In [1]:
import os
from itertools import product

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import *
from loader.fi_loader import *
from models.cnn_lstm import CNN_LSTM
from models.cnn import CNN
from models.lstm import LSTM
from models.mlp import MLP
from train import batch_train

# Train CNN-LSTM

In [None]:
# Train one CNN_LSTM per (cf, method, k) configuration of the grid below.
# `cf`, `method` and `k` are forwarded unchanged to FIDataset and encoded in the
# checkpoint name; `method` presumably selects the preprocessing/normalization
# variant of the data files (only 'Zscore' here) — confirm in loader.fi_loader.
# A configuration is skipped when ./trained_models/<model_name>.pth already
# exists, so re-running the cell resumes an interrupted sweep (assumes
# batch_train writes its checkpoint to that path — confirm in train.py).
#
# The grid is materialized as a list because a bare itertools.product iterator
# has no len(): tqdm could then only show a running count, with no total/ETA.
sweep_configs = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep_configs, desc='CNN_LSTM'):
    model_name = f'CNN_LSTM_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue
    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # test_loader feeds the validation metrics batch_train reports (presumably
    # evaluation-only — confirm). Shuffling it adds nothing and makes validation
    # passes order-dependent, so its order is kept fixed.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = CNN_LSTM()
    # The model exposes its own target device; move its parameters there.
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

# Train CNN

In [None]:
# Train one CNN per (cf, method, k) configuration of the grid below.
# `cf`, `method` and `k` are forwarded unchanged to FIDataset and encoded in the
# checkpoint name; `method` presumably selects the preprocessing/normalization
# variant of the data files (only 'Zscore' here) — confirm in loader.fi_loader.
# A configuration is skipped when ./trained_models/<model_name>.pth already
# exists, so re-running the cell resumes an interrupted sweep (assumes
# batch_train writes its checkpoint to that path — confirm in train.py).
#
# The grid is materialized as a list because a bare itertools.product iterator
# has no len(): tqdm could then only show a running count, with no total/ETA.
sweep_configs = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep_configs, desc='CNN'):
    model_name = f'CNN_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue
    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # test_loader feeds the validation metrics batch_train reports (presumably
    # evaluation-only — confirm). Shuffling it adds nothing and makes validation
    # passes order-dependent, so its order is kept fixed.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = CNN()
    # The model exposes its own target device; move its parameters there.
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

# Train LSTM

In [2]:
# Train one LSTM per (cf, method, k) configuration of the grid below.
# `cf`, `method` and `k` are forwarded unchanged to FIDataset and encoded in the
# checkpoint name; `method` presumably selects the preprocessing/normalization
# variant of the data files (only 'Zscore' here) — confirm in loader.fi_loader.
# A configuration is skipped when ./trained_models/<model_name>.pth already
# exists, so re-running the cell resumes an interrupted sweep (assumes
# batch_train writes its checkpoint to that path — confirm in train.py).
#
# The grid is materialized as a list because a bare itertools.product iterator
# has no len(): tqdm could then only show a running count, with no total/ETA.
sweep_configs = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep_configs, desc='LSTM'):
    model_name = f'LSTM_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue
    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # test_loader feeds the validation metrics batch_train reports (presumably
    # evaluation-only — confirm). Shuffling it adds nothing and makes validation
    # passes order-dependent, so its order is kept fixed.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = LSTM()
    # The model exposes its own target device; move its parameters there.
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

