Unlearning

In [None]:
#getting the necessary imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse


In [None]:
#basic resnet architecture
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The input is added back to the output of the second batch norm through
    ``shortcut``. The shortcut is an empty ``nn.Sequential`` (identity) unless
    the stride or the channel count changes, in which case it is a strided
    1x1 convolution followed by batch norm.
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Only the first convolution applies the requested stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        summed = self.bn2(self.conv2(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last 1x1 convolution widens the features to ``expansion * planes``
    channels. The stride is applied by the 3x3 convolution. ``shortcut`` is
    an identity unless the stride or channel count changes, in which case it
    is a strided 1x1 convolution followed by batch norm.
    """

    # Output channels are ``expansion * planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already agree: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        summed = self.bn3(self.conv3(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem, four residual stages and a linear head.

    Args:
      block: residual block class; must expose an ``expansion`` attribute and
        accept ``(in_planes, planes, stride)``.
      num_blocks: sequence whose first four entries give the number of blocks
        in stages 1-4.
      num_classes: size of the output logits.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer reads it and advances it.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (width, stride of the first block) for stages layer1..layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, first_stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[index],
                                     stride=first_stride)
            setattr(self, f"layer{index + 1}", stage)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``."""
        stage_blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage_blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # Fixed 4x4 average pooling, then flatten everything but the batch axis.
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18():
    """Build a ResNet of BasicBlocks with stage depths 2-2-2-2."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths)


def ResNet34():
    """Build a ResNet of BasicBlocks with stage depths 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths)


def ResNet50():
    """Build a ResNet of Bottleneck blocks with stage depths 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)


def ResNet101():
    """Build a ResNet of Bottleneck blocks with stage depths 3-4-23-3."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths)


def ResNet152():
    """Build a ResNet of Bottleneck blocks with stage depths 3-8-36-3."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths)


def test():
    """Smoke test: push one random 1x3x32x32 input through ResNet18 and print the output size."""
    model = ResNet18()
    dummy_input = torch.nn.Parameter(torch.randn(1, 3, 32, 32))
    logits = model(dummy_input)
    print(logits.size())



import argparse
import sys

# Detect a live Jupyter/Colab kernel. Being able to *import* ipykernel only
# proves the package is installed (true for many plain-script environments),
# so check whether it has actually been loaded into this process instead.
in_notebook = 'ipykernel' in sys.modules

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')

if in_notebook:
    # sys.argv belongs to the kernel launcher, not to this program, so take
    # the declared defaults. This keeps `args` defined in every environment.
    args = parser.parse_args([])
else:
    try:
        args = parser.parse_args()
    except SystemExit:
        # argparse exits on bad input or --help. Keep the original best-effort
        # behaviour (do not abort the session) but say so and fall back to the
        # defaults instead of leaving `args` undefined.
        print('Could not parse command-line arguments; using defaults.',
              file=sys.stderr)
        args = parser.parse_args([])

# Rest of your code here


# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)
best_acc = 0  # highest test accuracy seen so far
start_epoch = 0  # first epoch to run (0, or the one after the last checkpoint)

# Data
print('==> Preparing data..')
# Per-channel statistics applied to both splits
# (presumably the CIFAR-10 mean/std -- TODO confirm).
_channel_mean = (0.4914, 0.4822, 0.4465)
_channel_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop and horizontal flip, then tensor
# conversion and normalization.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
transform_train = transforms.Compose(_train_steps)

# Evaluation pipeline: no augmentation.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
transform_test = transforms.Compose(_test_steps)

_data_root = './data'

trainset = torchvision.datasets.CIFAR10(
    root=_data_root, train=True, download=True, transform=transform_train)
# Shuffled mini-batches of 128 for training.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, shuffle=True, num_workers=1, batch_size=128)

testset = torchvision.datasets.CIFAR10(
    root=_data_root, train=False, download=True, transform=transform_test)
# Fixed-order mini-batches of 128 for evaluation.
testloader = torch.utils.data.DataLoader(
    dataset=testset, shuffle=False, num_workers=1, batch_size=128)

# Human-readable names for the ten class indices.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

#net = lora_model


cuda
==> Preparing data..
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29967729.63it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
#resnet defining
# Move the model to the training device before building the optimizer:
# train()/test() send every batch to `device`, so a model left on the CPU
# fails with a device mismatch as soon as a GPU is available.
net = ResNet18().to(device)
# Multi-GPU wrapping is intentionally left disabled; note that enabling it
# prefixes every state_dict key with "module.".
#if device == 'cuda':
   #net = torch.nn.DataParallel(net)
   #cudnn.benchmark = True

# Cross-entropy loss with SGD (momentum + weight decay) and a cosine
# learning-rate schedule whose period is 200 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init





def get_mean_and_std(dataset, num_workers=1):
    '''Compute the per-channel mean and std value of an image dataset.

    Each sample is loaded on its own (batch size 1); its per-channel mean and
    (unbiased) standard deviation are accumulated and finally averaged over
    the number of samples. The returned std is therefore the mean of the
    per-image stds, not the std of all pixels pooled together.

    Args:
      dataset: map-style dataset yielding ``(image, target)`` pairs where the
        image is a ``(channels, height, width)`` tensor.
      num_workers: worker processes used for loading (0 loads in-process).

    Returns:
      ``(mean, std)``: two 1-D tensors with one entry per image channel.

    Raises:
      ValueError: if the dataset is empty.
    '''
    if len(dataset) == 0:
        raise ValueError('Cannot compute mean and std of an empty dataset')

    # The accumulation is order-independent, so there is no reason to shuffle
    # (shuffling would also advance the global RNG as a side effect).
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    mean = None
    std = None
    channels = 0
    print('==> Computing mean and std..')
    for inputs, _targets in dataloader:
        if mean is None:
            # Size the accumulators from the data instead of assuming RGB.
            channels = inputs.size(1)
            mean = torch.zeros(channels)
            std = torch.zeros(channels)
        for i in range(channels):
            channel = inputs[:, i, :, :]
            mean[i] += channel.mean()
            std[i] += channel.std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters in place.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    * ``nn.BatchNorm2d``: weight 1, bias 0.
    * ``nn.Linear``: normal weights with std 1e-3, zero bias.

    Layers built without a bias (or batch norm without affine parameters)
    simply skip the missing tensor.

    Args:
      net: module whose sub-modules are initialised.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would raise for any bias with more than one
            # element; the question is only whether the layer has a bias.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)




import shutil

# Width of the attached terminal in columns, always as an int. The standard
# library helper falls back to 80 columns when there is no terminal (e.g. in a
# notebook), without spawning a `stty` subprocess.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns

print(f"Terminal Width: {term_width}")



import tqdm
from tqdm import tqdm




Terminal Width: 80


In [None]:
#training the resnet
def train(epoch):
    """Run one training epoch of the global ``net`` over ``trainloader``.

    Uses the module-level ``criterion``, ``optimizer`` and ``device`` and
    prints the epoch's training accuracy and mean batch loss.

    Args:
      epoch: epoch index, used only for the progress message.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in tqdm(trainloader, unit="batch", total=len(trainloader)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if total == 0:
        # Nothing was trained on; avoid dividing by zero below.
        print("TRAIN  ACCURACY= n/a (no training batches)")
        return

    print("TRAIN  ACCURACY=", 100.*correct/total )
    # Report the mean loss over the epoch rather than the raw tensor of
    # whichever batch happened to come last.
    print("loss" , train_loss / num_batches)

def test(epoch):
    """Evaluate the global ``net`` on ``testloader`` and checkpoint the best model.

    When the accuracy beats the module-level ``best_acc``, the model weights,
    accuracy, epoch and optimizer state are written to ``./ckpt.pth`` -- the
    same path the resume logic of this notebook reads -- and ``best_acc`` is
    updated.

    Args:
      epoch: epoch index stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    batch_idx = -1

    with torch.no_grad():
        # Evaluate on the held-out split; iterating the training loader here
        # would report (augmented) training accuracy as "test" accuracy.
        for batch_idx, (inputs, targets) in tqdm(enumerate(testloader) , unit = "batch" , total = len(testloader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    if total == 0:
        # Nothing to score; avoid dividing by zero and keep best_acc as is.
        print("TEST ACCURACY= n/a (no test batches)")
        return

    print("batch index=",batch_idx,"test_loss/(batch index+1)=",test_loss/(batch_idx+1),"accuracy=", 100.*correct/total,"correct=", correct,"total=", total)

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
        }
        # Relative path on purpose: '/ckpt.pth' would target the filesystem
        # root and never be found by the './ckpt.pth' resume check.
        torch.save(state, './ckpt.pth')
        best_acc = acc
    print("TEST ACCURACY=",acc)

#for epoch in range(0, 51):
#    train(epoch)
#    test(epoch)
#    scheduler.step()
# Set the starting epoch here
start_epoch = 0
# Training stops once this epoch index is reached.
NUM_EPOCHS = 104

# Load checkpoint if available
if os.path.exists('./ckpt.pth'):
    checkpoint = torch.load('./ckpt.pth' , map_location = torch.device('cpu') )
    net.load_state_dict(checkpoint['net'], strict = False)
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1  # Start from the next epoch
    # Carry over the best accuracy so far; otherwise the first epoch after a
    # resume always overwrites the best checkpoint, even with a worse model.
    best_acc = checkpoint.get('acc', best_acc)
    # NOTE(review): the scheduler state is not part of the checkpoint, so the
    # cosine schedule restarts its own epoch count on resume -- confirm
    # whether that is acceptable.
    print("start_epoch",start_epoch)
for epoch in range(start_epoch, NUM_EPOCHS):  # Set the desired number of epochs
    train(epoch)
    test(epoch)
    scheduler.step()
# NOTE(review): this path lives on Google Drive; the drive must already be
# mounted when this cell runs.
save_path = '/content/drive/MyDrive/model.pth'

# Save the model's state dictionary to the specified file
torch.save(net.state_dict(), save_path)
#torch.save(net.state_dict(), '/content/drive/MyDrive')


start_epoch 104


In [None]:
# Mount Google Drive at /content/drive so the notebook can read and write
# files stored there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Make the Drive root the working directory so relative paths resolve there.
os.chdir('/content/drive/MyDrive')

In [None]:
# Restore the trained weights saved on Drive. map_location remaps the tensors
# to whatever device this session uses, so a file written on a GPU runtime
# still loads on a CPU-only one.
state_dict = torch.load("/content/drive/MyDrive/model.pth", map_location=device)
# strict=False tolerates key mismatches; the returned report lists them.
net.load_state_dict(state_dict , strict=False)

<All keys matched successfully>

In [None]:
import torch.nn as nn


# Iterate through the model's layers and print detailed information about Conv2d layers
# Print one record per Conv2d layer of the model (its qualified name and
# constructor hyper-parameters), each followed by a blank line.
for name, layer in net.named_modules():
    if not isinstance(layer, nn.Conv2d):
        continue
    record = (
        ("Conv2d Layer Name", name),
        ("Input Channels", layer.in_channels),
        ("Output Channels", layer.out_channels),
        ("Kernel Size", layer.kernel_size),
        ("Stride", layer.stride),
        ("Padding", layer.padding),
        ("Dilation", layer.dilation),
        ("Groups", layer.groups),
    )
    for label, value in record:
        print(f"{label}: {value}")
    print()



In [None]:
# Weight-tensor shapes of every Conv2d layer in the model, in module order.
target_config = [
    conv.weight.shape
    for conv in net.modules()
    if isinstance(conv, nn.Conv2d)
]


Structured pruning


In [None]:
import torch.nn.utils.prune as prune

# Define the target Conv2d layer configuration


# Iterate through the model's layers and apply pruning to the similar Conv2d layer
# Select every Conv2d whose weight shape was recorded in target_config, then
# apply structured pruning along dim 0 of each weight: 80% of the slices are
# masked, ranked by their L(-inf) norm.
conv_layers = [
    module
    for module in net.modules()
    if isinstance(module, nn.Conv2d) and module.weight.shape in target_config
]
for layer in conv_layers:
    prune.ln_structured(layer, name="weight", amount=0.8, n=float('-inf'), dim=0)


RUN EITHER THE CELL ABOVE (structured pruning) OR THE CELL BELOW (L1 pruning), NOT BOTH

L1 pruning


In [None]:
import torch.nn.utils.prune as prune


# Unstructured L1 pruning: mask the smallest-magnitude weights of every
# convolution and every linear layer. Both amounts are hyperparameters.
for name, module in net.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name="weight", amount=0.3)   # prune 30% of conv weights
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.4)   # prune 40% of linear weights


# Prune the 3 smallest entries (by L1 norm) of the classifier bias. The layer
# is named explicitly instead of relying on whichever module the loop above
# happened to visit last.
prune.l1_unstructured(net.linear, name="bias", amount=3)

Linear(in_features=512, out_features=10, bias=True)

In [None]:
# Names of all parameters registered on the model, in registration order.
parameter_names = list(dict(net.named_parameters()))
print(parameter_names)

In [None]:
def _fold_pruning_masks(pruned_state):
    """Turn a pruned state dict into a plain one.

    torch.nn.utils.prune stores a pruned tensor ``w`` as the pair ``w_orig``
    and ``w_mask``. Every such pair is collapsed into a single ``w`` entry
    holding ``w_orig * w_mask``; all other entries are copied unchanged.

    Raises:
      KeyError: an ``_orig`` entry has no matching ``_mask`` entry.
      ValueError: an ``_orig``/``_mask`` pair has mismatching shapes.
    """
    folded = {}
    for key, value in pruned_state.items():
        if key.endswith('_orig'):
            raw_key = key[:-len('_orig')]
            mask_w_key = raw_key + '_mask'
            if mask_w_key not in pruned_state:
                raise KeyError(f"Missing orig/mask keys for {raw_key}")
            mask = pruned_state[mask_w_key]
            if value.shape != mask.shape:
                raise ValueError(f"Shapes of {key} and {mask_w_key} do not match")
            # Pointwise product = the effective (pruned) weight.
            folded[raw_key] = value.mul(mask)
        elif key.endswith('_mask') and key[:-len('_mask')] + '_orig' in pruned_state:
            # Already consumed together with its `_orig` partner; a plain
            # model has no slot for the mask itself.
            continue
        else:
            folded[key] = value
    return folded


# Fresh, un-pruned model that receives the effective weights of `net`.
net2 = ResNet18()
model_state_dict = net2.state_dict()
# Snapshot net's state once (each state_dict() call builds a new mapping).
model_state_dict.update(_fold_pruning_masks(net.state_dict()))



In [None]:
# Copy the assembled weights into net2. strict=False makes the call report,
# rather than reject, entries of model_state_dict that net2 does not define.
net2.load_state_dict(model_state_dict,strict = False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['conv1.weight_mask', 'layer1.0.conv1.weight_mask', 'layer1.0.conv2.weight_mask', 'layer1.1.conv1.weight_mask', 'layer1.1.conv2.weight_mask', 'layer2.0.conv1.weight_mask', 'layer2.0.conv2.weight_mask', 'layer2.0.shortcut.0.weight_mask', 'layer2.1.conv1.weight_mask', 'layer2.1.conv2.weight_mask', 'layer3.0.conv1.weight_mask', 'layer3.0.conv2.weight_mask', 'layer3.0.shortcut.0.weight_mask', 'layer3.1.conv1.weight_mask', 'layer3.1.conv2.weight_mask', 'layer4.0.conv1.weight_mask', 'layer4.0.conv2.weight_mask', 'layer4.0.shortcut.0.weight_mask', 'layer4.1.conv1.weight_mask', 'layer4.1.conv2.weight_mask', 'linear.weight_mask', 'linear.bias_mask'])

In [None]:
# Sanity check: the first conv weight in the assembled dict and the one now
# held by net2 are printed one after the other for visual comparison.
print(model_state_dict['conv1.weight'])
print(net2.state_dict()['conv1.weight'])

In [None]:
!pip install transformers accelerate evaluate datasets peft -q

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.9 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/7.9 MB[0m [31m3.2 MB/s[0m eta [36m0:00:03[0m[2K     [91m━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/7.9 MB[0m [31m44.1 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m7.9/7.9 MB[0m [31m90.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m66.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m38.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m493.7/493.7 kB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━

In [None]:
from peft import LoraConfig, get_peft_model

In [None]:
import re

# Module names that contain ".conv1" or ".conv2" and no later "dropout".
# The pattern requires a literal dot before the name, so a top-level module
# called just "conv1" is not selected.
pattern = re.compile(r'.*(\.(conv1|conv2))(?!.*dropout).*')

# Keep every module name of `net` that the pattern matches from its start.
module_names = (module_name for module_name, _ in net.named_modules())
target_modules = list(filter(pattern.match, module_names))

# LoRA settings: rank-16 adapters on the selected modules, no bias training;
# modules listed in modules_to_save are kept fully trainable by PEFT.
config = LoraConfig(
    target_modules=target_modules,
    modules_to_save=["linear","classifier"],
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)



In [None]:
# Display the list of module names selected for LoRA.
target_modules

In [None]:
def print_trainable_parameters(model):
    """Print how many of the model's parameters are trainable.

    Counts the elements of every parameter, and separately those with
    ``requires_grad`` set, then prints both counts and the trainable
    percentage. A model without parameters reports 0.00% instead of dividing
    by zero.

    Args:
      model: any ``nn.Module``.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, trainable in sizes if trainable)
    percentage = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percentage:.2f}"
    )


