In [None]:


# Imports required to run the code in this cell.

import os
import sys
import re
import random

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

import tensorflow as tf
sys.modules['keras'] = tf.keras  # Para que vit_keras use tf.keras internamente

from tensorflow.keras import layers
from tensorflow.keras.models import Model

# Ruta a tu proyecto (PainClassifier)
sys.path.append(
    "C:\\Users\\" + os.getlogin() +
    "\\OneDrive - Instituto Tecnologico y de Estudios Superiores de Monterrey\\PainClassifier"
)
from my_data_generator import *  # O importa solo lo que uses (recomendado)

from vit_keras import vit, utils

from sklearn.model_selection import (
    train_test_split,
    KFold,
    StratifiedKFold,
    StratifiedShuffleSplit
)

from sklearn.metrics import (
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    classification_report,
    balanced_accuracy_score,
    precision_recall_fscore_support,
    accuracy_score
)

import wandb
from wandb.keras import WandbCallback

wandb.login()


# Seed the Python, NumPy and TensorFlow random generators with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


# Weights & Biases project name and run group.
PROJECT_NAME = "Male_W7_vs_W1_ViT"
RUN_GROUP = "ViT-B16-50e-pain-only-M-W7_vs_W1"

# Side length of the square image fed to the Vision Transformer, followed by
# the batch size and the number of training epochs.
IMG_SIZE = 224
BATCH = 32
EPOCHS = 50

# Subjects used by this experiment.
male_all = [57, 60, 73, 74, 93, 94, 95, 96, 98, 99, 100]

# Select functional MRI, and the "dist" acquisition rather than "rest".
MRI_type = "func"
functional_type = "dist"
# Session numbering: 1 = baseline, 2 = week 1, 3 = week 7.
# Only week 1 and week 7 are needed here.
needed_sessions = [2, 3]

# Timing values taken from the thesis. TR is divided into times given in
# seconds, so it is presumably the repetition time in seconds (TODO confirm).
# ON_SEC / OFF_SEC are the two segment durations of one cycle and CYCLES is
# the number of cycles.
TR = 1.51069
ON_SEC = 45
OFF_SEC = 15
CYCLES = 15


def subject_id_from_path(p: str):
    """Return the subject token (e.g. ``sub-057``) embedded in a file path.

    Backslashes are converted to forward slashes before searching, and the
    first ``sub-`` followed by three digits is returned.

    Raises:
        ValueError: if no such token occurs in ``p``.
    """
    normalized = p.replace("\\", "/")
    match = re.search(r"sub-\d{3}", normalized)
    if match is None:
        raise ValueError(f"Could not find subject id in path: {p}")
    return match.group(0)


def get_session_from_path(p: str) -> int:
    """Return the session number encoded as ``ses-XX`` in a file path.

    For example a path containing ``ses-02`` yields ``2``.

    Raises:
        ValueError: if ``p`` contains no ``ses-`` followed by two digits.
    """
    match = re.search(r"ses-(\d{2})", p)
    if match is None:
        raise ValueError(f"Could not find ses-XX in path: {p}")
    session_digits = match.group(1)
    return int(session_digits)


def compute_on_indices(T, TR=TR, on_s=ON_SEC, off_s=OFF_SEC, cycles=CYCLES):
    """Return the volume indices selected as pain frames in a run of ``T`` volumes.

    The run is modelled as ``cycles`` repetitions of an ``on_s``-second segment
    followed by an ``off_s``-second segment. For every cycle one inclusive
    window of indices is kept:

    * it starts at ``ceil(t_a / TR) + 1`` where ``t_a`` is the time at which
      the cycle's first segment ends, and
    * it ends at ``int(t_b / TR)`` where ``t_b`` is the time at which the
      cycle's second segment ends.

    NOTE(review): the kept frames therefore fall inside the second
    (``off_s``-long) segment of each cycle even though the parameters are
    named on/off. The original notes say this yields the ~135 pain frames, so
    the selection is left as is; confirm the naming against the paradigm.

    Windows are clipped to the valid range ``[0, T - 1]``. A window lying
    entirely beyond the end of the run is dropped. Previously it was clamped
    to ``T - 1``, which emitted the last volume once per out-of-range cycle
    and produced index 0 for an empty run.

    Args:
        T: number of volumes in the run.
        TR: value the segment end times are divided by to get volume indices.
        on_s: duration of the first segment of each cycle.
        off_s: duration of the second segment of each cycle.
        cycles: number of cycles.

    Returns:
        1-D integer NumPy array of selected volume indices, in ascending order
        for positive durations.
    """
    # Cumulative end time of every segment: first, second, first, second, ...
    boundaries = []
    elapsed = 0
    for _ in range(cycles):
        for duration in (on_s, off_s):
            elapsed += duration
            boundaries.append(elapsed)

    # Even positions close a first segment (window start); odd positions close
    # a second segment (window end).
    starts = [int(np.ceil(t / TR) + 1) for t in boundaries[0::2]]
    ends = [int(t / TR) for t in boundaries[1::2]]

    on = []
    last_valid = T - 1
    for start, end in zip(starts, ends):
        start = max(0, start)
        end = min(end, last_valid)
        # An empty range here means the window lies outside the run.
        if end >= start:
            on.extend(range(start, end + 1))
    return np.array(on, dtype=int)

