In [1]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import ConcatDataset, Subset

# Make the project-local `lib` package importable before importing from it.
sys.path.append("/ocean/projects/cis240129p/soederha/silent_speech")
from lib import BrennanDataset

# Root of the EEG "Alice" dataset. Defaults to the shared cluster location;
# set the EEG_ALICE_DIR environment variable to run from another machine.
base_dir = Path(
    os.environ.get("EEG_ALICE_DIR", "/ocean/projects/cis240129p/shared/data/eeg_alice")
)

# Subjects to load. 'S05' is deliberately excluded: it has fewer channels.
subjects_used = ["S04", "S13", "S19"]

In [2]:
def create_datasets(subjects, base_dir, train_frac=0.8):
    """
    Build one train/test split per subject and return them as parallel lists.

    Each subject's ``BrennanDataset`` is split contiguously by index: the first
    ``int(len(dataset) * train_frac)`` samples form the train subset and the
    remaining samples form the test subset. No shuffling is done, so the test
    subset is always the last samples by dataset index (presumably the end of
    each recording -- TODO confirm sample order in BrennanDataset).

    Args:
        subjects: iterable of subject ids (e.g. ``"S04"``), passed as ``idx``.
        base_dir: dataset root, ``str`` or ``Path``. ``base_dir / "phonemes"``
            and ``base_dir / "phoneme_dict.txt"`` are passed to BrennanDataset.
        train_frac: fraction of each subject's samples used for training, in
            ``[0, 1]``. Defaults to 0.8, the previously hard-coded ratio.

    Returns:
        ``(train_datasets, test_datasets)``: two lists of ``Subset`` objects,
        one entry per subject, in the order given by ``subjects``.

    Raises:
        ValueError: if ``train_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac!r}")

    # Accept plain strings as well as Path objects (the `/` joins below need Path).
    base_dir = Path(base_dir)

    train_datasets = []
    test_datasets = []
    for subject in subjects:
        dataset = BrennanDataset(
            root_dir=base_dir,
            phoneme_dir=base_dir / "phonemes",
            idx=subject,
            phoneme_dict_path=base_dir / "phoneme_dict.txt",
        )
        num_data_points = len(dataset)

        # Contiguous split per subject: head -> train, tail -> test.
        split_index = int(num_data_points * train_frac)
        train_datasets.append(Subset(dataset, list(range(split_index))))
        test_datasets.append(
            Subset(dataset, list(range(split_index, num_data_points)))
        )
    return train_datasets, test_datasets


# Per-subject splits first, then pool each side across all subjects.
train_ds, test_ds = create_datasets(subjects_used, base_dir)
train_dataset, test_dataset = (ConcatDataset(parts) for parts in (train_ds, test_ds))
print(
    f"Train dataset length: {len(train_dataset)}, Test dataset length: {len(test_dataset)}"
)

Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S04.vhdr...
Setting channel info structure...
Reading 0 ... 368449  =      0.000 ...   736.898 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S13.vhdr...
Setting channel info structure...
Reading 0 ... 368274  =      0.000 ...   736.548 secs...
Extracting parameters from /ocean/projects/cis240129p/shared/data/eeg_alice/S19.vhdr...
Setting channel info structure...
Reading 0 ... 373374  =      0.000 ...   746.748 secs...
Train dataset length: 5109, Test dataset length: 1278


In [3]:
def collate_fn(batch, padding_value=0.0):
    """
    Collate a list of per-sample dicts into one batch dict.

    For every key of the first sample:
      * numpy arrays are converted to tensors, item by item, so a batch that
        mixes ndarrays and tensors is handled;
      * tensors with at least one dimension are padded at the end of their
        first dimension with ``padding_value``. They are then combined into a
        ``(batch, max_len, ...)`` tensor; all trailing dimensions must already
        match;
      * 0-d tensors/arrays are stacked into a ``(batch,)`` tensor;
      * any other value (strings, lists, Python scalars, ...) is returned as a
        plain list in batch order.

    Args:
        batch: list of dicts; every dict must contain the keys of ``batch[0]``.
        padding_value: fill value for padded positions. Defaults to 0.0, the
            same default as ``pad_sequence``. Original sequence lengths are
            NOT returned, so padding cannot be told apart from real values
            equal to ``padding_value``.

    Returns:
        dict mapping each key to its collated value; ``{}`` for an empty batch.
    """
    if not batch:
        return {}

    batch_dict = {}
    for key in batch[0].keys():
        batch_items = [item[key] for item in batch]
        if isinstance(batch_items[0], (np.ndarray, torch.Tensor)):
            # Convert per item rather than based on the first item's type only.
            tensors = [
                torch.tensor(b) if isinstance(b, np.ndarray) else b
                for b in batch_items
            ]
            if tensors[0].dim() > 0:
                batch_dict[key] = torch.nn.utils.rnn.pad_sequence(
                    tensors, batch_first=True, padding_value=padding_value
                )
            else:
                batch_dict[key] = torch.stack(tensors)
        else:
            batch_dict[key] = batch_items

    return batch_dict


# Loader settings. The seeded generator makes the training shuffle order
# reproducible across Restart & Run All.
# (The "dataloder" spelling is kept because later cells refer to these names.)
LOADER_SEED = 0
BATCH_SIZE = 2

train_dataloder = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=1,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

test_dataloder = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=1,
    shuffle=False,  # fixed order for evaluation
    collate_fn=collate_fn,
)

In [23]:
# One raw sample, kept for interactive inspection of a single item.
item = train_dataset[0]

# Preview the word labels of the first few (shuffled) training batches.
n_label_batches = 6
for batch_idx, batch in enumerate(train_dataloder):
    print(batch['label'])
    # Stop after the last wanted batch so no extra batch is loaded.
    if batch_idx + 1 >= n_label_batches:
        break

['feel', 'wind']
['fallen', 'a']
['the', 'time']
['to', 'it']
['bat', 'getting']
['center', 'It']


In [16]:
# Sanity-check the loader output. Note: this iterates the TRAINING loader
# (train_dataloder); swap in test_dataloder to inspect held-out batches.
# For each of the first few batches, print every field's type, plus the
# shape and values of tensor fields.
n_inspect_batches = 5
for batch_idx, batch in enumerate(train_dataloder):
    print(batch_idx)
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            print(key, value.shape, type(value))
            print(value)
        else:
            # Non-tensor fields (e.g. the list of labels) have no .shape.
            print(key, type(value))
    # Stop after the last wanted batch so no extra batch is loaded.
    if batch_idx + 1 >= n_inspect_batches:
        break

0
label
label <class 'list'>
audio_feats
audio_feats torch.Size([2, 130, 128]) <class 'torch.Tensor'>
tensor([[[ -6.7778,  -6.8072,  -6.8244,  ..., -10.9782, -11.0637, -11.4997],
         [ -7.1028,  -7.7666, -11.2400,  ..., -10.1732, -10.6435, -11.4568],
         [ -9.6935,  -6.3227,  -5.6189,  ...,  -9.2405, -10.0268, -11.3830],
         ...,
         [ -7.1799,  -7.4217,  -7.7385,  ..., -11.3257, -11.4271, -11.5099],
         [ -8.7858,  -7.6728,  -7.1338,  ..., -10.9809, -11.4416, -11.5081],
         [ -6.7073,  -6.2596,  -5.9299,  ..., -11.2956, -11.3942, -11.4918]],

        [[ -6.3884,  -7.0069,  -9.1198,  ..., -11.4579, -11.4999, -11.5121],
         [ -6.1568,  -6.6722,  -7.8415,  ..., -11.4823, -11.4937, -11.5111],
         [ -6.6638,  -6.6713,  -6.6651,  ..., -11.4457, -11.4955, -11.5119],
         ...,
         [ -9.8596,  -9.6112,  -9.3950,  ..., -11.4810, -11.5087, -11.5129],
         [-10.8165, -10.8182, -10.8129,  ..., -11.5027, -11.5126, -11.5129],
         [ -7.8479,  