# Imports and Setup

In [1]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [2]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable
from scipy import ndimage
import copy
import random
import time
import pickle

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

# Data Entry and Processing

In [3]:
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

In [4]:
# Using MNIST
traindata = datasets.MNIST('/data', download=True, train=True, transform=transform)
testdata = datasets.MNIST('/data', download=True, train=False, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 12708940.81it/s]


Extracting /data/MNIST/raw/train-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 335085.33it/s]


Extracting /data/MNIST/raw/train-labels-idx1-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3143349.82it/s]


Extracting /data/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1041980.46it/s]

Extracting /data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw






In [5]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=True)

In [6]:
# Test dataloader with 3's only
threes_index = []
nonthrees_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == 3:
        threes_index.append(i)
    else:
        nonthrees_index.append(i)
three_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(threes_index))
nonthree_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nonthrees_index))

In [7]:
# Train dataloaders with limited 3s
nonthrees_index = []
threes_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != 3:
        nonthrees_index.append(i)
        threes_index.append(i)
    if traindata[i][1] == 3 and count < 100:
        count += 1
        threes_index.append(i)
nonthree_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nonthrees_index))
three_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(threes_index))

In [8]:
# Unlearning dataset with all "3" labels randomly assigned
unlearningdata = copy.deepcopy(traindata)
unlearninglabels = list(range(10))
unlearninglabels.remove(3)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == 3:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(unlearningdata, batch_size=64, shuffle=True)

# Model

In [9]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.ReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.ReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [10]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
num_classes = 10
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [11]:
# Training method
def train(model, epoch, loader, returnable=False, keep_p=.1, seed=42):
    model.train()
    rng = np.random.default_rng(seed)
    if returnable:
        thracc = []
        nacc = []
        batches = []
    for batch_idx, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        steps = []
        if 3 in target:
            before = {}
            for param_tensor in model.state_dict():
                if "weight" in param_tensor or "bias" in param_tensor:
                     before[param_tensor] = model.state_dict()[param_tensor].clone()
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if 3 in target:
            batches.append(batch_idx)
            after = {}
            for param_tensor in model.state_dict():
                if "weight" in param_tensor or "bias" in param_tensor:
                    after[param_tensor] = model.state_dict()[param_tensor].clone()
            step = {}
            for key in before:
                diff = (after[key] - before[key]).flatten()
                size = diff.shape.numel()
                subset = rng.choice(size, int(size * keep_p), replace=False, shuffle=False)
                step[key] = (diff[subset], subset)
                f = open(f"steps/e{epoch}b{batches[-1]:04}.pkl", "wb")
                pickle.dump(step, f)
                f.close()
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
        if returnable and batch_idx % 10 == 0:
            thracc.append(test(model, three_test_loader, dname="Threes only", printable=False))
            if batch_idx % 10 == 0:
                nacc.append(test(model, nonthree_test_loader, dname="nonthree only", printable=False))
            model.train()
    if returnable:
        return thracc, nacc, batches, steps

In [12]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

# Original Training

In [13]:
trainingepochs = 4
forgetfulepochs = 4
naive_accuracy_three = []
naive_accuracy_nonthree = []

In [14]:
# load resnet 18 and change to fit problem dimensionality
resnet = SimpleModel(1, num_classes)
#resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
#resnet.fc = nn.Sequential(nn.Linear(512, num_classes))
optimizer = optim.Adam(resnet.parameters())

In [15]:
# Train new model for 5 epochs
steps = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    thracc, nacc, three_batches, three_steps = train(resnet, epoch, three_train_loader, returnable=True)
    naive_accuracy_three += thracc
    naive_accuracy_nonthree += nacc
    steps = steps + three_steps
    print(f"{three_batches} batches effected")
    test(resnet, all_data_test_loader, dname="All data")
    test(resnet, three_test_loader, dname="Threes  ")
    test(resnet, nonthree_test_loader, dname="Nonthree")
    print(f"Time taken: {time.process_time() - starttime}")
    path = F"selective_trained_e{epoch}.pt"
    torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)
    path = F"selective_trained_accuracy_three_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_three:
            f.write(f"{data},")
    path = F"selective_trained_accuracy_nonthree_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_nonthree:
            f.write(f"{data},")