def load_masked_cropped(bold_path, mask_path):
    """Load a 4-D BOLD image, apply its brain mask, crop it and put time first.

    Both files are read with nibabel and converted to float32. The mask is
    broadcast along the last (time) axis and multiplied into the BOLD data.
    The first three axes are then cropped to ``[3:45, 4:69, 7:36]``.

    Args:
        bold_path: path to the 4-D BOLD image.
        mask_path: path to the 3-D mask image.

    Returns:
        float32 array with the last axis moved to the front. The shape is
        ``(T, 42, 65, 29)`` when the input is at least 45 x 69 x 36 along its
        first three axes.
    """
    bold_img = nib.load(bold_path)
    mask_img = nib.load(mask_path)
    volume = np.asarray(bold_img.dataobj, dtype=np.float32)
    brain = np.asarray(mask_img.dataobj, dtype=np.float32)

    masked = volume * brain[..., np.newaxis]
    cropped = masked[3:45, 4:69, 7:36, :]
    return cropped.transpose(3, 0, 1, 2)



def build_slices_for_pairs(pairs, sub_ses_to_files, window=30, return_ids=False):
    """Turn (subject, session, label) pairs into 2-D slices with labels.

    For each pair present in ``sub_ses_to_files`` the masked, cropped run is
    loaded, reduced to the frames chosen by ``compute_on_indices`` and walked
    in blocks of ``window`` volumes. Every block is flattened into 42 x 65
    images, one per volume and per index along the last spatial axis. Pairs
    without an entry in ``sub_ses_to_files`` are skipped silently.

    Args:
        pairs: iterable of ``(subject_id, session_number, class_label)``.
        sub_ses_to_files: mapping ``(subject_id, session_number)`` to
            ``(bold_path, mask_path)``.
        window: number of volumes per block.
        return_ids: also return one ``"<sub>_ses-<XX>"`` string per slice.

    Returns:
        ``(X2D, Y)`` or ``(X2D, Y, IDs)``: float32 slices of shape
        ``(N, 42, 65)``, int32 labels of shape ``(N,)`` and, if requested, an
        object array of ids of shape ``(N,)``. All are empty when nothing
        could be loaded.
    """
    slice_chunks, label_chunks, id_chunks = [], [], []
    for sub, ses, y in pairs:
        lookup = (sub, ses)
        if lookup not in sub_ses_to_files:
            continue
        bold_path, mask_path = sub_ses_to_files[lookup]
        volumes = load_masked_cropped(bold_path, mask_path)
        pain = volumes[compute_on_indices(volumes.shape[0])]

        n_pain = pain.shape[0]
        for start in range(0, n_pain, window):
            block = pain[start:start + window]
            if not block.shape[0]:
                continue
            # (t, 42, 65, 29) -> (t, 29, 42, 65) -> (t * 29, 42, 65)
            slices_all = np.moveaxis(block, -1, 1).reshape(-1, 42, 65).astype(np.float32)
            n_slices = slices_all.shape[0]
            slice_chunks.append(slices_all)
            label_chunks.append(np.full(n_slices, y, dtype=np.int32))
            if return_ids:
                id_chunks.append(np.full(n_slices, f"{sub}_ses-{ses:02d}", dtype=object))

    if slice_chunks:
        X2D = np.concatenate(slice_chunks, axis=0)
        Y = np.concatenate(label_chunks, axis=0)
    else:
        X2D = np.empty((0, 42, 65), np.float32)
        Y = np.empty((0,), np.int32)

    if not return_ids:
        return X2D, Y
    IDs = np.concatenate(id_chunks, axis=0) if id_chunks else np.empty((0,), dtype=object)
    return X2D, Y, IDs

def prep(x, y):
    """Convert one 2-D slice into the image format fed to the ViT.

    Adds a channel axis, resizes to ``IMG_SIZE x IMG_SIZE``, replicates the
    single channel to three, casts to float32 and standardizes the image with
    its own mean and standard deviation (plus 1e-6 to avoid division by zero).
    The label ``y`` is passed through unchanged.
    """
    image = tf.image.resize(tf.expand_dims(x, -1), (IMG_SIZE, IMG_SIZE))
    image = tf.cast(tf.image.grayscale_to_rgb(image), tf.float32)
    centered = image - tf.reduce_mean(image)
    spread = tf.math.reduce_std(image) + 1e-6
    return centered / spread, y


