In [None]:
# imports
import torch
from stock_dataloader import create_stock_dataloader
from lstm import StockLSTM
from transformer import StockTransformer
from model_trainer import train_model, save_model, load_model
from evaluation import evaluate_model

In [None]:
# Create dataloader

# Reproducibility: seed torch's global RNG once, before any stochastic step
# in the notebook, so Restart & Run All yields the same run each time.
# NOTE(review): if create_stock_dataloader picks stocks via `random` or
# numpy, those generators need seeding as well - confirm.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
SEQ_LEN = 100           # default 100; window for set of time-series data points
BATCH_SIZE = 32         # default 32; increase if GPU mem allows
STOCKS_PER_BUCKET = 5   # default 13 (currently set below the default); number of stocks per category bucket
TRAIN_PER_BUCKET = 3    # default 10 (currently set below the default); number of training stocks per category bucket

# Input files, given relative to the notebook's working directory
stock_csv = 'selected_stocks_data.csv'
metadata_csv = 'selected_stocks_quality.csv'

# The factory returns a mapping; only the train and eval loaders are used below
stock_dataloader = create_stock_dataloader(stock_csv, metadata_csv, seq_len=SEQ_LEN, batch_size=BATCH_SIZE,
                                           stocks_per_bucket=STOCKS_PER_BUCKET, train_per_bucket=TRAIN_PER_BUCKET)
train_loader = stock_dataloader['train_loader']
eval_loader = stock_dataloader['eval_loader']

In [None]:
# Build the LSTM model

# Hyperparameters (tune these first when the model under- or overfits)
INPUT_SIZE = 1          # default 1; dictated by the data
HIDDEN_SIZE = 64        # default 64; counterpart of the Transformer's D_MODEL; try 128 if underfitting
NUM_LAYERS = 2          # default 2; revisit if underfitting
DROP_OUT = 0.2          # default 0.2; revisit if overfitting

# Gather the constructor arguments in one place, then instantiate
lstm_kwargs = {
    'input_size': INPUT_SIZE,
    'hidden_size': HIDDEN_SIZE,
    'num_layers': NUM_LAYERS,
    'dropout': DROP_OUT,
}
lstm_model = StockLSTM(**lstm_kwargs)

In [None]:
# Build the Transformer model

# Hyperparameters (tune these first when the model under- or overfits)
INP_DIM = 1             # default 1; dictated by the data
D_MODEL = 64            # default 64; counterpart of the LSTM's HIDDEN_SIZE; revisit if underfitting
N_HEADS = 4             # default 4; 64/4 = 16 per head - standard ratio
N_LAYERS = 3            # default 3; revisit if underfitting
DIM_FEEDFORWARD = 256   # default 256; the usual 4x D_MODEL
DROPOUT = 0.1           # default 0.1; revisit if overfitting
OUTPUT_DIM = 1          # default 1; dictated by the data - next-day closing price
MAX_LEN = 500           # default 500; keep this > SEQ_LEN

# Gather the constructor arguments in one place, then instantiate
transformer_kwargs = {
    'inp_dim': INP_DIM,
    'd_model': D_MODEL,
    'n_heads': N_HEADS,
    'n_layers': N_LAYERS,
    'dim_feedforward': DIM_FEEDFORWARD,
    'dropout': DROPOUT,
    'output_dim': OUTPUT_DIM,
    'max_len': MAX_LEN,
}
transformer_model = StockTransformer(**transformer_kwargs)

In [None]:
# Train model (or load a previously saved one)

# Hyperparameters
NUM_EPOCHS = 50         # default 50; increase if underfitting
LEARNING_RATE = 0.001   # default 0.001; drop to 3e-4 if unstable
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model_choice = 'Transformer'   # Select 'LSTM' or 'Transformer'
model_save_name = f'Stock{model_choice}_ModelMini'
model_save = True              # persist the resulting model via save_model
model_load = True              # try a saved model first; train only if none is found

# Resolve the untrained model for the selected architecture. Only the chosen
# name is evaluated, so the other model's cell does not need to have been run.
# An unknown choice fails loudly here instead of silently leaving
# `trained_model` undefined (or stale from an earlier run) for later cells.
if model_choice == 'LSTM':
    untrained_model = lstm_model
elif model_choice == 'Transformer':
    untrained_model = transformer_model
else:
    raise ValueError(f"Unknown model_choice {model_choice!r}; expected 'LSTM' or 'Transformer'")

print(f"Training {model_choice} model on device: {DEVICE}")

# Attempt to load first when requested; a missing file falls back to training.
# NOTE(review): loading reads 'models/<name>' while saving passes only
# '<name>' - presumably save_model writes under models/; confirm.
needs_training = not model_load
if model_load:
    try:
        trained_model = load_model(f'models/{model_save_name}')
    except FileNotFoundError as e:
        print(e)
        needs_training = True

# Training runs outside the except handler so a failure during training is
# reported on its own rather than chained to the FileNotFoundError.
if needs_training:
    trained_model = train_model(untrained_model, train_loader, num_epochs=NUM_EPOCHS,
                                learning_rate=LEARNING_RATE, device=DEVICE)

# A model that was loaded from disk is saved again as well when model_save is
# True (unchanged from the previous behavior).
if model_save:
    save_model(trained_model, save_name=model_save_name)


In [None]:
# Evaluate model
# Runs the trained (or loaded) model from the training cell against the
# evaluation loader on the same DEVICE. The call is left as the cell's last
# expression so that any value it returns is displayed as the cell output.
evaluate_model(trained_model, eval_loader, device=DEVICE)