### This notebook gives an overview of the general methodology so that the core contribution can be followed and reproduced.

#### Preservation Set Construction


In [10]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchcam.methods import GradCAM
from sklearn.cluster import KMeans
import numpy as np
import os  
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import math
from tqdm import trange
import os
from PIL import Image
import torchvision


In [3]:
# Load the model or define and train one  
#''' model '''

In [4]:
def load_model(model_class, model_path):
    """Build a model and populate it with weights stored on disk.

    The model is placed on the GPU when CUDA is available and on the CPU
    otherwise; the checkpoint is mapped onto that same device while loading.

    Args:
        model_class: Zero-argument callable that returns the model to fill.
        model_path: Location of the saved ``state_dict`` to load.

    Returns:
        The model, on the selected device, with the stored weights loaded.
        Its train/eval mode is left as constructed; callers that need
        inference behaviour must call ``eval()`` themselves.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    net = model_class()
    net = net.to(device)

    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    return net


### Grad-CAM Intensity Calculation

Grad-CAM (Gradient-weighted Class Activation Mapping) visualizes which regions of an input image influence the model's prediction the most. For each image we compute the activation map for the class the model predicts and sum it into a single intensity score. Ranking images by this score lets us select the examples with the strongest activations, so that we focus on the most critical examples.



In [5]:
def compute_gradcam_intensity(images, labels, model, cam_layer='specify the layer to cam at'):
    """Score every image by the summed intensity of its Grad-CAM map.

    Each image is passed through ``model`` on its own (batch of one), the
    Grad-CAM map for the *predicted* class is extracted at ``cam_layer``, and
    the map is summed into a single scalar.

    The model is run in evaluation mode: Grad-CAM only needs gradients, which
    are available regardless of train/eval mode, whereas train mode would make
    Dropout stochastic and make BatchNorm normalise with (and update its
    running statistics from) a single image. The caller's train/eval mode is
    restored before returning.

    Args:
        images: Iterable of image tensors, one per example (e.g. a batch
            tensor iterated along its first dimension).
        labels: Iterable of labels aligned with ``images``; passed through
            unchanged.
        model: The network to explain; must have at least one parameter.
        cam_layer: Target layer handed to ``GradCAM``. The default is a
            placeholder and must be replaced with a real layer of ``model``.

    Returns:
        A list of ``(image_on_cpu, intensity, label)`` tuples in input order.
    """
    device = next(model.parameters()).device
    activations = []

    # Remember the caller's mode so this function has no lasting side effect.
    was_training = model.training
    model.eval()
    try:
        with GradCAM(model, target_layer=cam_layer) as cam_extractor:
            for image, label in zip(images, labels):
                image = image.to(device).requires_grad_(True)
                # Force autograd on even if the caller wrapped us in no_grad().
                with torch.set_grad_enabled(True):
                    output = model(image.unsqueeze(0))
                    prediction = output.argmax(dim=1).item()
                    cam = cam_extractor(prediction, output)
                    # The extractor's result is indexed at 0 here — presumably
                    # one map per target layer; confirm against torchcam docs.
                    intensity = cam[0].sum().item()
                activations.append((image.detach().cpu(), intensity, label))
    finally:
        model.train(was_training)

    return activations


### Uncertainty Sampling

Uncertainty sampling is a technique used to select examples for which the model is least confident in its predictions. Here, uncertainty is measured as one minus the largest softmax probability, and an example is kept when that value exceeds a threshold. This is crucial for creating a robust dataset, as it focuses on examples where the model might make mistakes. By selecting the images with high uncertainty (low confidence), we gather examples that help refine and improve the model's performance.


In [6]:
def get_uncertain_examples(images, labels, model, threshold="hyperparameter to search"):
    """Return the examples whose top softmax probability is low.

    Uncertainty is defined as ``1 - max(softmax(logits))`` per example; an
    example is selected when its uncertainty is strictly greater than
    ``threshold``.

    Args:
        images: Batch tensor of inputs accepted by ``model``.
        labels: Tensor of labels aligned with ``images`` along dimension 0.
        model: Classifier returning per-class logits with classes on
            dimension 1. It is switched to eval mode and left there.
        threshold: Numeric uncertainty cut-off. The default is a placeholder
            string and must be replaced with a number.

    Returns:
        ``(uncertain_examples, uncertain_labels)``: two aligned lists of CPU
        tensors, in batch order.

    Raises:
        TypeError: If ``threshold`` is still a string (the placeholder).
    """
    # Fail before the forward pass instead of with an opaque "Tensor > str"
    # error after all the work has been done.
    if isinstance(threshold, str):
        raise TypeError(
            "threshold must be a number (uncertainty cut-off in [0, 1]), "
            f"got placeholder string {threshold!r}"
        )

    device = next(model.parameters()).device
    images = images.to(device)
    labels = labels.to(device)

    model.eval()
    with torch.no_grad():
        probabilities = F.softmax(model(images), dim=1)
        uncertainties = 1 - probabilities.max(dim=1)[0]
        # One vectorised comparison instead of a Python-level loop.
        selected = torch.nonzero(uncertainties > threshold).flatten().tolist()
        uncertain_examples = [images[i].cpu() for i in selected]
        uncertain_labels = [labels[i].cpu() for i in selected]

    return uncertain_examples, uncertain_labels


### Clustering-Based Projection

Clustering-based projection groups examples into clusters based on feature embeddings. This ensures that the selected examples represent a diverse range of images, minimizing redundancy in the dataset. By applying clustering, we can ensure that the dataset contains varied instances from different regions of the feature space, enhancing the overall quality of the selection process.


In [7]:
def get_embedding(model, image):
    """Compute a flattened convolutional feature vector for a single image.

    The image is given a batch dimension and passed through
    ``conv1 -> relu -> conv2 -> relu -> conv3`` of ``model`` with gradients
    disabled. Any other conventional feature-extraction method can be used
    instead; the right choice differs between, for example, CNNs and
    attention blocks.

    Args:
        model: Network exposing ``conv1``, ``conv2``, ``conv3`` and ``relu``.
            It is switched to eval mode and left there.
        image: A single image tensor without a batch dimension.

    Returns:
        A 2-D tensor whose first dimension is the (size-one) batch and whose
        second dimension holds the flattened ``conv3`` output.
    """
    device = next(model.parameters()).device
    model.eval()
    batch = image.to(device).unsqueeze(0)
    with torch.no_grad():
        hidden = model.relu(model.conv1(batch))
        hidden = model.relu(model.conv2(hidden))
        features = model.conv3(hidden)
        return features.view(features.size(0), -1)

def get_diverse_examples(images, labels, model, num_clusters=10, random_state=None):
    """Pick one representative example per k-means cluster of embeddings.

    Every image is embedded with ``get_embedding``, the embeddings are
    clustered with k-means, and the first example (in input order) assigned
    to each cluster is returned.

    Args:
        images: Iterable of single-image tensors.
        labels: Iterable of label tensors aligned with ``images``.
        model: Network passed to ``get_embedding``.
        num_clusters: Upper bound on the number of clusters; it is capped at
            the number of examples.
        random_state: Optional seed forwarded to ``KMeans`` so the selection
            is reproducible. ``None`` keeps scikit-learn's default behaviour.

    Returns:
        ``(selected_images, selected_labels)``: aligned lists of CPU tensors,
        ordered by cluster id. Both are empty when no examples are given.
    """
    embeddings = []
    images_list = []
    labels_list = []
    for image, label in zip(images, labels):
        embedding = get_embedding(model, image)
        # reshape(-1) (not squeeze()) keeps a single-feature embedding 1-D so
        # the stacked array is always (n_samples, n_features).
        embeddings.append(embedding.reshape(-1).cpu().numpy())
        images_list.append(image.cpu())
        labels_list.append(label.cpu())

    # KMeans cannot be fitted on zero samples; an empty batch selects nothing.
    if not embeddings:
        return [], []

    embeddings = np.array(embeddings)
    kmeans = KMeans(
        n_clusters=min(num_clusters, len(embeddings)),
        random_state=random_state,
    )
    clusters = kmeans.fit_predict(embeddings)

    # Single pass: remember the first example index seen for each cluster.
    first_index = {}
    for index, cluster in enumerate(clusters):
        first_index.setdefault(int(cluster), index)

    selected_images = []
    selected_labels = []
    for cluster in sorted(first_index):
        chosen = first_index[cluster]
        selected_images.append(images_list[chosen])
        selected_labels.append(labels_list[chosen])

    return selected_images, selected_labels


#### Sort and Select Top examples 

In [8]:
def select_top_examples(testloader, model, num_examples="hyperparameter to search",
                        cam_layer=None, threshold=None, num_clusters=None):
    """Build a de-duplicated selection of Grad-CAM, uncertain and diverse examples.

    For every batch the three selectors are run. Afterwards the
    ``num_examples // 2`` highest Grad-CAM intensities are kept, followed by
    all uncertain examples and then all diverse examples. Exact duplicate
    images are dropped (first occurrence wins) and the result is truncated to
    ``num_examples`` items.

    Args:
        testloader: Iterable yielding ``(images, labels)`` batches.
        model: Network handed to the three selector functions.
        num_examples: Integer cap on the number of returned examples. The
            default is a placeholder string and must be replaced.
        cam_layer: Optional override forwarded to ``compute_gradcam_intensity``.
        threshold: Optional override forwarded to ``get_uncertain_examples``.
        num_clusters: Optional override forwarded to ``get_diverse_examples``.
            When an override is ``None`` the helper's own default is used.

    Returns:
        ``(unique_images, unique_labels)``: two aligned lists.

    Raises:
        TypeError: If ``num_examples`` is still a string (the placeholder).
    """
    # Fail before the expensive per-batch work rather than at ``// 2`` below.
    if isinstance(num_examples, str):
        raise TypeError(
            "num_examples must be an integer, "
            f"got placeholder string {num_examples!r}"
        )

    gradcam_kwargs = {} if cam_layer is None else {"cam_layer": cam_layer}
    uncertain_kwargs = {} if threshold is None else {"threshold": threshold}
    diverse_kwargs = {} if num_clusters is None else {"num_clusters": num_clusters}

    all_activations = []
    # Each selector keeps its own (example, label) pairs. Sharing one label
    # list across selectors interleaves labels per batch and misaligns them
    # with the examples once there is more than one batch.
    uncertain_pairs = []
    diverse_pairs = []

    for images, labels in testloader:
        all_activations.extend(
            compute_gradcam_intensity(images, labels, model, **gradcam_kwargs)
        )

        uncertain_examples, uncertain_labels = get_uncertain_examples(
            images, labels, model, **uncertain_kwargs
        )
        uncertain_pairs.extend(zip(uncertain_examples, uncertain_labels))

        diverse_examples, diverse_labels = get_diverse_examples(
            images, labels, model, **diverse_kwargs
        )
        diverse_pairs.extend(zip(diverse_examples, diverse_labels))

    sorted_activations = sorted(all_activations, key=lambda item: item[1], reverse=True)
    gradcam_pairs = [(img, label) for img, _, label in sorted_activations[:num_examples // 2]]

    combined_pairs = gradcam_pairs + uncertain_pairs + diverse_pairs

    unique_dict = {}
    for img, label in combined_pairs:
        array = img.numpy()
        # Shape plus raw bytes identifies an image without materialising a
        # Python tuple holding every pixel.
        key = (array.shape, array.tobytes())
        if key not in unique_dict:
            unique_dict[key] = (img, label)

    unique_items = list(unique_dict.values())[:num_examples]
    unique_images = [item[0] for item in unique_items]
    unique_labels = [item[1] for item in unique_items]

    return unique_images, unique_labels


#### CNN – MNIST Modeling Approach