In [2]:
import torch

torch.manual_seed(42)

<torch._C.Generator at 0x12136d330>

# Preprocessing

In [3]:
import pandas

df = pandas.read_csv('s&p500.csv')
df = df.set_index('Date')
df

Unnamed: 0_level_0,A,AAL,AAPL,ABBV,ABT,ACGL,ACN,ADBE,ADI,ADM,...,WTW,WY,WYNN,XEL,XOM,XYL,YUM,ZBH,ZBRA,ZTS
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2014-10-31 00:00:00+00:00,36.393887,39.179268,23.973495,41.956585,36.026073,18.773333,68.768227,70.120003,40.179710,35.525333,...,91.401855,22.913128,162.714340,24.432884,62.614067,32.120022,42.933430,99.750992,73.750000,34.574112
2014-11-03 00:00:00+00:00,37.588364,39.795147,24.284258,41.817753,35.778130,18.799999,68.641075,69.910004,40.260685,35.714283,...,90.792976,23.197344,162.286133,24.637280,61.675262,31.837324,42.772045,99.006737,73.879997,35.163273
2014-11-04 00:00:00+00:00,36.980907,40.458393,24.106678,41.427670,36.042606,18.910000,68.793671,71.070000,40.042057,37.445198,...,90.725319,23.319159,158.646683,24.432884,61.196133,31.748985,43.070908,97.930672,72.779999,36.468380
2014-11-05 00:00:00+00:00,36.934883,40.032017,24.164398,41.348339,36.083923,18.906668,69.921150,71.370003,40.511711,38.027210,...,92.010750,23.265015,153.919647,24.885473,61.552242,32.146519,43.298035,99.526817,73.070000,36.608208
2014-11-06 00:00:00+00:00,38.076160,40.875298,24.233505,41.460724,36.108704,19.026667,70.565430,72.099998,40.665558,37.770222,...,93.950172,23.170267,153.037628,24.126284,62.322708,32.782562,43.692513,98.809441,73.879997,37.018387
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2024-10-25 00:00:00+00:00,130.190002,13.150000,231.410004,187.850006,114.220001,105.300003,360.799988,483.720001,230.169998,56.560001,...,290.450012,31.799999,97.989998,64.459999,119.489998,130.419998,133.039993,102.349998,360.089996,180.009995
2024-10-28 00:00:00+00:00,131.539993,13.600000,233.399994,189.679993,114.070000,105.660004,361.320007,481.040009,230.139999,57.240002,...,292.119995,31.950001,98.620003,64.480003,118.900002,130.789993,134.860001,103.599998,363.579987,182.759995
2024-10-29 00:00:00+00:00,131.229996,13.820000,233.669998,189.449997,113.400002,104.800003,363.040009,485.390015,235.229996,56.320000,...,292.179993,30.879999,99.000000,63.340000,117.279999,129.889999,133.970001,103.879997,384.679993,181.270004
2024-10-30 00:00:00+00:00,131.490005,13.940000,230.100006,201.500000,114.449997,105.139999,346.570007,486.679993,230.119995,55.529999,...,293.540009,31.719999,98.489998,63.049999,116.690002,130.220001,133.389999,109.809998,383.890015,182.740005


In [24]:
import torch

def preprocess(df):
    data = torch.from_numpy(df.to_numpy()).to(torch.float32)
    ratios = torch.zeros_like(data)
    # calculate daily return ratio
    for d in range(1, data.size(0)):
        ratios[d-1] = (data[d] - data[d-1]) / data[d-1]
    # skip the first day which cannot calculate daily return ratio
    # and round data size to nearest multiple of batch_size
    days_in_quarter = 64
    num_quarters = data.size(0) // days_in_quarter
    ratios = ratios[:num_quarters * days_in_quarter]
    # split into batches
    ratios = ratios.view(num_quarters, days_in_quarter, 472)
    train_data = ratios[:int(num_quarters*0.8)]
    val_data = ratios[int(num_quarters*0.8):int(num_quarters*0.9)]
    test_data = ratios[int(num_quarters*0.9):]
    # shuffle train batches
    train_data = train_data[torch.randperm(train_data.size(0))]
    return train_data, val_data, test_data

