In [1]:
import os
from itertools import product

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import *
from loader.fi_loader import *
from models.cnn_lstm import CNN_LSTM
from models.cnn import CNN
from models.lstm import LSTM
from models.mlp import MLP
from train import batch_train

# Train CNN-LSTM

In [2]:
# Train one CNN_LSTM per (cf, method, k) configuration of the sweep below.
# A configuration whose checkpoint already exists under ./trained_models is
# skipped, so the cell can be re-run after an interruption without retraining
# finished models.
RUN_SEED = 42
# Materialized as a list (itertools.product has no len) so tqdm can show a total/ETA.
sweep = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep, desc='CNN_LSTM'):
    model_name = f'CNN_LSTM_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue

    # Re-seed per configuration so each run's weight init and shuffle order are
    # reproducible on their own (up to nondeterministic GPU kernels), independent
    # of which earlier configurations were skipped above.
    torch.manual_seed(RUN_SEED)

    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # Held-out data needs no shuffling; a fixed order keeps evaluation deterministic.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = CNN_LSTM()
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

0it [00:00, ?it/s]


        Epoch 1/15,
        Train Loss: 0.9813, Train Acc:  0.5909,
        Validation Loss: 0.9536, Validation Acc:  0.5925,
        Duration: 0:00:18.520207

        Epoch 2/15,
        Train Loss: 0.9179, Train Acc:  0.6322,
        Validation Loss: 0.9557, Validation Acc:  0.5923,
        Duration: 0:00:17.241208

        Epoch 3/15,
        Train Loss: 0.9161, Train Acc:  0.6324,
        Validation Loss: 0.9554, Validation Acc:  0.5919,
        Duration: 0:00:17.764070

        Epoch 4/15,
        Train Loss: 0.9149, Train Acc:  0.6325,
        Validation Loss: 0.9564, Validation Acc:  0.5895,
        Duration: 0:00:17.469690

        Epoch 5/15,
        Train Loss: 0.9138, Train Acc:  0.6332,
        Validation Loss: 0.9565, Validation Acc:  0.5873,
        Duration: 0:00:17.487949

        Epoch 6/15,
        Train Loss: 0.9128, Train Acc:  0.6338,
        Validation Loss: 0.9563, Validation Acc:  0.5861,
        Duration: 0:00:17.653556

        Epoch 7/15,
        Train Loss:

1it [05:16, 316.96s/it]


        Epoch 15/15,
        Train Loss: 0.9046, Train Acc:  0.6415,
        Validation Loss: 0.9681, Validation Acc:  0.5656,
        Duration: 0:00:23.880260
model saved

        Epoch 1/15,
        Train Loss: 1.0645, Train Acc:  0.4414,
        Validation Loss: 1.0831, Validation Acc:  0.4067,
        Duration: 0:00:23.846414

        Epoch 2/15,
        Train Loss: 1.0417, Train Acc:  0.4584,
        Validation Loss: 1.0932, Validation Acc:  0.3991,
        Duration: 0:00:23.469619

        Epoch 3/15,
        Train Loss: 1.0325, Train Acc:  0.4703,
        Validation Loss: 1.0933, Validation Acc:  0.3957,
        Duration: 0:00:23.693020

        Epoch 4/15,
        Train Loss: 1.0235, Train Acc:  0.4825,
        Validation Loss: 1.0874, Validation Acc:  0.4073,
        Duration: 0:00:23.683984

        Epoch 5/15,
        Train Loss: 1.0132, Train Acc:  0.4990,
        Validation Loss: 1.0888, Validation Acc:  0.3991,
        Duration: 0:00:23.900440

        Epoch 6/15,
      

2it [11:21, 344.78s/it]


        Epoch 15/15,
        Train Loss: 0.9271, Train Acc:  0.6095,
        Validation Loss: 1.1245, Validation Acc:  0.3799,
        Duration: 0:00:24.239224
model saved

        Epoch 1/15,
        Train Loss: 1.0613, Train Acc:  0.4419,
        Validation Loss: 1.0517, Validation Acc:  0.4501,
        Duration: 0:00:24.631312

        Epoch 2/15,
        Train Loss: 0.9858, Train Acc:  0.5441,
        Validation Loss: 1.0423, Validation Acc:  0.4700,
        Duration: 0:00:24.304547

        Epoch 3/15,
        Train Loss: 0.9515, Train Acc:  0.5847,
        Validation Loss: 1.0265, Validation Acc:  0.4949,
        Duration: 0:00:24.359346

        Epoch 4/15,
        Train Loss: 0.9246, Train Acc:  0.6132,
        Validation Loss: 1.0240, Validation Acc:  0.4990,
        Duration: 0:00:24.420848

        Epoch 5/15,
        Train Loss: 0.8984, Train Acc:  0.6412,
        Validation Loss: 1.0162, Validation Acc:  0.5109,
        Duration: 0:00:24.300209

        Epoch 6/15,
      

