# Imports and Setup

In [1]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [2]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset
from scipy import ndimage
import copy
import random
import time
import pickle
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

# Data Entry and Processing

In [3]:
class IndexingDataset(Dataset):
    def __init__(self, internal_dataset):
        self.dataset = internal_dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, sample_index):
        r = self.dataset[sample_index]
        if not isinstance(r, tuple):
            r = (r,)
        return *r, sample_index

In [4]:
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

In [5]:
# Using MNIST
MNIST_data = datasets.MNIST('/data', download=True, train=True, transform=transform)
traindata = IndexingDataset(MNIST_data)
testdata = datasets.MNIST('/data', download=True, train=False, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15697649.47it/s]


Extracting /data/MNIST/raw/train-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 438219.47it/s]


Extracting /data/MNIST/raw/train-labels-idx1-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4300823.17it/s]


Extracting /data/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2896097.41it/s]

Extracting /data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw






In [6]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=False)

In [7]:
# Test dataloader with 3's only
threes_index = []
nonthrees_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == 3:
        threes_index.append(i)
    else:
        nonthrees_index.append(i)
three_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(threes_index))
nonthree_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nonthrees_index))

In [8]:
# Train dataloaders with limited 3s
nonthrees_index = []
threes_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != 3:
        nonthrees_index.append(i)
        threes_index.append(i)
    if traindata[i][1] == 3 and count < 100:
        count += 1
        threes_index.append(i)
nonthree_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nonthrees_index))
three_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(threes_index))

In [9]:
# Unlearning dataset with all "3" labels randomly assigned
unlearningdata = copy.deepcopy(MNIST_data)
unlearninglabels = list(range(10))
unlearninglabels.remove(3)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == 3:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(IndexingDataset(unlearningdata), batch_size=64, shuffle=True)

# Model

In [10]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.ReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.ReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [11]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
num_classes = 10
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [12]:
# Training method
def train(model, epoch, loader, returnable=False, keep_p=.1, seed=42):
    model.train()
    if returnable:
        thracc = []
        nacc = []
        batches = []
    for batch_idx, (data, target, samples_idx) in enumerate(loader):
        for param in model.parameters():
            param.grad_sample = None
        steps = []
        output = call_for_per_sample_grads(model)(data)
        loss = criterion(output, target)
        loss.backward()
        with torch.no_grad():
            for param in model.parameters():
                param.grad = param.grad_sample.sum(0)
                
        optimizer.step()
        
        with torch.no_grad():
            for i, sample_idx in enumerate(samples_idx[target == 3].tolist()):
                rng = np.random.default_rng(sample_idx)
                batches.append(sample_idx)
                step = {}
                if epoch > 1:
                    f = open(f"steps/{sample_idx:04}.pkl", "rb")
                    step = pickle.load(f)
                    f.close()
                for key, param in model.named_parameters():
                    diff = param.grad_sample[i].flatten()
                    size = diff.shape.numel()
                    subset = rng.choice(size, int(size * keep_p), replace=False, shuffle=False)
                    step[key] = step.get(key, 0) + diff[subset]
                f = open(f"steps/{sample_idx:04}.pkl", "wb")
                pickle.dump(step, f)
                f.close()
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
        if returnable and batch_idx % 10 == 0:
            thracc.append(test(model, three_test_loader, dname="Threes only", printable=False))
            if batch_idx % 10 == 0:
                nacc.append(test(model, nonthree_test_loader, dname="nonthree only", printable=False))
            model.train()
    if returnable:
        return thracc, nacc, batches, steps

In [13]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

# Original Training

In [14]:
trainingepochs = 4
forgetfulepochs = 4
naive_accuracy_three = []
naive_accuracy_nonthree = []

In [15]:
# load resnet 18 and change to fit problem dimensionality
resnet = SimpleModel(1, num_classes)
#resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
#resnet.fc = nn.Sequential(nn.Linear(512, num_classes))
optimizer = optim.Adam(resnet.parameters())

In [16]:
# Train new model for 5 epochs
steps = []
epoch_indices = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    thracc, nacc, three_batches, three_steps = train(resnet, epoch, three_train_loader, returnable=True)
    naive_accuracy_three += thracc
    naive_accuracy_nonthree += nacc
    steps = steps + three_steps
    print(f"{three_batches} batches effected")
    epoch_indices.append(three_batches)
    test(resnet, all_data_test_loader, dname="All data")
    test(resnet, three_test_loader, dname="Threes  ")
    test(resnet, nonthree_test_loader, dname="Nonthree")
    print(f"Time taken: {time.process_time() - starttime}")
    path = F"selective_trained_e{epoch}.pt"
    torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)
    path = F"selective_trained_accuracy_three_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_three:
            f.write(f"{data},")
    path = F"selective_trained_accuracy_nonthree_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_nonthree:
            f.write(f"{data},")

