In [1]:
import os
import numpy as np
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
import json
import matplotlib.pyplot as plt
from glob import glob
from transformers import BartTokenizer
from tqdm import tqdm
from torch.utils.data import Dataset
from datasets import load_metric
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import sys
sys.path.append('C:/Users/acer/Desktop/IISc/EEG2text/src')
from dataclasses import dataclass
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.data import Dataset, DataLoader
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch
import numpy as np
from torch import optim
import logging
from datetime import datetime
from config import ModelConfig_wordlevel
from models import *
from transformers import BartTokenizer
import pickle
from common.utils.data import *
import warnings


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ZuCo_dataset(Dataset):
    """Sentence-level ZuCo EEG-to-text dataset assembled from one or more task dicts.

    Each task dict maps a subject key (e.g. ``'ZAB'``) to a list of per-sentence
    records. Every selected record is converted with ``get_input_sample`` (made
    available by the notebook's star import of ``common.utils.data``). A record
    is kept only when that call does not return ``None``.

    Split settings:
      * ``'unique_sent'``: the first 80% of sentence indices go to train, the
        next 10% to dev, and the remainder to test. The cut points are computed
        from the first selected subject's sentence count and applied to every
        selected subject.
      * ``'unique_subj'``: every sentence index is used, but each split draws
        from a fixed subject list (``_UNIQUE_SUBJ_SPLIT``). The original author
        notes this is only implemented for the SR v1 dataset.
    """

    # Band suffixes forwarded to get_input_sample when ``bands`` is not given.
    _DEFAULT_BANDS = ('_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2')
    # Fixed subject split used by the 'unique_subj' setting.
    _UNIQUE_SUBJ_SPLIT = {
        'train': ('ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'),
        'dev': ('ZMG',),
        'test': ('ZPH',),
    }
    _PHASES = ('train', 'dev', 'test')
    _SETTINGS = ('unique_sent', 'unique_subj')

    def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = None, setting = 'unique_sent', is_add_CLS_token = False, max_len : int= None):
        """Build the requested split.

        Args:
            input_dataset_dicts: a task dict, or a list of task dicts, each mapping
                subject key -> list of per-sentence records.
            phase: 'train', 'dev' or 'test'.
            tokenizer: forwarded to ``get_input_sample`` and stored on ``self``.
            subject: 'ALL' to use every key of each task dict, otherwise a single
                subject key. Under 'unique_subj' it only determines the sentence
                count; the split's subjects come from ``_UNIQUE_SUBJ_SPLIT``.
            eeg_type: forwarded to ``get_input_sample``.
            bands: forwarded to ``get_input_sample``. ``None`` means a fresh list
                of the eight suffixes in ``_DEFAULT_BANDS``.
            setting: 'unique_sent' or 'unique_subj'.
            is_add_CLS_token: forwarded to ``get_input_sample`` as ``add_CLS_token``.
            max_len: if not ``None``, ``__getitem__`` truncates or zero-pads
                ``raw_eeg`` along axis 1 to exactly this length.

        Raises:
            ValueError: if ``phase`` or ``setting`` is not recognised.
        """
        if phase not in self._PHASES:
            raise ValueError(f"phase must be one of {self._PHASES}, got {phase!r}")
        if setting not in self._SETTINGS:
            raise ValueError(f"setting must be one of {self._SETTINGS}, got {setting!r}")
        if bands is None:
            # Build a fresh list per call instead of sharing a mutable default argument.
            bands = list(self._DEFAULT_BANDS)

        self.inputs = []
        self.tokenizer = tokenizer
        self.max_len = max_len

        if not isinstance(input_dataset_dicts, list):
            input_dataset_dicts = [input_dataset_dicts]
        print(f'[INFO]loading {len(input_dataset_dicts)} task datasets')

        for input_dataset_dict in input_dataset_dicts:
            if subject == 'ALL':
                subjects = list(input_dataset_dict.keys())
                print('[INFO]using subjects: ', subjects)
            else:
                subjects = [subject]

            if not subjects:
                # Nothing to size the split from; indexing subjects[0] would raise IndexError.
                print('[WARNING]skipping a task dataset with no subjects')
                continue

            total_num_sentence = len(input_dataset_dict[subjects[0]])
            print("total number of sentences is, ", total_num_sentence)

            # int() floors both cut points, so the test split receives any remainder.
            train_divider = int(0.8*total_num_sentence)
            dev_divider = train_divider + int(0.1*total_num_sentence)

            print(f'train divider = {train_divider}')
            print(f'dev divider = {dev_divider}')

            if setting == 'unique_sent':
                # first 80% -> train, next 10% -> dev, remainder -> test
                start, stop = {
                    'train': (0, train_divider),
                    'dev': (train_divider, dev_divider),
                    'test': (dev_divider, total_num_sentence),
                }[phase]
                print(f'[INFO]initializing a {phase} set...')
                # subject-major order, as in the original nested loops
                for key in subjects:
                    for i in range(start, stop):
                        self._add_sample(input_dataset_dict[key][i], eeg_type, bands, is_add_CLS_token)
            else:  # 'unique_subj'
                print('WARNING!!! only implemented for SR v1 dataset ')
                print(f'[INFO]initializing a {phase} set using {setting} setting...')
                split_subjects = self._UNIQUE_SUBJ_SPLIT[phase]
                # sentence-major order, as in the original nested loops
                for i in range(total_num_sentence):
                    for key in split_subjects:
                        self._add_sample(input_dataset_dict[key][i], eeg_type, bands, is_add_CLS_token)

        print()

    def _add_sample(self, sentence_record, eeg_type, bands, add_CLS_token):
        """Convert one sentence record via get_input_sample and keep it unless it is None."""
        input_sample = get_input_sample(sentence_record, self.tokenizer, eeg_type, bands = bands, add_CLS_token = add_CLS_token)
        if input_sample is not None:
            self.inputs.append(input_sample)

    def __len__(self):
        """Number of kept samples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return sample ``idx`` as an 8-tuple.

        Order: (raw_eeg, target_string, target_ids, target_mask,
        input_embeddings, seq_len, input_attn_mask, input_attn_mask_invert).
        When ``max_len`` is set, ``raw_eeg`` is truncated or right-padded with
        zeros along axis 1 to exactly ``max_len`` columns.
        """
        input_sample = self.inputs[idx]

        raw_eeg = input_sample['raw_eeg']
        if self.max_len is not None:
            # assumes raw_eeg is a 2-D array with time along axis 1 -- TODO confirm
            if raw_eeg.shape[1] > self.max_len:
                raw_eeg = raw_eeg[:, :self.max_len]
            elif raw_eeg.shape[1] < self.max_len:
                pad_width = ((0, 0), (0, self.max_len - raw_eeg.shape[1]))
                raw_eeg = np.pad(raw_eeg, pad_width, mode='constant', constant_values=0)
        return (
            raw_eeg,
            input_sample['target_string'],
            input_sample['target_ids'],
            input_sample['target_mask'],
            input_sample['input_embeddings'],  # word-level features (author's note: 32 * 840 -- TODO confirm)
            input_sample['seq_len'],
            input_sample['input_attn_mask'],
            input_sample['input_attn_mask_invert'],
        )

