# Dataset Construction

In [4]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignal, DynamicGraphTemporalSignal
import torch
from typing import Union
import glob
from natsort import natsorted
import random

random.seed(42)

class SP500CorrelationsDatasetLoader(object):
    def __init__(self, corr_name, corr_scope):
        self._read_csv(corr_name, corr_scope)

    def _read_csv(self, corr_name, corr_scope):
        match corr_scope:
            case 'global':
                self._correlation_matrices = [np.loadtxt(f'{corr_name}/{corr_scope}_corr.csv', delimiter=',')]
            case 'local':
                self._correlation_matrices = []
                corr_files = natsorted(glob.glob(f'{corr_name}/local_corr_*.csv'))
                for corr_file in corr_files:
                    matrix = np.loadtxt(corr_file, delimiter=',')
                    self._correlation_matrices.append(matrix)
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)

        # Round data size to nearest multiple of batch_size
        self.days_in_quarter = 64
        num_quarters = data.size(0) // self.days_in_quarter
        num_days = num_quarters * self.days_in_quarter
        data = data[:num_days]
        
        # z-score normalization with training data following GERU
        train_days = int(0.8 * num_quarters) * self.days_in_quarter
        data = (data - data[:train_days].mean(dim=0)) / data[:train_days].std(dim=0)
        data = data.numpy()

        data = data[..., np.newaxis]

        # # Add percent change features
        # p_chg = data / np.roll(data, 1, axis=0) - 1
        # p_chg[0] = 0.0
        # p_chg_3 = data / np.roll(data, 3, axis=0) - 1
        # p_chg_3[0:3] = 0.0
        # p_chg_6 = data / np.roll(data, 6, axis=0) - 1
        # p_chg_6[0:6] = 0.0

        # data = np.stack([data, p_chg, p_chg_3, p_chg_6], axis=-1)
        # print('data.shape', data.shape)

        assert(not np.any(np.isnan(data)))
        self._dataset = data

    def _get_edges(self, times, overlap):
        if len(self._correlation_matrices) == 1:
            _edges = np.array(np.ones_like(self._correlation_matrices[0]).nonzero())
        else:
            _edges = []
            for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):
                if not time in times:
                    continue
                corr_index = max(0, time // self.days_in_quarter - 1)
                _edges.append(
                    np.array(np.ones_like(self._correlation_matrices[corr_index]).nonzero())
                )
        return _edges

    def _get_edge_weights(self, times, overlap):
        if len(self._correlation_matrices) == 1:
            _edge_weights = self._correlation_matrices[0].flatten()
        else:
            _edge_weights = []
            for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):
                if not time in times:
                    continue
                corr_index = max(0, time // self.days_in_quarter - 1)
                _edge_weights.append(
                    np.array(self._correlation_matrices[corr_index]).flatten()
                )
        return _edge_weights

    def _get_targets_and_features(self, times, overlap, predict_all):
        features = [
            self._dataset[i : i + self.batch_size, :]
            for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)
            if i in times
        ]
        # predict next-day stock prices
        targets = [
            (self._dataset[i+1 : i + self.batch_size+1, :, 0]).T if predict_all else (self._dataset[i + self.batch_size, :, 0]).T
            for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)
            if i in times
        ]
        return features, targets

    def get_dataset(self, batch_size, split) -> Union[StaticGraphTemporalSignal, DynamicGraphTemporalSignal]:
        """Returning the data iterator.
        """
        self.batch_size = batch_size

        total_times = list(range(0, self._dataset.shape[0] - self.batch_size, self.batch_size))

        if split == 'train':
            times = list(range(total_times[int(len(total_times) * 0.6)], total_times[int(len(total_times) * 0.8)]))
            overlap = self.batch_size
            predict_all = True
        elif split == 'val':
            times = list(range(total_times[int(len(total_times) * 0.8)], total_times[int(len(total_times) * 0.85)]))
            overlap = 1
            predict_all = False
        elif split == 'test':
            times = list(range(total_times[int(len(total_times) * 0.85)], total_times[int(len(total_times) * 0.9)]))
            overlap = 1
            predict_all = False
        else:
            raise ValueError(f'Invalid split name: {split}')

        _edges = self._get_edges(times, overlap)
        _edge_weights = self._get_edge_weights(times, overlap)
        features, targets = self._get_targets_and_features(times, overlap, predict_all)
        dataset = (DynamicGraphTemporalSignal if type(_edges) == list else StaticGraphTemporalSignal)(
            _edges, _edge_weights, features, targets
        )
        return dataset

