# Imports and Setup

In [1]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [2]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset
from scipy import ndimage
import copy
import random
import time
import pickle
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

# Data Entry and Processing

In [3]:
class IndexingDataset(Dataset):
    def __init__(self, internal_dataset):
        self.dataset = internal_dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, sample_index):
        r = self.dataset[sample_index]
        if not isinstance(r, tuple):
            r = (r,)
        return *r, sample_index

In [4]:
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

In [5]:
# Using MNIST
MNIST_data = datasets.MNIST('/data', download=True, train=True, transform=transform)
traindata = IndexingDataset(MNIST_data)
testdata = datasets.MNIST('/data', download=True, train=False, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5454279.07it/s] 


Extracting /data/MNIST/raw/train-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 158683.76it/s]


Extracting /data/MNIST/raw/train-labels-idx1-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:04<00:00, 335225.17it/s]


Extracting /data/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2629472.57it/s]

Extracting /data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw






In [6]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=False)

In [7]:
# Test dataloader with 3's only
threes_index = []
nonthrees_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == 3:
        threes_index.append(i)
    else:
        nonthrees_index.append(i)
three_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(threes_index))
nonthree_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nonthrees_index))

In [8]:
# Train dataloaders with limited 3s
nonthrees_index = []
threes_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != 3:
        nonthrees_index.append(i)
        threes_index.append(i)
    if traindata[i][1] == 3 and count < 100:
        count += 1
        threes_index.append(i)
nonthree_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nonthrees_index))
three_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(threes_index))

In [9]:
# Unlearning dataset with all "3" labels randomly assigned
unlearningdata = copy.deepcopy(MNIST_data)
unlearninglabels = list(range(10))
unlearninglabels.remove(3)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == 3:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(IndexingDataset(unlearningdata), batch_size=64, shuffle=True)

# Model

In [10]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.ReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.ReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [11]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
num_classes = 10
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [12]:
# Training method
def train(model, epoch, loader, returnable=False, keep_p=.1, seed=42):
    model.train()
    if returnable:
        thracc = []
        nacc = []
        batches = []
    for batch_idx, (data, target, samples_idx) in enumerate(loader):
        for param in model.parameters():
            param.grad_sample = None
        steps = []
        output = call_for_per_sample_grads(model)(data)
        loss = criterion(output, target)
        loss.backward()
        with torch.no_grad():
            for param in model.parameters():
                param.grad = param.grad_sample.sum(0)
                
        optimizer.step()
        
        with torch.no_grad():
            for i, sample_idx in enumerate(samples_idx[target == 3].tolist()):
                rng = np.random.default_rng(sample_idx)
                batches.append(sample_idx)
                step = {}
                if epoch > 1:
                    f = open(f"steps/{sample_idx:04}.pkl", "rb")
                    step = pickle.load(f)
                    f.close()
                for key, param in model.named_parameters():
                    diff = param.grad_sample[i].flatten()
                    size = diff.shape.numel()
                    subset = rng.choice(size, int(size * keep_p), replace=False, shuffle=False)
                    step[key] = step.get(key, 0) + diff[subset]
                f = open(f"steps/{sample_idx:04}.pkl", "wb")
                pickle.dump(step, f)
                f.close()
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
        if returnable and batch_idx % 10 == 0:
            thracc.append(test(model, three_test_loader, dname="Threes only", printable=False))
            if batch_idx % 10 == 0:
                nacc.append(test(model, nonthree_test_loader, dname="nonthree only", printable=False))
            model.train()
    if returnable:
        return thracc, nacc, batches, steps

In [13]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

# Original Training

In [14]:
trainingepochs = 4
forgetfulepochs = 4
naive_accuracy_three = []
naive_accuracy_nonthree = []

In [15]:
# load resnet 18 and change to fit problem dimensionality
resnet = SimpleModel(1, num_classes)
#resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
#resnet.fc = nn.Sequential(nn.Linear(512, num_classes))
optimizer = optim.Adam(resnet.parameters())

In [16]:
# Train new model for 5 epochs
steps = []
epoch_indices = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    thracc, nacc, three_batches, three_steps = train(resnet, epoch, three_train_loader, returnable=True)
    naive_accuracy_three += thracc
    naive_accuracy_nonthree += nacc
    steps = steps + three_steps
    print(f"{three_batches} batches effected")
    epoch_indices.append(three_batches)
    test(resnet, all_data_test_loader, dname="All data")
    test(resnet, three_test_loader, dname="Threes  ")
    test(resnet, nonthree_test_loader, dname="Nonthree")
    print(f"Time taken: {time.process_time() - starttime}")
    path = F"selective_trained_e{epoch}.pt"
    torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)
    path = F"selective_trained_accuracy_three_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_three:
            f.write(f"{data},")
    path = F"selective_trained_accuracy_nonthree_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_nonthree:
            f.write(f"{data},")

