In [1]:
import os
import sys
import argparse
import re
import time
import random
from datetime import datetime
from typing import Any, Tuple, Dict, List
from copy import deepcopy
import copy
import math
import shutil


from tqdm import tqdm
from sklearn import linear_model, model_selection
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset, dataset
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR100, CIFAR10, ImageFolder
from torchvision.models import resnet18
from torch.autograd import Variable
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
from transformers import ViTModel, ViTFeatureExtractor
import seaborn as sns
import scipy.stats as stats
import pandas as pd
from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from transformers.data.processors import SingleSentenceClassificationProcessor, InputFeatures
from transformers import AutoModel, AutoTokenizer , AutoModelForSequenceClassification, AutoConfig



In [3]:
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18 and ResNet-34.

    The block computes ``ReLU(residual_function(x) + shortcut(x))``.

    Args:
        in_channels: number of channels of the incoming feature map
        out_channels: number of channels produced by the 3x3 convolutions
        stride: stride of the first 3x3 convolution and, when a projection
            is required, of the 1x1 shortcut convolution
    """

    # Output width is ``out_channels * expansion``; ResNet reads this class
    # attribute to tell BasicBlock (1) apart from BottleNeck (4).
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        expanded_channels = out_channels * BasicBlock.expansion

        # main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN
        main_path = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                expanded_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(expanded_channels),
        ]
        self.residual_function = nn.Sequential(*main_path)

        # skip path: identity unless the spatial size or channel count
        # changes, in which case a 1x1 convolution matches the dimensions
        skip_path = []
        needs_projection = stride != 1 or in_channels != expanded_channels
        if needs_projection:
            skip_path = [
                nn.Conv2d(
                    in_channels,
                    expanded_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(expanded_channels),
            ]
        self.shortcut = nn.Sequential(*skip_path)

    def forward(self, x):
        """Add the main and skip paths, then apply ReLU."""
        summed = self.residual_function(x) + self.shortcut(x)
        return nn.ReLU(inplace=True)(summed)


class BottleNeck(nn.Module):
    """Three-convolution (1x1 -> 3x3 -> 1x1) residual block for deep ResNets.

    The block computes ``ReLU(residual_function(x) + shortcut(x))`` and
    outputs ``out_channels * expansion`` channels.

    Args:
        in_channels: number of channels of the incoming feature map
        out_channels: width of the two inner convolutions
        stride: stride of the 3x3 convolution and, when a projection is
            required, of the 1x1 shortcut convolution
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        expanded_channels = out_channels * BottleNeck.expansion

        # main path: reduce (1x1) -> transform (3x3) -> expand (1x1)
        main_path = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                stride=stride,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                expanded_channels,
                kernel_size=1,
                bias=False,
            ),
            nn.BatchNorm2d(expanded_channels),
        ]
        self.residual_function = nn.Sequential(*main_path)

        # skip path: identity unless a 1x1 projection is needed to match
        # the main path's output shape
        skip_path = []
        needs_projection = stride != 1 or in_channels != expanded_channels
        if needs_projection:
            skip_path = [
                nn.Conv2d(
                    in_channels,
                    expanded_channels,
                    stride=stride,
                    kernel_size=1,
                    bias=False,
                ),
                nn.BatchNorm2d(expanded_channels),
            ]
        self.shortcut = nn.Sequential(*skip_path)

    def forward(self, x):
        """Add the main and skip paths, then apply ReLU."""
        summed = self.residual_function(x) + self.shortcut(x)
        return nn.ReLU(inplace=True)(summed)


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem: four residual stages, pooling, linear head.

    Args:
        block: residual block class (e.g. BasicBlock or BottleNeck); it must
            expose an ``expansion`` class attribute and accept
            ``(in_channels, out_channels, stride)``
        num_block: sequence with the number of blocks in each of the 4 stages
        num_classes: size of the final linear layer's output
    """

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        # running channel count, updated by _make_layer as stages are built
        self.in_channels = 64

        stem = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*stem)

        # (attribute name, stage width, stride of the stage's first block);
        # the input size differs from the original paper, so conv2_x keeps
        # stride 1
        stage_table = (
            ("conv2_x", 64, 1),
            ("conv3_x", 128, 2),
            ("conv4_x", 256, 2),
            ("conv5_x", 512, 2),
        )
        for position, (attr_name, width, first_stride) in enumerate(stage_table):
            stage = self._make_layer(block, width, num_block[position], first_stride)
            setattr(self, attr_name, stage)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one ResNet stage (a stack of residual blocks, not a single
        neural-network layer).

        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this stage
            num_blocks: how many blocks the stage contains
            stride: the stride of the first block of this stage

        Return:
            an ``nn.Sequential`` holding the stage's blocks
        """

        # only the first block may use a stride other than 1
        per_block_strides = [stride] + [1] * (num_blocks - 1)

        stage_blocks = []
        for block_stride in per_block_strides:
            stage_blocks.append(block(self.in_channels, out_channels, block_stride))
            # the next block consumes this block's (expanded) output
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Run stem, the four stages, global average pooling and the classifier."""
        features = self.conv1(x)
        for stage in (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x):
            features = stage(features)
        pooled = self.avg_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


def resnet18(num_classes=100):
    """Return a ResNet-18 (BasicBlock, stages of 2/2/2/2 blocks).

    Note: this definition replaces the ``resnet18`` name imported from
    ``torchvision.models`` at the top of the notebook.

    Args:
        num_classes: size of the classifier output (defaults to 100, the
            value ResNet itself defaults to)
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


def resnet34(num_classes=100):
    """Return a ResNet-34 (BasicBlock, stages of 3/4/6/3 blocks).

    Args:
        num_classes: size of the classifier output (defaults to 100, the
            value ResNet itself defaults to)
    """
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)


def resnet50(num_classes=100):
    """Return a ResNet-50 (BottleNeck, stages of 3/4/6/3 blocks).

    Args:
        num_classes: size of the classifier output (defaults to 100, the
            value ResNet itself defaults to)
    """
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes)


def resnet101(num_classes=100):
    """Return a ResNet-101 (BottleNeck, stages of 3/4/23/3 blocks).

    Args:
        num_classes: size of the classifier output (defaults to 100, the
            value ResNet itself defaults to)
    """
    return ResNet(BottleNeck, [3, 4, 23, 3], num_classes=num_classes)


def resnet152(num_classes=100):
    """Return a ResNet-152 (BottleNeck, stages of 3/8/36/3 blocks).

    Args:
        num_classes: size of the classifier output (defaults to 100, the
            value ResNet itself defaults to)
    """
    return ResNet(BottleNeck, [3, 8, 36, 3], num_classes=num_classes)

In [4]:
""" helper function

author baiyu
"""
# https://github.com/weiaicunzai/pytorch-cifar100

# Maps every supported ``args.net`` value to ``(submodule of the "models"
# package, factory function name)``. Adding a network is a one-line change.
_NETWORK_FACTORIES = {
    "vgg16": ("vgg", "vgg16_bn"),
    "vgg13": ("vgg", "vgg13_bn"),
    "vgg11": ("vgg", "vgg11_bn"),
    "vgg19": ("vgg", "vgg19_bn"),
    "densenet121": ("densenet", "densenet121"),
    "densenet161": ("densenet", "densenet161"),
    "densenet169": ("densenet", "densenet169"),
    "densenet201": ("densenet", "densenet201"),
    "googlenet": ("googlenet", "googlenet"),
    "inceptionv3": ("inceptionv3", "inceptionv3"),
    "inceptionv4": ("inceptionv4", "inceptionv4"),
    "inceptionresnetv2": ("inceptionv4", "inception_resnet_v2"),
    "xception": ("xception", "xception"),
    "resnet18": ("resnet", "resnet18"),
    "resnet34": ("resnet", "resnet34"),
    "resnet50": ("resnet", "resnet50"),
    "resnet101": ("resnet", "resnet101"),
    "resnet152": ("resnet", "resnet152"),
    "preactresnet18": ("preactresnet", "preactresnet18"),
    "preactresnet34": ("preactresnet", "preactresnet34"),
    "preactresnet50": ("preactresnet", "preactresnet50"),
    "preactresnet101": ("preactresnet", "preactresnet101"),
    "preactresnet152": ("preactresnet", "preactresnet152"),
    "resnext50": ("resnext", "resnext50"),
    "resnext101": ("resnext", "resnext101"),
    "resnext152": ("resnext", "resnext152"),
    "shufflenet": ("shufflenet", "shufflenet"),
    "shufflenetv2": ("shufflenetv2", "shufflenetv2"),
    "squeezenet": ("squeezenet", "squeezenet"),
    "mobilenet": ("mobilenet", "mobilenet"),
    "mobilenetv2": ("mobilenetv2", "mobilenetv2"),
    "nasnet": ("nasnet", "nasnet"),
    "attention56": ("attention", "attention56"),
    "attention92": ("attention", "attention92"),
    "seresnet18": ("senet", "seresnet18"),
    "seresnet34": ("senet", "seresnet34"),
    "seresnet50": ("senet", "seresnet50"),
    "seresnet101": ("senet", "seresnet101"),
    "seresnet152": ("senet", "seresnet152"),
    "wideresnet": ("wideresidual", "wideresnet"),
    "stochasticdepth18": ("stochasticdepth", "stochastic_depth_resnet18"),
    "stochasticdepth34": ("stochasticdepth", "stochastic_depth_resnet34"),
    "stochasticdepth50": ("stochasticdepth", "stochastic_depth_resnet50"),
    "stochasticdepth101": ("stochasticdepth", "stochastic_depth_resnet101"),
}


