# ***Contrastive Unlearning with CIFAR-10/SVHN and ResNet***

This notebook contains the whole project for the ***Deep Learning and Applied Artificial Intelligence*** course.

I have replicated the pipeline and the experiments of the article *"A Contrastive Approach to Machine Unlearning"*, and I also try to improve the performance of this approach by implementing it in **Hyperbolic Spaces**, a representation space that has recently gained interest in the **Computer Vision field**.

Since the article only names the architectures used, without their configuration, I took the initiative and built ResNets specifically for the CIFAR-10 and SVHN datasets.

The notebook is divided into chapters, each with a short explanation of what is happening and why.

Some cells operate on code from previous chapters, but I tried to keep the whole notebook organic, so that each object is instantiated only when it is necessary and not before.

Because of this narrative-driven style, it is suggested to run the cells in order; otherwise some later parts of the code may not work.

# Utility stuff

### Install libraries

In [None]:
%pip install ray tqdm plotly matplotlib sklearn hypll tabulate numpy

### Imports

In [None]:
# Torch and torchvision-related imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torch.backends.cudnn as cudnn
import torch.amp
import torchvision 

# Ray Tune for hyperparameter tuning
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler

# Visualization and utility imports
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import random
import os
from tabulate import tabulate

# Typing for function signatures
from typing import Optional

# Plotting and manifold learning
from sklearn.manifold import TSNE
import plotly.graph_objects as go
import plotly.io as pio
from scipy.spatial import ConvexHull

# Hypll and Geoopt for hyperbolic space operations
from hypll import nn as hnn
from hypll.tensors import ManifoldTensor, TangentTensor
from hypll.manifolds.poincare_ball import Curvature, PoincareBall
from hypll.optim import RiemannianAdam, RiemannianSGD

### Plotly renderer

Needed because Plotly has some problems displaying interactive plots in Visual Studio Code.

In [None]:
# Select Plotly's 'notebook_connected' renderer so that interactive figures display
# in VS Code notebooks (see the note above this cell).
pio.renderers.default = 'notebook_connected'

### Image show function

In [None]:
def imshow(img, mean=0.5, std=0.5):
    """
    Un-normalize an image (or a batch of images) and show it using plt.

    Args:
    - img: tensor of shape (C, H, W), or a batch of shape (B, C, H, W); a batch is
      tiled horizontally (side by side) into a single image
    - mean: per-channel mean used by the Normalize transform (scalar or sequence).
      The default 0.5, together with std=0.5, reproduces the historical
      `img / 2 + 0.5` behavior
    - std: per-channel std used by the Normalize transform (scalar or sequence)
    """
    # Work on a detached CPU float copy so CUDA tensors and autograd tensors also work.
    img = img.detach().cpu().float()

    # Tile a batch along the width so the transpose below always sees (C, H, W).
    if img.dim() == 4:
        img = torch.cat(list(img), dim=2)

    # Invert Normalize, which computed (x - mean) / std, broadcasting over (C, H, W).
    mean_t = torch.as_tensor(mean, dtype=img.dtype).reshape(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype).reshape(-1, 1, 1)
    img = (img * std_t + mean_t).clamp(0.0, 1.0)  # clip to the valid RGB float range

    npimg = img.numpy()
    plt.figure(figsize=(30, 30))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

### Embeddings extraction and plotting functions

The *create_plot* function uses **t-SNE** as a dimensionality-reduction method so that the data can be plotted in 2D or 3D.

**WARNING**: it works well, but it is quite slow and runs on the CPU, so plot the graphs only when you need them.

In [None]:
def extract_features(encoder, loader, hyperbolic='', manifold=None, device='cuda'):
    """
    Run `encoder` over every batch of `loader` and collect embeddings and labels.

    The encoder is switched to eval mode and evaluated under `torch.no_grad()`.

    Args:
    - encoder: the ResNet encoder (Euclidean or hyperbolic variant)
    - loader: iterable yielding (inputs, labels) batches, e.g. the testset DataLoader
    - hyperbolic: '' for the Euclidean model, or one of 'complete-ResNet',
      'hybrid-ResNet', 'academic-ResNet'
    - manifold: manifold whose `expmap` maps the inputs; only used by the
      'complete-ResNet' and 'academic-ResNet' variants
    - device: device the inputs are moved to before the forward pass

    Returns:
    - features: numpy array with all embeddings concatenated along axis 0
    - labels: numpy array with the matching labels concatenated along axis 0
    """
    maps_inputs_to_manifold = hyperbolic in ('complete-ResNet', 'academic-ResNet')
    outputs_manifold_tensor = hyperbolic in ('complete-ResNet', 'hybrid-ResNet', 'academic-ResNet')

    encoder.eval()
    feature_batches, label_batches = [], []

    with torch.no_grad():
        for batch, targets in loader:
            batch = batch.to(device)

            if maps_inputs_to_manifold:
                # Fully hyperbolic encoders take inputs on the manifold: treat the
                # pixels as tangent vectors and apply the exponential map.
                batch = manifold.expmap(
                    TangentTensor(data=batch, man_dim=1, manifold=manifold)
                )

            embeddings = encoder(batch)
            if outputs_manifold_tensor:
                # Hyperbolic encoders' outputs are unwrapped to the underlying tensor.
                embeddings = embeddings.tensor

            feature_batches.append(embeddings.cpu().numpy())
            label_batches.append(targets.numpy())

    return np.concatenate(feature_batches), np.concatenate(label_batches)



def create_plot(features, labels, class_names, dimension=3, convexhull=False):
    """
    Reduce the features with t-SNE and show an interactive Plotly scatter plot.

    Each class is drawn as its own trace (toggleable from the legend, with the class
    name shown on hover). Optionally the convex hull of each class is drawn too. The
    hull shares the class legend group, so toggling a class also toggles its hull.

    Args:
    - features: array of shape (n_samples, n_features), e.g. the testset embeddings
    - labels: array-like of shape (n_samples,) with integer labels
    - class_names: sequence of class names indexed by label number
    - dimension: dimension of the plot (2 or 3)
    - convexhull: if True, draw the convex hull around each class (2D and 3D)

    Raises:
    - ValueError: if `dimension` is not 2 or 3
    """
    if dimension not in (2, 3):
        # Fail before the (slow) t-SNE; any other value would just give an empty figure.
        raise ValueError(f"dimension must be 2 or 3, got {dimension!r}")

    # Accept lists too: the boolean mask below needs an ndarray.
    labels = np.asarray(labels)

    # Reduce to the specified dimension using t-SNE (fixed seed for a stable layout)
    tsne = TSNE(n_components=dimension, random_state=42)
    features_reduced = tsne.fit_transform(features)

    fig = go.Figure()

    # Plot each class separately to enable toggling in the legend
    for label in np.unique(labels):
        class_points = features_reduced[labels == label]
        name = class_names[label]
        legendgroup = f"group_{label}"

        if dimension == 3:
            fig.add_trace(go.Scatter3d(
                x=class_points[:, 0],
                y=class_points[:, 1],
                z=class_points[:, 2],
                mode='markers',
                marker=dict(size=3, opacity=0.5),
                name=name,
                hoverinfo='text',
                text=[name] * len(class_points),
                legendgroup=legendgroup
            ))

            # A 3D hull needs at least 4 points
            if convexhull and len(class_points) >= 4:
                hull = ConvexHull(class_points)
                # Pass the hull's triangular facets explicitly (hull.simplices holds row
                # indices into class_points). Without i/j/k, Plotly derives its own
                # triangulation from the points (alphahull, Delaunay by default), which
                # is not the convex hull.
                fig.add_trace(go.Mesh3d(
                    x=class_points[:, 0],
                    y=class_points[:, 1],
                    z=class_points[:, 2],
                    i=hull.simplices[:, 0],
                    j=hull.simplices[:, 1],
                    k=hull.simplices[:, 2],
                    opacity=0.5,
                    color='rgba(255, 0, 0, 0.2)',
                    showlegend=False,
                    legendgroup=legendgroup,
                    hoverinfo='skip'
                ))

        else:
            fig.add_trace(go.Scatter(
                x=class_points[:, 0],
                y=class_points[:, 1],
                mode='markers',
                marker=dict(size=5, opacity=0.5),
                name=name,
                hoverinfo='text',
                text=[name] * len(class_points),
                legendgroup=legendgroup
            ))

            # A 2D hull needs at least 3 points; fill='toself' closes the polygon
            if convexhull and len(class_points) >= 3:
                hull = ConvexHull(class_points)
                hull_vertices = class_points[hull.vertices]
                fig.add_trace(go.Scatter(
                    x=hull_vertices[:, 0],
                    y=hull_vertices[:, 1],
                    fill='toself',
                    opacity=0.5,
                    fillcolor='rgba(255, 0, 0, 0.2)',
                    line=dict(color='rgba(255, 0, 0, 0.2)'),
                    showlegend=False,
                    legendgroup=legendgroup,
                    hoverinfo='skip'
                ))

    # Update layout for better visualization and interactivity
    if dimension == 3:
        fig.update_layout(
            scene=dict(
                xaxis_title='Component 1',
                yaxis_title='Component 2',
                zaxis_title='Component 3'
            ),
            title="3D t-SNE of Extracted Features with Class Toggles",
            template="plotly_dark",
            legend=dict(
                title="Classes",
                itemsizing='constant',
                x=0.8
            )
        )
    else:
        fig.update_layout(
            xaxis_title='Component 1',
            yaxis_title='Component 2',
            title="2D t-SNE of Extracted Features with Class Toggles",
            template="plotly_dark",
            legend=dict(
                title="Classes",
                itemsizing='constant',
                x=1.0
            )
        )

    # Show interactive plot (uncomment the next line to also save it as HTML)
    fig.show()
    #fig.write_html('./plot.html')

### Loss plotting function

In [None]:
def _plot_history_pair(train_values, val_values, title, yaxis_title, train_name, val_name):
    """Show one Plotly line chart comparing a training curve with a validation curve."""
    fig = go.Figure()
    for values, name in ((train_values, train_name), (val_values, val_name)):
        fig.add_trace(go.Scatter(x=list(range(len(values))), y=values,
                                 mode='lines+markers',
                                 name=name))
    fig.update_layout(title=title,
                      xaxis_title='Epoch',
                      yaxis_title=yaxis_title,
                      template="plotly_dark")
    fig.show()


def plot_training_history(train_loss_history, val_loss_history, train_acc_history, val_acc_history):
    """
    Plot the training and validation loss and accuracy histories as two figures.

    The x axis is the 0-based index of each entry in the history lists.

    Args:
    - train_loss_history: List of training loss values per epoch
    - val_loss_history: List of validation loss values per epoch
    - train_acc_history: List of training accuracy values per epoch
    - val_acc_history: List of validation accuracy values per epoch
    """
    _plot_history_pair(train_loss_history, val_loss_history,
                       'Loss over Epochs', 'Loss',
                       'Train Loss', 'Validation Loss')
    _plot_history_pair(train_acc_history, val_acc_history,
                       'Accuracy over Epochs', 'Accuracy (%)',
                       'Train Accuracy', 'Validation Accuracy')

### Early stopper class

PyTorch does not seem to have a built-in EarlyStopping utility, so I implemented one.

In [None]:
class EarlyStopper:
    """
    Signal when training should stop because the validation loss keeps getting worse.

    The lowest validation loss seen so far is tracked. A new minimum resets the
    counter. A loss more than `min_delta` above that minimum increments the counter.
    A loss within `min_delta` of the minimum leaves the counter unchanged. Once the
    counter reaches `patience`, `early_stop` returns True.

    NOTE(review): a loss that plateaus within `min_delta` of the best never triggers
    a stop; confirm this is the intended semantics.

    Args:
    - patience: number of counted "worse" epochs tolerated before stopping
    - min_delta: tolerance above the best loss that is not counted as worse
    """

    def __init__(self, patience=5, min_delta=0.1):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0                          # "worse" epochs since the last minimum
        self.min_validation_loss = float('inf')   # best loss seen so far

    def early_stop(self, validation_loss):
        """Record `validation_loss`; return True when training should stop."""
        best = self.min_validation_loss

        if validation_loss < best:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        if validation_loss > best + self.min_delta:
            self.counter += 1
            return self.counter >= self.patience

        # Within tolerance: neither an improvement nor counted as worse.
        return False

# General configuration

This section instantiates all the non-destructive data needed throughout the notebook.

### Global variables

In [None]:
# NOTE: `global` statements at module (cell) level are no-ops; names assigned at the
# top level of the notebook are already global. DATASET and SCENARIO are defined in
# the configuration cells below.

### Experiment configuration

These are the constants that will be used throughout the course of the notebook.

They define the experiment configuration, so it is advised to change them only here, at the start of the notebook.

***1. Choose the dataset***

In [None]:
# Dataset used by every experiment; the dataset-loading cell handles 'CIFAR-10' and 'SVHN'.
#DATASET = 'CIFAR-10'
DATASET = 'SVHN'

***2. Choose the unlearning method***

In [None]:
# Unlearning scenario: 'single-class' or 'random-sample'. The value is also embedded
# in the checkpoint file names.
#SCENARIO = 'single-class'
SCENARIO = 'random-sample'

***3. Choose the model type***

In [None]:
# ResNet variant; the model-definition cell maps this name to the factory function
# of the same name (ResNet18/34/50/101/152).
#MODEL_NAME = 'ResNet18'
MODEL_NAME = 'ResNet34'
#MODEL_NAME = 'ResNet50'
#MODEL_NAME = 'ResNet101'
#MODEL_NAME = 'ResNet152'

***4. Choose the optimizer***

**WARNING**: It is not advised to use the *Adam* optimizer for this project, since it makes overly strong parameter updates, which cause the later unlearning procedure to also forget the samples we want to maintain.

That's why I ended up using only SGD.

I still left the option to test it, just in case.

In [None]:
# Optimizer selected in the optimizer-configuration cell: 'sgd' or 'adam'. Adam is
# discouraged in the note above. The name is also embedded in the checkpoint file names.
#OPTIMIZER_NAME = 'adam' # NOT ADVISED TO USE
OPTIMIZER_NAME = 'sgd'

***5. Classification loss***

In [None]:
# Classification loss: cross-entropy with label smoothing 0.1.
# NOTE(review): the trailing numbers look like alternative label_smoothing values
# (presumably from Ray Tune runs); confirm before reusing them.
CRITERION = nn.CrossEntropyLoss(label_smoothing=0.1) #0.0447361 #0.022574

***6. Training epochs***

In [None]:
# Number of training epochs.
NUM_EPOCHS = 50

***7. Checkpoint names***

This notebook creates a *data* and a *checkpoint* directory to store the datasets and the individual checkpoints.

Only the *original* and *retrain* model checkpoints will be saved.

Contrastive unlearning is performed exclusively at execution time 

In [None]:
# Checkpoint file names follow the pattern
#   <DATASET>_<space>_<stage>_<MODEL_NAME>_<OPTIMIZER_NAME>_<SCENARIO>_ckpt.pth
# where <space> is 'eucl' or 'hypbl' and <stage> is 'original' or 'retrain', so each
# configuration gets its own files.
EUCL_ORIGINAL_CKPT_NAME = f'{DATASET}_eucl_original_{MODEL_NAME}_{OPTIMIZER_NAME}_{SCENARIO}_ckpt.pth'
EUCL_RETRAIN_CKPT_NAME = f'{DATASET}_eucl_retrain_{MODEL_NAME}_{OPTIMIZER_NAME}_{SCENARIO}_ckpt.pth'
HYPBL_ORIGINAL_CKPT_NAME = f'{DATASET}_hypbl_original_{MODEL_NAME}_{OPTIMIZER_NAME}_{SCENARIO}_ckpt.pth'
HYPBL_RETRAIN_CKPT_NAME = f'{DATASET}_hypbl_retrain_{MODEL_NAME}_{OPTIMIZER_NAME}_{SCENARIO}_ckpt.pth'

***8. Unlearning hyperparameters***

The values below are found with **Ray Tune** hyperparameter tuning.

Since the authors of the original article didn't specify anything about these (except for omega), I had to try several combinations and found these to be the best.

However, since the unlearning procedure heavily depends on randomly sampled data, these fixed parameters will not guarantee results that are reproducible to the last digit.

Tuning these modifies two aspects of unlearning: convergence speed and knowledge retention.

- **TEMPERATURE** -> modulates how aggressively the unlearning samples are pushed away from positive samples and attracted to negative ones (tested values: [0.1, 0.8])
- **OMEGA** -> determines how many times each batch of unlearning samples is compared with a different set of remaining samples (tested values: [2, 4, 6])
- **REGULARIZER_CE** -> balances how much knowledge of the remaining samples is maintained (tested values: [0.3, 1.0])
- **REGULARIZER_UL** -> balances how effectively the unwanted samples are forgotten (tested values: [0.3, 1.0])

In [None]:
# Contrastive-unlearning hyperparameters; the markdown list above explains what each
# one controls.
# NOTE(review): the trailing numbers look like values from an earlier tuning run; confirm.
# NOTE(review): REGULARIZER_CE (0.18), HYPBL_REGULARIZER_CE (0.23) and
# HYPBL_REGULARIZER_UL (0.11) lie outside the [0.3, 1.0] ranges listed above as the
# tested values; confirm whether the ranges or the values are out of date.

# euclidean
TEMPERATURE = 0.3      #0.21
OMEGA = 4               #6
REGULARIZER_CE = 0.18   #0.31
REGULARIZER_UL = 0.71   #0.48

# hyperbolic
HYPBL_TEMPERATURE = 0.69
HYPBL_OMEGA = 4
HYPBL_REGULARIZER_CE = 0.23
HYPBL_REGULARIZER_UL = 0.11

***9. Choose the hyperbolic architecture***

In [None]:
# Hyperbolic variant. The training/feature-extraction functions compare against these
# exact strings: 'complete-ResNet' and 'academic-ResNet' map the inputs onto the
# manifold first, and all three read the underlying tensor of the outputs via `.tensor`.
#HYPBL_ARCHITECTURE = 'complete-ResNet'
HYPBL_ARCHITECTURE = 'hybrid-ResNet'
#HYPBL_ARCHITECTURE = 'academic-ResNet'

***10. Hyperbolic model and curvature definition***

For the curvature I chose the value *-1*, because a hyperbolic space has a constant negative curvature.

In [None]:
# Poincaré-ball manifold for the hyperbolic models, built from a Curvature object.
# NOTE(review): hypll's Curvature appears to pass `value` through a constraining
# function (softplus by default, as far as I recall) to obtain the ball parameter c.
# If so, value=-1.0 gives c = softplus(-1.0) ≈ 0.313 rather than a curvature of exactly
# -1. Verify against the hypll documentation before changing it, because the tuned
# HYPBL_* hyperparameters and hyperbolic checkpoints depend on this value.
CURVATURE = Curvature(value=-1.0)
MANIFOLD = PoincareBall(c=CURVATURE)

### Choose device

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Organize the dataset

### **Can I fix the data splits for reproducibility?**

Originally, I thought about fixing a seed to reduce the randomness of the results, but I figured out this is not possible.

If you fix all the data splits, the unlearning procedure no longer works, since it is based on random sampling.

That's also why the results from the article are, and will always be, different from mine.

Despite this, the experiment is still meaningful: it combines a novel approach to Machine Unlearning (Contrastive Unlearning) with a novel approach to data representation (Hyperbolic spaces).

### Download the datasets

Following the main article steps, no data augmentations will be used.

The only things that change are the normalization values for the chosen dataset.

In [None]:
batch_size = 128
validation_split = 0.2  # fraction of the original training set held out for validation


# Each dataset only differs in its normalization statistics (per-channel mean, std),
# its torchvision loader and its class names. No data augmentation is applied.
if DATASET == 'CIFAR-10':

    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=preprocess)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=preprocess)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

elif DATASET == 'SVHN':

    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4377, 0.4438, 0.4728), (0.198, 0.201, 0.197)),
    ])

    trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=preprocess)
    testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=preprocess)

    classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

else:
    # Fail here with a clear message instead of a NameError on `trainset` further down.
    raise ValueError(f"Unsupported DATASET {DATASET!r}; expected 'CIFAR-10' or 'SVHN'")


# Determine the size of the validation set
train_size = int((1 - validation_split) * len(trainset))
val_size = len(trainset) - train_size

# Split the training dataset into training and validation sets. The split is random
# on purpose (see the reproducibility note above), so no generator/seed is fixed.
trainset, valset = torch.utils.data.random_split(trainset, [train_size, val_size])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# 1. Contrastive Unlearning Pipeline on Euclidean Space

