In [1]:
# train_eval_sentiment_textbased.ipynb: trains the text-based sentiment classifier (first cell) and evaluates the zero-shot sentiment classification pipeline (second cell).

##### training sentiment classification model #####
customed_cache_dir = "/home/linal20/transformers_cache"

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, random_split
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification
from data import ZuCo_dataset, SST_tenary_dataset
from model_sentiment import FineTunePretrainedTwoStep


# Function to calculate the accuracy of our predictions vs labels
def accuracy(preds, labels):
    """Return the fraction of samples whose top-scoring class equals the label.

    preds:  2-D array of per-class scores, one row per sample; the predicted
            class of a row is the index of its largest score.
    labels: array of class ids; it is flattened before the comparison.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    expected_classes = labels.flatten()
    hits = predicted_classes == expected_classes
    return np.mean(hits)

def train_SST(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'):
    """Fine-tune `model` and keep the weights with the best dev accuracy.

    Args:
        dataloaders: mapping with 'train' and 'dev' entries; each yields
            (input_ids, input_masks, sentiment_labels) batches.
        device: torch device the batches are moved to.
        model: called as model(input_ids=..., attention_mask=...,
            return_dict=True, labels=...); its output must expose `.logits`
            and `.loss`.
        criterion: unused; the loss is taken from the model output. Kept for
            interface compatibility.
        optimizer: stepped once per training batch.
        scheduler: stepped once per epoch, after the training phase.
        num_epochs: number of train+dev epochs to run.
        checkpoint_path_best: where the state dict is saved every time the
            dev accuracy improves.
        checkpoint_path_last: where the state dict of the final epoch is saved.

    Returns:
        `model`, reloaded with the weights of the best dev-accuracy epoch.
    """
    since = time.time()

    # torch.save does not create missing parent directories.
    for ckpt_path in (checkpoint_path_best, checkpoint_path_last):
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100000000000
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            n_correct = 0
            n_samples = 0

            # Iterate over data.
            for input_ids, input_masks, sentiment_labels in tqdm(dataloaders[phase]):

                input_ids = input_ids.to(device)
                input_masks = input_masks.to(device)
                sentiment_labels = sentiment_labels.to(device)
                batch_n = input_ids.size()[0]

                # zero the parameter gradients
                optimizer.zero_grad()

                # Only track gradients while training; the dev pass needs none.
                with torch.set_grad_enabled(is_train):
                    output = model(input_ids = input_ids, attention_mask = input_masks, return_dict = True, labels = sentiment_labels)
                    logits = output.logits
                    loss = output.loss

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Count hits per sample so a smaller final batch does not
                # skew the epoch accuracy.
                predictions = logits.detach().argmax(dim=1)
                n_correct += (predictions == sentiment_labels.reshape(-1)).sum().item()
                n_samples += batch_n

                # statistics: the model loss is a batch mean, so weight it
                # by the batch size
                running_loss += loss.item() * batch_n

            if is_train:
                scheduler.step()

            # Normalise by the samples actually seen instead of relying on a
            # notebook-level `dataset_sizes` global.
            epoch_loss = running_loss / max(n_samples, 1)
            epoch_acc = n_correct / max(n_samples, 1)
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))

            # deep copy the model
            if phase == 'dev' and (epoch_acc > best_acc):
                best_loss = epoch_loss  # dev loss of the best-accuracy epoch
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                # save checkpoint
                torch.save(model.state_dict(), checkpoint_path_best)
                print(f'update best on dev checkpoint: {checkpoint_path_best}')
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))
    print('Best val acc: {:4f}'.format(best_acc))
    # Saved before the best weights are restored, so this is the last epoch.
    torch.save(model.state_dict(), checkpoint_path_last)
    print(f'update last checkpoint: {checkpoint_path_last}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

# ---- training configuration -------------------------------------------------
# `customed_cache_dir` is already defined at the top of this cell.

epoch_n = 20
# lr = 1e-3 # Bert, RoBerta
lr = 1e-4 # Bart
dataset_name = 'SST'
dataset_setting = 'unique_sent'
batch_size = 32

# model_name = 'pretrain_Bert'
# model_name = 'pretrain_RoBerta'
model_name = 'pretrain_Bart'
print(f'model: {model_name}')

save_path = './checkpoints/text_sentiment_classifier'

# Only the SST dataset is handled below; fail loudly for anything else instead
# of evaluating a comparison whose result is thrown away.
if dataset_name != 'SST':
    raise ValueError(f'unsupported dataset_name: {dataset_name!r} (only "SST" is implemented)')

# NOTE: the spelling of this name is kept as-is because the evaluation cell
# loads the checkpoint by exactly this file name.
save_name = f'Textbased_StanfordSentitmentTreeband_{model_name}_b{batch_size}_{epoch_n}_{lr}'

output_checkpoint_name_best = save_path + f'/best/{save_name}.pt'
output_checkpoint_name_last = save_path + f'/last/{save_name}.pt'

# torch.save does not create missing directories.
os.makedirs(save_path + '/best', exist_ok=True)
os.makedirs(save_path + '/last', exist_ok=True)


# set seed
set_seed = 312
np.random.seed(set_seed)
torch.manual_seed(set_seed)
torch.cuda.manual_seed_all(set_seed)

# set up device
# use cuda
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"
device = torch.device(dev)
print(f'device: {dev}')


# tokenizer
if model_name == 'pretrain_Bert':
    print('tokenizer: bert-base-cased')
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased',cache_dir = customed_cache_dir)
elif model_name == 'pretrain_RoBerta':
    print('tokenizer: roberta-base')
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base',cache_dir = customed_cache_dir)
elif model_name == 'pretrain_Bart':
    print('tokenizer: bart-large')
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large',cache_dir = customed_cache_dir)
else:
    raise ValueError(f'unknown model_name: {model_name!r}')

# create SST_dataset
with open('./dataset/stanfordsentiment/ternary_dataset.json') as label_file:
    SENTIMENT_LABELS = json.load(label_file)
SST_dataset = SST_tenary_dataset(SENTIMENT_LABELS, tokenizer)


# create train and dev set (90% / 10%); the split is driven by the torch seed set above
train_size = int(0.9 * len(SST_dataset))
train_set, dev_set = random_split(SST_dataset, [train_size, len(SST_dataset) - train_size])

dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
print('train size: ', len(train_set))
print('dev size: ', len(dev_set))

# train dataloader
train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)
# dev dataloader
val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)
# dataloaders
dataloaders = {'train':train_dataloader, 'dev':val_dataloader}


# classifier with a 3-way classification head
if model_name == 'pretrain_Bert':
    model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3,cache_dir = customed_cache_dir)
elif model_name == 'pretrain_RoBerta':
    model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3,cache_dir = customed_cache_dir)
elif model_name == 'pretrain_Bart':
    model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels = 3,cache_dir = customed_cache_dir)
else:
    raise ValueError(f'unknown model_name: {model_name!r}')

model.to(device)

optimizer_step1 = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

# decay the learning rate by 10x every 10 epochs
lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=10, gamma=0.1)

criterion = nn.CrossEntropyLoss()

# train_SST returns the model reloaded with its best dev-accuracy weights
print(f'=== start training ===')
model = train_SST(dataloaders, device, model, criterion, optimizer_step1, lr_scheduler_step1, num_epochs=epoch_n, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)


Matplotlib created a temporary cache directory at /tmp/matplotlib-8ogova30 because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
There was a problem when trying to write in your cache folder (/home/jovyan/.cache/huggingface/hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.


model: pretrain_Bart
device: cuda:0
tokenizer: bart-large
Original distribution:
	Very positive: 1701
	Neutral: 2101
	Very negative: 1383
balance class to 1383 each...
train size:  3734
dev size:  415


Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['classification_head.out_proj.weight', 'classification_head.dense.bias', 'classification_head.out_proj.bias', 'classification_head.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


=== start training ===
Epoch 0/19
----------


 48%|████▊     | 56/117 [00:17<00:17,  3.48it/s]

update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 2/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.40it/s]


train Loss: 0.5698
train Acc: 0.7579


100%|██████████| 415/415 [00:11<00:00, 34.99it/s]


dev Loss: 0.5017
dev Acc: 0.7735
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 3/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.5199
train Acc: 0.7784


100%|██████████| 415/415 [00:12<00:00, 34.41it/s]


dev Loss: 0.4972
dev Acc: 0.7783
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 4/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.4724
train Acc: 0.8044


100%|██████████| 415/415 [00:11<00:00, 35.32it/s]


dev Loss: 0.4368
dev Acc: 0.8072
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 5/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.4296
train Acc: 0.8264


100%|██████████| 415/415 [00:11<00:00, 35.43it/s]


dev Loss: 0.4503
dev Acc: 0.8193
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 6/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.4176
train Acc: 0.8318


100%|██████████| 415/415 [00:11<00:00, 35.21it/s]


dev Loss: 0.4544
dev Acc: 0.8193

Epoch 7/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.4071
train Acc: 0.8317


100%|██████████| 415/415 [00:12<00:00, 34.12it/s]


dev Loss: 0.4889
dev Acc: 0.8072

Epoch 8/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.43it/s]


train Loss: 0.3719
train Acc: 0.8477


100%|██████████| 415/415 [00:11<00:00, 35.45it/s]


dev Loss: 0.4638
dev Acc: 0.8193

Epoch 9/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.3641
train Acc: 0.8517


100%|██████████| 415/415 [00:11<00:00, 35.57it/s]


dev Loss: 0.4739
dev Acc: 0.8289
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 10/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


train Loss: 0.3317
train Acc: 0.8725


100%|██████████| 415/415 [00:12<00:00, 34.14it/s]


dev Loss: 0.4713
dev Acc: 0.8265

Epoch 11/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.3434
train Acc: 0.8616


100%|██████████| 415/415 [00:11<00:00, 34.98it/s]


dev Loss: 0.4675
dev Acc: 0.8313
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 12/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


train Loss: 0.3379
train Acc: 0.8624


100%|██████████| 415/415 [00:12<00:00, 32.83it/s]


dev Loss: 0.4730
dev Acc: 0.8241

Epoch 13/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


train Loss: 0.3380
train Acc: 0.8652


100%|██████████| 415/415 [00:11<00:00, 35.37it/s]


dev Loss: 0.4709
dev Acc: 0.8289

Epoch 14/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


train Loss: 0.3208
train Acc: 0.8741


100%|██████████| 415/415 [00:11<00:00, 35.28it/s]


dev Loss: 0.4770
dev Acc: 0.8313

Epoch 15/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


train Loss: 0.3310
train Acc: 0.8648


100%|██████████| 415/415 [00:12<00:00, 34.29it/s]


dev Loss: 0.4695
dev Acc: 0.8313

Epoch 16/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.43it/s]


train Loss: 0.3314
train Acc: 0.8693


100%|██████████| 415/415 [00:11<00:00, 35.16it/s]


dev Loss: 0.4813
dev Acc: 0.8361
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 17/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.42it/s]


train Loss: 0.3428
train Acc: 0.8640


100%|██████████| 415/415 [00:11<00:00, 35.03it/s]


dev Loss: 0.4605
dev Acc: 0.8386
update best on dev checkpoint: ./checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt

Epoch 18/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


train Loss: 0.3321
train Acc: 0.8638


100%|██████████| 415/415 [00:11<00:00, 35.22it/s]


dev Loss: 0.4649
dev Acc: 0.8361

Epoch 19/19
----------


100%|██████████| 117/117 [00:34<00:00,  3.39it/s]


train Loss: 0.3317
train Acc: 0.8693


100%|██████████| 415/415 [00:12<00:00, 33.70it/s]


dev Loss: 0.4763
dev Acc: 0.8313

Best val loss: 0.460492
Best val acc: 0.838554
update last checkpoint: ./checkpoints/text_sentiment_classifier/last/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt


In [6]:
##### evaluating zero-shot classification model #####

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence 
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification
from data import ZuCo_dataset
from model_sentiment import BaselineMLPSentence, BaselineLSTM, FineTunePretrainedTwoStep, ZeroShotSentimentDiscovery, JointBrainTranslatorSentimentClassifier
from model_decoding import BrainTranslator, BrainTranslatorNaive
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score


customed_cache_dir = "/home/linal20/transformers_cache"

# Function to calculate the accuracy of our predictions vs labels
def accuracy(preds, labels):
    """Return the share of samples whose arg-max prediction matches the label.

    preds:  2-D array of per-class scores, one row per sample.
    labels: array of class ids; flattened before it is compared.
    """
    predicted_classes = np.argmax(preds, axis=1).flatten()
    expected_classes = labels.flatten()
    n_hits = np.sum(predicted_classes == expected_classes)
    n_total = len(expected_classes)
    return n_hits / n_total

def eval_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, tokenizer = None):
    """Run a single evaluation pass over dataloaders['test'] and print metrics.

    Args:
        dataloaders: mapping with a 'test' entry yielding 8-tuples
            (input_word_eeg_features, seq_lens, input_masks, input_mask_invert,
            target_ids, target_mask, sentiment_labels, sent_level_EEG).
        device: torch device the batch tensors are moved to.
        model: called as model(eeg_features, input_masks, input_mask_invert,
            target_ids, sentiment_labels); its output must expose `.logits`
            and `.loss`.
        criterion, optimizer, scheduler: unused; kept for interface
            compatibility.
        num_epochs: only used in the "Epoch i/N" banner; exactly one pass is run.
        tokenizer: tokenizer used to decode the target ids for logging and to
            look up the pad token id. When None, the facebook/bart-large
            tokenizer is loaded on first use (not at definition time).

    Returns:
        `model`, unchanged, so that `model = eval_model(...)` keeps a usable
        reference.
    """
    if tokenizer is None:
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', cache_dir = customed_cache_dir)

    since = time.time()

    best_loss = 100000000000
    best_acc = 0.0

    total_pred_labels = np.array([])
    total_true_labels = np.array([])

    for epoch in range(1):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['test']:
            model.eval()
            running_loss = 0.0
            n_correct = 0
            n_samples = 0

            # Evaluation only: no gradients are needed.
            with torch.no_grad():
                # Iterate over data.
                for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in dataloaders[phase]:

                    input_word_eeg_features = input_word_eeg_features.to(device).float()
                    input_masks = input_masks.to(device)
                    input_mask_invert = input_mask_invert.to(device)

                    sent_level_EEG = sent_level_EEG.to(device)
                    sentiment_labels = sentiment_labels.to(device)

                    target_ids = target_ids.to(device)
                    target_mask = target_mask.to(device)

                    print()
                    print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0])

                    # Replace padding ids with -100 before the forward pass
                    # (presumably so padded positions are ignored by the
                    # loss - confirm against the model implementation).
                    target_ids[target_ids == tokenizer.pad_token_id] = -100

                    output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels)
                    logits = output.logits
                    loss = output.loss

                    # per-sample predictions and labels
                    preds_cpu = logits.detach().cpu().numpy()
                    label_cpu = sentiment_labels.cpu().numpy()
                    pred_flat = np.argmax(preds_cpu, axis=1).flatten()
                    labels_flat = label_cpu.flatten()
                    batch_n = len(labels_flat)

                    n_correct += int(np.sum(pred_flat == labels_flat))
                    n_samples += batch_n

                    # add to total pred and label array, for cal F1, precision, recall
                    total_pred_labels = np.concatenate((total_pred_labels,pred_flat))
                    total_true_labels = np.concatenate((total_true_labels,labels_flat))

                    # batch loss, weighted by the number of samples in the batch
                    running_loss += loss.item() * batch_n

            # Normalise by the samples actually seen instead of relying on a
            # notebook-level `dataset_sizes` global.
            epoch_loss = running_loss / max(n_samples, 1)
            epoch_acc = n_correct / max(n_samples, 1)
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))

            if phase == 'test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_acc = epoch_acc
        print()

    time_elapsed = time.time() - since
    print('Evaluation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best test loss: {:4f}'.format(best_loss))
    print('Best test acc: {:4f}'.format(best_acc))
    print()
    print('test sample num:', len(total_pred_labels))
    print('total preds:',total_pred_labels)
    print('total truth:',total_true_labels)
    print('sklearn macro: precision, recall, F1:')
    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='macro'))
    print()
    print('sklearn micro: precision, recall, F1:')
    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='micro'))
    print()
    print('sklearn accuracy:')
    print(accuracy_score(total_true_labels,total_pred_labels))
    print()

    return model
    
    
    
if __name__ == '__main__':

    # ---- evaluation configuration ----
    num_epochs = 1

    dataset_setting = 'unique_sent'

    # model name
    model_name = 'ZeroShotSentimentDiscovery'
    print(f'[INFO] eval {model_name}')

    # decoder name
    decoder_name = 'BrainTranslator'
    decoder_checkpoint='./checkpoints/task1_all_steps/best/final.pt'
    print(f'[INFO] using decoder: {decoder_name}')

    # classifier: one of pretrain_Bert, pretrain_RoBerta, pretrain_Bart
    # classifier_name = 'pretrain_Bert'
    # classifier_name = 'pretrain_RoBerta'
    classifier_name = 'pretrain_Bart'
    # classifier_checkpoint = './checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_RoBerta_b32_20_0.0001.pt'
    # classifier_checkpoint = './checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bert_b32_20_0.0001.pt'
    classifier_checkpoint = './checkpoints/text_sentiment_classifier/best/Textbased_StanfordSentitmentTreeband_pretrain_Bart_b32_20_0.0001.pt'
    print(f'[INFO] using classifier: {classifier_name}')

    batch_size = 32

    subject_choice = 'ALL'
    print(f'![Debug]using {subject_choice}')
    eeg_type_choice = 'GD'
    print(f'[INFO]eeg type {eeg_type_choice}')
    bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
    print(f'[INFO]using bands {bands_choice}')


    # random seeds
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)


    # set device
    if torch.cuda.is_available():
        # Prefer the second GPU as before, but fall back to the first one
        # when only a single GPU is visible.
        dev = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
    else:
        dev = "cpu"
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')


    # Task1_SR_processed.pickle
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load dataset pickles from a trusted source.
    whole_dataset_dict = []
    dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/Task1_SR_processed.pickle'
    with open(dataset_path_task1, 'rb') as handle:
        whole_dataset_dict.append(pickle.load(handle))


    # set tokenizer
    decoder_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large',cache_dir = customed_cache_dir ) # Bart
    tokenizer = decoder_tokenizer

    if classifier_name == 'pretrain_Bert':
        sentiment_tokenizer = BertTokenizer.from_pretrained('bert-base-cased',cache_dir=customed_cache_dir) # Bert
    elif classifier_name == 'pretrain_Bart':
        sentiment_tokenizer = decoder_tokenizer
    elif classifier_name == 'pretrain_RoBerta':
        sentiment_tokenizer = RobertaTokenizer.from_pretrained('roberta-base',cache_dir=customed_cache_dir)
    else:
        raise ValueError(f'unknown classifier_name: {classifier_name!r}')

    print(f'[INFO]Model: ZeroShotSentimentDiscovery, using classifer:{classifier_name}, using generator: {decoder_name}')

    pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large',cache_dir = customed_cache_dir)
    decoder = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)

    # Load onto CPU first so the checkpoint does not depend on the GPU it was
    # saved from; the assembled model is moved to `device` below.
    # NOTE(review): strict=False silently ignores missing/unexpected keys.
    decoder.load_state_dict(torch.load(decoder_checkpoint, map_location='cpu'), strict=False)

    if classifier_name == 'pretrain_Bert':
        classifier = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3,cache_dir=customed_cache_dir)
    elif classifier_name == 'pretrain_Bart':
        classifier = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3,cache_dir=customed_cache_dir)
    elif classifier_name == 'pretrain_RoBerta':
        classifier = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3,cache_dir=customed_cache_dir)
    else:
        raise ValueError(f'unknown classifier_name: {classifier_name!r}')

    classifier.load_state_dict(torch.load(classifier_checkpoint, map_location='cpu'))

    model = ZeroShotSentimentDiscovery(decoder, classifier, decoder_tokenizer, sentiment_tokenizer, device = device)
    model.to(device)

    # set up dataloader
    # test dataset
    test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = 'unique_sent')

    dataset_sizes = {'test': len(test_set)}
    print('[INFO]test_set size: ', len(test_set))

    # feed into dataloaders (one sentence per batch)
    test_dataloader = DataLoader(test_set, batch_size = 1, shuffle=False, num_workers=4)
    dataloaders = {'test':test_dataloader}

    # optimizer and scheduler are not needed for evaluation
    optimizer_step1 = None
    lr_scheduler_step1 = None

    # loss function
    criterion = nn.CrossEntropyLoss()

    print('=== start evaluating ... ===')

    # evaluate model; the result is not assigned back so `model` keeps
    # pointing at the assembled pipeline
    eval_model(dataloaders, device, model, criterion, optimizer_step1, lr_scheduler_step1, num_epochs=num_epochs, tokenizer = tokenizer)


[INFO] eval ZeroShotSentimentDiscovery
[INFO] using decoder: BrainTranslator
[INFO] using classifier: pretrain_Bart
![Debug]using ALL
[INFO]eeg type GD
[INFO]using bands ['_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2']
[INFO]using device cuda:1
[INFO]Model: ZeroShotSentimentDiscovery, using classifer:pretrain_Bart, using generator: BrainTranslator


Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['classification_head.out_proj.weight', 'classification_head.dense.bias', 'classification_head.out_proj.bias', 'classification_head.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[INFO]loading 1 task datasets
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train divider = 320
dev divider = 360
[INFO]initializing a test set...
++ adding task to dataset, now we have: 456
[INFO]input tensor size: torch.Size([56, 840])

[INFO]test_set size:  456
=== start evaluating ... ===
Epoch 0/0
----------

target string: <s>Everything its title implies, a standard-issue crime drama spat out from the Tinseltown assembly line.
predict string:  the a.- the. the. the the's's's. the's the the the's. the.'s's's.. the's the's's the's the's's.'s. the the's the the

target string: <s>This odd, poetic road movie, spiked by jolts of pop music, pretty much takes place in Morton's ever-watchful gaze -- and it's a tribute to the actress, and to her inventive director, that the journey is such a mesmerizing one.
predict string: .. the. the..

target string: <s>Co-writer/director Jonathan Parker's attempts to fashion a Brazil-like, 

  _warn_prf(average, modifier, msg_start, len(result))