def get_network(args):
    """Return the network named by ``args.net``, optionally moved to the GPU.

    The factory is looked up in ``_NETWORK_FACTORIES`` and imported lazily
    from the ``models`` package, so only the requested architecture's module
    is loaded.

    Args:
        args: object with attributes ``net`` (network name) and ``gpu``
            (truthy to call ``.cuda()`` on the network)

    Returns:
        the instantiated network

    Raises:
        SystemExit: after printing a message, when ``args.net`` is not a
            supported name
    """
    # local import: this is the only place in the notebook that needs it
    import importlib

    factory_spec = _NETWORK_FACTORIES.get(args.net)
    if factory_spec is None:
        print("the network name you have entered is not supported yet")
        sys.exit()

    module_name, factory_name = factory_spec
    module = importlib.import_module("models." + module_name)
    net = getattr(module, factory_name)()

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net


def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    """Build a DataLoader over the augmented CIFAR-100 training split.

    The dataset is read from ``./data`` with ``download=True``.

    Args:
        mean: per-channel mean handed to ``transforms.Normalize``
        std: per-channel std handed to ``transforms.Normalize``
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
    Returns: torch dataloader object over the training split
    """
    augmentation_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    ]
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    pipeline = transforms.Compose(augmentation_steps + tensor_steps)

    train_split = torchvision.datasets.CIFAR100(
        root="./data", train=True, download=True, transform=pipeline
    )

    return DataLoader(
        train_split,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )


def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True):
    """Build a DataLoader over the CIFAR-100 test split (no augmentation).

    The dataset is read from ``./data`` with ``download=True``.

    Args:
        mean: per-channel mean handed to ``transforms.Normalize``
        std: per-channel std handed to ``transforms.Normalize``
        batch_size: dataloader batchsize
        num_workers: dataloader num_works
        shuffle: whether to shuffle
    Returns: torch dataloader object over the test split
    """
    pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )

    test_split = torchvision.datasets.CIFAR100(
        root="./data", train=False, download=True, transform=pipeline
    )

    return DataLoader(
        test_split,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )


def compute_mean_std(cifar100_dataset):
    """Compute the per-channel mean and std of a dataset of images.

    Args:
        cifar100_dataset: indexable dataset with ``len()``; the image is read
            from position 1 of every item and indexed as ``[:, :, channel]``
            (assumes an H x W x 3 array at index 1 -- confirm against the
            dataset class actually passed in)

    Returns:
        a tuple ``(mean, std)``, each a 3-tuple with one value per channel
    """
    # The notebook imports ``numpy as np``; the bare name ``numpy`` used
    # previously was undefined. Each sample is also fetched only once now
    # instead of once per channel.
    images = [cifar100_dataset[i][1] for i in range(len(cifar100_dataset))]

    channels = [
        np.dstack([image[:, :, channel] for image in images])
        for channel in range(3)
    ]
    mean = tuple(np.mean(channel) for channel in channels)
    std = tuple(np.std(channel) for channel in channels)

    return mean, std


class WarmUpLR(_LRScheduler):
    """Linear learning-rate warm-up scheduler.

    Args:
        optimizer: optimizer whose learning rates are scheduled (e.g. SGD)
        total_iters: number of scheduler steps in the warm-up phase
        last_epoch: forwarded unchanged to ``_LRScheduler``
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # must be set first: the base constructor already calls get_lr()
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * m / total_iters`` for every param group, where
        ``m`` is the number of steps taken so far (``self.last_epoch``).
        """
        # the small epsilon keeps the division defined for total_iters == 0
        denominator = self.total_iters + 1e-8
        warmed_up = []
        for base_lr in self.base_lrs:
            warmed_up.append(base_lr * self.last_epoch / denominator)
        return warmed_up


def most_recent_folder(net_weights, fmt):
    """Return the name of the most recent non-empty folder under net_weights.

    "Most recent" is decided by parsing each folder name with ``fmt``, not by
    filesystem timestamps.

    Args:
        net_weights: directory whose sub-folders are named after timestamps
        fmt: ``strptime`` format used to parse the folder names

    Returns:
        the folder name, or an empty string if no non-empty folder exists

    Raises:
        ValueError: if a non-empty folder name does not match ``fmt``
    """
    # Imported locally under an unambiguous alias: the notebook binds
    # ``datetime`` to the class (``from datetime import datetime``), so the
    # previous ``datetime.datetime.strptime`` raised AttributeError.
    from datetime import datetime as _datetime

    # keep only real, non-empty directories (plain files are skipped rather
    # than crashing os.listdir)
    candidates = []
    for entry in os.listdir(net_weights):
        entry_path = os.path.join(net_weights, entry)
        if os.path.isdir(entry_path) and len(os.listdir(entry_path)):
            candidates.append(entry)
    if len(candidates) == 0:
        return ""

    # order by the time encoded in the folder name
    candidates = sorted(candidates, key=lambda name: _datetime.strptime(name, fmt))
    return candidates[-1]


def most_recent_weights(weights_folder):
    """Return the weights file with the highest epoch number in a folder.

    File names are expected to contain ``<net>-<epoch>-<regular|best>``;
    files that do not match this pattern are ignored.

    Args:
        weights_folder: directory containing the saved weight files

    Returns:
        the file name, or an empty string if the folder holds no matching file
    """
    pattern = re.compile(r"([A-Za-z0-9]+)-([0-9]+)-(regular|best)")

    # (epoch, file name) for every file that follows the naming scheme.
    # The old emptiness check tested the length of the path string, so an
    # empty folder fell through to an IndexError.
    ranked = []
    for file_name in os.listdir(weights_folder):
        match = pattern.search(file_name)
        if match is not None:
            ranked.append((int(match.groups()[1]), file_name))
    if len(ranked) == 0:
        return ""

    # stable sort by epoch only, so ties keep directory-listing order
    ranked.sort(key=lambda pair: pair[0])
    return ranked[-1][1]


def last_epoch(weights_folder):
    """Return the epoch number of the newest weights file in a folder.

    The epoch is taken from the second ``-``-separated field of the file
    name returned by ``most_recent_weights``.

    Raises:
        Exception: if ``most_recent_weights`` returns an empty name
    """
    newest_file = most_recent_weights(weights_folder)
    if newest_file:
        name_fields = newest_file.split("-")
        return int(name_fields[1])

    raise Exception("no recent weights were found")


def best_acc_weights(weights_folder):
    """Return the "best" weights file with the highest epoch in a folder.

    File names are expected to contain ``<net>-<epoch>-<regular|best>``;
    files that do not match this pattern are ignored instead of raising
    AttributeError on a failed regex match.

    Args:
        weights_folder: directory containing the saved weight files

    Returns:
        the file name, or an empty string if no "best" file is found
    """
    pattern = re.compile(r"([A-Za-z0-9]+)-([0-9]+)-(regular|best)")

    # (epoch, file name) for every file tagged "best"
    best_ranked = []
    for file_name in os.listdir(weights_folder):
        match = pattern.search(file_name)
        if match is not None and match.groups()[2] == "best":
            best_ranked.append((int(match.groups()[1]), file_name))
    if len(best_ranked) == 0:
        return ""

    # stable sort by epoch only, so ties keep directory-listing order
    best_ranked.sort(key=lambda pair: pair[0])
    return best_ranked[-1][1]

In [5]:
# Global experiment configuration. Each key appears exactly once: the
# Cifar100 ResNet schedule used to be listed twice with identical values.
conf = {
    # directory for saved checkpoints
    "CHECKPOINT_PATH": "checkpoint",

    # Class correspondence as done in https://github.com/vikram2000b/bad-teaching-unlearning
    "class_dict": {
        "rocket": 69,
        "vehicle2": 19,
        "veg": 4,
        "mushroom": 51,
        "people": 14,
        "baby": 2,
        "electrical_devices": 5,
        "lamp": 40,
        "natural_scenes": 10,
        "sea": 71,
        "42": 42,
        "1": 1,
        "10": 10,
        "20": 20,
        "30": 30,
        "40": 40,
        "lion": 43,
    },

    # Classes from https://github.com/vikram2000b/bad-teaching-unlearning
    "cifar20_classes": {"vehicle2", "veg", "people", "electrical_devices", "natural_scenes"},

    # Classes from https://github.com/vikram2000b/bad-teaching-unlearning
    "cifar100_classes": {"rocket", "mushroom", "baby", "lamp", "sea"},

    # Training parameters for the tasks: total training epochs, and the
    # milestones (epochs) at which the learning rate gets lowered
    "Cifar100_EPOCHS": 200,
    "Cifar100_MILESTONES": [60, 120, 160],

    "Cifar10_EPOCHS": 20,
    "Cifar10_MILESTONES": [8, 12, 16],

    "Cifar20_EPOCHS": 40,
    "Cifar20_MILESTONES": [15, 30, 35],

    # Same parameters for the ViT runs
    "Cifar10_ViT_EPOCHS": 8,
    "Cifar10_ViT_MILESTONES": [7],

    "Cifar20_ViT_EPOCHS": 9,
    "Cifar20_ViT_MILESTONES": [8],

    "Cifar100_ViT_EPOCHS": 8,
    "Cifar100_ViT_MILESTONES": [7],

    # log dir
    "LOG_DIR": "runs",

    # save weights file per SAVE_EPOCH epoch
    "SAVE_EPOCH": 10,
}

