# Purpose
This code trains a classification network using a variable learning rate determined by L2O

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision.transforms import ToTensor

from models.class_network import class_net

In [2]:
## Hyperparameters
batch_size = 128
iters = 50

In [3]:
## Learnable Parameters
lr = torch.tensor([0.001], requires_grad=True)

In [4]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [5]:
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

print(f"Number of batches: {int(len(training_data) / batch_size)}")

Shape of X [N, C, H, W]: torch.Size([128, 1, 28, 28])
Shape of y: torch.Size([128]) torch.int64
Number of batches: 468


In [6]:
# Defines Model Architecture used
model_arch = class_net
def new_model(inputNum, outputNum):
    return model_arch(inputNum, outputNum)

In [7]:
models_t = [None for _ in range(iters)]
models_t[0] = new_model(28*28, 10)

In [8]:
loss_fn = nn.CrossEntropyLoss()

In [9]:
total_loss = 0

for i, (X, y) in enumerate(train_dataloader):
    # Preprocessing
    X = X.reshape(batch_size, -1)
    
    # Forward Pass
    pred = models_t[i](X)
    loss = loss_fn(pred, y)
    total_loss = total_loss + loss

    if i % 10 == 0:
        print(f"Batch {i:2}, MSE Loss: {loss:.5f}")
    if i == iters - 1:
        break

    # Backprop
    loss.backward(retain_graph=True)

    # Initialize a new model with previous weights
    models_t[i+1] = new_model(28*28, 10)

    for m_key in models_t[i]._modules:
        m1, m2 = models_t[i]._modules[m_key], models_t[i+1]._modules[m_key]
        for p_key in m1._parameters:
            m2._parameters[p_key] = m1._parameters[p_key].detach() - m1._parameters[p_key].grad * lr
            m2._parameters[p_key].requires_grad_()
            m2._parameters[p_key].retain_grad()


print(f"Total loss in {iters} iterations: {total_loss:.3f}")

Batch  0, MSE Loss: 2.33190
Batch 10, MSE Loss: 2.30209
Batch 20, MSE Loss: 2.30005
Batch 30, MSE Loss: 2.27335
Batch 40, MSE Loss: 2.29328
Total loss in 50 iterations: 115.081


In [10]:
# Resets model sequence after training iterations
# Assigns the starting model to be the prev ending model

model_end = models_t[iters - 1]
model_new_start = new_model(28*28, 10)

for m_key in model_end._modules:
    m1, m2 = model_end._modules[m_key], model_new_start._modules[m_key]
    for p_key in m1._parameters:
        m2._parameters[p_key] = m1._parameters[p_key].detach()
        m2.requires_grad_()

models_t[0] = model_new_start

In [11]:
# Outer Loop Backprop
total_loss.backward()

In [12]:
print(lr.grad)

tensor([-40.5704])
