In [1]:
import pandas as pd
import numpy as np
import os
import sys
from datetime import datetime
import matplotlib.pyplot as plt
import sqlite3
from tqdm import tqdm
import pickle
import time 

from datetime import datetime
from datetime import timedelta

from create_financial_database import get_credentials 
from SQLite_tools import query_stock_data, check_if_close_price_exists
from ticker_loader import load_SPY_components

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch_geometric.utils import add_remaining_self_loops
from torch.nn import Linear
from torch.nn import LayerNorm

try:
    from torch_geometric_temporal.nn.recurrent import A3TGCN2
except ModuleNotFoundError:
    from torch_geometric_temporal.nn.recurrent import A3TGCN2
from sklearn.model_selection import train_test_split

In [2]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [3]:
print("Expected Output :")
print("2.2.0+cu121")
print("CUDA available: True")
print("CUDA version: 12.1 \n")
print("Actual Output :")
print(torch.__version__)  # To check the PyTorch version
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)

USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda")
print("CUDA:", USE_CUDA, DEVICE)
print("")
print("PyTorch Devices: ", torch.cuda.get_device_name(0))
print("Using GPU: ", USE_CUDA)
print("DEVICE: ", DEVICE)

Expected Output :
2.2.0+cu121
CUDA available: True
CUDA version: 12.1 

Actual Output :
2.2.0+cu121
CUDA available: True
CUDA version: 12.1
CUDA: True cuda

PyTorch Devices:  NVIDIA GeForce RTX 3060
Using GPU:  True
DEVICE:  cuda


In [4]:
# Define the directory containing the chunks
# CHUNK_DATA_DIR = "../Data/Networks_chunks/window_size_10"
# CHUNK_DATA_DIR = "../Data/Networks_chunks/window_size_10"
# CHUNK_DATA_DIR = "../Data/Networks_chunks/window_size_3"
CHUNK_DATA_DIR = "../Data/Networks_chunks/window_size_5"

# Helper function to load a chunk
def load_chunk(file_path):
    with open(file_path, "rb") as f:
        return pickle.load(f)

# List all chunk files
chunk_files = [os.path.join(CHUNK_DATA_DIR, f) for f in os.listdir(CHUNK_DATA_DIR) if f.endswith(".pkl")]

# Load an example chunk
example_chunk_path = chunk_files[150]
print(f"Loading example chunk: {example_chunk_path}")
example_chunk = load_chunk(example_chunk_path)

# Inspect the first graph in the chunk
print(f"Number of graphs in the chunk: {len(example_chunk)}")
print(f"Example Graph - index 0: {example_chunk[0]}")


Loading example chunk: ../Data/Networks_chunks/window_size_5\chunk_1996-11-07__5.pkl
Number of graphs in the chunk: 5
Example Graph - index 0: Data(x=[734, 71], edge_index=[2, 1011], edge_attr=[1011], y=[734])


In [5]:
print(example_chunk[0])
print(example_chunk[0].x.size())
print(example_chunk[0].x.size()[0])
print(example_chunk[0].x.size()[1])
print(len(example_chunk))



Data(x=[734, 71], edge_index=[2, 1011], edge_attr=[1011], y=[734])
torch.Size([734, 71])
734
71
5


In [None]:
def load_chunk(file_path):
    with open(file_path, "rb") as f:
        return pickle.load(f)


def load_chunks(file_list):
    chunks = []
    for file in file_list:
        with open(file, "rb") as f:
            chunk_graphs = pickle.load(f)
        primary_target = chunk_graphs[-1].y  
        chunks.append((chunk_graphs, primary_target))
    return chunks

        
def prepare_ordered_splits(CHUNK_DATA_DIR, train_ratio=0.7, val_ratio=0.1):
    """
    Prepares train, validation, and test datasets from chunked graph data in chronological order.

    Args:
        CHUNK_DATA_DIR (str): Directory containing graph chunk files.
        train_ratio (float): Proportion of data to use for training.
        val_ratio (float): Proportion of data to use for validation.

    Returns:
        tuple: Lists of train, validation, and test chunk files.
    """
    chunk_files = [os.path.join(CHUNK_DATA_DIR, f) for f in os.listdir(CHUNK_DATA_DIR) if f.endswith(".pkl")]
    chunk_files = sorted(chunk_files, key=lambda x: x.split("_")[1])

    num_files = len(chunk_files)
    train_end = int(num_files * train_ratio)
    val_end = train_end + int(num_files * val_ratio)

    train_files = chunk_files[:train_end]
    val_files = chunk_files[train_end:val_end]
    test_files = chunk_files[val_end:]

    return train_files, val_files, test_files


