In [45]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import FashionMNIST, MNIST, CIFAR10, SVHN
import torchvision
from torchvision import transforms
import torchvision.utils as vision_utils
import matplotlib.pyplot as plt
import random
import os
import time
import math

# Restrict this process to physical GPU 1. CUDA reads this variable when it is
# first initialised, so it must be set before any CUDA call; re-running this cell
# after CUDA has been initialised has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Fall back to the CPU so the notebook still runs on machines without a (visible) GPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [46]:
def switch_to_device(dataset,device=None):
    """Materialise an iterable of ``(x, y)`` pairs into a single TensorDataset.

    All ``x`` are stacked along a new leading dimension (so they must share one
    shape), and all ``y`` are gathered into one tensor with ``torch.tensor``.
    When ``device`` is given, both tensors are moved there before wrapping.

    Args:
        dataset: any iterable yielding ``(x, y)`` pairs, e.g. a torchvision dataset.
        device: optional target device; ``None`` leaves the tensors where they are.

    Returns:
        ``torch.utils.data.TensorDataset(X, Y)``.
    """
    pairs = list(dataset)
    inputs = torch.stack([sample for sample, _ in pairs])
    targets = torch.tensor([label for _, label in pairs])
    if device is not None:
        inputs, targets = inputs.to(device), targets.to(device)
    return torch.utils.data.TensorDataset(inputs, targets)


def get_mnist_dl(batch_size_train=256, batch_size_eval=256, device=torch.device('cpu'), valid_size=5000, seed=None, root='./datasets'):
    """Build MNIST train/valid/test DataLoaders with all tensors preloaded on ``device``.

    The MNIST training split is materialised on ``device`` and randomly split
    into a training part and a ``valid_size``-image validation part; the official
    MNIST test split is used as the test set. Images only go through
    ``ToTensor`` (no further normalisation).

    Args:
        batch_size_train: batch size of the shuffled training loader.
        batch_size_eval: batch size of the validation and test loaders.
        device: device the whole dataset is moved to up front.
        valid_size: number of training images held out for validation
            (5000 reproduces the original 55000/5000 split).
        seed: if given, seeds the train/valid split so it is reproducible;
            ``None`` keeps using torch's global default generator.
        root: directory MNIST is downloaded to / read from.

    Returns:
        ``(train_dl, valid_dl, test_dl)``.

    Raises:
        ValueError: if ``valid_size`` does not leave a non-empty train and valid part.
    """
    transform = transforms.Compose([transforms.ToTensor()])

    data_train = MNIST(root, train=True, download=True, transform=transform)
    data_train = switch_to_device(data_train, device=device)

    num_total = len(data_train)
    if not 0 < valid_size < num_total:
        raise ValueError(f"valid_size must be in (0, {num_total}), got {valid_size}")
    # Only pass a generator when a seed is requested, so the default behaviour
    # (global torch RNG) is unchanged.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    data_train, data_valid = torch.utils.data.random_split(
        data_train, [num_total - valid_size, valid_size], **split_kwargs
    )

    data_test = MNIST(root, train=False, download=True, transform=transform)
    data_test = switch_to_device(data_test, device=device)

    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_eval, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_eval, shuffle=False)

    return train_dl, valid_dl, test_dl

