In [None]:
!pip install -q pybaselines

In [None]:
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pandas as pd
import torch
from numpy.fft import fft, ifft
from pybaselines import Baseline, utils
from torch.utils.data import Dataset


In [None]:

class SignalDataset(Dataset):
    """Map-style dataset of fixed-length signal segments paired with labels.

    Every ``<subject>_<NNN>_data_time_series.npy`` file in the data directory
    is paired with its ``<subject>_<NNN>_label_time_series.npy`` file. Each
    signal is optionally baseline-corrected and filtered channel by channel,
    cut into consecutive segments, and every segment is stored together with
    the label at the same index.

    Args:
        training_data_paths: Directory containing the ``.npy`` files
            (a single directory path, despite the plural name).
        baseline_correction: Optional callable taking one channel and
            returning its estimated baseline; the baseline is subtracted.
        filter_function: Optional callable taking one channel and returning
            the filtered channel.
    """

    def __init__(self, training_data_paths, baseline_correction=None, filter_function=None):
        self.training_data_paths = training_data_paths
        self.baseline_correction = baseline_correction
        self.filter_function = filter_function

        self.signals = []
        self.labels = []

        for signal_path, label_path in self.get_data(self.training_data_paths):
            signal = np.load(signal_path)
            trial_labels = np.load(label_path)

            for segment, segment_label in self.get_segments_and_labels(
                signal, trial_labels, self.baseline_correction, self.filter_function
            ):
                self.signals.append(segment)
                self.labels.append(segment_label)

    def get_data(self, path: str, n_trials: int = 8):
        """Collect ``(signal_path, label_path)`` pairs found in ``path``.

        Subject ids are taken from file names matching
        ``(s.*p[0-9]+)_([0-9]+)`` (group 1). File names that do not match are
        skipped. For each subject, trial indices ``000`` up to
        ``n_trials - 1`` are probed, and a pair is kept only when both its
        data and label files exist.

        Args:
            path: Directory to scan.
            n_trials: Number of zero-padded trial indices probed per subject.

        Returns:
            List of ``(signal_path, label_path)`` tuples, ordered by subject id
            and then by trial index.
        """
        subject_ids = set()
        for file_name in os.listdir(path):
            match = re.search("(s.*p[0-9]+)_([0-9]+)", file_name)
            if match is None:
                # Unrelated files (READMEs, metadata, ...) carry no subject id.
                continue
            subject_ids.add(match.group(1))

        pairs_data = []
        # Sorted so dataset order does not depend on hash-randomized set order.
        for subject in sorted(subject_ids):
            for trial in range(n_trials):
                trial_id = str(trial).zfill(3)
                signal_name = "_".join([subject, trial_id, "data", "time", "series"]) + ".npy"
                label_name = "_".join([subject, trial_id, "label", "time", "series"]) + ".npy"
                signal_path = os.path.join(path, signal_name)
                label_path = os.path.join(path, label_name)
                if os.path.exists(signal_path) and os.path.exists(label_path):
                    pairs_data.append((signal_path, label_path))

        return pairs_data

    def get_segments_and_labels(self, signals, labels, baseline=None, signal_filter=None,
                                segment_length=1750, n_segments=30):
        """Preprocess each channel, then cut the signal into labelled segments.

        Args:
            signals: 2-D array iterated as ``signals.T``, so each column is
                treated as a channel (presumably ``(n_samples, n_channels)``;
                confirm against the data files). It is not modified.
            labels: Sequence of labels; label ``i`` is paired with segment ``i``.
            baseline: Optional callable returning a channel's baseline, which
                is subtracted from that channel.
            signal_filter: Optional callable returning the filtered channel.
            segment_length: Number of rows per segment.
            n_segments: Number of consecutive segments to cut.

        Returns:
            Tuple of ``(segment, label)`` pairs. Its length is
            ``min(n_segments, len(labels))``. Segments past the end of
            ``signals`` come out shorter than ``segment_length``, or empty.
        """
        processed_channels = []
        for channel in signals.T:
            if baseline is not None:
                # Out-of-place: avoids writing through the view into `signals`
                # and permits float baselines on integer-typed signals.
                channel = channel - baseline(channel)
            if signal_filter is not None:
                channel = signal_filter(channel)
            processed_channels.append(channel)

        processed = np.array(processed_channels).T
        segments = [
            processed[index * segment_length:(index + 1) * segment_length]
            for index in range(n_segments)
        ]
        # zip stops at the shorter of `segments` and `labels`.
        return tuple(zip(segments, labels))

    def __len__(self):
        """Return the number of stored segments."""
        return len(self.signals)

    def __getitem__(self, idx):
        """Return the ``(segment, label)`` pair at position ``idx``."""
        return self.signals[idx], self.labels[idx]
        
def baseline_snip(self, signal=None):
    """Estimate a signal's baseline with pybaselines' SNIP algorithm.

    Can be passed as ``SignalDataset``'s ``baseline_correction``, which
    calls it with a single argument (the channel). The leading ``self``
    parameter is kept so existing two-argument calls keep working. When
    ``signal`` is omitted, the first argument is used as the signal.

    Args:
        self: The signal when called with one argument; otherwise ignored.
        signal: 1-D array-like signal to fit.

    Returns:
        The baseline returned by ``Baseline.snip``. The fit uses
        ``max_half_window=40``, ``decreasing=True`` and
        ``smooth_half_window=3``; the fit-parameters dict is discarded.
    """
    if signal is None:
        # One-argument call: the sole positional argument is the signal.
        signal = self
    baseline_fitter = Baseline(x_data=range(len(signal)))
    baseline, _ = baseline_fitter.snip(
        signal, max_half_window=40, decreasing=True, smooth_half_window=3
    )
    return baseline

In [None]:
# Directory holding the competition's training .npy files (Kaggle input mount).
DATA_DIR = "/kaggle/input/brain-motor-imagery-classification/train/train"

signal_dataset = SignalDataset(DATA_DIR)
# Fail loudly instead of silently showing nothing when no pairs were found.
assert len(signal_dataset) > 0, f"no signal/label pairs found in {DATA_DIR}"

# Peek at the first (segment, label) sample.
segment, label = signal_dataset[0]
segment.shape, label