In [1]:
import torch

In [2]:
import scipy.io
import numpy as np

## DE Features (one subject)

https://github.com/ynulonger/DE_CNN

https://www.researchgate.net/publication/328504085_Continuous_Convolutional_Neural_Network_with_3D_Input_for_EEG-Based_Emotion_Recognition

In [3]:
deap_de_path = '../../methods/DE_CNN/1D_dataset/'

In [4]:
s0 = scipy.io.loadmat(deap_de_path + 'DE_s01.mat')
for i, key in enumerate(s0):
    print(key)

__header__
__version__
__globals__
base_data
data
valence_labels
arousal_labels


In [5]:
X_0 = s0['data']
y_0_valence = s0['valence_labels']
y_0_arousal = s0['arousal_labels']

In [6]:
X_0.shape

(2400, 4, 32)

In [7]:
y_0_valence.shape

(1, 2400)

In [8]:
np.transpose(y_0_valence).shape

(2400, 1)

### Merge all subjects' features

Subjects number and indexing

In [9]:
c = 2400

for idx in range(32):
    print(idx+1, idx*c, (idx+1)*c-1)

1 0 2399
2 2400 4799
3 4800 7199
4 7200 9599
5 9600 11999
6 12000 14399
7 14400 16799
8 16800 19199
9 19200 21599
10 21600 23999
11 24000 26399
12 26400 28799
13 28800 31199
14 31200 33599
15 33600 35999
16 36000 38399
17 38400 40799
18 40800 43199
19 43200 45599
20 45600 47999
21 48000 50399
22 50400 52799
23 52800 55199
24 55200 57599
25 57600 59999
26 60000 62399
27 62400 64799
28 64800 67199
29 67200 69599
30 69600 71999
31 72000 74399
32 74400 76799


In [10]:
deap_de_path

'../../methods/DE_CNN/1D_dataset/'

In [11]:
merge_de_cnn_features = False

In [12]:
if merge_de_cnn_features:
    de_cnn_features = np.empty((2400 * 32, 4, 32))
    de_cnn_y_valence = np.empty((2400 * 32, 1))
    de_cnn_y_arousal = np.empty((2400 * 32, 1))
    
    for i in range(1, 33):  # Subjects 1-32 in DEAP
        subj_data = scipy.io.loadmat(deap_de_path + f'DE_s{i:02}.mat')

        Xi_de = subj_data['data']
        yi_valence = np.transpose(subj_data['valence_labels'])
        yi_arousal = np.transpose(subj_data['arousal_labels'])
        
        idx = i-1  # indexing 0-31 for arrays
        c = 2400  # size of each subject's trials*1s_windows

        # efficient assigning, not really needed, could use np.append
        de_cnn_features[idx*c:(idx+1)*c] = Xi_de
        de_cnn_y_valence[idx*c:(idx+1)*c] = yi_valence
        de_cnn_y_arousal[idx*c:(idx+1)*c] = yi_arousal
        
        save_dict = {'data': de_cnn_features, 
                     'valence_labels': de_cnn_y_valence, 
                     'arousal_labels': de_cnn_y_arousal}
         
    np.save(deap_de_path + 'DE_merged.npy', save_dict)  

In [13]:
if not merge_de_cnn_features:
    de_cnn_merged = np.load(deap_de_path + 'DE_merged.npy', allow_pickle=True).item()
    de_cnn_features = de_cnn_merged['data']
    de_cnn_y_valence = de_cnn_merged['valence_labels']
    de_cnn_y_arousal = de_cnn_merged['arousal_labels']
    
    print('Loaded from file.')
    print(de_cnn_features.shape)
    print(de_cnn_y_valence.shape)
    print(de_cnn_y_arousal.shape)

Loaded from file.
(76800, 4, 32)
(76800, 1)
(76800, 1)


## Load DE Features (all subjects)

https://github.com/gzoumpourlis/DEAP_MNE_preprocessing

In [14]:
de_features_path = '../../preprocessing/DEAP_MNE_preprocessing/features_new/de_feats_merged.npy'

In [15]:
de_features = np.load(de_features_path)

In [16]:
de_features.shape

(1280, 32, 5, 232)

In [17]:
deap_path = '../../datasets/DEAP/merged/'

In [18]:
y = np.load(deap_path + 'deap_full_labels.npy')
y.shape

(1280, 3)

Column 0 is Valence, 1 is Arousal, 2 is quadrants notation (HAHV, HALV, LAHV, LALV)