def make_model(vit_variant="b16"):
    """Build and compile a two-class classifier on a pretrained ViT backbone.

    Args:
        vit_variant: ``"b16"`` selects ``vit.vit_b16``; any other value
            selects ``vit.vit_b32``.

    Returns:
        A compiled ``tf.keras.Model`` taking ``(IMG_SIZE, IMG_SIZE, 3)`` images
        and producing a 2-way softmax. It is compiled with Adam (learning rate
        3e-5), sparse categorical cross-entropy and an accuracy metric.
    """
    build_backbone = vit.vit_b16 if vit_variant == "b16" else vit.vit_b32
    backbone = build_backbone(
        image_size=IMG_SIZE,
        pretrained=True,
        include_top=False,
        pretrained_top=False,
    )

    inputs = tf.keras.Input((IMG_SIZE, IMG_SIZE, 3))
    features = backbone(inputs)
    probabilities = tf.keras.layers.Dense(2, activation="softmax")(features)

    model = tf.keras.Model(inputs, probabilities)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(3e-5),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


from my_data_generator import FILES_and_LABELS

# Ask the project's file locator for the (bold, mask) paths of the selected
# subjects and sessions.
files = FILES_and_LABELS(male_all, needed_sessions, MRI_type, functional_type).get_mask_and_bold()
print(f"Loaded {len(files)} (subject,session) file pairs for DIST ses-02/03.")

# Index the file pairs by (subject id, session number), both parsed from the
# BOLD path.
sub_ses_to_files = {
    (subject_id_from_path(bold_path), get_session_from_path(bold_path)): (bold_path, mask_path)
    for (bold_path, mask_path) in files
}

# Keep only the subjects that have both W1 (session 2) and W7 (session 3).
subjects_all = sorted({key[0] for key in sub_ses_to_files})
usable_subjects = np.array(
    [s for s in subjects_all if all((s, wanted) in sub_ses_to_files for wanted in (2, 3))]
)
assert len(usable_subjects) > 0, "No male subjects with both W1 and W7 found."

print(f"Usable male subjects with W1 & W7: {list(usable_subjects)} (n={len(usable_subjects)})")


# Two entries per subject, in this order: W1 (session 2) -> class 0, then
# W7 (session 3) -> class 1.
pairs = [
    (sub, session, label)
    for sub in usable_subjects
    for session, label in ((2, 0), (3, 1))
]

def subset_pairs_by_subjects(pairs, subj_set):
    """Return, in original order, the pairs whose first element is in ``subj_set``."""
    kept = []
    for pair in pairs:
        if pair[0] in subj_set:
            kept.append(pair)
    return kept


# 4-fold cross-validation over subjects. Splitting by subject keeps both
# sessions of a subject inside the same partition (train, val or test).
K = 4
kf = KFold(n_splits=K, shuffle=True, random_state=SEED)


def _make_dataset(X, y, shuffle=False):
    """Build a batched, prefetched tf.data pipeline that applies ``prep``.

    With ``shuffle=True`` the raw slices are shuffled with a buffer covering
    the whole array *before* ``prep`` runs. The slices are concatenated
    session by session, so a partial buffer would produce batches dominated
    by a few sessions. Shuffling before the resize also means the buffer
    holds small 42 x 65 slices, not IMG_SIZE x IMG_SIZE x 3 float32 images.
    """
    ds = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        ds = ds.shuffle(len(X))
    return (ds.map(prep, num_parallel_calls=tf.data.AUTOTUNE)
              .batch(BATCH)
              .prefetch(tf.data.AUTOTUNE))