3it [18:21, 379.14s/it]


        Epoch 15/15,
        Train Loss: 0.7849, Train Acc:  0.7635,
        Validation Loss: 1.0113, Validation Acc:  0.5240,
        Duration: 0:00:24.211795
model saved

        Epoch 1/15,
        Train Loss: 0.9489, Train Acc:  0.6112,
        Validation Loss: 0.9473, Validation Acc:  0.5988,
        Duration: 0:01:21.369612

        Epoch 2/15,
        Train Loss: 0.9269, Train Acc:  0.6204,
        Validation Loss: 0.9468, Validation Acc:  0.5981,
        Duration: 0:00:45.152205

        Epoch 3/15,
        Train Loss: 0.9262, Train Acc:  0.6204,
        Validation Loss: 0.9481, Validation Acc:  0.5988,
        Duration: 0:01:26.337819

        Epoch 4/15,
        Train Loss: 0.9255, Train Acc:  0.6203,
        Validation Loss: 0.9475, Validation Acc:  0.5989,
        Duration: 0:01:24.573271

        Epoch 5/15,
        Train Loss: 0.9246, Train Acc:  0.6209,
        Validation Loss: 0.9483, Validation Acc:  0.5991,
        Duration: 0:01:09.528402

        Epoch 6/15,
      

4it [36:37, 662.10s/it]


        Epoch 15/15,
        Train Loss: 0.9164, Train Acc:  0.6317,
        Validation Loss: 0.9494, Validation Acc:  0.5966,
        Duration: 0:01:13.597240
model saved

        Epoch 1/15,
        Train Loss: 1.0665, Train Acc:  0.4310,
        Validation Loss: 1.0912, Validation Acc:  0.4196,
        Duration: 0:00:49.245369

        Epoch 2/15,
        Train Loss: 1.0491, Train Acc:  0.4499,
        Validation Loss: 1.0862, Validation Acc:  0.4263,
        Duration: 0:01:25.660665

        Epoch 3/15,
        Train Loss: 1.0386, Train Acc:  0.4631,
        Validation Loss: 1.0842, Validation Acc:  0.4243,
        Duration: 0:00:45.269320

        Epoch 4/15,
        Train Loss: 1.0257, Train Acc:  0.4810,
        Validation Loss: 1.0818, Validation Acc:  0.4285,
        Duration: 0:01:26.717555

        Epoch 5/15,
        Train Loss: 1.0123, Train Acc:  0.5012,
        Validation Loss: 1.0583, Validation Acc:  0.4498,
        Duration: 0:01:26.572453

        Epoch 6/15,
      

5it [54:20, 806.96s/it]


        Epoch 15/15,
        Train Loss: 0.8752, Train Acc:  0.6624,
        Validation Loss: 0.9284, Validation Acc:  0.6057,
        Duration: 0:01:28.368335
model saved

        Epoch 1/15,
        Train Loss: 1.0390, Train Acc:  0.4716,
        Validation Loss: 1.0427, Validation Acc:  0.4612,
        Duration: 0:00:52.175940

        Epoch 2/15,
        Train Loss: 0.9682, Train Acc:  0.5639,
        Validation Loss: 1.0130, Validation Acc:  0.5074,
        Duration: 0:00:58.719153

        Epoch 3/15,
        Train Loss: 0.9229, Train Acc:  0.6128,
        Validation Loss: 0.9858, Validation Acc:  0.5408,
        Duration: 0:01:28.503997

        Epoch 4/15,
        Train Loss: 0.8855, Train Acc:  0.6532,
        Validation Loss: 0.9446, Validation Acc:  0.5856,
        Duration: 0:00:56.621489

        Epoch 5/15,
        Train Loss: 0.8565, Train Acc:  0.6849,
        Validation Loss: 0.9204, Validation Acc:  0.6126,
        Duration: 0:01:38.074278

        Epoch 6/15,
      

6it [1:11:37, 884.95s/it]


        Epoch 15/15,
        Train Loss: 0.7510, Train Acc:  0.7968,
        Validation Loss: 0.8776, Validation Acc:  0.6611,
        Duration: 0:01:04.979754
model saved

        Epoch 1/15,
        Train Loss: 0.9499, Train Acc:  0.6029,
        Validation Loss: 0.9303, Validation Acc:  0.6140,
        Duration: 0:01:51.489863

        Epoch 2/15,
        Train Loss: 0.9357, Train Acc:  0.6096,
        Validation Loss: 0.9301, Validation Acc:  0.6154,
        Duration: 0:01:41.470304

        Epoch 3/15,
        Train Loss: 0.9342, Train Acc:  0.6107,
        Validation Loss: 0.9305, Validation Acc:  0.6149,
        Duration: 0:02:21.992705

        Epoch 4/15,
        Train Loss: 0.9329, Train Acc:  0.6116,
        Validation Loss: 0.9297, Validation Acc:  0.6159,
        Duration: 0:02:10.703740

        Epoch 5/15,
        Train Loss: 0.9316, Train Acc:  0.6126,
        Validation Loss: 0.9315, Validation Acc:  0.6163,
        Duration: 0:01:35.912692

        Epoch 6/15,
      

