In [1]:
import pickle
import numpy as np
from config import get_config

In [2]:
from transformers import BertLMHeadModel, BartTokenizer
from data import ZuCo_dataset

In [3]:
# Comma-separated list of ZuCo tasks requested for this run.
# NOTE(review): only taskNRv2 is actually loaded by the dataset-loading cell below;
# the task1/task2/task3 loaders there are disabled, so task1 and task2 are ignored.
task_name = "task1, task2, taskNRv2"

In [4]:
# Set up the dataloader inputs: load the pickled ZuCo dataset of each requested task.

ZUCO_ROOT = r'I:\Science\CIS-YASHMOSH\niallmcguire\ZuCo'

# Pickled dataset per task; datasets are appended to whole_dataset_dicts in this order.
TASK_DATASET_PATHS = {
    'task1': ZUCO_ROOT + r'\task1-SR\pickle\task1-SR-dataset.pickle',
    'task2': ZUCO_ROOT + r'\task2-NR\pickle\task2-NR-dataset.pickle',
    'task3': ZUCO_ROOT + r'\task3-TSR\pickle\task3-TSR-dataset.pickle',
    'taskNRv2': ZUCO_ROOT + r'\task2-NR-2.0\pickle\task2-NR-2.0-dataset.pickle',
}

# task1-3 were disabled in this notebook (their loading code was commented out), so only
# taskNRv2 is loaded even when task_name requests the others. Add a task here to enable it.
ENABLED_TASKS = {'taskNRv2'}

# Exact task-name matching: a substring test would let e.g. 'task12' also match 'task1'.
requested_tasks = {name.strip() for name in task_name.split(',') if name.strip()}

whole_dataset_dicts = []
for task, dataset_path in TASK_DATASET_PATHS.items():
    if task not in requested_tasks:
        continue
    if task not in ENABLED_TASKS:
        print(f'[WARN]task {task!r} requested but disabled; skipping')
        continue
    # NOTE: pickle.load can execute arbitrary code - only load trusted, locally produced files.
    with open(dataset_path, 'rb') as handle:
        whole_dataset_dicts.append(pickle.load(handle))


In [5]:
# Report how many task datasets the loading cell picked up.
n_loaded_tasks = len(whole_dataset_dicts)
print(f"Loaded in {n_loaded_tasks} task datasets")

Loaded in 1 task datasets


In [6]:
# Normalise to a list of per-task dataset dicts so later cells can always iterate over it.
Task_Dataset_List = (
    whole_dataset_dicts
    if isinstance(whole_dataset_dicts, list)
    else [whole_dataset_dicts]
)

In [7]:
# BART tokenizer plus the 'train' split of ZuCo_dataset built from all loaded task datasets.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
train_set = ZuCo_dataset(
    whole_dataset_dicts,
    'train',
    tokenizer,
    subject='ALL',
    eeg_type='GD',
    bands=['_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2'],
    setting='unique_sent',
    is_add_CLS_token=False,
)

