In [2]:
import torch
from model import ResNet18
from utils import *
from torch.utils.data import DataLoader

In [3]:
from torchvision.datasets import CIFAR10
train_ds = CIFAR10(root='.', train=True,download=True, transform=transform_train)
valid_ds = CIFAR10(root='.', train=False,download=True, transform=transform_train)

batch_size = 64
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=32, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size, num_workers=32, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
num_classes = 10
classwise_train = {}
for i in range(num_classes):
    classwise_train[i] = []

for img, label in train_ds:
    classwise_train[label].append((img, label))
    
classwise_test = {}
for i in range(num_classes):
    classwise_test[i] = []

for img, label in valid_ds:
    classwise_test[label].append((img, label))

In [5]:
# train the model
device = 'cuda'
model = ResNet18(num_classes = 10, pretrained = True).to(device)
epochs = 5
history = fit_one_cycle(epochs, model, train_dl, valid_dl, device = device)
torch.save(model.state_dict(), "ResNET18_CIFAR10_Pretrained_ALL_CLASSES_5_Epochs.pt")



Epoch [0], last_lr: 0.00100, train_loss: 0.8286, val_loss: 1.0791, val_acc: 69.1481
Epoch [1], last_lr: 0.00100, train_loss: 0.5047, val_loss: 0.4167, val_acc: 86.4948
Epoch [2], last_lr: 0.00100, train_loss: 0.4152, val_loss: 0.3620, val_acc: 88.1568
Epoch [3], last_lr: 0.00100, train_loss: 0.3131, val_loss: 0.3767, val_acc: 88.4256
Epoch [4], last_lr: 0.00100, train_loss: 0.3418, val_loss: 0.2974, val_acc: 90.7842


In [6]:
# load the trained model
device = 'cuda'
model = ResNet18(num_classes = 10, pretrained = True).to(device)
model.load_state_dict(torch.load("ResNET18_CIFAR10_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location='cuda'))

<All keys matched successfully>

In [7]:
forget_valid = []
forget_classes = [0]
for cls in range(num_classes):
    if cls in forget_classes:
        for img, clabel in classwise_test[cls]:
            forget_valid.append((img, clabel))

retain_valid = []
for cls in range(num_classes):
    if cls not in forget_classes:
        for img, clabel in classwise_test[cls]:
            retain_valid.append((img, clabel))
            
forget_train = []
for cls in range(num_classes):
    if cls in forget_classes:
        for img, clabel in classwise_train[cls]:
            forget_train.append((img, clabel))

retain_train = []
for cls in range(num_classes):
    if cls not in forget_classes:
        for img, clabel in classwise_train[cls]:
            retain_train.append((img, clabel))

forget_valid_dl = DataLoader(forget_valid, batch_size, num_workers=32, pin_memory=True)

retain_valid_dl = DataLoader(retain_valid, batch_size, num_workers=32, pin_memory=True)

forget_train_dl = DataLoader(forget_train, batch_size, num_workers=32, pin_memory=True)
retain_train_dl = DataLoader(retain_train, batch_size, num_workers=32, pin_memory=True, shuffle = True)
import random
retain_train_subset = random.sample(retain_train, int(0.3*len(retain_train)))
retain_train_subset_dl = DataLoader(retain_train_subset, batch_size, num_workers=32, pin_memory=True, shuffle = True)

In [8]:
evaluate(model, retain_valid_dl, device)
evaluate(model, forget_valid_dl, device)

{'Loss': 0.3983558118343353, 'Acc': 86.54296875}

In [10]:
device = 'cuda'
retrain_model = ResNet18(num_classes = 10, pretrained = True).to(device)
epochs = 5
history = fit_one_cycle(epochs, retrain_model, retain_train_dl, retain_valid_dl, device = device)
torch.save(retrain_model.state_dict(), "ResNET18_CIFAR10_Pretrained_retrain_Class0_5_Epochs.pt")



Epoch [0], last_lr: 0.00100, train_loss: 0.7912, val_loss: 0.7193, val_acc: 77.5488
Epoch [1], last_lr: 0.00100, train_loss: 0.4643, val_loss: 0.5097, val_acc: 83.8586
Epoch [2], last_lr: 0.00100, train_loss: 0.3529, val_loss: 0.4857, val_acc: 84.0869
Epoch [3], last_lr: 0.00100, train_loss: 0.4092, val_loss: 0.3867, val_acc: 88.4131
Epoch [4], last_lr: 0.00100, train_loss: 0.3746, val_loss: 0.3028, val_acc: 90.7203


In [12]:
device = 'cuda'
retrained_model = ResNet18(num_classes = 10, pretrained = True).to(device)
retrained_model.load_state_dict(torch.load("ResNET18_CIFAR10_Pretrained_retrain_Class0_5_Epochs.pt", map_location=device))

<All keys matched successfully>

In [13]:
evaluate(retrained_model, forget_valid_dl, device)
evaluate(retrained_model, retain_valid_dl, device)

{'Loss': 0.3027856647968292, 'Acc': 90.72029876708984}

In [None]:
device = 'cuda'
unlearning_teacher = ResNet18(num_classes = 10, pretrained = False).to(device).eval()
student_model = ResNet18(num_classes = 10, pretrained = False).to(device)
student_model.load_state_dict(torch.load("ResNET18_CIFAR10_Pretrained_ALL_CLASSES_5_Epochs.pt", map_location = device))
model = model.eval()

KL_temperature = 1

optimizer = torch.optim.Adam(student_model.parameters(), lr = 0.0001)

blindspot_unlearner(model = student_model, unlearning_teacher = unlearning_teacher, full_trained_teacher = model, 
          retain_data = retain_train_subset, forget_data = forget_train, epochs = 1, optimizer = optimizer, lr = 0.0001, 
          batch_size = 64, num_workers = 32, device = device, KL_temperature = KL_temperature)

In [None]:
evaluate(student_model, forget_valid_dl, device)
evaluate(student_model, retain_valid_dl, device)