In [1]:
import os
from itertools import product

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import *
from loader.fi_loader import *
from models.cnn_lstm import CNN_LSTM
from models.cnn import CNN
from models.lstm import LSTM
from models.mlp import MLP
from train import batch_train

# Train CNN-LSTM

In [2]:
# Train one CNN-LSTM per (cf, method, k) combination of the grid below.
# The grid is materialised as a list so tqdm can read its length and render a
# real progress bar with total/ETA; a bare itertools.product has no len(),
# which is why the bar previously showed only "Nit" with no total.
param_grid = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(param_grid, desc='CNN-LSTM'):
    model_name = f'CNN_LSTM_{method}_CF{cf}_pred_{k}'
    # Skip combinations that already have a saved .pth file under ./trained_models.
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue

    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # The train=False split is handed to batch_train as the second loader and is
    # reported as "Validation" in the logs; presumably it is only evaluated, so
    # it is iterated in a fixed order instead of being reshuffled every epoch.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    # A fresh model per combination, moved to the device attribute it exposes.
    lob_model = CNN_LSTM()
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

0it [00:00, ?it/s]


        Epoch 1/15,
        Train Loss: 1.0111, Train Acc:  0.5060,
        Validation Loss: 1.0089, Validation Acc:  0.5086,
        Duration: 0:01:03.760454

        Epoch 2/15,
        Train Loss: 0.9134, Train Acc:  0.6212,
        Validation Loss: 0.9477, Validation Acc:  0.5837,
        Duration: 0:00:57.142130

        Epoch 3/15,
        Train Loss: 0.8613, Train Acc:  0.6768,
        Validation Loss: 0.8975, Validation Acc:  0.6387,
        Duration: 0:01:05.581307

        Epoch 4/15,
        Train Loss: 0.8307, Train Acc:  0.7101,
        Validation Loss: 0.8733, Validation Acc:  0.6639,
        Duration: 0:01:05.910181

        Epoch 5/15,
        Train Loss: 0.8102, Train Acc:  0.7326,
        Validation Loss: 0.8645, Validation Acc:  0.6744,
        Duration: 0:01:06.518268

        Epoch 6/15,
        Train Loss: 0.7957, Train Acc:  0.7476,
        Validation Loss: 0.8488, Validation Acc:  0.6921,
        Duration: 0:01:07.942270

        Epoch 7/15,
        Train Loss:

9it [22:26, 149.56s/it]


        Epoch 15/15,
        Train Loss: 0.7345, Train Acc:  0.8133,
        Validation Loss: 0.8301, Validation Acc:  0.7127,
        Duration: 0:01:45.041338
model saved

        Epoch 1/15,
        Train Loss: 0.9225, Train Acc:  0.6264,
        Validation Loss: 0.8424, Validation Acc:  0.7045,
        Duration: 0:02:35.699928

        Epoch 2/15,
        Train Loss: 0.9155, Train Acc:  0.6295,
        Validation Loss: 0.8427, Validation Acc:  0.7044,
        Duration: 0:03:07.723048

        Epoch 3/15,
        Train Loss: 0.9144, Train Acc:  0.6317,
        Validation Loss: 0.8418, Validation Acc:  0.7068,
        Duration: 0:03:21.077885

        Epoch 4/15,
        Train Loss: 0.9129, Train Acc:  0.6340,
        Validation Loss: 0.8414, Validation Acc:  0.7079,
        Duration: 0:03:25.190671

        Epoch 5/15,
        Train Loss: 0.9113, Train Acc:  0.6364,
        Validation Loss: 0.8411, Validation Acc:  0.7088,
        Duration: 0:03:37.658936

        Epoch 6/15,
      

10it [1:16:06, 570.24s/it]


        Epoch 15/15,
        Train Loss: 0.7766, Train Acc:  0.7657,
        Validation Loss: 0.7293, Validation Acc:  0.8160,
        Duration: 0:04:04.739293
model saved

        Epoch 1/15,
        Train Loss: 1.0416, Train Acc:  0.4629,
        Validation Loss: 0.9758, Validation Acc:  0.5616,
        Duration: 0:02:46.208619

        Epoch 2/15,
        Train Loss: 0.9689, Train Acc:  0.5449,
        Validation Loss: 0.8677, Validation Acc:  0.6745,
        Duration: 0:03:04.769699

        Epoch 3/15,
        Train Loss: 0.8811, Train Acc:  0.6549,
        Validation Loss: 0.8113, Validation Acc:  0.7312,
        Duration: 0:03:35.507359

        Epoch 4/15,
        Train Loss: 0.8561, Train Acc:  0.6810,
        Validation Loss: 0.8036, Validation Acc:  0.7393,
        Duration: 0:03:45.248810

        Epoch 5/15,
        Train Loss: 0.8469, Train Acc:  0.6914,
        Validation Loss: 0.7967, Validation Acc:  0.7460,
        Duration: 0:03:14.396992

        Epoch 6/15,
      

11it [2:11:57, 1025.35s/it]


        Epoch 15/15,
        Train Loss: 0.8101, Train Acc:  0.7327,
        Validation Loss: 0.7841, Validation Acc:  0.7578,
        Duration: 0:03:19.076094