In [19]:
valence = 0
arousal = 1
quadrants = 2

In [20]:
y = y[:, valence]

In [21]:
y.shape

(1280,)

## Define DEAP Dataset

In [22]:
from torch.utils.data import Dataset

In [23]:
class DEAPDataset(Dataset):
    def __init__(self, data, labels):
        self.X = data
        self.y = labels
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

Either use 'de_features' (DEAP_MNE_preprocessing) or 'de_cnn_features' (DE_CNN)

Labels are 'y' or 'de_cnn_y_valence/de_cnn_y_arousal' respectively

In [24]:
use_de_cnn = True

In [25]:
if use_de_cnn:
    deap_dataset = DEAPDataset(de_cnn_features, de_cnn_y_valence)
else:
    deap_dataset = DEAPDataset(de_features, y)

## DataLoader

In [26]:
from torch.utils.data import DataLoader

In [27]:
batch_size=32

In [28]:
train_dataloader = DataLoader(deap_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [29]:
len(train_dataloader) * batch_size

76800

## Define BiHDM Model

In [31]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
device

device(type='cpu')

In [32]:
from Models_DEAP import BiHDM

### BiHDM Initialization Parameters

In [33]:
hidden_size=32
num_layers=2
input_size=4
n_classes=1

# batch_first=False
# bidirectional=False

fc_input=448
fc_hidden=96

In [34]:
model = BiHDM(hidden_size=hidden_size, num_layers=num_layers, input_size=input_size, 
              fc_input=fc_input, fc_hidden=fc_hidden, n_classes=n_classes)

In [35]:
model.to(device).float()

BiHDM(
  (RNN_VL): RNN(4, 32, num_layers=2)
  (RNN_VR): RNN(4, 32, num_layers=2)
  (RNN_V): RNN(32, 32, num_layers=2)
  (RNN_HL): RNN(4, 32, num_layers=2)
  (RNN_HR): RNN(4, 32, num_layers=2)
  (RNN_H): RNN(32, 32, num_layers=2)
  (fc_v): Sequential(
    (0): Linear(in_features=448, out_features=96, bias=True)
    (1): ReLU()
  )
  (fc_h): Sequential(
    (0): Linear(in_features=448, out_features=96, bias=True)
    (1): ReLU()
  )
  (fc_c): Sequential(
    (0): Linear(in_features=96, out_features=1, bias=True)
  )
)

In [36]:
input_data = de_features[:64] # small batch deap_mne_preprocessing
input_data.shape

(64, 32, 5, 232)

In [37]:
input_data = de_cnn_features[:64] # small batch de_cnn
input_data.shape

(64, 4, 32)

## Training BiHDM

In [30]:
criterion = torch.nn.BCEWithLogitsLoss()

In [41]:
lr=0.001
betas=(0.9, 0.999)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

In [42]:
num_epochs = 1

In [47]:
for epoch in range(num_epochs):
    for i, (data, labels) in enumerate(train_dataloader):
        model.train()

        data = data.to(device).float()
        labels = labels.to(device).float()

        outputs = model(data.permute(0, 2, 1))    
        loss = criterion(outputs, labels)

        print(f'batch: {i}, loss: {loss.item()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step() 

batch: 0, loss: 0.5709908604621887
batch: 1, loss: 0.5710475444793701
batch: 2, loss: 0.5710583329200745
batch: 3, loss: 0.5710269212722778
batch: 4, loss: 0.5709583759307861
batch: 5, loss: 0.6689745187759399
batch: 6, loss: 0.8326308727264404
batch: 7, loss: 0.7017001509666443
batch: 8, loss: 0.5706980228424072
batch: 9, loss: 0.5706344246864319
batch: 10, loss: 0.570536732673645
batch: 11, loss: 0.5704078078269958
batch: 12, loss: 0.570252001285553
batch: 13, loss: 0.570070743560791
batch: 14, loss: 0.5698675513267517
batch: 15, loss: 0.83408123254776
batch: 16, loss: 0.8342747688293457
batch: 17, loss: 0.8343794941902161
batch: 18, loss: 0.834405779838562
batch: 19, loss: 0.8343607187271118
batch: 20, loss: 0.8342517614364624
batch: 21, loss: 0.8340850472450256
batch: 22, loss: 0.8338665962219238
batch: 23, loss: 0.8336013555526733
batch: 24, loss: 0.8332943320274353
batch: 25, loss: 0.8329495191574097
batch: 26, loss: 0.8325717449188232
batch: 27, loss: 0.8321627378463745
batch: 2

batch: 230, loss: 0.5788710117340088
batch: 231, loss: 0.5789196491241455
batch: 232, loss: 0.7005223035812378
batch: 233, loss: 0.8221120834350586
batch: 234, loss: 0.6701322793960571
batch: 235, loss: 0.5790378451347351
batch: 236, loss: 0.7612394690513611
batch: 237, loss: 0.8219336867332458
batch: 238, loss: 0.6094856858253479
batch: 239, loss: 0.5791931748390198
batch: 240, loss: 0.5791904330253601
batch: 241, loss: 0.6094828248023987
batch: 242, loss: 0.8219242691993713
batch: 243, loss: 0.7612193822860718
batch: 244, loss: 0.5790886878967285
batch: 245, loss: 0.5790659189224243
batch: 246, loss: 0.5790042877197266
batch: 247, loss: 0.7005243897438049
batch: 248, loss: 0.8222451210021973
batch: 249, loss: 0.6701051592826843
batch: 250, loss: 0.5787806510925293
batch: 251, loss: 0.5787171721458435
batch: 252, loss: 0.5786188244819641
batch: 253, loss: 0.792151927947998
batch: 254, loss: 0.8227721452713013
batch: 255, loss: 0.5783973336219788
batch: 256, loss: 0.6089072227478027
ba

batch: 457, loss: 0.5733487606048584
batch: 458, loss: 0.5731762647628784
batch: 459, loss: 0.5729803442955017
batch: 460, loss: 0.5727632641792297
batch: 461, loss: 0.5725269317626953
batch: 462, loss: 0.5722736716270447
batch: 463, loss: 0.5720051527023315
batch: 464, loss: 0.5717227458953857
batch: 465, loss: 0.5714279413223267
batch: 466, loss: 0.5711219310760498
batch: 467, loss: 0.5708063840866089
batch: 468, loss: 0.570481538772583
batch: 469, loss: 0.5701488852500916
batch: 470, loss: 0.5698089003562927
batch: 471, loss: 0.5694628357887268
batch: 472, loss: 0.5691109299659729
batch: 473, loss: 0.5687540173530579
batch: 474, loss: 0.5683923363685608
batch: 475, loss: 0.568026602268219
batch: 476, loss: 0.5676573514938354
batch: 477, loss: 0.5672848224639893
batch: 478, loss: 0.80381178855896
batch: 479, loss: 0.8380440473556519
batch: 480, loss: 0.5663988590240479
batch: 481, loss: 0.5661661028862
batch: 482, loss: 0.5659165382385254
batch: 483, loss: 0.5656516551971436
batch: 4

batch: 686, loss: 0.5701514482498169
batch: 687, loss: 0.5699998140335083
batch: 688, loss: 0.5698595643043518
batch: 689, loss: 0.5696487426757812
batch: 690, loss: 0.5694072246551514
batch: 691, loss: 0.5691702961921692
batch: 692, loss: 0.5689184069633484
batch: 693, loss: 0.5686418414115906
batch: 694, loss: 0.5683610439300537
batch: 695, loss: 0.5680619478225708
batch: 696, loss: 0.5677560567855835
batch: 697, loss: 0.5674417614936829
batch: 698, loss: 0.5671175718307495
batch: 699, loss: 0.56678307056427
batch: 700, loss: 0.566443920135498
batch: 701, loss: 0.5660985112190247
batch: 702, loss: 0.5657480359077454
batch: 703, loss: 0.8053515553474426
batch: 704, loss: 0.8400129675865173
batch: 705, loss: 0.8402709364891052
batch: 706, loss: 0.8059737682342529
batch: 707, loss: 0.5647161602973938
batch: 708, loss: 0.5646128058433533
batch: 709, loss: 0.5644816160202026
batch: 710, loss: 0.5643175840377808
batch: 711, loss: 0.5641318559646606
batch: 712, loss: 0.7027536630630493
batc

batch: 912, loss: 0.8266247510910034
batch: 913, loss: 0.6069028973579407
batch: 914, loss: 0.5756282210350037
batch: 915, loss: 0.5756682753562927
batch: 916, loss: 0.6069921255111694
batch: 917, loss: 0.8263373374938965
batch: 918, loss: 0.7636444568634033
batch: 919, loss: 0.5757011771202087
batch: 920, loss: 0.5757045149803162
batch: 921, loss: 0.5756667256355286
batch: 922, loss: 0.5755918025970459
batch: 923, loss: 0.5754833817481995
batch: 924, loss: 0.575344979763031
batch: 925, loss: 0.5751795172691345
batch: 926, loss: 0.7641175985336304
batch: 927, loss: 0.82734215259552
batch: 928, loss: 0.827438473701477
batch: 929, loss: 0.8274574875831604
batch: 930, loss: 0.574798047542572
batch: 931, loss: 0.5747924447059631
batch: 932, loss: 0.5747466683387756
batch: 933, loss: 0.5746647715568542
batch: 934, loss: 0.574550211429596
batch: 935, loss: 0.5744062066078186
batch: 936, loss: 0.5742359757423401
batch: 937, loss: 0.7012115120887756
batch: 938, loss: 0.8285988569259644
batch: 

batch: 1134, loss: 0.725239634513855
batch: 1135, loss: 0.8046988844871521
batch: 1136, loss: 0.6458005905151367
batch: 1137, loss: 0.5928928852081299
batch: 1138, loss: 0.7781336903572083
batch: 1139, loss: 0.8045600652694702
batch: 1140, loss: 0.8044627904891968
batch: 1141, loss: 0.7779102921485901
batch: 1142, loss: 0.593267023563385
batch: 1143, loss: 0.5933616161346436
batch: 1144, loss: 0.5934051275253296
batch: 1145, loss: 0.5934000611305237
batch: 1146, loss: 0.5933539271354675
batch: 1147, loss: 0.6986939907073975
batch: 1148, loss: 0.8042054176330566
batch: 1149, loss: 0.8042191863059998
batch: 1150, loss: 0.8041667342185974
batch: 1151, loss: 0.804054856300354
batch: 1152, loss: 0.8038896918296814
batch: 1153, loss: 0.803676187992096
batch: 1154, loss: 0.8034195899963379
batch: 1155, loss: 0.8031240105628967
batch: 1156, loss: 0.776737630367279
batch: 1157, loss: 0.5946294665336609
batch: 1158, loss: 0.646674394607544
batch: 1159, loss: 0.8019757866859436
batch: 1160, loss:

batch: 1353, loss: 0.6010656356811523
batch: 1354, loss: 0.6011285781860352
batch: 1355, loss: 0.6011419892311096
batch: 1356, loss: 0.6011108756065369
batch: 1357, loss: 0.601039469242096
batch: 1358, loss: 0.6009321808815002
batch: 1359, loss: 0.6007924675941467
batch: 1360, loss: 0.600623369216919
batch: 1361, loss: 0.6004281640052795
batch: 1362, loss: 0.6002092957496643
batch: 1363, loss: 0.5999693274497986
batch: 1364, loss: 0.5997101664543152
batch: 1365, loss: 0.599433958530426
batch: 1366, loss: 0.5991424918174744
batch: 1367, loss: 0.5988370776176453
batch: 1368, loss: 0.5985192656517029
batch: 1369, loss: 0.5981904864311218
batch: 1370, loss: 0.5978516936302185
batch: 1371, loss: 0.5975040793418884
batch: 1372, loss: 0.5971483588218689
batch: 1373, loss: 0.5967855453491211
batch: 1374, loss: 0.5964164137840271
batch: 1375, loss: 0.5960415005683899
batch: 1376, loss: 0.5956615805625916
batch: 1377, loss: 0.5952770709991455
batch: 1378, loss: 0.7762190699577332
batch: 1379, lo

batch: 1571, loss: 0.6492304801940918
batch: 1572, loss: 0.6009083986282349
batch: 1573, loss: 0.7703667879104614
batch: 1574, loss: 0.7942020893096924
batch: 1575, loss: 0.60130375623703
batch: 1576, loss: 0.6014633178710938
batch: 1577, loss: 0.6015539169311523
batch: 1578, loss: 0.6015347838401794
batch: 1579, loss: 0.601452112197876
batch: 1580, loss: 0.6012139916419983
batch: 1581, loss: 0.6009461283683777
batch: 1582, loss: 0.6006637811660767
batch: 1583, loss: 0.6003378629684448
batch: 1584, loss: 0.7224608659744263
batch: 1585, loss: 0.7962337136268616
batch: 1586, loss: 0.6487724184989929
batch: 1587, loss: 0.599428117275238
batch: 1588, loss: 0.5992893576622009
batch: 1589, loss: 0.5991228818893433
batch: 1590, loss: 0.5989335775375366
batch: 1591, loss: 0.6235576868057251
batch: 1592, loss: 0.797702431678772
batch: 1593, loss: 0.7479991316795349
batch: 1594, loss: 0.5982444882392883
batch: 1595, loss: 0.6731336116790771
batch: 1596, loss: 0.7983359694480896
batch: 1597, loss

batch: 1792, loss: 0.7941337823867798
batch: 1793, loss: 0.7936694622039795
batch: 1794, loss: 0.7931884527206421
batch: 1795, loss: 0.7926924228668213
batch: 1796, loss: 0.7921832799911499
batch: 1797, loss: 0.7916620969772339
batch: 1798, loss: 0.7911301255226135
batch: 1799, loss: 0.790588915348053
batch: 1800, loss: 0.6048195958137512
batch: 1801, loss: 0.6282384395599365
batch: 1802, loss: 0.7892364859580994
batch: 1803, loss: 0.7430894374847412
batch: 1804, loss: 0.6061274409294128
batch: 1805, loss: 0.6745474934577942
batch: 1806, loss: 0.7879223823547363
batch: 1807, loss: 0.697227954864502
batch: 1808, loss: 0.6070470213890076
batch: 1809, loss: 0.7196871042251587
batch: 1810, loss: 0.7869886159896851
batch: 1811, loss: 0.6523555517196655
batch: 1812, loss: 0.607710599899292
batch: 1813, loss: 0.6078064441680908
batch: 1814, loss: 0.6078488826751709
batch: 1815, loss: 0.6078432202339172
batch: 1816, loss: 0.6301289200782776
batch: 1817, loss: 0.7865614295005798
batch: 1818, lo

batch: 2012, loss: 0.8058739900588989
batch: 2013, loss: 0.8061670660972595
batch: 2014, loss: 0.8063659071922302
batch: 2015, loss: 0.7258104681968689
batch: 2016, loss: 0.5912953019142151
batch: 2017, loss: 0.6989406943321228
batch: 2018, loss: 0.8067946434020996
batch: 2019, loss: 0.8068260550498962
batch: 2020, loss: 0.8067893385887146
batch: 2021, loss: 0.8066912293434143
batch: 2022, loss: 0.8065381050109863
batch: 2023, loss: 0.6183350682258606
batch: 2024, loss: 0.5915946364402771
batch: 2025, loss: 0.5916574001312256
batch: 2026, loss: 0.5916717052459717
batch: 2027, loss: 0.5916423797607422
batch: 2028, loss: 0.5915736556053162
batch: 2029, loss: 0.5914695858955383
batch: 2030, loss: 0.5913336277008057
batch: 2031, loss: 0.5911691188812256
batch: 2032, loss: 0.5909788012504578
batch: 2033, loss: 0.5907654166221619
batch: 2034, loss: 0.5905311703681946
batch: 2035, loss: 0.5902782082557678
batch: 2036, loss: 0.590008556842804
batch: 2037, loss: 0.5897237658500671
batch: 2038, 

batch: 2231, loss: 0.8259029984474182
batch: 2232, loss: 0.8259314298629761
batch: 2233, loss: 0.6072171926498413
batch: 2234, loss: 0.5759783983230591
batch: 2235, loss: 0.8259414434432983
batch: 2236, loss: 0.8259211778640747
batch: 2237, loss: 0.8258357644081116
batch: 2238, loss: 0.7633017301559448
batch: 2239, loss: 0.5762625932693481
batch: 2240, loss: 0.6697469353675842
batch: 2241, loss: 0.8253459930419922
batch: 2242, loss: 0.8252061605453491
batch: 2243, loss: 0.8250131011009216
batch: 2244, loss: 0.6698203682899475
batch: 2245, loss: 0.577012300491333
batch: 2246, loss: 0.7626003623008728
batch: 2247, loss: 0.8242682814598083
batch: 2248, loss: 0.6082386374473572
batch: 2249, loss: 0.5775261521339417
batch: 2250, loss: 0.577593207359314
batch: 2251, loss: 0.5776111483573914
batch: 2252, loss: 0.5775871872901917
batch: 2253, loss: 0.5775246620178223
batch: 2254, loss: 0.5774256587028503
batch: 2255, loss: 0.5772965550422668
batch: 2256, loss: 0.5771396160125732
batch: 2257, l