In [1]:
from pathlib import Path
from typing import Dict, Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2
from transformers import ViTMAEForPreTraining

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [None]:
class PreTrainDataset(Dataset):
    """Map-style dataset stacking spectrogram crops and EEG spectrogram channels.

    Each item is ``(X, y)``. ``X`` is a float32 tensor of shape ``(8, 128, 256)``:
    four channels cut from ``all_specs[row['spectrogram_id']]`` followed by four
    channels taken from ``all_eegs[row['eeg_id']]``. ``y`` is a float32 target
    vector with one entry per label column.

    Args:
        df: One row per sample. Read columns: ``'spectrogram_id'``, ``'eeg_id'``,
            ``'min'``, ``'max'``, plus ``label_cols`` in ``'train'``/``'valid'`` mode.
        all_specs: Spectrogram arrays keyed by ``spectrogram_id``. Rows
            ``r:r+300`` and columns ``0:400`` are read from each one.
        all_eegs: EEG spectrogram arrays keyed by ``eeg_id``. Each must be
            indexable as ``[:, :, i]`` for ``i`` in ``0..3``, yielding
            ``(128, 256)`` slices so they stack with the spectrogram channels.
        mode: ``'train'`` / ``'valid'`` read targets from ``label_cols``;
            ``'test'`` returns an all-zero target. Defaults to ``'test'`` so an
            unlabeled pre-training set needs no label columns.
        label_cols: Names of the target columns. ``None`` (default) means no
            labels, giving a zero-length target.

    Raises:
        ValueError: If ``mode`` is not ``'train'``, ``'valid'`` or ``'test'``.
    """

    _MODES = ('train', 'valid', 'test')

    def __init__(
        self,
        df: pd.DataFrame,
        all_specs: Dict[str, np.ndarray],
        all_eegs: Dict[str, np.ndarray],
        mode: str = 'test',
        label_cols: Optional[Sequence[str]] = None,
    ):
        # Fail at construction rather than on the first __getitem__ call.
        if mode not in self._MODES:
            raise ValueError(f"Invalid mode {mode}!")
        self.df = df
        self.spectrograms = all_specs
        self.eeg_spectrograms = all_eegs
        self.mode = mode
        # Stored as a list (not a tuple) so ``row[self.label_cols]`` selects
        # several columns instead of being looked up as one hashable key.
        self.label_cols = list(label_cols) if label_cols is not None else []

    def __len__(self):
        """Number of samples, i.e. rows of ``df``."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(X, y)`` for row ``index`` after the (identity) transform."""
        X, y = self.__data_generation(index)
        X, y = self.__transform(X, y)
        return X, y

    def __data_generation(self, index):  # --> [(C=8) x (H=128) x (W=256)]
        """Build the stacked input image and the target vector for row ``index``."""
        row = self.df.iloc[index]
        # Start row of the 300-row window, derived from the row's 'min'/'max'.
        # NOTE(review): the // 4 presumably converts a time offset into
        # spectrogram rows — confirm against the data source.
        r = int((row['min'] + row['max']) // 4)

        spectrogram_full = self.spectrograms[row['spectrogram_id']]
        img_list = []
        for region in range(4):
            # 300 rows x 100 columns from this region's column block, transposed
            # to (100, 300). Presumably one block per scalp region — TODO confirm.
            spectrogram = spectrogram_full[r:r + 300, region * 100:(region + 1) * 100].T
            # transform_spectrogram is defined elsewhere in the notebook and is
            # assumed to preserve shape — TODO confirm.
            spectrogram = transform_spectrogram(spectrogram)

            # Place the (100, 256) centre crop (22 columns trimmed from each side,
            # values halved) into a zero (128, 256) canvas with 14 empty rows
            # above and below.
            img = np.zeros((128, 256), dtype='float32')
            img[14:-14, :] = spectrogram[:, 22:-22] / 2.0
            img_list.append(img)

        eeg = self.eeg_spectrograms[row['eeg_id']]
        img_list += [eeg[:, :, i] for i in range(4)]

        # np.array builds a fresh contiguous float32 array, so sharing its
        # memory via from_numpy is safe and avoids an extra copy.
        X = torch.from_numpy(np.array(img_list, dtype='float32'))

        if self.mode in ('train', 'valid'):
            y = row[self.label_cols].values.astype(np.float32)
        elif self.mode == 'test':
            y = np.zeros(len(self.label_cols), dtype=np.float32)
        else:
            # Only reachable if self.mode was reassigned after construction.
            raise ValueError(f"Invalid mode {self.mode}!")

        y = torch.tensor(y, dtype=torch.float32)
        return X, y

    def __transform(self, x, y):
        """Augmentation hook; currently returns ``x`` and ``y`` unchanged."""
        return x, y