model saved

        Epoch 1/15,
        Train Loss: 0.9910, Train Acc:  0.5281,
        Validation Loss: 0.9746, Validation Acc:  0.5648,
        Duration: 0:02:47.377531

        Epoch 2/15,
        Train Loss: 0.8707, Train Acc:  0.6657,
        Validation Loss: 0.8843, Validation Acc:  0.6566,
        Duration: 0:03:16.201966

        Epoch 3/15,
        Train Loss: 0.8230, Train Acc:  0.7170,
        Validation Loss: 0.8273, Validation Acc:  0.7144,
        Duration: 0:03:50.045467

        Epoch 4/15,
        Train Loss: 0.7981, Train Acc:  0.7439,
        Validation Loss: 0.8168, Validation Acc:  0.7242,
        Duration: 0:03:21.266232

        Epoch 5/15,
        Train Loss: 0.7810, Train Acc:  0.7630,
        Validation Loss: 0.8037, Validation Acc:  0.7387,
        Duration: 0:04:13.217396

        Epoch 6/15,
      

12it [3:09:20, 946.74s/it] 


        Epoch 15/15,
        Train Loss: 0.7207, Train Acc:  0.8272,
        Validation Loss: 0.7716, Validation Acc:  0.7728,
        Duration: 0:04:26.183760
model saved





# Train CNN

In [3]:
# Train one CNN per (cf, method, k) combination of the grid below.
# The grid is materialised as a list so tqdm can read its length and render a
# real progress bar with total/ETA; a bare itertools.product has no len(),
# which is why the bar previously showed only "Nit" with no total.
param_grid = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(param_grid, desc='CNN'):
    model_name = f'CNN_{method}_CF{cf}_pred_{k}'
    # Skip combinations that already have a saved .pth file under ./trained_models.
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue

    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # The train=False split is handed to batch_train as the second loader and is
    # reported as "Validation" in the logs; presumably it is only evaluated, so
    # it is iterated in a fixed order instead of being reshuffled every epoch.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    # A fresh model per combination, moved to the device attribute it exposes.
    lob_model = CNN()
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

0it [00:00, ?it/s]


        Epoch 1/15,
        Train Loss: 0.9760, Train Acc:  0.6048,
        Validation Loss: 0.9568, Validation Acc:  0.5914,
        Duration: 0:00:12.917815

        Epoch 2/15,
        Train Loss: 0.9173, Train Acc:  0.6334,
        Validation Loss: 0.9561, Validation Acc:  0.5918,
        Duration: 0:00:10.787764

        Epoch 3/15,
        Train Loss: 0.9169, Train Acc:  0.6335,
        Validation Loss: 0.9542, Validation Acc:  0.5927,
        Duration: 0:00:10.852068

        Epoch 4/15,
        Train Loss: 0.9167, Train Acc:  0.6334,
        Validation Loss: 0.9550, Validation Acc:  0.5921,
        Duration: 0:00:10.781471

        Epoch 5/15,
        Train Loss: 0.9166, Train Acc:  0.6334,
        Validation Loss: 0.9541, Validation Acc:  0.5912,
        Duration: 0:00:10.747962

        Epoch 6/15,
        Train Loss: 0.9164, Train Acc:  0.6335,
        Validation Loss: 0.9537, Validation Acc:  0.5918,
        Duration: 0:00:10.729918

        Epoch 7/15,
        Train Loss:

1it [02:58, 178.88s/it]


        Epoch 15/15,
        Train Loss: 0.9161, Train Acc:  0.6335,
        Validation Loss: 0.9534, Validation Acc:  0.5918,
        Duration: 0:00:10.406938
model saved

        Epoch 1/15,
        Train Loss: 1.0778, Train Acc:  0.3929,
        Validation Loss: 1.0849, Validation Acc:  0.4050,
        Duration: 0:00:10.693489

        Epoch 2/15,
        Train Loss: 1.0539, Train Acc:  0.4507,
        Validation Loss: 1.0832, Validation Acc:  0.4033,
        Duration: 0:00:10.122057

        Epoch 3/15,
        Train Loss: 1.0504, Train Acc:  0.4499,
        Validation Loss: 1.0832, Validation Acc:  0.4043,
        Duration: 0:00:10.045844

        Epoch 4/15,
        Train Loss: 1.0492, Train Acc:  0.4500,
        Validation Loss: 1.0827, Validation Acc:  0.4017,
        Duration: 0:00:10.028521

        Epoch 5/15,
        Train Loss: 1.0486, Train Acc:  0.4503,
        Validation Loss: 1.0831, Validation Acc:  0.3938,
        Duration: 0:00:10.189936

        Epoch 6/15,
      

2it [05:45, 171.64s/it]


        Epoch 15/15,
        Train Loss: 1.0453, Train Acc:  0.4509,
        Validation Loss: 1.0813, Validation Acc:  0.4024,
        Duration: 0:00:10.146720
