# Imports and Setup

In [1]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable
from scipy import ndimage
import pickle
import copy
import random
import time

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False
datasetCifar = True
amensiac = True
#[0, 1, len(dataLoad) * 0.1, len(dataLoad) * 0.2, len(dataLoad) * 0.3, len(dataLoad) * 0.4, len(dataLoad) * (0.5), len(dataLoad) * (0.6), len(dataLoad) * (0.7), len(dataLoad) * (0.8), len(dataLoad) * (0.9)]
idx = 0

# Data Entry and Processing

In [2]:
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,),(0.5)),
                                ])

In [3]:
# Using CIFAR100
traindata = datasets.CIFAR100('./data', download=True, train=True, transform=transform)
testdata = datasets.CIFAR100('./data', download=True, train=False, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Loaders that give 64 example batches
cifar_train_loader = torch.utils.data.DataLoader(traindata, batch_size=50, shuffle=True)
cifar_test_loader = torch.utils.data.DataLoader(testdata, batch_size=50, shuffle=True)

In [5]:
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

In [6]:
# Using MNIST
traindata = datasets.MNIST('./data', download=True, train=True, transform=transform)
testdata = datasets.MNIST('./data', download=True, train=False, transform=transform)

In [7]:
mnist_train_loader = torch.utils.data.DataLoader(traindata, batch_size=60, shuffle=True)
mnist_test_loader = torch.utils.data.DataLoader(testdata, batch_size=60, shuffle=True)

# Model

In [8]:
# Hyperparameters
log_interval = 10
num_classes = 100
torch.backends.cudnn.enabled = True
criterion = F.nll_loss

In [9]:
# Training method that saves batch updates
def train(model, epoch, loader, returnable=False, batchStop=0):
  model.train()
  deltas = []
  if amensiac:
    for _ in range(50):
        delta = {}
        for param_tensor in model.state_dict():
          if "weight" in param_tensor or "bias" in param_tensor:
              delta[param_tensor] = 0
        deltas.append(delta)
  before = {}
  for param_tensor in model.state_dict():
      if "weight" in param_tensor or "bias" in param_tensor:
          before[param_tensor] = model.state_dict()[param_tensor].clone()
  for batch_idx, (data, target) in enumerate(loader):
    if batch_idx <= len((loader)) - 1 - batchStop:
      optimizer.zero_grad()
      output = model(data)
      loss = criterion(output, target)
      loss.backward()
      optimizer.step()
      if batch_idx % 10 == 0 and batch_idx < 500 and amensiac:
        after = {}
        for param_tensor in model.state_dict():
          if "weight" in param_tensor or "bias" in param_tensor:
            after[param_tensor] = model.state_dict()[param_tensor].clone()
        for key in before:
          deltas[batch_idx // 10][key] = after[key] - before[key]
        for param_tensor in model.state_dict():
          if "weight" in param_tensor or "bias" in param_tensor:
            before[param_tensor] = model.state_dict()[param_tensor].clone()
      if batch_idx % log_interval == 0:
        print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
            epoch, batch_idx*len(data),  loss.item()
        ), end="")
  return deltas

In [10]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
  model.eval()
  test_loss = 0
  total = 0
  correct = 0
  with torch.no_grad():
    for data, target in loader:
      output = model(data)
      total += target.size()[0]
      test_loss += criterion(output, target).item()
      _, pred = torch.topk(output, 1, dim=1, largest=True, sorted=True)
      for i, t in enumerate(target):
        if t in pred[i]:
            correct += 1
  test_loss /= len(loader.dataset)
  if printable:
    print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        dname, test_loss, correct, total, 
        100. * correct / total
        ))
  return 1. * correct / total

# Original Training

In [11]:
trainingepochs = 10
num_classes = 100 if datasetCifar else 10
dataLoad = cifar_train_loader if datasetCifar else mnist_train_loader
dataTest = cifar_test_loader if datasetCifar else mnist_test_loader
stop = [0, 1, len(dataLoad) * 0.1, len(dataLoad) * 0.2, len(dataLoad) * 0.3, len(dataLoad) * 0.4, len(dataLoad) * (0.5), len(dataLoad) * (0.6), len(dataLoad) * (0.7), len(dataLoad) * (0.8), len(dataLoad) * (0.9)]
batchStop = stop[idx]

In [12]:
# load resnet 18 and change to fit problem dimensionality
resnet = models.resnet18()
resnet.conv1 = nn.Conv2d(3, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) if datasetCifar else nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
resnet.fc = nn.Sequential(nn.Linear(512, num_classes), nn.LogSoftmax(dim=1))
optimizer = optim.Adam(resnet.parameters())

