In [1]:
import os
if not os.path.exists('steps'):
    os.mkdir('steps')

In [2]:
import torch
import numpy as np
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.nn import functional as F
import torch.nn as nn
from torch.utils.data import Dataset
import copy
import random
import time
import pickle
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

torch.set_printoptions(precision=3)
cuda = True if torch.cuda.is_available() else False

In [3]:
to_forget = 3
num_classes = 10
max_count = 100
in_size = 1

torch.manual_seed(42)

<torch._C.Generator at 0x7a4f60403250>

In [4]:
class IndexingDataset(Dataset):
    def __init__(self, internal_dataset):
        self.dataset = internal_dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, sample_index):
        r = self.dataset[sample_index]
        if not isinstance(r, tuple):
            r = (r,)
        return *r, sample_index
    
# Transform image to tensor and normalize features from [0,255] to [0,1]
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,),(0.5,)),
                                ])

# Using MNIST
data = datasets.MNIST('/data', download=True, train=True, transform=transform)
traindata = IndexingDataset(data)
testdata = datasets.MNIST('/data', download=True, train=False, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 34304977.16it/s]


Extracting /data/MNIST/raw/train-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1157277.37it/s]


Extracting /data/MNIST/raw/train-labels-idx1-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 9644047.88it/s]


Extracting /data/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2127362.23it/s]

Extracting /data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/MNIST/raw






In [5]:
# Loaders that give 64 example batches
all_data_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64, shuffle=True)
all_data_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64, shuffle=False)

target_index = []
nontarget_index = []
for i in range(0, len(testdata)):
    if testdata[i][1] == to_forget:
        target_index.append(i)
    else:
        nontarget_index.append(i)
target_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_test_loader = torch.utils.data.DataLoader(testdata, batch_size=64,
              sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))

target_index = []
nontarget_index = []
count = 0
for i in range(0, len(traindata)):
    if traindata[i][1] != to_forget:
        target_index.append(i)
        nontarget_index.append(i)
    if traindata[i][1] == to_forget and (count < max_count or max_count < 1):
        count += 1
        target_index.append(i)
target_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(target_index))
nontarget_train_loader = torch.utils.data.DataLoader(traindata, batch_size=64,
                     sampler = torch.utils.data.SubsetRandomSampler(nontarget_index))


unlearningdata = copy.deepcopy(data)
unlearninglabels = list(range(num_classes))
unlearninglabels.remove(to_forget)
for i in range(len(unlearningdata)):
    if unlearningdata.targets[i] == to_forget:
        unlearningdata.targets[i] = random.choice(unlearninglabels)
unlearning_train_loader = torch.utils.data.DataLoader(IndexingDataset(unlearningdata), batch_size=64, shuffle=True)

In [6]:
class SimpleModel(nn.Module):
    def __init__(self, in_size, out_size, h_size=100):
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        self.h_size = h_size
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.Conv2d(h_size, h_size, 3, 2, padding=1),
            nn.LeakyReLU(.1),
            nn.AdaptiveMaxPool2d((2,2)),
            nn.Flatten(1),
            nn.Linear(4 * h_size, out_size)
        )
        
        nn.init.xavier_normal_(self.layers[0].weight)
        nn.init.zeros_(self.layers[0].bias)
        nn.init.xavier_normal_(self.layers[2].weight)
        nn.init.zeros_(self.layers[2].bias)
        nn.init.xavier_normal_(self.layers[6].weight)
        nn.init.zeros_(self.layers[6].bias)
        
    def forward(self, x):
        return self.layers(x)

In [7]:
# Hyperparameters
batch_size_train = 64
batch_size_test = 64
log_interval = 16
P=.1
torch.backends.cudnn.enabled = True
criterion = F.cross_entropy

