### This notebook is another version of the Memory augmented optimizer, instead of store the estimate gradient, we can store the random key that generate the tangent vector to save the memory consumption

In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import FashionMNIST, MNIST, CIFAR10, SVHN
import torchvision
from torchvision import transforms
import torchvision.utils as vision_utils
import matplotlib.pyplot as plt
import random
import os
import time
import math

DEVICE = torch.device('cuda')
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [13]:
def switch_to_device(dataset,device=None):
    final_X, final_Y = [], []
    for x, y in dataset:
        final_X.append(x)
        final_Y.append(y)
    X = torch.stack(final_X)
    Y = torch.tensor(final_Y)
    if device is not None:
        X = X.to(device)
        Y = Y.to(device)
    return torch.utils.data.TensorDataset(X, Y)


def get_mnist_dl(batch_size_train=256, batch_size_eval=256, device=torch.device('cpu')):
    transform = transforms.Compose([transforms.ToTensor()])
    
    data_train = MNIST('./datasets', train=True, download=True, transform=transform)
    data_train = switch_to_device(data_train, device=device)
    data_train, data_valid = torch.utils.data.random_split(data_train, [55000,5000])
    
    data_test = MNIST('./datasets', train=False, download=True, transform=transform)
    data_test = switch_to_device(data_test, device=device)
    
    train_dl = DataLoader(data_train, batch_size=batch_size_train, shuffle=True)
    valid_dl = DataLoader(data_valid, batch_size=batch_size_eval, shuffle=False)
    test_dl = DataLoader(data_test, batch_size=batch_size_eval, shuffle=False)
    
    return train_dl, valid_dl, test_dl

In [14]:
class MLP_Net(nn.Module):

  def __init__(self, num_classes=10) -> None:
    super().__init__()
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(28*28, 1024)
    #self.Relu1 = nn.ReLU()
    self.Relu1 = nn.Hardtanh()
    self.fc2 = nn.Linear(1024, 1024)
    #self.Relu2 = nn.ReLU()
    self.Relu2 = nn.Hardtanh()
    self.fc3 = nn.Linear(1024, 1024)
    #self.Relu3 = nn.ReLU()
    self.Relu3 = nn.Hardtanh()
    self.fc4 = nn.Linear(1024, num_classes)
    self.softmax = nn.Softmax(dim = -1)


  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.flatten(x)
    x = self.Relu1(self.fc1(x))
    x.data = x.data.sign()
    x = self.Relu2(self.fc2(x))
    x.data = x.data.sign()
    x = self.Relu3(self.fc3(x))
    x.data = x.data.sign()
    x = self.fc4(x)
    
    return x

In [15]:
class CNN_Net(nn.Module):
    def __init__(self, num_classes=10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 26, kernel_size=5, stride=1, padding = 0)
        self.maxpooling1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(26, 52, kernel_size=3, stride=1, padding = 0)

        self.conv3 = nn.Conv2d(52, 10, kernel_size=1, stride=1, padding=0)
        self.maxpooling3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc_1 = nn.Linear(5*5*10, 1024)
        self.fc_2 = nn.Linear(1024, num_classes)
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.maxpooling1(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.maxpooling3(F.relu(self.conv3(x)))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc_1(x))
        x = self.fc_2(x)
        return x

In [16]:
def print_stats(stats):

  fig, (ax1, ax2) = plt.subplots(1,2,figsize=(7,3), dpi=110)
  ax1.grid()
  ax2.grid()

  ax1.set_title("ERM loss")
  ax2.set_title("Valid Acc")
  
  ax1.set_xlabel("iterations")
  ax2.set_xlabel("iterations")

  itrs = [x[0] for x in stats['train-loss']]
  loss = [x[1] for x in stats['train-loss']]
  ax1.plot(itrs, loss)

  itrs = [x[0] for x in stats['valid-acc']]
  acc = [x[1] for x in stats['valid-acc']]
  ax2.plot(itrs, acc)

  ax1.set_ylim(0.0, max(loss))
  ax2.set_ylim(0.0, 1.05)


