In [16]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [17]:
import torch
import numpy as np
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import Dataset
import copy
import random
import time
import pickle
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

In [18]:
to_forget = 3
num_classes = 10
max_count = 100
in_size = 1

torch.manual_seed(42)

<torch._C.Generator at 0x7cc1a04c32f0>

In [19]:
class IndexingDataset(Dataset):
    def __init__(self, internal_dataset):
        self.dataset = internal_dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, sample_index):
        r = self.dataset[sample_index]
        if not isinstance(r, tuple):
            r = (r,)
        return *r, sample_index
    
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

# Using MNIST
data = datasets.MNIST('/data', download=True, train=True, transform=transform)
traindata = IndexingDataset(data)
testdata = datasets.MNIST('/data', download=True, train=False, transform=transform)

In [20]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=False)

target_index = []
nontarget_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == to_forget:
        target_index.append(i)
    else:
        nontarget_index.append(i)
target_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))

target_index = []
nontarget_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != to_forget:
        target_index.append(i)
        nontarget_index.append(i)
    if traindata[i][1] == to_forget and (count < max_count or max_count < 1):
        count += 1
        target_index.append(i)
target_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))


unlearningdata = copy.deepcopy(data)
unlearninglabels = list(range(num_classes))
unlearninglabels.remove(to_forget)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == to_forget:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(IndexingDataset(unlearningdata), batch_size=64, shuffle=True)

In [21]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [22]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
P=.1
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [23]:
# Training method
def train(model, epoch, loader, returnable=False, keep_p=.1):
    model.train()
    if returnable:
        batches = []
    for batch_idx, (data, target, samples_idx) in enumerate(loader):
        optimizer.zero_grad()
        if to_forget in target:
            before = {}
            for key, param in model.named_parameters():
                before[key] = param.clone()
        data = data.to(device)
        output = model(data)
        loss = criterion(output, target.to(device))
        loss.backward()
        
        optimizer.step()
        
        with torch.no_grad():
            if to_forget in target:
                batches.append(batch_idx)
                step = {}
                for key, param in model.named_parameters():
                    step[key] = (param - before[key]).cpu()
                torch.save(step, f"steps/e{epoch}b{batches[-1]:04}.pkl")
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
    if returnable:
        return batches

In [24]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

In [25]:
trainingepochs = 4
forgetfulepochs = 4

In [26]:
# load resnet 18 and change to fit problem dimensionality
model = models.resnet18()
model.bn1 = nn.GroupNorm(1, model.bn1.weight.shape[0])
model.layer1[0].bn1 = nn.GroupNorm(1, model.layer1[0].bn1.weight.shape[0])
model.layer1[0].bn2 = nn.GroupNorm(1, model.layer1[0].bn2.weight.shape[0])
model.layer1[1].bn1 = nn.GroupNorm(1, model.layer1[1].bn1.weight.shape[0])
model.layer1[1].bn2 = nn.GroupNorm(1, model.layer1[1].bn2.weight.shape[0])

model.layer2[0].bn1 = nn.GroupNorm(1, model.layer2[0].bn1.weight.shape[0])
model.layer2[0].bn2 = nn.GroupNorm(1, model.layer2[0].bn2.weight.shape[0])
model.layer2[0].downsample[1] = nn.GroupNorm(1, model.layer2[0].downsample[1].weight.shape[0])
model.layer2[1].bn1 = nn.GroupNorm(1, model.layer2[1].bn1.weight.shape[0])
model.layer2[1].bn2 = nn.GroupNorm(1, model.layer2[1].bn2.weight.shape[0])

model.layer3[0].bn1 = nn.GroupNorm(1, model.layer3[0].bn1.weight.shape[0])
model.layer3[0].bn2 = nn.GroupNorm(1, model.layer3[0].bn2.weight.shape[0])
model.layer3[0].downsample[1] = nn.GroupNorm(1, model.layer3[0].downsample[1].weight.shape[0])
model.layer3[1].bn1 = nn.GroupNorm(1, model.layer3[1].bn1.weight.shape[0])
model.layer3[1].bn2 = nn.GroupNorm(1, model.layer3[1].bn2.weight.shape[0])

model.layer4[0].bn1 = nn.GroupNorm(1, model.layer4[0].bn1.weight.shape[0])
model.layer4[0].bn2 = nn.GroupNorm(1, model.layer4[0].bn2.weight.shape[0])
model.layer4[0].downsample[1] = nn.GroupNorm(1, model.layer4[0].downsample[1].weight.shape[0])
model.layer4[1].bn1 = nn.GroupNorm(1, model.layer4[1].bn1.weight.shape[0])
model.layer4[1].bn2 = nn.GroupNorm(1, model.layer4[1].bn2.weight.shape[0])