# strftime/strptime format for timestamped run names
DATE_FORMAT = "%A_%d_%B_%Y_%Hh_%Mm_%Ss"

# time of script run
TIME_NOW = datetime.now().strftime(DATE_FORMAT)

In [6]:
"""
Datasets used for the experiments (CIFAR and HARD)
"""

# Per-channel normalisation constants passed to transforms.Normalize below
# (values taken from https://github.com/weiaicunzai/pytorch-cifar100)
CIFAR_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Augmented pipeline, selected by the dataset classes when train=True and
# unlearning=False: random crop / flip / rotation before normalisation
# (details see https://github.com/weiaicunzai/pytorch-cifar100)
transform_train_from_scratch = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]

# Pipeline selected when train=True and unlearning=True: no random
# augmentation, only tensor conversion and normalisation
transform_unlearning = [
    # transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]

# Pipeline selected when train=False; same steps as transform_unlearning
transform_test = [
    # transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
]

# Module-level resize transform; overwritten as a side effect every time one
# of the Cifar100 / Cifar20 / Cifar10 dataset classes is constructed
active_transform = None


class Cifar100(CIFAR100):
    """CIFAR-100 dataset yielding ``(image, empty_tensor, label)`` triples.

    Constructing the dataset also overwrites the module-level
    ``active_transform`` with a resize to ``img_size`` (followed by a
    conversion to float16 when ``img_size > 32``).

    Args:
        root: dataset directory, forwarded to torchvision's CIFAR100
        train: selects the train (True) or test (False) split
        unlearning: with ``train=True``, selects ``transform_unlearning``
            instead of ``transform_train_from_scratch``
        download: forwarded to torchvision's CIFAR100
        img_size: target edge length used for ``active_transform``
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        if not train:
            steps = transform_test
        elif unlearning:
            steps = transform_unlearning
        else:
            steps = transform_train_from_scratch
        pipeline = transforms.Compose(steps)

        global active_transform
        if img_size > 32:
            active_transform = transforms.Compose([transforms.Resize((img_size,img_size)), transforms.ConvertImageDtype(torch.float16)])
        else:
            active_transform = transforms.Resize((img_size,img_size))

        # index -> already produced sample; filled lazily by __getitem__
        self.imgs = {}
        super().__init__(root=root, train=train, download=download, transform=pipeline)

    def __getitem__(self, index):
        """Return the sample for ``index``, reusing the stored one if present.

        NOTE(review): a sample is stored on first access and returned as-is
        afterwards, so the transform presumably runs only once per index --
        confirm this is intended for the randomly augmented training pipeline.
        """
        if index not in self.imgs:
            image, label = super().__getitem__(index)
            # the empty tensor is a placeholder in the middle slot
            self.imgs[index] = image, torch.Tensor([]), label
        return self.imgs[index]


class Cifar20(CIFAR100):
    """CIFAR-100 dataset that also returns the 20-way superclass label.

    Samples are ``(image, fine_label, coarse_label)`` triples. Constructing
    the dataset also overwrites the module-level ``active_transform`` with a
    resize to ``img_size`` (followed by a conversion to float16 when
    ``img_size > 32``).

    Args:
        root: dataset directory, forwarded to torchvision's CIFAR100
        train: selects the train (True) or test (False) split
        unlearning: with ``train=True``, selects ``transform_unlearning``
            instead of ``transform_train_from_scratch``
        download: forwarded to torchvision's CIFAR100
        img_size: target edge length used for ``active_transform``
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        if train:
            if unlearning:
                transform = transform_unlearning
            else:
                transform = transform_train_from_scratch
        else:
            transform = transform_test
        global active_transform
        if img_size > 32:
            active_transform = transforms.Compose([transforms.Resize((img_size,img_size)), transforms.ConvertImageDtype(torch.float16)])
        else:
            active_transform = transforms.Resize((img_size,img_size))
        transform = transforms.Compose(transform)

        super().__init__(root=root, train=train, download=download, transform=transform)

        # This map is for the matching of subclasses to the superclasses. E.g., rocket (69) to Vehicle2 (19)
        # Taken from https://github.com/vikram2000b/bad-teaching-unlearning
        self.coarse_map = {
            0: [4, 30, 55, 72, 95],
            1: [1, 32, 67, 73, 91],
            2: [54, 62, 70, 82, 92],
            3: [9, 10, 16, 28, 61],
            4: [0, 51, 53, 57, 83],
            5: [22, 39, 40, 86, 87],
            6: [5, 20, 25, 84, 94],
            7: [6, 7, 14, 18, 24],
            8: [3, 42, 43, 88, 97],
            9: [12, 17, 37, 68, 76],
            10: [23, 33, 49, 60, 71],
            11: [15, 19, 21, 31, 38],
            12: [34, 63, 64, 66, 75],
            13: [26, 45, 77, 79, 99],
            14: [2, 11, 35, 46, 98],
            15: [27, 29, 44, 78, 93],
            16: [36, 50, 65, 74, 80],
            17: [47, 52, 56, 59, 96],
            18: [8, 13, 48, 58, 90],
            19: [41, 69, 81, 85, 89],
        }

        # Reverse lookup (fine label -> coarse label), built once so that
        # __getitem__ no longer scans the whole map for every sample.
        self._fine_to_coarse = {
            fine: coarse
            for coarse, fine_labels in self.coarse_map.items()
            for fine in fine_labels
        }

    def __getitem__(self, index):
        """Return ``(image, fine_label, coarse_label)`` for ``index``.

        Raises:
            AssertionError: if the fine label is missing from ``coarse_map``
                (the label is printed first)
        """
        x, y = super().__getitem__(index)
        coarse_y = self._fine_to_coarse.get(y)
        if coarse_y is None:
            print(y)
            assert coarse_y is not None
        return x, y, coarse_y


class Cifar10(CIFAR10):
    """CIFAR-10 dataset yielding ``(image, empty_tensor, label)`` triples.

    Constructing the dataset also overwrites the module-level
    ``active_transform`` with a resize to ``img_size`` (followed by a
    conversion to float16 when ``img_size > 32``).

    Args:
        root: dataset directory, forwarded to torchvision's CIFAR10
        train: selects the train (True) or test (False) split
        unlearning: with ``train=True``, selects ``transform_unlearning``
            instead of ``transform_train_from_scratch``
        download: forwarded to torchvision's CIFAR10
        img_size: target edge length used for ``active_transform``
    """

    def __init__(self, root, train, unlearning, download, img_size=32):
        if not train:
            steps = transform_test
        elif unlearning:
            steps = transform_unlearning
        else:
            steps = transform_train_from_scratch

        global active_transform
        if img_size > 32:
            active_transform = transforms.Compose([transforms.Resize((img_size,img_size)), transforms.ConvertImageDtype(torch.float16)])
        else:
            active_transform = transforms.Resize((img_size,img_size))

        pipeline = transforms.Compose(steps)
        super().__init__(root=root, train=train, download=download, transform=pipeline)

    def __getitem__(self, index):
        """Return ``(image, empty_tensor, label)`` for ``index``."""
        image, label = super().__getitem__(index)
        # the empty tensor is a placeholder in the middle slot
        placeholder = torch.Tensor([])
        return image, placeholder, label

#https://github.com/elnagara/HARD-Arabic-Dataset
class HARD():
    """Tokenised HARD (Hotel Arabic Reviews Dataset) split.

    The CSV at ``file_path`` must have ``review`` and ``rating`` columns.
    The first ``train_size`` rows form the training split and the remaining
    rows the test split. Slicing is positional (``iloc``): the previous
    label-based ``loc[:40000]`` / ``loc[40000:]`` slices are inclusive on
    both ends, which put row 40000 into both splits.

    Args:
        root: unused; kept for signature compatibility with the CIFAR classes
        download: unused; kept for signature compatibility
        train: selects the training (True) or test (False) split
        unlearning: unused; kept for signature compatibility
        img_size: unused; kept for signature compatibility
        file_path: path of the CSV file to read
        model_dir: directory or name the tokenizer is loaded from
        train_size: number of leading rows that form the training split
    """

    def __init__(self, root, download, train, unlearning, img_size,file_path,
                 model_dir="/kaggle/input/marbert-hard-data", train_size=40000):
        df = pd.read_csv(file_path)
        if(train):
            df = df.iloc[:train_size]
        else:
            df = df.iloc[train_size:]

        self.data = df["review"].tolist()
        self.targets = df["rating"].tolist()

        tokenizer = AutoTokenizer.from_pretrained(model_dir)

        dataset = SingleSentenceClassificationProcessor(mode='classification')
        dataset.add_examples(texts_or_text_and_labels=self.data,overwrite_examples = True)

        tokenizer.max_len = 512
        # self.data is replaced by the tokenised features of the raw reviews
        self.data = dataset.get_features(tokenizer = tokenizer, max_length =512)

    def __getitem__(self, index):
        """Return ``({"input_ids", "attention_mask"} tensors, rating)`` for ``index``."""
        review = self.data[index]
        review_dict = {"input_ids":torch.tensor(review.input_ids), "attention_mask": torch.tensor(review.attention_mask)}
        return (review_dict,self.targets[index])

    def __len__(self):
        """Return the number of tokenised reviews in this split."""
        return len(self.data)


class UnLearningData(Dataset):
    """Concatenation of a forget set and a retain set with membership flags.

    Indices ``0 .. len(forget_data) - 1`` address the forget set and are
    returned with flag ``1``; the remaining indices address the retain set
    and are returned with flag ``0``. Only element ``[0]`` of each underlying
    sample is kept.

    Args:
        forget_data: indexable dataset of samples to forget
        retain_data: indexable dataset of samples to retain
    """

    def __init__(self, forget_data, retain_data):
        super().__init__()
        self.forget_data, self.retain_data = forget_data, retain_data
        self.forget_len, self.retain_len = len(forget_data), len(retain_data)

    def __len__(self):
        return self.forget_len + self.retain_len

    def __getitem__(self, index):
        """Return ``(sample[0], flag)`` with flag 1 for forget, 0 for retain."""
        if index < self.forget_len:
            source, position, flag = self.forget_data, index, 1
        else:
            # retain samples follow directly after the forget samples
            source, position, flag = self.retain_data, index - self.forget_len, 0
        return source[position][0], flag

In [7]:
"""
From https://github.com/vikram2000b/bad-teaching-unlearning
And https://github.com/weiaicunzai/pytorch-cifar100 (better performance) <- Refer to this for comments
"""

def ResNet18(num_classes):
    """Return a ResNet built from BasicBlock with 2 blocks in each of its 4
    stages and a ``num_classes``-way classifier.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)


class Identity(nn.Module):
    """No-op module: ``forward`` returns its input object unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        passthrough = x
        return passthrough


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)


