# Female Naive Vs Female CPH
1. Female CPH (baseline)
2. Female CPH (Week 7)

In [None]:

# Environment and GPU configuration


import os
import sys

# Suppress verbose TensorFlow/C++ logs.
# These variables are set before `import tensorflow` below, so they are
# already present in the environment when TensorFlow is loaded.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"    # 0 = all logs, 3 = only errors
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"  # avoid grabbing all GPU memory
os.environ["WANDB_SILENT"] = "true"   # make Weights & Biases less verbose
# os.environ["WANDB_DISABLED"] = "true"  # uncomment to fully disable W&B

# Limit thread usage (optional)
os.environ["OMP_NUM_THREADS"] = "2"
os.environ["TF_NUM_INTRAOP_THREADS"] = "2"
os.environ["TF_NUM_INTEROP_THREADS"] = "1"

# Add the project root to the Python path (presumably so the project-local
# modules imported in the next cell can be found - adjust to the actual
# project folder on your machine).
sys.path.append(r"C:/Users/PC-EIAD209/Desktop/AnaKei/NIPD-AI")  # adjust to the actual project folder

import tensorflow as tf

# Reduce TensorFlow Python-level logging
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(0)
tf.debugging.set_log_device_placement(False)

# Ensure GPU memory growth is enabled on every visible GPU
gpus = tf.config.list_physical_devices('GPU')
for g in gpus:
    try:
        tf.config.experimental.set_memory_growth(g, True)
    except Exception:
        # Best effort: a failure to enable memory growth on one device is
        # ignored so the notebook can continue with the default behaviour.
        pass


In [None]:
#import  modules

from tfdata_generator import*
from gradcam_utils import *
from  callbacks import *


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import scipy.io as sio
import os
from tensorflow.keras.utils import plot_model
from sklearn.model_selection import train_test_split
import h5py
import warnings
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import precision_recall_fscore_support as score, precision_score, recall_score, f1_score
from sklearn.metrics import confusion_matrix
from datetime import datetime
import scipy as sp
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import auc, roc_curve
from itertools import cycle
from sklearn.metrics import RocCurveDisplay
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Conv3D, MaxPooling3D, Flatten, Dropout, GlobalAveragePooling3D, concatenate, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import resample
import seaborn as sns
from tensorflow.keras.applications.vgg16 import VGG16
from keras.regularizers import l2
import cv2
from keras import initializers
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
import tensorflow as tf
from keras import losses
from tensorflow.keras.optimizers import SGD
import wandb
from wandb.integration.keras import WandbCallback
import gc
#from numba import cuda
from tensorflow.keras.optimizers.schedules import ExponentialDecay
import nibabel as nib 
from tensorflow.data import AUTOTUNE
from sklearn.metrics import roc_auc_score, average_precision_score, balanced_accuracy_score



In [None]:
# Reference volume used as the anatomical background when overlaying heatmaps.
# Two alternatives are kept here; only Option 2 is active.

# Option 1 (disabled): use the SIGMA template volume.

# rabies_ref_path = r"F:/New data/sigma_files/SIGMA_resam_InVivo_Brain_Template_Masked.nii"
# rabies_ref = nib.load(rabies_ref_path).get_fdata()
# rabies_vol = np.mean(rabies_ref, axis=3)
# print("Template rabies_ref shape:", rabies_ref.shape)
# print("Template rabies_vol shape (mean over time):", rabies_vol.shape)

# Option 2: use one preprocessed RABIES functional run as background volume.
# The path is machine-specific (external drive F:); adjust it as needed.
rabies_ref_path = (
    r"F:/rabies/preprocess_batch-001/commonspace_bold/"
    r"_scan_info_subject_id003.session01_split_name_sub-003_ses-01_desc-o_T2w/_run_None/"
    r"sub-003_ses-01_task-dist_desc-oa_bold_autobox_combined.nii.gz"
)
# Load the image data as a floating-point array.
rabies_ref = nib.load(rabies_ref_path).get_fdata()
# Collapse the 4th axis by averaging; the print below labels this axis as
# time, i.e. the file is assumed to be a 4-D (x, y, z, t) series.
rabies_vol = np.mean(rabies_ref, axis=3)

print("rabies_ref shape:", rabies_ref.shape)
print("rabies_vol shape (mean over time):", rabies_vol.shape)


In [None]:

# Enable memory growth on the first visible GPU (if there is one) and report
# which devices TensorFlow can see.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
except (IndexError, ValueError, RuntimeError):
    # IndexError: no GPU is visible.
    # ValueError / RuntimeError: invalid device, or the device configuration
    # cannot be modified because the runtime is already initialized.
    # Any other exception is a real error and is no longer silently swallowed.
    pass

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
print("Devices:", tf.config.list_physical_devices())
tf.debugging.set_log_device_placement(False)

In [None]:

# Function to filter a heatmap (keeps only the high activations)

def filter_heatmap(heatmap, threshold=0.6):
    """Keep only the strongest activations of a heatmap and rescale to uint8.

    The input is never modified. Values above 1 are taken to mean the map is
    on a 0-255 scale and it is divided by 255 first. Every value below
    ``threshold * max(heatmap)`` is then set to zero and the result is
    stretched so that its maximum becomes 255.

    Args:
        heatmap: array-like heatmap of any shape.
        threshold: fraction of the maximum below which values are zeroed.

    Returns:
        ``np.uint8`` array with the same shape as ``heatmap``. An empty input
        yields an empty ``uint8`` array; a map with no positive value after
        thresholding is returned as zeros rather than divided by zero.
    """
    # Copy to avoid modifying the caller's array.
    filtered = np.copy(heatmap)

    # np.max raises on an empty array, and there is nothing to threshold.
    if filtered.size == 0:
        return filtered.astype(np.uint8)

    # Bring 0-255 maps into the [0, 1] range.
    if np.max(filtered) > 1:
        filtered = filtered / 255.0

    # Adaptive threshold, relative to the current maximum.
    cutoff = threshold * np.max(filtered)
    filtered[filtered < cutoff] = 0

    # Re-normalise to 0-255; the maximum is re-read because thresholding may
    # have zeroed everything (e.g. threshold > 1).
    remaining_peak = np.max(filtered)
    if remaining_peak > 0:
        filtered = 255 * filtered / remaining_peak

    return np.uint8(filtered)


