In [39]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [40]:
import torch
import numpy as np
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import Dataset
import copy
import random
import time
import pickle
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

In [41]:
to_forget = 81
num_classes = 100
max_count = -1
in_size = 3

torch.manual_seed(42)

<torch._C.Generator at 0x79802124b2f0>

In [42]:
class IndexingDataset(Dataset):
    def __init__(self, internal_dataset):
        self.dataset = internal_dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, sample_index):
        r = self.dataset[sample_index]
        if not isinstance(r, tuple):
            r = (r,)
        return *r, sample_index
    
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

# Using MNIST
data = datasets.CIFAR100('/data', download=True, train=True, transform=transform)
traindata = IndexingDataset(data)
testdata = datasets.CIFAR100('/data', download=True, train=False, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:10<00:00, 15687434.33it/s]


Extracting /data/cifar-100-python.tar.gz to /data
Files already downloaded and verified


In [43]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=False)

target_index = []
nontarget_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == to_forget:
        target_index.append(i)
    else:
        nontarget_index.append(i)
target_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))

target_index = []
nontarget_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != to_forget:
        target_index.append(i)
        nontarget_index.append(i)
    if traindata[i][1] == to_forget and (count < max_count or max_count < 1):
        count += 1
        target_index.append(i)
target_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))


unlearningdata = copy.deepcopy(data)
unlearninglabels = list(range(num_classes))
unlearninglabels.remove(to_forget)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == to_forget:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(IndexingDataset(unlearningdata), batch_size=64, shuffle=True)

In [44]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [45]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
P=.1
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [46]:
# Training method
def train(model, epoch, loader, returnable=False, keep_p=.1):
    model.train()
    if returnable:
        batches = []
    for batch_idx, (data, target, samples_idx) in enumerate(loader):
        optimizer.zero_grad()
        if to_forget in target:
            before = {}
            for key, param in model.named_parameters():
                before[key] = param.clone()
        data = data.to(device)
        output = model(data)
        loss = criterion(output, target.to(device))
        loss.backward()
        
        optimizer.step()
        
        with torch.no_grad():
            if to_forget in target:
                batches.append(batch_idx)
                step = {}
                for key, param in model.named_parameters():
                    step[key] = (param - before[key]).cpu()
                f = open(f"steps/e{epoch}b{batches[-1]:04}.pkl", "wb")
                pickle.dump(step, f)
                f.close()
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
    if returnable:
        return batches

In [47]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

In [48]:
trainingepochs = 4
forgetfulepochs = 4

In [49]:
# load resnet 18 and change to fit problem dimensionality
#model = models.resnet18()
#model.bn1 = nn.GroupNorm(1, model.bn1.weight.shape[0])
#model.layer1[0].bn1 = nn.GroupNorm(1, model.layer1[0].bn1.weight.shape[0])
#model.layer1[0].bn2 = nn.GroupNorm(1, model.layer1[0].bn2.weight.shape[0])
#model.layer1[1].bn1 = nn.GroupNorm(1, model.layer1[1].bn1.weight.shape[0])
#model.layer1[1].bn2 = nn.GroupNorm(1, model.layer1[1].bn2.weight.shape[0])

#model.layer2[0].bn1 = nn.GroupNorm(1, model.layer2[0].bn1.weight.shape[0])
#model.layer2[0].bn2 = nn.GroupNorm(1, model.layer2[0].bn2.weight.shape[0])
#model.layer2[0].downsample[1] = nn.GroupNorm(1, model.layer2[0].downsample[1].weight.shape[0])
#model.layer2[1].bn1 = nn.GroupNorm(1, model.layer2[1].bn1.weight.shape[0])
#model.layer2[1].bn2 = nn.GroupNorm(1, model.layer2[1].bn2.weight.shape[0])

#model.layer3[0].bn1 = nn.GroupNorm(1, model.layer3[0].bn1.weight.shape[0])
#model.layer3[0].bn2 = nn.GroupNorm(1, model.layer3[0].bn2.weight.shape[0])
#odel.layer3[0].downsample[1] = nn.GroupNorm(1, model.layer3[0].downsample[1].weight.shape[0])
#model.layer3[1].bn1 = nn.GroupNorm(1, model.layer3[1].bn1.weight.shape[0])
#model.layer3[1].bn2 = nn.GroupNorm(1, model.layer3[1].bn2.weight.shape[0])

#model.layer4[0].bn1 = nn.GroupNorm(1, model.layer4[0].bn1.weight.shape[0])
#model.layer4[0].bn2 = nn.GroupNorm(1, model.layer4[0].bn2.weight.shape[0])
#model.layer4[0].downsample[1] = nn.GroupNorm(1, model.layer4[0].downsample[1].weight.shape[0])
#model.layer4[1].bn1 = nn.GroupNorm(1, model.layer4[1].bn1.weight.shape[0])
#model.layer4[1].bn2 = nn.GroupNorm(1, model.layer4[1].bn2.weight.shape[0])