class ConvStandard(nn.Conv2d):
    """Conv2d whose weights are drawn from N(0, (w_sig / fan_in)^2) and
    whose bias is zero, with ``fan_in = in_channels * prod(kernel_size)``.

    Fixes over the previous version:
      * ``nn.Conv2d.__init__`` calls ``reset_parameters()``, which reads
        ``self.w_sig``; that attribute is now set before the base constructor
        runs (it used to raise AttributeError on construction).
      * ``stride`` and ``padding`` are passed to the base class instead of
        being patched on afterwards, so the module's attributes describe the
        convolution actually performed.
      * ``padding=None`` now means ``(kernel_size - 1) // 2`` per dimension,
        the same default the ``Conv`` wrapper in this notebook uses; ``None``
        used to be handed to ``F.conv2d`` unchanged.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: int or pair of ints
        stride: convolution stride
        padding: explicit padding, or None for ``(kernel_size - 1) // 2``
        output_padding: accepted for signature compatibility; unused
        w_sig: scale of the weight initialisation
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=None,
        output_padding=0,
        w_sig=np.sqrt(1.0),
    ):
        # Plain (non-tensor, non-module) attributes may be assigned before
        # nn.Module.__init__; reset_parameters() needs this during it.
        self.w_sig = w_sig

        if padding is None:
            if isinstance(kernel_size, (tuple, list)):
                kernel_dims = tuple(kernel_size)
            else:
                kernel_dims = (kernel_size, kernel_size)
            padding = tuple((k - 1) // 2 for k in kernel_dims)

        # the base constructor creates the parameters and initialises them
        # through the overridden reset_parameters()
        super(ConvStandard, self).__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )

    def reset_parameters(self):
        """Initialise weights from N(0, (w_sig / fan_in)^2) and zero the bias."""
        # self.kernel_size is the tuple form stored by nn.Conv2d
        fan_in = self.in_channels * np.prod(self.kernel_size)
        torch.nn.init.normal_(self.weight, mean=0, std=float(self.w_sig / fan_in))
        if self.bias is not None:
            # equivalent to the former normal_(mean=0, std=0)
            torch.nn.init.zeros_(self.bias)

    def forward(self, input):
        return F.conv2d(input, self.weight, self.bias, self.stride, self.padding)


class Conv(nn.Sequential):
    """Convolution (or transposed convolution) + optional BatchNorm + activation.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: convolution kernel size
        stride: convolution stride
        padding: explicit padding; None selects ``(kernel_size - 1) // 2``
        output_padding: only used by the transposed convolution
        activation_fn: zero-argument callable that returns the activation module
        batch_norm: append ``nn.BatchNorm2d``; the convolution has a bias
            only when this is False
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=None,
        output_padding=0,
        activation_fn=nn.ReLU,
        batch_norm=True,
        transpose=False,
    ):
        resolved_padding = (kernel_size - 1) // 2 if padding is None else padding
        # BatchNorm has its own shift term, so the conv bias is dropped with it
        use_bias = not batch_norm

        if transpose:
            convolution = nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=resolved_padding,
                output_padding=output_padding,
                bias=use_bias,
            )
        else:
            convolution = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=resolved_padding,
                bias=use_bias,
            )

        stack = [convolution]
        if batch_norm:
            stack.append(nn.BatchNorm2d(out_channels, affine=True))
        stack.append(activation_fn())

        super(Conv, self).__init__(*stack)


