### This notebook gives an overview of the general methodology, to help track and reproduce the core contribution.

#### Preservation Set Construction


In [10]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchcam.methods import GradCAM
from sklearn.cluster import KMeans
import numpy as np
import os  
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import math
from tqdm import trange
import os
from PIL import Image
import torchvision


In [3]:
# Load the model or define and train one  
#''' model '''

In [4]:
def load_model(model_class, model_path):
    """Build ``model_class()`` and restore its parameters from ``model_path``.

    The model is placed on CUDA when available, otherwise on CPU, and the
    saved state dict is mapped onto that same device while loading.

    Args:
        model_class: zero-argument callable returning an ``nn.Module``.
        model_path: path to a saved state dict (loaded with ``torch.load``).

    Returns:
        The instantiated module with the loaded parameters.
    """
    use_cuda = torch.cuda.is_available()
    target = torch.device("cuda") if use_cuda else torch.device("cpu")
    net = model_class().to(target)
    state = torch.load(model_path, map_location=target)
    net.load_state_dict(state)
    return net


### Grad-CAM Intensity Calculation

Grad-CAM (Gradient-weighted Class Activation Mapping) helps in visualizing which parts of the input image are influencing the model's prediction the most. By applying Grad-CAM, we compute activation maps that highlight important regions in the image for a given class. This allows us to rank and select images based on the intensity of these activation maps, ensuring we focus on the most critical examples.



In [5]:
def compute_gradcam_intensity(images, labels, model, cam_layer='specify the layer to cam at'):
    """Score each image by the total Grad-CAM map of its predicted class.

    Args:
        images: iterable of image tensors without a batch dimension (one is
            added with ``unsqueeze(0)`` before the forward pass).
        labels: iterable aligned with ``images``; each label is passed
            through unchanged into the result.
        model: network to explain; inputs are moved to the device of its
            first parameter.
        cam_layer: target layer handed to torchcam's ``GradCAM``. The default
            is a placeholder string and must be replaced with a real layer of
            ``model`` before use.

    Returns:
        A list of ``(image, intensity, label)`` tuples, where ``image`` is a
        detached CPU copy of the input and ``intensity`` is the sum of the CAM
        for the predicted class, as a Python float.

    The model's train/eval mode is restored on exit, even if an error occurs,
    so calling this function has no lasting effect on the model's mode.
    """
    device = next(model.parameters()).device
    was_training = model.training
    activations = []
    try:
        with GradCAM(model, target_layer=cam_layer) as cam_extractor:
            # NOTE(review): the reason for switching to train mode here is not
            # visible in this notebook; kept as in the original implementation.
            model.train()
            for image, label in zip(images, labels):
                image = image.to(device).requires_grad_(True)
                with torch.set_grad_enabled(True):
                    output = model(image.unsqueeze(0))
                    prediction = output.argmax(dim=1).item()
                    # NOTE(review): torchcam may normalize the map per image by
                    # default, in which case this measures spatial extent rather
                    # than absolute magnitude -- confirm for the installed version.
                    cam = cam_extractor(prediction, output)
                    intensity = cam[0].sum().item()
                    activations.append((image.cpu().detach(), intensity, label))
    finally:
        # Do not leak the temporary train-mode switch to the caller.
        model.train(was_training)
    return activations


### Uncertainty Sampling

Uncertainty sampling is a technique used to select examples for which the model is least confident in its predictions. This is crucial for creating a robust dataset, as it focuses on examples where the model might make mistakes. By selecting the images with high uncertainty (low confidence), we gather examples that help refine and improve the model's performance.


In [6]:
def get_uncertain_examples(images, labels, model, threshold="hyperparameter to search"):
    """Return the examples on which the model's top-class confidence is low.

    Uncertainty is defined as ``1 - max softmax probability``. An example is
    kept when its uncertainty is strictly greater than ``threshold``. The
    default ``threshold`` is a placeholder string, not a usable number.

    Args:
        images: batched input tensor (the first dimension indexes examples).
        labels: tensor of labels aligned with ``images``.
        model: network producing class logits. It is switched to eval mode and
            run without gradients on the device of its first parameter.
        threshold: uncertainty cut-off.

    Returns:
        ``(examples, labels)``: two lists of CPU tensors, in input order.
    """
    target = next(model.parameters()).device
    batch = images.to(target)
    targets = labels.to(target)
    model.eval()
    with torch.no_grad():
        confidence = F.softmax(model(batch), dim=1).max(dim=1)[0]
        uncertainty = 1 - confidence
        keep = [idx for idx, u in enumerate(uncertainty) if u > threshold]
    return [batch[idx].cpu() for idx in keep], [targets[idx].cpu() for idx in keep]


