<a href="https://colab.research.google.com/github/Deepak-Mewada/NeuralDecoder/blob/main/BrainDecoder/Basics/BenchMarking_eager_and_lazy_loading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install mne



In [None]:
!pip install torch



In [None]:
!pip install braindecode



In [None]:
import os
import time
from itertools import product

import mne
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from braindecode.datasets import TUHAbnormal
from braindecode.models import ShallowFBCSPNet, Deep4Net
from braindecode.preprocessing import create_fixed_length_windows


# Raise MNE's log level to WARNING so informational messages are not printed.
mne.set_log_level('WARNING')

In [None]:
# Thread count handed to torch.set_num_threads() below.
N_JOBS = 8
torch.backends.cudnn.benchmark = True  # Enables automatic algorithm optimizations
torch.set_num_threads(N_JOBS)

In [None]:
# Root directory of the TUH Abnormal EEG corpus (v2.0.0). The location is
# machine-specific, so it can be overridden through the TUH_PATH environment
# variable; the original hard-coded path is kept as the default.
TUH_PATH = os.environ.get(
    'TUH_PATH',
    '/storage/store/data/tuh_eeg/www.isip.piconepress.com/projects/'
    'tuh_eeg/downloads/tuh_eeg_abnormal/v2.0.0')

In [None]:
def load_example_data(preload, window_len_s, n_recordings=10,
                      add_physician_reports=False, n_jobs=1):
    """Load TUH Abnormal recordings and cut them into fixed-length windows.

    Parameters
    ----------
    preload : bool
        Forwarded to both ``TUHAbnormal`` and ``create_fixed_length_windows``.
    window_len_s : int | float
        Window length in seconds; converted to samples with the sampling
        frequency of the first recording.
    n_recordings : int
        Number of recordings to load (recording ids ``0 .. n_recordings - 1``).
    add_physician_reports : bool
        Forwarded to ``TUHAbnormal``.
    n_jobs : int
        Forwarded to ``TUHAbnormal``.

    Returns
    -------
    windows_ds
        The object returned by ``create_fixed_length_windows``, after
        ``drop_bad()`` has been called on every contained windows object.
    """
    recording_ids = list(range(n_recordings))

    # add_physician_reports and n_jobs used to be accepted but silently
    # ignored; they are now passed through to the dataset constructor.
    recordings_ds = TUHAbnormal(
        TUH_PATH, recording_ids=recording_ids,
        target_name='pathological',
        preload=preload,
        add_physician_reports=add_physician_reports,
        n_jobs=n_jobs)

    fs = recordings_ds.datasets[0].raw.info['sfreq']
    window_len_samples = int(fs * window_len_s)
    # The stride is fixed at 4 s, independent of the window length.
    window_stride_samples = int(fs * 4)
    windows_ds = create_fixed_length_windows(
        recordings_ds, start_offset_samples=0, stop_offset_samples=None,
        window_size_samples=window_len_samples,
        window_stride_samples=window_stride_samples, drop_last_window=True,
        preload=preload)

    # A distinct loop name avoids shadowing the recordings dataset above.
    for window_ds in windows_ds.datasets:
        window_ds.windows.drop_bad()
        # Sanity check: the windows must keep the requested loading mode.
        assert window_ds.windows.preload == preload

    return windows_ds


def create_example_model(n_channels, n_classes, window_len_samples,
                         kind='shallow', cuda=False):
    """Build a network together with its loss function and optimizer.

    Parameters
    ----------
    n_channels : int
        Number of input channels, passed to the model constructor.
    n_classes : int
        Number of output classes, passed to the model constructor.
    window_len_samples : int
        Passed to the model constructor as ``input_window_samples``.
    kind : str
        ``'shallow'`` builds a ``ShallowFBCSPNet``, ``'deep'`` a ``Deep4Net``.
    cuda : bool
        If True, ``model.cuda()`` is called before the optimizer is created.

    Returns
    -------
    model, loss, optimizer
        The network, an ``nn.NLLLoss`` instance, and an ``optim.Adam``
        optimizer (default settings) over the model's parameters.

    Raises
    ------
    ValueError
        If ``kind`` is neither ``'shallow'`` nor ``'deep'``.
    """
    if kind == 'shallow':
        model = ShallowFBCSPNet(
            n_channels, n_classes, input_window_samples=window_len_samples,
            n_filters_time=40, filter_time_length=25, n_filters_spat=40,
            pool_time_length=75, pool_time_stride=15, final_conv_length='auto',
            split_first_layer=True, batch_norm=True, batch_norm_alpha=0.1,
            drop_prob=0.5)
    elif kind == 'deep':
        model = Deep4Net(
            n_channels, n_classes, input_window_samples=window_len_samples,
            final_conv_length='auto', n_filters_time=25, n_filters_spat=25,
            filter_time_length=10, pool_time_length=3, pool_time_stride=3,
            n_filters_2=50, filter_length_2=10, n_filters_3=100,
            filter_length_3=10, n_filters_4=200, filter_length_4=10,
            first_pool_mode="max", later_pool_mode="max", drop_prob=0.5,
            double_time_convs=False, split_first_layer=True, batch_norm=True,
            batch_norm_alpha=0.1, stride_before_pool=False)
    else:
        # Name the offending value so a typo in the benchmark grid is obvious.
        raise ValueError(
            f"Unknown model kind {kind!r}; expected 'shallow' or 'deep'.")

    if cuda:
        model.cuda()

    optimizer = optim.Adam(model.parameters())
    loss = nn.NLLLoss()

    return model, loss, optimizer