# Model

In [None]:
# Layout of the convolutional trunk for each supported value of `blocks`.
#   "stages":      one (filters, n_convs, max_pool) tuple per conv stage
#   "mid_dropout": dropout rate after every stage except the last one
#   "global_pool": whether the last stage ends with GlobalAveragePooling3D
_VGG16_3D_LAYOUTS = {
    1: {
        "stages": [(64, 2, True)],
        "mid_dropout": None,
        "global_pool": True,
    },
    2: {
        "stages": [(64, 2, True), (128, 2, True)],
        "mid_dropout": 0.2,
        "global_pool": True,
    },
    3: {
        "stages": [(64, 2, True), (128, 2, True), (256, 3, True)],
        "mid_dropout": 0.2,
        "global_pool": True,
    },
    # NOTE(review): this variant has no global pooling (and no Flatten), so the
    # dense head below is applied to an un-pooled 5-D tensor and the model
    # output is not of shape (batch, 2). Kept exactly as in the original
    # code - confirm whether GlobalAveragePooling3D should be restored here.
    4: {
        "stages": [(64, 2, True), (128, 2, False), (256, 3, False), (512, 3, True)],
        "mid_dropout": 0.2,
        "global_pool": False,
    },
    5: {
        "stages": [
            (64, 2, False),
            (128, 2, False),
            (256, 3, False),
            (512, 3, False),
            (512, 3, False),
        ],
        "mid_dropout": 0.3,
        "global_pool": True,
    },
}