@torch.no_grad()
def get_acc(model, dl):
  model.eval()
  acc = []
  for X, y in dl:
    #acc.append((torch.sigmoid(model(X)) > 0.5) == y)
    acc.append(torch.argmax(model(X), dim=1) == y)
  acc = torch.cat(acc)
  acc = torch.sum(acc)/len(acc)
  model.train()
  return acc.item()

In [17]:
def buffer_decay(vector_buffer, decay_rate = 0.9):
    new_buffer = {}
    for key in vector_buffer:
        new_buffer[key*decay_rate] = vector_buffer[key]
    del vector_buffer
    return new_buffer

In [18]:
def mean_aggregation(vector_buffer, cur_vector, dir_align):
    res = {}
    num = len(vector_buffer.keys())
    if num == 0:
        return cur_vector
    
    if dir_align:
        dir_dict = torch.ones(len(vector_buffer), device = DEVICE)
        
        for j, key in enumerate(vector_buffer.keys()):
            similarity = 0
            for i in vector_buffer[key].keys():
                similarity += torch.sum(vector_buffer[key][i] * cur_vector[i])

            if similarity >= 0 :
                dir_dict[j] = 1
            else:
                dir_dict[j] = -1


        for j, key in enumerate(vector_buffer.keys()):
            for i in vector_buffer[key].keys():
                if i not in res.keys():
                    res[i] = dir_dict[j]*vector_buffer[key][i]
                else:
                    res[i] += (dir_dict[j]*vector_buffer[key][i])


    else:
        for j, key in enumerate(vector_buffer.keys()):
            for i in vector_buffer[key].keys():
                if i not in res.keys():
                    res[i] = vector_buffer[key][i]
                else:
                    res[i] += 1*vector_buffer[key][i]
                    #print(vector_buffer[key][i] - 1*vector_buffer[key][i])
    
    for i in cur_vector.keys():
        res[i] += cur_vector[i]
        res[i] /= (num + 1)
    
    return res

In [19]:
def sum_aggregation(vector_buffer, cur_vector, dir_align):
    res = {}
    num = len(vector_buffer.keys())
    if num == 0:
        return cur_vector
        
    if dir_align:
        dir = torch.zeros(len(vector_buffer), device = DEVICE)
        for j, key in enumerate(vector_buffer.keys()):
            similarity = 0
            for i in vector_buffer[key].keys():
                similarity += torch.sum(vector_buffer[key][i] * cur_vector[i])
            if similarity >= 0 :
                dir[j] = 1
            else:
                dir[j] = -1

        for j, key in enumerate(vector_buffer.keys()):
            for i in vector_buffer[key].keys():
                if i not in res.keys():
                    res[i] = dir[j]*vector_buffer[key][i]
                else:
                    res[i] += (dir[j]*vector_buffer[key][i])
    
    else:
        for key in vector_buffer.keys():
            for i in vector_buffer[key].keys():
                if i not in res.keys():
                    res[i] = vector_buffer[key][i]
                else:
                    res[i] += vector_buffer[key][i]
    
    for i in cur_vector.keys():
        '''
        res[i] += cur_vector[i]
        res[i] /= (num + 1)
        '''
        res[i] /= num
        res[i] += cur_vector[i]
    
    return res

