In [1]:
%load_ext jupyternotify

<IPython.core.display.Javascript object>

In [2]:
from torch.utils.tensorboard import SummaryWriter

In [3]:
writer = SummaryWriter('C:/Users/Alex/git/EEG-emotion/methods/Saliency_Emotion_EEG/TensorBoard/deap_bihdm')

In [4]:
import torch

In [5]:
import scipy.io
import numpy as np

## DE Features (one subject)

https://github.com/ynulonger/DE_CNN

https://www.researchgate.net/publication/328504085_Continuous_Convolutional_Neural_Network_with_3D_Input_for_EEG-Based_Emotion_Recognition

In [6]:
deap_de_path = '../../methods/DE_CNN/1D_dataset/'

In [7]:
s0 = scipy.io.loadmat(deap_de_path + 'DE_s01.mat')
for i, key in enumerate(s0):
    print(key)

__header__
__version__
__globals__
base_data
data
valence_labels
arousal_labels


In [8]:
X_0 = s0['data']
y_0_valence = s0['valence_labels']
y_0_arousal = s0['arousal_labels']

In [9]:
X_0.shape

(2400, 4, 32)

In [10]:
y_0_valence.shape

(1, 2400)

In [11]:
np.transpose(y_0_valence).shape

(2400, 1)

### Merge all subjects' features

Subjects number and indexing

In [12]:
c = 2400

for idx in range(32):
    print(idx+1, idx*c, (idx+1)*c-1)

1 0 2399
2 2400 4799
3 4800 7199
4 7200 9599
5 9600 11999
6 12000 14399
7 14400 16799
8 16800 19199
9 19200 21599
10 21600 23999
11 24000 26399
12 26400 28799
13 28800 31199
14 31200 33599
15 33600 35999
16 36000 38399
17 38400 40799
18 40800 43199
19 43200 45599
20 45600 47999
21 48000 50399
22 50400 52799
23 52800 55199
24 55200 57599
25 57600 59999
26 60000 62399
27 62400 64799
28 64800 67199
29 67200 69599
30 69600 71999
31 72000 74399
32 74400 76799


In [13]:
deap_de_path

'../../methods/DE_CNN/1D_dataset/'

In [14]:
merge_de_cnn_features = False

In [15]:
if merge_de_cnn_features:
    de_cnn_features = np.empty((2400 * 32, 4, 32))
    de_cnn_y_valence = np.empty((2400 * 32, 1))
    de_cnn_y_arousal = np.empty((2400 * 32, 1))
    
    for i in range(1, 33):  # Subjects 1-32 in DEAP
        subj_data = scipy.io.loadmat(deap_de_path + f'DE_s{i:02}.mat')

        Xi_de = subj_data['data']
        yi_valence = np.transpose(subj_data['valence_labels'])
        yi_arousal = np.transpose(subj_data['arousal_labels'])
        
        idx = i-1  # indexing 0-31 for arrays
        c = 2400  # size of each subject's trials*1s_windows

        # efficient assigning, not really needed, could use np.append
        de_cnn_features[idx*c:(idx+1)*c] = Xi_de
        de_cnn_y_valence[idx*c:(idx+1)*c] = yi_valence
        de_cnn_y_arousal[idx*c:(idx+1)*c] = yi_arousal
        
        save_dict = {'data': de_cnn_features, 
                     'valence_labels': de_cnn_y_valence, 
                     'arousal_labels': de_cnn_y_arousal}
         
    np.save(deap_de_path + 'DE_merged.npy', save_dict)  

In [16]:
if not merge_de_cnn_features:
    de_cnn_merged = np.load(deap_de_path + 'DE_merged.npy', allow_pickle=True).item()
    de_cnn_features = de_cnn_merged['data']
    de_cnn_y_valence = de_cnn_merged['valence_labels']
    de_cnn_y_arousal = de_cnn_merged['arousal_labels']
    
    print('Loaded from file.')
    print(de_cnn_features.shape)
    print(de_cnn_y_valence.shape)
    print(de_cnn_y_arousal.shape)

Loaded from file.
(76800, 4, 32)
(76800, 1)
(76800, 1)


## Load DE Features (all subjects)

https://github.com/gzoumpourlis/DEAP_MNE_preprocessing

In [17]:
de_features_path = '../../preprocessing/DEAP_MNE_preprocessing/features_new/de_feats_merged.npy'

In [18]:
de_features = np.load(de_features_path)

In [19]:
de_features.shape

(1280, 32, 5, 232)

In [20]:
deap_path = '../../datasets/DEAP/merged/'

In [21]:
y = np.load(deap_path + 'deap_full_labels.npy')
y.shape

