# Dataset Construction

In [1]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignal, DynamicGraphTemporalSignal
import torch
from typing import Union
import glob
from natsort import natsorted
import random

# Fix the seed of Python's built-in `random` module for reproducibility.
# NOTE: this does not seed NumPy or PyTorch; their generators are left untouched here.
random.seed(42)

class SP500CorrelationsDatasetLoader(object):
    def __init__(self, corr_name, corr_scope):
        self._read_csv(corr_name, corr_scope)

    def _load_global_corr(self, corr_name):
        return np.loadtxt(f'{corr_name}/global_corr.csv', delimiter=',')

    def _load_local_corrs(self, corr_name):
        _correlation_matrices = []
        corr_files = natsorted(glob.glob(f'{corr_name}/local_corr_*.csv'))
        for corr_file in corr_files:
            matrix = np.loadtxt(corr_file, delimiter=',')
            _correlation_matrices.append(matrix)
        return _correlation_matrices

    def _read_csv(self, corr_name, corr_scope):
        match corr_scope:
            case 'global':
                self._correlation_matrices = [self._load_global_corr(corr_name)]
            case 'local':
                self._correlation_matrices = self._load_local_corrs(corr_name)
            case 'dual':
                global_corr = self._load_global_corr(corr_name)
                self._correlation_matrices = [np.stack((global_corr, local_corr), axis=-1) for local_corr in self._load_local_corrs(corr_name)]
            case 'none':
                # None uses identity matrix as correlation
                global_corr = self._load_global_corr(corr_name)
                self._correlation_matrices = [np.eye(global_corr.shape[0], global_corr.shape[1])]
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)

        # Round data size to nearest multiple of batch_size
        self.days_in_quarter = 64
        num_quarters = data.size(0) // self.days_in_quarter
        num_days = num_quarters * self.days_in_quarter
        data = data[:num_days]
        
        # z-score normalization with training data following GERU
        train_days = int(0.8 * num_quarters) * self.days_in_quarter
        data = (data - data[:train_days].mean(dim=0)) / data[:train_days].std(dim=0)
        data = data.numpy()

        data = data[..., np.newaxis]

        # # Add percent change features
        # p_chg = data / np.roll(data, 1, axis=0) - 1
        # p_chg[0] = 0.0
        # p_chg_3 = data / np.roll(data, 3, axis=0) - 1
        # p_chg_3[0:3] = 0.0
        # p_chg_6 = data / np.roll(data, 6, axis=0) - 1
        # p_chg_6[0:6] = 0.0

        # data = np.stack([data, p_chg, p_chg_3, p_chg_6], axis=-1)
        # print('data.shape', data.shape)

        assert(not np.any(np.isnan(data)))
        self._dataset = data

    def _get_edges(self, times, overlap):
        if len(self._correlation_matrices) == 1:
            _edges = np.array(np.ones(self._correlation_matrices[0].shape[:2]).nonzero())
        else:
            _edges = []
            for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):
                if not time in times:
                    continue
                corr_index = max(0, time // self.days_in_quarter - 1)
                _edges.append(
                    np.array(np.ones(self._correlation_matrices[corr_index].shape[:2]).nonzero())
                )
        return _edges

    def _get_edge_weights(self, times, overlap):
        if len(self._correlation_matrices) == 1:
            # Flatten the first two dimensions
            w = self._correlation_matrices[0]
            _edge_weights = w.reshape((w.shape[0] * w.shape[1],) + w.shape[2:])
        else:
            _edge_weights = []
            for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):
                if not time in times:
                    continue
                corr_index = max(0, time // self.days_in_quarter - 1)
                # Flatten the first two dimensions
                w = self._correlation_matrices[corr_index]
                _edge_weights.append(
                    np.array(w.reshape((w.shape[0] * w.shape[1],) + w.shape[2:]))
                )
        return _edge_weights

    def _get_targets_and_features(self, times, overlap, predict_all):
        features = [
            self._dataset[i : i + self.batch_size, :]
            for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)
            if i in times
        ]
        # predict next-day stock prices
        targets = [
            (self._dataset[i+1 : i + self.batch_size+1, :, 0]).T if predict_all else (self._dataset[i + self.batch_size, :, 0]).T
            for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)
            if i in times
        ]
        return features, targets

    def get_dataset(self, batch_size, split) -> Union[StaticGraphTemporalSignal, DynamicGraphTemporalSignal]:
        """Returning the data iterator.
        """
        self.batch_size = batch_size

        total_times = list(range(0, self._dataset.shape[0] - self.batch_size, self.batch_size))

        if split == 'train':
            times = list(range(total_times[int(len(total_times) * 0)], total_times[int(len(total_times) * 0.8)]))
            overlap = self.batch_size
            predict_all = True
        elif split == 'val':
            times = list(range(total_times[int(len(total_times) * 0.8)], total_times[int(len(total_times) * 0.9)]))
            overlap = 1
            predict_all = False
        elif split == 'test':
            times = list(range(total_times[int(len(total_times) * 0.9)], total_times[-1] + self.batch_size))
            overlap = 1
            predict_all = False
        else:
            raise ValueError(f'Invalid split name: {split}')

        _edges = self._get_edges(times, overlap)
        _edge_weights = self._get_edge_weights(times, overlap)
        features, targets = self._get_targets_and_features(times, overlap, predict_all)
        dataset = (DynamicGraphTemporalSignal if type(_edges) == list else StaticGraphTemporalSignal)(
            _edges, _edge_weights, features, targets
        )
        return dataset