# Train and evaluate one model per fold.
for fold_idx, (train_idx_all, test_idx) in enumerate(kf.split(usable_subjects), start=1):
    # Drop the Keras global state left by the previous fold so that models do
    # not pile up in memory across folds.
    tf.keras.backend.clear_session()

    TEST_SUBS = usable_subjects[test_idx]
    TRAIN_SUBS = usable_subjects[train_idx_all]

    # Hold out 25% of the training subjects for validation. Every subject
    # contributes both classes, so there is no real label to stratify on;
    # the alternating 0/1 vector only satisfies StratifiedShuffleSplit's API.
    inner = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=SEED + fold_idx)
    strat_y = np.arange(len(TRAIN_SUBS)) % 2
    tr_inner_idx, val_inner_idx = next(inner.split(TRAIN_SUBS, strat_y))
    TR_SUBS = TRAIN_SUBS[tr_inner_idx]
    VAL_SUBS = TRAIN_SUBS[val_inner_idx]

    TR_PAIRS = subset_pairs_by_subjects(pairs, set(TR_SUBS))
    VAL_PAIRS = subset_pairs_by_subjects(pairs, set(VAL_SUBS))
    TE_PAIRS = subset_pairs_by_subjects(pairs, set(TEST_SUBS))

    print(f"\n=== Fold {fold_idx}/{K} ===")
    print("TEST subs :", list(TEST_SUBS))
    print("TRAIN subs:", list(TR_SUBS))
    print("VAL subs  :", list(VAL_SUBS))

    X_train, y_train = build_slices_for_pairs(TR_PAIRS, sub_ses_to_files, return_ids=False)
    X_val, y_val = build_slices_for_pairs(VAL_PAIRS, sub_ses_to_files, return_ids=False)
    X_test, y_test, ids_test = build_slices_for_pairs(TE_PAIRS, sub_ses_to_files, return_ids=True)

    print(f"Fold {fold_idx}: train slices={len(X_train)}, val slices={len(X_val)}, test slices={len(X_test)}")

    # Fail early with a clear message instead of deep inside tf.data / Keras.
    for split_name, split_X in (("train", X_train), ("val", X_val), ("test", X_test)):
        if len(split_X) == 0:
            raise RuntimeError(f"Fold {fold_idx}: no slices available for the {split_name} split.")

    ds_train = _make_dataset(X_train, y_train, shuffle=True)
    ds_val = _make_dataset(X_val, y_val)
    ds_test = _make_dataset(X_test, y_test)

    # One W&B run per fold. Using the run as a context manager finishes it
    # even when training or evaluation raises, so the next fold starts clean.
    with wandb.init(project=PROJECT_NAME,
                    group=RUN_GROUP,
                    name=f"M-W7_vs_W1-ViT-B16-fold{fold_idx}",
                    config={
                        "epochs": EPOCHS,
                        "batch_size": BATCH,
                        "img_size": IMG_SIZE,
                        "backbone": "vit_b16",
                        "split": "Male W7 (1) vs Male W1 (0), pain-only (DIST)",
                        "subjects_train": list(TR_SUBS),
                        "subjects_val": list(VAL_SUBS),
                        "subjects_test": list(TEST_SUBS),
                        "seed": SEED,
                    }) as run:

        # Fresh pretrained ViT-B/16 classifier for this fold.
        model = make_model(vit_variant="b16")

        history = model.fit(
            ds_train,
            validation_data=ds_val,
            epochs=EPOCHS,
            verbose=1,
            callbacks=[WandbCallback(save_model=False)],
        )

        # Final validation metrics after the last epoch.
        v_loss, v_acc = model.evaluate(ds_val, verbose=0)
        print(f"[Fold {fold_idx}] VAL  → acc={v_acc:.4f}  loss={v_loss:.6f}")
        wandb.log({"val/final_acc": v_acc, "val/final_loss": v_loss})

        # Slice-level predictions on the held-out subjects.
        y_prob = model.predict(ds_test, verbose=0)
        y_pred = np.argmax(y_prob, axis=1)
        y_score = y_prob[:, 1]  # probability assigned to class 1 (W7)

        acc_slice = accuracy_score(y_test, y_pred)
        bacc_slice = balanced_accuracy_score(y_test, y_pred)
        prec_m, rec_m, f1_m, _ = precision_recall_fscore_support(
            y_test, y_pred, average='macro', zero_division=0
        )
        try:
            auc_slice = roc_auc_score(y_test, y_score)
        except ValueError:
            # AUC is undefined when y_test holds a single class.
            auc_slice = float("nan")

        print(f"[Fold {fold_idx}] TEST (slice-level) → acc={acc_slice:.4f}  bal_acc={bacc_slice:.4f}  "
              f"F1_macro={f1_m:.4f}  AUC={auc_slice:.4f}")

        wandb.log({
            "test_slice/acc": acc_slice,
            "test_slice/bal_acc": bacc_slice,
            "test_slice/f1_macro": f1_m,
            "test_slice/precision_macro": prec_m,
            "test_slice/recall_macro": rec_m,
            "test_slice/roc_auc": auc_slice,
        })

        # Confusion matrix: printed, then rendered and uploaded as an image.
        cm_slice = confusion_matrix(y_test, y_pred, labels=[0, 1])
        print(f"[Fold {fold_idx}] Confusion matrix (TEST, slice-level):\n{cm_slice}")

        fig = plt.figure(figsize=(4, 4))
        try:
            plt.imshow(cm_slice, interpolation='nearest')
            plt.title('Confusion matrix (TEST, slice-level)')
            plt.colorbar()
            tick_marks = np.arange(2)
            plt.xticks(tick_marks, ['W1 (0)', 'W7 (1)'], rotation=45)
            plt.yticks(tick_marks, ['W1 (0)', 'W7 (1)'])
            for i in range(cm_slice.shape[0]):
                for j in range(cm_slice.shape[1]):
                    val = cm_slice[i, j]
                    # White text on dark (high-count) cells, black otherwise.
                    plt.text(j, i, val, ha="center", va="center",
                             color="white" if val > cm_slice.max() / 2 else "black")
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
            plt.tight_layout()
            wandb.log({"test_slice/confusion_matrix": wandb.Image(fig)})
        finally:
            plt.close(fig)


