In [20]:
import sys
import os

module_path = os.path.abspath(os.path.join("./EEG-to-Text-Project.git"))
if module_path not in sys.path:
    sys.path.append(module_path)
    
sys.path

custome_cache_dir = "/home/linal20/transformers_cache"

In [27]:
# custome_cache_dir = "/home/linal20/transformers_cache"
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence 
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm

from transformers import BertTokenizer, BertLMHeadModel, BertConfig
from data import ZuCo_dataset
from model_sentiment import BaselineMLPSentence, BaselineLSTM

# Function to calculate the accuracy
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()  
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [36]:
##### training sentiment baseline models #####
def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints2/eeg_sentiment/best/test.pt', checkpoint_path_last = './checkpoints2/eeg_sentiment/last/test.pt'):
    since = time.time()
      
    best_weights = copy.deepcopy(model.state_dict())
    best_loss = 100000000000
    best_acc = 0.0
    

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            total_accuracy = 0.0
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0

            # Iterate over data.
            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]):
                
                input_word_eeg_features = input_word_eeg_features.to(device).float()
                sent_level_EEG = sent_level_EEG.to(device)
                input_masks = input_masks.to(device)
                sentiment_labels = sentiment_labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                if isinstance(model, BaselineMLPSentence):
                    # forward
                    logits = model(sent_level_EEG) # before softmax
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)
                
                elif isinstance(model, BaselineLSTM):
                    x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False)
                    logits = model(x_packed)
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)
                

                # backward & optimize only if in training phase
                if phase == 'train':
                    # with torch.autograd.detect_anomaly():
                    loss.backward()
                    optimizer.step()

                # calculate accuracy
                preds_cpu = logits.detach().cpu().numpy()
                label_cpu = sentiment_labels.cpu().numpy()

                total_accuracy += flat_accuracy(preds_cpu, label_cpu)

                # statistics
                running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = total_accuracy / len(dataloaders[phase])
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))
         
            if phase == 'dev' and (epoch_acc > best_acc):
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())
                '''save checkpoint'''
                torch.save(model.state_dict(), checkpoint_path_best)
                print(f'update best on dev checkpoint: {checkpoint_path_best}')
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    print('Best val acc: {:4f}'.format(best_acc))
    torch.save(model.state_dict(), checkpoint_path_last)
    print(f'update last checkpoint: {checkpoint_path_last}')

    # load best model weights
    model.load_state_dict(best_weights)
    return model
# import os
# os.environ["TRANSFORMERS_CACHE"] = "/home/linal20/NLP_final_project"
# python3 train_sentiment_baseline.py --model_name BaselineMLP --num_epoch 20 -lr 0.00005 -b 32 -s ./checkpoints/eeg_sentiment -cuda cuda:0

if __name__ == '__main__':

    
    #set param
    num_epochs = 20
    step_lr = 0.00005
    dataset_setting = 'unique_sent'

    subject_choice = 'ALL'
    print(f'![Debug]using {subject_choice}')
    eeg_type_choice = 'GD'
    print(f'[INFO]eeg type {eeg_type_choice}')
    bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] 
    print(f'[INFO]using bands {bands_choice}')
    
    # model name
    model_name = 'BaselineMLP'
    # model_name = 'BaselineLSTM'


    batch_size = 32
    save_path = './checkpoints/eeg_sentiment'
    save_name = f'{model_name}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}'

    if model_name == 'BaselineLSTM':
        num_layers = 4
        save_name = f'{model_name}_numLayers-{num_layers}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}'

    output_checkpoint_name_best = save_path + f'/best/{save_name}.pt' 
    output_checkpoint_name_last = save_path + f'/last/{save_name}.pt' 


    
    # random seeds 
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)


    # set up device
    # use cuda
    if torch.cuda.is_available():  
        # dev = args['cuda']
        dev = 'cuda:0'
    else:  
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3  
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')


    # load pickle
    whole_dataset_dict = []
    # dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle' 
    dataset_path_task1 ='./dataset/processed/Task1_SR_processed.pickle'
    with open(dataset_path_task1, 'rb') as handle:
        whole_dataset_dict.append(pickle.load(handle))
    
    # tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased',cache_dir=custome_cache_dir )

    # set up dataloader 
    # train dataset
    train_set = ZuCo_dataset(whole_dataset_dict, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
    # dev dataset
    dev_set = ZuCo_dataset(whole_dataset_dict, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)

    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
    print('[INFO]train_set size: ', len(train_set))
    print('[INFO]dev_set size: ', len(dev_set))
    
    # train dataloader
    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)
    # dev dataloader
    val_dataloader = DataLoader(dev_set, batch_size = batch_size, shuffle=False, num_workers=4)
    # dataloaders
    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}

    ''' set up model '''
    if model_name == 'BaselineMLP':
        print('[INFO]Model: BaselineMLP')
        model = BaselineMLPSentence(input_dim = 105*len(bands_choice), hidden_dim = 128, output_dim = 3)
    elif model_name == 'BaselineLSTM':
        print('[INFO]Model: BaselineLSTM')
        model = BaselineLSTM(input_dim = 105*len(bands_choice), hidden_dim = 256, output_dim = 3, num_layers = num_layers)

    model.to(device)
    


    
    ### training loop ### 

    # set up optimizer and scheduler
    optimizer_step1 = optim.SGD(model.parameters(), lr=step_lr, momentum=0.9)
    exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.5)

    # loss function 
    criterion = nn.CrossEntropyLoss()

    print('=== start training  ===')
    # return best loss model from step1 training
    model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)

