In [1]:
import copy
import os

import matplotlib.pyplot as plt
import torch

import helpers
# NOTE(review): torch_vr is imported only after changing to the parent directory, so it
# presumably lives there; re-running this cell changes directory again.
os.chdir('../')
import torch_vr

In [2]:
# problem size: N samples, d_in input features, d_out outputs of the linear model
N, d_in, d_out = 1000, 10, 1

In [3]:
# generates synthetic data X = U diag(Lambda) V^T with a controlled singular spectrum
torch.manual_seed(0)  # seed before sampling so the data are identical on every re-run

U, _, Vh = torch.linalg.svd(torch.randn(N, d_in), full_matrices=False)
Lambda = 10.0 ** (torch.arange(d_in) / 4.5)  # singular values 10^(i/4.5), i = 0..d_in-1
X = (U * Lambda) @ Vh  # X.T @ X has eigenvalues Lambda.pow(2)
beta = torch.randn(d_in)
y = X @ beta  # noiseless targets, shape (N,)

# The summed squared-error loss ||X w - y||^2 has Hessian 2 X.T @ X, so its smoothness
# and strong-convexity constants follow from the extreme singular values
# (2e4 and 2 for d_in = 10, but they stay correct if d_in is changed).
L = 2 * Lambda.max().pow(2)  # Lipschitz constant of the gradients
mu = 2 * Lambda.min().pow(2)  # strong convexity constant

In [4]:
# Defines the true optimum: the least-squares solution of the normal equations
# (X^T X) w = X^T y, shaped (1, d_in) like the weight of the bias-free Linear layer.
gram = X.T @ X
true_optimum = torch.linalg.solve(gram, X.T @ y).unsqueeze(0)

# Defines the true mean and covariance of the gaussian posterior
# NOTE(review): (X^T X)^{-1} is the covariance of exp(-||X w - y||^2 / 2); if the
# negative log-posterior is the summed squared error without the 1/2, the covariance
# would be (2 X^T X)^{-1} instead — confirm against nlp_func.
true_mean = true_optimum
true_cov = torch.linalg.inv(gram)

In [5]:
# defines and initializes the model: a single bias-free linear layer d_in -> d_out
torch.manual_seed(0)  # makes the layer's initial weights reproducible

layers = [torch.nn.Linear(d_in, d_out, bias=False)]
activations = []
def squeeze(tensor):
    """Return ``tensor`` with dimension 1 removed (only if that dimension has size 1).

    Used as the model's output activation; presumably it turns a (batch, 1)
    prediction into shape (batch,) to match the targets ``y``.
    ``Tensor.squeeze`` is not in-place, so its result must be returned.
    """
    return tensor.squeeze(1)
activations.append(squeeze)
model = torch_vr.Sequential(layers, activations)

# defines the loss (summed, not averaged, squared error)
loss_func = torch.nn.MSELoss(reduction='sum')

# copies the starting point. For torch.nn.Module, state_dict() tensors share storage
# with the live parameters (assumed to hold for torch_vr.Sequential too), so deep-copy
# them; otherwise in-place parameter updates would silently move `start` as well.
start = copy.deepcopy(model.state_dict())

In [None]:
# optimization parameters, defined here so this cell does not depend on a later cell
# (same values as the sampling cell's initial step size and batch size)
step_size = 0.1*2.0/(L + mu)
batch_size = 16

# runs the non-accelerated algorithms; GD passes N as the batch size (presumably full batch)
# NOTE(review): the results are presumably per-iteration loss values, as the plot cell assumes
losses_GD = helpers.optimize(X, y, model, loss_func, start, step_size, N, accelerated=False, var_reduce=None)
losses_SGD = helpers.optimize(X, y, model, loss_func, start, step_size, batch_size, accelerated=False, var_reduce=None)
losses_SAG = helpers.optimize(X, y, model, loss_func, start, step_size, batch_size, accelerated=False, var_reduce='SAG')
losses_SAGA = helpers.optimize(X, y, model, loss_func, start, step_size, batch_size, accelerated=False, var_reduce='SAGA')

# runs the accelerated algorithms, stored under their own names so they do not
# overwrite the non-accelerated results above
losses_acc_SGD = helpers.optimize(X, y, model, loss_func, start, step_size, batch_size, accelerated=True, var_reduce=None)
losses_acc_SAG = helpers.optimize(X, y, model, loss_func, start, step_size, batch_size, accelerated=True, var_reduce='SAG')
losses_acc_SAGA = helpers.optimize(X, y, model, loss_func, start, step_size, batch_size, accelerated=True, var_reduce='SAGA')