In this macro-section I replicate the setup of the main article.

As mentioned at the start of the notebook, the authors didn't include some important specs of the pipeline, so I had to improvise while keeping the experiments generally consistent.

## Original model

### Architecture configuration

I gave the possibility to test the whole project with different versions of the ResNet architecture.

The available models are: ResNet 18/34/50/101/152

After some testing, I found that ResNet 34 offers the best accuracy-over-training-time trade-off, so I ended up focusing on it.

If the approach works for this version, the deeper architectures can reasonably be expected to reach similar or better accuracy.

In [None]:
'''
ResNet in PyTorch.

Heavly inspired by:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''



class BasicBlock(nn.Module):
    """
    Two-convolution residual block (used by ResNet-18/34).

    Main path: 3x3 conv (with `stride`) -> BN -> ReLU -> 3x3 conv -> BN -> dropout(0.3).
    The result is summed with the shortcut (identity, or 1x1 conv + BN when the
    stride or channel count changes) and passed through a final ReLU.

    Args:
    - in_planes: number of input channels
    - planes: number of output channels (`expansion` is 1 for this block)
    - stride: stride of the first convolution and of the projection shortcut
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.dropout = nn.Dropout(p=0.3)

        # Project the input only when its shape differs from the block output;
        # an empty Sequential acts as the identity.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        y = self.conv1(x)
        y = F.relu(self.bn1(y))
        y = self.bn2(self.conv2(y))
        y = self.dropout(y)     # dropout after the last batch norm, before the sum
        y += self.shortcut(x)
        return F.relu(y)



class Bottleneck(nn.Module):
    """
    1x1 -> 3x3 -> 1x1 bottleneck residual block (used by ResNet-50/101/152).

    The output has `expansion * planes` (= 4 * planes) channels. Main path:
    1x1 conv -> BN -> ReLU -> 3x3 conv (with `stride`) -> BN -> ReLU -> 1x1 conv -> BN
    -> dropout(0.3). It is summed with the shortcut (identity, or 1x1 conv + BN when
    the stride or channel count changes) and passed through a final ReLU.

    Args:
    - in_planes: number of input channels
    - planes: width of the inner (bottleneck) convolutions
    - stride: stride of the 3x3 convolution and of the projection shortcut
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.dropout = nn.Dropout(p=0.3)

        # Identity shortcut when shapes already match, projection otherwise.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = self.dropout(y)     # dropout after the last batch norm, before the sum
        y += self.shortcut(x)
        return F.relu(y)




class ResNetEncoder(nn.Module):
    """
    ResNet feature extractor: 3x3 stem, four residual stages, global average pooling.

    The stem uses stride 1 (no early downsampling) and stages 2-4 each halve the
    spatial resolution, so 32x32 inputs reach the pooling step as a 4x4 map.

    Args:
    - block: residual block class (e.g. BasicBlock or Bottleneck) exposing `expansion`
    - num_blocks: number of blocks in each of the four stages
    - num_classes: unused here (the classifier lives in ClassificationHead); kept so
      existing calls keep working

    forward(x) returns embeddings of shape (batch, 512 * block.expansion).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNetEncoder, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.dropout = nn.Dropout(p=0.3)  # applied to the pooled embedding

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one uses `stride`."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            # The next block's input width is this block's output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling to 1x1. For 32x32 inputs this equals
        # avg_pool2d(out, 4) on the 4x4 map; for other input sizes it still yields
        # one value per channel, so the embedding size matches the classifier.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.dropout(out)
        return out
    

class ClassificationHead(nn.Module):
    """
    Linear classifier on top of the encoder embeddings, with dropout (p=0.5) on its input.

    Args:
    - input_dim: size of the embedding vector
    - num_classes: number of output logits
    """

    def __init__(self, input_dim, num_classes=10):
        super().__init__()
        self.dropout = nn.Dropout(p=0.5)
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.linear(self.dropout(x))



def _build_resnet(block, num_blocks):
    """
    Build an (encoder, classification head) pair.

    The head's input size is 512 * block.expansion, i.e. the width of the encoder's
    last stage (512 for BasicBlock, 2048 for Bottleneck). The number of classes is
    read from the global `classes` tuple set by the dataset cell.
    """
    encoder = ResNetEncoder(block, num_blocks)
    classification = ClassificationHead(512 * block.expansion, num_classes=len(classes))
    return encoder, classification


def ResNet18():
    return _build_resnet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return _build_resnet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return _build_resnet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return _build_resnet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return _build_resnet(Bottleneck, [3, 8, 36, 3])

### DO NOT UNCOMMENT THESE: Hyperparameter tuning 

I used **Ray Tune** to find the most promising hyperparameters for the training phase.

The hyperparameters tested were:

- **Architecture** --> ResNet18, ResNet34
- **Optimizer** --> SGD, Adam
- **Learning rate** --> [1e-4, 1e-1]
- **Momentum** --> [0.5, 0.9] only for SGD
- **Weight decay** --> [1e-6, 1e-2]
- **T_max** --> 100, 200

I collected the best hyperparameters for training with SGD and Adam independently.

In [None]:
"""
# Put large objects in Ray's object store for efficient memory management
large_trainloader = ray.put(trainloader)
large_valloader = ray.put(valloader)
"""

In [None]:
"""
def train_and_evaluate(config):

    # Retrieve objects stored via Ray
    trainloader = ray.get(large_trainloader)
    valloader = ray.get(large_valloader)

    scaler = torch.amp.grad_scaler.GradScaler()

    # Model architecture setup
    if config['arch'] == 'ResNet18':
        encoder, classifier = ResNet18()
    elif config['arch'] == 'ResNet34':
        encoder, classifier = ResNet34()

    # Move the models to the correct device (e.g., GPU)
    encoder = encoder.to(device)
    classifier = classifier.to(device)

    # Loss and optimizer setup
    criterion = nn.CrossEntropyLoss(label_smoothing=config['label_smoothing'])

    if config['opt'] == 'sgd':
        optimizer = optim.SGD(
            list(encoder.parameters()) + list(classifier.parameters()),
            lr=config["lr"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"]
        )
    elif config['opt'] == 'adam':
        optimizer = optim.Adam(
            list(encoder.parameters()) + list(classifier.parameters()),
            lr=config["lr"],
            weight_decay=config["weight_decay"]
        )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["T_max"])

    for epoch in range(config['num_epochs']):
        encoder.train()
        classifier.train()
        train_loss = 0
        correct = 0
        total = 0

        # Training loop
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            with torch.amp.autocast_mode.autocast(device_type=device, dtype=torch.float16):
                features = encoder(inputs)
                outputs = classifier(features)
                
            loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        # Evaluation phase
        encoder.eval()
        classifier.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, targets in valloader:
                inputs, targets = inputs.to(device), targets.to(device)
                features = encoder(inputs)
                outputs = classifier(features)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()

        # Calculate test loss and accuracy
        val_loss /= len(valloader)
        val_acc = 100. * val_correct / val_total

        # Update scheduler
        scheduler.step()

        # Report metrics to Ray Tune
        ray.train.report(dict(loss=val_loss, accuracy=val_acc, from_epoch=epoch))
"""

In [None]:
"""
# Custom trial directory name creator
def custom_trial_name(trial):
    return f"trial_{trial.trial_id}"

# Define the search space for hyperparameters
search_space = {
    "arch": tune.grid_search(['ResNet18', 'ResNet34']),
    "opt": tune.grid_search(['sgd', 'adam']),
    "lr": tune.loguniform(1e-4, 1e-1),
    "momentum": tune.uniform(0.5, 0.9),
    "weight_decay": tune.loguniform(1e-6, 1e-2),
    "T_max": tune.choice([100, 200]),
    "label_smoothing": tune.uniform(0.0, 0.2),
    "num_epochs": tune.grid_search([15]) 
}


# Define the scheduler for Ray Tune
scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=15,  # Maximum number of epochs for any trial
    grace_period=7,  # Minimum number of epochs before a trial can be stopped
    reduction_factor=2  # How aggressively trials are stopped (larger means more aggressive)
)

# Launch Ray Tune hyperparameter tuning
analysis = tune.run(
    train_and_evaluate,
    resources_per_trial={"cpu": 20, "gpu": 1},
    config=search_space,
    num_samples=5,
    scheduler=scheduler,
    trial_dirname_creator=custom_trial_name  # Use custom trial name
)
"""

### Model definition

When CUDA is available, the models are wrapped in `DataParallel`, so training is parallelized across multiple GPUs if they are present.

In [None]:
# Build the encoder/classifier pair selected by MODEL_NAME.
_RESNET_FACTORIES = {
    'ResNet18': ResNet18,
    'ResNet34': ResNet34,
    'ResNet50': ResNet50,
    'ResNet101': ResNet101,
    'ResNet152': ResNet152,
}

if MODEL_NAME not in _RESNET_FACTORIES:
    # Without this, an unknown name would leave `encoder`/`classifier` undefined, or
    # silently reuse the models from a previous run of this cell.
    raise ValueError(f"Unknown MODEL_NAME {MODEL_NAME!r}; expected one of {sorted(_RESNET_FACTORIES)}")

encoder, classifier = _RESNET_FACTORIES[MODEL_NAME]()


# On CUDA, wrap both parts in DataParallel (parallelizes over multiple GPUs if present)
# and enable cuDNN autotuning.
if device == 'cuda':
    encoder = torch.nn.DataParallel(encoder).to(device)
    classifier = torch.nn.DataParallel(classifier).to(device)
    cudnn.benchmark = True

 ### Optimizer and scheduler configuration

These are the optimizers and scheduler with the best hyperparameters found via **Ray Tune** for each dataset.

Again, it is not advised to use the *Adam* optimizer, since its parameter updates are too strong and will compromise the Contrastive Unlearning procedure.

I still left the option to do so though, just in case

In [None]:
# BEST CONFIGURATION ADAM
if OPTIMIZER_NAME == 'adam':
    # NOTE(review): weight_decay=0.567643 is far outside the [1e-6, 1e-2] range listed
    # for the Ray Tune search; confirm it is not a typo.
    optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=0.000641388, weight_decay=0.567643)

# BEST CONFIGURATION SGD
elif OPTIMIZER_NAME == 'sgd':
    optimizer = optim.SGD(list(encoder.parameters()) + list(classifier.parameters()), lr=0.015254, momentum=0.582486, weight_decay=0.00366885)

else:
    # Fail here instead of with a NameError on `optimizer` below.
    raise ValueError(f"Unknown OPTIMIZER_NAME {OPTIMIZER_NAME!r}; expected 'adam' or 'sgd'")

# NOTE(review): T_max=200 while NUM_EPOCHS is 50. If scheduler.step() runs once per
# epoch, only the first quarter of the cosine cycle is used; confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

### Train, Validate and Test functions

I encapsulated the whole training pipeline into functions for better understanding of what happens.

Also, they are convenient for keeping track of the changes made in the hyperbolic variants.

In [None]:
def train(encoder, classifier, loader, criterion, optimizer, scaler, hyperbolic='', manifold=None, device='cuda'):
    """
    The training pipeline: runs one epoch of mixed-precision training over `loader`
    and returns the epoch's average loss and accuracy.

    Args:
    - encoder: ResNet encoder
    - classifier: ResNet classifier
    - loader: the DataLoader of the trainset (iterable of (inputs, targets) batches)
    - criterion: the loss function (CrossEntropyLoss)
    - optimizer: applies the parameter updates (SGD or Adam)
    - scaler: gradient scaler needed for mixed-precision training (float16 autocast instead of float32)
    - hyperbolic: which model variant is running ('' for the Euclidean ResNet).
      'complete-ResNet' and 'academic-ResNet' first map the inputs onto `manifold`;
      for those two and 'hybrid-ResNet' the classifier output is a manifold tensor whose
      underlying `.tensor` is used for the loss
    - manifold: the hyperbolic manifold used to project the inputs (None for the Euclidean model)
    - device: where the batches are moved; also passed to autocast as `device_type`

    Returns:
    - average_loss: the accumulated loss divided by the number of batches in the loader
      (0.0 if the loader is empty)
    - acc: the accuracy (in %) of the model on the train set (0.0 if no sample was seen)
    """

    progress_bar_train = tqdm(enumerate(loader), total=len(loader))

    encoder.train()
    classifier.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in progress_bar_train:
        inputs, targets = inputs.to(device), targets.to(device)

        # The input projection is needed only if the whole model is hyperbolic
        if hyperbolic in ('complete-ResNet', 'academic-ResNet'):
            # Wrap the inputs as tangent vectors and map them onto the manifold (exponential map)
            tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
            inputs = manifold.expmap(tangents)

        optimizer.zero_grad()

        # Forward pass under float16 autocast
        with torch.amp.autocast_mode.autocast(device_type=device, dtype=torch.float16):
            features = encoder(inputs)
            outputs = classifier(features)

            # Every hyperbolic variant returns a manifold tensor: unwrap it for the loss computation
            if hyperbolic in ('complete-ResNet', 'hybrid-ResNet', 'academic-ResNet'):
                outputs = outputs.tensor  # the underlying PyTorch tensor

            loss = criterion(outputs, targets)

        # Scale the loss to avoid float16 gradient underflow; step() unscales before updating
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar_train.set_description('TRAIN | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Guard against an empty loader instead of dividing by zero
    num_batches = len(loader)
    average_loss = train_loss / num_batches if num_batches > 0 else 0.0
    acc = 100. * correct / total if total > 0 else 0.0

    return average_loss, acc

The validation function saves a checkpoint of the model whenever a new best validation accuracy is reached.

In [None]:
def validate(epoch, encoder, classifier, loader, criterion, best_acc, hyperbolic='', manifold=None, ckpt_name='', device='cuda'):
    """
    The validation pipeline: evaluates the model on `loader` and saves a checkpoint
    when the accuracy beats `best_acc`.

    Args:
    - epoch: the current epoch, stored in the checkpoint to know which epoch produced the best result
    - encoder: ResNet encoder
    - classifier: ResNet classifier
    - loader: the DataLoader of the validation set
    - criterion: the loss function (CrossEntropyLoss)
    - best_acc: the best accuracy so far; a checkpoint is saved only if this epoch is strictly better
    - hyperbolic: which model variant is running ('' for the Euclidean ResNet).
      'complete-ResNet' and 'academic-ResNet' first map the inputs onto `manifold`;
      for those two and 'hybrid-ResNet' the classifier output is unwrapped via `.tensor`
    - manifold: the hyperbolic manifold used to project the inputs (None for the Euclidean model)
    - ckpt_name: file name of the checkpoint, saved under './checkpoint/'
    - device: where the batches are moved

    Returns:
    - average_loss: the accumulated loss divided by the number of batches (0.0 if the loader is empty)
    - best_acc: the best accuracy so far (either the input value or this epoch's accuracy),
      NOT necessarily this epoch's accuracy
    """

    progress_bar_val = tqdm(enumerate(loader), total=len(loader))

    encoder.eval()
    classifier.eval()
    val_loss = 0
    correct = 0
    total = 0
    # Defined up front so an empty loader does not leave `acc` unbound
    acc = 0.0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in progress_bar_val:
            inputs, targets = inputs.to(device), targets.to(device)

            # The input projection is needed only if the whole model is hyperbolic
            if hyperbolic in ('complete-ResNet', 'academic-ResNet'):
                # Wrap the inputs as tangent vectors and map them onto the manifold (exponential map)
                tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
                inputs = manifold.expmap(tangents)

            features = encoder(inputs)
            outputs = classifier(features)

            # Every hyperbolic variant returns a manifold tensor: unwrap it for the loss computation
            if hyperbolic in ('complete-ResNet', 'hybrid-ResNet', 'academic-ResNet'):
                outputs = outputs.tensor  # the underlying PyTorch tensor

            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            acc = 100.*correct/total

            progress_bar_val.set_description('VAL | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (val_loss/(batch_idx+1), acc, correct, total))

    num_batches = len(loader)
    average_loss = val_loss / num_batches if num_batches > 0 else 0.0

    # Save checkpoint only on a strict improvement
    if acc > best_acc:
        print('Saving checkpoint..')
        state = {
            'encoder': encoder.state_dict(),
            'classifier': classifier.state_dict(),
            'loss': average_loss,
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir()
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/{}'.format(ckpt_name))
        best_acc = acc

    return average_loss, best_acc

In [None]:
def test(encoder, classifier, loader, criterion, hyperbolic='', manifold=None, device='cuda'):
    """
    The test pipeline: evaluates the model on `loader` without updating it.

    Args:
    - encoder: ResNet encoder
    - classifier: ResNet classifier
    - loader: the DataLoader of the test set (or of any subset to evaluate)
    - criterion: the loss function (CrossEntropyLoss)
    - hyperbolic: which model variant is running ('' for the Euclidean ResNet).
      'complete-ResNet' and 'academic-ResNet' first map the inputs onto `manifold`;
      for those two and 'hybrid-ResNet' the classifier output is unwrapped via `.tensor`
    - manifold: the hyperbolic manifold used to project the inputs (None for the Euclidean model)
    - device: where the batches are moved

    Returns:
    - average_loss: the accumulated loss divided by the number of batches (0.0 if the loader is empty)
    - acc: the accuracy (in %) over all the samples of the loader (0.0 if the loader is empty)
    """

    progress_bar_test = tqdm(enumerate(loader), total=len(loader))

    encoder.eval()
    classifier.eval()
    test_loss = 0
    correct = 0
    total = 0
    # Defined up front so an empty loader does not leave `acc` unbound
    acc = 0.0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in progress_bar_test:
            inputs, targets = inputs.to(device), targets.to(device)

            # The input projection is needed only if the whole model is hyperbolic
            if hyperbolic in ('complete-ResNet', 'academic-ResNet'):
                # Wrap the inputs as tangent vectors and map them onto the manifold (exponential map)
                tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
                inputs = manifold.expmap(tangents)

            features = encoder(inputs)
            outputs = classifier(features)

            # Every hyperbolic variant returns a manifold tensor: unwrap it for the loss computation
            if hyperbolic in ('complete-ResNet', 'hybrid-ResNet', 'academic-ResNet'):
                outputs = outputs.tensor  # the underlying PyTorch tensor

            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            acc = 100.*correct/total

            progress_bar_test.set_description('TEST | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), acc, correct, total))

    num_batches = len(loader)
    average_loss = test_loss / num_batches if num_batches > 0 else 0.0

    return average_loss, acc

### Start training

Since the article doesn't use any data augmentation, the model tends to overfit.

That said, in order to replicate the results and test them in hyperbolic space, I have to keep this configuration.

The EarlyStopper class is used to stop the training when the validation loss stops improving meaningfully.

In [None]:
# Lists that collect the per-epoch loss and accuracy for plotting
train_loss_history = []
val_loss_history = []
train_acc_history = []
val_acc_history = []

# Best validation accuracy so far: passed to validate() as best_acc and replaced by the value it returns.
# NOTE(review): validate() returns the *best* accuracy, not the current epoch's, so val_acc_history
# records the running best (a non-decreasing curve) - confirm this is what the plot should show.
val_acc = 0 

# Gradient scaler for the float16 autocast used inside train()
scaler = torch.amp.grad_scaler.GradScaler()

early_stopper = EarlyStopper()

for epoch in range(NUM_EPOCHS):

    print('\nEpoch: %d' % epoch)

    # One epoch on the full trainset (Euclidean model: no hyperbolic projection)
    average_train_loss, train_acc = train(encoder=encoder, 
                                          classifier=classifier, 
                                          loader=trainloader, 
                                          criterion=CRITERION, 
                                          optimizer=optimizer, 
                                          scaler=scaler, 
                                          hyperbolic='', 
                                          manifold=None, 
                                          device=device)
    
    train_loss_history.append(average_train_loss)
    train_acc_history.append(train_acc)
    
    # Validation; saves EUCL_ORIGINAL_CKPT_NAME whenever the validation accuracy beats the best so far
    average_val_loss, val_acc = validate(epoch=epoch, 
                                         encoder=encoder, 
                                         classifier=classifier, 
                                         loader=valloader, 
                                         criterion=CRITERION, 
                                         best_acc=val_acc, 
                                         hyperbolic='', 
                                         manifold=None, 
                                         ckpt_name=EUCL_ORIGINAL_CKPT_NAME, 
                                         device=device)
    
    val_loss_history.append(average_val_loss)
    val_acc_history.append(val_acc)

    # Early stopping is driven by the validation loss (not the accuracy)
    if early_stopper.early_stop(average_val_loss):  
        print('\nEarly stopper activated')           
        break

    # Advance the cosine-annealing learning-rate schedule once per epoch
    scheduler.step()

### Plot loss and accuracy over the epochs

In [None]:
# Loss/accuracy curves of the original model.
# val_acc_history holds the best-so-far accuracy returned by validate(), not the per-epoch accuracy.
plot_training_history(train_loss_history, val_loss_history, train_acc_history, val_acc_history)

### Test the model

Load best checkpoint.

It is needed since there is no guarantee that the latest epochs produced the most accurate parameters.

Also comes in handy for quick testing

In [None]:
print('Resuming from checkpoint...')

ckpt_path = './checkpoint/{}'.format(EUCL_ORIGINAL_CKPT_NAME)
# Explicit check instead of `assert`, which is stripped when Python runs with -O
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f'Error: no checkpoint found at {ckpt_path}!')
# map_location lets a checkpoint saved on a GPU be restored on the current device
checkpoint = torch.load(ckpt_path, map_location=device)
encoder.load_state_dict(checkpoint['encoder'])
classifier.load_state_dict(checkpoint['classifier'])

print('Checkpoint loaded!')
# Double quotes outside, single quotes inside: reusing the same quote in an f-string
# is a SyntaxError before Python 3.12
print(f"\nValidation average loss: {checkpoint['loss']:.3f}")
print(f"Validation set accuracy: {checkpoint['acc']:.3f}%")
print(f"From epoch: {checkpoint['epoch']}")

Try the model accuracy over some samples

In [None]:
# Take one batch from the test set
dataiter = iter(testloader)
images, labels = next(dataiter)

encoder.eval()
classifier.eval()

# Send the batch to wherever the model's parameters live; `images` itself stays on the CPU for plotting
model_device = next(encoder.parameters()).device

with torch.no_grad():
    features = encoder(images.to(model_device))
    outputs = classifier(features)
    _, predicted = torch.max(outputs, 1)

# Bring the predictions back to the CPU, like `labels`
predicted = predicted.cpu()

# Show at most 16 images: the batch may be smaller than that
num_shown = min(16, images.size(0))

# print images
imshow(torchvision.utils.make_grid(images[0:num_shown]))

# Prepare rows with GroundTruth and Predicted labels
rows = [["GroundTruth", *[classes[labels[j]] for j in range(num_shown)]],
        ["Predicted", *[classes[predicted[j]] for j in range(num_shown)]]]

# Display the table
print(tabulate(rows, headers=[""] + [f"Image {i+1}" for i in range(num_shown)], tablefmt="grid"))

Test the model overall performance

In [None]:
# Overall loss/accuracy of the original (Euclidean) model on the full test set
test_loss, test_acc = test(encoder=encoder, 
                           classifier=classifier, 
                           loader=testloader, 
                           criterion=CRITERION, 
                           hyperbolic='', 
                           manifold=None, 
                           device=device)

print(f'\nTest Set - Loss: {test_loss:.3f}, Accuracy: {test_acc:.3f}%')

### Visualize feature space

You can choose which class samples to display by toggling the class names.

**WARNING**: As mentioned in the *Configuration* section, Plotly has some problems with Visual Studio Code. To prevent strange plot behaviour, clear the cell's plot output before re-running the cell.

In [None]:
# Extract features from the test set, and plot them
# Extract the encoder features of the whole test set (Euclidean model) and plot them in 2 dimensions
features, labels = extract_features(encoder, testloader, hyperbolic='', manifold=None, device=device)
create_plot(features, labels, classes, dimension=2, convexhull=False)

## Trainset, Valset and Testset split into: samples to unlearn, samples to maintain

This is a crucial step which is needed also in the unlearning procedure.

Each of the three original sets is split into the samples we want to unlearn and the ones we want to maintain.

**ATTENTION**: in the *random-sample* scenario, the *"remaining_testset"* is the whole original test set. This is because the goal in this scenario is to preserve the model's test accuracy even after these random training samples are removed.

**WARNING**: in the *random-sample* scenario, running this cell multiple times WILL COMPROMISE the final results, because every run draws a different random set of samples to unlearn. This is expected behavior, not an error, so run this cell only once.

In [None]:
def remove_samples(dataset, class_to_remove, samples_to_remove, scenario=None):
    """
    Split samples based on scenario (class or random) into two datasets:
    unlearning and remaining.

    Args:
    - dataset: the dataset to filter (e.g., CIFAR-10, SVHN); iterating it must yield (img, label) pairs
    - class_to_remove: when the scenario is 'single-class', the label of the class we want to unlearn
    - samples_to_remove: when the scenario is 'random-sample', the number of random samples we want to unlearn
    - scenario: 'single-class' or 'random-sample'; defaults to the global SCENARIO

    Returns:
    - unlearning_dataset: Subset containing the selected samples for unlearning
    - remaining_dataset: Subset with all the other samples, in their original order
    - indices_to_remove: the dataset indices that make up unlearning_dataset

    Raises:
    - ValueError: if the scenario is neither 'single-class' nor 'random-sample'
    """
    if scenario is None:
        scenario = SCENARIO

    if scenario == 'single-class':
        # A single pass over the dataset: every access may load and transform an image
        indices_to_remove = []
        indices_to_keep = []
        for i, (_, label) in enumerate(dataset):
            if label == class_to_remove:
                indices_to_remove.append(i)
            else:
                indices_to_keep.append(i)

    elif scenario == 'random-sample':
        total_indices = list(range(len(dataset)))
        indices_to_remove = random.sample(total_indices, samples_to_remove)
        # Set for O(1) membership; the comprehension keeps the original index order
        # (a set difference would not guarantee any order)
        removed = set(indices_to_remove)
        indices_to_keep = [i for i in total_indices if i not in removed]

    else:
        raise ValueError(f"Unknown scenario {scenario!r}: expected 'single-class' or 'random-sample'")

    # Create the unlearning and remaining datasets
    unlearning_dataset = torch.utils.data.Subset(dataset, indices_to_remove)
    remaining_dataset = torch.utils.data.Subset(dataset, indices_to_keep)

    return unlearning_dataset, remaining_dataset, indices_to_remove
            
if SCENARIO == 'single-class': 
    # Remove the class to unlearn from every split (train, val and test)
    class_to_remove = 5  # label index of the class to unlearn
    unlearning_trainset, remaining_trainset, _ = remove_samples(trainset, class_to_remove=class_to_remove, samples_to_remove=None)
    _, remaining_valset, _ = remove_samples(valset, class_to_remove=class_to_remove, samples_to_remove=None) 
    unlearning_testset , remaining_testset, unlearning_test_indices = remove_samples(testset, class_to_remove=class_to_remove, samples_to_remove=None) 

    # The tags next to each loader mark the later section(s) where it is used
    remaining_trainloader = torch.utils.data.DataLoader(remaining_trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) #retrain #unlearning
    unlearning_trainloader = torch.utils.data.DataLoader(unlearning_trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) #unlearning #experiments
    remaining_valloader = torch.utils.data.DataLoader(remaining_valset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True) #retrain #validation
    remaining_testloader = torch.utils.data.DataLoader(remaining_testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True) #retest #experiments
    unlearning_testloader = torch.utils.data.DataLoader(unlearning_testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True) #unlearning #experiments

elif SCENARIO == 'random-sample':
    # Only the trainset is split: `samples_to_remove` random training samples are to be forgotten,
    # while the validation and test sets are reused whole
    samples_to_remove = 500
    unlearning_trainset, remaining_trainset, unlearning_train_indices = remove_samples(trainset, class_to_remove=None, samples_to_remove=samples_to_remove)
    remaining_testset = testset

    remaining_trainloader = torch.utils.data.DataLoader(remaining_trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) #retrain #unlearning
    unlearning_trainloader = torch.utils.data.DataLoader(unlearning_trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True) #unlearning #experiments
    remaining_valloader = valloader #retrain
    remaining_testloader = testloader   

## Retrain model

In this section, another instantiation of the model will be trained without the samples that we want to forget (from a single class or with random sampling).

Most of the functions are reused from the previous chapters, so refer to them if needed

### Choose the model

In [None]:
# Instantiate a fresh (encoder, classifier) pair, with the same architecture, for retraining
if MODEL_NAME == 'ResNet18':
    retrain_encoder, retrain_classifier = ResNet18()
elif MODEL_NAME == 'ResNet34':
    retrain_encoder, retrain_classifier = ResNet34()
elif MODEL_NAME == 'ResNet50':
    retrain_encoder, retrain_classifier = ResNet50()
elif MODEL_NAME == 'ResNet101':
    retrain_encoder, retrain_classifier = ResNet101()
elif MODEL_NAME == 'ResNet152':
    retrain_encoder, retrain_classifier = ResNet152()
else:
    # Fail fast: otherwise the retrain models would be undefined (or stale from a previous run)
    raise ValueError(f"Unsupported MODEL_NAME {MODEL_NAME!r}: expected one of "
                     "'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152'")


# On CUDA, DataParallel splits each batch across all visible GPUs
if device == 'cuda':
    retrain_encoder = torch.nn.DataParallel(retrain_encoder).to(device)
    retrain_classifier = torch.nn.DataParallel(retrain_classifier).to(device)

### Optimizer and scheduler configuration

In [None]:
# Same best hyperparameters as for the original model
if OPTIMIZER_NAME == 'adam':
    # BEST CONFIGURATION ADAM
    retrain_optimizer = optim.Adam(list(retrain_encoder.parameters()) + list(retrain_classifier.parameters()), lr=0.000641388, weight_decay=0.567643)

elif OPTIMIZER_NAME == 'sgd':
    # BEST CONFIGURATION SGD
    retrain_optimizer = optim.SGD(list(retrain_encoder.parameters()) + list(retrain_classifier.parameters()), lr=0.015254, momentum=0.582486, weight_decay=0.00366885)

else:
    # Fail fast instead of leaving `retrain_optimizer` undefined (or stale from a previous run)
    raise ValueError(f"Unsupported OPTIMIZER_NAME {OPTIMIZER_NAME!r}: expected 'adam' or 'sgd'")

# Cosine annealing with a fixed period of 200 scheduler steps (NOTE(review): independent of NUM_EPOCHS)
retrain_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(retrain_optimizer, T_max=200)

### Start processing

In [None]:
# Lists that collect the per-epoch loss and accuracy of the retrained model for plotting
retrain_train_loss_history = []
retrain_val_loss_history = []
retrain_train_acc_history = []
retrain_val_acc_history = []

# Best validation accuracy so far: passed to validate() as best_acc and replaced by the value it returns.
# NOTE(review): validate() returns the *best* accuracy, so retrain_val_acc_history records the running best.
retrain_val_acc = 0 

# Fresh gradient scaler for the float16 autocast used inside train()
scaler = torch.amp.grad_scaler.GradScaler()

early_stopper = EarlyStopper()

for epoch in range(NUM_EPOCHS):

    print('\nEpoch: %d' % epoch)

    # One epoch on the remaining training samples only (the samples to forget are excluded)
    retrain_average_train_loss, retrain_train_acc = train(encoder=retrain_encoder, 
                                                          classifier=retrain_classifier, 
                                                          loader=remaining_trainloader, 
                                                          criterion=CRITERION, 
                                                          optimizer=retrain_optimizer, 
                                                          scaler=scaler, 
                                                          hyperbolic='', 
                                                          manifold=None, 
                                                          device=device)
    
    retrain_train_loss_history.append(retrain_average_train_loss)
    retrain_train_acc_history.append(retrain_train_acc)

    # Validation on the remaining validation samples; saves EUCL_RETRAIN_CKPT_NAME on improvement
    retrain_average_val_loss, retrain_val_acc = validate(epoch=epoch, 
                                                         encoder=retrain_encoder, 
                                                         classifier=retrain_classifier, 
                                                         loader=remaining_valloader, 
                                                         criterion=CRITERION, 
                                                         best_acc=retrain_val_acc, 
                                                         hyperbolic='', 
                                                         manifold=None, 
                                                         ckpt_name=EUCL_RETRAIN_CKPT_NAME, 
                                                         device=device)
    
    retrain_val_loss_history.append(retrain_average_val_loss)
    retrain_val_acc_history.append(retrain_val_acc)

    # Early stopping is driven by the validation loss (not the accuracy)
    if early_stopper.early_stop(retrain_average_val_loss):
        print('\nEarly stopper activated') 
        break

    # Advance the cosine-annealing learning-rate schedule once per epoch
    retrain_scheduler.step()

### Plot the loss

In [None]:
# Loss/accuracy curves of the retrained model (validation accuracy is the best-so-far value from validate())
plot_training_history(retrain_train_loss_history, retrain_val_loss_history, retrain_train_acc_history, retrain_val_acc_history)

### Test the model

Load best checkpoint.

It is needed since there is no guarantee that the latest epochs produced the most accurate parameters.

Also comes in handy for quick testing

In [None]:
print('Resuming from checkpoint...')

ckpt_path = './checkpoint/{}'.format(EUCL_RETRAIN_CKPT_NAME)
# Explicit check instead of `assert`, which is stripped when Python runs with -O
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f'Error: no checkpoint found at {ckpt_path}!')
# map_location lets a checkpoint saved on a GPU be restored on the current device
checkpoint = torch.load(ckpt_path, map_location=device)
retrain_encoder.load_state_dict(checkpoint['encoder'])
retrain_classifier.load_state_dict(checkpoint['classifier'])

print('Checkpoint loaded!')
# Double quotes outside, single quotes inside: reusing the same quote in an f-string
# is a SyntaxError before Python 3.12
print(f"\nValidation average loss: {checkpoint['loss']:.3f}")
print(f"Validation set accuracy: {checkpoint['acc']:.3f}%")
print(f"From epoch: {checkpoint['epoch']}")

Try the model's predictions on some samples of the *remaining test set*

In [None]:
# Take one batch from the remaining test set
dataiter = iter(remaining_testloader)
images, labels = next(dataiter)

retrain_encoder.eval()
retrain_classifier.eval()

# Send the batch to wherever the model's parameters live; `images` itself stays on the CPU for plotting
model_device = next(retrain_encoder.parameters()).device

with torch.no_grad():
    features = retrain_encoder(images.to(model_device))
    outputs = retrain_classifier(features)

    _, predicted = torch.max(outputs, 1)

# Bring the predictions back to the CPU, like `labels`
predicted = predicted.cpu()

# Show at most 16 images: the batch may be smaller than that
num_shown = min(16, images.size(0))

# print images
imshow(torchvision.utils.make_grid(images[0:num_shown]))

# Prepare rows with GroundTruth and Predicted labels
rows = [["GroundTruth", *[classes[labels[j]] for j in range(num_shown)]],
        ["Predicted", *[classes[predicted[j]] for j in range(num_shown)]]]

# Display the table
print(tabulate(rows, headers=[""] + [f"Image {i+1}" for i in range(num_shown)], tablefmt="grid"))

Test the model overall performance on *remaining samples* and *unlearned samples*

In [None]:
# Evaluate the retrained model on the test samples it must still classify well.
# (Despite the `best` in the names, test() returns the accuracy over the whole loader.)
retrain_average_remaining_test_loss, retrain_best_remaining_test_acc = test(encoder=retrain_encoder,
                                                                            classifier=retrain_classifier,
                                                                            loader=remaining_testloader,
                                                                            criterion=CRITERION,
                                                                            hyperbolic='',
                                                                            manifold=None,
                                                                            device=device)

# ...and on the training samples to forget, which were excluded from retraining
retrain_average_unl_train_loss, retrain_best_unl_train_acc = test(encoder=retrain_encoder,
                                                                classifier=retrain_classifier,
                                                                loader=unlearning_trainloader,
                                                                criterion=CRITERION,
                                                                hyperbolic='',
                                                                manifold=None,
                                                                device=device)

# A held-out unlearning test set exists only in the single-class scenario
if SCENARIO == 'single-class':

    retrain_average_unl_test_loss, retrain_best_unl_test_acc = test(encoder=retrain_encoder,
                                                                    classifier=retrain_classifier,
                                                                    loader=unlearning_testloader,
                                                                    criterion=CRITERION,
                                                                    hyperbolic='',
                                                                    manifold=None,
                                                                    device=device)

print(f'\nRemaining Test Set Loss: {retrain_average_remaining_test_loss:.3f}, Accuracy: {retrain_best_remaining_test_acc:.3f}%')
print(f'Unlearning Train Set Loss: {retrain_average_unl_train_loss:.3f}, Accuracy: {retrain_best_unl_train_acc:.3f}%')
# A plain `if` statement instead of a conditional expression evaluated only for its side effect
if SCENARIO == 'single-class':
    print(f'Unlearning Test Set Loss: {retrain_average_unl_test_loss:.3f}, Accuracy: {retrain_best_unl_test_acc:.3f}%')

### Visualize feature space

You can choose which class samples to display by toggling the class names 

In [None]:
if SCENARIO == 'single-class':

    # Extract features of the remaining and of the unlearning test samples separately
    remaining_features, remaining_labels = extract_features(retrain_encoder, remaining_testloader, hyperbolic='', manifold=None, device=device)
    unlearning_features, unlearning_labels = extract_features(retrain_encoder, unlearning_testloader, hyperbolic='', manifold=None, device=device)

    # Combine both groups for t-SNE, so they are embedded and plotted together
    features = np.concatenate((unlearning_features, remaining_features), axis=0)
    labels = np.concatenate((unlearning_labels, remaining_labels), axis=0)

# In this case, there is no specific class to isolate, so the plotting goal is for all the models (original, retrain and unlearned) to plot a similar distribution of the whole test set
elif SCENARIO == 'random-sample':

    # Extract features from the test set (remaining_testloader is the whole test set in this scenario)
    features, labels = extract_features(retrain_encoder, remaining_testloader, hyperbolic='', manifold=None, device=device)


create_plot(features, labels, classes, dimension=2, convexhull=False)

## Unlearning procedure

### Unlearning validation set configuration

The unlearning algorithm takes as input:

1. The unlearning samples from the trainset
2. The remaining samples from the trainset
3. A validation set built like this:
    - for *single-class* unlearning -> unlearning trainset samples
    - for *random-sample* unlearning -> **subset**(unlearning trainset samples) + **subset**(original testset)

For the contrastive unlearning procedure, the *unlearning validation set* is built following these rules, and it differs from the previous validation set

In [None]:
# Used only for the random-sample scenario
def sample_validation_set(unlearning_trainset, testset, ratio=0.5):
    """
    Create two distinct validation sets for the random-sample unlearning case: a random
    subset of the unlearning training samples and a random subset of the original test set.

    Args:
    - unlearning_trainset: the dataset of unlearning samples (a Subset of it is taken)
    - testset: the original test set (a Subset of it is taken)
    - ratio: fraction of each set to keep for validation, in [0, 1] (default 0.5, i.e. 50%)

    Returns:
    - a (unlearning_subset, test_subset) pair of torch Subsets

    Raises:
    - ValueError: if ratio is outside [0, 1]
    """
    if not 0 <= ratio <= 1:
        raise ValueError(f'ratio must be in [0, 1], got {ratio}')

    # Randomly select a subset from the unlearning samples (drawn first, to keep the RNG sequence stable)
    total_indices = list(range(len(unlearning_trainset)))
    eval_unlearning_trainset_indices = random.sample(total_indices, int(len(unlearning_trainset) * ratio))

    # Randomly select a subset from the test set
    total_test_indices = list(range(len(testset)))
    eval_test_indices = random.sample(total_test_indices, int(len(total_test_indices) * ratio))

    return torch.utils.data.Subset(unlearning_trainset, eval_unlearning_trainset_indices), torch.utils.data.Subset(testset, eval_test_indices)

# Single-class: the held-out test samples of the forgotten class are the unlearning validation set
if SCENARIO == 'single-class':
    unl_valset_test = unlearning_testset
    unl_valloader_test = unlearning_testloader

# Random-sample: half of the forgotten training samples plus half of the original test set (ratio=0.5)
elif SCENARIO == 'random-sample':
    unl_subset_trainset, unl_subset_testset = sample_validation_set(unlearning_trainset, testset, ratio=0.5)
    unl_valloader_train = torch.utils.data.DataLoader(unl_subset_trainset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    unl_valloader_test = torch.utils.data.DataLoader(unl_subset_testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

### Load original checkpoint

**WARNING**: the encoder and classifier here are the original ones (instantiated in the first part of the notebook).

During the unlearning procedure, the original model parameters will be overwritten.

So if you go back and try inference on the original model, the output will change (for worse, obviously).

*This checkpoint cell is just for quick testing*

In [None]:
print('Resuming from checkpoint...')

ckpt_path = './checkpoint/{}'.format(EUCL_ORIGINAL_CKPT_NAME)
# Explicit check instead of `assert`, which is stripped when Python runs with -O
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(f'Error: no checkpoint found at {ckpt_path}!')
# map_location lets a checkpoint saved on a GPU be restored on the current device
checkpoint = torch.load(ckpt_path, map_location=device)
encoder.load_state_dict(checkpoint['encoder'])
classifier.load_state_dict(checkpoint['classifier'])

print('Checkpoint loaded!')
# Double quotes outside, single quotes inside: reusing the same quote in an f-string
# is a SyntaxError before Python 3.12
print(f"\nValidation average loss: {checkpoint['loss']:.3f}")
print(f"Validation set accuracy: {checkpoint['acc']:.3f}%")
print(f"From epoch: {checkpoint['epoch']}")

### Contrastive unlearning loss

This is the custom ***Contrastive Unlearning*** loss, as described in the original article

In [None]:
class ContrastiveUnlearningLoss(nn.Module):
    """
    Contrastive unlearning loss of a batch of unlearning embeddings (anchors) against a
    batch of remaining embeddings.

    Both batches are L2-normalised, so each similarity s is a cosine similarity divided
    by the temperature.

    - 'single-class': every remaining sample is a negative. Per anchor the loss is
      -1/R * sum_r log(exp(s_ur) / R), where R is the number of remaining samples.
      NOTE(review): the denominator is the count R, not a sum of exponentiated similarities,
      so this reduces to log(R) minus the mean scaled similarity - confirm against the paper.
    - 'random-sample': remaining samples with the anchor's label are positives, the others
      are negatives. Per anchor the loss is -1/|N| * sum_n log(exp(s_un) / sum_p exp(s_up)).

    The returned value is the mean of the per-anchor losses.

    Args:
    - temperature: divides the similarities before the exponential
    - scenario: 'single-class' or 'random-sample'; when None (default) the global SCENARIO
      is read at every forward call
    """

    def __init__(self, temperature, scenario=None):
        super(ContrastiveUnlearningLoss, self).__init__()
        self.temperature = temperature
        self.scenario = scenario

    def forward(self, embeddings_u, embeddings_r, labels_u, labels_r):
        """
        Args:
        - embeddings_u: embeddings of the unlearning samples (anchors), one row per sample
        - embeddings_r: embeddings of the remaining samples, one row per sample
        - labels_u: labels of the unlearning samples (used only in 'random-sample')
        - labels_r: labels of the remaining samples (used only in 'random-sample')

        Returns:
        - the loss averaged over the anchors (a zero tensor for an empty anchor batch)

        Raises:
        - ValueError: if the scenario is neither 'single-class' nor 'random-sample'
        """
        scenario = SCENARIO if self.scenario is None else self.scenario
        if scenario not in ('single-class', 'random-sample'):
            raise ValueError(f"Unknown scenario {scenario!r}: expected 'single-class' or 'random-sample'")

        unlearning_batch_size = embeddings_u.size(0)
        # Accumulate in a tensor on the embeddings' device, so CPU and CUDA tensors are never mixed
        loss = torch.zeros((), device=embeddings_u.device)
        if unlearning_batch_size == 0:
            return loss

        # Normalize embeddings
        embeddings_u = F.normalize(embeddings_u, p=2, dim=1)
        embeddings_r = F.normalize(embeddings_r, p=2, dim=1)

        for i in range(unlearning_batch_size):
            # Anchor: embedding of unlearning sample
            z_u = embeddings_u[i]

            if scenario == 'single-class':
                negatives = embeddings_r  # All remaining samples are negative
                denominator = negatives.size(0) + 1e-8
            else:
                pos_mask = labels_r == labels_u[i]
                positives = embeddings_r[pos_mask]   # Same class as anchor
                negatives = embeddings_r[~pos_mask]  # Different class from anchor
                # The denominator holds ONLY the positive similarities; with no positives
                # the (empty) sum is 0 and just the epsilon remains
                pos_sim = torch.exp(torch.matmul(z_u, positives.T) / self.temperature)
                denominator = pos_sim.sum() + 1e-8

            if negatives.size(0) == 0:
                if scenario == 'single-class':
                    print("Warning: neg_sim is zero or too small.")
                # No negatives: this anchor contributes nothing
                continue

            neg_sim = torch.exp(torch.matmul(z_u, negatives.T) / self.temperature)
            inner_loss = torch.log(neg_sim / denominator).sum()
            # Average over the negatives of this anchor
            loss = loss + (-1 / (neg_sim.size(0) + 1e-8)) * inner_loss

        return loss / unlearning_batch_size  # Average loss per batch

### Unlearning optimizer and scheduler configuration

Here an assumption is made: the training hyperparameters are also good for unlearning.

After some personal tests, it seems that this holds

In [None]:
# Same best hyperparameters as for training, applied to the original model's parameters
if OPTIMIZER_NAME == 'adam':
    # BEST CONFIGURATION ADAM
    unl_optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=0.000641388, weight_decay=0.567643)

elif OPTIMIZER_NAME == 'sgd':
    # BEST CONFIGURATION SGD
    unl_optimizer = optim.SGD(list(encoder.parameters()) + list(classifier.parameters()), lr=0.015254, momentum=0.582486, weight_decay=0.00366885)

else:
    # Fail fast instead of leaving `unl_optimizer` undefined (or stale from a previous run)
    raise ValueError(f"Unsupported OPTIMIZER_NAME {OPTIMIZER_NAME!r}: expected 'adam' or 'sgd'")

# Cosine annealing with a fixed period of 200 scheduler steps (NOTE(review): independent of NUM_EPOCHS)
unl_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(unl_optimizer, T_max=200)

### DO NOT UNCOMMENT THESE: Unlearning hyperparameters tuning

In [None]:
"""
# Put large objects in Ray's object store for efficient memory management
large_unlearning_trainloader = ray.put(unlearning_trainloader)
if SCENARIO == 'random-sample':
    large_unl_valloader_train = ray.put(unl_valloader_train)
large_unl_valloader_test = ray.put(unl_valloader_test)
large_remaining_trainset = ray.put(remaining_trainset)

checkpoint = torch.load('./checkpoint/{}'.format(EUCL_ORIGINAL_CKPT_NAME))
large_checkpoint = ray.put(checkpoint)
"""

In [None]:
"""
# Calculate loss and accuracy of a data loader
def get_loss_and_accuracy(encoder, classifier, data_loader, criterion):
        
        correct = 0
        total = 0
        total_loss = 0.0
        encoder.eval()
        classifier.eval()
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = classifier(encoder(inputs))
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        accuracy = correct / total if total > 0 else 0
        avg_loss = total_loss / len(data_loader) if len(data_loader) > 0 else 0
        
        return avg_loss, accuracy


# Validation function to evaluate the unlearned efficency
def validate_unlearning(encoder, classifier, valloader_train, valloader_test, classification_loss, class_count):

    if SCENARIO == 'single-class':
        # For single-class unlearning, accuracy on the unlearning class should be <= 1/class_count
        unlearning_loss, unlearning_accuracy = get_loss_and_accuracy(encoder, classifier, valloader_test, classification_loss)
        verdict = unlearning_accuracy <= 1 / class_count
        #print('Unlearning class Loss: {:.4f}, Unlearning Accuracy: {:.4f}, Termination condition reached: {}'.format(unlearning_loss, unlearning_accuracy, verdict))
    
    elif SCENARIO == 'random-sample':
        # For sample-unlearning, accuracy on unlearning samples <= accuracy on the test samples
        unlearning_loss, unlearning_accuracy = get_loss_and_accuracy(encoder, classifier, valloader_train, classification_loss)
        test_loss, test_accuracy = get_loss_and_accuracy(encoder, classifier, valloader_test, classification_loss)
        verdict = unlearning_accuracy <= test_accuracy
        #print('Unlearning sample Loss: {:.4f}, Unlearning Accuracy: {:.4f}, Test loss: {:.4f}, Test accuracy: {:.4f}, Termination condition reached: {}'.format(unlearning_loss, unlearning_accuracy, test_loss, test_accuracy, verdict))

    return unlearning_accuracy, verdict
    


# Validation function to evaluate the remaining accuracy
def validate_containment(encoder, classifier, remaining_testloader, classification_loss, class_count):

    # For single-class unlearning, accuracy on the unlearning class should be <= 1/class_count
    remaining_loss, remaining_accuracy = get_loss_and_accuracy(encoder, classifier, remaining_testloader, classification_loss)
    #print('Remaining Test Set Loss: {:.4f}, Accuracy: {:.4f}'.format(remaining_loss, remaining_accuracy))
    return remaining_accuracy
"""

In [None]:
"""
# Main training loop with validation
def train_contrastive_unlearning(config):

    # DA TOGLIERE INSIEME A DATA PARALLEL
    def remove_module_prefix(state_dict):
        return {k.replace("module.", ""): v for k, v in state_dict.items()}
    
    scaler = torch.amp.grad_scaler.GradScaler()

    encoder, classifier = ResNet34()

    checkpoint = ray.get(large_checkpoint)

    encoder_state = remove_module_prefix(checkpoint['encoder'])
    classifier_state = remove_module_prefix(checkpoint['classifier'])

    encoder.load_state_dict(encoder_state)
    classifier.load_state_dict(classifier_state)

    encoder = encoder.to(device)
    classifier = classifier.to(device)

    unlearning_trainloader = ray.get(large_unlearning_trainloader)
    remaining_trainset = ray.get(large_remaining_trainset)
    if SCENARIO == 'random-sample':
        unl_valloader_train = ray.get(large_unl_valloader_train)
    unl_valloader_test = ray.get(large_unl_valloader_test)
    
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    if config['opt'] == 'sgd':
        optimizer = optim.SGD(list(encoder.parameters()) + list(classifier.parameters()), lr=0.015254, momentum=0.582486, weight_decay=0.00366885)
    elif config['opt'] == 'adam':
        optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=0.000641388, weight_decay=0.567643)
    
    class_count = 10

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


    for epoch in range(config["num_epochs"]):

        encoder.train()
        classifier.train()

        total_loss_epoch = 0.0  # To accumulate loss over the epoch
        total_batches = 0
        
        # Iterate over unlearning dataset (D_u)
        for x_u, y_u in unlearning_trainloader:
            x_u, y_u = x_u.to(device), y_u.to(device)
            optimizer.zero_grad()
            batch_loss = 0.0  # To accumulate loss over the batch
            
            # Perform ω iterations over the remaining dataset for each unlearning batch
            for _ in range(config["omega"]):  # Omega loop

                # Create a DataLoader with a RandomSampler
                random_sampler = torch.utils.data.RandomSampler(remaining_trainset)
                remaining_trainset_loader = torch.utils.data.DataLoader(remaining_trainset, sampler=random_sampler, batch_size=128, pin_memory=True)

                for remaining_data in remaining_trainset_loader:
                    x_r, y_r = remaining_data
                    break

                x_r, y_r = x_r.to(device), y_r.to(device)

                # Use mixed precision
                with torch.amp.autocast_mode.autocast(device_type=device, dtype=torch.float16):
                    # Forward pass for remaining samples (classification loss)
                    z_r = encoder(x_r) 
                    logits_r = classifier(z_r)

                    # Forward pass for unlearning samples
                    z_u = encoder(x_u)
                    logits_u = classifier(z_u)

                ce_loss = criterion(logits_r, y_r)
                # Compute contrastive unlearning loss
                loss_fn = ContrastiveUnlearningLoss(temperature=config["temperature"])
                ul_loss = loss_fn(z_u, z_r, y_u, y_r)

                # Total loss: classification loss + contrastive unlearning loss
                total_loss = config["regularizer_ul"] * ul_loss + config["regularizer_ce"] * ce_loss
                batch_loss += total_loss.item()

                scaler.scale(total_loss).backward()
                
            scaler.step(optimizer)
            scaler.update()

            total_batches += 1
            total_loss_epoch += batch_loss    
            
        # Compute validation for termination criteria
        if SCENARIO == 'single-class':
            unlearned_acc, verdict = validate_unlearning(encoder, classifier, None, unl_valloader_test, CRITERION, class_count=class_count)
        elif SCENARIO == 'random-sample':
            unlearned_acc, verdict = validate_unlearning(encoder, classifier, unl_valloader_train, unl_valloader_test, CRITERION, class_count=class_count)

        # Compute accuracy retain over the remaining samples
        remaining_acc = validate_containment(encoder, classifier, remaining_testloader, CRITERION, class_count=class_count)

        scheduler.step()

        # report results to Ray Tune
        ray.train.report(dict(unlearned_accuracy=unlearned_acc, remaining_accuracy=remaining_acc))       

        # Termination criteria
        if verdict: 
            print("\nTerminating unlearning procedure")
            break
"""

In [None]:
"""
def custom_trial_name(trial):
    return f"trial_unlearning{trial.trial_id}"

# Define the hyperparameter search space
search_space = {
    "num_epochs": tune.grid_search([NUM_EPOCHS]),
    "opt": tune.grid_search(['sgd']),
    "temperature": tune.uniform(0.1, 1.0),
    "regularizer_ce": tune.uniform(0.1, 1.0),
    "regularizer_ul": tune.uniform(0.1, 1.0),
    "omega": tune.choice([2, 4, 6]),
}

# Use partial to pass fixed_params
analysis = tune.run(
    train_contrastive_unlearning,
    resources_per_trial={"cpu": 20, "gpu": 1},
    config=search_space,
    num_samples=50,
    scheduler=ASHAScheduler(metric="remaining_accuracy", mode="max"),    
    trial_dirname_creator=custom_trial_name
)
"""

### Validate Unlearning

There are two validation functions:

- *validate_unlearning* tests how effective the unlearning is
- *validate_containment* tests how well the model maintains its knowledge of the rest of the data

In [None]:
# Calculate loss and accuracy
def get_loss_and_accuracy(encoder, classifier, data_loader, criterion, set_name='', hyperbolic='', manifold=None, device='cuda'):
    """
    Compute the average loss and the accuracy of `classifier(encoder(x))` over a data loader.

    Both models are switched to eval mode and evaluated under `torch.no_grad()`.

    Args:
        encoder: The encoder model to extract features.
        classifier: The classifier head model for predictions.
        data_loader: The loader of a specific dataset, yielding (inputs, labels) batches.
        criterion: Loss function (e.g., CrossEntropyLoss).
        set_name: Name of the evaluated set, shown in the tqdm bar.
        hyperbolic: Architecture variant. '' means Euclidean. 'complete-ResNet' and
            'academic-ResNet' first map the inputs onto `manifold` (expmap of a TangentTensor).
            Those two plus 'hybrid-ResNet' produce manifold tensors whose `.tensor`
            is used for the loss and predictions.
        manifold: The manifold used when `hyperbolic` is not ''.
        device: Device the batches are moved to.

    Returns:
        avg_loss: Mean per-batch loss over the loader (0 if the loader is empty).
        accuracy: Fraction in [0, 1] of correctly classified samples (0 if no samples).
    """
    maps_inputs = hyperbolic in ('complete-ResNet', 'academic-ResNet')
    returns_manifold_tensor = hyperbolic in ('complete-ResNet', 'hybrid-ResNet', 'academic-ResNet')

    progress_bar_val = tqdm(enumerate(data_loader), total=len(data_loader))

    correct = 0
    total = 0
    total_loss = 0.0
    encoder.eval()
    classifier.eval()
    with torch.no_grad():
        for batch_idx, (inputs, labels) in progress_bar_val:
            inputs, labels = inputs.to(device), labels.to(device)

            # The projection is needed only if the whole model is hyperbolic
            if maps_inputs:
                # Move the inputs onto the manifold
                tangents = TangentTensor(data=inputs, man_dim=1, manifold=manifold)
                inputs = manifold.expmap(tangents)

            outputs = classifier(encoder(inputs))

            # Any hyperbolic head returns a manifold tensor: unwrap it for the loss
            if returns_manifold_tensor:
                outputs = outputs.tensor  # underlying PyTorch tensor

            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            progress_bar_val.set_description('UNLEARNING VAL | %s | Loss: %.3f, Accuracy: %.3f%% (%d/%d)' % (set_name, total_loss/(batch_idx+1), 100. * correct/total, correct, total))

    accuracy = correct / total if total > 0 else 0
    avg_loss = total_loss / len(data_loader) if len(data_loader) > 0 else 0

    return avg_loss, accuracy



# Validation function to evaluate the unlearning efficiency
def validate_unlearning(encoder, classifier, valloader_train, valloader_test, classification_loss, class_count, hyperbolic='', manifold=None, device='cuda'):
    """
    Check the unlearning termination condition for the current global SCENARIO.

    - 'single-class': the accuracy on `valloader_test` (the forgotten class) must be
      <= 1 / class_count, i.e. no better than random guessing.
    - 'random-sample': the accuracy on `valloader_train` (forgotten samples) must be
      <= the accuracy on `valloader_test` (held-out test samples).

    Args:
        encoder: The encoder model to extract features.
        classifier: The classifier head model for predictions.
        valloader_train: The unlearning-trainset subset DataLoader (unused, may be None, for 'single-class').
        valloader_test: The testset subset DataLoader.
        classification_loss: Loss function (e.g., CrossEntropyLoss).
        class_count: Total number of classes in the dataset (random-guess accuracy threshold).
        hyperbolic, manifold, device: Forwarded to get_loss_and_accuracy.

    Returns:
        unlearning_accuracy: The accuracy of the model on the unlearning validation set.
        verdict: True if the termination condition is met.

    Raises:
        ValueError: If SCENARIO is neither 'single-class' nor 'random-sample'.
    """
    eval_kwargs = dict(hyperbolic=hyperbolic, manifold=manifold, device=device)

    if SCENARIO == 'single-class':
        # Accuracy on the unlearned class should be <= 1/class_count
        _, unlearning_accuracy = get_loss_and_accuracy(encoder, classifier, valloader_test, classification_loss, set_name='Unlearning Test Set', **eval_kwargs)
        verdict = unlearning_accuracy <= 1 / class_count

    elif SCENARIO == 'random-sample':
        # Accuracy on the unlearned samples should be <= accuracy on unseen test samples
        _, unlearning_accuracy = get_loss_and_accuracy(encoder, classifier, valloader_train, classification_loss, set_name='Unlearning Train Subset', **eval_kwargs)
        _, test_accuracy = get_loss_and_accuracy(encoder, classifier, valloader_test, classification_loss, set_name='Test Subset', **eval_kwargs)
        verdict = unlearning_accuracy <= test_accuracy

    else:
        # Previously this fell through to an UnboundLocalError on the return below
        raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

    return unlearning_accuracy, verdict
    


# Validation function to evaluate the remaining accuracy
def validate_containment(encoder, classifier, remaining_testloader, classification_loss, class_count, hyperbolic='', manifold=None, device='cuda'):
    """
    Measure how much knowledge of the non-forgotten data the model retains.

    Args:
        encoder: The encoder model to extract features.
        classifier: The classifier head model for predictions.
        remaining_testloader: The loader of the remaining samples of the test set.
        classification_loss: Loss function (e.g., CrossEntropyLoss).
        class_count: Total number of classes in the dataset. Not used here; kept so the
            signature matches validate_unlearning.
        hyperbolic, manifold, device: Forwarded to get_loss_and_accuracy.

    Returns:
        remaining_accuracy: Accuracy (fraction in [0, 1]) on the remaining test set.
    """
    _, remaining_accuracy = get_loss_and_accuracy(encoder, classifier, remaining_testloader, classification_loss, set_name='Remaining Test Set', hyperbolic=hyperbolic, manifold=manifold, device=device)

    return remaining_accuracy

### Unlearning algorithm

In [None]:
# Main training loop with validation
def train_contrastive_unlearning(encoder, classifier, unlearning_trainloader, remaining_trainset, classification_loss, unlearning_loss, 
                                 optimizer, omega, lambda_ce, lambda_ul, scaler, hyperbolic='', manifold=None, device='cuda',
                                 remaining_batch_size=None):
    """
    Run one epoch of contrastive unlearning.

    For every batch (x_u, y_u) of the unlearning set D_u, `omega` random batches (x_r, y_r)
    are drawn from the remaining set D_r. For each one, the loss

        lambda_ul * unlearning_loss(z_u, z_r, y_u, y_r) + lambda_ce * classification_loss(logits_r, y_r)

    is back-propagated through `scaler`. The gradients accumulated over the omega iterations
    are then applied with a single `scaler.step(optimizer)`.

    Args:
        encoder, classifier: Models being unlearned (set to train mode here).
        unlearning_trainloader: DataLoader over D_u.
        remaining_trainset: Dataset D_r; random batches are sampled from it.
        classification_loss: Loss on the remaining samples' logits (e.g. CrossEntropyLoss).
        unlearning_loss: Contrastive loss called as unlearning_loss(z_u, z_r, y_u, y_r).
        optimizer: Optimizer over encoder + classifier parameters.
        omega: Number of remaining batches per unlearning batch (must be >= 1).
        lambda_ce, lambda_ul: Weights of the classification and unlearning losses.
        scaler: torch.amp GradScaler used for mixed precision.
        hyperbolic: '' (Euclidean), 'hybrid-ResNet', 'complete-ResNet' or 'academic-ResNet'.
        manifold: Manifold used by the hyperbolic variants.
        device: Device for the batches. May be a string such as 'cuda' or 'cuda:0', or a torch.device.
        remaining_batch_size: Batch size of the D_r batches. When None, falls back to the
            notebook-level `batch_size` global (the previous behavior).

    Returns:
        The mean over unlearning batches of the summed omega losses (0.0 if D_u is empty).

    Raises:
        ValueError: If omega < 1. With no backward pass, GradScaler.step cannot run.
    """
    if omega < 1:
        raise ValueError(f"omega must be >= 1, got {omega}")

    if remaining_batch_size is None:
        remaining_batch_size = batch_size  # notebook-level global, as in the original code

    maps_inputs = hyperbolic in ('complete-ResNet', 'academic-ResNet')
    returns_manifold_tensor = hyperbolic in ('complete-ResNet', 'hybrid-ResNet', 'academic-ResNet')

    # autocast expects a device *type* ('cuda'/'cpu'), not 'cuda:0' or a torch.device object
    autocast_device_type = torch.device(device).type

    # The loader is loop-invariant, so it is built once. Every iter() call on it draws a new
    # RandomSampler permutation, so next(iter(...)) still yields a fresh random D_r batch.
    remaining_loader = torch.utils.data.DataLoader(
        remaining_trainset,
        sampler=torch.utils.data.RandomSampler(remaining_trainset),
        batch_size=remaining_batch_size,
        pin_memory=True,
    )

    encoder.train()
    classifier.train()

    progress_bar_unlearn = tqdm(enumerate(unlearning_trainloader), total=len(unlearning_trainloader))

    total_loss_epoch = 0.0  # Sum of batch losses over the epoch
    total_batches = 0

    # Iterate over unlearning dataset (D_u)
    for batch_idx, (x_u, y_u) in progress_bar_unlearn:
        x_u, y_u = x_u.to(device), y_u.to(device)

        # The projection is needed only if the whole model is hyperbolic
        if maps_inputs:
            x_u = manifold.expmap(TangentTensor(data=x_u, man_dim=1, manifold=manifold))

        optimizer.zero_grad()
        batch_loss = 0.0  # Sum of the omega losses for this unlearning batch

        # Perform ω iterations over the remaining dataset for each unlearning batch
        for _ in range(omega):
            x_r, y_r = next(iter(remaining_loader))
            x_r, y_r = x_r.to(device), y_r.to(device)

            if maps_inputs:
                x_r = manifold.expmap(TangentTensor(data=x_r, man_dim=1, manifold=manifold))

            # Mixed-precision forward passes
            with torch.amp.autocast_mode.autocast(device_type=autocast_device_type, dtype=torch.float16):
                # Remaining samples: embeddings and logits (classification loss)
                z_r = encoder(x_r)
                logits_r = classifier(z_r)

                # Unlearning samples: only the embeddings feed the contrastive loss
                z_u = encoder(x_u)

            # A hyperbolic head returns a manifold tensor: unwrap it for the classification loss
            if returns_manifold_tensor:
                logits_r = logits_r.tensor

            ce_loss = classification_loss(logits_r, y_r)
            ul_loss = unlearning_loss(z_u, z_r, y_u, y_r)

            # Total loss: weighted contrastive unlearning loss + classification loss
            total_loss = lambda_ul * ul_loss + lambda_ce * ce_loss
            batch_loss += total_loss.item()

            # Gradients accumulate across the omega iterations
            scaler.scale(total_loss).backward()

        scaler.step(optimizer)
        scaler.update()

        total_batches += 1
        total_loss_epoch += batch_loss

        # The running average is over the batches processed so far, not the whole loader length
        progress_bar_unlearn.set_description('UNLEARNING TRAIN | Batch Loss: %.4f | Average Loss: %.4f' % (batch_loss, total_loss_epoch / total_batches))

    return total_loss_epoch / total_batches if total_batches else 0.0

### Start unlearning

In [None]:
# Mixed-precision gradient scaler shared by all unlearning epochs
scaler = torch.amp.grad_scaler.GradScaler()

# Fail fast on a misconfigured scenario instead of a NameError on `verdict` inside the loop
if SCENARIO not in ('single-class', 'random-sample'):
    raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

# Build the contrastive unlearning loss once instead of once per epoch
unl_loss_fn = ContrastiveUnlearningLoss(temperature=TEMPERATURE)

for epoch in range(NUM_EPOCHS):

    print('\nEpoch: %d' % epoch)

    train_contrastive_unlearning(encoder=encoder, 
                                 classifier=classifier, 
                                 unlearning_trainloader=unlearning_trainloader, 
                                 remaining_trainset=remaining_trainset, 
                                 classification_loss=CRITERION, 
                                 unlearning_loss=unl_loss_fn,
                                 optimizer=unl_optimizer, 
                                 omega=OMEGA, 
                                 lambda_ce=REGULARIZER_CE, 
                                 lambda_ul=REGULARIZER_UL,  
                                 scaler=scaler,
                                 hyperbolic='',
                                 manifold=None,
                                 device=device
                                 )

    # Termination check. The unlearning train subset is only needed (and only
    # defined) in the random-sample scenario.
    unlearned_acc, verdict = validate_unlearning(encoder=encoder, 
                                                 classifier=classifier, 
                                                 valloader_train=unl_valloader_train if SCENARIO == 'random-sample' else None, 
                                                 valloader_test=unl_valloader_test, 
                                                 classification_loss=CRITERION, 
                                                 class_count=len(classes),
                                                 hyperbolic='',
                                                 manifold=None,
                                                 device=device
                                                 )

    # Accuracy retained on the remaining test set, tracked in both scenarios
    remaining_acc = validate_containment(encoder=encoder, 
                                         classifier=classifier, 
                                         remaining_testloader=remaining_testloader, 
                                         classification_loss=CRITERION, 
                                         class_count=len(classes), 
                                         device=device
                                         )

    print(f"\nRemaining test set accuracy: {remaining_acc:.4f}")
    print(f"\nTermination condition reached: {verdict}")

    # Termination criteria
    if verdict: 
        print("\nTerminating unlearning procedure")
        break

    unl_scheduler.step()

### Results of unlearning

The unlearning efficiency is tested on the same subsets used by the authors for the experiments:

- *single-class* scenario --> remaining test set, unlearning train set, unlearning test set
- *random-sample* scenario --> remaining test set, unlearning train set

In [None]:
# Evaluate the unlearned (Euclidean) model on the same subsets used by the paper's authors
unlearned_average_remaining_test_loss, unlearned_best_remaining_test_acc = test(encoder=encoder, 
                                                                                classifier=classifier, 
                                                                                loader=remaining_testloader, 
                                                                                criterion=CRITERION, 
                                                                                hyperbolic='', 
                                                                                manifold=None, 
                                                                                device=device)

unlearned_average_unl_train_loss, unlearned_best_unl_train_acc = test(encoder=encoder, 
                                                                        classifier=classifier, 
                                                                        loader=unlearning_trainloader, 
                                                                        criterion=CRITERION, 
                                                                        hyperbolic='', 
                                                                        manifold=None, 
                                                                        device=device)

# The unlearning test set (held-out samples of the forgotten class) exists only for single-class
if SCENARIO == 'single-class':

    unlearned_average_unl_test_loss, unlearned_best_unl_test_acc = test(encoder=encoder, 
                                                                        classifier=classifier, 
                                                                        loader=unlearning_testloader, 
                                                                        criterion=CRITERION, 
                                                                        hyperbolic='', 
                                                                        manifold=None, 
                                                                        device=device)

print(f'\nUnlearned Remaining Test Set Loss: {unlearned_average_remaining_test_loss:.3f}, Accuracy: {unlearned_best_remaining_test_acc:.3f}%')
print(f'Unlearned Unlearning Train Set Loss: {unlearned_average_unl_train_loss:.3f}, Accuracy: {unlearned_best_unl_train_acc:.3f}%')
# A plain `if` statement instead of a conditional expression used only for its side effect
if SCENARIO == 'single-class':
    print(f'Unlearned Unlearning Test Set Loss: {unlearned_average_unl_test_loss:.3f}, Accuracy: {unlearned_best_unl_test_acc:.3f}%')

### Get predictions over batch sample of the *remaining test set*

In [None]:
# Show ground truth vs predictions on one batch of the remaining test set
dataiter = iter(remaining_testloader)
images, labels = next(dataiter)

encoder.eval()
classifier.eval()

with torch.no_grad():
    # The models live on `device`; `images` stays on the CPU for plotting
    features = encoder(images.to(device))
    outputs = classifier(features)

    _, predicted = torch.max(outputs, 1)

# Show up to 16 samples (the batch may be smaller)
num_shown = min(16, images.size(0))

# print images
imshow(torchvision.utils.make_grid(images[:num_shown]))

# Prepare rows with GroundTruth and Predicted labels (plain ints avoid per-element device syncs)
true_names = [classes[idx] for idx in labels[:num_shown].tolist()]
pred_names = [classes[idx] for idx in predicted[:num_shown].tolist()]
rows = [["GroundTruth", *true_names],
        ["Predicted", *pred_names]]

# Display the table
print(tabulate(rows, headers=[""] + [f"Image {i+1}" for i in range(num_shown)], tablefmt="grid"))

### Get predictions over batch sample of the *unlearned set*

In [None]:
# Show ground truth vs predictions on one batch of the forgotten data:
# the unlearning test set for single-class, the unlearning train set for random-sample
if SCENARIO == 'single-class':

    dataiter = iter(unlearning_testloader)

elif SCENARIO == 'random-sample':

    dataiter = iter(unlearning_trainloader)

else:
    raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

images, labels = next(dataiter)

encoder.eval()
classifier.eval()

with torch.no_grad():
    # The models live on `device`; `images` stays on the CPU for plotting
    features = encoder(images.to(device))
    outputs = classifier(features)

    _, predicted = torch.max(outputs, 1)

# Show up to 16 samples (the batch may be smaller)
num_shown = min(16, images.size(0))

# print images
imshow(torchvision.utils.make_grid(images[:num_shown]))

# Prepare rows with GroundTruth and Predicted labels (plain ints avoid per-element device syncs)
true_names = [classes[idx] for idx in labels[:num_shown].tolist()]
pred_names = [classes[idx] for idx in predicted[:num_shown].tolist()]
rows = [["GroundTruth", *true_names],
        ["Predicted", *pred_names]]

# Display the table
print(tabulate(rows, headers=[""] + [f"Image {i+1}" for i in range(num_shown)], tablefmt="grid"))

### Visualize feature space

In [None]:
# Extract encoder embeddings of the test data and plot them in 2D
if SCENARIO == 'single-class':

    # Extract features from the test set: remaining classes and the unlearned class
    remaining_features, remaining_labels = extract_features(encoder, remaining_testloader, hyperbolic='', manifold=None, device=device)
    unlearned_features, unlearned_labels = extract_features(encoder, unlearning_testloader, hyperbolic='', manifold=None, device=device)

    # Combine for the embedding plot (unlearned-class samples first)
    features = np.concatenate((unlearned_features, remaining_features), axis=0)
    labels = np.concatenate((unlearned_labels, remaining_labels), axis=0)

# In this case, there is no specific class to isolate, so the goal is for all the models (original, retrained and unlearned) to plot a similar distribution of the test data
elif SCENARIO == 'random-sample':

    features, labels = extract_features(encoder, remaining_testloader, hyperbolic='', manifold=None, device=device)

else:
    # Previously this fell through to a NameError on `features`
    raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

create_plot(features, labels, classes, dimension=2, convexhull=False)

# 2. Contrastive Unlearning Pipeline on Hyperbolic Space

In this macro-section I will repeat the whole setup, but using **hyperbolic spaces**.

The library used for experimenting with hyperbolic spaces is [***Hypll***](https://github.com/maxvanspengler/hyperbolic_learning_library)

The first one is mainly used for building the hyperbolic architecture, while the second one for plotting

**WARNING**: to run this second chapter, you NEED to have the *Chapter 1* executed, because the data and most of the functions used here are defined in previous cells

## Hyperbolic original model

Here the most logical thing to do is to take the original architecture and make it hyperbolic.

However, hyperbolic operations are exceptionally **slow** and **heavy**, so I needed to figure out something else...

I ended up with 3 different architectures. You will find an explanation for each one of them below.

### Complete hyperbolic version

This is the original architecture (defined in the *Original model* section of chapter 1) made completely hyperbolic.

This means that every single operation is defined to work in hyperbolic spaces.

This is the slowest and heaviest of the 3 architectures.

On my computer I had to decrease the batch size from 128 to 32, and it takes around 1 hour to run a single epoch.

In [None]:
class HyperbolicBasicBlock(nn.Module):
    """
    Two-convolution residual block in which every layer operates on `manifold` (hypll).

    Main path: HConv3x3(stride) -> HBatchNorm -> HReLU -> HConv3x3 -> HBatchNorm.
    The shortcut is the identity, or a 1x1 HConv + HBatchNorm projection when the stride
    or the channel count changes. The two paths are merged with Möbius addition,
    followed by a final HReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (times `expansion`, which is 1 here).
        manifold: Hyperbolic manifold shared by all layers.
        stride: Stride of the first convolution (and of the projection shortcut).
    """
    expansion = 1

    def __init__(self, in_planes, planes, manifold, stride=1):
        super().__init__()
        # Registered first so the submodule order matches existing checkpoints
        self.manifold = manifold

        self.conv1 = hnn.HConvolution2d(in_planes, planes, kernel_size=3, stride=stride,
                                        padding=1, bias=False, manifold=manifold)
        self.bn1 = hnn.HBatchNorm2d(planes, manifold=manifold)
        self.relu = hnn.HReLU(manifold=manifold)
        self.conv2 = hnn.HConvolution2d(planes, planes, kernel_size=3, stride=1,
                                        padding=1, bias=False, manifold=manifold)
        self.bn2 = hnn.HBatchNorm2d(planes, manifold=manifold)

        out_planes = self.expansion * planes
        needs_projection = stride != 1 or in_planes != out_planes
        # An empty Sequential acts as the identity shortcut
        self.shortcut = (
            nn.Sequential(
                hnn.HConvolution2d(in_planes, out_planes, kernel_size=1,
                                   stride=stride, bias=False, manifold=manifold),
                hnn.HBatchNorm2d(features=out_planes, manifold=manifold),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Residual connection via Möbius addition (hyperbolic counterpart of "+")
        merged = self.manifold.mobius_add(main, self.shortcut(x))
        return self.relu(merged)
    


class HyperbolicBottleneck(nn.Module):
    """
    Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions) operating on `manifold` (hypll).

    The output has `expansion * planes` channels. The shortcut is the identity, or a 1x1
    HConv + HBatchNorm projection when the stride or the channel count changes. The paths
    are merged with Möbius addition, followed by a final HReLU.

    The arguments are ordered (in_planes, planes, manifold, stride), matching HyperbolicBasicBlock
    and the `block(self.in_planes, planes, self.manifold, stride)` call in
    HyperbolicResNetEncoder._make_layer. The previous (manifold, in_planes, planes, stride)
    order made HyperbolicResNet50/101/152 pass the arguments in the wrong slots.

    Args:
        in_planes: Number of input channels.
        planes: Bottleneck width; the output has `expansion * planes` channels.
        manifold: Hyperbolic manifold shared by all layers.
        stride: Stride of the 3x3 convolution (and of the projection shortcut).
    """
    expansion = 4

    def __init__(self, in_planes, planes, manifold, stride=1):
        super(HyperbolicBottleneck, self).__init__()
        self.manifold = manifold
        self.conv1 = hnn.HConvolution2d(in_planes, planes, kernel_size=1, 
                                        bias=False, manifold=manifold)
        self.bn1 = hnn.HBatchNorm2d(planes, manifold)
        self.relu = hnn.HReLU(manifold=manifold)
        self.conv2 = hnn.HConvolution2d(planes, planes, kernel_size=3, stride=stride, 
                                        padding=1, bias=False, manifold=manifold)
        self.bn2 = hnn.HBatchNorm2d(planes, manifold)
        self.conv3 = hnn.HConvolution2d(planes, self.expansion * planes, kernel_size=1, 
                                        bias=False, manifold=manifold)
        self.bn3 = hnn.HBatchNorm2d(self.expansion * planes, manifold=manifold)

        # Shortcut connection: identity unless the shape changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                hnn.HConvolution2d(in_planes, self.expansion * planes, kernel_size=1, 
                                   stride=stride, bias=False, manifold=manifold),
                hnn.HBatchNorm2d(self.expansion * planes, manifold=manifold)
            )

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Apply Möbius addition for the residual connection
        out = self.manifold.mobius_add(out, self.shortcut(x))
        out = self.relu(out)
        return out



class HyperbolicResNetEncoder(nn.Module):
    """
    ResNet feature extractor in which every layer operates on `manifold` (hypll).

    Stem (3x3 stride-1 HConv, HBatchNorm, HReLU, no max-pool), then four residual stages
    with 64/128/256/512 base planes (stages 2-4 use stride 2), then a 4x4 hyperbolic
    average pool and a flatten.

    Args:
        block: Residual block class (HyperbolicBasicBlock or HyperbolicBottleneck). It must expose
            `expansion` and accept (in_planes, planes, manifold, stride).
        num_blocks: Number of blocks in each of the four stages.
        manifold: Hyperbolic manifold shared by all layers.
        num_classes: Not used by the encoder (the classification head owns the output size).
    """

    def __init__(self, block, num_blocks, manifold, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.manifold = manifold

        self.conv1 = hnn.HConvolution2d(3, 64, kernel_size=3, stride=1,
                                        padding=1, bias=False, manifold=manifold)
        self.bn1 = hnn.HBatchNorm2d(64, manifold=manifold)
        self.relu = hnn.HReLU(manifold=manifold)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Registered but not applied in forward (kept so the module structure is unchanged)
        self.dropout = nn.Dropout(p=0.3)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage uses `stride`; the rest keep the resolution
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, self.manifold, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        out = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # Pooling/flatten are built per call. Registering them as submodules would add extra
        # manifold entries to the state_dict and break loading of existing checkpoints.
        pooled = hnn.HAvgPool2d(kernel_size=4, manifold=self.manifold)(out)
        return hnn.HFlatten()(pooled)



class HyperbolicClassificationHead(nn.Module):
    """Hyperbolic linear classifier used by the fully hyperbolic ResNet.

    A single ``hnn.HLinear`` layer maps ``input_dim`` features on ``manifold``
    to ``num_classes`` outputs. The input is used as-is (no dropout, no extra
    exponential map).

    NOTE: the hybrid-architecture cell later in the notebook defines a
    different class under this same name.
    """

    def __init__(self, input_dim, manifold, num_classes=10):
        super().__init__()
        self.fc = hnn.HLinear(input_dim, num_classes, manifold=manifold)

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        out = self.fc(x)
        return out



# The hybrid-architecture cell further down redefines ``HyperbolicClassificationHead``
# with different behaviour (dropout plus a second expmap). The factories below used to
# look the name up at call time, so after running the notebook in order they silently
# built complete-ResNet models with the hybrid head. Bind the class defined in this
# cell instead.
_CompleteHyperbolicClassificationHead = HyperbolicClassificationHead


def _build_hyperbolic_resnet(block, num_blocks, embedding_dim, manifold):
    """Return an ``(encoder, classifier)`` pair for a fully hyperbolic ResNet.

    Args:
        block: Residual block class used in every stage.
        num_blocks: Number of blocks in each of the four stages.
        embedding_dim: Size of the flattened encoder output (512 for basic
            blocks, 2048 for bottleneck blocks, as used by the factories below).
        manifold: Manifold shared by encoder and classifier.
    """
    encoder = HyperbolicResNetEncoder(block, num_blocks, manifold)
    classification = _CompleteHyperbolicClassificationHead(embedding_dim, manifold)
    return encoder, classification


def HyperbolicResNet18(manifold):
    """Fully hyperbolic ResNet-18: returns ``(encoder, classifier)``."""
    return _build_hyperbolic_resnet(HyperbolicBasicBlock, [2, 2, 2, 2], 512, manifold)


def HyperbolicResNet34(manifold):
    """Fully hyperbolic ResNet-34: returns ``(encoder, classifier)``."""
    return _build_hyperbolic_resnet(HyperbolicBasicBlock, [3, 4, 6, 3], 512, manifold)


def HyperbolicResNet50(manifold):
    """Fully hyperbolic ResNet-50: returns ``(encoder, classifier)``."""
    return _build_hyperbolic_resnet(HyperbolicBottleneck, [3, 4, 6, 3], 2048, manifold)


def HyperbolicResNet101(manifold):
    """Fully hyperbolic ResNet-101: returns ``(encoder, classifier)``."""
    return _build_hyperbolic_resnet(HyperbolicBottleneck, [3, 4, 23, 3], 2048, manifold)


def HyperbolicResNet152(manifold):
    """Fully hyperbolic ResNet-152: returns ``(encoder, classifier)``."""
    return _build_hyperbolic_resnet(HyperbolicBottleneck, [3, 8, 36, 3], 2048, manifold)

### Hybrid hyperbolic (final embedding layer + classifier)

This is the original architecture (defined in the *Original model* section of chapter 1) with hyperbolic mapping for the final embeddings.

This means the features are computed with standard Euclidean operations and are projected into a hyperbolic space (the *Poincaré ball*) only at the end.

Also the classification head is hyperbolic.

This is the fastest of the three architectures, and the one I used for testing.

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style) with dropout.

    Residual branch: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN -> Dropout(0.3).
    The shortcut is the identity, or a 1x1 conv + BN projection whenever the
    stride or channel count changes. The sum goes through a final ReLU.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.dropout = nn.Dropout(p=0.3)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.dropout(self.bn2(self.conv2(residual)))
        return F.relu(residual + self.shortcut(x))



class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (ResNet-50/101/152 style) with dropout.

    The block outputs ``expansion * planes`` channels. The shortcut is the
    identity, or a 1x1 conv + BN projection whenever the stride or channel count
    changes. Dropout (p=0.3) is applied to the residual branch before the addition.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.dropout = nn.Dropout(p=0.3)

        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = F.relu(self.bn2(self.conv2(residual)))
        residual = self.dropout(self.bn3(self.conv3(residual)))
        return F.relu(residual + self.shortcut(x))



class ResNetEncoder(nn.Module):
    """Euclidean ResNet backbone whose pooled embedding is mapped onto ``manifold``.

    Convolutions, batch norms and residual stages are ordinary PyTorch layers.
    Only the final flattened (and dropout-regularised) embedding is treated as a
    tangent vector (``man_dim=1``) and sent onto the manifold with ``expmap``.

    Args:
        block: Residual block class (``BasicBlock`` or ``Bottleneck``).
        num_blocks: Number of blocks in each of the four stages.
        manifold: Manifold the embeddings are mapped onto.
        num_classes: Not used by the encoder; kept for signature compatibility.
    """

    def __init__(self, block, num_blocks, manifold, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.manifold = manifold

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.dropout = nn.Dropout(p=0.3)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in (stride, *([1] * (num_blocks - 1))):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # Kernel 4 presumably reduces the 4x4 map of a 32x32 input to 1x1 — TODO confirm.
        features = F.avg_pool2d(features, 4)
        features = self.dropout(features.view(features.size(0), -1))

        # Interpret the Euclidean embedding as a tangent vector and map it onto the Poincaré ball.
        tangent = TangentTensor(data=features, man_dim=1, manifold=self.manifold)
        return self.manifold.expmap(tangent)
    

class HyperbolicClassificationHead(nn.Module):
    """Hyperbolic linear classifier for the hybrid ResNet.

    Applies dropout (p=0.5) to the raw coordinates of the incoming
    ``ManifoldTensor``, maps the result onto ``manifold`` with ``expmap``
    (treating it as a tangent vector, ``man_dim=1``) and classifies it with
    ``hnn.HLinear``.

    NOTE(review): ``ResNetEncoder`` already returns ``expmap`` of its embedding,
    so here the ball coordinates are treated as a tangent vector and mapped a
    second time. If the intent was to undo the first map, ``logmap`` would be
    the inverse — changing it would alter the behaviour of trained checkpoints.
    """

    def __init__(self, input_dim, manifold, num_classes=10):
        super().__init__()
        self.manifold = manifold
        self.dropout = nn.Dropout(p=0.5)
        self.fc = hnn.HLinear(input_dim, num_classes, manifold=manifold)

    def forward(self, x):
        dropped = self.dropout(x.tensor)
        tangent = TangentTensor(data=dropped, man_dim=1, manifold=self.manifold)
        on_ball = self.manifold.expmap(tangent)
        return self.fc(on_ball)



# Bind the hybrid head defined in this cell. The complete-ResNet cell defines a different
# class under the same name ``HyperbolicClassificationHead``; looking the name up at call
# time meant re-running that cell silently swapped the head used by these factories.
_HybridHyperbolicClassificationHead = HyperbolicClassificationHead


def _build_hybrid_hyperbolic_resnet(block, num_blocks, embedding_dim, manifold):
    """Return an ``(encoder, classifier)`` pair for a hybrid hyperbolic ResNet.

    Args:
        block: Residual block class used in every stage.
        num_blocks: Number of blocks in each of the four stages.
        embedding_dim: Size of the flattened encoder output (512 for
            ``BasicBlock``, 2048 for ``Bottleneck``, as used below).
        manifold: Manifold the embeddings are mapped onto.

    The classifier's output size is ``len(classes)``, read when the factory is called.
    """
    encoder = ResNetEncoder(block, num_blocks, manifold)
    classification = _HybridHyperbolicClassificationHead(embedding_dim, manifold, num_classes=len(classes))
    return encoder, classification


def HybridHyperbolicResNet18(manifold):
    """Hybrid hyperbolic ResNet-18: returns ``(encoder, classifier)``."""
    return _build_hybrid_hyperbolic_resnet(BasicBlock, [2, 2, 2, 2], 512, manifold)


def HybridHyperbolicResNet34(manifold):
    """Hybrid hyperbolic ResNet-34: returns ``(encoder, classifier)``."""
    return _build_hybrid_hyperbolic_resnet(BasicBlock, [3, 4, 6, 3], 512, manifold)


def HybridHyperbolicResNet50(manifold):
    """Hybrid hyperbolic ResNet-50: returns ``(encoder, classifier)``."""
    return _build_hybrid_hyperbolic_resnet(Bottleneck, [3, 4, 6, 3], 2048, manifold)


def HybridHyperbolicResNet101(manifold):
    """Hybrid hyperbolic ResNet-101: returns ``(encoder, classifier)``."""
    return _build_hybrid_hyperbolic_resnet(Bottleneck, [3, 4, 23, 3], 2048, manifold)


def HybridHyperbolicResNet152(manifold):
    """Hybrid hyperbolic ResNet-152: returns ``(encoder, classifier)``."""
    return _build_hybrid_hyperbolic_resnet(Bottleneck, [3, 8, 36, 3], 2048, manifold)

### Suggested hypll model

This is the Hypll library implementation of a hyperbolic classifier.

However, this is not a random classifier. 

It is based on the [*Poincaré ResNet paper*](https://arxiv.org/abs/2303.14027) which, in turn, is based on the original Euclidean implementation described in the paper [*Deep Residual Learning for Image Recognition*](https://arxiv.org/abs/1512.03385).

It was already built for the CIFAR-10 dataset, so I simply took it and split it into an encoder and a classifier.

This model is reasonably fast, but still nowhere near as fast as the hybrid version.

In [None]:
class PoincareResidualBlock(nn.Module):
    """Residual block of the hypll Poincaré ResNet.

    Residual branch: HConv3x3(stride) -> HBatchNorm -> HReLU -> HConv3x3 -> HBatchNorm.
    The input (passed through ``downsample`` when given) is combined with the
    branch by Möbius addition, followed by HReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        manifold: Poincaré ball shared by all layers.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before the Möbius
            addition (``PoincareEncoder`` passes an ``hnn.HConvolution2d``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        manifold: PoincareBall,
        stride: int = 1,
        downsample: Optional[nn.Sequential] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.manifold = manifold
        self.stride = stride
        self.downsample = downsample

        # Options shared by both 3x3 convolutions.
        conv_options = dict(kernel_size=3, manifold=manifold, padding=1)
        self.conv1 = hnn.HConvolution2d(
            in_channels=in_channels, out_channels=out_channels, stride=stride, **conv_options
        )
        self.bn1 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)
        self.relu = hnn.HReLU(manifold=self.manifold)
        self.conv2 = hnn.HConvolution2d(
            in_channels=out_channels, out_channels=out_channels, **conv_options
        )
        self.bn2 = hnn.HBatchNorm2d(features=out_channels, manifold=manifold)

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        branch = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        skip = x if self.downsample is None else self.downsample(x)
        return self.relu(self.manifold.mobius_add(branch, skip))

class PoincareEncoder(nn.Module):
    """Convolutional trunk of the hypll Poincaré ResNet (everything before pooling).

    Stem (HConv3x3 -> HBatchNorm -> HReLU) followed by three groups of
    ``PoincareResidualBlock``. Groups 2 and 3 halve the spatial resolution.

    Args:
        channel_sizes: Output channels of the three groups (the stem uses the first).
        group_depths: Number of residual blocks in each group.
        manifold: Poincaré ball shared by all layers.
    """

    def __init__(
        self,
        channel_sizes: list[int],
        group_depths: list[int],
        manifold: PoincareBall,
    ):
        super().__init__()
        self.channel_sizes = channel_sizes
        self.group_depths = group_depths
        self.manifold = manifold

        self.conv = hnn.HConvolution2d(
            in_channels=3,
            out_channels=channel_sizes[0],
            kernel_size=3,
            manifold=manifold,
            padding=1,
        )
        self.bn = hnn.HBatchNorm2d(features=channel_sizes[0], manifold=manifold)
        self.relu = hnn.HReLU(manifold=manifold)
        self.group1 = self._make_group(
            in_channels=channel_sizes[0],
            out_channels=channel_sizes[0],
            depth=group_depths[0],
        )
        self.group2 = self._make_group(
            in_channels=channel_sizes[0],
            out_channels=channel_sizes[1],
            depth=group_depths[1],
            stride=2,
        )
        self.group3 = self._make_group(
            in_channels=channel_sizes[1],
            out_channels=channel_sizes[2],
            depth=group_depths[2],
            stride=2,
        )

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        x = self.relu(self.bn(self.conv(x)))
        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        return x

    def _make_group(
        self,
        in_channels: int,
        out_channels: int,
        depth: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one group: a (possibly strided) first block plus ``depth - 1`` more.

        The residual of the first block is projected by a 1x1 hyperbolic
        convolution whenever the block changes the resolution OR the channel
        count. Previously only ``stride != 1`` triggered the projection, so a
        stride-1 group that changes channels would have passed mismatched
        shapes to ``mobius_add``. The shipped configuration (group 1 keeps its
        channels, groups 2-3 use stride 2) builds exactly the same layers as before.
        """
        if stride == 1 and in_channels == out_channels:
            downsample = None
        else:
            downsample = hnn.HConvolution2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                manifold=self.manifold,
                stride=stride,
            )

        layers = [
            PoincareResidualBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                manifold=self.manifold,
                stride=stride,
                downsample=downsample,
            )
        ]
        layers.extend(
            PoincareResidualBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                manifold=self.manifold,
            )
            for _ in range(1, depth)
        )
        return nn.Sequential(*layers)

class PoincareClassifier(nn.Module):
    """Hyperbolic pooling + linear head of the hypll Poincaré ResNet.

    Averages the encoder feature map with an 8x8 hyperbolic average pool,
    flattens it to ``(batch, in_features)`` and applies ``hnn.HLinear``.

    Args:
        in_features: Channels of the encoder output (the last ``channel_sizes`` entry).
        num_classes: Number of output classes.
        manifold: Poincaré ball shared with the encoder.
    """

    def __init__(self, in_features: int, num_classes: int, manifold: PoincareBall):
        super().__init__()
        self.manifold = manifold
        # kernel_size=8 presumably reduces the 8x8 map left by two stride-2 groups
        # on a 32x32 input to 1x1 — TODO confirm for other input sizes.
        self.avg_pool = hnn.HAvgPool2d(kernel_size=8, manifold=manifold)
        self.fc = hnn.HLinear(in_features=in_features, out_features=num_classes, manifold=manifold)

    def forward(self, x: ManifoldTensor) -> ManifoldTensor:
        x = self.avg_pool(x)
        # Flatten from dim 1 instead of x.squeeze(): squeeze also removed the batch
        # dimension when a batch held a single sample, so callers expecting
        # (batch, num_classes) outputs (e.g. torch.max(outputs.tensor, 1)) would fail.
        # Same HAvgPool2d -> HFlatten pattern as HyperbolicResNetEncoder.
        x = hnn.HFlatten()(x)
        return self.fc(x)
    
    

def AcademicHyperbolicResNet(manifold):
    """Build the hypll Poincaré ResNet split into ``(encoder, classifier)``.

    Three groups of three residual blocks with 4, 8 and 16 channels. The
    classifier reads the last group's channels and outputs 10 classes.
    """
    channels = [4, 8, 16]
    encoder = PoincareEncoder(channel_sizes=channels, group_depths=[3, 3, 3], manifold=manifold)
    classifier = PoincareClassifier(in_features=channels[-1], num_classes=10, manifold=manifold)
    return encoder, classifier
    

### Choose the architecture

In [None]:
# Build the hyperbolic (encoder, classifier) pair selected by HYPBL_ARCHITECTURE and MODEL_NAME.
# Unknown values now raise instead of silently leaving the models undefined (or stale
# from a previous run).
if HYPBL_ARCHITECTURE == 'academic-ResNet':
    # Fixed-size model: MODEL_NAME does not apply.
    hypbl_encoder, hypbl_classifier = AcademicHyperbolicResNet(manifold=MANIFOLD)
else:
    if HYPBL_ARCHITECTURE == 'complete-ResNet':
        hypbl_builders = {
            'ResNet18': HyperbolicResNet18,
            'ResNet34': HyperbolicResNet34,
            'ResNet50': HyperbolicResNet50,
            'ResNet101': HyperbolicResNet101,
            'ResNet152': HyperbolicResNet152,
        }
    elif HYPBL_ARCHITECTURE == 'hybrid-ResNet':
        hypbl_builders = {
            'ResNet18': HybridHyperbolicResNet18,
            'ResNet34': HybridHyperbolicResNet34,
            'ResNet50': HybridHyperbolicResNet50,
            'ResNet101': HybridHyperbolicResNet101,
            'ResNet152': HybridHyperbolicResNet152,
        }
    else:
        raise ValueError(f'Unknown HYPBL_ARCHITECTURE: {HYPBL_ARCHITECTURE!r}')

    if MODEL_NAME not in hypbl_builders:
        raise ValueError(f'Unknown MODEL_NAME {MODEL_NAME!r} for {HYPBL_ARCHITECTURE!r}; '
                         f'expected one of {sorted(hypbl_builders)}')
    hypbl_encoder, hypbl_classifier = hypbl_builders[MODEL_NAME](manifold=MANIFOLD)


if device == 'cuda':
    MANIFOLD = MANIFOLD.to(device)
    hypbl_encoder = torch.nn.DataParallel(hypbl_encoder).to(device)
    hypbl_classifier = torch.nn.DataParallel(hypbl_classifier).to(device)

### Optimizer and scheduler definition

Even in hyperbolic spaces, using the *Adam* optimizer is not advised.

In [None]:
# Riemannian optimizers over the encoder + classifier parameters, using the tuned
# ("best configuration") hyperparameters for each optimizer.
hypbl_params = list(hypbl_encoder.parameters()) + list(hypbl_classifier.parameters())

if OPTIMIZER_NAME == 'adam':
    # BEST CONFIGURATION ADAM
    hypbl_optimizer = RiemannianAdam(hypbl_params, lr=0.000641388, weight_decay=0.567643)
elif OPTIMIZER_NAME == 'sgd':
    # BEST CONFIGURATION SGD
    hypbl_optimizer = RiemannianSGD(hypbl_params, lr=0.015254, momentum=0.582486, weight_decay=0.00366885)
else:
    # Previously an unknown name fell through to a NameError (or reused a stale optimizer).
    raise ValueError(f"Unknown OPTIMIZER_NAME: {OPTIMIZER_NAME!r} (expected 'adam' or 'sgd')")

# NOTE(review): T_max is fixed at 200 rather than tied to NUM_EPOCHS; the cosine
# schedule only reaches its minimum if training runs for 200 epochs.
hypbl_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(hypbl_optimizer, T_max=200)

### Start training

In [None]:
# Per-epoch histories (lists) consumed by plot_training_history.
hypbl_train_loss_history = []
hypbl_val_loss_history = []
hypbl_train_acc_history = []
hypbl_val_acc_history = []

hypbl_val_acc = 0
# Best validation accuracy seen so far, passed to validate() as the checkpointing
# threshold. Previously best_acc received the *previous epoch's* accuracy. If validate()
# saves whenever the current accuracy beats best_acc (presumably), that let a worse epoch
# overwrite a better checkpoint.
# NOTE(review): if validate() already returns the running best, the max() below is a no-op.
hypbl_best_val_acc = 0

scaler = torch.amp.grad_scaler.GradScaler()

early_stopper = EarlyStopper()

for epoch in range(NUM_EPOCHS):

    print('\nEpoch: %d' % epoch)

    hypbl_average_train_loss, hypbl_train_acc = train(encoder=hypbl_encoder,
                                                      classifier=hypbl_classifier,
                                                      loader=trainloader,
                                                      criterion=CRITERION,
                                                      optimizer=hypbl_optimizer,
                                                      scaler=scaler,
                                                      hyperbolic=HYPBL_ARCHITECTURE,
                                                      manifold=MANIFOLD,
                                                      device=device)

    hypbl_train_loss_history.append(hypbl_average_train_loss)
    hypbl_train_acc_history.append(hypbl_train_acc)

    hypbl_average_val_loss, hypbl_val_acc = validate(epoch=epoch,
                                                     encoder=hypbl_encoder,
                                                     classifier=hypbl_classifier,
                                                     loader=valloader,
                                                     criterion=CRITERION,
                                                     best_acc=hypbl_best_val_acc,
                                                     hyperbolic=HYPBL_ARCHITECTURE,
                                                     manifold=MANIFOLD,
                                                     ckpt_name=HYPBL_ORIGINAL_CKPT_NAME,
                                                     device=device)
    hypbl_best_val_acc = max(hypbl_best_val_acc, hypbl_val_acc)

    hypbl_val_loss_history.append(hypbl_average_val_loss)
    hypbl_val_acc_history.append(hypbl_val_acc)

    if early_stopper.early_stop(hypbl_average_val_loss):
        print('\nEarly stopper activated')
        break

    hypbl_scheduler.step()

### Plot loss and accuracy over the epochs

In [None]:
# Plot the per-epoch training/validation loss and accuracy recorded by the training loop.
plot_training_history(hypbl_train_loss_history, hypbl_val_loss_history, hypbl_train_acc_history, hypbl_val_acc_history)

### Test the model

Load best checkpoint.

It is needed since there is no guarantee that the latest epochs produced the most accurate parameters.

Also comes in handy for quick testing

In [None]:
# Restore the best saved original hyperbolic model and report its validation metrics.
print('Resuming from checkpoint...')

assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
# map_location lets a checkpoint written on a GPU machine be restored on the current device.
checkpoint = torch.load('./checkpoint/{}'.format(HYPBL_ORIGINAL_CKPT_NAME), map_location=device)
hypbl_encoder.load_state_dict(checkpoint['encoder'])
hypbl_classifier.load_state_dict(checkpoint['classifier'])

print('Checkpoint loaded!')
# Double-quoted f-strings: reusing the outer quote character inside the braces is a
# SyntaxError before Python 3.12 (PEP 701).
print(f"\nValidation average loss: {checkpoint['loss']:.3f}")
print(f"Validation set accuracy: {checkpoint['acc']:.3f}%")
print(f"From epoch: {checkpoint['epoch']}")

Check the model's predictions on a few test samples

In [None]:
# Run the model on the first test batch and tabulate 16 predictions against the labels.
dataiter = iter(testloader)
images, labels = next(dataiter)

# Evaluation mode (disables the dropout layers of the model).
hypbl_encoder.eval()
hypbl_classifier.eval()

with torch.no_grad():
    features = hypbl_encoder(images)
    outputs = hypbl_classifier(features)
    # outputs is a ManifoldTensor: take the arg-max over its raw coordinates.
    predicted = torch.max(outputs.tensor, 1).indices

# Show the first 16 images of the batch.
imshow(torchvision.utils.make_grid(images[0:16]))

rows = [
    ["GroundTruth"] + [classes[labels[j]] for j in range(16)],
    ["Predicted"] + [classes[predicted[j]] for j in range(16)],
]
print(tabulate(rows, headers=[""] + ["Image {}".format(j + 1) for j in range(16)], tablefmt="grid"))

Test the model overall performance

In [None]:
# Evaluate the original hyperbolic model on the full test set.
# test() gets the same architecture flag as train()/validate(); it was previously
# hard-coded to 'hybrid-ResNet', which mismatched the complete and academic architectures.
hypbl_test_loss, hypbl_test_acc = test(encoder=hypbl_encoder,
                                       classifier=hypbl_classifier,
                                       loader=testloader,
                                       criterion=CRITERION,
                                       hyperbolic=HYPBL_ARCHITECTURE,
                                       manifold=MANIFOLD,
                                       device=device)

print(f'\nTest Set - Loss: {hypbl_test_loss:.3f}, Accuracy: {hypbl_test_acc:.3f}%')

### Visualize feature space

You can choose which class samples to display by toggling the class names.

**WARNING**: As mentioned in the *Configuration* section, Plotly has some problems with Visual Studio Code. To prevent strange plot behaviour, clear the cell's plot output before re-running the cell.

In [None]:
# Extract features from the test set, and plot them
features, labels = extract_features(hypbl_encoder, testloader, hyperbolic=HYPBL_ARCHITECTURE, manifold=MANIFOLD, device=device)
create_plot(features, labels, classes, dimension=2, convexhull=False)

## Hyperbolic retrain model

In this section, another instantiation of the model will be trained without the samples that we want to forget (from a single class or with random sampling).

Most of the functions are reused from the previous chapters, so refer to them if needed

### Choose the model

In [None]:
# Build a fresh hyperbolic (encoder, classifier) pair of the same architecture for retraining.
# Unknown values now raise instead of silently leaving the models undefined or stale.
if HYPBL_ARCHITECTURE == 'academic-ResNet':
    # Fixed-size model: MODEL_NAME does not apply.
    retrain_hypbl_encoder, retrain_hypbl_classifier = AcademicHyperbolicResNet(manifold=MANIFOLD)
else:
    if HYPBL_ARCHITECTURE == 'complete-ResNet':
        retrain_hypbl_builders = {
            'ResNet18': HyperbolicResNet18,
            'ResNet34': HyperbolicResNet34,
            'ResNet50': HyperbolicResNet50,
            'ResNet101': HyperbolicResNet101,
            'ResNet152': HyperbolicResNet152,
        }
    elif HYPBL_ARCHITECTURE == 'hybrid-ResNet':
        retrain_hypbl_builders = {
            'ResNet18': HybridHyperbolicResNet18,
            'ResNet34': HybridHyperbolicResNet34,
            'ResNet50': HybridHyperbolicResNet50,
            'ResNet101': HybridHyperbolicResNet101,
            'ResNet152': HybridHyperbolicResNet152,
        }
    else:
        raise ValueError(f'Unknown HYPBL_ARCHITECTURE: {HYPBL_ARCHITECTURE!r}')

    if MODEL_NAME not in retrain_hypbl_builders:
        raise ValueError(f'Unknown MODEL_NAME {MODEL_NAME!r} for {HYPBL_ARCHITECTURE!r}; '
                         f'expected one of {sorted(retrain_hypbl_builders)}')
    retrain_hypbl_encoder, retrain_hypbl_classifier = retrain_hypbl_builders[MODEL_NAME](manifold=MANIFOLD)


if device == 'cuda':
    # Reassign MANIFOLD itself, as the original-model cell does. This cell used to bind
    # the result to a new lowercase ``manifold`` name instead.
    MANIFOLD = MANIFOLD.to(device)
    # NOTE(review): lowercase alias kept in case later cells read ``manifold``.
    manifold = MANIFOLD
    retrain_hypbl_encoder = torch.nn.DataParallel(retrain_hypbl_encoder).to(device)
    retrain_hypbl_classifier = torch.nn.DataParallel(retrain_hypbl_classifier).to(device)

### Optimizer and scheduler configuration

In [None]:
# Riemannian optimizers for the retrained model, with the same tuned hyperparameters
# as the original model.
retrain_hypbl_params = list(retrain_hypbl_encoder.parameters()) + list(retrain_hypbl_classifier.parameters())

if OPTIMIZER_NAME == 'adam':
    # BEST CONFIGURATION ADAM
    retrain_hypbl_optimizer = RiemannianAdam(retrain_hypbl_params, lr=0.000641388, weight_decay=0.567643)
elif OPTIMIZER_NAME == 'sgd':
    # BEST CONFIGURATION SGD
    retrain_hypbl_optimizer = RiemannianSGD(retrain_hypbl_params, lr=0.015254, momentum=0.582486, weight_decay=0.00366885)
else:
    # Previously an unknown name fell through to a NameError (or reused a stale optimizer).
    raise ValueError(f"Unknown OPTIMIZER_NAME: {OPTIMIZER_NAME!r} (expected 'adam' or 'sgd')")

# NOTE(review): T_max is fixed at 200 rather than tied to NUM_EPOCHS.
retrain_hypbl_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(retrain_hypbl_optimizer, T_max=200)

### Start processing

In [None]:
# Per-epoch histories (lists) consumed by plot_training_history.
retrain_hypbl_train_loss_history = []
retrain_hypbl_val_loss_history = []
retrain_hypbl_train_acc_history = []
retrain_hypbl_val_acc_history = []

retrain_hypbl_val_acc = 0
# Best validation accuracy seen so far, passed to validate() as the checkpointing
# threshold. Previously best_acc received the *previous epoch's* accuracy. If validate()
# saves whenever the current accuracy beats best_acc (presumably), that let a worse epoch
# overwrite a better checkpoint.
# NOTE(review): if validate() already returns the running best, the max() below is a no-op.
retrain_hypbl_best_val_acc = 0

scaler = torch.amp.grad_scaler.GradScaler()

early_stopper = EarlyStopper()

for epoch in range(NUM_EPOCHS):

    print('\nEpoch: %d' % epoch)

    # Train only on the retained (non-forgotten) samples.
    retrain_hypbl_average_train_loss, retrain_hypbl_train_acc = train(encoder=retrain_hypbl_encoder,
                                                                      classifier=retrain_hypbl_classifier,
                                                                      loader=remaining_trainloader,
                                                                      criterion=CRITERION,
                                                                      optimizer=retrain_hypbl_optimizer,
                                                                      scaler=scaler,
                                                                      hyperbolic=HYPBL_ARCHITECTURE,
                                                                      manifold=MANIFOLD,
                                                                      device=device)

    retrain_hypbl_train_loss_history.append(retrain_hypbl_average_train_loss)
    retrain_hypbl_train_acc_history.append(retrain_hypbl_train_acc)

    retrain_hypbl_average_val_loss, retrain_hypbl_val_acc = validate(epoch=epoch,
                                                                     encoder=retrain_hypbl_encoder,
                                                                     classifier=retrain_hypbl_classifier,
                                                                     loader=remaining_valloader,
                                                                     criterion=CRITERION,
                                                                     best_acc=retrain_hypbl_best_val_acc,
                                                                     hyperbolic=HYPBL_ARCHITECTURE,
                                                                     manifold=MANIFOLD,
                                                                     ckpt_name=HYPBL_RETRAIN_CKPT_NAME,
                                                                     device=device)
    retrain_hypbl_best_val_acc = max(retrain_hypbl_best_val_acc, retrain_hypbl_val_acc)

    retrain_hypbl_val_loss_history.append(retrain_hypbl_average_val_loss)
    retrain_hypbl_val_acc_history.append(retrain_hypbl_val_acc)

    if early_stopper.early_stop(retrain_hypbl_average_val_loss):
        print('\nEarly stopper activated')
        break

    retrain_hypbl_scheduler.step()

### Plot the loss

In [None]:
# Plot the per-epoch training/validation loss and accuracy of the retrained model.
plot_training_history(retrain_hypbl_train_loss_history, retrain_hypbl_val_loss_history, retrain_hypbl_train_acc_history, retrain_hypbl_val_acc_history)

### Test the model

Load best checkpoint.

It is needed since there is no guarantee that the latest epochs produced the most accurate parameters.

Also comes in handy for quick testing

In [None]:
# Restore the best saved retrained hyperbolic model and report its validation metrics.
print('Resuming from checkpoint...')

assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
# map_location lets a checkpoint written on a GPU machine be restored on the current device.
checkpoint = torch.load('./checkpoint/{}'.format(HYPBL_RETRAIN_CKPT_NAME), map_location=device)
retrain_hypbl_encoder.load_state_dict(checkpoint['encoder'])
retrain_hypbl_classifier.load_state_dict(checkpoint['classifier'])

print('Checkpoint loaded!')
# Double-quoted f-strings: reusing the outer quote character inside the braces is a
# SyntaxError before Python 3.12 (PEP 701).
print(f"\nValidation average loss: {checkpoint['loss']:.3f}")
print(f"Validation set accuracy: {checkpoint['acc']:.3f}%")
print(f"From epoch: {checkpoint['epoch']}")

Check the model's predictions on some samples of the *remaining test set*

In [None]:
# Run the retrained model on the first batch of the remaining test set and
# tabulate 16 predictions against the labels.
dataiter = iter(remaining_testloader)
images, labels = next(dataiter)

# Evaluation mode (disables the dropout layers of the model).
retrain_hypbl_encoder.eval()
retrain_hypbl_classifier.eval()

with torch.no_grad():
    features = retrain_hypbl_encoder(images)
    outputs = retrain_hypbl_classifier(features)
    # outputs is a ManifoldTensor: take the arg-max over its raw coordinates.
    predicted = torch.max(outputs.tensor, 1).indices

# Show the first 16 images of the batch.
imshow(torchvision.utils.make_grid(images[0:16]))

rows = [
    ["GroundTruth"] + [classes[labels[j]] for j in range(16)],
    ["Predicted"] + [classes[predicted[j]] for j in range(16)],
]
print(tabulate(rows, headers=[""] + ["Image {}".format(j + 1) for j in range(16)], tablefmt="grid"))

Test the model overall performance on *remaining samples* and *unlearned samples*

In [None]:
# Evaluate the retrained model on the retained test data and on the forget (unlearning) data.
retrain_hypbl_average_remaining_test_loss, retrain_hypbl_best_remaining_test_acc = test(encoder=retrain_hypbl_encoder,
                                                                                        classifier=retrain_hypbl_classifier,
                                                                                        loader=remaining_testloader,
                                                                                        criterion=CRITERION,
                                                                                        hyperbolic=HYPBL_ARCHITECTURE,
                                                                                        manifold=MANIFOLD,
                                                                                        device=device)

retrain_hypbl_average_unl_train_loss, retrain_hypbl_best_unl_train_acc = test(encoder=retrain_hypbl_encoder,
                                                                              classifier=retrain_hypbl_classifier,
                                                                              loader=unlearning_trainloader,
                                                                              criterion=CRITERION,
                                                                              hyperbolic=HYPBL_ARCHITECTURE,
                                                                              manifold=MANIFOLD,
                                                                              device=device)

# The unlearning test set is only evaluated in the single-class scenario.
if SCENARIO == 'single-class':
    retrain_hypbl_average_unl_test_loss, retrain_hypbl_best_unl_test_acc = test(encoder=retrain_hypbl_encoder,
                                                                                classifier=retrain_hypbl_classifier,
                                                                                loader=unlearning_testloader,
                                                                                criterion=CRITERION,
                                                                                hyperbolic=HYPBL_ARCHITECTURE,
                                                                                manifold=MANIFOLD,
                                                                                device=device)

print(f'\nRemaining Test Set Loss: {retrain_hypbl_average_remaining_test_loss:.3f}, Accuracy: {retrain_hypbl_best_remaining_test_acc:.3f}%')
print(f'Unlearning Train Set Loss: {retrain_hypbl_average_unl_train_loss:.3f}, Accuracy: {retrain_hypbl_best_unl_train_acc:.3f}%')
# A plain if statement instead of a conditional expression evaluated only for its side effect.
if SCENARIO == 'single-class':
    print(f'Unlearning Test Set Loss: {retrain_hypbl_average_unl_test_loss:.3f}, Accuracy: {retrain_hypbl_best_unl_test_acc:.3f}%')

### Visualize feature space

You can choose which class samples to display by toggling the class names 

In [None]:
if SCENARIO == 'single-class':

    # Extract features of the retained and of the forgotten test samples
    remaining_features, remaining_labels = extract_features(retrain_hypbl_encoder, remaining_testloader, hyperbolic=HYPBL_ARCHITECTURE, manifold=MANIFOLD, device=device)
    unlearning_features, unlearning_labels = extract_features(retrain_hypbl_encoder, unlearning_testloader, hyperbolic=HYPBL_ARCHITECTURE, manifold=MANIFOLD, device=device)

    # Combine for t-SNE
    features = np.concatenate((unlearning_features, remaining_features), axis=0)
    labels = np.concatenate((unlearning_labels, remaining_labels), axis=0)

# In this case, there is no specific class to isolate, so the plotting goal is for all the models (original, retrain and unlearned) to plot a similar distribution of the whole test set
elif SCENARIO == 'random-sample':

    features, labels = extract_features(retrain_hypbl_encoder, remaining_testloader, hyperbolic=HYPBL_ARCHITECTURE, manifold=MANIFOLD, device=device)

else:
    # Without this, an unknown scenario silently plotted stale `features`/`labels`
    # left over from an earlier cell.
    raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

create_plot(features, labels, classes, dimension=2, convexhull=False)

## Unlearning model

In this section I will replicate the unlearning procedure, but in the new representation space.

You will find details about it in the next cells.

### Load checkpoint

**WARNING**: the encoder and classifier here are the original ones (instantiated in the first part of *Chapter 2*).

During the unlearning procedure, the original model parameters will be overwritten.

So if you go back and run inference on the original model, the output will change (for the worse, obviously).

*This checkpoint cell is just for quick testing*

In [None]:
# Restore the best saved original hyperbolic model before unlearning overwrites its parameters.
print('Resuming from checkpoint...')

assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
# map_location lets a checkpoint written on a GPU machine be restored on the current device.
checkpoint = torch.load('./checkpoint/{}'.format(HYPBL_ORIGINAL_CKPT_NAME), map_location=device)
hypbl_encoder.load_state_dict(checkpoint['encoder'])
hypbl_classifier.load_state_dict(checkpoint['classifier'])

print('Checkpoint loaded!')
# Double-quoted f-strings: reusing the outer quote character inside the braces is a
# SyntaxError before Python 3.12 (PEP 701).
print(f"\nValidation average loss: {checkpoint['loss']:.3f}")
print(f"Validation set accuracy: {checkpoint['acc']:.3f}%")
print(f"From epoch: {checkpoint['epoch']}")

### Hyperbolic contrastive unlearning loss

This is the variant of the ***Contrastive Unlearning*** loss adapted to the Poincaré ball manifold.

The code is essentially the same, but it changes how the distances between samples are computed 

In [None]:
class HyperbolicContrastiveUnlearningLoss(nn.Module):
    """Contrastive unlearning loss evaluated on a hyperbolic manifold.

    Unlearning embeddings are contrasted against remaining embeddings: every
    unlearning sample is pulled towards remaining samples of *other* classes
    (negatives) and, in the 'random-sample' scenario, pushed away from
    remaining samples of its *own* class (positives).

    Both batches are L2-normalised and mapped onto ``manifold`` with the
    exponential map. The similarity between two points is the *negative*
    geodesic distance scaled by the temperature, ``s(a, b) = -d(a, b) / T``.
    The sign matters: a distance grows as points separate, so plugging ``+d``
    into the contrastive objective would reverse it and pull unlearning samples
    back towards their original class.

    The loss is computed in log space (``log(exp(s_n) / D) = s_n - log D``,
    using ``logsumexp`` for ``D``) instead of exponentiating and taking the
    log again. This avoids overflow/underflow of ``exp`` for small
    temperatures or reduced-precision (float16) embeddings.

    The scenario ('single-class' or 'random-sample') is read from the global
    ``SCENARIO``.

    Args:
        temperature: temperature T dividing the similarities.
        manifold: hypll manifold providing ``expmap`` and ``dist``.
    """

    # Numerical stabiliser kept from the original formulation:
    # log(sum + eps) in the denominator and 1 / (N + eps) as normaliser.
    _EPS = 1e-8

    def __init__(self, temperature, manifold):
        super(HyperbolicContrastiveUnlearningLoss, self).__init__()
        self.temperature = temperature
        self.manifold = manifold

    def _to_manifold(self, embeddings):
        """L2-normalise the raw data of ``embeddings`` and map it onto the manifold."""
        # float32 for the loss computation: encoder outputs may be float16 under autocast.
        data = F.normalize(embeddings.tensor.float(), p=2, dim=1)
        tangent = TangentTensor(data=data, man_dim=1, manifold=self.manifold)
        return self.manifold.expmap(tangent)

    def _logits(self, anchor, others):
        """Return ``-d(anchor, x) / temperature`` for every ``x`` in ``others``, flattened to 1-D."""
        return (-self.manifold.dist(anchor, others) / self.temperature).reshape(-1)

    def forward(self, embeddings_u, embeddings_r, labels_u, labels_r):
        """Compute the contrastive unlearning loss averaged over the unlearning batch.

        Args:
            embeddings_u: encoder output for the unlearning samples; its
                ``.tensor`` attribute is read, so a hypll ManifoldTensor is expected.
            embeddings_r: encoder output for the remaining samples (same kind).
            labels_u: labels of the unlearning samples.
            labels_r: labels of the remaining samples.

        Returns:
            Scalar tensor: sum of the per-anchor losses divided by the
            unlearning batch size.
        """
        eps = self._EPS
        unlearning_batch_size = embeddings_u.size(0)
        loss = 0.0

        # Normalize and project embeddings in hyperbolic space
        embeddings_u = self._to_manifold(embeddings_u)
        embeddings_r = self._to_manifold(embeddings_r)

        for i in range(unlearning_batch_size):
            # Anchor: embedding of the i-th unlearning sample
            z_u = embeddings_u[i]

            if SCENARIO == 'single-class':
                # All remaining samples are negatives; with no positives the
                # denominator is the number of negatives (as in the original).
                neg_logits = self._logits(z_u, embeddings_r)
                num_neg = neg_logits.size(0)

                if num_neg > 0:
                    log_denominator = torch.log(neg_logits.new_tensor(num_neg + eps))
                    inner_loss = (neg_logits - log_denominator).sum()
                else:
                    inner_loss = torch.tensor(0.0)
                loss += (-1 / (num_neg + eps)) * inner_loss

                if num_neg == 0:
                    print("Warning: neg_sim is zero or too small.")

            elif SCENARIO == 'random-sample':
                pos_mask = labels_r == labels_u[i]
                positives = embeddings_r[pos_mask]
                negatives = embeddings_r[~pos_mask]

                neg_logits = self._logits(z_u, negatives)
                num_neg = neg_logits.size(0)

                # log(sum_p exp(s_p) + eps): appending log(eps) as one more
                # logsumexp term adds eps exactly; with no positives it is log(eps).
                log_eps = torch.log(neg_logits.new_tensor([eps]))
                if len(positives.tensor) > 0:
                    pos_logits = self._logits(z_u, positives)
                    log_denominator = torch.logsumexp(torch.cat([pos_logits, log_eps]), dim=0)
                else:
                    log_denominator = log_eps[0]

                inner_loss = (neg_logits - log_denominator).sum() if num_neg > 0 else torch.tensor(0.0)
                loss += (-1 / (num_neg + eps)) * inner_loss

        return loss / unlearning_batch_size

### DO NOT UNCOMMENT THESE: hyperbolic unlearning hyperparameter tuning

In [None]:
# Disabled on purpose (the code is kept inside a string literal, so running this
# cell does nothing). When enabled, it stores the data loaders/datasets and the
# original hyperbolic checkpoint in Ray's object store; the trainable
# `tune_train_contrastive_unlearning` retrieves them with `ray.get`.
# NOTE(review): `large_unl_valloader_train` only exists in the 'random-sample'
# scenario, matching the conditional `ray.get` in the trainable.
"""
# Put large objects in Ray's object store for efficient memory management
large_unlearning_trainloader = ray.put(unlearning_trainloader)
if SCENARIO == 'random-sample':
    large_unl_valloader_train = ray.put(unl_valloader_train)
large_unl_valloader_test = ray.put(unl_valloader_test)
large_remaining_trainset = ray.put(remaining_trainset)

checkpoint = torch.load('./checkpoint/{}'.format(HYPBL_ORIGINAL_CKPT_NAME))
large_checkpoint = ray.put(checkpoint)
"""

In [None]:
# Disabled on purpose (the code is kept inside a string literal, so running this
# cell only evaluates a string). When enabled, it defines the Ray Tune
# trainable. The trainable rebuilds a HybridHyperbolicResNet34 from the stored
# checkpoint and runs hyperbolic contrastive unlearning for one sampled
# `config`. Each epoch it reports `unlearned_accuracy` / `remaining_accuracy`.
# The Italian note inside ("DA TOGLIERE INSIEME A DATA PARALLEL") means
# "to be removed together with DataParallel": it strips the "module." prefix
# from checkpoint keys.
# NOTE(review) if this is ever re-enabled:
#   - `remaining_acc` is only assigned in the 'single-class' branch, so
#     `ray.train.report(...)` raises NameError in the 'random-sample' scenario.
#   - optimizer lr/momentum/weight_decay are hard-coded; only num_epochs, opt,
#     temperature, regularizer_ce, regularizer_ul and omega come from `config`.
#   - `remaining_testloader`, `device`, `CRITERION` and the validate_* helpers
#     are read from notebook globals rather than from Ray's object store.
"""
# Main training loop with validation
def tune_train_contrastive_unlearning(config):

    # DA TOGLIERE INSIEME A DATA PARALLEL
    def remove_module_prefix(state_dict):
        return {k.replace("module.", ""): v for k, v in state_dict.items()}
    
    scaler = torch.amp.grad_scaler.GradScaler()

    hyperbolic = HYPBL_ARCHITECTURE

    curvature = Curvature(value=-1.0)
    manifold = PoincareBall(c=curvature).to(device)

    encoder, classifier = HybridHyperbolicResNet34(manifold=manifold)

    checkpoint = ray.get(large_checkpoint)

    encoder_state = remove_module_prefix(checkpoint['encoder'])
    classifier_state = remove_module_prefix(checkpoint['classifier'])

    encoder.load_state_dict(encoder_state)
    classifier.load_state_dict(classifier_state)

    encoder = encoder.to(device)
    classifier = classifier.to(device)

    unlearning_trainloader = ray.get(large_unlearning_trainloader)
    remaining_trainset = ray.get(large_remaining_trainset)
    if SCENARIO == 'random-sample':
        unl_valloader_train = ray.get(large_unl_valloader_train)
    unl_valloader_test = ray.get(large_unl_valloader_test)
    
    criterion = CRITERION

    if config['opt'] == 'sgd':
        optimizer = RiemannianSGD(list(encoder.parameters()) + list(classifier.parameters()), lr=0.015254, momentum=0.582486, weight_decay=0.00366885)
    elif config['opt'] == 'adam':
        optimizer = RiemannianAdam(list(encoder.parameters()) + list(classifier.parameters()), lr=0.000641388, weight_decay=0.567643)
    
    class_count = 10

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


    for epoch in range(config["num_epochs"]):

        encoder.train()
        classifier.train()

        total_loss_epoch = 0.0  # To accumulate loss over the epoch
        total_batches = 0
        
        # Iterate over unlearning dataset (D_u)
        for x_u, y_u in unlearning_trainloader:
            x_u, y_u = x_u.to(device), y_u.to(device)

            # The projection is needed ony if the whole model is hyperbolic
            if hyperbolic in ['complete-ResNet', 'academic-ResNet']:
                # move the inputs to the manifold
                tangents = TangentTensor(data=x_u, man_dim=1, manifold=manifold)
                x_u = manifold.expmap(tangents)

            optimizer.zero_grad()
            batch_loss = 0.0  # To accumulate loss over the batch
            
            # Perform ω iterations over the remaining dataset for each unlearning batch
            for _ in range(config["omega"]):  # Omega loop

                # Create a DataLoader with a RandomSampler
                random_sampler = torch.utils.data.RandomSampler(remaining_trainset)
                remaining_trainset_loader = torch.utils.data.DataLoader(remaining_trainset, sampler=random_sampler, batch_size=128, pin_memory=True)

                for remaining_data in remaining_trainset_loader:
                    x_r, y_r = remaining_data
                    break

                x_r, y_r = x_r.to(device), y_r.to(device)

                # The projection is needed ony if the whole model is hyperbolic
                if hyperbolic in ['complete-ResNet', 'academic-ResNet']:
                    # move the inputs to the manifold
                    tangents = TangentTensor(data=x_r, man_dim=1, manifold=manifold)
                    x_r = manifold.expmap(tangents)

                # Use mixed precision
                with torch.amp.autocast_mode.autocast(device_type=device, dtype=torch.float16):
                    # Forward pass for remaining samples (classification loss)
                    z_r = encoder(x_r) 
                    logits_r = classifier(z_r)

                    # Forward pass for unlearning samples
                    z_u = encoder(x_u)
                    logits_u = classifier(z_u)

                # Compute classification loss
                # Indipendently of the hyperbolic model, we need to extract the tensor  
                if hyperbolic in ['complete-ResNet', 'hybrid-ResNet', 'academic-ResNet']:
                    # Needed for the loss computation
                    logits_r = logits_r.tensor # This gives the underlying PyTorch tensor

                ce_loss = criterion(logits_r, y_r)
                # Compute contrastive unlearning loss
                loss_fn = HyperbolicContrastiveUnlearningLoss(temperature=config["temperature"], manifold=manifold)
                ul_loss = loss_fn(z_u, z_r, y_u, y_r)

                # Total loss: classification loss + contrastive unlearning loss
                total_loss = config["regularizer_ul"] * ul_loss + config["regularizer_ce"] * ce_loss
                batch_loss += total_loss.item()

                scaler.scale(total_loss).backward()
                
            scaler.step(optimizer)
            scaler.update()

            total_batches += 1
            total_loss_epoch += batch_loss    
            
        # Compute validation for termination criteria
        if SCENARIO == 'single-class':
            unlearned_acc, verdict = validate_unlearning(encoder, classifier, None, unl_valloader_test, CRITERION, class_count=class_count, hyperbolic=hyperbolic, manifold=manifold)
            # Compute accuracy retain over the remaining samples
            remaining_acc = validate_containment(encoder, classifier, remaining_testloader, CRITERION, class_count=class_count, hyperbolic=hyperbolic, manifold=manifold)
        elif SCENARIO == 'random-sample':
            unlearned_acc, verdict = validate_unlearning(encoder, classifier, unl_valloader_train, unl_valloader_test, CRITERION, class_count=class_count, hyperbolic=hyperbolic, manifold=manifold)


        scheduler.step()

        # report results to Ray Tune
        ray.train.report(dict(unlearned_accuracy=unlearned_acc, remaining_accuracy=remaining_acc))       

        # Termination criteria
        if verdict: 
            print("\nTerminating unlearning procedure")
            break
"""

In [None]:
# Disabled on purpose (the code is kept inside a string literal, so running this
# cell does nothing). When enabled, it launches the Ray Tune search for
# `tune_train_contrastive_unlearning`. The search draws 50 samples, uses ASHA
# scheduling that maximises the reported `remaining_accuracy`, and requests
# 20 CPUs + 1 GPU per trial.
# NOTE(review): the "Use partial to pass fixed_params" comment inside is stale;
# the trainable is passed directly, with no functools.partial.
"""
def custom_trial_name(trial):
    return f"trial_unlearning{trial.trial_id}"

# Define the hyperparameter search space
search_space = {
    "num_epochs": tune.grid_search([NUM_EPOCHS]),
    "opt": tune.grid_search(['sgd']),
    "temperature": tune.uniform(0.1, 1.0),
    "regularizer_ce": tune.uniform(0.1, 1.0),
    "regularizer_ul": tune.uniform(0.1, 1.0),
    "omega": tune.choice([2, 4, 6]),
}

# Use partial to pass fixed_params
analysis = tune.run(
    tune_train_contrastive_unlearning,
    resources_per_trial={"cpu": 20, "gpu": 1},
    config=search_space,
    num_samples=50,
    scheduler=ASHAScheduler(metric="remaining_accuracy", mode="max"),    
    trial_dirname_creator=custom_trial_name
)
"""

### Unlearning optimizer

Again, we assume that the training hyperparameters of the Euclidean model are also good for unlearning in hyperbolic space.

In [None]:
# Riemannian optimizer + cosine LR schedule for hyperbolic unlearning.
# The hyperparameters are the "best configuration" values; the same values are
# hard-coded in the (disabled) tuning trainable.
hypbl_unlearning_params = list(hypbl_encoder.parameters()) + list(hypbl_classifier.parameters())

if OPTIMIZER_NAME == 'adam':
    # BEST CONFIGURATION ADAM
    unl_hypbl_optimizer = RiemannianAdam(hypbl_unlearning_params, lr=0.000641388, weight_decay=0.567643)
elif OPTIMIZER_NAME == 'sgd':
    # BEST CONFIGURATION SGD
    unl_hypbl_optimizer = RiemannianSGD(hypbl_unlearning_params, lr=0.015254, momentum=0.582486, weight_decay=0.00366885)
else:
    # Fail loudly: otherwise the scheduler below would hit a NameError or,
    # in a notebook, silently wrap an optimizer left over from an earlier run.
    raise ValueError(f"Unsupported OPTIMIZER_NAME: {OPTIMIZER_NAME!r} (expected 'adam' or 'sgd')")

# Cosine annealing over a fixed horizon of 200 scheduler steps.
unl_hypbl_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(unl_hypbl_optimizer, T_max=200)

### Start unlearning

In [None]:
# Hyperbolic contrastive unlearning: each epoch runs one unlearning pass and then
# a validation pass whose `verdict` decides whether to stop early.
if SCENARIO not in ('single-class', 'random-sample'):
    # Checked up front: otherwise `verdict` below would be undefined (or stale
    # from an earlier cell) and the error would only surface after a full epoch.
    raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

scaler = torch.amp.grad_scaler.GradScaler()

# The loss only stores `temperature` and `manifold`, so a single instance is
# reused for every epoch instead of being rebuilt inside the loop.
hypbl_unlearning_loss = HyperbolicContrastiveUnlearningLoss(temperature=HYPBL_TEMPERATURE, manifold=MANIFOLD)

for epoch in range(NUM_EPOCHS):

    print('\nEpoch: %d' % epoch)

    train_contrastive_unlearning(encoder=hypbl_encoder,
                                 classifier=hypbl_classifier,
                                 unlearning_trainloader=unlearning_trainloader,
                                 remaining_trainset=remaining_trainset,
                                 classification_loss=CRITERION,
                                 unlearning_loss=hypbl_unlearning_loss,
                                 optimizer=unl_hypbl_optimizer,
                                 omega=HYPBL_OMEGA,
                                 lambda_ce=HYPBL_REGULARIZER_CE,
                                 lambda_ul=HYPBL_REGULARIZER_UL,
                                 scaler=scaler,
                                 hyperbolic=HYPBL_ARCHITECTURE,
                                 manifold=MANIFOLD,
                                 device=device
                                 )

    if SCENARIO == 'single-class':
        # No validation split of the unlearning train set in this scenario.
        unlearned_hypbl_acc, verdict = validate_unlearning(encoder=hypbl_encoder,
                                                           classifier=hypbl_classifier,
                                                           valloader_train=None,
                                                           valloader_test=unl_valloader_test,
                                                           classification_loss=CRITERION,
                                                           class_count=len(classes),
                                                           hyperbolic=HYPBL_ARCHITECTURE,
                                                           manifold=MANIFOLD,
                                                           device=device
                                                           )

        remaining_hypbl_acc = validate_containment(encoder=hypbl_encoder,
                                                   classifier=hypbl_classifier,
                                                   remaining_testloader=remaining_testloader,
                                                   classification_loss=CRITERION,
                                                   class_count=len(classes),
                                                   hyperbolic=HYPBL_ARCHITECTURE,
                                                   manifold=MANIFOLD,
                                                   device=device
                                                   )

    elif SCENARIO == 'random-sample':
        unlearned_hypbl_acc, verdict = validate_unlearning(encoder=hypbl_encoder,
                                                           classifier=hypbl_classifier,
                                                           valloader_train=unl_valloader_train,
                                                           valloader_test=unl_valloader_test,
                                                           classification_loss=CRITERION,
                                                           class_count=len(classes),
                                                           hyperbolic=HYPBL_ARCHITECTURE,
                                                           manifold=MANIFOLD,
                                                           device=device
                                                           )

    print(f"\nTermination condition reached: {verdict}")

    # Termination criteria
    if verdict:
        print("\nTerminating unlearning procedure")
        break

    unl_hypbl_scheduler.step()

### Results of unlearning

The unlearning efficiency is tested on the same subsets used by the authors for the experiments:

- *single-class* scenario --> remaining test set, unlearning train set, unlearning test set
- *random-sample* scenario --> remaining test set, unlearning train set

In [None]:
# Evaluate the unlearned hyperbolic model on the remaining test set and the
# unlearning train set, plus the unlearning test set in the 'single-class' scenario.
# NOTE(review): `test` presumably returns (average loss, accuracy); the
# `*_best_*_acc` names are kept unchanged in case later cells reference them.
unlearned_hypbl_average_remaining_test_loss, unlearned_hypbl_best_remaining_test_acc = test(
    encoder=hypbl_encoder,
    classifier=hypbl_classifier,
    loader=remaining_testloader,
    criterion=CRITERION,
    hyperbolic=HYPBL_ARCHITECTURE,
    manifold=MANIFOLD,
    device=device)

unlearned_hypbl_average_unl_train_loss, unlearned_hypbl_best_unl_train_acc = test(
    encoder=hypbl_encoder,
    classifier=hypbl_classifier,
    loader=unlearning_trainloader,
    criterion=CRITERION,
    hyperbolic=HYPBL_ARCHITECTURE,
    manifold=MANIFOLD,
    device=device)

if SCENARIO == 'single-class':

    unlearned_hypbl_average_unl_test_loss, unlearned_hypbl_best_unl_test_acc = test(
        encoder=hypbl_encoder,
        classifier=hypbl_classifier,
        loader=unlearning_testloader,
        criterion=CRITERION,
        hyperbolic=HYPBL_ARCHITECTURE,
        manifold=MANIFOLD,
        device=device)

print(f'\nUnlearned Remaining Test Set Loss: {unlearned_hypbl_average_remaining_test_loss:.3f}, Accuracy: {unlearned_hypbl_best_remaining_test_acc:.3f}%')
print(f'Unlearned Unlearning Train Set Loss: {unlearned_hypbl_average_unl_train_loss:.3f}, Accuracy: {unlearned_hypbl_best_unl_train_acc:.3f}%')
# A plain `if` instead of a conditional expression used only for its side effect.
if SCENARIO == 'single-class':
    print(f'Unlearned Unlearning Test Set Loss: {unlearned_hypbl_average_unl_test_loss:.3f}, Accuracy: {unlearned_hypbl_best_unl_test_acc:.3f}%')

### Get predictions over batch sample of the *remaining test set*

In [None]:
# Show ground-truth vs. predicted labels of the unlearned hyperbolic model on one
# batch of the remaining test set.
dataiter = iter(remaining_testloader)
images, labels = next(dataiter)

hypbl_encoder.eval()
hypbl_classifier.eval()

# `images` stays on the CPU for plotting; the model receives a copy on `device`,
# the same device passed to the training/evaluation helpers.
inputs = images.to(device)
# Fully hyperbolic architectures take their input on the manifold, as done in
# the unlearning training code (`tune_train_contrastive_unlearning`).
if HYPBL_ARCHITECTURE in ['complete-ResNet', 'academic-ResNet']:
    inputs = MANIFOLD.expmap(TangentTensor(data=inputs, man_dim=1, manifold=MANIFOLD))

with torch.no_grad():
    features = hypbl_encoder(inputs)
    outputs = hypbl_classifier(features)

    _, predicted = torch.max(outputs.tensor, 1)

# Show at most 16 images (fewer if the last/only batch is smaller).
n_show = min(16, len(labels))
true_names = [classes[k] for k in labels[:n_show].tolist()]
pred_names = [classes[k] for k in predicted[:n_show].tolist()]

# print images
imshow(torchvision.utils.make_grid(images[:n_show]))

# Prepare rows with GroundTruth and Predicted labels
rows = [["GroundTruth", *true_names],
        ["Predicted", *pred_names]]

# Display the table
print(tabulate(rows, headers=[""] + [f"Image {i+1}" for i in range(n_show)], tablefmt="grid"))

### Get predictions over batch sample of the *unlearned set*

In [None]:
# Show ground-truth vs. predicted labels of the unlearned hyperbolic model on one
# batch of the data that was supposed to be forgotten.
if SCENARIO == 'single-class':
    sample_loader = unlearning_testloader
elif SCENARIO == 'random-sample':
    # This scenario has no unlearning test set: inspect the unlearning train set.
    sample_loader = unlearning_trainloader
else:
    # Without this, `images`/`labels` from an earlier cell would be shown silently.
    raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

dataiter = iter(sample_loader)
images, labels = next(dataiter)

hypbl_encoder.eval()
hypbl_classifier.eval()

# `images` stays on the CPU for plotting; the model receives a copy on `device`,
# the same device passed to the training/evaluation helpers.
inputs = images.to(device)
# Fully hyperbolic architectures take their input on the manifold, as done in
# the unlearning training code (`tune_train_contrastive_unlearning`).
if HYPBL_ARCHITECTURE in ['complete-ResNet', 'academic-ResNet']:
    inputs = MANIFOLD.expmap(TangentTensor(data=inputs, man_dim=1, manifold=MANIFOLD))

with torch.no_grad():
    features = hypbl_encoder(inputs)
    outputs = hypbl_classifier(features)

    _, predicted = torch.max(outputs.tensor, 1)

# Show at most 16 images (fewer if the batch is smaller).
n_show = min(16, len(labels))
true_names = [classes[k] for k in labels[:n_show].tolist()]
pred_names = [classes[k] for k in predicted[:n_show].tolist()]

# print images
imshow(torchvision.utils.make_grid(images[:n_show]))

# Prepare rows with GroundTruth and Predicted labels
rows = [["GroundTruth", *true_names],
        ["Predicted", *pred_names]]

# Display the table
print(tabulate(rows, headers=[""] + [f"Image {i+1}" for i in range(n_show)], tablefmt="grid"))

### Visualize feature space

In [None]:
# Visualise the feature space of the unlearned hyperbolic encoder.
if SCENARIO == 'single-class':

    # Extract features from the test set
    remaining_features, remaining_labels = extract_features(hypbl_encoder, remaining_testloader, hyperbolic=HYPBL_ARCHITECTURE, manifold=MANIFOLD, device=device)
    unlearned_features, unlearned_labels = extract_features(hypbl_encoder, unlearning_testloader, hyperbolic=HYPBL_ARCHITECTURE, manifold=MANIFOLD, device=device)

    # Combine for t-SNE
    features = np.concatenate((unlearned_features, remaining_features), axis=0)
    labels = np.concatenate((unlearned_labels, remaining_labels), axis=0)

# In this case, there is no specific class to isolate, so the plotting goal is for all the models (original, retrain and unlearned) to plot a similar distribution of the whole test set
elif SCENARIO == 'random-sample':

    features, labels = extract_features(hypbl_encoder, remaining_testloader, hyperbolic=HYPBL_ARCHITECTURE, manifold=MANIFOLD, device=device)

else:
    # Without this, `features`/`labels` left over from an earlier cell would be plotted silently.
    raise ValueError(f"Unknown SCENARIO: {SCENARIO!r} (expected 'single-class' or 'random-sample')")

create_plot(features, labels, classes, dimension=2, convexhull=False)