class AllCNN(nn.Module):
    """All-convolutional classifier: eight ``Conv`` units, average pooling
    and a single linear layer.

    Args:
        filters_percentage: multiplier applied to the base widths 96 and 192
        n_channels: number of input image channels
        num_classes: size of the classifier output
        dropout: insert ``nn.Dropout`` after each strided unit (an
            ``Identity`` placeholder is inserted otherwise, so layer indices
            do not depend on this flag)
        batch_norm: forwarded to every ``Conv`` unit
    """

    def __init__(
        self,
        filters_percentage=1.0,
        n_channels=3,
        num_classes=10,
        dropout=False,
        batch_norm=True,
    ):
        super(AllCNN, self).__init__()
        narrow = int(96 * filters_percentage)
        wide = int(192 * filters_percentage)

        def regulariser():
            # dropout, or a no-op occupying the same slot
            return nn.Dropout(inplace=True) if dropout else Identity()

        # built in forward order so layer indices match the original layout
        layers = [
            Conv(n_channels, narrow, kernel_size=3, batch_norm=batch_norm),
            Conv(narrow, narrow, kernel_size=3, batch_norm=batch_norm),
            Conv(narrow, wide, kernel_size=3, stride=2, padding=1, batch_norm=batch_norm),
            regulariser(),
            Conv(wide, wide, kernel_size=3, stride=1, batch_norm=batch_norm),
            Conv(wide, wide, kernel_size=3, stride=1, batch_norm=batch_norm),
            Conv(wide, wide, kernel_size=3, stride=2, padding=1, batch_norm=batch_norm),  # 14
            regulariser(),
            Conv(wide, wide, kernel_size=3, stride=1, batch_norm=batch_norm),
            Conv(wide, wide, kernel_size=1, stride=1, batch_norm=batch_norm),
            nn.AvgPool2d(8),
            Flatten(),
        ]
        self.features = nn.Sequential(*layers)

        head = [nn.Linear(wide, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.classifier(self.features(x))


class ViT(nn.Module):
    """Pretrained ViT encoder ("google/vit-base-patch16-224") with a linear
    classification head applied to the first token of the last hidden state.

    Args:
        num_classes: size of the classifier output
        **kwargs: accepted and ignored
    """

    def __init__(self, num_classes=20, **kwargs):
        super().__init__()
        self.base = ViTModel.from_pretrained("google/vit-base-patch16-224")
        hidden_size = self.base.config.hidden_size
        self.final = nn.Linear(hidden_size, num_classes)
        self.num_classes = num_classes
        # not used by forward()
        self.relu = nn.ReLU()

    def forward(self, pixel_values):
        """Return the class logits for a batch of ``pixel_values``."""
        encoded = self.base(pixel_values=pixel_values)
        first_token = encoded.last_hidden_state[:, 0]
        return self.final(first_token)
    
class Marbert(nn.Module):
    """Sequence classifier wrapping a Hugging Face sequence-classification model.

    Parameters:
    num_classes (int): number of labels, forwarded to ``AutoConfig`` as ``num_labels``
    model_dir (str): name or path passed to ``from_pretrained`` for both the
        config and the model
    **kwargs: accepted and ignored
    """

    def __init__(self, num_classes=5, model_dir="", **kwargs):
        super(Marbert, self).__init__()

        config = AutoConfig.from_pretrained(model_dir, num_labels=num_classes)

        self.base = AutoModelForSequenceClassification.from_pretrained(
            model_dir, config=config
        )
        self.num_classes = num_classes
        self.config = config

    def _module_device(self):
        """Device of the first parameter; CPU when the module has no parameters."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, input_ids, attention_mask):
        """Return the wrapped model's logits for the given token ids and mask."""
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits

    def __call__(self, reviews):
        """
        Call the model on a tokenizer-style batch.

        Parameters:
        reviews: mapping with ``"input_ids"`` and ``"attention_mask"`` tensors
        Returns:
        logits produced by ``forward``
        """
        # Follow the module's own device instead of a hard-coded "cuda" so the
        # model also works on CPU or on a non-default GPU. The logits are
        # already on that device, so no extra move is needed on the way out.
        device = self._module_device()
        return super(Marbert, self).__call__(
            reviews["input_ids"].to(device), reviews["attention_mask"].to(device)
        )

In [8]:
# Original code from https://github.com/weiaicunzai/pytorch-cifar100 <- refer to this repo for comments


def train(epochs):
    """
    Run one training pass over ``trainloader`` and print the time it took.

    Relies on the module-level names ``net``, ``trainloader``, ``optimizer``,
    ``loss_function``, ``warmup_scheduler`` and ``args``.

    Parameters:
    epochs: index of the epoch being trained (singular, despite the name);
        compared against ``args["warm"]`` and shown in the timing message
    Returns:
    None
    """
    # The body refers to this value as ``epoch``; bind it locally so the
    # function no longer depends on a same-named global being defined.
    epoch = epochs

    start = time.time()
    net.train()
    # Each batch is a 3-tuple; the middle element is not used here.
    for batch_index, (images, _, labels) in enumerate(trainloader):
        if args["gpu"]:
            labels = labels.cuda()
            images = images.cuda()

        optimizer.zero_grad()
        outputs = net(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        # print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format(
        #     loss.item(),
        #     optimizer.param_groups[0]['lr'],
        #     epoch=epoch,
        #     trained_samples=batch_index * args.b + len(images),
        #     total_samples=len(trainloader.dataset)
        # ))

        # The warm-up scheduler is stepped per batch during the warm-up epochs.
        if epoch <= args["warm"]:
            warmup_scheduler.step()

    finish = time.time()

    print("epoch {} training time consumed: {:.2f}s".format(epoch, finish - start))


@torch.no_grad()
def eval_training(epoch=0, tb=True):
    """
    Evaluate ``net`` on ``testloader``, print a summary and return the accuracy.

    Relies on the module-level names ``net``, ``testloader``, ``loss_function``
    and ``args``. ``tb`` is accepted but not used.

    Parameters:
    epoch: value shown in the printed summary
    tb: unused
    Returns:
    number of correct predictions divided by ``len(testloader.dataset)``
    """
    started_at = time.time()
    net.eval()

    loss_total = 0.0  # sum of the per-batch loss values
    correct = 0.0

    for images, _, labels in testloader:
        if args["gpu"]:
            images = images.cuda()
            labels = labels.cuda()

        outputs = net(images)
        loss_total += loss_function(outputs, labels).item()
        predicted = outputs.max(1)[1]
        correct += predicted.eq(labels).sum()

    finished_at = time.time()
    if args["gpu"]:
        print("GPU INFO.....")
        print(torch.cuda.memory_summary(), end="")
    print("Evaluating Network.....")

    n_samples = len(testloader.dataset)
    acc = correct.float() / n_samples
    summary = "Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s".format(
        epoch,
        loss_total / n_samples,
        acc,
        finished_at - started_at,
    )
    print(summary)
    print()

    return acc

In [9]:
# Maps the dataset name used in the experiment config to its dataset class.
datasets_dict = dict(
    Cifar10=Cifar10,
    Cifar20=Cifar20,
    Cifar100=Cifar100,
    HARD=HARD,
)

In [13]:
# From https://github.com/vikram2000b/bad-teaching-unlearning

def accuracy(outputs, labels):
    """
    Top-1 accuracy of ``outputs`` against ``labels``, as a percentage.

    Parameters:
    outputs: per-class scores; the prediction is the arg-max along dim 1
    labels: target class indices
    Returns:
    0-dim tensor holding ``100 * correct / len(predictions)``
    """
    predicted = torch.max(outputs, dim=1).indices
    n_correct = (predicted == labels).sum().item()
    fraction = n_correct / len(predicted)
    return torch.tensor(fraction) * 100


def training_step(model, batch, device, nlp=False):
    """
    Compute the cross-entropy loss of ``model`` on one batch.

    Parameters:
    model: callable returning class logits
    batch: ``(images, labels, clabels)`` when ``nlp`` is False (only ``images``
        and ``clabels`` are used); ``(review, rating)`` when ``nlp`` is True
    device: device the inputs/targets are moved to
    nlp (bool): selects the text branch; ``review`` is passed to the model
        untouched
    Returns:
    loss tensor (still attached to the autograd graph)
    """
    if nlp:
        review, rating = batch
        # Honour the caller's device rather than a hard-coded "cuda".
        rating = rating.to(device)
        out = model(review)  # Generate predictions
        return F.cross_entropy(out, rating)  # Calculate loss

    images, _, clabels = batch
    # ``active_transform`` is a module-level transform applied per image.
    images = torch.stack([active_transform(image) for image in images]).to(device)
    clabels = clabels.to(device)
    out = model(images)  # Generate predictions
    return F.cross_entropy(out, clabels)  # Calculate loss


def validation_step(model, batch, device, nlp=False):
    """
    Compute loss and accuracy of ``model`` on one validation batch.

    Parameters:
    model: callable returning class logits
    batch: ``(images, labels, clabels)`` when ``nlp`` is False (only ``images``
        and ``clabels`` are used); ``(review, rating)`` when ``nlp`` is True
    device: device the inputs/targets are moved to
    nlp (bool): selects the text branch; ``review`` is passed to the model
        untouched
    Returns:
    dict with ``"Loss"`` (detached loss tensor) and ``"Acc"`` (result of
    ``accuracy``)
    """
    if nlp:
        review, targets = batch
        # Honour the caller's device rather than a hard-coded "cuda".
        targets = targets.to(device)
        out = model(review)  # Generate predictions
    else:
        images, _, targets = batch
        # ``active_transform`` is a module-level transform applied per image.
        images = torch.stack([active_transform(image) for image in images]).to(device)
        targets = targets.to(device)
        out = model(images)  # Generate predictions

    loss = F.cross_entropy(out, targets)  # Calculate loss
    acc = accuracy(out, targets)  # Calculate accuracy
    return {"Loss": loss.detach(), "Acc": acc}


def validation_epoch_end(model, outputs):
    """
    Average the per-batch validation results of one epoch.

    Parameters:
    model: unused
    outputs: list of dicts with tensor entries ``"Loss"`` and ``"Acc"``
    Returns:
    dict with the mean ``"Loss"`` and mean ``"Acc"`` as Python floats
    """
    losses = [step["Loss"] for step in outputs]
    mean_loss = torch.stack(losses).mean()  # Combine losses
    accs = [step["Acc"] for step in outputs]
    mean_acc = torch.stack(accs).mean()  # Combine accuracies
    summary = {}
    summary["Loss"] = mean_loss.item()
    summary["Acc"] = mean_acc.item()
    return summary


def epoch_end(model, epoch, result):
    """
    Print a one-line summary of an epoch.

    Parameters:
    model: unused
    epoch: epoch number shown in the message
    result: dict providing ``"lrs"`` (its last entry is printed),
        ``"train_loss"``, ``"Loss"`` and ``"Acc"``
    Returns:
    None
    """
    template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}"
    fields = (
        epoch,
        result["lrs"][-1],
        result["train_loss"],
        result["Loss"],
        result["Acc"],
    )
    print(template.format(*fields))


@torch.no_grad()
def evaluate(model, val_loader, device, nlp=False):
    """
    Put ``model`` in eval mode and score it on every batch of ``val_loader``.

    Parameters:
    model: model to evaluate
    val_loader: iterable of batches accepted by ``validation_step``
    device: forwarded to ``validation_step``
    nlp (bool): forwarded to ``validation_step``
    Returns:
    the dict produced by ``validation_epoch_end``
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(validation_step(model, batch, device, nlp))
    return validation_epoch_end(model, step_results)


def get_lr(optimizer):
    """Return the ``"lr"`` of the optimizer's first parameter group (None if there is none)."""
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]


def fit_one_cycle(
    epochs, model, train_loader, val_loader, device, lr=0.01, milestones=None
):
    """
    Train ``model`` with SGD and evaluate it after every epoch.

    Parameters:
    epochs (int): number of epochs to run
    model: model to train
    train_loader: iterable of batches accepted by ``training_step``
    val_loader: loader passed to ``evaluate``
    device: device forwarded to ``training_step`` / ``evaluate``
    lr (float): SGD learning rate
    milestones: epochs at which ``MultiStepLR`` multiplies the lr by 0.2; when
        falsy, neither the step scheduler nor the warm-up scheduler is used
    Returns:
    list with one result dict per epoch (``evaluate`` output plus
    ``"train_loss"`` and ``"lrs"``)
    """
    torch.cuda.empty_cache()
    history = []

    optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=5e-4)
    if milestones:
        train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=0.2
        )  # learning rate decay
        warmup_scheduler = WarmUpLR(optimizer, len(train_loader))

    for epoch in range(epochs):
        # Epochs 0 and 1 are driven by the warm-up scheduler (see below); the
        # milestone scheduler takes over afterwards.
        # NOTE(review): recent PyTorch releases warn about passing ``epoch`` to
        # ``step()`` - confirm before upgrading.
        if epoch > 1 and milestones:
            train_scheduler.step(epoch)

        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch, device)
            # Keep only the value: storing the attached tensor would hold every
            # batch's autograd graph in memory until the end of the epoch.
            train_losses.append(loss.detach())
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            lrs.append(get_lr(optimizer))

            if epoch <= 1 and milestones:
                warmup_scheduler.step()

        # Validation phase
        result = evaluate(model, val_loader, device)
        result["train_loss"] = torch.stack(train_losses).mean().item()
        result["lrs"] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
    return history