In [13]:
# Train new model for n epochs, saving parameter updates for
# sensitive batches

trainstart = time.process_time()
deltas = []
trainAcc = 0
if amensiac:  
  for _ in range(50):
    delta = {}
    for param_tensor in resnet.state_dict():
      if "weight" in param_tensor or "bias" in param_tensor:
          delta[param_tensor] = 0
    deltas.append(delta)
    
for epoch in range(1, trainingepochs+1):
  starttime = time.process_time()
  # train(resnet, epoch, all_data_train_loader, returnable=False)
  # batch = train(resnet, epoch, mnist_train_loader, returnable=True) 
  batch = train(resnet, epoch, dataLoad, returnable=True, batchStop=batchStop)
  if amensiac:
    for i in range(50):
      for key in deltas[i]:
          deltas[i][key] = batch[i][key] + deltas[i][key]
  # test(resnet, mnist_test_loader, dname="All data")
  trainAcc = test(resnet, dataTest, dname="All data")
  if epoch == trainingepochs:
    print("Accuracy: ", trainAcc)

  print(f"Time taken: {time.process_time() - starttime}")
endTimeTrain = time.process_time() - trainstart
print(f"Time taken for training: {endTimeTrain}")
accTime = {"accuracy": trainAcc, "time": endTimeTrain}


Epoch: 1 [ 49500]	Loss: 3.065165All data: Mean loss: 0.0600, Accuracy: 2605/10000 (26%)
Time taken: 1359.777856088
Epoch: 2 [ 49500]	Loss: 2.451857All data: Mean loss: 0.0504, Accuracy: 3518/10000 (35%)
Time taken: 1345.723209185
Epoch: 3 [ 49500]	Loss: 2.196525All data: Mean loss: 0.0479, Accuracy: 3735/10000 (37%)
Time taken: 1716.7423827029997
Epoch: 4 [ 49500]	Loss: 1.907825All data: Mean loss: 0.0456, Accuracy: 4074/10000 (41%)
Time taken: 1825.4690560110002
Epoch: 5 [ 49500]	Loss: 1.131327All data: Mean loss: 0.0434, Accuracy: 4417/10000 (44%)
Time taken: 1817.608159841
Epoch: 6 [ 49500]	Loss: 1.305457All data: Mean loss: 0.0435, Accuracy: 4478/10000 (45%)
Time taken: 1830.7144409049997
Epoch: 7 [ 49500]	Loss: 1.184521All data: Mean loss: 0.0434, Accuracy: 4571/10000 (46%)
Time taken: 1670.3460952349997
Epoch: 8 [ 49500]	Loss: 1.585987All data: Mean loss: 0.0455, Accuracy: 4612/10000 (46%)
Time taken: 1552.0582283940003
Epoch: 9 [ 49500]	Loss: 0.821411All data: Mean loss: 0.0488,

In [14]:
if batchStop != 0:
    accAndTime = F"CIFARCPUBatchStop" + str(batchStop) + ".pickle" if datasetCifar else F"MNISTCPUBatchStop" + str(batchStop) + ".pickle"
    with open(accAndTime, 'wb') as handle:
        pickle.dump(accTime, handle)

In [15]:
if batchStop != 0:
    with open(accAndTime, 'rb') as handle:
        b = pickle.load(handle)
    print(accAndTime)
    print(b)

In [16]:
deltaPickles = F"CIFARCPUdeltaPickles.pickle" if datasetCifar else F"MnistCPUdeltapickles.pickle"

with open(deltaPickles, 'wb') as handle:
    pickle.dump(deltas, handle)

