In [None]:
import itertools
import os
import warnings
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
from IPython.display import display
from sklearn.decomposition import PCA
from umap import UMAP

# Global warning suppression, presumably to quiet third-party chatter (UMAP/numba).
# NOTE(review): this silences every warning for the whole kernel session, including
# UMAP's own diagnostics about parameter adjustments or fallbacks — consider narrowing.
warnings.filterwarnings(action="ignore")

# --- Input locations ---------------------------------------------------------
BASE_RESULTS_DIR = "../PSD_ANALYSIS_RESULTS"
PSD_DATA_DIR = os.path.join(BASE_RESULTS_DIR, "PSD_DATA")

# Condition labels whose epochs are kept. Labels noted as available in the data:
# pre, MI-SES, MI-IES, post, follow (TODO confirm against the saved .npz labels).
CONDITIONS = ["pre", "post", "follow"]

# --- Search target -----------------------------------------------------------
TARGET_SUBJECT_ID = "sub-01"
TARGET_FREQ_BAND = "ALL"  # one of the FREQ_BANDS keys: "ALL", "THETA", "ALPHA", "BETA"

# Inclusive (f_min, f_max) ranges in Hz used to slice the PSD frequency axis.
# "ALL" is listed for completeness; the loader skips band slicing for that key.
FREQ_BANDS = dict(
    ALL=(0, 40),
    THETA=(4, 7),
    ALPHA=(9, 13),
    BETA=(14, 35),
)

# UMAP hyperparameter grid: each key maps to the list of values to try, and the
# search runs the Cartesian product of all lists (all singletons -> one run).
# Key order matters: it fixes the order in which combinations are generated.
umap_param_grid = dict(
    # neighbourhood size and embedding dimensionality
    n_neighbors=[20],
    n_components=[100],
    # distance metric and layout of the embedding
    metric=["euclidean"],
    min_dist=[0.1],
    spread=[1.0],
    # optimisation settings
    learning_rate=[1.0],
    local_connectivity=[1],
    repulsion_strength=[1.0],
    negative_sample_rate=[5],
)

In [8]:
def load_and_filter_data(
    subject_id: str, data_dir: str, conditions: List[str], freq_band_key: str
) -> Tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None, str]:
    """Loads PSD data for a subject and filters by conditions and frequency band.

    Args:
        subject_id (str): The ID of the current subject.
        data_dir (str): Path to the directory containing the PSD data files.
        conditions (List[str]): List of conditions to include.
        freq_band_key (str): Type of the frequency band to use.

    Returns:
        Tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None, str]: A tuple containing:
            1. X_filtered: np.ndarray | None
                Filtered PSD features (epochs x features).
            2. labels_filtered: np.ndarray | None
                Corresponding condition labels.
            3. run_labels_filtered: np.ndarray | None
                Corresponding run labels
            4. current_band_info: np.ndarray | None
                A string describing the applied frequency filter.
    """
    file_path = os.path.join(data_dir, f"{subject_id}_epoch_psd_data.npz")

    if not os.path.exists(file_path):
        print(f"❌ Error: Data file not found for {subject_id} at {file_path}")
        return None, None, None, ""

    data = np.load(file_path, allow_pickle=True)
    X = data["data_for_dr"]
    labels = data["labels"]
    run_labels = data["run_labels"]
    freqs = data["freqs"]

    # Filter by conditions
    mask = np.isin(labels, conditions)
    X_filtered = X[mask]
    labels_filtered = labels[mask]
    run_labels_filtered = run_labels[mask]

    # Filter by frequency band
    current_band_info = f"Band: {freq_band_key}"
    if freq_band_key != "ALL":
        if freq_band_key not in FREQ_BANDS:
            print(
                f"❌ Error: Frequency band '{freq_band_key}' not defined in FREQ_BANDS."
            )
            return None, None, None, ""

        f_min, f_max = FREQ_BANDS[freq_band_key]
        freq_indices = np.where((freqs >= f_min) & (freqs <= f_max))[0]

        if freq_indices.size == 0:
            print(
                f"⚠️ Warning: No frequencies found in range {f_min}-{f_max} Hz. Skipping."
            )
            return None, None, None, ""

        n_freqs_all = len(freqs)
        n_channels = X_filtered.shape[1] // n_freqs_all

        # Reshape to (N_epochs, N_channels, N_freqs)
        X_reshaped = X_filtered.reshape(X_filtered.shape[0], n_channels, n_freqs_all)

        # Select band and flatten back to (N_epochs, N_features_band_only)
        X_band_selected = X_reshaped[:, :, freq_indices]
        X_filtered = X_band_selected.reshape(X_band_selected.shape[0], -1)

        current_band_info = f"Band: {freq_band_key} ({f_min}-{f_max} Hz)"

    if X_filtered.shape[0] == 0:
        print("⚠️ Warning: No data points remaining after filtering. Skipping.")
        return None, None, None, ""

    print(f"Data loaded. Filtered shape for DR: {X_filtered.shape}")
    return X_filtered, labels_filtered, run_labels_filtered, current_band_info


