# MetaTrain
Trains the LSTM optimizer model.

### Imports and Setup

In [None]:
import os
import sys
# Put the parent directory on the import path so that modules living one
# level above this notebook can be imported.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from models import simple_training_NN
from models import l2o_lstm_model

In [None]:
# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# MNIST training split, stored under ./data (download requested) with each
# sample passed through ToTensor().
training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())

# MNIST test split, configured the same way.
test_data = datasets.MNIST(root="data", train=False, download=True, transform=ToTensor())

In [None]:
# Hyperparameters

# Number of samples drawn per batch by the data loaders.
batch_size = 128
# Number of steps for one meta-training sample.
iterations = 50
# Number of epochs; a fresh optimizee is created and trained in each one.
epochs = 30

In [None]:
# Wrap both dataset splits in shuffling loaders that yield `batch_size`
# samples at a time.
train_dataloader = DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=True,
)

print(f"Number of batches: {len(train_dataloader)}")

In [None]:
# The learned optimizer: an l2o_lstm_model instance moved to the selected
# device. Its parameters are the ones handed to `meta_optimizer`.
optimizer = l2o_lstm_model()
optimizer = optimizer.to(device)

### Meta-Training Process

In [None]:
# Plain SGD over the learned optimizer's own parameters.
meta_optimizer = torch.optim.SGD(params=optimizer.parameters(), lr=1e-3)
# Classification loss used to score the optimizee's predictions.
loss_fn = nn.CrossEntropyLoss()

In [None]:
def meta_train(optimizer, optimizee, train_dataloader, iterations=50):
    """Unroll the optimizee for up to `iterations` batches and sum its losses.

    For each training batch the optimizee's loss is computed, printed, added
    to a running total and back-propagated. Every parameter that received a
    gradient is then replaced, inside its owning module, by a new tensor
    derived from the old value. After the unroll, one more loss is evaluated
    on the last batch drawn from the loader and added to the total.

    The update rule is still a placeholder: the gradient is multiplied by 0,
    so parameter values do not change and `optimizer` is not consulted yet.

    Relies on the module-level names `device` and `loss_fn`.

    Args:
        optimizer: The learned optimizer model. Currently unused.
        optimizee: The model being optimized; called as `optimizee(X)` and
            walked with `optimizee.modules()`.
        train_dataloader: Iterable yielding `(X, y)` batches.
        iterations: Number of batches to train on before stopping.

    Returns:
        The sum of the per-step losses plus the final evaluation loss.

    Raises:
        ValueError: If `train_dataloader` yields no batches.
    """
    loss = 0
    X = y = None  # remain None only when the loader yields nothing

    for i, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)

        # The batch fetched on the step that reaches the limit is not trained
        # on; it is kept (already on `device`) for the final evaluation below.
        if i == iterations:
            break

        # First pass
        pred = optimizee(X)
        current_loss = loss_fn(pred, y)
        print(f"Current loss: {current_loss}")

        loss += current_loss

        # backprop for optimizee gradient calculations
        current_loss.backward()

        # modules() walks the whole module tree (the root included), so
        # parameters nested inside container modules are updated as well.
        for module in optimizee.modules():
            # Snapshot the items because entries are reassigned in the loop.
            for p_key, param in list(module._parameters.items()):
                # Skip empty slots (e.g. a disabled bias) and parameters that
                # did not receive a gradient from this backward pass.
                if param is None or param.grad is None:
                    continue

                # TODO: replace the zero-scaled gradient with the learned
                # update. Intended form (LSTM input shape is
                # (seq_len, batch_size, input_dim)):
                #   update = optimizer(param.grad.reshape(1, -1, 1))
                #   new_val = param - update.reshape(param.grad.shape)
                new_val = param - param.grad * 0

                # The replacement is a non-leaf tensor, and autograd only
                # fills `.grad` on non-leaf tensors that asked for it. Without
                # this the next step would read `.grad` as None and fail.
                if new_val.requires_grad:
                    new_val.retain_grad()

                module._parameters[p_key] = new_val

    if X is None:
        raise ValueError("train_dataloader yielded no batches")

    # Final evaluation on the last batch drawn from the loader: the held-out
    # batch when the loop stopped at the limit, otherwise the last trained one.
    pred = optimizee(X)
    current_loss = loss_fn(pred, y)
    print(f"Current loss: {current_loss}")

    loss += current_loss

    print(f"Total Loss: {loss}")

    return loss
        

In [None]:
# Each epoch builds a brand-new optimizee on the selected device and runs
# one meta-training pass on it.
for i in range(epochs):
    print(f"=== Epoch: {i} ===")
    optimizee = simple_training_NN()
    optimizee = optimizee.to(device)

    # NOTE: a single step is requested here; the `iterations`
    # hyperparameter is not passed to meta_train.
    meta_train(optimizer, optimizee, train_dataloader, iterations=1)