model saved

        Epoch 1/15,
        Train Loss: 1.0750, Train Acc:  0.3889,
        Validation Loss: 1.0544, Validation Acc:  0.4240,
        Duration: 0:00:10.654768

        Epoch 2/15,
        Train Loss: 1.0450, Train Acc:  0.4322,
        Validation Loss: 1.0509, Validation Acc:  0.4037,
        Duration: 0:00:10.050306

        Epoch 3/15,
        Train Loss: 1.0342, Train Acc:  0.4434,
        Validation Loss: 1.0496, Validation Acc:  0.4227,
        Duration: 0:00:10.045143

        Epoch 4/15,
        Train Loss: 1.0295, Train Acc:  0.4474,
        Validation Loss: 1.0494, Validation Acc:  0.4245,
        Duration: 0:00:10.174559

        Epoch 5/15,
        Train Loss: 1.0273, Train Acc:  0.4516,
        Validation Loss: 1.0498, Validation Acc:  0.4257,
        Duration: 0:00:10.336142

        Epoch 6/15,
      

3it [08:34, 170.24s/it]


        Epoch 15/15,
        Train Loss: 1.0217, Train Acc:  0.4579,
        Validation Loss: 1.0497, Validation Acc:  0.4256,
        Duration: 0:00:10.382614
model saved

        Epoch 1/15,
        Train Loss: 0.9437, Train Acc:  0.6159,
        Validation Loss: 0.9501, Validation Acc:  0.5979,
        Duration: 0:00:19.280926

        Epoch 2/15,
        Train Loss: 0.9270, Train Acc:  0.6209,
        Validation Loss: 0.9491, Validation Acc:  0.5975,
        Duration: 0:00:14.864416

        Epoch 3/15,
        Train Loss: 0.9267, Train Acc:  0.6209,
        Validation Loss: 0.9485, Validation Acc:  0.5978,
        Duration: 0:00:13.478386

        Epoch 4/15,
        Train Loss: 0.9267, Train Acc:  0.6209,
        Validation Loss: 0.9485, Validation Acc:  0.5979,
        Duration: 0:00:13.449622

        Epoch 5/15,
        Train Loss: 0.9266, Train Acc:  0.6208,
        Validation Loss: 0.9482, Validation Acc:  0.5977,
        Duration: 0:00:13.724279

        Epoch 6/15,
      

4it [12:32, 197.17s/it]


        Epoch 15/15,
        Train Loss: 0.9259, Train Acc:  0.6208,
        Validation Loss: 0.9480, Validation Acc:  0.5979,
        Duration: 0:00:13.661738
model saved

        Epoch 1/15,
        Train Loss: 1.0707, Train Acc:  0.4378,
        Validation Loss: 1.0843, Validation Acc:  0.4070,
        Duration: 0:00:16.386982

        Epoch 2/15,
        Train Loss: 1.0633, Train Acc:  0.4410,
        Validation Loss: 1.0826, Validation Acc:  0.4168,
        Duration: 0:00:13.774120

        Epoch 3/15,
        Train Loss: 1.0616, Train Acc:  0.4424,
        Validation Loss: 1.0809, Validation Acc:  0.4163,
        Duration: 0:00:13.617643

        Epoch 4/15,
        Train Loss: 1.0605, Train Acc:  0.4425,
        Validation Loss: 1.0826, Validation Acc:  0.4158,
        Duration: 0:00:13.604419

        Epoch 5/15,
        Train Loss: 1.0596, Train Acc:  0.4426,
        Validation Loss: 1.0803, Validation Acc:  0.4120,
        Duration: 0:00:13.713468

        Epoch 6/15,
      

5it [16:29, 211.60s/it]


        Epoch 15/15,
        Train Loss: 1.0524, Train Acc:  0.4498,
        Validation Loss: 1.0835, Validation Acc:  0.4020,
        Duration: 0:00:13.568772
model saved

        Epoch 1/15,
        Train Loss: 1.0688, Train Acc:  0.3873,
        Validation Loss: 1.0502, Validation Acc:  0.4110,
        Duration: 0:00:16.094236

        Epoch 2/15,
        Train Loss: 1.0563, Train Acc:  0.4125,
        Validation Loss: 1.0492, Validation Acc:  0.4122,
        Duration: 0:00:13.637723

        Epoch 3/15,
        Train Loss: 1.0510, Train Acc:  0.4174,
        Validation Loss: 1.0487, Validation Acc:  0.4128,
        Duration: 0:00:13.552833

        Epoch 4/15,
        Train Loss: 1.0496, Train Acc:  0.4194,
        Validation Loss: 1.0501, Validation Acc:  0.4038,
        Duration: 0:00:13.484406

        Epoch 5/15,
        Train Loss: 1.0487, Train Acc:  0.4241,
        Validation Loss: 1.0490, Validation Acc:  0.4066,
        Duration: 0:00:13.537105

        Epoch 6/15,
      

6it [20:23, 219.07s/it]


        Epoch 15/15,
        Train Loss: 1.0402, Train Acc:  0.4421,
        Validation Loss: 1.0508, Validation Acc:  0.4167,
        Duration: 0:00:13.605832
