In [None]:
%matplotlib inline
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.dpi'] = 200
import matplotlib.colors as colors
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import numpy.linalg as la
from sklearn.metrics import r2_score, roc_auc_score
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import StepLR
from pathlib import Path

In [None]:
class TimeseriesDataset(Dataset):
    """Sliding-window dataset over an aligned state/input time series.

    Sample ``index`` is the window of ``num_lags`` consecutive steps that starts
    at ``index`` (for both ``X`` and ``U``), together with the state that
    immediately follows the window, which serves as the prediction target.

    Args:
        X: indexable sequence of states (e.g. ``torch.tensor(y_session)``).
        U: indexable sequence of inputs aligned with ``X``
            (e.g. ``torch.tensor(u_session)``).
        num_lags: number of past steps contained in each window.
    """

    def __init__(self, X, U, num_lags):
        self.X = X
        self.U = U
        self.num_lags = num_lags

    def __len__(self):
        # Each sample needs num_lags history steps before its target, so the
        # last num_lags positions cannot start a full window + target.
        return len(self.X) - self.num_lags

    def __getitem__(self, index):
        """Return ``(X_history, U_history, X_next)`` for the window at ``index``.

        ``X_history`` / ``U_history`` are Python lists holding
        ``X[index], ..., X[index + num_lags - 1]`` (resp. ``U``), and ``X_next``
        is ``X[index + num_lags]``. Returning lists (rather than one stacked
        tensor) means the default DataLoader collate yields one batched tensor
        per lag, which is what the models' ``forward`` iterates over.
        """
        window = range(index, index + self.num_lags)
        X_history = [self.X[step] for step in window]
        U_history = [self.U[step] for step in window]
        target = self.X[index + self.num_lags]
        return X_history, U_history, target

def singular_value_norm(matrix, max_norm=2):
    """Cap the spectral norm (largest singular value) of ``matrix`` at ``max_norm``.

    If the largest singular value exceeds ``max_norm``, the whole matrix is
    rescaled so that its largest singular value equals ``max_norm``. Otherwise
    ``matrix`` is returned unchanged; it is not scaled *up* to ``max_norm``.

    Args:
        matrix: 2-D tensor.
        max_norm: upper bound on the largest singular value. Defaults to 2,
            the value used by the models in this notebook.

    Returns:
        A tensor with the same shape as ``matrix``.
    """
    # For a 2-D input, ord=2 gives the spectral norm (largest singular value).
    norm_val = torch.linalg.norm(matrix, 2)
    if norm_val > max_norm:
        matrix = max_norm * matrix / norm_val
    return matrix


