In [1]:
import torch
from pathlib import Path
import itertools
import numpy as np
import time
import random
from random import randrange
    
###Test Streaming DataLoader with PyTorch####
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, filePaths, frameLength, hopSize):
            super(MyIterableDataset).__init__()
            self.filePaths = self.group_recordings(filePaths)
            self.frameLength = frameLength
            self.hopSize = hopSize
            self.filePage = len(self.filePaths)
            self.filePool = list(range(self.filePage))
            random.shuffle(self.filePool)

            self.currentFileIndx = 0
            self.CurrentEEG = []
            self.CurrentAudio = []
            self.samplePosistions = []

            self.currentSampleIndex = 0
            self.loadDataToBuffer(self.currentFileIndx)
            
            self.samplePosMap = []
            self.generateSamplePostion()
            
    def group_recordings(self, files):
        #Group recordings and corresponding stimuli.
        new_files = []
        grouped = itertools.groupby(sorted(files), lambda x: "_-_".join(x.stem.split("_-_")[:3]))
        for recording_name, feature_paths in grouped:
            new_files += [sorted(feature_paths, key=lambda x: "0" if x == "eeg" else x)]
        return new_files    
    import random
    
    def loadDataToBuffer(self,fileIndex):
        self.CurrentEEG = np.load(self.filePaths[self.filePool[self.currentFileIndx]][0]).astype(np.float32)
        self.CurrentAudio = np.load(self.filePaths[self.filePool[self.currentFileIndx]][1]).astype(np.float32)

    
    def generateSamplePostion(self):
        count = 0
        for i in range(self.filePage):
            tempAudio = np.load(self.filePaths[i][1]).astype(np.float32)
            totalLength,_ = tempAudio.shape
            startPos = [*range(self.frameLength, totalLength+1, self.hopSize)]
            self.samplePosMap.append(startPos)
            noData = (totalLength-self.frameLength)//self.hopSize + 1
            assert len(startPos)==noData
            count += noData
        return count
    def sample_random_data_number_in_one_batch(self,n, total):
    #Return a randomly chosen list of n nonnegative integers summing to total.
    #n: the number of total files    total: batch size
        return [x - 1 for x in self.constrained_sum_sample_pos(n, total + n)]
    
    def constrained_sum_sample_pos(self,n, total):
    #Return a randomly chosen list of n positive integers summing to total.Each such list is equally likely to occur."""
        dividers = sorted(random.sample(range(1, total), n - 1))
        return [a - b for a, b in zip(dividers + [total], [0] + dividers)]
            
    def __iter__(self):
       
        return self
    
    def __next__(self):
        if self.currentSampleIndex < len(self.samplePosMap[self.filePool[self.currentFileIndx]]): # still in the same file
            thisEnd = self.samplePosMap[self.filePool[self.currentFileIndx]][self.currentSampleIndex]
            self.currentSampleIndex += 1
            return self.CurrentEEG[thisEnd-self.frameLength:thisEnd,:], self.CurrentAudio[thisEnd-self.frameLength:thisEnd,:]
        else: # move to the next file
            #print("next file")
            #### need to shuffle samples from the last file
            random.shuffle(self.samplePosMap[self.filePool[self.currentFileIndx]])
            self.currentFileIndx +=1
            self.currentSampleIndex = 0
            if self.currentFileIndx < self.filePage: # still in the same iteration
                self.loadDataToBuffer(self.currentFileIndx)
                thisEnd = self.samplePosMap[self.filePool[self.currentFileIndx]][self.currentSampleIndex]
                self.currentSampleIndex += 1
                return self.CurrentEEG[thisEnd-self.frameLength:thisEnd,:], self.CurrentAudio[thisEnd-self.frameLength:thisEnd,:]
            else:
                #print("here 2")
                random.shuffle(self.filePool)
                self.currentFileIndx = 0
                self.loadDataToBuffer(self.currentFileIndx)
                raise StopIteration
                print("iteration done, restart")

In [2]:
import glob
import json
from pathlib import Path

 # Get the path to the config file
experiments_folder = "C:/Users/YLY/Documents/eegAudChallenge/auditory-eeg-challenge-2023-code/task2_regression"
task_folder = Path(experiments_folder)
config_path = task_folder/ "util/config.json"
    
with open(config_path) as fp:
    config = json.load(fp)

data_folder = Path(config["dataset_folder"])/ config["split_folder"]
stimulus_features = ["envelope"]
features = ["eeg"] + stimulus_features

train_files = [path for path in Path(data_folder).resolve().glob("train_-_*") if path.stem.split("_-_")[-1].split(".")[0] in features]

train_dataset = MyIterableDataset(train_files,64*10,64)

dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=False)

for ite in range(3):
    start_time = time.time()
    count = 0
    for batch in dataloader:
        count+=1
    print("total: ", (count-1)*64,batch[0][1,1,1],batch[1][1][1])
    print("--- %s seconds ---" % (time.time() - start_time))

total:  340544 tensor(1.3223) tensor([-0.4877])
--- 28.135335445404053 seconds ---
total:  340544 tensor(0.8434) tensor([-0.9459])
--- 34.69516396522522 seconds ---
total:  340544 tensor(-1.2050) tensor([-0.1638])
--- 37.483428716659546 seconds ---


In [3]:
print(batch[0].shape)

torch.Size([41, 640, 64])