0it [00:00, ?it/s]


        Epoch 1/15,
        Train Loss: 0.9750, Train Acc:  0.6232,
        Validation Loss: 0.9548, Validation Acc:  0.5920,
        Duration: 0:00:17.789400

        Epoch 2/15,
        Train Loss: 0.9181, Train Acc:  0.6323,
        Validation Loss: 0.9539, Validation Acc:  0.5924,
        Duration: 0:00:13.839703

        Epoch 3/15,
        Train Loss: 0.9175, Train Acc:  0.6322,
        Validation Loss: 0.9538, Validation Acc:  0.5918,
        Duration: 0:00:13.484217

        Epoch 4/15,
        Train Loss: 0.9175, Train Acc:  0.6323,
        Validation Loss: 0.9531, Validation Acc:  0.5920,
        Duration: 0:00:13.464067

        Epoch 5/15,
        Train Loss: 0.9174, Train Acc:  0.6322,
        Validation Loss: 0.9543, Validation Acc:  0.5921,
        Duration: 0:00:13.548719

        Epoch 6/15,
        Train Loss: 0.9173, Train Acc:  0.6322,
        Validation Loss: 0.9545, Validation Acc:  0.5921,
        Duration: 0:00:14.305797

        Epoch 7/15,
        Train Loss:

1it [03:37, 217.87s/it]


        Epoch 15/15,
        Train Loss: 0.9161, Train Acc:  0.6323,
        Validation Loss: 0.9559, Validation Acc:  0.5917,
        Duration: 0:00:13.824507
model saved

        Epoch 1/15,
        Train Loss: 1.0690, Train Acc:  0.4236,
        Validation Loss: 1.0851, Validation Acc:  0.4028,
        Duration: 0:00:13.710783

        Epoch 2/15,
        Train Loss: 1.0477, Train Acc:  0.4481,
        Validation Loss: 1.0840, Validation Acc:  0.3938,
        Duration: 0:00:13.425445

        Epoch 3/15,
        Train Loss: 1.0452, Train Acc:  0.4500,
        Validation Loss: 1.0835, Validation Acc:  0.3959,
        Duration: 0:00:13.143023

        Epoch 4/15,
        Train Loss: 1.0438, Train Acc:  0.4501,
        Validation Loss: 1.0843, Validation Acc:  0.3882,
        Duration: 0:00:16.142641

        Epoch 5/15,
        Train Loss: 1.0422, Train Acc:  0.4549,
        Validation Loss: 1.0843, Validation Acc:  0.3954,
        Duration: 0:00:13.285643

        Epoch 6/15,
      

2it [07:02, 210.35s/it]


        Epoch 15/15,
        Train Loss: 1.0284, Train Acc:  0.4805,
        Validation Loss: 1.0913, Validation Acc:  0.3969,
        Duration: 0:00:13.256938
model saved

        Epoch 1/15,
        Train Loss: 1.0770, Train Acc:  0.3865,
        Validation Loss: 1.0514, Validation Acc:  0.4228,
        Duration: 0:00:13.731919

        Epoch 2/15,
        Train Loss: 1.0318, Train Acc:  0.4495,
        Validation Loss: 1.0495, Validation Acc:  0.4235,
        Duration: 0:00:12.951858

        Epoch 3/15,
        Train Loss: 1.0251, Train Acc:  0.4542,
        Validation Loss: 1.0494, Validation Acc:  0.4232,
        Duration: 0:00:12.812280

        Epoch 4/15,
        Train Loss: 1.0270, Train Acc:  0.4538,
        Validation Loss: 1.0487, Validation Acc:  0.4243,
        Duration: 0:00:13.121694

        Epoch 5/15,
        Train Loss: 1.0251, Train Acc:  0.4529,
        Validation Loss: 1.0489, Validation Acc:  0.4238,
        Duration: 0:00:13.097952

        Epoch 6/15,
      

3it [10:25, 206.69s/it]


        Epoch 15/15,
        Train Loss: 1.0170, Train Acc:  0.4636,
        Validation Loss: 1.0486, Validation Acc:  0.4181,
        Duration: 0:00:13.193229
model saved

        Epoch 1/15,
        Train Loss: 0.9562, Train Acc:  0.5753,
        Validation Loss: 0.9488, Validation Acc:  0.5985,
        Duration: 0:00:21.076203

        Epoch 2/15,
        Train Loss: 0.9271, Train Acc:  0.6205,
        Validation Loss: 0.9471, Validation Acc:  0.5987,
        Duration: 0:00:19.040177

        Epoch 3/15,
        Train Loss: 0.9269, Train Acc:  0.6205,
        Validation Loss: 0.9470, Validation Acc:  0.5985,
        Duration: 0:00:19.008679

        Epoch 4/15,
        Train Loss: 0.9268, Train Acc:  0.6205,
        Validation Loss: 0.9472, Validation Acc:  0.5987,
        Duration: 0:00:19.466249

        Epoch 5/15,
        Train Loss: 0.9267, Train Acc:  0.6204,
        Validation Loss: 0.9471, Validation Acc:  0.5988,
        Duration: 0:00:19.060293

        Epoch 6/15,
      

