In [1]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [2]:
import torch
import numpy as np
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import Dataset
import copy
import random
import time
import pickle
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

In [3]:
to_forget = 81
num_classes = 100
max_count = -1
in_size = 3

torch.manual_seed(42)

<torch._C.Generator at 0x7ae59580f310>

In [4]:
class IndexingDataset(Dataset):
    def __init__(self, internal_dataset):
        self.dataset = internal_dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, sample_index):
        r = self.dataset[sample_index]
        if not isinstance(r, tuple):
            r = (r,)
        return *r, sample_index
    
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

# Using MNIST
data = datasets.CIFAR100('/data', download=True, train=True, transform=transform)
traindata = IndexingDataset(data)
testdata = datasets.CIFAR100('/data', download=True, train=False, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:01<00:00, 105732000.59it/s]


Extracting /data/cifar-100-python.tar.gz to /data
Files already downloaded and verified


In [5]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=False)

target_index = []
nontarget_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == to_forget:
        target_index.append(i)
    else:
        nontarget_index.append(i)
target_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))

target_index = []
nontarget_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != to_forget:
        target_index.append(i)
        nontarget_index.append(i)
    if traindata[i][1] == to_forget and (count < max_count or max_count < 1):
        count += 1
        target_index.append(i)
target_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))


unlearningdata = copy.deepcopy(data)
unlearninglabels = list(range(num_classes))
unlearninglabels.remove(to_forget)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == to_forget:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(IndexingDataset(unlearningdata), batch_size=64, shuffle=True)

In [6]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [7]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
P=.1
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [8]:
# Training method
stuff = {}
def train(model, epoch, loader, returnable=False, keep_p=.1):
    model.train()
    rng = np.random.default_rng(42)
    if returnable:
        batches = []
    for batch_idx, (data, target, samples_idx) in enumerate(loader):
        optimizer.zero_grad()
        if to_forget in target:
            before = {}
            for key, param in model.named_parameters():
                before[key] = param.clone()
        data = data.to(device)
        output = model(data)
        loss = criterion(output, target.to(device))
        loss.backward()
        
        optimizer.step()
        
        with torch.no_grad():
            if to_forget in target:
                batches.append(batch_idx)
                step = {}
                for key, param in model.named_parameters():
                    if key not in stuff:
                        stuff[key] = torch.zeros(param.shape, device='cpu')
                    diff = (param - before[key]).cpu().flatten()
                    size = diff.shape.numel()
                    subset = rng.choice(size, int(size * keep_p), replace=False, shuffle=False)
                    stuff[key].view(-1)[subset] += diff[subset]
                #torch.save(step, f"steps/e{epoch}b{batches[-1]:04}.pkl")
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
    if returnable:
        return batches

In [9]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

In [10]:
trainingepochs = 4
forgetfulepochs = 4

In [11]:
# load resnet 18 and change to fit problem dimensionality
model = models.resnet18()
model.bn1 = nn.GroupNorm(1, model.bn1.weight.shape[0])
model.layer1[0].bn1 = nn.GroupNorm(1, model.layer1[0].bn1.weight.shape[0])
model.layer1[0].bn2 = nn.GroupNorm(1, model.layer1[0].bn2.weight.shape[0])
model.layer1[1].bn1 = nn.GroupNorm(1, model.layer1[1].bn1.weight.shape[0])
model.layer1[1].bn2 = nn.GroupNorm(1, model.layer1[1].bn2.weight.shape[0])

model.layer2[0].bn1 = nn.GroupNorm(1, model.layer2[0].bn1.weight.shape[0])
model.layer2[0].bn2 = nn.GroupNorm(1, model.layer2[0].bn2.weight.shape[0])
model.layer2[0].downsample[1] = nn.GroupNorm(1, model.layer2[0].downsample[1].weight.shape[0])
model.layer2[1].bn1 = nn.GroupNorm(1, model.layer2[1].bn1.weight.shape[0])
model.layer2[1].bn2 = nn.GroupNorm(1, model.layer2[1].bn2.weight.shape[0])

model.layer3[0].bn1 = nn.GroupNorm(1, model.layer3[0].bn1.weight.shape[0])
model.layer3[0].bn2 = nn.GroupNorm(1, model.layer3[0].bn2.weight.shape[0])
model.layer3[0].downsample[1] = nn.GroupNorm(1, model.layer3[0].downsample[1].weight.shape[0])
model.layer3[1].bn1 = nn.GroupNorm(1, model.layer3[1].bn1.weight.shape[0])
model.layer3[1].bn2 = nn.GroupNorm(1, model.layer3[1].bn2.weight.shape[0])

