# Dataset Construction

In [30]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch, DynamicGraphTemporalSignalBatch
import torch
from typing import Union
import glob
from natsort import natsorted

class SP500CorrelationsDatasetLoader(object):
    def __init__(self, corr_name, corr_scope):
        self._read_csv(corr_name, corr_scope)

    def _read_csv(self, corr_name, corr_scope):
        match corr_scope:
            case 'global':
                self._correlation_matrices = [np.loadtxt(f'{corr_name}/{corr_scope}_corr.csv', delimiter=',')]
            case 'local':
                self._correlation_matrices = []
                corr_files = natsorted(glob.glob(f'{corr_name}/local_corr_*.csv'))
                for corr_file in corr_files:
                    matrix = np.loadtxt(corr_file, delimiter=',')
                    self._correlation_matrices.append(matrix)
        
        # Thesholding
        _correlation_threshold = 0.9
        for matrix in self._correlation_matrices:
            matrix[matrix < _correlation_threshold] = 0
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)

        # Round data size to nearest multiple of batch_size
        self.days_in_quarter = 64
        num_quarters = data.size(0) // self.days_in_quarter
        num_days = num_quarters * self.days_in_quarter
        data = data[:num_days]
        
        # z-score normalization with training data following GERU
        train_days = int(0.8 * num_quarters) * self.days_in_quarter
        data = (data - data[:train_days].mean(dim=0)) / data[:train_days].std(dim=0)
        data = data.numpy()
        np.savetxt('s&p500_z_scores.csv', data, delimiter=',')
        self._dataset = data

    def _get_edges(self):
        if len(self._correlation_matrices) == 1:
            self._edges = np.array(self._correlation_matrix.nonzero())
        else:
            self._edges = []
            for time in range(self._dataset.shape[0] - self.lags):
                corr_index = time // self.days_in_quarter
                self._edges.append(
                    np.array(self._correlation_matrices[corr_index].nonzero())
                )

    def _get_edge_weights(self):
        if len(self._correlation_matrices) == 1:
            self._edge_weights = self._correlation_matrix[self._correlation_matrix > 0]
        else:
            self._edge_weights = []
            for time in range(self._dataset.shape[0] - self.lags):
                corr_index = time // self.days_in_quarter
                self._edge_weights.append(
                    np.array(self._correlation_matrices[corr_index][self._correlation_matrices[corr_index] > 0])
                )

    def _get_targets_and_features(self):
        stacked_target = self._dataset
        self.features = [
            stacked_target[i : i + self.lags, :].T
            for i in range(stacked_target.shape[0] - self.lags)
        ]
        # predict next-day stock movement
        self.targets = [
            ((stacked_target[i + self.lags, :] > stacked_target[i + self.lags - 1, :]).astype(float)).T
            for i in range(stacked_target.shape[0] - self.lags)
        ]

    def get_dataset(self) -> Union[StaticGraphTemporalSignalBatch, DynamicGraphTemporalSignalBatch]:
        """Returning the data iterator.
        """
        self.lags = self.days_in_quarter
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        self.batches = np.repeat(np.arange((self._dataset.shape[0] - self.lags) // self.days_in_quarter), self.days_in_quarter)
        dataset = (DynamicGraphTemporalSignalBatch if type(self._edges) == list else StaticGraphTemporalSignalBatch)(
            self._edges, self._edge_weights, self.features, self.targets, self.batches
        )
        return dataset

In [31]:
from torch_geometric_temporal.signal import temporal_signal_split

# Run everything on the CPU. For a GPU, use
# torch.device("cuda" if torch.cuda.is_available() else "cpu") instead.
device = 'cpu'

# Correlation source: sub-directory name, and 'global' (one matrix) or 'local' (one per quarter).
corr_name, corr_scope = 'mi', 'local'
loader = SP500CorrelationsDatasetLoader(corr_name=corr_name, corr_scope=corr_scope)

dataset = loader.get_dataset()
lags = loader.lags  # feature-window length, fixed inside get_dataset()

# 80% train; the remaining 20% is halved into validation and test.
train_dataset, test_val_dataset = temporal_signal_split(dataset, train_ratio=0.8)
val_dataset, test_dataset = temporal_signal_split(test_val_dataset, train_ratio=0.5)

# Evaluation

In [32]:
from torcheval.metrics.functional import binary_f1_score, binary_accuracy

def accuracy(y_hats, ys):
    """Score predictions ``y_hats`` against 0/1 labels ``ys`` with torcheval's binary_accuracy.

    Both tensors are flattened first, so they only need the same number of
    elements. Returns a plain Python number.
    """
    flat_pred = y_hats.reshape(-1)
    flat_true = ys.reshape(-1)
    return binary_accuracy(flat_pred, flat_true).item()

def f1(y_hats, ys):
    """Score predictions ``y_hats`` against 0/1 labels ``ys`` with torcheval's binary_f1_score.

    Both tensors are flattened first, so they only need the same number of
    elements. Returns a plain Python number.
    """
    score = binary_f1_score(y_hats.reshape(-1), ys.reshape(-1))
    return score.item()


# Recurrent GCN (DCRNN)

In [33]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    """DCRNN recurrent graph layer followed by a ReLU and a 1-output linear head with sigmoid.

    Args:
        node_features: Number of input features per node.
        hidden_channels: Width of the DCRNN output (default 64, the original value).
        K: Filter size passed as DCRNN's third argument (default 1, the original value).

    Submodule names (``recurrent``, ``linear``) are unchanged, so checkpoints
    saved from the previous version still load with the default arguments.
    """

    def __init__(self, node_features, hidden_channels=64, K=1):
        super().__init__()
        self.recurrent = DCRNN(node_features, hidden_channels, K)
        self.linear = torch.nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, edge_weight):
        """Return a value in (0, 1) for each row of the DCRNN output (presumably one per node)."""
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        return torch.sigmoid(self.linear(h))

In [34]:
from tqdm import tqdm
import wandb

model = RecurrentGCN(node_features = lags).to(device)

lr = 1e-3
num_epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
track_with_wandb = True

if track_with_wandb:
    wandb.init(project="cs224w-stock-market-prediction", config={
        "dataset": "S&P500",
        "corr_name": corr_name,
        "corr_scope": corr_scope,
        "learning_rate": lr,
        "epochs": num_epochs,
        "architecture": "DCRNN",
    })


def _predict_all(net, snapshots):
    """Run ``net`` on every snapshot; return stacked ``(predictions, targets)`` on ``device``.

    Predictions are squeezed, so the trailing size-1 output dimension of each
    snapshot's prediction is dropped to line up with the stacked targets.
    """
    predictions, targets = [], []
    for snapshot in snapshots:
        predictions.append(net(snapshot.x.to(device), snapshot.edge_index.to(device), snapshot.edge_attr.to(device)))
        targets.append(snapshot.y)
    return torch.stack(predictions).squeeze().to(device), torch.stack(targets).to(device)


# Start below any attainable F1 so the first epoch always writes a checkpoint.
# With 0 and a strict '>', a run whose F1 stays 0 would never save one.
best_f1 = float('-inf')

for epoch in tqdm(range(num_epochs)):
    model.train()
    train_loss = 0.0
    num_train_snapshots = 0
    for snapshot in train_dataset:
        y_hat = model(snapshot.x.to(device), snapshot.edge_index.to(device), snapshot.edge_attr.to(device))
        # The target must be on the same device as the prediction.
        loss = F.binary_cross_entropy(y_hat.squeeze(), snapshot.y.to(device))
        train_loss += loss.item()
        num_train_snapshots += 1
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # Mean loss per snapshot; max() guards against an empty training split.
    train_loss /= max(num_train_snapshots, 1)

    # Validate and checkpoint every epoch; only the wandb logging is optional.
    model.eval()
    with torch.no_grad():
        val_y_hats, val_ys = _predict_all(model, val_dataset)
        val_acc = accuracy(val_y_hats, val_ys)
        val_f1 = f1(val_y_hats, val_ys)
    if track_with_wandb:
        wandb.log({"epoch": epoch,
                   "train/loss": train_loss,
                   "val/acc": val_acc,
                   "val/f1": val_f1 })
    print(f'Epoch {epoch}, val/acc: {val_acc}, val/f1: {val_f1}')
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), 'dcrnn_best_model.pth')


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇███
test/acc,▁
test/f1,▁
train/loss,█▄▃▂▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
val/acc,▁▁▁▁▁███▂▇█▁▁█▂▂▁▂▁▁
val/f1,▁▂▂▃▃███▃██▂▂█▂▂▂▂▂▂

