# Utilizing FIFDataLoader for EEG Data Augmentation with DDPMs

This notebook demonstrates how to use the `FIFDataLoader` for loading EEG data from `.fif` files. We will explore setting up the dataloader for both unconditional and conditional (class-based) training of Denoising Diffusion Probabilistic Models (DDPMs).

The primary goal is to simulate new EEG signals, which can be particularly useful for data augmentation, such as balancing datasets with underrepresented classes. For this demonstration, we will consider three classes: Dementia, Mild Cognitive Impairment (MCI), and Healthy Controls.

## 1. Setup: Importing Libraries

In [None]:
import os
import shutil
import glob
import numpy as np
import mne
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Assuming 'ntd' is in PYTHONPATH or the notebook is run from the repo root
from ntd.datasets import FIFDataLoader
# from ntd.diffusion_model import DiffusionModel # Placeholder for model, if full training is shown
# from ntd.networks import AdaConv # Placeholder for network, if full training is shown
# import hydra # For config loading (optional, can also manually define configs)
# from omegaconf import OmegaConf # For config loading

# Configure matplotlib for inline plotting
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
plt.style.use('seaborn-v0_8-whitegrid')

print("Libraries imported.")

## 2. Data Preparation: Creating Dummy .fif Files

For this demonstration, we will create a set of dummy `.fif` files. In a real-world scenario, you would point the `FIFDataLoader` to your existing directory of `.fif` files.

The dummy data will have the following parameters:
- Sampling frequency (`sfreq`): 200 Hz
- Number of channels (`n_channels`): 19 (standard EEG channels)
- Epoch duration: 5 seconds (resulting in `n_times = sfreq * epoch_duration = 1000` time points per epoch)
- Number of epochs per subject file (`n_epochs_per_subject`): 20
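These parameters fix the tensor shapes you should expect downstream. The check below assumes that `FIFDataLoader` yields one item per stored epoch and that `n_epochs=20` loads every epoch in each file; both are assumptions about the loader, not confirmed behavior.

```python
# Parameters of the dummy recordings (mirroring create_dummy_fif_files defaults)
sfreq = 200              # Hz
epoch_duration = 5       # seconds
n_channels = 19
n_epochs_per_subject = 20
n_subjects = 6           # 2 subjects x 3 classes in this notebook
batch_size = 4

n_times = int(sfreq * epoch_duration)                  # samples per epoch
file_shape = (n_epochs_per_subject, n_channels, n_times)
total_epochs = n_subjects * n_epochs_per_subject       # items in the dataset
n_batches = -(-total_epochs // batch_size)             # ceiling division (drop_last=False)

print(n_times, file_shape, total_epochs, n_batches)    # 1000 (20, 19, 1000) 120 30
```

Each batch from the dataloader should therefore carry a `signal` tensor of shape `(4, 19, 1000)`.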

In [None]:
def create_dummy_fif_files(base_path, class_name, subject_ids, n_epochs_per_subject=20, sfreq=200, n_channels=19, epoch_duration=5):
    """
    Creates dummy .fif files for a given class and subject IDs directly in base_path.
    Each file will contain random EEG-like data.
    """
    # Files go into base_path, class information is in the filename
    os.makedirs(base_path, exist_ok=True) # Ensure base_path exists
    n_times = int(sfreq * epoch_duration)
    
    created_files = []
    for subj_id in subject_ids:
        # Create random data: (n_epochs, n_channels, n_times)
        # Gaussian noise with a standard deviation of 10 µV (MNE stores EEG data in volts)
        data = np.random.randn(n_epochs_per_subject, n_channels, n_times) * 10e-6 
        
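        # Create MNE info object. With 19 channels, use the standard 10-20 labels so that
        # set_montage() can assign electrode positions; generic names such as 'EEG 01'
        # match no montage entry and would be silently skipped.
        standard_1020_19 = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T7', 'C3', 'Cz',
                            'C4', 'T8', 'P7', 'P3', 'Pz', 'P4', 'P8', 'O1', 'O2']
        ch_names = standard_1020_19 if n_channels == 19 else [f'EEG {i+1:02}' for i in range(n_channels)]
        ch_types = ['eeg'] * n_channels
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
        info.set_montage('standard_1020', on_missing='ignore') # Positions for channels found in the montage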
        
        # Create MNE EpochsArray
        # Events are simplified for this dummy data
        events = np.array([[i, 0, 1] for i in range(n_epochs_per_subject)])
        event_id = {'dummy_event': 1}
        epochs = mne.EpochsArray(data, info, events=events, event_id=event_id, tmin=0, verbose=False)
        
        # Filename convention: CLASSNAME_SUBJECTID_eeg.fif (e.g., DEMENTIA_NBS001_eeg.fif)
        file_name = f"{class_name.upper()}_{subj_id}_eeg.fif" 
        file_path = os.path.join(base_path, file_name) # Save directly in base_path
        epochs.save(file_path, overwrite=True, verbose=False)
        created_files.append(file_path)
        print(f"Created: {file_path}")
    return base_path, created_files # Return base_path as it's the location of files

We will create a small number of dummy subjects for each class:
- Dementia: 2 subjects
- Mild Cognitive Impairment (MCI): 2 subjects
- Healthy Controls (CONTROL): 2 subjects

In a real scenario, you would have many more subjects per class (e.g., Dementia: 216, MCI: 318, CONTROL: 433, as per the original dataset description). When `condition_on_class_label=True`, the `FIFDataLoader` uses the first part of the filename (before the first `_`) as the class label; the subject ID is the second part. For example, `DEMENTIA_NBS001_eeg.fif`, as created by our function, yields class `DEMENTIA` and subject `NBS001`.
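The naming convention can be illustrated with a small parser. `parse_fif_filename` and `build_int_mapping` are hypothetical helpers written for this notebook, not `FIFDataLoader`'s actual code. In particular, assigning integer IDs in sorted order is an assumption; always read the real mapping from the dataset (e.g., `label_to_int_id`) rather than reconstructing it.

```python
import os

def parse_fif_filename(path):
    """Split 'CLASS_SUBJECT_rest.fif' into (class_label, subject_id).

    Illustrative only: mirrors the convention described above.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    parts = stem.split("_")
    if len(parts) < 2:
        raise ValueError(f"Filename does not follow CLASS_SUBJECT_... convention: {path}")
    return parts[0], parts[1]

def build_int_mapping(values):
    """Map each unique string to a contiguous integer ID (sorted for determinism)."""
    return {v: i for i, v in enumerate(sorted(set(values)))}

files = ["DEMENTIA_NBS001_eeg.fif", "MCI_NBS003_eeg.fif",
         "CONTROL_NBS005_eeg.fif", "DEMENTIA_NBS002_eeg.fif"]
parsed = [parse_fif_filename(f) for f in files]
label_to_int_id = build_int_mapping(c for c, _ in parsed)
subject_to_int_id = build_int_mapping(s for _, s in parsed)
print(label_to_int_id)    # {'CONTROL': 0, 'DEMENTIA': 1, 'MCI': 2}
print(subject_to_int_id)
```

Because the split is on `_`, class names and subject IDs must not themselves contain underscores.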

In [None]:
# Base path for dummy data
DUMMY_DATA_PARENT_DIR = "dummy_eeg_data_notebook" # Kept separate from any test-suite data directory

# Clean up old dummy data if it exists
if os.path.exists(DUMMY_DATA_PARENT_DIR):
    shutil.rmtree(DUMMY_DATA_PARENT_DIR)
os.makedirs(DUMMY_DATA_PARENT_DIR)

# Define subjects for each class
class_subjects_map = {
    "DEMENTIA": ["NBS001", "NBS002"], # Using "NBS" prefix for "Notebook Subject"
    "MCI": ["NBS003", "NBS004"],
    "CONTROL": ["NBS005", "NBS006"]
}

all_created_fif_files = []
for class_label, subject_list in class_subjects_map.items():
    # create_dummy_fif_files saves all files directly into DUMMY_DATA_PARENT_DIR
    _returned_path, created_files_for_class = create_dummy_fif_files(DUMMY_DATA_PARENT_DIR, class_label, subject_list)
    all_created_fif_files.extend(created_files_for_class)

print(f"\nAll dummy .fif files created in parent directory: {os.path.abspath(DUMMY_DATA_PARENT_DIR)}")
print(f"Total files created: {len(all_created_fif_files)}")

# For loading, we can point FIFDataLoader to DUMMY_DATA_PARENT_DIR.
# FIFDataLoader will find all *.fif files directly in this directory.

## 3. Loading Data: Unconditional Case
First, let's load the data without any conditioning on subject ID or class label. We point the dataloader to `DUMMY_DATA_PARENT_DIR`, which holds all the `.fif` files in a single flat directory; class and subject information lives in the filenames, not in subdirectories.

In [None]:
# Configuration for unconditional loading
unconditional_config = {
    'file_path': DUMMY_DATA_PARENT_DIR, # Directory containing the .fif files for all classes
    'n_epochs': 20, # Number of epochs to load from each file
    'condition_on_subject_id': False,
    'condition_on_class_label': False
}

print(f"Loading unconditional data from: {unconditional_config['file_path']}")
unconditional_dataset = FIFDataLoader(**unconditional_config)
unconditional_dataloader = DataLoader(unconditional_dataset, batch_size=4, shuffle=True)

print(f"Total epochs loaded (unconditional): {len(unconditional_dataset)}")
print(f"Number of batches: {len(unconditional_dataloader)}")

# Fetch and inspect a batch
try:
    batch_unconditional = next(iter(unconditional_dataloader))
    print(f"Batch signal shape: {batch_unconditional['signal'].shape}") # Expected: (batch_size, n_channels, n_times)
    print(f"Keys in batch: {batch_unconditional.keys()}")
    
    # Plot a sample from the batch
    sample_signal_unconditional = batch_unconditional['signal'][0].numpy()
    plt.figure(figsize=(10, 4))
    plt.plot(sample_signal_unconditional[0, :]) # Plot first channel
    plt.title("Sample EEG Signal (Unconditional - First Channel)")
    plt.xlabel("Time Points")
    plt.ylabel("Amplitude (Simulated)")
    plt.show()
except StopIteration:
    print("Dataloader is empty. Ensure dummy files were created and DUMMY_DATA_PARENT_DIR is correct.")

## 4. Loading Data: Conditional on Class Label
Now, let's configure the dataloader to be aware of class labels. This is essential for training conditional DDPMs. The `FIFDataLoader` extracts class labels from the first part of the filename (e.g., `DEMENTIA` from `DEMENTIA_NBS001_eeg.fif`).

In [None]:
# Configuration for class-conditional loading
class_conditional_config = {
    'file_path': DUMMY_DATA_PARENT_DIR,
    'n_epochs': 20,
    'condition_on_subject_id': False, # Can also be True
    'condition_on_class_label': True
}

print(f"Loading class-conditional data from: {class_conditional_config['file_path']}")
class_conditional_dataset = FIFDataLoader(**class_conditional_config)
class_conditional_dataloader = DataLoader(class_conditional_dataset, batch_size=4, shuffle=True)

print(f"Total epochs loaded (class-conditional): {len(class_conditional_dataset)}")
print(f"Number of batches: {len(class_conditional_dataloader)}")
    
# Inspect label mapping
if hasattr(class_conditional_dataset, 'label_to_int_id'):
    print(f"Class label to integer ID mapping: {class_conditional_dataset.label_to_int_id}")

# Fetch and inspect a batch
try:
    batch_class_conditional = next(iter(class_conditional_dataloader))
    print(f"Batch signal shape: {batch_class_conditional['signal'].shape}")
    print(f"Keys in batch: {batch_class_conditional.keys()}")
    print(f"Batch class labels: {batch_class_conditional['class_label']}")
    print(f"Batch class labels shape: {batch_class_conditional['class_label'].shape}")

    # Plot a sample and show its class
    sample_idx = 0
    sample_signal_class = batch_class_conditional['signal'][sample_idx].numpy()
    sample_label_int = batch_class_conditional['class_label'][sample_idx].item()
    
    # Reverse map integer label to string label for display
    if hasattr(class_conditional_dataset, 'label_to_int_id'):
        int_to_label_id = {v: k for k, v in class_conditional_dataset.label_to_int_id.items()}
        sample_label_str = int_to_label_id.get(sample_label_int, 'Unknown')
    else:
        sample_label_str = 'Unknown (mapping not found)'
    
    plt.figure(figsize=(10, 4))
    plt.plot(sample_signal_class[0, :]) # Plot first channel
    plt.title(f"Sample EEG Signal (Class: {sample_label_str} [{sample_label_int}] - First Channel)")
    plt.xlabel("Time Points")
    plt.ylabel("Amplitude (Simulated)")
    plt.show()
except StopIteration:
    print("Dataloader is empty. Ensure dummy files were created and DUMMY_DATA_PARENT_DIR is correct.")

## 5. Loading Data: Conditional on Subject ID
Similarly, we can condition on subject IDs. The `FIFDataLoader` extracts subject IDs from the second part of the filename (e.g., `NBS001` from `DEMENTIA_NBS001_eeg.fif`).

In [None]:
# Configuration for subject-conditional loading
subject_conditional_config = {
    'file_path': DUMMY_DATA_PARENT_DIR,
    'n_epochs': 20,
    'condition_on_subject_id': True,
    'condition_on_class_label': False # Can also be True
}

print(f"Loading subject-conditional data from: {subject_conditional_config['file_path']}")
subject_conditional_dataset = FIFDataLoader(**subject_conditional_config)
subject_conditional_dataloader = DataLoader(subject_conditional_dataset, batch_size=4, shuffle=True)

print(f"Total epochs loaded (subject-conditional): {len(subject_conditional_dataset)}")
print(f"Number of batches: {len(subject_conditional_dataloader)}")
    
# Inspect subject ID mapping
if hasattr(subject_conditional_dataset, 'subject_str_to_int_id'):
    print(f"Subject string to integer ID mapping: {subject_conditional_dataset.subject_str_to_int_id}")

# Fetch and inspect a batch
try:
    batch_subject_conditional = next(iter(subject_conditional_dataloader))
    print(f"Batch signal shape: {batch_subject_conditional['signal'].shape}")
    print(f"Keys in batch: {batch_subject_conditional.keys()}")
    print(f"Batch subject IDs: {batch_subject_conditional['subject_id']}")
    print(f"Batch subject IDs shape: {batch_subject_conditional['subject_id'].shape}")

    # Plot a sample and show its subject ID
    sample_idx = 0
    sample_signal_subj = batch_subject_conditional['signal'][sample_idx].numpy()
    sample_subj_id_int = batch_subject_conditional['subject_id'][sample_idx].item()
    
    # Reverse map integer ID to string ID for display
    if hasattr(subject_conditional_dataset, 'subject_str_to_int_id'):
        int_to_subj_id_str = {v: k for k, v in subject_conditional_dataset.subject_str_to_int_id.items()}
        sample_subj_id_str = int_to_subj_id_str.get(sample_subj_id_int, 'Unknown')
    else:
        sample_subj_id_str = 'Unknown (mapping not found)'
        
    plt.figure(figsize=(10, 4))
    plt.plot(sample_signal_subj[0, :]) # Plot first channel
    plt.title(f"Sample EEG Signal (Subject ID: {sample_subj_id_str} [{sample_subj_id_int}] - First Channel)")
    plt.xlabel("Time Points")
    plt.ylabel("Amplitude (Simulated)")
    plt.show()
except StopIteration:
    print("Dataloader is empty. Ensure dummy files were created and DUMMY_DATA_PARENT_DIR is correct.")

## 6. Next Steps: Training a DDPM and Simulating New Signals

With the `FIFDataLoader` set up and capable of providing EEG data in the required format (and optionally class/subject labels), the next stage involves training a Denoising Diffusion Probabilistic Model (DDPM) and then using it to simulate new EEG signals. This is particularly powerful for data augmentation, such as balancing datasets with underrepresented classes.

### Conceptual DDPM Training

The core idea is to train a model that learns to reverse a noise process. The `signal` tensor from our `FIFDataLoader` (e.g., `batch['signal']`) serves as the clean data input to this process.
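For reference, the standard DDPM formulation is summarized below. `ntd`'s `DiffusionModel` may use a different noise schedule or parameterization, so treat this as background rather than a description of its internals. Here $\mathbf{x}_0$ is one clean epoch from `batch['signal']` (shape $19 \times 1000$), $\beta_t$ is the noise schedule, and $\boldsymbol\epsilon_\theta$ is the network's noise prediction:

$$
\begin{aligned}
q(\mathbf{x}_t \mid \mathbf{x}_0) &= \mathcal{N}\!\left(\mathbf{x}_t;\ \sqrt{\bar\alpha_t}\,\mathbf{x}_0,\ (1-\bar\alpha_t)\,\mathbf{I}\right), \qquad \bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s),\\
\mathcal{L}(\theta) &= \mathbb{E}_{\mathbf{x}_0,\,t,\,\boldsymbol\epsilon \sim \mathcal{N}(\mathbf{0},\mathbf{I})}\left[\left\lVert \boldsymbol\epsilon - \boldsymbol\epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon,\ t\right)\right\rVert^2\right].
\end{aligned}
$$

Training minimizes $\mathcal{L}$; generation applies the learned denoising steps in reverse, starting from pure noise $\mathbf{x}_T \sim \mathcal{N}(\mathbf{0},\mathbf{I})$. A conditional model simply receives the label $c$ as an extra input, $\boldsymbol\epsilon_\theta(\mathbf{x}_t, t, c)$.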

#### A. Unconditional Model Training

If the goal is to generate diverse EEG signals representative of the entire dataset distribution (without targeting specific classes or subjects):

1.  **Data:** Use `FIFDataLoader` with `condition_on_class_label=False` and `condition_on_subject_id=False`.
2.  **Network Architecture:**
    *   A good starting point could be adapting a configuration like `conf/network/ada_conv_ner.yaml`.
    *   **Key parameters to adjust:**
        *   `signal_channel`: Set this to **19** for our EEG data.
        *   `in_kernel_size`, `slconv_kernel_size`, `num_scales`: These may need tuning based on your data's specific characteristics (e.g., complexity, temporal dependencies). Refer to the Q&A email for guidance on these.
3.  **Model:** The DDPM would typically be an instance of a class like `ntd.diffusion_model.DiffusionModel`, which wraps the chosen network (e.g., `ntd.networks.AdaConv`).
4.  **Training Loop:** This involves:
    *   An optimizer (e.g., AdamW).
    *   Iteratively feeding batches of signals to the `DiffusionModel`.
    *   Calculating and minimizing the diffusion loss, which trains the network to predict and remove the added noise at each step of the diffusion process.
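The loop described above can be sketched in plain PyTorch. This is a minimal illustration of the standard epsilon-prediction DDPM objective, not `ntd`'s implementation: `TinyDenoiser` is a hypothetical stand-in for `AdaConv`, the linear β schedule and 1000 steps are assumptions, and random tensors replace `batch['signal']`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for the denoiser (NOT ntd.networks.AdaConv).

    Predicts the noise added to a (channels, time) signal at diffusion step t.
    """
    def __init__(self, n_channels=19, hidden=32, n_steps=1000):
        super().__init__()
        self.t_embed = nn.Embedding(n_steps, hidden)   # learned timestep embedding
        self.inp = nn.Conv1d(n_channels, hidden, kernel_size=5, padding=2)
        self.out = nn.Conv1d(hidden, n_channels, kernel_size=5, padding=2)

    def forward(self, x, t):
        h = self.inp(x) + self.t_embed(t).unsqueeze(-1)  # broadcast embedding over time
        return self.out(torch.relu(h))

n_steps = 1000
betas = torch.linspace(1e-4, 0.02, n_steps)    # assumed linear noise schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

def diffusion_loss(model, x0):
    """Epsilon-prediction DDPM loss for a batch x0 of shape (B, C, T)."""
    t = torch.randint(0, n_steps, (x0.shape[0],))
    eps = torch.randn_like(x0)
    a = alpha_bars[t].view(-1, 1, 1)
    xt = a.sqrt() * x0 + (1 - a).sqrt() * eps   # forward (noising) process
    return nn.functional.mse_loss(model(xt, t), eps)

model = TinyDenoiser()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

signals = torch.randn(4, 19, 1000)   # stand-in for batch['signal']
for step in range(3):
    loss = diffusion_loss(model, signals)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss = {loss.item():.4f}")
```

Because the forward process adds unit-variance noise, the data should be on a comparable scale. The dummy signals in this notebook are on the order of 1e-5 (volts), so if `FIFDataLoader` does not already normalize them, standardize each channel before training.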

#### B. Class-Conditional Model Training (for Data Augmentation)

To generate data for *specific* classes (e.g., to augment the 'DEMENTIA' class):

1.  **Data:** Use `FIFDataLoader` with `condition_on_class_label=True`. This provides both `signal` and `class_label` tensors in each batch.
2.  **Network Architecture:**
    *   A starting point could be `conf/network/ada_conv_tycho.yaml`.
    *   **Key parameters to adjust:**
        *   `signal_channel`: Set to **19**.
        *   `cond_dim`: This is crucial. Set it to the **number of unique classes** in your dataset (e.g., 3 for 'DEMENTIA', 'MCI', 'CONTROL'). The `FIFDataLoader` makes the integer-mapped class labels available (e.g., in `class_conditional_dataset.label_to_int_id`).
        *   Other parameters like `in_kernel_size`, `slconv_kernel_size` as needed.
3.  **Model:** The `DiffusionModel` (wrapping a network like `AdaConv`) would be configured to accept these class labels as conditional input. The network internally uses this conditional information (often via adaptive layer normalization or similar mechanisms) to guide the denoising process.
4.  **Training Loop:** Similar to unconditional training, but the `class_label` tensor is passed to the model along with the `signal` tensor during each training step.
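One common way to feed integer class labels to a convolutional denoiser is to one-hot encode them and repeat the vector along time, which yields exactly `cond_dim` = number of classes conditioning channels. This sketch assumes the network expects a conditioning tensor of shape `(batch, cond_dim, n_times)`; the format `ntd`'s networks actually require may differ, so check the network code before relying on it.

```python
import torch
import torch.nn.functional as F

def labels_to_cond(labels, num_classes, n_times):
    """One-hot encode integer labels and repeat them along the time axis.

    Returns a float tensor of shape (batch, num_classes, n_times).
    """
    one_hot = F.one_hot(labels.long(), num_classes=num_classes).float()  # (B, num_classes)
    return one_hot.unsqueeze(-1).expand(-1, -1, n_times)                 # (B, num_classes, T)

# Example: a batch of 4 class labels, 3 classes, 1000 time points per epoch
class_labels = torch.tensor([0, 2, 1, 0])
cond = labels_to_cond(class_labels, num_classes=3, n_times=1000)
print(cond.shape)  # torch.Size([4, 3, 1000])
```

One-hot channels grow linearly with the number of conditions; for subject conditioning with hundreds of subjects, a learned `nn.Embedding` is the usual alternative.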

#### C. Subject-Conditional Model Training (Brief Mention)

For generating data specific to individual subjects:

1.  **Data:** Use `FIFDataLoader` with `condition_on_subject_id=True`.
2.  **Network & Model:** Similar to class-conditional, but `cond_dim` in the network configuration would correspond to the number of unique subjects. The `subject_id` tensor would be used as the condition. This can be useful for generating more data for subjects with limited recordings.

### Simulating New EEG Signals

Once your DDPM is trained:

1.  **Unconditional Model:**
    *   Call the model's `sample(num_samples)` method (or a similar generation function).
    *   This will generate `num_samples` new EEG epochs, each reflecting the general characteristics of the training data.
2.  **Class-Conditional Model:**
    *   Call `sample(num_samples, condition_labels)` (actual signature might vary).
    *   `condition_labels` would be a tensor of integer class IDs for which you want to generate data. For example, to generate 100 new 'DEMENTIA' epochs, you'd pass the integer ID corresponding to 'DEMENTIA' repeated 100 times.
    *   This is the key to targeted data augmentation for balancing classes.
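Class-targeted generation can be sketched with standard DDPM ancestral sampling. Everything here is illustrative: `dummy_eps_model` is a hypothetical placeholder that always predicts zero noise (substitute your trained conditional network), the 50-step linear schedule is an assumption chosen so the cell runs quickly, and the example `label_to_int_id` should be replaced by the mapping exposed by your `FIFDataLoader`. The per-class counts reuse the subject numbers quoted earlier purely to illustrate the balancing arithmetic; for epoch-level balancing, count epochs per class instead.

```python
import torch

torch.manual_seed(0)

n_steps = 50                                    # small, so the sketch runs quickly
betas = torch.linspace(1e-4, 0.02, n_steps)      # assumed linear noise schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def sample(eps_model, labels, n_channels=19, n_times=1000):
    """DDPM ancestral sampling: one generated epoch per entry in `labels`."""
    x = torch.randn(len(labels), n_channels, n_times)        # x_T ~ N(0, I)
    for t in reversed(range(n_steps)):
        t_batch = torch.full((len(labels),), t, dtype=torch.long)
        eps = eps_model(x, t_batch, labels)                   # predicted noise
        # Mean of p(x_{t-1} | x_t) under the epsilon parameterization
        mean = (x - betas[t] / (1.0 - alpha_bars[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # sigma_t^2 = beta_t
        else:
            x = mean                                          # no noise on the final step
    return x

def samples_to_balance(counts):
    """How many extra items each class needs to match the largest class."""
    target = max(counts.values())
    return {label: target - n for label, n in counts.items()}

# Hypothetical placeholder: a trained conditional network goes here
dummy_eps_model = lambda x, t, labels: torch.zeros_like(x)

# Example mapping; read the real one from class_conditional_dataset.label_to_int_id
label_to_int_id = {"CONTROL": 0, "DEMENTIA": 1, "MCI": 2}
deficits = samples_to_balance({"DEMENTIA": 216, "MCI": 318, "CONTROL": 433})
print(deficits)  # {'DEMENTIA': 217, 'MCI': 115, 'CONTROL': 0}

# Repeat the integer ID for 'DEMENTIA' once per sample to generate
target = torch.full((deficits["DEMENTIA"],), label_to_int_id["DEMENTIA"], dtype=torch.long)
synthetic = sample(dummy_eps_model, target)
print(synthetic.shape)  # torch.Size([217, 19, 1000])
```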

### Configuration Files

*   Remember to set up your experiment using Hydra configuration files.
*   Start with `conf/dataset/fif_data_example.yaml` for the dataset component, ensuring `file_path` points to your data and conditional flags are set correctly.
*   Create or adapt network configuration files (e.g., from `conf/network/ada_conv_ner.yaml` or `ada_conv_tycho.yaml`) adjusting `signal_channel`, `cond_dim` (if conditional), and other hyperparameters like `in_kernel_size`, `slconv_kernel_size`, `num_scales` as discussed in project documentation or Q&A.
*   Set up the main experiment configuration file that ties together the dataset, network, diffusion, and training parameters.
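A frequent source of errors is a network config that disagrees with the dataset config. The hypothetical helper below derives `signal_channel` and `cond_dim` from the data instead of hard-coding them. The key names mirror the parameters discussed in this notebook, but `ntd`'s actual config schema may differ, and how to combine class and subject conditioning is deliberately left open.

```python
def derive_network_overrides(n_channels, dataset_cfg, label_to_int_id=None, subject_str_to_int_id=None):
    """Return network settings that must agree with the dataset settings (hypothetical)."""
    overrides = {"signal_channel": n_channels}
    by_class = dataset_cfg.get("condition_on_class_label", False)
    by_subject = dataset_cfg.get("condition_on_subject_id", False)
    if by_class and by_subject:
        # Combining both conditions depends on the network design; not covered here.
        raise NotImplementedError("Decide how class and subject conditions are combined.")
    if by_class:
        overrides["cond_dim"] = len(label_to_int_id)        # number of unique classes
    elif by_subject:
        overrides["cond_dim"] = len(subject_str_to_int_id)  # number of unique subjects
    return overrides

dataset_cfg = {
    "file_path": "dummy_eeg_data_notebook",
    "n_epochs": 20,
    "condition_on_subject_id": False,
    "condition_on_class_label": True,
}
# Example mapping; in practice use class_conditional_dataset.label_to_int_id
label_to_int_id = {"CONTROL": 0, "DEMENTIA": 1, "MCI": 2}
overrides = derive_network_overrides(19, dataset_cfg, label_to_int_id=label_to_int_id)
print(overrides)  # {'signal_channel': 19, 'cond_dim': 3}
```

With Hydra, such values could be passed as command-line overrides (e.g., `network.cond_dim=3`, assuming the network config group is named `network`).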

In [None]:
# --- Conceptual Placeholder Code Snippets for Model Training ---
# Note: These are highly simplified and illustrative. 
# Actual implementation depends on your project's training scripts and Hydra setup.

# --- Unconditional Model Example ---
print("\n--- Conceptual Unconditional Model Setup ---")
# 1. Configuration (Illustrative - assuming Hydra is used elsewhere to load these)
#    Typically, you'd load a main config file that composes dataset, network, diffusion configs.
#    For example, if you have 'conf/experiments/my_unconditional_exp.yaml':
#    # import hydra
#    # from omegaconf import OmegaConf
#    # @hydra.main(config_path="../../conf", config_name="experiments/my_unconditional_exp.yaml") # Adjust path as needed
#    # def my_app(cfg):
#    #    dataset_cfg = cfg.dataset 
#    #    network_cfg = cfg.network 
#    #    diffusion_cfg = cfg.diffusion

# 2. Dataloader (using settings from this notebook for unconditional case)
#    (Make sure DUMMY_DATA_PARENT_DIR is still populated if running this cell standalone)
if os.path.exists(DUMMY_DATA_PARENT_DIR):
    unconditional_ds_placeholder = FIFDataLoader(
        file_path=DUMMY_DATA_PARENT_DIR, 
        n_epochs=20, 
        condition_on_class_label=False, 
        condition_on_subject_id=False
    )
    unconditional_loader_placeholder = DataLoader(unconditional_ds_placeholder, batch_size=4, shuffle=True)
    print(f"Unconditional DataLoader ready: {len(unconditional_ds_placeholder)} samples.")

    # 3. Network & Diffusion Model (Conceptual - actual classes from ntd.*)
    #    from ntd.networks import AdaConv # Example network
    #    from ntd.diffusion_model import DiffusionModel # Example model
    #    # Assuming network_cfg and diffusion_cfg are OmegaConf dicts from Hydra or manually defined
    #    # network_cfg_dict = {'signal_channel': 19, 'in_kernel_size': 17, ...} # Simplified example
    #    # diffusion_cfg_dict = {'num_diff_steps': 1000, ...} # Simplified example
    #    # network = AdaConv(**network_cfg_dict)
    #    # model = DiffusionModel(network=network, **diffusion_cfg_dict) 
    #    print(f"Conceptual Unconditional DiffusionModel instantiated (placeholder). Actual instantiation requires full configs.")
    #    # Training loop would follow:
    #    # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    #    # for epoch in range(num_epochs):
    #    #     for batch in unconditional_loader_placeholder:
    #    #         signals = batch['signal']
    #    #         loss = model.compute_loss(signals) # Or model(signals) depending on API
    #    #         optimizer.zero_grad()
    #    #         loss.backward()
    #    #         optimizer.step()
else:
    print(f"Dummy data directory {DUMMY_DATA_PARENT_DIR} not found. Skipping unconditional conceptual setup.")


# --- Class-Conditional Model Example ---
print("\n--- Conceptual Class-Conditional Model Setup ---")
# 1. Configuration (Illustrative)
#    # @hydra.main(config_path="../../conf", config_name="experiments/my_conditional_exp.yaml") 
#    # def my_conditional_app(cfg_cond):
#    #    dataset_cfg_cond = cfg_cond.dataset
#    #    network_cfg_cond = cfg_cond.network # Ensure cond_dim is set (e.g., to 3 for our classes)
#    #    diffusion_cfg_cond = cfg_cond.diffusion

# 2. Dataloader (using settings from this notebook for class-conditional case)
if os.path.exists(DUMMY_DATA_PARENT_DIR):
    class_conditional_ds_placeholder = FIFDataLoader(
        file_path=DUMMY_DATA_PARENT_DIR, 
        n_epochs=20, 
        condition_on_class_label=True,
        condition_on_subject_id=False
    )
    class_conditional_loader_placeholder = DataLoader(class_conditional_ds_placeholder, batch_size=4, shuffle=True)
    print(f"Class-Conditional DataLoader ready: {len(class_conditional_ds_placeholder)} samples.")
    if hasattr(class_conditional_ds_placeholder, 'label_to_int_id'):
        print(f"Class mapping: {class_conditional_ds_placeholder.label_to_int_id}")
    
    # Fetch a sample batch to show structure
    try:
        first_batch_cond = next(iter(class_conditional_loader_placeholder))
        signals_cond = first_batch_cond['signal']
        labels_cond = first_batch_cond['class_label']
        print(f"Sample conditional batch - signals shape: {signals_cond.shape}, labels: {labels_cond}")
    except StopIteration:
        print("Conditional loader is empty (perhaps DUMMY_DATA_PARENT_DIR is empty).")

    # 3. Network & Diffusion Model (Conceptual)
    #    # network_cfg_cond_dict = {'signal_channel': 19, 'cond_dim': 3, ...} # Example for 3 classes
    #    # network_cond = AdaConv(**network_cfg_cond_dict)
    #    # model_cond = DiffusionModel(network=network_cond, **diffusion_cfg_dict) # diffusion_cfg_dict from above
    #    print(f"Conceptual Class-Conditional DiffusionModel instantiated (placeholder). Actual instantiation requires full configs.")
    #    # Training loop would pass labels:
    #    # optimizer_cond = torch.optim.AdamW(model_cond.parameters(), lr=1e-4)
    #    # for epoch in range(num_epochs):
    #    #     for batch in class_conditional_loader_placeholder:
    #    #         signals = batch['signal']
    #    #         class_labels = batch['class_label']
    #    #         loss = model_cond.compute_loss(signals, cond=class_labels) # Or model(signals, cond=class_labels)
    #    #         optimizer_cond.zero_grad()
    #    #         loss.backward()
    #    #         optimizer_cond.step()
    #
    #    # Simulating data for a specific class (e.g., class_id 0 - DEMENTIA)
    #    # num_new_samples = 10
    #    # class_id_to_generate = 0 
    #    # target_class_id_tensor = torch.tensor([class_id_to_generate] * num_new_samples, dtype=torch.long)
    #    # synthetic_signals = model_cond.sample(num_samples=num_new_samples, condition_labels=target_class_id_tensor)
    #    # print(f"Conceptual: Generated {synthetic_signals.shape[0]} samples for class ID {class_id_to_generate}.")
else:
    print(f"Dummy data directory {DUMMY_DATA_PARENT_DIR} not found. Skipping class-conditional conceptual setup.")

## 7. Cleanup
Remove the dummy data directory created for this notebook.

In [None]:
if os.path.exists(DUMMY_DATA_PARENT_DIR):
    shutil.rmtree(DUMMY_DATA_PARENT_DIR)
    print(f"Cleaned up dummy data directory: {DUMMY_DATA_PARENT_DIR}")
else:
    print("Dummy data directory not found, no cleanup needed.")