In [None]:
import scipy.io as scio
import pandas as pd
import os
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import interpolate
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.utils.data as Data
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import einsum
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold, KFold, LeaveOneGroupOut
import copy
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
#from sklearn import preprocessing
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from tqdm import tqdm, trange

from ConLoss import SupConLoss
import random
import Module as md

In [None]:
def seed_it(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), and configures cuDNN so that it selects deterministic
    kernels, making training runs reproducible.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # PYTHONHASHSEED (not "PYTHONSEED") is the variable Python actually reads.
    # It only affects interpreters started after this point (e.g. spawned
    # worker processes), not the already-running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the CPU and all CUDA generators
    torch.cuda.manual_seed_all(seed)  # explicit; harmless when CUDA is absent
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick non-deterministic algorithms,
    # which defeats deterministic=True -- it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True


seed = 123
seed_it(seed)

In [None]:
# Load data.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load data_all.pkl from a trusted source.
with open('.//data_all.pkl', 'rb') as file:
    data_all = pickle.load(file)
eeg_data = data_all['eeg_data']
emo_label = data_all['emo_label']
task_label = data_all['task_label']
group = data_all['group']

# The inputs and both label arrays are sliced with the same fold indices during
# training, so they must agree on the number of samples (first axis).
assert len(eeg_data) == len(emo_label) == len(task_label), (
    'eeg_data, emo_label and task_label must have the same number of samples'
)

eeg_data.shape, emo_label.shape, task_label.shape, group.shape

In [None]:
# Shared training configuration used by ``train``.
# Default reduction='mean': each call returns the mean loss over the batch.
criterion = nn.CrossEntropyLoss()
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 10-fold CV; shuffle=True permutes samples before splitting (seeded).
# NOTE(review): KFold ignores the labels passed to split(); switch to
# StratifiedKFold if class-balanced folds are intended.
kf = KFold(n_splits=10, shuffle=True, random_state=seed)
# NOTE(review): not used in this section -- presumably intended for
# leave-one-group-out evaluation with ``group``; confirm or remove.
logo = LeaveOneGroupOut()

In [None]:
def _train_one_epoch(model, loader, optimizer, epoch, epoch_size):
    """Run one optimisation pass of ``model`` over ``loader`` on the task label.

    A tqdm bar reports the running mean batch loss and the running task
    accuracy over the samples seen so far in this epoch.

    Returns:
        (mean batch loss, task accuracy in percent) for the epoch, or
        (0.0, 0.0) if ``loader`` yields no batches.
    """
    model.train()
    loop = tqdm(enumerate(loader), total=len(loader))
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_batches = 0
    for step, (x, _emo, y_task) in loop:
        x, y_task = x.to(device), y_task.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y_task.long())
        loss.backward()
        optimizer.step()

        # ``criterion`` returns a per-batch mean, so average over batches
        # (dividing by the sample count would shrink it by the batch size).
        loss_sum += loss.item()
        n_batches = step + 1
        n_correct += (torch.max(logits, 1)[1] == y_task).sum().item()
        n_seen += x.size(0)

        loop.set_description(f'Epoch [{epoch+1} / {epoch_size}]')
        loop.set_postfix({
            'loss': '{:.6f}'.format(loss_sum / n_batches),
            'acc_task': '{:.6f}'.format(n_correct * 100 / n_seen),
        })
    if n_batches == 0:
        return 0.0, 0.0
    return loss_sum / n_batches, n_correct * 100 / n_seen


def _evaluate(model, loader):
    """Evaluate ``model`` on ``loader``; return (mean per-sample loss, task accuracy %)."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, _emo, y_task in loader:
            x, y_task = x.to(device), y_task.to(device)
            logits = model(x)
            # Weight the batch-mean loss by batch size so the result is a true
            # per-sample mean for any loader batch size.
            loss_sum += criterion(logits, y_task.long()).item() * x.size(0)
            n_correct += (torch.max(logits, 1)[1] == y_task).sum().item()
            n_seen += x.size(0)
    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct * 100 / n_seen


def train(dim, epoch_size, learning_rate, save=True):
    """Cross-validate ``md.model_1`` on the task label using the folds of ``kf``.

    For every fold a fresh model and Adam optimizer are built, trained for
    ``epoch_size`` epochs on the training split, then evaluated on the
    held-out split. The emotion labels are carried through the datasets but
    are not used for training or scoring here.

    Args:
        dim: ``token_dim`` passed to ``md.model_1``; also names the output
            directory.
        epoch_size: number of training epochs per fold.
        learning_rate: Adam learning rate, identical for every fold.
        save: if True, write each fold's final weights to
            ``./model_parameter/model1_dim=<dim>/KFold=<k>.pkl`` and, after the
            last fold, the per-fold test accuracies to ``result.npy`` there.

    Returns:
        List of per-fold test accuracies on ``task_label``, in percent.
    """
    model_path = './model_parameter/model1_dim=%s' % (dim)
    if save:
        os.makedirs(model_path, exist_ok=True)

    test_task = []
    for k, (train_idx, test_idx) in enumerate(kf.split(eeg_data, emo_label)):
        # Fresh model/optimizer per fold. The learning rate is deliberately
        # NOT decayed between folds so that every fold follows the same protocol.
        model = md.model_1(token_dim=dim, out_put='pred').to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.0005)

        print('*'*10, '{}-fold'.format(k+1), '*'*10)
        train_set = TensorDataset(eeg_data[train_idx], emo_label[train_idx], task_label[train_idx])
        test_set = TensorDataset(eeg_data[test_idx], emo_label[test_idx], task_label[test_idx])
        # KFold does not shuffle samples within a split, so shuffle the
        # mini-batches here; otherwise every epoch sees samples in file order.
        train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
        test_loader = DataLoader(test_set, batch_size=1)

        for epoch in range(epoch_size):
            _train_one_epoch(model, train_loader, optimizer, epoch, epoch_size)

        # Save once after the final epoch (previously re-saved on every batch).
        if save:
            state = {'model': model.state_dict()}
            torch.save(state, os.path.join(model_path, 'KFold=%s.pkl' % (k+1)))

        test_loss, test_acc = _evaluate(model, test_loader)
        print('Test Loss: {:.6f},  Test Acc: {:.6f}'.format(test_loss, test_acc))
        test_task.append(test_acc)

    # Written after all folds, whatever n_splits ``kf`` uses (was hard-coded to 10).
    if save:
        np.save(os.path.join(model_path, 'result'), test_task)
    return test_task

In [None]:
# Run the 10-fold cross-validation: 128-d tokens, 50 epochs per fold,
# saving checkpoints and per-fold accuracies under ./model_parameter/.
result = train(
    dim=128,
    epoch_size=50,
    learning_rate=0.001,
    save=True,
)