0,1
epoch,19.0
test/acc,0.53708
test/f1,0.69185
train/loss,0.69197
val/acc,0.49125
val/f1,0.27127


  5%|▌         | 1/20 [00:10<03:25, 10.80s/it]

Epoch 0, val/acc: 0.5114389061927795, val/f1: 0.6593183279037476


 10%|█         | 2/20 [00:23<03:34, 11.94s/it]

Epoch 1, val/acc: 0.4902612268924713, val/f1: 0.23769477009773254


 15%|█▌        | 3/20 [00:36<03:30, 12.37s/it]

Epoch 2, val/acc: 0.5117702484130859, val/f1: 0.6603999733924866


 20%|██        | 4/20 [00:48<03:18, 12.41s/it]

Epoch 3, val/acc: 0.4902088940143585, val/f1: 0.2013576626777649


 25%|██▌       | 5/20 [01:01<03:07, 12.48s/it]

Epoch 4, val/acc: 0.5118748545646667, val/f1: 0.6542878746986389


 30%|███       | 6/20 [01:13<02:54, 12.45s/it]

Epoch 5, val/acc: 0.5114215016365051, val/f1: 0.655341625213623


 35%|███▌      | 7/20 [01:25<02:39, 12.28s/it]

Epoch 6, val/acc: 0.5115784406661987, val/f1: 0.655494749546051


 40%|████      | 8/20 [01:38<02:28, 12.38s/it]

