# Purpose
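This notebook meta-trains an LSTM learned optimizer (L2O): a classification network is trained using the updates proposed by the LSTM optimizer, and the loss summed over the unrolled training steps is backpropagated to update the optimizer itself.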

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision.transforms import ToTensor

from models.l2o_optimizer import lstm_l2o_optimizer

from models.class_network import class_net

In [None]:
# Pick the compute device used for training: CUDA when PyTorch reports it as
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Optional override: uncomment the lines below to force CPU execution.
# device = "cpu"
# print(f"Using {device} device")

In [None]:
## Hyperparameters
# Number of samples per mini-batch handed to the DataLoaders.
batch_size: int = 128
# Number of unrolled optimization steps (and models kept) per pass of the main loop.
iters: int = 50

In [None]:
## Optimizer Model
# `update_fn` is the learned optimizer: in the main loop it is called with a
# parameter's gradient and the previous hidden state and returns an update
# plus the new hidden state.
update_fn = lstm_l2o_optimizer()
update_fn = update_fn.to(device)

# The learned optimizer's own parameters are trained with plain SGD.
meta_optimizer = torch.optim.SGD(update_fn.parameters(), lr=0.01)

In [None]:
# Fetch FashionMNIST from the open datasets (downloaded into ./data when not
# already present). Both splits share every option except `train`.
_dataset_kwargs = dict(root="data", download=True, transform=ToTensor())

# Training split.
training_data = datasets.FashionMNIST(train=True, **_dataset_kwargs)

# Test split.
test_data = datasets.FashionMNIST(train=False, **_dataset_kwargs)

In [None]:
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at the first test batch to report the tensor shapes and label dtype.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Ask the loader itself for the batch count: it includes the final, smaller
# batch, which truncating len(training_data) / batch_size would leave out.
print(f"Number of batches: {len(train_dataloader)}")

In [None]:
# Defines Model Architecture used
# Every model in the notebook is built through `new_model`, so rebinding
# `model_arch` switches the architecture everywhere.
model_arch = class_net

def new_model(inputNum, outputNum):
    """Instantiate the architecture currently bound to ``model_arch``.

    Both arguments are forwarded positionally to the architecture's
    constructor. The notebook calls this with ``28*28`` and ``10``, so they
    are presumably the input feature count and the number of output classes
    (confirm against ``class_net``).
    """
    fresh_model = model_arch(inputNum, outputNum)
    return fresh_model

In [None]:
# One slot per unrolled step; slot t holds the model used at step t.
# Only the starting model exists up front, the rest are filled in by the main loop.
models_t = [None] * iters
initial_model = new_model(28*28, 10)
models_t[0] = initial_model.to(device)

In [None]:
# Classification loss applied to the model's predictions in the main loop.
loss_fn = torch.nn.CrossEntropyLoss()

### Main Loop
Re-run the cells below, in order, once per epoch to train the L2O optimizer for multiple epochs.

In [None]:
# Builds the hidden-state table for the learned optimizer: one entry per
# parameter of each direct submodule of the starting model, keyed first by
# submodule name and then by parameter name. Every entry starts as None and
# is replaced by the state returned from `update_fn` in the main loop.

h = {
    module_name: {param_name: None for param_name in module._parameters}
    for module_name, module in models_t[0]._modules.items()
}

In [None]:
# Outer Forward Pass
# Unrolls up to `iters` optimization steps. At step i, models_t[i] is
# evaluated on one batch; the learned optimizer `update_fn` then turns each
# parameter's gradient into an update that defines models_t[i+1]. The per-step
# losses are summed into `total_loss`, which stays attached to the autograd
# graph so the outer backward pass can reach the parameters of `update_fn`.
total_loss = 0

for i, (X, y) in enumerate(train_dataloader):
    # Preprocessing: flatten every sample to one row. The row count is taken
    # from the batch itself so a final, smaller batch is flattened per sample
    # instead of being folded into `batch_size` rows.
    X = X.reshape(X.shape[0], -1)
    X, y = X.to(device), y.to(device)

    # Forward Pass
    pred = models_t[i](X)
    loss = loss_fn(pred, y)
    total_loss = total_loss + loss

    if i % 10 == 0:
        # loss_fn is nn.CrossEntropyLoss, so label the value accordingly.
        print(f"Batch {i:2}, CE Loss: {loss:.5f}")
    if i == iters - 1:
        # The last model only contributes its loss; no successor is built.
        break

    # Backprop: fills `.grad` on the current model's parameters. The graph is
    # retained because `total_loss` is backpropagated through it later.
    loss.backward(retain_graph=True)

    # Initialize a new model with previous weights
    models_t[i+1] = new_model(28*28, 10).to(device)

    for m_key, m1 in models_t[i]._modules.items():
        m2 = models_t[i+1]._modules[m_key]
        h_module = h[m_key]
        for p_key, param in m1._parameters.items():
            if param is None:
                # Unset parameter slot (e.g. a layer built without bias):
                # nothing to update.
                continue

            # Shape for Batch input: (1, Num, 1)
            # Shape for Hidden State: (1, Num, 24) per the original note;
            # confirm against lstm_l2o_optimizer.
            grad_in = param.grad.reshape(1, -1, 1)

            update, h_module[p_key] = update_fn(grad_in, h_module[p_key])
            update = update.reshape(param.shape)

            # The previous weights are detached, so the only differentiable
            # path into the new weights runs through `update`.
            new_param = param.detach() - update
            new_param.requires_grad_()
            # The new tensor is generally not a leaf; keep its `.grad` so the
            # next step can read it.
            new_param.retain_grad()
            m2._parameters[p_key] = new_param


print(f"Total loss in {iters} iterations: {total_loss:.3f}")

In [None]:
# Resets model sequence after training iterations
# Assigns the starting model to be the prev ending model
#
# The weights of the last model in the sequence are copied, detached from the
# autograd graph, into a fresh model that becomes the start of the next pass.

model_end = models_t[iters - 1]
# Moved to `device` like every other model in the sequence, so anything that
# is not overwritten below also lives on the same device.
model_new_start = new_model(28*28, 10).to(device)

for m_key, m1 in model_end._modules.items():
    m2 = model_new_start._modules[m_key]
    for p_key, param in m1._parameters.items():
        if param is None:
            # Unset parameter slot (e.g. a layer built without bias).
            continue
        # Detached copy that tracks gradients again as a fresh leaf tensor.
        m2._parameters[p_key] = param.detach().requires_grad_()

models_t[0] = model_new_start

In [None]:
# Outer Loop Backprop
# Clear the gradients held by the meta-optimizer's parameters first: the
# per-step loss.backward() calls in the main loop run through the same graph,
# so they may already have accumulated gradients there.
meta_optimizer.zero_grad()
# Backpropagate the loss summed over all unrolled steps.
total_loss.backward()

In [None]:
# Apply one SGD step to the learned optimizer's parameters (the parameters
# `meta_optimizer` was constructed with), using the gradients currently stored on them.
meta_optimizer.step()