In [15]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18
import torch.nn.functional as F
import os
import numpy as np
import requests

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE)

# manual random seed is used for dataset partitioning
# to ensure reproducible results across runs
RNG = torch.Generator().manual_seed(42)
print(torch.__version__)

Running on device: cuda
2.1.0+cu118


In [16]:
# check if cuda is installed
!nvidia-smi
!nvcc --version

Wed Dec  6 06:29:33 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P0    29W /  70W |   1835MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [17]:
# some of this portion is from (i.e starting point): https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb
# loading the normalization applied during training

normalize = transforms.Compose(
    [
        transforms.ToTensor(), # mean and sd from pytorch
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# download train set
train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=normalize
)

train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=2)

# download held out data into test set
held_out = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=normalize
)


test_loader = DataLoader(held_out, batch_size=256, shuffle=False, num_workers=2)



Files already downloaded and verified
Files already downloaded and verified


In [18]:
def accuracy(nn, dataLoader):
    nn.eval()
    num_correct = 0
    total = 0
    for inputs, targets in dataLoader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        outputs = nn(inputs) # get logits
        _, predicted = outputs.max(1) # select max index
        total += targets.size(0)
        num_correct += predicted.eq(targets).sum().item() # sum correct instances
    return num_correct / total

def accuracy_from_output(logits, truthLabels):
    out = torch.argmax(logits.detach(), dim=1)
    return (truthLabels==out).sum().item()

def evaluate(nn, dataLoader):
    nn.eval()
    total_correct = 0.0
    total_loss = 0.0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataLoader:
              inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
              outputs = nn(inputs) # get logits
              loss = F.cross_entropy(outputs, targets)
              total += targets.size(0)
              total_correct += accuracy_from_output(outputs, targets)
              total_loss += loss.detach()
    return total_correct/total, total_loss/total



In [19]:
# download pre-trained weights, TODO: plug in trained weights from Nur
# using unlearning-challenge weights for now

local_path = "weights_resnet18_cifar10.pth"
if not os.path.exists(local_path):
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/weights_resnet18_cifar10.pth"
    )
    open(local_path, "wb").write(response.content)

pretrained_state_dict = torch.load(local_path, map_location=DEVICE)

# load model with pre-trained weights
model = resnet18(weights=None, num_classes=10)
model.load_state_dict(pretrained_state_dict)
model.to(DEVICE)
model.eval()

print(f"Train accuracy: {100.0 * accuracy(model, train_loader)}%")
print(f"Test accuracy: {100.0 * accuracy(model, test_loader)}%")


Train accuracy: 99.46199999999999%
Test accuracy: 88.64%


In [20]:
# Noise data
class Noise(nn.Module):
    def __init__(self, *dim):
        super().__init__()
        self.noise = torch.nn.Parameter(torch.randn(*dim), requires_grad = True)

    def forward(self):
        return self.noise

In [21]:
def fast_effective_unlearning(net, classesToForget, retainSamples):
    '''
    net: NN.module (i.e the neural network)
    classesToForget: List (i.e the list of classes to unlearn)
    retainSamples: List (i.e the (images, label_idx) sampled from D_retain)
    '''
    # Learn noise
    BATCH_SIZE = 256 # same as paper
    noises = {} # noise dict --> maps foget class to learnt noises (i.e Noise nn.module)
    IMG_SIZE = (3, 32, 32)
    L2_REG = 0.1 # same as in the paper

    print("Phase 1: Learning Noise for forget classes")

    net.eval() # freeze weights when generating noise
    for classF in classesToForget:
        print(f"Learning noise matrices for class = {classF}")
        noises[classF] = Noise(BATCH_SIZE, *IMG_SIZE).to(DEVICE)
        opt = torch.optim.Adam(noises[classF].parameters(), lr = 0.1, weight_decay=L2_REG) # same learning rate in the paper
        noises[classF].train(True)
        numEpochs = 5 # same as paper
        stepPerEpoch = 20 # same as in paper
        for epoch in range(numEpochs):
            total_loss = []
            for batch in range(stepPerEpoch):
                inputs = noises[classF]() # input set as noise matrix
                labels = torch.zeros(BATCH_SIZE).to(DEVICE) + classF # set all labels as class to forget
                labels = labels.long()
                outputs = net(inputs) # get outputs from trained nn
                loss = -F.cross_entropy(outputs, labels)
                opt.zero_grad()
                loss.backward()
                opt.step()
                total_loss.append(loss.cpu().detach().numpy())
        print("Loss: {}".format(np.mean(total_loss)))

    print("Phase 1 complete.")
    print("Forget Set Performance:", evaluate(net, forgetSetTestLoader))
    print("Retain Set Performance:", evaluate(net, retainSetTestLoader))
    print("+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-")


    print("Phase 2: Impair")
    noiseData = []
    numBatches = 20 # number of times the noisy data is replicated, same as in paper

    for classF in classesToForget:
        for i in range(numBatches):
            batch = noises[classF]().cpu().detach()
            for i in range(batch[0].size(0)): # for each noise matrix in batch
                noiseData.append((batch[i], torch.tensor(classF))) # (noise matrix, class_num)
    # TODO: what if we randomized the labels for classF as well?

    retainSampleImpair = []
    for i in range(len(retainSamples)):
        retainSampleImpair.append((retainSamples[i][0].cpu(), torch.tensor(retainSamples[i][1]))) # data, label

    impairData = []
    impairData.extend(noiseData)
    impairData.extend(retainSampleImpair)
    impairLoader = torch.utils.data.DataLoader(impairData, batch_size=256, shuffle = True)


    optimizer = torch.optim.Adam(net.parameters(), lr = 0.02)

    net.train() # set to training mode
    NUM_EPOCH_IMPAIR = 1 # same as in the paper
    for epoch in range(NUM_EPOCH_IMPAIR):
        totalAcc = 0.0
        totalLoss = 0.0
        for impairData in impairLoader:
            inputs, labels = impairData
            inputs = inputs.to(DEVICE)
            labels = labels.clone().detach().to(DEVICE)

            outputs = net(inputs)
            loss = F.cross_entropy(outputs, labels) # cross entropy loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            totalLoss += loss.item()
            totalAcc += accuracy_from_output(outputs, labels)
        print(f"Epoch {epoch+1} | Train loss: {totalLoss/len(impairLoader.dataset)}, Train Acc:{totalAcc*100/len(impairLoader.dataset)}%")


    print("Phase 2 complete.")
    print("Forget Set Performance:", evaluate(net, forgetSetTestLoader))
    print("Retain Set Performance:", evaluate(net, retainSetTestLoader))
    print("+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-")


    print("Phase 3: Repair")

    repairLoader = torch.utils.data.DataLoader(retainSampleImpair, batch_size=256, shuffle = True)
    repairOtpm = torch.optim.Adam(net.parameters(), lr = 0.01)

    NUM_EPOCH_REPAIR = 20
    net.train()
    for epoch in range(NUM_EPOCH_REPAIR):
        total_loss = 0.0
        total_acc = 0.0
        for data in repairLoader:
            inputs, labels = data
            inputs = inputs.to(DEVICE)
            labels = labels.clone().detach().to(DEVICE)
            outputs = net(inputs)
            loss = F.cross_entropy(outputs, labels)
            repairOtpm.zero_grad()
            loss.backward()
            repairOtpm.step()
            total_loss += loss.item()
            total_acc += accuracy_from_output(outputs, labels)
        print(f"Epoch {epoch+1} | Train loss: {total_loss/len(repairLoader.dataset)}, Training Acc:{total_acc*100/len(repairLoader.dataset)}%")

    print("Phase 3 complete.")
    print("+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-")
    net.eval()
    return net