train_files, val_files, test_files = prepare_ordered_splits(CHUNK_DATA_DIR)

train_files[0:6]

['../Data/Networks_chunks/window_size_5\\chunk_1996-04-08__5.pkl',
 '../Data/Networks_chunks/window_size_5\\chunk_1996-04-09__5.pkl',
 '../Data/Networks_chunks/window_size_5\\chunk_1996-04-10__5.pkl',
 '../Data/Networks_chunks/window_size_5\\chunk_1996-04-11__5.pkl',
 '../Data/Networks_chunks/window_size_5\\chunk_1996-04-12__5.pkl',
 '../Data/Networks_chunks/window_size_5\\chunk_1996-04-15__5.pkl']

In [7]:

train_start_date = train_files[0].split("_")[4]
train_end_date   = train_files[-1].split("_")[4]
val_start_date   = val_files[0].split("_")[4]
val_end_date     = val_files[-1].split("_")[4]
test_start_date  = test_files[0].split("_")[4]
test_end_date    = test_files[-1].split("_")[4]

print("Train Start Date:", train_start_date)
print("Train End Date:", train_end_date)
print("Validation Start Date:", val_start_date)
print("Validation End Date:", val_end_date)
print("Test Start Date:", test_start_date)
print("Test End Date:", test_end_date)

Train Start Date: 1996-04-08
Train End Date: 2016-01-07
Validation Start Date: 2016-01-08
Validation End Date: 2018-10-31
Test Start Date: 2018-11-01
Test End Date: 2024-06-28


In [None]:
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data


class A3TGCN2ChunkDataset(Dataset):
    def __init__(self, file_paths):
        """
        Dataset for A3TGCN2 with chunked graph data.

        Args:
            file_paths (list): List of paths to chunk files.
        """
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """
        Returns processed data for a single chunk.

        Args:
            idx (int): Index of the chunk.

        Returns:
            tuple: (X, edge_index, edge_weight, y) for the chunk.
        """
        chunk_path = self.file_paths[idx]
        chunk = load_chunk(chunk_path) 

        X = torch.stack([graph.x for graph in chunk], dim=1)  
        A = chunk[-1].edge_index 
        edge_index = A
        edge_weight = chunk[-1].edge_attr  
        y = chunk[-1].y  

        return X, edge_index, edge_weight, y


In [None]:
from torch.utils.data import DataLoader

def a3tgcn2_collate_fn(batch):
    """
    Collate function for batching A3TGCN2 data.

    Args:
        batch (list): List of (X, edge_index, edge_weight, y) tuples.

    Returns:
        tuple: Batched X, edge_index, edge_weight, and y.
    """
    X_batch = torch.cat([item[0].unsqueeze(0) for item in batch], dim=0)  # [B, N, T, d]
    
    edge_indices = []
    edge_weights = []
    y_batch = []

    node_offset = 0  
    for i, (X, edge_index, edge_weight, y) in enumerate(batch):

        # NOTE: Edge data will be empty for this test, allowing the model only to use the node featur
        if len(edge_indices) == 0:
            edge_index = edge_index + node_offset
            # print("edge_index: ", edge_index)
            # print("edge_weight: ", edge_weight)
            # Make a dummy edge_index and edge_weight
            edge_index = torch.tensor([[1, 1],
                                        [1,1]], dtype=torch.long)
            edge_weight = torch.tensor([1.0, 1.0], dtype=torch.float)


            edge_indices.append(edge_index)
            edge_weights.append(edge_weight)

        y_batch.append(y)
        # node_offset += X.size(0)  # Update node offset (number of nodes in the graph)

    edge_index_batch = torch.cat(edge_indices, dim=1)  # [2, num_edges]
    edge_weight_batch = torch.cat(edge_weights)  # [num_edges]
    y_batch = torch.cat(y_batch)  # [B * N]

    # re-adjust shape such that it is [B, N, d, T]
    X_batch = X_batch.permute(0, 1, 3, 2)

    return X_batch, edge_index_batch, edge_weight_batch, y_batch