def fit_one_unlearning_cycle(epochs, model, train_loader, val_loader, lr, device, nlp=False):
    """
    Train ``model`` with Adam and evaluate it after every epoch.

    Parameters:
    epochs (int): number of epochs to run
    model: model to train
    train_loader: iterable of batches accepted by ``training_step``
    val_loader: loader passed to ``evaluate``
    lr (float): Adam learning rate
    device: device forwarded to ``training_step`` / ``evaluate``
    nlp (bool): forwarded to ``training_step`` / ``evaluate``
    Returns:
    list with one result dict per epoch (``evaluate`` output plus
    ``"train_loss"`` as a tensor and ``"lrs"``)
    """
    adam = torch.optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch in range(epochs):
        model.train()
        batch_losses, lr_trace = [], []
        for batch in train_loader:
            loss = training_step(model, batch, device, nlp)
            loss.backward()
            batch_losses.append(loss.detach().cpu())

            adam.step()
            adam.zero_grad()

            lr_trace.append(get_lr(adam))

        summary = evaluate(model, val_loader, device, nlp)
        summary.update(train_loss=torch.stack(batch_losses).mean(), lrs=lr_trace)
        epoch_end(model, epoch, summary)
        history.append(summary)
    return history

In [14]:
"""
From https://github.com/vikram2000b/bad-teaching-unlearning / https://arxiv.org/abs/2205.08096
"""

def JSDiv(p, q):
    """
    Jensen-Shannon style divergence between ``p`` and ``q``.

    Averages ``F.kl_div`` (called with its default reduction) of ``log(p)`` and
    of ``log(q)`` against the midpoint ``(p + q) / 2``.
    """
    midpoint = (p + q) / 2
    kl_p = F.kl_div(p.log(), midpoint)
    kl_q = F.kl_div(q.log(), midpoint)
    return 0.5 * kl_p + 0.5 * kl_q


# ZRF/UnLearningScore https://arxiv.org/abs/2205.08096
def UnLearningScore(tmodel, gold_model, forget_dl, batch_size, device, nlp=False):
    """
    Return ``1 - JSDiv`` between the softmax outputs of two models on ``forget_dl``.

    Parameters:
    tmodel: model under test
    gold_model: reference model
    forget_dl: iterable of batches; 3-tuples when ``nlp`` is False (first item
        is the input), 2-tuples when ``nlp`` is True
    batch_size: unused
    device: device image inputs are moved to (text inputs are passed as-is)
    nlp (bool): selects the text branch
    Returns:
    tensor ``1 - JSDiv(model_probs, gold_probs)``
    """
    tested_probs, gold_probs = [], []
    with torch.no_grad():
        for batch in forget_dl:
            if nlp:
                inputs, _ = batch
            else:
                raw_inputs, _, _ = batch
                # ``active_transform`` is a module-level transform applied per sample.
                inputs = torch.stack([active_transform(sample) for sample in raw_inputs])
                inputs = inputs.to(device)
            tested_logits = tmodel(inputs)
            gold_logits = gold_model(inputs)
            tested_probs.append(F.softmax(tested_logits, dim=1).detach().cpu())
            gold_probs.append(F.softmax(gold_logits, dim=1).detach().cpu())

    tested_probs = torch.cat(tested_probs, axis=0)
    gold_probs = torch.cat(gold_probs, axis=0)
    return 1 - JSDiv(tested_probs, gold_probs)


def entropy(p, dim=-1, keepdim=False):
    """
    Entropy ``-sum(p * log p)`` along ``dim``.

    Entries that are not strictly positive contribute 0 to the sum.
    """
    zero = p.new_zeros(1)
    p_log_p = torch.where(p > 0, p * p.log(), zero)
    return -p_log_p.sum(dim=dim, keepdim=keepdim)


def collect_prob(data_loader, model, nlp=False):
    """
    Collect softmax outputs of ``model`` for every sample of a loader's dataset.

    The dataset is re-wrapped in an unshuffled ``DataLoader`` with batch size 1.

    Parameters:
    data_loader: object exposing the dataset as ``.dataset``
    model: callable returning logits
    nlp (bool): when False each item is a 3-tuple whose first element is
        transformed with ``active_transform`` and moved to the model's device;
        when True each item is a 2-tuple whose first element is passed as-is
    Returns:
    tensor of per-sample probabilities concatenated along dim 0
    """
    single_loader = torch.utils.data.DataLoader(
        data_loader.dataset, batch_size=1, shuffle=False
    )
    collected = []
    with torch.no_grad():
        for batch in single_loader:
            if nlp:
                inputs, _ = batch
            else:
                raw_inputs, _, _ = batch
                inputs = torch.stack([active_transform(sample) for sample in raw_inputs])
                inputs = inputs.to(next(model.parameters()).device)
            logits = model(inputs)
            collected.append(F.softmax(logits, dim=-1).data)
    return torch.cat(collected)


# https://arxiv.org/abs/2205.08096
def get_membership_attack_data(retain_loader, forget_loader, test_loader, model, nlp=False):
    """
    Build entropy features for the membership-inference attack.

    Returns:
    X_f, Y_f: entropies of the forget samples (column vector) and an all-ones
        label vector of the same length
    X_r, Y_r: entropies of retain followed by test samples (column vector),
        labelled 1 for retain and 0 for test
    """
    retain_prob, forget_prob, test_prob = (
        collect_prob(loader, model, nlp)
        for loader in (retain_loader, forget_loader, test_loader)
    )

    reference_entropy = torch.cat([entropy(retain_prob), entropy(test_prob)])
    X_r = reference_entropy.cpu().numpy().reshape(-1, 1)
    Y_r = np.concatenate([np.ones(len(retain_prob)), np.zeros(len(test_prob))])

    X_f = entropy(forget_prob).cpu().numpy().reshape(-1, 1)
    Y_f = np.ones(len(forget_prob))
    return X_f, Y_f, X_r, Y_r


# https://arxiv.org/abs/2205.08096
def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model, nlp=False):
    """
    Fraction of forget samples a logistic-regression attacker labels as class 1.

    The attacker is fitted on the retain (label 1) vs. test (label 0) entropies
    from ``get_membership_attack_data`` and then applied to the forget entropies.
    """
    X_f, _, X_r, Y_r = get_membership_attack_data(
        retain_loader, forget_loader, test_loader, model, nlp
    )
    attacker = LogisticRegression(
        class_weight="balanced", solver="lbfgs", multi_class="multinomial"
    )
    attacker.fit(X_r, Y_r)
    predictions = attacker.predict(X_f)
    return predictions.mean()


@torch.no_grad()
def actv_dist(model1, model2, dataloader, device="cuda"):
    """
    Mean L2 distance between the softmax outputs of two models.

    Parameters:
    model1, model2: callables returning logits
    dataloader: iterable of 3-tuples; only the first element is used
    device: device the inputs are moved to
    Returns:
    0-dim tensor with the mean per-sample distance
    """
    per_sample = []
    for x, _, _ in dataloader:
        x = x.to(device)
        logits1 = model1(x)
        logits2 = model2(x)
        gap = F.softmax(logits1, dim=1) - F.softmax(logits2, dim=1)
        l2 = torch.sqrt(torch.sum(torch.square(gap), dim=1))
        per_sample.append(l2.detach().cpu())
    return torch.cat(per_sample, dim=0).mean()

In [15]:
"""
This file is used for the Selective Synaptic Dampening method
https://github.com/if-loops/selective-synaptic-dampening
"""

###############################################
# Clean implementation
###############################################


