## Purpose
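This notebook trains a simple regression model with a learnable learning rate, in the learning-to-optimize (L2O) style: several training steps are unrolled, their losses are summed, and that sum is backpropagated into the learning rate, which is then updated by gradient descent.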

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader

from datasets import LinearDataset, BostonDataset

In [None]:
# Hyper-parameters:
#   batch_size    - samples per DataLoader batch
#   unfold_length - number of optimisation steps unrolled per meta-update
#   epochs        - NOTE(review): not referenced in the visible cells
batch_size, unfold_length, epochs = 1, 10, 500

In [None]:
# Learnable parameter: the learning rate itself, a 1-element float32 leaf
# tensor that tracks gradients so it can be meta-optimised.
lr = torch.tensor([0.001]).requires_grad_()

In [None]:
# Serve LinearDataset in its stored order, `batch_size` samples per batch.
dataset = LinearDataset()
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

In [None]:
def loss_fn(y, y_pred):
    """Return the element-wise squared error between ``y`` and ``y_pred``.

    No reduction is applied; callers take ``.mean()`` themselves.
    """
    residual = y - y_pred
    return residual ** 2

In [None]:
# Peek at the first batch (cast to float32) to check the tensor shapes.
X, y = (t.to(dtype=torch.float32) for t in next(iter(train_loader)))
print(X.shape, y.shape, sep="\n")

### Shape of Parameters
The parameters over time are stored in plain Python lists (`weights`, `bias`), one entry per unrolled step, so each list has length `unfold_length`.
At the n-th step, `weights[n]` and `bias[n]` are used for the forward pass.

In [None]:
# One slot per unrolled step. Only step 0 starts as a leaf Parameter (random
# init in [0, 1)); the unrolled training loop fills in the later slots.
weights = [None] * unfold_length
weights[0] = nn.Parameter(torch.rand(1, dtype=torch.float32))

bias = [None] * unfold_length
bias[0] = nn.Parameter(torch.rand(1, dtype=torch.float32))

In [None]:
# Unroll up to `unfold_length` optimisation steps. Each step updates the
# parameters with the current learnable `lr`, so every later loss is
# differentiable w.r.t. `lr`. The sum of the per-step losses (`total_loss`)
# is the meta-objective.
total_loss = 0

# The parameter lists only have `unfold_length` slots, so never unroll past
# them, even when the loader yields more batches.
n_steps = min(len(train_loader), unfold_length)

for i, (X, y) in enumerate(train_loader):
    print(f"Batch: {i}")
    X, y = X.to(dtype=torch.float32), y.to(dtype=torch.float32)

    pred = torch.matmul(X, weights[i]) + bias[i]
    loss = loss_fn(y, pred).mean()

    print(f"MSE: {loss.item()}")

    total_loss = total_loss + loss

    if i == n_steps - 1:
        break

    # Gradient of this step's loss w.r.t. the current parameters only.
    # Unlike loss.backward(), torch.autograd.grad does not accumulate into
    # lr.grad, so the later total_loss.backward() does not double-count the
    # learning-rate gradient. retain_graph keeps this step's graph alive for
    # that later backward pass through total_loss.
    w_grad, b_grad = torch.autograd.grad(
        loss, [weights[i], bias[i]], retain_graph=True
    )

    # The previous parameters are detached, so `lr` receives gradient only
    # through the most recent update step. The new tensors already require
    # grad because they depend on `lr`.
    weights[i + 1] = weights[i].detach() - w_grad * lr
    bias[i + 1] = bias[i].detach() - b_grad * lr

print(f"Total loss: {total_loss}")

In [None]:
# Carry the parameters from the last unrolled step over as the starting point
# of the next unroll. Fresh lists are built, and step 0 becomes a new leaf
# Parameter detached from the current graph. `total_loss` still holds its own
# graph for the meta-gradient.
# (The last slot is indexed via `unfold_length` rather than a hard-coded 9 so
# that changing the unroll length keeps this cell correct.)
last_weights = weights[unfold_length - 1].detach()
last_bias = bias[unfold_length - 1].detach()

weights = [None for i in range(unfold_length)]
weights[0] = nn.Parameter(last_weights)

bias = [None for i in range(unfold_length)]
bias[0] = nn.Parameter(last_bias)

In [None]:
# Backpropagate the meta-objective (sum of the unrolled losses) into `lr`.
# Drop any gradient already accumulated into lr.grad (for example by per-step
# loss.backward() calls), so the meta-gradient is not counted twice.
lr.grad = None
total_loss.backward()

In [None]:
# Meta-update of the learning rate. Clip its gradient to [-10, 10], then take
# a plain gradient-descent step with meta step size 0.001. This runs outside
# autograd so the in-place update is not recorded.
with torch.no_grad():
    lr.grad = lr.grad.clamp(min=-10.0, max=10.0)
    lr -= lr.grad * 0.001

print(lr)

In [None]:
# Re-create `lr` as a fresh leaf that tracks gradients, cut off from the
# previous unroll's graph, ready for the next unroll.
lr = lr.detach().requires_grad_()