[INFO]loading 1 task datasets
[INFO]using subjects:  ['YAC', 'YAG', 'YAK', 'YDG', 'YDR', 'YFR', 'YFS', 'YHS', 'YIS', 'YLS', 'YMD', 'YMS', 'YRH', 'YRK', 'YRP', 'YSD', 'YSL', 'YTL']
train divider = 279
dev divider = 313
[INFO]initializing a train set...
['Henry', 'Ford,', 'with', 'son', 'Edsel,', 'founded', 'Ford', 'Foundation', 'in', '1936', 'local', 'philanthropic', 'organization', 'broad', 'charter', 'promote', 'human', 'welfare.']
['this', 'initial', 'success,', 'Ford', 'left', 'Edison', 'Illuminating', 'and,', 'with', 'other', 'investors,', 'formed', 'Detroit', 'Automobile', 'Company.']
['With', 'his', 'interest', 'race', 'cars,', 'formed', 'second', 'company,', 'Henry', 'Ford', 'Company.']
['During', 'this', 'period,', 'personally', 'drove', 'his', 'Quadricycle', 'victory', 'a', 'race', 'against', 'Alexander', 'Winton,', 'well-known', 'driver', 'heavy', 'favorite', 'on', 'October', '10,', '1901.']
['Ford', 'was', 'forced', 'of', 'the', 'company', 'by', 'investors,', 'including', 'H

In [10]:
# Peek at the last training sample; its final element is displayed (a list of words).
# Named `last_sample` rather than `set`: rebinding `set` shadowed the builtin and broke the
# `set()` call in the unique-word cell below on Restart & Run All.
last_sample = train_set[-1]
last_sample[-1]

['received',
 'his',
 'degree',
 '1965',
 "master's",
 'degree',
 'political',
 'science',
 '1966',
 'both',
 'University',
 'Wyoming.']

In [9]:


def get_eeg_word_embedding(word, eeg_type = 'GD', bands = ('_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2')):
    """Build the word-level EEG feature matrix and its text label for one word record.

    Args:
        word: word record; reads ``word['content']`` (the label) and, for each band,
            ``word['word_level_EEG'][eeg_type][eeg_type + band]``.
        eeg_type: feature-family key, e.g. 'GD'.
        bands: band suffixes to stack, in column order.

    Returns:
        ``(embedding, label)``. ``embedding`` is a ``(105, len(bands))`` array whose column
        ``j`` holds the 105 values of ``bands[j]`` (presumably one per electrode - TODO
        confirm). It is ``None`` (and a message is printed) when the per-band features do
        not concatenate to ``105 * len(bands)`` values.
    """
    EEG_word_level_label = word['content']
    n_bands = len(bands)
    EEG_frequency_features = [word['word_level_EEG'][eeg_type][eeg_type + band] for band in bands]
    word_eeg_embedding = np.concatenate(EEG_frequency_features)
    if len(word_eeg_embedding) != 105 * n_bands:
        print(f'expect word eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None')
        return None, EEG_word_level_label
    # The concatenation is band-major: [band0 x105, band1 x105, ...]. Reshape to
    # (n_bands, 105) and transpose so row i holds value i of every band. The former
    # reshape(105, 8) mixed values of different bands within a row, and its hard-coded 8
    # raised for any other number of bands.
    word_eeg_embedding = np.ascontiguousarray(word_eeg_embedding.reshape(n_bands, 105).T)
    return word_eeg_embedding, EEG_word_level_label





In [None]:

# Print the number of unique words in each task dataset, across all subjects.
for Task_Dataset in Task_Dataset_List:
    subjects = list(Task_Dataset.keys())
    print('[INFO]using subjects: ', subjects)
    if not subjects:
        continue
    total_num_sentence = len(Task_Dataset[subjects[0]])
    print(f'[INFO]total number of sentences = {total_num_sentence}')
    # A set comprehension rather than `set()`, so this cell keeps working even if another
    # cell rebound the name `set`. Sentence entries that are None are skipped.
    unique_words = {
        word['content']
        for key in subjects
        for i in range(total_num_sentence)
        if Task_Dataset[key][i] is not None
        for word in Task_Dataset[key][i]['word']
    }
    print(f'[INFO]total number of unique words = {len(unique_words)}')

In [None]:
import torch

# Build word-level (EEG embedding, word label) pairs from the train split of every task.
# The accumulators are created once, before the task loop, so pairs from all loaded tasks
# are kept; previously they were reset per task and only the last task's pairs survived.
EEG_word_level_embeddings = []
EEG_word_level_labels = []

#Main loop, looping through each task
for Task_Dataset in Task_Dataset_List:
    subjects = list(Task_Dataset.keys())
    print('[INFO]using subjects: ', subjects)
    if not subjects:
        continue

    total_num_sentence = len(Task_Dataset[subjects[0]])

    # Split boundaries over sentence indices: [0, train_divider) is train;
    # [train_divider, dev_divider) is presumably dev. Only the train range is used here.
    train_divider = int(0.8*total_num_sentence)
    dev_divider = train_divider + int(0.1*total_num_sentence)

    # These are boundaries, not sizes (the old "dev size" label was misleading).
    print(f'train divider = {train_divider}')
    print(f'dev divider = {dev_divider}')

    print('[INFO]initializing a train set...')
    for key in subjects:
        print(f'key = {key}')
        for i in range(train_divider):
            sentence_object = Task_Dataset[key][i]
            if sentence_object is None:
                continue
            for word in sentence_object['word']:
                word_eeg_embedding, EEG_word_level_label = get_eeg_word_embedding(word)
                # Keep only well-formed embeddings that contain no NaNs.
                if word_eeg_embedding is not None and not np.isnan(word_eeg_embedding).any():
                    EEG_word_level_embeddings.append(word_eeg_embedding)
                    EEG_word_level_labels.append(EEG_word_level_label)



In [None]:
# Summarise the extracted training words: total samples and number of distinct labels.
# (A bare `len(...)` that is not the last line of a cell is evaluated and discarded, so print it.)
print(f'[INFO]total number of word-level samples = {len(EEG_word_level_labels)}')
#count unique items in list
unique, counts = np.unique(EEG_word_level_labels, return_counts=True)
print(f'[INFO]number of unique word labels = {len(unique)}')


In [None]:
# Pair each EEG embedding with its word label: train_data[i] == [embedding_i, label_i].
# This cell used to be a no-op string literal, which left `train_data` undefined for the
# DataLoader cell (NameError).
assert len(EEG_word_level_embeddings) == len(EEG_word_level_labels), 'embeddings and labels are misaligned'
train_data = [
    [embedding, label]
    for embedding, label in zip(EEG_word_level_embeddings, EEG_word_level_labels)
]

In [None]:
# Optional on-disk cache of the extracted (embedding, label) lists. Disabled by default,
# matching the previously commented-out code; set SAVE_EEG_TEXT_PAIRS = True to write it.
# To reuse the file, read it back with two pickle.load calls in the same order
# (embeddings first, then labels).
EEG_TEXT_PAIRS_PATH = 'EEG_Text_Pairs.pkl'
SAVE_EEG_TEXT_PAIRS = False

if SAVE_EEG_TEXT_PAIRS:
    with open(EEG_TEXT_PAIRS_PATH, 'wb') as file:
        pickle.dump(EEG_word_level_embeddings, file)
        pickle.dump(EEG_word_level_labels, file)
# NOTE: pickle.load executes arbitrary code - only load a file this notebook wrote itself.

In [None]:
import torch

# Shuffled mini-batches of 64 training samples; train_data is expected to hold
# [embedding, label] pairs built by the pairing cell above.
trainloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)

In [None]:
#sant