### Importing Libraries

In [None]:
import numpy as np
import tarfile
import os

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torchvision.models import resnet18

# Seed PyTorch's global RNG (used by weight init, DataLoader shuffling and
# torch.randn). Note: this alone does not make cuDNN kernels deterministic.
SEED = 100
torch.manual_seed(SEED)

###  Helper Functions

In [None]:
def accuracy(outputs, labels):
    """
    Fraction of samples whose highest-scoring class equals the label.

    Args:
        outputs (torch.Tensor): Model scores, one row per sample, one column per class.
        labels (torch.Tensor): Ground-truth class indices, one per sample.

    Returns:
        torch.Tensor: Scalar tensor holding the accuracy in [0, 1].
    """
    predicted = outputs.max(dim=1).indices
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

def training_step(model, batch):
    """
    Compute the cross-entropy loss of `model` on one training batch.

    The batch is moved to the module-level `device` before the forward pass.

    Args:
        model (torch.nn.Module): The neural network model.
        batch (tuple): `(inputs, targets)` pair produced by a DataLoader.

    Returns:
        torch.Tensor: Scalar loss, still attached to the autograd graph.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    logits = model(inputs)
    return F.cross_entropy(logits, targets)

def validation_step(model, batch):
    """
    Score `model` on one validation batch.

    Args:
        model (torch.nn.Module): The neural network model.
        batch (tuple): `(inputs, targets)` pair produced by a DataLoader.

    Returns:
        dict: `{'Loss': detached scalar loss, 'Acc': scalar accuracy tensor}`.
    """
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)
    logits = model(inputs)
    batch_loss = F.cross_entropy(logits, targets)
    return {'Loss': batch_loss.detach(), 'Acc': accuracy(logits, targets)}

def validation_epoch_end(model, outputs):
    """
    Average the per-batch validation metrics of one epoch.

    Args:
        model (torch.nn.Module): Unused; kept for interface compatibility.
        outputs (list): Dicts returned by `validation_step`, each with scalar
            tensors under 'Loss' and 'Acc'.

    Returns:
        dict: `{'Loss': float, 'Acc': float}` — unweighted means over batches.
    """
    mean_loss = torch.stack([entry['Loss'] for entry in outputs]).mean()
    mean_acc = torch.stack([entry['Acc'] for entry in outputs]).mean()
    return {'Loss': mean_loss.item(), 'Acc': mean_acc.item()}

def epoch_end(model, epoch, result):
    """
    Print a one-line summary of a finished epoch.

    Args:
        model (torch.nn.Module): Unused; kept for interface compatibility.
        epoch (int): Epoch index to display.
        result (dict): Must contain 'lrs' (list of per-step learning rates; the
            last one is shown), 'train_loss', 'Loss' and 'Acc'.
    """
    template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
    last_lr = result['lrs'][-1]
    print(template.format(epoch, last_lr, result['train_loss'], result['Loss'], result['Acc']))

def distance(model, model0):
    """
    Compute the relative L2 distance between the parameters of two models.

    Parameters are paired positionally (in `parameters()` order), so both
    models are expected to share the same architecture; with mismatched
    models `zip` silently stops at the shorter parameter list.

    Prints the raw distance and the normalized distance.

    Args:
        model (torch.nn.Module): The model being measured.
        model0 (torch.nn.Module): The reference model.

    Returns:
        float: ||model - model0|| / ||model||, taken over all parameters.

    Raises:
        ZeroDivisionError: If every parameter of `model` is zero.
    """
    squared_dist = 0.0
    squared_norm = 0.0
    for p, p0 in zip(model.parameters(), model0.parameters()):
        squared_dist += (p.data - p0.data).pow(2).sum().item()
        squared_norm += p.data.pow(2).sum().item()
    print(f'Distance: {np.sqrt(squared_dist)}')
    # Normalize by the norm of `model` so the result is scale-independent.
    normalized = np.sqrt(squared_dist / squared_norm)
    print(f'Normalized Distance: {normalized}')
    return normalized


### Evaluate Function

In [None]:
@torch.no_grad()
def evaluate(model, val_loader):
    """
    Run `model` over a whole validation loader without tracking gradients.

    Switches the model to eval mode (it is left in eval mode afterwards).

    Args:
        model (torch.nn.Module): The neural network model.
        val_loader (torch.utils.data.DataLoader): Yields `(inputs, targets)` batches.

    Returns:
        dict: `{'Loss': float, 'Acc': float}` averaged over batches.
    """
    model.eval()
    batch_results = []
    for batch in val_loader:
        batch_results.append(validation_step(model, batch))
    return validation_epoch_end(model, batch_results)

### Get Learning Rate Function

In [None]:
def get_lr(optimizer):
    """
    Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer.

    Returns:
        float or None: The first group's 'lr', or None if there are no groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']


