In [7]:
from main.prepare_data import create_dataloaders
from dataclasses import dataclass
from main.models import ConvNet
from torchvision import datasets, transforms
import numpy as np
import torch
import laplace
import torch.nn.functional as F
from sklearn.cluster import KMeans

%reload_ext autoreload
%autoreload 2

In [2]:
# Experiment configuration: one dataclass holding every tunable, plus the
# instance that the rest of the notebook passes around.
@dataclass
class ActiveLearningConfigMNIST:
    """Settings bundle for an active-learning run.

    Every field has a default, so ``ActiveLearningConfigMNIST()`` is a complete
    configuration. Field order is part of the interface (it fixes the order of
    positional constructor arguments and of the generated ``repr``).
    """

    # Names match laplace.Laplace keyword arguments -- presumably forwarded
    # there by project code; confirm against the caller.
    subset_of_weights: str = "last_layer"
    hessian_structure: str = "kron"
    backend: str = "AsdlGGN"
    temperature: float = 1

    # Labelling budget and acquisition settings.
    max_training_samples: int = 500
    acquisition_batch_size: int = 100
    al_method: str = "random"

    test_batch_size: int = 512
    num_classes: int = 10
    num_initial_samples: int = 50
    training_iterations: int = 6 * 4096

    scoring_batch_size: int = 64
    train_batch_size: int = 64

    # Number of samples to extract from the dataset (bit of a hack).
    extract_pool: int = 55000
    dataset: str = "fashion_mnist"


config = ActiveLearningConfigMNIST()

In [3]:
# Build the four data loaders from the shared config. In the cells shown here
# only pool_loader is used afterwards (its .dataset is handed to badge_selection).
train_loader, test_loader, pool_loader, active_loader = create_dataloaders(config=config)



In [4]:
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.cluster import KMeans

def _kmeans_pp_indices(embeddings, k, rng):
    """Pick ``k`` distinct row indices of ``embeddings`` by k-means++ (D^2) seeding."""
    n_points = embeddings.shape[0]

    # First centre: uniformly at random.
    first = int(rng.integers(n_points))
    chosen = [first]
    available = np.ones(n_points, dtype=bool)
    available[first] = False

    # Squared distance from every point to its nearest already-chosen centre.
    min_sq_dist = np.sum((embeddings - embeddings[first]) ** 2, axis=1)

    while len(chosen) < k:
        # Already-chosen rows get weight 0 so an index can never be drawn twice.
        weights = np.where(available, min_sq_dist, 0.0)
        total = weights.sum()
        if total > 0.0 and np.isfinite(total):
            nxt = int(rng.choice(n_points, p=weights / total))
        else:
            # Degenerate case: every remaining point coincides with a chosen
            # centre (or distances are not finite) -- fall back to a uniform
            # draw among the rows that are still available.
            nxt = int(rng.choice(np.flatnonzero(available)))
        chosen.append(nxt)
        available[nxt] = False
        min_sq_dist = np.minimum(
            min_sq_dist, np.sum((embeddings - embeddings[nxt]) ** 2, axis=1)
        )

    return np.asarray(chosen, dtype=np.int64)