model.layer4[0].bn1 = nn.GroupNorm(1, model.layer4[0].bn1.weight.shape[0])
model.layer4[0].bn2 = nn.GroupNorm(1, model.layer4[0].bn2.weight.shape[0])
model.layer4[0].downsample[1] = nn.GroupNorm(1, model.layer4[0].downsample[1].weight.shape[0])
model.layer4[1].bn1 = nn.GroupNorm(1, model.layer4[1].bn1.weight.shape[0])
model.layer4[1].bn2 = nn.GroupNorm(1, model.layer4[1].bn2.weight.shape[0])

model.conv1 = nn.Conv2d(in_size, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
model.fc = nn.Sequential(nn.Linear(512, num_classes))

device = "cuda" if torch.cuda.is_available() else 'cpu'

#model = SimpleModel(in_size, num_classes)

model = model.to(device)

optimizer = optim.Adam(model.parameters())

In [12]:
# Train new model for 5 epochs
epoch_indices = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    batches = train(model, epoch, target_train_loader, returnable=True, keep_p=P)
    print(f"{batches} batches effected")
    epoch_indices.append(batches)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")
    print(f"Time taken: {time.process_time() - starttime}")

Epoch: 1 [ 49152]	Loss: 3.566571[0, 1, 3, 4, 7, 8, 9, 12, 13, 14, 17, 24, 25, 26, 27, 31, 32, 33, 37, 39, 40, 41, 45, 47, 48, 51, 52, 54, 56, 58, 60, 65, 66, 68, 69, 71, 77, 78, 79, 80, 85, 87, 94, 100, 101, 102, 103, 105, 106, 107, 108, 111, 112, 119, 120, 123, 124, 125, 126, 127, 130, 131, 132, 136, 137, 140, 142, 143, 147, 151, 153, 155, 156, 160, 161, 162, 163, 169, 170, 172, 173, 174, 175, 176, 177, 178, 180, 183, 184, 185, 186, 188, 190, 191, 193, 198, 199, 200, 201, 204, 210, 211, 212, 213, 215, 217, 220, 221, 222, 224, 225, 226, 227, 229, 230, 232, 235, 236, 237, 243, 248, 250, 251, 252, 253, 254, 255, 256, 257, 259, 261, 264, 265, 267, 272, 275, 281, 285, 287, 289, 290, 292, 293, 296, 300, 301, 303, 305, 308, 310, 312, 321, 322, 323, 324, 327, 328, 329, 333, 336, 338, 340, 344, 346, 347, 348, 350, 351, 352, 353, 354, 355, 356, 357, 358, 360, 361, 362, 365, 367, 370, 376, 377, 382, 384, 385, 386, 389, 390, 392, 393, 394, 396, 397, 399, 403, 406, 407, 411, 412, 421, 422, 425, 42

In [13]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [14]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [15]:
#for i in range(1, trainingepochs+1):
#    for j in epoch_indices[i-1]:
#        path = f"steps/e{i}b{j:04}.pkl"
#        steps = torch.load(path)
#        print(f"\rLoading steps/e{i}b{j:04}.pkl", end="")
#        const = 1
with torch.no_grad():
    for key, param in model.named_parameters():
        param -= stuff[key].to(device)

In [16]:
test(model, all_data_test_loader, dname="All data")
test(model, target_test_loader, dname="Forget  ")
test(model, nontarget_test_loader, dname="Retain  ")

All data: Mean loss: 0.0447, Accuracy: 3019/10000 (30%)
Forget  : Mean loss: 0.0012, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0437, Accuracy: 3019/9900 (30%)


tensor(0.305, device='cuda:0')

In [17]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [18]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [19]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    _ = train(model, epoch, nontarget_train_loader, returnable=True)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")

Epoch: 5 [ 49152]	Loss: 2.096139All data: Mean loss: 0.0403, Accuracy: 3507/10000 (35%)
Forget  : Mean loss: 0.0013, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0392, Accuracy: 3507/9900 (35%)
Epoch: 6 [ 49152]	Loss: 1.976110All data: Mean loss: 0.0388, Accuracy: 3833/10000 (38%)
Forget  : Mean loss: 0.0016, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0374, Accuracy: 3833/9900 (39%)
Epoch: 7 [ 49152]	Loss: 1.570183All data: Mean loss: 0.0391, Accuracy: 3920/10000 (39%)
Forget  : Mean loss: 0.0017, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0377, Accuracy: 3920/9900 (40%)
Epoch: 8 [ 49152]	Loss: 1.563240All data: Mean loss: 0.0406, Accuracy: 3873/10000 (39%)
Forget  : Mean loss: 0.0019, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0390, Accuracy: 3873/9900 (39%)