In [3]:
# BART tokenizer used to encode target sentences.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

# Dataset options shared by the split-construction cell below.
dataset_setting = 'unique_sent'   # 80/10/10 split over sentence indices
subject_choice = 'ALL'            # every subject key present in the pickle
eeg_type_choice = 'GD'
bands_choice = ['_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2']

config = ModelConfig_wordlevel()

In [4]:
print('Creating whole_dataset_dicts...')
whole_dataset_dicts = []

# Single task pickle, resolved relative to the kernel's working directory.
dataset_path_task1 = 'NR-1.0'
# NOTE(review): pickle.load can execute arbitrary code -- only open trusted files.
with open(dataset_path_task1, 'rb') as task1_file:
    task1_dict = pickle.load(task1_file)
whole_dataset_dicts.append(task1_dict)

print('whole_dataset_dicts created')

Creating whole_dataset_dicts...
whole_dataset_dicts created


In [5]:
# All three splits share the same options; only the phase differs.
_split_kwargs = dict(
    subject=subject_choice,
    eeg_type=eeg_type_choice,
    bands=bands_choice,
    setting=dataset_setting,
    max_len=config.time_len,
)
train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, **_split_kwargs)
dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, **_split_kwargs)
test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, **_split_kwargs)


[INFO]loading 1 task datasets
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train divider = 240
dev divider = 270
[INFO]initializing a train set...


  normalized_data = (eeg_data - channel_means) / channel_stds