def badge_selection(model, pool_dataset, batch_size, random_state=None):
    r"""
    Implement the BADGE selection strategy for a single iteration.

    For every pool example the model's own prediction is used as a hypothetical
    label, and the gradient of the cross-entropy loss w.r.t. the first tensor
    yielded by ``model.get_last_layer_parameters()`` is taken as that example's
    embedding. ``batch_size`` examples are then drawn with k-means++ seeding
    (D^2 sampling) over those embeddings, so the returned indices are always
    distinct.

    Args:
    - model: PyTorch model exposing ``get_last_layer_parameters()``.
      It is switched to eval mode and left in that mode.
    - pool_dataset: Iterable of ``(x, label)`` pairs for the unlabeled examples
      (U \ S); ``x`` must be a tensor without a batch dimension.
    - batch_size: Number of examples to select (B)
    - random_state: Optional seed (or ``numpy.random.Generator``) that makes the
      selection reproducible. ``None`` draws fresh entropy.

    Returns:
    - indices: int64 array of ``batch_size`` distinct positions in
      ``pool_dataset`` to be added to S

    Raises:
    - ValueError: if ``batch_size`` is smaller than 1 or larger than the pool,
      or if the model exposes no last-layer parameters.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    model.eval()

    # Materialise once: the accessor may return a one-shot generator, and the
    # parameter objects do not change while the pool is being scored.
    last_layer_params = list(model.get_last_layer_parameters())
    if not last_layer_params:
        raise ValueError("model.get_last_layer_parameters() yielded no parameters")
    embed_param = last_layer_params[0]
    device = embed_param.device

    gradient_embeddings = []

    # Compute gradient embeddings for all examples in U \ S
    for x, _ in pool_dataset:
        # Add the batch dimension and make sure the input lives with the model.
        x = x.unsqueeze(0).to(device)

        output = model(x)

        # Hypothetical label: the model's current prediction.
        y_hat = output.argmax(dim=1)
        loss = F.cross_entropy(output, y_hat)

        # Only the first last-layer tensor contributes to the embedding.
        (grad_embedding,) = torch.autograd.grad(loss, embed_param)
        gradient_embeddings.append(grad_embedding.detach().cpu().numpy().ravel())

    if batch_size > len(gradient_embeddings):
        raise ValueError(
            f"batch_size ({batch_size}) exceeds the number of pool examples "
            f"({len(gradient_embeddings)})"
        )

    # float64 keeps the squared-distance arithmetic well away from overflow.
    gradient_embeddings = np.asarray(gradient_embeddings, dtype=np.float64)

    rng = np.random.default_rng(random_state)
    return _kmeans_pp_indices(gradient_embeddings, batch_size, rng)


In [5]:
# Draw one acquisition batch from the pool. The batch size is read from the
# shared config instead of repeating the literal, so the two cannot drift apart.
# NOTE(review): ConvNet() is constructed fresh here, so this presumably scores
# the pool with untrained weights -- confirm that is intended.
badge_selection(
    model=ConvNet(),
    pool_dataset=pool_loader.dataset,
    batch_size=config.acquisition_batch_size,
)

array([ 445, 1164, 3554, 3170, 1052,  355, 2336, 4115, 4125,  111, 4165,
       2921, 3434, 1353, 2835, 3479,  345, 1899,  450,  674, 4548,  901,
        317, 2106, 1664, 3300,  313, 4729, 2574, 2615, 4276, 1351, 4316,
       4032, 1674, 3775, 3005, 2518, 2480, 3642, 3839,  737, 2189,  392,
       2565, 4663, 1952, 4427, 2582, 3513, 3306, 3378,  508, 4056, 2537,
       2426,  853, 4551, 3044, 4499, 1719,  699,  787, 3025, 2787, 2255,
       4841, 2358, 4559, 2415, 1120, 3437,   96, 4790, 4795, 2748, 3668,
        341, 1942, 3140,  974, 2843, 1194, 3485, 3700, 1565,  977, 3951,
       1242, 2383, 2531,  849, 4376, 1434, 3698, 3634, 1327, 2552, 3601,
       3474], dtype=int64)

In [44]:
# Build the Laplace approximation from the shared config rather than relying on
# the library defaults, so changes to `config` actually reach this call.
# NOTE(review): config.backend is a string and config.temperature is not passed
# here; wire them through only once their expected types are confirmed.
la = laplace.Laplace(
    model=ConvNet(),
    likelihood='classification',
    subset_of_weights=config.subset_of_weights,
    hessian_structure=config.hessian_structure,
)

In [43]:
# Inspect what get_last_layer_parameters() returns. The recorded output shows a
# generator, so it has to be materialised (e.g. list(...)) to see the tensors.
ConvNet().get_last_layer_parameters()

<generator object Module.parameters at 0x0000027CD3EFECE0>

In [51]:
# Expose the final two entries of the wrapped model's parameter list as a
# one-shot generator. The slice is evaluated immediately; only the iteration
# over it is lazy.
# NOTE(review): presumably these are the last layer's weight and bias -- confirm.
last_layer = (param for param in list(la.model.parameters())[-2:])

In [52]:
# Display the object; the recorded output confirms it is a generator.
last_layer

<generator object <genexpr> at 0x0000027CD3995E40>