Epoch: 1 [ 53248]	Loss: 0.132564[867, 74, 198, 495, 731, 279, 107, 215, 425, 157, 840, 1055, 546, 10, 27, 321, 856, 629, 909, 12, 581, 675, 811, 695, 250, 966, 433, 992, 452, 49, 298, 341, 327, 111, 254, 291, 130, 135, 281, 490, 330, 486, 207, 613, 1097, 235, 953, 715, 983, 1021, 561, 895, 361, 670, 181, 86, 874, 659, 752, 149, 760, 392, 843, 459, 500, 875, 1007, 179, 1035, 540, 30, 890, 7, 255, 136, 509, 549, 975, 405, 557, 998, 356, 44, 789, 767, 242, 645, 808, 861, 574, 98, 479, 1077, 857, 228, 643, 1068, 203, 50, 878] batches effected
All data: Mean loss: 0.0088, Accuracy: 8725/10000 (87%)
Threes  : Mean loss: 0.0073, Accuracy: 56/1010 (6%)
Nonthree: Mean loss: 0.0016, Accuracy: 8669/8990 (96%)
Time taken: 613.054891437
Epoch: 2 [ 53248]	Loss: 0.084274[752, 675, 356, 255, 998, 86, 405, 509, 149, 613, 203, 10, 767, 867, 30, 840, 181, 645, 1077, 452, 479, 811, 760, 557, 500, 291, 341, 361, 561, 808, 27, 321, 207, 49, 875, 581, 878, 874, 895, 670, 330, 433, 242, 857, 953, 459, 490, 10

In [17]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [18]:
path = F"selective_trained_accuracy_three.txt"
with open(path, 'w') as f:
    for data in naive_accuracy_three:
        f.write(f"{data},")

In [19]:
path = F"selective_trained_accuracy_nonthree.txt"
with open(path, 'w') as f:
    for data in naive_accuracy_nonthree:
        f.write(f"{data},")

In [20]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
resnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [21]:
for j in epoch_indices[0]:
    path = f"steps/{j:04}.pkl"
    f = open(path, "rb")
    steps = pickle.load(f)
    f.close()
    print(f"\rLoading steps/{j:04}.pkl", end="")
    const = 1
    rng = np.random.default_rng(j)
    with torch.no_grad():
        for key, param in resnet.named_parameters():
            size = param.shape.numel()
            subset = rng.choice(size, int(size * .1), replace=False, shuffle=False)
            param.view(-1)[subset] += const*steps[key]

Loading steps/0878.pkl

In [22]:
test(resnet, all_data_test_loader, dname="All data")
test(resnet, three_test_loader, dname="Threes  ")
test(resnet, nonthree_test_loader, dname="Nonthree")

All data: Mean loss: 0.0229, Accuracy: 7580/10000 (76%)
Threes  : Mean loss: 0.0155, Accuracy: 3/1010 (0%)
Nonthree: Mean loss: 0.0076, Accuracy: 7577/8990 (84%)


tensor(0.843)

In [23]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [24]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
resnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [25]:
selective_post_accuracy_three = []
selective_post_accuracy_nonthree =[]

In [26]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    thracc, nacc, _, _ = train(resnet, epoch, nonthree_train_loader, returnable=True)
    selective_post_accuracy_three += thracc
    selective_post_accuracy_nonthree += nacc
    test(resnet, all_data_test_loader, dname="All data")
    test(resnet, three_test_loader, dname="Threes  ")
    test(resnet, nonthree_test_loader, dname="Nonthree")
    path = F"selective-post-epoch-{epoch}.pt"
    torch.save({ 
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)
    path = F"selective_post_accuracy_three_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_three:
            f.write(f"{data},")
    path = F"selective_post_accuracy_nonthree_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_nonthree:
            f.write(f"{data},")

Epoch: 5 [ 53248]	Loss: 0.007851All data: Mean loss: 0.0115, Accuracy: 8900/10000 (89%)
Threes  : Mean loss: 0.0109, Accuracy: 39/1010 (4%)
Nonthree: Mean loss: 0.0007, Accuracy: 8861/8990 (99%)
Epoch: 6 [ 53248]	Loss: 0.008013All data: Mean loss: 0.0111, Accuracy: 8893/10000 (89%)
Threes  : Mean loss: 0.0104, Accuracy: 58/1010 (6%)
Nonthree: Mean loss: 0.0008, Accuracy: 8835/8990 (98%)
Epoch: 7 [ 53248]	Loss: 0.001969All data: Mean loss: 0.0121, Accuracy: 8849/10000 (88%)
Threes  : Mean loss: 0.0112, Accuracy: 38/1010 (4%)
Nonthree: Mean loss: 0.0010, Accuracy: 8811/8990 (98%)
Epoch: 8 [ 53248]	Loss: 0.000895All data: Mean loss: 0.0121, Accuracy: 8895/10000 (89%)
Threes  : Mean loss: 0.0115, Accuracy: 41/1010 (4%)
Nonthree: Mean loss: 0.0007, Accuracy: 8854/8990 (98%)


In [27]:
path = F"selective_post_accuracy_three.txt"
with open(path, 'w') as f:
    for data in selective_post_accuracy_three:
        f.write(f"{data},")

In [28]:
path = F"selective_post_accuracy_nonthree.txt"
with open(path, 'w') as f:
    for data in selective_post_accuracy_nonthree:
        f.write(f"{data},")