# MNIST EXAMPLE

# Installing libraries

In [None]:
!pip install shap --upgrade

In [None]:
import os
import random
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from copy import deepcopy
from math import sqrt
from torch.utils.data import Subset

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Running on device:", DEVICE.upper())

# Seeded generator passed to random_split / DataLoader so that dataset
# partitioning (and retain-set shuffling) is reproducible across runs.
RNG = torch.Generator()
RNG.manual_seed(42)

# Download the dataset

In [None]:
# Download MNIST and normalise it with mean 0.1307 / std 0.3081.
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)


def _mnist_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with batch size 128 and 2 worker processes."""
    return DataLoader(dataset, batch_size=128, shuffle=shuffle, num_workers=2)


train_set = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=normalize)
train_loader = _mnist_loader(train_set, shuffle=True)

# The official test split is halved (seeded by RNG) into a test and a validation set.
held_out = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=normalize)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = _mnist_loader(test_set, shuffle=False)
val_loader = _mnist_loader(val_set, shuffle=False)

# Training original model

In [None]:
# Train the original (pre-unlearning) classifier: a ResNet-18 trained from scratch on MNIST.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network once; `weights=None` is the non-deprecated spelling of pretrained=False.
model = resnet18(weights=None)
# MNIST images have a single channel, so replace the 3-channel stem convolution.
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Adapt the classifier head to the 10 MNIST classes.
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

# Loss function and optimiser.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop.
num_epochs = 5
train_losses = []
test_accuracies = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoka [{epoch+1}/{num_epochs}], Strata: {avg_loss:.4f}")

    # Evaluate on the test split after every epoch (monitoring only).
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f"Dokładność na zbiorze testowym: {accuracy:.2f}%")

# Retrain function and accuracy

In [None]:
def full_retrain(train_loader, num_epochs=5, lr=0.001):
    """Train a fresh single-channel ResNet-18 (10 classes) from scratch on ``train_loader``.

    The notebook uses this to retrain on the retain set, producing the
    reference model that the unlearning methods are compared against.

    Args:
        train_loader: iterable of ``(images, labels)`` batches.
        num_epochs: number of passes over ``train_loader`` (default 5).
        lr: Adam learning rate (default 0.001).

    Returns:
        The trained model, switched to eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = resnet18(weights=None)
    # MNIST images have a single channel, so replace the 3-channel stem convolution.
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # max(..., 1) guards against an empty loader.
        avg_loss = running_loss / max(len(train_loader), 1)
        print(f"Epoka [{epoch+1}/{num_epochs}], Strata: {avg_loss:.4f}")

    model.eval()
    return model

In [None]:
def accuracy(net, loader, device=None):
    """Return the fraction of samples in ``loader`` that ``net`` classifies correctly.

    The pass runs under ``torch.no_grad()`` with ``net`` in eval mode, so no
    autograd graph is built and BatchNorm running statistics / Dropout are
    not affected. The model's previous train/eval mode is restored afterwards.

    Args:
        net: classifier returning logits with classes on dim 1.
        loader: iterable of ``(inputs, targets)`` batches.
        device: where to move each batch; defaults to the module-level ``DEVICE``.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = DEVICE
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
    finally:
        net.train(was_training)
    if total == 0:
        raise ValueError("accuracy() received an empty loader")
    return correct / total


# Report the trained model's accuracy on the training and test splits.
train_acc_pct = 100.0 * accuracy(model, train_loader)
print(f"Train set accuracy: {train_acc_pct:0.1f}%")
test_acc_pct = 100.0 * accuracy(model, test_loader)
print(f"Test set accuracy: {test_acc_pct:0.1f}%")

# Forget-set split

In [None]:
def data_to_forget(train, class_indices, label, size, *, batch_size=128, generator=None):
    """Split ``train`` into a forget set and a retain set by sampling indices per class.

    Args:
        train: dataset that the indices in ``class_indices`` refer to.
        class_indices: list where ``class_indices[c]`` holds the dataset indices
            with label ``c``. Updated in place: sampled indices are removed.
        label: class to draw the forget set from, or ``-1`` to spread ``size``
            samples as evenly as possible over all classes.
        size: total number of samples to forget.
        batch_size: batch size of both returned loaders.
        generator: ``torch.Generator`` used to shuffle the retain loader;
            ``None`` uses the module-level ``RNG``.

    Returns:
        ``(forget_loader, retain_loader, class_indices)``.

    Raises:
        ValueError: (from ``random.sample``) if a class has fewer samples than requested.
    """
    if generator is None:
        generator = RNG

    def _take(cls, count):
        # Draw `count` indices from class `cls` and drop them from its list,
        # keeping the remaining indices in their original order.
        picked = random.sample(class_indices[cls], count)
        picked_set = set(picked)
        class_indices[cls] = [i for i in class_indices[cls] if i not in picked_set]
        return picked

    if label == -1:
        num_classes = len(class_indices)
        per_class, remainder = divmod(size, num_classes)
        sampled_class = []
        for cls in range(num_classes):
            # The first `remainder` classes contribute one extra sample so the total equals `size`.
            sampled_class.extend(_take(cls, per_class + (1 if cls < remainder else 0)))
    else:
        sampled_class = _take(label, size)

    flat_class_indices = [idx for sublist in class_indices for idx in sublist]
    forget_set = torch.utils.data.Subset(train, sampled_class)
    retain_set = torch.utils.data.Subset(train, flat_class_indices)

    forget_loader = torch.utils.data.DataLoader(
        forget_set, batch_size=batch_size, shuffle=True, num_workers=2
    )
    retain_loader = torch.utils.data.DataLoader(
        retain_set, batch_size=batch_size, shuffle=True, num_workers=2, generator=generator
    )
    return forget_loader, retain_loader, class_indices

# Unlearning algorithms
Source: https://www.kaggle.com/competitions/neurips-2023-machine-unlearning

# Fauchan - 1st place



In [None]:
def evaluation(net, dataloader, criterion, device='cuda'):
    """Print and return the mean loss and accuracy of ``net`` over ``dataloader``.

    Batches may be ``(images, labels)`` pairs, as the loaders in this notebook
    yield, or dicts with ``'image'`` and ``'age_group'`` keys. ``net`` is left
    in eval mode.

    Args:
        net: classifier returning logits with classes on dim 1.
        dataloader: iterable of batches in either format above.
        criterion: loss taking ``(logits, labels)``; averaged over batches.
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, mean_acc)``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    net.eval()
    total_samp = 0
    total_acc = 0.0
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for sample in dataloader:
            if isinstance(sample, dict):
                images, labels = sample['image'], sample['age_group']
            else:
                images, labels = sample[0], sample[1]
            images, labels = images.to(device), labels.to(device)
            logits = net(images)
            total_samp += len(labels)
            total_loss += criterion(logits, labels).item()
            total_acc += (logits.max(1)[1] == labels).float().sum().item()
            num_batches += 1
    if total_samp == 0:
        raise ValueError("evaluation() received an empty dataloader")
    mean_loss = total_loss / num_batches
    mean_acc = total_acc / total_samp
    print(f'loss={mean_loss}')
    print(f'acc={mean_acc}')
    return mean_loss, mean_acc

