In [14]:
from sklearn.base import BaseEstimator, clone
from scipy import sparse
import numpy as np

In [16]:
class ArchetypeEstimator(BaseEstimator):
    """
    An abstract estimator that applies some transformation on data that has
    a fuzzy membership to a given cluster (or archetype).

    ``base_embedder.transform(X)`` must return one membership score per
    archetype (one column each) and every row must sum to 1. One clone of
    ``final_transformer`` is fitted per archetype, using only the rows whose
    membership to that archetype is strictly positive. When
    ``use_membership_weights`` is true, the membership score of each row is
    used as its sample weight; if a user defined ``sample_weight`` is passed
    as well, the final training weight is the product of both.

    Parameters
    ----------
    base_embedder : estimator
        Object exposing ``fit`` and ``transform``; ``transform`` yields the
        membership matrix.
    final_transformer : estimator
        Template estimator, cloned and fitted once per archetype.
    prefit_embedder : bool, default=False
        If true, ``base_embedder`` is assumed to be already fitted and is
        used as-is instead of being cloned and refitted.
    use_membership_weights : bool, default=True
        Whether membership scores are used as sample weights during ``fit``
        and as output weights in ``_infer_reduce``.

    Attributes
    ----------
    archetype_estimator_list_ : list
        Fitted estimators, one per archetype (set by ``fit``).
    base_embedder_ : estimator
        The embedder used to compute memberships (set by ``fit``).
    n_archetypes_ : int
        Number of columns of the membership matrix seen in ``fit``.
    """

    def __init__(
        self,
        base_embedder,
        final_transformer,
        prefit_embedder = False,
        use_membership_weights = True,
    ):
        self.base_embedder = base_embedder
        self.final_transformer = final_transformer
        self.prefit_embedder = prefit_embedder
        self.use_membership_weights = use_membership_weights

    def fit(self, X, y = None, sample_weight = None, **kwargs):
        """
        Fit the embedder (unless prefit) and one estimator per archetype.

        ``**kwargs`` is accepted for signature compatibility and is not
        forwarded anywhere. Returns ``self``.

        Raises ``ValueError`` if some membership row does not sum up to 1.
        """
        if self.prefit_embedder:
            # clone() returns an *unfitted* copy, which would discard the
            # learned state of a prefit embedder; use the object itself.
            base_embedder = self.base_embedder
        else:
            base_embedder = clone(self.base_embedder)
            if sample_weight is None:
                # to ensure it works with embedders that do not accept a
                # sample_weight parameter in fit
                base_embedder.fit(X, y=y)
            else:
                base_embedder.fit(X, y=y, sample_weight=sample_weight)

        memberships = self._compute_memberships(base_embedder, X)

        n_archetypes = memberships.shape[-1]
        archetype_estimator_list = []
        for i in range(n_archetypes):
            estim = clone(self.final_transformer)
            X_sample, y_sample, weights, _ = self._get_subset_and_weights(
                X=X,
                y=y,
                membership=memberships[:, i],
                sample_weight=sample_weight,
                use_membership_weights=self.use_membership_weights,
            )

            if weights is not None:
                estim.fit(X=X_sample, y=y_sample, sample_weight=weights)
            else:
                # to ensure it works with estimators that do not accept a
                # sample_weight parameter in fit
                estim.fit(X=X_sample, y=y_sample)

            archetype_estimator_list.append(estim)

        # save states
        self.archetype_estimator_list_ = archetype_estimator_list
        self.base_embedder_ = base_embedder
        self.n_archetypes_ = n_archetypes
        return self

    def _compute_memberships(self, embedder, X):
        """
        Return ``embedder.transform(X)`` as a dense 2-D array, validated so
        that every row sums to 1 (up to floating point tolerance).
        """
        memberships = embedder.transform(X)
        if sparse.issparse(memberships):
            memberships = memberships.toarray()
        # Plain ndarray so that memberships[:, i] is always 1-D.
        memberships = np.asarray(memberships)

        # Exact equality would reject valid rows that differ from 1 only by
        # floating point rounding.
        if not np.isclose(memberships.sum(axis=1), 1).all():
            raise ValueError("Some membership rows do not sum up to 1")
        return memberships

    def _get_subset_and_weights(self, X, y, membership, sample_weight, use_membership_weights):
        """
        Return data instances and sample weights for membership > 0.

        Returns the tuple ``(X_sample, y_sample, weights, mask)`` where
        ``mask`` is the boolean row selector, ``y_sample`` is ``None`` when
        ``y`` is ``None`` and ``weights`` is ``None`` when neither
        ``sample_weight`` nor membership weighting applies.
        """
        mask = membership > 0
        X_sample = X[mask]
        y_sample = y[mask] if y is not None else None

        if sample_weight is None:
            weights = membership[mask] if use_membership_weights else None
        else:
            # asarray so that list-like weights support boolean masking
            weights = np.asarray(sample_weight)[mask]
            if use_membership_weights:
                weights = weights * membership[mask]

        return X_sample, y_sample, weights, mask

    def _infer_per_archetype(self, infer_method, X, apply_weights, **kwargs):
        """
        Call ``infer_method`` of every archetype estimator on its own rows.

        Returns a ``(n_samples, n_archetypes_)`` LIL matrix whose entry
        ``[r, i]`` holds the output of archetype ``i`` for row ``r`` and is
        left empty where row ``r`` has no positive membership to ``i``.
        When ``apply_weights`` is true and membership weights are enabled,
        each output is multiplied by the row's membership score.
        """
        memberships = self._compute_memberships(self.base_embedder_, X)
        results = sparse.lil_matrix((X.shape[0], self.n_archetypes_), dtype=np.float32)

        for i in range(self.n_archetypes_):
            estim = self.archetype_estimator_list_[i]
            X_sample, _, weights, mask = self._get_subset_and_weights(
                X=X,
                y=None,
                membership=memberships[:, i],
                sample_weight=None,
                use_membership_weights=self.use_membership_weights,
            )
            if not mask.any():
                # No row of this batch belongs to the archetype; there is
                # nothing to evaluate and the column stays empty.
                continue

            # Evaluate on the selected rows only, so the output length
            # matches the number of rows written through `mask`.
            res = np.ravel(getattr(estim, infer_method)(X_sample, **kwargs))

            if apply_weights and weights is not None:
                res = res * weights

            results[mask, i] = res

        return results

    def _infer_matrix(self, infer_method, X, **kwargs):
        """
        Return the unweighted per-archetype outputs of ``infer_method`` as a
        sparse ``(n_samples, n_archetypes_)`` LIL matrix.

        Raises ``ValueError`` if some membership row does not sum up to 1.
        """
        return self._infer_per_archetype(infer_method, X, False, **kwargs)

    def _infer_reduce(self, infer_method, X, **kwargs):
        """
        Return the row-wise sum over archetypes of the outputs of
        ``infer_method``, each multiplied by the membership score when
        ``use_membership_weights`` is true.

        The result is what ``scipy.sparse`` returns for ``.sum(1)``: a
        dense column with one entry per sample.

        Raises ``ValueError`` if some membership row does not sum up to 1.
        """
        results = self._infer_per_archetype(infer_method, X, True, **kwargs)
        return results.sum(1)

In [9]:
# Scratch constant for the sparse-matrix experiment below.
# NOTE(review): presumably the side length of M (reported as 1000x1000 in
# the output); the cell that creates M is not shown here - confirm.
N = 1000


In [12]:
# Fancy-index assignment: set column 0 of rows 1, 2, 3, 6 and 12 to 1.
# NOTE(review): M is created in a cell that is not shown here; the output
# below reports it as a sparse matrix in List of Lists format.
M[[1,2,3,6,12],0] = 1

In [13]:
# Display M's repr to check its shape, dtype and number of stored elements.
M

<1000x1000 sparse matrix of type '<class 'numpy.float64'>'
	with 5 stored elements in List of Lists format>