In [1]:
import pickle

from torch.utils.data import DataLoader

X = pickle.load(open('data/X.dat', 'rb'))
Y = pickle.load(open('data/Y.dat', 'rb'))

(X.shape, Y.shape)

((32, 40, 6, 3, 28, 28), (32, 40, 2))

In [2]:
SEQ_LEN, IMG_CHANNELS, RESOLUTION = X.shape[2:5]

In [3]:
from src.training import cv_subject, train_subject, finetune_evaluate
import numpy as np

In [4]:
from src.model import CNN_LSTM

results = []
accuracies = []
models = []
for subject in range(32):
    model = CNN_LSTM(2)
    cv_loss, cv_accuracies = cv_subject(subject, X, Y, model)

    mean_mse, mean_accuracies = np.mean(cv_loss, axis=0), np.mean(cv_accuracies, axis=0)

    model, _ = train_subject(subject, X, Y,  CNN_LSTM(2))
    models.append(model)

    print(f'SUBJECT: {subject + 1}, CV MSE {mean_mse}, CV ACC {mean_accuracies}')

    accuracies.append(mean_accuracies)
    results.append(mean_mse)

SUBJECT: 1, CV MSE [0.40406528 0.30275238], CV ACC [0.35 0.6 ]
SUBJECT: 2, CV MSE [0.54679906 0.5829817 ], CV ACC [0.55 0.6 ]
SUBJECT: 3, CV MSE [0.13510291 0.15464857], CV ACC [0.55 0.8 ]
SUBJECT: 4, CV MSE [0.34724063 0.23586233], CV ACC [0.6 0.6]
SUBJECT: 5, CV MSE [0.3420049  0.24230602], CV ACC [0.6   0.475]
SUBJECT: 6, CV MSE [0.1423387  0.13334192], CV ACC [0.75  0.575]
SUBJECT: 7, CV MSE [0.24642369 0.23320922], CV ACC [0.7   0.625]
SUBJECT: 8, CV MSE [0.27139825 0.14004673], CV ACC [0.55  0.575]
SUBJECT: 9, CV MSE [0.13613693 0.07439212], CV ACC [0.375 0.6  ]
SUBJECT: 10, CV MSE [0.28270546 0.14337233], CV ACC [0.475 0.4  ]
SUBJECT: 11, CV MSE [0.2916038 0.366479 ], CV ACC [0.55  0.625]
SUBJECT: 12, CV MSE [0.32480314 0.18965743], CV ACC [0.425 0.825]
SUBJECT: 13, CV MSE [0.35641676 0.25988418], CV ACC [0.55 0.85]
SUBJECT: 14, CV MSE [0.36594063 0.18812475], CV ACC [0.225 0.675]
SUBJECT: 15, CV MSE [0.3627438  0.11817557], CV ACC [0.5   0.475]
SUBJECT: 16, CV MSE [0.21790862 0

In [5]:
print(np.mean(results, axis=0))
print(np.std(results, axis=0))

[0.28494513 0.23748952]
[0.10818671 0.12326779]


In [6]:
print(np.mean(accuracies, axis=0))
print(np.std(accuracies, axis=0))

[0.54921875 0.6       ]
[0.10977327 0.16286018]


In [10]:
from src.training import MyDataset, train_epoch, evaluate, predict

from torch import nn

from torch import optim


def finetune_evaluate(idx, X, y, model, trial_num):
    tx, ty = X[:, :trial_num], y[:, :trial_num]
    vx, vy = X[:, 20:], y[:, 20:]

    tx, ty = tx[idx].reshape((-1, SEQ_LEN, IMG_CHANNELS, RESOLUTION, RESOLUTION)), ty[idx].reshape((-1, y.shape[-1]))
    vx, vy = vx[idx].reshape((-1, SEQ_LEN, IMG_CHANNELS, RESOLUTION, RESOLUTION)), vy[idx].reshape((-1, y.shape[-1]))

    val_loss = []
    val_accuracies = []
    model.train()

    iter_no_change = 0
    best_model_state = model.state_dict()
    best_loss = 0
    best_metrics = []

    train_dl = DataLoader(MyDataset(tx, ty), batch_size=4, shuffle=True)
    val_dl = DataLoader(MyDataset(vx, vy), batch_size=4, shuffle=False)
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=0.01)
    model.to('cuda')

    for epoch in range(10):
        model.train()
        epoch_train_loss = train_epoch(model, train_dl, criterion, optimizer)

        epoch_val_loss, val_acc, ars_acc = evaluate(*predict(model, val_dl))

        if np.mean(epoch_val_loss) < best_loss or epoch == 0:
            best_loss = np.mean(epoch_val_loss)
            best_metrics = epoch_val_loss, val_acc, ars_acc
            iter_no_change = 0
            best_model_state = model.state_dict()
        else:
            iter_no_change += 1

        if iter_no_change > 5:
            model.load_state_dict(best_model_state)
            break

    return best_metrics