In [None]:
import peft
# Wrap net2 with the LoRA adapters described by `config`.
peft_model = peft.get_peft_model(net2, config)

In [None]:
# (Disabled) direct weight transfer from net into the PEFT model:
#peft_model.load_state_dict(net.state_dict(), strict=False)
# Print the trainable/total parameter counts of the LoRA-wrapped model and of
# `net` for comparison.
print_trainable_parameters(peft_model)
print_trainable_parameters(net)

trainable params: 555018 || all params: 11728980 || trainable%: 4.73
trainable params: 11173962 || all params: 11173962 || trainable%: 100.00


# **UNLEARNING**

In [None]:
import os
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18

In [None]:
# Dedicated random generator with a fixed seed (42). It is passed to a data
# loader so that its shuffling order is reproducible across runs.
RNG = torch.Generator().manual_seed(42)

In [None]:
# download the forget and retain index split (cached next to the notebook)
local_path = "forget_idx.npy"
if not os.path.exists(local_path):
    response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/" + local_path,
        timeout=60,
    )
    # Fail loudly on an HTTP error instead of caching an error page as .npy.
    response.raise_for_status()
    with open(local_path, "wb") as index_file:
        index_file.write(response.content)
forget_idx = np.load(local_path)

# construct indices of retain from those of the forget set
forget_mask = np.zeros(len(trainset.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.arange(forget_mask.size)[~forget_mask]

# split train set into a forget and a retain set
forget_set = torch.utils.data.Subset(trainset, forget_idx)
retain_set = torch.utils.data.Subset(trainset, retain_idx)

forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=128, shuffle=True, num_workers=2
)
# The retain loader shuffles with the seeded generator for reproducibility.
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=128, shuffle=True, num_workers=2, generator=RNG
)

