In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import torch

from src.model import StockMixer
from src.train import get_batch, validate, train

In [2]:
# --- Reproducibility -------------------------------------------------------
# Seed NumPy's and PyTorch's global generators before anything stochastic runs.
np.random.seed(123456789)
torch.random.manual_seed(12345678)

# Always a torch.device (previously a bare 'cpu' string when CUDA was missing),
# so code that inspects e.g. `device.type` behaves the same on both backends.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data location ---------------------------------------------------------
# NOTE(review): data_path is not referenced by the visible cells, which read
# from 'dataset/<market_name>' instead — confirm which location is intended.
data_path = '../dataset'
market_name = 'NASDAQ'
relation_name = 'wikidata'

# --- Model dimensions (passed to StockMixer when the model is built) -------
stock_num = 1026        # -> StockMixer(stocks=...)
lookback_length = 16    # -> StockMixer(time_steps=...)
fea_num = 5             # -> StockMixer(channels=...)
market_num = 20         # -> StockMixer(market=...)
scale_factor = 3        # -> StockMixer(scale=...)

# --- Training settings -----------------------------------------------------
# NOTE(review): none of the values below are passed to train(model) in this
# notebook; presumably src.train holds its own copies — verify they match.
epochs = 1
valid_index = 756
test_index = 1008
steps = 1
learning_rate = 0.001
alpha = 0.1             # presumably the `alpha` weight shown in the printed loss — TODO confirm
activation = 'GELU'

In [3]:
# Directory (relative to the working directory) holding this market's pickles.
dataset_path = 'dataset/' + market_name


def _load_pickle(file_name):
    """Unpickle and return the object stored at ``dataset_path/file_name``."""
    with open(os.path.join(dataset_path, file_name), "rb") as handle:
        return pickle.load(handle)


# NOTE: pickle can execute arbitrary code while loading — only read trusted files.
eod_data = _load_pickle("eod_data.pkl")
mask_data = _load_pickle("mask_data.pkl")
gt_data = _load_pickle("gt_data.pkl")
price_data = _load_pickle("price_data.pkl")

In [4]:
# Size of mask_data's second axis, kept under the name `trade_dates`.
trade_dates = mask_data.shape[1]

# Instantiate the network from the configured sizes ...
model = StockMixer(
    stocks=stock_num,
    time_steps=lookback_length,
    channels=fea_num,
    market=market_num,
    scale=scale_factor,
)
# ... then place it on the selected device.
model = model.to(device)

In [5]:
# Read the first column of dates.csv / tickers.csv (no header row) as plain lists.
# The files are resolved under `dataset_path`, so they always belong to the same
# market as the pickled arrays instead of being pinned to NASDAQ.
dates = pd.read_csv(os.path.join(dataset_path, 'dates.csv'), header=None)[0].tolist()
tickers = pd.read_csv(os.path.join(dataset_path, 'tickers.csv'), header=None)[0].tolist()

In [6]:
# Run the training routine imported from src.train on the model built above.
# The call is the cell's last expression, so its return value is displayed:
# the saved output shows per-epoch losses and metrics followed by a tuple of
# DataFrames with `prediction`, `ground_truth` and `id` columns.
# NOTE(review): only `model` is passed — the hyperparameters defined in this
# notebook (epochs, learning_rate, alpha, ...) are presumably read elsewhere
# by src.train; confirm they actually take effect.
train(model)

epoch1##########################################################


100%|████████████████████████████████████████████████████████████████████████████████| 740/740 [00:24<00:00, 29.61it/s]


Train : loss:6.71e-02  =  6.70e-02 + alpha*1.38e-03
Valid : loss:1.28e-02  =  1.27e-02 + alpha*1.03e-03
Test: loss:6.32e-02  =  6.31e-02 + alpha*1.46e-03
Valid performance:
 mse:1.28e-02, IC:2.72e-02, RIC:2.22e-01, prec@10:5.18e-01, SR:2.07e+00
Test performance:
 mse:6.33e-02, IC:7.99e-03, RIC:1.09e-01, prec@10:5.04e-01, SR:7.96e-01 




(        prediction  ground_truth   id
 0         0.503555     -0.055923  740
 1         0.584760     -0.046081  740
 2         0.622293      0.000854  740
 3         0.624257     -0.027092  740
 4         0.759180     -0.029208  740
 ...            ...           ...  ...
 258547    0.697878     -0.002095  991
 258548    0.705162     -0.037346  991
 258549    0.750173      0.005843  991
 258550    0.569513     -0.010418  991
 258551    0.521542     -0.002284  991
 
 [258552 rows x 3 columns],
         prediction  ground_truth    id
 0         0.487662      0.005948   992
 1         0.735221      0.009077   992
 2         0.613335      0.002849   992
 3         0.716638     -0.011506   992
 4         0.679961      0.009832   992
 ...            ...           ...   ...
 243157    0.639666      0.002228  1228
 243158    0.627578      0.013216  1228
 243159    0.705741      0.009115  1228
 243160    0.835815      0.012404  1228
 243161    0.584964      0.002363  1228
 
 [243162 rows x 3 co