In [None]:
# Environment setup: fetch the project's helper modules.
# The repository is cloned and its five modules (datasets, utils, models,
# metrics, unlearn) are moved into /kaggle/working, matching the module names
# imported further down in this cell.
# NOTE(review): the paths are hard-coded for a Kaggle working directory, and
# because the files are *moved*, re-running this cell presumably fails on the
# %mv lines (and the clone target already exists) — confirm before relying on
# Restart & Run All.
!git clone https://github.com/shreyasudaya/zero-shot-unlearning

%mv /kaggle/working/zero-shot-unlearning/datasets.py /kaggle/working
%mv /kaggle/working/zero-shot-unlearning/utils.py /kaggle/working
%mv /kaggle/working/zero-shot-unlearning/models.py /kaggle/working
%mv /kaggle/working/zero-shot-unlearning/metrics.py /kaggle/working
%mv /kaggle/working/zero-shot-unlearning/unlearn.py /kaggle/working
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from datasets import *
from utils import *
from models import AllCNN
from metrics import *
from unlearn import *

# Seed PyTorch's global random number generator so repeated runs start from
# the same random state.
# NOTE(review): only torch is seeded here; NumPy and Python's `random` are not.
torch.manual_seed(100)

In [None]:
# Train / validation datasets as returned by the project's cifar10() helper.
train_ds, valid_ds = cifar10()

# Mini-batch size; the forget / retain loaders created later reuse it.
batch_size = 256

# Only the training loader reshuffles its samples; the validation loader
# keeps the default (unshuffled) order.
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)
valid_dl = DataLoader(
    valid_ds,
    batch_size=batch_size,
    num_workers=16,
)

In [None]:
num_classes = 10

# Bucket every training sample by label: class index -> list of (img, label).
classwise_train = {cls_idx: [] for cls_idx in range(num_classes)}
for image, target in train_ds:
    classwise_train[target].append((image, target))

# Same bucketing for the validation split.
classwise_test = {cls_idx: [] for cls_idx in range(num_classes)}
for image, target in valid_ds:
    classwise_test[target].append((image, target))

In [None]:
# Prefer the GPU, but fall back to the CPU so that `.to(device=device)` and the
# helpers receiving `device` do not crash on a kernel without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Build the AllCNN network for 3-channel inputs, then place it on `device`.
model = AllCNN(n_channels=3)
model = model.to(device=device)

## Creating the fully trained model

In [None]:
# Settings forwarded to fit_one_cycle() when training the full model.
epochs = 50                    # first positional argument of fit_one_cycle
max_lr = 1e-2                  # second positional argument of fit_one_cycle
grad_clip = 1e-1               # passed as `grad_clip=`
weight_decay = 1e-4            # passed as `weight_decay=`
opt_func = torch.optim.Adam    # optimizer class (not an instance), passed as `opt_func=`

In [None]:
%%time
history = fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl,
                             grad_clip=grad_clip,
                             weight_decay=weight_decay,
                             opt_func=opt_func, device = device)
torch.save(model.state_dict(), "AllCNN_MNIST_ALL_CLASSES.pt")

In [None]:
# Restore the saved weights, then score the model on the validation loader.
saved_weights = torch.load("AllCNN_MNIST_ALL_CLASSES.pt")
model.load_state_dict(saved_weights)

# `history` is a one-element list holding the metrics returned by evaluate().
history = [evaluate(model, valid_dl, device=device)]
history

## Forgetting Class 0 using Gated Knowledge Transfer (GKT)

In [None]:
# Split the validation samples into the class(es) to forget and the classes to
# retain, and wrap each subset in a DataLoader.
#
# The forget class is 0, in line with this section's "Forgetting Class 0"
# heading and the `cifar_allcnn_class0.csv` results file written at the end of
# the notebook. (It was previously set to [1], so the reported forget / retain
# accuracies were measured against a different class than the one advertised.)
forget_classes = [0]

# Every validation sample whose class is in `forget_classes`.
forget_valid = [
    (img, label)
    for cls in range(num_classes)
    if cls in forget_classes
    for img, label in classwise_test[cls]
]

# Every validation sample from the remaining classes.
retain_valid = [
    (img, label)
    for cls in range(num_classes)
    if cls not in forget_classes
    for img, label in classwise_test[cls]
]

forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=3, pin_memory=True)

retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=3, pin_memory=True)

In [None]:
# Update schedule of the unlearning loop: within each cycle the first
# `n_generator_iter` steps update the generator, the rest update the student.
n_generator_iter = 1
n_student_iter = 10

# Length of one full generator + student cycle.
n_repeat_batch = n_student_iter + n_generator_iter

In [None]:
# Teacher: a fresh AllCNN loaded with the weights of the fully trained model.
model = AllCNN(n_channels=3).to(device=device)
model.load_state_dict(torch.load("AllCNN_MNIST_ALL_CLASSES.pt"))