(1280, 3)

Column 0 is Valence, 1 is Arousal, 2 is quadrants notation (HAHV, HALV, LAHV, LALV)

In [22]:
valence = 0
arousal = 1
quadrants = 2

In [23]:
y = y[:, valence]

In [24]:
y.shape

(1280,)

## Define DEAP Dataset

In [25]:
from torch.utils.data import Dataset

In [26]:
class DEAPDataset(Dataset):
    def __init__(self, data, labels):
        self.X = data
        self.y = labels
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

Either use 'de_features' (DEAP_MNE_preprocessing) or 'de_cnn_features' (DE_CNN)

Labels are 'y' or 'de_cnn_y_valence/de_cnn_y_arousal' respectively

In [27]:
use_de_cnn = True

In [28]:
if use_de_cnn:
    deap_dataset = DEAPDataset(de_cnn_features, de_cnn_y_valence)
else:
    deap_dataset = DEAPDataset(de_features, y)

In [29]:
deap_dataset[0][0].shape

(4, 32)

In [30]:
input_data = de_features[:64] # small batch deap_mne_preprocessing
input_data.shape

(64, 32, 5, 232)

In [31]:
input_data = de_cnn_features[:64] # small batch de_cnn
input_data.shape

(64, 4, 32)

## DataLoader

In [32]:
from torch.utils.data import DataLoader

In [33]:
batch_size=32

In [34]:
train_split = 0.75
test_split = 0.25

train_n_elems = int(train_split * len(deap_dataset))
test_n_elems = int(test_split * len(deap_dataset))

print(train_split, test_split)
print(train_n_elems, test_n_elems)

0.75 0.25
57600 19200


In [35]:
from torch.utils.data import random_split

In [36]:
deap_train, deap_test = random_split(deap_dataset, [train_n_elems, test_n_elems])

In [37]:
# deap_train = deap_dataset[:train_n_elems]
# deap_test = deap_dataset[train_n_elems:]

In [38]:
len(deap_train)

57600

In [39]:
train_dataloader = DataLoader(deap_train, batch_size=batch_size, shuffle=True, num_workers=0)
test_dataloader = DataLoader(deap_test, batch_size=batch_size, shuffle=True, num_workers=0)

## Define BiHDM Model

In [40]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
device

device(type='cpu')

In [41]:
from Models_DEAP import BiHDM

### BiHDM Initialization Parameters

In [42]:
hidden_size = 32
num_layers = 2
input_size = 4
n_classes = 1

# batch_first=False
# bidirectional=False

fc_input=448
fc_hidden=96

In [43]:
model = BiHDM(hidden_size=hidden_size, num_layers=num_layers, input_size=input_size, 
              fc_input=fc_input, fc_hidden=fc_hidden, n_classes=n_classes)

In [44]:
model.to(device).float()

BiHDM(
  (RNN_VL): RNN(4, 32, num_layers=2)
  (RNN_VR): RNN(4, 32, num_layers=2)
  (RNN_V): RNN(32, 32, num_layers=2)
  (RNN_HL): RNN(4, 32, num_layers=2)
  (RNN_HR): RNN(4, 32, num_layers=2)
  (RNN_H): RNN(32, 32, num_layers=2)
  (fc_v): Sequential(
    (0): Linear(in_features=448, out_features=96, bias=True)
    (1): ReLU()
  )
  (fc_h): Sequential(
    (0): Linear(in_features=448, out_features=96, bias=True)
    (1): ReLU()
  )
  (fc_c): Sequential(
    (0): Linear(in_features=96, out_features=1, bias=True)
  )
)

In [45]:
example_data = next(iter(train_dataloader))
example_data[0].shape

torch.Size([32, 4, 32])

In [46]:
writer.add_graph(model, example_data[0].float().permute(0, 2, 1).to(device))
#writer.close()
#sys.exit()

In [47]:
#writer.add_scalar('test', 1, 3)

## Training BiHDM

In [48]:
criterion = torch.nn.BCEWithLogitsLoss()

In [49]:
lr=0.001
betas=(0.9, 0.999)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

In [50]:
num_epochs = 100

In [51]:
n_total_steps = len(train_dataloader)

In [52]:
model_path = './saved_models/BiHDM-Model.pth'

In [53]:
skip_training = False
resume_training = False

In [54]:
if not skip_training:
    if resume_training:
        #checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 
        min_loss = checkpoint['min_loss']
    else:
        min_loss = 1000
        
print(min_loss)    

1000


In [55]:
running_loss = 0.0
running_correct = 0
acc_n_steps = 300