![Debug]using ALL
[INFO]eeg type GD
[INFO]using bands ['_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2']
[INFO]using device cuda:0
[INFO]loading 1 task datasets
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train divider = 320
dev divider = 360
[INFO]initializing a train set...
discard length zero instance:  Weiss and Speck never make a convincing case for the relevance of these two 20th-century footnotes.
discard length zero instance:  Reassuring, retro uplifter.
discard length zero instance:  Flaccid drama and exasperatingly slow journey.
++ adding task to dataset, now we have: 3609
[INFO]input tensor size: torch.Size([56, 840])

[INFO]loading 1 task datasets
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train divider = 320
dev divider = 360
[INFO]initializing a dev set...
discard length zero instance:  Gollum's `performance' is incredible!
++ adding task 

100%|██████████| 113/113 [00:01<00:00, 69.28it/s]


train Loss: 1.1012
train Acc: 0.3509


100%|██████████| 15/15 [00:00<00:00, 25.03it/s]


dev Loss: 1.1024
dev Acc: 0.2940
update best on dev checkpoint: ./checkpoints/eeg_sentiment/best/BaselineMLP_5e-05_b32_unique_sent_GD.pt

Epoch 1/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.45it/s]


train Loss: 1.1007
train Acc: 0.3529


100%|██████████| 15/15 [00:00<00:00, 24.56it/s]


dev Loss: 1.1006
dev Acc: 0.2940

Epoch 2/19
----------


100%|██████████| 113/113 [00:01<00:00, 68.54it/s]


train Loss: 1.0998
train Acc: 0.3529


100%|██████████| 15/15 [00:00<00:00, 24.32it/s]


dev Loss: 1.0989
dev Acc: 0.2898

Epoch 3/19
----------


100%|██████████| 113/113 [00:01<00:00, 68.36it/s]


train Loss: 1.1001
train Acc: 0.3550


100%|██████████| 15/15 [00:00<00:00, 25.14it/s]


dev Loss: 1.0983
dev Acc: 0.2856

Epoch 4/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.76it/s]


train Loss: 1.1008
train Acc: 0.3399


100%|██████████| 15/15 [00:00<00:00, 24.50it/s]


dev Loss: 1.0979
dev Acc: 0.2836

Epoch 5/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.48it/s]


train Loss: 1.0987
train Acc: 0.3469


100%|██████████| 15/15 [00:00<00:00, 24.35it/s]


dev Loss: 1.0976
dev Acc: 0.2856

Epoch 6/19
----------


100%|██████████| 113/113 [00:01<00:00, 66.87it/s]


train Loss: 1.0995
train Acc: 0.3374


100%|██████████| 15/15 [00:00<00:00, 24.78it/s]


dev Loss: 1.0969
dev Acc: 0.2794

Epoch 7/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.77it/s]