In [8]:
# Training method
def train(model, epoch, loader, returnable=False, keep_p=.1):
    model.train()
    rng = np.random.default_rng(42)
    if returnable:
        batches = []
    for batch_idx, (data, target, samples_idx) in enumerate(loader):
        optimizer.zero_grad()
        if to_forget in target:
            before = {}
            for key, param in model.named_parameters():
                before[key] = param.clone()
        data = data.to(device)
        output = model(data)
        loss = criterion(output, target.to(device))
        loss.backward()
        
        optimizer.step()
        
        with torch.no_grad():
            if to_forget in target:
                batches.append(batch_idx)
                step = {}
                for key, param in model.named_parameters():
                    diff = (param - before[key]).cpu().flatten()
                    size = diff.shape.numel()
                    subset = rng.choice(size, int(size * keep_p), replace=False, shuffle=False)
                    step[key] = (diff[subset], subset)
                torch.save(step, f"steps/e{epoch}b{batches[-1]:04}.pkl")
        if batch_idx % log_interval == 0:
            print("\rEpoch: {} [{:6d}]\tLoss: {:.6f}".format(
              epoch, batch_idx*len(data),  loss.item()), end="")
    if returnable:
        return batches

In [9]:
# Testing method
def test(model, loader, dname="Test set", printable=True):
    model.eval()
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            total += target.size()[0]
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(loader.dataset)
    if printable:
        print('{}: Mean loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            dname, test_loss, correct, total, 
            100. * correct / total
            ))
    return 1. * correct / total

In [10]:
trainingepochs = 4
forgetfulepochs = 4

In [11]:
# load resnet 18 and change to fit problem dimensionality
#model = models.resnet18()
#model.bn1 = nn.GroupNorm(1, model.bn1.weight.shape[0])
#model.layer1[0].bn1 = nn.GroupNorm(1, model.layer1[0].bn1.weight.shape[0])
#model.layer1[0].bn2 = nn.GroupNorm(1, model.layer1[0].bn2.weight.shape[0])
#model.layer1[1].bn1 = nn.GroupNorm(1, model.layer1[1].bn1.weight.shape[0])
#model.layer1[1].bn2 = nn.GroupNorm(1, model.layer1[1].bn2.weight.shape[0])

#model.layer2[0].bn1 = nn.GroupNorm(1, model.layer2[0].bn1.weight.shape[0])
#model.layer2[0].bn2 = nn.GroupNorm(1, model.layer2[0].bn2.weight.shape[0])
#model.layer2[0].downsample[1] = nn.GroupNorm(1, model.layer2[0].downsample[1].weight.shape[0])
#model.layer2[1].bn1 = nn.GroupNorm(1, model.layer2[1].bn1.weight.shape[0])
#model.layer2[1].bn2 = nn.GroupNorm(1, model.layer2[1].bn2.weight.shape[0])

#model.layer3[0].bn1 = nn.GroupNorm(1, model.layer3[0].bn1.weight.shape[0])
#model.layer3[0].bn2 = nn.GroupNorm(1, model.layer3[0].bn2.weight.shape[0])
#model.layer3[0].downsample[1] = nn.GroupNorm(1, model.layer3[0].downsample[1].weight.shape[0])
#model.layer3[1].bn1 = nn.GroupNorm(1, model.layer3[1].bn1.weight.shape[0])
#model.layer3[1].bn2 = nn.GroupNorm(1, model.layer3[1].bn2.weight.shape[0])

#model.layer4[0].bn1 = nn.GroupNorm(1, model.layer4[0].bn1.weight.shape[0])
#model.layer4[0].bn2 = nn.GroupNorm(1, model.layer4[0].bn2.weight.shape[0])
#model.layer4[0].downsample[1] = nn.GroupNorm(1, model.layer4[0].downsample[1].weight.shape[0])
#model.layer4[1].bn1 = nn.GroupNorm(1, model.layer4[1].bn1.weight.shape[0])
#model.layer4[1].bn2 = nn.GroupNorm(1, model.layer4[1].bn2.weight.shape[0])

#model.conv1 = nn.Conv2d(in_size, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
#model.fc = nn.Sequential(nn.Linear(512, num_classes))

device = "cuda" if torch.cuda.is_available() else 'cpu'

model = SimpleModel(in_size, num_classes)

model = model.to(device)

optimizer = optim.Adam(model.parameters())

In [12]:
# Train new model for 5 epochs
epoch_indices = []
for epoch in range(1, trainingepochs+1):
    starttime = time.process_time()
    # train(resnet, epoch, all_data_train_loader, returnable=False)
    batches = train(model, epoch, target_train_loader, returnable=True, keep_p=P)
    print(f"{batches} batches effected")
    epoch_indices.append(batches)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")
    print(f"Time taken: {time.process_time() - starttime}")

