In [1]:
import json
import os.path

import pandas as pd
import torch.optim as optim
import torch
from torch.utils.data import DataLoader

from utils import find_last_model, load_data, handle_zigzag, label_data
from dataloader.SwingDatasets import HourlyDataset
from trainer import HourlySwingModelTrainer
from model.swing_model import HourlySwingModel

In [2]:
def prepare_data(data_df, tz):
    """Clean raw hourly OHLCV data and attach labels for the swing model.

    The caller's frame is never modified: all work happens on a copy.

    Steps: run ``handle_zigzag``; keep the OHLCV + zigzag columns (renaming
    ``Volume`` -> ``volume``); parse ``time`` as UTC and convert it to ``tz``;
    run ``label_data``; finally drop rows whose ``value`` label is missing.

    Args:
        data_df: Raw frame as produced by ``load_data(..., add_zigzag_col=True)``.
            Must contain ``time, open, high, low, close, Volume, zigzag`` and
            the helper columns ``Volume MA, RelativeIndice, HighOrLow``.
        tz: Timezone name that ``time`` is converted to (e.g. ``"Etc/GMT-2"``).

    Returns:
        A new DataFrame with columns ``time, open, high, low, close, volume,
        zigzag`` followed by whatever ``label_data`` adds (at least ``value``),
        containing no missing values. The original index labels are kept.

    Raises:
        KeyError: if any expected input column is missing.
        AssertionError: if no labelled rows remain, if the surviving index is
            not contiguous ``0..n-1`` (assumes the input has a 0-based
            RangeIndex — i.e. only trailing rows may be unlabelled), or if any
            NaN remains.
    """
    # Copy first: handle_zigzag / label_data work in place and must not
    # touch the caller's frame (keeps notebook cells re-runnable).
    data_df = data_df.copy()
    handle_zigzag(data_df)

    # Drop the helper columns explicitly (raises KeyError if the schema
    # changed), then select/rename. rename() returns a fresh frame, so the
    # assignments below don't hit pandas' SettingWithCopyWarning.
    data_df = (
        data_df
        .drop(columns=["Volume MA", "RelativeIndice", "HighOrLow"])
        [["time", "open", "high", "low", "close", "Volume", "zigzag"]]
        .rename(columns={"Volume": "volume"})
    )
    data_df["time"] = pd.to_datetime(data_df["time"], utc=True).dt.tz_convert(tz)

    # label_data adds the `value` column in place (it is read right below).
    label_data(data_df)
    data_df = data_df[data_df["value"].notna()].copy()

    assert len(data_df) > 0, "no labelled rows left after label_data"
    assert len(data_df) == data_df.index[-1] + 1, (
        "index is not contiguous 0..n-1 after dropping unlabelled rows"
    )
    assert not data_df.isna().any().any(), "NaN values remain after preparation"

    return data_df

In [3]:
data_base = "data/"
out_base = "output/"

# Each supported pair reads its hourly CSV from data/<PAIR IN UPPER CASE>.csv
path_dict = {
    pair: f"{data_base}{pair.upper()}.csv"
    for pair in ("audusd", "eurusd", "sp", "btcusdt", "dogeusdt")
}

# Timezone each pair's timestamps are converted to. The "Etc/GMT" zone names
# use an inverted sign: "Etc/GMT-2" is UTC+2, "Etc/GMT-0" is UTC.
tz_dict = {
    **dict.fromkeys(("audusd", "eurusd", "sp"), "Etc/GMT-2"),  # UTC+2
    **dict.fromkeys(("btcusdt", "dogeusdt"), "Etc/GMT-0"),     # UTC
}

In [4]:
with open("config_train.json", "r") as file:
    config_train = json.load(file)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train cfg
PAIR = config_train["pair"]
# Fail early with an actionable message instead of a bare KeyError('<pair>').
if PAIR not in path_dict or PAIR not in tz_dict:
    raise KeyError(
        f"Unsupported pair {PAIR!r} in config_train.json; "
        f"expected one of {sorted(path_dict)}"
    )
PAIR_PATH = path_dict[PAIR]
PAIR_TZ = tz_dict[PAIR]
WINDOW_SIZE = config_train["window_size"]
BATCH_SIZE = config_train["batch_size"]
VALUE_ONLY = config_train["value_only"]  # True: value-only training, else certitude as well
LR = config_train["lr"]
EPOCHS = config_train["epochs"]
EVAL_PERIOD = config_train["eval_period"]
CHECKPOINT_PERIOD = config_train["checkpoint_period"]
MODEL_OUT_PATH = config_train["model_out"]

