In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import numpy as np
import torch
from pathlib import Path
import mne
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from scipy.signal import detrend

# ---------------------------------------------------
# CONFIG
# ---------------------------------------------------
# Project root: Path("..").resolve() is the parent of the working directory,
# and .parent goes one level further up.
PROJ = Path("..").resolve().parent
MODEL_BASE   = PROJ / "models" / "04_hybrid" / "ModeloW" / "nb2_h4" / "nb2_h4"
BRAINBIT_DIR = PROJ / "data" / "brainbit"
BRAINBIT_FILE = BRAINBIT_DIR / "brainbit_MI_LR_preproc_final_for_model.npz"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Usando modelos en:", MODEL_BASE)
print("Usando BrainBit npz:", BRAINBIT_FILE)

# Seed the RNGs this cell actually uses: numpy (augmentation decisions) and
# torch (jitter shifts, channel dropout, DataLoader shuffling). Without this,
# the environment variables set at the top (PYTHONHASHSEED, which anyway only
# takes effect at interpreter start-up, and CUBLAS_WORKSPACE_CONFIG) do not
# make the fine-tuning results reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

# ---------------------------------------------------
# SWITCHES GLOBALES
# ---------------------------------------------------
# TTA time shifts in seconds (same values as used in training).
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]

# Optional notch filter for BrainBit.
BB_APPLY_NOTCH  = True        # True -> apply the notch filter
# Notch frequencies in Hz, e.g. [50.0] or [60.0]. Only frequencies below the
# Nyquist frequency (fs / 2 = 80 Hz at the expected 160 Hz) are meaningful,
# so harmonics such as 100 Hz cannot be used here.
BB_NOTCH_FREQS  = [60.0]

# Optional global per-channel z-score normalization.
BB_APPLY_ZSCORE = False        # True -> apply z-score, False -> skip it

# Optional TTA (used both for the global and the fine-tuned evaluation).
BB_USE_TTA      = True       # True -> time-shift TTA, False -> direct inference

# ---------------------------------------------------
# 0) Data augmentation (igual que en entrenamiento)
# ---------------------------------------------------
def augment_batch(xb,
                  p_jitter=0.35,
                  p_noise=0.35,
                  p_chdrop=0.15,
                  max_jitter_frac=0.03,
                  noise_std=0.03,
                  max_chdrop=1):
    """
    Apply random training-time augmentations to a batch of signals.

    Each augmentation is switched on for the whole batch with its own
    probability (drawn from ``np.random``):

    - temporal jitter: every sample is circularly shifted (``torch.roll``)
      by an independent random offset in ``[-max_shift, max_shift]``, with
      ``max_shift = int(max(1, T * max_jitter_frac))``;
    - gaussian noise: adds ``noise_std * N(0, 1)`` element-wise;
    - channel dropout: zeroes ``min(max_chdrop, C)`` randomly chosen
      channels in every sample.

    Args:
        xb: tensor of shape (B, C, T). It is NOT modified; the
            augmentations are applied to a copy.
        p_jitter, p_noise, p_chdrop: probability of applying each step.
        max_jitter_frac: maximum jitter as a fraction of T.
        noise_std: standard deviation of the additive noise.
        max_chdrop: channels to zero per sample (0 disables the step).

    Returns:
        Augmented tensor with the same shape, dtype and device as ``xb``.
    """
    B, C, T = xb.shape
    # Jitter and channel dropout write in place; work on a copy so the
    # caller's tensor (possibly a view into a dataset) is never corrupted.
    xb = xb.clone()

    # Temporal jitter: independent circular shift per sample.
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T * max_jitter_frac))
        shifts = torch.randint(
            low=-max_shift,
            high=max_shift + 1,
            size=(B,),
            device=xb.device,
        ).tolist()  # one host sync instead of one .item() per sample
        for i, sh in enumerate(shifts):
            xb[i] = torch.roll(xb[i], shifts=sh, dims=-1)

    # Additive gaussian noise.
    if np.random.rand() < p_noise:
        xb = xb + noise_std * torch.randn_like(xb)

    # Channel dropout: zero k random channels in each sample.
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0

    return xb