model.conv1 = nn.Conv2d(in_size, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
model.fc = nn.Sequential(nn.Linear(512, num_classes))

device = "cuda" if torch.cuda.is_available() else 'cpu'

#model = SimpleModel(in_size, num_classes)

model = model.to(device)

optimizer = optim.Adam(model.parameters())

In [27]:
# Train new model for 5 epochs
epoch_indices = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    batches = train(model, epoch, target_train_loader, returnable=True, keep_p=P)
    print(f"{batches} batches effected")
    epoch_indices.append(batches)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")
    print(f"Time taken: {time.process_time() - starttime}")

Epoch: 1 [ 53248]	Loss: 0.015137[8, 33, 40, 46, 48, 53, 56, 68, 73, 79, 87, 103, 108, 124, 142, 151, 155, 170, 179, 203, 213, 215, 217, 218, 226, 231, 248, 250, 254, 268, 276, 277, 280, 284, 289, 308, 311, 312, 318, 323, 327, 339, 344, 345, 350, 352, 360, 373, 379, 402, 433, 434, 440, 447, 457, 458, 461, 462, 466, 475, 495, 497, 500, 507, 520, 525, 528, 537, 541, 547, 558, 577, 630, 642, 649, 652, 655, 685, 702, 703, 704, 706, 709, 718, 724, 725, 729, 731, 763, 772, 789, 790, 804, 808, 818, 831] batches effected
All data: Mean loss: 0.0038, Accuracy: 9306/10000 (93%)
Forget  : Mean loss: 0.0029, Accuracy: 494/1010 (49%)
Retain  : Mean loss: 0.0010, Accuracy: 8812/8990 (98%)
Time taken: 37.70613332200001
Epoch: 2 [ 53248]	Loss: 0.154789[2, 14, 25, 26, 32, 49, 51, 65, 73, 95, 115, 116, 128, 136, 154, 163, 167, 200, 226, 241, 249, 257, 258, 260, 261, 271, 277, 283, 290, 306, 307, 310, 326, 336, 337, 340, 353, 368, 389, 397, 399, 402, 414, 430, 446, 457, 460, 467, 472, 488, 501, 508, 510, 

In [28]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [31]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [32]:
for i in range(1, trainingepochs+1):
    for j in epoch_indices[i-1]:
        path = f"steps/e{i}b{j:04}.pkl"
        steps = torch.load(path)
        print(f"\rLoading steps/e{i}b{j:04}.pkl", end="")
        const = 1
        with torch.no_grad():
            for key, param in model.named_parameters():
                #print(steps[key][0].sum())
                param -= const * steps[key].to(device)

Loading steps/e4b0835.pkl

In [33]:
test(model, all_data_test_loader, dname="All data")
test(model, target_test_loader, dname="Forget  ")
test(model, nontarget_test_loader, dname="Retain  ")

All data: Mean loss: 0.0381, Accuracy: 8725/10000 (87%)
Forget  : Mean loss: 0.0368, Accuracy: 0/1010 (0%)
Retain  : Mean loss: 0.0016, Accuracy: 8725/8990 (97%)


tensor(0.971, device='cuda:0')

In [34]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [35]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [36]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    _ = train(model, epoch, nontarget_train_loader, returnable=True)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")

Epoch: 5 [ 53248]	Loss: 0.018239All data: Mean loss: 0.0231, Accuracy: 8863/10000 (89%)
Forget  : Mean loss: 0.0226, Accuracy: 0/1010 (0%)
Retain  : Mean loss: 0.0007, Accuracy: 8863/8990 (99%)
Epoch: 6 [ 53248]	Loss: 0.002232All data: Mean loss: 0.0193, Accuracy: 8881/10000 (89%)
Forget  : Mean loss: 0.0189, Accuracy: 0/1010 (0%)
Retain  : Mean loss: 0.0005, Accuracy: 8881/8990 (99%)
Epoch: 7 [ 53248]	Loss: 0.002823All data: Mean loss: 0.0210, Accuracy: 8857/10000 (89%)
Forget  : Mean loss: 0.0204, Accuracy: 0/1010 (0%)
Retain  : Mean loss: 0.0008, Accuracy: 8857/8990 (99%)
Epoch: 8 [ 53248]	Loss: 0.037096All data: Mean loss: 0.0165, Accuracy: 8879/10000 (89%)
Forget  : Mean loss: 0.0161, Accuracy: 0/1010 (0%)
Retain  : Mean loss: 0.0006, Accuracy: 8879/8990 (99%)