model saved

        Epoch 1/15,
        Train Loss: 0.9482, Train Acc:  0.6075,
        Validation Loss: 0.9325, Validation Acc:  0.6115,
        Duration: 0:00:26.469335

        Epoch 2/15,
        Train Loss: 0.9366, Train Acc:  0.6088,
        Validation Loss: 0.9309, Validation Acc:  0.6115,
        Duration: 0:00:17.026168

        Epoch 3/15,
        Train Loss: 0.9366, Train Acc:  0.6087,
        Validation Loss: 0.9288, Validation Acc:  0.6148,
        Duration: 0:00:17.299119

        Epoch 4/15,
        Train Loss: 0.9365, Train Acc:  0.6087,
        Validation Loss: 0.9288, Validation Acc:  0.6148,
        Duration: 0:00:17.412208

        Epoch 5/15,
        Train Loss: 0.9363, Train Acc:  0.6087,
        Validation Loss: 0.9285, Validation Acc:  0.6148,
        Duration: 0:00:17.347959

        Epoch 6/15,
      

7it [25:42, 251.76s/it]


        Epoch 15/15,
        Train Loss: 0.9351, Train Acc:  0.6093,
        Validation Loss: 0.9310, Validation Acc:  0.6117,
        Duration: 0:00:17.491920
model saved

        Epoch 1/15,
        Train Loss: 1.0757, Train Acc:  0.4220,
        Validation Loss: 1.0733, Validation Acc:  0.4502,
        Duration: 0:00:37.894523

        Epoch 2/15,
        Train Loss: 1.0698, Train Acc:  0.4308,
        Validation Loss: 1.0741, Validation Acc:  0.4457,
        Duration: 0:00:20.201073

        Epoch 3/15,
        Train Loss: 1.0683, Train Acc:  0.4315,
        Validation Loss: 1.0677, Validation Acc:  0.4456,
        Duration: 0:00:20.686731

        Epoch 4/15,
        Train Loss: 1.0672, Train Acc:  0.4320,
        Validation Loss: 1.0744, Validation Acc:  0.4406,
        Duration: 0:00:20.341913

        Epoch 5/15,
        Train Loss: 1.0663, Train Acc:  0.4331,
        Validation Loss: 1.0710, Validation Acc:  0.4435,
        Duration: 0:00:20.728500

        Epoch 6/15,
      

8it [32:00, 292.06s/it]


        Epoch 15/15,
        Train Loss: 1.0607, Train Acc:  0.4389,
        Validation Loss: 1.0746, Validation Acc:  0.4419,
        Duration: 0:00:21.006784
model saved

        Epoch 1/15,
        Train Loss: 1.0575, Train Acc:  0.4051,
        Validation Loss: 1.0672, Validation Acc:  0.4028,
        Duration: 0:00:32.113174

        Epoch 2/15,
        Train Loss: 1.0494, Train Acc:  0.4169,
        Validation Loss: 1.0690, Validation Acc:  0.3816,
        Duration: 0:00:20.144831

        Epoch 3/15,
        Train Loss: 1.0458, Train Acc:  0.4270,
        Validation Loss: 1.0669, Validation Acc:  0.3940,
        Duration: 0:00:21.050243

        Epoch 4/15,
        Train Loss: 1.0448, Train Acc:  0.4309,
        Validation Loss: 1.0702, Validation Acc:  0.3862,
        Duration: 0:00:21.120660

        Epoch 5/15,
        Train Loss: 1.0438, Train Acc:  0.4312,
        Validation Loss: 1.0679, Validation Acc:  0.3965,
        Duration: 0:00:21.126573

        Epoch 6/15,
      

9it [38:15, 317.98s/it]


        Epoch 15/15,
        Train Loss: 1.0394, Train Acc:  0.4473,
        Validation Loss: 1.0684, Validation Acc:  0.4044,
        Duration: 0:00:21.524360
model saved

        Epoch 1/15,
        Train Loss: 0.9240, Train Acc:  0.6283,
        Validation Loss: 0.8420, Validation Acc:  0.7045,
        Duration: 0:01:18.211090

        Epoch 2/15,
        Train Loss: 0.9165, Train Acc:  0.6283,
        Validation Loss: 0.8420, Validation Acc:  0.7045,
        Duration: 0:00:37.611925

        Epoch 3/15,
        Train Loss: 0.9164, Train Acc:  0.6283,
        Validation Loss: 0.8419, Validation Acc:  0.7045,
        Duration: 0:00:38.755981

        Epoch 4/15,
        Train Loss: 0.9163, Train Acc:  0.6283,
        Validation Loss: 0.8420, Validation Acc:  0.7044,
        Duration: 0:00:38.938534

        Epoch 5/15,
        Train Loss: 0.9163, Train Acc:  0.6283,
        Validation Loss: 0.8420, Validation Acc:  0.7045,
        Duration: 0:00:38.586458

        Epoch 6/15,
      

10it [50:33, 447.59s/it]


        Epoch 15/15,
        Train Loss: 0.9157, Train Acc:  0.6295,
        Validation Loss: 0.8427, Validation Acc:  0.7036,
        Duration: 0:00:41.536802