def VGG16_3D(blocks):
    """Build a VGG16-style 3-D CNN with a configurable number of conv stages.

    The network takes volumes of shape (42, 65, 29), adds a channel axis, and
    stacks ``blocks`` stages of BatchNormalization -> Conv3D (x2 or x3) ->
    optional MaxPooling3D -> Dropout. The last stage is followed by global
    average pooling (except for ``blocks=4``, see the layout table) and a
    dense head of 4096 -> 4096 -> 2 (softmax). All Conv3D and Dense layers use
    L2 regularisation with factor 1e-4.

    Layers are created in the same order as before, so automatically
    generated layer names are unchanged.

    Args:
        blocks: number of convolutional stages, an integer from 1 to 5.

    Returns:
        An uncompiled ``Model`` mapping the input volume to class probabilities.

    Raises:
        ValueError: if ``blocks`` is not one of 1, 2, 3, 4, 5. Previously such
            a value silently produced a model with no convolutional trunk.
    """
    try:
        layout = _VGG16_3D_LAYOUTS[blocks]
    except (KeyError, TypeError):
        raise ValueError(
            f"blocks must be one of {sorted(_VGG16_3D_LAYOUTS)}, got {blocks!r}"
        ) from None

    inputs = Input(shape=(42, 65, 29), name='input_layer')
    x = Reshape(target_shape=[42, 65, 29, 1], name='input_x_3d_volumes')(inputs)

    print("entra al %d" % blocks)

    stages = layout["stages"]
    last_stage = len(stages) - 1
    for stage_idx, (filters, n_convs, max_pool) in enumerate(stages):
        x = BatchNormalization()(x)
        for _ in range(n_convs):
            x = Conv3D(
                filters=filters,
                kernel_size=3,
                padding='same',
                activation='relu',
                kernel_regularizer=tf.keras.regularizers.L2(1e-4),
            )(x)
        if max_pool:
            x = MaxPooling3D(pool_size=2, strides=2, padding='same')(x)

        if stage_idx == last_stage:
            # End of the trunk: collapse the spatial axes (when configured)
            # and apply the stronger dropout before the dense head.
            if layout["global_pool"]:
                x = tf.keras.layers.GlobalAveragePooling3D()(x)
            x = tf.keras.layers.Dropout(0.5)(x)
        else:
            x = tf.keras.layers.Dropout(layout["mid_dropout"])(x)

    # Fully connected layers
    x = Dense(units=4096, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(1e-4))(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    x = Dense(units=4096, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(1e-4))(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    output = Dense(units=2, activation='softmax', kernel_regularizer=tf.keras.regularizers.L2(1e-4))(x)

    # Create the model
    VGG_3d_model = Model(inputs=inputs, outputs=output)

    return VGG_3d_model

def set_pretrained_weigths(VGG_3d_model):
    """Initialise the Conv3D layers of a 3-D model from ImageNet VGG16 weights.

    A 2-D VGG16 (ImageNet weights, no top) is instantiated and its Conv2D
    layers are paired, in order, with the Conv3D layers of ``VGG_3d_model``.
    For each pair the 2-D kernel is inflated to 3-D and written, together with
    the 2-D bias, into the 3-D layer. The model is modified in place; nothing
    is returned.

    Args:
        VGG_3d_model: Keras model whose Conv3D layers have 3x3x3 kernels
            (the reshapes below hard-code a kernel size of 3).

    NOTE(review): layers are detected by substring match on ``str(layer)``
    ("conv3d" / "Conv2D"), so detection depends on the layer's string
    representation. If the 3-D model has more Conv3D layers than VGG16 has
    Conv2D layers, ``layers_conv_pretrained[i]`` raises IndexError. If the 3-D
    model has no Conv3D layer at all, the final ``del`` raises
    UnboundLocalError because ``w``, ``WnB`` etc. were never assigned.
    """
    # 2-D VGG16 with ImageNet weights
    pretrained_model = tf.keras.applications.VGG16(
        include_top=False,
        weights="imagenet",
        pooling='avg',
        input_shape = (42, 65, 3)
    )
    
    # Indices of the conv layers in VGG_3d_model
    layers_conv = []
    for j in range(len(VGG_3d_model.layers)):
        if "conv3d" in str(VGG_3d_model.layers[j]):
            layers_conv.append(j)
    # Indices of the conv layers in the pretrained 2-D model
    layers_conv_pretrained = []
    for j in range(len(pretrained_model.layers)):
        if "Conv2D" in str(pretrained_model.layers[j]):
            layers_conv_pretrained.append(j)
    
    # Pair the i-th Conv3D layer with the i-th Conv2D layer.
    for i in range(len(layers_conv)):
        if "Conv2D" in str(pretrained_model.layers[layers_conv_pretrained[i]]):
            if i == 0:
                # First layer: sum the 2-D kernel over axis 2 (assumed to be
                # the input-channel axis of a (kh, kw, in, out) kernel - TODO
                # confirm) so that it matches a single-channel input.
                w = pretrained_model.layers[layers_conv_pretrained[i]].get_weights()[0].sum(axis=2, keepdims=True)
            else:
                w = pretrained_model.layers[layers_conv_pretrained[i]].get_weights()[0]
                
            w3d=[]
            
            # Flatten all trailing axes (Fortran order) into a stack of 3x3 slices.
            w = np.reshape(w,(3,3,-1),order='F')
            # Repeat every 3x3 slice three times in a row; after the
            # Fortran-order reshape below these repeats land on the third
            # kernel axis, i.e. the 2-D kernel is replicated along depth.
            for j in range(len(w[0,0,:])):
                for k in range(3):
                    w3d.append(w[:,:,j])
            # List of (3, 3) slices -> array of shape (3, 3, n_slices * 3)
            w3d = np.transpose(w3d, (1,2,0))
            
            # Reshape to the exact kernel shape of the target Conv3D layer.
            new_weights = np.reshape(w3d, np.array(VGG_3d_model.layers[layers_conv[i]].get_weights()[0]).shape,order='F')
            # The bias is copied unchanged from the 2-D layer.
            new_bias = pretrained_model.layers[layers_conv_pretrained[i]].get_weights()[1]
            
            WnB = []
            WnB.append(new_weights)
            WnB.append(new_bias)
    
            VGG_3d_model.layers[layers_conv[i]].set_weights(WnB)

    # Drop the references to the pretrained model and temporaries.
    del pretrained_model, w, WnB, new_weights, new_bias, w3d

# Metrics

In [None]:
def confusionmatrix_multiclass(y_test, pred):
    """Plot an annotated 3x3 confusion matrix (Baseline / Week 1 / Week 7).

    Each cell shows a descriptive name, the raw count and the percentage of
    its row (true class).

    Args:
        y_test: true class labels.
        pred: predicted labels or scores; values are rounded to the nearest
            integer before the matrix is computed.

    The matrix is expected to be 3x3 (the reshapes below hard-code this), so
    all three classes must occur in ``y_test`` or ``pred``.
    """
    # Fixed: the body used the undefined name `preds` instead of the `pred`
    # parameter, which raised NameError (or silently used a notebook global).
    cm = confusion_matrix(y_test, np.rint(pred).astype(int))
    # Cell names follow the row (true class); 'True week 1' replaces the
    # garbled 'Truec' label of the centre cell.
    group_names = ['True baseline','False Baseline','False Baseline',
                   'False week 1','True week 1','False Week 1',
                  'False week 7','False week 7','True week 7']
    group_counts = ["{0:0.0f}".format(value) for value in
                    cm.flatten()]
    # Row-normalised percentages (fraction of each true class).
    group_percentages = ["{0:.2%}".format(value) for value in
                         np.ndarray.flatten(cm/(np.sum(cm,axis=1).reshape(3,1)))]
    labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in
              zip(group_names,group_counts,group_percentages)]
    labels = np.asarray(labels).reshape(3,3)
    sns.heatmap(cm, annot=labels, fmt='', cmap='Blues', xticklabels = ['Baseline','Week 1','Week 7'] ,yticklabels = ['Baseline','Week 1','Week 7'])
    plt.show()


def confusionmatrix(y_test, preds, xticklabels=('BL','W7'), yticklabels=('BL','W7'), **kwargs):
    """Draw an annotated confusion-matrix heatmap and return it.

    Args:
        y_test: true labels; cast to int and flattened.
        preds: predicted labels; cast to int and flattened.
        xticklabels: tick labels for the predicted axis.
        yticklabels: tick labels for the true axis.
        **kwargs: forwarded to ``sklearn.metrics.confusion_matrix``
            (for example ``labels=(0, 1)``).

    Returns:
        ``(fig, cm)``: the matplotlib figure and the raw confusion matrix.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.metrics import confusion_matrix

    truth = np.asarray(y_test).astype(int).ravel()
    predicted = np.asarray(preds).astype(int).ravel()

    # Any extra keyword argument is passed straight through.
    cm = confusion_matrix(truth, predicted, **kwargs)

    # Row-wise fractions; rows without samples use 1 as the denominator so
    # that no division by zero occurs.
    totals = cm.sum(axis=1, keepdims=True)
    denominators = np.where(totals == 0, 1, totals)
    fractions = cm / denominators

    cell_text = np.asarray(
        [f"{count}\n{frac:.2%}" for count, frac in zip(cm.ravel(), fractions.ravel())]
    ).reshape(cm.shape)

    fig, ax = plt.subplots(figsize=(4.5,4))
    sns.heatmap(
        cm,
        annot=cell_text,
        fmt='',
        cmap='Blues',
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title("Confusion Matrix")
    plt.tight_layout()
    return fig, cm


    
def confusionmatrix_binary(y_test, preds):
    """Plot an annotated 2x2 confusion matrix for baseline (BL) vs week 7 (W7).

    Each cell shows a descriptive name, the raw count and the percentage of
    its row (true class).

    Args:
        y_test: true binary labels.
        preds: predicted binary labels.

    The matrix is expected to be 2x2 (the reshape below hard-codes this), so
    both classes must occur in ``y_test`` or ``preds``.
    """
    cm = confusion_matrix(y_test, preds)
    # Cell names follow the row (true class). The second row is week 7, in
    # line with the 'W7' tick labels; it was previously labelled "Week 1".
    group_names = ['True baseline','False baseline','False week 7','True week 7']
    group_counts = ["{0:0.0f}".format(value) for value in
                    cm.flatten()]
    # Row-normalised percentages; a row without samples uses 1 as denominator
    # so it shows 0.00% instead of NaN (same approach as `confusionmatrix`).
    row_sum = np.sum(cm, axis=1, keepdims=True)
    safe_row_sum = np.where(row_sum == 0, 1, row_sum)
    group_percentages = ["{0:.2%}".format(value) for value in
                         np.ndarray.flatten(cm / safe_row_sum)]
    labels = [f"{v1}\n{v2}\n{v3}" for v1, v2, v3 in
              zip(group_names,group_counts,group_percentages)]
    labels = np.asarray(labels).reshape(2,2)
    sns.heatmap(cm, annot=labels, fmt='', cmap='Blues', xticklabels = ['BL','W7'] ,yticklabels = ['BL','W7'])
    plt.show()

def ROC(probs,y_test): #binary
    """Print binary classification metrics and plot the ROC curve.

    Prints the ROC AUC (as a percentage), shows the ROC curve together with
    the no-skill diagonal, then rounds the scores to hard 0/1 predictions and
    prints precision, recall and F1.

    Args:
        probs: predicted scores for the positive class.
        y_test: true binary labels.

    Note: this installs a global "ignore" warnings filter that stays active
    after the function returns.
    """
    warnings.filterwarnings('ignore')

    # Area under the ROC curve
    auc_value = roc_auc_score(y_test, probs)
    print('AUC - Test Set: %.2f%%' % (auc_value*100))

    # ROC curve: no-skill diagonal first, then the model's curve
    fpr, tpr, _ = roc_curve(y_test, probs)
    plt.plot([0, 1], [0, 1], linestyle='--')
    plt.plot(fpr, tpr, marker='.')
    plt.xlabel('False positive rate')
    plt.ylabel('Sensitivity/ Recall')
    plt.show()

    # Hard predictions: round the scores to the nearest integer
    hard_preds = np.rint(probs).astype(int)

    # precision: tp / (tp + fp)
    precision = precision_score(y_test, hard_preds)
    print('Precision: %f' % precision)

    # recall: tp / (tp + fn)
    recall = recall_score(y_test, hard_preds)
    print('Recall: %f' % recall)

    # f1: 2*tp / (2*tp + fp + fn)
    f1 = f1_score(y_test, hard_preds)
    print('F1 score: %f' % f1)
        
def ROC_multiclass(model, y_test, n_class, X_test=None,
                   target_names=('Naive', 'Week1', 'Week7')):
    """Plot one-vs-rest ROC curves with micro and macro averages.

    Args:
        model: trained model; ``model.predict(X_test)`` must return per-class
            scores of shape (n_samples, n_classes).
        y_test: array of shape (n_samples,) holding the true class of each
            sample.
        n_class: number of classes to plot.
        X_test: inputs to score. When omitted, the notebook-level global
            ``X_test`` is used, which is what the original code relied on.
        target_names: display name of each class, at least ``n_class`` long.

    Raises:
        NameError: if ``X_test`` is neither passed nor defined globally.
        ValueError: if ``target_names`` has fewer than ``n_class`` entries.
    """
    if X_test is None:
        # Backward compatibility: the function used to read a global X_test.
        try:
            X_test = globals()["X_test"]
        except KeyError:
            raise NameError(
                "X_test was not passed and no global X_test is defined"
            ) from None

    n_classes = n_class
    if len(target_names) < n_classes:
        raise ValueError(
            f"target_names has {len(target_names)} entries but n_class={n_classes}"
        )

    # One-hot encode the true labels, shape (n_samples, n_classes)
    label_binarizer = LabelBinarizer().fit(y_test)
    y_onehot_test = label_binarizer.transform(y_test)

    y_score = model.predict(X_test)  # per-class scores

    # store the fpr, tpr, and roc_auc for all averaging strategies
    fpr, tpr, roc_auc = dict(), dict(), dict()
    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(y_onehot_test.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    print(f"Micro-averaged One-vs-Rest ROC AUC score:\n{roc_auc['micro']:.2f}")

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_onehot_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    fpr_grid = np.linspace(0.0, 1.0, 1000)

    # Interpolate all ROC curves on a common grid and average them
    mean_tpr = np.zeros_like(fpr_grid)
    for i in range(n_classes):
        mean_tpr += np.interp(fpr_grid, fpr[i], tpr[i])  # linear interpolation
    mean_tpr /= n_classes

    fpr["macro"] = fpr_grid
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    print(f"Macro-averaged One-vs-Rest ROC AUC score:\n{roc_auc['macro']:.2f}")

    fig, ax = plt.subplots(figsize=(6, 6))

    plt.plot(
        fpr["micro"],
        tpr["micro"],
        label=f"micro-average ROC curve (AUC = {roc_auc['micro']:.2f})",
        color="deeppink",
        linestyle=":",
        linewidth=4,
    )

    plt.plot(
        fpr["macro"],
        tpr["macro"],
        label=f"macro-average ROC curve (AUC = {roc_auc['macro']:.2f})",
        color="navy",
        linestyle=":",
        linewidth=4,
    )

    colors = cycle(["aqua", "darkorange", "cornflowerblue"])
    for class_id, color in zip(range(n_classes), colors):
        RocCurveDisplay.from_predictions(
            y_onehot_test[:, class_id],
            y_score[:, class_id],
            name=f"ROC curve for {target_names[class_id]}",
            color=color,
            ax=ax,
            # Draw the chance line once, together with the last class
            # (this was hard-coded to class index 2).
            plot_chance_level=(class_id == n_classes - 1),
        )

    _ = ax.set(
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title="Extension of Receiver Operating Characteristic\nto One-vs-Rest multiclass",
    )
    
# plot diagnostic learning curves
def summarize_diagnostics(histories):
    """Plot training and validation loss/accuracy curves for every fold.

    The top subplot shows the loss and the bottom one the accuracy. Each fold
    gets its own colour; training curves are solid and validation curves are
    dashed. The legend is attached to the accuracy subplot.

    Args:
        histories: iterable of Keras ``History`` objects (one per fold) whose
            ``history`` dict holds the keys 'loss', 'val_loss', 'accuracy'
            and 'val_accuracy'.
    """
    # Colours are reused cyclically and labels are generated on the fly, so
    # any number of folds works (the label lists used to stop at five folds).
    # White is left out because it is invisible on the default background.
    palette = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
    for i, fold_history in enumerate(histories):
        color = palette[i % len(palette)]
        train_label = f"fold {i + 1}(train)"
        val_label = f"fold {i + 1}(val)"
        curves = fold_history.history

        # plot loss
        plt.subplot(2, 1, 1)
        plt.title('Cross Entropy Loss')
        plt.plot(curves['loss'], color=color, label=train_label, linestyle="-")
        plt.plot(curves['val_loss'], color=color, label=val_label, linestyle="--")
        # plot accuracy
        plt.subplot(2, 1, 2)
        plt.title('Classification Accuracy')
        plt.plot(curves['accuracy'], color=color, label=train_label, linestyle="-")
        plt.plot(curves['val_accuracy'], color=color, label=val_label, linestyle="--")
    plt.legend()
    plt.show()

# summarize model performance
def summarize_performance(scores):
    """Print mean/std of the scores (scaled by 100) and show a box plot.

    Args:
        scores: sequence of per-run accuracy values.
    """
    mean_pct = np.mean(scores) * 100
    std_pct = np.std(scores) * 100
    n_scores = len(scores)
    print('Accuracy: mean=%.3f std=%.3f, n=%d' % (mean_pct, std_pct, n_scores))

    # Box-and-whisker plot of the raw (unscaled) scores
    plt.boxplot(scores)
    plt.show()

# Just brain. Female. Naive vs CPH
1) Naive (CPH_BL)
2) CPH (CPH_W7)

In [None]:
# Subject IDs of the female cohort.
female = [49, 50, 51, 52, 65, 66, 77, 78, 79, 80, 81, 82, 83]

# One subject-level label per subject; every subject gets the value 1.0.
y_female = np.ones(len(female))

subjects = np.asarray(female)
labels = np.array(y_female.tolist())

# Sessions and data type passed on to the rest of the notebook.
sessions = [1, 3]
MRI_type, functional_type = "func", "rest"

In [None]:
# Authenticate with Weights & Biases before any run is started.
# Make sure you are logged in to your account first.
wandb.login()



In [None]:
# Metrics, dataset benchmark and ETA callbacks:

import time
import gc
import os
from datetime import datetime

import tensorflow as tf
import numpy as np

# Binary AUC that takes the positive column from a 2-class softmax
class AUCPos(tf.keras.metrics.Metric):
    """Binary AUC computed from the positive column of a 2-class softmax.

    Thin wrapper around ``tf.keras.metrics.AUC``: labels are flattened to a
    float vector and, when predictions have more than one dimension, column 1
    (the positive class) is selected before the wrapped metric is updated.

    Args:
        name: metric name, also given to the wrapped AUC metric.
        curve: curve type forwarded to ``tf.keras.metrics.AUC``.
        **kwargs: forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, name="auc", curve="ROC", **kwargs):
        super().__init__(name=name, **kwargs)
        self._auc = tf.keras.metrics.AUC(curve=curve, name=name)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true: integer labels 0/1 -> float vector (B,)
        y_true = tf.cast(tf.reshape(y_true, (-1,)), tf.float32)

        # y_pred: if shape is (B, 2), take column 1 (positive class).
        # The static rank is used when known so that the branch is an
        # ordinary Python bool; a Python `if` on the tensor returned by
        # tf.rank fails in graph mode when autograph is not applied. The
        # dynamic check is kept only as a fallback for unknown rank.
        y_pred = tf.convert_to_tensor(y_pred)
        static_rank = y_pred.shape.rank
        if static_rank is None:
            has_class_axis = tf.rank(y_pred) > 1
        else:
            has_class_axis = static_rank > 1
        if has_class_axis:
            y_pred = y_pred[..., 1]
        y_pred = tf.cast(tf.reshape(y_pred, (-1,)), tf.float32)

        return self._auc.update_state(y_true, y_pred, sample_weight)

    def result(self):
        return self._auc.result()

    def reset_state(self):
        # Newer Keras versions call `reset_state`; older ones only provide
        # `reset_states` on the wrapped metric.
        reset = getattr(self._auc, "reset_state", None) or self._auc.reset_states
        reset()

    def reset_states(self):
        # Legacy name, kept for older Keras versions and existing callers.
        self.reset_state()