def run_training(model, dataloader, loss, optimizer, n_epochs=1, cuda=False):
    """Train ``model`` for ``n_epochs`` passes over ``dataloader``.

    Every item yielded by ``dataloader`` is unpacked as ``(X, y, _)``; the
    third element is ignored. Targets are cast to ``long`` and, when ``cuda``
    is True, inputs and targets are moved to the GPU before the forward pass.
    After each epoch the mean of the per-batch loss values is printed.

    Returns the same ``model`` object that was passed in.
    """
    for epoch_number in range(1, n_epochs + 1):
        batch_losses = []
        for inputs, targets, _ in dataloader:
            model.train()
            model.zero_grad()

            targets = targets.long()
            if cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            batch_loss = loss(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()

            # .item() reads the scalar loss; the optimizer step does not
            # change the value already stored in batch_loss.
            batch_losses.append(batch_loss.item())

        print(f'Epoch {epoch_number} - mean training loss: {np.mean(batch_losses)}')

    return model
# Benchmark grid: the benchmarking loop runs every combination of the values
# below (itertools.product), and repeats the whole grid N_REPETITIONS times.
PRELOAD = [True, False]  # preload flag passed to load_example_data
N_RECORDINGS = [10]  # number of recordings to load
WINDOW_LEN_S = [2, 4, 15]  # window lengths, in seconds
N_EPOCHS = [2]  # training epochs per run
BATCH_SIZE = [64, 256]  # DataLoader batch sizes
MODEL = ['shallow', 'deep']  # `kind` values accepted by create_example_model

NUM_WORKERS = [8, 0]  # DataLoader num_workers values
PIN_MEMORY = [False]  # DataLoader pin_memory values
# Only include the GPU configuration when CUDA is actually available.
CUDA = [True, False] if torch.cuda.is_available() else [False]

N_REPETITIONS = 3  # number of times the full grid is repeated

In [None]:
# Run every configuration of the benchmark grid and collect one timing record
# (a dict) per run in all_results.
all_results = list()
for (i, preload, n_recordings, win_len_s, n_epochs, batch_size, model_kind,
        num_workers, pin_memory, cuda) in product(
            range(N_REPETITIONS), PRELOAD, N_RECORDINGS, WINDOW_LEN_S, N_EPOCHS,
            BATCH_SIZE, MODEL, NUM_WORKERS, PIN_MEMORY, CUDA):

    results = {
        'repetition': i,
        'preload': preload,
        'n_recordings': n_recordings,
        'win_len_s': win_len_s,
        'n_epochs': n_epochs,
        'batch_size': batch_size,
        'model_kind': model_kind,
        'num_workers': num_workers,
        'pin_memory': pin_memory,
        'cuda': cuda
    }
    print(f'\nRepetition {i + 1}/{N_REPETITIONS}:\n{results}')

    # Durations are measured with time.perf_counter(): it is monotonic and
    # high-resolution, so the results are unaffected by system clock
    # adjustments (unlike time.time()).

    # Load the dataset
    data_loading_start = time.perf_counter()
    dataset = load_example_data(preload, win_len_s, n_recordings=n_recordings)
    data_loading_end = time.perf_counter()

    # Create the data loader
    training_setup_start = time.perf_counter()
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory,
        num_workers=num_workers, worker_init_fn=None)

    # Instantiate model and optimizer
    n_channels = len(dataset.datasets[0].windows.ch_names)
    n_times = len(dataset.datasets[0].windows.times)
    n_classes = 2
    model, loss, optimizer = create_example_model(
        n_channels, n_classes, n_times, kind=model_kind, cuda=cuda)
    training_setup_end = time.perf_counter()

    # Start training loop
    model_training_start = time.perf_counter()
    trained_model = run_training(
        model, dataloader, loss, optimizer, n_epochs=n_epochs, cuda=cuda)
    model_training_end = time.perf_counter()

    # The DataLoader holds a reference to the dataset, so it must be deleted
    # as well; otherwise the previous dataset stays in memory while the next
    # configuration is being loaded.
    del dataset, dataloader, model, loss, optimizer, trained_model

    # Record timing results
    results['data_preparation'] = data_loading_end - data_loading_start
    results['training_setup'] = training_setup_end - training_setup_start
    results['model_training'] = model_training_end - model_training_start
    all_results.append(results)



Repetition 1/3:
{'repetition': 0, 'preload': True, 'n_recordings': 10, 'win_len_s': 2, 'n_epochs': 2, 'batch_size': 64, 'model_kind': 'shallow', 'num_workers': 8, 'pin_memory': False, 'cuda': False}


KeyError: 'year'