class ParameterPerturber:
    """
    Selective Synaptic Dampening helper.

    Computes per-parameter importances (mean squared gradients) and dampens the
    parameters of ``model`` whose importance on the forget data exceeds a
    multiple of their importance on the original data.
    """

    def __init__(
        self,
        model,
        opt,
        device="cuda" if torch.cuda.is_available() else "cpu",
        parameters=None,
    ):
        """
        Parameters:
        model: model whose parameters are inspected and modified in place
        opt: optimizer; only its ``zero_grad()`` is used
        device: device batches are moved to
        parameters (dict): must provide "lower_bound", "exponent",
            "magnitude_diff", "min_layer", "max_layer", "forget_threshold",
            "dampening_constant" and "selection_weighting"
        Raises:
        TypeError: if ``parameters`` is None
        """
        if parameters is None:
            # Same exception type the subscript below would raise, but with a
            # message that says what is missing.
            raise TypeError(
                "ParameterPerturber requires a `parameters` dict with the SSD hyperparameters"
            )

        self.model = model
        self.opt = opt
        self.device = device
        self.alpha = None
        self.xmin = None

        print(parameters)
        self.lower_bound = parameters["lower_bound"]
        self.exponent = parameters["exponent"]
        self.magnitude_diff = parameters["magnitude_diff"]  # unused
        self.min_layer = parameters["min_layer"]
        self.max_layer = parameters["max_layer"]
        self.forget_threshold = parameters["forget_threshold"]
        self.dampening_constant = parameters["dampening_constant"]
        self.selection_weighting = parameters["selection_weighting"]

    def get_layer_num(self, layer_name: str) -> int:
        """
        Extract the numeric second component of a dotted parameter name.

        Parameters:
        layer_name (str): e.g. "features.3.weight"
        Returns:
        int: the second dot-separated component when it is numeric, else -1
            (also -1 when the name contains no dot)
        """
        parts = layer_name.split(".")
        if len(parts) > 1 and parts[1].isnumeric():
            return int(parts[1])
        return -1

    def zerolike_params_dict(self, model: nn.Module) -> Dict[str, torch.Tensor]:
        """
        Taken from: Avalanche: an End-to-End Library for Continual Learning - https://github.com/ContinualAI/avalanche
        Returns a dict like named_parameters(), with zeroed-out parameter values
        Parameters:
        model (nn.Module): model to get param dict from
        Returns:
        dict(str,torch.Tensor): dict of zero-like params
        """
        return {
            name: torch.zeros_like(param, device=param.device)
            for name, param in model.named_parameters()
        }

    def fulllike_params_dict(
        self, model: nn.Module, fill_value, as_tensor: bool = False
    ) -> Dict[str, torch.Tensor]:
        """
        Returns a dict like named_parameters(), with parameter values replaced with fill_value

        Parameters:
        model (nn.Module): model to get param dict from
        fill_value: value to fill dict with
        as_tensor (bool): when True each entry is a tensor on ``self.device``
            built from the nested list; otherwise the nested list itself
        Returns:
        dict(str,torch.Tensor): dict of named_parameters() with filled in values
        """

        def full_like_tensor(fillval, shape) -> list:
            """
            recursively builds nd list of shape shape, filled with fillval
            Parameters:
            fillval: value to fill matrix with
            shape: shape of target tensor
            Returns:
            list of shape shape, filled with fillval at each index
            (the bare fillval for a 0-dim shape)
            """
            if len(shape) == 0:
                # Scalar (0-dim) parameter: there is no axis to repeat along.
                return fillval
            if len(shape) > 1:
                fillval = full_like_tensor(fillval, shape[1:])
            # Note: every row references the same inner object, so the nested
            # list must be treated as read-only.
            return [fillval for _ in range(shape[0])]

        dictionary = {}

        for n, p in model.named_parameters():
            filled = full_like_tensor(fill_value, p.shape)
            dictionary[n] = (
                torch.tensor(filled, device=self.device) if as_tensor else filled
            )
        return dictionary

    def subsample_dataset(self, dataset: Dataset, sample_perc: float) -> Subset:
        """
        Take a subset of the dataset

        Parameters:
        dataset (dataset): dataset to be subsampled
        sample_perc (float): fraction of dataset to sample; every
            int(1 / sample_perc)-th index is kept
        Returns:
        Subset: requested subset of the dataset
        """
        sample_idxs = np.arange(0, len(dataset), step=int((1 / sample_perc)))
        return Subset(dataset, sample_idxs)

    def split_dataset_by_class(self, dataset: Dataset) -> List[Subset]:
        """
        Split dataset into list of subsets
            each idx corresponds to samples from that class

        The dataset is iterated twice and must yield ``(x, target)`` pairs whose
        targets can index a list of length "number of distinct targets".

        Parameters:
        dataset (dataset): dataset to be split
        Returns:
        subsets (List[Subset]): list of subsets of the dataset,
            each containing only the samples belonging to that class
        """
        n_classes = len(set([target for _, target in dataset]))
        subset_idxs = [[] for _ in range(n_classes)]
        for idx, (x, y) in enumerate(dataset):
            subset_idxs[y].append(idx)

        return [Subset(dataset, subset_idxs[idx]) for idx in range(n_classes)]

    def _prepare_batch(self, batch, nlp: bool):
        """Unpack one batch into (model input, targets on ``self.device``)."""
        if nlp:
            # Text batches: the input is handed to the model untouched.
            x, y = batch
            return x, y.to(self.device)
        # Image batches are 3-tuples; the middle element is not used. Each image
        # goes through the module-level ``active_transform`` first.
        x, _, y = batch
        x = torch.stack([active_transform(image) for image in x])
        return x.to(self.device), y.to(self.device)

    def calc_importance(self, dataloader: DataLoader, nlp=False) -> Dict[str, torch.Tensor]:
        """
        Adapted from: Avalanche: an End-to-End Library for Continual Learning - https://github.com/ContinualAI/avalanche
        Calculate per-parameter importance as the squared cross-entropy
        gradient, summed over batches and divided by the number of batches.

        Parameters:
        dataloader (DataLoader): batches to iterate over (3-tuples when ``nlp``
            is False, 2-tuples when it is True); must support ``len()``
        nlp (bool): selects how a batch is unpacked
        Returns:
        importances (dict(str, torch.Tensor)): named_parameters-like dictionary
            containing the importance of each parameter
        """
        criterion = nn.CrossEntropyLoss()
        importances = self.zerolike_params_dict(self.model)

        for batch in dataloader:
            x, y = self._prepare_batch(batch, nlp)
            self.opt.zero_grad()
            out = self.model(x)
            loss = criterion(out, y)
            loss.backward()

            for name, p in self.model.named_parameters():
                # Parameters that received no gradient keep their current sum.
                if p.grad is not None:
                    importances[name].add_(p.grad.detach().pow(2))

        # average over mini batch length
        n_batches = float(len(dataloader))
        for imp in importances.values():
            imp.div_(n_batches)
        return importances

    def modify_weight(
        self,
        original_importance: Dict[str, torch.Tensor],
        forget_importance: Dict[str, torch.Tensor],
    ) -> None:
        """
        Perturb weights based on the SSD equations given in the paper.

        Both dicts are walked in parallel with ``model.named_parameters()``, so
        their entries must be in the same order as the model's parameters.

        Parameters:
        original_importance (Dict[str, torch.Tensor]): importances for original dataset
        forget_importance (Dict[str, torch.Tensor]): importances for forget sample

        Returns:
        None
        """

        with torch.no_grad():
            for (n, p), (oimp_n, oimp), (fimp_n, fimp) in zip(
                self.model.named_parameters(),
                original_importance.items(),
                forget_importance.items(),
            ):
                # Synapse Selection with parameter alpha
                oimp_norm = oimp.mul(self.selection_weighting)
                locations = torch.where(fimp > oimp_norm)

                # Synapse Dampening with parameter lambda
                weight = ((oimp.mul(self.dampening_constant)).div(fimp)).pow(
                    self.exponent
                )
                update = weight[locations]
                # Cap the factor at lower_bound so parameter values cannot increase.
                min_locs = torch.where(update > self.lower_bound)
                update[min_locs] = self.lower_bound
                p[locations] = p[locations].mul(update)


###############################################

In [16]:
def get_metric_scores(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device
):
    """
    Compute the metric tuple reported for an unlearning method.

    Returns:
    (accuracy on ``valid_dl``, accuracy on ``retain_valid_dl``,
     ``UnLearningScore`` against ``unlearning_teacher`` on ``forget_valid_dl``,
     ``get_membership_attack_prob`` result, accuracy on ``forget_valid_dl``)
    """
    test_metrics = evaluate(model, valid_dl, device)
    retain_metrics = evaluate(model, retain_valid_dl, device)
    zrf = UnLearningScore(model, unlearning_teacher, forget_valid_dl, 128, device)
    forget_metrics = evaluate(model, forget_valid_dl, device)
    mia = get_membership_attack_prob(retain_train_dl, forget_train_dl, valid_dl, model)

    test_acc = test_metrics["Acc"]
    retain_acc = retain_metrics["Acc"]
    forget_acc = forget_metrics["Acc"]
    return (test_acc, retain_acc, zrf, mia, forget_acc)


def baseline(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    device,
    **kwargs,
):
    """Score ``model`` as-is, without unlearning; extra keyword arguments are ignored."""
    scoring_args = (
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    )
    return get_metric_scores(*scoring_args)


