# Dataset Construction

In [1]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
import torch

class StaticSP500DatasetLoader(object):
    def __init__(self, correlation_type='pearsons'):
        self._read_csv(correlation_type)

    def _read_csv(self, correlation_type):
        self._correlation_matrix = np.fromfile(f's&p500_{correlation_type}.csv', sep=',')
        N = int(np.sqrt(len(self._correlation_matrix)))
        self._correlation_matrix = self._correlation_matrix.reshape(N, N)
        _correlation_threshold = 0.9
        self._correlation_matrix[self._correlation_matrix < _correlation_threshold] = 0
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)
        ratios = torch.zeros_like(data)
        # calculate daily return ratio
        for d in range(1, data.size(0)):
            ratios[d-1] = (data[d] - data[d-1]) / data[d-1]
        # skip the first day which cannot calculate daily return ratio
        # and round data size to nearest multiple of batch_size
        days_in_quarter = 64
        num_quarters = data.size(0) // days_in_quarter
        num_days = num_quarters * days_in_quarter
        ratios = ratios[:num_days]
        self._dataset = ratios.numpy()

    def _get_edges(self):
        self._edges = np.array(self._correlation_matrix.nonzero())

    def _get_edge_weights(self):
        self._edge_weights = self._correlation_matrix[self._correlation_matrix > 0]

    def _get_targets_and_features(self):
        stacked_target = self._dataset
        self.features = [
            stacked_target[i : i + self.lags, :].T
            for i in range(stacked_target.shape[0] - self.lags)
        ]
        self.targets = [
            stacked_target[i + self.lags, :].T
            for i in range(stacked_target.shape[0] - self.lags)
        ]

    def get_dataset(self, lags: int) -> StaticGraphTemporalSignal:
        """Returning the data iterator.

        Args types:
            * **lags** *(int)* - The number of time lags.
        Return types:
            * **dataset** *(StaticGraphTemporalSignal)*
        """
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = StaticGraphTemporalSignal(
            self._edges, self._edge_weights, self.features, self.targets
        )
        return dataset

In [2]:
from torch_geometric_temporal.signal import temporal_signal_split

device = 'cpu'

loader = StaticSP500DatasetLoader()

lags = 63

dataset = loader.get_dataset(lags)

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

# RGCN

In [3]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [4]:
from tqdm import tqdm

model = RecurrentGCN(node_features = lags).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(100)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x.to(device), snapshot.edge_index.to(device), snapshot.edge_attr.to(device))
        cost = cost + torch.mean((y_hat-snapshot.y.to(device))**2)
    cost = cost / (time+1)
    print(f'Epoch {epoch}, MSE: {cost.item()}')
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0, MSE: 0.010833659209311008


  1%|          | 1/100 [00:09<16:25,  9.95s/it]

Epoch 1, MSE: 0.007697213441133499


  2%|▏         | 2/100 [00:21<18:14, 11.17s/it]

Epoch 2, MSE: 0.0050129820592701435


  3%|▎         | 3/100 [00:34<18:45, 11.60s/it]

Epoch 3, MSE: 0.002805506344884634


  4%|▍         | 4/100 [00:47<19:37, 12.27s/it]

Epoch 4, MSE: 0.0012611948186531663


  5%|▌         | 5/100 [01:00<19:50, 12.54s/it]

Epoch 5, MSE: 0.0005417628563009202


  6%|▌         | 6/100 [01:13<19:56, 12.73s/it]

Epoch 6, MSE: 0.0006845638272352517


  7%|▋         | 7/100 [01:26<19:46, 12.76s/it]

Epoch 7, MSE: 0.0013658736133947968


  8%|▊         | 8/100 [01:38<19:26, 12.68s/it]

Epoch 8, MSE: 0.0019495455780997872


  9%|▉         | 9/100 [01:51<19:18, 12.73s/it]

Epoch 9, MSE: 0.0021083292085677385


 10%|█         | 10/100 [02:06<20:18, 13.54s/it]

Epoch 10, MSE: 0.0018905947217717767


 11%|█         | 11/100 [02:19<19:36, 13.22s/it]