In [None]:
# ================================================
# W7 Male (1) vs W1 Female (0) — pain-only (DIST)
# 4-fold CV, 50 epochs, ViT-B/16, Stratified by class
# ================================================


# Imports required to run the code in this cell.

import os
import sys
import re
import random

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

import tensorflow as tf
sys.modules['keras'] = tf.keras  # Para que vit_keras use tf.keras internamente

from tensorflow.keras import layers
from tensorflow.keras.models import Model

# Ruta a tu proyecto (PainClassifier)
sys.path.append(
    "C:\\Users\\" + os.getlogin() +
    "\\OneDrive - Instituto Tecnologico y de Estudios Superiores de Monterrey\\PainClassifier"
)
from my_data_generator import *  # O importa solo lo que uses (recomendado)

from vit_keras import vit, utils

from sklearn.model_selection import (
    train_test_split,
    KFold,
    StratifiedKFold,
    StratifiedShuffleSplit
)

from sklearn.metrics import (
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    classification_report,
    balanced_accuracy_score,
    precision_recall_fscore_support,
    accuracy_score
)

import wandb
from wandb.keras import WandbCallback

wandb.login()



# Seed the Python, NumPy and TensorFlow random generators with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


# Weights & Biases project name and run group.
PROJECT_NAME = "W7Male_vs_W1Female_ViT"
RUN_GROUP = "ViT-B16-50e-pain-only-W7M_vs_W1F"

# Side length of the square image fed to the Vision Transformer, followed by
# the batch size and the number of training epochs.
IMG_SIZE = 224
BATCH = 32
EPOCHS = 50

# Subjects used by this experiment.
male = [57, 59, 60, 73, 74, 93, 94, 95, 96, 98, 99, 100]
female = [49, 50, 51, 52, 65, 66, 77, 78, 79, 80, 81, 82]

# Class label per subject id ("sub-XXX"): males map to 1, females to 0.
label_map_subject = {f"sub-{m:03d}": 1 for m in male}
label_map_subject.update({f"sub-{f_:03d}": 0 for f_ in female})


# Select functional MRI, the "dist" acquisition, and sessions 2 and 3.
MRI_type = "func"
functional_type = "dist"
needed_sessions = [2, 3]


# Timing values consumed by compute_on_indices: TR divides times given in
# seconds (presumably the repetition time in seconds, TODO confirm).
# ON_SEC / OFF_SEC are the two segment durations of one cycle and CYCLES is
# the number of cycles.
TR = 1.51069
ON_SEC = 45
OFF_SEC = 15
CYCLES = 15

def subject_id_from_path(p: str):
    """Extract the ``sub-XXX`` token (three digits) from a file path.

    Backslashes are converted to forward slashes before searching.

    Raises:
        ValueError: if the path holds no such token.
    """
    found = re.search(r"sub-\d{3}", p.replace("\\", "/"))
    if found:
        return found.group(0)
    raise ValueError(f"Could not find subject id in path: {p}")

def get_session_from_path(p: str) -> int:
    """Parse the two-digit session number that follows ``ses-`` in a path.

    Raises:
        ValueError: if the path has no ``ses-XX`` component.
    """
    found = re.search(r"ses-(\d{2})", p)
    if found:
        return int(found.group(1))
    raise ValueError(f"Could not find ses-XX in path: {p}")

def compute_on_indices(T, TR=TR, on_s=ON_SEC, off_s=OFF_SEC, cycles=CYCLES):
    """Return the volume indices selected as pain frames in a run of ``T`` volumes.

    The run is modelled as ``cycles`` repetitions of an ``on_s``-second segment
    followed by an ``off_s``-second segment. For every cycle one inclusive
    window of indices is kept, from ``ceil(t_a / TR) + 1`` to
    ``int(t_b / TR)``, where ``t_a`` and ``t_b`` are the end times of the
    cycle's first and second segment.

    NOTE(review): the kept frames therefore fall inside the second
    (``off_s``-long) segment of each cycle even though the parameters are
    named on/off. The selection is left unchanged; confirm the naming
    against the paradigm.

    Windows are clipped to ``[0, T - 1]``. A window lying entirely beyond the
    end of the run is dropped. Previously it was clamped to ``T - 1``, which
    emitted the last volume once per out-of-range cycle and produced index 0
    for an empty run.

    Args:
        T: number of volumes in the run.
        TR: value the segment end times are divided by to get volume indices.
        on_s: duration of the first segment of each cycle.
        off_s: duration of the second segment of each cycle.
        cycles: number of cycles.

    Returns:
        1-D integer NumPy array of selected volume indices.
    """
    # Cumulative end time of every segment: first, second, first, second, ...
    segment_ends = []
    clock = 0
    for _ in range(cycles):
        for duration in (on_s, off_s):
            clock += duration
            segment_ends.append(clock)

    # Even positions close a first segment (window start); odd positions close
    # a second segment (window end).
    starts = [int(np.ceil(t / TR) + 1) for t in segment_ends[0::2]]
    ends = [int(t / TR) for t in segment_ends[1::2]]

    on = []
    for first, last in zip(starts, ends):
        first = max(0, first)
        last = min(last, T - 1)
        # An empty range here means the window lies outside the run.
        if last >= first:
            on.extend(range(first, last + 1))
    return np.array(on, dtype=int)

