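# Purpose
This notebook meta-trains an LSTM-based learned optimizer (L2O) on MNIST. The LSTM optimizer trains a series of freshly initialised classification networks (the optimizees), and its own weights are updated from the losses those networks incur. The trained optimizer weights are saved at the end.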

In [1]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision.transforms import ToTensor

from models.l2o_optimizer import lstm_l2o_optimizer

from models.class_network import class_net

In [2]:
# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
# Constants
# MNIST images are 28x28 and are flattened before being fed to the classifier;
# there are 10 digit classes.
in_size = 28 * 28
out_size = 10

# Seed PyTorch's RNGs so network initialisation and any random sampling are
# reproducible across runs.
seed = 0
torch.manual_seed(seed)

In [4]:
## Meta-Hyperparameters
# num_optimizee: how many freshly initialised classifiers the L2O optimizer is trained on.
# unroll_len: optimizer steps per truncated unroll before each meta-update.
num_optimizee, unroll_len = 10, 10

In [5]:
## Hyperparameters
# Optimizee training: mini-batch size and passes over the training set per optimizee.
batch_size, num_epochs = 128, 5

In [6]:
## Optimizer Model
# The learned optimizer being meta-trained (placed on `device`) and the Adam
# optimizer that updates its weights.
update_fn = lstm_l2o_optimizer()
update_fn = update_fn.to(device)
meta_optimizer = torch.optim.Adam(params=update_fn.parameters(), lr=1e-2)

In [7]:
def _load_mnist(train):
    """Download (if not already cached under ``data/``) and return one MNIST split.

    Args:
        train: True for the training split, False for the test split.
    """
    return datasets.MNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


# Download training and test data from open datasets.
training_data = _load_mnist(train=True)
test_data = _load_mnist(train=False)

In [8]:
# Create data loaders.
# Shuffle the training set so each epoch (and each optimizee) sees a different
# batch order. inner_pass skips any trailing batches that cannot fill a whole
# unroll, so without shuffling the same samples would never be trained on.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Batches per epoch, including a possibly smaller final batch. This equals
# ceil(len(training_data) / batch_size) because drop_last is False.
num_batches = len(train_dataloader)
print(f"Number of batches: {num_batches}")

Shape of X [N, C, H, W]: torch.Size([128, 1, 28, 28])
Shape of y: torch.Size([128]) torch.int64
Number of batches: 469


In [9]:
# NOTE(review): this module-level sequence is not used by the main loop;
# outer_pass() builds its own list with init_sequence(class_net).
models_t = [None] * unroll_len
models_t[0] = class_net(in_size, out_size).to(device)

In [10]:
# Cross-entropy classification loss on the optimizees' predictions (default mean reduction).
loss_fn = torch.nn.CrossEntropyLoss()

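### Useful Functions
Helper functions used by the main loop:
- hidden-state initialisation,
- truncation of the computational graph between unrolls,
- the learned-optimizer update step,
- the inner (per-epoch) and outer (per-optimizee) training passes.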

In [11]:
def init_hidden(model):
    """Build an empty hidden-state table that mirrors ``model``'s parameters.

    Returns a nested dict ``{child_module_name: {param_name: None}}``. It covers
    the parameters registered directly on each immediate child module of
    ``model``. Each ``None`` is later replaced by the learned optimizer's state
    for that parameter.
    """
    return {
        child_name: {param_name: None for param_name in child._parameters}
        for child_name, child in model._modules.items()
    }

In [12]:
def init_sequence(model_arch):
    """Allocate the unroll buffer for a new optimizee.

    Returns a list of ``unroll_len`` slots. The first slot holds a freshly
    constructed ``model_arch(in_size, out_size)`` moved to ``device``; the rest
    are ``None`` until they are filled during the unroll.
    """
    sequence = [None] * unroll_len
    sequence[0] = model_arch(in_size, out_size).to(device)
    return sequence

In [13]:
def reset_computational_graph(models_t, h_dict):
    """Truncate the unrolled graph so the next unroll starts from the last model.

    Builds a fresh ``class_net`` whose parameters are detached copies of the
    parameters of ``models_t[-1]`` and stores it in ``models_t[0]``. Every
    populated hidden state in ``h_dict`` is detached as well, so the next
    meta-backward pass does not reach back into earlier unrolls.

    Args:
        models_t: List of optimizee models for the current unroll. Modified in place.
        h_dict: Nested ``{module_name: {param_name: (h, c) or None}}`` dict as
            built by ``init_hidden``. Modified in place.
    """
    model_end = models_t[-1]
    # Move to `device` like every other optimizee, so any state that is not
    # replaced below lives on the same device as the copied parameters.
    model_new_start = class_net(in_size, out_size).to(device)

    for m_key, m_end in model_end._modules.items():
        m_new = model_new_start._modules[m_key]
        for p_key, param in m_end._parameters.items():
            # Detached leaf that still requires grad, so the next inner backward
            # pass fills its .grad for the learned optimizer to read.
            m_new._parameters[p_key] = param.detach().requires_grad_()

    models_t[0] = model_new_start

    # Reset the computational graph of the hidden state.
    for h_mod in h_dict.values():
        for p_key, state in h_mod.items():
            if state is None:
                # Never written by the optimizer yet, so there is nothing to detach.
                continue
            # Every h has two values: short-term and long-term memory.
            h_mod[p_key] = (
                state[0].detach().requires_grad_(),
                state[1].detach().requires_grad_(),
            )

In [14]:
def zero_gradients(model):
    """Zero, in place, any existing ``.grad`` on the parameters of ``model``'s
    immediate child modules. Parameters whose ``.grad`` is still ``None`` are
    left untouched.
    """
    for child in model._modules.values():
        for param in child._parameters.values():
            grad = param.grad
            if grad is None:
                continue
            grad.zero_()

In [15]:
def update_optimizee_and_copy(old_model, new_model, hidden):
    """Apply one learned-optimizer step from ``old_model`` into ``new_model``.

    For every parameter on ``old_model``'s immediate child modules:
    - its gradient and hidden state are passed to ``update_fn``;
    - the returned update is subtracted from the detached old parameter;
    - the result is installed as the matching parameter of ``new_model``.

    ``hidden`` is updated in place with the new state.

    Assumes each of ``old_model``'s parameters already has a ``.grad``, i.e. a
    backward pass has been run through it.
    """
    for child_name, old_child in old_model._modules.items():
        new_child = new_model._modules[child_name]
        child_state = hidden[child_name]
        for param_name, old_param in old_child._parameters.items():
            # Shape for batch input: (1, Num, 1)
            # Shape for hidden state: (1, Num, 24)
            grad_in = old_param.grad.reshape(1, -1, 1)
            step, child_state[param_name] = update_fn(grad_in, child_state[param_name])

            # Detach the old weights so the meta-gradient flows through `step`
            # (and hence into update_fn), not back through earlier optimizee weights.
            new_param = old_param.detach() - step.reshape(old_param.shape)
            new_param.requires_grad_()
            # The result depends on update_fn's output, so it is presumably a
            # non-leaf; retain_grad() lets the next inner backward populate .grad.
            new_param.retain_grad()
            new_child._parameters[param_name] = new_param

In [16]:
def backprop_on_optimizer(total_loss):
    """Outer-loop (meta) update.

    Back-propagates the accumulated unroll loss into the learned optimizer's
    weights and takes one ``meta_optimizer`` step.
    """
    opt = meta_optimizer
    # Clear stale gradients first; inner backward passes may also reach update_fn.
    opt.zero_grad()
    total_loss.backward()
    opt.step()

In [17]:
def inner_pass(h, models_t):
    """Run one epoch of truncated-unroll meta-training over ``train_dataloader``.

    Each batch is evaluated by the current optimizee ``models_t[step]``. Inside an
    unroll of ``unroll_len`` batches, the batch loss is back-propagated to obtain
    the optimizee's gradients, and the learned optimizer then produces the next
    optimizee ``models_t[step + 1]``. At the end of every unroll, the summed loss
    is back-propagated into the optimizer and the graph is truncated.

    Only complete unrolls are run. Trailing batches that cannot fill one are
    skipped, so ``models_t[0]`` always holds the latest model when this returns.

    Args:
        h: Hidden-state dict from ``init_hidden``. Updated in place.
        models_t: List of ``unroll_len`` optimizee slots whose first entry is the
            starting model. Updated in place so the next epoch resumes from it.
    """
    total_loss = 0
    for i, (X, y) in enumerate(train_dataloader):
        step = i % unroll_len  # position inside the current unroll

        # Decide only at an unroll boundary whether another full unroll fits.
        # Stopping mid-unroll would discard progress and reset from a stale
        # models_t[-1]. Stopping here also avoids touching `h` before it has
        # ever been populated (when num_batches < unroll_len).
        if step == 0 and (num_batches - i) < unroll_len:
            break

        # Preprocessing: flatten each image. Use the actual batch size so a
        # short final batch is not mis-reshaped.
        X = X.reshape(X.shape[0], -1)
        X, y = X.to(device), y.to(device)

        # Forward Pass
        pred = models_t[step](X)
        loss = loss_fn(pred, y)
        total_loss = total_loss + loss

        if (i % 125) == 0:
            print(f"Batch {i} / {num_batches}, Model loss: {loss.item()}")

        if step == unroll_len - 1:
            # End of unroll: meta-update the optimizer, then truncate the graph.
            backprop_on_optimizer(total_loss)
            reset_computational_graph(models_t, h)
            total_loss = 0
        else:
            # Inner backprop for the optimizee's gradients. Keep the graph,
            # because this loss is also part of total_loss.
            zero_gradients(models_t[step])
            loss.backward(retain_graph=True)
            # Fresh container that receives the updated parameters.
            models_t[step + 1] = class_net(in_size, out_size).to(device)
            update_optimizee_and_copy(
                old_model=models_t[step], new_model=models_t[step + 1], hidden=h)

In [18]:
def outer_pass():
    """Train one freshly initialised optimizee for ``num_epochs`` epochs.

    The learned optimizer drives the training and is meta-updated along the way.
    """
    sequence = init_sequence(class_net)
    hidden = init_hidden(sequence[0])
    for epoch_idx in range(num_epochs):
        print(f"Epoch {epoch_idx}")
        inner_pass(hidden, sequence)

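### Main Loop
Run the cell below to train the L2O optimizer. Each of the `num_optimizee` optimizees is trained for `num_epochs` epochs, and the optimizer is meta-updated after every unroll of `unroll_len` steps.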

In [19]:
# Meta-train the L2O optimizer on `num_optimizee` freshly initialised
# classifiers, one after another. `update_fn` and `meta_optimizer` carry over
# between optimizees, so training accumulates across them.
for count in range(num_optimizee):
    print(f"{count}-th optimizee")
    outer_pass()

0-th optimizee
Epoch 0
Batch 0 / 469, Model loss: 2.308332920074463
Batch 125 / 469, Model loss: 0.48283201456069946
Batch 250 / 469, Model loss: 0.4312162399291992
Batch 375 / 469, Model loss: 0.2685301899909973
Epoch 1
Batch 0 / 469, Model loss: 0.29593631625175476
Batch 125 / 469, Model loss: 0.37319380044937134
Batch 250 / 469, Model loss: 0.36491602659225464
Batch 375 / 469, Model loss: 0.22021491825580597
Epoch 2
Batch 0 / 469, Model loss: 0.2292969673871994
Batch 125 / 469, Model loss: 0.3217734098434448
Batch 250 / 469, Model loss: 0.31891757249832153
Batch 375 / 469, Model loss: 0.200679749250412
Epoch 3
Batch 0 / 469, Model loss: 0.1947544515132904
Batch 125 / 469, Model loss: 0.29188230633735657
Batch 250 / 469, Model loss: 0.27890434861183167
Batch 375 / 469, Model loss: 0.18150165677070618
Epoch 4
Batch 0 / 469, Model loss: 0.16952216625213623
Batch 125 / 469, Model loss: 0.2691188454627991
Batch 250 / 469, Model loss: 0.25245538353919983
Batch 375 / 469, Model loss: 0.164

In [24]:
# Save the meta-trained optimizer weights.
# Build the path portably instead of using a backslash literal. A backslash
# path is Windows-only, contains an invalid escape sequence, and on POSIX the
# backslash would become part of the file name. Create the directory first.
l2o_save_path = Path("trained_model") / "l2o_optimizer.pth"
l2o_save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(update_fn.state_dict(), l2o_save_path)