In [None]:
import numpy as np
import torch as torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from saga import SAGA
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import StepLR,LambdaLR
from helpers import *
import time
%matplotlib inline

In [None]:
# Run on the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Experiment configuration: every knob a reader may want to tune lives in this cell.
SEED = 0
# Seed NumPy and PyTorch so the weight initialisation and the per-step sample draws
# (np.random.randint in the training cells) are reproducible under Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

MOMENTUM_SAGA = 0.1
MOMENTUM_SGD = 0.9  # NOTE(review): not passed to any optimizer in this notebook
betas = (0.9, 0.999)  # Adam (beta1, beta2)
N_SAMPLES = 0  # set to 0 to use all dataset
n_channels = 1  # 1 to flatten cifar10
sched_step = 10000  # for step scheduler
gamma = 0.1  # for step scheduler
weight_decay = 0.001
data = "mnist"  # "mnist" or "cifar10"
X, y, X_test, y_test, IN_DIM, OUT_DIM = get_data(data, n_channels, N_SAMPLES)
model_str = "LR"  # "LR" selects the linear model; any other value selects the one-hidden-layer NN
opti = "Adam"  # "Adam"; any other value selects torch.optim.SGD
lr = 0.00001
n_epochs = 1000  # number of optimisation steps per optimizer (SGD/SAGA use one sample per step)
centered = "uncentered"  # "centered" -> pixels in [-1, 1]; anything else -> [0, 1]


def _scale_pixels(images):
    """Rescale pixel intensities and cast to float32, the dtype of the model parameters.

    Assumes raw intensities in [0, 255] — TODO confirm against get_data.
    With centered == "centered" the result lies in [-1, 1], otherwise in [0, 1].
    """
    if centered == "centered":
        scaled = images / 127.5 - 1
    else:
        scaled = images / 255
    # Integer or float64 input would otherwise yield float64, which float32 nn.Linear rejects.
    return np.asarray(scaled, dtype=np.float32)


X = _scale_pixels(X)
if X_test is not None:
    # Keep the held-out split on exactly the same scale as the training data.
    X_test = _scale_pixels(X_test)

if N_SAMPLES == 0:
    # From here on N_SAMPLES is the real training-set size; the training cells use it as the
    # number of single-sample steps between learning-rate scheduler updates.
    N_SAMPLES = X.shape[0]

In [None]:
# Sanity-check the pixel scaling: mean, min, max and shape of the training matrix, one per line.
print(X.mean(), X.min(), X.max(), X.shape, sep="\n")

In [None]:
import pandas as pd

# Class balance of the training labels, most frequent class first.
label_counts = pd.Series(y).value_counts()
print(label_counts)

In [None]:
# Class probabilities of the training labels, computed by the `helpers` module and passed to the
# SAGA optimizer (Adam branch) in the setup cell.
# NOTE(review): presumably per-class label frequencies; exact type/shape comes from helpers — confirm.
class_proba = get_class_proba(y)

In [None]:
# Display the class probabilities returned by get_class_proba(y) for inspection.
class_proba

In [None]:
class LR(nn.Module):
    """Linear classifier: a single affine map from IN_DIM input features to OUT_DIM logits."""

    def __init__(self, IN_DIM, OUT_DIM):
        super().__init__()
        # Stored as `linear` so parameter names remain linear.weight / linear.bias.
        self.linear = nn.Linear(IN_DIM, OUT_DIM)

    def forward(self, x):
        logits = self.linear(x)
        return logits