# Model parameters
# num_nodes = 114     # Number of stocks in the graph
# in_channels = 64    # Number of features per node
# periods = 5        # Number of historical time steps (T)
# out_channels = 1    # Predicting a single scalar per node (e.g., regression)
# hidden_dim = 32     # Size of the GRU hidden state
# batch_size = 1      # Batch size


num_nodes    = example_chunk[0].x.size()[0]     # Number of stocks in the graph
in_channels  = example_chunk[0].x.size()[1]     # Number of features per node
periods      = len(example_chunk)               # Number of historical time steps (T)
out_channels = 1                      # Predicting a single scalar per node (e.g., regression)
hidden_dim   = 512                    # Size of the GRU hidden state
batch_size   = 1                      # Batch size


# Create DataLoader
train_dataset = A3TGCN2ChunkDataset(train_files)
val_dataset = A3TGCN2ChunkDataset(val_files)
test_dataset = A3TGCN2ChunkDataset(test_files)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  collate_fn=a3tgcn2_collate_fn)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, collate_fn=a3tgcn2_collate_fn)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, collate_fn=a3tgcn2_collate_fn)



nan_vals_detected = False

for idx, (X_batch, edge_index_batch, edge_weight_batch, y_batch) in enumerate(train_loader):
    if idx == 0:
        print("")
        print(f"X_batch dtype: {X_batch.dtype}, X_batch shape: {X_batch.shape}")
        print(f"edge_index_batch dtype: {edge_index_batch.dtype}, edge_index_batch shape: {edge_index_batch.shape}")
        print(f"edge_weight_batch dtype: {edge_weight_batch.dtype}, edge_weight_batch shape: {edge_weight_batch.shape}")
        print(f"y_batch dtype: {y_batch.dtype}, y_batch shape: {y_batch.shape}")
        print(f"y_batch mean: {y_batch.mean()},  unique: {y_batch.unique()}")
        print("")

        print("X_batch: ", X_batch.mean(axis=3).shape)

    if torch.isnan(X_batch).any():
        print("Nan value X_batch")
        nan_vals_detected = True
    if torch.isnan(edge_index_batch).any():
        print("Nan value detected in edge_index_batch")
        nan_vals_detected = True
    if torch.isnan(edge_weight_batch).any():
        print("Nan value detected in edge_weight_batch")
        nan_vals_detected = True
    if torch.isnan(y_batch).any():
        print("Nan value detected in y_batch")
        nan_vals_detected = True

print("")
if not nan_vals_detected:
    print("No Nan values detected in the data.")
else:
    print("Nan values detected in the data.")

    


X_batch dtype: torch.float32, X_batch shape: torch.Size([1, 734, 71, 5])
edge_index_batch dtype: torch.int64, edge_index_batch shape: torch.Size([2, 2])
edge_weight_batch dtype: torch.float32, edge_weight_batch shape: torch.Size([2])
y_batch dtype: torch.float32, y_batch shape: torch.Size([734])
y_batch mean: 1.3746594190597534,  unique: tensor([-2., -1.,  0.,  1.,  2.])

X_batch:  torch.Size([1, 734, 71])

No Nan values detected in the data.


In [None]:
class A3TGCN2WithCustomOutput(nn.Module):
    def __init__(self, in_channels, hidden_dim, out_channels, num_nodes, periods, batch_size, device):
        super(A3TGCN2WithCustomOutput, self).__init__()
        self.device = device
        self.base_model = A3TGCN2(
            in_channels=in_channels,
            out_channels=hidden_dim,
            periods=periods,
            batch_size=1,
            add_self_loops = True
        ).to(device)
        # self.dropout = nn.Dropout(p=0.1)  # Drop 10% of activations
        # self.activation = nn.Tanh()
        self.output_layer = nn.Linear(hidden_dim, out_channels).to(device)

    def forward(self, X, edge_index, edge_weight, H=None):
        # Forward pass through A3TGCN2
        # X = self.dropout(X)
        H_output = self.base_model(X, edge_index, edge_weight, H)
        out = self.output_layer(H_output).squeeze(-1)  # [B * N]

        return out.view(-1)


def initialize_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


model = A3TGCN2WithCustomOutput(
    in_channels=in_channels,
    hidden_dim=hidden_dim,
    out_channels=out_channels,
    num_nodes=num_nodes,
    periods=periods,
    batch_size = batch_size,
    device=DEVICE
).to(DEVICE)

model.apply(initialize_weights)