# ---------------------------------------------------
# 1) Cargar 6 s (1–7 s) desde el NPZ preprocesado
# ---------------------------------------------------
def load_brainbit_6s(path_npz, *, apply_notch=None, notch_freqs=None,
                     apply_zscore=None):
    """
    Load the preprocessed BrainBit NPZ and keep the central 6 s of each trial.

    The NPZ must contain ``X`` with 8 channels and 1280 samples per trial,
    stored as (N, T, C) or (N, C, T); ``y`` with integer labels; and ``fs``,
    the sampling rate (1280 samples correspond to 8 s at 160 Hz).

    Processing:
      1. keep the 1.0-7.0 s window (1 s before + 4 s MI + 1 s after),
      2. linear detrend of every channel of every trial,
      3. optional notch filter (``mne.filter.notch_filter``),
      4. optional global per-channel z-score (statistics over trials and
         time).

    Args:
        path_npz: path to the ``.npz`` file.
        apply_notch: override for ``BB_APPLY_NOTCH`` (None -> use the global).
        notch_freqs: override for ``BB_NOTCH_FREQS`` (None -> use the global).
        apply_zscore: override for ``BB_APPLY_ZSCORE`` (None -> use the
            global).

    Returns:
        ``(X6, y, fs)``: ``X6`` is a float32 array of shape (N, 8, 960) at
        160 Hz, ``y`` holds int labels, and ``fs`` is the int sampling rate.

    Raises:
        ValueError: if ``X`` is not 3-D, has no axis with 8 channels, or
            trials are not 1280 samples long.
    """
    if apply_notch is None:
        apply_notch = BB_APPLY_NOTCH
    if apply_zscore is None:
        apply_zscore = BB_APPLY_ZSCORE

    # Context manager closes the NpzFile handle. allow_pickle=True is kept from
    # the original call; pickled content can execute code, so load only
    # trusted files.
    with np.load(path_npz, allow_pickle=True) as d:
        X = d["X"]
        y = d["y"].astype(int)
        fs = int(d["fs"])

    print("\n====== CARGA BRAINBIT ======")
    print("X original:", X.shape)
    print("fs =", fs)

    if X.ndim != 3:
        raise ValueError(f"X debe ser 3D (N, C, T) o (N, T, C); shape={X.shape}.")

    # Bring X to (N, C, T).
    if X.shape[1] == 8:
        pass
    elif X.shape[2] == 8:
        X = np.transpose(X, (0, 2, 1))
    else:
        raise ValueError("X debe tener 8 canales en alguna dimensión.")

    N, C, T = X.shape
    if T != 1280:
        raise ValueError(f"Se esperaban 1280 muestras (8s a 160Hz); recibidas {T}.")

    # Window 1.0 s -> 7.0 s (6 s, 960 samples at 160 Hz).
    idx_start = int(1.0 * fs)
    idx_end = int(7.0 * fs)
    # astype makes a float64 copy. Detrending in place on a view of X would
    # truncate integer-typed inputs, and the MNE notch filter requires float64.
    X6 = X[:, :, idx_start:idx_end].astype(np.float64)
    print("Ventana 6s:", X6.shape)

    # Linear detrend of every (trial, channel) series along time, vectorized.
    X6 = detrend(X6, axis=-1, type="linear")

    # Optional notch filter.
    if apply_notch:
        freqs = BB_NOTCH_FREQS if notch_freqs is None else notch_freqs
        if freqs:
            print(f"Aplicando notch en BrainBit: freqs={freqs} Hz ...")
            X_flat = X6.reshape(N * C, -1)
            X_flat = mne.filter.notch_filter(
                X_flat,
                Fs=fs,
                freqs=freqs,
                verbose="ERROR",
            )
            X6 = X_flat.reshape(N, C, -1)

    # Optional global z-score per channel (statistics over trials and time).
    if apply_zscore:
        print("Aplicando z-score global por canal...")
        mean = X6.mean(axis=(0, 2), keepdims=True)
        std = X6.std(axis=(0, 2), keepdims=True) + 1e-6
        X6 = (X6 - mean) / std

    print("Etiquetas únicas:", np.unique(y))
    return X6.astype(np.float32), y, fs