In [47]:
class MLP_Net(nn.Module):
    """Fully connected classifier: 784 -> 1024 -> 1024 -> 1024 -> num_classes.

    The input is flattened from dim 1 onward, so it must contain 28*28 values per
    sample. Hidden layers use ReLU; the output layer returns raw logits (no
    softmax), which is what ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, num_classes=10) -> None:
        super().__init__()
        hidden = 1024
        # Registration order (flatten, fc1, Relu1, fc2, ...) fixes both the
        # initialisation RNG order and the order of model.parameters().
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.Relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden, hidden)
        self.Relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden, hidden)
        self.Relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.flatten(x)
        hidden_stages = (
            (self.fc1, self.Relu1),
            (self.fc2, self.Relu2),
            (self.fc3, self.Relu3),
        )
        for linear, activation in hidden_stages:
            h = activation(linear(h))
        return self.fc4(h)

In [48]:
class CNN_Net(nn.Module):
    """Small convolutional classifier for single-channel 28x28 inputs.

    Spatial size per stage: 28 -conv5-> 24 -pool-> 12 -conv3-> 10 -conv1x1-> 10
    -pool-> 5, so the flattened feature vector has 10 * 5 * 5 = 250 entries.
    Returns raw logits of shape (batch, num_classes).
    """

    def __init__(self, num_classes=10) -> None:
        super().__init__()
        # Stage 1: 5x5 conv to 26 channels, then 2x2 max-pool.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=26, kernel_size=5, stride=1, padding=0)
        self.maxpooling1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 3x3 conv to 52 channels (no pooling).
        self.conv2 = nn.Conv2d(in_channels=26, out_channels=52, kernel_size=3, stride=1, padding=0)
        # Stage 3: 1x1 conv squeezes to 10 channels, then 2x2 max-pool.
        self.conv3 = nn.Conv2d(in_channels=52, out_channels=10, kernel_size=1, stride=1, padding=0)
        self.maxpooling3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head on the 10 x 5 x 5 feature map.
        self.fc_1 = nn.Linear(10 * 5 * 5, 1000)
        self.fc_2 = nn.Linear(1000, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.maxpooling1(F.relu(self.conv1(x)))
        features = F.relu(self.conv2(features))
        features = self.maxpooling3(F.relu(self.conv3(features)))
        flat = features.flatten(start_dim=1)
        hidden = F.relu(self.fc_1(flat))
        return self.fc_2(hidden)

In [49]:
def print_stats(stats):
    """Plot the training-loss curve and the validation-accuracy curve side by side.

    Despite the name, nothing is printed: a 1x2 matplotlib figure is created and
    left for the notebook to display.

    Args:
        stats: dict with keys 'train-loss' and 'valid-acc', each a list of
            ``(iteration, value)`` pairs (the format returned by run_experiment).
    """
    _, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(7, 3), dpi=110)

    for ax, title in ((ax_loss, "ERM loss"), (ax_acc, "Valid Acc")):
        ax.grid()
        ax.set_title(title)
        ax.set_xlabel("iterations")

    loss_itrs = [itr for itr, _ in stats['train-loss']]
    loss = [value for _, value in stats['train-loss']]
    ax_loss.plot(loss_itrs, loss)

    acc_itrs = [itr for itr, _ in stats['valid-acc']]
    acc = [value for _, value in stats['valid-acc']]
    ax_acc.plot(acc_itrs, acc)

    # Derive the loss axis limit from finite values only: max() of an empty
    # history raises, and an inf/nan loss (a diverged run) makes set_ylim raise.
    # A 0..0 range is skipped as degenerate.
    finite_loss = [v for v in loss if math.isfinite(v)]
    if finite_loss and max(finite_loss) > 0.0:
        ax_loss.set_ylim(0.0, max(finite_loss))
    ax_acc.set_ylim(0.0, 1.05)


@torch.no_grad()
def get_acc(model, dl):
    """Return the top-1 classification accuracy of ``model`` over ``dl``.

    The model is put in eval mode for the pass and afterwards restored to
    whichever mode (train or eval) it was in before the call, even if an
    exception is raised.

    Args:
        model: module whose output for a batch ``X`` has class scores along
            dim 1; the arg-max over dim 1 is taken as the prediction.
        dl: iterable of ``(X, y)`` batches with integer class labels ``y``.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if ``dl`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        for X, y in dl:
            preds = torch.argmax(model(X), dim=1)
            # Keep the count as a tensor (on y's device) to avoid a host sync per batch.
            correct = correct + (preds == y).sum()
            total += y.numel()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("get_acc: data loader yielded no samples")
    return (correct / total).item()

In [50]:
def buffer_decay(gradient_buffer, decay_rate = 0.9):
    """Return a new buffer whose keys are the old keys multiplied by ``decay_rate``.

    Values are carried over unchanged (the same objects) and insertion order is
    preserved. The dict passed in is not modified.
    """
    return {norm * decay_rate: grads for norm, grads in gradient_buffer.items()}

In [51]:
def mean_aggregation(gradient_buffer, cur_gradient):
    """Average the current gradient estimate with every gradient stored in the buffer.

    Args:
        gradient_buffer: dict whose values are per-parameter gradient dicts
            ``{param_index: tensor}``; its keys (scores) are not used here.
        cur_gradient: per-parameter gradient dict for the current step, with the
            same ``param_index`` keys as the buffered entries.

    Returns:
        A new dict ``{param_index: (sum(buffered) + current) / (len(buffer) + 1)}``.
        Fresh tensors are always returned. Neither the buffer nor ``cur_gradient``
        is modified, so the result can be installed as ``p.grad`` without aliasing
        buffered tensors.
    """
    stored = list(gradient_buffer.values())
    if not stored:
        # Copy instead of returning cur_gradient itself: the training loop also
        # stores cur_gradient in the buffer, and p.grad must not alias it.
        return {i: g.clone() for i, g in cur_gradient.items()}
    count = len(stored) + 1
    # Out-of-place arithmetic: the previous `res[i] += ...` / `/=` mutated the
    # first buffered tensor in place.
    return {
        i: (sum(entry[i] for entry in stored) + g) / count
        for i, g in cur_gradient.items()
    }

In [52]:
def sum_aggregation(gradient_buffer, cur_gradient):
    """Add the mean of the buffered gradients to the current gradient estimate.

    Args:
        gradient_buffer: dict whose values are per-parameter gradient dicts
            ``{param_index: tensor}``; its keys (scores) are not used here.
        cur_gradient: per-parameter gradient dict for the current step, with the
            same ``param_index`` keys as the buffered entries.

    Returns:
        A new dict ``{param_index: sum(buffered) / len(buffer) + current}``, or a
        copy of ``cur_gradient`` when the buffer is empty. Fresh tensors are always
        returned. Neither the buffer nor ``cur_gradient`` is modified.
    """
    stored = list(gradient_buffer.values())
    if not stored:
        # Copy instead of returning cur_gradient itself: the training loop also
        # stores cur_gradient in the buffer, and p.grad must not alias it.
        return {i: g.clone() for i, g in cur_gradient.items()}
    count = len(stored)
    # Out-of-place arithmetic: the previous `res[i] += ...` / `/=` mutated the
    # first buffered tensor in place.
    return {
        i: sum(entry[i] for entry in stored) / count + g
        for i, g in cur_gradient.items()
    }

In [53]:
def run_experiment(model, opt, train_dl, valid_dl, test_dl, criterion, max_epochs=20, use_forward_grad=False, num_forward_grad=1, use_memory_augmented=False, buffer_capacity=5, decay_rate = 0.9, aggregate_method = "mean", eval_every=100):
    """Train ``model`` with ``opt`` and report validation / test accuracy.

    Each step backpropagates ``criterion(model(x), y)``. With ``use_forward_grad``
    the backprop gradient is replaced by a random-projection ("forward gradient")
    estimate built from ``num_forward_grad`` Gaussian directions. If
    ``use_memory_augmented`` is also set, that estimate is combined with a buffer
    of past estimates via ``aggregate_method`` ("mean" or "sum"). The buffer holds
    at most ``buffer_capacity`` entries, keyed by squared norm, with keys decayed
    by ``decay_rate`` each step. ``use_memory_augmented`` has no effect without
    ``use_forward_grad``.

    Validation accuracy is computed and printed every ``eval_every`` iterations.
    Test accuracy is printed once at the end.

    Returns:
        dict with 'train-loss' and 'valid-acc', each a list of ``(iteration, value)``.

    Raises:
        ValueError: if the memory-augmented path is enabled with an unknown
            ``aggregate_method``.
    """
    aggregators = {"mean": mean_aggregation, "sum": sum_aggregation}
    use_memory = use_forward_grad and use_memory_augmented
    if use_memory and aggregate_method not in aggregators:
        # Previously this fell through the if/elif and hit an UnboundLocalError.
        raise ValueError(
            f"aggregate_method must be one of {sorted(aggregators)}, got {aggregate_method!r}"
        )
    aggregate = aggregators.get(aggregate_method)

    params = list(model.parameters())
    stats = {'train-loss': [], 'valid-acc': []}
    gradient_buffer = {}
    itr = -1

    for epoch in range(max_epochs):
        for x, y in train_dl:
            itr += 1
            opt.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()

            if use_forward_grad:
                with torch.no_grad():
                    temp_grad, norm = _forward_gradient_estimate(
                        params, num_forward_grad, track_norm=use_memory
                    )
                    if use_memory:
                        # Aggregate against the buffer as it was *before* this step.
                        estimated_gradient = aggregate(gradient_buffer, temp_grad)
                        gradient_buffer = _update_gradient_buffer(
                            gradient_buffer, temp_grad, norm, buffer_capacity, decay_rate
                        )
                    else:
                        estimated_gradient = temp_grad
                    for i, p in enumerate(params):
                        p.grad = estimated_gradient[i].view(p.grad.shape)

            opt.step()
            loss_value = loss.item()
            stats['train-loss'].append((itr, loss_value))

            if itr % eval_every == 0:
                valid_acc = get_acc(model, valid_dl)
                stats['valid-acc'].append((itr, valid_acc))
                print(f"{epoch}:{itr} [train] loss:{loss_value:.3f}, [valid] acc:{valid_acc:.3f}")

    test_acc = get_acc(model, test_dl)
    print(f"[test] acc:{test_acc:.3f}")

    return stats


def _forward_gradient_estimate(params, num_forward_grad, track_norm=False):
    """Replace the backprop gradients in ``p.grad`` by a forward-gradient estimate.

    For K = ``num_forward_grad`` Gaussian directions ``v_k`` (drawn per parameter,
    in parameter order), forms ``da[k] = sum_p <v_k^p, grad_p>``. It then
    returns, per parameter index, ``mean_k(da[k] * v_k^p)`` reshaped like
    ``p.grad``.

    Returns:
        ``(estimate, norm)``. ``estimate`` maps parameter index to tensor.
        ``norm`` is the summed squared L2 norm of the estimate (a 0-d tensor),
        or ``None`` when ``track_norm`` is False.
    """
    directions = []
    da = torch.zeros((num_forward_grad, 1), device=params[0].device)
    for p in params:
        g = p.grad.view(-1)
        v = torch.randn(num_forward_grad, len(g), device=p.device)
        directions.append(v)
        # Accumulate the projection of the full gradient onto each direction.
        da = da + (v @ g).view(num_forward_grad, 1)

    estimate = {}
    norm = 0 if track_norm else None
    for i, (p, v) in enumerate(zip(params, directions)):
        g = (da * v).mean(dim=0)
        if track_norm:
            norm += torch.norm(g) ** 2
        estimate[i] = g.view(p.grad.shape)
    return estimate, norm


def _update_gradient_buffer(gradient_buffer, grad, norm, buffer_capacity, decay_rate):
    """Decay the buffer's keys, then insert ``grad`` under key ``norm``; return the new buffer.

    When the buffer is full, the entry with the smallest (decayed) key is evicted
    only if ``norm`` is at least as large; otherwise ``grad`` is dropped.
    """
    gradient_buffer = buffer_decay(gradient_buffer, decay_rate)
    if len(gradient_buffer) < buffer_capacity:
        gradient_buffer[norm] = grad
    elif gradient_buffer:  # with buffer_capacity <= 0 the buffer simply stays empty
        min_norm = min(gradient_buffer.keys())
        if min_norm <= norm:
            del gradient_buffer[min_norm]
            gradient_buffer[norm] = grad
    return gradient_buffer
            

In [54]:
# Seed every RNG the experiment may touch so runs are repeatable. This covers
# parameter init, DataLoader shuffling, the train/valid split and the
# forward-gradient directions.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

model = MLP_Net().to(DEVICE)
opt = torch.optim.SGD(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

max_epochs = 300              # NOTE: the recorded run below was interrupted around epoch 4
use_forward_grad = False      # True: replace backprop gradients by forward-gradient estimates
num_forward_grad = 1          # random directions per step for the forward-gradient estimate
use_memory_augmented = False  # True (with use_forward_grad): combine with a buffer of past estimates
buffer_capacity = 5           # max number of past estimates kept in that buffer
decay_rate = 0.9              # per-step multiplicative decay of the buffer keys (squared norms)
aggregate_method = 'mean'     # 'mean' or 'sum' (see mean_aggregation / sum_aggregation)

In [55]:
train_dl, valid_dl, test_dl = get_mnist_dl(device=DEVICE)

# Re-initialise every parameter (weights AND biases) from N(0, 0.1^2), replacing
# PyTorch's default layer initialisation. Filling in place keeps the same
# Parameter objects that `opt` already holds and avoids a CPU round-trip.
with torch.no_grad():
    for p in model.parameters():
        nn.init.normal_(p, mean=0.0, std=0.1)

stats = run_experiment(
    model, opt, train_dl, valid_dl, test_dl, criterion,
    max_epochs=max_epochs,
    use_forward_grad=use_forward_grad,
    num_forward_grad=num_forward_grad,
    use_memory_augmented=use_memory_augmented,
    buffer_capacity=buffer_capacity,
    decay_rate=decay_rate,
    aggregate_method=aggregate_method,
)
print_stats(stats)


0:0 [train] loss:17.984, [valid] acc:0.065
0:100 [train] loss:7.241, [valid] acc:0.127
0:200 [train] loss:4.751, [valid] acc:0.249
1:300 [train] loss:4.002, [valid] acc:0.378
1:400 [train] loss:3.374, [valid] acc:0.475
2:500 [train] loss:2.364, [valid] acc:0.537
2:600 [train] loss:2.389, [valid] acc:0.579
3:700 [train] loss:1.888, [valid] acc:0.614
3:800 [train] loss:2.011, [valid] acc:0.638
4:900 [train] loss:1.675, [valid] acc:0.655


KeyboardInterrupt: 