In [5]:
from ast import Tuple
from mads_datasets import datatools
from pathlib import Path
from scipy.io import arff
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from typing import Protocol, List, Any, Mapping, Iterator, Optional, Callable, Sequence
from abc import ABC, abstractmethod
from pydantic import BaseModel
import numpy as np
from numpy import ndarray

In [9]:
class DatasetProtocol(Protocol):
    """Structural type for a sized, integer-indexable collection of items."""

    def __len__(self) -> int:
        """Return how many items the dataset holds."""

    def __getitem__(self, index: int) -> Any:
        """Return the item stored at position ``index``."""

class ProcessingDatasetProtocol(DatasetProtocol, Protocol):
    """Dataset protocol that can additionally ingest raw data.

    ``Protocol`` is listed explicitly as a base: a class that only inherits
    from another protocol is an ordinary class, not a protocol itself.
    """

    def process_data(self, data: Any) -> None:
        """Turn ``data`` into the items the dataset serves.

        Args:
            data: Raw input; its concrete type is up to the implementer.
        """
        ...

class DataStreamerProtocol(Protocol):
    """Structural type for objects that hand out batches through an iterator."""

    def stream(self) -> Iterator:
        """Return an iterator over batches."""

class PreprocessorProtocol(Protocol):
    """Structural type for callables that collate a batch into two tensors."""

    def __call__(self, batch: List[tuple]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate ``batch`` (a list of tuples) into a pair of tensors."""

class AbstractDataset(ProcessingDatasetProtocol):
    """Base class for in-memory datasets that are filled at construction time.

    Subclasses implement ``process_data`` to populate ``self.dataset``;
    ``__len__`` and ``__getitem__`` then read from that list.
    """

    def __init__(self, data: Any) -> None:
        """Create the item list and fill it from ``data``.

        Args:
            data: Raw input, passed unchanged to ``process_data``.
        """
        # The list must exist before process_data runs, because
        # implementations append to it.
        self.dataset: List = []
        self.process_data(data)

    def __len__(self) -> int:
        """Return the number of stored items."""
        return len(self.dataset)

    # NOTE: annotated with the builtin ``tuple``; the name ``Tuple`` in this
    # file is imported from ``ast`` and is not a typing construct.
    def __getitem__(self, index: int) -> tuple:
        """Return the item stored at position ``index``."""
        return self.dataset[index]

    @abstractmethod
    def process_data(self, data: Any) -> None:
        """Populate ``self.dataset`` from ``data``; must be overridden.

        Raises:
            NotImplementedError: If the base implementation is called.
        """
        raise NotImplementedError

class EegDataset(AbstractDataset):
    """Dataset of ``(features, label)`` tensor pairs, one per input record."""

    def process_data(self, data: ndarray) -> None:
        """Append one ``(features, label)`` pair to ``self.dataset`` per record.

        Each record is converted with ``tolist()``. Every value except the
        last becomes the feature tensor; the last value is cast to ``int``
        and becomes the label tensor.
        """
        for record in data:
            values = record.tolist()
            features = torch.tensor(values[:-1])
            label = torch.tensor(int(values[-1]))
            self.dataset.append((features, label))


In [8]:
class BasePreprocessor(PreprocessorProtocol):
    """Collate a batch by stacking the features and the labels separately."""

    def __call__(self, batch: list[tuple]) -> tuple[torch.Tensor, torch.Tensor]:
        """Split ``batch`` into features and labels and ``torch.stack`` each."""
        features, labels = zip(*batch)
        stacked_features = torch.stack(features)
        stacked_labels = torch.stack(labels)
        return stacked_features, stacked_labels


class PaddedPreprocessor(PreprocessorProtocol):
    """Collate a batch of sequences by zero-padding them to a common length."""

    def __call__(self, batch: list[tuple]) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad the features batch-first with 0 and build a tensor from the labels."""
        sequences, labels = zip(*batch)
        padded = pad_sequence(sequences, batch_first=True, padding_value=0)
        label_tensor = torch.tensor(labels)
        return padded, label_tensor

class WindowingPreprocessor(PreprocessorProtocol):
    """Collate a batch by cutting every sequence into fixed-size windows.

    Each sequence becomes a stack of consecutive, non-overlapping windows of
    ``window_size`` steps. The per-sequence stacks are then zero-padded
    (batch-first) to the same number of windows.
    """

    def __init__(self, window_size: int):
        """Store the window length.

        Args:
            window_size: Number of steps per window; must be at least 1.

        Raises:
            ValueError: If ``window_size`` is smaller than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size

    def __call__(self, batch: list[tuple]) -> tuple[torch.Tensor, torch.Tensor]:
        """Window every sequence, pad across the batch and stack the labels."""
        X, y = zip(*batch)

        X_windowed = [self.window_sequence(seq) for seq in X]
        X_padded = pad_sequence(X_windowed, batch_first=True, padding_value=0)

        return X_padded, torch.stack(y)

    def window_sequence(self, sequence):
        """Split ``sequence`` along its first axis into equal-sized windows.

        A trailing partial window is right-padded with zeros so that all
        windows have the same length and can be stacked. Previously a
        sequence whose length was not a multiple of ``window_size`` made
        ``torch.stack`` fail.

        Args:
            sequence: Tensor whose first axis is the one being windowed.

        Returns:
            Tensor with the windows stacked along a new first axis. For an
            empty sequence, a tensor holding zero windows.
        """
        trailing_shape = tuple(sequence.shape[1:])
        remainder = len(sequence) % self.window_size
        if remainder:
            # Same fill value (0) as the batch-level padding in __call__.
            filler = sequence.new_zeros((self.window_size - remainder, *trailing_shape))
            sequence = torch.cat([sequence, filler])

        windows = [
            sequence[start:start + self.window_size]
            for start in range(0, len(sequence), self.window_size)
        ]
        if not windows:
            # torch.stack rejects an empty list; return "zero windows" instead.
            return sequence.new_zeros((0, self.window_size, *trailing_shape))
        return torch.stack(windows)

class BaseDatastreamer(DataStreamerProtocol):
    """Endless stream of randomly ordered, fixed-size batches from a dataset.

    Indices are drawn from a random permutation of the dataset. When fewer
    than ``batch_size`` unseen indices remain, a new permutation is drawn, so
    the leftover items of the old permutation are skipped.
    """

    def __init__(
        self,
        dataset: DatasetProtocol,
        batch_size: int,
        preprocessor: Optional[Callable] = None
    ) -> None:
        """Set up the streamer and draw the first permutation.

        Args:
            dataset: Sized, indexable source whose items unpack into ``(x, y)``.
            batch_size: Number of items per batch; must be at least 1.
            preprocessor: Callable that turns a list of ``(x, y)`` tuples into
                an ``(X, Y)`` pair. When omitted, the batch is only
                transposed with ``zip``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size

        if preprocessor is None:
            # Default collate: [(x0, y0), (x1, y1), ...] -> (xs, ys)
            self.preprocessor = lambda x: zip(*x)
        else:
            self.preprocessor = preprocessor

        self.size = len(self.dataset)
        self.reset_index()

    def __len__(self) -> int:
        """Return the number of full batches in one pass over the dataset."""
        return len(self.dataset) // self.batch_size

    def reset_index(self) -> None:
        """Draw a fresh random permutation and restart at its beginning."""
        self.index_list = np.random.permutation(self.size)
        self.index = 0

    def batchloop(self) -> Sequence[tuple]:
        """Collect the next ``batch_size`` items following the permutation.

        Advances ``self.index`` by ``batch_size``. The caller must make sure
        enough indices remain; ``stream`` does so by reshuffling first.
        """
        batch = []
        for _ in range(self.batch_size):
            x, y = self.dataset[self.index_list[self.index]]
            batch.append((x, y))
            self.index += 1
        return batch

    def stream(self) -> Iterator:
        """Yield preprocessed ``(X, Y)`` batches indefinitely.

        The generator previously yielded a single batch and was then
        exhausted, so a second ``next()`` raised ``StopIteration``. It now
        loops forever and reshuffles whenever the current permutation cannot
        fill another batch.
        """
        while True:
            if self.index > (self.size - self.batch_size):
                self.reset_index()
            batch = self.batchloop()
            X, Y = self.preprocessor(batch)
            yield X, Y




In [12]:
class DatasetSettings(BaseModel):
    """Settings that describe where a dataset comes from and where it is stored."""

    dataset_url: str  # remote URL the data file is downloaded from
    data_dir: Path  # local directory that holds the data file
    filename: Path  # file name inside data_dir
    name: str  # label for the dataset
    unzip: bool  # presumably whether the download is an archive to unpack -- TODO confirm against datatools.get_file
    
# Default settings for the "EEG Eye State" ARFF file from the UCI archive.
# data_dir is resolved against the current working directory.
eegDatasetSettings = DatasetSettings(
    dataset_url="https://archive.ics.uci.edu/ml/machine-learning-databases/00264/EEG%20Eye%20State.arff",
    data_dir=Path("../data/eeg").resolve(),
    filename=Path("eeg.arff"),
    name="EEG",
    unzip=False,
)

class AbstractDatasetFactory(ABC):
    """Base factory that downloads a data file and builds datasets and streamers.

    Subclasses implement ``create_dataset``; downloading and wrapping the
    datasets in streamers are provided here.
    """

    def __init__(self, settings: DatasetSettings) -> None:
        """Store the settings that describe the data file."""
        self.settings = settings

    def download_data(self) -> None:
        """Download the data file unless it already exists locally.

        The target directory is created when missing. An existing directory
        no longer raises ``FileExistsError``, and the ``unzip`` flag is taken
        from the settings instead of being fixed to ``False``.
        """
        data_dir = self.settings.data_dir
        filename = self.settings.filename
        data_path = data_dir / filename

        if not data_path.exists():
            # The directory may already exist without the file in it.
            data_dir.mkdir(parents=True, exist_ok=True)
            datatools.get_file(
                data_dir=data_dir,
                filename=filename,
                url=self.settings.dataset_url,
                unzip=self.settings.unzip,
            )

    @abstractmethod
    def create_dataset(self) -> Mapping[str, DatasetProtocol]:
        """Build the datasets; must be overridden.

        ``create_datastreamer`` reads the keys ``"train"`` and ``"valid"``
        from the returned mapping.
        """
        raise NotImplementedError

    def create_datastreamer(
            self, batch_size: int, **kwargs
        ) -> Mapping[str, DataStreamerProtocol]:
        """Create one ``BaseDatastreamer`` per split.

        Args:
            batch_size: Number of items per batch for both streamers.
            **kwargs: Only ``preprocessor`` is read (default ``None``); it is
                passed to both streamers. Other keys are ignored.

        Returns:
            Mapping with the keys ``"train"`` and ``"valid"``.
        """
        datasets = self.create_dataset()
        train_dataset = datasets["train"]
        valid_dataset = datasets["valid"]
        preprocessor: Optional[Callable] = kwargs.pop("preprocessor", None)

        train_streamer = BaseDatastreamer(
            train_dataset, batch_size=batch_size, preprocessor=preprocessor
        )
        valid_streamer = BaseDatastreamer(
            valid_dataset, batch_size=batch_size, preprocessor=preprocessor
        )

        return {
            "train": train_streamer,
            "valid": valid_streamer
        }

class EegDatasetFactory(AbstractDatasetFactory):
    """Factory that builds train and valid ``EegDataset`` objects from an ARFF file."""

    def __init__(self, settings: DatasetSettings = eegDatasetSettings) -> None:
        """Store the settings and start with no datasets created yet."""
        super().__init__(settings)
        # Annotated and empty until create_dataset runs. The attribute used
        # to be assigned the typing alias object itself.
        self.datasets: Mapping[str, DatasetProtocol] = {}

    def create_dataset(self) -> Mapping[str, DatasetProtocol]:
        """Download the file if needed, load it and split it into train and valid.

        The records are split in file order without shuffling: the first 80%
        become the train set and the rest the valid set. The result is also
        stored on ``self.datasets``.

        Returns:
            Mapping with the keys ``"train"`` and ``"valid"``.
        """
        self.download_data()

        data_path = self.settings.data_dir / self.settings.filename
        # loadarff returns a pair; only its first element (the records) is used.
        dataset = arff.loadarff(data_path)[0]

        split = int(0.8 * len(dataset))
        train_dataset = EegDataset(dataset[:split])
        valid_dataset = EegDataset(dataset[split:])

        datasets = {
            "train": train_dataset,
            "valid": valid_dataset
        }
        self.datasets = datasets

        return datasets

In [19]:
# Build the factory, create the train/valid datasets, then wrap them in
# streamers that serve batches of 5 items.
eegDatasetFactory = EegDatasetFactory()
eegDataset = eegDatasetFactory.create_dataset()
eegdatastreamer = eegDatasetFactory.create_datastreamer(batch_size=5)

In [26]:
# stream() returns a generator, so next() can pull the first batch directly.
train_stream = eegdatastreamer["train"].stream()
next(train_stream)

((tensor([4289.2300, 4011.7900, 4248.2100, 4134.3599, 4338.9702, 4622.5601,
          4079.4900, 4603.5898, 4192.8198, 4233.3301, 4195.3799, 4274.8701,
          4587.6899, 4358.4600]),
  tensor([4307.6899, 4007.1799, 4245.6401, 4120.5098, 4340.0000, 4632.3101,
          4084.6201, 4603.0801, 4188.7202, 4237.4399, 4197.9502, 4282.5601,
          4607.1802, 4380.5098]),
  tensor([4306.6699, 4002.5601, 4259.4902, 4116.4102, 4328.7202, 4605.6401,
          4056.4099, 4614.3599, 4187.1802, 4223.5898, 4187.1802, 4265.6401,
          4607.6899, 4362.0498]),
  tensor([4269.2300, 3985.6399, 4245.1299, 4109.2300, 4320.0000, 4625.6401,
          4058.9700, 4610.7700, 4196.4102, 4230.2598, 4195.8999, 4265.1299,
          4606.1499, 4339.4902]),
  tensor([4275.3799, 3968.2100, 4249.2300, 4092.8201, 4321.5400, 4598.9702,
          4036.9199, 4590.7700, 4176.4102, 4221.5400, 4184.1001, 4249.7402,
          4582.5601, 4335.3799])),
 (tensor(0), tensor(0), tensor(1), tensor(0), tensor(0)))