In [37]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import Subset
import math

In [38]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    "Defines a simple convolutional neural net"
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

        self.initialize_weights()
        
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            if isinstance(m, nn.Linear):
                nn.init.uniform_(m.weight, -1/math.sqrt(5), 1/math.sqrt(5))

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

In [39]:
# model = SimpleCNN()
# dataiter = iter(data_loader)
# input, target = dataiter.next()
# y = model(input)

In [40]:
def loss_fn(predictions, targets):
    return F.cross_entropy(predictions, targets)

In [41]:
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split

data = datasets.MNIST(
    root = '~/Documents/Code/Gradient_Matching/data',
    train = True,
    transform = ToTensor(), 
    download = True,            
)

test_data = datasets.MNIST(
    root = '~/Documents/Code/Gradient_Matching/data',
    train = False,
    transform = ToTensor(), 
    download = True,            
)

batch_size = 7000
# data = ConcatDataset([train_data, test_data])
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False)

# Split the indices in a stratified way
indices = np.arange(len(data))
train_indices, test_indices = train_test_split(indices, train_size=100*10, stratify=data.targets, random_state=42)

# Warp into Subsets and DataLoaders
train_subset = Subset(data, train_indices)
subset_loader = DataLoader(data, batch_size=1, shuffle=False)
subset_loader.dataset.targets.size()

torch.Size([60000])

In [42]:
def subset_forward_pass_grads(model, dataloader):
    "Returns the gradient of the loss wrt the final layer parameters after one forward pass"
    sample_grads = []
    for _, (input, target) in enumerate(dataloader):
        output = model(input)
        loss = loss_fn(output, target)

        grad = torch.autograd.grad(loss, model.fc2.weight, retain_graph=True)[0]
        sample_grads.append(grad)
#         print(len(sample_grads))

    for i, tensor in enumerate(sample_grads):
        sample_grads[i] = torch.flatten(tensor)

    return sample_grads


model = SimpleCNN()
sample_grads = subset_forward_pass_grads(model, subset_loader)
sample_grads[0].shape
len(sample_grads)
model

SimpleCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [43]:
gradients = []
for _, (input, target) in enumerate(data_loader):
    output = model(input)
    loss = loss_fn(output, target)
    grad_ = torch.autograd.grad(loss, model.fc2.weight, retain_graph=True)[0]
    gradients.append(grad_)

In [44]:
grad = torch.flatten(torch.mean(torch.stack(gradients), dim=0))

In [45]:
def gradient_closure(full_grad, subset_grad):
    "Returns the cosine distance between two gradient vectors calculated from different passes"
    cosine_distances = []
    cosine = nn.CosineSimilarity(dim=0)
    for gradient in subset_grad:
        dist = cosine(full_grad, gradient)
        cosine_distances.append(float(dist))
    return cosine_distances
list_g = gradient_closure(grad, sample_grads)

In [46]:
# list_g_indices = [(int(data.targets[i]), i) for i, index in enumerate(data.targets)]
list_g_indices = [int(data.targets[i]) for i, index in enumerate(data.targets)]

In [47]:
zipped_list = zip(list_g, list_g_indices, list(range(60000)))
zipped_list = sorted(zipped_list, reverse=True)
zipped_list

