In [1]:
import itertools
import os
import warnings
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
from IPython.display import display
from sklearn.decomposition import PCA
from umap import UMAP

# =========================================================================
# ⚠️ CONFIGURATION SETUP
# Replace these values with the actual paths and constants from your config.py
# =========================================================================

# Base path of the project (for example ~/diplom).
BASE_DIR = os.path.expanduser("~/diplom")
# The data is expected to be stored in PSD_ANALYSIS_RESULTS/PSD_DATA under BASE_DIR.
PSD_DATA_DIR = os.path.join(BASE_DIR, "PSD_ANALYSIS_RESULTS", "PSD_DATA")

# Conditions used to filter the data.
CONDITIONS: List[str] = ["pre", "post", "follow"]

# Fixed number of PCA components for the 3D plot.
PCA_FINAL_COMPONENTS: int = 3

# Silence these warning categories so they do not clutter the notebook output.
# The filters are registered in the same order as before (UserWarning first).
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category


def load_subject_data(
    subject_id: str, data_dir: str, conditions: List[str]
) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Load one subject's PSD arrays and keep only the requested conditions.

    Reads ``<data_dir>/<subject_id>_epoch_psd_data.npz`` and extracts the
    ``data_for_dr``, ``labels`` and ``run_labels`` arrays. Rows whose label is
    not contained in ``conditions`` are dropped from all three arrays.

    Args:
        subject_id: Subject identifier used as the file-name prefix.
        data_dir: Directory that contains the ``*_epoch_psd_data.npz`` files.
        conditions: Label values to keep.

    Returns:
        ``(X, labels, run_labels)`` restricted to the selected rows, or ``None``
        when the file is missing, cannot be read/filtered, or no row matches.
        Failures are reported with ``print`` instead of being raised.
    """
    file_path = os.path.join(data_dir, f"{subject_id}_epoch_psd_data.npz")

    if not os.path.exists(file_path):
        print(f"❌ Data file not found for {subject_id}: {file_path}")
        return None

    try:
        # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code. Only load files you trust.
        # The context manager closes the underlying archive handle; the arrays
        # are read into memory on key access, so they stay valid afterwards.
        with np.load(file_path, allow_pickle=True) as data:
            X: np.ndarray = data[
                "data_for_dr"
            ]  # Flattened epoch data (N_epochs, N_features)
            labels: np.ndarray = data["labels"]
            run_labels: np.ndarray = data["run_labels"]

        # Filter data by conditions
        mask: np.ndarray = np.isin(labels, conditions)
        X_filtered: np.ndarray = X[mask]
        labels_filtered: np.ndarray = labels[mask]
        run_labels_filtered: np.ndarray = run_labels[mask]

        if X_filtered.shape[0] == 0:
            print(f"❌ No epochs found after filtering for conditions: {conditions}")
            return None

        print(
            f"✅ Data loaded for {subject_id}. Total epochs: {X_filtered.shape[0]}. Feature size: {X_filtered.shape[1]}"
        )
        return X_filtered, labels_filtered, run_labels_filtered

    except Exception as e:
        # Deliberate best-effort: missing keys, mismatched array lengths or a
        # corrupt file are reported and turned into a None result.
        print(f"❌ Error loading data for {subject_id}: {e}")
        return None



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# =========================================================================
# 1. TARGET SUBJECT
# =========================================================================
TARGET_SUBJECT_ID: str = "sub-01"

# =========================================================================
# 2. UMAP PARAMETER GRID
# Give each parameter the list of values you want to test.
# =========================================================================
umap_param_grid: Dict[str, List[Any]] = dict(
    # --- Key tuning parameters ---
    # Number of nearest neighbors
    n_neighbors=[20],
    # Intermediate dimensionality (before final PCA to 3D); e.g. [10, 50, 100]
    n_components=[100],
    # Minimum distance apart for embedded points; e.g. [0.0, 0.1, 0.5]
    min_dist=[0.1],
    # --- Parameters that affect the shape of the clusters ---
    # Distance metric; e.g. ["euclidean", "cosine"]
    metric=["euclidean"],
    # Effective scale of embedded points
    spread=[1.0],
    # --- Additional parameters (can be left at these values if not needed) ---
    learning_rate=[1.0],
    local_connectivity=[1],
    repulsion_strength=[1.0],
    negative_sample_rate=[5],
)

# =========================================================================
# 3. PCA PARAMETER GRID
# =========================================================================
pca_param_grid: Dict[str, List[Any]] = dict(
    # Normalize principal components to unit variance; e.g. [False, True]
    whiten=[False],
)

In [3]:
def run_hyperparameter_search(
    subject_id: str,
    data_root: str,
    conditions: List[str],
    umap_params: Dict[str, List[Any]],
    pca_params: Dict[str, List[Any]],
) -> None:
    """Run UMAP followed by PCA for every parameter combination and plot each result.

    The subject's data is loaded once via ``load_subject_data``. For every UMAP
    combination the data is embedded with UMAP; for every PCA combination that
    embedding is reduced to ``PCA_FINAL_COMPONENTS`` dimensions and shown as an
    interactive Plotly 3D scatter plot coloured by condition.

    Args:
        subject_id: Subject whose data file is loaded.
        data_root: Directory passed on to ``load_subject_data``.
        conditions: Condition labels to keep when loading.
        umap_params: Grid mapping UMAP keyword names to lists of values.
            ``random_state=42`` and ``verbose=False`` are used unless the grid
            provides its own values for them.
        pca_params: Grid mapping PCA keyword names to lists of values.

    Returns:
        None. Progress, skipped configurations and errors are printed; figures
        are rendered with ``display``. Returns early when loading fails.
    """
    # Load the data once; load_subject_data prints its own failure reason.
    data_result = load_subject_data(subject_id, data_root, conditions)
    if data_result is None:
        return

    X, labels, run_labels = data_result

    umap_combinations = _expand_param_grid(umap_params)
    pca_combinations = _expand_param_grid(pca_params)

    total_iterations = len(umap_combinations) * len(pca_combinations)
    print(f"Total hyperparameter combinations to test: {total_iterations}")

    count = 1
    for umap_config in umap_combinations:
        # Check if n_neighbors is compatible with data size. UMAP needs
        # n_neighbors to be strictly smaller than the number of samples;
        # otherwise it truncates the value (with only a warning, which this
        # notebook suppresses), so the plot title would misreport what was run.
        n_neigh = umap_config.get("n_neighbors", 15)  # Default is 15
        if X.shape[0] <= n_neigh:
            print(
                f"Skipping (Iter {count}): n_neighbors={n_neigh} is too large for {X.shape[0]} epochs. Remaining to skip: {total_iterations - count}"
            )
            # Every PCA variant of this UMAP config is skipped as well.
            count += len(pca_combinations)
            continue

        try:
            # Step 1: Run UMAP
            # random_state=42 keeps runs reproducible across combinations.
            # The defaults come first so a grid that sets random_state/verbose
            # overrides them instead of raising a duplicate-keyword TypeError.
            reducer = UMAP(**{"random_state": 42, "verbose": False, **umap_config})
            X_umap = reducer.fit_transform(X)
        except Exception as e:
            print(
                f"⚠️ UMAP failed for config {umap_config}: {e}. Skipping related PCA runs."
            )
            count += len(pca_combinations)
            continue

        for pca_config in pca_combinations:
            try:
                # Step 2: Run PCA (on the UMAP result)
                pca = PCA(n_components=PCA_FINAL_COMPONENTS, **pca_config)
                X_pca_3d = pca.fit_transform(X_umap)

                # Step 3: Prepare data for Plotly
                df = pd.DataFrame(X_pca_3d, columns=["PC 1", "PC 2", "PC 3"])
                df["Condition"] = labels
                df["Run"] = run_labels
                df["Condition_Run"] = (
                    df["Condition"].astype(str) + "_" + df["Run"].astype(str)
                )

                title_text = _format_search_title(
                    subject_id, count, total_iterations, umap_config, pca_config
                )

                # Step 4: Interactive 3D visualization
                fig = px.scatter_3d(
                    df,
                    x="PC 1",
                    y="PC 2",
                    z="PC 3",
                    color="Condition",
                    symbol="Condition_Run",
                    hover_data=["Condition", "Run"],
                    title=title_text,
                    opacity=0.7,
                    height=700,
                )

                fig.update_traces(marker=dict(size=4))

                # Display the figure in the notebook
                display(fig)

            except Exception as e:
                # Best-effort: one failing combination must not stop the search.
                print(f"❌ Error in Plotting/PCA (Iter {count}): {e}")

            count += 1

    print("\n--- Hyperparameter search complete. ---")


def _expand_param_grid(param_grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand a ``{name: [values]}`` grid into one keyword dict per combination."""
    names = list(param_grid.keys())
    return [
        dict(zip(names, values)) for values in itertools.product(*param_grid.values())
    ]