class LinearDynamicModel(nn.Module):
    """Full-rank, multi-lag linear dynamical model.

    With ``x_k = X_history[k]`` and ``u_k = U_history[k]`` for
    ``k = 0 .. num_lags-1``, the prediction is::

        x_next = sum_k N(W[k] + diag(alpha[k])) @ x_k
               + sum_k N(B[k] + diag(beta[k])) @ u_k
               + V

    ``N`` is ``singular_value_norm``, which caps the largest singular value of
    each effective matrix. ``W`` plays the role of matrix A in the paper and
    ``B`` of matrix B.

    ``diag(beta[k])`` is ``(state_dim, state_dim)`` and is added to ``B[k]``,
    which is ``(state_dim, input_dim)``. The model therefore only works when
    ``input_dim == state_dim``.

    Args:
        state_dim: dimension of each state vector.
        input_dim: dimension of each input vector (must equal ``state_dim``).
        num_lags: number of past states/inputs used per prediction.
        init_value: optional dict with keys ``'alpha'``, ``'beta'``, ``'W'``
            and ``'B'`` (each indexable by lag ``0..num_lags-1``), plus ``'V'``.
            The given tensors are wrapped as parameters. When ``None``, every
            parameter is drawn from a standard normal distribution.
    """

    def __init__(self, state_dim, input_dim, num_lags, init_value=None):
        super().__init__()

        if init_value is None:
            # One tensor per lag. Creation order (alpha, beta, W, B, V) is
            # kept fixed so that seeded runs reproduce the same init.
            self.alpha = nn.ParameterList([nn.Parameter(torch.randn(state_dim)) for _ in range(num_lags)])
            self.beta = nn.ParameterList([nn.Parameter(torch.randn(state_dim)) for _ in range(num_lags)])
            self.W = nn.ParameterList([nn.Parameter(torch.randn(state_dim, state_dim)) for _ in range(num_lags)])
            self.B = nn.ParameterList([nn.Parameter(torch.randn(state_dim, input_dim)) for _ in range(num_lags)])
            self.V = nn.Parameter(torch.randn(state_dim))
        else:
            # Build the parameters from the per-lag lists in init_value.
            self.alpha = nn.ParameterList([nn.Parameter(init_value['alpha'][i]) for i in range(num_lags)])
            self.beta = nn.ParameterList([nn.Parameter(init_value['beta'][i]) for i in range(num_lags)])
            self.W = nn.ParameterList([nn.Parameter(init_value['W'][i]) for i in range(num_lags)])
            self.B = nn.ParameterList([nn.Parameter(init_value['B'][i]) for i in range(num_lags)])
            self.V = nn.Parameter(init_value['V'])

    def forward(self, X_history, U_history):
        """Predict the next state from ``num_lags`` past states and inputs.

        Args:
            X_history: sequence of ``num_lags`` state tensors, each with a
                trailing ``state_dim`` axis. They are presumably
                ``(batch, state_dim)`` when collated by a DataLoader; this
                should be confirmed against the caller.
            U_history: sequence of ``num_lags`` input tensors, each with a
                trailing ``input_dim`` axis.

        Returns:
            Tensor shaped like ``X_history[0]``.
        """
        X_next = torch.zeros_like(X_history[0])

        # State contributions: (S, S) @ (..., S, 1) -> (..., S, 1) -> (..., S)
        for W_k, alpha_k, X_k in zip(self.W, self.alpha, X_history):
            A_k = singular_value_norm(W_k + torch.diag(alpha_k))
            X_next += torch.matmul(A_k, X_k.unsqueeze(-1)).squeeze(-1)

        # Input contributions, with the same shape flow as above.
        for B_k, beta_k, U_k in zip(self.B, self.beta, U_history):
            B_eff = singular_value_norm(B_k + torch.diag(beta_k))
            X_next += torch.matmul(B_eff, U_k.unsqueeze(-1)).squeeze(-1)

        # The bias broadcasts over any leading (batch) dimensions, so both
        # batched and single-sample inputs work.
        X_next += self.V
        return X_next
    