def _project_to_3d(embedding: np.ndarray) -> np.ndarray:
    """Project an embedding onto its leading principal components, always as 3 columns.

    PCA can return at most ``min(n_samples, n_features)`` components; when the
    embedding has fewer than 3 usable dimensions the missing columns are zero-filled
    so the 3-D plot can still be drawn.
    """
    n_keep = min(3, *embedding.shape)
    projected = PCA(n_components=n_keep).fit_transform(embedding)
    if n_keep < 3:
        padding = np.zeros((projected.shape[0], 3 - n_keep), dtype=projected.dtype)
        projected = np.hstack([projected, padding])
    return projected


def plot_umap_pca(
    X: np.ndarray,
    labels: np.ndarray,
    run_labels: np.ndarray,
    subject_id: str,
    band_info: str,
    umap_params: Dict[str, Any],
    random_state: int | None = 42,
) -> None:
    """Embed ``X`` with UMAP, reduce the embedding to 3 PCs and display a Plotly 3D scatter.

    Args:
        X (np.ndarray): Input feature array (epochs x features).
        labels (np.ndarray): Condition label for each epoch (used for colour).
        run_labels (np.ndarray): Run label for each epoch.
        subject_id (str): The ID of the current subject (shown in the title).
        band_info (str): String describing the frequency band (shown in the title).
        umap_params (Dict[str, Any]): UMAP hyperparameters; must contain
            n_neighbors, n_components, metric, min_dist, spread, learning_rate,
            local_connectivity, repulsion_strength and negative_sample_rate.
        random_state (int | None): Seed forwarded to UMAP. Defaults to 42; pass
            None to run UMAP unseeded.

    Any exception while building or fitting UMAP (including a missing key in
    ``umap_params``) is printed and the function returns without plotting, so a
    surrounding grid search can continue.
    """
    try:
        reducer = UMAP(
            n_neighbors=umap_params["n_neighbors"],
            n_components=umap_params["n_components"],
            metric=umap_params["metric"],
            min_dist=umap_params["min_dist"],
            spread=umap_params["spread"],
            learning_rate=umap_params["learning_rate"],
            local_connectivity=umap_params["local_connectivity"],
            repulsion_strength=umap_params["repulsion_strength"],
            negative_sample_rate=umap_params["negative_sample_rate"],
            random_state=random_state,
            verbose=False,
        )
        X_umap = reducer.fit_transform(X)
    except Exception as e:
        print(f"❌ UMAP failed with parameters: {umap_params}. Error: {e}")
        return

    # PCA only picks 3 display axes: the directions of largest variance within
    # the UMAP embedding (zero-padded if the embedding has fewer than 3 dims).
    X_pca_3d = _project_to_3d(X_umap)

    df = pd.DataFrame(X_pca_3d, columns=["PC 1", "PC 2", "PC 3"])
    df["Condition"] = labels
    df["Run"] = run_labels
    # Combined key so each (condition, run) pair gets its own marker symbol.
    # NOTE(review): Plotly has a limited set of 3D marker symbols and cycles
    # through them, so with many pairs symbols will repeat.
    df["Condition_Run"] = df["Condition"].astype(str) + "_" + df["Run"].astype(str)

    # Only n_neighbors and n_components are shown to keep the title short;
    # run_hyperparameter_search prints the full parameter dict per iteration.
    umap_title = (
        f"N_N={umap_params['n_neighbors']}, N_C={umap_params['n_components']}"
    )

    plot_title = (
        f"[{subject_id}] UMAP hyperparameter search<br>{band_info}<br>"
        f"UMAP: {umap_title}<br>"
    )

    fig = px.scatter_3d(
        df,
        x="PC 1",
        y="PC 2",
        z="PC 3",
        color="Condition",
        symbol="Condition_Run",
        hover_data=["Condition", "Run"],
        title=plot_title,
        opacity=0.7,
        height=700,
    )

    fig.update_traces(marker=dict(size=4))

    display(fig)


