In [1]:
import torch
import sys
sys.path.append("/home/onyxia/work/Advanced-ML")
from data import S3ParquetReader
from config import USER
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import polars as pl
from sklearn.preprocessing import StandardScaler

## Data Loading

In [2]:
# Load the preprocessed Jane Street data from the user's S3 storage.
BUCKET = f"/{USER}/jane_street_data"
# Parquet file to read (presumably resolved relative to BUCKET by the reader).
DATA_FILE = "preprocessed.parquet/data_clean_symb_1.parquet"

reader = S3ParquetReader(bucket=BUCKET)
data = reader.read_parquet(DATA_FILE)

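## Train/validation split

Ordered 80/20 split of the rows (no shuffling). Standardization is currently disabled: the scaler code in the cell below is commented out.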

In [3]:
# We pick responder_6 as the target (same target as the data challenge).
target = "responder_6"
# NOTE(review): substring match — every column whose name contains "feature"
# is used as an input; confirm no non-feature column accidentally matches.
features = [col for col in data.columns if "feature" in col]
X, y = data[features], data[target]

# Ordered (non-shuffled) split: the first TRAIN_FRAC of rows go to training,
# the rest to validation. This is a chronological split only if `data` is
# stored in time order — TODO confirm upstream.
TRAIN_FRAC = 0.8
n = X.height
n_train = int(TRAIN_FRAC * n)
X_train = X.slice(0, n_train)
y_train = y.slice(0, n_train)
X_val = X.slice(n_train)
y_val = y.slice(n_train)

# Sanity checks: both splits are non-empty and together cover every row.
assert 0 < n_train < n, f"degenerate split: n={n}, n_train={n_train}"
assert X_train.height + X_val.height == n
assert len(y_train) + len(y_val) == n

# Standardization is currently DISABLED. To enable it, uncomment the block
# below. The scalers are fit on the training split only, so no validation
# statistics leak into training.
# scaler_x = StandardScaler()
# scaler_y = StandardScaler()
#
# # StandardScaler works on numpy arrays, hence .to_numpy()
# x_train_np = scaler_x.fit_transform(X_train.to_numpy())
# # Important: y must be 2D for the scaler
# y_train_np = scaler_y.fit_transform(y_train.to_numpy().reshape(-1, 1))
#
# x_val_np = scaler_x.transform(X_val.to_numpy())
# y_val_np = scaler_y.transform(y_val.to_numpy().reshape(-1, 1))

## Training

Train a `TimeSeriesTransformer` and report train/validation R² after each epoch.

In [5]:
from models.transformers import TimeSeriesTransformer
from models.transformers_utils import train_model

# Seed torch's RNGs (all devices) before building the model so that weight
# initialisation — and any torch-based shuffling inside train_model — is
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

n_features = X_train.shape[1]
lr = 5e-3
criterion = nn.MSELoss()
n_epochs = 10
batch_size = 2048
# Fall back to CPU so the notebook still runs on a kernel without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
seq_len = 50
d_model = 16
num_heads = 4
num_layers = 3
d_ff = 16

model = TimeSeriesTransformer(
    n_features=n_features,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff
)

optimizer = optim.Adam(model.parameters(), lr=lr)

# NOTE(review): in the recorded run, validation R² peaked at epoch 4 (0.8563)
# and then declined to 0.7955 while train R² kept rising. Consider early
# stopping or keeping the best-validation checkpoint.
r2_train_hist, r2_val_hist = train_model(
    model, optimizer, criterion,
    X_train, y_train, X_val, y_val,
    epochs=n_epochs, batch_size=batch_size, seq_len=seq_len, device=device
)

 10%|█         | 1/10 [01:20<12:04, 80.48s/it]

Epoch 001 | R² Train: 0.7992 | R² Val: 0.8444


 20%|██        | 2/10 [02:38<10:34, 79.32s/it]

Epoch 002 | R² Train: 0.8402 | R² Val: 0.8522


 30%|███       | 3/10 [03:59<09:19, 79.87s/it]

Epoch 003 | R² Train: 0.8428 | R² Val: 0.8501


 40%|████      | 4/10 [05:17<07:55, 79.32s/it]

Epoch 004 | R² Train: 0.8453 | R² Val: 0.8563


 50%|█████     | 5/10 [06:37<06:37, 79.49s/it]

Epoch 005 | R² Train: 0.8467 | R² Val: 0.8343


 60%|██████    | 6/10 [07:56<05:17, 79.29s/it]

Epoch 006 | R² Train: 0.8461 | R² Val: 0.8291


 70%|███████   | 7/10 [09:18<04:00, 80.18s/it]

Epoch 007 | R² Train: 0.8478 | R² Val: 0.8134


 80%|████████  | 8/10 [10:38<02:40, 80.02s/it]

Epoch 008 | R² Train: 0.8484 | R² Val: 0.8408


 90%|█████████ | 9/10 [11:58<01:19, 79.93s/it]

Epoch 009 | R² Train: 0.8485 | R² Val: 0.7691


100%|██████████| 10/10 [13:16<00:00, 79.66s/it]

Epoch 010 | R² Train: 0.8471 | R² Val: 0.7955