class WeightedMSELoss(nn.Module):
    def __init__(self, weights=None):
        """
        Weighted MSE Loss function.
        
        Args:
            weights (dict): A dictionary mapping target values to weights. This is done as some labels are more important to predict accurately than others.
                            For this use case, we want to accurately predict the outliers (e.g., -2 and 2) more than stock with no movement (e.g., 0).
                                      Example: {-2: 4.0, -1: 2.0, 0: 1.0, 1: 2.0, 2: 4.0}.
                            The Default is currently all 1's, as i want to the option to test the model without the weights to see if it improves the model.
        """
        super(WeightedMSELoss, self).__init__()
        self.weights = weights if weights else {-2: 1, -1: 1.0, 0: 1.0, 1: 1.0, 2: 1.0}
        self.base_loss = nn.MSELoss(reduction='none')

    def forward(self, predictions, targets):
        """
        Compute the loss.
        
        Args:
            predictions (torch.Tensor): Predicted values from the model. Shape: [batch_size, ...].
            targets (torch.Tensor): Ground truth values. Shape: [batch_size, ...].
            
        Returns:
            torch.Tensor: Computed loss value.
        """
        # Get the weights for each target
        device = predictions.device
        target_weights = torch.tensor([self.weights[int(target.item())] for target in targets], device=device)

        # Compute the base MSE loss
        mse_loss = self.base_loss(predictions, targets)

        # Encourage higher predictions:
        # I had a problem with the model predicting 0's, followed by only positive values between 0.2 and 0.6.
        # This was due to the model being penalized equally for false positives and false negatives.
        # This penalty system encourages the model to predict higher values for targets `1` and `2`, and lower values for target `-2`.

        too_low_prediction_penalty = torch.zeros_like(predictions)
        too_high_prediction_penalty = torch.zeros_like(predictions)

        # Apply false-negative penalty for targets `1` or `2` (encouraging high predictions)
        high_target_mask = (targets >= 1) 
        too_low_prediction_penalty[high_target_mask] = (predictions[high_target_mask] < targets[high_target_mask]).float() * 1.5
        too_high_prediction_penalty[high_target_mask] = (predictions[high_target_mask] > targets[high_target_mask]).float() * 0.75

        # Apply false-positive penalty for targets `-2` (encouraging low predictions)
        low_target_mask = (targets == -2)
        too_low_prediction_penalty[low_target_mask] = (predictions[low_target_mask] > targets[low_target_mask]).float() * 1.5
        too_high_prediction_penalty[low_target_mask] = (predictions[low_target_mask] < targets[low_target_mask]).float() * 0.75

        #The Combined penalties are subtracted from eachother to balance the loss, ensuring that the model is not encouraged to predict high values for all targets.
        penalties = 1 + too_low_prediction_penalty - too_high_prediction_penalty

        weighted_loss = mse_loss * target_weights

        return weighted_loss.mean()


# # Optimizer and loss function
optimizer = torch.optim.Adam(list(model.parameters()), lr=0.00001) 
# weights = {-2: 4.0, -1: 2.0, 0: 1.0, 1: 2.0, 2: 4.0}
weights = {-2: 3.0, -1: 2.0, 0: 0.5, 1: 2.0, 2: 5.0}
criterion = WeightedMSELoss(weights=weights)
criterion_unweighted = WeightedMSELoss(weights=None)

print(model)

A3TGCN2WithCustomOutput(
  (base_model): A3TGCN2(
    (_base_tgcn): TGCN2(
      (conv_z): GCNConv(71, 512)
      (linear_z): Linear(in_features=1024, out_features=512, bias=True)
      (conv_r): GCNConv(71, 512)
      (linear_r): Linear(in_features=1024, out_features=512, bias=True)
      (conv_h): GCNConv(71, 512)
      (linear_h): Linear(in_features=1024, out_features=512, bias=True)
    )
  )
  (output_layer): Linear(in_features=512, out_features=1, bias=True)
)