[INFO]loading 1 task datasets
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train divider = 240
dev divider = 270
[INFO]initializing a dev set...

[INFO]loading 1 task datasets
[INFO]using subjects:  ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH']
train divider = 240
dev divider = 270
[INFO]initializing a test set...



In [27]:
# Unshuffled loader over the training split, used below to inspect batches in order.
train_loader = DataLoader(train_set, batch_size=2, shuffle=False, drop_last=False)

In [33]:
# Each batch from train_loader is the 8-tuple returned by ZuCo_dataset.__getitem__,
# collated by the default DataLoader collate (e.g. seq_len arrives as a
# tensor with one entry per sample in the batch):
#   0 raw_eeg            1 target_string     2 target_ids       3 target_mask
#   4 input_embeddings   5 seq_len           6 input_attn_mask  7 input_attn_mask_invert
SEQ_LEN_FIELD = 5

for batch in train_loader:
    print(batch[SEQ_LEN_FIELD])

tensor([12, 19])
tensor([14, 28])
tensor([19, 22])
tensor([12, 18])
tensor([32, 23])
tensor([ 9, 16])
tensor([24, 27])
tensor([10, 35])
tensor([14,  7])
tensor([18, 18])
tensor([13, 11])
tensor([21, 12])
tensor([12, 14])
tensor([37, 11])
tensor([11, 38])
tensor([13, 21])
tensor([11,  7])
tensor([ 9, 11])
tensor([21,  7])
tensor([11, 14])
tensor([14,  6])
tensor([14, 17])
tensor([16, 18])
tensor([13, 15])
tensor([5, 7])
tensor([11, 13])
tensor([20, 10])
tensor([13, 14])
tensor([18,  6])
tensor([37, 14])
tensor([26, 33])
tensor([19, 10])
tensor([30,  7])
tensor([32, 10])
tensor([15, 40])
tensor([18,  8])
tensor([11, 10])
tensor([16, 13])
tensor([12, 12])
tensor([12,  8])
tensor([11,  9])
tensor([18, 17])
tensor([ 5, 10])
tensor([18,  9])
tensor([6, 8])
tensor([13, 14])
tensor([13,  8])
tensor([12,  6])
tensor([16,  9])
tensor([11, 17])
tensor([14, 17])
tensor([10, 15])
tensor([10,  9])
tensor([14,  3])
tensor([5, 9])
tensor([17,  5])
tensor([11, 31])
tensor([12, 10])
tensor([16, 12])
ten

In [None]:
# Notes (kept as comments so the cell runs under Restart & Run All):
# word 5
# EEG word level feature 5*8*105
#   -> presumably 5 words x 8 bands x 105 channels per word (TODO confirm)


In [None]:
# Note (kept as a comment; as bare code `BS` raises NameError):
# BS,56*840
#   -> presumably a batch shaped (batch_size, 56, 840), with 840 = 8 * 105 (TODO confirm)