### Clustering-Based Projection

Clustering-based projection groups examples into clusters based on feature embeddings. This ensures that the selected examples represent a diverse range of images, minimizing redundancy in the dataset. By applying clustering, we can ensure that the dataset contains varied instances from different regions of the feature space, enhancing the overall quality of the selection process.


In [7]:
def get_embedding(model, image):
    """Return a flattened feature embedding for a single ``image``.

    The feature extractor is chosen as follows:

    * If ``model`` exposes ``conv1``, ``conv2``, ``conv3`` and ``relu``, the
      embedding is ``conv3(relu(conv2(relu(conv1(x)))))``.
    * Otherwise, if it has a ``features`` module (like the ``Model`` class in
      this notebook), ``features(x)`` is used.
    * Otherwise an ``AttributeError`` is raised. Adapt this function for other
      architectures (e.g. attention blocks), where feature extraction differs.

    Args:
        model: ``nn.Module``. It is switched to eval mode and run without
            gradients on the device of its first parameter.
        image: a single input without a batch dimension; a batch dimension of
            size 1 is added internally.

    Returns:
        Tensor of shape ``(1, D)``.
    """
    device = next(model.parameters()).device
    batch = image.to(device).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        if all(hasattr(model, name) for name in ("conv1", "conv2", "conv3", "relu")):
            output = model.conv3(model.relu(model.conv2(model.relu(model.conv1(batch)))))
        elif hasattr(model, "features"):
            output = model.features(batch)
        else:
            raise AttributeError(
                "get_embedding needs a model with conv1/conv2/conv3/relu "
                "or a `features` module"
            )
        # flatten(1) also handles non-contiguous outputs, where view() would fail.
        return output.flatten(start_dim=1)

def get_diverse_examples(images, labels, model, num_clusters=10, random_state=None):
    """Pick one representative example per KMeans cluster of model embeddings.

    Each image is embedded with ``get_embedding``. The embeddings are
    clustered into ``min(num_clusters, len(images))`` clusters, and the first
    member (in input order) of every non-empty cluster is returned.

    Args:
        images: iterable of single image tensors (no batch dimension).
        labels: iterable of label tensors aligned with ``images``.
        model: network passed to ``get_embedding``.
        num_clusters: requested number of clusters.
        random_state: forwarded to ``KMeans`` for reproducible clustering
            (``None`` keeps the previous, unseeded behavior).

    Returns:
        ``(selected_images, selected_labels)``: lists of CPU tensors, ordered
        by cluster id. Both are empty when no images are given.
    """
    embeddings, images_list, labels_list = [], [], []
    for image, label in zip(images, labels):
        # get_embedding returns shape (1, D); keep row 0 so that D == 1 still
        # yields a 1-D vector (squeeze() would collapse it to a scalar).
        embeddings.append(get_embedding(model, image)[0].cpu().numpy())
        images_list.append(image.cpu())
        labels_list.append(label.cpu())

    if not embeddings:
        # KMeans cannot be fitted on zero samples.
        return [], []

    n_clusters = min(num_clusters, len(embeddings))
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    clusters = kmeans.fit_predict(np.stack(embeddings))

    selected_images, selected_labels = [], []
    for cluster in range(n_clusters):
        members = np.flatnonzero(clusters == cluster)
        if members.size:
            first = members[0]
            selected_images.append(images_list[first])
            selected_labels.append(labels_list[first])
    return selected_images, selected_labels


#### Sort and Select Top examples 