Epoch 11, MSE: 0.0014780224300920963


 12%|█▏        | 12/100 [02:32<19:24, 13.24s/it]

Epoch 12, MSE: 0.0010442928178235888


 13%|█▎        | 13/100 [02:46<19:13, 13.25s/it]

Epoch 13, MSE: 0.0007028182735666633


 14%|█▍        | 14/100 [03:00<19:39, 13.72s/it]

Epoch 14, MSE: 0.0004998959484510124


 15%|█▌        | 15/100 [03:16<20:09, 14.23s/it]

Epoch 15, MSE: 0.0004297230625525117


 16%|█▌        | 16/100 [03:29<19:27, 13.89s/it]

Epoch 16, MSE: 0.0004570379969663918


 17%|█▋        | 17/100 [03:41<18:34, 13.43s/it]

Epoch 17, MSE: 0.0005376611952669919


 18%|█▊        | 18/100 [03:54<17:58, 13.16s/it]

Epoch 18, MSE: 0.0006320815300568938


 19%|█▉        | 19/100 [04:08<18:00, 13.34s/it]

Epoch 19, MSE: 0.0007118212524801493


 20%|██        | 20/100 [04:21<17:56, 13.45s/it]

Epoch 20, MSE: 0.000760625465773046


 21%|██        | 21/100 [04:34<17:14, 13.10s/it]

Epoch 21, MSE: 0.0007729039643891156


 22%|██▏       | 22/100 [04:48<17:29, 13.45s/it]

Epoch 22, MSE: 0.0007511576404795051


 23%|██▎       | 23/100 [05:01<17:04, 13.31s/it]

Epoch 23, MSE: 0.0007032848079688847


 24%|██▍       | 24/100 [05:14<16:44, 13.22s/it]

Epoch 24, MSE: 0.0006401057471521199


 25%|██▌       | 25/100 [05:26<16:13, 12.98s/it]

Epoch 25, MSE: 0.0005732049467042089


 26%|██▌       | 26/100 [05:39<15:56, 12.93s/it]

Epoch 26, MSE: 0.0005130909848958254


 27%|██▋       | 27/100 [05:52<15:52, 13.05s/it]

Epoch 27, MSE: 0.00046772274072282016


 28%|██▊       | 28/100 [06:05<15:37, 13.01s/it]

Epoch 28, MSE: 0.00044148764573037624


 29%|██▉       | 29/100 [06:18<15:09, 12.82s/it]

Epoch 29, MSE: 0.00043477045255713165


 30%|███       | 30/100 [06:30<14:47, 12.68s/it]

Epoch 30, MSE: 0.00044421630445867777


 31%|███       | 31/100 [06:43<14:40, 12.76s/it]

Epoch 31, MSE: 0.0004637207603082061


 32%|███▏      | 32/100 [06:55<14:15, 12.58s/it]

Epoch 32, MSE: 0.00048596897977404296


 33%|███▎      | 33/100 [07:07<13:50, 12.40s/it]

Epoch 33, MSE: 0.000504193885717541


 34%|███▍      | 34/100 [07:20<13:39, 12.42s/it]

Epoch 34, MSE: 0.0005136636318638921


 35%|███▌      | 35/100 [07:32<13:27, 12.42s/it]

Epoch 35, MSE: 0.000512521481141448


 36%|███▌      | 36/100 [07:45<13:17, 12.46s/it]

Epoch 36, MSE: 0.0005017827497795224


 37%|███▋      | 37/100 [07:57<13:06, 12.48s/it]

Epoch 37, MSE: 0.0004846023803111166


 38%|███▊      | 38/100 [08:11<13:12, 12.78s/it]

Epoch 38, MSE: 0.0004651318013202399


 39%|███▉      | 39/100 [08:23<12:56, 12.73s/it]

Epoch 39, MSE: 0.0004473435692489147


 40%|████      | 40/100 [08:36<12:37, 12.62s/it]

Epoch 40, MSE: 0.00043413706589490175


 41%|████      | 41/100 [08:48<12:16, 12.48s/it]