model saved

        Epoch 1/15,
        Train Loss: 1.0582, Train Acc:  0.4594,
        Validation Loss: 0.9917, Validation Acc:  0.5669,
        Duration: 0:01:34.027840

        Epoch 2/15,
        Train Loss: 1.0522, Train Acc:  0.4632,
        Validation Loss: 0.9890, Validation Acc:  0.5669,
        Duration: 0:00:39.782822

        Epoch 3/15,
        Train Loss: 1.0505, Train Acc:  0.4637,
        Validation Loss: 0.9870, Validation Acc:  0.5667,
        Duration: 0:00:43.017035

        Epoch 4/15,
        Train Loss: 1.0486, Train Acc:  0.4648,
        Validation Loss: 0.9926, Validation Acc:  0.5678,
        Duration: 0:00:41.786491

        Epoch 5/15,
        Train Loss: 1.0469, Train Acc:  0.4655,
        Validation Loss: 0.9950, Validation Acc:  0.5582,
        Duration: 0:00:41.582062

        Epoch 6/15,
      

11it [1:03:37, 550.58s/it]


        Epoch 15/15,
        Train Loss: 1.0386, Train Acc:  0.4710,
        Validation Loss: 0.9889, Validation Acc:  0.5539,
        Duration: 0:00:41.730804
model saved

        Epoch 1/15,
        Train Loss: 1.0661, Train Acc:  0.4009,
        Validation Loss: 1.1144, Validation Acc:  0.3521,
        Duration: 0:01:54.744638

        Epoch 2/15,
        Train Loss: 1.0608, Train Acc:  0.4133,
        Validation Loss: 1.1168, Validation Acc:  0.3451,
        Duration: 0:00:39.927968

        Epoch 3/15,
        Train Loss: 1.0599, Train Acc:  0.4153,
        Validation Loss: 1.1151, Validation Acc:  0.3506,
        Duration: 0:00:40.636175

        Epoch 4/15,
        Train Loss: 1.0592, Train Acc:  0.4163,
        Validation Loss: 1.1120, Validation Acc:  0.3689,
        Duration: 0:00:43.105114

        Epoch 5/15,
        Train Loss: 1.0585, Train Acc:  0.4186,
        Validation Loss: 1.1143, Validation Acc:  0.3588,
        Duration: 0:00:42.289722

        Epoch 6/15,
      

12it [1:17:16, 386.40s/it]


        Epoch 15/15,
        Train Loss: 1.0381, Train Acc:  0.4619,
        Validation Loss: 1.0969, Validation Acc:  0.3904,
        Duration: 0:00:41.370140
model saved





# Train LSTM

In [4]:
# Train one LSTM per (cf, method, k) combination of the grid below.
# The grid is materialised as a list so tqdm can read its length and render a
# real progress bar with total/ETA; a bare itertools.product has no len(),
# which is why the bar previously showed only "Nit" with no total.
param_grid = list(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4]))
for cf, method, k in tqdm(param_grid, desc='LSTM'):
    model_name = f'LSTM_{method}_CF{cf}_pred_{k}'
    # Skip combinations that already have a saved .pth file under ./trained_models.
    if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
        continue

    train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
    test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
    # The train=False split is handed to batch_train as the second loader and is
    # reported as "Validation" in the logs; presumably it is only evaluated, so
    # it is iterated in a fixed order instead of being reshuffled every epoch.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    # A fresh model per combination, moved to the device attribute it exposes.
    lob_model = LSTM()
    lob_model.to(lob_model.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)

    batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)

0it [00:00, ?it/s]


        Epoch 1/15,
        Train Loss: 0.9801, Train Acc:  0.5915,
        Validation Loss: 0.9530, Validation Acc:  0.5918,
        Duration: 0:00:12.525883

        Epoch 2/15,
        Train Loss: 0.9167, Train Acc:  0.6335,
        Validation Loss: 0.9535, Validation Acc:  0.5914,
        Duration: 0:00:11.539791

        Epoch 3/15,
        Train Loss: 0.9162, Train Acc:  0.6335,
        Validation Loss: 0.9514, Validation Acc:  0.5916,
        Duration: 0:00:11.368249

        Epoch 4/15,
        Train Loss: 0.9161, Train Acc:  0.6334,
        Validation Loss: 0.9526, Validation Acc:  0.5918,
        Duration: 0:00:11.717339

        Epoch 5/15,
        Train Loss: 0.9159, Train Acc:  0.6335,
        Validation Loss: 0.9543, Validation Acc:  0.5916,
        Duration: 0:00:11.620203

        Epoch 6/15,
        Train Loss: 0.9158, Train Acc:  0.6335,
        Validation Loss: 0.9539, Validation Acc:  0.5916,
        Duration: 0:00:11.702326

        Epoch 7/15,
        Train Loss:

1it [03:14, 194.97s/it]


        Epoch 15/15,
        Train Loss: 0.9151, Train Acc:  0.6333,
        Validation Loss: 0.9559, Validation Acc:  0.5912,
        Duration: 0:00:12.093943
