In [10]:
%run constants.ipynb
%run dataloaders.ipynb

In [11]:
class CoresetAlg:
    """Base class for coreset selection strategies.

    Holds the coresets collected so far and the number of samples each
    coreset is meant to contain.

    Attributes:
        coresets: list that coreset loaders are appended to (starts empty).
        coreset_size: requested number of samples per coreset.
    """

    def __init__(self, coreset_size=200):
        """Store the requested coreset size and start with no coresets."""
        self.coreset_size = coreset_size
        self.coresets = []

    def add_coreset(self, dataloader):
        """Do nothing; the base class selects no coreset.

        Args:
            dataloader: loader for the current task (unused here).

        Returns:
            None.
        """
        return None

In [12]:
class RandomCoresetAlg(CoresetAlg):
    """Coreset strategy that picks coreset samples uniformly at random."""

    def __init__(self, coreset_size=200):
        """Store the requested coreset size and start with no coresets."""
        super().__init__(coreset_size)

    def add_coreset(self, dataloader):
        """Move a random subset of the task's samples into a new coreset.

        Up to ``self.coreset_size`` indices are drawn at random from
        ``dataloader.sampler.indices``. They are removed from the dataloader's
        sampler (which is mutated in place) and wrapped in a new ``DataLoader``
        over the same dataset, which is appended to ``self.coresets``.

        Args:
            dataloader: loader whose sampler exposes a writable ``indices``
                attribute holding the dataset indices of the current task.
        """
        task_indices = dataloader.sampler.indices
        # A plain Python sequence cannot be indexed by the permutation tensor
        # below, so normalise it first. Tensors / arrays are left untouched.
        if isinstance(task_indices, (list, tuple, range)):
            task_indices = torch.as_tensor(list(task_indices), dtype=torch.long)

        permutation = torch.randperm(len(task_indices))
        shuffled_indices = task_indices[permutation]

        # Split into coreset and remaining data
        core_indices = shuffled_indices[:self.coreset_size]
        remaining_indices = shuffled_indices[self.coreset_size:]
        dataloader.sampler.indices = remaining_indices

        # Create coreset loader
        coreset_loader = DataLoader(
            dataloader.dataset,
            batch_size=dataloader.batch_size,
            sampler=SubsetRandomSampler(core_indices)
        )
        self.coresets.append(coreset_loader)

In [13]:
class KCenterCoresetAlg(CoresetAlg):
    """Coreset strategy that picks samples with the greedy k-center heuristic."""

    def __init__(self, coreset_size=200):
        """Store the requested coreset size and start with no coresets."""
        super().__init__(coreset_size)

    def add_coreset(self, dataloader):
        """ Adds a coreset selected by greedy k-center algorithm to the existing set of coresets.

        Starting from the first sample of the task, the sample farthest
        (Euclidean distance) from all already chosen centers is added
        repeatedly until ``self.coreset_size`` centers are chosen (or the task
        runs out of samples). The chosen samples are removed from the
        dataloader's sampler (mutated in place) and wrapped in a new
        ``DataLoader`` appended to ``self.coresets``.

        Args:
            dataloader: loader whose sampler exposes a writable ``indices``
                attribute with the dataset indices of the current task, and
                whose dataset exposes the raw samples as ``dataset.data``.
        """
        raw_indices = dataloader.sampler.indices  # filter for target classes in this task
        task_indices = np.asarray(raw_indices).reshape(-1)
        num_points = len(task_indices)
        # Never ask for more centers than there are samples in the task.
        num_centers = max(0, min(self.coreset_size, num_points))

        # Positions (into the task subset, NOT into the dataset) of the centers.
        center_positions = []
        if num_centers > 0:
            # Work in floating point: integer sample data (e.g. unsigned
            # image bytes) would otherwise wrap around when subtracted.
            X_train = np.asarray(dataloader.dataset.data[raw_indices], dtype=np.float64)

            # Initialize distances and the first point in the coreset
            dists = np.full(num_points, np.inf)
            current_index = 0
            dists = self._update_distances(dists, X_train, current_index)
            # Chosen centers get -inf so argmax can never select them again,
            # even when every remaining sample duplicates a center.
            dists[current_index] = -np.inf
            center_positions.append(current_index)

            # Select k-center points
            for _ in range(1, num_centers):
                current_index = int(np.argmax(dists))
                dists = self._update_distances(dists, X_train, current_index)
                dists[current_index] = -np.inf
                center_positions.append(current_index)

        # Translate subset positions back to real dataset indices.
        center_positions = np.asarray(center_positions, dtype=np.intp)
        is_center = np.zeros(num_points, dtype=bool)
        is_center[center_positions] = True
        core_indices = task_indices[center_positions].tolist()

        coreset_loader = DataLoader(
            dataloader.dataset,
            batch_size=dataloader.batch_size,
            sampler=SubsetRandomSampler(core_indices)
        )

        # Update the original dataloader's sampler to exclude coreset points
        remaining_indices = task_indices[~is_center].tolist()
        dataloader.sampler.indices = remaining_indices
        self.coresets.append(coreset_loader)

    def _update_distances(self, dists, X_train, current_index):
        """ Updates the distance array with the minimum distance
        to the current center (with index `current_index`).

        Args:
            dists: per-sample distance to the nearest center so far; updated
                in place.
            X_train: samples of the task, first axis indexing the samples.
                Any trailing axes are flattened before measuring distance.
            current_index: position in ``X_train`` of the newly added center.

        Returns:
            The updated ``dists``.
        """
        points = np.asarray(X_train, dtype=np.float64)
        if points.shape[0] == 0:
            return dists
        # Flattening keeps the same value as the norm of the unflattened
        # difference while allowing a single vectorised computation.
        points = points.reshape(points.shape[0], -1)
        to_center = np.linalg.norm(points - points[current_index], axis=1)
        dists[:] = np.minimum(dists, to_center)
        return dists