In [2]:
import os
import random
import torch
import numpy as np
import pickle as pkl
from analysis import *
import argparse
from sys import platform
%load_ext autoreload
%autoreload 2

In [3]:
torch.manual_seed(12)
torch.cuda.manual_seed(12)
np.random.seed(12)
torch.backends.cudnn.deterministics = True
torch.set_num_threads(1)

In [4]:
import sys
import os
import random
import math
import time
import torch; torch.utils.backcompat.broadcast_warning.enabled = True
from torchvision import transforms, datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
import numpy as np

In [5]:
import sys
import os
import random
import math
import time
import torch; torch.utils.backcompat.broadcast_warning.enabled = True
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.backends.cudnn as cudnn; cudnn.benchmark = True
from scipy.fftpack import fft, rfft, fftfreq, irfft, ifft, rfftfreq
from scipy import signal
import numpy as np
import importlib

In [6]:
class Model(nn.Module):

    def __init__(self, input_size=128, lstm_size=128, lstm_layers=1, output_size=128):
        # Call parent
        super().__init__()
        # Define parameters
        self.input_size = input_size
        self.lstm_size = lstm_size
        self.lstm_layers = lstm_layers
        self.output_size = output_size

        # Define internal modules
        self.lstm = nn.LSTM(input_size, lstm_size, num_layers=lstm_layers, batch_first=True)
        self.output = nn.Linear(lstm_size, output_size)
        self.classifier = nn.Linear(output_size,40)
        
    def forward(self, x):
        # Prepare LSTM initiale state
        batch_size = x.size(0)
        lstm_init = (torch.zeros(self.lstm_layers, batch_size, self.lstm_size), torch.zeros(self.lstm_layers, batch_size, self.lstm_size))
        if x.is_cuda: lstm_init = (lstm_init[0].cuda(), lstm_init[0].cuda())
        lstm_init = (Variable(lstm_init[0], volatile=x.volatile), Variable(lstm_init[1], volatile=x.volatile))

        # Forward LSTM and get final state
        x = self.lstm(x, lstm_init)[0][:,-1,:]
        
        # Forward output
        x = F.relu(self.output(x))
        x = self.classifier((x))
        return x

In [7]:
# Define options
import argparse
parser = argparse.ArgumentParser(description="Template")
# Dataset options

#Data - Data needs to be pre-filtered and filtered data is available

### BLOCK DESIGN ###
#Data
#parser.add_argument('-ed', '--eeg-dataset', default=r"data\block\eeg_55_95_std.pth", help="EEG dataset path") #55-95Hz
parser.add_argument('-ed', '--eeg-dataset', default=r"/media/mountHDD1/LanxHuyen/CVPR2017/eeg_5_95_std.pth", help="EEG dataset path") #5-95Hz
#parser.add_argument('-ed', '--eeg-dataset', default=r"data\block\eeg_14_70_std.pth", help="EEG dataset path") #14-70Hz
#Splits
parser.add_argument('-sp', '--splits-path', default=r"/media/mountHDD1/LanxHuyen/CVPR2017/block_splits_by_image_all.pth", help="splits path") #All subjects
#parser.add_argument('-sp', '--splits-path', default=r"data\block\block_splits_by_image_single.pth", help="splits path") #Single subject
### BLOCK DESIGN ###

parser.add_argument('-sn', '--split-num', default=0, type=int, help="split number") #leave this always to zero.

#Subject selecting
parser.add_argument('-sub','--subject', default=0   , type=int, help="choose a subject from 1 to 6, default is 0 (all subjects)")

#Time options: select from 20 to 460 samples from EEG data
parser.add_argument('-tl', '--time_low', default=20, type=float, help="lowest time value")
parser.add_argument('-th', '--time_high', default=460,  type=float, help="highest time value")

# Model type/options
parser.add_argument('-mt','--model_type', default='lstm', help='specify which generator should be used: lstm|EEGChannelNet')
# It is possible to test out multiple deep classifiers:
# - lstm is the model described in the paper "Deep Learning Human Mind for Automated Visual Classification”, in CVPR 2017
# - model10 is the model described in the paper "Decoding brain representations by multimodal learning of neural activity and visual features", TPAMI 2020
parser.add_argument('-mp','--model_params', default='', nargs='*', help='list of key=value pairs of model options')
parser.add_argument('--pretrained_net', default='', help="path to pre-trained net (to continue training)")