def load_masked_cropped(bold_path, mask_path):
    """Read a BOLD run and its mask, zero out non-mask voxels, crop, time first.

    Both images are loaded with nibabel as float32. The mask is multiplied
    into every time point, and the first three axes are cropped to
    ``[3:45, 4:69, 7:36]``.

    Returns:
        float32 array with the last axis moved to the front. The shape is
        ``(T, 42, 65, 29)`` when the input is at least 45 x 69 x 36 along its
        first three axes.
    """
    img = nib.load(bold_path)
    mask = nib.load(mask_path)
    data = np.asarray(img.dataobj, dtype=np.float32)
    weights = np.asarray(mask.dataobj, dtype=np.float32)

    # Broadcast the 3-D mask over the trailing time axis.
    data = data * weights[..., np.newaxis]
    data = data[3:45, 4:69, 7:36, :]
    return data.transpose(3, 0, 1, 2)

def build_slices_for_pairs(pairs, sub_ses_to_files, window=30, return_ids=False):
    """Convert (subject, session, label) pairs into labelled 2-D slices.

    ``pairs`` is ``[(sub_id, ses_int, class_label), ...]``. For each pair with
    an entry in ``sub_ses_to_files`` the run is loaded, restricted to the
    frames returned by ``compute_on_indices`` (~135 frames per the original
    notes) and processed in blocks of ``window`` volumes. Each block is
    flattened into 42 x 65 images, one per volume and per index along the
    last spatial axis. Pairs without files are skipped silently.

    Returns:
        ``(X2D, Y)`` or, with ``return_ids=True``, ``(X2D, Y, IDs)``: float32
        slices ``(N, 42, 65)``, int32 labels ``(N,)`` and an object array of
        ``"<sub>_ses-<XX>"`` strings ``(N,)``. All are empty when nothing
        could be loaded.
    """
    images, labels, tags = [], [], []
    for sub, ses, y in pairs:
        file_key = (sub, ses)
        if file_key not in sub_ses_to_files:
            continue
        bold_path, mask_path = sub_ses_to_files[file_key]
        run = load_masked_cropped(bold_path, mask_path)
        pain = run[compute_on_indices(run.shape[0])]

        for offset in range(0, pain.shape[0], window):
            chunk = pain[offset:offset + window]
            if not chunk.shape[0]:
                continue
            # (t, 42, 65, 29) -> (t, 29, 42, 65) -> (t * 29, 42, 65)
            flat = np.moveaxis(chunk, -1, 1).reshape(-1, 42, 65).astype(np.float32)
            count = flat.shape[0]
            images.append(flat)
            labels.append(np.full(count, y, dtype=np.int32))
            if return_ids:
                tags.append(np.full(count, f"{sub}_ses-{ses:02d}", dtype=object))

    if images:
        X2D = np.concatenate(images, axis=0)
        Y = np.concatenate(labels, axis=0)
    else:
        X2D = np.empty((0, 42, 65), np.float32)
        Y = np.empty((0,), np.int32)

    if not return_ids:
        return X2D, Y
    IDs = np.concatenate(tags, axis=0) if tags else np.empty((0,), dtype=object)
    return X2D, Y, IDs

def prep(x, y):
    """Prepare one 2-D slice for the ViT; the label is returned untouched.

    Steps: add a channel axis, resize to ``IMG_SIZE x IMG_SIZE``, replicate
    the channel to RGB, cast to float32, then standardize with the image's
    own mean and standard deviation (1e-6 added to the denominator).
    """
    with_channel = tf.expand_dims(x, -1)
    resized = tf.image.resize(with_channel, (IMG_SIZE, IMG_SIZE))
    rgb = tf.cast(tf.image.grayscale_to_rgb(resized), tf.float32)
    mean = tf.reduce_mean(rgb)
    std = tf.math.reduce_std(rgb)
    return (rgb - mean) / (std + 1e-6), y

def make_model(vit_variant="b16"):
    """Create a compiled 2-class softmax classifier on a pretrained ViT.

    ``vit_variant == "b16"`` uses ``vit.vit_b16``; every other value uses
    ``vit.vit_b32``. The model is compiled with Adam (learning rate 3e-5),
    sparse categorical cross-entropy and an accuracy metric.
    """
    backbone_kwargs = dict(
        image_size=IMG_SIZE,
        pretrained=True,
        include_top=False,
        pretrained_top=False,
    )
    if vit_variant == "b16":
        backbone = vit.vit_b16(**backbone_kwargs)
    else:
        backbone = vit.vit_b32(**backbone_kwargs)

    inputs = tf.keras.Input((IMG_SIZE, IMG_SIZE, 3))
    outputs = tf.keras.layers.Dense(2, activation="softmax")(backbone(inputs))
    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(3e-5),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