In [2]:
# Device is pinned to CPU; the commented line is the disabled auto-detection (CUDA, then MPS, then CPU).
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = 'cpu'

# `corr_name` is the directory the correlation CSVs are read from;
# `corr_scope` selects which of them the loader uses.
corr_name = 'mi'
corr_scope = 'global'
loader = SP500CorrelationsDatasetLoader(corr_name=corr_name, corr_scope=corr_scope)

# Window length in dataset rows. Training windows are twice as long as the
# validation/test windows.
lag_size = 64
train_dataset = loader.get_dataset(batch_size=lag_size * 2, split='train')
val_dataset = loader.get_dataset(batch_size=lag_size, split='val')
test_dataset = loader.get_dataset(batch_size=lag_size, split='test')

In [3]:
# Report the number of feature windows and target windows per split
# (train, then val, then test).
for split_dataset in (train_dataset, val_dataset, test_dataset):
    print(len(split_dataset.features))
    print(len(split_dataset.targets))

15
15
256
256
256
256


# Differential Graph Transformer

In [4]:
import math
import torch
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from torch_geometric.nn.conv import MessagePassing

class DConv(MessagePassing):
    r"""An implementation of the Diffusion Convolution Layer.
    For details see: `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Filter size :math:`K`.
        bias (bool, optional): If set to :obj:`False`, the layer
            will not learn an additive bias (default :obj:`True`).

    """

    def __init__(self, in_channels, out_channels, K, bias=True):
        super(DConv, self).__init__(aggr="add", flow="source_to_target")
        assert K > 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        # weight[0] holds the out-direction filters, weight[1] the in-direction ones.
        self.weight = torch.nn.Parameter(torch.Tensor(2, K, in_channels, out_channels))

        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias", None)

        self.__reset_parameters()

    def __reset_parameters(self):
        """Xavier-initialise the filters and zero the bias, if there is one."""
        torch.nn.init.xavier_uniform_(self.weight)
        # With bias=False the parameter is registered as None and must be skipped.
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def message(self, x_j, norm):
        """Scale each neighbour feature by its per-edge normalisation factor."""
        return norm.view(-1, 1) * x_j

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor,
    ) -> torch.FloatTensor:
        r"""Making a forward pass.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Float Tensor) - Edge weight vector.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        # Dense weighted adjacency; the leading batch axis is dropped.
        adj_mat = to_dense_adj(edge_index, edge_attr=edge_weight)
        adj_mat = adj_mat.reshape(adj_mat.size(1), adj_mat.size(2))

        # Weighted out-degree (row sums) and in-degree (column sums).
        deg_out = torch.matmul(
            adj_mat, torch.ones(size=(adj_mat.size(0), 1)).to(X.device)
        )
        deg_out = deg_out.flatten()
        deg_in = torch.matmul(
            torch.ones(size=(1, adj_mat.size(0))).to(X.device), adj_mat
        )
        deg_in = deg_in.flatten()

        deg_out_inv = torch.reciprocal(deg_out)
        deg_in_inv = torch.reciprocal(deg_in)
        row, _ = edge_index
        norm_out = deg_out_inv[row]
        # NOTE(review): norm_in is indexed with the forward edge list's rows but is
        # applied to `reverse_edge_index` below; confirm the two edge orderings
        # (and lengths, if any weight is zero) line up for the graphs in use.
        norm_in = deg_in_inv[row]

        # Edge list of the transposed graph, used for the in-direction diffusion.
        reverse_edge_index, _ = dense_to_sparse(adj_mat.transpose(0, 1))

        Tx_0 = X
        Tx_1 = X
        # Order-0 term for both diffusion directions.
        H = torch.matmul(Tx_0, (self.weight[0])[0]) + torch.matmul(
            Tx_0, (self.weight[1])[0]
        )

        # Order-1 term: a single propagation step in each direction.
        if self.weight.size(1) > 1:
            Tx_1_o = self.propagate(edge_index, x=X, norm=norm_out, size=None)
            Tx_1_i = self.propagate(reverse_edge_index, x=X, norm=norm_in, size=None)
            H = (
                H
                + torch.matmul(Tx_1_o, (self.weight[0])[1])
                + torch.matmul(Tx_1_i, (self.weight[1])[1])
            )

        # Higher orders via the recurrence 2 * propagate(previous) - Tx_0.
        for k in range(2, self.weight.size(1)):
            Tx_2_o = self.propagate(edge_index, x=Tx_1_o, norm=norm_out, size=None)
            Tx_2_o = 2.0 * Tx_2_o - Tx_0
            Tx_2_i = self.propagate(
                reverse_edge_index, x=Tx_1_i, norm=norm_in, size=None
            )
            Tx_2_i = 2.0 * Tx_2_i - Tx_0
            H = (
                H
                + torch.matmul(Tx_2_o, (self.weight[0])[k])
                + torch.matmul(Tx_2_i, (self.weight[1])[k])
            )
            Tx_0, Tx_1_o, Tx_1_i = Tx_1, Tx_2_o, Tx_2_i

        if self.bias is not None:
            H += self.bias

        return H


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from multihead_diffattn import MultiheadDiffAttn

class FeedForward(nn.Module):
    """Two-layer MLP: expand by ``expand_ratio``, ReLU, project back, dropout."""

    def __init__(self, hidden_size, expand_ratio, dropout):
        super().__init__()
        inner_size = hidden_size * expand_ratio
        self.linear = nn.Linear(hidden_size, inner_size)
        self.linear2 = nn.Linear(inner_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``; dropout comes last."""
        expanded = self.relu(self.linear(x))
        return self.dropout(self.linear2(expanded))