### Fit One Cycle Function


In [None]:
def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """
    Train the model for a number of epochs, validating after each one.

    Despite the historical name, this does NOT implement the one-cycle policy:
    the learning rate starts at `max_lr` and is halved by ReduceLROnPlateau
    whenever the validation loss has not improved for 3 consecutive epochs.

    Args:
        epochs (int): Number of epochs to train.
        max_lr (float): Initial learning rate (the scheduler only ever lowers it).
        model (torch.nn.Module): The neural network model.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        weight_decay (float): Weight decay passed to the optimizer (default: 0).
        grad_clip (float or None): If truthy, gradients are clipped element-wise
            to [-grad_clip, grad_clip] before each step (default: None).
        opt_func (type): Optimizer class, called as `opt_func(params, lr, weight_decay=...)`
            (default: torch.optim.SGD).

    Returns:
        list: One dict per epoch with 'Loss', 'Acc' (validation), 'train_loss'
        and 'lrs' (learning rate after every optimizer step).
    """
    torch.cuda.empty_cache()
    history = []

    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    # Halve the LR when validation loss plateaus. `verbose` is deprecated in
    # recent PyTorch; the current LR is already printed by epoch_end via 'lrs'.
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch)
            # Store a detached copy so the list does not keep autograd nodes
            # alive for the whole epoch.
            train_losses.append(loss.detach())
            loss.backward()

            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            lrs.append(get_lr(optimizer))

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        sched.step(result['Loss'])
    return history

"""
Implements the one-cycle learning rate policy for training a neural network.
It trains the model for the specified number of epochs using the train_loader for training and 
the val_loader for validation.
"""

## Training & Loading Model


### Downloading The Dataset

In [None]:
# Download the dataset archive into the current directory
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
download_url(dataset_url, '.')

# Extract from archive. The archive comes from the network, so when the
# running Python supports extraction filters (3.12+ and security backports)
# use the 'data' filter. It rejects absolute paths, '..' components, links
# pointing outside the target directory, and special files.
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path='./data', filter='data')
    else:
        tar.extractall(path='./data')

# Look into the data directory
data_dir = './data/cifar10'
print(os.listdir(data_dir))
classes = os.listdir(data_dir + "/train")
print(classes)

"""
Loading the CIFAR-10 dataset with a collection of 60,000 32x32 color images in 10 classes, with 6,000 images per class. 
The classes are: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'.
"""

### Transforming the DataSet

In [None]:
# Per-channel normalization statistics shared by both pipelines.
# NOTE(review): these std values are the ones widely copied in CIFAR-10 code;
# they differ from the per-pixel dataset std (~0.247, 0.243, 0.261). Confirm intended.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Training images: tensor conversion + normalization only (no augmentation).
transform_train = tt.Compose([tt.ToTensor(), tt.Normalize(cifar10_mean, cifar10_std)])

# Test images: the same pipeline as training.
transform_test = tt.Compose([tt.ToTensor(), tt.Normalize(cifar10_mean, cifar10_std)])


In [None]:
# Both splits are read with ImageFolder (one sub-directory per class).
train_dir = f"{data_dir}/train"
test_dir = f"{data_dir}/test"
train_ds = ImageFolder(train_dir, transform_train)
valid_ds = ImageFolder(test_dir, transform_test)


In [None]:
# Mini-batch size for training (validation uses twice this)
batch_size = 256

# Training loader: reshuffled every epoch, 3 worker processes, pinned host memory
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=3, pin_memory=True)

# Validation loader: double batch size, fixed (unshuffled) order
valid_dl = DataLoader(valid_ds, batch_size=batch_size * 2, num_workers=3, pin_memory=True)


### Train and save the model

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# ResNet-18 with a 10-way classification head, moved to the selected device
model = resnet18(num_classes=10).to(device=device)

# Number of training epochs
epochs = 40

# Initial learning rate handed to fit_one_cycle (its scheduler only lowers it)
max_lr = 0.01

# Gradient values are clipped element-wise to [-grad_clip, grad_clip]
grad_clip = 0.1

# Weight decay passed to the optimizer for regularization
weight_decay = 1e-4

# Optimizer class used by fit_one_cycle
opt_func = torch.optim.Adam


