In [None]:
import copy
import matplotlib.pyplot as plt
import torch
import torchvision
import helpers

In [None]:
# Loads the MNIST training split from ./data (download=False: the files must already
# be present there; set download=True on a first run).
# `transform` must be a transform *instance*; it is only applied when indexing
# `mnist[i]` — the tensors below are read directly from `mnist.data`.
mnist = torchvision.datasets.MNIST(root='./data', train=True, download=False,
                                   transform=torchvision.transforms.ToTensor())
# One row per image, rescaled by 1/255 (assumes raw pixel values in 0..255 — TODO confirm dtype).
X = mnist.data.flatten(1) / 255.0
y = mnist.targets
N = X.shape[0]   # number of training examples
d_in = 28*28     # features per flattened image
d_out = 10       # number of output logits learned by the model
assert X.shape[1] == d_in, f"expected {d_in} features per image, got {X.shape[1]}"

In [None]:
# defines the model
torch.manual_seed(0)
class Softmax_regression(torch.nn.Module):
    """Softmax (multinomial logistic) regression that records data for gradient inspection.

    A single affine layer maps each input row to ``out_features`` logits. Before
    ``log_softmax``, a constant logit of 1 is appended as an extra last column, so
    ``forward`` returns ``out_features + 1`` log-probabilities per row.
    NOTE(review): with d_out = 10 and MNIST's 10 classes this adds an 11th class
    (never a target) with a fixed logit of 1 — confirm this is intended.

    Each ``forward`` call clears and then refills two lists:
      * ``inputs``: the input passed to the linear layer;
      * ``outputs_grad``: filled by a tensor hook during ``backward`` with the
        gradient w.r.t. the linear layer's output (only when that output requires grad).

    Args:
        in_features: input dimension; defaults to the notebook-level ``d_in``.
        out_features: number of learned logits; defaults to the notebook-level ``d_out``.
    """

    def __init__(self, in_features=None, out_features=None):
        super().__init__()
        # Fall back to the notebook globals so `Softmax_regression()` keeps working.
        if in_features is None:
            in_features = d_in
        if out_features is None:
            out_features = d_out
        self.lin = torch.nn.Linear(in_features, out_features, bias=True)

        # Plain list (not a ModuleList); `lin` is already registered as a submodule
        # via the attribute assignment above.
        self.layers = [self.lin]

        # Most recent input/output of the linear layer (set in forward).
        self.input_1 = None
        self.output_1 = None
        self.input_2 = None  # not assigned anywhere in this class

        self.inputs = list()
        self.outputs_grad = list()

    def hook(self, grad):
        """Tensor hook: store the gradient flowing into the linear layer's output."""
        self.outputs_grad += [grad]

    def forward(self, X):
        """Return log-probabilities of shape (batch, out_features + 1) for input X."""
        # Reset the per-pass records so they only reflect this forward/backward pair.
        self.outputs_grad.clear()
        self.inputs.clear()

        self.input_1 = X
        self.inputs.append(self.input_1)

        self.output_1 = self.lin(self.input_1)
        # Hooks can only be registered on tensors that participate in autograd
        # (e.g. not under torch.no_grad()).
        if self.output_1.requires_grad:
            self.output_1.register_hook(self.hook)

        # Append a constant logit of 1 as an extra last class, then normalize.
        return torch.nn.functional.log_softmax(
            torch.cat([self.output_1, torch.ones_like(self.output_1[:,0:1])], dim=1), dim=1)
    
# defines the loss and initializes the model
# NLLLoss consumes the log-probabilities returned by Softmax_regression.forward;
# reduction='sum' gives the total (not mean) negative log-likelihood of a batch.
loss_func = torch.nn.NLLLoss(reduction='sum')
# The same function serves as the negative-log-probability for sampling; it is
# exposed under both names, `nlp_func` and `nlps_func`, so either spelling resolves.
nlp_func = loss_func
nlps_func = nlp_func
model = Softmax_regression()

# copies the starting point; deepcopy is required because state_dict() returns
# references to the live parameter tensors, which in-place updates would modify
start = copy.deepcopy(model.state_dict())

In [None]:
# optimization settings shared by every run below
step_size = 1e-05
batch_size = 128