# Training options
parser.add_argument("-b", "--batch_size", default=16, type=int, help="batch size")
parser.add_argument('-o', '--optim', default="Adam", help="optimizer")
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, help="learning rate")
parser.add_argument('-lrdb', '--learning-rate-decay-by', default=0.5, type=float, help="learning rate decay factor")
parser.add_argument('-lrde', '--learning-rate-decay-every', default=10, type=int, help="learning rate decay period")
parser.add_argument('-dw', '--data-workers', default=4, type=int, help="data loading workers")
parser.add_argument('-e', '--epochs', default=200, type=int, help="training epochs")

# Save options
parser.add_argument('-sc', '--saveCheck', default=100, type=int, help="learning rate")

# Backend options
parser.add_argument('--no-cuda', default=False, help="disable CUDA", action="store_true")

# Parse arguments
opt, unknown = parser.parse_known_args()
print(opt)


Namespace(eeg_dataset='/media/mountHDD1/LanxHuyen/CVPR2017/eeg_5_95_std.pth', splits_path='/media/mountHDD1/LanxHuyen/CVPR2017/block_splits_by_image_all.pth', split_num=0, subject=0, time_low=20, time_high=460, model_type='lstm', model_params='', pretrained_net='', batch_size=16, optim='Adam', learning_rate=0.001, learning_rate_decay_by=0.5, learning_rate_decay_every=10, data_workers=4, epochs=200, saveCheck=100, no_cuda=False)


In [8]:
# Dataset class
class EEGDataset:
    
    # Constructor
    def __init__(self, eeg_signals_path):
        # Load EEG signals
        loaded = torch.load(eeg_signals_path)
        if opt.subject!=0:
            self.data = [loaded['dataset'][i] for i in range(len(loaded['dataset']) ) if loaded['dataset'][i]['subject']==opt.subject]
        else:
            self.data=loaded['dataset']        
        self.labels = loaded["labels"]
        self.images = loaded["images"]
        
        # Compute size
        self.size = len(self.data)

    # Get size
    def __len__(self):
        return self.size

    # Get item
    def __getitem__(self, i):
        # Process EEG
        eeg = self.data[i]["eeg"].float().t()
        eeg = eeg[opt.time_low:opt.time_high,:]

        if opt.model_type == "model10":
            eeg = eeg.t()
            eeg = eeg.view(1,128,opt.time_high-opt.time_low)
        # Get label
        label = self.data[i]["label"]
        # Return
        return eeg, label

# Splitter class
class Splitter:

    def __init__(self, dataset, split_path, split_num=0, split_name="train"):
        # Set EEG dataset
        self.dataset = dataset
        # Load split
        loaded = torch.load(split_path)
        self.split_idx = loaded["splits"][split_num][split_name]
        # Filter data
        self.split_idx = [i for i in self.split_idx if 450 <= self.dataset.data[i]["eeg"].size(1) <= 600]
        # Compute size
        self.size = len(self.split_idx)

    # Get size
    def __len__(self):
        return self.size

    # Get item
    def __getitem__(self, i):
        # Get sample from dataset
        eeg, label = self.dataset[self.split_idx[i]]
        # Return
        return eeg, label


In [9]:
# Load dataset
dataset = EEGDataset(opt.eeg_dataset)
# Create loaders
loaders = {split: DataLoader(Splitter(dataset, split_path = opt.splits_path, split_num = opt.split_num, split_name = split), batch_size = opt.batch_size, drop_last = True, shuffle = True) for split in ["train", "val", "test"]}

In [12]:
# Load model

model_options = {key: int(value) if value.isdigit() else (float(value) if value[0].isdigit() else value) for (key, value) in [x.split("=") for x in opt.model_params]}
# Create discriminator model/optimizer
# module = importlib.import_module(Model() + opt.model_type)
model = Model(**model_options)
optimizer = getattr(torch.optim, opt.optim)(model.parameters(), lr = opt.learning_rate)