In [61]:
%%notify -m "Training Completed!"

if not skip_training:
    for epoch in range(num_epochs):
        model.train()
        for i, (data, labels) in enumerate(train_dataloader):
            data = data.to(device).float().permute(0, 2, 1)
            labels = labels.to(device).float()

            outputs = model(data)    
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step() 
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            running_correct += (predicted == labels.reshape(-1,)).sum().item()

            if (i+1) % acc_n_steps == 0:
                writer.add_scalar('training loss', running_loss / acc_n_steps, epoch * n_total_steps + i)
                running_accuracy = running_correct / acc_n_steps / predicted.size(0) * 100
                writer.add_scalar('training accuracy', running_accuracy, epoch * n_total_steps + i)                        
                running_correct = 0
                running_loss = 0.0
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.6f}, Accuracy: {running_accuracy:.2f}')

            if epoch > 5 and loss.item() < min_loss:
                    min_loss = loss.item()
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'min_loss': min_loss
                    }, model_path)
                    print(f'Saved checkpoint - loss: {min_loss:.6f}')

    print('Finished Training')

Epoch [1/100], Step [300/1800], Loss: 0.645213, Accuracy: 44.90
Epoch [1/100], Step [600/1800], Loss: 0.684014, Accuracy: 45.14
Epoch [1/100], Step [900/1800], Loss: 0.656388, Accuracy: 43.94
Epoch [1/100], Step [1200/1800], Loss: 0.652716, Accuracy: 44.69
Epoch [1/100], Step [1500/1800], Loss: 0.560026, Accuracy: 44.72
Epoch [1/100], Step [1800/1800], Loss: 0.589553, Accuracy: 44.38
Epoch [2/100], Step [300/1800], Loss: 0.563429, Accuracy: 44.44
Epoch [2/100], Step [600/1800], Loss: 0.576751, Accuracy: 45.45
Epoch [2/100], Step [900/1800], Loss: 0.553657, Accuracy: 44.42
Epoch [2/100], Step [1200/1800], Loss: 0.568537, Accuracy: 44.59
Epoch [2/100], Step [1500/1800], Loss: 0.616477, Accuracy: 44.36
Epoch [2/100], Step [1800/1800], Loss: 0.674468, Accuracy: 44.49
Epoch [3/100], Step [300/1800], Loss: 0.447895, Accuracy: 44.89
Epoch [3/100], Step [600/1800], Loss: 0.597282, Accuracy: 43.66
Epoch [3/100], Step [900/1800], Loss: 0.655064, Accuracy: 44.83
Epoch [3/100], Step [1200/1800], L

Epoch [20/100], Step [1500/1800], Loss: 0.456962, Accuracy: 44.59
Epoch [20/100], Step [1800/1800], Loss: 0.570227, Accuracy: 44.74
Epoch [21/100], Step [300/1800], Loss: 0.517955, Accuracy: 45.30
Epoch [21/100], Step [600/1800], Loss: 0.408738, Accuracy: 44.36
Epoch [21/100], Step [900/1800], Loss: 0.482150, Accuracy: 44.86
Epoch [21/100], Step [1200/1800], Loss: 0.534109, Accuracy: 43.99
Epoch [21/100], Step [1500/1800], Loss: 0.475088, Accuracy: 44.74
Epoch [21/100], Step [1800/1800], Loss: 0.348393, Accuracy: 44.49
Epoch [22/100], Step [300/1800], Loss: 0.484027, Accuracy: 43.88
Epoch [22/100], Step [600/1800], Loss: 0.579137, Accuracy: 44.44
Epoch [22/100], Step [900/1800], Loss: 0.381793, Accuracy: 44.23
Epoch [22/100], Step [1200/1800], Loss: 0.426502, Accuracy: 44.85
Epoch [22/100], Step [1500/1800], Loss: 0.534128, Accuracy: 45.29
Epoch [22/100], Step [1800/1800], Loss: 0.498062, Accuracy: 45.06
Epoch [23/100], Step [300/1800], Loss: 0.468666, Accuracy: 43.90
Epoch [23/100], S