In [11]:
from src.model import EnsembleModel
from tqdm import tqdm

trial_data = {}
for trial_num in tqdm(range(1, 21)):
    loss_arr = []
    metrics_arr = []
    for subject in range(24, 32):
        ensemble = EnsembleModel(models[:24])
        metrics = finetune_evaluate(subject, X, Y, ensemble, trial_num)
        loss = np.mean(metrics[0])
        metrics_arr.append(metrics)

        if subject == 24:
            print(f'SUBJECT: {subject + 1}, LOSS {loss}')

    trial_data[trial_num] = [loss_arr, metrics_arr]

  0%|          | 0/20 [00:00<?, ?it/s]

SUBJECT: 25, LOSS 0.5256332159042358


  5%|▌         | 1/20 [00:22<07:09, 22.59s/it]

SUBJECT: 25, LOSS 0.43553030490875244


 10%|█         | 2/20 [00:50<07:42, 25.71s/it]

SUBJECT: 25, LOSS 0.3842684030532837


 15%|█▌        | 3/20 [01:14<07:04, 24.98s/it]

SUBJECT: 25, LOSS 0.3124127984046936


 20%|██        | 4/20 [01:38<06:32, 24.56s/it]

SUBJECT: 25, LOSS 0.5073813199996948


 25%|██▌       | 5/20 [02:06<06:27, 25.83s/it]

SUBJECT: 25, LOSS 0.421988844871521


 30%|███       | 6/20 [02:34<06:10, 26.48s/it]

SUBJECT: 25, LOSS 0.4894446134567261


 35%|███▌      | 7/20 [03:04<05:59, 27.62s/it]

SUBJECT: 25, LOSS 0.46718376874923706


 40%|████      | 8/20 [03:29<05:21, 26.81s/it]

SUBJECT: 25, LOSS 0.5210740566253662


 45%|████▌     | 9/20 [03:55<04:53, 26.70s/it]

SUBJECT: 25, LOSS 0.4714527428150177


 50%|█████     | 10/20 [04:33<05:00, 30.04s/it]

SUBJECT: 25, LOSS 0.5172622799873352


 55%|█████▌    | 11/20 [05:11<04:52, 32.50s/it]

SUBJECT: 25, LOSS 0.308277428150177


 60%|██████    | 12/20 [05:40<04:11, 31.40s/it]

SUBJECT: 25, LOSS 0.31748291850090027


 65%|██████▌   | 13/20 [06:16<03:48, 32.71s/it]

SUBJECT: 25, LOSS 0.3005702495574951


 70%|███████   | 14/20 [06:48<03:15, 32.50s/it]

SUBJECT: 25, LOSS 0.39683881402015686


 75%|███████▌  | 15/20 [07:18<02:39, 31.92s/it]

SUBJECT: 25, LOSS 0.4392407536506653


 80%|████████  | 16/20 [07:41<01:56, 29.21s/it]

SUBJECT: 25, LOSS 0.2810152769088745


 85%|████████▌ | 17/20 [08:07<01:24, 28.33s/it]

SUBJECT: 25, LOSS 0.3163915276527405


 90%|█████████ | 18/20 [08:34<00:55, 27.71s/it]

SUBJECT: 25, LOSS 0.3618026375770569


 95%|█████████▌| 19/20 [09:00<00:27, 27.18s/it]

SUBJECT: 25, LOSS 0.4254027009010315


100%|██████████| 20/20 [09:23<00:00, 28.19s/it]


In [12]:
X = pickle.load(open('eeg-experiment/X.dat', 'rb'))
Y = pickle.load(open('eeg-experiment/Y.dat', 'rb'))

(X.shape, Y.shape)

((1, 32, 6, 3, 28, 28), (1, 32, 2))

In [13]:
tx, vx = X[:, :16], X[:, 16:]
ty, vy = Y[:, :16], Y[:, 16:]

In [14]:
subject = 0

model, loss = train_subject(subject, tx, ty,  CNN_LSTM(Y.shape[-1]))
loss

0.20119953155517578