In [11]:
def calculate_test_loss(model, test_loader, criterion, device):
    """
    Calculates the test loss for the given model.

    Args:
        model (nn.Module): The trained model to evaluate.
        test_loader (DataLoader): DataLoader for the test dataset.
        criterion (nn.Module): Loss function to use.
        device (torch.device): Device to perform computations on.

    Returns:
        float: The average test loss.
    """
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():  # Disable gradient computation for testing
        for X_batch, edge_index_batch, edge_weight_batch, y_batch in test_loader:
            # Move data to the specified device
            X_batch, edge_index_batch, edge_weight_batch, y_batch = (
                X_batch.to(device),
                edge_index_batch.to(device),
                edge_weight_batch.to(device),
                y_batch.to(device),
            )

            # Forward pass through the model
            predictions = model(X_batch, edge_index_batch, edge_weight_batch)
            # predictions = torch.nan_to_num(predictions, nan=0.0)

            # Compute loss
            loss = criterion(predictions, y_batch)
            total_loss += loss.item()
            num_batches += 1

    # Compute average loss
    avg_loss = total_loss / num_batches 
    return avg_loss

In [None]:
# Reload the model at the last state from the checkpoint IF the notebook crashed or was stopped for some reason.

# model_name = "A3TGCN2_2"
# checkpoint = torch.load(f"../Data/Models/{model_name}.pt")
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# best_val_loss = checkpoint['best_val_loss']
# model = model.to(DEVICE)

In [None]:
# Training Loop
epochs = 500
force_epochs = 10
stop_training = False
patience = 5
cur_patience = 0
best_val_loss = float("inf")
model_name = "A3TGCN2_7"

train_losses = []
train_unweighted_losses = []
val_losses = []
val_unweighted_losses = []

if not os.path.exists(f"../Data/Models/"):
    os.makedirs(f"../Data/Models/")
# elif os.path.exists(f"../Data/Models/{model_name}.pt"):
#     print(f"Model {model_name} already exists. Do you want to overwrite it? (y/n)")
#     response = input()
#     if response.lower() != "y":
#         print("Exiting training loop...")
#         sys.exit()

for epoch in range(0,epochs):
    total_loss = 0
    total_unweighted_loss = 0
    model.train()
    for X_batch, edge_index_batch, edge_weight_batch, y_batch in tqdm(train_loader, total = len(train_loader), desc=f"Epoch {epoch + 1}/{epochs} in progress..."):
        # Move to DEVICE
        # X_batch = clip_outliers(X_batch)
        if stop_training:
            break

        X_batch, edge_index_batch, edge_weight_batch, y_batch = (
            X_batch.to(DEVICE), 
            edge_index_batch.to(DEVICE), 
            edge_weight_batch.to(DEVICE),  
            y_batch.to(DEVICE)
        )

        #[B, N, d, T]
        batch_size, num_nodes, num_features, time_steps = X_batch.size()
        if epoch == 0:
            assert not torch.isnan(edge_weight_batch).any(), "Edge weights contain NaN values!"
            assert not torch.isinf(edge_weight_batch).any(), "Edge weights contain Inf values!"
            assert edge_index_batch.dim() == 2 and edge_index_batch.size(0) == 2, "edge_index must have shape [2, num_edges]"
            assert edge_index_batch.max() < num_nodes, "edge_index contains indices outside the valid node range"
            assert edge_index_batch.min() >= 0, "edge_index contains negative indices"
            assert edge_weight_batch.size(0) == edge_index_batch.size(1), "edge_weight size must match the number of edges"

        # Add self-loops to edge indices
        # edge_index_batch, edge_weight_batch = add_remaining_self_loops(
        #     edge_index_batch, edge_attr=edge_weight_batch, fill_value=1.0
        # )

        if epoch == 0:
            assert not torch.isnan(X_batch).any(), "Input features contain NaN values!"
            assert not torch.isinf(X_batch).any(), "Input features contain Inf values!"
            assert not torch.isnan(edge_weight_batch).any(), "Edge weights contain NaN values!"
            assert not torch.isinf(edge_weight_batch).any(), "Edge weights contain Inf values!"

        # Forward pass through A3TGCN2
        predictions = model(X_batch, edge_index_batch, edge_weight_batch)#, H)
        # predictions = torch.nan_to_num(predictions, nan=0.0)

        if torch.isnan(predictions).any():
            print("predictions", predictions)
            # replace NaN values with 0

        # if all predictions are nan, break
        if torch.isnan(predictions).all():
            print("All predictions are NaN")
            print("predictions: ", predictions)
            print("y_batch: ", y_batch)
            stop_training = True

        # Compute loss
        loss = criterion(predictions, y_batch)
        optimizer.zero_grad()
        loss.backward()

        # Clipping gradients in an attempt to avoid Nan/Inf values
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        # print("loss: ", loss)

        total_loss += loss.item()

    # Compute validation loss
    model.eval()
    val_loss = calculate_test_loss(model, val_loader, criterion, DEVICE)
    val_unweighted_loss = calculate_test_loss(model, val_loader, criterion_unweighted, DEVICE)
    # train_unweighted_loss = calculate_test_loss(model, train_loader, criterion_unweighted, DEVICE)
    total_loss = total_loss / len(train_loader)

    train_losses.append(total_loss)
    # train_unweighted_losses.append(train_unweighted_loss)
    val_losses.append(val_loss)
    val_unweighted_losses.append(val_unweighted_loss)

    # Early stopping
    force_epochs-=1
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        cur_patience = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
        }, f"../Data/Models/{model_name}.pt")
    else:
        if force_epochs <= 0:
            cur_patience += 1
            if cur_patience == patience:
                print(f"Early stopping: validation loss did not improve for {patience} epochs")
                break

    # print(f"Epoch {epoch + 1}/{epochs}, Training Loss Weighted MSE: {total_loss:.4f} | Unweighted MSE: {train_unweighted_loss:.4f}    Validation Loss Weighted MSE: {val_loss:.4f} | Unweighted MSE: {val_unweighted_loss:.4f}")
    print(f"Epoch {epoch + 1}/{epochs}, Training Loss Weighted MSE: {total_loss:.4f}    Validation Loss Weighted MSE: {val_loss:.4f} | Unweighted MSE: {val_unweighted_loss:.4f}")
    if stop_training:
            break