train Loss: 1.1008
train Acc: 0.3311


100%|██████████| 15/15 [00:00<00:00, 24.76it/s]


dev Loss: 1.0961
dev Acc: 0.2815

Epoch 8/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.59it/s]


train Loss: 1.1000
train Acc: 0.3429


100%|██████████| 15/15 [00:00<00:00, 24.47it/s]


dev Loss: 1.0967
dev Acc: 0.2877

Epoch 9/19
----------


100%|██████████| 113/113 [00:01<00:00, 66.26it/s]


train Loss: 1.0995
train Acc: 0.3407


100%|██████████| 15/15 [00:00<00:00, 24.69it/s]


dev Loss: 1.0960
dev Acc: 0.2877

Epoch 10/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.77it/s]


train Loss: 1.0991
train Acc: 0.3394


100%|██████████| 15/15 [00:00<00:00, 24.00it/s]


dev Loss: 1.0959
dev Acc: 0.2856

Epoch 11/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.72it/s]


train Loss: 1.0973
train Acc: 0.3511


100%|██████████| 15/15 [00:00<00:00, 24.27it/s]


dev Loss: 1.0959
dev Acc: 0.2856

Epoch 12/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.88it/s]


train Loss: 1.0987
train Acc: 0.3522


100%|██████████| 15/15 [00:00<00:00, 23.73it/s]


dev Loss: 1.0960
dev Acc: 0.2919

Epoch 13/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.33it/s]


train Loss: 1.0982
train Acc: 0.3434


100%|██████████| 15/15 [00:00<00:00, 24.92it/s]


dev Loss: 1.0958
dev Acc: 0.2940

Epoch 14/19
----------


100%|██████████| 113/113 [00:01<00:00, 68.04it/s]


train Loss: 1.0997
train Acc: 0.3358


100%|██████████| 15/15 [00:00<00:00, 24.01it/s]


dev Loss: 1.0959
dev Acc: 0.2981
update best on dev checkpoint: ./checkpoints/eeg_sentiment/best/BaselineMLP_5e-05_b32_unique_sent_GD.pt

Epoch 15/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.24it/s]


train Loss: 1.0976
train Acc: 0.3546


100%|██████████| 15/15 [00:00<00:00, 24.36it/s]


dev Loss: 1.0957
dev Acc: 0.3002
update best on dev checkpoint: ./checkpoints/eeg_sentiment/best/BaselineMLP_5e-05_b32_unique_sent_GD.pt

Epoch 16/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.09it/s]


train Loss: 1.0984
train Acc: 0.3546


100%|██████████| 15/15 [00:00<00:00, 24.63it/s]


dev Loss: 1.0955
dev Acc: 0.3009
update best on dev checkpoint: ./checkpoints/eeg_sentiment/best/BaselineMLP_5e-05_b32_unique_sent_GD.pt

Epoch 17/19
----------


100%|██████████| 113/113 [00:01<00:00, 67.31it/s]


train Loss: 1.0988
train Acc: 0.3397


100%|██████████| 15/15 [00:00<00:00, 24.64it/s]


dev Loss: 1.0955
dev Acc: 0.3030
update best on dev checkpoint: ./checkpoints/eeg_sentiment/best/BaselineMLP_5e-05_b32_unique_sent_GD.pt

Epoch 18/19
----------


100%|██████████| 113/113 [00:01<00:00, 66.66it/s]


train Loss: 1.0979
train Acc: 0.3480


100%|██████████| 15/15 [00:00<00:00, 24.20it/s]


dev Loss: 1.0952
dev Acc: 0.3030

Epoch 19/19
----------


100%|██████████| 113/113 [00:01<00:00, 66.62it/s]


train Loss: 1.0988
train Acc: 0.3391


100%|██████████| 15/15 [00:00<00:00, 24.53it/s]

dev Loss: 1.0950
dev Acc: 0.3050
update best on dev checkpoint: ./checkpoints/eeg_sentiment/best/BaselineMLP_5e-05_b32_unique_sent_GD.pt

Training complete in 0m 46s
Best val loss: 1.094992
Best val acc: 0.305044
update last checkpoint: ./checkpoints/eeg_sentiment/last/BaselineMLP_5e-05_b32_unique_sent_GD.pt





