# Dataset Construction

In [2]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch, DynamicGraphTemporalSignalBatch
import torch
from typing import Union
import glob
from natsort import natsorted
import random

# Seed every RNG the notebook relies on. Python's `random` drives the per-epoch
# shuffling of the training snapshots, while torch's generator drives the
# initial model weights; seeding only `random` leaves runs irreproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class SP500CorrelationsDatasetLoader(object):
    def __init__(self, corr_name, corr_scope):
        self._read_csv(corr_name, corr_scope)

    def _read_csv(self, corr_name, corr_scope):
        match corr_scope:
            case 'global':
                self._correlation_matrices = [np.loadtxt(f'{corr_name}/{corr_scope}_corr.csv', delimiter=',')]
            case 'local':
                self._correlation_matrices = []
                corr_files = natsorted(glob.glob(f'{corr_name}/local_corr_*.csv'))
                for corr_file in corr_files:
                    matrix = np.loadtxt(corr_file, delimiter=',')
                    self._correlation_matrices.append(matrix)
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)

        # Round data size to nearest multiple of batch_size
        self.days_in_quarter = 64
        num_quarters = data.size(0) // self.days_in_quarter
        num_days = num_quarters * self.days_in_quarter
        data = data[:num_days]
        
        # z-score normalization with training data following GERU
        train_days = int(0.8 * num_quarters) * self.days_in_quarter
        data = (data - data[:train_days].mean(dim=0)) / data[:train_days].std(dim=0)
        data = data.numpy()
        np.savetxt('s&p500_z_scores.csv', data, delimiter=',')
        self._dataset = data

    def _get_edges(self):
        if len(self._correlation_matrices) == 1:
            self._edges = np.array(np.ones_like(self._correlation_matrices[0]).nonzero())
        else:
            self._edges = []
            for time in range(self._dataset.shape[0] - self.lags):
                corr_index = time // self.days_in_quarter
                self._edges.append(
                    np.array(np.ones_like(self._correlation_matrices[corr_index]).nonzero())
                )

    def _get_edge_weights(self):
        if len(self._correlation_matrices) == 1:
            self._edge_weights = self._correlation_matrices[0].flatten()
            # print(self._edge_weights.shape)
        else:
            self._edge_weights = []
            for time in range(self._dataset.shape[0] - self.lags):
                corr_index = time // self.days_in_quarter
                self._edge_weights.append(
                    np.array(self._correlation_matrices[corr_index]).flatten()
                )

    def _get_targets_and_features(self):
        stacked_target = self._dataset
        self.features = [
            stacked_target[i : i + self.lags, :].T
            for i in range(stacked_target.shape[0] - self.lags)
        ]
        # predict next-day stock movement
        self.targets = [
            ((stacked_target[i + self.lags, :] > stacked_target[i + self.lags - 1, :]).astype(int)).T
            for i in range(stacked_target.shape[0] - self.lags)
        ]

    def get_dataset(self) -> Union[StaticGraphTemporalSignalBatch, DynamicGraphTemporalSignalBatch]:
        """Returning the data iterator.
        """
        self.lags = self.days_in_quarter
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        self.batches = np.repeat(np.arange((self._dataset.shape[0] - self.lags) // self.days_in_quarter), self.days_in_quarter)
        dataset = (DynamicGraphTemporalSignalBatch if type(self._edges) == list else StaticGraphTemporalSignalBatch)(
            self._edges, self._edge_weights, self.features, self.targets, self.batches
        )
        return dataset

In [3]:
from torch_geometric_temporal.signal import temporal_signal_split

# Everything runs on CPU. The commented-out line below would instead pick
# CUDA or MPS automatically when one of them is available.
# device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = 'cpu'

# `corr_name` is the directory the correlation matrices are read from and
# `corr_scope` selects a single 'global' matrix or per-quarter 'local' ones.
corr_name = 'pearsons'
corr_scope = 'local'
loader = SP500CorrelationsDatasetLoader(corr_name=corr_name, corr_scope=corr_scope)

dataset = loader.get_dataset()
# `lags` only exists on the loader after get_dataset() has run; it is the
# feature length per node and is later used as the model's input width.
lags = loader.lags

# Two successive splits: 80% train, then the remaining 20% halved into
# validation and test (10% / 10% of the snapshots overall).
train_dataset, test_val_dataset = temporal_signal_split(dataset, train_ratio=0.8)
val_dataset, test_dataset = temporal_signal_split(test_val_dataset, train_ratio=0.5)

# Evaluation

In [4]:
from torcheval.metrics.functional import binary_f1_score, binary_accuracy

def accuracy(y_hats, ys):
    """Binary accuracy of ``y_hats`` against ``ys`` as a Python float.

    Both tensors are flattened before being handed to torcheval's
    ``binary_accuracy``, so any pair of matching shapes is accepted.
    """
    predictions = y_hats.flatten()
    labels = ys.flatten()
    score = binary_accuracy(predictions, labels)
    return score.item()

def f1(y_hats, ys):
    """Binary F1 score of ``y_hats`` against ``ys`` as a Python float.

    Both tensors are flattened before being handed to torcheval's
    ``binary_f1_score``, so any pair of matching shapes is accepted.
    """
    predictions = y_hats.flatten()
    labels = ys.flatten()
    score = binary_f1_score(predictions, labels)
    return score.item()


# Differential Graph Transformer

In [5]:
import torch
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn.conv import MessagePassing
import torch.nn as nn
import torch.nn.functional as F


class DGAttn(MessagePassing):
    r"""Differential graph attention layer.

    Combines two row-normalised attention maps over the nodes of one graph,
    one taken from the edge weights and one learned from the node features:

    .. math::
        A_1 = \mathrm{softmax}(A), \qquad
        A_2 = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_h}}\right)

    .. math::
        H_1 = (A_1 - \lambda A_2)\, V, \qquad
        H = \mathrm{FFN}(\mathrm{LN}(H_1)) + H_1

    where :math:`A` is the dense weighted adjacency matrix, :math:`Q, K, V`
    are linear projections of the input and :math:`\lambda` is a learned scalar.

    Args:
        in_channels (int): Number of input features per node.
        out_channels (int): Number of output features per node.
        num_heads (int, optional): Only sets the attention scale
            :math:`d_h = \text{out\_channels} // \text{num\_heads}`; the
            projections are not split into separate heads (default :obj:`1`).
    """

    def __init__(self, in_channels, out_channels, num_heads=1):
        super(DGAttn, self).__init__(aggr="add", flow="source_to_target")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self.Q = nn.Linear(in_channels, out_channels)
        self.K = nn.Linear(in_channels, out_channels)
        self.V = nn.Linear(in_channels, out_channels)
        self.ln = nn.LayerNorm(out_channels)
        self.ffn = nn.Linear(out_channels, out_channels)
        self.head_dim = out_channels // self.num_heads
        # Learned weight of the subtracted attention map, initialised near zero.
        self.lambda_ = nn.Parameter(torch.zeros(1, dtype=torch.float32).normal_(mean=0,std=0.1))

    def message(self, x_j, norm):
        """MessagePassing hook; unused because ``forward`` never calls ``propagate``."""
        return norm.view(-1, 1) * x_j

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        r"""Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features, one row per node.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Float Tensor, optional) - Edge weight vector.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        # Size the dense adjacency from X so that it always matches the
        # attention map, even if the highest-numbered nodes have no edges.
        # squeeze(0) drops only the batch dim added by to_dense_adj.
        A = to_dense_adj(
            edge_index, edge_attr=edge_weight, max_num_nodes=X.size(0)
        ).squeeze(0)
        # Node pairs without an edge are zeros in A and still get softmax mass.
        a1 = F.softmax(A, dim=-1)
        q = self.Q(X)
        k = self.K(X)
        a2 = F.softmax(torch.matmul(q, k.transpose(-1, -2)) * (self.head_dim ** -0.5), dim=-1)
        # Differential attention: structural map minus the learned map.
        a = a1 - self.lambda_ * a2
        v = self.V(X)
        H1 = torch.matmul(a, v)
        # Residual connection around the normalised feed-forward layer.
        H = self.ffn(self.ln(H1)) + H1
        return H


class DGRNN(torch.nn.Module):
    r"""GRU-style recurrent cell whose three gates are :class:`DGAttn` layers.

    Each gate receives the node features concatenated with the hidden state
    and produces ``out_channels`` values per node.

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        num_heads (int): Number of attention heads handed to every gate layer.
    """

    def __init__(self, in_channels: int, out_channels: int, num_heads: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_heads = num_heads
        self._create_parameters_and_layers()

    def _new_gate_layer(self):
        """Create one gate layer mapping [features ‖ hidden] to hidden size."""
        gate_input_width = self.in_channels + self.out_channels
        return DGAttn(
            in_channels=gate_input_width,
            out_channels=self.out_channels,
            num_heads=self.num_heads,
        )

    def _create_update_gate_parameters_and_layers(self):
        self.conv_x_z = self._new_gate_layer()

    def _create_reset_gate_parameters_and_layers(self):
        self.conv_x_r = self._new_gate_layer()

    def _create_candidate_state_parameters_and_layers(self):
        self.conv_x_h = self._new_gate_layer()

    def _create_parameters_and_layers(self):
        # Creation order (update, reset, candidate) fixes the order in which
        # the random initial weights are drawn; keep it stable.
        for build in (
            self._create_update_gate_parameters_and_layers,
            self._create_reset_gate_parameters_and_layers,
            self._create_candidate_state_parameters_and_layers,
        ):
            build()

    def _set_hidden_state(self, X, H):
        """Return ``H`` unchanged, or an all-zero state on X's device if it is None."""
        if H is not None:
            return H
        return torch.zeros(X.shape[0], self.out_channels, device=X.device)

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        gate_input = torch.cat([X, H], dim=1)
        return torch.sigmoid(self.conv_x_z(gate_input, edge_index, edge_weight))

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        gate_input = torch.cat([X, H], dim=1)
        return torch.sigmoid(self.conv_x_r(gate_input, edge_index, edge_weight))

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        # The reset gate scales the previous state before it is mixed in.
        gate_input = torch.cat([X, H * R], dim=1)
        return torch.tanh(self.conv_x_h(gate_input, edge_index, edge_weight))

    def _calculate_hidden_state(self, Z, H, H_tilde):
        # Z close to 1 keeps the old state, Z close to 0 takes the candidate.
        return Z * H + (1 - Z) * H_tilde

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        r"""Run one recurrent step. ``edge_weight`` is handed on to the gate
        layers as given. When no hidden state is supplied it starts as zeros.

        Arg types:
            * **X** (PyTorch Float Tensor) - Node features.
            * **edge_index** (PyTorch Long Tensor) - Graph edge indices.
            * **edge_weight** (PyTorch Float Tensor, optional) - Edge weight vector.
            * **H** (PyTorch Float Tensor, optional) - Hidden state matrix for all nodes.

        Return types:
            * **H** (PyTorch Float Tensor) - Hidden state matrix for all nodes.
        """
        previous = self._set_hidden_state(X, H)
        update = self._calculate_update_gate(X, edge_index, edge_weight, previous)
        reset = self._calculate_reset_gate(X, edge_index, edge_weight, previous)
        candidate = self._calculate_candidate_state(X, edge_index, edge_weight, previous, reset)
        return self._calculate_hidden_state(update, previous, candidate)


# Recurrent GNN Classifier and Training

In [6]:
import torch
import torch.nn.functional as F

class RecurrentGNN(torch.nn.Module):
    """Node-level classifier: one recurrent graph layer, ReLU, then a linear head.

    Args:
        gnn: Class or factory called as ``gnn(node_features, hidden_channels,
            gnn_arg)``; the object it returns must accept
            ``(x, edge_index, edge_weight)``.
        node_features (int): Number of input features per node.
        hidden_channels (int, optional): Width of the recurrent layer's output
            (default ``64``).
        num_classes (int, optional): Number of output classes (default ``2``).
        gnn_arg (int, optional): Third positional constructor argument of
            ``gnn``. Its meaning depends on the layer, e.g. the number of
            attention heads for ``DGRNN`` (default ``1``).
    """

    def __init__(self, gnn, node_features, hidden_channels=64, num_classes=2, gnn_arg=1):
        super(RecurrentGNN, self).__init__()
        self.recurrent = gnn(node_features, hidden_channels, gnn_arg)
        self.linear = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, edge_weight):
        """Return per-node log-probabilities over the classes.

        No hidden state is passed to the recurrent layer, so every call
        starts from that layer's default initial state.
        """
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return F.log_softmax(h, dim=-1)

In [6]:
from tqdm import tqdm
import wandb
from torch_geometric_temporal.nn.recurrent import DCRNN

gnn = DGRNN

model = RecurrentGNN(gnn = gnn, node_features = lags).to(device)

lr = 1e-3
num_epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
track_with_wandb = True

if track_with_wandb:
    wandb.init(project="cs224w-stock-market-prediction", config={
        "dataset": "S&P500",
        "corr_name": corr_name,
        "corr_scope": corr_scope,
        "learning_rate": lr,
        "epochs": num_epochs,
        "architecture": gnn.__name__,
    })


def _evaluate(net, snapshots):
    """Return ``(accuracy, f1)`` of ``net`` over every snapshot in ``snapshots``.

    Puts ``net`` in eval mode and runs without gradients; the predicted class
    of each node is the arg-max of the model's log-probabilities.
    """
    net.eval()
    predictions, labels = [], []
    with torch.no_grad():
        for snapshot in snapshots:
            log_probs = net(
                snapshot.x.to(device),
                snapshot.edge_index.to(device),
                snapshot.edge_attr.to(device),
            )
            predictions.append(log_probs.argmax(dim=-1))
            labels.append(snapshot.y.to(device))
    predictions, labels = torch.stack(predictions), torch.stack(labels)
    return accuracy(predictions, labels), f1(predictions, labels)


best_acc = 0

# Materialise the training snapshots once so they can be reshuffled per epoch.
train_samples = list(train_dataset)

for epoch in tqdm(range(num_epochs)):
    model.train()
    train_loss = 0
    random.shuffle(train_samples)
    for snapshot in train_samples:
        optimizer.zero_grad()
        y_hat = model(snapshot.x.to(device), snapshot.edge_index.to(device), snapshot.edge_attr.to(device))
        loss = F.nll_loss(y_hat.squeeze(), snapshot.y.to(device))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Average over the number of samples rather than a leaked loop index, which
    # is undefined (or stale) when the training set is empty.
    train_loss /= max(len(train_samples), 1)

    # Validation and checkpointing always run; only the wandb upload depends
    # on the tracking flag.
    val_acc, val_f1 = _evaluate(model, val_dataset)
    if track_with_wandb:
        wandb.log({"epoch": epoch,
                "train/loss": train_loss,
                "val/acc": val_acc,
                "val/f1": val_f1 })
    print(f'Epoch {epoch}, val/acc: {val_acc}, val/f1: {val_f1}')
    if val_acc > best_acc:
        best_acc = val_acc
        # File name kept as-is: the test cell loads the checkpoint by this name.
        torch.save(model.state_dict(), 'dcrnn_best_model.pth')


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkevinxli[0m. Use [1m`wandb login --relogin`[0m to force relogin


  2%|▏         | 1/50 [00:27<22:48, 27.93s/it]

Epoch 0, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


  4%|▍         | 2/50 [00:51<20:24, 25.51s/it]

Epoch 1, val/acc: 0.48798564076423645, val/f1: 0.0


  6%|▌         | 3/50 [01:16<19:48, 25.28s/it]

Epoch 2, val/acc: 0.48798564076423645, val/f1: 0.0


  8%|▊         | 4/50 [01:39<18:40, 24.36s/it]

Epoch 3, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 10%|█         | 5/50 [02:01<17:39, 23.55s/it]

Epoch 4, val/acc: 0.48798564076423645, val/f1: 0.0


 12%|█▏        | 6/50 [02:23<16:53, 23.04s/it]

Epoch 5, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 14%|█▍        | 7/50 [02:46<16:27, 22.97s/it]

Epoch 6, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 16%|█▌        | 8/50 [03:09<15:58, 22.82s/it]

Epoch 7, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 18%|█▊        | 9/50 [03:31<15:31, 22.71s/it]

Epoch 8, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 20%|██        | 10/50 [03:53<15:01, 22.53s/it]

Epoch 9, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 22%|██▏       | 11/50 [04:15<14:32, 22.38s/it]

Epoch 10, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 24%|██▍       | 12/50 [04:37<14:05, 22.24s/it]

Epoch 11, val/acc: 0.48798564076423645, val/f1: 0.0


 26%|██▌       | 13/50 [04:59<13:40, 22.19s/it]

Epoch 12, val/acc: 0.48732301592826843, val/f1: 0.09155234694480896


 28%|██▊       | 14/50 [05:25<13:57, 23.26s/it]

Epoch 13, val/acc: 0.48836052417755127, val/f1: 0.038850218057632446


 30%|███       | 15/50 [05:49<13:40, 23.43s/it]

Epoch 14, val/acc: 0.48798564076423645, val/f1: 0.0


 32%|███▏      | 16/50 [06:12<13:12, 23.31s/it]

Epoch 15, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 34%|███▍      | 17/50 [06:34<12:38, 22.98s/it]

Epoch 16, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 36%|███▌      | 18/50 [06:58<12:21, 23.19s/it]

Epoch 17, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 38%|███▊      | 19/50 [07:22<12:06, 23.45s/it]

Epoch 18, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 40%|████      | 20/50 [07:46<11:46, 23.57s/it]

Epoch 19, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 42%|████▏     | 21/50 [08:10<11:33, 23.92s/it]

Epoch 20, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 44%|████▍     | 22/50 [08:34<11:06, 23.82s/it]

Epoch 21, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 46%|████▌     | 23/50 [08:57<10:32, 23.44s/it]

Epoch 22, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 48%|████▊     | 24/50 [09:19<10:05, 23.29s/it]

Epoch 23, val/acc: 0.5240113139152527, val/f1: 0.6194108128547668


 50%|█████     | 25/50 [09:45<09:59, 23.96s/it]

Epoch 24, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 52%|█████▏    | 26/50 [10:10<09:43, 24.31s/it]

Epoch 25, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 54%|█████▍    | 27/50 [10:35<09:25, 24.57s/it]

Epoch 26, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 56%|█████▌    | 28/50 [11:00<09:02, 24.68s/it]

Epoch 27, val/acc: 0.5205238461494446, val/f1: 0.6597748398780823


 58%|█████▊    | 29/50 [11:24<08:31, 24.38s/it]

Epoch 28, val/acc: 0.5132873058319092, val/f1: 0.6511741280555725


 60%|██████    | 30/50 [11:47<07:58, 23.92s/it]

Epoch 29, val/acc: 0.5119010210037231, val/f1: 0.6769310235977173


 62%|██████▏   | 31/50 [12:10<07:29, 23.66s/it]

Epoch 30, val/acc: 0.5122584700584412, val/f1: 0.6456407308578491


 64%|██████▍   | 32/50 [12:45<08:05, 26.97s/it]

Epoch 31, val/acc: 0.5217357277870178, val/f1: 0.6668730974197388


 66%|██████▌   | 33/50 [13:20<08:22, 29.56s/it]

Epoch 32, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 68%|██████▊   | 34/50 [13:54<08:13, 30.87s/it]

Epoch 33, val/acc: 0.5189980268478394, val/f1: 0.665871262550354


 70%|███████   | 35/50 [14:30<08:05, 32.36s/it]

Epoch 34, val/acc: 0.5120143890380859, val/f1: 0.6772612929344177


 72%|███████▏  | 36/50 [15:05<07:46, 33.33s/it]

Epoch 35, val/acc: 0.5064780116081238, val/f1: 0.6705851554870605


 74%|███████▍  | 37/50 [15:43<07:27, 34.45s/it]

Epoch 36, val/acc: 0.5175333023071289, val/f1: 0.6663532257080078


 76%|███████▌  | 38/50 [16:20<07:04, 35.38s/it]

Epoch 37, val/acc: 0.5107937455177307, val/f1: 0.6428662538528442


 78%|███████▊  | 39/50 [16:58<06:36, 36.02s/it]

Epoch 38, val/acc: 0.5109855532646179, val/f1: 0.6717734932899475


 80%|████████  | 40/50 [17:35<06:03, 36.38s/it]

Epoch 39, val/acc: 0.5110727548599243, val/f1: 0.6760818362236023


 82%|████████▏ | 41/50 [18:13<05:30, 36.77s/it]

Epoch 40, val/acc: 0.5067046880722046, val/f1: 0.6704142689704895


 84%|████████▍ | 42/50 [18:50<04:54, 36.85s/it]

Epoch 41, val/acc: 0.48809024691581726, val/f1: 0.16287890076637268


 86%|████████▌ | 43/50 [19:28<04:21, 37.34s/it]

Epoch 42, val/acc: 0.5037316083908081, val/f1: 0.6628082394599915


 88%|████████▊ | 44/50 [20:06<03:44, 37.42s/it]

Epoch 43, val/acc: 0.505518913269043, val/f1: 0.6688117384910583


 90%|█████████ | 45/50 [20:43<03:06, 37.29s/it]

Epoch 44, val/acc: 0.48798564076423645, val/f1: 0.0


In [10]:
# Test-set evaluation of the checkpoint with the best validation accuracy.
# Guarded by the tracking flag because the results are reported to wandb.
if track_with_wandb:
    best_model = RecurrentGNN(gnn = gnn, node_features = lags).to(device)
    # map_location keeps the load working when the checkpoint was written on a
    # different device than the one used now.
    best_model.load_state_dict(
        torch.load('dcrnn_best_model.pth', map_location=device, weights_only=True)
    )
    best_model.eval()
    with torch.no_grad():
        y_hats, ys = zip(*[(best_model(snapshot.x.to(device), snapshot.edge_index.to(device), snapshot.edge_attr.to(device)), snapshot.y.to(device))
                       for snapshot in test_dataset])
        # Predicted class per node = arg-max over the log-probabilities.
        y_hats = torch.stack([y_hat.argmax(dim=-1) for y_hat in y_hats]).to(device)
        ys = torch.stack(list(ys)).to(device)
    test_acc, test_f1 = accuracy(y_hats, ys), f1(y_hats, ys)
    # `epoch` is the last value left behind by the training loop.
    wandb.log({"epoch": epoch,
            "test/acc": test_acc,
            "test/f1": test_f1 })
    # Also show the numbers in the notebook, not just in the wandb dashboard.
    print(f'test/acc: {test_acc}, test/f1: {test_f1}')

: 