Epoch 1/500 in progress...: 100%|██████████| 4973/4973 [10:30<00:00,  7.89it/s]
Epoch 2/500 in progress...:   0%|          | 1/4973 [00:00<11:23,  7.27it/s]

Epoch 1/500, Training Loss Weighted MSE: 4.3542    Validation Loss Weighted MSE: 3.7209 | Unweighted MSE: 1.2097


Epoch 2/500 in progress...: 100%|██████████| 4973/4973 [10:49<00:00,  7.66it/s]
Epoch 3/500 in progress...:   0%|          | 1/4973 [00:00<11:36,  7.13it/s]

Epoch 2/500, Training Loss Weighted MSE: 4.3346    Validation Loss Weighted MSE: 3.7261 | Unweighted MSE: 1.2144


Epoch 3/500 in progress...: 100%|██████████| 4973/4973 [10:44<00:00,  7.71it/s]
Epoch 4/500 in progress...:   0%|          | 1/4973 [00:00<11:01,  7.52it/s]

Epoch 3/500, Training Loss Weighted MSE: 4.3306    Validation Loss Weighted MSE: 3.7310 | Unweighted MSE: 1.2416


Epoch 4/500 in progress...: 100%|██████████| 4973/4973 [10:48<00:00,  7.67it/s]
Epoch 5/500 in progress...:   0%|          | 1/4973 [00:00<11:08,  7.44it/s]

Epoch 4/500, Training Loss Weighted MSE: 4.3295    Validation Loss Weighted MSE: 3.7275 | Unweighted MSE: 1.2201


Epoch 5/500 in progress...: 100%|██████████| 4973/4973 [10:44<00:00,  7.72it/s]
Epoch 6/500 in progress...:   0%|          | 1/4973 [00:00<11:14,  7.37it/s]

Epoch 5/500, Training Loss Weighted MSE: 4.3274    Validation Loss Weighted MSE: 3.7255 | Unweighted MSE: 1.2239


Epoch 6/500 in progress...: 100%|██████████| 4973/4973 [10:19<00:00,  8.02it/s]
Epoch 7/500 in progress...:   0%|          | 1/4973 [00:00<10:10,  8.14it/s]

Epoch 6/500, Training Loss Weighted MSE: 4.3272    Validation Loss Weighted MSE: 3.7275 | Unweighted MSE: 1.2130


Epoch 7/500 in progress...: 100%|██████████| 4973/4973 [10:19<00:00,  8.03it/s]
Epoch 8/500 in progress...:   0%|          | 1/4973 [00:00<10:47,  7.68it/s]

Epoch 7/500, Training Loss Weighted MSE: 4.3239    Validation Loss Weighted MSE: 3.7316 | Unweighted MSE: 1.2181


