In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from models.mnist_nets import deep_net

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"using {device}")

In [None]:
def _load_mnist_split(train):
    """Build one MNIST split rooted at ``data/``.

    Args:
        train: ``True`` selects the training split, ``False`` the test split.

    Returns:
        A ``datasets.MNIST`` instance created with ``download=True`` whose
        images are converted by ``ToTensor()``.
    """
    return datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


# Download training data from open datasets.
training_data = _load_mnist_split(train=True)

# Download test data from open datasets.
test_data = _load_mnist_split(train=False)

In [None]:
# Constants: logging cadence, measured in batches.

output_step = 100   # training loops print the loss when batch % output_step == 0
history_step = 15   # training loops record the loss when batch % history_step == 0

In [None]:
# Hyperparameters

epochs = 5          # passes over the training set for every optimizer
batch_size = 128    # samples per batch, used by both data loaders

In [None]:
# Create data loaders.
# No shuffle argument is passed, so DataLoader keeps its default sequential
# order and every optimizer compared in this notebook sees the same batches.
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size)
    for split in (training_data, test_data)
)

print(f"Number of batches: {len(train_dataloader)}")

In [None]:
# Baseline optimizers, each paired with the learning rate it is run with.
baseline_configs = (
    (torch.optim.SGD, 0.01),
    (torch.optim.Adam, 0.001),
    (torch.optim.RMSprop, 0.01),
)

# Parallel lists indexed by optimizer; the rest of the notebook reads these.
optimizers = [opt_cls for opt_cls, _ in baseline_configs]
lrs = [lr for _, lr in baseline_configs]

num_opts = len(optimizers)

In [None]:
from copy import deepcopy

# Build one network, then deep-copy it once per baseline optimizer and once
# more for the learned optimizer, so every run starts from identical weights.
# NOTE(review): the arguments are presumably (input features, output classes)
# for flattened 28x28 MNIST digits — confirm against deep_net's signature.
model = deep_net(28*28, 10).to(device)
models = list(map(deepcopy, [model] * num_opts))
model_for_l2o = deepcopy(model)