from my_data_generator import FILES_and_LABELS

# Locate the (bold, mask) files for both males and females.
files = FILES_and_LABELS(male + female, needed_sessions, MRI_type, functional_type).get_mask_and_bold()
print(f"Loaded {len(files)} (subject,session) file pairs for DIST ses-02/03.")


# Index the file pairs by (subject id, session number), both parsed from the
# BOLD path.
sub_ses_to_files = {
    (subject_id_from_path(bold_path), get_session_from_path(bold_path)): (bold_path, mask_path)
    for (bold_path, mask_path) in files
}


# Build the (subject, session, label) list:
#   male,   session 3 (week 7) -> class 1
#   female, session 2 (week 1) -> class 0
# Subjects without the required session are left out.
male_pairs = [
    (sid, 3, 1)
    for sid in (f"sub-{m:03d}" for m in male)
    if (sid, 3) in sub_ses_to_files
]
female_pairs = [
    (sid, 2, 0)
    for sid in (f"sub-{f_:03d}" for f_ in female)
    if (sid, 2) in sub_ses_to_files
]

pairs = male_pairs + female_pairs
assert len(male_pairs) > 0 and len(female_pairs) > 0, "Missing W7 males or W1 females."


# Unique subjects and their class labels; these drive the stratified K-fold.
units = np.array([sub for sub, _, _ in pairs])
y_units = np.array([label for _, _, label in pairs])

# np.unique returns the subjects sorted plus the index of each first occurrence.
uniq, idx = np.unique(units, return_index=True)
units, y_units = units[idx], y_units[idx]

print(f"Usable subjects: {list(units)}")
print(f"Counts → class1(♂W7): {int(y_units.sum())}, class0(♀W1): {len(y_units)-int(y_units.sum())}")


def subset_pairs_by_subjects(pairs, subj_set):
    """Keep only the pairs whose subject id (element 0) belongs to ``subj_set``."""
    return list(filter(lambda pair: pair[0] in subj_set, pairs))


# Stratified 4-fold cross-validation over subjects: every fold keeps the
# male-W7 / female-W1 proportion of the full subject list.
K = 4
skf = StratifiedKFold(n_splits=K, shuffle=True, random_state=SEED)


def _make_dataset(X, y, shuffle=False):
    """Build a batched, prefetched tf.data pipeline that applies ``prep``.

    With ``shuffle=True`` the raw slices are shuffled with a buffer covering
    the whole array *before* ``prep`` runs. The slices are concatenated
    subject by subject (all class-1 subjects first), so a partial buffer
    would produce batches dominated by a single class. Shuffling before the
    resize also means the buffer holds small 42 x 65 slices, not
    IMG_SIZE x IMG_SIZE x 3 float32 images.
    """
    ds = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        ds = ds.shuffle(len(X))
    return (ds.map(prep, num_parallel_calls=tf.data.AUTOTUNE)
              .batch(BATCH)
              .prefetch(tf.data.AUTOTUNE))


