In [50]:
import scipy.io as scio
import pandas as pd
import os
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import interpolate
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.utils.data as Data
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import einsum
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold, KFold, LeaveOneGroupOut
import copy
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
#from sklearn import preprocessing
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from tqdm import tqdm, trange

from ConLoss import SupConLoss
import random
import Module as md
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, accuracy_score



In [51]:
def seed_it(seed):
    random.seed(seed) 
    os.environ["PYTHONSEED"] = str(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True 
    torch.backends.cudnn.benchmark = True 
    torch.backends.cudnn.enabled = True 
    torch.manual_seed(seed)
seed = 123
seed_it(seed)

In [52]:
'''
load data
'''
with open('.//data_all.pkl', 'rb') as file:
    data_all = pickle.load(file)
eeg_data = data_all['eeg_data']
emo_label = data_all['emo_label']
task_label = data_all['task_label']
group = data_all['group']

eeg_data.shape, emo_label.shape, task_label.shape, group.shape

(torch.Size([1363, 1, 62, 425]),
 torch.Size([1363]),
 torch.Size([1363]),
 torch.Size([1363]))

In [53]:
temp=1
criterion1 = nn.CrossEntropyLoss()
device = torch.device("cuda:0")
criterion2 = SupConLoss(temperature=temp)
kf = KFold(n_splits=10, shuffle=True, random_state=seed)
logo = LeaveOneGroupOut()
device = torch.device("cuda:0")

In [54]:
load_path = './model_parameter'
model_list = os.listdir(load_path)

In [61]:
for model_name in model_list:  
    print(model_name)
    acc_all = []
    f1_all = []
    recall_all = []
    precision_all = []
    model_path = os.path.join(load_path, model_name)
    for k, (train, test) in enumerate(kf.split(eeg_data, emo_label)):
        pkl_name ='KFold=%s.pkl' % (k+1)  
        state_dict = torch.load(os.path.join(model_path, pkl_name))
        if model_name[5] == '1':
            model = md.model_1(token_dim=128, out_put='pred').to(device)
        if model_name[5] == '2':
            model = md.model_2(token_dim=128, out_put='pred').to(device)       
        if model_name[5] == '3' and model_name[-3:] != 'dim':
 
            model = md.model_3(token_dim=128, out_put='pred', device='cuda:0', GRL=False).to(device)        
        
        if model_name[5] == '3'  and model_name[-3:] == 'dim' and model_name[41:43] == '64':
            model = md.model_3(token_dim=64, out_put='pred', device='cuda:0', GRL=False).to(device)
            
        if model_name[5] == '3'  and model_name[-3:] == 'dim' and model_name[41:43] != '64':
            model = md.model_3(token_dim=int(model_name[41:44]), out_put='pred', device='cuda:0', GRL=False).to(device)
            
        model.load_state_dict(state_dict['model'])
        test_set = TensorDataset(eeg_data[test], emo_label[test], task_label[test])
        prob_all = []
        label_all = []
        test_loader = Data.DataLoader(test_set, batch_size=32)
        with torch.no_grad():
            model.eval()  
            for x, y1, y2 in test_loader:
                x, y1, y2 = Variable(x).to(device), Variable(y1).to(device), Variable(y2).to(device)

                pred_task = model(x)
                prob = pred_task.cpu().numpy()
                prob_all.extend(np.argmax(prob,axis=1)) #求每一行的最大值索引
                label_all.extend(y2.cpu().numpy())
            acc_now = accuracy_score(y_true=prob_all, y_pred=label_all)
            f1_now = f1_score(y_true=prob_all, y_pred=label_all, average="macro")
            recall_now = recall_score(y_true=prob_all, y_pred=label_all, average="macro")
            precision_now = precision_score(y_true=prob_all, y_pred=label_all, average="macro")
        acc_all.append(acc_now)
        f1_all.append(f1_now)
        recall_all.append(recall_now)
        precision_all.append(precision_now)
    
    print("-"*10, model_name, '-'*10)
    print('acc:', np.around(np.mean(acc_all)*100, 2), np.std(acc_all))

    print('precision:', np.around(np.mean(precision_all)*100, 2), np.std(precision_all))

    print('recall:', np.around(np.mean(recall_all)*100, 2), np.std(recall_all))
    print('f1:', np.around(np.mean(f1_all)*100, 2), np.std(f1_all))


model1_dim=128
---------- model1_dim=128 ----------
acc: 76.3 0.028886019747642207
precision: 73.88 0.0355498035095437
recall: 75.96 0.028419327268883222
f1: 74.08 0.034540620506842444
model2_cls=1_emo=0.05_task=0.05_dim=128_temp=1_contrast=all
---------- model2_cls=1_emo=0.05_task=0.05_dim=128_temp=1_contrast=all ----------
acc: 80.39 0.07448221980072922
precision: 76.8 0.09202802728173429
recall: 81.72 0.06673434799287611
f1: 77.29 0.09305800243363845
model2_cls=1_emo=0.5_task=0.5_dim=128_temp=1_contrast=all
---------- model2_cls=1_emo=0.5_task=0.5_dim=128_temp=1_contrast=all ----------
acc: 90.75 0.039254144317588016
precision: 89.12 0.04845801893079232
recall: 91.6 0.03474904163358188
f1: 89.87 0.044730928392923856
model2_cls=1_emo=0.5_task=0_dim=128_temp=1_contrast=emo
---------- model2_cls=1_emo=0.5_task=0_dim=128_temp=1_contrast=emo ----------
acc: 79.82 0.026656489868539513
precision: 76.7 0.03288338396104354
recall: 81.11 0.02255189991989765
f1: 77.36 0.03304142480193799
model

  _warn_prf(average, modifier, msg_start, len(result))


---------- model3_cls=1_emo=0.5_task=0.5_mi=0.1_dim=448_temp=1_GRL=False_dim ----------
acc: 68.24 0.08132567311057293
precision: 60.2 0.1040123554243482
recall: 68.17 0.1632122287471637
f1: 55.5 0.1465102008980704
model3_cls=1_emo=0.5_task=0.5_mi=0.1_dim=512_temp=1_GRL=False_dim


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


---------- model3_cls=1_emo=0.5_task=0.5_mi=0.1_dim=512_temp=1_GRL=False_dim ----------
acc: 65.15 0.058292664824009126
precision: 56.04 0.06465739365156797
recall: 57.04 0.2243810075299224
f1: 49.12 0.11993623068494046
model3_cls=1_emo=0.5_task=0.5_mi=0.1_dim=64_temp=1_GRL=False_dim
---------- model3_cls=1_emo=0.5_task=0.5_mi=0.1_dim=64_temp=1_GRL=False_dim ----------
acc: 91.05 0.03888188265104451
precision: 89.07 0.04913916013939472
recall: 92.7 0.02874311481583811
f1: 90.12 0.04391686059624269
model3_cls=1_emo=0.5_task=0.5_mi=10_dim=128_temp=1_GRL=False
---------- model3_cls=1_emo=0.5_task=0.5_mi=10_dim=128_temp=1_GRL=False ----------
acc: 84.59 0.04706303996228171
precision: 81.65 0.056440831569555316
recall: 86.98 0.031910365055666405
f1: 82.55 0.05581550686829044
model3_cls=1_emo=0.5_task=0.5_mi=1_dim=128_temp=1_GRL=False
---------- model3_cls=1_emo=0.5_task=0.5_mi=1_dim=128_temp=1_GRL=False ----------
acc: 91.05 0.03290574478823793
precision: 89.13 0.03747825726388325
recall: 9