model saved

        Epoch 1/15,
        Train Loss: 1.0714, Train Acc:  0.4350,
        Validation Loss: 1.0837, Validation Acc:  0.4055,
        Duration: 0:00:12.373872

        Epoch 2/15,
        Train Loss: 1.0471, Train Acc:  0.4503,
        Validation Loss: 1.0840, Validation Acc:  0.3935,
        Duration: 0:00:11.733405

        Epoch 3/15,
        Train Loss: 1.0439, Train Acc:  0.4518,
        Validation Loss: 1.0844, Validation Acc:  0.3971,
        Duration: 0:00:11.795772

        Epoch 4/15,
        Train Loss: 1.0423, Train Acc:  0.4530,
        Validation Loss: 1.0834, Validation Acc:  0.3960,
        Duration: 0:00:11.713578

        Epoch 5/15,
        Train Loss: 1.0409, Train Acc:  0.4536,
        Validation Loss: 1.0857, Validation Acc:  0.3854,
        Duration: 0:00:11.938461

        Epoch 6/15,
      

2it [06:30, 195.22s/it]


        Epoch 15/15,
        Train Loss: 1.0313, Train Acc:  0.4695,
        Validation Loss: 1.0866, Validation Acc:  0.4085,
        Duration: 0:00:12.179188
model saved

        Epoch 1/15,
        Train Loss: 1.0703, Train Acc:  0.4174,
        Validation Loss: 1.0515, Validation Acc:  0.4220,
        Duration: 0:00:12.385247

        Epoch 2/15,
        Train Loss: 1.0297, Train Acc:  0.4494,
        Validation Loss: 1.0494, Validation Acc:  0.4245,
        Duration: 0:00:11.941239

        Epoch 3/15,
        Train Loss: 1.0203, Train Acc:  0.4584,
        Validation Loss: 1.0493, Validation Acc:  0.4251,
        Duration: 0:00:11.974410

        Epoch 4/15,
        Train Loss: 1.0172, Train Acc:  0.4590,
        Validation Loss: 1.0496, Validation Acc:  0.4254,
        Duration: 0:00:12.046539

        Epoch 5/15,
        Train Loss: 1.0201, Train Acc:  0.4585,
        Validation Loss: 1.0495, Validation Acc:  0.4253,
        Duration: 0:00:12.035896

        Epoch 6/15,
      

3it [09:46, 195.67s/it]


        Epoch 15/15,
        Train Loss: 1.0083, Train Acc:  0.4763,
        Validation Loss: 1.0485, Validation Acc:  0.4341,
        Duration: 0:00:12.201574
model saved

        Epoch 1/15,
        Train Loss: 0.9494, Train Acc:  0.6159,
        Validation Loss: 0.9479, Validation Acc:  0.5976,
        Duration: 0:00:17.604069

        Epoch 2/15,
        Train Loss: 0.9266, Train Acc:  0.6209,
        Validation Loss: 0.9483, Validation Acc:  0.5975,
        Duration: 0:00:16.492374

        Epoch 3/15,
        Train Loss: 0.9264, Train Acc:  0.6208,
        Validation Loss: 0.9482, Validation Acc:  0.5980,
        Duration: 0:00:17.365681

        Epoch 4/15,
        Train Loss: 0.9262, Train Acc:  0.6209,
        Validation Loss: 0.9487, Validation Acc:  0.5980,
        Duration: 0:00:18.096907

        Epoch 5/15,
        Train Loss: 0.9262, Train Acc:  0.6208,
        Validation Loss: 0.9474, Validation Acc:  0.5980,
        Duration: 0:00:17.531529

        Epoch 6/15,
      

4it [14:40, 234.33s/it]


        Epoch 15/15,
        Train Loss: 0.9254, Train Acc:  0.6209,
        Validation Loss: 0.9481, Validation Acc:  0.5975,
        Duration: 0:00:18.274647
model saved

        Epoch 1/15,
        Train Loss: 1.0712, Train Acc:  0.4240,
        Validation Loss: 1.0835, Validation Acc:  0.4233,
        Duration: 0:00:18.681397

        Epoch 2/15,
        Train Loss: 1.0602, Train Acc:  0.4405,
        Validation Loss: 1.0799, Validation Acc:  0.4236,
        Duration: 0:00:16.996678

        Epoch 3/15,
        Train Loss: 1.0575, Train Acc:  0.4411,
        Validation Loss: 1.0809, Validation Acc:  0.4103,
        Duration: 0:00:17.674999

        Epoch 4/15,
        Train Loss: 1.0552, Train Acc:  0.4435,
        Validation Loss: 1.0802, Validation Acc:  0.4088,
        Duration: 0:00:18.460995

        Epoch 5/15,
        Train Loss: 1.0529, Train Acc:  0.4453,
        Validation Loss: 1.0844, Validation Acc:  0.4104,
        Duration: 0:00:17.952689

        Epoch 6/15,
      

5it [19:44, 259.48s/it]


        Epoch 15/15,
        Train Loss: 1.0348, Train Acc:  0.4729,
        Validation Loss: 1.0842, Validation Acc:  0.4076,
        Duration: 0:00:18.408581