In [None]:
###### evaluate sentiment baseline models #####
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence 
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification
from data import ZuCo_dataset
from model_sentiment import BaselineMLPSentence, BaselineLSTM, FineTunePretrainedTwoStep, ZeroShotSentimentDiscovery, JointBrainTranslatorSentimentClassifier
from model_decoding import BrainTranslator, BrainTranslatorNaive
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
# from config import get_config

customed_cache_dir = "/home/linal20/transformers_cache"

# Function to calculate the accuracy 
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()  
    
    labels_flat = labels.flatten()
    
    return np.sum(pred_flat == labels_flat) / len(labels_flat)



def eval_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, tokenizer = BartTokenizer.from_pretrained('facebook/bart-large',cache_dir= customed_cache_dir)):

    def logits2PredString(logits, tokenizer):
        probs = logits[0].softmax(dim = 1)
        values, predictions = probs.topk(1)
        predictions = torch.squeeze(predictions)
        predict_string = tokenizer.decode(predictions)
        return predict_string
    
    since = time.time()
      
    best_weights = copy.deepcopy(model.state_dict())
    best_loss = 100000000000
    best_acc = 0.0
    
    total_pred_labels = np.array([])
    total_true_labels = np.array([])
    
    for epoch in range(1):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
    
        # Each epoch has a training and validation phase
        for phase in ['test']:
            total_accuracy = 0.0
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
    
            running_loss = 0.0
    
            # Iterate over data.
            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in dataloaders[phase]:
                
                input_word_eeg_features = input_word_eeg_features.to(device).float()
                input_masks = input_masks.to(device)
                input_mask_invert = input_mask_invert.to(device)
    
                sent_level_EEG = sent_level_EEG.to(device)
                sentiment_labels = sentiment_labels.to(device)
    
                target_ids = target_ids.to(device)
                target_mask = target_mask.to(device)
    
                ## forward ###################
                if isinstance(model, BaselineMLPSentence):
                    logits = model(sent_level_EEG) # before softmax
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)
    
                elif isinstance(model, BaselineLSTM):
                    x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False)
                    logits = model(x_packed)
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)
    
          
    
                # backward and optimize only if in training phase
                if phase == 'train':
                    # with torch.autograd.detect_anomaly():
                    loss.backward()
                    optimizer.step()
    
                # calculate accuracy
                preds_cpu = logits.detach().cpu().numpy()
                label_cpu = sentiment_labels.cpu().numpy()
    
                total_accuracy += flat_accuracy(preds_cpu, label_cpu)
                
                # add to total pred and label array, for cal F1, precision, recall
                pred_flat = np.argmax(preds_cpu, axis=1).flatten()
                labels_flat = label_cpu.flatten()
    
                total_pred_labels = np.concatenate((total_pred_labels,pred_flat))
                total_true_labels = np.concatenate((total_true_labels,labels_flat))
                
    
                # statistics
                running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss               
    
            if phase == 'train':
                scheduler.step()
    
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = total_accuracy / len(dataloaders[phase])
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))
    
            # deep copy the model
            if phase == 'test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_acc = epoch_acc
        print()
    
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best test loss: {:4f}'.format(best_loss))
    print('Best test acc: {:4f}'.format(best_acc))
    print()
    print('test sample num:', len(total_pred_labels))
    print('total preds:',total_pred_labels)
    print('total truth:',total_true_labels)
    print('sklearn macro: precision, recall, F1:')
    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='macro'))
    print()
    print('sklearn micro: precision, recall, F1:')
    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='micro'))
    print()
    print('sklearn accuracy:')
    print(accuracy_score(total_true_labels,total_pred_labels))
    print()

    