In [17]:
path = F"./resnet/selective_mnist.pt"
torch.save({
            'model_state_dict': resnet.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [18]:
path = F"resnet/selective_CIFAR.pt" if datasetCifar else F"resnet/selective_mnist.pt"
checkpoint = torch.load(path)
resnet.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [19]:
accuracy = [[],[],[],[],[],[],[],[],[],[]]
recordTime = [[],[],[],[],[],[],[],[],[],[]]
print(type(deltas))
# print(deltas)
# fileo = open(path, 'rb')
# deltas = pickle.load(fileo)
# fileo.close()
b = []
with open(deltaPickles, 'rb') as handle:
    b = pickle.load(handle)
print(len(b))
print(len(deltas))
import os
print(os.path.getsize(deltaPickles)/1073741824, ' GB')
# import sys
# print(sys.getsizeof(deltas)/1073741824)


<class 'list'>
50
50
2.092260834760964  GB


In [20]:
# Begin amnesiac unlearning process, evaluating
# model accuracy as batches are removed
# 1 iteration = 1 batch
for j in range(1):
    random.shuffle(deltas)
    resnet.load_state_dict(checkpoint['model_state_dict'])
    for i in range(50):
        unlearnItStart = time.process_time()
        print(f"\riteration {j},{i}", end="")
        const = 1
        with torch.no_grad():
            state = resnet.state_dict()
            for param_tensor in state:
                if "weight" in param_tensor or "bias" in param_tensor:
                  state[param_tensor] = state[param_tensor] - const*deltas[i][param_tensor]
        resnet.load_state_dict(state)
        accuracy[j].append(test(resnet, dataTest, dname="All data"))
        end = time.process_time() - unlearnItStart
        recordTime[j].append(end)
        print(f"Time taken for unlearning: {end}")

iteration 0,0All data: Mean loss: 0.0612, Accuracy: 4131/10000 (41%)
Time taken for unlearning: 28.976975218000007
iteration 0,1All data: Mean loss: 0.0720, Accuracy: 3459/10000 (35%)
Time taken for unlearning: 29.29865210800017
iteration 0,2All data: Mean loss: 0.0926, Accuracy: 2597/10000 (26%)
Time taken for unlearning: 28.749031584000477
iteration 0,3All data: Mean loss: 0.0997, Accuracy: 2366/10000 (24%)
Time taken for unlearning: 29.054123707001054
iteration 0,4All data: Mean loss: 0.0988, Accuracy: 2504/10000 (25%)
Time taken for unlearning: 29.04716838099921
iteration 0,5All data: Mean loss: 0.1126, Accuracy: 1988/10000 (20%)
Time taken for unlearning: 29.535244231001343
iteration 0,6All data: Mean loss: 0.1406, Accuracy: 1406/10000 (14%)
Time taken for unlearning: 27.73660612599997
iteration 0,7All data: Mean loss: 0.1491, Accuracy: 1314/10000 (13%)
Time taken for unlearning: 29.15061930400043
iteration 0,8All data: Mean loss: 0.1589, Accuracy: 1219/10000 (12%)
Time taken for 

In [21]:
accuracyStore = F"CPUselective_acc_CIFAR.pk" if datasetCifar else F"CPUselective_acc_mnist.pk"
f = open(accuracyStore, "wb")
pickle.dump(accuracy, f)
f.close()

In [22]:
timeStore = F"CPUselective_time_CIFAR.pk" if datasetCifar else F"CPUselective_time_mnist.pk"
f = open(timeStore, "wb")
pickle.dump(recordTime, f)
f.close()

In [23]:
accuracyStore = F"CPUselective_acc_CIFAR.pk" if datasetCifar else F"CPUselective_acc_mnist.pk"
timeStore = F"CPUselective_time_CIFAR.pk" if datasetCifar else F"CPUselective_time_mnist.pk"
res = []
with open(accuracyStore, 'rb') as handle:
    res = pickle.load(handle)
print(res)
letime = []
with open(timeStore, 'rb') as handle:
    letime = pickle.load(handle)
print(letime)

[[0.4131, 0.3459, 0.2597, 0.2366, 0.2504, 0.1988, 0.1406, 0.1314, 0.1219, 0.0942, 0.1017, 0.0958, 0.0787, 0.0491, 0.0384, 0.0331, 0.0193, 0.019, 0.0187, 0.0155, 0.0138, 0.0125, 0.0124, 0.0135, 0.0188, 0.0151, 0.0146, 0.0148, 0.0118, 0.0134, 0.0123, 0.0153, 0.0114, 0.0102, 0.0105, 0.0102, 0.0113, 0.0103, 0.0107, 0.0145, 0.0108, 0.0101, 0.0105, 0.012, 0.0141, 0.0111, 0.0101, 0.01, 0.0106, 0.01], [], [], [], [], [], [], [], [], []]
[[28.976975218000007, 29.29865210800017, 28.749031584000477, 29.054123707001054, 29.04716838099921, 29.535244231001343, 27.73660612599997, 29.15061930400043, 27.71568135099733, 27.696878058999573, 29.201040317999286, 27.494895020001422, 28.79763525299859, 28.819799802997295, 28.49836202900042, 27.943681576998642, 28.393502940998587, 28.129857987998548, 29.663375370997528, 29.11832324200077, 28.32082129599803, 28.265235038001265, 29.48378090299957, 28.850802190998365, 29.400531224000588, 27.75640896200275, 29.277627798997855, 26.191764044000593, 29.1265378839998