# Dataset Construction

In [9]:
import numpy as np
import pandas as pd
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
import torch

class StaticSP500DatasetLoader(object):
    def __init__(self, correlation_type='pearsons'):
        self._read_csv(correlation_type)

    def _read_csv(self, correlation_type):
        self._correlation_matrix = np.fromfile(f's&p500_{correlation_type}.csv', sep=',')
        N = int(np.sqrt(len(self._correlation_matrix)))
        self._correlation_matrix = self._correlation_matrix.reshape(N, N)
        _correlation_threshold = 0.9
        self._correlation_matrix[self._correlation_matrix < _correlation_threshold] = 0
        
        df = pd.read_csv('s&p500.csv')
        df = df.set_index('Date')
        data = torch.from_numpy(df.to_numpy()).to(torch.float32)
        ratios = torch.zeros_like(data)
        # calculate daily return ratio
        for d in range(1, data.size(0)):
            ratios[d-1] = (data[d] - data[d-1]) / data[d-1]
        # skip the first day which cannot calculate daily return ratio
        # and round data size to nearest multiple of batch_size
        days_in_quarter = 64
        num_quarters = data.size(0) // days_in_quarter
        num_days = num_quarters * days_in_quarter
        ratios = ratios[:num_days]
        self._dataset = ratios.numpy()

    def _get_edges(self):
        self._edges = np.array(self._correlation_matrix.nonzero())

    def _get_edge_weights(self):
        self._edge_weights = self._correlation_matrix[self._correlation_matrix > 0]

    def _get_targets_and_features(self):
        stacked_target = self._dataset
        self.features = [
            stacked_target[i : i + self.lags, :].T
            for i in range(stacked_target.shape[0] - self.lags)
        ]
        self.targets = [
            stacked_target[i + self.lags, :].T
            for i in range(stacked_target.shape[0] - self.lags)
        ]

    def get_dataset(self, lags: int) -> StaticGraphTemporalSignal:
        """Returning the data iterator.

        Args types:
            * **lags** *(int)* - The number of time lags.
        Return types:
            * **dataset** *(StaticGraphTemporalSignal)*
        """
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        dataset = StaticGraphTemporalSignal(
            self._edges, self._edge_weights, self.features, self.targets
        )
        return dataset

In [10]:
from torch_geometric_temporal.signal import temporal_signal_split

device = 'cpu'

loader = StaticSP500DatasetLoader()

lags = 64

dataset = loader.get_dataset(lags)

train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

# RGCN

In [11]:
import torch
import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import DCRNN

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [12]:
from tqdm import tqdm

model = RecurrentGCN(node_features = lags).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for epoch in tqdm(range(100)):
    cost = 0
    for time, snapshot in enumerate(train_dataset):
        y_hat = model(snapshot.x.to(device), snapshot.edge_index.to(device), snapshot.edge_attr.to(device))
        cost = cost + torch.mean((y_hat-snapshot.y.to(device))**2)
    cost = cost / (time+1)
    print(f'Epoch {epoch}, MSE: {cost.item()}')
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0, MSE: 0.022084178403019905


  1%|          | 1/100 [00:12<20:15, 12.28s/it]

Epoch 1, MSE: 0.01684962399303913


  2%|▏         | 2/100 [00:35<30:44, 18.82s/it]

Epoch 2, MSE: 0.01205836609005928


  3%|▎         | 3/100 [00:53<29:23, 18.19s/it]

Epoch 3, MSE: 0.007785529363900423


  4%|▍         | 4/100 [01:06<26:23, 16.49s/it]

Epoch 4, MSE: 0.0043006958439946175


  5%|▌         | 5/100 [01:21<24:49, 15.68s/it]

Epoch 5, MSE: 0.0018705528927966952


  6%|▌         | 6/100 [01:36<24:08, 15.41s/it]

Epoch 6, MSE: 0.0007183755515143275


  7%|▋         | 7/100 [01:53<24:46, 15.98s/it]

Epoch 7, MSE: 0.0008846055716276169


  8%|▊         | 8/100 [02:07<23:48, 15.53s/it]

Epoch 8, MSE: 0.0019559580832719803


  9%|▉         | 9/100 [02:22<23:08, 15.26s/it]

Epoch 9, MSE: 0.003061587456613779


 10%|█         | 10/100 [02:38<23:25, 15.62s/it]

Epoch 10, MSE: 0.003575196024030447


 11%|█         | 11/100 [02:53<22:32, 15.20s/it]

Epoch 11, MSE: 0.0034212584141641855


 12%|█▏        | 12/100 [03:08<22:19, 15.22s/it]

Epoch 12, MSE: 0.0028189164586365223


 13%|█▎        | 13/100 [03:23<22:12, 15.32s/it]

Epoch 13, MSE: 0.0020476889330893755


 14%|█▍        | 14/100 [03:39<22:01, 15.37s/it]

Epoch 14, MSE: 0.0013331820955500007


 15%|█▌        | 15/100 [03:58<23:30, 16.59s/it]

Epoch 15, MSE: 0.0008082785061560571


 16%|█▌        | 16/100 [04:13<22:17, 15.92s/it]

