In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import Subset
import math

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    "Defines a simple convolutional neural net"
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        self.initialize_weights()
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            if isinstance(m, nn.Linear):
                nn.init.uniform_(m.weight, -1/math.sqrt(5), 1/math.sqrt(5))

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

In [3]:
# model = SimpleCNN()
# dataiter = iter(data_loader)
# input, target = dataiter.next()
# y = model(input)

In [4]:
def loss_fn(predictions, targets):
    return F.cross_entropy(predictions, targets)

In [5]:
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split

data = datasets.MNIST(
    root = '~/Documents/Code/Gradient_Matching/data',
    train = True,
    transform = ToTensor(), 
    download = True,            
)

test_data = datasets.MNIST(
    root = '~/Documents/Code/Gradient_Matching/data',
    train = False,
    transform = ToTensor(), 
    download = True,            
)

batch_size = 7000
# data = ConcatDataset([train_data, test_data])
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False)

# Split the indices in a stratified way
indices = np.arange(len(data))
train_indices, test_indices = train_test_split(indices, train_size=100*10, stratify=data.targets, random_state=42)

# Warp into Subsets and DataLoaders
train_subset = Subset(data, train_indices)
subset_loader = DataLoader(data, batch_size=1, shuffle=False)

In [6]:
def subset_forward_pass_grads(model, dataloader):
    "Returns the gradient of the loss wrt the final layer parameters after one forward pass"
    sample_grads = []
    for _, (input, target) in enumerate(dataloader):
        output = model(input)
        loss = loss_fn(output, target)

        grad = torch.autograd.grad(loss, model.fc2.weight, retain_graph=True)[0]
        sample_grads.append(grad)
#         print(len(sample_grads))

    for i, tensor in enumerate(sample_grads):
        sample_grads[i] = torch.flatten(tensor)

    return sample_grads


model = SimpleCNN()
sample_grads = subset_forward_pass_grads(model, subset_loader)
sample_grads[0].shape
len(sample_grads)

60000

In [7]:
gradients = []
for _, (input, target) in enumerate(data_loader):
    output = model(input)
    loss = loss_fn(output, target)
    grad_ = torch.autograd.grad(loss, model.fc2.weight, retain_graph=True)[0]
    gradients.append(grad_)

In [8]:
grad = torch.flatten(torch.mean(torch.stack(gradients), dim=0))


In [9]:
def gradient_closure(full_grad, subset_grad):
    "Returns the cosine distance between two gradient vectors calculated from different passes"
    cosine_distances = []
    cosine = nn.CosineSimilarity(dim=0)
    for gradient in subset_grad:
        dist = cosine(full_grad, gradient)
        cosine_distances.append(dist)
    return cosine_distances
list_g = gradient_closure(grad, sample_grads)

In [10]:
list_g = [float(element) for element in list_g]


In [11]:
# list_g_indices = [(int(data.targets[i]), i) for i, index in enumerate(data.targets)]
list_g_indices = [int(data.targets[i]) for i, index in enumerate(data.targets)]

In [12]:
zipped_list = zip(list_g, list_g_indices, list(range(60000)))
zipped_list = sorted(zipped_list, reverse=True)
zipped_list