4it [15:25, 243.76s/it]


        Epoch 15/15,
        Train Loss: 0.9263, Train Acc:  0.6205,
        Validation Loss: 0.9477, Validation Acc:  0.5987,
        Duration: 0:00:19.736430
model saved

        Epoch 1/15,
        Train Loss: 1.0690, Train Acc:  0.4373,
        Validation Loss: 1.0836, Validation Acc:  0.4081,
        Duration: 0:00:20.207208

        Epoch 2/15,
        Train Loss: 1.0592, Train Acc:  0.4395,
        Validation Loss: 1.0843, Validation Acc:  0.4175,
        Duration: 0:00:18.716006

        Epoch 3/15,
        Train Loss: 1.0558, Train Acc:  0.4423,
        Validation Loss: 1.0821, Validation Acc:  0.4059,
        Duration: 0:00:19.177648

        Epoch 4/15,
        Train Loss: 1.0532, Train Acc:  0.4446,
        Validation Loss: 1.0816, Validation Acc:  0.3989,
        Duration: 0:00:19.120959

        Epoch 5/15,
        Train Loss: 1.0509, Train Acc:  0.4472,
        Validation Loss: 1.0854, Validation Acc:  0.3951,
        Duration: 0:00:19.196534

        Epoch 6/15,
      

5it [20:34, 267.27s/it]


        Epoch 15/15,
        Train Loss: 1.0364, Train Acc:  0.4692,
        Validation Loss: 1.0968, Validation Acc:  0.3846,
        Duration: 0:00:19.859117
model saved

        Epoch 1/15,
        Train Loss: 1.0600, Train Acc:  0.4129,
        Validation Loss: 1.0482, Validation Acc:  0.4129,
        Duration: 0:00:22.154246

        Epoch 2/15,
        Train Loss: 1.0446, Train Acc:  0.4223,
        Validation Loss: 1.0476, Validation Acc:  0.4053,
        Duration: 0:00:20.494394

        Epoch 3/15,
        Train Loss: 1.0424, Train Acc:  0.4244,
        Validation Loss: 1.0494, Validation Acc:  0.4015,
        Duration: 0:00:24.864270

        Epoch 4/15,
        Train Loss: 1.0411, Train Acc:  0.4278,
        Validation Loss: 1.0466, Validation Acc:  0.4029,
        Duration: 0:00:20.616594

        Epoch 5/15,
        Train Loss: 1.0402, Train Acc:  0.4279,
        Validation Loss: 1.0488, Validation Acc:  0.4022,
        Duration: 0:00:19.646427

        Epoch 6/15,
      

6it [25:50, 283.56s/it]


        Epoch 15/15,
        Train Loss: 1.0223, Train Acc:  0.4763,
        Validation Loss: 1.0427, Validation Acc:  0.4531,
        Duration: 0:00:20.319836
model saved

        Epoch 1/15,
        Train Loss: 0.9475, Train Acc:  0.6085,
        Validation Loss: 0.9298, Validation Acc:  0.6133,
        Duration: 0:00:50.953176

        Epoch 2/15,
        Train Loss: 0.9366, Train Acc:  0.6085,
        Validation Loss: 0.9304, Validation Acc:  0.6133,
        Duration: 0:00:27.855674

        Epoch 3/15,
        Train Loss: 0.9365, Train Acc:  0.6085,
        Validation Loss: 0.9302, Validation Acc:  0.6135,
        Duration: 0:00:29.276476

        Epoch 4/15,
        Train Loss: 0.9362, Train Acc:  0.6085,
        Validation Loss: 0.9302, Validation Acc:  0.6135,
        Duration: 0:00:29.781458

        Epoch 5/15,
        Train Loss: 0.9359, Train Acc:  0.6085,
        Validation Loss: 0.9302, Validation Acc:  0.6135,
        Duration: 0:00:29.432144

        Epoch 6/15,
      

