# 🚀 **NeurIPS 2023 Unlearning Challenge Notebook**


Welcome to this Jupyter notebook designed for the NeurIPS 2023 Unlearning Challenge. This notebook demonstrates an unlearning approach using fine-tuning and checkpoint creation.



## Let's get started! 🎉


## 📦 **Cell 1: Importing Libraries and Setting Device**
In this cell, we import necessary libraries, check for GPU availability, and set the device accordingly.

In [1]:
import os
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Both spellings (DEVICE and device) are defined because later cells use each.
device = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = device
print("Running on device:", device.upper())
import timm

# Seeded generator intended for dataset partitioning, so that the
# partitions are reproducible from one run to the next.
RNG = torch.Generator()
RNG.manual_seed(42)



Running on device: CUDA


In [2]:
import os  # 📂 Import the 'os' module to interact with the operating system.
import subprocess  # 🚀 Import the 'subprocess' module for running external commands.

import pandas as pd  # 📊 Import 'pandas' library for data manipulation.
import torch  # 🔥 Import 'torch' library for PyTorch functionalities.
import torchvision  # 🖼️ Import 'torchvision' for computer vision utilities.
import torch.nn as nn  # 🧠 Import 'nn' module from 'torch' for neural network components.
import torch.optim as optim  # ⚙️ Import 'optim' module for optimization algorithms.
from torchvision.models import resnet18  # 📸 Import the ResNet-18 model from torchvision.
from torch.utils.data import DataLoader, Dataset  # 🧾 Import data-related components.

# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'


## 🚧 **Cell 2: Accelerator Check**
This cell checks if a GPU accelerator is available and raises an error if not.

In [3]:
# Use a P100 GPU accelerator.
# Fail fast when DEVICE is anything other than 'cuda'; as the error message
# states, the submission will fail without an accelerator.
if DEVICE != 'cuda':
    raise RuntimeError('❗ Make sure you have added an accelerator (e.g., GPU) to your notebook; the submission will fail otherwise! 🚀')


## 📂 **Cell 3: Dataset Loading and Preparation**
Here, we define functions and classes for loading and preparing the dataset for training and validation.

In [4]:
def load_example(df_row):
    '''Build one example dict from a row of the split table.

    The image file named by ``df_row['image_path']`` is read with
    ``torchvision.io.read_image``; the remaining metadata fields are copied
    from the row unchanged.

    Returns:
        A dict with the keys ``image``, ``image_id``, ``age_group``, ``age``
        and ``person_id``.
    '''
    example = {'image': torchvision.io.read_image(df_row['image_path'])}
    # Copy the metadata columns straight through, in a fixed key order.
    for field in ('image_id', 'age_group', 'age', 'person_id'):
        example[field] = df_row[field]
    return example


class HiddenDataset(Dataset):
    '''In-memory dataset for one split of the hidden competition data.

    Every image listed in ``<root>/<split>.csv`` is read eagerly at
    construction time, so indexing never touches the disk.

    Args:
        split: Name of the CSV file (without extension) describing the split,
            e.g. ``'retain'``, ``'forget'`` or ``'validation'``.
        root: Directory holding the split CSV files and the ``images`` folder.

    Raises:
        ValueError: If the split contains no examples.
    '''

    def __init__(self, split='train', root='/kaggle/input/neurips-2023-machine-unlearning'):
        super().__init__()

        df = pd.read_csv(os.path.join(root, f'{split}.csv'))

        def to_image_path(image_id):
            # An image id of the form '<a>-<b>' maps to '<root>/images/<a>/<b>.png'.
            parts = image_id.split('-')
            return os.path.join(root, 'images', parts[0], parts[1] + '.png')

        df['image_path'] = df['image_id'].apply(to_image_path)
        # Sort by path so the example order does not depend on the CSV row order.
        df = df.sort_values(by='image_path')

        # Plain iteration instead of DataFrame.apply with a side-effecting
        # lambda: apply is meant for transformations and would also build a
        # throwaway result object.
        self.examples = [load_example(row) for _, row in df.iterrows()]
        if not self.examples:
            raise ValueError('No examples.')

    def __len__(self):
        '''Return the number of examples loaded for this split.'''
        return len(self.examples)

    def __getitem__(self, idx):
        '''Return ``(image, age_group)`` for the example at ``idx``.

        The image is returned as a float32 tensor. The cached example is left
        untouched, so indexing never modifies the stored data.
        '''
        example = self.examples[idx]
        return example['image'].to(torch.float32), example['age_group']


def get_dataset(batch_size):
    '''Build shuffled dataloaders for the retain, forget and validation splits.

    Args:
        batch_size: Batch size shared by all three loaders.

    Returns:
        A ``(retain_loader, forget_loader, validation_loader)`` tuple.
    '''
    split_names = ('retain', 'forget', 'validation')
    # Load every split first, then wrap each one in its own shuffling loader.
    datasets = [HiddenDataset(split=name) for name in split_names]
    return tuple(
        DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset in datasets
    )


## 🧠 **Cell 4: Unlearning Function**
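These cells define the `Discriminator` and `Generator` helper classes and the unlearning function, which fine-tunes the model on the retain set while adversarially training its features on the forget set against a discriminator.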

In [5]:
class Discriminator(nn.Module):
    '''MLP that maps a feature vector to a single real/fake logit.

    Args:
        input_dim: Size of the incoming feature vectors.
    '''

    def __init__(self, input_dim=512):
        super().__init__()
        # NOTE: the attribute name 'dics' is kept as-is because it is part of
        # this module's state_dict keys.
        self.dics = nn.Sequential(
            nn.Linear(input_dim, input_dim * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(input_dim * 4, input_dim * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(input_dim * 8, input_dim * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(input_dim * 2, input_dim),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(input_dim, 1),
        )

    def forward(self, x):
        '''Return one raw (pre-sigmoid) logit per input row.'''
        return self.dics(x)

    def get_loss(self, gen, forget, test):
        '''Compute the discriminator loss for one training step.

        Features that ``gen`` extracts from ``forget`` are labelled 0 and
        features extracted from ``test`` are labelled 1. The loss is the mean
        of the two terms, each computed with the module-level ``criterion``.

        Args:
            gen: Object exposing ``forward_features(images)``.
            forget: Batch of images whose features are the "fake" class.
            test: Batch of images whose features are the "real" class.

        Returns:
            Scalar loss tensor; gradients flow only into the discriminator.
        '''
        # The generator is treated as a fixed feature extractor here, so its
        # forward pass runs without recording an autograd graph rather than
        # building one and discarding it afterwards.
        with torch.no_grad():
            fake = gen.forward_features(forget)
            real = gen.forward_features(test)

        disc_fake_pred = self(fake)
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))

        disc_real_pred = self(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

        return (disc_fake_loss + disc_real_loss) / 2

In [6]:
class Generator(nn.Module):
    '''Wrap a classifier so its features and its head can be used separately.

    On construction the wrapped network's ``fc`` layer is moved to
    ``self.head`` and replaced on the network by ``nn.Identity``, so calling
    the network directly yields the input of the original ``fc`` layer.

    Args:
        net: Network exposing an ``fc`` attribute; it is modified in place.
    '''

    def __init__(self, net):
        super().__init__()
        self.model = net
        # Keep the original classification layer and make the backbone
        # return its features instead.
        self.head = net.fc
        net.fc = nn.Identity()

    def forward(self, x):
        '''Return class logits: backbone features passed through the head.'''
        return self.head(self.forward_features(x))

    def forward_features(self, x):
        '''Return the backbone output, i.e. the head's input features.'''
        return self.model(x)

    def get_loss(self, disc, forget, class_labels):
        '''Compute the adversarial loss for the forget batch.

        The loss is the module-level ``criterion`` between the
        discriminator's predictions on the forget features and an all-ones
        target. ``class_labels`` is accepted but not used; no classification
        term is included in the returned loss.
        '''
        features = self.forward_features(forget)
        verdict = disc(features)
        all_ones = torch.ones_like(verdict)
        return criterion(verdict, all_ones)

In [7]:
# Per-channel (mean, std) normalization for 3-channel image tensors.
# NOTE(review): these values look like the commonly quoted CIFAR-10 statistics —
# confirm they are the intended ones for this data.
normalize =transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

In [8]:
# Binary cross-entropy on raw logits; used by the Discriminator and Generator
# get_loss methods for the real/fake objective.
criterion = nn.BCEWithLogitsLoss()
# Multi-class cross-entropy; used in unlearning() for the retain-set classification loss.
clas_criterion = nn.CrossEntropyLoss()

def unlearning(net, retain_loader, forget_loader, val_loader, lr=0.0005, n_epochs=25):
    '''Unlearn the forget set in place via adversarial fine-tuning.

    ``net`` is wrapped in a ``Generator`` and trained against a fresh
    ``Discriminator``. Each step performs two updates:

    1. Discriminator: learn to separate features of forget images (target 0)
       from features of normalized uniform-noise images (target 1).
    2. Generator: minimize the cross-entropy on a retain batch plus the
       adversarial loss that pushes forget features towards target 1.

    Each epoch stops at the end of the shorter of the two loaders, because
    they are iterated together with ``zip``. The discriminator is built with
    its default input width, so ``net`` must produce features of that size.

    Args:
        net: Classifier exposing an ``fc`` head; its weights are updated in
            place and it is left in evaluation mode.
        retain_loader: Yields ``(images, labels)`` batches to keep fitting.
        forget_loader: Yields ``(images, labels)`` batches to be forgotten.
        val_loader: Accepted for interface compatibility; not used.
        lr: Initial Adam learning rate for both networks.
        n_epochs: Number of epochs; also the cosine-annealing period.
    '''
    gen = Generator(net)
    try:
        gen.to(device)
        gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
        disc = Discriminator().to(device)
        disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
        gscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(gen_opt, T_max=n_epochs)
        dscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(disc_opt, T_max=n_epochs)

        for _ in range(n_epochs):
            for (forget_images, forget_label), (retain_images, retain_labels) in zip(forget_loader, retain_loader):
                forget_images = forget_images.to(device, dtype=torch.float)
                forget_label = forget_label.to(device, dtype=torch.long)

                retain_images = retain_images.to(device, dtype=torch.float)
                retain_labels = retain_labels.to(device, dtype=torch.long)

                # Stand-in "real" inputs: uniform noise pushed through the same
                # normalization. rand_like inherits device and dtype from
                # forget_images, so no further conversion is needed.
                test_image = normalize(torch.rand_like(forget_images))

                # --- Discriminator update (generator held fixed) ---
                gen.eval()
                disc.train()
                disc_opt.zero_grad()
                disc_loss = disc.get_loss(gen, forget_images, test_image)
                # The graph is not reused afterwards, so it need not be retained.
                disc_loss.backward()
                disc_opt.step()

                # --- Generator update (discriminator held fixed) ---
                gen.train()
                disc.eval()
                gen_opt.zero_grad()

                outputs = gen(retain_images)
                loss = clas_criterion(outputs, retain_labels)
                gen_loss = gen.get_loss(disc, forget_images, forget_label)
                (gen_loss + loss).backward()

                gen_opt.step()
            gscheduler.step()
            dscheduler.step()
    finally:
        # Generator replaced net.fc with Identity; always put the head back so
        # the network stays usable even if training raised.
        net.fc = gen.head
        net.eval()  # Set the network to evaluation mode when done training.


## 💾 **Cell 5: Checkpoint Generation & Submission**
In this cell, we create the unlearned model checkpoints, ensure there are exactly 512 of them, and handle the submission process, including creating the submission.zip file.

In [9]:
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # Mock submission: create an empty submission.zip file.
    subprocess.run('touch submission.zip', shell=True, check=True)
else:
    # Create unlearned checkpoints outside of the working directory to avoid disk space issues.
    os.makedirs('/kaggle/tmp', exist_ok=True)

    # Load the datasets and initialize the model.
    retain_loader, forget_loader, validation_loader = get_dataset(128 * 2)
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)

    # Read the original weights from disk once; load_state_dict copies the
    # values into the model, so the same state can be reused on every iteration.
    original_state = torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth')

    # Produce one unlearned checkpoint per iteration, always restarting from
    # the original model.
    for i in range(512):
        net.load_state_dict(original_state)
        unlearning(net, retain_loader, forget_loader, validation_loader)
        torch.save(net.state_dict(), f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

    # Check the number of unlearned checkpoints to ensure it's 512.
    unlearned_ckpts = os.listdir('/kaggle/tmp')
    if len(unlearned_ckpts) != 512:
        raise RuntimeError('❗ Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')

    # Create the submission.zip file containing the unlearned checkpoints.
    # The shell is needed to expand the glob; check=True surfaces a failed
    # archive step instead of silently continuing.
    subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True, check=True)