In [None]:
%%time
# Train the classifier; wall-clock time is reported by the %%time cell magic.
history = fit_one_cycle(
    epochs, max_lr, model, train_dl, valid_dl,
    weight_decay=weight_decay,
    grad_clip=grad_clip,
    opt_func=opt_func,
)

# Persist only the learned weights (state_dict), not the module object itself.
torch.save(model.state_dict(), "ResNET18_CIFAR10_ALL_CLASSES.pt")


### Testing the model

In [None]:
# Restore the saved weights into the existing model instance.
state_dict = torch.load("ResNET18_CIFAR10_ALL_CLASSES.pt")
model.load_state_dict(state_dict)
del state_dict  # avoid keeping a second copy of the weights in the namespace

# Evaluate on the full test split, wrapped in a list like fit_one_cycle's history.
history = [evaluate(model, valid_dl)]
history


## Unlearning Model

In [None]:
# Learnable noise batch used as synthetic model input
class Noise(nn.Module):
    """
    Module that holds a single trainable noise tensor.

    The tensor has shape `dim` and is initialised from a standard normal
    distribution. Because it is registered as an `nn.Parameter`, an optimizer
    built from `Noise(...).parameters()` updates the noise itself.
    """

    def __init__(self, *dim):
        super().__init__()
        initial = torch.randn(*dim)
        self.noise = nn.Parameter(initial)  # requires_grad defaults to True

    def forward(self):
        """Return the trainable noise tensor (no input is consumed)."""
        return self.noise

"""
Noise Module that represents a learnable noise tensor.
"""

In [None]:
# Integer class indices 0..9. This rebinds `classes`, which previously held folder names.
classes = list(range(10))

# Class indices to unlearn. With ImageFolder's sorted class folders these are
# presumably 'airplane' and 'bird' — check train_ds.class_to_idx to confirm.
classes_to_forget = [0, 2]

In [None]:
# Bucket every (image, label) pair by label: class index -> list of samples.
num_classes = 10
classwise_train = {cls_idx: [] for cls_idx in range(num_classes)}
classwise_test = {cls_idx: [] for cls_idx in range(num_classes)}

# Note: iterating the datasets applies their transforms, so the buckets hold
# normalized tensors for the whole train and test splits in memory.
for image, target in train_ds:
    classwise_train[target].append((image, target))

for image, target in valid_ds:
    classwise_test[target].append((image, target))


In [None]:
# How many training samples to keep from each retained class. This must be
# defined before the loop below uses it.
# NOTE(review): 1000 per class is an assumed default — tune as needed.
num_samples_per_class = 1000

# Take the first `num_samples_per_class` samples of every class not being forgotten
retain_samples = []
for cls in range(num_classes):
    if cls not in classes_to_forget:
        retain_samples.extend(classwise_train[cls][:num_samples_per_class])


In [None]:
# Split the test samples into retained-class and forgotten-class sets,
# both ordered by ascending class index.
retain_valid = []
forget_valid = []
for cls_idx in range(num_classes):
    bucket = forget_valid if cls_idx in classes_to_forget else retain_valid
    bucket.extend(classwise_test[cls_idx])

# Loaders for each split; the retain loader uses a double batch size.
forget_valid_dl = DataLoader(forget_valid, batch_size=batch_size, num_workers=3, pin_memory=True)
retain_valid_dl = DataLoader(retain_valid, batch_size=batch_size * 2, num_workers=3, pin_memory=True)

"""
Now we have two DataLoader objects (forget_valid_dl and retain_valid_dl) 
that can be used to iterate over the validation set for the classes 
to forget and the classes to retain, respectively.
"""


### Training The Noise

In [None]:
# Rebuild a fresh ResNet-18 on the device and restore the fully-trained weights,
# so unlearning starts from the original model.
model = resnet18(num_classes=10)
model.to(device=device)
model.load_state_dict(torch.load("ResNET18_CIFAR10_ALL_CLASSES.pt"))


### Impair Step

In [None]:
%%time

# One learnable noise batch per class to forget, keyed by class index
noises = {}