7it [33:54, 349.09s/it]


        Epoch 15/15,
        Train Loss: 0.9339, Train Acc:  0.6118,
        Validation Loss: 0.9287, Validation Acc:  0.6176,
        Duration: 0:00:30.113740
model saved

        Epoch 1/15,
        Train Loss: 1.0744, Train Acc:  0.4221,
        Validation Loss: 1.0696, Validation Acc:  0.4464,
        Duration: 0:00:54.934301

        Epoch 2/15,
        Train Loss: 1.0664, Train Acc:  0.4320,
        Validation Loss: 1.0670, Validation Acc:  0.4477,
        Duration: 0:00:29.958644

        Epoch 3/15,
        Train Loss: 1.0633, Train Acc:  0.4357,
        Validation Loss: 1.0706, Validation Acc:  0.4482,
        Duration: 0:00:30.042418

        Epoch 4/15,
        Train Loss: 1.0622, Train Acc:  0.4365,
        Validation Loss: 1.0665, Validation Acc:  0.4472,
        Duration: 0:00:28.819793

        Epoch 5/15,
        Train Loss: 1.0601, Train Acc:  0.4388,
        Validation Loss: 1.0679, Validation Acc:  0.4458,
        Duration: 0:00:29.621591

        Epoch 6/15,
      

8it [41:56, 391.50s/it]


        Epoch 15/15,
        Train Loss: 1.0361, Train Acc:  0.4731,
        Validation Loss: 1.0619, Validation Acc:  0.4552,
        Duration: 0:00:30.038147
model saved

        Epoch 1/15,
        Train Loss: 1.0536, Train Acc:  0.4127,
        Validation Loss: 1.0684, Validation Acc:  0.3831,
        Duration: 0:00:50.315019

        Epoch 2/15,
        Train Loss: 1.0441, Train Acc:  0.4198,
        Validation Loss: 1.0669, Validation Acc:  0.3915,
        Duration: 0:00:28.468782

        Epoch 3/15,
        Train Loss: 1.0403, Train Acc:  0.4345,
        Validation Loss: 1.0613, Validation Acc:  0.4186,
        Duration: 0:00:29.726120

        Epoch 4/15,
        Train Loss: 1.0282, Train Acc:  0.4658,
        Validation Loss: 1.0512, Validation Acc:  0.4352,
        Duration: 0:00:29.726125

        Epoch 5/15,
        Train Loss: 1.0190, Train Acc:  0.4805,
        Validation Loss: 1.0464, Validation Acc:  0.4396,
        Duration: 0:00:29.712916

        Epoch 6/15,
      

9it [49:53, 418.27s/it]


        Epoch 15/15,
        Train Loss: 0.9474, Train Acc:  0.5834,
        Validation Loss: 1.0188, Validation Acc:  0.5035,
        Duration: 0:00:29.781033
model saved

        Epoch 1/15,
        Train Loss: 0.9250, Train Acc:  0.6263,
        Validation Loss: 0.8429, Validation Acc:  0.7036,
        Duration: 0:01:54.901309

        Epoch 2/15,
        Train Loss: 0.9164, Train Acc:  0.6281,
        Validation Loss: 0.8429, Validation Acc:  0.7036,
        Duration: 0:00:52.651767

        Epoch 3/15,
        Train Loss: 0.9163, Train Acc:  0.6281,
        Validation Loss: 0.8434, Validation Acc:  0.7036,
        Duration: 0:00:53.776103

        Epoch 4/15,
        Train Loss: 0.9161, Train Acc:  0.6283,
        Validation Loss: 0.8435, Validation Acc:  0.7041,
        Duration: 0:00:54.849657

        Epoch 5/15,
        Train Loss: 0.9159, Train Acc:  0.6289,
        Validation Loss: 0.8431, Validation Acc:  0.7056,
        Duration: 0:00:57.068556

        Epoch 6/15,
      