[(0.6034554243087769, 2, 18783),
 (0.6017210483551025, 2, 36715),
 (0.59885573387146, 2, 30355),
 (0.5978372097015381, 2, 32430),
 (0.5976244211196899, 2, 54837),
 (0.5898824334144592, 2, 34326),
 (0.5897219777107239, 2, 55157),
 (0.5890105366706848, 2, 12079),
 (0.5877698063850403, 2, 54831),
 (0.5877556800842285, 2, 48748),
 (0.587235152721405, 2, 12938),
 (0.5868067145347595, 2, 55203),
 (0.5863257050514221, 2, 33982),
 (0.5862870216369629, 2, 25138),
 (0.5860787034034729, 2, 42092),
 (0.5837225317955017, 2, 15016),
 (0.5825902223587036, 2, 23335),
 (0.5822906494140625, 6, 5629),
 (0.5820091366767883, 6, 8481),
 (0.5807057619094849, 2, 15603),
 (0.5806631445884705, 6, 1269),
 (0.5801284909248352, 6, 2631),
 (0.5800151824951172, 2, 15731),
 (0.5799592733383179, 2, 44092),
 (0.5799319744110107, 2, 59253),
 (0.5795027613639832, 2, 56083),
 (0.5794111490249634, 6, 6266),
 (0.5792816281318665, 6, 8146),
 (0.579276442527771, 6, 29048),
 (0.5790455937385559, 2, 54765),
 (0.5787364840507507

In [48]:
top_1000 = []
for label in range(10):
    dummy = 100
    for tup in zipped_list:
        if dummy == 0:
            break
        if tup[1] == label:
            top_1000.append(tup)
            dummy-=1
            
len(top_1000)

1000

In [49]:
top_1000 = sorted(top_1000, reverse=False)
top_1000

[(-7.134053885238245e-05, 3, 48732),
 (-5.458311352413148e-05, 3, 23796),
 (-5.110523125040345e-05, 3, 4238),
 (-4.797607471118681e-05, 3, 46908),
 (-4.709909626399167e-05, 3, 36337),
 (-4.441680357558653e-05, 3, 10155),
 (-4.2793886677827686e-05, 3, 42707),
 (-4.172565240878612e-05, 3, 29249),
 (-4.138366784900427e-05, 3, 26691),
 (-3.779450344154611e-05, 3, 2707),
 (-1.9501972928992473e-05, 3, 3815),
 (-1.908394187921658e-05, 3, 33694),
 (-1.8289138097316027e-05, 3, 17491),
 (-1.7579779523657635e-05, 3, 16502),
 (-1.6577705537201837e-05, 3, 55292),
 (-1.4900831956765614e-05, 3, 19793),
 (-1.0910996934399009e-05, 3, 57300),
 (-9.524201232125051e-06, 3, 58219),
 (-7.451016699633328e-06, 3, 2735),
 (-7.112022558430908e-06, 3, 55732),
 (-6.57173177387449e-06, 3, 34407),
 (-6.358486643875949e-06, 3, 32940),
 (-6.094162927183788e-06, 3, 46588),
 (-6.027310519129969e-06, 3, 11039),
 (-4.2071842472068965e-06, 3, 46729),
 (-4.184943918517092e-06, 3, 54377),
 (-4.181859367236029e-06, 3, 50565)

In [55]:
top_1000_indices = [i[2] for i in top_1000]
t1_subset = Subset(data, top_1000_indices)
t1_dataloader = DataLoader(t1_subset, batch_size=1000, shuffle=False)
t1_model = SimpleCNN()
t1_optimizer = torch.optim.Adam(t1_model.parameters(), lr=0.01)

full_trained_model = SimpleCNN()
full_trained_model_optimizer = torch.optim.Adam(full_trained_model.parameters(), lr=0.01)

random_model = SimpleCNN()
random_optimizer = torch.optim.Adam(random_model.parameters(), lr=0.01)

In [56]:
t1_model.train()
for epoch in range(50):
    for x, y in t1_dataloader:
        output = t1_model(x)
        loss = loss_fn(output, y)
        t1_optimizer.zero_grad()
        loss.backward()
    t1_optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 47.96760177612305
Epoch: 2, Loss: 61.86268615722656
Epoch: 3, Loss: 53.41510009765625
Epoch: 4, Loss: 34.731597900390625
Epoch: 5, Loss: 18.73115348815918
Epoch: 6, Loss: 8.452287673950195
Epoch: 7, Loss: 4.2869744300842285
Epoch: 8, Loss: 2.4124882221221924
Epoch: 9, Loss: 1.4991490840911865
Epoch: 10, Loss: 1.2039600610733032
Epoch: 11, Loss: 1.0722095966339111
Epoch: 12, Loss: 0.9388383030891418
Epoch: 13, Loss: 0.7879299521446228
Epoch: 14, Loss: 0.6380698084831238
Epoch: 15, Loss: 0.504909873008728
Epoch: 16, Loss: 0.39124754071235657
Epoch: 17, Loss: 0.29604047536849976
Epoch: 18, Loss: 0.21848855912685394
Epoch: 19, Loss: 0.1633404940366745
Epoch: 20, Loss: 0.12995800375938416
Epoch: 21, Loss: 0.1053372249007225
Epoch: 22, Loss: 0.08379978686571121
Epoch: 23, Loss: 0.06652709096670151
Epoch: 24, Loss: 0.05335800722241402
Epoch: 25, Loss: 0.04311821982264519
Epoch: 26, Loss: 0.034420620650053024
Epoch: 27, Loss: 0.025486920028924942
Epoch: 28, Loss: 0.016320912167

In [57]:
# Split the indices in a stratified way
indices = np.arange(len(data))
train_indices, test_indices = train_test_split(indices, train_size=100*10, stratify=data.targets, random_state=72)

# Warp into Subsets and DataLoaders
train_subset = Subset(data, train_indices)
subset_loader = DataLoader(train_subset, batch_size=1000, shuffle=False)


random_model.train()
for epoch in range(50):
    for x, y in subset_loader:
        output = random_model(x)
        loss = loss_fn(output, y)
        random_optimizer.zero_grad()
        loss.backward()
        random_optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 32.84511947631836
Epoch: 2, Loss: 74.23544311523438
Epoch: 3, Loss: 45.376441955566406
Epoch: 4, Loss: 14.558917045593262
Epoch: 5, Loss: 6.532426834106445
Epoch: 6, Loss: 3.12656831741333
Epoch: 7, Loss: 2.127545118331909
Epoch: 8, Loss: 1.9246206283569336
Epoch: 9, Loss: 1.8332191705703735
Epoch: 10, Loss: 1.7123725414276123
Epoch: 11, Loss: 1.547063946723938
Epoch: 12, Loss: 1.3582741022109985
Epoch: 13, Loss: 1.1908224821090698
Epoch: 14, Loss: 1.0171470642089844
Epoch: 15, Loss: 0.872646689414978
Epoch: 16, Loss: 0.733734667301178
Epoch: 17, Loss: 0.5950322151184082
Epoch: 18, Loss: 0.48759546875953674
Epoch: 19, Loss: 0.4251521825790405
Epoch: 20, Loss: 0.38743114471435547
Epoch: 21, Loss: 0.34929296374320984
Epoch: 22, Loss: 0.306177020072937
Epoch: 23, Loss: 0.26482847332954407
Epoch: 24, Loss: 0.2237488180398941
Epoch: 25, Loss: 0.18525531888008118
Epoch: 26, Loss: 0.15430320799350739
Epoch: 27, Loss: 0.1318640559911728
Epoch: 28, Loss: 0.11300668865442276
Epoc

In [53]:
full_trained_model.train()
for epoch in range(20):
    for x, y in data_loader:
        output = full_trained_model(x)
        loss = loss_fn(output, y)
        full_trained_model_optimizer.zero_grad()
        loss.backward()
        full_trained_model_optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 3.2623860836029053
Epoch: 2, Loss: 1.1148806810379028
Epoch: 3, Loss: 0.4795956611633301
Epoch: 4, Loss: 0.29435062408447266
Epoch: 5, Loss: 0.21368089318275452
Epoch: 6, Loss: 0.16567398607730865
Epoch: 7, Loss: 0.12689635157585144
Epoch: 8, Loss: 0.09860623627901077
Epoch: 9, Loss: 0.0794445350766182
Epoch: 10, Loss: 0.06694278120994568
Epoch: 11, Loss: 0.05684063211083412
Epoch: 12, Loss: 0.04848324880003929
Epoch: 13, Loss: 0.042625147849321365
Epoch: 14, Loss: 0.038683872669935226
Epoch: 15, Loss: 0.034155186265707016
Epoch: 16, Loss: 0.029371730983257294
Epoch: 17, Loss: 0.02551121823489666
Epoch: 18, Loss: 0.022589072585105896
Epoch: 19, Loss: 0.021332040429115295
Epoch: 20, Loss: 0.020627425983548164


In [58]:
test_dataloader = DataLoader(test_data, batch_size=1000, shuffle=False)
t1_correct = 0
full_correct = 0
random_correct = 0
total = 0

t1_correct_pred = {classname: 0 for classname in range(10)}
random_correct_pred = {classname: 0 for classname in range(10)}
full_correct_pred = {classname: 0 for classname in range(10)}
total_pred = {classname: 0 for classname in range(10)}


with torch.no_grad():
    for x, y in test_dataloader:
        t1_output = t1_model(x)
        full_model_output = full_trained_model(x)
        random_output = random_model(x)
        
        _, t1_predicted = torch.max(t1_output.data, 1)
        total += y.size(0)
        t1_correct += (t1_predicted == y).sum().item()
        
        _, random_predicted = torch.max(random_output.data, 1)
        random_correct += (random_predicted == y).sum().item()
        
        _, full_predicted = torch.max(full_model_output.data, 1)
        full_correct += (full_predicted == y).sum().item()
        
        for label, prediction in zip(y, t1_predicted):
            if label == prediction:
                t1_correct_pred[int(label)] += 1
            
        for label, prediction in zip(y, random_predicted):
            if label == prediction:
                random_correct_pred[int(label)] += 1
            
        for label, prediction in zip(y, full_predicted):
            if label == prediction:
                full_correct_pred[int(label)] += 1
                
            total_pred[int(label)] += 1
        
        

print(100*t1_correct/total)
print(100*random_correct/total)
print(100*full_correct/total)

77.32
90.65
10.13