In [None]:
def accuracy(net, loader):
    """Return accuracy on a dataset given by the data loader.

    Batches are moved to the device that holds the model's parameters, and
    the forward passes run without gradient tracking. The model's train/eval
    mode is left untouched.

    Args:
      net : nn.Module producing one row of class scores per input.
      loader : iterable of ``(inputs, targets)`` batches.

    Returns:
      Fraction of samples whose highest-scoring class equals the target.

    Raises:
      ValueError: if the loader yields no samples.
    """
    first_param = next(net.parameters(), None)
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            if first_param is not None:
                inputs = inputs.to(first_param.device)
                targets = targets.to(first_param.device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    if total == 0:
        raise ValueError("cannot compute accuracy: the loader yielded no samples")
    return correct / total

In [None]:
def unlearning(net, retain, forget, validation):
    """Unlearn by fine-tuning on the retain set only.

    The model is trained for a fixed number of epochs on the retain loader
    with SGD and a cosine learning-rate schedule; nothing else is done to make
    it forget.

    Args:
      net : nn.Module.
        pre-trained model that is fine-tuned in place.
      retain : torch.utils.data.DataLoader.
        Loader over the part of the training set that should be kept.
      forget : torch.utils.data.DataLoader.
        Loader over the part of the training set that should be forgotten.
        Not used by this method.
      validation : torch.utils.data.DataLoader.
        Loader over the validation set. Not used by this method.
    Returns:
      net : the same model, updated and switched to eval mode.
    """
    epochs = 5

    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, T_max=epochs)
    net.train()

    for epoch_number in range(1, epochs + 1):
        progress = tqdm(retain, desc=f'Epoch {epoch_number}/{epochs}')
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            sgd.zero_grad()
            batch_loss = loss_fn(net(inputs), targets)
            batch_loss.backward()
            sgd.step()
        # One scheduler step per epoch.
        lr_schedule.step()

    net.eval()
    return net