Epoch [41/100], Step [600/1800], Loss: 0.306022, Accuracy: 45.14
Epoch [41/100], Step [900/1800], Loss: 0.473352, Accuracy: 44.11
Epoch [41/100], Step [1200/1800], Loss: 0.326478, Accuracy: 44.95
Epoch [41/100], Step [1500/1800], Loss: 0.338830, Accuracy: 44.49
Epoch [41/100], Step [1800/1800], Loss: 0.312074, Accuracy: 44.62
Epoch [42/100], Step [300/1800], Loss: 0.245071, Accuracy: 44.97
Epoch [42/100], Step [600/1800], Loss: 0.262424, Accuracy: 44.74
Epoch [42/100], Step [900/1800], Loss: 0.364092, Accuracy: 44.11
Epoch [42/100], Step [1200/1800], Loss: 0.398532, Accuracy: 44.79
Epoch [42/100], Step [1500/1800], Loss: 0.345310, Accuracy: 45.01
Epoch [42/100], Step [1800/1800], Loss: 0.279924, Accuracy: 44.12
Epoch [43/100], Step [300/1800], Loss: 0.272869, Accuracy: 44.62
Epoch [43/100], Step [600/1800], Loss: 0.298556, Accuracy: 44.38
Epoch [43/100], Step [900/1800], Loss: 0.506817, Accuracy: 44.50
Epoch [43/100], Step [1200/1800], Loss: 0.347194, Accuracy: 44.68
Epoch [43/100], St

Epoch [61/100], Step [1800/1800], Loss: 0.366447, Accuracy: 45.05
Epoch [62/100], Step [300/1800], Loss: 0.241882, Accuracy: 44.27
Epoch [62/100], Step [600/1800], Loss: 0.350180, Accuracy: 45.09
Epoch [62/100], Step [900/1800], Loss: 0.347775, Accuracy: 44.64
Epoch [62/100], Step [1200/1800], Loss: 0.269960, Accuracy: 45.03
Epoch [62/100], Step [1500/1800], Loss: 0.316176, Accuracy: 44.36
Epoch [62/100], Step [1800/1800], Loss: 0.233170, Accuracy: 44.35
Epoch [63/100], Step [300/1800], Loss: 0.201728, Accuracy: 45.23
Epoch [63/100], Step [600/1800], Loss: 0.459712, Accuracy: 44.62
Epoch [63/100], Step [900/1800], Loss: 0.327582, Accuracy: 44.20
Epoch [63/100], Step [1200/1800], Loss: 0.457190, Accuracy: 44.33
Epoch [63/100], Step [1500/1800], Loss: 0.234290, Accuracy: 44.71
Epoch [63/100], Step [1800/1800], Loss: 0.382316, Accuracy: 44.66
Epoch [64/100], Step [300/1800], Loss: 0.214918, Accuracy: 44.70
Epoch [64/100], Step [600/1800], Loss: 0.255089, Accuracy: 45.29
Epoch [64/100], St

Epoch [82/100], Step [1500/1800], Loss: 0.281611, Accuracy: 44.70
Epoch [82/100], Step [1800/1800], Loss: 0.235606, Accuracy: 44.69
Epoch [83/100], Step [300/1800], Loss: 0.150297, Accuracy: 45.19
Epoch [83/100], Step [600/1800], Loss: 0.334760, Accuracy: 44.85
Epoch [83/100], Step [900/1800], Loss: 0.232791, Accuracy: 44.03
Epoch [83/100], Step [1200/1800], Loss: 0.254706, Accuracy: 43.86
Epoch [83/100], Step [1500/1800], Loss: 0.060684, Accuracy: 44.55
Epoch [83/100], Step [1800/1800], Loss: 0.399012, Accuracy: 45.26
Epoch [84/100], Step [300/1800], Loss: 0.131903, Accuracy: 44.58
Epoch [84/100], Step [600/1800], Loss: 0.126247, Accuracy: 43.86
Epoch [84/100], Step [900/1800], Loss: 0.200978, Accuracy: 45.34
Epoch [84/100], Step [1200/1800], Loss: 0.262835, Accuracy: 44.53
Epoch [84/100], Step [1500/1800], Loss: 0.472915, Accuracy: 44.42
Epoch [84/100], Step [1800/1800], Loss: 0.219541, Accuracy: 45.01
Epoch [85/100], Step [300/1800], Loss: 0.217073, Accuracy: 45.23
Epoch [85/100], S

<IPython.core.display.Javascript object>

In [62]:
writer.close()

## Test Accuracy

In [63]:
load_model = True

In [64]:
chosen_n_classes = 1

In [65]:
with torch.no_grad():
    model.eval()  
    if load_model:
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        
    n_correct = 0
    n_samples = 0

    for data, labels in test_dataloader:
        data = data.to(device).float().permute(0, 2, 1)
        labels = labels.to(device).float().reshape(-1,)

        outputs = model(data)

        # torch.max returns (value, index)
        _, predicted = torch.max(outputs, 1)

        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

acc = n_correct / n_samples * 100.0 
print(f'Accuracy of the network: {acc:.2f}%')

Accuracy of the network: 44.88%