In [5]:
# model cfg: hyper-parameters later passed to HourlySwingModel
with open("config_model.json") as model_cfg_file:
    config_model = json.load(model_cfg_file)

(
    inp_dim,
    metadata_dim,
    metadata_bias,
    metadata_gate_bias,
    fusion_model_dim,
    fusion_num_heads,
    fusion_num_layers,
    fusion_apply_grn,
    fusion_dropout,
    lstm_num_layers,
    lstm_bidirectional,
    lstm_dropout,
    loss_punish_cert,
) = (
    config_model[key]
    for key in (
        "inp_dim",
        "metadata_dim",
        "metadata_bias",
        "metadata_gate_bias",
        "fusion_model_dim",
        "fusion_num_heads",
        "fusion_num_layers",
        "fusion_apply_grn",
        "fusion_dropout",
        "lstm_num_layers",
        "lstm_bidirectional",
        "lstm_dropout",
        "loss_punish_cert",
    )
)

In [6]:
# load model
model = HourlySwingModel(inp_dim=inp_dim, metadata_dim=metadata_dim, metadata_bias=metadata_bias,
                         metadata_gate_bias=metadata_gate_bias, fusion_model_dim=fusion_model_dim,
                         fusion_num_heads=fusion_num_heads, fusion_num_layers=fusion_num_layers,
                         fusion_apply_grn=fusion_apply_grn, fusion_dropout=fusion_dropout,
                         lstm_num_layers=lstm_num_layers, lstm_bidirectional=lstm_bidirectional,
                         lstm_dropout=lstm_dropout, loss_punish_cert=loss_punish_cert)
model.to(device)
if config_train["load_model"] is not None:
    # Resume from the checkpoint file find_last_model picks in this directory
    # (presumably the most recent one — confirm in utils.find_last_model).
    load_dir = config_train["load_model"]
    ckpt_path = os.path.join(load_dir, find_last_model(load_dir))
    print(f"Loading model from: {ckpt_path}")
    # weights_only=True limits unpickling to tensors/plain containers, which is
    # all a state_dict needs, and refuses to execute arbitrary pickled code.
    # (Available since torch 1.13; it is the default from torch 2.6.)
    model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))

In [7]:
# load data
# Batch size for the dev/test loaders (not shuffled); fixed independently of
# the training BATCH_SIZE from config_train.json.
EVAL_BATCH_SIZE = 16

# Keep the raw frame under its own name so this cell is re-runnable and the
# unprocessed data stays inspectable.
pair_data_raw = load_data(PAIR_PATH, add_zigzag_col=True)
pair_data = prepare_data(pair_data_raw, PAIR_TZ)

train_dataset = HourlyDataset(pair_data, WINDOW_SIZE, "train")
dev_dataset = HourlyDataset(pair_data, WINDOW_SIZE, "dev")
test_dataset = HourlyDataset(pair_data, WINDOW_SIZE, "test")

# NOTE(review): assumes get_collate_fn() returns a stateless callable, so one
# instance can be shared by all three loaders.
collate_fn = HourlyDataset.get_collate_fn()
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
dev_dataloader = DataLoader(dev_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [9]:
# save the cfg before start
HourlySwingModelTrainer.save_cfg(config_model, config_train, out_path=MODEL_OUT_PATH)

# VALUE_ONLY=True trains for the value only (checkpoints under "reg");
# otherwise certitude is trained as well (checkpoints under "cer").
value_only = bool(VALUE_ONLY)
stage_dir = "reg" if value_only else "cer"

model.set_value_only(value_only)
optimizer = optim.Adam(model.parameters(), lr=LR)
# Pass `device` in both modes. The certitude branch previously omitted it and
# fell back to the trainer's own default, which need not match the device the
# model was moved to.
trainer = HourlySwingModelTrainer(model, train_dataloader, dev_dataloader, test_dataloader, optimizer, device)
trainer.train(epochs=EPOCHS, eval_period=EVAL_PERIOD,
              checkpoint_period=CHECKPOINT_PERIOD, out_path=os.path.join(MODEL_OUT_PATH, stage_dir))

Model configuration saved to out/audusd/checkpoint1/model_config.json
Training configuration saved to out/audusd/checkpoint1/training_config.json


Epoch 1/10:  26%|████▋             | 112/428 [00:25<01:13,  4.32it/s, loss=4.26]


KeyboardInterrupt: 