In [None]:
# Cross-entropy criterion passed to every training run in this notebook.
loss_fn = nn.CrossEntropyLoss()

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over ``dataloader``, updating ``model`` with ``optimizer``.

    Reads the notebook-level globals ``device``, ``output_step`` and
    ``history_step``.

    Args:
        dataloader: iterable of ``(X, y)`` batches exposing ``.dataset``.
        model: module to train; switched to training mode.
        loss_fn: callable ``(prediction, target) -> loss tensor``.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.

    Returns:
        list[float]: the loss of every batch whose index is a multiple of
        ``history_step``.
    """
    num_samples = len(dataloader.dataset)
    model.train()
    recorded_losses = []

    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass and loss for this batch.
        batch_loss = loss_fn(model(inputs), targets)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()

        # Periodic progress line.
        if not step % output_step:
            seen = step * len(inputs)
            print(f"loss: {loss_value:>7f}  [{seen:>5d}/{num_samples:>5d}]")

        # Periodic sample kept for the comparison plot.
        if not step % history_step:
            recorded_losses.append(loss_value)

    return recorded_losses

In [None]:
# Train one independent model copy per baseline optimizer.
# his[i] collects one list of sampled losses per epoch for optimizer i.
his = [[] for _ in range(num_opts)]
for i in range(num_opts):
    print(f"Using {optimizers[i].__name__}")

    # Build the optimizer once per model. Re-creating it every epoch would
    # discard the internal state that stateful optimizers such as Adam and
    # RMSprop accumulate across steps, restarting them from scratch each epoch.
    optimizer = optimizers[i](models[i].parameters(), lr=lrs[i])

    for t in range(epochs):
        print(f"Epoch {t+1}: ")

        # `loss` is the list of losses sampled during this epoch.
        loss = train(
            train_dataloader,
            models[i],
            loss_fn,
            optimizer)

        his[i].append(loss)

    print("Done!\n")

In [None]:
# Collapse each optimizer's per-epoch loss lists into a single 1-D tensor.
for i in range(num_opts):
    per_epoch_losses = his[i]
    his[i] = torch.tensor(per_epoch_losses).reshape(-1)

## Using L2O Optimizer

In [None]:
from models.optim_nets import lstm_l2o_optimizer   
from trainUtil import init_hidden, zero_gradients

In [None]:
# Restore the pre-trained learned optimizer.
l2o_optimizer = lstm_l2o_optimizer().to(device)
# map_location remaps the stored tensors onto the active device, so a
# checkpoint that was saved on a GPU still loads on a CPU-only machine.
l2o_optimizer.load_state_dict(
    torch.load("trained_model/l2o_optimizer.pth", map_location=device)
)
# Switch to evaluation mode; as the cell's last expression the returned
# module is also displayed.
l2o_optimizer.eval()

In [None]:
def update_weights(model, update_fn, hidden):
    """Apply one learned-optimizer step to ``model`` in place.

    For every parameter held directly by each direct child module of
    ``model``, the gradient is reshaped to ``(1, N, 1)`` and passed to
    ``update_fn`` together with that parameter's stored state. The returned
    update is reshaped to the parameter's shape and subtracted from the
    parameter, and the returned state replaces the stored one.

    Parameter slots that are ``None`` (for example the ``bias`` of
    ``nn.Linear(..., bias=False)``) and parameters whose ``.grad`` is ``None``
    are skipped instead of raising ``AttributeError``.

    Args:
        model: module whose entries in ``_modules`` are updated. Parameters of
            more deeply nested submodules are not visited.
        update_fn: callable ``(grad, state) -> (update, new_state)``.
        hidden: mapping ``module name -> {parameter name -> state}``; it must
            contain an entry for every child module and is mutated in place.
    """
    with torch.no_grad():
        for module_name, module in model._modules.items():
            module_hidden = hidden[module_name]
            for param_name, param in module._parameters.items():
                # Nothing to update for empty slots or gradient-less params.
                if param is None or param.grad is None:
                    continue

                grad_in = param.grad.reshape(1, -1, 1)

                update, module_hidden[param_name] = update_fn(
                    grad_in, module_hidden[param_name]
                )

                # In-place subtraction keeps the same Parameter object.
                param -= update.reshape(param.shape)

In [None]:
def train_with_l2o(dataloader, model, loss_fn, l2o_optimizer):
    """Run one pass over ``dataloader``, updating ``model`` with a learned optimizer.

    Fresh hidden state is created with ``init_hidden(model)`` at the start of
    the pass. Each batch runs a forward pass, clears gradients with
    ``zero_gradients``, backpropagates, and applies the step through
    ``update_weights``. Reads the notebook-level globals ``device``,
    ``output_step`` and ``history_step``.

    Args:
        dataloader: iterable of ``(X, y)`` batches exposing ``.dataset``.
        model: module to train; switched to training mode.
        loss_fn: callable ``(prediction, target) -> loss tensor``.
        l2o_optimizer: learned optimizer passed to ``update_weights`` as the
            update function.

    Returns:
        list[float]: the loss of every batch whose index is a multiple of
        ``history_step``.
    """
    h = init_hidden(model)
    size = len(dataloader.dataset)
    history = []
    model.train()
    # Iterate the loader that was passed in, not the notebook-level
    # `train_dataloader`, so the function honours its own argument.
    for batch, (X, y) in enumerate(dataloader):
        # Preprocessing
        X, y = X.to(device), y.to(device)

        # Forward Pass
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backprop
        zero_gradients(model)
        loss.backward()
        update_weights(model, l2o_optimizer, h)

        loss = loss.item()

        if batch % output_step == 0:
            current = batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

        if batch % history_step == 0:
            history.append(loss)

    return history

In [None]:
# Train the dedicated model copy with the learned optimizer, keeping one list
# of sampled losses per epoch.
l2o_hist = []
for epoch_idx in range(epochs):
    print(f"Epoch: {epoch_idx + 1}")
    l2o_hist.append(
        train_with_l2o(train_dataloader, model_for_l2o, loss_fn, l2o_optimizer)
    )


### Plotting Results

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Flatten the per-epoch L2O losses into one 1-D tensor and store it as the
# entry after the baseline histories. The cell is safe to re-run: an
# already-converted tensor is not wrapped again, and any L2O entry appended by
# a previous run is dropped before appending.
if not isinstance(l2o_hist, torch.Tensor):
    l2o_hist = torch.tensor(l2o_hist)
l2o_hist = l2o_hist.reshape(-1)
del his[num_opts:]
his.append(l2o_hist)

In [None]:
# Legend labels: the class name of each baseline optimizer, followed by the
# label for the learned optimizer.
names = [
    *(opt_cls.__name__ for opt_cls in optimizers[:num_opts]),
    "L2O-LSTM",
]

In [None]:
# Draw one loss curve per optimizer (baselines plus the learned optimizer) on
# a single set of axes.
fig, ax = plt.subplots()
ax.set_title('Comparing Training using Different Optimizers')
ax.set_xlabel('Training Steps')
ax.set_ylabel('Training Loss')

for series_idx in range(num_opts + 1):
    ax.plot(his[series_idx], label=names[series_idx])

ax.legend(loc="upper left")
plt.show()