# Student: same architecture, no weights loaded.
student = AllCNN(n_channels=3).to(device=device)

# Generator that supplies the pseudo-batches consumed by the unlearning loop.
generator = LearnableLoader(
    n_repeat_batch=n_repeat_batch,
    num_channels=3,
    device=device,
).to(device=device)

# One Adam optimizer per trainable network, each paired with a scheduler that
# halves the learning rate once the monitored value stops decreasing.
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=0.001)
scheduler_generator = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_generator,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)

optimizer_student = torch.optim.Adam(student.parameters(), lr=0.001)
scheduler_student = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_student,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)

In [None]:
# Baseline: report how the fully trained model scores on the forget split and
# on the retain split before any unlearning happens.
for banner, loader in (
    ("Performance of Fully Trained Model on Forget Class", forget_valid_dl),
    ("Performance of Fully Trained Model on Retain Class", retain_valid_dl),
):
    print(banner)
    history = [evaluate(model, loader, device=device)]
    print("Accuracy: {}".format(history[0]["Acc"]*100))
    print("Loss: {}".format(history[0]["Loss"]))


# The same two evaluations for the still-untrained student; results are kept in
# variables rather than printed.
history = [evaluate(student, forget_valid_dl, device=device)]
AccForget = history[0]["Acc"] * 100
ErrForget = history[0]["Loss"]

history = [evaluate(student, retain_valid_dl, device=device)]
AccRetain = history[0]["Acc"] * 100
ErrRetain = history[0]["Loss"]

In [None]:
# Directories that receive the periodic generator / student checkpoints.
# (They are created with os.makedirs(..., exist_ok=True) in the training cell.)
generator_path = "./ckpts/mnist_allcnn/generator"
student_path = "./ckpts/mnist_allcnn/student"

# Loop counters: individual optimisation steps and completed cycles.
idx_pseudo = n_pseudo_batches = 0

# The unlearning loop stops once this many cycles have completed.
total_n_pseudo_batches = 4000

# Losses collected since the last logging point.
running_gen_loss, running_stu_loss = [], []

# Cut-off applied to the teacher's softmax probability when filtering
# generated samples in the unlearning loop.
threshold = 0.01

In [None]:
# Silence every Python warning from this point on in the kernel session.
# NOTE(review): this is a blanket filter — it also hides deprecation and
# numerical warnings raised by later cells; consider narrowing it by category
# or module once the intended target is known.
import warnings
warnings.filterwarnings("ignore")

### Training the unlearned model

In [None]:
# Loss settings for the unlearning loop:
#   KL_temperature -> forwarded to KT_loss_generator and KT_loss_student
#   AT_beta        -> forwarded to KT_loss_student
KL_temperature, AT_beta = 1, 250

In [None]:
# Unlearning loop: alternately trains the generator and the student against the
# frozen teacher (`model`) on generated pseudo-batches.
#
# Per cycle of `n_repeat_batch` steps:
#   * the first `n_generator_iter` steps update the generator with
#     KT_loss_generator,
#   * the remaining steps update the student with KT_loss_student.
# `n_pseudo_batches` counts completed cycles; whenever it is a multiple of 50
# at the end of a cycle, the student is evaluated on the forget / retain
# loaders, a row is appended to `df`, both LR schedulers are stepped and both
# networks are checkpointed.

# Make sure both checkpoint directories exist.
os.makedirs(generator_path, exist_ok=True)
os.makedirs(student_path, exist_ok=True)

# Step-0 metrics of the student before any unlearning update.
history_forget = [evaluate(student, forget_valid_dl, device=device)]
AccForget = history_forget[0]["Acc"] * 100
ErrForget = history_forget[0]["Loss"]

history_retain = [evaluate(student, retain_valid_dl, device=device)]
AccRetain = history_retain[0]["Acc"] * 100
ErrRetain = history_retain[0]["Loss"]

new_row = {"Epochs": 0,
           "AccForget": AccForget,
           "AccRetain": AccRetain,
           "ErrForget": ErrForget,
           "ErrRetain": ErrRetain,
           "MeanGeneratorLoss": None,
           "MeanStudentLoss": None}

# Start the log directly from the first row (no concat onto an empty frame).
df = pd.DataFrame(
    [new_row],
    columns=["Epochs", "AccForget", "AccRetain", "ErrForget", "ErrRetain",
             "MeanGeneratorLoss", "MeanStudentLoss"],
)

# Initial checkpoints; "0.pt" is also the fallback target of the generator
# reset below.
torch.save(generator.state_dict(), os.path.join(generator_path, str(0) + ".pt"))
torch.save(student.state_dict(), os.path.join(student_path, str(0) + ".pt"))

# Number of consecutive generator batches that were filtered out completely.
# It has to exist before the loop: the very first batch may already be empty,
# which previously raised a NameError on `zero_count += 1`.
zero_count = 0

