In [1]:
import torch

# Seed PyTorch's global RNG so later random draws in this notebook are repeatable.
torch.manual_seed(seed=42)

<torch._C.Generator at 0x7f64d13e9cb0>

# Preprocessing

In [2]:
import torch
import numpy as np

def preprocess(path='s&p500_z_scores.csv', days_in_quarter=64):
    """Load a CSV of per-day values and split it into train/val/test quarters.

    The file is read with ``np.loadtxt`` (comma-delimited), converted to a
    float32 tensor and reshaped to ``(num_quarters, days_in_quarter, -1)``.
    Rows are presumably consecutive trading days and columns individual
    stocks -- verify against the data file.

    Args:
        path: CSV file to load. Defaults to the notebook's original file.
        days_in_quarter: Number of consecutive rows grouped into one quarter.

    Returns:
        ``(train_data, val_data, test_data)``: the first 80% of quarters
        (shuffled along the quarter axis), the next 10%, and the rest.

    Raises:
        ValueError: If the file holds fewer than ``days_in_quarter`` rows.
    """
    data = torch.from_numpy(np.loadtxt(path, delimiter=',')).to(torch.float32)
    num_quarters = data.size(0) // days_in_quarter
    if num_quarters == 0:
        raise ValueError(
            f'{path!r} has {data.size(0)} rows; need at least {days_in_quarter} for one quarter'
        )

    # Drop trailing rows that do not fill a whole quarter. Without this the
    # reshape below fails (or silently mis-shapes) whenever the row count is
    # not an exact multiple of days_in_quarter.
    data = data[:num_quarters * days_in_quarter]
    data = data.view(num_quarters, days_in_quarter, -1)

    # Chronological split: the split points are computed exactly as before.
    train_end = int(num_quarters * 0.8)
    val_end = int(num_quarters * 0.9)
    train_data = data[:train_end]
    val_data = data[train_end:val_end]
    test_data = data[val_end:]

    # Shuffle only the order of training quarters; rows inside a quarter keep
    # their original order. Uses the global torch RNG.
    train_data = train_data[torch.randperm(train_data.size(0))]
    return train_data, val_data, test_data

# Build the three splits once and report the shape of each.
train_data, val_data, test_data = preprocess()
print('Train data size:', train_data.shape)
print('Validation data size:', val_data.shape)
print('Test data size:', test_data.shape)

Train data size: torch.Size([31, 64, 472])
Validation data size: torch.Size([4, 64, 472])
Test data size: torch.Size([4, 64, 472])


# Model

In [32]:
import torch
import torch.nn as nn
from multihead_diffattn import MultiheadDiffAttn

