In [1]:
import os
import requests
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model, model_selection

import time
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.models import resnet18
import numpy as np
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# One seed for the whole notebook.
# - `RNG` is a dedicated generator for data splitting / loader shuffling.
# - The global torch / NumPy seeds additionally make model initialisation, NumPy sampling
#   and the random re-initialisation masks (torch.rand_like, etc.) reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
RNG = torch.Generator().manual_seed(SEED)

Running on device: CUDA


In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

In [3]:
# CIFAR-10: convert images to tensors, then normalise each channel with the
# commonly used CIFAR-10 per-channel mean / std constants.
normalize = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Training split, reshuffled every epoch.
train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=normalize
)
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=2)

# Held-out CIFAR-10 test split, used as a whole (no test/validation split is made here);
# it serves as the "val" phase of training. Evaluation needs no shuffling, so iterate
# in a fixed order.
test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=normalize
)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13188017.09it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
# Indices of training samples labelled 1; the cell displays the shape (i.e. the count).
np.flatnonzero(np.asarray(train_set.targets) == 1).shape

(5000,)

In [5]:
# Choose random forget indices from one class (class 1 = cars / "automobile").
class_index = 1
class_set = np.flatnonzero(np.asarray(train_set.targets) == class_index)

# Fraction of that class to forget.
amount = 0.1  # 10 %
amount_int = int(class_set.shape[0] * amount)

# Sample forget indices WITHOUT replacement (np.random.choice defaults to replace=True,
# which can yield duplicates) from a generator seeded like RNG, so the split is
# reproducible and re-running this cell gives the same split.
split_gen = torch.Generator().manual_seed(RNG.initial_seed())
forget_idx = class_set[torch.randperm(class_set.shape[0], generator=split_gen)[:amount_int].numpy()]

# Retain set = every training index that is not in the forget set.
forget_mask = np.zeros(len(train_set.targets), dtype=bool)
forget_mask[forget_idx] = True
retain_idx = np.flatnonzero(~forget_mask)

assert np.unique(forget_idx).size == amount_int
assert retain_idx.size + forget_idx.size == len(train_set.targets)

# Split the train set into a forget and a retain subset.
forget_set = torch.utils.data.Subset(train_set, forget_idx)
retain_set = torch.utils.data.Subset(train_set, retain_idx)

forget_loader = torch.utils.data.DataLoader(
    forget_set, batch_size=256, shuffle=True, num_workers=2
)
retain_loader = torch.utils.data.DataLoader(
    retain_set, batch_size=256, shuffle=True, num_workers=2, generator=RNG
)


In [6]:
# Phase name -> loader iterated by `train_model`; "val" is the CIFAR-10 test split.
dataloaders = dict(train=train_loader, val=test_loader)