7it [1:41:53, 1189.41s/it]


        Epoch 15/15,
        Train Loss: 0.9214, Train Acc:  0.6246,
        Validation Loss: 0.9336, Validation Acc:  0.6126,
        Duration: 0:02:13.645365
model saved

        Epoch 1/15,
        Train Loss: 1.0675, Train Acc:  0.4300,
        Validation Loss: 1.0637, Validation Acc:  0.4540,
        Duration: 0:01:49.302676

        Epoch 2/15,
        Train Loss: 1.0495, Train Acc:  0.4442,
        Validation Loss: 1.0564, Validation Acc:  0.4538,
        Duration: 0:01:37.768066

        Epoch 3/15,
        Train Loss: 1.0119, Train Acc:  0.4965,
        Validation Loss: 0.9807, Validation Acc:  0.5487,
        Duration: 0:01:55.336887

        Epoch 4/15,
        Train Loss: 0.9463, Train Acc:  0.5811,
        Validation Loss: 0.9348, Validation Acc:  0.6017,
        Duration: 0:01:50.047059

        Epoch 5/15,
        Train Loss: 0.9051, Train Acc:  0.6281,
        Validation Loss: 0.9007, Validation Acc:  0.6364,
        Duration: 0:02:29.372088

        Epoch 6/15,
      

8it [2:10:25, 1355.91s/it]


        Epoch 15/15,
        Train Loss: 0.8358, Train Acc:  0.7056,
        Validation Loss: 0.8530, Validation Acc:  0.6852,
        Duration: 0:01:53.787126
model saved

        Epoch 1/15,
        Train Loss: 1.0117, Train Acc:  0.5066,
        Validation Loss: 0.9932, Validation Acc:  0.5322,
        Duration: 0:01:54.433472

        Epoch 2/15,
        Train Loss: 0.9056, Train Acc:  0.6303,
        Validation Loss: 0.9291, Validation Acc:  0.5994,
        Duration: 0:01:50.671061

        Epoch 3/15,
        Train Loss: 0.8514, Train Acc:  0.6881,
        Validation Loss: 0.8786, Validation Acc:  0.6584,
        Duration: 0:02:31.158336

        Epoch 4/15,
        Train Loss: 0.8204, Train Acc:  0.7221,
        Validation Loss: 0.8604, Validation Acc:  0.6788,
        Duration: 0:01:47.811616

        Epoch 5/15,
        Train Loss: 0.7997, Train Acc:  0.7443,
        Validation Loss: 0.8501, Validation Acc:  0.6888,
        Duration: 0:01:54.717806

        Epoch 6/15,
      

9it [2:42:02, 1524.85s/it]


        Epoch 15/15,
        Train Loss: 0.7250, Train Acc:  0.8232,
        Validation Loss: 0.8171, Validation Acc:  0.7262,
        Duration: 0:02:16.179728
model saved

        Epoch 1/15,
        Train Loss: 0.9231, Train Acc:  0.6254,
        Validation Loss: 0.8434, Validation Acc:  0.7037,
        Duration: 0:04:04.209507

        Epoch 2/15,
        Train Loss: 0.9158, Train Acc:  0.6292,
        Validation Loss: 0.8433, Validation Acc:  0.7050,
        Duration: 0:03:39.929121

        Epoch 3/15,
        Train Loss: 0.9149, Train Acc:  0.6307,
        Validation Loss: 0.8429, Validation Acc:  0.7054,
        Duration: 0:03:44.214855

        Epoch 4/15,
        Train Loss: 0.9140, Train Acc:  0.6319,
        Validation Loss: 0.8427, Validation Acc:  0.7058,
        Duration: 0:03:48.456152

        Epoch 5/15,
        Train Loss: 0.9132, Train Acc:  0.6333,
        Validation Loss: 0.8420, Validation Acc:  0.7066,
        Duration: 0:03:47.418125

        Epoch 6/15,
      

10it [3:44:21, 2208.41s/it]


        Epoch 15/15,
        Train Loss: 0.8095, Train Acc:  0.7358,
        Validation Loss: 0.7557, Validation Acc:  0.7911,
        Duration: 0:03:45.274735