In [None]:
# Put the LoRA model on the same device the training loop sends its batches
# to (a hard-coded "cuda:0" fails on CPU-only runtimes), then fine-tune it on
# the retain set.
peft_model.to(device)
unlearned_prune_lora_model=unlearning(peft_model, retain_loader, forget_loader, testloader)

Epoch 1/5: 100%|██████████| 352/352 [00:52<00:00,  6.69it/s]
Epoch 2/5: 100%|██████████| 352/352 [00:43<00:00,  8.05it/s]
Epoch 3/5: 100%|██████████| 352/352 [00:43<00:00,  8.18it/s]
Epoch 4/5: 100%|██████████| 352/352 [00:44<00:00,  7.93it/s]
Epoch 5/5: 100%|██████████| 352/352 [00:43<00:00,  8.13it/s]


In [None]:
# Reload the saved unlearned weights into a fresh ResNet18.
unlearned_prune_lora_model2 = ResNet18()
# map_location lets a checkpoint written on a GPU runtime load on CPU too.
state_dict_unlearn_prune = torch.load(
    "/content/drive/MyDrive/Unlearning_project/unlearned_pruned_lora_model(30%)",
    map_location=device,
)
# Drop any "module." fragment (as added by DataParallel wrappers) from the keys.
state_dict_unlearn_prune = {key.replace("module.", ""): value for key, value in state_dict_unlearn_prune.items()}
# NOTE(review): strict=False silently skips keys that do not match. If the
# file holds a PEFT-wrapped state dict, its keys presumably carry a different
# prefix and little or nothing would be loaded -- check the returned
# missing/unexpected key report.
unlearned_prune_lora_model2.load_state_dict(state_dict_unlearn_prune,strict = False)