# Load the BrainBit set once; every fold below evaluates and fine-tunes on
# these same arrays.
X_bb, y_bb, fs_bb = load_brainbit_6s(BRAINBIT_FILE)
print("Shape final BrainBit para el modelo:", X_bb.shape, y_bb.shape)

# ---------------------------------------------------
# 2) TTA por time-shift (misma filosofía que en entrenamiento)
# ---------------------------------------------------
def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    """
    Time-shift test-time augmentation: average logits over shifted copies.

    For every trial, one shifted copy is built per entry of ``shifts_s``
    (shift in samples = ``round(shift_s * sfreq)``). All copies of a trial go
    through the model in a single forward pass, and their logits are averaged.

    Shift convention (edge padding, length preserved):
      - shift > 0: drop the first ``shift`` samples and repeat the last
        sample at the end (content moves earlier in time);
      - shift < 0: drop the last ``|shift|`` samples and repeat the first
        sample at the start (content moves later in time);
      - shift == 0: the trial is used unchanged.

    Args:
        model: torch module mapping (B, C, T) -> (B, n_classes); it is put in
            eval mode.
        X: numpy array of shape (N, C, T).
        sfreq: sampling rate in Hz.
        shifts_s: iterable of shifts in seconds.
        device: torch device used for inference.

    Returns:
        numpy array of shape (N, n_classes) with the shift-averaged logits.
    """
    model.eval()
    shifts = [int(round(sh * sfreq)) for sh in shifts_s]

    def _shifted(x0, shift):
        # x0: (C, T). Return a copy shifted by `shift` samples, edge-padded.
        if shift > 0:
            return np.pad(x0[:, shift:], ((0, 0), (0, shift)), mode="edge")
        if shift < 0:
            return np.pad(x0[:, :shift], ((0, 0), (-shift, 0)), mode="edge")
        return x0

    out = []
    with torch.no_grad():
        for x0 in X:
            # (S, C, T): all shifted versions of this trial in one batch.
            batch = np.stack([_shifted(x0, s) for s in shifts], axis=0)
            xb = torch.as_tensor(batch, dtype=torch.float32, device=device)
            logits = model(xb).cpu().numpy()  # (S, n_classes)
            out.append(logits.mean(axis=0))

    return np.stack(out, axis=0)

