# Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
os.chdir(Path(os.path.abspath("")).parent)
from mros_data.datamodule import SleepEventDataModule


# Datamodule

The `SleepEventDataModule` class wraps a `SleepEventDataset` and contains the logic for iterating over event data.
The datamodule is also responsible for splitting the data into train/eval partitions with its `setup()` method. After that, a PyTorch `DataLoader` for each partition is available from the corresponding `*_dataloader()` method (e.g. `train_dataloader()` and `val_dataloader()`).
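The split logic itself is not shown in this notebook. As a minimal sketch, assuming the datamodule splits at the record level (so windows from one recording never end up in more than one partition) using the `n_eval`, `n_test` and `seed` parameters passed below, it could look like this. `split_records` is a hypothetical helper, not part of `mros_data`:

```python
import random


def split_records(records, n_eval, n_test, seed):
    """Hypothetical record-level split into train/eval/test partitions."""
    # Shuffle a copy with a seeded RNG so the split is reproducible.
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    test = shuffled[:n_test]
    eval_ = shuffled[n_test:n_test + n_eval]
    train = shuffled[n_test + n_eval:]
    return train, eval_, test


records = [f"record_{i:02d}" for i in range(10)]
train, eval_, test = split_records(records, n_eval=2, n_test=2, seed=1337)
print(len(train), len(eval_), len(test))  # 6 2 2
```

Splitting by record rather than by window avoids leaking data from the same night into both the training and evaluation sets.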

## Instantiate class

We instantiate the datamodule by unpacking a dictionary of parameters as keyword arguments.
The event-specific parameters of note are `events`, `default_event_window_duration`, `fs`, and `picks`, which specify the event code/event name, the duration of default events, the sampling frequency, and the channels to include, respectively.
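To make these parameters concrete, the arithmetic below converts the durations used in this notebook into samples. Assuming durations are in seconds and no `transform` is applied, each raw data window would hold `len(picks)` channels × `window_duration * fs` samples; the variables simply mirror the values in the parameter dictionary:

```python
# Values mirrored from the parameter dictionary used in this notebook.
fs = 128                               # sampling frequency in Hz
window_duration = 600                  # seconds per data window
default_event_window_duration = [15]   # seconds (assumed unit)
picks = ["c3", "c4", "eogl", "eogr", "chin"]

samples_per_window = window_duration * fs
samples_per_default_event = [d * fs for d in default_event_window_duration]

print(samples_per_window)                # 76800
print(samples_per_default_event)         # [1920]
print((len(picks), samples_per_window))  # (5, 76800)
```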

Transformations of the input data, such as short-time Fourier or continuous wavelet transforms, can be applied via the `transform` parameter.
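The internals of `STFTTransform` are not shown here. As a rough sketch of what a short-time Fourier transform with the same `segment_size`, `step_size` and `nfft` values computes, the NumPy function below takes a Hann-windowed real FFT of each frame. The Hann window, the absence of edge padding and the omission of the `normalize` step are assumptions, and `stft_magnitude` is a hypothetical helper:

```python
import numpy as np


def stft_magnitude(x, segment_size, step_size, nfft):
    """Magnitude STFT of a 1-D signal: one Hann-windowed rFFT per frame, no padding."""
    window = np.hanning(segment_size)
    n_frames = 1 + (len(x) - segment_size) // step_size
    frames = np.stack([
        x[i * step_size : i * step_size + segment_size] * window
        for i in range(n_frames)
    ])  # shape: (n_frames, segment_size)
    # rfft zero-pads each frame to nfft points and keeps nfft // 2 + 1 bins.
    spec = np.abs(np.fft.rfft(frames, n=nfft, axis=-1))
    return spec.T  # shape: (n_freqs, n_frames)


fs = 128
t = np.arange(10 * fs) / fs        # 10 s of signal
x = np.sin(2 * np.pi * 10.0 * t)   # 10 Hz sine
S = stft_magnitude(x, segment_size=int(4.0 * fs), step_size=int(0.125 * fs), nfft=1024)
freqs = np.fft.rfftfreq(1024, d=1 / fs)

print(S.shape)                   # (513, 49)
print(freqs[S[:, 0].argmax()])   # 10.0
```

Each column of `S` is one 4 s frame, and each row is a frequency bin spaced `fs / nfft` Hz apart.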

In [None]:
from mros_data.datamodule.transforms import STFTTransform

FS = 128  # sampling frequency in Hz

params = dict(
    data_dir="data/processed/mros/ar",
    batch_size=16,
    n_eval=2,
    n_test=2,
    num_workers=0,
    seed=1337,
    events={"ar": "Arousal"},            # event code -> event name
    window_duration=600,                 # seconds
    cache_data=True,
    default_event_window_duration=[15],  # seconds
    event_buffer_duration=3,
    factor_overlap=2,
    fs=FS,
    matching_overlap=0.5,
    n_jobs=-1,
    n_records=10,
    picks=["c3", "c4", "eogl", "eogr", "chin"],  # channels to include
    # Alternative transform (MultitaperTransform must be imported first):
    # transform=MultitaperTransform(FS, 0.5, 35.0, tw=8.0, normalize=True),
    # Short-time Fourier transform: 4 s segments, 0.125 s steps, 1024-point FFT.
    transform=STFTTransform(
        fs=FS, segment_size=int(4.0 * FS), step_size=int(0.125 * FS), nfft=1024, normalize=True
    ),
    scaling="robust",
)
dm = SleepEventDataModule(**params)
print(dm)
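The notebook does not define what `scaling="robust"` does. Assuming it follows the common definition (as in scikit-learn's `RobustScaler`), each channel is centred on its median and divided by its interquartile range, which is less sensitive to large artefacts than mean/standard-deviation scaling. A minimal sketch under that assumption, with `robust_scale` as a hypothetical helper:

```python
import numpy as np


def robust_scale(x):
    """Scale each channel (row) of x by its median and interquartile range."""
    median = np.median(x, axis=-1, keepdims=True)
    q25, q75 = np.percentile(x, [25, 75], axis=-1, keepdims=True)
    iqr = q75 - q25
    # Leave flat channels (IQR == 0) unscaled to avoid division by zero.
    return (x - median) / np.where(iqr > 0, iqr, 1.0)


x = np.array([[1.0, 2.0, 3.0, 4.0, 100.0]])  # one channel with an outlier
print(robust_scale(x).tolist())  # [[-1.0, -0.5, 0.0, 0.5, 48.5]]
```

The outlier at 100 does not affect the median or IQR, so the other samples keep a sensible scale.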

## Split dataset into train/eval partitions

In [None]:
# The datamodule splits the dataset into train/eval partitions when setup() is called.
dm.setup('fit')
train_dl, eval_dl = dm.train_dataloader(), dm.val_dataloader()

# DataLoaders are iterables, i.e. we can loop over them or fetch a single batch with next(iter(...)).
data, events, records, *_ = next(iter(train_dl))
n_events = sum(ev.shape[0] for ev in events)
print(f'Batch size: {data.shape[0]} | No. channels: {data.shape[1]} | No. timepoints: {data.shape[2]} | No. events: {n_events} | Data sample size: {list(data.shape[1:])}')
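Because each window can contain a different number of events, the events cannot be stacked into a single array, which is why `events` in a batch is treated as a sequence of per-window arrays in the cell above. A minimal sketch of a collate function that handles this, assuming each sample is a dict with `signal`, `events` and `record` keys (as used with the dataset below) and events stored as `(start, duration)` rows; `collate_events` is hypothetical and NumPy arrays stand in for PyTorch tensors:

```python
import numpy as np


def collate_events(samples):
    """Stack fixed-size signals; keep variable-length event arrays as a list."""
    signals = np.stack([s["signal"] for s in samples])  # (batch, channels, timepoints)
    events = [s["events"] for s in samples]             # list of (n_events_i, 2) arrays
    records = [s["record"] for s in samples]
    return signals, events, records


samples = [
    {"signal": np.zeros((5, 1000)), "events": np.array([[0.10, 0.025]]), "record": "a"},
    {"signal": np.zeros((5, 1000)), "events": np.zeros((3, 2)), "record": "b"},
]
data, events, records = collate_events(samples)
print(data.shape)                         # (2, 5, 1000)
print(sum(ev.shape[0] for ev in events))  # 4
```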

## Access the underlying datasets

The underlying data windows can be accessed by indexing into the dataset. This calls the `__getitem__()` method, which returns the signals together with the associated events and the record they come from.
The events' start times and durations are normalized to the window length, i.e. an event with elements (0.1, 0.025) in a 10 min window starts at 10 min × 60 s/min × 0.1 = 60 s and lasts 10 min × 60 s/min × 0.025 = 15 s.
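The conversion in the example above can be written as a small helper. `event_to_seconds` is hypothetical and assumes each event starts with its normalized `(start, duration)` pair:

```python
def event_to_seconds(event, window_duration_s):
    """Convert a window-normalized (start, duration) pair to seconds."""
    start, duration = event[0], event[1]
    return start * window_duration_s, duration * window_duration_s


window_duration_s = 10 * 60  # 10 min window
start_s, duration_s = event_to_seconds((0.1, 0.025), window_duration_s)
print(f"start = {start_s:.1f} s, duration = {duration_s:.1f} s")  # start = 60.0 s, duration = 15.0 s
```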

In [None]:
train_ds = dm.train
# Find the first window containing more than five events.
for idx in range(len(train_ds)):
    sample = train_ds[idx]  # calls SleepEventDataset.__getitem__(idx)
    record = sample['record']
    data = sample['signal']
    events = sample['events']
    if len(events) > 5:
        break
print(sample.keys())
print(f'Record: {record} | No. channels: {data.shape[0]} | No. timepoints: {data.shape[1]} | No. events: {len(events)}')

## Plotting signals

We can plot the signals in the underlying dataset with the `plot_signals()` method of `SleepEventDataset`. Provide an index in the range `[0, len(dataset) - 1]` and, optionally, a list of names for the selected channels:

In [None]:
# Channel names are given in the same order as `picks` (c3, c4, eogl, eogr, chin).
train_ds.plot_signals(idx, channel_names=['C3-A2', 'C4-A1', 'EOGL-A2', 'EOGR-A2', 'EMG'])

## Transforming data on the fly

By passing a `transform` to `SleepEventDataModule`, we get spectrograms of the data instead of raw signals. To visualize the spectrogram of a single channel, use the `plot_spect()` method:
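The `window_size`, `step_size` and `nfft` values passed to `plot_spect()` in the next cell match the STFT settings of the transform used above. The arithmetic below shows what they imply: the frequency grid spacing is `fs / nfft`, zero-padding from 512 to 1024 points refines that grid but not the underlying resolution of roughly 1 / (4 s), and a 600 s window yields several thousand frames. The frame count assumes no edge padding, which may differ from what `STFTTransform` or `plot_spect()` actually do:

```python
fs = 128                       # Hz
segment_size = int(4.0 * fs)   # 512 samples = 4 s
step_size = int(0.125 * fs)    # 16 samples = 0.125 s
nfft = 1024

bin_spacing_hz = fs / nfft                # spacing of the FFT frequency grid
native_resolution_hz = fs / segment_size  # 1 / (segment length in s)
frame_step_s = step_size / fs
n_freqs = nfft // 2 + 1
# Frames in a 600 s window, assuming no padding at the edges.
n_frames = 1 + (600 * fs - segment_size) // step_size

print(bin_spacing_hz, native_resolution_hz, frame_step_s)  # 0.125 0.25 0.125
print(n_freqs, n_frames)                                   # 513 4769
```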

In [None]:
train_ds.plot_spect(
    idx,
    channel_idx=0,
    window_size=int(4.0 * train_ds.fs),
    step_size=int(0.125 * train_ds.fs),
    nfft=1024,
)

We can also combine the plots by using the `plot()` method:

In [None]:
# Interactive figures in Jupyter (requires the ipympl package).
%matplotlib widget
train_ds.plot(
    idx,
    channel_names=['C3-A2', 'C4-A1', 'EOGL-A2', 'EOGR-A2', 'EMG'],
    channel_idx=0,
    window_size=int(4.0 * train_ds.fs),
    step_size=int(0.125 * train_ds.fs),
    nfft=1024,
)