In [None]:
import os
# Destination on Google Drive for the unlearned model's weights.
file_path = "/content/drive/MyDrive/Unlearning_project/unlearned_pruned_lora_model(30%)"
#torch.save(unlearnedresnet.state_dict(), file_path)
# Persist the state dict of the LoRA model returned by unlearning().
torch.save(unlearned_prune_lora_model.state_dict(), file_path)

In [None]:
# Accuracy of the unlearned model on the retain, test and forget splits.
evaluation_splits = (
    ("Retain", retain_loader),
    ("Test", testloader),
    ("Forget", forget_loader),
)
for split_name, split_loader in evaluation_splits:
    print(f"{split_name} set accuracy: {100.0 * accuracy(unlearned_prune_lora_model, split_loader):0.1f}%")

Retain set accuracy: 62.4%
Test set accuracy: 63.5%
Forget set accuracy: 61.0%


In [None]:
# Move the original model to the session's device (a hard-coded "cuda" fails
# on CPU-only runtimes).
net.to(device)
#print(f"Retain set accuracy: {100.0 * accuracy(net, retain_loader):0.1f}%")
#print(f"Test set accuracy: {100.0 * accuracy(net, testloader):0.1f}%")

# **Evaluation**

In [None]:
def simple_mia(sample_loss, members, n_splits=10, random_state=0):
    """Computes cross-validation score of a membership inference attack.

    A logistic-regression attacker is trained to tell members from
    non-members using only the per-sample loss, and scored with stratified
    shuffle-split cross-validation.

    Args:
      sample_loss : array_like of shape (n, 1).
        objective function evaluated on n samples (one feature per row).
      members : array_like of shape (n,),
        whether a sample was used for training (1) or not (0).
      n_splits: int
        number of splits to use in the cross-validation.
      random_state: int
        seed for the cross-validation splitter.
    Returns:
      scores : array_like of size (n_splits,)
    Raises:
      ValueError: if ``members`` does not contain exactly the labels 0 and 1.
    """

    # array_equal compares shape and values, so label sets of any size (one
    # class only, three classes, ...) reach the intended error below instead
    # of tripping over a broadcasting comparison.
    unique_members = np.unique(members)
    if not np.array_equal(unique_members, [0, 1]):
        raise ValueError("members should only have 0 and 1s")

    attack_model = linear_model.LogisticRegression()
    cv = model_selection.StratifiedShuffleSplit(
        n_splits=n_splits, random_state=random_state
    )
    return model_selection.cross_val_score(
        attack_model, sample_loss, members, cv=cv, scoring="accuracy"
    )