while n_pseudo_batches < total_n_pseudo_batches:
    x_pseudo = next(generator)

    # Keep only the pseudo-samples to which the teacher assigns at most
    # `threshold` probability for every class in `forget_classes`. The teacher
    # pass is used for filtering only, so no autograd graph is needed for it.
    with torch.no_grad():
        preds, *_ = model(x_pseudo)
    forget_probs = torch.softmax(preds, dim=1)[:, forget_classes]
    mask = (forget_probs <= threshold).all(dim=1)
    x_pseudo = x_pseudo[mask]

    if x_pseudo.size(0) == 0:
        zero_count += 1
        if zero_count > 100:
            print("Generator Stopped Producing datapoints corresponding to retain classes.")
            print("Resetting the generator to previous checkpoint")
            # Checkpoints are named after multiples of 50; clamp at 0 so that a
            # reset during the first 50 cycles loads "0.pt" instead of the
            # non-existent "-50.pt".
            ckpt_step = max(0, ((n_pseudo_batches // 50) - 1) * 50)
            generator.load_state_dict(
                torch.load(os.path.join(generator_path, str(ckpt_step) + ".pt"))
            )
            # Give the restored generator a fresh budget of attempts instead of
            # reloading the checkpoint again on every following empty batch.
            zero_count = 0
        continue
    else:
        zero_count = 0

    ## Take n_generator_iter steps on generator
    if idx_pseudo % n_repeat_batch < n_generator_iter:
        # Gradients have to flow through both networks back into the
        # generator, hence no torch.no_grad() around the teacher here.
        student_logits, *student_activations = student(x_pseudo)
        teacher_logits, *teacher_activations = model(x_pseudo)
        generator_total_loss = KT_loss_generator(student_logits, teacher_logits, KL_temperature=KL_temperature)

        optimizer_generator.zero_grad()
        generator_total_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), 5)
        optimizer_generator.step()
        # Store a plain float so np.mean() below works on numbers, not tensors.
        running_gen_loss.append(generator_total_loss.item())

    ## Take n_student_iter steps on student
    elif idx_pseudo % n_repeat_batch < (n_generator_iter + n_student_iter):
        # The teacher only provides targets for the student update.
        with torch.no_grad():
            teacher_logits, *teacher_activations = model(x_pseudo)

        student_logits, *student_activations = student(x_pseudo)
        student_total_loss = KT_loss_student(student_logits, student_activations,
                                             teacher_logits, teacher_activations,
                                             KL_temperature=KL_temperature, AT_beta=AT_beta)

        optimizer_student.zero_grad()
        student_total_loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), 5)
        optimizer_student.step()
        running_stu_loss.append(student_total_loss.item())

    # End of one generator + student cycle.
    if (idx_pseudo + 1) % n_repeat_batch == 0:
        if n_pseudo_batches % 50 == 0:
            # Mean losses since the previous logging point.
            MeanGLoss = np.mean(running_gen_loss)
            running_gen_loss = []
            MeanSLoss = np.mean(running_stu_loss)
            running_stu_loss = []

            history_forget = [evaluate(student, forget_valid_dl, device=device)]
            AccForget = history_forget[0]["Acc"] * 100
            ErrForget = history_forget[0]["Loss"]

            history_retain = [evaluate(student, retain_valid_dl, device=device)]
            AccRetain = history_retain[0]["Acc"] * 100
            ErrRetain = history_retain[0]["Loss"]
            new_row = {
                "Epochs": n_pseudo_batches,
                "AccForget": AccForget,
                "AccRetain": AccRetain,
                "ErrForget": ErrForget,
                "ErrRetain": ErrRetain,
                "MeanGeneratorLoss": MeanGLoss,
                "MeanStudentLoss": MeanSLoss,
            }
            df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
            print(df.iloc[-1:])

            # Both schedulers follow the freshly measured retain loss. The
            # generator scheduler previously read `history`, a leftover from an
            # earlier cell that never changes inside this loop, so its learning
            # rate was driven by a constant.
            scheduler_student.step(history_retain[0]["Loss"])
            scheduler_generator.step(history_retain[0]["Loss"])

            # saving the generator
            torch.save(generator.state_dict(), os.path.join(generator_path, str(n_pseudo_batches) + ".pt"))

            # saving the student
            torch.save(student.state_dict(), os.path.join(student_path, str(n_pseudo_batches) + ".pt"))

        n_pseudo_batches += 1

    idx_pseudo += 1

In [None]:
# Display rows 10-19 (positional slice) of the metrics log built by the
# unlearning loop.
df.iloc[10:20]

In [None]:
# Write the metrics log to a CSV file in the working directory, without the
# DataFrame index column.
df.to_csv("cifar_allcnn_class0.csv", index = False)