class SparseCatAccFixed(tf.keras.metrics.SparseCategoricalAccuracy):
    """SparseCategoricalAccuracy that flattens ``y_true`` to shape (B,).

    Casting the labels to int64 and reshaping them to a vector avoids shape
    mismatches when they arrive as (B, 1).
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        labels_int = tf.cast(y_true, tf.int64)
        flat_labels = tf.reshape(labels_int, (-1,))
        return super().update_state(flat_labels, y_pred, sample_weight)


def benchmark_ds(ds, n_steps=50, label="train"):
    """Measure the average seconds per batch of an input pipeline.

    The model is not run: the function only iterates over ``ds``. One warm-up
    batch is fetched first (it forces the pipeline to start and is not
    timed), then up to ``n_steps`` further batches are timed.

    Args:
        ds: any iterable of batches, e.g. a ``tf.data.Dataset``. Batches may
            have any structure (single tensor, pair, triple, ...).
        n_steps: maximum number of batches to time.
        label: tag used in the printed messages.

    Returns:
        Average seconds per batch as a float, or ``None`` when the dataset is
        empty or yields nothing after the warm-up batch.
    """
    it = iter(ds)
    try:
        # warm-up batch (forces the pipeline to start)
        next(it)
    except StopIteration:
        print(f"[bench/{label}] empty dataset")
        return None

    t0 = time.perf_counter()
    got = 0
    for _ in range(n_steps):
        try:
            # The batch content is irrelevant here, so it is not unpacked;
            # unpacking into (x, y) failed for datasets that do not yield pairs.
            next(it)
        except StopIteration:
            break
        got += 1

    dt = time.perf_counter() - t0
    if got == 0:
        print(f"[bench/{label}] could not iterate")
        return None

    sec_per_batch = dt / got
    print(f"[bench/{label}] {got} batches in {dt:.2f}s  ⇒  {sec_per_batch:.3f} s/batch")
    return sec_per_batch


class BatchTime(tf.keras.callbacks.Callback):
    '''
    Simple callback that prints an approximate ETA for the current epoch,
    based on the average time per batch observed so far in that epoch.

    Args:
        steps_per_epoch: Number of training steps in one epoch; used to
            compute the remaining time and the printing interval.
    '''
    def __init__(self, steps_per_epoch):
        super().__init__()
        self.steps_per_epoch = steps_per_epoch
        self.t0 = None  # start of the current timing window (perf_counter)

    def on_train_begin(self, logs=None):
        self.t0 = time.perf_counter()

    def on_epoch_begin(self, epoch, logs=None):
        # The batch index restarts at 0 every epoch, so the clock has to
        # restart too. Otherwise, in a multi-epoch fit, the time of all
        # previous epochs is divided by the batches of the current one.
        self.t0 = time.perf_counter()

    def on_train_batch_end(self, batch, logs=None):
        if self.t0 is None:
            # Hook invoked without a preceding begin hook: start timing now.
            self.t0 = time.perf_counter()

        passed = time.perf_counter() - self.t0
        done = batch + 1
        sec_per_batch = passed / max(1, done)
        # Clamp at zero so the ETA is never reported as negative.
        remaining = max(0, self.steps_per_epoch - done) * sec_per_batch

        # Print only at the beginning and then about 10 times per epoch
        if done <= 5 or done % max(1, self.steps_per_epoch // 10) == 0:
            print(
                f"[ETA] step {done}/{self.steps_per_epoch} | "
                f"{sec_per_batch:.2f}s/step | ETA {remaining/3600:.2f}h"
            )


Training 

In [None]:

# ---- Run configuration -----------------------------------------------------

# Bootstrap resampling of subjects. With START_BOOT_AT = 1 the first run
# (index 0) uses the original subjects and resampling begins at run 2.
USE_BOOTSTRAP = False
N_BOOTSTRAPS  = 1
START_BOOT_AT = 1

# Parallelism of the tf.data pipeline (tuned for Windows / external disk);
# the prefetch buffer is kept equal to the parallelism level.
PAR = 2
PREFETCH_BUF = PAR

# Number of volumes taken from each session: per training epoch (a schedule
# with a single fixed entry here) and for the test split.
EPOCH_VOL_SCHEDULE = [120]
VOLS_TEST = 120

# One experiment per bootstrap run: hold out a test split, then run a
# stratified 2-fold cross-validation on the remaining subjects. Each fold
# trains a fresh model epoch by epoch and evaluates it on the test split.
for boot in range(N_BOOTSTRAPS):
    do_bootstrap = (USE_BOOTSTRAP and boot >= START_BOOT_AT)
    print(f"\n=== Run {boot+1}/{N_BOOTSTRAPS} | bootstrap={do_bootstrap} ===")

    if do_bootstrap:
        boot_subjects, boot_labels = resample(
            subjects, labels, replace=True, random_state=42 + boot
        )
    else:
        boot_subjects, boot_labels = subjects, labels

    # Split subject-level data into train+val vs test
    sub_trainval, sub_test, y_trainval, y_test = train_test_split(
        boot_subjects,
        boot_labels,
        test_size=0.2,
        random_state=42,
        stratify=boot_labels,
    )

    kfold = StratifiedKFold(n_splits=2, shuffle=True, random_state=1)

    scores = []
    histories = []
    run = 1

    for train_ix, val_ix in kfold.split(sub_trainval, y_trainval):
        print("Run #", run)

        sub_train, sub_val = sub_trainval[train_ix], sub_trainval[val_ix]
        y_train,  y_val   = y_trainval[train_ix],  y_trainval[val_ix]

        # Weights & Biases configuration (tracking hyperparameters and metrics)
        # NOTE(review): wandb.init() runs once per fold but wandb.finish() only
        # once per bootstrap run (end of this cell) - confirm the previous
        # fold's run is closed as intended before the next one starts.
        config_defaults = {"batch": 30}
        wandb.init(
            project="FEMALE_Naive_vs_CPH(BLvsW7)",
            notes="tf.data 3D mean; (42,65,29,1); SparseCE probs; flips 3D; AUCPos; ETA visible",
            config=config_defaults,
        )
        wandb.config.epochs = 10
        wandb.config.sub_batch = 8
        wandb.config.sub_batch_group = 1
        wandb.config.sub_batch_ts = 8
        wandb.config.subjects = subjects
        wandb.config.architecture_name = "VGG16_3D"
        wandb.config.dataset_name = "NAIVE vs CPH (BL vs W7)"
        wandb.config.CNN_blocks = 5
        wandb.config.sessions = sessions
        wandb.config.vols_per_session_tr = EPOCH_VOL_SCHEDULE[-1]
        wandb.config.vols_per_session_ts = VOLS_TEST
        wandb.config.initial_learning_rate = 1e-5
        wandb.config.optimizer = "Adam"

        # Cropping coordinates and label task
        CROP6 = (3, 4, 7, 45, 69, 36)
        TASK_NAME = "bl_vs_w7"  # BL=0, W7=1
        SEED = 42

        SUBBATCH        = int(wandb.config.sub_batch)
        SUBBATCH_GROUP  = int(wandb.config.sub_batch_group)
        EFFECTIVE_BATCH = SUBBATCH * SUBBATCH_GROUP

        # Build file lists for each split (REST functional runs)
        CPHclassTrain = FILES_and_LABELS(sub_train, sessions, MRI_type, "rest")
        CPHclassVal   = FILES_and_LABELS(sub_val,   sessions, MRI_type, "rest")
        CPHclassTest  = FILES_and_LABELS(sub_test,  sessions, MRI_type, "rest")

        X_train = CPHclassTrain.get_mask_and_bold()
        X_val   = CPHclassVal.get_mask_and_bold()
        X_test  = CPHclassTest.get_mask_and_bold()

        train_pairs = X_train
        val_pairs   = X_val
        test_pairs  = X_test

        print("Train sessions:", np.array(X_train)[:, 0])
        print("Test sessions:",  np.array(X_test)[:, 0])
        print("Val sessions:",   np.array(X_val)[:, 0])
        print("# Train sessions:", len(X_train))
        print("# Test sessions:",  len(X_test))
        print("# Val sessions:",   len(X_val))

        # Build and compile the 3D VGG16 model
        print("Starting VGG16_3D -----------------------------------------------------")
        tf.keras.backend.clear_session()
        gc.collect()

        CNN = VGG16_3D(5)
        print("CNN input shape:", CNN.input_shape)
        set_pretrained_weigths(CNN)

        CNN.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            optimizer=tf.keras.optimizers.Adam(
                learning_rate=wandb.config.initial_learning_rate
            ),
            metrics=[
                SparseCatAccFixed(name="acc"),
                AUCPos(name="auc", curve="ROC"),
            ],
        )

        checkpoint_filepath = os.path.join(os.getcwd(), wandb.run.name)

        # Callbacks: combined metric and model checkpointing
        acc_loss_rate = CombineCallback()
        model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_filepath,
            monitor="val_auc",
            mode="max",
            save_best_only=True,
            save_weights_only=True,
            verbose=1,
        )

        # Early stopping is tracked here, in the outer loop, instead of with a
        # Keras EarlyStopping callback. fit() is called once per epoch, and
        # EarlyStopping resets its patience counter and best value at the
        # start of every fit(), so it could never trigger or restore weights.
        ES_PATIENCE = 8
        best_val_auc = -np.inf
        epochs_without_improvement = 0

        # Each fit() returns a History holding a single epoch; the per-epoch
        # values are merged so the fold ends up with its full training curve.
        merged_history = {}
        epochs_run = []

        start_time = datetime.now()
        history = None

        # Epoch-by-epoch training loop (rebuilds tf.data datasets every epoch)
        for e in range(wandb.config.epochs):
            print(f"\n=== Epoch {e+1}/{wandb.config.epochs} ===")

            # Volumes per session for this epoch (can be scheduled if needed)
            VOLS_PER_SESSION_EPOCH = EPOCH_VOL_SCHEDULE[
                min(e, len(EPOCH_VOL_SCHEDULE) - 1)
            ]

            # Steps per epoch for train and validation
            batches_per_session = max(1, VOLS_PER_SESSION_EPOCH // SUBBATCH)
            steps_per_epoch     = len(train_pairs) * batches_per_session
            validation_steps    = len(val_pairs)   * batches_per_session

            print(
                f"[config] SUBBATCH={SUBBATCH} | GROUP={SUBBATCH_GROUP} | "
                f"EFFECTIVE_BATCH={EFFECTIVE_BATCH} | "
                f"VOLS_PER_SESSION_EPOCH={VOLS_PER_SESSION_EPOCH}"
            )
            print(
                "steps_per_epoch:",
                steps_per_epoch,
                "| validation_steps:",
                validation_steps,
            )

            # Build epoch-specific datasets with tf.data
            train_ds = make_epoch_ds(
                train_pairs,
                training=True,
                epoch=e,
                par=PAR,
                prefetch_buf=PREFETCH_BUF,
                seed=SEED,
                subbatch=SUBBATCH,
                vols_per_session_epoch=VOLS_PER_SESSION_EPOCH,
                crop_idx6=CROP6,
                task_name=TASK_NAME,
            )
            val_ds = make_epoch_ds(
                val_pairs,
                training=False,
                epoch=e,
                par=PAR,
                prefetch_buf=PREFETCH_BUF,
                seed=SEED,
                subbatch=SUBBATCH,
                vols_per_session_epoch=VOLS_PER_SESSION_EPOCH,
                crop_idx6=CROP6,
                task_name=TASK_NAME,
            )

            # Full-loader benchmark only on the first epoch
            if e == 0:
                t0 = time.perf_counter()
                it = iter(train_ds)
                n_batches = 0
                for _ in range(steps_per_epoch):
                    try:
                        xb, yb = next(it)
                        # Force some computation so the pipeline really executes
                        _ = tf.reduce_mean(xb).numpy()
                        n_batches += 1
                    except StopIteration:
                        break
                dt = time.perf_counter() - t0
                if n_batches > 0:
                    print(
                        f"[FULL-LOADER tf.data] {n_batches} batches in {dt:.2f}s "
                        f"⇒ {dt / n_batches:.3f} s/batch"
                    )

            # Optional: tf.data options (non-deterministic for speed)
            opts = tf.data.Options()
            opts.experimental_deterministic = False
            opts.experimental_slack = True
            train_ds = train_ds.with_options(opts)
            val_ds   = val_ds.with_options(opts)

            # Cardinality (may be -1 if unknown)
            try:
                card_tr = tf.data.experimental.cardinality(train_ds).numpy()
            except Exception:
                card_tr = -1
            try:
                card_va = tf.data.experimental.cardinality(val_ds).numpy()
            except Exception:
                card_va = -1
            print(f"[cardinality] train: {card_tr} | val: {card_va}")

            # Lightweight benchmark of the loader
            spb_data = benchmark_ds(
                train_ds,
                n_steps=min(50, max(10, steps_per_epoch // 20)),
                label="train",
            )
            if spb_data is not None:
                print(
                    f"[est_ETA loader] ~{(spb_data * steps_per_epoch) / 3600:.2f} h/epoch"
                )

            # Sanity check: one batch from the dataset
            xb, yb = next(iter(train_ds))
            print(f"[sanity] X {xb.shape} {xb.dtype} | y {yb.shape} {yb.dtype}")
            # Expected: X (B, 42, 65, 29, 1), y (B,)

            # ETA callback (uses steps_per_epoch)
            eta_cb = BatchTime(steps_per_epoch)

            callbacks = [
                WandbCallback(monitor="val_auc", mode="max", save_model=False),
                acc_loss_rate,
                model_checkpoint_callback,
                eta_cb,
            ]

            history = CNN.fit(
                train_ds,
                epochs=e + 1,
                initial_epoch=e,
                validation_data=val_ds,
                steps_per_epoch=steps_per_epoch,
                validation_steps=validation_steps,
                shuffle=False,
                callbacks=callbacks,
                verbose=1,
            )

            # Keep this epoch's metrics so the fold's history covers all epochs
            for metric_name, metric_values in history.history.items():
                merged_history.setdefault(metric_name, []).extend(metric_values)
            epochs_run.extend(history.epoch)

            # Manual early stopping on val_auc (mode "max"). The strict ">"
            # matches the rule ModelCheckpoint uses to decide when to save.
            val_auc_values = history.history.get("val_auc") or [None]
            val_auc_epoch = val_auc_values[-1]
            if val_auc_epoch is not None:
                if val_auc_epoch > best_val_auc:
                    best_val_auc = val_auc_epoch
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

                if epochs_without_improvement >= ES_PATIENCE:
                    print(
                        f"[early-stopping] val_auc did not improve for "
                        f"{ES_PATIENCE} epochs; stopping after epoch {e+1}"
                    )
                    break

        # Expose the merged per-epoch metrics through the last History object,
        # so downstream summaries see the whole training curve of this fold.
        if history is not None:
            history.history = merged_history
            history.epoch = epochs_run

        # Restore the weights of the best validation epoch before testing.
        # A checkpoint is only written when val_auc improved at least once.
        if np.isfinite(best_val_auc):
            print(f"[restore] loading best weights (val_auc={best_val_auc:.4f})")
            CNN.load_weights(checkpoint_filepath)

        # Build full test dataset (all volumes per session)
        test_ds = make_full_ds(
            test_pairs,
            subbatch=int(wandb.config.sub_batch_ts),
            crop_idx6=CROP6,
            task_name=TASK_NAME,
            par=PAR,
            prefetch_buf=PREFETCH_BUF,
        )

        # Flatten labels to shape (B,). Reshaping to a scalar `()` only works
        # when a batch carries exactly one label and fails for larger batches.
        test_ds = test_ds.map(
            lambda x, y: (x, tf.reshape(tf.cast(y, tf.int32), (-1,))),
            num_parallel_calls=PAR,
        )

        test_opts = tf.data.Options()
        test_opts.experimental_deterministic = True
        test_opts.experimental_slack = True
        test_ds = test_ds.with_options(test_opts)

        loss, acc, auc_val = CNN.evaluate(test_ds, verbose=1)
        print(
            "Test — Loss: {:.4f} | Acc: {:.2f}% | AUC: {:.4f}".format(
                loss, acc * 100, auc_val
            )
        )

        #confusion matrices,
        # Grad-CAM visualizations and W&B logging.

        scores.append(acc)
        histories.append(history)
        run += 1

    print("histories and scores from VGG16_3D")
    summarize_diagnostics(histories)
    summarize_performance(scores)
    wandb.finish()