Epoch 7, val/acc: 0.5117179155349731, val/f1: 0.6622521281242371


 45%|████▌     | 9/20 [01:51<02:17, 12.46s/it]

Epoch 8, val/acc: 0.5113255977630615, val/f1: 0.6542853116989136


 50%|█████     | 10/20 [02:02<02:02, 12.21s/it]

Epoch 9, val/acc: 0.5117179155349731, val/f1: 0.6616277098655701


 55%|█████▌    | 11/20 [02:14<01:48, 12.07s/it]

Epoch 10, val/acc: 0.49080178141593933, val/f1: 0.2851267158985138


 60%|██████    | 12/20 [02:25<01:35, 11.91s/it]

Epoch 11, val/acc: 0.5110989212989807, val/f1: 0.6507166028022766


 65%|██████▌   | 13/20 [02:38<01:23, 11.98s/it]

Epoch 12, val/acc: 0.5112035274505615, val/f1: 0.6653294563293457


 70%|███████   | 14/20 [02:50<01:12, 12.04s/it]

Epoch 13, val/acc: 0.49496060609817505, val/f1: 0.39971813559532166


 75%|███████▌  | 15/20 [03:03<01:01, 12.25s/it]

Epoch 14, val/acc: 0.5111250877380371, val/f1: 0.6674337387084961


 80%|████████  | 16/20 [03:14<00:48, 12.12s/it]

Epoch 15, val/acc: 0.5113255977630615, val/f1: 0.6674261689186096


 85%|████████▌ | 17/20 [03:26<00:35, 11.98s/it]

Epoch 16, val/acc: 0.5110030174255371, val/f1: 0.6607674360275269


 90%|█████████ | 18/20 [03:38<00:24, 12.08s/it]

Epoch 17, val/acc: 0.5108983516693115, val/f1: 0.6506843566894531


 95%|█████████▌| 19/20 [03:51<00:12, 12.18s/it]

Epoch 18, val/acc: 0.5118051171302795, val/f1: 0.6670867800712585


100%|██████████| 20/20 [04:03<00:00, 12.16s/it]

Epoch 19, val/acc: 0.5117789506912231, val/f1: 0.6646665334701538





In [35]:
if track_with_wandb:
    # Rebuild the architecture and load the best validation checkpoint.
    best_model = RecurrentGCN(node_features = lags).to(device)
    # map_location lets the checkpoint load even if it was saved from another device.
    state = torch.load('dcrnn_best_model.pth', map_location=device, weights_only=True)
    best_model.load_state_dict(state)
    best_model.eval()
    with torch.no_grad():
        predictions, targets = [], []
        for snapshot in test_dataset:
            predictions.append(best_model(snapshot.x.to(device), snapshot.edge_index.to(device), snapshot.edge_attr.to(device)))
            targets.append(snapshot.y)
        y_hats = torch.stack(predictions).squeeze().to(device)
        ys = torch.stack(targets).to(device)
        test_acc = accuracy(y_hats, ys)
        test_f1 = f1(y_hats, ys)
        # `epoch` is the last value left over from the training loop.
        wandb.log({"epoch": epoch,
                "test/acc": test_acc,
                "test/f1": test_f1 })
        print(f'test/acc: {test_acc}, test/f1: {test_f1}')