# Samples per phase, used to turn summed loss / correct counts into epoch averages.
dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and return it with the weights of its best "val" epoch loaded.

    Each epoch runs a "train" phase (optimisation, then one ``scheduler.step()``) and a
    "val" phase (no gradients) over the global ``dataloaders``; loss and accuracy are
    averaged with the global ``dataset_sizes``. The state dict with the highest "val"
    accuracy is checkpointed to a temporary directory and restored at the end.

    NOTE(review): "val" is the CIFAR-10 test split in this notebook, so selecting the best
    epoch on it leaks test information into model selection.

    Args:
        model: network whose parameters live on ``DEVICE``.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimiser over ``model``'s parameters.
        scheduler: LR scheduler stepped once per training epoch.
        num_epochs: number of epochs to run.

    Returns:
        The same ``model`` object, with the best checkpoint loaded.
    """
    since = time.time()

    # Temporary directory holding the best checkpoint seen so far.
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0  # kept as a Python float

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                if is_train:
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(DEVICE)
                    labels = labels.to(DEVICE)

                    optimizer.zero_grad()

                    # Build the autograd graph only while training.
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        if is_train:
                            loss.backward()
                            optimizer.step()

                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += (preds == labels).sum().item()
                if is_train:
                    scheduler.step()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # Checkpoint whenever validation accuracy improves.
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Restore the best weights (mapped onto the training device).
        model.load_state_dict(torch.load(best_model_params_path, map_location=DEVICE))
    return model

In [17]:
# ResNet-18 trained from scratch, with its classification head sized for the 10 CIFAR-10 classes.
# NOTE(review): this is the ImageNet stem (7x7 stride-2 conv + max-pool) applied to 32x32 inputs.
model_ft = resnet18(weights=None, num_classes=10).to(DEVICE)

criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the network.
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001)

# StepLR: scale the learning rate by 0.1 every 7 scheduler steps (one step per epoch).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [18]:

# Train for 15 epochs (logged as Epoch 0 .. Epoch 14).
model_ft = train_model(
    model_ft,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    num_epochs=15,
)

Epoch 0/14
----------
train Loss: 1.3797 Acc: 0.5043
val Loss: 1.1638 Acc: 0.5866

Epoch 1/14
----------
train Loss: 0.9742 Acc: 0.6572
val Loss: 0.9782 Acc: 0.6557

Epoch 2/14
----------
train Loss: 0.7955 Acc: 0.7217
val Loss: 0.9502 Acc: 0.6741

Epoch 3/14
----------
train Loss: 0.6601 Acc: 0.7668
val Loss: 0.8870 Acc: 0.6997

Epoch 4/14
----------
train Loss: 0.5535 Acc: 0.8049
val Loss: 0.9101 Acc: 0.7031

Epoch 5/14
----------
train Loss: 0.4575 Acc: 0.8372
val Loss: 0.8126 Acc: 0.7381

Epoch 6/14
----------
train Loss: 0.3671 Acc: 0.8693
val Loss: 0.8807 Acc: 0.7344

Epoch 7/14
----------
train Loss: 0.1551 Acc: 0.9525
val Loss: 0.7870 Acc: 0.7758

Epoch 8/14
----------
train Loss: 0.0804 Acc: 0.9780
val Loss: 0.8682 Acc: 0.7738

Epoch 9/14
----------
train Loss: 0.0464 Acc: 0.9896
val Loss: 0.9669 Acc: 0.7718

Epoch 10/14
----------
train Loss: 0.0252 Acc: 0.9953
val Loss: 1.0843 Acc: 0.7668

Epoch 11/14
----------
train Loss: 0.0133 Acc: 0.9980
val Loss: 1.2156 Acc: 0.7684

Ep

In [24]:
# Unlearning hyper-parameters, read as globals by `unlearning`.
sch = 'linear'                    # per-epoch LR schedule: 'cosine' | 'linear' | 'increase' | 'decrease'
init_rate = 0.3                   # fraction of each conv weight tensor to re-initialise / prune
init_method = 'snip_little_grad'  # re-initialisation strategy name
lr = 0.001                        # base SGD learning rate for fine-tuning
epoch = 5                         # fine-tuning epochs on the retain set
weight_decay = 5e-4               # SGD weight decay

from torch.optim.lr_scheduler import _LRScheduler

class LinearAnnealingLR(_LRScheduler):
    """Triangular LR schedule: linear warm-up, then linear decay to zero.

    With ``s = self._step_count`` (1 right after construction, +1 per ``step()``), each
    parameter group's learning rate is

    * ``base_lr * s / num_annealing_steps`` while ``s <= num_annealing_steps``, then
    * ``base_lr * (num_total_steps - s) / (num_total_steps - num_annealing_steps)``,

    clamped at 0 so that stepping past ``num_total_steps`` never produces a negative
    learning rate (which would make the optimiser step uphill).
    """

    def __init__(self, optimizer, num_annealing_steps, num_total_steps):
        self.num_annealing_steps = num_annealing_steps
        self.num_total_steps = num_total_steps

        super().__init__(optimizer)

    def get_lr(self):
        step = self._step_count
        if step <= self.num_annealing_steps:
            return [base_lr * step / self.num_annealing_steps for base_lr in self.base_lrs]
        remaining = self.num_total_steps - step
        decay_len = self.num_total_steps - self.num_annealing_steps
        return [max(0.0, base_lr * remaining / decay_len) for base_lr in self.base_lrs]

class IncreasingLR(_LRScheduler):
    """Linear warm-up to ``base_lr`` over ``num_total_steps`` scheduler steps.

    ``lr = base_lr * min(s, num_total_steps) / num_total_steps`` with
    ``s = self._step_count`` (1 right after construction, +1 per ``step()``); once the
    ramp is complete the rate stays at ``base_lr`` instead of overshooting it.
    """

    def __init__(self, optimizer, num_total_steps):
        self.num_total_steps = num_total_steps

        super().__init__(optimizer)

    def get_lr(self):
        step = min(self._step_count, self.num_total_steps)
        return [base_lr * step / self.num_total_steps for base_lr in self.base_lrs]

def mp_importance_score(model):  # weight magnitude
    """Magnitude importance scores ``|w|`` for the weight of every ``nn.Conv2d`` in ``model``.

    Returns:
        dict mapping ``(module, "weight")`` to a detached ``|weight|`` tensor, in the form
        passed as ``importance_scores`` to ``prune.global_unstructured`` by ``mp_prune``.
    """
    # The original referenced an undefined name `Conv2d` (NameError); use nn.Conv2d.
    return {
        (m, "weight"): m.weight.detach().abs()
        for m in model.modules()
        if isinstance(m, nn.Conv2d)
    }

def mp_prune(model, ratio):
    """Globally prune the fraction ``ratio`` of conv weights with the smallest magnitude.

    All ``nn.Conv2d`` weights are ranked together (one shared threshold) by
    ``prune.global_unstructured`` with ``L1Unstructured``. The pruning reparametrisation
    stays attached until ``remove_prune`` is called.
    """
    importance = mp_importance_score(model)
    targets = importance.keys()
    prune.global_unstructured(
        parameters=targets,
        pruning_method=prune.L1Unstructured,
        importance_scores=importance,
        amount=ratio,
    )

def remove_prune(model):
    """Make pruning permanent on every pruned ``nn.Conv2d`` in ``model``.

    ``prune.remove`` folds ``weight_orig * weight_mask`` into a plain ``weight`` parameter
    and drops the pruning hook. Conv layers that were never pruned (no ``weight_orig``)
    are skipped instead of making ``prune.remove`` raise ``ValueError``.

    NOTE(review): ``remove_prune`` is defined twice in this cell; the later definition is
    the one that takes effect — keep them in sync or delete one.
    """
    print("Remove hooks for multiplying masks (all conv layers)")
    for _, m in model.named_modules():
        if isinstance(m, nn.Conv2d) and hasattr(m, "weight_orig"):
            prune.remove(m, "weight")

def pruning_model_random(model, px):
    """Randomly prune a global fraction ``px`` of all ``nn.Conv2d`` weights.

    Uses ``prune.global_unstructured`` with ``RandomUnstructured``, so the pruned entries
    are drawn from the pooled weights of every conv layer at once.
    """
    print("Apply Unstructured Random Pruning Globally (all conv layers)")
    conv_weights = tuple(
        (module, "weight")
        for _, module in model.named_modules()
        if isinstance(module, nn.Conv2d)
    )
    prune.global_unstructured(
        conv_weights,
        pruning_method=prune.RandomUnstructured,
        amount=px,
    )

@torch.no_grad()
def re_init_model_random(model, px):
    """Re-initialise a random fraction ``px`` of every ``nn.Conv2d`` weight, in place.

    Per conv layer, ``round(px * numel)`` entries are chosen uniformly at random (top-k
    of uniform noise) and overwritten with the matching entries of a freshly built
    ``nn.Conv2d`` of the same weight shape (PyTorch's default conv init). The selection
    is per layer even though the log line says "Globally".

    Tensors are created on each layer's own device, and the full ``(kh, kw)`` kernel
    shape is used, so non-square kernels and models not on ``DEVICE`` work too.
    """
    print("Apply Unstructured random re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        noise = torch.rand_like(weight, dtype=torch.float)
        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(noise.view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]

@torch.no_grad()
def re_init_model_random_diffrent_init(model, px):
    """Re-initialise a random fraction ``px`` of every ``nn.Conv2d`` weight with Kaiming-normal values.

    Per conv layer, ``round(px * numel)`` entries are chosen uniformly at random and
    overwritten with values from ``nn.init.kaiming_normal_(mode="fan_out",
    nonlinearity="relu")`` drawn into a tensor of the weight's own shape, dtype and device.
    """
    print("Apply Unstructured random re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        k = round(px * weight.nelement())

        noise = torch.rand_like(weight, dtype=torch.float)
        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(noise.view(-1), k=k).indices] = True

        fresh = torch.empty_like(weight)
        nn.init.kaiming_normal_(fresh, mode="fan_out", nonlinearity="relu")
        weight.data[mask] = fresh[mask]

@torch.no_grad()
def re_init_model_mp(model, px):  # re init smallest weight
    """Re-initialise the smallest-magnitude fraction ``px`` of every ``nn.Conv2d`` weight.

    Per conv layer, the ``round(px * numel)`` entries with the smallest ``|w|`` are
    overwritten with the matching entries of a freshly built ``nn.Conv2d`` (default init)
    of the same ``(kh, kw)`` kernel shape, created on the layer's own device.

    NOTE(review): this function is defined twice in this cell; the later definition is the
    one that takes effect.
    """
    print("Apply Unstructured mp re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(-weight.abs().view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]

@torch.no_grad()
def re_init_model_mp_ver2(model, px):  # re init Biggest weight
    """Re-initialise the largest-magnitude fraction ``px`` of every ``nn.Conv2d`` weight.

    Per conv layer, the ``round(px * numel)`` entries with the largest ``|w|`` are
    overwritten with the matching entries of a freshly built ``nn.Conv2d`` (default init)
    of the same ``(kh, kw)`` kernel shape, created on the layer's own device.
    """
    print("Apply Unstructured mp re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(weight.abs().view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]

@torch.no_grad()
def re_init_model_mp(model, px):  # re init smallest weight
    """Re-initialise the smallest-magnitude fraction ``px`` of every ``nn.Conv2d`` weight.

    Per conv layer, the ``round(px * numel)`` entries with the smallest ``|w|`` are
    overwritten with the matching entries of a freshly built ``nn.Conv2d`` (default init)
    of the same ``(kh, kw)`` kernel shape, created on the layer's own device.

    NOTE(review): this function is defined twice in this cell; this later definition is
    the one that takes effect.
    """
    print("Apply Unstructured mp re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(-weight.abs().view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]

@torch.no_grad()
def re_init_model_snip(model, px):  # re init Biggest gradients
    """Re-initialise, per ``nn.Conv2d``, the fraction ``px`` of weights with the largest ``|grad|``.

    Requires ``.grad`` to be populated (e.g. by ``get_grads_for_snip``). Selected entries
    are overwritten with those of a freshly built ``nn.Conv2d`` (default init) of the same
    ``(kh, kw)`` kernel shape on the layer's own device.

    Raises:
        RuntimeError: if a conv weight has no gradient.
    """
    # Log message previously said "mp" (copy-paste); this is the SNIP variant.
    print("Apply Unstructured snip re init no grads Globally (all conv layers)")
    for name, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        if weight.grad is None:
            raise RuntimeError(f"{name}: weight.grad is None; compute gradients (e.g. get_grads_for_snip) first")
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(weight.grad.abs().view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]

@torch.no_grad()
def re_init_model_snip_ver2(model, px):  # re init smallest gradients
    """Re-initialise, per ``nn.Conv2d``, the fraction ``px`` of weights with the smallest ``|grad|``.

    Requires ``.grad`` to be populated (e.g. by ``get_grads_for_snip``). Selected entries
    are overwritten with those of a freshly built ``nn.Conv2d`` (default init) of the same
    ``(kh, kw)`` kernel shape on the layer's own device.

    Raises:
        RuntimeError: if a conv weight has no gradient.
    """
    print("Apply Unstructured snip ve2 re init no grads Globally (all conv layers)")
    for name, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        if weight.grad is None:
            raise RuntimeError(f"{name}: weight.grad is None; compute gradients (e.g. get_grads_for_snip) first")
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(-weight.grad.abs().view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]
def get_grads_for_snip(model, retain_loader, forget_loader):
    """Accumulate SNIP-style gradients into ``model``'s ``.grad`` fields.

    The accumulated gradient is that of
    ``sum_b CE(retain batch b) - sum_b CE(forget batch b)``: the retain batches come from a
    random subset of ``retain_loader.dataset`` with as many samples as the forget dataset
    (batch size 64), the forget batches come from ``forget_loader`` unchanged. Existing
    gradients are cleared first; no optimiser step is taken. Inputs are moved to the
    device of ``model``'s first parameter.
    """
    # Imported here so the function does not depend on a later notebook cell.
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, Subset

    device = next(model.parameters()).device

    n_subset = len(forget_loader.dataset)
    indices = torch.randperm(len(retain_loader.dataset))[:n_subset].tolist()
    balanced_retain_loader = DataLoader(
        Subset(retain_loader.dataset, indices), batch_size=64, shuffle=True
    )

    model.zero_grad()
    for sample in balanced_retain_loader:
        inputs, targets = sample[0].to(device), sample[1].to(device)
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()

    # Negated loss: gradients push *away* from fitting the forget samples.
    for sample in forget_loader:
        inputs, targets = sample[0].to(device), sample[1].to(device)
        loss = -F.cross_entropy(model(inputs), targets)
        loss.backward()

def get_all_grads_for_snip(model, retain_loader, forget_loader):
    """Accumulate SNIP-style gradients using the *whole* retain loader.

    Each retain batch's mean cross-entropy is scaled by
    ``len(forget dataset) / len(retain dataset)`` to balance it against the (negated)
    forget losses. Existing gradients are cleared first; no optimiser step is taken.
    Inputs are moved to the device of ``model``'s first parameter.
    """
    # Imported here so the function does not depend on a later notebook cell.
    import torch.nn.functional as F

    device = next(model.parameters()).device
    retain_weight = len(forget_loader.dataset) / len(retain_loader.dataset)

    model.zero_grad()
    for sample in retain_loader:
        inputs, targets = sample[0].to(device), sample[1].to(device)
        loss = retain_weight * F.cross_entropy(model(inputs), targets)
        loss.backward()

    for sample in forget_loader:
        inputs, targets = sample[0].to(device), sample[1].to(device)
        loss = -F.cross_entropy(model(inputs), targets)
        loss.backward()
from functools import partial
from types import MethodType

def zero_grads(mask, g):
    """Gradient hook: keep ``g`` where ``mask`` is set and zero it elsewhere (``g * mask``)."""
    return torch.mul(g, mask)

def zero_grads_ver2(mask, g):
    """Gradient hook: return a new tensor equal to ``g`` with entries outside ``mask`` set to 0."""
    return g.masked_fill(~mask, 0)

def zero_grads_ver3(self, g):
    """Bound gradient hook: copy of ``g`` with entries outside ``self.mask`` set to 0."""
    return g.masked_fill(~self.mask, 0)

@torch.no_grad()
def re_init_model_random_zero_grad(model, px):
    """Re-initialise a random fraction ``px`` of each conv weight and let only those entries train.

    Selection and re-init match ``re_init_model_random`` (default ``nn.Conv2d`` init, layer's
    own device, full ``(kh, kw)`` kernel). Each conv weight additionally gets a tensor hook
    ``partial(zero_grads, mask)`` that multiplies its gradient by the boolean mask.

    NOTE(review): the hook masks only the gradient; optimiser-side weight decay / momentum
    can still move the "frozen" entries. Calling this twice stacks hooks.
    """
    print('re_init_model_random_zero_grad')
    print("Apply Unstructured random re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        noise = torch.rand_like(weight, dtype=torch.float)
        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(noise.view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]
        weight.register_hook(partial(zero_grads, mask))

@torch.no_grad()
def re_init_model_random_zero_grad_ver2(model, px):
    """Same as ``re_init_model_random_zero_grad`` but the hook is ``zero_grads_ver2``.

    A random fraction ``px`` of each conv weight is re-initialised (default ``nn.Conv2d``
    init, layer's own device, full ``(kh, kw)`` kernel), and a tensor hook zeroes the
    gradient outside that boolean mask via an explicit copy.

    NOTE(review): optimiser-side weight decay / momentum can still move the masked-out
    entries. Calling this twice stacks hooks.
    """
    print('re_init_model_random_zero_grad_ver2')
    print("Apply Unstructured random re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        noise = torch.rand_like(weight, dtype=torch.float)
        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(noise.view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]
        weight.register_hook(partial(zero_grads_ver2, mask))

@torch.no_grad()
def re_init_model_random_zero_grad_ver3(model, px):
    """Same as ``re_init_model_random_zero_grad`` but the hook is a bound method.

    A random fraction ``px`` of each conv weight is re-initialised (default ``nn.Conv2d``
    init, layer's own device, full ``(kh, kw)`` kernel). The mask is stored on the module
    as ``m.mask`` and ``zero_grads_ver3`` is bound to the module as ``m.backward_fn``, which
    is registered as the weight's gradient hook.

    NOTE(review): optimiser-side weight decay / momentum can still move the masked-out
    entries. Calling this twice stacks hooks.
    """
    print('re_init_model_random_zero_grad_ver3')
    print("Apply Unstructured random re init no grads Globally (all conv layers)")
    for _, m in model.named_modules():
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        out_c, in_c, kh, kw = weight.shape
        k = round(px * weight.nelement())

        noise = torch.rand_like(weight, dtype=torch.float)
        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(noise.view(-1), k=k).indices] = True

        fresh = nn.Conv2d(in_c, out_c, (kh, kw), device=weight.device).weight
        weight.data[mask] = fresh[mask]

        # The hook reads the mask through the module it is bound to.
        m.mask = mask
        m.backward_fn = MethodType(zero_grads_ver3, m)
        weight.register_hook(m.backward_fn)
class Masker(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``mask`` in backward.

    ``mask`` may be boolean (hard gradient gate) or float (per-entry gradient scaling).
    """

    @staticmethod
    def forward(ctx, x, mask):
        ctx.save_for_backward(mask)
        return x

    @staticmethod
    def backward(ctx, grad):
        mask, = ctx.saved_tensors
        return grad * mask, None  # no gradient w.r.t. the mask itself


class MaskConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight gradient is multiplied element-wise by ``mask``.

    The forward output equals a plain convolution with the same weight and bias; only
    backward is affected (through ``Masker``). ``mask`` is registered as a non-persistent
    buffer so ``.to(device)`` moves it with the weights, while ``state_dict()`` keys stay
    identical to those of ``nn.Conv2d``.
    """

    def __init__(self, mask, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device='cpu'):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode, device=device)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, input):
        masked_weight = Masker.apply(self.weight, self.mask)
        return self._conv_forward(input, masked_weight, self.bias)

@torch.no_grad()
def re_init_model_random_zero_grad_Maskconv(model, px):
    """Swap every ``nn.Conv2d`` for a ``MaskConv2d`` with a random fraction ``px`` re-initialised.

    Per conv layer, a random boolean ``mask`` of ``round(px * numel)`` entries keeps the
    fresh default initialisation of the new ``MaskConv2d``; all other weight entries, and
    the bias if present, are copied from the old layer. Because the mask is boolean, only
    the re-initialised entries receive gradients.
    """
    print("Apply Unstructured random re init no grads Globally (all conv layers)")
    for name, m in list(model.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        k = round(px * weight.nelement())

        noise = torch.rand_like(weight, dtype=torch.float)
        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(noise.view(-1), k=k).indices] = True

        new_conv = MaskConv2d(mask, m.in_channels, m.out_channels, m.kernel_size, m.stride,
                              m.padding, m.dilation, m.groups, m.bias is not None,
                              m.padding_mode, device=weight.device)
        new_conv.weight.data[~mask] = weight[~mask]
        if m.bias is not None:
            new_conv.bias.data.copy_(m.bias)
        # `setattr(model, "layer1.0.conv1", ...)` only registers an unused child with a dotted
        # name on the root; set_layer replaces the nested conv that forward() actually uses.
        set_layer(model, name, new_conv)

@torch.no_grad()
def re_init_model_random_little_grad_Maskconv_fanout(model, px):
    """Swap every ``nn.Conv2d`` for a ``MaskConv2d`` with a random fraction ``px`` re-initialised.

    The selected entries take Kaiming-normal (``fan_out``, ``relu``) values; all other
    weight entries, and the bias if present, are copied from the old layer. The gradient
    multiplier is 1.0 on re-initialised entries and 0.1 on the kept ones, so old weights
    still train, only more slowly.
    """
    print("Apply Unstructured re_init_model_random_little_grad_Maskconv_fanout (all conv layers)")
    for name, m in list(model.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        k = round(px * weight.nelement())

        noise = torch.rand_like(weight, dtype=torch.float)
        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(noise.view(-1), k=k).indices] = True
        grad_mask = mask.float().masked_fill(~mask, 0.1)

        new_conv = MaskConv2d(grad_mask, m.in_channels, m.out_channels, m.kernel_size, m.stride,
                              m.padding, m.dilation, m.groups, m.bias is not None,
                              m.padding_mode, device=weight.device)
        nn.init.kaiming_normal_(new_conv.weight, mode="fan_out", nonlinearity="relu")
        new_conv.weight.data[~mask] = weight[~mask]
        if m.bias is not None:
            new_conv.bias.data.copy_(m.bias)
        # `setattr(model, "layer1.0.conv1", ...)` only registers an unused child with a dotted
        # name on the root; set_layer replaces the nested conv that forward() actually uses.
        set_layer(model, name, new_conv)
@torch.no_grad()
def re_init_model_snip_ver2_zero_grad(model, px):  # re init smallest gradients
    """Swap each ``nn.Conv2d`` for a ``MaskConv2d`` that re-initialises the smallest-``|grad|`` weights.

    Per conv layer, the ``round(px * numel)`` entries with the smallest ``|grad|`` take
    Kaiming-normal (``fan_out``, ``relu``) values; all other entries, and the bias if any,
    are copied from the old layer. The boolean mask means only re-initialised entries
    receive gradients afterwards.

    Raises:
        RuntimeError: if a conv weight has no gradient (run ``get_grads_for_snip`` first).
    """
    print("Apply Unstructured re_init_model_snip_ver2_zero_grad Globally (all conv layers)")
    for name, m in list(model.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        if weight.grad is None:
            raise RuntimeError(f"{name}: weight.grad is None; compute gradients (e.g. get_grads_for_snip) first")
        k = round(px * weight.nelement())

        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(-weight.grad.abs().view(-1), k=k).indices] = True

        new_conv = MaskConv2d(mask, m.in_channels, m.out_channels, m.kernel_size, m.stride,
                              m.padding, m.dilation, m.groups, m.bias is not None,
                              m.padding_mode, device=weight.device)
        nn.init.kaiming_normal_(new_conv.weight, mode="fan_out", nonlinearity="relu")

        new_conv.weight.data[~mask] = weight[~mask]
        if m.bias is not None:
            new_conv.bias.data.copy_(m.bias)
        set_layer(model, name, new_conv)

@torch.no_grad()
def re_init_model_snip_ver2_little_grad(model, px):  # re init smallest gradients
    """Swap each ``nn.Conv2d`` for a ``MaskConv2d`` that re-initialises the smallest-``|grad|`` weights.

    Per conv layer, the ``round(px * numel)`` entries with the smallest ``|grad|`` take
    Kaiming-normal (``fan_out``, ``relu``) values; all other entries, and the bias if any,
    are copied from the old layer. The gradient multiplier is 1.0 on re-initialised entries
    and 0.1 on the kept ones.

    Raises:
        RuntimeError: if a conv weight has no gradient (run ``get_grads_for_snip`` first).
    """
    print("Apply Unstructured re_init_model_snip_ver2_little_grad Globally (all conv layers)")
    for name, m in list(model.named_modules()):
        if not isinstance(m, nn.Conv2d):
            continue
        weight = m.weight
        if weight.grad is None:
            raise RuntimeError(f"{name}: weight.grad is None; compute gradients (e.g. get_grads_for_snip) first")
        k = round(px * weight.nelement())

        mask = torch.zeros_like(weight, dtype=torch.bool)
        mask.view(-1)[torch.topk(-weight.grad.abs().view(-1), k=k).indices] = True
        grad_mask = mask.float().masked_fill(~mask, 0.1)

        new_conv = MaskConv2d(grad_mask, m.in_channels, m.out_channels, m.kernel_size, m.stride,
                              m.padding, m.dilation, m.groups, m.bias is not None,
                              m.padding_mode, device=weight.device)
        nn.init.kaiming_normal_(new_conv.weight, mode="fan_out", nonlinearity="relu")

        new_conv.weight.data[~mask] = weight[~mask]
        if m.bias is not None:
            new_conv.bias.data.copy_(m.bias)

        set_layer(model, name, new_conv)

def set_layer(model, layer_name, layer):
    """Replace the sub-module at dotted path ``layer_name`` with ``layer``.

    ``layer_name`` is a name as yielded by ``model.named_modules()`` (e.g. ``"conv1"``,
    ``"layer1.0.conv1"``, ``"layer2.0.downsample.0"``); any nesting depth is supported,
    including integer components that index into ``nn.Sequential`` / ``nn.ModuleList``.

    Raises:
        ValueError: if ``layer_name`` is empty (the root itself cannot be replaced).
        AttributeError: if an intermediate component does not exist.
    """
    if not layer_name:
        raise ValueError("set_layer(): layer_name must name a sub-module, not the root")
    parent_name, _, child_name = layer_name.rpartition('.')
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, layer)

@torch.no_grad()
def replace_maskconv(model):
    """Swap every ``MaskConv2d`` in ``model`` back to a plain ``nn.Conv2d``.

    The replacement has the same hyper-parameters, is created on the old layer's device
    and dtype, shares the old weight storage and reuses its bias parameter. The gradient
    mask is dropped.
    """
    print("Remove Maskconv")
    for name, m in list(model.named_modules()):
        if not isinstance(m, MaskConv2d):
            continue
        conv = nn.Conv2d(m.in_channels, m.out_channels, m.kernel_size, m.stride,
                         m.padding, m.dilation, m.groups, m.bias is not None, m.padding_mode,
                         device=m.weight.device, dtype=m.weight.dtype)
        conv.weight.data = m.weight.data
        conv.bias = m.bias
        set_layer(model, name, conv)
@torch.no_grad()
def reinit_topn(model, n):
    """Kaiming-normal (``fan_out``, ``relu``) re-initialise the weights of the last ``n`` listed convs.

    The candidates are five ResNet-18 ``layer4`` conv names, in the order below; ``n`` is
    clamped to ``[0, 5]`` and the last ``n`` of them are re-initialised (biases untouched).
    """
    print(f"Reinit top {n} conv2d(only 5 layer available)")
    candidates = ['layer4.0.conv1', 'layer4.0.conv2', 'layer4.0.downsample.0', 'layer4.1.conv1', 'layer4.1.conv2']
    # `candidates[-n:]` would select ALL layers for n == 0; slice from an explicit start instead.
    n = max(0, min(n, len(candidates)))
    target_layer = set(candidates[len(candidates) - n:])
    for name, m in list(model.named_modules()):
        if name in target_layer:
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
@torch.no_grad()
def reinit_batchnorm(model):
    """Reset the affine parameters of every ``nn.BatchNorm2d``: weight -> 1, bias -> 0.

    NOTE(review): running_mean / running_var are not touched; add
    ``m.reset_running_stats()`` if a full BatchNorm reset is intended.
    """
    print("Reinit BatchNorm2d")
    batchnorms = [m for _, m in model.named_modules() if isinstance(m, nn.BatchNorm2d)]
    for bn in batchnorms:
        nn.init.ones_(bn.weight)
        nn.init.zeros_(bn.bias)
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def snip_importance_score(model):
    """Scores ``-|grad|`` for every ``nn.Conv2d`` weight, keyed by ``(module, "weight")``.

    Requires ``.grad`` to be populated. When these scores are fed to
    ``prune.L1Unstructured`` the sign is irrelevant (that method ranks by absolute value),
    so the entries with the smallest ``|grad|`` are the ones pruned.
    """
    return {
        (conv, "weight"): conv.weight.grad.abs().neg()
        for conv in model.modules()
        if isinstance(conv, nn.Conv2d)
    }

@torch.no_grad()
def prune_model_snip_ver2(model, px):  # prune smallest gradients
    """Globally prune the fraction ``px`` of conv weights with the smallest ``|grad|``.

    Scores from ``snip_importance_score`` are passed to ``prune.global_unstructured`` with
    ``L1Unstructured``, so one threshold is shared by all ``nn.Conv2d`` layers. The pruning
    reparametrisation stays attached until ``remove_prune`` is called.
    """
    print("Apply Unstructured snip ver2 prune grads Globally (all conv layers)")

    scores = snip_importance_score(model)
    prune.global_unstructured(
        parameters=list(scores),
        pruning_method=prune.L1Unstructured,
        importance_scores=scores,
        amount=px,
    )

def remove_prune(model):
    """Make pruning permanent on every pruned ``nn.Conv2d`` in ``model``.

    ``prune.remove`` folds ``weight_orig * weight_mask`` into a plain ``weight`` parameter
    and drops the pruning hook. Conv layers that were never pruned (no ``weight_orig``)
    are skipped instead of making ``prune.remove`` raise ``ValueError``.

    NOTE(review): ``remove_prune`` is defined twice in this cell; this later definition is
    the one that takes effect — keep them in sync or delete one.
    """
    print("Remove hooks for multiplying masks (all conv layers)")
    for _, m in model.named_modules():
        if isinstance(m, nn.Conv2d) and hasattr(m, "weight_orig"):
            prune.remove(m, "weight")


In [27]:
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, Subset
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.models import resnet18
from torch.utils.data import DataLoader, Dataset
def _unlearning_reinit(net, retain_loader, forget_loader):
    """Apply the re-initialisation / pruning step selected by the global ``init_method``.

    Uses the global ``init_rate`` as the fraction. SNIP-based variants first accumulate
    gradients from ``retain_loader`` / ``forget_loader``.

    Raises:
        NotImplementedError: if ``init_method`` is not a supported name.
    """
    if init_method == 'random':
        re_init_model_random(net, init_rate)
    elif init_method == 'random_init2':
        re_init_model_random_diffrent_init(net, init_rate)
    elif init_method == 'mp':
        re_init_model_mp(net, init_rate)
    elif init_method == 'reversed_mp':
        re_init_model_mp_ver2(net, init_rate)
    elif init_method == 'snip':
        get_grads_for_snip(net, retain_loader, forget_loader)
        re_init_model_snip_ver2(net, init_rate)
    elif init_method == 'random_zero_grad':
        re_init_model_random_zero_grad(net, init_rate)
    elif init_method == 'random_zero_grad_ver2':
        re_init_model_random_zero_grad_ver2(net, init_rate)
    elif init_method == 'random_zero_grad_ver3':
        re_init_model_random_zero_grad_ver3(net, init_rate)
    elif init_method == 'maskconv':
        re_init_model_random_zero_grad_Maskconv(net, init_rate)
    elif init_method == 'little_grad':
        re_init_model_random_little_grad_Maskconv_fanout(net, init_rate)
    elif init_method == 'snip_zero_grad':
        replace_maskconv(net)
        get_grads_for_snip(net, retain_loader, forget_loader)
        re_init_model_snip_ver2_zero_grad(net, init_rate)
    elif init_method == 'snip_little_grad':
        replace_maskconv(net)
        get_grads_for_snip(net, retain_loader, forget_loader)
        re_init_model_snip_ver2_little_grad(net, init_rate)
    elif init_method == 'snip_and_reinit_bn':
        get_grads_for_snip(net, retain_loader, forget_loader)
        re_init_model_snip_ver2(net, init_rate)
        reinit_batchnorm(net)
    elif init_method == 'reinit_bn':
        reinit_batchnorm(net)
    elif init_method == 'snip_and_topn_reinit':
        get_grads_for_snip(net, retain_loader, forget_loader)
        re_init_model_snip_ver2(net, init_rate)
        reinit_topn(net, 1)
    elif init_method == 'prune_snip':
        get_grads_for_snip(net, retain_loader, forget_loader)
        prune_model_snip_ver2(net, init_rate)
    elif init_method == 'all_snip_reinit':
        get_all_grads_for_snip(net, retain_loader, forget_loader)
        re_init_model_snip_ver2(net, init_rate)
    else:
        # `raise "not implemented"` raised TypeError (a str is not an exception).
        raise NotImplementedError(f"init_method={init_method!r} is not implemented")


def _unlearning_scheduler(optimizer, epochs):
    """Build the per-epoch LR scheduler selected by the global ``sch``.

    Raises:
        ValueError: if ``sch`` is not 'cosine', 'linear', 'increase' or 'decrease'.
    """
    if sch == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    if sch == 'linear':
        return LinearAnnealingLR(optimizer, num_annealing_steps=(epochs + 1) // 2, num_total_steps=epochs + 1)
    if sch == 'increase':
        return IncreasingLR(optimizer, num_total_steps=epochs)
    if sch == 'decrease':
        print('decrease')
        return LinearAnnealingLR(optimizer, num_annealing_steps=1, num_total_steps=epochs + 1)
    raise ValueError(f"unknown LR schedule sch={sch!r}")


def unlearning(
    net,
    retain_loader,
    forget_loader,
    val_loader):
    """Simple unlearning: partial re-initialisation followed by fine-tuning on the retain set.

    1. Re-initialise / prune part of ``net`` according to the global ``init_method`` (fraction
       ``init_rate``).
    2. Fine-tune on ``retain_loader`` for ``epoch`` epochs with SGD (``lr``, momentum 0.9,
       ``weight_decay``), stepping the LR schedule selected by ``sch`` once per epoch.

    ``net`` is modified in place and left in eval mode.

    Args:
        net: model to unlearn (mutated in place).
        retain_loader: loader over the data to keep.
        forget_loader: loader over the data to forget.
        val_loader: currently unused; kept for interface compatibility.

    Returns:
        ``net``.

    Raises:
        ValueError: unknown ``sch`` (checked before ``net`` is touched).
        NotImplementedError: unknown ``init_method``.
    """
    if sch not in ('cosine', 'linear', 'increase', 'decrease'):
        # Fail before mutating `net`; an unknown schedule used to surface only as an
        # UnboundLocalError after a full epoch of fine-tuning.
        raise ValueError(f"unknown LR schedule sch={sch!r}")

    _unlearning_reinit(net, retain_loader, forget_loader)

    epochs = epoch
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr,
                          momentum=0.9, weight_decay=weight_decay)
    scheduler = _unlearning_scheduler(optimizer, epochs)

    net.train()

    for _ in range(epochs):
        for sample in retain_loader:
            inputs, targets = sample[0].to(DEVICE), sample[1].to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)
            loss.backward()
            optimizer.step()

        scheduler.step()

    net.eval()
    return net

In [28]:
%%time
model_ft_forget = unlearning(model_ft, retain_loader, forget_loader, test_loader)

Remove Maskconv
Apply Unstructured re_init_model_snip_ver2_little_grad Globally (all conv layers)
CPU times: user 58.4 s, sys: 1.96 s, total: 1min
Wall time: 1min 27s


In [12]:
# Sanity check on the class of the forget loader.
forget_loader.__class__

torch.utils.data.dataloader.DataLoader