In [1]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [2]:
import torch
import numpy as np
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import Dataset
import copy
import random
import time
import pickle
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

In [3]:
to_forget = 81
num_classes = 100
max_count = -1

torch.manual_seed(42)

<torch._C.Generator at 0x788f8d567230>

In [4]:
class IndexingDataset(Dataset):
    def __init__(self, internal_dataset):
        self.dataset = internal_dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, sample_index):
        r = self.dataset[sample_index]
        if not isinstance(r, tuple):
            r = (r,)
        return *r, sample_index
    
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

# Using MNIST
data = datasets.CIFAR100('/data', download=True, train=True, transform=transform)
traindata = IndexingDataset(data)
testdata = datasets.CIFAR100('/data', download=True, train=False, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:01<00:00, 94060149.85it/s] 


Extracting /data/cifar-100-python.tar.gz to /data
Files already downloaded and verified


In [5]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=False)

target_index = []
nontarget_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == to_forget:
        target_index.append(i)
    else:
        nontarget_index.append(i)
target_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))

target_index = []
nontarget_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != to_forget:
        target_index.append(i)
        nontarget_index.append(i)
    if traindata[i][1] == to_forget and (count < max_count or max_count < 1):
        count += 1
        target_index.append(i)
target_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))


unlearningdata = copy.deepcopy(data)
unlearninglabels = list(range(num_classes))
unlearninglabels.remove(to_forget)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == to_forget:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(IndexingDataset(unlearningdata), batch_size=64, shuffle=True)

In [6]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [7]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
P=.01
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [8]:
# Training method
def train(model, epoch, loader, returnable=False, keep_p=.1):
    model.train()
    if returnable:
        batches = []
    for batch_idx, (data, target, samples_idx) in enumerate(loader):
        for param in model.parameters():
            param.grad_sample = None
        output = call_for_per_sample_grads(model)(data)
        loss = criterion(output, target)
        loss.backward()
        with torch.no_grad():
            for param in model.parameters():
                param.grad = param.grad_sample.sum(0)
                
        optimizer.step()
        
        with torch.no_grad():
            for i, sample_idx in enumerate(samples_idx[target == to_forget].tolist()):
                rng = np.random.default_rng(sample_idx)
                batches.append(sample_idx)
                step = {}
                if epoch > 1:
                    f = open(f"steps/{sample_idx:04}.pkl", "rb")
                    step = pickle.load(f)
                    f.close()
                for key, param in model.named_parameters():
                    diff = param.grad_sample[i].flatten()
                    size = diff.shape.numel()
                    subset = rng.choice(size, int(size * keep_p), replace=False, shuffle=False)
                    step[key] = step.get(key, 0) + diff[subset]
                f = open(f"steps/{sample_idx:04}.pkl", "wb")
                pickle.dump(step, f)
                f.close()
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
    if returnable:
        return batches

In [9]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

In [10]:
trainingepochs = 4
forgetfulepochs = 4

In [11]:
# load resnet 18 and change to fit problem dimensionality
model = SimpleModel(3, num_classes)
#resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
#resnet.fc = nn.Sequential(nn.Linear(512, num_classes))
optimizer = optim.Adam(model.parameters())

In [12]:
# Train new model for 5 epochs
epoch_indices = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    batches = train(model, epoch, target_train_loader, returnable=True, keep_p=P)
    print(f"{batches} batches effected")
    epoch_indices.append(batches)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")
    print(f"Time taken: {time.process_time() - starttime}")

Epoch: 1 [ 49152]	Loss: 3.275319[20226, 16857, 9956, 48461, 27114, 40790, 6258, 14648, 36, 18265, 43537, 26308, 37252, 13258, 45585, 18330, 45157, 29483, 503, 32625, 39498, 41444, 46697, 40161, 37599, 49479, 37465, 7554, 25500, 20967, 35119, 8699, 42555, 36593, 16778, 15933, 31364, 30059, 38528, 37260, 15624, 6337, 34261, 1241, 2097, 3009, 13060, 23331, 13014, 11546, 6214, 11575, 489, 1641, 23159, 10180, 24444, 25909, 36497, 14377, 32141, 41549, 39219, 11885, 47568, 12480, 12683, 1943, 49309, 14235, 19383, 48199, 38982, 25308, 36296, 6998, 29408, 654, 12539, 33414, 975, 39868, 24756, 7023, 37917, 27139, 29499, 31180, 34026, 27129, 12246, 6172, 31393, 196, 36724, 39398, 15521, 40957, 10222, 37364, 7208, 20470, 17913, 22076, 9789, 18708, 12077, 48259, 21463, 16877, 2132, 29252, 12395, 21151, 7370, 7612, 18735, 31655, 12346, 21829, 12195, 33201, 3871, 21025, 27845, 43371, 49742, 4954, 36433, 31727, 24481, 24669, 22504, 45865, 31466, 38668, 28774, 16318, 19430, 4904, 24965, 13603, 14809, 1

In [13]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [14]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [15]:
for j in epoch_indices[0]:
    path = f"steps/{j:04}.pkl"
    f = open(path, "rb")
    steps = pickle.load(f)
    f.close()
    print(f"\rLoading steps/{j:04}.pkl", end="")
    const = 1
    rng = np.random.default_rng(j)
    with torch.no_grad():
        for key, param in model.named_parameters():
            size = param.shape.numel()
            subset = rng.choice(size, int(size * P), replace=False, shuffle=False)
            param.view(-1)[subset] += const*steps[key]

Loading steps/23507.pkl

In [16]:
test(model, all_data_test_loader, dname="All data")
test(model, target_test_loader, dname="Forget  ")
test(model, nontarget_test_loader, dname="Retain  ")

All data: Mean loss: 0.0436, Accuracy: 3209/10000 (32%)
Forget  : Mean loss: 0.0007, Accuracy: 5/100 (5%)
Retain  : Mean loss: 0.0429, Accuracy: 3204/9900 (32%)


tensor(0.324)

In [17]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [18]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [19]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    _ = train(model, epoch, nontarget_train_loader, returnable=True)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")

Epoch: 5 [ 49152]	Loss: 2.144419All data: Mean loss: 0.0429, Accuracy: 3327/10000 (33%)
Forget  : Mean loss: 0.0015, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0417, Accuracy: 3327/9900 (34%)
Epoch: 6 [ 49152]	Loss: 2.739810All data: Mean loss: 0.0428, Accuracy: 3404/10000 (34%)
Forget  : Mean loss: 0.0017, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0413, Accuracy: 3404/9900 (34%)
Epoch: 7 [ 49152]	Loss: 2.298378All data: Mean loss: 0.0423, Accuracy: 3530/10000 (35%)
Forget  : Mean loss: 0.0019, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0406, Accuracy: 3530/9900 (36%)
Epoch: 8 [ 49152]	Loss: 2.028638All data: Mean loss: 0.0419, Accuracy: 3505/10000 (35%)
Forget  : Mean loss: 0.0019, Accuracy: 0/100 (0%)
Retain  : Mean loss: 0.0403, Accuracy: 3505/9900 (35%)