def _format_search_title(
    subject_id: str,
    count: int,
    total_iterations: int,
    umap_config: Dict[str, Any],
    pca_config: Dict[str, Any],
) -> str:
    """Build the HTML plot title; parameters absent from a grid show as 'default'."""
    title_text = f"[{subject_id}] UMAP/PCA Grid Search (Iter {count}/{total_iterations})<br>"
    title_text += (
        f"UMAP: N_N={umap_config.get('n_neighbors', 'default')}, "
        f"N_C={umap_config.get('n_components', 'default')}, "
        f"Met={umap_config.get('metric', 'default')}, "
        f"MinD={umap_config.get('min_dist', 'default')}<br>"
    )
    title_text += f"PCA: Whiten={pca_config.get('whiten', 'default')} (Final D={PCA_FINAL_COMPONENTS})"
    return title_text

In [11]:
# Run the UMAP -> PCA grid search for the configured subject; keyword
# arguments make explicit which setting feeds which parameter.
run_hyperparameter_search(
    subject_id=TARGET_SUBJECT_ID,
    data_root=PSD_DATA_DIR,
    conditions=CONDITIONS,
    umap_params=umap_param_grid,
    pca_params=pca_param_grid,
)

✅ Data loaded for sub-01. Total epochs: 114. Feature size: 12800
Total hyperparameter combinations to test: 2



--- Hyperparameter search complete. ---