In [None]:
USE_MOCK: bool = False

from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR


def kl_loss_sym(x, y):
    """Return KL(y || softmax(x)), averaged over the batch.

    ``x`` holds logits and ``y`` holds target probabilities (not log-probabilities).
    Despite the name, this loss is not symmetric.
    """
    log_probs = F.log_softmax(x, dim=-1)
    return F.kl_div(log_probs, y, reduction='batchmean')
def fauchan(
        net,
        retain_loader,
        forget_loader,
        val_loader,
        epochs=8,
        retain_bs=256,
        temperature=1.15,
):
    """Two-stage unlearning recipe (1st place).

    Stage 1: one pass over ``forget_loader`` that pulls the model's predictions
    towards the uniform distribution (KL to a uniform pseudo-label).

    Stage 2: for ``epochs`` rounds, alternate
      * a forget round: for zipped forget/retain batches, minimise the mean
        negative log-softmax of ``forget_logits @ retain_logits.T / temperature``
        (retain logits detached). The original authors call this a
        contrastive-learning loss;
      * a retain round: plain cross-entropy over the whole retain set.

    Args:
        net: model to unlearn; modified in place.
        retain_loader: loader whose ``.dataset`` is the retain set (re-batched here).
        forget_loader: ``(inputs, labels)`` batches of the forget set.
        val_loader: unused; kept for a uniform interface with the other methods.
        epochs: number of stage-2 rounds (default 8).
        retain_bs: batch size of the retain round (default 256).
        temperature: logit-similarity temperature of the forget round (default 1.15).

    Returns:
        ``net`` in eval mode, consistent with the other unlearning methods here.
    """
    print('-----------------------------------')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.005,
                          momentum=0.9, weight_decay=0)
    # The retain learning rate scales linearly with the retain batch size (0.001 per 64 samples).
    optimizer_retain = optim.SGD(net.parameters(), lr=0.001 * retain_bs / 64, momentum=0.9, weight_decay=1e-2)
    optimizer_forget = optim.SGD(net.parameters(), lr=3e-4, momentum=0.9, weight_decay=0)
    total_step = int(len(forget_loader) * epochs)
    retain_ld = DataLoader(retain_loader.dataset, batch_size=retain_bs, shuffle=True)
    retain_ld4fgt = DataLoader(retain_loader.dataset, batch_size=256, shuffle=True)
    # Stepped once per forget batch, i.e. total_step times over stage 2.
    scheduler = CosineAnnealingLR(optimizer_forget, T_max=total_step, eta_min=1e-6)

    # Stage 1: push predictions on the forget set towards the uniform distribution.
    net.train()
    for inputs, _ in forget_loader:
        inputs = inputs.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(inputs)
        uniform_label = torch.ones_like(outputs) / outputs.shape[1]
        loss = kl_loss_sym(outputs, uniform_label)
        loss.backward()
        optimizer.step()

    # Stage 2: alternate forget and retain rounds.
    for _ in range(epochs):
        net.train()
        # zip stops at the shorter of the two loaders.
        for (inputs_forget, _), (inputs_retain, _) in zip(forget_loader, retain_ld4fgt):
            inputs_forget, inputs_retain = inputs_forget.to(DEVICE), inputs_retain.to(DEVICE)
            optimizer_forget.zero_grad()
            logits_forget = net(inputs_forget)
            logits_retain = net(inputs_retain).detach()
            similarity = logits_forget @ logits_retain.T / temperature
            loss = (-1 * nn.LogSoftmax(dim=-1)(similarity)).mean()
            loss.backward()
            optimizer_forget.step()
            scheduler.step()
        for inputs, labels in retain_ld:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer_retain.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer_retain.step()
    print('-----------------------------------')
    net.eval()
    return net

# Kookmin - 2nd place

In [None]:
# Hyper-parameters read by kookmin().
sch = "linear"                    # LR schedule: 'cosine', 'linear' or 'decrease'
init_rate = 0.3                   # fraction of each conv layer's weights to re-initialise
init_method = "snip_little_grad"  # the only re-initialisation method kookmin() supports
lr = 1e-3                         # SGD learning rate
epoch = 5                         # number of fine-tuning epochs
weight_decay = 5e-4               # SGD weight decay

In [None]:
from torch.optim.lr_scheduler import _LRScheduler

class LinearAnnealingLR(_LRScheduler):
    """Linear warm-up followed by linear decay of every param group's learning rate.

    With ``s = self._step_count`` (maintained by the PyTorch base class; it
    counts ``step()`` calls, including the initial one made by the constructor):

    * ``s <= num_annealing_steps``: ``lr = base_lr * s / num_annealing_steps``
    * otherwise: ``lr = base_lr * (num_total_steps - s) / (num_total_steps - num_annealing_steps)``,
      clamped at 0 so the learning rate never becomes negative (gradient
      ascent) once ``num_total_steps`` has been passed.

    Raises:
        ValueError: if ``num_annealing_steps < 1``.
    """

    def __init__(self, optimizer, num_annealing_steps, num_total_steps):
        if num_annealing_steps < 1:
            raise ValueError("num_annealing_steps must be >= 1")
        # Set before super().__init__: the base constructor performs an initial step, which reads them.
        self.num_annealing_steps = num_annealing_steps
        self.num_total_steps = num_total_steps

        super().__init__(optimizer)

    def get_lr(self):
        step = self._step_count
        if step <= self.num_annealing_steps:
            return [base_lr * step / self.num_annealing_steps for base_lr in self.base_lrs]
        decay_span = self.num_total_steps - self.num_annealing_steps
        remaining = self.num_total_steps - step
        if decay_span <= 0 or remaining <= 0:
            return [0.0 for _ in self.base_lrs]
        return [base_lr * remaining / decay_span for base_lr in self.base_lrs]

In [None]:
class Masker(torch.autograd.Function):
    """Identity in the forward pass; scales the incoming gradient by ``mask`` in the backward pass."""

    @staticmethod
    def forward(ctx, x, mask):
        # Only the mask is needed to compute the gradient.
        ctx.save_for_backward(mask)
        return x

    @staticmethod
    def backward(ctx, grad):
        (saved_mask,) = ctx.saved_tensors
        masked_grad = grad.mul(saved_mask)
        # The mask itself receives no gradient.
        return masked_grad, None


class MaskConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight gradient is multiplied element-wise by a fixed ``mask``.

    The forward output equals that of a plain convolution; only backpropagation
    into ``self.weight`` is affected (through ``Masker``). ``mask`` is stored as a
    plain attribute, not a registered buffer.
    """

    def __init__(self, mask, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device='cpu'):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode, device=device)
        self.mask = mask

    def forward(self, input):
        weight_with_masked_grad = Masker.apply(self.weight, self.mask)
        return super()._conv_forward(input, weight_with_masked_grad, self.bias)

@torch.no_grad()
def re_init_model_random_zero_grad_Maskconv(model, px):
    """Replace every Conv2d in ``model`` with a MaskConv2d in which a random ``px`` fraction of weights is re-initialised.

    The chosen weights keep the new layer's default initialisation and are the
    only ones that receive gradient afterwards (the boolean mask zeroes the
    gradient of all other weights). The remaining weights and the bias are
    copied from the original layer.

    Args:
        model: network modified in place.
        px: fraction (0..1) of each conv layer's weights to re-initialise.
    """
    print("Apply Unstructured random re init no grads Globally (all conv layers)")
    for name, m in list(model.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        mask = torch.zeros_like(m.weight, device=DEVICE).bool()
        nparams_toprune = round(px * mask.nelement())

        # Top-k of random scores selects a uniformly random subset of exactly nparams_toprune positions.
        prob = torch.rand_like(mask.float(), device=DEVICE)
        topk = torch.topk(prob.view(-1), k=nparams_toprune)
        mask.view(-1)[topk.indices] = True

        new_conv = MaskConv2d(mask, m.in_channels, m.out_channels, m.kernel_size, m.stride,
                              m.padding, m.dilation, m.groups, m.bias is not None, m.padding_mode,
                              device=DEVICE)
        new_conv.weight.data[~mask] = m.weight[~mask]
        if m.bias is not None:
            new_conv.bias.data.copy_(m.bias)
        # `name` is a dotted path (e.g. "layer1.0.conv1"); setattr on the root would only
        # add a new, unused attribute for nested layers instead of replacing them.
        set_layer(model, name, new_conv)

@torch.no_grad()
def re_init_model_random_little_grad_Maskconv_fanout(model, px):
    """Replace every Conv2d with a MaskConv2d in which a random ``px`` fraction of weights is Kaiming-re-initialised.

    Re-initialised weights use ``kaiming_normal_(mode="fan_out", nonlinearity="relu")``
    and receive full gradient later; kept weights are copied from the original
    layer and receive 0.1x gradient. The bias is copied from the original layer.

    Args:
        model: network modified in place.
        px: fraction (0..1) of each conv layer's weights to re-initialise.
    """
    print("Apply Unstructured re_init_model_random_little_grad_Maskconv_fanout (all conv layers)")
    for name, m in list(model.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        mask = torch.zeros_like(m.weight, device=DEVICE).bool()
        nparams_toprune = round(px * mask.nelement())

        # Top-k of random scores selects a uniformly random subset of exactly nparams_toprune positions.
        prob = torch.rand_like(mask.float(), device=DEVICE)
        topk = torch.topk(prob.view(-1), k=nparams_toprune)
        mask.view(-1)[topk.indices] = True
        # Gradient scale used by MaskConv2d: 1.0 for re-initialised weights, 0.1 for kept ones.
        grad_mask = mask.clone().float()
        grad_mask[grad_mask == 0] += 0.1

        new_conv = MaskConv2d(grad_mask, m.in_channels, m.out_channels, m.kernel_size, m.stride,
                              m.padding, m.dilation, m.groups, m.bias is not None, m.padding_mode,
                              device=DEVICE)
        nn.init.kaiming_normal_(new_conv.weight, mode="fan_out", nonlinearity="relu")
        new_conv.weight.data[~mask] = m.weight[~mask]
        if m.bias is not None:
            new_conv.bias.data.copy_(m.bias)
        # `name` is a dotted path (e.g. "layer1.0.conv1"); setattr on the root would only
        # add a new, unused attribute for nested layers instead of replacing them.
        set_layer(model, name, new_conv)

In [None]:
@torch.no_grad()
def re_init_model_snip_ver2_little_grad(model, px): # re init smallest gradients
    """Re-initialise the ``px`` fraction of each conv layer's weights that have the smallest |gradient|.

    ``.grad`` must already be populated on every conv weight (kookmin() calls
    ``get_grads_for_snip`` first). Each Conv2d is replaced by a MaskConv2d:
    re-initialised weights (Kaiming-normal, fan_out) receive full gradient
    later, while kept weights, copied from the original layer, receive 0.1x
    gradient. The bias is copied from the original layer.

    Args:
        model: network modified in place.
        px: fraction (0..1) of each conv layer's weights to re-initialise.

    Raises:
        RuntimeError: if a conv layer has no gradient.
    """
    print("Apply Unstructured re_init_model_snip_ver2_little_grad Globally (all conv layers)")
    for name, m in list(model.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        if m.weight.grad is None:
            raise RuntimeError(
                f"conv layer '{name}' has no gradient; run a backward pass "
                "(e.g. get_grads_for_snip) before re-initialising"
            )
        mask = torch.zeros_like(m.weight, device=DEVICE).bool()
        nparams_toprune = round(px * mask.nelement())

        # Negating |grad| makes top-k select the positions with the smallest gradient magnitude.
        value = -m.weight.grad.abs()
        topk = torch.topk(value.view(-1), k=nparams_toprune)
        mask.view(-1)[topk.indices] = True
        # Gradient scale used by MaskConv2d: 1.0 for re-initialised weights, 0.1 for kept ones.
        grad_mask = mask.clone().float()
        grad_mask[grad_mask == 0] += 0.1

        new_conv = MaskConv2d(grad_mask, m.in_channels, m.out_channels, m.kernel_size, m.stride,
                              m.padding, m.dilation, m.groups, m.bias is not None, m.padding_mode,
                              device=DEVICE)
        nn.init.kaiming_normal_(new_conv.weight, mode="fan_out", nonlinearity="relu")
        new_conv.weight.data[~mask] = m.weight[~mask]
        if m.bias is not None:
            new_conv.bias.data.copy_(m.bias)

        set_layer(model, name, new_conv)

def set_layer(model, layer_name, layer):
    """Replace the submodule at dotted path ``layer_name`` with ``layer``.

    ``layer_name`` is a name as produced by ``model.named_modules()`` (e.g.
    ``"conv1"``, ``"layer1.0.conv1"``, ``"layer2.0.downsample.0"``), of any depth.
    Numeric components address children of ``nn.Sequential``/``nn.ModuleList``,
    which are registered under their string index, so plain ``getattr``/``setattr``
    handle them.

    Raises:
        AttributeError: if an intermediate component does not exist.
    """
    *parent_path, leaf = layer_name.split('.')
    parent = model
    for part in parent_path:
        parent = getattr(parent, part)
    setattr(parent, leaf, layer)

@torch.no_grad()
def replace_maskconv(model):
    print("Remove Maskconv")
    for name, m in list(model.named_modules()):
        if isinstance(m, MaskConv2d):
            conv = nn.Conv2d(m.in_channels, m.out_channels, m.kernel_size, m.stride,
                 m.padding, m.dilation, m.groups, m.bias!=None, m.padding_mode, device=DEVICE)
            conv.weight.data = m.weight
            conv.bias = m.bias
            set_layer(model, name, conv)

In [None]:
@torch.no_grad()
def re_init_model_snip_ver2(model, px): # re init smallest gradients
    """Re-initialise in place the ``px`` fraction of each conv layer's weights with the smallest |gradient|.

    Replacement values are taken from a freshly constructed ``nn.Conv2d`` with
    the same weight shape (i.e. PyTorch's default conv initialisation). The
    layer is modified in place, not replaced.

    Args:
        model: network whose conv weights already have ``.grad`` populated.
        px: fraction (0..1) of each conv layer's weights to re-initialise.

    Raises:
        RuntimeError: if a conv layer has no gradient.
    """
    print("Apply Unstructured snip ve2 re init no grads Globally (all conv layers)")
    for name, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        if weight.grad is None:
            raise RuntimeError(
                f"conv layer '{name}' has no gradient; run a backward pass before re-initialising"
            )
        # The mask lives on the weight's own device, so indexing works wherever the model is.
        mask = torch.zeros(weight.shape, dtype=torch.bool, device=weight.device)
        nparams_toprune = round(px * mask.nelement())

        # Negating |grad| makes top-k select the positions with the smallest gradient magnitude.
        value = -weight.grad.abs()
        topk = torch.topk(value.reshape(-1), k=nparams_toprune)
        mask.view(-1)[topk.indices] = True

        out_c, in_c = weight.shape[:2]
        # m.kernel_size (instead of one weight dimension) keeps non-square kernels working.
        fresh = nn.Conv2d(in_c, out_c, m.kernel_size, device=weight.device, dtype=weight.dtype)
        weight.data[mask] = fresh.weight[mask]

def get_grads_for_snip(model, retain_loader, forget_loader):
    """Accumulate saliency gradients into the ``.grad`` fields of ``model``'s parameters.

    After the call each gradient equals the sum of the cross-entropy gradients
    over a random retain subset with as many samples as the forget set, minus
    the sum of the cross-entropy gradients over the forget set. The subset is
    drawn with ``torch.randperm``, so it depends on the global torch RNG.

    Batches are moved to the device of ``model``'s parameters.

    Args:
        model: network whose gradients are zeroed and then accumulated.
        retain_loader: loader over the retain set; its ``.dataset`` is subsampled.
        forget_loader: loader of ``(inputs, targets)`` forget batches.
    """
    device = next(model.parameters()).device
    n_forget = len(forget_loader.dataset)
    indices = torch.randperm(len(retain_loader.dataset), dtype=torch.int32, device='cpu')[:n_forget]
    # Match the retain sample count to the forget set so neither term dominates the sum.
    retain_subset = torch.utils.data.Subset(retain_loader.dataset, indices.tolist())
    retain_subset_loader = torch.utils.data.DataLoader(
        retain_subset, batch_size=retain_loader.batch_size or 128, shuffle=False
    )

    model.zero_grad()
    for inputs, targets in retain_subset_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()

    for inputs, targets in forget_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        # Negated: weights whose gradients mainly serve the forget set stand out.
        loss = -F.cross_entropy(outputs, targets)
        loss.backward()

In [None]:
def kookmin(
    net,
    retain_loader,
    forget_loader,
    val_loader):
    """Re-initialise low-saliency conv weights, then fine-tune on the retain set (2nd place).

    Configured by the module-level hyper-parameters ``init_method``,
    ``init_rate``, ``sch``, ``epoch``, ``lr`` and ``weight_decay``:

    1. Convert any MaskConv2d back to Conv2d, accumulate saliency gradients
       with ``get_grads_for_snip`` (retain term minus forget term), then
       re-initialise the ``init_rate`` fraction of each conv layer's weights
       with the smallest |gradient|.
    2. Fine-tune with SGD (momentum 0.9) on ``retain_loader`` for ``epoch``
       epochs, stepping the chosen LR scheduler once per epoch.

    ``val_loader`` is unused; it is kept for a uniform interface.

    Returns:
        ``net`` (modified in place) in eval mode.

    Raises:
        NotImplementedError: if ``init_method`` is not ``'snip_little_grad'``.
        ValueError: if ``sch`` is not ``'cosine'``, ``'linear'`` or ``'decrease'``.
    """
    if init_method != 'snip_little_grad':
        raise NotImplementedError(f"init_method {init_method!r} is not implemented")
    # Validate before touching the network so a bad config does not leave it half-modified.
    if sch not in ('cosine', 'linear', 'decrease'):
        raise ValueError(f"unknown scheduler {sch!r}; expected 'cosine', 'linear' or 'decrease'")

    replace_maskconv(net)
    get_grads_for_snip(net, retain_loader, forget_loader)
    re_init_model_snip_ver2_little_grad(net, init_rate)

    epochs = epoch
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr,
                          momentum=0.9, weight_decay=weight_decay)
    if sch == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    elif sch == 'linear':
        scheduler = LinearAnnealingLR(optimizer, num_annealing_steps=(epochs + 1) // 2, num_total_steps=epochs + 1)
    else:  # 'decrease'
        print('decrease')
        scheduler = LinearAnnealingLR(optimizer, num_annealing_steps=1, num_total_steps=epochs + 1)

    net.train()

    for _ in range(epochs):
        for inputs, targets in retain_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

        scheduler.step()
    net.eval()
    return net

# Seif - 3rd place

In [None]:
def seif(
    net,
    retain_loader,
    forget_loader,
    val_loader):
    """Weight-noise plus class-weighted fine-tuning on the retain set (3rd place).

    1. Add Gaussian noise (std 0.6) to the weights of ``net``'s direct
       children whose name contains ``'conv'``.
    2. Fine-tune for 4 epochs on ``retain_loader`` with SGD and a cosine LR
       schedule, using a per-sample weighted cross-entropy (class 0 has
       weight 1, every other class 0.05).
    3. Before the last epoch, add small noise (std 0.005) again.

    ``forget_loader`` and ``val_loader`` are unused; they are kept for a
    uniform interface.

    Returns:
        ``net`` (modified in place) in eval mode.
    """

    class CustomCrossEntropyLoss(nn.Module):
        """Cross-entropy in which every sample is weighted by the weight of its target class."""

        def __init__(self, class_weights=None):
            super().__init__()
            self.class_weights = class_weights

        def forward(self, input, target):
            if self.class_weights is None:
                return nn.functional.cross_entropy(input, target)
            # Weight per-sample losses *before* averaging; averaging first would
            # reduce the class weights to one constant factor for the whole batch.
            per_sample = nn.functional.cross_entropy(input, target, reduction='none')
            weights = torch.as_tensor(self.class_weights, dtype=per_sample.dtype, device=input.device)[target]
            return torch.mean(per_sample * weights)

    def vision_confuser(model, std=0.6):
        """Add N(0, std^2) noise to the weights of direct children named like '*conv*'.

        Only top-level children are visited (``named_children``); for
        torchvision's ResNet-18 that is presumably only the stem ``conv1``.
        """
        for name, module in model.named_children():
            if hasattr(module, 'weight') and 'conv' in name:
                actual_value = module.weight.clone().detach()
                new_values = torch.normal(mean=actual_value, std=std)
                module.weight.data.copy_(new_values)

    vision_confuser(net)

    epochs = 4
    w = 0.05
    # NOTE(review): these weights favour class 0 (digit 0 on MNIST); confirm this is intended.
    class_weights = [1, w, w, w, w, w, w, w, w, w]
    criterion = CustomCrossEntropyLoss(class_weights)

    optimizer = optim.SGD(net.parameters(), lr=0.0007,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs)

    for ep in range(epochs):
        net.train()
        for inputs, targets in retain_loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        if ep == epochs - 2:
            vision_confuser(net, 0.005)  # increase model robustness before the last training epoch

        scheduler.step()

    net.eval()
    return net

# Sebastian - 4th place

In [None]:
def kl_loss_fn(outputs, dist_target):
    """Return KL(target || softmax(outputs)) averaged over the batch.

    ``dist_target`` must hold log-probabilities (``log_target=True``).
    """
    student_log_probs = torch.log_softmax(outputs, dim=1)
    return F.kl_div(student_log_probs, dist_target, reduction='batchmean', log_target=True)

def entropy_loss_fn(outputs, labels, dist_target, class_weights):
    """Return class-weighted cross-entropy plus an entropy-matching penalty.

    The penalty is the MSE between each sample's prediction entropy and the
    entropy of the target distribution ``dist_target``, which is given as
    log-probabilities. ``class_weights`` may be ``None``.
    """
    target_entropy = (-torch.exp(dist_target) * dist_target).sum(dim=1)
    output_log_probs = torch.log_softmax(outputs, dim=1)
    output_entropy = (-torch.softmax(outputs, dim=1) * output_log_probs).sum(dim=1)
    weighted_ce = F.cross_entropy(outputs, labels, weight=class_weights)
    return weighted_ce + F.mse_loss(output_entropy, target_entropy)

In [None]:
def sebastian(
    net,
    retain_loader,
    forget_loader,
    val_loader,
    class_weights=None,
):
    """Heavy pruning with re-initialisation, then entropy-matched fine-tuning (4th place).

    1. Globally prune 99% of all Conv2d/Linear weights and biases by L1
       magnitude, make the pruning permanent, and re-initialise the zeroed
       weights uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
    2. Fine-tune on ``retain_loader`` for exactly ``int(3.2 * len(retain_loader))``
       iterations with ``entropy_loss_fn``: (weighted) cross-entropy plus an MSE
       term matching the prediction entropy of a frozen copy of the original,
       unpruned network.

    Args:
        net: model to unlearn; modified in place.
        retain_loader: ``(inputs, targets)`` batches of the retain set.
        forget_loader, val_loader: unused; kept for a uniform interface.
        class_weights: optional per-class CE weights passed to ``entropy_loss_fn``.

    Returns:
        ``net`` in eval mode.
    """
    epochs = 3.2
    max_iters = int(len(retain_loader) * epochs)
    optimizer = optim.SGD(net.parameters(), lr=0.0005,
                          momentum=0.9, weight_decay=5e-4)
    # Frozen copy of the unpruned network; its predictions define the entropy targets.
    initial_net = deepcopy(net)

    net.train()
    initial_net.eval()

    def prune_model(model, amount=0.95, rand_init=True):
        """Globally L1-prune conv/linear parameters of ``model`` in place; optionally re-init zeroed weights."""
        layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
        params_to_prune = []
        for m in layers:
            params_to_prune.append((m, 'weight'))
            if m.bias is not None:
                params_to_prune.append((m, 'bias'))

        prune.global_unstructured(
            params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=amount,
        )

        # Make the pruning permanent (drop the mask re-parametrisation).
        for m in layers:
            prune.remove(m, 'weight')
            if m.bias is not None:
                prune.remove(m, 'bias')

        if rand_init:
            for m in layers:
                mask = m.weight == 0
                if isinstance(m, nn.Conv2d):
                    fan_in = mask.shape[1] * mask.shape[2] * mask.shape[3]
                else:
                    fan_in = mask.shape[1]
                randinit = (torch.rand_like(m.weight) - 0.5) * 2 * sqrt(1 / fan_in)
                m.weight.data[mask] = randinit[mask]

    prune_amount = 0.99
    prune_model(net, prune_amount, True)

    num_iters = 0
    while num_iters < max_iters:
        net.train()
        made_progress = False
        for inputs, targets in retain_loader:
            made_progress = True
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # Target distribution from the original network.
            with torch.no_grad():
                preds = torch.log_softmax(initial_net(inputs), dim=1)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = entropy_loss_fn(outputs, targets, preds, class_weights)
            loss.backward()
            optimizer.step()

            num_iters += 1
            # Stop after exactly max_iters optimisation steps.
            if num_iters >= max_iters:
                break
        if not made_progress:
            break  # the loader yielded nothing; avoid spinning forever

    net.eval()
    return net

# Amnesiacs - 6th place

In [None]:
def kl_loss(model_logits, teacher_logits, temperature=1.):
    '''
    Return the batch-averaged KL divergence KL(teacher || student) between temperature-scaled softmax outputs.

    Args:
        model_logits (torch.Tensor): raw (pre-softmax) outputs of the student model; classes on dim 1.
        teacher_logits (torch.Tensor): raw outputs of the teacher model, same shape.
        temperature (float, optional): both logit tensors are divided by this before the softmax;
            larger values give softer distributions. Defaults to 1.

    Returns:
        torch.Tensor: scalar KL divergence with ``reduction='batchmean'``.
    '''
    scaled_student = model_logits / temperature
    scaled_teacher = teacher_logits / temperature
    return F.kl_div(
        F.log_softmax(scaled_student, dim=1),
        F.softmax(scaled_teacher, dim=1),
        reduction='batchmean',
    )

In [None]:
def amnesiacs(
    net,
    retain_loader,
    forget_loader,
    val_loader,
    class_weights=None,
):
    """Heavy pruning with re-initialisation, then entropy-matched fine-tuning.

    NOTE(review): this body is the same algorithm as ``sebastian()``, and the
    ``kl_loss`` helper defined in this section is not used anywhere in the
    visible code. Verify it against the Amnesiacs (6th place) write-up.

    1. Globally prune 99% of all Conv2d/Linear weights and biases by L1
       magnitude, make the pruning permanent, and re-initialise the zeroed
       weights uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
    2. Fine-tune on ``retain_loader`` for exactly ``int(3.2 * len(retain_loader))``
       iterations with ``entropy_loss_fn`` against a frozen copy of the
       original, unpruned network.

    Args:
        net: model to unlearn; modified in place.
        retain_loader: ``(inputs, targets)`` batches of the retain set.
        forget_loader, val_loader: unused; kept for a uniform interface.
        class_weights: optional per-class CE weights passed to ``entropy_loss_fn``.

    Returns:
        ``net`` in eval mode.
    """
    epochs = 3.2
    max_iters = int(len(retain_loader) * epochs)
    optimizer = optim.SGD(net.parameters(), lr=0.0005,
                          momentum=0.9, weight_decay=5e-4)
    # Frozen copy of the unpruned network; its predictions define the entropy targets.
    initial_net = deepcopy(net)

    net.train()
    initial_net.eval()

    def prune_model(model, amount=0.95, rand_init=True):
        """Globally L1-prune conv/linear parameters of ``model`` in place; optionally re-init zeroed weights."""
        layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
        params_to_prune = []
        for m in layers:
            params_to_prune.append((m, 'weight'))
            if m.bias is not None:
                params_to_prune.append((m, 'bias'))

        prune.global_unstructured(
            params_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=amount,
        )

        # Make the pruning permanent (drop the mask re-parametrisation).
        for m in layers:
            prune.remove(m, 'weight')
            if m.bias is not None:
                prune.remove(m, 'bias')

        if rand_init:
            for m in layers:
                mask = m.weight == 0
                if isinstance(m, nn.Conv2d):
                    fan_in = mask.shape[1] * mask.shape[2] * mask.shape[3]
                else:
                    fan_in = mask.shape[1]
                randinit = (torch.rand_like(m.weight) - 0.5) * 2 * sqrt(1 / fan_in)
                m.weight.data[mask] = randinit[mask]

    prune_amount = 0.99
    prune_model(net, prune_amount, True)

    num_iters = 0
    while num_iters < max_iters:
        net.train()
        made_progress = False
        for inputs, targets in retain_loader:
            made_progress = True
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            # Target distribution from the original network.
            with torch.no_grad():
                preds = torch.log_softmax(initial_net(inputs), dim=1)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = entropy_loss_fn(outputs, targets, preds, class_weights)
            loss.backward()
            optimizer.step()

            num_iters += 1
            # Stop after exactly max_iters optimisation steps.
            if num_iters >= max_iters:
                break
        if not made_progress:
            break  # the loader yielded nothing; avoid spinning forever

    net.eval()
    return net

# Unlearning operation

In [None]:
import copy

# Every unlearning method mutates the network it receives, so each one gets
# its own copy of the trained model.
fau_model = copy.deepcopy(model)
koo_model = copy.deepcopy(model)
sei_model = copy.deepcopy(model)
seb_model = copy.deepcopy(model)
amn_model = copy.deepcopy(model)

# Group training-set indices by label. Reading the label tensor directly avoids
# decoding and normalising every image just to obtain its label.
class_indices = [[] for _ in range(10)]
for idx, label in enumerate(train_set.targets.tolist()):
    class_indices[label].append(idx)

print("Samples per class:", [len(indices) for indices in class_indices])

# data_to_forget samples the forget set with Python's `random`; seed it so the split is reproducible.
random.seed(42)
forget_loader, retain_loader, class_indices = data_to_forget(train_set, class_indices, -1, 6000)
print("Fauchan")
fau_model = fauchan(fau_model, retain_loader, forget_loader, val_loader)
print("Kookmin")
koo_model = kookmin(koo_model, retain_loader, forget_loader, val_loader)
print("Seif")
sei_model = seif(sei_model, retain_loader, forget_loader, val_loader)
print("Sebastian")
seb_model = sebastian(seb_model, retain_loader, forget_loader, val_loader)
print("Amnesiacs")
amn_model = amnesiacs(amn_model, retain_loader, forget_loader, val_loader)
print("Retain")
# Retrain-from-scratch reference on the retain set only.
rt_model = full_retrain(retain_loader)

# SHAP Evaluation

In [None]:
import shap
from collections import defaultdict
import torch

def calculate_shap(model, test_images):
    """Compute GradientExplainer SHAP values of ``model`` for ``test_images``.

    The same images serve both as background data and as the samples to
    explain. ``model`` is switched to eval mode first so that BatchNorm uses
    its running statistics and does not update them.

    Args:
        model: classifier with one output per class.
        test_images: tensor of images, assumed (N, C, H, W) — TODO confirm for other inputs.

    Returns:
        shap_values_images: list with one array per class, each in
            channels-last layout (N, H, W, C), convenient for plotting.
        shap_values: the explainer's raw output.
        shap_sum: array (N, num_classes) holding the total absolute attribution
            of each sample towards each class.
    """
    model.eval()
    images = test_images.to(DEVICE)

    explainer = shap.GradientExplainer(model, images)
    shap_values = explainer.shap_values(images)

    # Older shap releases return a list with one (N, C, H, W) array per output
    # class; newer ones return a single (N, C, H, W, num_classes) array.
    # Normalise to the latter so that both code paths below are correct.
    if isinstance(shap_values, list):
        per_class = np.stack(shap_values, axis=-1)
    else:
        per_class = np.asarray(shap_values)

    shap_values_images = [
        np.transpose(per_class[..., k], (0, 2, 3, 1)) for k in range(per_class.shape[-1])
    ]
    shap_sum = np.abs(per_class).sum(axis=(1, 2, 3))

    return shap_values_images, shap_values, shap_sum


def get_one_image_per_class(dataset):
    """Return the first image seen for each label, in order of first appearance.

    Scanning stops as soon as 10 distinct labels have been collected.
    """
    first_by_label = {}
    for image, label in dataset:
        first_by_label.setdefault(label, image)
        if len(first_by_label) == 10:
            break
    return list(first_by_label.values())

# Collect the index of the first test-set sample of each label, stopping once
# 10 distinct labels have been found, and wrap them in a single-batch loader.
# NOTE(review): neither `subset` nor `data_loader` is used later in the visible
# notebook, and this rebinds `class_indices` (earlier the per-class lists of
# training indices) to a dict {label: first index}.
class_indices = {}
subset_indices = []
for idx, (img, label) in enumerate(test_set):
    if label not in class_indices:
        class_indices[label] = idx
        subset_indices.append(idx)
    if len(subset_indices) == 10:
        break

subset = Subset(test_set, subset_indices)
data_loader = DataLoader(subset, batch_size=10, shuffle=False)

import torch

def extract_data(dataset):
    """Materialise ``dataset`` into a stacked image tensor and a label tensor, both on the CPU.

    Each item must be an ``(image_tensor, label)`` pair; images must share a shape.
    """
    pairs = [(img.cpu(), label) for img, label in dataset]
    images = torch.stack([img for img, _ in pairs]).cpu()
    labels = torch.tensor([label for _, label in pairs]).cpu()
    return images, labels

def get_balanced_subset(dataset, samples_per_class=100, rng=None):
    """Draw the same number of random examples from every class of ``dataset``.

    Args:
        dataset: indexable dataset yielding ``(image, label)`` pairs.
        samples_per_class: number of indices sampled (without replacement)
            from each class.
        rng: optional ``random.Random`` instance used for sampling; pass a
            seeded one to get a reproducible subset. Defaults to the global
            ``random`` module.

    Returns:
        ``torch.utils.data.Subset`` over the selected indices, grouped by class
        in order of each class's first appearance in ``dataset``.

    Raises:
        ValueError: if some class has fewer than ``samples_per_class`` examples.
    """
    sampler = random if rng is None else rng

    class_indices = defaultdict(list)
    # Iterating the dataset applies its transform to every image just to read
    # the label; acceptable for small splits such as the MNIST test half here.
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)

    selected_indices = []
    for label, indices in class_indices.items():
        if len(indices) < samples_per_class:
            raise ValueError(f"Klasa {label} ma tylko {len(indices)} próbek.")
        selected_indices.extend(sampler.sample(indices, samples_per_class))

    return Subset(dataset, selected_indices)

# Class-balanced evaluation set: 100 randomly chosen test images per label.
balanced_subset = get_balanced_subset(test_set, samples_per_class=100)

# Materialise the subset as (images, labels) CPU tensors.
test_images, test_labels = extract_data(balanced_subset)

print(len(test_images))


# SHAP values for every model on the same images; calculate_shap uses these
# images both as the explainer background and as the samples explained.
fau_shap_values_images, fau_shap_values, fau_shap_sum = calculate_shap(fau_model, test_images)
koo_shap_values_images, koo_shap_values, koo_shap_sum = calculate_shap(koo_model, test_images)
sei_shap_values_images, sei_shap_values, sei_shap_sum = calculate_shap(sei_model, test_images)
seb_shap_values_images, seb_shap_values, seb_shap_sum = calculate_shap(seb_model, test_images)
amn_shap_values_images, amn_shap_values, amn_shap_sum = calculate_shap(amn_model, test_images)
rt_shap_values_images, rt_shap_values, rt_shap_sum = calculate_shap(rt_model, test_images)

# Compute ϕ (share of SHAP attribution on the true class, relative to the retrained model)

In [None]:
# For every sample, pick the total |SHAP| attributed to the output of its true
# label; the shap_sum arrays are indexed as (sample, output).
sample_rows = np.arange(len(test_labels))
tmp_fau = fau_shap_sum[sample_rows, test_labels]
tmp_koo = koo_shap_sum[sample_rows, test_labels]
tmp_sei = sei_shap_sum[sample_rows, test_labels]
tmp_seb = seb_shap_sum[sample_rows, test_labels]
tmp_amn = amn_shap_sum[sample_rows, test_labels]
tmp_rt = rt_shap_sum[sample_rows, test_labels]

In [None]:
# ϕ ratio per unlearning method: the summed per-sample fraction of |SHAP| mass
# on the true class, divided by the same quantity for the retrained model.
rt_phi = np.sum(tmp_rt / np.sum(rt_shap_sum, axis=1))
for method_tmp, method_sum in (
    (tmp_fau, fau_shap_sum),
    (tmp_koo, koo_shap_sum),
    (tmp_sei, sei_shap_sum),
    (tmp_seb, seb_shap_sum),
    (tmp_amn, amn_shap_sum),
):
    print(np.sum(method_tmp / np.sum(method_sum, axis=1)) / rt_phi)

# Compute accuracy

In [None]:
# Evaluation loader over the class-balanced subset. Sample order does not
# affect accuracy (assuming `accuracy` evaluates the model in eval mode), so
# it is not shuffled — consistent with test_loader / val_loader.
balanced_loader = torch.utils.data.DataLoader(
    balanced_subset, batch_size=128, shuffle=False, num_workers=2
)

# Accuracy of each unlearned model, then of the retrained reference.
for evaluated_model in (fau_model, koo_model, sei_model, seb_model, amn_model, rt_model):
    print(accuracy(evaluated_model, balanced_loader))

In [None]:
# Accuracy of each unlearned model relative to the retrained reference.
# The reference accuracy is computed once instead of re-evaluating rt_model
# for every ratio.
rt_accuracy = accuracy(rt_model, balanced_loader)
for unlearned_model in (fau_model, koo_model, sei_model, seb_model, amn_model):
    print(accuracy(unlearned_model, balanced_loader) / rt_accuracy)

In [None]:
# Difference between each unlearned model's summed SHAP base values and the
# retrained model's, divided by 10 (presumably the number of MNIST classes,
# i.e. a per-class mean difference; TODO confirm).
# NOTE(review): the `*_base_values` names are not defined anywhere in the
# visible code — calculate_shap returns only (shap_values_images, shap_values,
# shap_sum) — so this cell raises NameError unless an earlier cell defines
# them. TODO decide where the base values should come from.
print((np.sum((fau_base_values)) - np.sum((rt_base_values))) /10)
print((np.sum((koo_base_values)) - np.sum((rt_base_values))) /10)
print((np.sum((sei_base_values)) - np.sum((rt_base_values))) /10)
print((np.sum((seb_base_values)) - np.sum((rt_base_values))) /10)
print((np.sum((amn_base_values)) - np.sum((rt_base_values))) /10)

# Compare SHAP plots on one test image per digit (first occurrence of each label, sorted 0–9)

In [None]:
import shap

def calculate_shap(model, data_loader):
    """Compute GradientExplainer SHAP values for the first batch of ``data_loader``.

    The batch serves both as the explainer's background set and as the
    samples being explained.

    Args:
        model: PyTorch model to explain; assumed to already live on ``DEVICE``.
        data_loader: loader yielding ``(images, labels)`` batches; only the
            first batch is used.

    Returns:
        ``(shap_values_images, shap_values, shap_sum)`` as produced by
        :func:`summarize_shap`.
    """
    images, _ = next(iter(data_loader))
    images = images.to(DEVICE)
    explainer = shap.GradientExplainer(model, images)
    return summarize_shap(explainer.shap_values(images))


def summarize_shap(shap_values):
    """Normalise raw SHAP output and derive plotting/aggregate views of it.

    Newer SHAP releases return one array of shape (N, C, H, W, K) with the
    model outputs on the last axis; older releases return a list of K arrays
    of shape (N, C, H, W). Both are converted to the 5-D layout so downstream
    code (true-class indexing of ``shap_sum``, per-output image lists) works
    with either version.

    Args:
        shap_values: raw value returned by ``explainer.shap_values``.

    Returns:
        A tuple ``(shap_values_images, shap_values, shap_sum)``:
        - ``shap_values_images``: list with one channels-last (N, H, W, C)
          array per model output, the layout ``shap.image_plot`` expects;
        - ``shap_values``: the (N, C, H, W, K) array;
        - ``shap_sum``: (N, K) total absolute attribution per sample/output.
    """
    if isinstance(shap_values, list):
        # Legacy format: one (N, C, H, W) array per output -> stack outputs last.
        shap_values = np.stack(shap_values, axis=-1)
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 4:
        # Single-output model: add the missing output axis.
        shap_values = shap_values[..., np.newaxis]

    # (N, C, H, W, K) -> K arrays of (N, H, W, C).
    shap_values_images = list(np.transpose(shap_values, (4, 0, 2, 3, 1)))
    # Sum |SHAP| over channel and spatial axes, keeping (sample, output).
    shap_sum = np.abs(shap_values).sum(axis=(1, 2, 3))
    return shap_values_images, shap_values, shap_sum


def get_one_image_per_class(dataset, num_classes=10):
    """Return the first image found for each class in ``dataset``.

    Iteration stops as soon as ``num_classes`` distinct labels have been seen,
    so only a prefix of the dataset is read.

    Args:
        dataset: iterable of ``(image, label)`` pairs.
        num_classes: number of distinct labels to collect (10 for MNIST).

    Returns:
        List of images ordered by the first appearance of their label in
        ``dataset`` (not sorted by label). Fewer than ``num_classes`` images
        are returned if the dataset is exhausted first.
    """
    if num_classes <= 0:
        return []
    images_per_class = {}
    for image, label in dataset:
        images_per_class.setdefault(label, image)
        if len(images_per_class) >= num_classes:
            break
    return list(images_per_class.values())

# Pick the first test-set example of every digit and order them by label, so
# the comparison plot shows one row per class in a fixed 0..9 order.
# A dedicated name is used (instead of rebinding `class_indices`) so the value
# returned by / passed to `data_to_forget` in the unlearning cell is not
# overwritten if that cell is re-run.
first_index_per_label = {}

for idx, (img, label) in enumerate(test_set):
    first_index_per_label.setdefault(label, idx)
    if len(first_index_per_label) == 10:
        break

# (label, index) pairs sorted by label; labels are unique, so this matches a
# sort on the label alone.
subset_indices = sorted(first_index_per_label.items())

sorted_indices = [idx for _, idx in subset_indices]

subset = Subset(test_set, sorted_indices)
# Single batch of all selected images, in label order (no shuffling).
data_loader = DataLoader(subset, batch_size=10, shuffle=False)


# Explain the single label-sorted batch from data_loader with every model,
# including `model`, the original network trained before unlearning.
fau_shap_values_images, fau_shap_values, fau_shap_sum = calculate_shap(fau_model, data_loader)
koo_shap_values_images, koo_shap_values, koo_shap_sum = calculate_shap(koo_model, data_loader)
sei_shap_values_images, sei_shap_values, sei_shap_sum = calculate_shap(sei_model, data_loader)
seb_shap_values_images, seb_shap_values, seb_shap_sum = calculate_shap(seb_model, data_loader)
amn_shap_values_images, amn_shap_values, amn_shap_sum = calculate_shap(amn_model, data_loader)
rt_shap_values_images, rt_shap_values, rt_shap_sum = calculate_shap(rt_model, data_loader)
org_shap_values_images, org_shap_values, org_shap_sum = calculate_shap(model, data_loader)

In [None]:
from IPython.display import display, HTML

# Pull the single batch of plotted test images and convert it to the
# channels-last NumPy layout used by shap.image_plot.
first_batch = next(iter(data_loader))
images = first_batch[0].to(DEVICE)
images_np = images.permute(0, 2, 3, 1).cpu().numpy()

def denormalize(img, mean=0.1307, std=0.3081):
    """Undo ``transforms.Normalize`` for display and clip to the [0, 1] range.

    The defaults match the MNIST normalisation applied when the datasets are
    loaded (``transforms.Normalize((0.1307,), (0.3081,))``); they must be
    kept in sync with it, otherwise plotted images are mis-scaled.

    Args:
        img: normalised image array (any shape).
        mean: mean used by the original normalisation.
        std: standard deviation used by the original normalisation.

    Returns:
        Array of the same shape with pixel intensities in [0, 1].
    """
    return np.clip(img * std + mean, 0, 1)

# Undo the input normalisation so the images can be displayed as pixels.
images_np = denormalize(images_np)
# Debug output: shape of the per-output image list returned by calculate_shap.
print(np.array(fau_shap_values_images).shape)

# Re-arrange the raw SHAP values into a list with one channels-last
# (N, H, W, C) array per model output, the layout shap.image_plot expects.
# NOTE(review): assumes each *_shap_values is a 5-D (N, C, H, W, K) array with
# the model outputs on the last axis (the layout newer SHAP releases return);
# TODO confirm for the installed shap version.
fau_shap_values_images_fix = list(np.transpose(fau_shap_values, (4, 0, 2, 3, 1)))
koo_shap_values_images_fix = list(np.transpose(koo_shap_values, (4, 0, 2, 3, 1)))
sei_shap_values_images_fix = list(np.transpose(sei_shap_values, (4, 0, 2, 3, 1)))
seb_shap_values_images_fix = list(np.transpose(seb_shap_values, (4, 0, 2, 3, 1)))
amn_shap_values_images_fix = list(np.transpose(amn_shap_values, (4, 0, 2, 3, 1)))
rt_shap_values_images_fix = list(np.transpose(rt_shap_values, (4, 0, 2, 3, 1)))
org_shap_values_images_fix = list(np.transpose(org_shap_values, (4, 0, 2, 3, 1)))

In [None]:
import shap
import matplotlib.pyplot as plt
import numpy as np

num_samples = len(images_np)

# For each plotted sample keep the SHAP map of the output for its true label.
# The labels are read from data_loader (shuffle=False, so the order matches
# images_np) rather than assuming that sample i has label i.
_, plot_labels = next(iter(data_loader))
target_classes = [int(lbl) for lbl in plot_labels[:num_samples]]

fau_selected_shap_values = [fau_shap_values_images_fix[c][i] for i, c in enumerate(target_classes)]
koo_selected_shap_values = [koo_shap_values_images_fix[c][i] for i, c in enumerate(target_classes)]
sei_selected_shap_values = [sei_shap_values_images_fix[c][i] for i, c in enumerate(target_classes)]
seb_selected_shap_values = [seb_shap_values_images_fix[c][i] for i, c in enumerate(target_classes)]
amn_selected_shap_values = [amn_shap_values_images_fix[c][i] for i, c in enumerate(target_classes)]
rt_selected_shap_values = [rt_shap_values_images_fix[c][i] for i, c in enumerate(target_classes)]
org_selected_shap_values = [org_shap_values_images_fix[c][i] for i, c in enumerate(target_classes)]

def ensure_channel_dim(shap_array):
    """Append a trailing singleton channel axis to 3-D arrays.

    Arrays of any other rank are returned unchanged (the same object).
    """
    return shap_array[..., np.newaxis] if shap_array.ndim == 3 else shap_array

# Stack each model's per-sample maps into (N, H, W, C) arrays.
fau_selected_shap_values = ensure_channel_dim(np.array(fau_selected_shap_values))
koo_selected_shap_values = ensure_channel_dim(np.array(koo_selected_shap_values))
sei_selected_shap_values = ensure_channel_dim(np.array(sei_selected_shap_values))
seb_selected_shap_values = ensure_channel_dim(np.array(seb_selected_shap_values))
amn_selected_shap_values = ensure_channel_dim(np.array(amn_selected_shap_values))
rt_selected_shap_values = ensure_channel_dim(np.array(rt_selected_shap_values))
org_selected_shap_values = ensure_channel_dim(np.array(org_selected_shap_values))

# One plot column per model, in the same order as `label` below.
selected_shap_values = [
    org_selected_shap_values,
    rt_selected_shap_values,
    fau_selected_shap_values,
    koo_selected_shap_values,
    sei_selected_shap_values,
    seb_selected_shap_values,
    amn_selected_shap_values,
]

print("Images shape:", images_np.shape)
for idx, shap_val in enumerate(selected_shap_values):
    print(f"SHAP values shape [{idx}]:", shap_val.shape)

label = ["Original", "Retrained", "Fauchan", "Kookmin", "Seif", "Sebastian", "Amnesiacs"]

# shap.image_plot expects labels as a (rows x columns) grid: one row per
# plotted image and one column per entry of selected_shap_values. Without it
# the model columns were left untitled.
column_titles = np.array([label] * len(images_np))
shap.image_plot(selected_shap_values, images_np, labels=column_titles)
plt.show()