In [8]:
def select_top_examples(testloader, model, num_examples="hyperparameter to search"):
    """Build the preservation set from Grad-CAM, uncertainty and diversity picks.

    Candidates, in priority order:
      1. the ``num_examples // 2`` images with the highest Grad-CAM intensity
         over the whole loader,
      2. every uncertain example (``get_uncertain_examples``, per batch),
      3. one example per KMeans cluster (``get_diverse_examples``, per batch).
    Candidates with byte-identical pixel data are deduplicated, keeping the
    first occurrence. The result is then truncated to ``num_examples``.

    NOTE(review): this relies on the helpers' default arguments
    (``cam_layer``, ``threshold``), which are placeholders in this notebook
    and must be configured before running.

    Args:
        testloader: iterable of ``(images, labels)`` batches.
        model: network used by all three selection strategies.
        num_examples: maximum size of the returned set (placeholder default).

    Returns:
        ``(images, labels)``: aligned lists of CPU tensors.
    """
    all_activations = []
    # Keep a separate label list per source so labels stay aligned with images
    # when the sources are concatenated below. (A single shared, interleaved
    # label list misaligns them as soon as there is more than one batch.)
    uncertain_images, uncertain_labels = [], []
    diverse_images, diverse_labels = [], []

    for images, labels in testloader:
        all_activations.extend(compute_gradcam_intensity(images, labels, model))

        batch_uncertain, batch_uncertain_labels = get_uncertain_examples(images, labels, model)
        uncertain_images.extend(batch_uncertain)
        uncertain_labels.extend(batch_uncertain_labels)

        batch_diverse, batch_diverse_labels = get_diverse_examples(images, labels, model)
        diverse_images.extend(batch_diverse)
        diverse_labels.extend(batch_diverse_labels)

    top_activations = sorted(all_activations, key=lambda item: item[1], reverse=True)
    top_activations = top_activations[:num_examples // 2]

    combined_images = [img for img, _, _ in top_activations] + uncertain_images + diverse_images
    combined_labels = [label for _, _, label in top_activations] + uncertain_labels + diverse_labels

    unique = {}
    for img, label in zip(combined_images, combined_labels):
        # Raw bytes (plus shape) are a compact, exact-content key; a tuple of
        # per-pixel numpy scalars costs tens of KB per image.
        key = (tuple(img.shape), img.numpy().tobytes())
        unique.setdefault(key, (img, label))

    selected = list(unique.values())[:num_examples]
    return [img for img, _ in selected], [label for _, label in selected]


#### CNN _ MNIST Modeling Approach

### Dataset Preparation and Transformation

The following block of code defines two sets of transformations for the MNIST dataset. The first transformation (`transform`) is applied to both the training and testing data, converting the images into tensors and normalizing them using the mean and standard deviation of the MNIST dataset.

The second transformation (`safety_transform`) is designed for the safety set and adds data augmentation: a random rotation (up to ±10°) and a random affine translation (up to 10% of the image size in each direction). These augmentations increase the diversity of the safety examples by introducing slight variations in the images, which helps improve the robustness of the model during training.

After defining the transformations, the MNIST dataset is loaded for both the training and testing sets. The entire dataset is then converted into tensors, `X_train` and `Y_train` for the training set, and `X_test` and `Y_test` for the test set, by using data loaders that load the entire dataset in one batch.


In [12]:
# MNIST normalization statistics (mean, std), one entry per channel.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Deterministic preprocessing for the train/test splits.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Augmented preprocessing for safety-set images: a small random rotation and
# translation, followed by the same tensor conversion and normalization.
safety_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])


def _load_all(dataset):
    """Return the entire ``dataset`` as a single ``(inputs, targets)`` batch."""
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset))
    return next(iter(loader))


train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