for cls in classes_to_forget:
    print("Optimizing loss for class {}".format(cls))
    # Noise batch shaped like a batch of CIFAR images: (batch_size, 3, 32, 32)
    noises[cls] = Noise(batch_size, 3, 32, 32).to(device)
    # Only the noise is optimized; the model weights are not in this optimizer
    opt = torch.optim.Adam(noises[cls].parameters(), lr=0.1)

    # Number of epochs and steps for optimization
    num_epochs = 5
    num_steps = 8
    class_label = cls
    # Every noise sample targets the class being forgotten; the labels never change
    labels = torch.full((batch_size,), class_label, dtype=torch.long, device=device)

    # NOTE(review): `model` has not been switched to eval() since it was rebuilt,
    # so it runs in train mode here and its BatchNorm running statistics are
    # updated by these noise batches. Use model.eval() if that is unintended.
    for epoch in range(num_epochs):
        total_loss = []
        for _ in range(num_steps):
            inputs = noises[cls]()
            outputs = model(inputs)
            # Maximize cross-entropy for class_label (hence the minus sign), with an
            # L2 penalty per noise image that keeps the noise magnitude bounded.
            loss = -F.cross_entropy(outputs, labels) + 0.1 * torch.mean(torch.sum(torch.square(inputs), [1, 2, 3]))
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss.append(loss.item())
        print("Loss: {}".format(np.mean(total_loss)))

"""
* Optimizing the noise tensors for the classes that need to be unlearned 

* This code uses the noise tensors to perturb the model's inputs, effectively "unlearning" 
the specified classes by modifying the inputs during inference. The regularization term
helps prevent the noise from becoming too large, which could negatively impact the model's performance.
"""

In [None]:
%%time

# Batch size for the noisy data loader
batch_size = 256
# (image, label) pairs: learned noise first, then the retained real samples
noisy_data = []
# How many copies of each class's noise batch to add. noises[cls]() returns the
# same tensor on every call, so these are repeated copies of one batch.
num_batches = 20

for cls in classes_to_forget:
    # Detached CPU copy of the noise optimized for this class: (batch, 3, 32, 32)
    noise_batch = noises[cls]().cpu().detach()
    for _ in range(num_batches):
        # Iterate over the batch dimension so every noise image becomes one sample.
        # Each is labeled with the class its noise was optimized against.
        for noise_img in noise_batch:
            noisy_data.append((noise_img, torch.tensor(cls)))

# Add the retained real samples (CPU image tensors with tensor labels)
other_samples = [(sample[0].cpu(), torch.tensor(sample[1])) for sample in retain_samples]
noisy_data += other_samples

# Create a data loader for the mixed noise + retain data
noisy_loader = torch.utils.data.DataLoader(noisy_data, batch_size=batch_size, shuffle=True)

# Optimizer over the model weights for the impair step
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

# Training loop
for epoch in range(1):
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for data in noisy_loader:
        inputs, labels = data
        # The default collate already yields tensors. Just move them; re-wrapping
        # with torch.tensor() would copy them and emit a UserWarning.
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted loss and the number of correct predictions
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(), dim=1)
        assert out.shape == labels.shape
        running_acc += (labels == out).sum().item()

    # Average over the samples in this loader, not the full training set
    num_samples = len(noisy_loader.dataset)
    print(f"Train loss {epoch+1}: {running_loss/num_samples}, Train Acc: {running_acc*100/num_samples}%")

"""
It trains the model on a dataset containing noisy samples from 
the classes that need to be unlearned (classes_to_forget) 
"""

In [None]:
print("Performance of Standard Forget Model on Forget Class")
history = [evaluate(model, forget_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Standard Forget Model on Retain Class")
history = [evaluate(model, retain_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

### Repair Step

In [None]:
%%time

# Repair step: fine-tune only on the retained real samples
heal_loader = torch.utils.data.DataLoader(other_samples, batch_size=256, shuffle=True)

# Optimizer over the model weights for the repair step
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training loop
for epoch in range(1):
    model.train(True)
    running_loss = 0.0
    running_acc = 0
    for data in heal_loader:
        inputs, labels = data
        # The default collate already yields tensors. Just move them; re-wrapping
        # with torch.tensor() would copy them and emit a UserWarning.
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted loss and the number of correct predictions
        running_loss += loss.item() * inputs.size(0)
        out = torch.argmax(outputs.detach(), dim=1)
        assert out.shape == labels.shape
        running_acc += (labels == out).sum().item()

    # Average over the samples in this loader, not the full training set
    num_samples = len(heal_loader.dataset)
    print(f"Train loss {epoch+1}: {running_loss/num_samples}, Train Acc: {running_acc*100/num_samples}%")


In [None]:
print("Performance of Standard Forget Model on Forget Class")
history = [evaluate(model, forget_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))

print("Performance of Standard Forget Model on Retain Class")
history = [evaluate(model, retain_valid_dl)]
print("Accuracy: {}".format(history[0]["Acc"]*100))
print("Loss: {}".format(history[0]["Loss"]))