# Stochastic Unlearning with Knowledge Preserving Loss

This notebook describes our solution, which placed 9th on the private leaderboard and 3rd on the public leaderboard.
A detailed explanation of our method can be found in the PDF below:
https://www.dropbox.com/scl/fi/izmzhmj3ktqk3ze6rhjv6/Kaggle_Unlearning_Challenge_Solution.pdf?rlkey=9pqw47izw4nuuthanw3czu1gi&dl=0

In [None]:
import os
import subprocess

import pandas as pd
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset

# Compute device used by the rest of the notebook: the GPU when torch can see one, otherwise the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

In [None]:
# It's really important to add an accelerator to your notebook, as otherwise the submission will fail.
# We recommend using the P100 GPU rather than T4 as it's faster and will increase the chances of passing the time cut-off threshold.

# Fail fast: DEVICE is only 'cuda' when torch reports an available GPU, so any other value means no accelerator is attached.
if DEVICE != 'cuda':
    raise RuntimeError('Make sure you have added an accelerator to your notebook; the submission will fail otherwise!')

In [None]:
# Helper functions for loading the hidden dataset.

class RandomGaussianNoise:
    '''Additive Gaussian noise augmentation for image tensors.

    Calling an instance returns ``image + noise * std + mean`` where ``noise``
    is drawn from a standard normal distribution with the same shape as
    ``image``. The input tensor itself is not modified.

    Args:
        mean: Constant offset added to every element.
        std: Scale applied to the standard-normal noise.
    '''

    def __init__(self, mean=0.0, std=1.0):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        '''Return a new tensor equal to ``image`` plus freshly sampled noise.'''
        # Sample on the image's own device. A bare torch.randn(size) always
        # lands on the default device, which makes the addition fail for
        # tensors that live elsewhere (e.g. an image already moved to the GPU).
        noise = torch.randn(image.size(), device=image.device)
        return image + noise * self.std + self.mean

def load_example(df_row):
    '''Build one example dict from a metadata row.

    Reads the image file at ``df_row['image_path']`` with
    ``torchvision.io.read_image`` and copies the ``image_id``, ``age_group``,
    ``age`` and ``person_id`` fields of the row alongside it.
    '''
    example = {'image': torchvision.io.read_image(df_row['image_path'])}
    # Carry the metadata columns over unchanged, in a fixed order.
    for field in ('image_id', 'age_group', 'age', 'person_id'):
        example[field] = df_row[field]
    return example


class HiddenDataset(Dataset):
    '''In-memory dataset over one split of the hidden competition data.

    Every image listed in ``<split>.csv`` is read once in ``__init__`` (via
    ``load_example``) and kept in ``self.examples``, ordered by image path.
    For the 'train', 'retain' and 'forget' splits a random Gaussian noise
    augmentation (std 0.1) is applied on each access; any other split is
    returned without augmentation.

    Args:
        split: Name of the CSV file (without extension) describing the split.

    Raises:
        ValueError: If the split's CSV lists no examples.
    '''

    def __init__(self, split='train'):
        super().__init__()
        self.transform = None
        if split in ('train', 'retain', 'forget'):
            from torchvision import transforms
            self.transform = transforms.Compose([
                RandomGaussianNoise(mean=0.0, std=0.1), # Add Data Augmentation (Random Gaussian Noise)
            ])

        data_root = '/kaggle/input/neurips-2023-machine-unlearning/'

        def image_path_for(image_id):
            # The id is split on '-': first piece is the sub-directory,
            # second piece is the file stem.
            parts = image_id.split('-')
            return os.path.join(data_root, 'images', parts[0], parts[1] + '.png')

        df = pd.read_csv(f'{data_root}{split}.csv')
        df['image_path'] = df['image_id'].apply(image_path_for)
        df = df.sort_values(by='image_path')
        # Explicit iteration instead of DataFrame.apply with an appending
        # lambda: apply is meant for computing values, not for side effects.
        self.examples = [load_example(row) for _, row in df.iterrows()]
        if not self.examples:
            raise ValueError('No examples.')

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        '''Return a copy of example ``idx`` with a float32 (and possibly augmented) image.'''
        stored = self.examples[idx]
        image = stored['image'].to(torch.float32)
        if self.transform is not None:
            image = self.transform(image)
        # Hand back a shallow copy rather than writing the result into the
        # cached example: writing it back would replace the stored raw image,
        # so noise from every access would pile up on top of earlier noise.
        return {**stored, 'image': image}


def get_dataset(batch_size):
    '''Create shuffled loaders for the retain, forget and validation splits.

    Args:
        batch_size: Batch size used for all three loaders.

    Returns:
        A ``(retain_loader, forget_loader, validation_loader)`` tuple.
    '''
    # Load all three splits first, then wrap each one in a shuffling loader.
    split_datasets = [
        HiddenDataset(split=split_name)
        for split_name in ('retain', 'forget', 'validation')
    ]
    retain_loader, forget_loader, validation_loader = [
        DataLoader(split_ds, batch_size=batch_size, shuffle=True)
        for split_ds in split_datasets
    ]
    return retain_loader, forget_loader, validation_loader