In [22]:
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
BATCH_SIZE = 256
forget_classes = [0, 6]
num_classes = len(classes)

train_set_classes = {} # define dict: class -> class imgs
for i in classes:
    train_set_classes[i] = []
for img, label in train_set:
    train_set_classes[label].append((img, label))

test_set_classes = {}
for i in classes:
    test_set_classes[i] = []
for img, label in held_out:
    test_set_classes[label].append((img, label))

# number of retain samples from each class, needed for repair and impair step
# subset of D_retain
numRetainSamples = 1000
retainedSamples = []
for i in classes:
    if classes[i] not in forget_classes:
        # get first numRetainSamples from each class not in the forget set
        retainedSamples.extend(train_set_classes[i][:numRetainSamples])

# retain test set
retainTestSet = []
for classR in classes:
    if classR not in forget_classes:
        for img, label in test_set_classes[classR]:
            retainTestSet.append((img, label))

# forget test set
forgetTestSet = []
for classF in classes:
    if classF in forget_classes:
        for img, label in test_set_classes[classF]:
            forgetTestSet.append((img, label))

forgetSetTestLoader = DataLoader(forgetTestSet, BATCH_SIZE, num_workers=2)
retainSetTestLoader = DataLoader(retainTestSet, BATCH_SIZE, num_workers=2)




In [23]:
%time
# load model with pre-trained weights for unlearning
model_for_unlearning = resnet18(weights=None, num_classes=10)
model_for_unlearning.load_state_dict(pretrained_state_dict)
model_for_unlearning.to(DEVICE)

# perform unlearning
unlearned_model = fast_effective_unlearning(model_for_unlearning, forget_classes, retainedSamples)

print("Unlearned model performance metrics on Forget Class:")
acc, loss = evaluate(unlearned_model, forgetSetTestLoader)
print(f"Accuracy: {acc}")
print(f"Loss: {loss}")


print("Unlearned model performance metrics on Retain Class:")
acc2, loss2 = evaluate(unlearned_model, retainSetTestLoader)
print(f"Accuracy: {acc2}")
print(f"Loss: {loss2}")


CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 5.72 µs
Phase 1: Learning Noise for forget classes
Learning noise matrices for class = 0
Loss: -4.171452045440674
Learning noise matrices for class = 6
Loss: -6.774816989898682
Phase 1 complete.
Forget Set Performance: (0.916, tensor(0.0013, device='cuda:0'))
Retain Set Performance: (0.879, tensor(0.0019, device='cuda:0'))
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
Phase 2: Impair
Epoch 1 | Train loss: 0.00753837710824506, Train Acc:29.704433497536947%
Phase 2 complete.
Forget Set Performance: (0.003, tensor(0.0908, device='cuda:0'))
Retain Set Performance: (0.256875, tensor(0.0138, device='cuda:0'))
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-
Phase 3: Repair
Epoch 1 | Train loss: 0.005908283486962319, Training Acc:41.95%
Epoch 2 | Train loss: 0.0051711806654930114, Training Acc:49.375%
Epoch 3 | Train loss: 0.004602989196777344, Training Acc:55.475%
Epoch 4 | Train loss: 0.004178617484867573, Training Acc:60.05%
Epoch 5 | Train loss: 0.003742990635335445

In [24]:
torch.save(model.state_dict(), "weights_resnet18_cifar10_unlearned_sid.pth")