Epoch: 1 [ 53248]	Loss: 0.105125[581, 675, 179, 629, 645, 10, 459, 1077, 546, 279, 731, 574, 181, 149, 752, 857, 255, 659, 878, 875, 479, 130, 561, 111, 953, 235, 327, 298, 811, 975, 983, 486, 549, 392, 1068, 998, 215, 207, 242, 425, 540, 107, 867, 281, 715, 500, 433, 861, 490, 198, 1097, 495, 509, 136, 909, 890, 50, 330, 30, 874, 613, 49, 86, 250, 356, 808, 840, 789, 254, 760, 203, 1035, 7, 12, 44, 992, 670, 643, 895, 695, 361, 1055, 966, 291, 228, 74, 321, 452, 341, 843, 157, 557, 1021, 98, 767, 856, 135, 1007, 405, 27] batches effected
All data: Mean loss: 0.0041, Accuracy: 9143/10000 (91%)
Threes  : Mean loss: 0.0028, Accuracy: 432/1010 (43%)
Nonthree: Mean loss: 0.0013, Accuracy: 8711/8990 (97%)
Time taken: 620.074776798
Epoch: 2 [ 53248]	Loss: 0.068070[861, 321, 760, 136, 30, 645, 789, 74, 574, 840, 242, 49, 675, 1021, 405, 953, 546, 843, 27, 878, 808, 557, 250, 425, 975, 811, 874, 643, 433, 486, 149, 181, 715, 495, 207, 356, 992, 752, 291, 1097, 361, 1055, 228, 111, 998, 890, 13

In [17]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [18]:
path = F"selective_trained_accuracy_three.txt"
with open(path, 'w') as f:
    for data in naive_accuracy_three:
        f.write(f"{data},")

In [19]:
path = F"selective_trained_accuracy_nonthree.txt"
with open(path, 'w') as f:
    for data in naive_accuracy_nonthree:
        f.write(f"{data},")

In [67]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
resnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [68]:
for j in epoch_indices[0]:
    path = f"steps/{j:04}.pkl"
    f = open(path, "rb")
    steps = pickle.load(f)
    f.close()
    print(f"\rLoading steps/{j:04}.pkl", end="")
    const = 1
    rng = np.random.default_rng(j)
    with torch.no_grad():
        for key, param in resnet.named_parameters():
            size = param.shape.numel()
            subset = rng.choice(size, int(size * .1), replace=False, shuffle=False)
            param.view(-1)[subset] += const*steps[key]

Loading steps/0027.pkl

In [69]:
test(resnet, all_data_test_loader, dname="All data")
test(resnet, three_test_loader, dname="Threes  ")
test(resnet, nonthree_test_loader, dname="Nonthree")

All data: Mean loss: 0.0157, Accuracy: 8122/10000 (81%)
Threes  : Mean loss: 0.0108, Accuracy: 37/1010 (4%)
Nonthree: Mean loss: 0.0050, Accuracy: 8085/8990 (90%)


tensor(0.899)

In [23]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [24]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
resnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [25]:
selective_post_accuracy_three = []
selective_post_accuracy_nonthree =[]

In [26]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    thracc, nacc, _, _ = train(resnet, epoch, nonthree_train_loader, returnable=True)
    selective_post_accuracy_three += thracc
    selective_post_accuracy_nonthree += nacc
    test(resnet, all_data_test_loader, dname="All data")
    test(resnet, three_test_loader, dname="Threes  ")
    test(resnet, nonthree_test_loader, dname="Nonthree")
    path = F"selective-post-epoch-{epoch}.pt"
    torch.save({ 
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)
    path = F"selective_post_accuracy_three_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_three:
            f.write(f"{data},")
    path = F"selective_post_accuracy_nonthree_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_nonthree:
            f.write(f"{data},")

Epoch: 5 [  1024]	Loss: 0.004545

KeyboardInterrupt: 

In [None]:
path = F"selective_post_accuracy_three.txt"
with open(path, 'w') as f:
    for data in selective_post_accuracy_three:
        f.write(f"{data},")

In [None]:
path = F"selective_post_accuracy_nonthree.txt"
with open(path, 'w') as f:
    for data in selective_post_accuracy_nonthree:
        f.write(f"{data},")