model saved

        Epoch 1/15,
        Train Loss: 1.0632, Train Acc:  0.4066,
        Validation Loss: 1.0490, Validation Acc:  0.4116,
        Duration: 0:00:18.955508

        Epoch 2/15,
        Train Loss: 1.0474, Train Acc:  0.4183,
        Validation Loss: 1.0493, Validation Acc:  0.4108,
        Duration: 0:00:17.047549

        Epoch 3/15,
        Train Loss: 1.0422, Train Acc:  0.4264,
        Validation Loss: 1.0477, Validation Acc:  0.4040,
        Duration: 0:00:17.869797

        Epoch 4/15,
        Train Loss: 1.0394, Train Acc:  0.4345,
        Validation Loss: 1.0471, Validation Acc:  0.4187,
        Duration: 0:00:19.541118

        Epoch 5/15,
        Train Loss: 1.0391, Train Acc:  0.4388,
        Validation Loss: 1.0466, Validation Acc:  0.4258,
        Duration: 0:00:18.334088

        Epoch 6/15,
      

6it [24:49, 274.93s/it]


        Epoch 15/15,
        Train Loss: 0.9883, Train Acc:  0.5209,
        Validation Loss: 1.0163, Validation Acc:  0.4919,
        Duration: 0:00:18.325888
model saved

        Epoch 1/15,
        Train Loss: 0.9482, Train Acc:  0.5982,
        Validation Loss: 0.9318, Validation Acc:  0.6115,
        Duration: 0:00:31.225634

        Epoch 2/15,
        Train Loss: 0.9366, Train Acc:  0.6087,
        Validation Loss: 0.9288, Validation Acc:  0.6148,
        Duration: 0:00:26.058223

        Epoch 3/15,
        Train Loss: 0.9363, Train Acc:  0.6086,
        Validation Loss: 0.9287, Validation Acc:  0.6148,
        Duration: 0:00:27.664640

        Epoch 4/15,
        Train Loss: 0.9360, Train Acc:  0.6087,
        Validation Loss: 0.9324, Validation Acc:  0.6115,
        Duration: 0:00:27.803554

        Epoch 5/15,
        Train Loss: 0.9360, Train Acc:  0.6086,
        Validation Loss: 0.9289, Validation Acc:  0.6148,
        Duration: 0:00:27.479893

        Epoch 6/15,
      

7it [32:39, 338.77s/it]


        Epoch 15/15,
        Train Loss: 0.9335, Train Acc:  0.6125,
        Validation Loss: 0.9297, Validation Acc:  0.6156,
        Duration: 0:00:28.202237
model saved

        Epoch 1/15,
        Train Loss: 1.0748, Train Acc:  0.4195,
        Validation Loss: 1.0774, Validation Acc:  0.4450,
        Duration: 0:00:36.701183

        Epoch 2/15,
        Train Loss: 1.0662, Train Acc:  0.4318,
        Validation Loss: 1.0694, Validation Acc:  0.4504,
        Duration: 0:00:25.946649

        Epoch 3/15,
        Train Loss: 1.0630, Train Acc:  0.4344,
        Validation Loss: 1.0649, Validation Acc:  0.4481,
        Duration: 0:00:26.838125

        Epoch 4/15,
        Train Loss: 1.0604, Train Acc:  0.4359,
        Validation Loss: 1.0674, Validation Acc:  0.4471,
        Duration: 0:00:28.043115

        Epoch 5/15,
        Train Loss: 1.0576, Train Acc:  0.4368,
        Validation Loss: 1.0714, Validation Acc:  0.4444,
        Duration: 0:00:27.767352

        Epoch 6/15,
      

8it [40:39, 383.67s/it]


        Epoch 15/15,
        Train Loss: 1.0389, Train Acc:  0.4652,
        Validation Loss: 1.0680, Validation Acc:  0.4462,
        Duration: 0:00:27.794655
model saved

        Epoch 1/15,
        Train Loss: 1.0559, Train Acc:  0.4088,
        Validation Loss: 1.0668, Validation Acc:  0.3839,
        Duration: 0:00:37.808673

        Epoch 2/15,
        Train Loss: 1.0469, Train Acc:  0.4155,
        Validation Loss: 1.0658, Validation Acc:  0.4028,
        Duration: 0:00:26.349210

        Epoch 3/15,
        Train Loss: 1.0444, Train Acc:  0.4178,
        Validation Loss: 1.0649, Validation Acc:  0.3852,
        Duration: 0:00:27.704957

        Epoch 4/15,
        Train Loss: 1.0417, Train Acc:  0.4241,
        Validation Loss: 1.0653, Validation Acc:  0.4017,
        Duration: 0:00:27.588716

        Epoch 5/15,
        Train Loss: 1.0393, Train Acc:  0.4303,
        Validation Loss: 1.0678, Validation Acc:  0.3904,
        Duration: 0:00:27.906634

        Epoch 6/15,
      