# Train and evaluate one model per fold.
for fold_idx, (train_idx_all, test_idx) in enumerate(skf.split(units, y_units), start=1):
    # Drop the Keras global state left by the previous fold so that models do
    # not pile up in memory across folds.
    tf.keras.backend.clear_session()

    TEST_SUBS = units[test_idx]
    TEST_LABELS = y_units[test_idx]
    TRAIN_SUBS = units[train_idx_all]

    # Hold out 25% of the training subjects for validation, stratified by class.
    train_labels = y_units[train_idx_all]
    inner = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=SEED + fold_idx)
    tr_inner_idx, val_inner_idx = next(inner.split(TRAIN_SUBS, train_labels))
    TR_SUBS = TRAIN_SUBS[tr_inner_idx]
    VAL_SUBS = TRAIN_SUBS[val_inner_idx]

    TR_PAIRS = subset_pairs_by_subjects(pairs, set(TR_SUBS))
    VAL_PAIRS = subset_pairs_by_subjects(pairs, set(VAL_SUBS))
    TE_PAIRS = subset_pairs_by_subjects(pairs, set(TEST_SUBS))

    print(f"\n=== Fold {fold_idx}/{K} ===")
    print("TEST subs :", list(TEST_SUBS), " (♂W7:", int(TEST_LABELS.sum()), " / ♀W1:", len(TEST_LABELS)-int(TEST_LABELS.sum()), ")")
    print("TRAIN subs:", list(TR_SUBS))
    print("VAL subs  :", list(VAL_SUBS))

    X_train, y_train = build_slices_for_pairs(TR_PAIRS, sub_ses_to_files, return_ids=False)
    X_val, y_val = build_slices_for_pairs(VAL_PAIRS, sub_ses_to_files, return_ids=False)
    X_test, y_test, ids_test = build_slices_for_pairs(TE_PAIRS, sub_ses_to_files, return_ids=True)

    print(f"Fold {fold_idx}: train slices={len(X_train)}, val slices={len(X_val)}, test slices={len(X_test)}")

    # Fail early with a clear message instead of deep inside tf.data / Keras.
    for split_name, split_X in (("train", X_train), ("val", X_val), ("test", X_test)):
        if len(split_X) == 0:
            raise RuntimeError(f"Fold {fold_idx}: no slices available for the {split_name} split.")

    # The module-level ``prep`` is used; it no longer needs to be redefined
    # inside the loop on every fold.
    ds_train = _make_dataset(X_train, y_train, shuffle=True)
    ds_val = _make_dataset(X_val, y_val)
    ds_test = _make_dataset(X_test, y_test)

    # One W&B run per fold. Using the run as a context manager finishes it
    # even when training or evaluation raises, so the next fold starts clean.
    with wandb.init(project=PROJECT_NAME,
                    group=RUN_GROUP,
                    name=f"W7M_vs_W1F-ViT-B16-fold{fold_idx}",
                    config={
                        "epochs": EPOCHS,
                        "batch_size": BATCH,
                        "img_size": IMG_SIZE,
                        "backbone": "vit_b16",
                        "split": "W7 Male (1) vs W1 Female (0), pain-only (DIST)",
                        "subjects_train": list(TR_SUBS),
                        "subjects_val": list(VAL_SUBS),
                        "subjects_test": list(TEST_SUBS),
                        "seed": SEED,
                    }) as run:

        # Fresh pretrained ViT-B/16 classifier for this fold.
        model = make_model(vit_variant="b16")

        history = model.fit(
            ds_train,
            validation_data=ds_val,
            epochs=EPOCHS,
            verbose=1,
            callbacks=[WandbCallback(save_model=False)],
        )

        # Final validation metrics after the last epoch.
        v_loss, v_acc = model.evaluate(ds_val, verbose=0)
        print(f"[Fold {fold_idx}] VAL  → acc={v_acc:.4f}  loss={v_loss:.6f}")
        wandb.log({"val/final_acc": v_acc, "val/final_loss": v_loss})

        # Slice-level predictions on the held-out subjects.
        y_prob = model.predict(ds_test, verbose=0)
        y_pred = np.argmax(y_prob, axis=1)
        y_score = y_prob[:, 1]  # probability assigned to class 1 (male W7)

        acc_slice = accuracy_score(y_test, y_pred)
        bacc_slice = balanced_accuracy_score(y_test, y_pred)
        prec_m, rec_m, f1_m, _ = precision_recall_fscore_support(
            y_test, y_pred, average='macro', zero_division=0
        )
        try:
            auc_slice = roc_auc_score(y_test, y_score)
        except ValueError:
            # AUC is undefined when y_test holds a single class.
            auc_slice = float("nan")

        print(f"[Fold {fold_idx}] TEST (slice-level) → acc={acc_slice:.4f}  bal_acc={bacc_slice:.4f}  "
              f"F1_macro={f1_m:.4f}  AUC={auc_slice:.4f}")

        wandb.log({
            "test_slice/acc": acc_slice,
            "test_slice/bal_acc": bacc_slice,
            "test_slice/f1_macro": f1_m,
            "test_slice/precision_macro": prec_m,
            "test_slice/recall_macro": rec_m,
            "test_slice/roc_auc": auc_slice,
        })

        # Confusion matrix: printed, then rendered and uploaded as an image.
        cm_slice = confusion_matrix(y_test, y_pred, labels=[0, 1])
        print(f"[Fold {fold_idx}] Confusion matrix (TEST, slice-level):\n{cm_slice}")

        fig = plt.figure(figsize=(4, 4))
        try:
            plt.imshow(cm_slice, interpolation='nearest')
            plt.title('Confusion matrix (TEST, slice-level)')
            plt.colorbar()
            tick_marks = np.arange(2)
            plt.xticks(tick_marks, ['Female W1 (0)', 'Male W7 (1)'], rotation=45)
            plt.yticks(tick_marks, ['Female W1 (0)', 'Male W7 (1)'])
            for i in range(cm_slice.shape[0]):
                for j in range(cm_slice.shape[1]):
                    val = cm_slice[i, j]
                    # White text on dark (high-count) cells, black otherwise.
                    plt.text(j, i, val, ha="center", va="center",
                             color="white" if val > cm_slice.max() / 2 else "black")
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
            plt.tight_layout()
            wandb.log({"test_slice/confusion_matrix": wandb.Image(fig)})
        finally:
            plt.close(fig)