In [13]:
# Setup CUDA
if not opt.no_cuda:
    model.cuda()
    print("Copied to CUDA")

if opt.pretrained_net != '':
        model = torch.load(opt.pretrained_net)
        print(model)

Copied to CUDA


In [14]:
#initialize training,validation, test losses and accuracy list
losses_per_epoch={"train":[], "val":[],"test":[]}
accuracies_per_epoch={"train":[],"val":[],"test":[]}

best_accuracy = 0
best_accuracy_val = 0
best_epoch = 0

In [15]:
for epoch in range(1, opt.epochs+1):
    # Initialize loss/accuracy variables
    losses = {"train": 0, "val": 0, "test": 0}
    accuracies = {"train": 0, "val": 0, "test": 0}
    counts = {"train": 0, "val": 0, "test": 0}
    # Adjust learning rate for SGD
    if opt.optim == "SGD":
        lr = opt.learning_rate * (opt.learning_rate_decay_by ** (epoch // opt.learning_rate_decay_every))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    # Process each split
    for split in ("train", "val", "test"):
        # Set network mode
        if split == "train":
            model.train()
            torch.set_grad_enabled(True)
        else:
            model.eval()
            torch.set_grad_enabled(False)
        # Process all split batches
        for i, (input, target) in enumerate(loaders[split]):
            # Check CUDA
            if not opt.no_cuda:
                input = input.to("cuda") 
                target = target.to("cuda") 
            # Forward
            output = model(input)

            # Compute loss
            loss = F.cross_entropy(output, target)
            losses[split] += loss.item()
            # Compute accuracy
            _,pred = output.data.max(1)
            correct = pred.eq(target.data).sum().item()
            accuracy = correct/input.data.size(0)   
            accuracies[split] += accuracy
            counts[split] += 1
            # Backward and optimize
            if split == "train":
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
    # Print info at the end of the epoch
    if accuracies["val"]/counts["val"] >= best_accuracy_val:
        best_accuracy_val = accuracies["val"]/counts["val"]
        best_accuracy = accuracies["test"]/counts["test"]
        best_epoch = epoch
    
    TrL,TrA,VL,VA,TeL,TeA=  losses["train"]/counts["train"],accuracies["train"]/counts["train"],losses["val"]/counts["val"],accuracies["val"]/counts["val"],losses["test"]/counts["test"],accuracies["test"]/counts["test"]
    print("Model: {11} - Subject {12} - Time interval: [{9}-{10}]  [{9}-{10} Hz] - Epoch {0}: TrL={1:.4f}, TrA={2:.4f}, VL={3:.4f}, VA={4:.4f}, TeL={5:.4f}, TeA={6:.4f}, TeA at max VA = {7:.4f} at epoch {8:d}".format(epoch,
                                                                                                         losses["train"]/counts["train"],
                                                                                                         accuracies["train"]/counts["train"],
                                                                                                         losses["val"]/counts["val"],
                                                                                                         accuracies["val"]/counts["val"],
                                                                                                         losses["test"]/counts["test"],
                                                                                                         accuracies["test"]/counts["test"],
                                                                                                         best_accuracy, best_epoch, opt.time_low,opt.time_high, opt.model_type,opt.subject))

    losses_per_epoch['train'].append(TrL)
    losses_per_epoch['val'].append(VL)
    losses_per_epoch['test'].append(TeL)
    accuracies_per_epoch['train'].append(TrA)
    accuracies_per_epoch['val'].append(VA)
    accuracies_per_epoch['test'].append(TeA)

    if epoch%opt.saveCheck == 0:
                torch.save(model, '%s__subject%d_epoch_%d.pth' % (opt.model_type, opt.subject,epoch))

  lstm_init = (Variable(lstm_init[0], volatile=x.volatile), Variable(lstm_init[1], volatile=x.volatile))


Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 1: TrL=3.6423, TrA=0.0389, VL=3.5429, VA=0.0393, TeL=3.5627, TeA=0.0413, TeA at max VA = 0.0413 at epoch 1
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 2: TrL=3.3629, TrA=0.0638, VL=3.2920, VA=0.0635, TeL=3.3130, TeA=0.0685, TeA at max VA = 0.0685 at epoch 2
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 3: TrL=3.2217, TrA=0.0869, VL=3.1831, VA=0.0781, TeL=3.1850, TeA=0.0811, TeA at max VA = 0.0811 at epoch 3
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 4: TrL=3.0164, TrA=0.1132, VL=3.7909, VA=0.0610, TeL=3.7777, TeA=0.0640, TeA at max VA = 0.0811 at epoch 3
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 5: TrL=2.9806, TrA=0.1246, VL=2.9468, VA=0.1114, TeL=2.9754, TeA=0.1023, TeA at max VA = 0.1023 at epoch 5
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 6: TrL=2.7351, TrA=0.1514, VL=3.0547, V

Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 47: TrL=0.2630, TrA=0.9175, VL=6.6996, VA=0.1457, TeL=6.8857, TeA=0.1452, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 48: TrL=0.1993, TrA=0.9425, VL=6.9295, VA=0.1416, TeL=7.0168, TeA=0.1316, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 49: TrL=0.1535, TrA=0.9584, VL=6.9140, VA=0.1436, TeL=7.0561, TeA=0.1436, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 50: TrL=0.1520, TrA=0.9604, VL=7.1537, VA=0.1447, TeL=7.2634, TeA=0.1431, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 51: TrL=0.2618, TrA=0.9186, VL=7.1430, VA=0.1386, TeL=7.2548, TeA=0.1391, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 52: TrL=0.2389, TrA=0.9247, V

Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 93: TrL=0.0559, TrA=0.9848, VL=8.6082, VA=0.1265, TeL=8.8111, TeA=0.1487, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 94: TrL=0.0284, TrA=0.9947, VL=8.6509, VA=0.1351, TeL=8.7879, TeA=0.1452, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 95: TrL=0.0114, TrA=0.9987, VL=8.7003, VA=0.1310, TeL=8.8821, TeA=0.1502, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 96: TrL=0.0085, TrA=0.9989, VL=8.7720, VA=0.1295, TeL=8.8988, TeA=0.1487, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 97: TrL=0.4235, TrA=0.8916, VL=8.8979, VA=0.1295, TeL=9.0343, TeA=0.1290, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 98: TrL=0.2866, TrA=0.9111, V

Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 139: TrL=0.0430, TrA=0.9883, VL=9.2689, VA=0.1305, TeL=9.3169, TeA=0.1442, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 140: TrL=0.0698, TrA=0.9808, VL=9.4360, VA=0.1316, TeL=9.5427, TeA=0.1426, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 141: TrL=0.1438, TrA=0.9588, VL=9.3301, VA=0.1391, TeL=9.4813, TeA=0.1452, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 142: TrL=0.1664, TrA=0.9481, VL=9.5081, VA=0.1265, TeL=9.6935, TeA=0.1310, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 143: TrL=0.1258, TrA=0.9613, VL=9.4002, VA=0.1310, TeL=9.4627, TeA=0.1361, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 144: TrL=0.0774, TrA=0.9

Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 185: TrL=0.1172, TrA=0.9660, VL=10.4261, VA=0.1265, TeL=10.5640, TeA=0.1255, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 186: TrL=0.1866, TrA=0.9428, VL=10.2426, VA=0.1200, TeL=10.3379, TeA=0.1361, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 187: TrL=0.0673, TrA=0.9794, VL=10.2413, VA=0.1285, TeL=10.2222, TeA=0.1452, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 188: TrL=0.0650, TrA=0.9795, VL=10.3894, VA=0.1235, TeL=10.3763, TeA=0.1416, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 189: TrL=0.0486, TrA=0.9855, VL=10.3924, VA=0.1169, TeL=10.3910, TeA=0.1371, TeA at max VA = 0.1452 at epoch 19
Model: lstm - Subject 0 - Time interval: [20-460]  [20-460 Hz] - Epoch 190: TrL=0.044