Epoch 8/500 in progress...: 100%|██████████| 4973/4973 [10:26<00:00,  7.93it/s]
Epoch 9/500 in progress...:   0%|          | 1/4973 [00:00<11:00,  7.53it/s]

Epoch 8/500, Training Loss Weighted MSE: 4.3239    Validation Loss Weighted MSE: 3.7450 | Unweighted MSE: 1.2250


Epoch 9/500 in progress...: 100%|██████████| 4973/4973 [11:17<00:00,  7.34it/s]
Epoch 10/500 in progress...:   0%|          | 1/4973 [00:00<12:16,  6.75it/s]

Epoch 9/500, Training Loss Weighted MSE: 4.3227    Validation Loss Weighted MSE: 3.7400 | Unweighted MSE: 1.2214


Epoch 10/500 in progress...: 100%|██████████| 4973/4973 [11:21<00:00,  7.30it/s]
Epoch 11/500 in progress...:   0%|          | 1/4973 [00:00<11:45,  7.05it/s]

Epoch 10/500, Training Loss Weighted MSE: 4.3205    Validation Loss Weighted MSE: 3.7468 | Unweighted MSE: 1.2164


Epoch 11/500 in progress...: 100%|██████████| 4973/4973 [10:49<00:00,  7.66it/s]
Epoch 12/500 in progress...:   0%|          | 1/4973 [00:00<12:59,  6.38it/s]

Epoch 11/500, Training Loss Weighted MSE: 4.3217    Validation Loss Weighted MSE: 3.7264 | Unweighted MSE: 1.2189


Epoch 12/500 in progress...: 100%|██████████| 4973/4973 [10:36<00:00,  7.82it/s]
Epoch 13/500 in progress...:   0%|          | 1/4973 [00:00<11:05,  7.47it/s]

Epoch 12/500, Training Loss Weighted MSE: 4.3179    Validation Loss Weighted MSE: 3.7359 | Unweighted MSE: 1.2195


Epoch 13/500 in progress...: 100%|██████████| 4973/4973 [11:11<00:00,  7.41it/s]
Epoch 14/500 in progress...:   0%|          | 1/4973 [00:00<11:17,  7.34it/s]

Epoch 13/500, Training Loss Weighted MSE: 4.3194    Validation Loss Weighted MSE: 3.7321 | Unweighted MSE: 1.2392


Epoch 14/500 in progress...: 100%|██████████| 4973/4973 [11:26<00:00,  7.24it/s]


Early stopping: validation loss did not improve for 5 epochs


In [None]:
# model_name = "A3TGCN2_5"
# checkpoint = torch.load(f"../Data/Models/{model_name}.pt")
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# best_val_loss = checkpoint['best_val_loss']
# model = model.to(DEVICE)
# model

A3TGCN2WithCustomOutput(
  (base_model): A3TGCN2(
    (_base_tgcn): TGCN2(
      (conv_z): GCNConv(71, 512)
      (linear_z): Linear(in_features=1024, out_features=512, bias=True)
      (conv_r): GCNConv(71, 512)
      (linear_r): Linear(in_features=1024, out_features=512, bias=True)
      (conv_h): GCNConv(71, 512)
      (linear_h): Linear(in_features=1024, out_features=512, bias=True)
    )
  )
  (output_layer): Linear(in_features=512, out_features=1, bias=True)
)

In [248]:
# Once the model is done training, we can evaluate it on the test set, save the model weights, and save the predictions together with the ground truth values for further analysis.