class LowRankLinearDynamicModel(nn.Module):
    """Multi-lag linear dynamical model with low-rank-plus-diagonal matrices.

    With ``x_k = X_history[k]`` and ``u_k = U_history[k]`` for
    ``k = 0 .. num_lags-1``, the prediction is::

        x_next = sum_k N(W_u[k] @ W_v[k].T + diag(alpha[k])) @ x_k
               + sum_k N(B_u[k] @ B_v[k].T + diag(beta[k])) @ u_k
               + V

    ``N`` is ``singular_value_norm``, which caps the largest singular value.
    ``W_u``, ``W_v``, ``B_u`` and ``B_v`` are all ``(state_dim, rank_dim)``.

    NOTE: ``input_dim`` is accepted but not used. ``B_u @ B_v.T`` is
    ``(state_dim, state_dim)``, so inputs must have ``input_dim == state_dim``.

    Args:
        state_dim: dimension of each state vector.
        input_dim: dimension of each input vector (unused; see note above).
        rank_dim: rank of the low-rank factors.
        num_lags: number of past states/inputs used per prediction.
        init_value: optional dict with keys ``'alpha'``, ``'beta'``,
            ``'W_u'``, ``'W_v'``, ``'B_u'`` and ``'B_v'`` (each indexable by lag
            ``0..num_lags-1``), plus ``'V'``. When ``None``, every parameter is
            drawn from a standard normal distribution.
    """

    def __init__(self, state_dim, input_dim, rank_dim, num_lags, init_value=None):
        super().__init__()

        if init_value is None:
            # Creation order is kept fixed so that seeded runs reproduce the
            # same init.
            self.alpha = nn.ParameterList([nn.Parameter(torch.randn(state_dim)) for _ in range(num_lags)])
            self.beta = nn.ParameterList([nn.Parameter(torch.randn(state_dim)) for _ in range(num_lags)])
            self.W_u = nn.ParameterList([nn.Parameter(torch.randn(state_dim, rank_dim)) for _ in range(num_lags)])
            self.W_v = nn.ParameterList([nn.Parameter(torch.randn(state_dim, rank_dim)) for _ in range(num_lags)])
            self.B_u = nn.ParameterList([nn.Parameter(torch.randn(state_dim, rank_dim)) for _ in range(num_lags)])
            self.B_v = nn.ParameterList([nn.Parameter(torch.randn(state_dim, rank_dim)) for _ in range(num_lags)])
            self.V = nn.Parameter(torch.randn(state_dim))
        else:
            # Build the parameters from the per-lag lists in init_value.
            self.alpha = nn.ParameterList([nn.Parameter(init_value['alpha'][i]) for i in range(num_lags)])
            self.beta = nn.ParameterList([nn.Parameter(init_value['beta'][i]) for i in range(num_lags)])
            self.W_u = nn.ParameterList([nn.Parameter(init_value['W_u'][i]) for i in range(num_lags)])
            self.W_v = nn.ParameterList([nn.Parameter(init_value['W_v'][i]) for i in range(num_lags)])
            self.B_u = nn.ParameterList([nn.Parameter(init_value['B_u'][i]) for i in range(num_lags)])
            self.B_v = nn.ParameterList([nn.Parameter(init_value['B_v'][i]) for i in range(num_lags)])
            self.V = nn.Parameter(init_value['V'])

    def forward(self, X_history, U_history):
        """Predict the next state from ``num_lags`` past states and inputs.

        Args:
            X_history: sequence of ``num_lags`` state tensors, each with a
                trailing ``state_dim`` axis. They are presumably
                ``(batch, state_dim)`` when collated by a DataLoader; this
                should be confirmed against the caller.
            U_history: sequence of ``num_lags`` input tensors, each with a
                trailing ``state_dim`` axis.

        Returns:
            Tensor shaped like ``X_history[0]``.
        """
        X_next = torch.zeros_like(X_history[0])

        for W_u_k, W_v_k, alpha_k, X_k in zip(self.W_u, self.W_v, self.alpha, X_history):
            # A_k = W_u_k @ W_v_k^T (low rank) + diag(alpha_k), then capped.
            A_k = singular_value_norm(torch.mm(W_u_k, W_v_k.T) + torch.diag(alpha_k))
            # (S, S) @ (..., S, 1) -> (..., S, 1) -> (..., S)
            X_next += torch.matmul(A_k, X_k.unsqueeze(-1)).squeeze(-1)

        for B_u_k, B_v_k, beta_k, U_k in zip(self.B_u, self.B_v, self.beta, U_history):
            # B_k = B_u_k @ B_v_k^T (low rank) + diag(beta_k), then capped.
            B_k = singular_value_norm(torch.mm(B_u_k, B_v_k.T) + torch.diag(beta_k))
            X_next += torch.matmul(B_k, U_k.unsqueeze(-1)).squeeze(-1)

        # The bias broadcasts over any leading (batch) dimensions, so both
        # batched and single-sample inputs work.
        X_next += self.V
        return X_next

def _move_batch_to_device(X_history, U_history, X_next, device):
    """Move one collated batch (per-lag lists of tensors plus the target) to ``device``."""
    return ([x.to(device) for x in X_history],
            [u.to(device) for u in U_history],
            X_next.to(device))


def _l2_penalty(model, device):
    """Sum ``torch.norm(p, p=2)`` over every parameter of ``model`` except the bias ``V``.

    ``torch.norm`` with ``p=2`` and no ``dim`` is the 2-norm over all entries
    (the Frobenius norm for matrices). It is *not* squared, so this differs
    from the squared-norm penalty implemented by an optimizer's
    ``weight_decay``.
    """
    penalty = torch.zeros((), device=device)
    for name, param in model.named_parameters():
        if name == 'V':
            continue  # the bias V is not regularised
        penalty = penalty + torch.norm(param, p=2)
    return penalty