# ---------------------------------------------------
# 3) Cargar modelo global de un fold
# ---------------------------------------------------
def load_global_model(fold: int):
    """
    Build an EEGCNNTransformer and load the global checkpoint of ``fold``.

    The checkpoint path is
    ``MODEL_BASE / f"fold{fold}" / f"model_global_fold{fold}_nb2_h4.pth"``.
    The file is passed directly to ``load_state_dict``, so it must be a plain
    state_dict.

    Args:
        fold: fold index used in the directory and file name.

    Returns:
        The model on ``DEVICE``, in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    model_path = MODEL_BASE / f"fold{fold}" / f"model_global_fold{fold}_nb2_h4.pth"
    if not model_path.exists():
        raise FileNotFoundError(f"No se encontró el modelo: {model_path}")

    # NOTE: EEGCNNTransformer is not defined in this cell; it must be defined
    # or imported beforehand. These architecture arguments must yield the
    # same parameter names/shapes as the checkpoint, otherwise the (strict)
    # load_state_dict call below raises.
    model = EEGCNNTransformer(
        n_ch=8, n_cls=2,
        d_model=128, n_heads=4, n_layers=1,
        p_drop=0.2, p_drop_encoder=0.1,
        n_dw_blocks=2, capture_attn=False,
    ).to(DEVICE)

    # Only tensors are needed (a state_dict), so refuse arbitrary pickled
    # objects in the checkpoint.
    state = torch.load(model_path, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state)
    model.eval()
    print(f"✔ Modelo fold{fold} cargado desde:\n{model_path}")
    return model

# ---------------------------------------------------
# 4) Evaluación GLOBAL (TTA opcional, sin fine-tuning)
# ---------------------------------------------------
def eval_global(model, X, y, fs):
    """
    Evaluate a pretrained model on BrainBit without any fine-tuning.

    Inference uses time-shift TTA (``time_shift_tta_logits`` with
    ``TTA_SHIFTS_S``) when ``BB_USE_TTA`` is True. Otherwise it runs a single
    forward pass over the whole ``X`` on ``DEVICE``. Prints accuracy,
    macro-F1, a per-class report and the confusion matrix.

    Args:
        model: torch module mapping (B, C, T) -> (B, n_classes) logits.
        X: numpy array of shape (N, C, T).
        y: integer labels (0 = left, 1 = right).
        fs: sampling rate in Hz (used by TTA).

    Returns:
        ``(accuracy, macro_f1, predicted_labels)``.
    """
    if BB_USE_TTA:
        logits = time_shift_tta_logits(model, X, fs, TTA_SHIFTS_S, DEVICE)
        print(">> GLOBAL (con TTA)")
    else:
        model.eval()
        with torch.no_grad():
            X_t = torch.tensor(X, dtype=torch.float32, device=DEVICE)
            logits = model(X_t).cpu().numpy()
        print(">> GLOBAL (sin TTA)")

    y_pred = logits.argmax(axis=1)

    acc = accuracy_score(y, y_pred)
    f1m = f1_score(y, y_pred, average="macro")

    print(f"ACC = {acc:.3f} | F1_macro = {f1m:.3f}")
    # Explicit labels keep the two target names valid even if only one class
    # occurs in y and y_pred (otherwise classification_report raises).
    print(classification_report(y, y_pred,
                                labels=[0, 1],
                                target_names=["left", "right"],
                                digits=3))
    print("Matriz de confusión:\n", confusion_matrix(y, y_pred, labels=[0, 1]))
    return acc, f1m, y_pred

# ---------------------------------------------------
# 5) Fine-tuning ligero (solo head) + augmentation + TTA opcional
# ---------------------------------------------------
def eval_with_ft(model, X, y, fs, epochs=10, batch=8, lr=1e-3):
    """
    Fine-tune only ``model.head`` on all BrainBit trials, then evaluate.

    Steps:
      - freeze every parameter outside ``model.head``;
      - train the head for ``epochs`` epochs with Adam (weight_decay=1e-3)
        and cross-entropy, applying ``augment_batch`` to every mini-batch;
      - evaluate on the same ``X``: with time-shift TTA if ``BB_USE_TTA`` is
        True, otherwise with one direct forward pass.

    NOTE(review): training and evaluation use the same trials (no held-out
    split), so the fine-tuned scores are optimistic.

    The model is modified in place (requires_grad flags and head weights).

    Args:
        model: torch module with a ``head`` submodule, mapping (B, C, T) to
            logits for the two classes (0 = left, 1 = right).
        X: numpy array of shape (N, C, T).
        y: integer labels in {0, 1}, shape (N,).
        fs: sampling rate in Hz (used by TTA).
        epochs, batch, lr: fine-tuning hyper-parameters.

    Returns:
        ``(accuracy, macro_f1, predicted_labels)`` on ``X`` after fine-tuning.

    Raises:
        ValueError: if ``y`` contains labels outside [0, 1].
    """
    y = np.asarray(y)
    if y.min() < 0 or y.max() > 1:
        raise ValueError("Etiquetas fuera de rango [0,1]")

    X_t = torch.tensor(X, dtype=torch.float32)
    y_t = torch.tensor(y, dtype=torch.long)
    ds = torch.utils.data.TensorDataset(X_t, y_t)
    ld = torch.utils.data.DataLoader(ds, batch_size=batch, shuffle=True)

    # Freeze the backbone: everything off, then re-enable exactly model.head
    # (a name-prefix test would also match e.g. "header.*").
    for p in model.parameters():
        p.requires_grad = False
    for p in model.head.parameters():
        p.requires_grad = True

    opt = torch.optim.Adam(model.head.parameters(), lr=lr, weight_decay=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Keep the frozen backbone in eval mode, so any BatchNorm running
    # statistics are not updated and its dropout is inactive during head
    # training. Only the head runs in train mode.
    model.eval()
    model.head.train()
    for ep in range(1, epochs + 1):
        tot = 0.0
        for xb, yb in ld:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            xb = augment_batch(xb)  # training-time augmentation

            opt.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            opt.step()
            tot += float(loss.item()) * len(yb)

        print(f"  FT epoch {ep:02d} | loss={tot/len(ds):.4f}")

    # Evaluate after fine-tuning.
    model.eval()

    if BB_USE_TTA:
        logits_ft = time_shift_tta_logits(model, X, fs, TTA_SHIFTS_S, DEVICE)
        print(">> FINE-TUNED (con augmentation + TTA)")
    else:
        with torch.no_grad():
            X_dev = torch.tensor(X, dtype=torch.float32, device=DEVICE)
            logits_ft = model(X_dev).cpu().numpy()
        print(">> FINE-TUNED (con augmentation, sin TTA)")

    y_pred = logits_ft.argmax(axis=1)

    acc = accuracy_score(y, y_pred)
    f1m = f1_score(y, y_pred, average="macro")

    print(f"ACC = {acc:.3f} | F1_macro = {f1m:.3f}")
    # Explicit labels keep the two target names valid even if only one class
    # occurs in y and y_pred (otherwise classification_report raises).
    print(classification_report(y, y_pred,
                                labels=[0, 1],
                                target_names=["left", "right"],
                                digits=3))
    print("Matriz de confusión:\n", confusion_matrix(y, y_pred, labels=[0, 1]))
    return acc, f1m, y_pred

# ---------------------------------------------------
# 6) Probar los 5 folds sobre BrainBit
# ---------------------------------------------------
results = []

for fold in range(1, 6):
    print("\n====================================")
    print(f" Fold {fold}")
    print("====================================")

    # Reload the checkpoint for every fold: eval_with_ft changes the model in
    # place (requires_grad flags and head weights).
    model = load_global_model(fold)

    # GLOBAL evaluation (with/without TTA according to BB_USE_TTA)
    acc_g, f1_g, pred_g = eval_global(model, X_bb, y_bb, fs_bb)

    # Head-only fine-tuning (augmentation + optional TTA)
    acc_ft, f1_ft, pred_ft = eval_with_ft(model, X_bb, y_bb,
                                          fs_bb,
                                          epochs=10,
                                          batch=8,
                                          lr=1e-3)

    results.append(dict(
        fold=fold,
        acc_global=acc_g, f1_global=f1_g,
        acc_ft=acc_ft,   f1_ft=f1_ft
    ))

print("\n======== RESUMEN ========")
for r in results:
    print(f"Fold {r['fold']}: Global={r['acc_global']:.3f} | FT={r['acc_ft']:.3f}"
          f" | F1 Global={r['f1_global']:.3f} | F1 FT={r['f1_ft']:.3f}")

# Mean ± population std across folds for every collected metric.
if results:
    for metric in ("acc_global", "acc_ft", "f1_global", "f1_ft"):
        vals = np.array([r[metric] for r in results])
        print(f"{metric}: {vals.mean():.3f} ± {vals.std():.3f}")