Epoch: 1 [ 53248]	Loss: 0.195378[11, 32, 52, 64, 96, 101, 112, 140, 141, 143, 148, 149, 151, 167, 169, 178, 188, 201, 213, 217, 218, 223, 224, 234, 245, 267, 294, 301, 304, 313, 316, 319, 320, 326, 338, 368, 374, 377, 388, 407, 415, 426, 427, 435, 436, 437, 440, 442, 447, 448, 449, 452, 466, 481, 483, 486, 516, 523, 526, 557, 563, 566, 571, 573, 581, 588, 598, 607, 610, 616, 619, 620, 631, 644, 654, 672, 687, 695, 697, 704, 707, 720, 724, 732, 742, 752, 754, 775, 780, 784, 790, 792, 812, 816, 817, 834] batches effected
All data: Mean loss: 0.0042, Accuracy: 9171/10000 (92%)
Threes  : Mean loss: 0.0028, Accuracy: 491/1010 (49%)
Nonthree: Mean loss: 0.0015, Accuracy: 8680/8990 (97%)
Time taken: 726.400257527
Epoch: 2 [ 53248]	Loss: 0.094751[14, 20, 31, 34, 36, 37, 46, 47, 72, 79, 83, 102, 114, 126, 127, 151, 160, 162, 170, 182, 195, 200, 217, 229, 234, 249, 260, 266, 271, 279, 289, 290, 293, 305, 324, 331, 335, 336, 344, 346, 348, 356, 358, 361, 368, 393, 398, 408, 409, 410, 416, 426, 44

In [16]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [17]:
path = F"selective_trained_accuracy_three.txt"
with open(path, 'w') as f:
    for data in naive_accuracy_three:
        f.write(f"{data},")

In [18]:
path = F"selective_trained_accuracy_nonthree.txt"
with open(path, 'w') as f:
    for data in naive_accuracy_nonthree:
        f.write(f"{data},")

In [19]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
resnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [20]:
for i in range(1, trainingepochs+1):
    for j in range(1600):
        path = f"steps/e{i}b{j:04}.pkl"
        try:
#             print("before")
            f = open(path, "rb")
            steps = pickle.load(f)
            f.close()
            print(f"\rLoading steps/e{i}b{j:04}.pkl", end="")
            const = 1
            with torch.no_grad():
                state = resnet.state_dict()
                for param_tensor in state:
                    if "weight" in param_tensor or "bias" in param_tensor:
                        state[param_tensor].view(-1)[steps[param_tensor][1]] -= const*steps[param_tensor][0]
            resnet.load_state_dict(state)
        except:
#             print(f"\r{i},{j}", end="")
            pass

Loading steps/e4b0838.pkl

In [21]:
test(resnet, all_data_test_loader, dname="All data")
test(resnet, three_test_loader, dname="Threes  ")
test(resnet, nonthree_test_loader, dname="Nonthree")

All data: Mean loss: 0.0045, Accuracy: 9283/10000 (93%)
Threes  : Mean loss: 0.0038, Accuracy: 442/1010 (44%)
Nonthree: Mean loss: 0.0007, Accuracy: 8841/8990 (98%)


tensor(0.983)

In [22]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [23]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
resnet.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [24]:
selective_post_accuracy_three = []
selective_post_accuracy_nonthree =[]

In [25]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    thracc, nacc, _, _ = train(resnet, epoch, nonthree_train_loader, returnable=True)
    selective_post_accuracy_three += thracc
    selective_post_accuracy_nonthree += nacc
    test(resnet, all_data_test_loader, dname="All data")
    test(resnet, three_test_loader, dname="Threes  ")
    test(resnet, nonthree_test_loader, dname="Nonthree")
    path = F"selective-post-epoch-{epoch}.pt"
    torch.save({ 
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)
    path = F"selective_post_accuracy_three_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_three:
            f.write(f"{data},")
    path = F"selective_post_accuracy_nonthree_e{epoch}.txt"
    with open(path, 'w') as f:
        for data in naive_accuracy_nonthree:
            f.write(f"{data},")

Epoch: 5 [ 53248]	Loss: 0.038998All data: Mean loss: 0.0066, Accuracy: 9096/10000 (91%)
Threes  : Mean loss: 0.0059, Accuracy: 251/1010 (25%)
Nonthree: Mean loss: 0.0007, Accuracy: 8845/8990 (98%)
Epoch: 6 [ 53248]	Loss: 0.006487All data: Mean loss: 0.0080, Accuracy: 9009/10000 (90%)
Threes  : Mean loss: 0.0074, Accuracy: 157/1010 (16%)
Nonthree: Mean loss: 0.0007, Accuracy: 8852/8990 (98%)
Epoch: 7 [ 53248]	Loss: 0.002871All data: Mean loss: 0.0092, Accuracy: 8906/10000 (89%)
Threes  : Mean loss: 0.0082, Accuracy: 126/1010 (12%)
Nonthree: Mean loss: 0.0011, Accuracy: 8780/8990 (98%)
Epoch: 8 [ 53248]	Loss: 0.083446All data: Mean loss: 0.0104, Accuracy: 8956/10000 (90%)
Threes  : Mean loss: 0.0097, Accuracy: 105/1010 (10%)
Nonthree: Mean loss: 0.0008, Accuracy: 8851/8990 (98%)


In [26]:
path = F"selective_post_accuracy_three.txt"
with open(path, 'w') as f:
    for data in selective_post_accuracy_three:
        f.write(f"{data},")

In [27]:
path = F"selective_post_accuracy_nonthree.txt"
with open(path, 'w') as f:
    for data in selective_post_accuracy_nonthree:
        f.write(f"{data},")