In [1]:
import pickle
import numpy as np
from config import get_config

In [2]:
from transformers import BertLMHeadModel, BartTokenizer
from data import ZuCo_dataset

In [3]:
task_name = "task1, task2, taskNRv2"

In [4]:
''' set up dataloader '''
whole_dataset_dicts = []
if 'task1' in task_name:
    dataset_path_task1 = r'C:\Users\gxb18167\PycharmProjects\EEG-To-Text\dataset\ZuCo\task1-SR\pickle\task1-SR-dataset.pickle'
    with open(dataset_path_task1, 'rb') as handle:
        whole_dataset_dicts.append(pickle.load(handle))
if 'task2' in task_name:
    dataset_path_task2 = r'C:\Users\gxb18167\PycharmProjects\EEG-To-Text\\dataset\ZuCo\task2-NR\pickle\task2-NR-dataset.pickle'
    with open(dataset_path_task2, 'rb') as handle:
        whole_dataset_dicts.append(pickle.load(handle))
if 'task3' in task_name:
    dataset_path_task3 = r'C:\Users\gxb18167\PycharmProjects\EEG-To-Text\\dataset\ZuCo\task3-TSR\pickle\task3-TSR-dataset.pickle'
    with open(dataset_path_task3, 'rb') as handle:
        whole_dataset_dicts.append(pickle.load(handle))
if 'taskNRv2' in task_name:
    dataset_path_taskNRv2 = r'C:\Users\gxb18167\PycharmProjects\EEG-To-Text\dataset\ZuCo\task2-NR-2.0\pickle\task2-NR-2.0-dataset.pickle'
    with open(dataset_path_taskNRv2, 'rb') as handle:
        whole_dataset_dicts.append(pickle.load(handle))

In [5]:
print("Loaded in", len(whole_dataset_dicts), "task datasets")

Loaded in 3 task datasets


In [6]:
Task_Dataset_List = whole_dataset_dicts
if not isinstance(whole_dataset_dicts,list):
    Task_Dataset_List = [whole_dataset_dicts]

In [7]:


def get_eeg_word_embedding(word, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']):
    EEG_frequency_features = []
    EEG_word_level_label = word['content']
    for band in bands:
        EEG_frequency_features.append(word['word_level_EEG'][eeg_type][eeg_type+band])
    word_eeg_embedding = np.concatenate(EEG_frequency_features)
    if len(word_eeg_embedding) != 105*len(bands):
        print(f'expect word eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None')
        word_eeg_embedding = None
    else:
        word_eeg_embedding = word_eeg_embedding.reshape(105, 8)

    return word_eeg_embedding, EEG_word_level_label





In [17]:
import torch

#Main loop, looping through each task
for Task_Dataset in Task_Dataset_List:
    subjects = list(Task_Dataset.keys())
    print('[INFO]using subjects: ', subjects)

    total_num_sentence = len(Task_Dataset[subjects[0]])

    train_divider = int(0.8*total_num_sentence)
    dev_divider = train_divider + int(0.1*total_num_sentence)

    print(f'train size = {train_divider}')
    print(f'dev size = {dev_divider}')

    EEG_word_level_embeddings = []
    EEG_word_level_labels = []
    print('[INFO]initializing a train set...')
    for key in subjects:
        print(f'key = {key}')
        for i in range(train_divider):
            if Task_Dataset[key][i] is not None:
                sentence_object = Task_Dataset[key][i]
                for word in sentence_object['word']:
                    word_eeg_embedding, EEG_word_level_label = get_eeg_word_embedding(word)
                    if word_eeg_embedding is not None and torch.isnan(torch.from_numpy(word_eeg_embedding)).any() == False:
                        print("True")
                        EEG_word_level_embeddings.append(word_eeg_embedding)
                        EEG_word_level_labels.append(EEG_word_level_label)



[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZMG', 'ZPH']
train size = 320
dev size = 360
[INFO]initializing a train set...
key = ZAB
key = ZDM
key = ZDN
key = ZGW
key = ZJM
key = ZJN
key = ZJS
key = ZKB
key = ZKH
key = ZMG
key = ZPH
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train size = 240
dev size = 270
[INFO]initializing a train set...
key = ZAB
key = ZDM
key = ZDN
key = ZGW
key = ZJM
key = ZJN
key = ZJS
key = ZKB
key = ZKH
key = ZKW
key = ZMG
key = ZPH
[INFO]using subjects:  ['YAC', 'YAG', 'YAK', 'YDG', 'YDR', 'YFR', 'YFS', 'YHS', 'YIS', 'YLS', 'YMD', 'YMS', 'YRH', 'YRK', 'YRP', 'YSD', 'YSL', 'YTL']
train size = 279
dev size = 313
[INFO]initializing a train set...
key = YAC
expect word eeg embedding dim to be 840, but got 0, return None
expect word eeg embedding dim to be 840, but got 0, return None
expect word eeg embedding dim to be 840, but got 0, return None
expect word

In [18]:
train_data = []
for i in range(len(EEG_word_level_embeddings)):
   train_data.append([EEG_word_level_embeddings[i], EEG_word_level_labels[i]])

In [21]:
train_data[0]

tensor(False)

In [22]:
# Save the lists to a file using pickle
with open('EEG_Text_Pairs.pkl', 'wb') as file:
    pickle.dump(EEG_word_level_embeddings, file)
    pickle.dump(EEG_word_level_labels, file)

# To load the lists from the file:
with open('EEG_Text_Pairs.pkl', 'rb') as file:
    EEG_word_level_embeddings = pickle.load(file)
    EEG_word_level_labels = pickle.load(file)

[array([[0.09515667, 0.45813037, 0.30879774, 0.48925483, 0.4663575 ,
         0.47793435, 0.08888531, 0.28408581],
        [0.34188234, 0.48905038, 0.7512639 , 1.65633442, 0.43332053,
         0.39021828, 0.42390918, 0.42640141],
        [0.21606035, 0.31481223, 0.50272886, 0.24707651, 0.38156306,
         0.41565591, 0.88719495, 0.86575871],
        [1.02048984, 0.58520356, 0.83853885, 1.01260253, 0.95925589,
         0.96508799, 0.18278719, 1.41046171],
        [1.40865462, 1.17285938, 1.03840129, 0.07390626, 0.05391974,
         1.59567461, 1.20218847, 1.16447783],
        [1.59421139, 1.21012559, 1.154855  , 1.1757874 , 1.64789178,
         1.0209493 , 1.75032681, 1.51349833],
        [1.46716506, 1.55225766, 1.53080869, 2.53303736, 1.55151326,
         1.93549963, 1.69098489, 1.9197302 ],
        [1.76540461, 2.14217345, 2.05729949, 2.13237491, 2.07070855,
         3.52763658, 2.27397441, 1.96421766],
        [1.58608133, 1.37231919, 1.02504781, 2.76136217, 3.10373409,
         2.

In [23]:
import torch

trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=64)

In [13]:
#sant