def _run_optimizer(accelerated, var_reduce):
    """Call helpers.optimize with the shared data, model, `start` state and settings above."""
    return helpers.optimize(X, y, model, loss_func, start, step_size, batch_size,
                            accelerated=accelerated, var_reduce=var_reduce)


# non-accelerated variants (the SAG run is currently disabled)
losses_SGD = _run_optimizer(accelerated=False, var_reduce=None)
#losses_SAG = _run_optimizer(accelerated=False, var_reduce='SAG')
losses_SAGA = _run_optimizer(accelerated=False, var_reduce='SAGA')

# accelerated variants (the SAG run is currently disabled)
losses_acc_SGD = _run_optimizer(accelerated=True, var_reduce=None)
#losses_acc_SAG = _run_optimizer(accelerated=True, var_reduce='SAG')
losses_acc_SAGA = _run_optimizer(accelerated=True, var_reduce='SAGA')

In [None]:
# plots the loss curves of SGD and SAGA (the SAG runs are currently disabled):
# non-accelerated on the left, accelerated on the right, sharing the y-axis
skip = 50  # leading records dropped, presumably so the initial transient doesn't dominate the y-scale
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(6.6*2, 4))

axes[0].plot(losses_SGD[skip:], label='SGD')
#axes[0].plot(losses_SAG[skip:], label='SAG')
axes[0].plot(losses_SAGA[skip:], label='SAGA')
axes[0].set(title='non-accelerated', xlabel=f'record index (after first {skip})', ylabel='loss')
axes[0].legend(loc="upper right")

axes[1].plot(losses_acc_SGD[skip:], label='acc_SGD')
#axes[1].plot(losses_acc_SAG[skip:], label='acc_SAG')
axes[1].plot(losses_acc_SAGA[skip:], label='acc_SAGA')
axes[1].set(title='accelerated', xlabel=f'record index (after first {skip})')
axes[1].legend(loc="upper right")

fig.tight_layout()

In [None]:
# sampling settings shared by every run below (same values as the optimization runs)
step_size = 1e-05
batch_size = 128


def _run_sampler(accelerated, var_reduce):
    """Call helpers.sample with the shared data, model, `start` state and settings above.

    Uses `nlps_func` as the negative-log-probability function; the previous
    reference to the undefined name `nlp_func` raised NameError on a fresh kernel.
    """
    return helpers.sample(X, y, model, nlps_func, start, step_size, batch_size,
                          accelerated=accelerated, var_reduce=var_reduce)


# non-accelerated samplers (the SAG run is currently disabled)
nlps_SGD = _run_sampler(accelerated=False, var_reduce=None)
#nlps_SAG = _run_sampler(accelerated=False, var_reduce='SAG')
nlps_SAGA = _run_sampler(accelerated=False, var_reduce='SAGA')

# accelerated samplers (the SAG run is currently disabled)
nlps_acc_SGD = _run_sampler(accelerated=True, var_reduce=None)
#nlps_acc_SAG = _run_sampler(accelerated=True, var_reduce='SAG')
nlps_acc_SAGA = _run_sampler(accelerated=True, var_reduce='SAGA')

In [None]:
# plots the negative-log-probability traces of the SGD and SAGA samplers (the SAG runs
# are currently disabled): non-accelerated on the left, accelerated on the right
skip = 50  # leading records dropped, presumably so the initial transient doesn't dominate the y-scale
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(6.6*2, 4))

axes[0].plot(nlps_SGD[skip:], label='SGD')
#axes[0].plot(nlps_SAG[skip:], label='SAG')
axes[0].plot(nlps_SAGA[skip:], label='SAGA')
axes[0].set(title='non-accelerated', xlabel=f'record index (after first {skip})',
            ylabel='negative log-probability')
axes[0].legend(loc="upper right")

axes[1].plot(nlps_acc_SGD[skip:], label='acc_SGD')
#axes[1].plot(nlps_acc_SAG[skip:], label='acc_SAG')
axes[1].plot(nlps_acc_SAGA[skip:], label='acc_SAGA')
axes[1].set(title='accelerated', xlabel=f'record index (after first {skip})')
axes[1].legend(loc="upper right")

fig.tight_layout()