In [None]:
class NN(nn.Module):
    """One-hidden-layer ReLU MLP mapping in_dim features to out_dim logits.

    Args:
        in_dim: number of input features.
        out_dim: number of outputs (one per class).
        hidden_dim: width of the hidden layer; defaults to 100, the width used originally.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=100):
        super(NN, self).__init__()
        # Layer order/indices (mlp.0, mlp.2) are unchanged so existing state_dicts still load.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.mlp(x)

In [None]:
# One copy of the chosen architecture per optimizer so the runs are directly comparable.
model = LR if model_str == 'LR' else NN

model_GD = model(IN_DIM, OUT_DIM).to(device)
model_SGD = model(IN_DIM, OUT_DIM).to(device)
model_SAGA_pc1 = model(IN_DIM, OUT_DIM).to(device)

model_GD_losses = []
model_SGD_losses = []
model_SAGA_pc1_losses = []

criterion = nn.CrossEntropyLoss()

if opti == "Adam":
    # Shared by the Adam baselines and the Adam-flavoured SAGA optimizer.
    adam_kwargs = dict(lr=lr, betas=betas, weight_decay=weight_decay)
    optimizer_GD = torch.optim.Adam(model_GD.parameters(), **adam_kwargs)
    optimizer_SGD = torch.optim.Adam(model_SGD.parameters(), **adam_kwargs)
    optimizer_SAGA_pc1 = SAGA(model_SAGA_pc1.parameters(),
                              n_classes=OUT_DIM,
                              class_proba=class_proba,
                              momentum=MOMENTUM_SAGA,
                              compute_var=True,
                              use_adam=True,
                              **adam_kwargs)
else:
    # NOTE(review): MOMENTUM_SGD is not passed here, so the baselines are plain SGD, and SAGA
    # receives class_proba=None on this path (unlike the Adam path) — confirm both are intended.
    optimizer_GD = torch.optim.SGD(model_GD.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer_SGD = torch.optim.SGD(model_SGD.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer_SAGA_pc1 = SAGA(model_SAGA_pc1.parameters(),
                              n_classes=OUT_DIM,
                              lr=lr,
                              class_proba=None,
                              momentum=MOMENTUM_SAGA,
                              compute_var=True,
                              weight_decay=weight_decay)


def lr_lambda(epoch):
    """Inverse square-root decay factor applied to the base lr: 1 / sqrt(epoch + 1)."""
    return 1 / np.sqrt(epoch + 1)


# The training cells call scheduler.step() once every N_SAMPLES optimisation steps.
# Set USE_STEP_SCHEDULER = True to instead multiply the lr by `gamma` every `sched_step`
# scheduler steps (the alternative that used to live here as commented-out code).
USE_STEP_SCHEDULER = False
if USE_STEP_SCHEDULER:
    scheduler_SGD = StepLR(optimizer_SGD, step_size=sched_step, gamma=gamma)
    scheduler_SAGA_pc1 = StepLR(optimizer_SAGA_pc1, step_size=sched_step, gamma=gamma)
else:
    scheduler_SGD = LambdaLR(optimizer_SGD, lr_lambda=lr_lambda)
    scheduler_SAGA_pc1 = LambdaLR(optimizer_SAGA_pc1, lr_lambda=lr_lambda)

In [None]:
#train GD
# Full-batch gradient descent baseline. It is disabled by default (it used to be commented out);
# set RUN_GD = True to train it.
RUN_GD = False
tstart = time.perf_counter()  # wall-clock time, matching the "Elapsed time" message
GD_avg_var = []
GD_var = 0.0
if RUN_GD:
    # The full batch never changes, so build it once. .float() matches the float32 parameters.
    inputs = torch.from_numpy(X).float().to(device)
    labels = torch.tensor(y, dtype=torch.long).to(device)
    for epoch in range(n_epochs):
        outputs = model_GD(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Squared L2 norm of the full gradient, accumulated as a Python float so it plots directly.
        with torch.no_grad():
            grad_sq = sum((p.grad ** 2).sum() for p in model_GD.parameters() if p.grad is not None)
        GD_var += float(grad_sq)
        optimizer_GD.step()
        GD_avg_var.append(GD_var / (epoch + 1))
        model_GD_losses.append(loss.item())  # true loss, not running average
        optimizer_GD.zero_grad()
print('GD Elapsed time: {:.2f}s'.format(time.perf_counter() - tstart))

In [None]:
#train SGD
# Baseline run: one uniformly sampled training example per optimisation step.
# SGD_avg_var holds the running mean of the squared L2 norm of the stochastic gradient (taken
# before the optimizer step). It is accumulated as a Python float so the list plots directly,
# even when training on the GPU.
tstart = time.perf_counter()  # wall-clock time, matching the "Elapsed time" message
SGD_avg_var = []
SGD_var = 0.0
total_loss = 0.0
# Convert the dataset once instead of on every step (torch.tensor(y) copies all labels).
# .float() matches the float32 model parameters.
X_tensor = torch.from_numpy(X).float()
y_tensor = torch.tensor(y, dtype=torch.long)
for epoch in range(n_epochs):  # "epoch" counts single-sample steps here
    idx = np.random.randint(X.shape[0])
    inputs = X_tensor[idx].to(device)
    labels = y_tensor[idx].view(1).to(device)
    outputs = model_SGD(inputs).view(1, -1)
    loss = criterion(outputs, labels)
    loss.backward()
    total_loss += loss.item()
    model_SGD_losses.append(total_loss / (epoch + 1))  # running average of the sampled losses
    with torch.no_grad():
        grad_sq = sum((p.grad ** 2).sum() for p in model_SGD.parameters() if p.grad is not None)
    SGD_var += float(grad_sq)
    SGD_avg_var.append(SGD_var / (epoch + 1))
    optimizer_SGD.step()
    optimizer_SGD.zero_grad()
    # Advance the lr schedule once every N_SAMPLES steps (one pass-sized chunk).
    if epoch != 0 and epoch % N_SAMPLES == 0:
        scheduler_SGD.step()
print('SGD Elapsed time: {:.2f}s'.format(time.perf_counter() - tstart))

In [None]:
#train SAGApc1
# SAGA run: one uniformly sampled training example per step. The sample's class label is passed
# to SAGA.step as `idx` (the optimizer was built with n_classes=OUT_DIM).
tstart = time.perf_counter()  # wall-clock time, matching the "Elapsed time" message
SAGA_pc1_avg_var = []
SAGA_pc1_var = 0.0
total_loss = 0.0
# Convert the dataset once instead of on every step (torch.tensor(y) copies all labels).
# .float() matches the float32 model parameters.
X_tensor = torch.from_numpy(X).float()
y_tensor = torch.tensor(y, dtype=torch.long)
for epoch in range(n_epochs):  # "epoch" counts single-sample steps here
    idx = np.random.randint(X.shape[0])
    inputs = X_tensor[idx].to(device)
    labels = y_tensor[idx].view(1).to(device)
    label = int(y_tensor[idx])  # read on the CPU copy to avoid a device round-trip
    outputs = model_SAGA_pc1(inputs).view(1, -1)
    loss = criterion(outputs, labels)
    loss.backward()
    total_loss += loss.item()
    model_SAGA_pc1_losses.append(total_loss / (epoch + 1))  # running average of the sampled losses
    _, var = optimizer_SAGA_pc1.step(idx=label)
    # Assumes `var` is a scalar (float or 0-d tensor) — TODO confirm against SAGA.step.
    # float() keeps the history plottable even if it is a CUDA tensor.
    SAGA_pc1_var += float(var)
    SAGA_pc1_avg_var.append(SAGA_pc1_var / (epoch + 1))
    optimizer_SAGA_pc1.zero_grad()
    # Advance the lr schedule once every N_SAMPLES steps (one pass-sized chunk).
    if epoch != 0 and epoch % N_SAMPLES == 0:
        scheduler_SAGA_pc1.step()
print('SAGApc1 Elapsed time: {:.2f}s'.format(time.perf_counter() - tstart))

In [None]:
#plot with hyperparam information
def plot_val(losses, labels, value = "Loss"):
    """Plot one curve per series on a wide figure titled with the run's hyperparameters.

    Args:
        losses: sequence of per-iteration series, one per method.
        labels: legend label for each series, in the same order as ``losses``.
        value: name of the plotted quantity; used in the title and as the y-axis label.

    The last value of every non-empty series is printed so runs can be compared numerically.
    The title reads model_str, data, centered, opti, betas, lr, weight_decay and N_SAMPLES
    from the notebook's global configuration.
    """
    _, ax = plt.subplots(figsize=(15, 5))
    for series, label in zip(losses, labels):
        if len(series) != 0:
            print(label, series[-1])
        ax.plot(series, label=label)
    ax.legend(loc='upper right')
    # The configuration cell replaces N_SAMPLES == 0 by the training-set size before any
    # training happens, so an "all" branch here was unreachable. Report the actual count.
    ax.set_title('{} {} {} training {} SAGA ({} {}) (lr: {}, weight decay: {}, n_samples = {})'.format(
        model_str, data, centered, value, opti, betas, lr, weight_decay, N_SAMPLES))
    ax.set_xlabel('iteration')
    ax.set_ylabel(value)

In [None]:
# Running-average training loss of the SGD baseline versus SAGA.
loss_curves = [model_SGD_losses, model_SAGA_pc1_losses]
curve_labels = ['SGD', 'SAGApc1 (gradient momentum : {})'.format(MOMENTUM_SAGA)]
plot_val(loss_curves, curve_labels)

In [None]:
#plot variance without hyperparam info
def plot_variance(losses, labels, val="Variance"):
    """Plot one curve per series on a square figure with a log-scaled y axis.

    Args:
        losses: sequence of per-iteration series (e.g. running-average gradient "variance").
            Entries are converted with float(), so Python/NumPy scalars and 0-d tensors
            (including CUDA tensors) are accepted.
        labels: legend label for each series, in the same order as ``losses``.
        val: y-axis label.
    """
    _, ax = plt.subplots(figsize=(6, 6))
    for series, label in zip(losses, labels):
        ax.plot([float(v) for v in series], label=label)
    ax.legend(loc='upper right')
    ax.set_xlabel('Iteration', fontsize=25)
    ax.set_ylabel(val, fontsize=25)
    # tick_params persists across the log-scale switch, unlike restyling the current tick labels.
    ax.tick_params(axis='y', labelsize=13)
    ax.set_yscale('log')


# Keep the original name working for existing calls. NOTE: this rebinding replaces the loss
# plotter of the same name defined earlier; call plot_variance explicitly in new code.
plot_val = plot_variance

In [None]:
# Running-average gradient "variance" of the SGD baseline versus SAGA, on a log scale.
variance_curves = [SGD_avg_var, SAGA_pc1_avg_var]
curve_names = ['SGD', 'SAGApc(1)']
plot_val(variance_curves, curve_names)