In [None]:
# presumably the accelerated (Nesterov) momentum / rate factor:
# (sqrt(L) - sqrt(mu)) / (sqrt(L) + sqrt(mu)) == (sqrt(kappa) - 1) / (sqrt(kappa) + 1), kappa = L / mu.
# mu must appear here; dropping it is only valid when mu == 1.
(L.sqrt() - mu.sqrt())/(L.sqrt() + mu.sqrt())

In [None]:
# plots the loss trace of each optimizer: non-accelerated methods (left), accelerated (right)
# x-axis: index into each returned sequence (presumably the iteration count)
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(6.6*2, 4))

axes[0].plot(losses_GD, label='GD')
#axes[0].plot(losses_SGD, label='SGD')
#axes[0].plot(losses_SAG, label='SAG')
#axes[0].plot(losses_SAGA, label='SAGA')
axes[0].set(title='non-accelerated', xlabel='iteration', ylabel='loss')
axes[0].legend(loc="upper right")

#axes[1].plot(losses_acc_SGD, label='acc_SGD')
#axes[1].plot(losses_acc_SAG, label='acc_SAG')
#axes[1].plot(losses_acc_SAGA, label='acc_SAGA')
#axes[1].legend(loc="upper right")
axes[1].set(title='accelerated', xlabel='iteration');

In [None]:
# optimization parameters
step_size = 0.1*2.0/(L + mu)
batch_size = 16

# NOTE(review): nlp_func (presumably the negative log-posterior) is not defined in any
# visible cell; it must be defined before this cell runs.

# runs the non-accelerated algorithms (N passed as the batch size)
nlps_LD = helpers.sample(X, y, model, nlp_func, start, step_size, N, accelerated=False, var_reduce=None, iterations=1000)
#nlps_SGD = helpers.sample(X, y, model, nlp_func, start, step_size, batch_size, accelerated=False, var_reduce=None)
#nlps_SAG = helpers.sample(X, y, model, nlp_func, start, step_size, batch_size, accelerated=False, var_reduce='SAG')
#nlps_SAGA = helpers.sample(X, y, model, nlp_func, start, step_size, batch_size, accelerated=False, var_reduce='SAGA')

# runs the accelerated algorithms with their own step size, so the global `step_size`
# above is not clobbered for other cells that read it
#gam = torch.sqrt(L + mu)
#acc_step_size = mu/(4*gam)
gam = 20  # fixed value used instead of the commented-out formula above
acc_step_size = 1e-03  # fixed value used instead of the commented-out formula above
nlps_acc_LD = helpers.sample(X, y, model, nlp_func, start, acc_step_size, N, accelerated=True, var_reduce=None, gam=gam, iterations=1000)
#nlps_acc_SGD = helpers.sample(X, y, model, nlp_func, start, acc_step_size, batch_size, accelerated=True, var_reduce=None, gam=gam)
#nlps_acc_SAG = helpers.sample(X, y, model, nlp_func, start, acc_step_size, batch_size, accelerated=True, var_reduce='SAG')
#nlps_acc_SAGA = helpers.sample(X, y, model, nlp_func, start, acc_step_size, batch_size, accelerated=True, var_reduce='SAGA', gam=gam)

In [None]:
# plots the nlp trace returned by each sampler: non-accelerated (left), accelerated (right)
# nlp: presumably the negative log-posterior; x-axis presumably the iteration count
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(6.6*2, 4))

axes[0].plot(nlps_LD, label='LD')
#axes[0].plot(nlps_SGD, label='SGD')
#axes[0].plot(nlps_SAG, label='SAG')
#axes[0].plot(nlps_SAGA, label='SAGA')
axes[0].set(title='non-accelerated', xlabel='iteration', ylabel='nlp')
axes[0].legend(loc="upper right")

axes[1].plot(nlps_acc_LD, label='acc_LD')  # accelerated LD sampler, not GD
#axes[1].plot(nlps_acc_SGD, label='acc_SGD')
#axes[1].plot(nlps_acc_SAG, label='acc_SAG')
#axes[1].plot(nlps_acc_SAGA, label='acc_SAGA')
axes[1].set(title='accelerated', xlabel='iteration')
axes[1].legend(loc="upper right");