[(0.7002543210983276, 8, 37198),
 (0.6954518556594849, 8, 13498),
 (0.6942363381385803, 8, 19291),
 (0.6932784914970398, 8, 17244),
 (0.6899743676185608, 8, 47136),
 (0.6894584894180298, 8, 28381),
 (0.6887336373329163, 8, 1207),
 (0.6887250542640686, 8, 14173),
 (0.6882762312889099, 8, 8465),
 (0.6864116787910461, 8, 22001),
 (0.6862618923187256, 8, 52983),
 (0.6862281560897827, 8, 12963),
 (0.6862051486968994, 8, 36299),
 (0.6857589483261108, 8, 35764),
 (0.6854323148727417, 8, 59402),
 (0.6849002838134766, 8, 14253),
 (0.6848194599151611, 8, 33089),
 (0.6846453547477722, 8, 5409),
 (0.684518039226532, 8, 44458),
 (0.6840699911117554, 2, 38195),
 (0.6837885975837708, 8, 6111),
 (0.6836216449737549, 7, 17044),
 (0.683563768863678, 8, 33205),
 (0.6832146048545837, 8, 48092),
 (0.6829837560653687, 8, 5977),
 (0.6825869679450989, 8, 36199),
 (0.6824616193771362, 0, 54711),
 (0.6820456385612488, 8, 8285),
 (0.6819006204605103, 8, 15861),
 (0.6817883849143982, 9, 8669),
 (0.681686580181121

In [13]:
top_1000 = []
for label in range(10):
    dummy = 100
    for tup in zipped_list:
        if dummy == 0:
            break
        if tup[1] == label:
            top_1000.append(tup)
            dummy-=1
            
len(top_1000)

1000

In [23]:
top_1000 = sorted(top_1000, reverse=True)
top_1000

[(0.7002543210983276, 8, 37198),
 (0.6954518556594849, 8, 13498),
 (0.6942363381385803, 8, 19291),
 (0.6932784914970398, 8, 17244),
 (0.6899743676185608, 8, 47136),
 (0.6894584894180298, 8, 28381),
 (0.6887336373329163, 8, 1207),
 (0.6887250542640686, 8, 14173),
 (0.6882762312889099, 8, 8465),
 (0.6864116787910461, 8, 22001),
 (0.6862618923187256, 8, 52983),
 (0.6862281560897827, 8, 12963),
 (0.6862051486968994, 8, 36299),
 (0.6857589483261108, 8, 35764),
 (0.6854323148727417, 8, 59402),
 (0.6849002838134766, 8, 14253),
 (0.6848194599151611, 8, 33089),
 (0.6846453547477722, 8, 5409),
 (0.684518039226532, 8, 44458),
 (0.6840699911117554, 2, 38195),
 (0.6837885975837708, 8, 6111),
 (0.6836216449737549, 7, 17044),
 (0.683563768863678, 8, 33205),
 (0.6832146048545837, 8, 48092),
 (0.6829837560653687, 8, 5977),
 (0.6825869679450989, 8, 36199),
 (0.6824616193771362, 0, 54711),
 (0.6820456385612488, 8, 8285),
 (0.6819006204605103, 8, 15861),
 (0.6817883849143982, 9, 8669),
 (0.681686580181121

In [15]:
top_1000_indices = [i[2] for i in top_1000]
t1_subset = Subset(data, top_1000_indices)
t1_dataloader = DataLoader(t1_subset, batch_size=1000, shuffle=False)
t1_model = SimpleCNN()
t1_optimizer = torch.optim.Adam(t1_model.parameters(), lr=0.01)

full_trained_model = SimpleCNN()
full_trained_model_optimizer = torch.optim.Adam(full_trained_model.parameters(), lr=0.01)

random_model = SimpleCNN()
random_optimizer = torch.optim.Adam(random_model.parameters(), lr=0.01)

In [24]:
t1_model.train()
for epoch in range(50):
    for x, y in t1_dataloader:
        output = t1_model(x)
        loss = loss_fn(output, y)
        t1_optimizer.zero_grad()
        loss.backward()
        t1_optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 0.0025904024951159954
Epoch: 2, Loss: 0.0022600123193114996
Epoch: 3, Loss: 0.001979546621441841
Epoch: 4, Loss: 0.0017336038872599602
Epoch: 5, Loss: 0.0015178925823420286
Epoch: 6, Loss: 0.0013332723174244165
Epoch: 7, Loss: 0.001176523743197322
Epoch: 8, Loss: 0.0010444113286212087
Epoch: 9, Loss: 0.0009335807408206165
Epoch: 10, Loss: 0.0008380103972740471
Epoch: 11, Loss: 0.000754927983507514
Epoch: 12, Loss: 0.000682217301800847
Epoch: 13, Loss: 0.0006180554046295583
Epoch: 14, Loss: 0.0005619860603474081
Epoch: 15, Loss: 0.0005132054211571813
Epoch: 16, Loss: 0.0004715305403806269
Epoch: 17, Loss: 0.0004363200278021395
Epoch: 18, Loss: 0.0004065672983415425
Epoch: 19, Loss: 0.00038129647145979106
Epoch: 20, Loss: 0.0003595245652832091
Epoch: 21, Loss: 0.00034055387368425727
Epoch: 22, Loss: 0.0003237771161366254
Epoch: 23, Loss: 0.00030868645990267396
Epoch: 24, Loss: 0.0002948575420305133
Epoch: 25, Loss: 0.00028204568661749363
Epoch: 26, Loss: 0.000270146731054

In [26]:
# Split the indices in a stratified way
indices = np.arange(len(data))
train_indices, test_indices = train_test_split(indices, train_size=100*10, stratify=data.targets, random_state=72)

# Warp into Subsets and DataLoaders
train_subset = Subset(data, train_indices)
subset_loader = DataLoader(train_subset, batch_size=1000, shuffle=False)


random_model.train()
for epoch in range(50):
    for x, y in subset_loader:
        output = random_model(x)
        loss = loss_fn(output, y)
        random_optimizer.zero_grad()
        loss.backward()
        random_optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 0.5406907796859741
Epoch: 2, Loss: 0.430744469165802
Epoch: 3, Loss: 0.33654481172561646
Epoch: 4, Loss: 0.2560882568359375
Epoch: 5, Loss: 0.1900797337293625
Epoch: 6, Loss: 0.14614783227443695
Epoch: 7, Loss: 0.11149246245622635
Epoch: 8, Loss: 0.08146645128726959
Epoch: 9, Loss: 0.05801115557551384
Epoch: 10, Loss: 0.0426979586482048
Epoch: 11, Loss: 0.03329777345061302
Epoch: 12, Loss: 0.026777150109410286
Epoch: 13, Loss: 0.02104392647743225
Epoch: 14, Loss: 0.015658466145396233
Epoch: 15, Loss: 0.011505262926220894
Epoch: 16, Loss: 0.009242282249033451
Epoch: 17, Loss: 0.008057740516960621
Epoch: 18, Loss: 0.0071452646516263485
Epoch: 19, Loss: 0.006221673917025328
Epoch: 20, Loss: 0.005316763650625944
Epoch: 21, Loss: 0.004465391859412193
Epoch: 22, Loss: 0.003771273884922266
Epoch: 23, Loss: 0.003265250939875841
Epoch: 24, Loss: 0.002923576394096017
Epoch: 25, Loss: 0.0026820520870387554
Epoch: 26, Loss: 0.002493408275768161
Epoch: 27, Loss: 0.002326343208551407

In [18]:
full_trained_model.train()
for epoch in range(20):
    for x, y in data_loader:
        output = full_trained_model(x)
        loss = loss_fn(output, y)
        full_trained_model_optimizer.zero_grad()
        loss.backward()
        full_trained_model_optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 1.6675498485565186
Epoch: 2, Loss: 0.3777286410331726
Epoch: 3, Loss: 0.18923455476760864
Epoch: 4, Loss: 0.12296568602323532
Epoch: 5, Loss: 0.09009825438261032
Epoch: 6, Loss: 0.06869860738515854
Epoch: 7, Loss: 0.05477095767855644
Epoch: 8, Loss: 0.044096823781728745
Epoch: 9, Loss: 0.03593887761235237
Epoch: 10, Loss: 0.028850944712758064
Epoch: 11, Loss: 0.022880660369992256
Epoch: 12, Loss: 0.01755237579345703
Epoch: 13, Loss: 0.013548647053539753
Epoch: 14, Loss: 0.010129846632480621
Epoch: 15, Loss: 0.007499981671571732
Epoch: 16, Loss: 0.005657300818711519
Epoch: 17, Loss: 0.004198737908154726
Epoch: 18, Loss: 0.0033829226158559322
Epoch: 19, Loss: 0.003674417734146118
Epoch: 20, Loss: 0.006961305160075426


In [27]:
test_dataloader = DataLoader(test_data, batch_size=1000, shuffle=False)
t1_correct = 0
full_correct = 0
random_correct = 0
total = 0

t1_correct_pred = {classname: 0 for classname in range(10)}
random_correct_pred = {classname: 0 for classname in range(10)}
full_correct_pred = {classname: 0 for classname in range(10)}
total_pred = {classname: 0 for classname in range(10)}


with torch.no_grad():
    for x, y in test_dataloader:
        t1_output = t1_model(x)
        full_model_output = full_trained_model(x)
        random_output = random_model(x)
        
        _, t1_predicted = torch.max(t1_output.data, 1)
        total += y.size(0)
        t1_correct += (t1_predicted == y).sum().item()
        
        _, random_predicted = torch.max(random_output.data, 1)
        random_correct += (random_predicted == y).sum().item()
        
        _, full_predicted = torch.max(full_model_output.data, 1)
        full_correct += (full_predicted == y).sum().item()
        
        for label, prediction in zip(y, t1_predicted):
            if label == prediction:
                t1_correct_pred[int(label)] += 1
            
        for label, prediction in zip(y, random_predicted):
            if label == prediction:
                random_correct_pred[int(label)] += 1
            
        for label, prediction in zip(y, full_predicted):
            if label == prediction:
                full_correct_pred[int(label)] += 1
                
            total_pred[int(label)] += 1
        
        

print(100*t1_correct/total)
print(100*random_correct/total)
print(100*full_correct/total)

74.56
93.8
97.9