In [None]:
def compute_losses(net, loader):
    """Auxiliary function to compute per-sample losses.

    Every batch is moved to the device that holds the model's parameters and
    evaluated without gradient tracking.

    Args:
      net : nn.Module producing one row of class scores per input.
      loader : iterable of ``(inputs, targets)`` batches.

    Returns:
      1-D numpy array with the cross-entropy loss of each sample, in loader
      order (empty if the loader yields nothing).
    """
    first_param = next(net.parameters(), None)
    criterion = nn.CrossEntropyLoss(reduction="none")
    batch_losses = []

    with torch.no_grad():
        for inputs, targets in loader:
            if first_param is not None:
                inputs = inputs.to(first_param.device)
                targets = targets.to(first_param.device)

            logits = net(inputs)
            batch_losses.append(criterion(logits, targets).detach().cpu().numpy())

    if not batch_losses:
        return np.array([])
    # One concatenation instead of appending element by element.
    return np.concatenate(batch_losses)


# Per-sample losses of the unlearned LoRA model on the full training loader
# and on the test loader.
train_losses = compute_losses(unlearned_prune_lora_model, trainloader)
test_losses = compute_losses(unlearned_prune_lora_model, testloader)

In [None]:
# Overlaid, density-normalised histograms of test and train losses on a log
# count axis; the x-range is capped at the largest test loss.
ax = plt.gca()
ax.set_title("Losses on train and test set (pre-trained model)")
for split_losses, split_label in ((test_losses, "Test set"), (train_losses, "Train set")):
    ax.hist(split_losses, density=True, alpha=0.5, bins=50, label=split_label)
ax.set_xlabel("Loss", fontsize=14)
ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlim((0, np.max(test_losses)))
ax.set_yscale("log")
ax.legend(frameon=False, fontsize=14)
# Hide the top and right frame lines.
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
plt.show()

In [None]:
ft_forget_losses = compute_losses(unlearned_prune_lora_model, forget_loader)
ft_test_losses = compute_losses(unlearned_prune_lora_model, testloader)

# Make sure we have a balanced dataset for the MIA: with unequal class sizes
# an attacker that always predicts the larger class already scores above 50%,
# which would make the reported accuracy misleading. Sub-sample the larger
# set (without replacement, fixed seed) down to the size of the smaller one.
n_per_class = min(len(ft_test_losses), len(ft_forget_losses))
balance_rng = np.random.default_rng(0)
mia_test_losses = balance_rng.choice(ft_test_losses, size=n_per_class, replace=False)
mia_forget_losses = balance_rng.choice(ft_forget_losses, size=n_per_class, replace=False)

# One loss feature per row; label 0 = unseen (test), 1 = forget-set member.
ft_samples_mia = np.concatenate((mia_test_losses, mia_forget_losses)).reshape((-1, 1))
labels_mia = [0] * n_per_class + [1] * n_per_class

In [None]:
# Cross-validated accuracy of the loss-based membership inference attack on
# the forget-vs-test samples; the mean over the splits is reported.
ft_mia_scores = simple_mia(ft_samples_mia, labels_mia)

print(
    f"The MIA has an accuracy of {ft_mia_scores.mean():.3f} on forgotten vs unseen images"
)

In [None]:
# Per-sample losses for the two panels. Computed here because nothing earlier
# in the notebook defines them.
# NOTE(review): `rt_model` (the reference model re-trained without the forget
# set) must have been loaded before this cell runs.
rt_retain_losses = compute_losses(rt_model, retain_loader)
rt_forget_losses = compute_losses(rt_model, forget_loader)
ft_retain_losses = compute_losses(unlearned_prune_lora_model, retain_loader)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

