# Dataset Construction

In [1]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignal, DynamicGraphTemporalSignal
import torch
from typing import Union
import glob
from natsort import natsorted
import random

random.seed(42)

class SP500CorrelationsDatasetLoader(object):
    def __init__(self, corr_name, corr_scope):
        self._read_csv(corr_name, corr_scope)

    def _read_csv(self, corr_name, corr_scope):
        match corr_scope:
            case 'global':
                self._correlation_matrices = [np.loadtxt(f'{corr_name}/{corr_scope}_corr.csv', delimiter=',')]
            case 'local':
                self._correlation_matrices = []
                corr_files = natsorted(glob.glob(f'{corr_name}/local_corr_*.csv'))
                for corr_file in corr_files:
                    matrix = np.loadtxt(corr_file, delimiter=',')
                    self._correlation_matrices.append(matrix)
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)

        # Round data size to nearest multiple of batch_size
        self.days_in_quarter = 64
        num_quarters = data.size(0) // self.days_in_quarter
        num_days = num_quarters * self.days_in_quarter
        data = data[:num_days]
        
        # z-score normalization with training data following GERU
        train_days = int(0.8 * num_quarters) * self.days_in_quarter
        data = (data - data[:train_days].mean(dim=0)) / data[:train_days].std(dim=0)
        data = data.numpy()

        # Add percent change features
        p_chg = data / np.roll(data, 1, axis=0) - 1
        p_chg[0] = 0.0
        p_chg_3 = data / np.roll(data, 3, axis=0) - 1
        p_chg_3[0:3] = 0.0
        p_chg_6 = data / np.roll(data, 6, axis=0) - 1
        p_chg_6[0:6] = 0.0

        data = np.stack([data, p_chg, p_chg_3, p_chg_6], axis=-1)
        print('data.shape', data.shape)

        assert(not np.any(np.isnan(data)))
        self._dataset = data

    def _get_edges(self):
        if len(self._correlation_matrices) == 1:
            self._edges = np.array(np.ones_like(self._correlation_matrices[0]).nonzero())
        else:
            self._edges = []
            for time in range(self._dataset.shape[0] - self.lags):
                corr_index = max(0, time // self.days_in_quarter - 1)
                self._edges.append(
                    np.array(np.ones_like(self._correlation_matrices[corr_index]).nonzero())
                )

    def _get_edge_weights(self):
        if len(self._correlation_matrices) == 1:
            self._edge_weights = self._correlation_matrices[0].flatten()
        else:
            self._edge_weights = []
            for time in range(self._dataset.shape[0] - self.lags):
                corr_index = max(0, time // self.days_in_quarter - 1)
                self._edge_weights.append(
                    np.array(self._correlation_matrices[corr_index]).flatten()
                )

    def _get_targets_and_features(self):
        stacked_target = self._dataset
        # print(stacked_target.shape)
        self.features = [
            stacked_target[i : i + self.lags, :]
            for i in range(stacked_target.shape[0] - self.lags)
        ]
        # predict next-day stock movement
        self.targets = [
            (stacked_target[i + self.lags, :, 0]).T
            for i in range(stacked_target.shape[0] - self.lags)
        ]

    def get_dataset(self, lags) -> Union[StaticGraphTemporalSignal, DynamicGraphTemporalSignal]:
        """Returning the data iterator.
        """
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = (DynamicGraphTemporalSignal if type(self._edges) == list else StaticGraphTemporalSignal)(
            self._edges, self._edge_weights, self.features, self.targets
        )
        return dataset

In [2]:
from torch_geometric_temporal.signal import temporal_signal_split

# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = 'cpu'

corr_name = 'mi'
corr_scope = 'global'
loader = SP500CorrelationsDatasetLoader(corr_name=corr_name, corr_scope=corr_scope)

lags = 6
dataset = loader.get_dataset(lags)

train_dataset, test_val_dataset = temporal_signal_split(dataset, train_ratio=0.8)
val_dataset, test_dataset = temporal_signal_split(test_val_dataset, train_ratio=0.5)

data.shape (2496, 472, 4)


In [3]:
print(train_dataset.features[0].shape)
print(train_dataset.targets[0].shape)

(6, 472, 4)
(472,)


# Differential Graph Transformer

In [4]:
import torch
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn.conv import MessagePassing
import torch.nn as nn
import torch.nn.functional as F


class DGAttn(MessagePassing):
    r"""An implementation of the Diffusion Convolution Layer.
    For details see: `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Filter size :math:`K`.
        bias (bool, optional): If set to :obj:`False`, the layer
            will not learn an additive bias (default :obj:`True`).

    """

    def __init__(self, in_channels, out_channels, num_heads=1, diff_attn=True):
        super(DGAttn, self).__init__(aggr="add", flow="source_to_target")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.head_dim = out_channels // self.num_heads
        self.diff_attn = diff_attn
        if self.diff_attn:
            self.Q = nn.Linear(in_channels, out_channels)
            self.K = nn.Linear(in_channels, out_channels)
        self.V = nn.Linear(in_channels, out_channels)
        self.ffn = nn.Linear(out_channels, out_channels)
        self.ln = nn.RMSNorm(self.head_dim, eps=1e-5, elementwise_affine=True)

        if self.diff_attn:
            self.lambda_init = 0.2
            self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

        # self.__reset_parameters()

    # def __reset_parameters(self):
        # torch.nn.init.xavier_uniform_(self.V)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor,
    ) -> torch.FloatTensor:
        r"""Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional) - Edge weight vector.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        A = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
        a1 = F.softmax(A, dim=-1)
        if self.diff_attn:
            q = self.Q(X)
            k = self.K(X)
            a2 = F.softmax(torch.matmul(q, k.transpose(-1, -2)) * (self.head_dim ** -0.5), dim=-1)
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
            lambda_full = lambda_1 - lambda_2 + self.lambda_init
            a = a1 - lambda_full * a2
        else:
            a = a1
        v = self.V(X)
        attn = torch.matmul(a, v)
        attn = self.ln(attn)
        if self.diff_attn:
            attn = attn * (1 - self.lambda_init)
        attn = self.ffn(attn)
        return attn


class DGRNN(torch.nn.Module):
    r"""An implementation of the Diffusion Convolutional Gated Recurrent Unit.
    For details see: `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        num_heads (int): Number of attention heads.
    """

    def __init__(self, in_channels: int, out_channels: int, num_heads: int, diff_attn: bool):
        super(DGRNN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.diff_attn = diff_attn,

        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):
        self.conv_x_z = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_reset_gate_parameters_and_layers(self):
        self.conv_x_r = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_candidate_state_parameters_and_layers(self):
        self.conv_x_h = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([X, H], dim=1)
        Z = self.conv_x_z(Z, edge_index, edge_weight)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([X, H], dim=1)
        R = self.conv_x_r(R, edge_index, edge_weight)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([X, H * R], dim=1)
        H_tilde = self.conv_x_h(H_tilde, edge_index, edge_weight)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        r"""Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional) - Edge weight vector.
            * **H** (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        # print('X.shape', X.shape)
        # print('H.shape', H.shape)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H


# Plain RNN

In [5]:
class GRU(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(GRU, self).__init__()
        self.rnn = nn.GRU(input_size=in_channels, hidden_size=out_channels, num_layers=2, batch_first=True)

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        outputs, H = self.rnn(X, H)
        return H[-1]

# RGCN

In [6]:
import torch
import torch.nn.functional as F

class RecurrentGNN(torch.nn.Module):
    def __init__(self, gnn, node_features, hidden_size=32, **kwargs):
        super(RecurrentGNN, self).__init__()
        if gnn.__name__ == "DCRNN":
            self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)
        elif gnn.__name__ == "DGRNN":
            self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)
        elif gnn.__name__ == "GRU":
            self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)
        else:
            raise ValueError(f"Invalid GNN type {gnn.__name__}")
        self.linear = torch.nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, hidden):
        h0 = self.recurrent(x, edge_index, edge_weight, hidden)
        h = F.relu(h0)
        h = self.linear(h)
        return (h, h0)

In [7]:
import math

def rmse(y_hat, y):
    return math.sqrt(F.mse_loss(y_hat, y).item())

def mae(y_hat, y):
    return F.l1_loss(y_hat, y).item()

In [8]:
import wandb

def eval(epoch, model, eval_dataset, eval_name):
    model.eval()
    with torch.no_grad():
        y_hats = []
        ys = []
        for snapshot in eval_dataset:
            X = snapshot.x.to(device)
            h = None
            if model.recurrent.__class__.__name__ == "GRU":
                y_hat, _ = model(X.transpose(0, 1), snapshot.edge_index.to(device), snapshot.edge_attr.to(device), h)
            else:
                for x in X:
                    y_hat, h = model(x, snapshot.edge_index.to(device), snapshot.edge_attr.to(device), h)
            y_hats.append(y_hat)
            ys.append(snapshot.y.to(device))
        y_hats, ys = torch.stack(list(y_hats)).squeeze().to(device), torch.stack(list(ys)).to(device)
        eval_rmse = rmse(y_hats, ys)
        eval_mae = mae(y_hats, ys)
        wandb.log({"epoch": epoch,
                f"{eval_name}/rmse": eval_rmse,
                f"{eval_name}/mae": eval_mae })
        print(f'Epoch {epoch}, {eval_name}/rmse: {eval_rmse}, {eval_name}/mae: {eval_mae}')
        return (eval_rmse, eval_mae)

In [9]:
from tqdm import tqdm
import wandb
from torch_geometric_temporal.nn.recurrent import DCRNN

gnn = GRU

model = RecurrentGNN(gnn = gnn, node_features = 4).to(device)

if gnn.__name__ == 'GRU':
    model_name = f'{gnn.__name__}'
elif gnn.__name__ == 'DCRNN':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}'
elif gnn.__name__ == 'DGRNN':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}{"_diff_attn" if model.recurrent.diff_attn else ""}'

lr = 1e-3
num_epochs = 50
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
track_with_wandb = True

if track_with_wandb:
    wandb.init(project="cs224w-stock-market-prediction", config={
        "dataset": "S&P500",
        "corr_name": corr_name,
        "corr_scope": corr_scope,
        "learning_rate": lr,
        "epochs": num_epochs,
        "architecture": gnn.__name__,
    })

best_val_rmse = float('inf')

batch_size = 64

def split_into_batches(lst, batch_size):
    return [lst[i:i + batch_size] for i in range(0, len(lst), batch_size)]

train_samples = list(train_dataset)
train_batches = split_into_batches(train_samples, batch_size)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    random.shuffle(train_batches)
    for step, batch in tqdm(enumerate(train_batches), total=len(train_batches), desc=f'Epoch {epoch}'):
        y_hats = []
        ys = []
        for snapshot in batch:
            X = snapshot.x.to(device)
            h = None
            if model.recurrent.__class__.__name__ == "GRU":
                y_hat, _ = model(X.transpose(0, 1), snapshot.edge_index.to(device), snapshot.edge_attr.to(device), h)
            else:
                for x in X:
                    y_hat, h = model(x, snapshot.edge_index.to(device), snapshot.edge_attr.to(device), h)
            y_hats.append(y_hat)
            ys.append(snapshot.y.to(device))
        y_hats, ys = torch.stack(y_hats).squeeze(), torch.stack(ys)
        loss = F.mse_loss(y_hats, ys)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        wandb.log({"epoch": epoch,
                   "step": epoch * len(train_samples) + step,
                    "train/loss": loss.item() })
    train_loss /= len(train_samples)

    if track_with_wandb:
        val_rmse, val_mae = eval(epoch, model, val_dataset, 'val')
        if val_rmse < best_val_rmse:
            best_rmse = val_rmse
            torch.save(model.state_dict(), f'{model_name}.pth')


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: 100%|██████████| 32/32 [00:15<00:00,  2.09it/s]


Epoch 0, val/rmse: 1.5263910078165706, val/mae: 1.0989220142364502


Epoch 1: 100%|██████████| 32/32 [00:15<00:00,  2.04it/s]


Epoch 1, val/rmse: 1.1081476541216275, val/mae: 0.5569794774055481


Epoch 2: 100%|██████████| 32/32 [00:16<00:00,  2.00it/s]


Epoch 2, val/rmse: 1.037747978541152, val/mae: 0.4860028624534607


Epoch 3: 100%|██████████| 32/32 [00:16<00:00,  1.97it/s]


Epoch 3, val/rmse: 1.0014471669080778, val/mae: 0.43980199098587036


Epoch 4: 100%|██████████| 32/32 [00:16<00:00,  1.92it/s]


Epoch 4, val/rmse: 0.9761763457722158, val/mae: 0.4146272540092468


Epoch 5: 100%|██████████| 32/32 [00:15<00:00,  2.01it/s]


Epoch 5, val/rmse: 0.9475856279819469, val/mae: 0.3821941018104553


Epoch 6: 100%|██████████| 32/32 [00:16<00:00,  1.94it/s]


Epoch 6, val/rmse: 0.9195391269758066, val/mae: 0.35318058729171753


Epoch 7: 100%|██████████| 32/32 [00:16<00:00,  1.98it/s]


Epoch 7, val/rmse: 0.8990149591026663, val/mae: 0.34921517968177795


Epoch 8: 100%|██████████| 32/32 [00:16<00:00,  1.94it/s]


Epoch 8, val/rmse: 0.8856155994807454, val/mae: 0.3115306794643402


Epoch 9: 100%|██████████| 32/32 [00:17<00:00,  1.83it/s]


Epoch 9, val/rmse: 0.868349482470958, val/mae: 0.29591983556747437


Epoch 10: 100%|██████████| 32/32 [00:16<00:00,  1.92it/s]


Epoch 10, val/rmse: 0.8593699368414413, val/mae: 0.28167763352394104


Epoch 11: 100%|██████████| 32/32 [00:16<00:00,  1.95it/s]


Epoch 11, val/rmse: 0.8482982120725764, val/mae: 0.2676037847995758


Epoch 12: 100%|██████████| 32/32 [00:17<00:00,  1.82it/s]


Epoch 12, val/rmse: 0.8389395881290931, val/mae: 0.25787636637687683


Epoch 13: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s]


Epoch 13, val/rmse: 0.8257770361382294, val/mae: 0.24264487624168396


Epoch 14: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s]


Epoch 14, val/rmse: 0.8179161674263496, val/mae: 0.2353849858045578


Epoch 15: 100%|██████████| 32/32 [00:18<00:00,  1.76it/s]


Epoch 15, val/rmse: 0.8118694132553184, val/mae: 0.22822602093219757


Epoch 16: 100%|██████████| 32/32 [00:16<00:00,  1.89it/s]


Epoch 16, val/rmse: 0.8047864538292688, val/mae: 0.22047922015190125


Epoch 17: 100%|██████████| 32/32 [00:17<00:00,  1.81it/s]


Epoch 17, val/rmse: 0.7975381821726879, val/mae: 0.21484145522117615


Epoch 18: 100%|██████████| 32/32 [00:17<00:00,  1.88it/s]


Epoch 18, val/rmse: 0.794705110655562, val/mae: 0.2104140669107437


Epoch 19: 100%|██████████| 32/32 [00:22<00:00,  1.44it/s]


Epoch 19, val/rmse: 0.7882887190586862, val/mae: 0.20334938168525696


Epoch 20: 100%|██████████| 32/32 [00:21<00:00,  1.48it/s]


Epoch 20, val/rmse: 0.784823750554938, val/mae: 0.23194439709186554


Epoch 21: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s]


Epoch 21, val/rmse: 0.7956290134229818, val/mae: 0.2555084824562073


Epoch 22: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s]


Epoch 22, val/rmse: 0.7760413295736583, val/mae: 0.19505389034748077


Epoch 23: 100%|██████████| 32/32 [00:21<00:00,  1.51it/s]


Epoch 23, val/rmse: 0.7769099803239807, val/mae: 0.19751904904842377


Epoch 24: 100%|██████████| 32/32 [00:20<00:00,  1.53it/s]


Epoch 24, val/rmse: 0.7698561084360038, val/mae: 0.1848697066307068


Epoch 25: 100%|██████████| 32/32 [00:19<00:00,  1.66it/s]


Epoch 25, val/rmse: 0.7667474887235877, val/mae: 0.18205635249614716


Epoch 26: 100%|██████████| 32/32 [00:20<00:00,  1.57it/s]


Epoch 26, val/rmse: 0.7634947978818649, val/mae: 0.17923355102539062


Epoch 27: 100%|██████████| 32/32 [00:21<00:00,  1.47it/s]


Epoch 27, val/rmse: 0.7602492413778889, val/mae: 0.17567890882492065


Epoch 28: 100%|██████████| 32/32 [00:21<00:00,  1.52it/s]


Epoch 28, val/rmse: 0.7586980934550509, val/mae: 0.17511647939682007


Epoch 29: 100%|██████████| 32/32 [00:18<00:00,  1.69it/s]


Epoch 29, val/rmse: 0.7549866361887934, val/mae: 0.170461043715477


Epoch 30: 100%|██████████| 32/32 [00:19<00:00,  1.64it/s]


Epoch 30, val/rmse: 0.7532897563233281, val/mae: 0.169327512383461


Epoch 31: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s]


Epoch 31, val/rmse: 0.7505095261077863, val/mae: 0.16687002778053284


Epoch 32: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s]


Epoch 32, val/rmse: 0.7468911348778758, val/mae: 0.16333219408988953


Epoch 33: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s]


Epoch 33, val/rmse: 0.7457854747512689, val/mae: 0.1621895581483841


Epoch 34: 100%|██████████| 32/32 [00:19<00:00,  1.65it/s]


Epoch 34, val/rmse: 0.7429262199222652, val/mae: 0.15959322452545166


Epoch 35: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s]


Epoch 35, val/rmse: 0.7415178590785855, val/mae: 0.1585560441017151


Epoch 36: 100%|██████████| 32/32 [00:22<00:00,  1.45it/s]


Epoch 36, val/rmse: 0.7399706609812355, val/mae: 0.15878956019878387


Epoch 37: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s]


Epoch 37, val/rmse: 0.7364594117354105, val/mae: 0.15650033950805664


Epoch 38: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s]


Epoch 38, val/rmse: 0.7348962516612253, val/mae: 0.15301468968391418


Epoch 39: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s]


Epoch 39, val/rmse: 0.7357229888007085, val/mae: 0.15812242031097412


Epoch 40: 100%|██████████| 32/32 [00:20<00:00,  1.59it/s]


Epoch 40, val/rmse: 0.7920170226242532, val/mae: 0.317247211933136


Epoch 41: 100%|██████████| 32/32 [00:19<00:00,  1.63it/s]


Epoch 41, val/rmse: 0.7327376435200134, val/mae: 0.1587306261062622


Epoch 42: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s]


Epoch 42, val/rmse: 0.7307659095522664, val/mae: 0.1512630581855774


Epoch 43: 100%|██████████| 32/32 [00:20<00:00,  1.56it/s]


Epoch 43, val/rmse: 0.7287065598380134, val/mae: 0.1479777842760086


Epoch 44: 100%|██████████| 32/32 [00:19<00:00,  1.61it/s]


Epoch 44, val/rmse: 0.7268871689464168, val/mae: 0.14604264497756958


Epoch 45: 100%|██████████| 32/32 [00:19<00:00,  1.62it/s]


Epoch 45, val/rmse: 0.7261013495398693, val/mae: 0.14695191383361816


Epoch 46: 100%|██████████| 32/32 [00:18<00:00,  1.70it/s]


Epoch 46, val/rmse: 0.7245590727175599, val/mae: 0.14558178186416626


Epoch 47: 100%|██████████| 32/32 [00:19<00:00,  1.60it/s]


Epoch 47, val/rmse: 0.7225406863626448, val/mae: 0.14256401360034943


Epoch 48: 100%|██████████| 32/32 [00:19<00:00,  1.67it/s]


Epoch 48, val/rmse: 0.720767093290359, val/mae: 0.14812614023685455


Epoch 49: 100%|██████████| 32/32 [00:20<00:00,  1.60it/s]


Epoch 49, val/rmse: 0.7201615860208117, val/mae: 0.14087459444999695


In [10]:
if track_with_wandb:
    best_model = RecurrentGNN(gnn = gnn, node_features = 4).to(device)
    best_model.load_state_dict(torch.load(f'{model_name}.pth', weights_only=True))
    test_rmse, test_mae = eval(epoch, best_model, test_dataset, 'test')


Epoch 49, test/rmse: 3.268968686434329, test/mae: 0.5551874041557312