In [20]:
def run_experiment(model, opt, criterion, train_dl, valid_dl, test_dl, max_epochs=20, use_forward_grad=False, num_forward_grad=1, use_memory_augmented=False, buffer_capacity=5, decay_rate = 0.9, aggregate_method = "mean", dir_align = False):
    itr = -1
    stats = {'train-loss' : [], 'valid-acc' : []}
    if use_memory_augmented:
        vector_buffer = {}
    random_dir = {}
    for i, p in enumerate(model.parameters()):
        random_dir[i] = 0
    
    for epoch in range(max_epochs):
        for x, y in train_dl:
            itr += 1
            opt.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            
            if use_memory_augmented:
                norm = 0
            
            if use_forward_grad:
                with torch.no_grad():
                    da = torch.zeros((num_forward_grad, 1), device = DEVICE)

                    for i, p in enumerate(model.parameters()):
                        g = p.grad.view(-1)
                        v = torch.randn(num_forward_grad, len(g), device = DEVICE)
                        random_dir[i] = v
                        da = da + (v @ g).view(num_forward_grad, 1)
                    
                    if use_memory_augmented:
                        for i, p in enumerate(model.parameters()):
                            g = (da * random_dir[i]).mean(dim = 0)
                            norm += torch.norm(g)**2
                        aggregate_vector = {} 
                        if aggregate_method == "mean":
                            estimated_dir = mean_aggregation(vector_buffer, random_dir, dir_align)
                        elif aggregate_method == "sum":
                            estimated_dir = sum_aggregation(vector_buffer, random_dir, dir_align)
                        vector_buffer = buffer_decay(vector_buffer, decay_rate)
                        if len(vector_buffer.keys()) < buffer_capacity:
                            vector_buffer[norm] = random_dir
                        else:
                            min_norm = min(vector_buffer.keys())
                            if min_norm <= norm:
                                del vector_buffer[min_norm]
                                vector_buffer[norm] = random_dir

                    else:
                        estimated_dir = random_dir
                    
                    vector_norm = 0
                    dim = 0
                    for i, p in enumerate(model.parameters()):
                        vector_norm += torch.norm(estimated_dir[i])**2
                        dim += len(estimated_dir[i][0])

                    vector_norm = vector_norm**0.5

                    da = torch.zeros((num_forward_grad, 1), device = DEVICE)
                    for i, p in enumerate(model.parameters()):
                        g = p.grad.view(-1)
                        v = estimated_dir[i]#/vector_norm
                        da = da + (v @ g).view(num_forward_grad, 1)

                    for i, p in enumerate(model.parameters()):
                        #g = (da * (estimated_dir[i]/vector_norm)).mean(dim = 0)
                        #g = g*dim
                        g = (da * (estimated_dir[i])).mean(dim = 0)
                        p.grad = g.view(p.grad.shape)
            opt.step()

            stats['train-loss'].append((itr, loss.item()))

            if itr % 100 == 0:
                valid_acc = get_acc(model, valid_dl)
                stats['valid-acc'].append((itr, valid_acc))
                s = f"{epoch}:{itr} [train] loss:{loss.item():.3f}, [valid] acc:{valid_acc:.3f}"
                print(s)
    
    test_acc = get_acc(model, test_dl)
    print(f"[test] acc:{test_acc:.3f}")

    return stats
                    


In [21]:
model = CNN_Net().to(DEVICE)
opt = torch.optim.SGD(model.parameters(), lr=5e-4)
max_epochs = 100
criterian = nn.CrossEntropyLoss()
use_forward_grad = True
num_forward_grad = 1
use_memory_augmented = True
buffer_capacity = 5
decay_rate = 0.9
aggregate_method = 'mean'
dir_align = False

In [22]:
train_dl, valid_dl, test_dl = get_mnist_dl(device=DEVICE)

for p in model.parameters():
    g = p.view(-1)
    v = torch.normal(mean = torch.full((1, len(g)), 0.), std = torch.full((1, len(g)), 0.1)).to(DEVICE)
    p.data = v.view(p.shape)

stats = run_experiment(model, opt, criterian, train_dl, valid_dl, test_dl, max_epochs, use_forward_grad, num_forward_grad, use_memory_augmented, buffer_capacity, decay_rate, aggregate_method, dir_align)
print_stats(stats)


0:0 [train] loss:2.380, [valid] acc:0.127
0:100 [train] loss:nan, [valid] acc:0.097
0:200 [train] loss:nan, [valid] acc:0.097
1:300 [train] loss:nan, [valid] acc:0.097
1:400 [train] loss:nan, [valid] acc:0.097
2:500 [train] loss:nan, [valid] acc:0.097
2:600 [train] loss:nan, [valid] acc:0.097
3:700 [train] loss:nan, [valid] acc:0.097
3:800 [train] loss:nan, [valid] acc:0.097


KeyboardInterrupt: 