class Attention(nn.Module):
    """Attention block followed by a residual feed-forward sub-layer.

    Args:
        d_model: Embedding size of the inputs and outputs.
        num_heads: Number of attention heads.
        expand_ratio: Width multiplier of the feed-forward hidden layer.
        dropout: Dropout probability used by the feed-forward sub-layer.
        attn_variant: ``'standard'`` for ``nn.MultiheadAttention`` with a
            residual connection and layer norm, or ``'diff'`` for the
            project-local ``MultiheadDiffAttn``.

    Raises:
        ValueError: If ``attn_variant`` is not ``'standard'`` or ``'diff'``.
    """

    def __init__(self, d_model, num_heads, expand_ratio, dropout, attn_variant='standard'):
        super().__init__()
        # Reject unknown variants up front; otherwise the module is built without
        # `self.mha` and only fails later inside forward().
        if attn_variant not in ('standard', 'diff'):
            raise ValueError(f'Invalid attn_variant: {attn_variant}')
        self.attn_variant = attn_variant
        if attn_variant == 'standard':
            self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
            self.ln1 = nn.LayerNorm(d_model)
        else:
            self.mha = MultiheadDiffAttn(embed_dim=d_model, num_heads=num_heads, depth=0)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(hidden_size=d_model, expand_ratio=expand_ratio, dropout=dropout)

    def forward(self, x, A=None, attn_mask=None, need_weights=False):
        """Run attention and the feed-forward sub-layer.

        Args:
            x: Input with ``d_model`` as its last dimension (batch first).
            A: Extra input forwarded to ``MultiheadDiffAttn``; used only by
                the ``'diff'`` variant.
            attn_mask: Optional attention mask forwarded to the attention layer.
            need_weights: If true, also return the attention weights.

        Returns:
            The transformed tensor, or ``(tensor, attn_weights)`` when
            ``need_weights`` is true.
        """
        if self.attn_variant == 'standard':
            # Per-head weights are requested (average_attn_weights=False).
            x1, attn_weights = self.mha(
                x, x, x, attn_mask=attn_mask, need_weights=need_weights, average_attn_weights=False
            )
            x2 = self.ln1(x + x1)
        else:
            # No residual/ln1 is applied on this path; presumably the project-local
            # MultiheadDiffAttn returns (output, weights) - verify against that module.
            x2, attn_weights = self.mha(x, A, attn_mask=attn_mask)
        x = self.ln2(self.ffn(x2) + x2)
        if need_weights:
            return (x, attn_weights)
        return x