def model_evaluation(model, test_loader, test_files, model_name):
    """
    Evaluates the model on the test set.  

    Args:
        model (nn.Module): The trained model to evaluate.
        test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
        tuple: Predictions and ground truth values.
    """
    model.eval()  # Set the model to evaluation mode
    predictions = []
    val_predictions = []
    ground_truth = []
    val_ground_truth = []
    total_test_loss = 0
    total_val_loss = 0
    total_unweighted_test_loss = 0
    total_unweighted_val_loss = 0

    # ['../Data/Networks_chunks/window_size_5\\chunk_2014-04-09__5.pkl',
    date_of_predictions = [x.split("chunk_")[1].split("__")[0] for x in test_files]

    with torch.no_grad():  # Disable gradient computation for testing
        for X_batch, edge_index_batch, edge_weight_batch, y_batch in test_loader:
            # Move data to the specified device
            X_batch, edge_index_batch, edge_weight_batch, y_batch = (
                X_batch.to(DEVICE),
                edge_index_batch.to(DEVICE),
                edge_weight_batch.to(DEVICE),
                y_batch.to(DEVICE),
            )

            # Forward pass through the model
            batch_predictions = model(X_batch, edge_index_batch, edge_weight_batch)
            predictions.append(batch_predictions.cpu().numpy())
            ground_truth.append(y_batch.cpu().numpy())
            total_test_loss += criterion(batch_predictions, y_batch).item()
            total_unweighted_test_loss += criterion_unweighted(batch_predictions, y_batch).item()

        for X_batch, edge_index_batch, edge_weight_batch, y_batch in val_loader:
            # Move data to the specified device
            X_batch, edge_index_batch, edge_weight_batch, y_batch = (
                X_batch.to(DEVICE),
                edge_index_batch.to(DEVICE),
                edge_weight_batch.to(DEVICE),
                y_batch.to(DEVICE),
            )

            # Forward pass through the model
            val_batch_predictions = model(X_batch, edge_index_batch, edge_weight_batch)
            val_predictions.append(val_batch_predictions.cpu().numpy())
            val_ground_truth.append(y_batch.cpu().numpy())
            total_val_loss += criterion(val_batch_predictions, y_batch).item()
            total_unweighted_val_loss += criterion_unweighted(val_batch_predictions, y_batch).item()

    # Compute average test loss
    test_loss = total_test_loss / len(test_loader)
    test_unweighted_loss = total_unweighted_test_loss / len(test_loader)

    print("len(date_of_predictions): ", len(date_of_predictions))
    print("date_of_predictions[0]: ", date_of_predictions[0])
    print("")
    print("len(predictions): ", len(predictions))
    print("predictions[0]: ", predictions[0])
    print("")
    print("len(ground_truth): ", len(ground_truth))
    print("ground_truth[0]: ", ground_truth[0])


    # Merge results into a single dataframe
    results_dict ={
            "Date": date_of_predictions,
            "Prediction": predictions,
            "Ground Truth": ground_truth,
            "Val_Prediction": val_predictions,
            "Val_Ground Truth": val_ground_truth,
            "Model_Name": model_name,
            "Hidden_Dim": model.base_model.out_channels,
            "Periods": model.base_model.periods,
            "Batch_Size": model.base_model.batch_size,
            "Optimizer": optimizer.__class__.__name__,
            "Learning_Rate": optimizer.param_groups[0]['lr'],
            "Weight_Decay": optimizer.param_groups[0]['weight_decay'],
            "Loss_Function": "MSELoss",
            "Epochs": epoch,
            "Patience": patience,
            "Train_Loss": train_losses,
            "Train_Unweighted_Loss": train_unweighted_losses,
            "Val_Loss": val_losses,
            "Val_Unweighted_Loss": val_unweighted_losses,
            "Test_Loss": test_loss,
            "Test_Unweighted_Loss": test_unweighted_loss            
        }
    
    with open (f"../Data/Results/{model_name}_results.pkl", 'wb') as f:
        pickle.dump(results_dict, f)


# Evaluate the model

model_evaluation(model, test_loader, test_files, model_name)



len(date_of_predictions):  1422
date_of_predictions[0]:  2018-11-01

len(predictions):  1422
predictions[0]:  [ 0.18713504  0.26728505  0.33683348  0.23380937 -0.02067674  0.43942797
  0.14771695  0.1693549   0.29873222  0.07086572  0.26728505  0.25873852
  0.17155436  0.40968335  0.04240787  0.04857202  0.15908572  0.32416096
  0.19801128  0.06003696  0.05724679  0.01846218 -0.00455145  0.18727724
  0.1731407   0.08695132  0.27933028  0.06702836  0.26795694  0.10675919
  0.22364448  0.26728505  0.10188337 -0.06394638  0.4403374   0.19050533
  0.23495775  0.31654358  0.20709878  0.29873222 -0.42290804  0.21678388
  0.3579213   0.18909512  0.2052049   0.07013295  0.27693275  0.29873222
  0.26728505  0.30708897  0.29873222  0.27942348  0.1914568   0.12074582
  0.32461566  0.29680824  0.22394052  0.20393027  0.21460499  0.2557029
  0.08885728  0.08922023  0.27503622  0.23647043  0.27933028  0.19178775
  0.16422468  0.16215119  0.03340145  0.26728505  0.10050316  0.02050108
  0.27933028  0