ax1.set_title(f"Re-trained model.")
ax1.hist(rt_retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")
ax1.hist(rt_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")

ax2.set_title(
    f"Unlearned by fine-tuning our LORA model(r=1) and linear layer as target module"
)
ax2.hist(ft_retain_losses, density=True, alpha=0.5, bins=50, label="Retain set")
ax2.hist(ft_forget_losses, density=True, alpha=0.5, bins=50, label="Forget set")

ax1.set_xlabel("Loss")
ax2.set_xlabel("Loss")
ax1.set_ylabel("Frequency")
for ax in (ax1, ax2):
    ax.set_yscale("log")
    ax.set_xscale("log")
    # Only the upper limit is set: a left limit of 0 is invalid on a log axis
    # and would be ignored with a warning.
    ax.set_xlim(right=np.max(test_losses))
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.legend(frameon=False, fontsize=14)
plt.show()

In [None]:
import evaluate
from peft import LoraConfig, get_peft_model
import transformers
import accelerate
import peft

In [None]:
from torchvision.models import resnet18
import torch.nn.utils.prune as prune

In [None]:
# Fresh, randomly initialised torchvision ResNet-18 with a 10-way output,
# wrapped with rank-1 LoRA adapters.
net2 = resnet18(weights=None, num_classes=10)
# NOTE(review): 'linear' and 'classifier' are the names used by the custom
# ResNet defined in this notebook; torchvision's resnet18 presumably names its
# head differently, in which case these entries match nothing -- confirm.
config = LoraConfig(
        r=1,
        lora_alpha=16,
        target_modules=['conv1','conv2','linear'],
        lora_dropout=0.1,
        bias="none",
        modules_to_save=["classifier"], )
lora_model = get_peft_model(net2, config)

In [None]:
# Fine-tune the LoRA model on the retain set.
epochs = 5
# The model must live on the device the batches are sent to.
lora_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(lora_model.parameters(), lr=0.001,
                          momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs)
for ep in range(epochs):
        lora_model.train()
        for sample in retain_loader:
            # Accept both batch layouts: dict batches with "image"/"age_group"
            # entries, and the (inputs, targets) pairs produced by the
            # loaders built in this notebook.
            if isinstance(sample, dict):
                inputs = sample["image"]
                targets = sample["age_group"]
            else:
                inputs, targets = sample
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = lora_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        scheduler.step()

lora_model.eval()

In [None]:
!pip install /content/datasets-2.14.6-py3-none-any.whl

Processing /content/datasets-2.14.6-py3-none-any.whl
Collecting dill<0.3.8,>=0.3.0 (from datasets==2.14.6)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets==2.14.6)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets==2.14.6)
  Downloading huggingface_hub-0.19.0-py3-none-any.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dill, multiprocess, huggingface-hub, datasets
Successfully installed datasets-2.14.6 dill-0.3.7 huggingface-hub-0.19.0 multiprocess-0.70.15


In [None]:
!pip install /content/evaluate-0.4.1-py3-none-any.whl

Processing /content/evaluate-0.4.1-py3-none-any.whl
Collecting responses<0.19 (from evaluate==0.4.1)
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: responses, evaluate
Successfully installed evaluate-0.4.1 responses-0.18.0


In [None]:
!pip install /content/peft-0.5.0-py3-none-any.whl

Processing /content/peft-0.5.0-py3-none-any.whl
Collecting transformers (from peft==0.5.0)
  Downloading transformers-4.35.0-py3-none-any.whl (7.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate (from peft==0.5.0)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from peft==0.5.0)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers->peft==0.5.0)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB

In [None]:
import evaluate
from peft import LoraConfig, get_peft_model
import transformers
import accelerate
import peft
import torch.nn.utils.prune as prune
import re