Epoch 41, MSE: 0.0004268856719136238


 42%|████▏     | 42/100 [09:01<12:10, 12.59s/it]

Epoch 42, MSE: 0.00042543330346234143


 43%|████▎     | 43/100 [09:13<11:52, 12.50s/it]

Epoch 43, MSE: 0.00042843198752962053


 44%|████▍     | 44/100 [09:25<11:32, 12.37s/it]

Epoch 44, MSE: 0.0004338687576819211


 45%|████▌     | 45/100 [09:38<11:36, 12.66s/it]

Epoch 45, MSE: 0.00043962447671219707


 46%|████▌     | 46/100 [09:51<11:18, 12.56s/it]

Epoch 46, MSE: 0.000443935306975618


 47%|████▋     | 47/100 [10:03<11:11, 12.67s/it]

Epoch 47, MSE: 0.00044570371392183006


 48%|████▊     | 48/100 [10:16<10:58, 12.66s/it]

Epoch 48, MSE: 0.0004446028033271432


 49%|████▉     | 49/100 [10:29<10:55, 12.86s/it]

Epoch 49, MSE: 0.00044101770617999136


 50%|█████     | 50/100 [10:44<11:08, 13.37s/it]

Epoch 50, MSE: 0.0004358432488515973


 51%|█████     | 51/100 [10:56<10:41, 13.08s/it]

Epoch 51, MSE: 0.00043020903831347823


 52%|█████▏    | 52/100 [11:10<10:29, 13.12s/it]

Epoch 52, MSE: 0.00042519738781265914


 53%|█████▎    | 53/100 [11:24<10:32, 13.46s/it]

Epoch 53, MSE: 0.00042160460725426674


 54%|█████▍    | 54/100 [11:38<10:24, 13.59s/it]

Epoch 54, MSE: 0.0004197918460704386


 55%|█████▌    | 55/100 [11:51<10:12, 13.60s/it]

Epoch 55, MSE: 0.00041966239223256707


 56%|█████▌    | 56/100 [12:05<10:02, 13.69s/it]

Epoch 56, MSE: 0.000420731637859717


 57%|█████▋    | 57/100 [12:19<09:48, 13.68s/it]

Epoch 57, MSE: 0.0004222900024615228


 58%|█████▊    | 58/100 [12:32<09:30, 13.59s/it]

Epoch 58, MSE: 0.00042360491352155805


 59%|█████▉    | 59/100 [12:45<09:08, 13.37s/it]

Epoch 59, MSE: 0.00042411487083882093


 60%|██████    | 60/100 [12:58<08:51, 13.30s/it]

Epoch 60, MSE: 0.0004235676024109125


 61%|██████    | 61/100 [13:12<08:41, 13.38s/it]

Epoch 61, MSE: 0.0004220517002977431


 62%|██████▏   | 62/100 [13:27<08:42, 13.75s/it]

Epoch 62, MSE: 0.00041992616024799645


 63%|██████▎   | 63/100 [13:41<08:38, 14.02s/it]

Epoch 63, MSE: 0.000417673378251493


 64%|██████▍   | 64/100 [13:55<08:27, 14.10s/it]

Epoch 64, MSE: 0.00041573456837795675


 65%|██████▌   | 65/100 [14:09<08:02, 13.80s/it]

Epoch 65, MSE: 0.00041438304469920695


 66%|██████▌   | 66/100 [14:22<07:46, 13.72s/it]

Epoch 66, MSE: 0.0004136827774345875


 67%|██████▋   | 67/100 [14:35<07:28, 13.58s/it]

Epoch 67, MSE: 0.0004135049821343273


 68%|██████▊   | 68/100 [14:48<07:05, 13.31s/it]

Epoch 68, MSE: 0.00041360006434842944


 69%|██████▉   | 69/100 [15:01<06:48, 13.17s/it]

Epoch 69, MSE: 0.0004136963689234108


 70%|███████   | 70/100 [15:13<06:27, 12.91s/it]

Epoch 70, MSE: 0.00041358298039995134


 71%|███████   | 71/100 [15:26<06:13, 12.87s/it]