def retrain(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    dataset_name,
    model_name,
    device,
    **kwargs,
):
    """
    Reset ``model`` and train it on the retain data, then score it.

    Only the model's direct children that expose ``reset_parameters`` are reset.
    Epoch count and milestones are read from the module-level ``conf`` mapping
    (keys are prefixed with the model name for "ViT"); ``kwargs["epochs"]``
    overrides the epoch count and ``kwargs["lr"]`` is required.
    """
    for child in model.children():
        if hasattr(child, "reset_parameters"):
            child.reset_parameters()

    key_prefix = f"{dataset_name}_{model_name}" if model_name == "ViT" else f"{dataset_name}"
    epochs = conf[f"{key_prefix}_EPOCHS"]
    milestones = conf[f"{key_prefix}_MILESTONES"]
    epochs = kwargs.get("epochs", epochs)

    fit_one_cycle(
        epochs,
        model,
        retain_train_dl,
        retain_valid_dl,
        lr=kwargs["lr"],
        milestones=milestones,
        device=device,
    )

    scoring_args = (
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    )
    return get_metric_scores(*scoring_args)

def amnesiac(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    num_classes,
    device,
    **kwargs,
):
    """
    Fine-tune ``model`` on forget samples with random wrong labels plus the
    retain samples, then score it.

    Requires ``kwargs["batch_size"]``, ``kwargs["epochs"]`` and ``kwargs["lr"]``.
    Labels are drawn with the ``random`` module, so results depend on its seed.
    """
    label_pool = list(range(num_classes))
    relabelled = []

    # Forget samples: redraw until the label differs from the true one.
    for x, aux, true_label in forget_train_dl.dataset:
        wrong_label = random.choice(label_pool)
        while wrong_label == true_label:
            wrong_label = random.choice(label_pool)
        relabelled.append((x, aux, wrong_label))

    # Retain samples keep their labels.
    relabelled.extend((x, aux, y) for x, aux, y in retain_train_dl.dataset)

    mixed_dl = DataLoader(
        relabelled, batch_size=kwargs["batch_size"], pin_memory=True, shuffle=True
    )

    fit_one_unlearning_cycle(
        kwargs["epochs"], model, mixed_dl, retain_valid_dl, device=device, lr=kwargs["lr"]
    )

    scoring_args = (
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    )
    return get_metric_scores(*scoring_args)


def ssd_tuning(
    model,
    unlearning_teacher,
    retain_train_dl,
    retain_valid_dl,
    forget_train_dl,
    forget_valid_dl,
    valid_dl,
    dampening_constant,
    selection_weighting,
    full_train_dl,
    device,
    **kwargs,
):
    """
    Apply Selective Synaptic Dampening to ``model`` in place, then score it.

    Importances are computed on ``forget_train_dl`` and on ``full_train_dl`` and
    handed to ``ParameterPerturber.modify_weight``. Extra keyword arguments are
    ignored.
    """
    parameters = dict(
        lower_bound=1,
        exponent=1,
        magnitude_diff=None,
        min_layer=-1,
        max_layer=-1,
        forget_threshold=1,
        dampening_constant=dampening_constant,
        selection_weighting=selection_weighting,
    )

    # The optimizer is handed to the perturber together with the model.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    perturber = ParameterPerturber(model, optimizer, device, parameters)
    model = model.eval()

    forget_importances = perturber.calc_importance(forget_train_dl)
    full_importances = perturber.calc_importance(full_train_dl)
    perturber.modify_weight(full_importances, forget_importances)

    scoring_args = (
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        device,
    )
    return get_metric_scores(*scoring_args)

In [17]:
# This cell initiates the random forgetting process for the provided arguments.
# Setting the arguments to be used below changes the output correspondingly.

# Experiment configuration. Changing these values changes the run below.
args = {
    # Network class to use (the class itself, not a string).
    # choices = [ResNet18, ViT]
    "net": ViT,
    # Name of the network as a string.
    # choices = ["ResNet18", "ViT"]
    "model_name": "ViT",
    # Path to the saved model weights file to load model weights from.
    "weight_path": "/kaggle/input/vit-cifar10/ViT-Cifar10-6-best.pth",
    # Dataset name to use for the random forgetting experiment (string).
    # choices = ["Cifar10"]
    "dataset": "Cifar10",
    # Number of classes predicted by the base model
    "classes": 10,
    # Fraction of the dataset to attempt to forget.
    # 0.00 <= forget_perc < 1.00
    "forget_perc": 0.05,
    # Whether to use GPU acceleration.
    "gpu": True,
    # Batch size to use.
    "b": 64,
    # Warming parameter for optimizer scheduling.
    "warm": 1,
    # Learning rate for retrain-based methods.
    "lr": 0.0002,
    # Alpha hyperparameter value to use for Selective Synaptic Dampening.
    # Does not affect other methods.
    "alpha": 25,
    # Gamma hyperparameter value to use for Selective Synaptic Dampening.
    # Does not affect other methods.
    "gamma": 1.0,
    # Unlearning method to apply (the function itself, not a string).
    # choices=[baseline, retrain, amnesiac, ssd_tuning]
    "method": ssd_tuning,
    # Number of epochs for retrain-based methods.
    "epochs": 1,
    # Random seed number to use.
    "seed": 0,
}

# Set seeds
torch.manual_seed(args["seed"])
np.random.seed(args["seed"])
random.seed(args["seed"])


batch_size = args["b"]
forget_perc = args["forget_perc"]

# get network
net = args["net"](num_classes=args["classes"])
# Load onto CPU first so a checkpoint saved from a GPU also loads when
# "gpu" is False; the model is moved to the GPU below when requested.
net.load_state_dict(torch.load(args["weight_path"], map_location="cpu"))

# Second instance of the same architecture, without the checkpoint loaded;
# passed to the selected method as ``unlearning_teacher``.
unlearning_teacher = args["net"](num_classes=args["classes"])

if args["gpu"]:
    net = net.cuda()
    unlearning_teacher = unlearning_teacher.cuda()


root = "./data"


# ViT is fed 224-pixel images; every other network 32-pixel ones.
img_size = 224 if args["net"] == ViT else 32
trainset = datasets_dict[args["dataset"]](
    root=root, download=True, train=True, unlearning=True, img_size=img_size
)
validset = datasets_dict[args["dataset"]](
    root=root, download=True, train=False, unlearning=True, img_size=img_size
)

trainloader = DataLoader(trainset, num_workers=4, batch_size=args["b"], shuffle=True)
validloader = DataLoader(validset, num_workers=4, batch_size=args["b"], shuffle=False)

# Random forget/retain split of the training set (uses the seeded torch RNG).
forget_train, retain_train = torch.utils.data.random_split(
    trainset, [forget_perc, 1 - forget_perc]
)
forget_train_dl = DataLoader(list(forget_train), batch_size=args["b"])
retain_train_dl = DataLoader(list(retain_train), batch_size=args["b"], shuffle=True)
# The forget "validation" loader is the forget training loader itself, and the
# retain "validation" loader is the full validation loader.
forget_valid_dl = forget_train_dl
retain_valid_dl = validloader

model_size_scaler = args["alpha"]

full_train_dl = DataLoader(
    ConcatDataset((retain_train_dl.dataset, forget_train_dl.dataset)),
    batch_size=batch_size,
)

# Keyword arguments handed to whichever method is selected.
kwargs = {
    "model": net,
    "unlearning_teacher": unlearning_teacher,
    "retain_train_dl": retain_train_dl,
    "retain_valid_dl": retain_valid_dl,
    "forget_train_dl": forget_train_dl,
    "forget_valid_dl": forget_valid_dl,
    "full_train_dl": full_train_dl,
    "valid_dl": validloader,
    "lr": args["lr"],
    "batch_size": batch_size,
    "dampening_constant": args["gamma"],
    "selection_weighting": model_size_scaler,
    "num_classes": args["classes"],
    "dataset_name": args["dataset"],
    "device": "cuda" if args["gpu"] else "cpu",
    "model_name": args["model_name"],
    "epochs": args["epochs"],
}

start = time.time()

testacc, retainacc, zrf, mia, d_f = args["method"](
    **kwargs
)
end = time.time()
time_elapsed = end - start

res_dict = {
    "TestAcc": testacc,
    "RetainTestAcc": retainacc,
    "ZRF": zrf,
    "MIA": mia,
    "Df": d_f,
    "MethodTime": time_elapsed,
}

for k, v in res_dict.items():
    print(k + ": " + str(v))

'dataset = "Cifar10"\nn_classes = 10\nweight_path = "/kaggle/input/vit-cifar10/ViT-Cifar10-6-best.pth"\n\nargs = {\n        "net": ViT,\n        "model_name": "ViT",\n        "weight_path": weight_path,\n        # choices=["Cifar10", "Cifar20", "Cifar100", "PinsFaceRecognition"]\n        "dataset": dataset,\n        "classes": n_classes,\n        "gpu": True,\n        "b": 64,\n        "warm": 1,\n        "lr": 0.0002,\n        # choices=["baseline","retrain","finetune","blindspot","amnesiac","FisherForgetting",\n        #    "ssd_tuning",]\n        "method": retrain,\n        "epochs": 1,\n        "seed": 0\n}\n\n# Set seeds\ntorch.manual_seed(args["seed"])\nnp.random.seed(args["seed"])\nrandom.seed(args["seed"])\n\n\nbatch_size = args["b"]\n\n# get network\nnet = args["net"](num_classes=args["classes"])\nnet.load_state_dict(torch.load(args["weight_path"]))\n\nunlearning_teacher = args["net"](num_classes=args["classes"])\n\nif args["gpu"]:\n    net = net.cuda()\n    unlearning_teacher