In [None]:
# You can replace the below simple unlearning with your own unlearning function.
def unlearning(net, retain_loader, forget_loader, val_loader):
    """Unlearn by pruning ``net`` and LoRA fine-tuning a copy on the retain set.

    Steps:
      1. Structured pruning of every ``nn.Conv2d`` in ``net`` (in place).
      2. The effective weights (``weight_orig * weight_mask``) are copied into
         a fresh torchvision ``resnet18``.
      3. That copy is wrapped with rank-1 LoRA adapters and fine-tuned for 10
         epochs on the retain set.

    Args:
      net : nn.Module. Pre-trained model; pruned in place but not fine-tuned.
      retain_loader : loader over the data to keep. Batches may be
        ``(inputs, targets)`` pairs or dicts with "image"/"age_group" entries.
      forget_loader : loader over the data to forget. Not used.
      val_loader : validation loader. Not used.

    Returns:
      The fine-tuned LoRA-wrapped model, in eval mode, on ``device``.

    Raises:
      KeyError: a pruned tensor has no matching mask in ``net``'s state dict.
      ValueError: a pruned tensor and its mask have different shapes.
    """
    # 1. Mask 80% of the dim-0 slices of every conv weight, ranked by their
    #    L(-inf) norm.
    for layer in net.modules():
        if isinstance(layer, nn.Conv2d):
            prune.ln_structured(layer, name="weight", amount=0.8, n=float('-inf'), dim=0)

    # 2. Fold each (weight_orig, weight_mask) pair into a plain weight.
    # NOTE(review): net2 is torchvision's resnet18; if `net` names its
    # parameters differently, strict=False below silently skips them -- confirm.
    net2 = resnet18(weights=None, num_classes=10)
    model_state_dict = net2.state_dict()
    pruned_state = net.state_dict()  # snapshot once instead of per key

    for key, value in pruned_state.items():
        if key.endswith('_orig'):
            raw_key = key[:-len('_orig')]
            mask_w_key = raw_key + '_mask'
            if mask_w_key not in pruned_state:
                raise KeyError(f"Missing orig/mask keys for {raw_key}")
            mask = pruned_state[mask_w_key]
            if value.shape != mask.shape:
                raise ValueError(f"Shapes of {key} and {mask_w_key} do not match")
            model_state_dict[raw_key] = value.mul(mask)
        elif key.endswith('_mask') and key[:-len('_mask')] + '_orig' in pruned_state:
            # Consumed together with its `_orig` partner.
            continue
        else:
            model_state_dict[key] = value

    net2.load_state_dict(model_state_dict, strict=False)

    # 3. LoRA fine-tuning of the pruned copy.
    config = LoraConfig(
        r=1,
        lora_alpha=16,
        target_modules=['conv1','conv2','linear'],
        lora_dropout=0.1,
        bias="none",
        modules_to_save=["classifier"], )
    lora_model = get_peft_model(net2, config)
    lora_model.to(device)

    epochs = 10
    criterion = nn.CrossEntropyLoss()
    # Optimise and run the LoRA model itself; driving `net` here would leave
    # the adapters untrained.
    optimizer = optim.SGD(lora_model.parameters(), lr=0.001,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs)
    for ep in range(epochs):
        lora_model.train()
        for sample in retain_loader:
            if isinstance(sample, dict):
                inputs = sample["image"]
                targets = sample["age_group"]
            else:
                inputs, targets = sample
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = lora_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        scheduler.step()

    lora_model.eval()
    return lora_model

In [None]:
# Run the unlearning procedure on `net`. All four arguments are required; the
# test loader stands in for the validation loader, as elsewhere in this
# notebook. (A model instantiated here beforehand would just be overwritten
# by the result.)
net2 = unlearning(net, retain_loader, forget_loader, testloader)

In [None]:
# Download the weights of the reference model re-trained without the forget set.
local_path = "retrain_weights_resnet18_cifar10.pth"
response = requests.get(
        "https://storage.googleapis.com/unlearning-challenge/" + local_path,
        timeout=60,
    )
# Fail loudly on an HTTP error instead of saving an error page as weights.
response.raise_for_status()
with open(local_path, "wb") as weights_file:
    weights_file.write(response.content)

weights_pretrained = torch.load(local_path, map_location=device)

# load model with pre-trained weights
rt_model = resnet18(weights=None, num_classes=10)
rt_model.load_state_dict(weights_pretrained)
rt_model.to(device)
rt_model.eval()

In [None]:
os.makedirs('/kaggle/tmp', exist_ok=True)

# Fetch the starting weights once; they are identical for every checkpoint.
local_path = "retrain_weights_resnet18_cifar10.pth"
response = requests.get(
    "https://storage.googleapis.com/unlearning-challenge/" + local_path,
    timeout=60,
)
# Fail loudly on an HTTP error instead of saving an error page as weights.
response.raise_for_status()
with open(local_path, "wb") as weights_file:
    weights_file.write(response.content)

weights_pretrained = torch.load(local_path, map_location=device)

for i in range(1):
    # Start every run from a fresh copy of the downloaded weights, on the
    # device the training batches are sent to.
    rt_model = resnet18(weights=None, num_classes=10)
    rt_model.load_state_dict(weights_pretrained)
    rt_model.to(device)

    # Save the model the procedure hands back; fall back to the model that
    # was passed in when nothing is returned (in-place procedures).
    unlearned = unlearning(rt_model, retain_loader, forget_loader, testloader)
    model_to_save = rt_model if unlearned is None else unlearned
    # NOTE(review): pruning / adapter wrapping can change state_dict key
    # names; confirm the saved keys are what the checkpoint consumer expects.
    state = model_to_save.state_dict()
    torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')
              #net1.load_state_dict(torch.load('/kaggle/input/neurips-2023-machine-unlearning/original_model.pth'))
              #torch.save(state, f'/kaggle/tmp/unlearned_checkpoint_{i}.pth')

            # Ensure that submission.zip will contain exactly 512 checkpoints
            # (if this is not the case, an exception will be thrown).
#unlearned_ckpts = os.listdir('/kaggle/tmp')
#if len(unlearned_ckpts) != 512:
          #raise RuntimeError('Expected exactly 512 checkpoints. The submission will throw an exception otherwise.')

#subprocess.run('zip submission.zip /kaggle/tmp/*.pth', shell=True)