class FeedForward(nn.Module):
    """Two-layer MLP: expand, ReLU, project back to ``hidden_size``, dropout.

    Args:
        hidden_size: Size of the input and output feature dimension.
        expand_ratio: Multiplier for the width of the inner layer.
        dropout: Dropout probability applied to the block's output.
    """

    def __init__(self, hidden_size, expand_ratio, dropout):
        super().__init__()
        inner_size = hidden_size * expand_ratio
        # Attribute names double as state_dict keys, so they are kept as-is.
        self.linear = nn.Linear(hidden_size, inner_size)
        self.linear2 = nn.Linear(inner_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        hidden = self.relu(self.linear(x))
        return self.dropout(self.linear2(hidden))

class Attention(nn.Module):
    """Post-norm transformer block: self-attention, then a feed-forward layer.

    Computes ``ln2(ffn(h) + h)`` with ``h = ln1(x + attn(x))``.

    Args:
        d_model: Feature dimension of the input and output.
        num_heads: Number of attention heads.
        expand_ratio: Width multiplier of the feed-forward inner layer.
        dropout: Dropout probability used inside the feed-forward layer.
        attn_variant: ``'standard'`` for ``nn.MultiheadAttention`` or
            ``'diff'`` for ``MultiheadDiffAttn``.

    Raises:
        ValueError: If ``attn_variant`` is not one of the supported names.
    """

    def __init__(self, d_model, num_heads, expand_ratio, dropout, attn_variant='standard'):
        super().__init__()
        self.attn_variant = attn_variant
        if attn_variant == 'standard':
            self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
        elif attn_variant == 'diff':
            self.mha = MultiheadDiffAttn(embed_dim=d_model, num_heads=num_heads, depth=0)
        else:
            raise ValueError(
                f"Unknown attn_variant {attn_variant!r}; expected 'standard' or 'diff'"
            )
        # forward() uses these for every variant, so they must exist for every
        # variant (previously only the 'standard' branch created them).
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(hidden_size=d_model, expand_ratio=expand_ratio, dropout=dropout)

    def forward(self, x, attn_mask=None, need_weights=False):
        """Run the block on ``x``.

        Args:
            x: Input passed as query, key and value. For the standard variant
                the layout is batch-first.
            attn_mask: Optional attention mask forwarded to the attention module.
            need_weights: If True, also return the attention weights.

        Returns:
            The transformed tensor, or ``(tensor, attn_weights)`` when
            ``need_weights`` is True. ``attn_weights`` may be ``None`` if the
            attention module did not provide any.
        """
        if self.attn_variant == 'standard':
            x1, attn_weights = self.mha(
                x, x, x,
                attn_mask=attn_mask,
                need_weights=need_weights,
                average_attn_weights=False,  # keep one weight matrix per head
            )
        else:
            # MultiheadDiffAttn is defined outside this file and its return
            # convention is not visible here, so accept either a bare tensor
            # or an (output, weights) pair. TODO confirm against the module.
            result = self.mha(x, attn_mask=attn_mask)
            if isinstance(result, tuple):
                x1, attn_weights = result
            else:
                x1, attn_weights = result, None

        x2 = self.ln1(x + x1)
        x = self.ln2(self.ffn(x2) + x2)
        if need_weights:
            return (x, attn_weights)
        return x

class GraphTransformer(nn.Module):
    """Temporal-then-spatial attention model over a (time, stock) grid.

    Each scalar input is projected to ``d_model`` features and summed with a
    learned per-stock and per-time-step embedding. A causal temporal attention
    block then runs along time independently for every stock, followed by a
    spatial attention block across stocks independently for every time step.
    A final linear layer plus sigmoid yields one value in (0, 1) per cell.

    Args:
        attn_variant: Attention variant for the spatial block (the temporal
            block always uses the default variant of ``Attention``).
        d_model: Feature dimension.
        num_heads: Number of attention heads in both blocks.
        expand_ratio: Feed-forward width multiplier in both blocks.
        dropout: Dropout probability in both blocks.
        T: Maximum number of time steps accepted by ``forward``.
        N: Exact number of stocks expected by ``forward``. The default is read
            from the notebook-level ``train_data`` when the class is defined.
    """

    def __init__(self, attn_variant='standard', d_model=64, num_heads=2, expand_ratio=1, dropout=0.1, T = 64, N=train_data.size(-1)):
        super().__init__()
        self.T = T
        self.N = N
        self.d_model = d_model
        self.num_heads = num_heads
        self.input_proj = nn.Linear(1, d_model)
        self.time_embedding = nn.Embedding(T, d_model)
        self.stock_embedding = nn.Embedding(N, d_model)
        self.spatial_attn = Attention(d_model, num_heads, expand_ratio, dropout, attn_variant=attn_variant)
        self.temporal_attn = Attention(d_model, num_heads, expand_ratio, dropout)
        self.output_proj = nn.Linear(d_model, 1)

    def forward(self, x, need_weights=False):
        """Compute per-cell outputs for one ``(T, N)`` window.

        Args:
            x: 2-D tensor of shape ``(T, N)`` with ``T <= self.T`` and
                ``N == self.N``.
            need_weights: If True, also return both attention-weight tensors.

        Returns:
            A ``(T, N, 1)`` tensor of sigmoid outputs, or the tuple
            ``(out, spatial_attn_weights, temporal_attn_weights)`` when
            ``need_weights`` is True.
        """
        T, N = x.size()
        assert T <= self.T and N == self.N
        device = x.device

        h = self.input_proj(x.unsqueeze(-1))  # (T, N, d_model)

        # Broadcast the embeddings instead of expanding and reshaping them:
        # stock n gets stock_embedding[n], time step t gets time_embedding[t].
        stock_embs = self.stock_embedding(torch.arange(N, device=device))  # (N, d_model)
        time_embs = self.time_embedding(torch.arange(T, device=device))    # (T, d_model)
        h = h + stock_embs.unsqueeze(0) + time_embs.unsqueeze(1)

        # IDEA: Each spatial head takes in a different type of correlation matrix.
        # Like one takes in positive Pearson's coefficient and the other takes in negative.

        # Temporal attention: one sequence of length T per stock. This must be
        # a transpose; .view(N, T, d) would only reinterpret the (T, N, d)
        # memory and mix different stocks into each "time" sequence.
        h = h.transpose(0, 1).contiguous()  # (N, T, d_model)
        temporal_causal_mask = torch.triu(
            torch.ones((T, T), dtype=torch.bool, device=device), diagonal=1
        ).expand(N * self.num_heads, T, T)  # True = position may not be attended to
        h = self.temporal_attn(h, attn_mask=temporal_causal_mask, need_weights=need_weights)
        if need_weights:
            h, temporal_attn_weights = h

        # Spatial attention: one sequence of length N per time step.
        h = h.transpose(0, 1).contiguous()  # (T, N, d_model)
        h = self.spatial_attn(h, need_weights=need_weights)
        if need_weights:
            h, spatial_attn_weights = h

        # torch.sigmoid avoids depending on `F`, which this notebook only
        # imports in a later cell.
        out = torch.sigmoid(self.output_proj(h))
        if need_weights:
            return (out, spatial_attn_weights, temporal_attn_weights)
        return out

# Training

In [33]:
import seaborn as sns
import matplotlib.pylab as plt

def visualize_attn_weights(weights, title):
    """Show ``weights`` as a seaborn heatmap on a 20x15 figure titled ``title``."""
    figure, axes = plt.subplots(figsize=(20, 15))
    sns.heatmap(weights, ax=axes)
    axes.set_title(title)
    plt.show()

In [7]:
import math
from torcheval.metrics.functional import binary_f1_score, binary_accuracy

def accuracy(y_hats, ys):
    """Return torcheval's ``binary_accuracy`` over all elements as a Python number.

    Both tensors are flattened first, so they only need to hold the same
    number of elements in the same order.
    """
    predictions = y_hats.flatten()
    targets = ys.flatten()
    score = binary_accuracy(predictions, targets)
    return score.item()

def f1(y_hats, ys):
    """Return torcheval's ``binary_f1_score`` over all elements as a Python number.

    Both tensors are flattened first, so they only need to hold the same
    number of elements in the same order.
    """
    predictions = y_hats.flatten()
    targets = ys.flatten()
    score = binary_f1_score(predictions, targets)
    return score.item()

In [8]:
def get_movements(data):
    """Return 1.0 where a value rises from one step to the next, else 0.0.

    Steps run along the second-to-last dimension, so the result has one fewer
    entry along that dimension than ``data``. Equal values count as 0.0.
    """
    earlier = data[..., :-1, :]
    later = data[..., 1:, :]
    return (later > earlier).float()

In [34]:
import math

def eval(model, data, epoch, need_weights=False):
    """Evaluate ``model`` on every window in ``data``; return accuracy and F1.

    Note: this name shadows Python's built-in ``eval`` inside the notebook; it
    is kept because other cells call it.

    Args:
        model: Module called as ``model(window[:-1, :], need_weights=...)``.
        data: Tensor iterated along its first dimension; each item is one
            window whose last row is used only as a target.
        epoch: Epoch number, used only in the heatmap titles.
        need_weights: If True, the model must return
            ``(out, spatial_attn_weights, temporal_attn_weights)``, and the
            weights of the first window are plotted head by head.

    Returns:
        ``(accuracy, f1)`` of the stacked outputs against
        ``get_movements(data)``.
    """
    model.eval()
    outs = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(data):
            out = model(batch[:-1, :], need_weights=need_weights)
            if need_weights:
                out, spatial_attn_weights, temporal_attn_weights = out
                if batch_idx == 0:
                    # Plot every head of the first sequence rather than
                    # assuming the model has exactly two heads.
                    per_head = zip(spatial_attn_weights[0], temporal_attn_weights[0])
                    for h, (spatial_w, temporal_w) in enumerate(per_head):
                        visualize_attn_weights(spatial_w.cpu().numpy(), title=f'Spatial Attention Weights (epoch {epoch}, head {h})')
                        visualize_attn_weights(temporal_w.cpu().numpy(), title=f'Temporal Attention Weights (epoch {epoch}, head {h})')
            outs.append(out)
    outs = torch.stack(outs)
    y = get_movements(data)
    eval_acc = accuracy(outs, y)
    eval_f1 = f1(outs, y)
    return eval_acc, eval_f1

In [37]:
import wandb
import torch.nn.functional as F

# Prefer CUDA, then Apple MPS, then CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')

model = GraphTransformer().to(device)
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Visualize attention weights
need_weights = False
track_with_wandb = True

num_epochs = 100

if track_with_wandb:
    wandb.init(project="cs224w-stock-market-prediction", config={
        "dataset": "S&P500",
        "learning_rate": lr,
        "epochs": num_epochs,
        "architecture": "DGT",
    })

run_failed = True
try:
    for epoch in range(num_epochs):
        # eval() below leaves the model in eval mode, so switch back each epoch.
        model.train()
        epoch_loss = 0
        for batch in train_data:
            # Clear gradients before every step; zeroing only once per epoch
            # lets gradients from earlier batches accumulate into later updates.
            optimizer.zero_grad()
            out = model(batch[:-1, :])
            # squeeze(-1) removes only the trailing output dimension so the
            # prediction lines up with the (T-1, N) movement targets.
            loss = F.binary_cross_entropy(out.squeeze(-1), get_movements(batch))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        if track_with_wandb:
            train_acc, train_f1 = eval(model, train_data, epoch, need_weights=need_weights)
            # Key renamed from 'train_f1' to match the other 'train/...' metrics.
            wandb.log({"epoch": epoch, "train/loss": epoch_loss / len(train_data), 'train/acc': train_acc, 'train/f1': train_f1})
        # Evaluate on validation
        val_acc, val_f1 = eval(model, val_data, epoch, need_weights=need_weights)
        if track_with_wandb:
            wandb.log({"epoch": epoch, "val/acc": val_acc, "val/f1": val_f1})

    # Evaluate on test at the end
    test_acc, test_f1 = eval(model, test_data, epoch, need_weights=need_weights)
    if track_with_wandb:
        wandb.log({"epoch": epoch, "test/acc": test_acc, "test/f1": test_f1})
    run_failed = False
finally:
    # Always close the wandb run, including when the cell is interrupted, and
    # record a non-zero exit code if training did not complete.
    if track_with_wandb:
        wandb.finish(exit_code=1 if run_failed else 0)

KeyboardInterrupt: 

In [9]:
# Persist only the model's parameters (its state_dict), not the module object.
torch.save(obj=model.state_dict(), f='model.pth')