In [None]:
def unlearning(
    net,
    retain_loader,
    forget_loader,
    val_loader):
    """Unlearn in place by cycling random re-initialization and output matching.

    Each cycle has two phases:

    * Forgetting: a few randomly chosen parameter tensors of ``net`` are
      re-initialized (weights from N(0, 0.01), biases to zero).
    * Remembering: ``net`` is trained on the retain set to reproduce the
      outputs of a frozen copy of the original model (MSE loss).

    Args:
        net: Model to unlearn; modified in place and left in eval mode.
        retain_loader: Iterable of batches; each batch is a mapping whose
            "image" entry is the input tensor.
        forget_loader: Unused; kept for interface compatibility.
        val_loader: Unused; kept for interface compatibility.

    Returns:
        None.
    """
    from copy import deepcopy
    import random

    cycles = 4              # number of cycles
    epochs = [2, 1, 1, 1]   # epochs for each cycle
    num_init = [6, 6, 6, 6] # number of re-initializing parameters for each cycle, 1/10 of the model parameters (layer level)

    # Frozen copy of the original model, used as the target ("teacher").
    cloned_net = deepcopy(net)
    for param in cloned_net.parameters():
        param.requires_grad = False
    # deepcopy also copies the train/eval flag, so without this the teacher's
    # mode would depend on whatever mode the caller left `net` in. Pin it to
    # eval so its normalization layers use their stored statistics and its
    # buffers are never updated by the forward passes below.
    cloned_net.eval()

    # Knowledge Preserving Loss (MSE loss) used in Remembering Phase
    criterion = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

    # Custom lr scheduling
    def adjust_lr(epoch):
        # base_lr * lambda
        # last additional element 0 for preventing indexing error at the last scheduler.step()
        lambda_list = [0.5, 1, 1, 1, 1, 0] # make lr=[0.0005, 0.001, 0.001, 0.001, 0.001] for each epoch
        return lambda_list[epoch]
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=adjust_lr)
    net.train()

    # Selection pool: parameters eligible for re-initialization in the Forgetting Phase
    params = [p for p in net.named_parameters() if 'downsample' not in p[0] and 'fc' not in p[0]] # do not re-initialize FC and Downsample

    for cycle in range(cycles):
        ### Forgetting Phase: Stochastic Re-initialization ###
        # Randomly reinitialize part of the network parameters (layer level)
        with torch.no_grad():
            for _ in range(num_init[cycle]):
                name, param = random.choice(params) # Sampling with replacement
                if 'weight' in name:
                    nn.init.normal_(param, mean=0, std=0.01) # Using normal distribution for initialization
                elif 'bias' in name:
                    nn.init.zeros_(param) # Initializing biases to zero

        ### Remembering Phase: Knowledge Preserving Loss ###
        # Remind the model about retain set
        # Only retain set is utilized
        for _ in range(epochs[cycle]):
            for sample in retain_loader:
                inputs = sample["image"].to(DEVICE)

                optimizer.zero_grad()
                # The teacher is frozen; no autograd bookkeeping is needed for it.
                with torch.no_grad():
                    targets = cloned_net(inputs)
                outputs = net(inputs)
                loss = criterion(outputs, targets) # Calculate loss between original model output and unlearned model output
                loss.backward()
                optimizer.step()
            # One scheduler step per epoch; sum(epochs) steps in total, which
            # is why lambda_list carries one extra trailing entry.
            scheduler.step()

    net.eval()

In [None]:
# Build submission.zip: either an empty placeholder (when the marker file
# empty.txt is present in the input directory) or 512 unlearned checkpoints.
if os.path.exists('/kaggle/input/neurips-2023-machine-unlearning/empty.txt'):
    # mock submission
    subprocess.run('touch submission.zip', shell=True)
else:

    # Note: it's really important to create the unlearned checkpoints outside of the working directory
    # as otherwise this notebook may fail due to running out of disk space.
    # The below code saves them in /kaggle/tmp to avoid that issue.

    os.makedirs('/kaggle/tmp', exist_ok=True)
    retain_loader, forget_loader, validation_loader = get_dataset(64)
    net = resnet18(weights=None, num_classes=10)
    net.to(DEVICE)
    # The original weights never change, so read them from disk once;
    # load_state_dict copies the values into `net` on every iteration.
    original_state = torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth')
    for i in range(512):
        # Every checkpoint starts again from the original weights.
        net.load_state_dict(original_state)
        unlearning(net, retain_loader, forget_loader, validation_loader)
        state = net.state_dict()
        torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

    # Ensure that submission.zip will contain exactly 512 checkpoints
    # (if this is not the case, an exception will be thrown).
    # Count only *.pth files, since that is exactly what the zip command below packs.
    unlearned_ckpts = [fname for fname in os.listdir('/kaggle/tmp') if fname.endswith('.pth')]
    if len(unlearned_ckpts) != 512:
        raise RuntimeError('Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')

    # shell=True is required here so the shell expands the *.pth glob.
    subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True)