train_data, val_data, test_data = preprocess(df)
print('Train data size:', train_data.size())
print('Validation data size:', val_data.size())
print('Test data size:', test_data.size())

Train data size: torch.Size([31, 64, 472])
Validation data size: torch.Size([4, 64, 472])
Test data size: torch.Size([4, 64, 472])


# Model

In [25]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, hidden_size, expand_ratio, dropout):
        super(FeedForward, self).__init__()
        self.linear = nn.Linear(hidden_size, hidden_size * expand_ratio)
        self.linear2 = nn.Linear(hidden_size * expand_ratio, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
    
    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x

class Attention(nn.Module):
    def __init__(self, d_model, num_heads, expand_ratio, dropout):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(hidden_size=d_model, expand_ratio=expand_ratio, dropout=dropout)

    def forward(self, x, attn_mask=None):
        x1, _ = self.mha(x, x, x, attn_mask=attn_mask, need_weights=False)
        x2 = self.ln1(x + x1)
        return self.ln2(self.ffn(x2) + x2)

class SpatialTemporalAttention(nn.Module):
    def __init__(self, d_model=32, num_heads=2, expand_ratio=1, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.input_proj = nn.Linear(1, d_model)
        # self.time_embedding = nn.Embedding(T, d_model)
        # self.stock_embedding = nn.Embedding(N, d_model)
        self.spatial_attn = Attention(d_model, num_heads, expand_ratio, dropout)
        self.temporal_attn = Attention(d_model, num_heads, expand_ratio, dropout)
        self.output_proj = nn.Linear(d_model, 1)
    
    def forward(self, x):
        T, N = x.size()
        # x = self.input_proj(x.view(T, N, 1)) + self.stock_embedding(torch.arange(N).unsqueeze(0).expand(T, N))
        x = self.input_proj(x.view(T, N, 1))
        # IDEA: Each spatial head takes in a different type of correlation matrix.
        # Like one takes in positive pearson's coefficnet and the other takes in negative
        x = self.spatial_attn(x).view(N, T, self.d_model)
        # x = self.time_embedding(torch.arange(T).unsqueeze(0).expand(N, T).to(x.device)) + x
        temporal_causal_mask = torch.triu(torch.ones((T, T), dtype=torch.bool), diagonal=1).expand(N * self.num_heads, T, T).to(x.device)
        x = self.temporal_attn(x, attn_mask=temporal_causal_mask)
        return self.output_proj(x).view(T, N)


# Training

In [53]:
import math

def eval(model, data):
    outs = []
    for batch in data:
        out = model(batch[:-1, :])
        outs.append(out)
    rmse = math.sqrt(nn.functional.mse_loss(torch.stack(outs), data[:, 1:, :]))
    return rmse

In [54]:
import wandb

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
model = SpatialTemporalAttention().to(device)
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

model.train()

num_epochs = 10

wandb.init(project="cs224w-stock-market-prediction", name='STAttn')

for epoch in range(num_epochs):
    optimizer.zero_grad()
    for batch_idx, batch in enumerate(train_data):
        out = model(batch[:-1, :])
        loss = loss_fn(out, batch[1:, :])
        wandb.log({"epoch": epoch, "train/rmse": math.sqrt(loss.item())})
        loss.backward()
        optimizer.step()
        # Evaluate on validation
        val_rmse = eval(model, val_data)
        wandb.log({"epoch": epoch, "val/rmse": val_rmse})

# Evaluate on test at the end
test_rmse = eval(model, test_data)
wandb.log({"epoch": epoch, "test/rmse": test_rmse})


wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇█████
test/rmse,▁
train/rmse,█▅▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/rmse,█▇▇▅▅▂▁▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,9.0
test/rmse,0.05994
train/rmse,0.05862
val/rmse,0.06031


In [55]:
torch.save(model.state_dict(), 'model.pth')