def run_hyperparameter_search(
    subject_id: str,
    data_dir: str,
    conditions: List[str],
    freq_band_key: str,
    param_grid: Dict[str, List[Any]],
) -> None:
    """Load one subject's data, then plot a UMAP+PCA view for every grid combination.

    Args:
        subject_id (str): The ID of the current subject.
        data_dir (str): Path to the directory containing the PSD data files.
        conditions (List[str]): List of conditions to include.
        freq_band_key (str): Key of the frequency band to use.
        param_grid (Dict[str, List[Any]]): UMAP hyperparameter names mapped to the
            lists of values to test. Every combination (Cartesian product, in key
            order) is tried once.

    Returns early (after printing a message) when the grid is empty or the data
    cannot be loaded. Combinations whose ``n_neighbors`` is not smaller than the
    number of epochs are skipped.
    """
    # itertools.product() of nothing yields a single empty combination, which
    # would otherwise reach the n_neighbors lookup and raise KeyError.
    if not param_grid:
        print("⚠️ Warning: Empty parameter grid. Nothing to search.")
        return

    print(f"--- Loading and filtering data for {subject_id} ({freq_band_key}) ---")

    X_filtered, labels_filtered, run_labels_filtered, band_info = load_and_filter_data(
        subject_id, data_dir, conditions, freq_band_key
    )
    if X_filtered is None:
        return

    n_samples = X_filtered.shape[0]
    keys = list(param_grid.keys())
    combinations = list(itertools.product(*param_grid.values()))
    n_total = len(combinations)

    print(f"--- Starting hyperparameter search ({n_total} combinations) ---")

    for i, combo in enumerate(combinations, start=1):
        umap_params = dict(zip(keys, combo))
        print(f"\n[Iteration {i}/{n_total}]: Running UMAP with {umap_params}")

        # Skip combinations whose neighbourhood is at least as large as the dataset.
        # A grid without "n_neighbors" is passed through; plot_umap_pca reports
        # the missing key instead of this loop crashing.
        n_neighbors = umap_params.get("n_neighbors")
        if n_neighbors is not None and n_neighbors >= n_samples:
            print(
                f"⚠️ Skipping: n_neighbors ({n_neighbors}) must be less than number of samples ({n_samples})."
            )
            continue

        plot_umap_pca(
            X_filtered,
            labels_filtered,
            run_labels_filtered,
            subject_id,
            band_info,
            umap_params,
        )

    print("\n--- Hyperparameter search complete. ---")

In [None]:
# Run the grid search for the subject and frequency band configured at the top.
run_hyperparameter_search(
    subject_id=TARGET_SUBJECT_ID,
    data_dir=PSD_DATA_DIR,
    conditions=CONDITIONS,
    freq_band_key=TARGET_FREQ_BAND,
    param_grid=umap_param_grid,
)

--- Loading and filtering data for sub-01 (ALL) ---
Data loaded. Filtered shape for DR: (114, 12800)
--- Starting hyperparameter search (10 combinations) ---

[Iteration 1/10]: Running UMAP with {'n_neighbors': 5, 'n_components': 50, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 2/10]: Running UMAP with {'n_neighbors': 5, 'n_components': 100, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 3/10]: Running UMAP with {'n_neighbors': 10, 'n_components': 50, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 4/10]: Running UMAP with {'n_neighbors': 10, 'n_components': 100, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 5/10]: Running UMAP with {'n_neighbors': 20, 'n_components': 50, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 6/10]: Running UMAP with {'n_neighbors': 20, 'n_components': 100, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 7/10]: Running UMAP with {'n_neighbors': 50, 'n_components': 50, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 8/10]: Running UMAP with {'n_neighbors': 50, 'n_components': 100, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 9/10]: Running UMAP with {'n_neighbors': 100, 'n_components': 50, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



[Iteration 10/10]: Running UMAP with {'n_neighbors': 100, 'n_components': 100, 'metric': 'euclidean', 'min_dist': 0.1, 'spread': 1.0, 'learning_rate': 1.0, 'local_connectivity': 1, 'repulsion_strength': 1.0, 'negative_sample_rate': 5}



--- Hyperparameter search complete. ---