Epoch 71, MSE: 0.00041316283750347793


 72%|███████▏  | 72/100 [15:39<06:02, 12.94s/it]

Epoch 72, MSE: 0.00041246626642532647


 73%|███████▎  | 73/100 [15:52<05:53, 13.09s/it]

Epoch 73, MSE: 0.0004116120981052518


 74%|███████▍  | 74/100 [16:05<05:36, 12.94s/it]

Epoch 74, MSE: 0.0004107549611944705


 75%|███████▌  | 75/100 [16:18<05:22, 12.91s/it]

Epoch 75, MSE: 0.0004100161022506654


 76%|███████▌  | 76/100 [16:32<05:16, 13.21s/it]

Epoch 76, MSE: 0.00040944464853964746


 77%|███████▋  | 77/100 [16:45<05:04, 13.24s/it]

Epoch 77, MSE: 0.00040902523323893547


 78%|███████▊  | 78/100 [16:58<04:49, 13.18s/it]

Epoch 78, MSE: 0.0004087116976734251


 79%|███████▉  | 79/100 [17:12<04:41, 13.40s/it]

Epoch 79, MSE: 0.000408459163736552


 80%|████████  | 80/100 [17:26<04:30, 13.53s/it]

Epoch 80, MSE: 0.0004082299128640443


 81%|████████  | 81/100 [17:38<04:11, 13.24s/it]

Epoch 81, MSE: 0.0004079968493897468


 82%|████████▏ | 82/100 [17:51<03:56, 13.15s/it]

Epoch 82, MSE: 0.000407740066293627


 83%|████████▎ | 83/100 [18:04<03:43, 13.13s/it]

Epoch 83, MSE: 0.00040744198486208916


 84%|████████▍ | 84/100 [18:16<03:24, 12.76s/it]

Epoch 84, MSE: 0.00040710344910621643


 85%|████████▌ | 85/100 [18:29<03:10, 12.73s/it]

Epoch 85, MSE: 0.00040673959301784635


 86%|████████▌ | 86/100 [18:42<02:57, 12.68s/it]

Epoch 86, MSE: 0.00040637943311594427


 87%|████████▋ | 87/100 [18:55<02:46, 12.83s/it]

Epoch 87, MSE: 0.0004060568171553314


 88%|████████▊ | 88/100 [19:09<02:38, 13.18s/it]

Epoch 88, MSE: 0.000405793049139902


 89%|████████▉ | 89/100 [19:22<02:23, 13.04s/it]

Epoch 89, MSE: 0.0004055896133650094


 90%|█████████ | 90/100 [19:34<02:07, 12.79s/it]

Epoch 90, MSE: 0.000405429134843871


 91%|█████████ | 91/100 [19:47<01:55, 12.79s/it]

Epoch 91, MSE: 0.0004052751755807549


 92%|█████████▏| 92/100 [19:59<01:42, 12.76s/it]

Epoch 92, MSE: 0.0004050922580063343


 93%|█████████▎| 93/100 [20:12<01:29, 12.72s/it]

Epoch 93, MSE: 0.00040486891521140933


 94%|█████████▍| 94/100 [20:27<01:20, 13.48s/it]

Epoch 94, MSE: 0.00040461879689246416


 95%|█████████▌| 95/100 [20:40<01:06, 13.25s/it]

Epoch 95, MSE: 0.0004043695516884327


 96%|█████████▌| 96/100 [20:54<00:53, 13.49s/it]

Epoch 96, MSE: 0.000404141319449991


 97%|█████████▋| 97/100 [21:08<00:41, 13.79s/it]

Epoch 97, MSE: 0.000403938494855538


 98%|█████████▊| 98/100 [21:22<00:27, 13.82s/it]

Epoch 98, MSE: 0.0004037547914776951


 99%|█████████▉| 99/100 [21:35<00:13, 13.53s/it]

Epoch 99, MSE: 0.0004035794408991933


100%|██████████| 100/100 [21:47<00:00, 13.08s/it]


In [5]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {}".format(cost))

MSE: 0.0003441054723225534
