In [21]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18
import torch.nn.functional as F
import os
import numpy as np
import requests

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE)

# manual random seed is used for dataset partitioning
# to ensure reproducible results across runs
RNG = torch.Generator().manual_seed(42)

Running on device: cuda


In [2]:
# check if cuda is installed
!nvidia-smi
!nvcc --version

Tue Dec  5 21:24:30 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   61C    P8    11W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
# some of this portion is from (i.e starting point): https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb
# loading the normalization applied during training

normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# download train set
train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

# download held out data into test set
held_out = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=normalize
)

test_loader = DataLoader(held_out, batch_size=128, shuffle=False, num_workers=2)


# download the forget and retain index split, will be replaced with the specific class
local_path = "forget_idx.npy"
if not os.path.exists(local_path):
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/" + local_path
    )
    open(local_path, "wb").write(response.content)
forget_idx = np.load(local_path)

# construct indices to retain based on the forget set
forget_mask = np.zeros(len(train_set.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

# split train set into a forget and a retain set
forget_set = torch.utils.data.Subset(train_set, forget_idx)
retain_set = torch.utils.data.Subset(train_set, retain_idx)

forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=128, shuffle=True, num_workers=2
)
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=128, shuffle=True, num_workers=2, generator=RNG
)




Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43388491.39it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
# download pre-trained weights, TODO: plug in trained weights from Nur
# using unlearning-challenge weights for now

local_path = "weights_resnet18_cifar10.pth"
if not os.path.exists(local_path):
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/weights_resnet18_cifar10.pth"
    )
    open(local_path, "wb").write(response.content)

pretrained_state_dict = torch.load(local_path, map_location=DEVICE)

# load model with pre-trained weights
model = resnet18(weights=None, num_classes=10)
model.load_state_dict(pretrained_state_dict)
model.to(DEVICE)
model.eval()


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
def accuracy(nn, dataLoader):
    num_correct = 0
    total = 0
    for inputs, targets in dataLoader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        outputs = nn(inputs) # get logits
        _, predicted = outputs.max(1) # select max index
        total += targets.size(0)
        num_correct += predicted.eq(targets).sum().item() # sum correct instances
    return num_correct / total


print(f"Train accuracy: {100.0 * accuracy(model, train_loader)}%")
print(f"Test accuracy: {100.0 * accuracy(model, test_loader)}%")

Train accuracy: 99.46199999999999%
Test accuracy: 88.34%


In [26]:
# Noise data
class Noise(nn.Module):
    def __init__(self, *dim):
        super().__init__()
        self.noise = torch.nn.Parameter(torch.randn(*dim), requires_grad = True)

    def forward(self):
        return self.noise

In [14]:
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

forget_classes = [1, 3 ,9]

num_classes = len(classes)

train_set_classes = {} # define dict: class -> class data
for i in range(num_classes):
    train_set_classes[i] = []
for img, label in train_set:
    train_set_classes[label].append((img, label))

test_set_classes = {}
for i in range(num_classes):
    test_set_classes[i] = []
for img, label in held_out:
    test_set_classes[label].append((img, label))

# number of retain samples from each class, needed for repair and impair step
numRetainSamples = 1000
retainedSamples = []
for i in range(num_classes):
    if classes[i] not in forget_classes:
        # get first numRetainSamples from each class not in the forget set
        retainedSamples.append(train_set_classes[i][:numRetainSamples])

In [34]:
def fast_effective_unlearning(net, classesToForget, retainSamples):
    '''
    net: NN.module (i.e the neural network)
    classesToForget: List (i.e the list of classes to unlearn)
    retainSamples: List (i.e the images sampled from D_retain)
    '''
    # Learn noise
    BATCH_SIZE = 256 # same as paper
    noises = {} # noise dict --> maps foget class to learnt noises (i.e Noise nn.module)
    IMG_SIZE = (3, 32, 32)
    L2_REG = 0.1 # same as in the paper

    print("Phase 1: Learning Noise for forget classes")

    for classF in classesToForget:
        print(f"Learning noise matrices for class= {classF}")
        noises[classF] = Noise(BATCH_SIZE, *IMG_SIZE).to(DEVICE)
        opt = torch.optim.Adam(noises[classF].parameters(), lr = 0.1, weight_decay=L2_REG) # same learning rate in the paper

        numEpochs = 5
        stepPerEpoch = 5
        for epoch in range(numEpochs):
            total_loss = []
            for batch in range(stepPerEpoch):
                inputs = noises[classF]() # input set as noise matrix
                labels = torch.zeros(BATCH_SIZE).to(DEVICE) + classF # set all labels as class to forget
                labels = labels.long()
                outputs = net(inputs) # get outputs from trained nn
                loss = -F.cross_entropy(outputs, labels) # L2 regularization performed by optimizer opt
                opt.zero_grad()
                loss.backward()
                opt.step()
                total_loss.append(loss.cpu().detach().numpy())
        print("Loss: {}".format(np.mean(total_loss)))

    print("Phase 1 complete.")

    print("Phase 2: Impair")

    noiseData = []
    numBatches = 20 # number of times the noisy data is replicated, same as in paper
    class_num = 0

    for classF in classesToForget:
        for i in range(numBatches):
            batch = noises[classF]().cpu().detach()
            for i in range(batch[0].size(0)): # for each noise matrix in batch
                noiseData.append((batch[i], torch.tensor(classF))) # (noise matrix, class_num)
    # TODO: what if we randomized the labels for classF as well?

    retainSampleImpair = []
    for i in range(len(retainSamples)):
        retainSampleImpair.append((retainSamples[i][0].cpu(), torch.tensor(retainSamples[i][1]))) # data, label

    impairData = []
    impairData.append(noiseData, retainSamretainSampleImpairpleCpu)
    impairLoader = torch.utils.data.DataLoader(impairData, batch_size=256, shuffle = True)


    optimizer = torch.optim.Adam(model.parameters(), lr = 0.02)

    NUM_EPOCH_IMPAIR = 1 # same as in the paper
    for epoch in range(NUM_EPOCH_IMPAIR):
        net.train(True) # set to training mode
        totalAcc = 0.0
        totalLoss = 0.0
        for impairData in impairLoader:
            inputs, labels = impairData
            inputs = inputs.cuda()
            labels = torch.tensor(labels).cuda()
            outputs = model(inputs)
            optimizer.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item() * inputs.size(0)
            out = torch.argmax(outputs.detach(),dim=1)
            assert out.shape==labels.shape
            running_acc += (labels==out).sum().item()
        print(f"Train loss {epoch+1}: {running_loss/len(train_ds)},Train Acc:{running_acc*100/len(train_ds)}%")


%time
fast_effective_unlearning(model, forget_classes, retainedSamples)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.68 µs
Phase 1: Learning Noise for forget classes
Learning noise matrices for class= 1
Loss: -9.631937980651855
Learning noise matrices for class= 3
Loss: -4.250012397766113
Learning noise matrices for class= 9
Loss: -8.33691692352295