10it [1:05:58, 587.19s/it]


        Epoch 15/15,
        Train Loss: 0.9133, Train Acc:  0.6339,
        Validation Loss: 0.8433, Validation Acc:  0.7056,
        Duration: 0:01:04.420790
model saved

        Epoch 1/15,
        Train Loss: 1.0539, Train Acc:  0.4610,
        Validation Loss: 0.9943, Validation Acc:  0.5643,
        Duration: 0:02:20.603710

        Epoch 2/15,
        Train Loss: 1.0449, Train Acc:  0.4650,
        Validation Loss: 0.9852, Validation Acc:  0.5603,
        Duration: 0:00:55.887040

        Epoch 3/15,
        Train Loss: 1.0408, Train Acc:  0.4676,
        Validation Loss: 0.9976, Validation Acc:  0.5427,
        Duration: 0:00:57.011010

        Epoch 4/15,
        Train Loss: 1.0360, Train Acc:  0.4726,
        Validation Loss: 0.9836, Validation Acc:  0.5477,
        Duration: 0:01:02.284585

        Epoch 5/15,
        Train Loss: 1.0325, Train Acc:  0.4771,
        Validation Loss: 0.9729, Validation Acc:  0.5583,
        Duration: 0:00:57.701918

        Epoch 6/15,
      

11it [1:23:50, 735.35s/it]


        Epoch 15/15,
        Train Loss: 0.9585, Train Acc:  0.5682,
        Validation Loss: 0.9130, Validation Acc:  0.6332,
        Duration: 0:01:09.055062
model saved

        Epoch 1/15,
        Train Loss: 1.0607, Train Acc:  0.4031,
        Validation Loss: 1.0949, Validation Acc:  0.3625,
        Duration: 0:02:23.497768

        Epoch 2/15,
        Train Loss: 1.0462, Train Acc:  0.4325,
        Validation Loss: 1.0867, Validation Acc:  0.3978,
        Duration: 0:00:56.284616

        Epoch 3/15,
        Train Loss: 1.0386, Train Acc:  0.4520,
        Validation Loss: 1.0678, Validation Acc:  0.4266,
        Duration: 0:01:01.470432

        Epoch 4/15,
        Train Loss: 1.0103, Train Acc:  0.5007,
        Validation Loss: 1.0567, Validation Acc:  0.4571,
        Duration: 0:00:56.772105

        Epoch 5/15,
        Train Loss: 0.9845, Train Acc:  0.5377,
        Validation Loss: 1.0313, Validation Acc:  0.4784,
        Duration: 0:01:02.937944

        Epoch 6/15,
      

12it [1:41:19, 506.60s/it]


        Epoch 15/15,
        Train Loss: 0.8933, Train Acc:  0.6430,
        Validation Loss: 0.9744, Validation Acc:  0.5529,
        Duration: 0:01:10.418527
model saved





# Train MLP

In [None]:
# Train one MLP per (cf, method, k) configuration of the grid below.
# `cf`, `method` and `k` are forwarded unchanged to FIDataset and encoded in the
# checkpoint name; `method` presumably selects the preprocessing/normalization
# variant of the data files (only 'Zscore' here) — confirm in loader.fi_loader.
# A configuration is skipped when ./trained_models/<model_name>.pth already
# exists, so re-running the cell resumes an interrupted sweep (assumes
# batch_train writes its checkpoint to that path — confirm in train.py).
#
# The grid is materialized as a list because a bare itertools.product iterator
# has no len(): tqdm could then only show a running count, with no total/ETA.
sweep_configs = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep_configs, desc='MLP'):
    model_name = f'MLP_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue
    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # test_loader feeds the validation metrics batch_train reports (presumably
    # evaluation-only — confirm). Shuffling it adds nothing and makes validation
    # passes order-dependent, so its order is kept fixed.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = MLP()
    # The model exposes its own target device; move its parameters there.
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)