model saved

        Epoch 1/15,
        Train Loss: 1.0403, Train Acc:  0.4624,
        Validation Loss: 0.9716, Validation Acc:  0.5685,
        Duration: 0:03:52.930405

        Epoch 2/15,
        Train Loss: 0.9701, Train Acc:  0.5436,
        Validation Loss: 0.8894, Validation Acc:  0.6479,
        Duration: 0:03:24.565529

        Epoch 3/15,
        Train Loss: 0.9022, Train Acc:  0.6289,
        Validation Loss: 0.8340, Validation Acc:  0.7108,
        Duration: 0:03:41.250856

        Epoch 4/15,
        Train Loss: 0.8669, Train Acc:  0.6697,
        Validation Loss: 0.8210, Validation Acc:  0.7227,
        Duration: 0:04:41.736610

        Epoch 5/15,
        Train Loss: 0.8544, Train Acc:  0.6829,
        Validation Loss: 0.8046, Validation Acc:  0.7376,
        Duration: 0:04:50.505724

        Epoch 6/15,
      

11it [4:49:15, 2724.22s/it]


        Epoch 15/15,
        Train Loss: 0.8175, Train Acc:  0.7241,
        Validation Loss: 0.7853, Validation Acc:  0.7584,
        Duration: 0:05:01.800021
model saved

        Epoch 1/15,
        Train Loss: 0.9483, Train Acc:  0.5757,
        Validation Loss: 0.9092, Validation Acc:  0.6272,
        Duration: 0:03:49.298201

        Epoch 2/15,
        Train Loss: 0.8514, Train Acc:  0.6868,
        Validation Loss: 0.8419, Validation Acc:  0.6964,
        Duration: 0:03:33.676479

        Epoch 3/15,
        Train Loss: 0.8192, Train Acc:  0.7211,
        Validation Loss: 0.8234, Validation Acc:  0.7155,
        Duration: 0:04:39.228462

        Epoch 4/15,
        Train Loss: 0.7978, Train Acc:  0.7445,
        Validation Loss: 0.8016, Validation Acc:  0.7409,
        Duration: 0:03:35.252645

        Epoch 5/15,
        Train Loss: 0.7824, Train Acc:  0.7613,
        Validation Loss: 0.7920, Validation Acc:  0.7516,
        Duration: 0:04:51.301220

        Epoch 6/15,
      

12it [5:47:20, 1736.74s/it]


        Epoch 15/15,
        Train Loss: 0.7212, Train Acc:  0.8267,
        Validation Loss: 0.7745, Validation Acc:  0.7701,
        Duration: 0:03:41.783727
model saved





# Train CNN

In [None]:
# Train one CNN per (cf, method, k) configuration of the sweep below.
# A configuration whose checkpoint already exists under ./trained_models is
# skipped, so the cell can be re-run after an interruption without retraining
# finished models.
RUN_SEED = 42
# Materialized as a list (itertools.product has no len) so tqdm can show a total/ETA.
sweep = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep, desc='CNN'):
    model_name = f'CNN_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue

    # Re-seed per configuration so each run's weight init and shuffle order are
    # reproducible on their own (up to nondeterministic GPU kernels), independent
    # of which earlier configurations were skipped above.
    torch.manual_seed(RUN_SEED)

    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # Held-out data needs no shuffling; a fixed order keeps evaluation deterministic.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = CNN()
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

# Train LSTM

In [None]:
# Train one LSTM per (cf, method, k) configuration of the sweep below.
# A configuration whose checkpoint already exists under ./trained_models is
# skipped, so the cell can be re-run after an interruption without retraining
# finished models.
RUN_SEED = 42
# Materialized as a list (itertools.product has no len) so tqdm can show a total/ETA.
sweep = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep, desc='LSTM'):
    model_name = f'LSTM_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue

    # Re-seed per configuration so each run's weight init and shuffle order are
    # reproducible on their own (up to nondeterministic GPU kernels), independent
    # of which earlier configurations were skipped above.
    torch.manual_seed(RUN_SEED)

    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # Held-out data needs no shuffling; a fixed order keeps evaluation deterministic.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = LSTM()
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

# Train MLP

In [None]:
# Train one MLP per (cf, method, k) configuration of the sweep below.
# A configuration whose checkpoint already exists under ./trained_models is
# skipped, so the cell can be re-run after an interruption without retraining
# finished models.
RUN_SEED = 42
# Materialized as a list (itertools.product has no len) so tqdm can show a total/ETA.
sweep = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(sweep, desc='MLP'):
    model_name = f'MLP_{method}_CF{cf}_pred_{k}'
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue

    # Re-seed per configuration so each run's weight init and shuffle order are
    # reproducible on their own (up to nondeterministic GPU kernels), independent
    # of which earlier configurations were skipped above.
    torch.manual_seed(RUN_SEED)

    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # Held-out data needs no shuffling; a fixed order keeps evaluation deterministic.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    lob_model = MLP()
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)