Epoch: 1 [ 53248]	Loss: 0.018577[1, 6, 26, 30, 37, 40, 57, 92, 93, 94, 100, 107, 108, 137, 138, 140, 162, 168, 170, 174, 194, 198, 199, 201, 203, 206, 218, 227, 229, 239, 242, 243, 244, 247, 265, 293, 336, 344, 346, 361, 384, 385, 392, 394, 395, 403, 404, 406, 415, 442, 456, 461, 484, 491, 496, 497, 498, 511, 523, 537, 540, 558, 570, 571, 576, 584, 594, 595, 608, 611, 614, 622, 634, 663, 672, 675, 681, 687, 690, 711, 715, 717, 719, 723, 732, 739, 750, 756, 771, 782, 791, 803, 816, 817, 826, 830, 838, 841] batches effected
All data: Mean loss: 0.0057, Accuracy: 8959/10000 (90%)
Forget  : Mean loss: 0.0041, Accuracy: 285/1010 (28%)
Retain  : Mean loss: 0.0016, Accuracy: 8674/8990 (96%)
Time taken: 23.541826562000004
Epoch: 2 [ 53248]	Loss: 0.127367[6, 8, 21, 37, 57, 59, 65, 66, 78, 81, 83, 89, 94, 100, 119, 128, 132, 151, 156, 160, 161, 162, 167, 173, 175, 176, 180, 185, 192, 195, 207, 226, 269, 272, 274, 280, 314, 315, 321, 331, 363, 367, 383, 399, 405, 407, 410, 415, 427, 436, 437, 442

In [13]:
path = F"selective_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [14]:
path = F"selective_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [15]:
for i in range(1, trainingepochs+1):
    for j in epoch_indices[i-1]:
        path = f"steps/e{i}b{j:04}.pkl"
        steps = torch.load(path)
        print(f"\rLoading steps/e{i}b{j:04}.pkl", end="")
        const = 1
        with torch.no_grad():
            for key, param in model.named_parameters():
                param.view(-1)[steps[key][1]] -= const * steps[key][0].to(device)

Loading steps/e4b0842.pkl

In [16]:
test(model, all_data_test_loader, dname="All data")
test(model, target_test_loader, dname="Forget  ")
test(model, nontarget_test_loader, dname="Retain  ")

All data: Mean loss: 0.0105, Accuracy: 8851/10000 (89%)
Forget  : Mean loss: 0.0097, Accuracy: 67/1010 (7%)
Retain  : Mean loss: 0.0010, Accuracy: 8784/8990 (98%)


tensor(0.977, device='cuda:0')

In [17]:
path = F"selective_post_trained.pt"
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, path)

In [18]:
path = F"selective_post_trained.pt"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [19]:
# Train model for 10 epochs
for epoch in range(trainingepochs+1,trainingepochs+forgetfulepochs+1):
  # train(resnet, epoch, nonthree_train_loader, returnable=False)
    _ = train(model, epoch, nontarget_train_loader, returnable=True)
    test(model, all_data_test_loader, dname="All data")
    test(model, target_test_loader, dname="Forget  ")
    test(model, nontarget_test_loader, dname="Retain  ")

Epoch: 5 [ 53248]	Loss: 0.014603All data: Mean loss: 0.0082, Accuracy: 8953/10000 (90%)
Forget  : Mean loss: 0.0074, Accuracy: 122/1010 (12%)
Retain  : Mean loss: 0.0008, Accuracy: 8831/8990 (98%)
Epoch: 6 [ 53248]	Loss: 0.002206All data: Mean loss: 0.0077, Accuracy: 9036/10000 (90%)
Forget  : Mean loss: 0.0071, Accuracy: 200/1010 (20%)
Retain  : Mean loss: 0.0007, Accuracy: 8836/8990 (98%)
Epoch: 7 [ 53248]	Loss: 0.032163All data: Mean loss: 0.0092, Accuracy: 8971/10000 (90%)
Forget  : Mean loss: 0.0087, Accuracy: 96/1010 (10%)
Retain  : Mean loss: 0.0006, Accuracy: 8875/8990 (99%)
Epoch: 8 [ 53248]	Loss: 0.017488All data: Mean loss: 0.0132, Accuracy: 8865/10000 (89%)
Forget  : Mean loss: 0.0126, Accuracy: 24/1010 (2%)
Retain  : Mean loss: 0.0007, Accuracy: 8841/8990 (98%)