if __name__ == '__main__':
    # set param
    num_epochs = 25
    
    dataset_setting = 'unique_sent'
    
    # model name
    model_name = 'BaselineMLP'
    # model_name = 'BaselineLSTM'
    
    print(f'[INFO] eval {model_name}')
    if model_name == 'ZeroShotSentimentDiscovery':
     
        # generator
        decoder_name = 'BrainTranslator'
        decoder_checkpoint= './checkpoints/decoding/best/task1_finetune_BrainTranslator_skipstep1_b32_20_30_5e-05_5e-07_unique_sent.pt' 
        print(f'[INFO] using decoder: {decoder_name}')
    
       # classifier
        # pretrain_Bert, pretrain_RoBerta, pretrain_Bart
        classifier_name = 'pretrain_Bert'
        classifier_checkpoint = './checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bert_b32_20_0.0001.pt'
        # classifier_checkpoint  = f'./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_{model_name}_b32_20_0.0001.pt'
        print(f'[INFO] using classifier: {classifier_name}')
    else:
        # checkpoint_path = args['checkpoint_path']
        # checkpoint_path = './checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt'
        if model_name == 'BaselineLSTM':
            num_layers = 4
            checkpoint_path = f'./checkpoints/eeg_sentiment/best/{model_name}_numLayers-{num_layers}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}.pt'
        else: 
            checkpoint_path = f'./checkpoints/eeg_sentiment/best/{model_name}_{step_lr}_b{batch_size}_{dataset_setting}_{eeg_type_choice}.pt'
    
        
        print('[INFO] loading baseline:', checkpoint_path)
    
    
    batch_size = 32
    
    
    subject_choice = 'ALL'
    print(f'![Debug]using {subject_choice}')
    eeg_type_choice = 'GD'
    print(f'[INFO]eeg type {eeg_type_choice}')
    bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] 
    print(f'[INFO]using bands {bands_choice}')
    
    
    
    # random seeds
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
    
    
    # device
    # use cuda
    if torch.cuda.is_available():  
        # dev = args['cuda']
         dev = 'cuda:1'
    else:  
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3  
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')
    
    
    #  load Task1_SR_processed.pickle
    whole_dataset_dict = []
    # dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/Task1_SR_processed.pickle'
    dataset_path_task1 ='./dataset/processed/Task1_SR_processed.pickle'
    with open(dataset_path_task1, 'rb') as handle:
        whole_dataset_dict.append(pickle.load(handle))
    
    # tokenizer
    if model_name in ['BaselineMLP','BaselineLSTM', 'NaiveFinetuneBert', 'FinetunedBertOnText']:
        print('[INFO]using Bert tokenizer')
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased',cache_dir="/home/linal20/NLP_final_project")
    
    # model 
    if model_name == 'BaselineMLP':
        print('[INFO]Model: BaselineMLP')
        model = BaselineMLPSentence(input_dim = 840, hidden_dim = 128, output_dim = 3)
    elif model_name == 'BaselineLSTM':
        print('[INFO]Model: BaselineLSTM')
        model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 4)
    
    
    if model_name != 'ZeroShotSentimentDiscovery':
        # load model and send to device
        model.load_state_dict(torch.load(checkpoint_path))
        model.to(device)
    
    #  dataloader '''
    # test dataset
    test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = 'unique_sent')
    
    dataset_sizes = {'test': len(test_set)}
    # print('[INFO]train_set size: ', len(train_set))
    print('[INFO]test_set size: ', len(test_set))
    
    test_dataloader = DataLoader(test_set, batch_size = batch_size, shuffle=False, num_workers=4)
    # dataloaders
    dataloaders = {'test':test_dataloader}
    
    # optimizer and scheduler
    optimizer_step1 = None
    exp_lr_scheduler_step1 = None
    
    # loss function 
    criterion = nn.CrossEntropyLoss()
    
    print('=== start evaluating ===')
    # return best loss model from step1 training
    model = eval_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=25, tokenizer = tokenizer)

[INFO] eval BaselineMLP
[INFO] loading baseline: ./checkpoints/eeg_sentiment/best/BaselineMLP_5e-05_b32_unique_sent_GD.pt
![Debug]using ALL
[INFO]eeg type GD
[INFO]using bands ['_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2']
[INFO]using device cuda:1
[INFO]using Bert tokenizer
[INFO]Model: BaselineMLP
[INFO]loading 1 task datasets
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train divider = 320
dev divider = 360
[INFO]initializing a test set...
++ adding task to dataset, now we have: 456
[INFO]input tensor size: torch.Size([56, 840])

[INFO]test_set size:  456
=== start evaluating ===
Epoch 0/24
----------