class GraphTransformer(nn.Module):
    """Temporal attention per stock followed by spatial attention per time step.

    Args:
        in_channels: Number of input features per (stock, time) entry.
        out_channels: Model/embedding width.
        attn_variant: Variant used by the spatial attention (``'standard'`` or
            ``'diff'``); temporal attention always uses the standard variant.
        num_heads: Number of attention heads.
        expand_ratio: Width multiplier of the feed-forward layers.
        dropout: Dropout probability of the feed-forward layers.
        T: Maximum supported sequence length (size of the time embedding).
        N: Exact number of stocks (size of the stock embedding).
    """

    def __init__(self, in_channels=1, out_channels=32, attn_variant='standard', num_heads=2, expand_ratio=1, dropout=0.1, T = 128, N=472):
        super().__init__()
        self.T = T
        self.N = N
        self.in_channels = in_channels
        self.attn_variant = attn_variant
        self.d_model = out_channels
        self.num_heads = num_heads
        # Honour `in_channels` instead of hard-coding 1 (default keeps old behaviour).
        self.input_proj = nn.Linear(in_channels, out_channels)
        self.time_embedding = nn.Embedding(T, out_channels)
        self.stock_embedding = nn.Embedding(N, out_channels)
        self.spatial_attn = Attention(out_channels, num_heads, expand_ratio, dropout, attn_variant=attn_variant)
        self.temporal_attn = Attention(out_channels, num_heads, expand_ratio, dropout)

    def forward(self, x, edge_index, edge_weight, hidden, need_weights=False):
        """Encode a window of node features.

        Args:
            x: Tensor of shape ``(N, T, in_channels)``.
            edge_index: Graph edge indices (used by the ``'diff'`` variant only).
            edge_weight: Edge weights (used by the ``'diff'`` variant only).
            hidden: Unused; accepted for signature parity with recurrent cells.
            need_weights: If true, also return the attention weights.

        Returns:
            Tensor of shape ``(N, T, out_channels)``; with ``need_weights`` a
            tuple ``(output, (temporal_weights, spatial_weights))``.

        Raises:
            ValueError: If the input shape does not match the configured
                ``in_channels``, ``T`` or ``N``.
        """
        N, T, D = x.size()
        if D != self.in_channels:
            raise ValueError(f'Expected {self.in_channels} input features, got {D}')
        if T > self.T or N != self.N:
            raise ValueError(
                f'Expected N == {self.N} and T <= {self.T}, got N={N}, T={T}'
            )

        x = self.input_proj(x)  # (N, T, d_model)

        # Embeddings are broadcast along the matching axis so that stock n and
        # time step t always receive their own embedding. Reinterpreting the
        # buffer with reshape/view would misalign them.
        stock_embs = self.stock_embedding(torch.arange(N, device=x.device))  # (N, d_model)
        time_embs = self.time_embedding(torch.arange(T, device=x.device))  # (T, d_model)
        x = x + stock_embs.unsqueeze(1) + time_embs.unsqueeze(0)

        # IDEA: Each spatial head takes in a different type of correlation matrix.
        # Like one takes in positive Pearson's coefficient and the other takes in negative

        # Temporal attention: batch over stocks, sequence over time, causal mask.
        # A 2-D (T, T) mask is broadcast across batch and heads.
        temporal_causal_mask = torch.triu(
            torch.ones((T, T), dtype=torch.bool, device=x.device), diagonal=1
        )
        temporal_out = self.temporal_attn(x, attn_mask=temporal_causal_mask, need_weights=need_weights)
        temporal_weights = None
        if need_weights:
            temporal_out, temporal_weights = temporal_out
        x = temporal_out + x

        # Spatial attention: batch over time, sequence over stocks. A real axis
        # swap is required here; `view(T, N, d)` would only reinterpret memory.
        x_spatial = x.transpose(0, 1).contiguous()  # (T, N, d_model)

        if self.attn_variant == 'diff':
            A = to_dense_adj(edge_index, edge_attr=edge_weight)
            # Encountered more than one adjacency matrix, e.g. dual correlations:
            # (1, N, N, C) -> (C, N, N), moving the channel axis to the front.
            if A.dim() == 4:
                A = A.squeeze(0).permute(2, 0, 1).contiguous()
            spatial_out = self.spatial_attn(x_spatial, A, need_weights=need_weights)
        else:
            spatial_out = self.spatial_attn(x_spatial, need_weights=need_weights)

        spatial_weights = None
        if need_weights:
            spatial_out, spatial_weights = spatial_out
        x_spatial = spatial_out + x_spatial

        x = x_spatial.transpose(0, 1).contiguous()  # back to (N, T, d_model)

        if need_weights:
            return x, (temporal_weights, spatial_weights)
        return x