Epoch 16, MSE: 0.0005158912390470505


 17%|█▋        | 17/100 [04:29<22:04, 15.95s/it]

Epoch 17, MSE: 0.0004327164788264781


 18%|█▊        | 18/100 [04:44<21:28, 15.72s/it]

Epoch 18, MSE: 0.0004995304043404758


 19%|█▉        | 19/100 [04:58<20:39, 15.30s/it]

Epoch 19, MSE: 0.0006475562695413828


 20%|██        | 20/100 [05:13<20:06, 15.08s/it]

Epoch 20, MSE: 0.0008159090648405254


 21%|██        | 21/100 [05:27<19:27, 14.78s/it]

Epoch 21, MSE: 0.000959953002166003


 22%|██▏       | 22/100 [05:41<18:44, 14.42s/it]

Epoch 22, MSE: 0.0010530605213716626


 23%|██▎       | 23/100 [05:55<18:22, 14.32s/it]

Epoch 23, MSE: 0.0010847091907635331


 24%|██▍       | 24/100 [06:10<18:23, 14.52s/it]

Epoch 24, MSE: 0.0010570604354143143


 25%|██▌       | 25/100 [06:24<18:15, 14.61s/it]

Epoch 25, MSE: 0.0009812420466914773


 26%|██▌       | 26/100 [06:39<17:50, 14.47s/it]

Epoch 26, MSE: 0.0008738181786611676


 27%|██▋       | 27/100 [06:54<17:55, 14.73s/it]

Epoch 27, MSE: 0.0007535990444011986


 28%|██▊       | 28/100 [07:08<17:21, 14.46s/it]

Epoch 28, MSE: 0.0006387936882674694


 29%|██▉       | 29/100 [07:22<16:53, 14.27s/it]

Epoch 29, MSE: 0.0005445484421215951


 30%|███       | 30/100 [07:35<16:25, 14.08s/it]

Epoch 30, MSE: 0.0004809717647731304


 31%|███       | 31/100 [07:48<15:53, 13.82s/it]

Epoch 31, MSE: 0.00045189596130512655


 32%|███▏      | 32/100 [08:01<15:12, 13.42s/it]

Epoch 32, MSE: 0.0004546280833892524


 33%|███▎      | 33/100 [08:14<14:45, 13.21s/it]

Epoch 33, MSE: 0.00048087415052577853


 34%|███▍      | 34/100 [08:26<14:24, 13.10s/it]

Epoch 34, MSE: 0.0005188047070987523


 35%|███▌      | 35/100 [08:39<14:06, 13.02s/it]

Epoch 35, MSE: 0.0005558606353588402


 36%|███▌      | 36/100 [08:52<13:50, 12.98s/it]

Epoch 36, MSE: 0.000581616535782814


 37%|███▋      | 37/100 [09:05<13:38, 12.99s/it]

Epoch 37, MSE: 0.0005899238749407232


 38%|███▊      | 38/100 [09:18<13:21, 12.92s/it]

Epoch 38, MSE: 0.0005797534249722958


 39%|███▉      | 39/100 [09:31<13:14, 13.02s/it]

Epoch 39, MSE: 0.0005546510801650584


 40%|████      | 40/100 [09:44<12:52, 12.88s/it]

Epoch 40, MSE: 0.000521110778208822


 41%|████      | 41/100 [09:57<12:38, 12.85s/it]

Epoch 41, MSE: 0.0004865115915890783


 42%|████▏     | 42/100 [10:09<12:24, 12.84s/it]

Epoch 42, MSE: 0.0004572265315800905


 43%|████▎     | 43/100 [10:22<12:08, 12.79s/it]

Epoch 43, MSE: 0.0004373649135231972


 44%|████▍     | 44/100 [10:35<12:00, 12.87s/it]

Epoch 44, MSE: 0.0004282970039639622


 45%|████▌     | 45/100 [10:50<12:28, 13.61s/it]

Epoch 45, MSE: 0.0004289110074751079


 46%|████▌     | 46/100 [11:05<12:27, 13.84s/it]

Epoch 46, MSE: 0.00043634528992697597


 47%|████▋     | 47/100 [11:19<12:21, 13.98s/it]

Epoch 47, MSE: 0.00044694149983115494


 48%|████▊     | 48/100 [11:33<12:08, 14.01s/it]

Epoch 48, MSE: 0.0004571536846924573


 49%|████▉     | 49/100 [11:47<11:52, 13.98s/it]

Epoch 49, MSE: 0.000464248179923743


 50%|█████     | 50/100 [12:02<11:51, 14.22s/it]

Epoch 50, MSE: 0.000466705794679001


 51%|█████     | 51/100 [12:16<11:35, 14.19s/it]

Epoch 51, MSE: 0.0004643096763174981


 52%|█████▏    | 52/100 [12:33<11:35, 14.50s/it]


KeyboardInterrupt: 

In [14]:
model.eval()
cost = 0
for time, snapshot in enumerate(test_dataset):
    y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    cost = cost + torch.mean((y_hat-snapshot.y)**2)
cost = cost / (time+1)
cost = cost.item()
print("MSE: {}".format(cost))

MSE: 0.0003744852147065103