def _plot_training_curves(train_losses, val_losses, l2_penalty_losses, l2_lambda):
    """Show per-epoch train/val loss curves and the per-epoch mean L2 penalty."""
    fig, ax = plt.subplots(figsize=(12, 6), dpi=80)
    ax.plot(train_losses, label='Training Loss (MSE)')  # x defaults to epoch index
    ax.plot(val_losses, label='Validation Loss')
    ax.set_title(f"Training and Validation Loss w/ L2 Reg of lambda {l2_lambda}")
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    plt.show()

    fig, ax = plt.subplots(figsize=(12, 6), dpi=80)
    ax.plot(l2_penalty_losses, label='L2 Penalty', color='red')
    ax.set_title(f"L2 Penalty Over Epochs of lambda {l2_lambda}")
    ax.set_xlabel('Epochs')
    ax.set_ylabel('L2 Penalty')
    ax.legend()
    ax.grid(True)
    plt.show()


def train_model(model, train_loader, val_loader, epochs=100, lr=0.01, clip_value=1.0, l2_lambda=0.01,
                step_size=50, gamma=0.5, checkpoint_name='linear_35', checkpoint_dir='checkpoints', plot=True):
    """Train ``model`` with Adam + StepLR and checkpoint the best validation epoch.

    Each training batch minimises
    ``MSE(model(X_history, U_history), X_next) + l2_lambda * P``, where ``P`` is
    the sum of (unsquared) 2-norms of every parameter except the bias ``V``.

    Args:
        model: module called as ``model(X_history, U_history)``.
        train_loader, val_loader: iterables with ``len`` that yield
            ``(X_history, U_history, X_next)`` batches, where the histories are
            lists of tensors.
        epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        clip_value: max gradient norm passed to ``clip_grad_norm_``. ``None``
            disables clipping.
        l2_lambda: weight of the L2 penalty.
        step_size, gamma: the LR is multiplied by ``gamma`` every
            ``step_size`` epochs.
        checkpoint_name: the best weights are saved to
            ``<checkpoint_dir>/model_best_<checkpoint_name>.pt``.
        checkpoint_dir: directory for the checkpoint (created if missing).
        plot: if True, show loss and penalty curves at the end.

    Logged losses:
        * Train loss is the MSE term only, so it is directly comparable to
          the validation loss. The penalty is logged separately as the
          "L2 Penalty Loss".
        * ``nn.MSELoss`` uses its default ``reduction='mean'``, so each batch
          loss is averaged over all samples and all state dimensions in that
          batch. Epoch values are the unweighted mean of those batch means,
          so a smaller final batch is slightly over-weighted.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Multiplies the LR by gamma every step_size epochs; stepped once per epoch below.
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    criterion = nn.MSELoss()
    best_val_loss = float('inf')
    checkpoint_path = Path(checkpoint_dir) / f'model_best_{checkpoint_name}.pt'

    train_losses = []
    val_losses = []
    l2_penalty_losses = []

    for epoch in range(epochs):
        # LR actually used for this epoch's updates.
        current_lr = optimizer.param_groups[0]['lr']

        # Training phase
        model.train()
        total_train_loss = 0.0
        total_l2_penalty = 0.0
        for X_history, U_history, X_next in train_loader:
            X_history, U_history, X_next = _move_batch_to_device(X_history, U_history, X_next, device)
            optimizer.zero_grad()
            predictions = model(X_history, U_history)
            data_loss = criterion(predictions, X_next)
            l2_penalty = _l2_penalty(model, device)
            loss = data_loss + l2_lambda * l2_penalty
            loss.backward()
            if clip_value is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()

            total_train_loss += data_loss.item()
            total_l2_penalty += l2_penalty.item()

        # Validation phase
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for X_history, U_history, X_next in val_loader:
                X_history, U_history, X_next = _move_batch_to_device(X_history, U_history, X_next, device)
                predictions = model(X_history, U_history)
                total_val_loss += criterion(predictions, X_next).item()

        train_loss = total_train_loss / len(train_loader)
        val_loss = total_val_loss / len(val_loader)
        l2_penalty_loss = total_l2_penalty / len(train_loader)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        l2_penalty_losses.append(l2_penalty_loss)

        print(f'Epoch {epoch}: Train Loss = {train_loss}, Val Loss = {val_loss}, L2 Penalty Loss = {l2_penalty_loss}, LR = {current_lr}')

        # Keep the weights with the lowest validation loss seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch} with Val Loss: {val_loss:.4f}")

        scheduler.step()

    if plot:
        _plot_training_curves(train_losses, val_losses, l2_penalty_losses, l2_lambda)