X_train, Y_train = _load_all(train_dataset)
X_test, Y_test = _load_all(test_dataset)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate in certificate chain (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:15<00:00, 654026.93it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate in certificate chain (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 321930.52it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate in certificate chain (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 553926.19it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate in certificate chain (_ssl.c:1000)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6348060.24it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



### Quantized Convolutional Layer

This block of code defines a custom quantized convolutional layer `QConv2d`, which applies quantization to the convolution weights. The class is designed to reduce the precision of the weights, effectively simulating a form of quantization during training. The key components are:

1. **Initialization**: The weights are drawn from a uniform distribution on `[-s, s]`, where `s = 1/sqrt(in_channels × kernel_height × kernel_width)`, so weight magnitudes scale with the fan-in. Two additional per-output-channel parameters, `e` and `b`, are also created. `e` (initialized to -8) is a base-2 exponent that sets the scale of the weights, and `b` (initialized to 32) is the number of bits used for quantization.

2. **qbits Function**: This function computes the total number of bits used by the layer's weights. It sums the `b` values over the output channels (after a ReLU, so negative bit counts contribute zero) and multiplies the sum by the number of weights per output channel (`self.weight[0].numel()`, i.e. `in_channels × kernel_height × kernel_width`). This gives the layer's total bit budget.

3. **qweight Function**: This function calculates the quantized version of the weights. It:
    - Ensures the bits in `b` are non-negative using ReLU.
    - Defines each output channel's signed integer range from `b`: `min_val = -2^(b-1)` and `max_val = 2^(b-1) - 1` (both are 0 when `b` is 0).
    - Scales the original weights by `2 ** -self.e`, expressing them in units of the quantization step `2 ** self.e`.
    - Clips the scaled weights to `[min_val, max_val]`. Rounding to integers is applied afterwards, in `forward`.

4. **forward Function**: During the forward pass, the weights are quantized using the `qweight` function. A "straight-through estimator" is applied, which rounds the weights but allows gradient flow through the non-quantized weights, preserving differentiability. The convolution is then performed using the quantized weights scaled by `2 ** self.e`.

This layer is useful for simulating quantization during training, which can help reduce model size and improve efficiency, especially in hardware-constrained environments.


In [15]:
class QConv2d(nn.Module):
    """Bias-free 2-D convolution with learnable per-channel quantization.

    Each output channel ``c`` owns two trainable scalars:

    * ``e[c]``: base-2 exponent of the quantization step. The effective
      weight is ``2 ** e[c] * q``, with ``q`` rounded to an integer.
    * ``b[c]``: bit depth. ``q`` is clipped to the signed range
      ``[-2 ** (b-1), 2 ** (b-1) - 1]``, which collapses to ``{0}`` once
      ``relu(b)`` reaches 0.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (one ``e``/``b`` pair each).
        kernel_size: int or pair giving the kernel height and width.
        stride: forwarded to ``conv2d``. The default of 1 matches the
            original behavior.
        padding: forwarded to ``conv2d``. The default of 0 matches the
            original behavior.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.kernel_size = (kernel_size, kernel_size) if isinstance(
            kernel_size, int) else tuple(kernel_size)
        # Plain attributes (not parameters/buffers), so state_dict keys stay weight/e/b.
        self.stride = stride
        self.padding = padding
        # Uniform init on [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        scale = 1 / math.sqrt(in_channels * math.prod(self.kernel_size))
        self.weight = nn.Parameter(torch.empty(
            out_channels, in_channels, *self.kernel_size).uniform_(-scale, scale))
        # Start with a fine step (2 ** -8) and 32 bits per weight.
        self.e = nn.Parameter(torch.full((out_channels, 1, 1, 1), -8.))
        self.b = nn.Parameter(torch.full((out_channels, 1, 1, 1), 32.))

    def qbits(self):
        """Return the total bit count: sum of relu(b) times the weights per channel."""
        return self.b.relu().sum() * self.weight[0].numel()

    def qweight(self):
        """Return the scaled weights clipped to each channel's signed b-bit range (not rounded)."""
        b_rel = self.b.relu()
        min_val = torch.where(b_rel > 0, -2 ** (b_rel - 1), torch.zeros_like(b_rel))
        max_val = torch.where(b_rel > 0, 2 ** (b_rel - 1) - 1, torch.zeros_like(b_rel))
        scaled_weight = 2 ** -self.e * self.weight
        return torch.max(torch.min(scaled_weight, max_val), min_val)

    def forward(self, x):
        """Convolve ``x`` with the rounded weights, rescaled by ``2 ** e``."""
        qw = self.qweight()
        # Straight-through estimator: the forward pass uses round(qw), while the
        # backward pass treats rounding as identity, so gradients reach weight, e and b.
        w = (qw.round() - qw).detach() + qw
        return nn.functional.conv2d(x, 2 ** self.e * w, stride=self.stride, padding=self.padding)


### CNN Model Using Quantized Convolutional Layers

This block defines a CNN model `Model` built from the previously defined quantized convolutional layers (`QConv2d`). The architecture follows a standard CNN design, but its convolutional layers learn reduced-precision weights, which improves efficiency and reduces model size.

1. **Model Architecture**:
    - The model is divided into two main parts: `features` and `classifier`.
    - The `features` part contains sequential layers of `QConv2d` (quantized convolutional layers), ReLU activations, Batch Normalization, and Max Pooling. The convolutional layers use quantization to limit precision, with ReLU as the activation function. Batch Normalization without affine parameters or running statistics is applied to normalize the feature maps.
    - The `classifier` consists of a single fully connected layer that maps the flattened output of the `features` section to 10 output classes, suitable for classification tasks.

2. **Forward Function**:
    - The input tensor `x` passes through the `features` section, which applies the sequence of quantized convolutional, activation, normalization, and pooling layers.
    - The output is then flattened using `x.view()` to prepare it for the fully connected `classifier` layer, which outputs class scores.

3. **qbits Function**:
    - The `qbits` function calculates the total number of quantized bits used by all `QConv2d` layers in the `features` section. This function iterates over the layers in `features`, checks if the layer is a quantized convolution (`QConv2d`), and sums the bit usage across all such layers. This helps track how many bits are used by the quantized layers during training.

This model is particularly useful for applications where resource efficiency is critical, such as edge devices, as it combines the power of CNNs with the memory and computational benefits of quantization.


In [16]:
class Model(nn.Module):
    """Small CNN classifier built from ``QConv2d`` layers, with 10 outputs.

    Shape trace for a 1x28x28 input: 5x5 conv -> 24, 5x5 conv -> 20,
    pool -> 10, 3x3 conv -> 8, 3x3 conv -> 6, pool -> 3. This gives the
    ``64 * 3 * 3`` features consumed by the linear classifier.
    """

    def __init__(self):
        super().__init__()
        # Order matters: the Sequential indices define the state_dict keys.
        stages = [
            QConv2d(1, 32, 5), nn.ReLU(),
            QConv2d(32, 32, 5), nn.ReLU(),
            nn.BatchNorm2d(32, affine=False, track_running_stats=False),
            nn.MaxPool2d(2),
            QConv2d(32, 64, 3), nn.ReLU(),
            QConv2d(64, 64, 3), nn.ReLU(),
            nn.BatchNorm2d(64, affine=False, track_running_stats=False),
            nn.MaxPool2d(2),
        ]
        self.features = nn.Sequential(*stages)
        self.classifier = nn.Linear(64 * 3 * 3, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)``."""
        flat = self.features(x).view(x.size(0), -1)
        return self.classifier(flat)

    def qbits(self):
        """Return the bit count summed over every ``QConv2d`` in ``features``."""
        total = 0
        for layer in self.features:
            if isinstance(layer, QConv2d):
                total = total + layer.qbits()
        return total


### Custom Dataset Class: SafetySetDataset

The `SafetySetDataset` class is a custom dataset designed to load images and their corresponding labels from a directory. It extends `torch.utils.data.Dataset`, which allows for easy integration with PyTorch's data loading utilities. This dataset is used to handle images saved in the safety set folder.

1. **Initialization (`__init__` Method)**:
    - The class is initialized with a `safety_set_path` (path to the folder containing images) and an optional `transform` parameter.
    - It loops through all `.png` files in the specified folder and extracts both the image paths and the labels. The label is inferred from the file name, assuming that it is in the format `image_{index}_label_{label}.png`.
    - The image paths and labels are stored in lists (`self.image_paths` and `self.labels`).

2. **Length (`__len__` Method)**:
    - This method returns the number of examples in the dataset, which is simply the length of the `self.labels` list.

3. **Getting an Item (`__getitem__` Method)**:
    - The `__getitem__` method retrieves a single image and its corresponding label based on the index `idx`.
    - The image is loaded using `PIL.Image` and converted to grayscale using `.convert('L')`.
    - If a transformation is provided (`self.transform`), it is applied to the image. Otherwise, the module-level default `transform` defined earlier in the notebook is applied.
    - The method returns a tuple of the transformed image and its corresponding label.

This custom dataset class makes it easy to load and transform images from the safety set. It is compatible with PyTorch's `DataLoader` for batching and iterating over data during model training or evaluation. Adapt it as needed to work with your own dataset.


In [17]:
class SafetySetDataset(torch.utils.data.Dataset):
    """Dataset of safety-set PNG images whose labels are encoded in the file name.

    Files are expected to be named like ``image_{index}_label_{label}.png``.
    The integer between the last ``_label_`` and ``.png`` is the label.
    Files not ending in ``.png`` are ignored. Files are indexed in sorted
    file-name order, so the index -> file mapping is deterministic.

    Args:
        safety_set_path: directory containing the PNG files.
        transform: callable applied to each grayscale PIL image. When
            ``None``, the notebook-level ``transform`` is used instead.

    Raises:
        ValueError: from ``int()`` if a ``.png`` file name does not end in an
            integer label.
    """

    def __init__(self, safety_set_path, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        # os.listdir order is arbitrary and filesystem-dependent; sort it for
        # reproducible indexing.
        for file in sorted(os.listdir(safety_set_path)):
            if not file.endswith('.png'):
                continue
            label = int(file.split('_label_')[-1].split('.png')[0])
            self.image_paths.append(os.path.join(safety_set_path, file))
            self.labels.append(label)

    def __len__(self):
        """Return the number of images found."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for position ``idx``."""
        # The context manager closes the file handle; convert('L') returns a
        # new, fully loaded grayscale image.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('L')
        if self.transform is not None:
            image = self.transform(image)
        else:
            image = transform(image)
        return image, self.labels[idx]


In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
opt = optim.Adam(model.parameters())
# Denominator of the per-parameter bit cost in train_step. It counts every
# trainable parameter (including e, b and the classifier), not only conv weights.
weight_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
# Directory of safety-set PNGs. Override it with the SAFETY_SET_PATH environment
# variable instead of editing this machine-specific default.
safety_set_path = os.environ.get("SAFETY_SET_PATH", "/home/mohammad/safety_set_images_d")
safety_dataset = SafetySetDataset(safety_set_path, transform=safety_transform)


### Training Step Function with Adjusted Compression and Safety Weights

This training step function trains the model on both the main training data and a safety dataset while balancing accuracy, compression, and safety using regularization weights.

1. **Model Training**:
    - The model is set to training mode (`model.train()`), and the optimizer gradient is reset using `opt.zero_grad()`.
    - A random batch of 512 samples is drawn with replacement from the training data (`X_train`) and passed through the model to get the outputs (note: the results reported in the paper are averaged).
    - The main loss function, cross-entropy, is calculated between the model's predictions and the actual labels (`Y_train`).

2. **Quantization Regularization**:
    - The number of quantized bits (`Q`) used by the model is computed using the `qbits()` function.
    - A compression weight (`compression_weight`) of **0.1** (hyperparameter to search for) is applied to penalize the model based on the number of bits used. This encourages the model to reduce its bit usage, balancing compression and accuracy.
    - The total loss is updated to include the compression penalty, combining both cross-entropy loss and the regularization term.

3. **Safety Set Penalty**:
    - A random batch of 64 images and labels is sampled with replacement from the safety dataset.
    - The images are transformed using augmentation, stacked into a batch, and passed through the model to obtain the predictions.
    - A separate cross-entropy loss is calculated between the model's predictions and the true labels from the safety set.
    - A **safety weight** of **0.05** (hyperparameter to search for) is applied to this safety loss. This regularization term encourages the model to maintain performance on critical safety examples while still prioritizing compression and accuracy.

4. **Loss Backpropagation**:
    - The total loss, which now includes both the main loss and the safety loss, is backpropagated through the model to compute gradients.
    - The optimizer steps to update the model's weights.

This function balances the trade-offs between model accuracy, compression (through quantization), and performance on the safety set, allowing for more efficient training while preserving key performance metrics.


In [19]:
def _sample_safety_batch(dataset, batch_size, target_device):
    """Draw ``batch_size`` (image, label) pairs with replacement and stack them on ``target_device``."""
    indices = torch.randint(0, len(dataset), (batch_size,))
    pairs = [dataset[idx] for idx in indices]
    images = torch.stack([img for img, _ in pairs]).to(target_device)
    targets = torch.tensor([label for _, label in pairs]).to(target_device)
    return images, targets


def train_step(batch_size=512, safety_batch_size=64, compression_weight=0.1, safety_weight=0.05):
    """Run one optimizer step on the combined objective.

    The minimized loss is::

        CE(train batch) + compression_weight * Q + safety_weight * CE(safety batch)

    where ``Q = model.qbits() / weight_count``. This function uses the
    notebook-level ``model``, ``opt``, ``device``, ``X_train``, ``Y_train``,
    ``weight_count`` and ``safety_dataset``. Safety images are drawn through
    the dataset's own transform (augmentation), so each call sees fresh
    variations.

    Args:
        batch_size: number of training samples drawn (with replacement).
        safety_batch_size: number of safety-set samples drawn (with replacement).
        compression_weight: trade-off between accuracy and bit usage
            (hyperparameter to search).
        safety_weight: weight of the safety-set loss (hyperparameter to search).

    Returns:
        ``(total_loss, Q, safety_loss)`` as Python floats.
    """
    model.train()
    opt.zero_grad()

    samples = torch.randint(0, X_train.shape[0], (batch_size,))
    outputs = model(X_train[samples].to(device))
    loss = nn.functional.cross_entropy(outputs, Y_train[samples].to(device))

    Q = model.qbits() / weight_count
    loss = loss + compression_weight * Q

    safety_images, safety_labels = _sample_safety_batch(safety_dataset, safety_batch_size, device)
    safety_loss = nn.functional.cross_entropy(model(safety_images), safety_labels)
    loss = loss + safety_weight * safety_loss

    loss.backward()
    opt.step()
    return loss.item(), Q.item(), safety_loss.item()