9it [48:38, 413.45s/it]


        Epoch 15/15,
        Train Loss: 0.9750, Train Acc:  0.5357,
        Validation Loss: 1.0257, Validation Acc:  0.4759,
        Duration: 0:00:28.261389
model saved

        Epoch 1/15,
        Train Loss: 0.9241, Train Acc:  0.6283,
        Validation Loss: 0.8419, Validation Acc:  0.7045,
        Duration: 0:01:31.190987

        Epoch 2/15,
        Train Loss: 0.9164, Train Acc:  0.6283,
        Validation Loss: 0.8419, Validation Acc:  0.7045,
        Duration: 0:00:47.537171

        Epoch 3/15,
        Train Loss: 0.9162, Train Acc:  0.6283,
        Validation Loss: 0.8420, Validation Acc:  0.7045,
        Duration: 0:00:47.586996

        Epoch 4/15,
        Train Loss: 0.9159, Train Acc:  0.6283,
        Validation Loss: 0.8424, Validation Acc:  0.7055,
        Duration: 0:00:48.723604

        Epoch 5/15,
        Train Loss: 0.9157, Train Acc:  0.6291,
        Validation Loss: 0.8423, Validation Acc:  0.7059,
        Duration: 0:00:49.111469

        Epoch 6/15,
      

10it [1:03:46, 566.15s/it]


        Epoch 15/15,
        Train Loss: 0.9142, Train Acc:  0.6316,
        Validation Loss: 0.8424, Validation Acc:  0.7056,
        Duration: 0:00:53.298213
model saved

        Epoch 1/15,
        Train Loss: 1.0537, Train Acc:  0.4585,
        Validation Loss: 0.9867, Validation Acc:  0.5669,
        Duration: 0:01:46.994972

        Epoch 2/15,
        Train Loss: 1.0461, Train Acc:  0.4630,
        Validation Loss: 0.9874, Validation Acc:  0.5664,
        Duration: 0:00:52.237491

        Epoch 3/15,
        Train Loss: 1.0412, Train Acc:  0.4667,
        Validation Loss: 0.9889, Validation Acc:  0.5599,
        Duration: 0:00:56.631802

        Epoch 4/15,
        Train Loss: 1.0362, Train Acc:  0.4700,
        Validation Loss: 0.9838, Validation Acc:  0.5608,
        Duration: 0:00:57.358888

        Epoch 5/15,
        Train Loss: 1.0323, Train Acc:  0.4742,
        Validation Loss: 0.9800, Validation Acc:  0.5593,
        Duration: 0:00:52.755737

        Epoch 6/15,
      

11it [1:19:37, 684.12s/it]


        Epoch 15/15,
        Train Loss: 0.9639, Train Acc:  0.5648,
        Validation Loss: 0.9142, Validation Acc:  0.6345,
        Duration: 0:00:51.001181
model saved

        Epoch 1/15,
        Train Loss: 1.0567, Train Acc:  0.4139,
        Validation Loss: 1.0786, Validation Acc:  0.3973,
        Duration: 0:01:46.981866

        Epoch 2/15,
        Train Loss: 1.0392, Train Acc:  0.4463,
        Validation Loss: 1.0765, Validation Acc:  0.4099,
        Duration: 0:00:54.018915

        Epoch 3/15,
        Train Loss: 1.0160, Train Acc:  0.4872,
        Validation Loss: 1.0568, Validation Acc:  0.4508,
        Duration: 0:00:54.231766

        Epoch 4/15,
        Train Loss: 0.9898, Train Acc:  0.5224,
        Validation Loss: 1.0383, Validation Acc:  0.4705,
        Duration: 0:00:52.213375

        Epoch 5/15,
        Train Loss: 0.9723, Train Acc:  0.5472,
        Validation Loss: 1.0287, Validation Acc:  0.4796,
        Duration: 0:00:52.586241

        Epoch 6/15,
      

12it [1:36:19, 481.65s/it]


        Epoch 15/15,
        Train Loss: 0.9012, Train Acc:  0.6331,
        Validation Loss: 0.9963, Validation Acc:  0.5290,
        Duration: 0:00:54.809479
model saved





# Train MLP

In [5]:
# NOTE(review): the MLP training loop below is entirely commented out, so
# running this cell trains nothing. It has the same shape as the other training
# cells, with MLP() as the model and an 'MLP_' prefix on the checkpoint name.
# TODO confirm whether it is disabled on purpose; if so, consider deleting it.
# for cf, method, k in tqdm(product([1, 3, 5, 8], ['Zscore'], [0, 2, 4])):
#     model_name = f'MLP_{method}_CF{cf}_pred_{k}'
#     if os.path.exists(os.path.join('.', 'trained_models', f'{model_name}.pth')):
#         continue
#     train_data = FIDataset(DATA_DIR, method, cf, k=k, train=True)
#     test_data = FIDataset(DATA_DIR, method, cf, k=k, train=False)
#     train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
#     test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)
#
#     lob_model = MLP()
#     lob_model.to(lob_model.device)
#
#     criterion = nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(lob_model.parameters(), lr=LR)
#
#     batch_train(model_name, lob_model, criterion, optimizer, train_loader, test_loader, EPOCHS)