# Plain RNN

In [6]:
class GRU(torch.nn.Module):
    """Graph-agnostic baseline wrapping a two-layer, batch-first ``nn.GRU``.

    The graph arguments of :meth:`forward` are accepted so the module can be
    called like the graph-based cells, but they are not used.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.rnn = nn.GRU(
            input_size=in_channels,
            hidden_size=out_channels,
            num_layers=2,
            batch_first=True,
        )

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """Return the per-step outputs of the GRU; the final hidden state is dropped."""
        return self.rnn(X, H)[0]

# T-GCN with GAT

In [7]:
import torch
from torch_geometric.nn import GATv2Conv

# https://pytorch-geometric-temporal.readthedocs.io/en/latest/_modules/torch_geometric_temporal/nn/recurrent/temporalgcn.html#TGCN
class TGCN(torch.nn.Module):
    r"""An implementation of the Temporal Graph Convolutional Gated Recurrent Cell,
    with the graph convolutions replaced by GATv2 attention layers.
    For details see this paper: `"T-GCN: A Temporal Graph ConvolutionalNetwork for
    Traffic Prediction." <https://arxiv.org/abs/1811.05320>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        num_heads (int): Number of attention heads
        add_self_loops (bool): Adding self-loops for smoothing. Default is True.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_heads: int = 2,
        add_self_loops: bool = True,
    ):
        super(TGCN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.add_self_loops = add_self_loops

        self._create_parameters_and_layers()

    def _make_conv(self):
        """Build one GATv2 layer that consumes scalar edge attributes."""
        return GATv2Conv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            heads=self.num_heads,
            add_self_loops=self.add_self_loops,
            edge_dim=1,
        )

    def _make_gate_linear(self):
        """Build the linear layer that mixes the conv output with the hidden state.

        GATv2Conv concatenates its heads by default, so the conv output has
        ``num_heads * out_channels`` features; the hidden state adds another
        ``out_channels``. For ``num_heads == 1`` this equals the previous
        ``2 * out_channels``.
        """
        return torch.nn.Linear((self.num_heads + 1) * self.out_channels, self.out_channels)

    def _create_update_gate_parameters_and_layers(self):
        self.conv_z = self._make_conv()
        self.linear_z = self._make_gate_linear()

    def _create_reset_gate_parameters_and_layers(self):
        self.conv_r = self._make_conv()
        self.linear_r = self._make_gate_linear()

    def _create_candidate_state_parameters_and_layers(self):
        self.conv_h = self._make_conv()
        self.linear_h = self._make_gate_linear()

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        """Return ``H``, or a zero hidden state on ``X``'s device when ``H`` is None."""
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], dim=1)
        Z = self.linear_z(Z)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([self.conv_r(X, edge_index, edge_weight), H], dim=1)
        R = self.linear_r(R)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([self.conv_h(X, edge_index, edge_weight), H * R], dim=1)
        H_tilde = self.linear_h(H_tilde)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        """Blend the previous state and the candidate state with the update gate."""
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. The edge weights are passed to the GATv2 layers
        as edge attributes. If the hidden state matrix is not present when the
        forward pass is called it is initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Float Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.

        Return types:
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H

# RGCN

In [8]:
import torch
import torch.nn.functional as F

class RecurrentGNN(torch.nn.Module):
    """Wrap a recurrent/graph cell with a linear read-out to one value per node.

    ``gnn`` is a class (not an instance); it is constructed with
    ``in_channels=node_features``, ``out_channels=hidden_size`` and any extra
    keyword arguments.
    """

    def __init__(self, gnn, node_features, hidden_size=32, **kwargs):
        super().__init__()
        self.recurrent = gnn(
            in_channels=node_features,
            out_channels=hidden_size,
            **kwargs,
        )
        self.linear = torch.nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, hidden):
        """Return ``(prediction, cell_output)``.

        The prediction is the linear read-out of the ReLU-activated cell
        output; the raw cell output is returned so it can be fed back in as
        the next hidden state.
        """
        cell_output = self.recurrent(x, edge_index, edge_weight, hidden)
        prediction = self.linear(F.relu(cell_output))
        return prediction, cell_output

In [9]:
import math

def rmse(y_hat, y):
    """Return the root mean squared error between two tensors as a Python float."""
    mean_squared_error = F.mse_loss(y_hat, y).item()
    return math.sqrt(mean_squared_error)

def mae(y_hat, y):
    """Return the mean absolute error between two tensors as a Python float."""
    mean_absolute_error = F.l1_loss(y_hat, y)
    return mean_absolute_error.item()

In [10]:
import wandb
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def eval(epoch, model, eval_dataset, eval_name):
    """Evaluate ``model`` on a list of snapshots and report RMSE / MAE.

    NOTE: this function shadows the built-in ``eval``; the name is kept
    because other cells call it.

    Args:
        epoch: Epoch number, used only for logging.
        model: Model exposing a ``recurrent`` sub-module; its class name
            selects whole-sequence or step-by-step inference.
        eval_dataset: Sized iterable of snapshots with ``x``, ``y``,
            ``edge_index`` and ``edge_attr``.
        eval_name: Prefix for the progress bar and the logged metric names.

    Returns:
        Tuple ``(rmse, mae)`` over all snapshots.
    """
    model.eval()
    # Sequence models consume the whole window at once; the others are
    # unrolled one time step at a time.
    whole_sequence = model.recurrent.__class__.__name__ in ['GRU', 'GraphTransformer']

    def process(snapshot):
        """Predict the value following the window held by ``snapshot``."""
        X = snapshot.x.to(device)
        # The graph is the same for every time step of the window, so it is
        # moved to the device once rather than on each step.
        edge_index = snapshot.edge_index.to(device)
        edge_attr = snapshot.edge_attr.to(device)
        if whole_sequence:
            batch_y_hats, _ = model(X.transpose(0, 1), edge_index, edge_attr, hidden=None)
            # Keep only the prediction made at the last time step.
            return batch_y_hats[:, -1]
        H = None
        for x in X:
            y_hat, H = model(x, edge_index, edge_attr, hidden=H)
        return y_hat

    with torch.no_grad():
        # `map` is lazy; `list` consumes it here, inside the no_grad context.
        y_hats = list(tqdm(map(process, eval_dataset), total=len(eval_dataset), desc=eval_name))
        ys = [snapshot.y.to(device) for snapshot in eval_dataset]
        y_hats, ys = torch.cat(y_hats, dim=0).squeeze().to(device), torch.cat(ys, dim=0).to(device)
        eval_rmse = rmse(y_hats, ys)
        eval_mae = mae(y_hats, ys)
        # Log only when a wandb run is active, so evaluation also works untracked.
        if wandb.run is not None:
            wandb.log({"epoch": epoch,
                    f"{eval_name}/rmse": eval_rmse,
                    f"{eval_name}/mae": eval_mae })
        print(f'Epoch {epoch}, {eval_name}/rmse: {eval_rmse}, {eval_name}/mae: {eval_mae}')
        return (eval_rmse, eval_mae)

In [11]:
from tqdm import tqdm
import wandb
from torch_geometric_temporal import DCRNN

# Architecture under test; RecurrentGNN instantiates this class itself.
gnn = TGCN

node_features = 1

model = RecurrentGNN(gnn = gnn, node_features = node_features, num_heads=1).to(device)

# Name used for the checkpoint file; it encodes the architecture and, where
# relevant, the correlation source and scope.
if gnn.__name__ == 'GRU':
    model_name = f'{gnn.__name__}'
elif gnn.__name__ == 'DCRNN':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}'
elif gnn.__name__ == 'TGCN':
    model_name = f'{gnn.__name__}_gat_{corr_name}_{corr_scope}'
elif gnn.__name__ == 'GraphTransformer':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}_{model.recurrent.attn_variant}'
else:
    # Fallback so the checkpoint path is always defined for new architectures.
    model_name = gnn.__name__

lr = 1e-2
num_epochs = 100
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
track_with_wandb = True

if track_with_wandb:
    wandb.init(project="cs224w-stock-market-prediction", config={
        "dataset": "S&P500",
        "corr_name": corr_name,
        "corr_scope": corr_scope,
        "learning_rate": lr,
        "epochs": num_epochs,
        "architecture": gnn.__name__,
    })

best_val_rmse = float('inf')

batch_size = 64
eval_per_epoch = 10

# Materialise the snapshot iterators once so every epoch reuses the same tensors.
train_samples = list(train_dataset)
val_samples = list(val_dataset)
test_samples = list(test_dataset)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for step, snapshot in tqdm(enumerate(train_samples), total=len(train_samples), desc=f'Epoch {epoch}'):
        X = snapshot.x.to(device)
        # The graph is shared by all time steps of the window: transfer it once.
        edge_index = snapshot.edge_index.to(device)
        edge_attr = snapshot.edge_attr.to(device)
        if model.recurrent.__class__.__name__ in ['GRU', 'GraphTransformer']:
            # Sequence models take the whole window in one call.
            y_hats, _ = model(X.transpose(0, 1), edge_index, edge_attr, hidden=None)
        else:
            # Recurrent cells are unrolled step by step, carrying the hidden state.
            H = None
            y_hats = []
            for x in X:
                y_hat, H = model(x, edge_index, edge_attr, hidden=H)
                y_hats.append(y_hat.squeeze())
            y_hats = torch.stack(y_hats, dim=1)
        loss = F.mse_loss(y_hats.squeeze(), snapshot.y.to(device))
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Only log when tracking is enabled; wandb.log needs an initialised run.
        if track_with_wandb:
            wandb.log({"epoch": epoch,
                       "step": step,
                        "train/loss": loss.item() })
    train_loss /= len(train_samples)

    # Evaluation stays tied to tracking, as before.
    if track_with_wandb and epoch % eval_per_epoch == 0:
        val_rmse, val_mae = eval(epoch, model, val_samples, 'val')
        if val_rmse < best_val_rmse:
            # Update the running best so that only improvements overwrite the
            # checkpoint. The original assigned to a different name, which
            # left the threshold at infinity.
            best_val_rmse = val_rmse
            torch.save(model.state_dict(), f'{model_name}.pth')


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: 100%|██████████| 15/15 [14:24<00:00, 57.64s/it]
val: 100%|██████████| 256/256 [43:11<00:00, 10.12s/it]


Epoch 0, val/rmse: 1.5127118616448685, val/mae: 1.061437964439392


Epoch 1: 100%|██████████| 15/15 [13:50<00:00, 55.39s/it]
Epoch 2: 100%|██████████| 15/15 [17:08<00:00, 68.58s/it]
Epoch 3: 100%|██████████| 15/15 [17:31<00:00, 70.09s/it]
Epoch 4: 100%|██████████| 15/15 [16:58<00:00, 67.90s/it]
Epoch 5: 100%|██████████| 15/15 [17:04<00:00, 68.27s/it]
Epoch 6: 100%|██████████| 15/15 [17:20<00:00, 69.40s/it]
Epoch 7: 100%|██████████| 15/15 [16:49<00:00, 67.32s/it]
Epoch 8: 100%|██████████| 15/15 [18:19<00:00, 73.29s/it]
Epoch 9: 100%|██████████| 15/15 [16:57<00:00, 67.81s/it]
Epoch 10: 100%|██████████| 15/15 [16:30<00:00, 66.04s/it]
val: 100%|██████████| 256/256 [44:08<00:00, 10.35s/it]


Epoch 10, val/rmse: 1.080539101429334, val/mae: 0.6324087977409363


Epoch 11: 100%|██████████| 15/15 [15:33<00:00, 62.23s/it]
Epoch 12: 100%|██████████| 15/15 [16:58<00:00, 67.88s/it]
Epoch 13: 100%|██████████| 15/15 [17:17<00:00, 69.19s/it]
Epoch 14: 100%|██████████| 15/15 [17:46<00:00, 71.13s/it]
Epoch 15: 100%|██████████| 15/15 [16:51<00:00, 67.41s/it]
Epoch 16: 100%|██████████| 15/15 [18:08<00:00, 72.56s/it]
Epoch 17: 100%|██████████| 15/15 [18:03<00:00, 72.21s/it]
Epoch 18: 100%|██████████| 15/15 [19:46<00:00, 79.08s/it]
Epoch 19: 100%|██████████| 15/15 [36:48<00:00, 147.23s/it]
Epoch 20:   7%|▋         | 1/15 [01:17<18:02, 77.29s/it]wandb-core(88410) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(88457) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(88477) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(88492) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(88

In [13]:
# Final test evaluation. A checkpoint is only written when tracking is
# enabled, so this step is guarded by the same flag.
if track_with_wandb:
    # Rebuild the architecture with the same constructor arguments as the trained model.
    best_model = RecurrentGNN(gnn = gnn, node_features = node_features, num_heads=1).to(device)
    # weights_only=True restricts unpickling to tensor data.
    best_model.load_state_dict(torch.load(f'{model_name}.pth', weights_only=True))
    # `epoch` still holds the last value from the training loop; it is used for logging only.
    test_rmse, test_mae = eval(epoch, best_model, test_samples, 'test')


test:   0%|          | 0/256 [00:00<?, ?it/s]

test: 100%|██████████| 256/256 [00:07<00:00, 33.98it/s]


Epoch 99, test/rmse: 0.3686766559592353, test/mae: 0.13227011263370514