In [5]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = 'cpu'

corr_name = 'mi'
corr_scope = 'global'
loader = SP500CorrelationsDatasetLoader(corr_name=corr_name, corr_scope=corr_scope)

lag_size = 64
train_dataset = loader.get_dataset(batch_size=lag_size * 2, split='train')
val_dataset = loader.get_dataset(batch_size=lag_size, split='val')
test_dataset = loader.get_dataset(batch_size=lag_size, split='test')

In [6]:
print(len(train_dataset.features))
print(len(train_dataset.targets))
print(len(val_dataset.features))
print(len(val_dataset.targets))
print(len(test_dataset.features))
print(len(test_dataset.targets))

4
4
128
128
128
128


# Differential Graph Transformer

In [7]:
import torch
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn.conv import MessagePassing
import torch.nn as nn
import torch.nn.functional as F


class DGAttn(MessagePassing):
    r"""An implementation of the Diffusion Convolution Layer.
    For details see: `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Filter size :math:`K`.
        bias (bool, optional): If set to :obj:`False`, the layer
            will not learn an additive bias (default :obj:`True`).

    """

    def __init__(self, in_channels, out_channels, num_heads=1, diff_attn=True):
        super(DGAttn, self).__init__(aggr="add", flow="source_to_target")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.head_dim = out_channels // self.num_heads
        self.diff_attn = diff_attn
        if self.diff_attn:
            self.Q = nn.Linear(in_channels, out_channels)
            self.K = nn.Linear(in_channels, out_channels)
        self.V = nn.Linear(in_channels, out_channels)
        self.ffn = nn.Linear(out_channels, out_channels)
        self.ln = nn.RMSNorm(self.head_dim, eps=1e-5, elementwise_affine=True)

        if self.diff_attn:
            self.lambda_init = 0.2
            self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

        # self.__reset_parameters()

    # def __reset_parameters(self):
        # torch.nn.init.xavier_uniform_(self.V)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor,
    ) -> torch.FloatTensor:
        r"""Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional) - Edge weight vector.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        A = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
        a1 = F.softmax(A, dim=-1)
        if self.diff_attn:
            q = self.Q(X)
            k = self.K(X)
            a2 = F.softmax(torch.matmul(q, k.transpose(-1, -2)) * (self.head_dim ** -0.5), dim=-1)
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
            lambda_full = lambda_1 - lambda_2 + self.lambda_init
            a = a1 - lambda_full * a2
        else:
            a = a1
        v = self.V(X)
        attn = torch.matmul(a, v)
        attn = self.ln(attn)
        if self.diff_attn:
            attn = attn * (1 - self.lambda_init)
        attn = self.ffn(attn)
        return attn


class DGRNN(torch.nn.Module):
    r"""An implementation of the Diffusion Convolutional Gated Recurrent Unit.
    For details see: `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        num_heads (int): Number of attention heads.
    """

    def __init__(self, in_channels: int, out_channels: int, num_heads: int, diff_attn: bool):
        super(DGRNN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.diff_attn = diff_attn,

        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):
        self.conv_x_z = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_reset_gate_parameters_and_layers(self):
        self.conv_x_r = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_candidate_state_parameters_and_layers(self):
        self.conv_x_h = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([X, H], dim=1)
        Z = self.conv_x_z(Z, edge_index, edge_weight)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([X, H], dim=1)
        R = self.conv_x_r(R, edge_index, edge_weight)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([X, H * R], dim=1)
        H_tilde = self.conv_x_h(H_tilde, edge_index, edge_weight)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        r"""Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional) - Edge weight vector.
            * **H** (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H


# Plain RNN

In [8]:
class GRU(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(GRU, self).__init__()
        self.rnn = nn.GRU(input_size=in_channels, hidden_size=out_channels, num_layers=2, batch_first=True)

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        outputs, _ = self.rnn(X, H)
        return outputs

# RGCN

In [9]:
import torch
import torch.nn.functional as F

class RecurrentGNN(torch.nn.Module):
    def __init__(self, gnn, node_features, hidden_size=32, **kwargs):
        super(RecurrentGNN, self).__init__()
        self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)
        self.linear = torch.nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, hidden):
        outputs = self.recurrent(x, edge_index, edge_weight, hidden)
        return self.linear(F.relu(outputs)), outputs

In [10]:
import math

def rmse(y_hat, y):
    return math.sqrt(F.mse_loss(y_hat, y).item())

def mae(y_hat, y):
    return F.l1_loss(y_hat, y).item()

In [11]:
import wandb
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def eval(epoch, model, eval_dataset, eval_name):
    model.eval()
    with torch.no_grad():
        def process(snapshot):
            X = snapshot.x.to(device)
            if model.recurrent.__class__.__name__ == 'GRU':
                batch_y_hats, _ = model(X.transpose(0, 1), snapshot.edge_index.to(device), snapshot.edge_attr.to(device), hidden=None)
                return batch_y_hats[:, -1]
            else:
                H = None
                for x in X:
                    y_hat, H = model(x, snapshot.edge_index.to(device), snapshot.edge_attr.to(device), hidden=H)
                return y_hat
        with ThreadPoolExecutor() as executor:
            y_hats = list(tqdm(executor.map(process, eval_dataset), total=len(eval_dataset), desc=eval_name))
        ys = [snapshot.y.to(device) for snapshot in eval_dataset]
        y_hats, ys = torch.cat(y_hats, dim=0).squeeze().to(device), torch.cat(ys, dim=0).to(device)
        eval_rmse = rmse(y_hats, ys)
        eval_mae = mae(y_hats, ys)
        wandb.log({"epoch": epoch,
                f"{eval_name}/rmse": eval_rmse,
                f"{eval_name}/mae": eval_mae })
        print(f'Epoch {epoch}, {eval_name}/rmse: {eval_rmse}, {eval_name}/mae: {eval_mae}')
        return (eval_rmse, eval_mae)

In [12]:
from tqdm import tqdm
import wandb
from torch_geometric_temporal import DCRNN

gnn = GRU

node_features = 1

model = RecurrentGNN(gnn = gnn, node_features = node_features).to(device)

if gnn.__name__ == 'GRU':
    model_name = f'{gnn.__name__}'
elif gnn.__name__ == 'DCRNN':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}'
elif gnn.__name__ == 'DGRNN':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}{"_diff_attn" if model.recurrent.diff_attn else ""}'

lr = 1e-2
num_epochs = 100
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
track_with_wandb = True

if track_with_wandb:
    wandb.init(project="cs224w-stock-market-prediction", config={
        "dataset": "S&P500",
        "corr_name": corr_name,
        "corr_scope": corr_scope,
        "learning_rate": lr,
        "epochs": num_epochs,
        "architecture": gnn.__name__,
    })

best_val_rmse = float('inf')

batch_size = 64
eval_per_epoch = 10

train_samples = list(train_dataset)
val_samples = list(val_dataset)
test_samples = list(test_dataset)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for step, snapshot in tqdm(enumerate(train_samples), total=len(train_samples), desc=f'Epoch {epoch}'):
        X = snapshot.x.to(device)
        if model.recurrent.__class__.__name__ == 'GRU':
            y_hats, _ = model(X.transpose(0, 1), snapshot.edge_index.to(device), snapshot.edge_attr.to(device), hidden=None)
        else:
            H = None
            y_hats = []
            for x in X:
                y_hat, H = model(x, snapshot.edge_index.to(device), snapshot.edge_attr.to(device), hidden=H)
                y_hats.append(y_hat.squeeze())
            y_hats = torch.stack(y_hats, dim=1)
        loss = F.mse_loss(y_hats.squeeze(), snapshot.y.to(device))
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        wandb.log({"epoch": epoch,
                   "step": step,
                    "train/loss": loss.item() })
    train_loss /= len(train_samples)

    if track_with_wandb and epoch % eval_per_epoch == 0:
        val_rmse, val_mae = eval(epoch, model, val_samples, 'val')
        if val_rmse < best_val_rmse:
            best_rmse = val_rmse
            torch.save(model.state_dict(), f'{model_name}.pth')


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: 100%|██████████| 4/4 [00:00<00:00,  5.24it/s]
val: 100%|██████████| 128/128 [00:01<00:00, 83.77it/s]


Epoch 0, val/rmse: 0.8863358195926226, val/mae: 0.6591887474060059


Epoch 1: 100%|██████████| 4/4 [00:00<00:00,  6.64it/s]
Epoch 2: 100%|██████████| 4/4 [00:00<00:00,  6.62it/s]
Epoch 3: 100%|██████████| 4/4 [00:00<00:00,  5.51it/s]
Epoch 4: 100%|██████████| 4/4 [00:00<00:00,  6.47it/s]
Epoch 5: 100%|██████████| 4/4 [00:00<00:00,  6.31it/s]
Epoch 6: 100%|██████████| 4/4 [00:00<00:00,  6.84it/s]
Epoch 7: 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
Epoch 8: 100%|██████████| 4/4 [00:00<00:00,  6.93it/s]
Epoch 9: 100%|██████████| 4/4 [00:00<00:00,  6.66it/s]
Epoch 10: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]
val: 100%|██████████| 128/128 [00:00<00:00, 131.57it/s]


Epoch 10, val/rmse: 0.3345378728618125, val/mae: 0.15500696003437042


Epoch 11: 100%|██████████| 4/4 [00:00<00:00,  6.32it/s]
Epoch 12: 100%|██████████| 4/4 [00:00<00:00,  5.64it/s]
Epoch 13: 100%|██████████| 4/4 [00:00<00:00,  5.81it/s]
Epoch 14: 100%|██████████| 4/4 [00:00<00:00,  6.35it/s]
Epoch 15: 100%|██████████| 4/4 [00:00<00:00,  6.52it/s]
Epoch 16: 100%|██████████| 4/4 [00:00<00:00,  6.71it/s]
Epoch 17: 100%|██████████| 4/4 [00:00<00:00,  6.88it/s]
Epoch 18: 100%|██████████| 4/4 [00:00<00:00,  6.56it/s]
Epoch 19: 100%|██████████| 4/4 [00:00<00:00,  6.20it/s]
Epoch 20: 100%|██████████| 4/4 [00:00<00:00,  6.10it/s]
val: 100%|██████████| 128/128 [00:00<00:00, 136.17it/s]


Epoch 20, val/rmse: 0.2824798424247039, val/mae: 0.1358330249786377


Epoch 21: 100%|██████████| 4/4 [00:00<00:00,  6.61it/s]
Epoch 22: 100%|██████████| 4/4 [00:00<00:00,  6.61it/s]
Epoch 23: 100%|██████████| 4/4 [00:00<00:00,  6.23it/s]
Epoch 24: 100%|██████████| 4/4 [00:00<00:00,  6.50it/s]
Epoch 25: 100%|██████████| 4/4 [00:00<00:00,  6.81it/s]
Epoch 26: 100%|██████████| 4/4 [00:00<00:00,  6.03it/s]
Epoch 27: 100%|██████████| 4/4 [00:00<00:00,  6.54it/s]
Epoch 28: 100%|██████████| 4/4 [00:00<00:00,  6.19it/s]
Epoch 29: 100%|██████████| 4/4 [00:00<00:00,  6.51it/s]
Epoch 30: 100%|██████████| 4/4 [00:00<00:00,  6.09it/s]
val: 100%|██████████| 128/128 [00:01<00:00, 114.89it/s]


Epoch 30, val/rmse: 0.25519357203916737, val/mae: 0.1097416803240776


Epoch 31: 100%|██████████| 4/4 [00:00<00:00,  6.22it/s]
Epoch 32: 100%|██████████| 4/4 [00:00<00:00,  6.14it/s]
Epoch 33: 100%|██████████| 4/4 [00:00<00:00,  6.47it/s]
Epoch 34: 100%|██████████| 4/4 [00:00<00:00,  6.71it/s]
Epoch 35: 100%|██████████| 4/4 [00:00<00:00,  6.48it/s]
Epoch 36: 100%|██████████| 4/4 [00:00<00:00,  6.35it/s]
Epoch 37: 100%|██████████| 4/4 [00:00<00:00,  6.56it/s]
Epoch 38: 100%|██████████| 4/4 [00:00<00:00,  7.22it/s]
Epoch 39: 100%|██████████| 4/4 [00:00<00:00,  6.60it/s]
Epoch 40: 100%|██████████| 4/4 [00:00<00:00,  6.59it/s]
val: 100%|██████████| 128/128 [00:01<00:00, 116.35it/s]


Epoch 40, val/rmse: 0.23163425604055768, val/mae: 0.10077028721570969


Epoch 41: 100%|██████████| 4/4 [00:00<00:00,  6.59it/s]
Epoch 42: 100%|██████████| 4/4 [00:00<00:00,  6.65it/s]
Epoch 43: 100%|██████████| 4/4 [00:00<00:00,  6.55it/s]
Epoch 44: 100%|██████████| 4/4 [00:00<00:00,  6.76it/s]
Epoch 45: 100%|██████████| 4/4 [00:00<00:00,  5.53it/s]
Epoch 46: 100%|██████████| 4/4 [00:00<00:00,  5.32it/s]
Epoch 47: 100%|██████████| 4/4 [00:00<00:00,  6.72it/s]
Epoch 48: 100%|██████████| 4/4 [00:00<00:00,  6.46it/s]
Epoch 49: 100%|██████████| 4/4 [00:00<00:00,  6.75it/s]
Epoch 50: 100%|██████████| 4/4 [00:00<00:00,  6.54it/s]
val: 100%|██████████| 128/128 [00:00<00:00, 140.93it/s]


Epoch 50, val/rmse: 0.21789922630919797, val/mae: 0.09768307954072952


Epoch 51: 100%|██████████| 4/4 [00:00<00:00,  6.77it/s]
Epoch 52: 100%|██████████| 4/4 [00:00<00:00,  6.91it/s]
Epoch 53: 100%|██████████| 4/4 [00:00<00:00,  6.65it/s]
Epoch 54: 100%|██████████| 4/4 [00:00<00:00,  6.49it/s]
Epoch 55: 100%|██████████| 4/4 [00:00<00:00,  6.63it/s]
Epoch 56: 100%|██████████| 4/4 [00:00<00:00,  6.66it/s]
Epoch 57: 100%|██████████| 4/4 [00:00<00:00,  6.80it/s]
Epoch 58: 100%|██████████| 4/4 [00:00<00:00,  6.43it/s]
Epoch 59: 100%|██████████| 4/4 [00:00<00:00,  7.14it/s]
Epoch 60: 100%|██████████| 4/4 [00:00<00:00,  6.79it/s]
val: 100%|██████████| 128/128 [00:01<00:00, 108.94it/s]


Epoch 60, val/rmse: 0.23716223722925972, val/mae: 0.11480558663606644


Epoch 61: 100%|██████████| 4/4 [00:00<00:00,  6.51it/s]
Epoch 62: 100%|██████████| 4/4 [00:00<00:00,  6.91it/s]
Epoch 63: 100%|██████████| 4/4 [00:00<00:00,  7.05it/s]
Epoch 64: 100%|██████████| 4/4 [00:00<00:00,  6.96it/s]
Epoch 65: 100%|██████████| 4/4 [00:00<00:00,  7.03it/s]
Epoch 66: 100%|██████████| 4/4 [00:00<00:00,  6.91it/s]
Epoch 67: 100%|██████████| 4/4 [00:00<00:00,  6.79it/s]
Epoch 68: 100%|██████████| 4/4 [00:00<00:00,  5.51it/s]
Epoch 69: 100%|██████████| 4/4 [00:00<00:00,  6.69it/s]
Epoch 70: 100%|██████████| 4/4 [00:00<00:00,  7.00it/s]
val: 100%|██████████| 128/128 [00:01<00:00, 126.94it/s]


Epoch 70, val/rmse: 0.20837200322481989, val/mae: 0.0897238701581955


Epoch 71: 100%|██████████| 4/4 [00:00<00:00,  6.30it/s]
Epoch 72: 100%|██████████| 4/4 [00:00<00:00,  6.55it/s]
Epoch 73: 100%|██████████| 4/4 [00:00<00:00,  6.47it/s]
Epoch 74: 100%|██████████| 4/4 [00:00<00:00,  6.52it/s]
Epoch 75: 100%|██████████| 4/4 [00:00<00:00,  5.49it/s]
Epoch 76: 100%|██████████| 4/4 [00:00<00:00,  6.62it/s]
Epoch 77: 100%|██████████| 4/4 [00:00<00:00,  6.21it/s]
Epoch 78: 100%|██████████| 4/4 [00:00<00:00,  6.09it/s]
Epoch 79: 100%|██████████| 4/4 [00:00<00:00,  6.40it/s]
Epoch 80: 100%|██████████| 4/4 [00:00<00:00,  6.26it/s]
val: 100%|██████████| 128/128 [00:01<00:00, 121.94it/s]


Epoch 80, val/rmse: 0.2000919354255372, val/mae: 0.08869493007659912


Epoch 81: 100%|██████████| 4/4 [00:00<00:00,  6.13it/s]
Epoch 82: 100%|██████████| 4/4 [00:00<00:00,  6.10it/s]
Epoch 83: 100%|██████████| 4/4 [00:00<00:00,  6.26it/s]
Epoch 84: 100%|██████████| 4/4 [00:00<00:00,  6.00it/s]
Epoch 85: 100%|██████████| 4/4 [00:00<00:00,  5.94it/s]
Epoch 86: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]
Epoch 87: 100%|██████████| 4/4 [00:00<00:00,  5.21it/s]
Epoch 88: 100%|██████████| 4/4 [00:00<00:00,  5.99it/s]
Epoch 89: 100%|██████████| 4/4 [00:00<00:00,  6.40it/s]
Epoch 90: 100%|██████████| 4/4 [00:00<00:00,  5.95it/s]
val: 100%|██████████| 128/128 [00:01<00:00, 116.97it/s]


Epoch 90, val/rmse: 0.19340025344455292, val/mae: 0.08611840754747391


Epoch 91: 100%|██████████| 4/4 [00:00<00:00,  6.45it/s]
Epoch 92: 100%|██████████| 4/4 [00:00<00:00,  6.43it/s]
Epoch 93: 100%|██████████| 4/4 [00:00<00:00,  5.77it/s]
Epoch 94: 100%|██████████| 4/4 [00:00<00:00,  6.46it/s]
Epoch 95: 100%|██████████| 4/4 [00:00<00:00,  6.31it/s]
Epoch 96: 100%|██████████| 4/4 [00:00<00:00,  6.30it/s]
Epoch 97: 100%|██████████| 4/4 [00:00<00:00,  6.10it/s]
Epoch 98: 100%|██████████| 4/4 [00:00<00:00,  6.34it/s]
Epoch 99: 100%|██████████| 4/4 [00:00<00:00,  6.51it/s]


In [13]:
if track_with_wandb:
    best_model = RecurrentGNN(gnn = gnn, node_features = node_features).to(device)
    best_model.load_state_dict(torch.load(f'{model_name}.pth', weights_only=True))
    test_rmse, test_mae = eval(epoch, best_model, test_samples, 'test')


test: 100%|██████████| 128/128 [00:01<00:00, 120.55it/s]


Epoch 99, test/rmse: 0.8909698956238586, test/mae: 0.13745185732841492