#model.conv1 = nn.Conv2d(in_size, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
#model.fc = nn.Sequential(nn.Linear(512, num_classes))

device = "cuda" if torch.cuda.is_available() else 'cpu'

model = SimpleModel(in_size, num_classes)

model = model.to(device)

optimizer = optim.Adam(model.parameters())

In [50]:
# Train new model for 5 epochs
epoch_indices = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    batches = train(model, epoch, target_train_loader, returnable=True, keep_p=P)
    print(f"{batches} batches effected")
    epoch_indices.append(batches)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")
    print(f"Time taken: {time.process_time() - starttime}")

Epoch: 1 [ 49152]	Loss: 3.184155[1, 2, 4, 8, 9, 10, 17, 19, 20, 22, 23, 25, 26, 27, 28, 33, 34, 36, 39, 45, 47, 51, 53, 54, 56, 57, 60, 62, 64, 65, 68, 70, 74, 76, 77, 84, 85, 86, 88, 89, 90, 91, 92, 93, 98, 100, 101, 102, 106, 110, 111, 116, 118, 119, 120, 121, 124, 125, 128, 129, 132, 134, 137, 138, 140, 141, 144, 148, 153, 154, 159, 160, 162, 165, 166, 169, 170, 172, 173, 177, 180, 182, 183, 184, 186, 188, 189, 192, 193, 196, 198, 200, 201, 203, 204, 205, 206, 210, 211, 213, 215, 217, 218, 219, 220, 221, 222, 224, 228, 232, 239, 242, 251, 252, 254, 256, 258, 259, 261, 267, 270, 274, 275, 276, 277, 281, 282, 284, 286, 287, 289, 290, 291, 293, 294, 295, 297, 299, 303, 304, 306, 307, 308, 310, 312, 314, 315, 316, 317, 318, 319, 322, 324, 325, 332, 333, 335, 336, 337, 338, 340, 342, 347, 349, 350, 352, 354, 357, 358, 360, 362, 363, 364, 366, 367, 370, 371, 372, 373, 374, 376, 377, 379, 380, 381, 382, 384, 385, 386, 390, 391, 394, 395, 398, 399, 400, 401, 404, 407, 410, 411, 412, 415, 41

In [51]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [52]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [53]:
for i in range(1, trainingepochs+1):
    for j in range(1600):
        path = f"steps/e{i}b{j:04}.pkl"
        try:
            f = open(path, "rb")
            steps = pickle.load(f)
            f.close()
            print(f"\rLoading steps/e{i}b{j:04}.pkl", end="")
            const = 1
            with torch.no_grad():
                for key, param in model.named_parameters():
                    param -= const*steps[key].to(device)
        except:
#             print(f"\r{i},{j}", end="")
            pass
print(steps.keys())

Loading steps/e4b0842.pkldict_keys(['layers.0.weight', 'layers.0.bias', 'layers.2.weight', 'layers.2.bias', 'layers.6.weight', 'layers.6.bias'])


In [54]:
test(model, all_data_test_loader, dname="All data")
test(model, target_test_loader, dname="Forget  ")
test(model, nontarget_test_loader, dname="Retain  ")

All data: Mean loss: 0.0819, Accuracy: 506/10000 (5%)
Forget  : Mean loss: 0.0055, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0773, Accuracy: 506/9900 (5%)


tensor(0.051, device='cuda:0')

In [55]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [56]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [57]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    _ = train(model, epoch, nontarget_train_loader, returnable=True)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")

Epoch: 5 [ 49152]	Loss: 2.372996All data: Mean loss: 0.0481, Accuracy: 3129/10000 (31%)
Forget  : Mean loss: 0.0056, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0435, Accuracy: 3129/9900 (32%)
Epoch: 6 [ 49152]	Loss: 2.968864All data: Mean loss: 0.0467, Accuracy: 3300/10000 (33%)
Forget  : Mean loss: 0.0058, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0421, Accuracy: 3300/9900 (33%)
Epoch: 7 [ 49152]	Loss: 2.371677All data: Mean loss: 0.0457, Accuracy: 3424/10000 (34%)
Forget  : Mean loss: 0.0057, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0411, Accuracy: 3424/9900 (35%)
Epoch: 8 [ 49152]	Loss: 2.189278All data: Mean loss: 0.0449, Accuracy: 3532/10000 (35%)
Forget  : Mean loss: 0.0055, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0404, Accuracy: 3532/9900 (36%)
