# Dataset Construction

In [10]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignal, DynamicGraphTemporalSignal
import torch
from typing import Union
import glob
from natsort import natsorted
import random

random.seed(42)

class SP500CorrelationsDatasetLoader(object):
    def __init__(self, corr_name, corr_scope):
        self._read_csv(corr_name, corr_scope)

    def _read_csv(self, corr_name, corr_scope):
        match corr_scope:
            case 'global':
                self._correlation_matrices = [np.loadtxt(f'{corr_name}/{corr_scope}_corr.csv', delimiter=',')]
            case 'local':
                self._correlation_matrices = []
                corr_files = natsorted(glob.glob(f'{corr_name}/local_corr_*.csv'))
                for corr_file in corr_files:
                    matrix = np.loadtxt(corr_file, delimiter=',')
                    self._correlation_matrices.append(matrix)
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)

        # Round data size to nearest multiple of batch_size
        self.days_in_quarter = 64
        num_quarters = data.size(0) // self.days_in_quarter
        num_days = num_quarters * self.days_in_quarter
        data = data[:num_days]
        
        # z-score normalization with training data following GERU
        train_days = int(0.8 * num_quarters) * self.days_in_quarter
        data = (data - data[:train_days].mean(dim=0)) / data[:train_days].std(dim=0)
        data = data.numpy()

        data = data[..., np.newaxis]

        # # Add percent change features
        # p_chg = data / np.roll(data, 1, axis=0) - 1
        # p_chg[0] = 0.0
        # p_chg_3 = data / np.roll(data, 3, axis=0) - 1
        # p_chg_3[0:3] = 0.0
        # p_chg_6 = data / np.roll(data, 6, axis=0) - 1
        # p_chg_6[0:6] = 0.0

        # data = np.stack([data, p_chg, p_chg_3, p_chg_6], axis=-1)
        # print('data.shape', data.shape)

        assert(not np.any(np.isnan(data)))
        self._dataset = data

    def _get_edges(self, times, overlap):
        if len(self._correlation_matrices) == 1:
            _edges = np.array(np.ones_like(self._correlation_matrices[0]).nonzero())
        else:
            _edges = []
            for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):
                if not time in times:
                    continue
                corr_index = max(0, time // self.days_in_quarter - 1)
                _edges.append(
                    np.array(np.ones_like(self._correlation_matrices[corr_index]).nonzero())
                )
        return _edges

    def _get_edge_weights(self, times, overlap):
        if len(self._correlation_matrices) == 1:
            _edge_weights = self._correlation_matrices[0].flatten()
        else:
            _edge_weights = []
            for time in range(0, self._dataset.shape[0] - self.batch_size, overlap):
                if not time in times:
                    continue
                corr_index = max(0, time // self.days_in_quarter - 1)
                _edge_weights.append(
                    np.array(self._correlation_matrices[corr_index]).flatten()
                )
        return _edge_weights

    def _get_targets_and_features(self, times, overlap, predict_all):
        features = [
            self._dataset[i : i + self.batch_size, :]
            for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)
            if i in times
        ]
        # predict next-day stock prices
        targets = [
            (self._dataset[i+1 : i + self.batch_size+1, :, 0]).T if predict_all else (self._dataset[i + self.batch_size, :, 0]).T
            for i in range(0, self._dataset.shape[0] - self.batch_size, overlap)
            if i in times
        ]
        return features, targets

    def get_dataset(self, batch_size, split) -> Union[StaticGraphTemporalSignal, DynamicGraphTemporalSignal]:
        """Returning the data iterator.
        """
        self.batch_size = batch_size

        total_times = list(range(0, self._dataset.shape[0] - self.batch_size, self.batch_size))

        if split == 'train':
            times = list(range(0, total_times[int(len(total_times) * 0.8)]))
            overlap = self.batch_size
            predict_all = True
        elif split == 'val':
            times = list(range(total_times[int(len(total_times) * 0.8)], total_times[int(len(total_times) * 0.9)]))
            overlap = 1
            predict_all = False
        elif split == 'test':
            times = list(range(total_times[int(len(total_times) * 0.9)], total_times[-1] + self.batch_size))
            overlap = 1
            predict_all = False
        else:
            raise ValueError(f'Invalid split name: {split}')

        _edges = self._get_edges(times, overlap)
        _edge_weights = self._get_edge_weights(times, overlap)
        features, targets = self._get_targets_and_features(times, overlap, predict_all)
        dataset = (DynamicGraphTemporalSignal if type(_edges) == list else StaticGraphTemporalSignal)(
            _edges, _edge_weights, features, targets
        )
        return dataset

In [11]:
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = 'cpu'

corr_name = 'mi'
corr_scope = 'global'
loader = SP500CorrelationsDatasetLoader(corr_name=corr_name, corr_scope=corr_scope)

lag_size = 64
train_dataset = loader.get_dataset(batch_size=lag_size * 2, split='train')
val_dataset = loader.get_dataset(batch_size=lag_size, split='val')
test_dataset = loader.get_dataset(batch_size=lag_size, split='test')

data.shape (2496, 472, 1)


In [12]:
print(len(train_dataset.features))
print(len(train_dataset.targets))
print(len(val_dataset.features))
print(len(val_dataset.targets))
print(len(test_dataset.features))
print(len(test_dataset.targets))

15
15
256
256
256
256


# Differential Graph Transformer

In [13]:
import torch
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn.conv import MessagePassing
import torch.nn as nn
import torch.nn.functional as F


class DGAttn(MessagePassing):
    r"""An implementation of the Diffusion Convolution Layer.
    For details see: `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        K (int): Filter size :math:`K`.
        bias (bool, optional): If set to :obj:`False`, the layer
            will not learn an additive bias (default :obj:`True`).

    """

    def __init__(self, in_channels, out_channels, num_heads=1, diff_attn=True):
        super(DGAttn, self).__init__(aggr="add", flow="source_to_target")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.head_dim = out_channels // self.num_heads
        self.diff_attn = diff_attn
        if self.diff_attn:
            self.Q = nn.Linear(in_channels, out_channels)
            self.K = nn.Linear(in_channels, out_channels)
        self.V = nn.Linear(in_channels, out_channels)
        self.ffn = nn.Linear(out_channels, out_channels)
        self.ln = nn.RMSNorm(self.head_dim, eps=1e-5, elementwise_affine=True)

        if self.diff_attn:
            self.lambda_init = 0.2
            self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
            self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

        # self.__reset_parameters()

    # def __reset_parameters(self):
        # torch.nn.init.xavier_uniform_(self.V)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor,
    ) -> torch.FloatTensor:
        r"""Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional) - Edge weight vector.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        A = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
        a1 = F.softmax(A, dim=-1)
        if self.diff_attn:
            q = self.Q(X)
            k = self.K(X)
            a2 = F.softmax(torch.matmul(q, k.transpose(-1, -2)) * (self.head_dim ** -0.5), dim=-1)
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
            lambda_full = lambda_1 - lambda_2 + self.lambda_init
            a = a1 - lambda_full * a2
        else:
            a = a1
        v = self.V(X)
        attn = torch.matmul(a, v)
        attn = self.ln(attn)
        if self.diff_attn:
            attn = attn * (1 - self.lambda_init)
        attn = self.ffn(attn)
        return attn


class DGRNN(torch.nn.Module):
    r"""An implementation of the Diffusion Convolutional Gated Recurrent Unit.
    For details see: `"Diffusion Convolutional Recurrent Neural Network:
    Data-Driven Traffic Forecasting" <https://arxiv.org/abs/1707.01926>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        num_heads (int): Number of attention heads.
    """

    def __init__(self, in_channels: int, out_channels: int, num_heads: int, diff_attn: bool):
        super(DGRNN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.diff_attn = diff_attn,

        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):
        self.conv_x_z = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_reset_gate_parameters_and_layers(self):
        self.conv_x_r = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_candidate_state_parameters_and_layers(self):
        self.conv_x_h = DGAttn(
            in_channels=self.in_channels + self.out_channels,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
            diff_attn=self.diff_attn,
        )

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([X, H], dim=1)
        Z = self.conv_x_z(Z, edge_index, edge_weight)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([X, H], dim=1)
        R = self.conv_x_r(R, edge_index, edge_weight)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([X, H * R], dim=1)
        H_tilde = self.conv_x_h(H_tilde, edge_index, edge_weight)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        r"""Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Long Tensor, optional) - Edge weight vector.
            * **H** (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        # print('X.shape', X.shape)
        # print('H.shape', H.shape)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H


# Plain RNN

In [14]:
class GRU(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(GRU, self).__init__()
        self.rnn = nn.GRU(input_size=in_channels, hidden_size=out_channels, num_layers=2, batch_first=True)

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        outputs, _ = self.rnn(X, H)
        return outputs

# RGCN

In [15]:
import torch
import torch.nn.functional as F

class RecurrentGNN(torch.nn.Module):
    def __init__(self, gnn, node_features, hidden_size=32, **kwargs):
        super(RecurrentGNN, self).__init__()
        if gnn.__name__ == "DCRNN":
            self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)
        elif gnn.__name__ == "DGRNN":
            self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)
        elif gnn.__name__ == "GRU":
            self.recurrent = gnn(in_channels=node_features, out_channels=hidden_size, **kwargs)
        else:
            raise ValueError(f"Invalid GNN type {gnn.__name__}")
        self.linear = torch.nn.Linear(hidden_size, 1)

    def forward(self, x, edge_index, edge_weight, hidden):
        outputs = self.recurrent(x, edge_index, edge_weight, hidden)
        return self.linear(F.relu(outputs))

In [16]:
import math

def rmse(y_hat, y):
    return math.sqrt(F.mse_loss(y_hat, y).item())

def mae(y_hat, y):
    return F.l1_loss(y_hat, y).item()

In [17]:
import wandb

def eval(epoch, model, eval_dataset, eval_name):
    model.eval()
    with torch.no_grad():
        y_hats = []
        ys = []
        for snapshot in eval_dataset:
            X = snapshot.x.to(device)
            batch_y_hats = model(X.transpose(0, 1), snapshot.edge_index.to(device), snapshot.edge_attr.to(device), hidden=None)
            y_hats.append(batch_y_hats[:, -1])
            ys.append(snapshot.y.to(device))
        y_hats, ys = torch.cat(y_hats, dim=0).squeeze().to(device), torch.cat(ys, dim=0).to(device)
        eval_rmse = rmse(y_hats, ys)
        eval_mae = mae(y_hats, ys)
        wandb.log({"epoch": epoch,
                f"{eval_name}/rmse": eval_rmse,
                f"{eval_name}/mae": eval_mae })
        print(f'Epoch {epoch}, {eval_name}/rmse: {eval_rmse}, {eval_name}/mae: {eval_mae}')
        return (eval_rmse, eval_mae)

In [18]:
from tqdm import tqdm
import wandb
from torch_geometric_temporal.nn.recurrent import DCRNN

gnn = GRU

node_features = 1

model = RecurrentGNN(gnn = gnn, node_features = node_features).to(device)

if gnn.__name__ == 'GRU':
    model_name = f'{gnn.__name__}'
elif gnn.__name__ == 'DCRNN':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}'
elif gnn.__name__ == 'DGRNN':
    model_name = f'{gnn.__name__}_{corr_name}_{corr_scope}{"_diff_attn" if model.recurrent.diff_attn else ""}'

lr = 1e-2
num_epochs = 50
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
track_with_wandb = True

if track_with_wandb:
    wandb.init(project="cs224w-stock-market-prediction", config={
        "dataset": "S&P500",
        "corr_name": corr_name,
        "corr_scope": corr_scope,
        "learning_rate": lr,
        "epochs": num_epochs,
        "architecture": gnn.__name__,
    })

best_val_rmse = float('inf')

batch_size = 64

def split_into_batches(lst, batch_size):
    return [lst[i:i + batch_size] for i in range(0, len(lst), batch_size)]

train_samples = list(train_dataset)
train_batches = split_into_batches(train_samples, batch_size)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    random.shuffle(train_batches)
    for step, batch in tqdm(enumerate(train_batches), total=len(train_batches), desc=f'Epoch {epoch}'):
        y_hats = []
        ys = []
        for snapshot in batch:
            X = snapshot.x.to(device)
            batched_y_hats = model(X.transpose(0, 1), snapshot.edge_index.to(device), snapshot.edge_attr.to(device), hidden=None)
            y_hats.append(batched_y_hats)
            ys.append(snapshot.y.to(device))
        y_hats, ys = torch.cat(y_hats, dim=0).squeeze().to(device), torch.cat(ys, dim=0).to(device)
        loss = F.mse_loss(y_hats, ys)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        wandb.log({"epoch": epoch,
                   "step": epoch * len(train_samples) + step,
                    "train/loss": loss.item() })
    train_loss /= len(train_samples)

    if track_with_wandb:
        val_rmse, val_mae = eval(epoch, model, val_dataset, 'val')
        if val_rmse < best_val_rmse:
            best_rmse = val_rmse
            torch.save(model.state_dict(), f'{model_name}.pth')


0,1
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▅▅▅▅▆▆▆▆▇▇▇▇█
step,▁▂▂▃▃▄▅▅▆▆▇▇█
train/loss,█▇▇▆▆▅▅▄▃▃▂▂▁
val/mae,█▇▇▆▆▅▅▄▃▃▂▁
val/rmse,█▇▇▆▆▅▅▄▃▂▂▁

0,1
epoch,12.0
step,180.0
train/loss,0.74188
val/mae,1.33227
val/rmse,1.72539


Epoch 0: 100%|██████████| 1/1 [00:02<00:00,  2.66s/it]


Epoch 0, val/rmse: 1.6795894150595054, val/mae: 1.280681848526001


Epoch 1: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]


Epoch 1, val/rmse: 1.5167575812426184, val/mae: 1.110804796218872


Epoch 2: 100%|██████████| 1/1 [00:02<00:00,  2.52s/it]


Epoch 2, val/rmse: 1.3325668108510638, val/mae: 0.8893771171569824


Epoch 3: 100%|██████████| 1/1 [00:02<00:00,  2.74s/it]


Epoch 3, val/rmse: 1.1632873207072225, val/mae: 0.6498034596443176


Epoch 4: 100%|██████████| 1/1 [00:02<00:00,  2.36s/it]


Epoch 4, val/rmse: 1.0708684660693413, val/mae: 0.5682592391967773


Epoch 5: 100%|██████████| 1/1 [00:02<00:00,  2.31s/it]


Epoch 5, val/rmse: 1.0522616144084467, val/mae: 0.5928544998168945


Epoch 6: 100%|██████████| 1/1 [00:02<00:00,  2.40s/it]


Epoch 6, val/rmse: 1.0368840606827032, val/mae: 0.5831014513969421


Epoch 7: 100%|██████████| 1/1 [00:02<00:00,  2.46s/it]


Epoch 7, val/rmse: 1.008340682378633, val/mae: 0.5220925807952881


Epoch 8: 100%|██████████| 1/1 [00:02<00:00,  2.33s/it]


Epoch 8, val/rmse: 1.0041229370430311, val/mae: 0.4656895697116852


Epoch 9: 100%|██████████| 1/1 [00:02<00:00,  2.56s/it]


Epoch 9, val/rmse: 1.057571818751484, val/mae: 0.5597987771034241


Epoch 10: 100%|██████████| 1/1 [00:02<00:00,  2.60s/it]


Epoch 10, val/rmse: 1.0699336920005909, val/mae: 0.5941886305809021


Epoch 11: 100%|██████████| 1/1 [00:02<00:00,  2.80s/it]


Epoch 11, val/rmse: 1.037263388252418, val/mae: 0.5478986501693726


Epoch 12: 100%|██████████| 1/1 [00:02<00:00,  2.41s/it]


Epoch 12, val/rmse: 0.9907424195312691, val/mae: 0.4631049633026123


Epoch 13: 100%|██████████| 1/1 [00:02<00:00,  2.36s/it]


Epoch 13, val/rmse: 0.9523008566906509, val/mae: 0.3920777440071106


Epoch 14: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]


Epoch 14, val/rmse: 0.928930663741676, val/mae: 0.38553139567375183


Epoch 15: 100%|██████████| 1/1 [00:02<00:00,  2.38s/it]


Epoch 15, val/rmse: 0.9180417944883049, val/mae: 0.4042925536632538


Epoch 16: 100%|██████████| 1/1 [00:02<00:00,  2.36s/it]


Epoch 16, val/rmse: 0.9118568675360794, val/mae: 0.41678929328918457


Epoch 17: 100%|██████████| 1/1 [00:02<00:00,  2.80s/it]


Epoch 17, val/rmse: 0.9033035854131193, val/mae: 0.40836870670318604


Epoch 18: 100%|██████████| 1/1 [00:02<00:00,  2.47s/it]


Epoch 18, val/rmse: 0.8915361127615478, val/mae: 0.37841200828552246


Epoch 19: 100%|██████████| 1/1 [00:02<00:00,  2.33s/it]


Epoch 19, val/rmse: 0.8819530320604667, val/mae: 0.33890652656555176


Epoch 20: 100%|██████████| 1/1 [00:02<00:00,  2.46s/it]


Epoch 20, val/rmse: 0.8798338001076795, val/mae: 0.31494179368019104


Epoch 21: 100%|██████████| 1/1 [00:02<00:00,  2.43s/it]


Epoch 21, val/rmse: 0.8828511525022311, val/mae: 0.33097848296165466


Epoch 22: 100%|██████████| 1/1 [00:02<00:00,  2.49s/it]


Epoch 22, val/rmse: 0.8827222592303682, val/mae: 0.3464357554912567


Epoch 23: 100%|██████████| 1/1 [00:02<00:00,  2.43s/it]


Epoch 23, val/rmse: 0.8745979679755754, val/mae: 0.3363317847251892


Epoch 24: 100%|██████████| 1/1 [00:02<00:00,  2.47s/it]


Epoch 24, val/rmse: 0.8601787969560145, val/mae: 0.306037575006485


Epoch 25: 100%|██████████| 1/1 [00:02<00:00,  2.53s/it]


Epoch 25, val/rmse: 0.8446017310897789, val/mae: 0.2750967741012573


Epoch 26: 100%|██████████| 1/1 [00:02<00:00,  2.84s/it]


Epoch 26, val/rmse: 0.8326544778138907, val/mae: 0.2696090042591095


Epoch 27: 100%|██████████| 1/1 [00:02<00:00,  2.63s/it]


Epoch 27, val/rmse: 0.8256928337929169, val/mae: 0.28244391083717346


Epoch 28: 100%|██████████| 1/1 [00:02<00:00,  2.42s/it]


Epoch 28, val/rmse: 0.8207451179143069, val/mae: 0.2888505458831787


Epoch 29: 100%|██████████| 1/1 [00:02<00:00,  2.38s/it]


Epoch 29, val/rmse: 0.8144058469959654, val/mae: 0.2750096619129181


Epoch 30: 100%|██████████| 1/1 [00:02<00:00,  2.31s/it]


Epoch 30, val/rmse: 0.8083946240840458, val/mae: 0.24567601084709167


Epoch 31: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]


Epoch 31, val/rmse: 0.8078614387961733, val/mae: 0.23908236622810364


Epoch 32: 100%|██████████| 1/1 [00:02<00:00,  2.38s/it]


Epoch 32, val/rmse: 0.8117799869822254, val/mae: 0.2669588625431061


Epoch 33: 100%|██████████| 1/1 [00:02<00:00,  2.34s/it]


Epoch 33, val/rmse: 0.8114559728533421, val/mae: 0.27984708547592163


Epoch 34: 100%|██████████| 1/1 [00:02<00:00,  2.22s/it]


Epoch 34, val/rmse: 0.8025035294190749, val/mae: 0.2612822651863098


Epoch 35: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]


Epoch 35, val/rmse: 0.7898565772398811, val/mae: 0.22599904239177704


Epoch 36: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]


Epoch 36, val/rmse: 0.7805379292755611, val/mae: 0.20857276022434235


Epoch 37: 100%|██████████| 1/1 [00:02<00:00,  2.23s/it]


Epoch 37, val/rmse: 0.7760859909883217, val/mae: 0.21711918711662292


Epoch 38: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]


Epoch 38, val/rmse: 0.7731994108319197, val/mae: 0.22195829451084137


Epoch 39: 100%|██████████| 1/1 [00:02<00:00,  2.46s/it]


Epoch 39, val/rmse: 0.7696108317010462, val/mae: 0.2113344818353653


Epoch 40: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]


Epoch 40, val/rmse: 0.7668387465235835, val/mae: 0.1962495595216751


Epoch 41: 100%|██████████| 1/1 [00:02<00:00,  2.67s/it]


Epoch 41, val/rmse: 0.7669222215758508, val/mae: 0.20001812279224396


Epoch 42: 100%|██████████| 1/1 [00:02<00:00,  2.39s/it]


Epoch 42, val/rmse: 0.7677518536375445, val/mae: 0.21381458640098572


Epoch 43: 100%|██████████| 1/1 [00:02<00:00,  2.59s/it]


Epoch 43, val/rmse: 0.7652028535403832, val/mae: 0.21137748658657074


Epoch 44: 100%|██████████| 1/1 [00:02<00:00,  2.88s/it]


Epoch 44, val/rmse: 0.7596069405577626, val/mae: 0.1939026266336441


Epoch 45: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]


Epoch 45, val/rmse: 0.7550498313753569, val/mae: 0.18592965602874756


Epoch 46: 100%|██████████| 1/1 [00:02<00:00,  2.52s/it]


Epoch 46, val/rmse: 0.7529273125874509, val/mae: 0.19222520291805267


Epoch 47: 100%|██████████| 1/1 [00:02<00:00,  2.30s/it]


Epoch 47, val/rmse: 0.7512469257666323, val/mae: 0.19338050484657288


Epoch 48: 100%|██████████| 1/1 [00:02<00:00,  2.23s/it]


Epoch 48, val/rmse: 0.7492608958493568, val/mae: 0.18483421206474304


Epoch 49: 100%|██████████| 1/1 [00:02<00:00,  2.36s/it]


Epoch 49, val/rmse: 0.7483655683956294, val/mae: 0.17858277261257172


In [19]:
if track_with_wandb:
    best_model = RecurrentGNN(gnn = gnn, node_features = node_features).to(device)
    best_model.load_state_dict(torch.load(f'{model_name}.pth', weights_only=True))
    test_rmse, test_mae = eval(epoch, best_model, test_dataset, 'test')


Epoch 49, test/rmse: 3.3100578374505885, test/mae: 0.6717974543571472
