In [1]:
import os
import numpy as np
import pandas as pd
import librosa
import torch
import warnings
import matplotlib.pyplot as plt
import librosa.display
from sklearn.preprocessing import minmax_scale
import IPython.display as ipd
import torchaudio.transforms as T
from sklearn.preprocessing import StandardScaler
import copy
from sklearn.model_selection import train_test_split
import torch.optim as optim
import random
from sklearn.manifold import TSNE
import math
import re
import IPython
from IPython.display import Audio
from IPython.display import Image
from sklearn.model_selection import StratifiedGroupKFold
import torchaudio
import torchaudio.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset
from sklearn.model_selection import KFold
import seaborn as sn
import sklearn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, roc_auc_score, roc_curve
from torch.autograd import Function
from imblearn.metrics import sensitivity_specificity_support
from sklearn.model_selection import StratifiedShuffleSplit
import shutil
import argparse
import torch.nn as nn
from scipy import signal
import itertools

# data preparation

In [None]:
# Root directory of the cross-validation splits; each fold is read from a
# "FOLD_<k>" sub-directory containing train.csv, val.csv and test.csv.
SPLIT_DIR = "./data/CV_FOLDS/"
# Sampling rate handed to the audio loader for every file.
SAMPLE_RATE = 16000

def load_audio_data_with_labels(csv_path, sample_rate=16000):
    """Read a split CSV and load every audio file it lists.

    The CSV must provide the columns ``audio_path``, ``speaker_id``,
    ``status`` and ``SEX``.

    Args:
        csv_path: path of the CSV file describing one split.
        sample_rate: value forwarded to ``librosa.load`` as ``sr``.

    Returns:
        A list with one dict per successfully loaded row, holding ``audio``
        (float32 tensor), ``speaker_id``, ``status`` (1 for 'pd', else 0),
        ``sex`` (1 for 'M', else 0) and ``file_name``. Rows whose loading
        raises are reported on stdout and left out.
    """
    records = []
    table = pd.read_csv(csv_path)

    for _, entry in table.iterrows():
        path = entry['audio_path']
        try:
            waveform, _sr = librosa.load(path, sr=sample_rate)
            record = {
                "audio": torch.tensor(waveform, dtype=torch.float32),
                "speaker_id": entry['speaker_id'],
                "status": int(entry['status'] == 'pd'),
                "sex": int(entry['SEX'] == 'M'),
                "file_name": path,
            }
        except Exception as e:
            # Best effort: a broken file must not abort the whole split.
            print(f"Error loading {path}: {e}")
        else:
            records.append(record)

    return records


def preprocess_audio_with_labels(signals, gender_labels, health_labels, sample_rate, file_names=None):
    """Peak-normalise each waveform and drop (near-)silent ones, keeping labels aligned.

    Args:
        signals: iterable of torch tensors holding the audio samples.
        gender_labels: iterable of gender labels, parallel to ``signals``.
        health_labels: iterable of health labels, parallel to ``signals``.
        sample_rate: not used by the function; kept so existing callers
            keep working.
        file_names: optional iterable of names, parallel to ``signals``.
            When omitted, ``None`` is used as the name of every signal.

    Returns:
        Tuple ``(segments, gender_labels, health_labels, file_names)`` of four
        lists that only contain the entries of the retained signals. Every
        retained signal is divided by its maximum absolute amplitude.
    """
    if file_names is None:
        # The documented default used to crash inside zip(); pair every
        # signal with a placeholder name instead.
        file_names = itertools.repeat(None)

    selected_segments = []
    segment_gender_labels = []
    segment_health_labels = []
    segment_file_names = []

    for audio, gender_label, health_label, file_name in zip(signals, gender_labels, health_labels, file_names):

        # torch.max() raises on an empty tensor, so treat "no samples" as silence.
        max_amp = torch.max(torch.abs(audio)) if audio.numel() > 0 else 0.0
        if max_amp < 1e-6:  # Skip near-silent audio
            print(f"Skipping near silent audio: {file_name}")
            continue

        selected_segments.append(audio / max_amp)
        segment_gender_labels.append(gender_label)
        segment_health_labels.append(health_label)
        segment_file_names.append(file_name)

    return selected_segments, segment_gender_labels, segment_health_labels, segment_file_names



# Number of cross-validation folds stored on disk as FOLD_1 ... FOLD_<NUM_FOLDS>.
NUM_FOLDS = 10


def _load_and_preprocess_split(csv_path, sample_rate):
    """Load one split CSV and return (signals, gender labels, health labels, file names)."""
    records = load_audio_data_with_labels(csv_path, sample_rate)
    return preprocess_audio_with_labels(
        [item["audio"] for item in records],
        [item["sex"] for item in records],
        [item["status"] for item in records],
        sample_rate,
        file_names=[item["file_name"] for item in records],
    )


# Per-fold accumulators: index k holds the data of FOLD_<k+1>.
# X = preprocessed waveforms, yg = gender labels, yh = health labels.
all_folds_X_train = []
all_folds_yg_train = []
all_folds_yh_train = []
all_folds_filenames_train = []

all_folds_X_test = []
all_folds_yg_test = []
all_folds_yh_test = []
all_folds_filenames_test = []

all_folds_X_val = []
all_folds_yg_val = []
all_folds_yh_val = []
all_folds_filenames_val = []


for fold in range(1, NUM_FOLDS + 1):
    fold_dir = os.path.join(SPLIT_DIR, f"FOLD_{fold}")

    # --- train ---
    X_train_processed, yg_train, yh_train, filenames_train = _load_and_preprocess_split(
        os.path.join(fold_dir, "train.csv"), SAMPLE_RATE
    )
    all_folds_X_train.append(X_train_processed)
    all_folds_yg_train.append(yg_train)
    all_folds_yh_train.append(yh_train)
    all_folds_filenames_train.append(filenames_train)
    print(f"Fold {fold} - Processed {len(X_train_processed)} audio segments for training.")

    # --- validation ---
    X_val_processed, yg_val, yh_val, filenames_val = _load_and_preprocess_split(
        os.path.join(fold_dir, "val.csv"), SAMPLE_RATE
    )
    all_folds_X_val.append(X_val_processed)
    all_folds_yg_val.append(yg_val)
    all_folds_yh_val.append(yh_val)
    all_folds_filenames_val.append(filenames_val)
    print(f"Fold {fold} - Processed {len(X_val_processed)} audio segments for validation.")

    # --- test ---
    X_test_processed, yg_test, yh_test, filenames_test = _load_and_preprocess_split(
        os.path.join(fold_dir, "test.csv"), SAMPLE_RATE
    )
    all_folds_X_test.append(X_test_processed)
    all_folds_yg_test.append(yg_test)
    all_folds_yh_test.append(yh_test)
    all_folds_filenames_test.append(filenames_test)
    print(f"Fold {fold} - Processed {len(X_test_processed)} audio segments for testing.")


# X-vector extraction

In [None]:

# Speaker-embedding model used by the x-vector extraction cells: SpeechBrain's
# EncoderClassifier built from the "speechbrain/spkrec-xvect-voxceleb" source,
# with `savedir` naming the local directory used for the model files.
# NOTE(review): this import lives mid-notebook; it presumably needs network
# access on first run to fetch the model - confirm before running offline.
from speechbrain.inference.speaker import EncoderClassifier
classifier = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", 
                                            savedir="pretrained_models/spkrec-xvect-voxceleb")

In [None]:
def X_vector(signals, expected_dim=512):
    """Encode every waveform with the global ``classifier`` and stack the embeddings.

    Args:
        signals: iterable of torch tensors; a 1-D tensor gets a leading
            batch axis before it is passed to ``classifier.encode_batch``.
        expected_dim: embedding length the caller expects. A mismatch only
            produces a warning on stdout, the embedding is still kept.

    Returns:
        ``np.ndarray`` with one row per signal. For an empty ``signals`` the
        result is an empty float32 array of shape ``(0, expected_dim)``
        (``np.stack`` would raise on an empty list).
    """
    x_vector = []

    # Pure inference: the embeddings are converted to NumPy right away, so
    # there is no reason to build an autograd graph.
    with torch.no_grad():
        for i, audio in enumerate(signals):
            if audio.ndim == 1:
                audio = audio.unsqueeze(0)

            embeddings = classifier.encode_batch(audio)
            # Drop the two leading singleton axes of the encoder output.
            squeezed_embedding = embeddings.squeeze(dim=0).squeeze(dim=0)
            if squeezed_embedding.shape[0] != expected_dim:
                print(f"Warning: Unexpected embedding size for sample {i} after squeezing with shape {squeezed_embedding.shape}")

            x_vector.append(squeezed_embedding.detach().cpu().numpy())

    if not x_vector:
        return np.empty((0, expected_dim), dtype=np.float32)

    return np.stack(x_vector, axis=0)


def standardize_data(train_xvectors, val_xvectors, test_xvectors):
    """Z-score the three splits using statistics of the training split only.

    Args:
        train_xvectors: training matrix; its per-column mean and standard
            deviation define the transform.
        val_xvectors: validation matrix, transformed with the training statistics.
        test_xvectors: test matrix, transformed with the training statistics.

    Returns:
        Tuple ``(standardized_train, standardized_val, standardized_test)``.
    """
    center = np.mean(train_xvectors, axis=0)
    scale = np.std(train_xvectors, axis=0)

    # A constant column has zero spread; swap in a tiny value so the
    # division below stays finite.
    np.putmask(scale, scale == 0, 1e-6)

    return tuple(
        (split - center) / scale
        for split in (train_xvectors, val_xvectors, test_xvectors)
    )
    

# Standardized x-vector matrices, one list entry per fold.
all_folds_train_xvectors = []
all_folds_val_xvectors = []
all_folds_test_xvectors = []

if __name__ == "__main__":

    for fold in range(10):
        print(f"\nProcessing Fold {fold + 1}")

        # Preprocessed waveforms of this fold.
        X_train_processed, X_val_processed, X_test_processed = (
            all_folds_X_train[fold],
            all_folds_X_val[fold],
            all_folds_X_test[fold],
        )

        # Embed the splits in the order train -> val -> test.
        X_train_xvectors, X_val_xvectors, X_test_xvectors = map(
            X_vector, (X_train_processed, X_val_processed, X_test_processed)
        )

        print(f"Fold {fold + 1} - Before standardization:")
        print(f"  Train X-Vectors shape: {X_train_xvectors.shape}")
        print(f"  Val X-Vectors shape:   {X_val_xvectors.shape}")
        print(f"  Test X-Vectors shape:  {X_test_xvectors.shape}")

        # Scale every split with the training split's statistics.
        X_train_xvectors, X_val_xvectors, X_test_xvectors = standardize_data(
            X_train_xvectors, X_val_xvectors, X_test_xvectors
        )

        print(f"Fold {fold + 1} - After standardization:")
        print(f"  Train X-Vectors shape: {X_train_xvectors.shape}")
        print(f"  Val X-Vectors shape:   {X_val_xvectors.shape}")
        print(f"  Test X-Vectors shape:  {X_test_xvectors.shape}")

        for store, xvectors in zip(
            (all_folds_train_xvectors, all_folds_val_xvectors, all_folds_test_xvectors),
            (X_train_xvectors, X_val_xvectors, X_test_xvectors),
        ):
            store.append(xvectors)

In [None]:
class XVectorDataset(Dataset):
    """Dataset yielding ``(x_vector, gender_label, health_label, filename)`` items."""

    def __init__(self, x_vectors, gender_labels, health_labels, filenames):
        # Features become float32 and labels int64; inputs that already are
        # tensors are stored untouched.
        self.x_vectors = self._to_tensor(x_vectors, torch.float32)
        self.gender_labels = self._to_tensor(gender_labels, torch.int64)
        self.health_labels = self._to_tensor(health_labels, torch.int64)
        self.filenames = filenames

    @staticmethod
    def _to_tensor(values, dtype):
        """Return ``values`` as-is when it is a tensor, else a new tensor of ``dtype``."""
        if isinstance(values, torch.Tensor):
            return values
        return torch.tensor(values, dtype=dtype)

    def __len__(self):
        return len(self.x_vectors)

    def __getitem__(self, idx):
        return (
            self.x_vectors[idx],
            self.gender_labels[idx],
            self.health_labels[idx],
            self.filenames[idx],
        )


def create_data_loader(x_vectors, gender_labels, health_labels, filenames, batch_size, shuffle=False, drop_last=None):
    """Wrap one split in an ``XVectorDataset`` and return a ``DataLoader`` over it.

    Args:
        x_vectors: feature matrix of the split.
        gender_labels: gender labels, parallel to ``x_vectors``.
        health_labels: health labels, parallel to ``x_vectors``.
        filenames: file names, parallel to ``x_vectors``.
        batch_size: number of samples per batch.
        shuffle: whether the loader reshuffles the data every epoch.
        drop_last: whether a trailing incomplete batch is discarded. ``None``
            (default) follows ``shuffle``: shuffled (training) loaders drop
            it, ordered (validation/test) loaders keep it so that metrics
            are computed on every sample instead of silently losing up to
            ``batch_size - 1`` of them.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    if drop_last is None:
        drop_last = shuffle

    dataset = XVectorDataset(x_vectors, gender_labels, health_labels, filenames)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)



# BATCH_SIZE = 4

# # Loop through folds
# for fold in range(10):
#     print(f"\nCreating DataLoaders for Fold {fold + 1}")

#     # --- TRAIN ---
#     X_train_xvectors = all_folds_train_xvectors[fold]
#     yg_train = all_folds_yg_train[fold]
#     yh_train = all_folds_yh_train[fold]
#     filenames_train = all_folds_filenames_train[fold]

#     train_loader = create_data_loader(
#         X_train_xvectors, yg_train, yh_train, filenames_train, BATCH_SIZE, shuffle=True
#     )

#     # --- VALIDATION (NEW) ---
#     X_val_xvectors = all_folds_val_xvectors[fold]
#     yg_val = all_folds_yg_val[fold]
#     yh_val = all_folds_yh_val[fold]
#     filenames_val = all_folds_filenames_val[fold]

#     val_loader = create_data_loader(
#         X_val_xvectors, yg_val, yh_val, filenames_val, BATCH_SIZE, shuffle=False
#     )

#     # --- TEST ---
#     X_test_xvectors = all_folds_test_xvectors[fold]
#     yg_test = all_folds_yg_test[fold]
#     yh_test = all_folds_yh_test[fold]
#     filenames_test = all_folds_filenames_test[fold]

#     test_loader = create_data_loader(
#         X_test_xvectors, yg_test, yh_test, filenames_test, BATCH_SIZE, shuffle=False
#     )

#     print(f"Train DataLoader for Fold {fold + 1} has {len(train_loader)} batches.")
#     print(f"Val DataLoader   for Fold {fold + 1} has {len(val_loader)} batches.")
#     print(f"Test DataLoader  for Fold {fold + 1} has {len(test_loader)} batches.")



# gender classifier

In [40]:
import torch
import torch.nn as nn
import torch.optim as optim

# One fixed seed for every random number generator used in the notebook
# (Python, NumPy, PyTorch CPU) so runs are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

cuda_available = torch.cuda.is_available()

if cuda_available:
    # Seed the current GPU and all GPUs, and make cuDNN pick deterministic
    # kernels instead of auto-tuning.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Models and batches are moved to this device.
device = torch.device("cuda") if cuda_available else torch.device("cpu")



class GenderClassifier(nn.Module):
    """Logistic-regression style model: one linear layer followed by a sigmoid.

    Args:
        input_dim: size of each input feature vector.
        output_dim: number of output units.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, output_dim)
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the linear projection of ``x``."""
        logits = self.linear1(x)
        return self.sigm(logits)



# class GenderClassifier(nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super(GenderClassifier, self).__init__()
#         self.linear1 = nn.Linear(input_dim, 128)
#         self.relu = nn.ReLU()
#         self.linear2 = nn.Linear(128, output_dim)
#         self.sigm = nn.Sigmoid()  
    
#     def forward(self, x):
#         x = self.relu(self.linear1(x))
#         x = self.sigm(self.linear2(x))
#         return x






In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from sklearn.metrics import roc_auc_score
# import matplotlib.pyplot as plt


# input_dim = 512  
# output_dim = 1  
# num_epochs = 20
# BATCH_SIZE = 4
# learning_rate = 0.0001
# patience = 5              


# all_fold_train_losses = []
# all_fold_val_losses = []       
# all_fold_val_roc_aucs = []      
# all_fold_roc_auc_scores = []    

# all_fold_best_val_aucs = []


# all_true_labels = []
# all_predictions = []



# for fold in range(5):
#     print(f"\nStarting fold {fold + 1}/5")
    
   
#     train_loader = create_data_loader(
#         all_folds_train_xvectors[fold], all_folds_yg_train[fold], all_folds_yh_train[fold],
#         all_folds_filenames_train[fold], BATCH_SIZE, shuffle=True
#     )

#     val_loader = create_data_loader(
#         all_folds_val_xvectors[fold],
#         all_folds_yg_val[fold],
#         all_folds_yh_val[fold],
#         all_folds_filenames_val[fold],
#         BATCH_SIZE, shuffle=False
#     )

    
#     test_loader = create_data_loader(
#         all_folds_test_xvectors[fold], all_folds_yg_test[fold], all_folds_yh_test[fold],
#         all_folds_filenames_test[fold], BATCH_SIZE, shuffle=False
#     )

    
#     model = GenderClassifier(input_dim, output_dim).to(device)
#     criterion = nn.BCELoss()
#     optimizer = optim.Adam(model.parameters(), lr=learning_rate)

   
#     fold_train_losses = []
#     fold_val_losses = []
#     fold_val_aucs = []

    
#     best_val_auc = 0.0
#     best_model_weights = None

#     for epoch in range(num_epochs):
#         model.train()
#         running_loss = 0.0

#         for x_vectors, gender_labels, _, _ in train_loader:
#             x_vectors = x_vectors.to(device)
#             gender_labels = gender_labels.unsqueeze(1).float().to(device)

#             optimizer.zero_grad()
#             outputs = model(x_vectors)
#             loss = criterion(outputs, gender_labels)
#             loss.backward()
#             optimizer.step()

#             running_loss += loss.item()

        
#         train_loss = running_loss / len(train_loader)
#         fold_train_losses.append(train_loss)


#         model.eval()
#         val_running_loss = 0.0
#         val_true_labels = []
#         val_predictions = []

#         with torch.no_grad():
#             for x_vectors, gender_labels, _, _ in val_loader:
#                 x_vectors = x_vectors.to(device)
#                 gender_labels = gender_labels.unsqueeze(1).float().to(device)
#                 outputs = model(x_vectors)

#                 val_loss = criterion(outputs, gender_labels)
#                 val_running_loss += val_loss.item()

#                 val_predictions.extend(outputs.cpu().numpy())
#                 val_true_labels.extend(gender_labels.cpu().numpy())

#         val_loss_epoch = val_running_loss / len(val_loader) if len(val_loader) > 0 else 0
#         val_roc_auc = roc_auc_score(val_true_labels, val_predictions)
        
#         fold_val_losses.append(val_loss_epoch)
#         fold_val_aucs.append(val_roc_auc)


#         if val_roc_auc > best_val_auc:
#             best_val_auc = val_roc_auc
#             best_model_weights = copy.deepcopy(model.state_dict())


#         print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs}, "
#               f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss_epoch:.4f}, Val ROC: {val_roc_auc:.4f}")        

#     all_fold_train_losses.append(fold_train_losses)
#     all_fold_val_losses.append(fold_val_losses)
#     all_fold_val_roc_aucs.append(fold_val_aucs)
#     all_fold_best_val_aucs.append(best_val_auc)

#     final_model_path = f"model_gender2layers_fold_{fold+1}_finalNVZ.pth"
#     torch.save(model.state_dict(), final_model_path)
#     print(f"Saved FINAL model for Fold {fold+1} as '{final_model_path}'")  

#     if best_model_weights is not None:
#         best_model_path = f"model_gender2layers_fold_{fold+1}_bestValNVZ.pth"
#         torch.save(best_model_weights, best_model_path)
#         print(f"Saved BEST-VAL model for Fold {fold+1} (AUC={best_val_auc:.4f}) as '{best_model_path}'")

    
    
#     # model.load_state_dict(best_model_weights)    
#     model.eval()
#     true_labels, predictions = [], []

#     with torch.no_grad():
#         for x_vectors, gender_labels, _, _ in test_loader:
#             x_vectors = x_vectors.to(device)
#             gender_labels = gender_labels.unsqueeze(1).float().to(device)
#             outputs = model(x_vectors)

#             predictions.extend(outputs.cpu().numpy())
#             true_labels.extend(gender_labels.cpu().numpy())

            
#     all_true_labels.extend(true_labels)
#     all_predictions.extend(predictions)  
   
#     roc_auc = roc_auc_score(true_labels, predictions)
#     all_fold_roc_auc_scores.append(roc_auc)
#     print(f"Fold {fold + 1}, Test ROC AUC: {roc_auc:.4f}")
    
#     fpr, tpr, _ = roc_curve(true_labels, predictions)
#     plt.figure()
#     plt.plot(fpr, tpr, label=f'Fold {fold + 1} ROC curve (area = {roc_auc:.4f})')
#     plt.plot([0, 1], [0, 1], 'k--')  
#     plt.xlabel('False Positive Rate')
#     plt.ylabel('True Positive Rate')
#     plt.title(f'ROC Curve for Fold {fold + 1}')
#     plt.legend(loc='lower right')
#     plt.show()


# avg_roc_auc = np.mean(all_fold_roc_auc_scores)
# std_roc_auc = np.std(all_fold_roc_auc_scores)
# print(f"\nAverage Test ROC AUC across folds: {avg_roc_auc:.4f} ± {std_roc_auc:.4f}")


# avg_best_val_auc = np.mean(all_fold_best_val_aucs)
# std_best_val_auc = np.std(all_fold_best_val_aucs)
# print(f"Average BEST Val AUC Across Folds: {avg_best_val_auc:.4f} ± {std_best_val_auc:.4f}")


# # plt.figure(figsize=(10, 5))
# # for fold_idx, (train_losses, val_losses) in enumerate(zip(all_fold_train_losses, all_fold_val_losses), 1):
# #     plt.plot(range(1, num_epochs + 1), train_losses, label=f'Fold {fold_idx} - Train')
# #     plt.plot(range(1, num_epochs + 1), val_losses, label=f'Fold {fold_idx} - Val')
# # plt.xlabel('Epoch')
# # plt.ylabel('Loss')
# # plt.title('Training & Validation Loss Curves for Each Fold')
# # plt.legend()
# # plt.show()


# all_true_labels = np.array(all_true_labels).flatten()
# all_predictions = np.array(all_predictions).flatten()

# male_scores = all_predictions[all_true_labels == 1]
# female_scores = all_predictions[all_true_labels == 0]

# plt.figure(figsize=(4, 4))
# plt.hist(female_scores, bins=100, alpha=0.9, color='blue', label='Female')
# plt.hist(male_scores, bins=100, alpha=0.9, color='red', label='Male')
# plt.xlabel('Predicted Probability')
# plt.ylabel('Frequency')
# plt.title('Posterior Probabilities Across All Folds')
# plt.legend()
# plt.show()

In [None]:
# Settings of the gender-classifier experiment. They are defined in this cell
# so that it runs on a fresh kernel instead of depending on a later cell
# having been executed first.
input_dim = 512    # length of one x-vector
output_dim = 1     # single sigmoid output
num_epochs = 20    # maximum number of training epochs
patience = 5       # early-stopping patience (epochs without val-AUC gain)

# Hyperparameter grid; every combination is tried on every fold.
param_grid = {
    'learning_rate': [0.001, 0.0001, 0.00001],
    'batch_size': [4],
    'optimizer': ['adam', 'sgd']

}
# Tuples in the order (learning_rate, batch_size, optimizer).
param_combinations = list(itertools.product(*param_grid.values()))

# Per-fold histories and scores.
all_fold_train_losses = []
all_fold_val_losses = []
all_fold_val_aucs = []
all_fold_test_aucs = []
all_fold_best_val_aucs = []
all_fold_best_params = []

# Test labels / predicted scores pooled over all folds.
labels = []
predictions = []


# Model selection and evaluation of the gender classifier, once per fold:
#   1. grid search: train a fresh model for every hyperparameter combination
#      and remember the combination with the highest validation ROC AUC;
#   2. retrain a fresh model with that combination, early-stopping on the
#      validation ROC AUC;
#   3. score the best-validation checkpoint on the fold's test split.
# Expects `param_combinations`, `input_dim`, `output_dim`, `num_epochs`,
# `patience` and the `all_fold_*` / `labels` / `predictions` lists to exist.


def _build_optimizer(opt_name, parameters, lr):
    """Return Adam for 'adam', otherwise SGD with momentum 0.9."""
    if opt_name == 'adam':
        return optim.Adam(parameters, lr=lr)
    return optim.SGD(parameters, lr=lr, momentum=0.9)


def _train_gender_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over `loader`; return the sum of the batch losses."""
    model.train()
    running_loss = 0.0
    for x_vectors, gender_labels, _, _ in loader:
        x_vectors = x_vectors.to(device)
        # BCELoss wants float targets shaped like the model output.
        gender_labels = gender_labels.unsqueeze(1).float().to(device)

        optimizer.zero_grad()
        outputs = model(x_vectors)
        loss = criterion(outputs, gender_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss


def _score_gender(model, loader, criterion=None):
    """Collect gender labels and model outputs over `loader` without gradients.

    Returns (true_labels, predictions, summed_loss); the loss is only
    accumulated when a criterion is given, otherwise it stays 0.0.
    """
    model.eval()
    true_labels, scores, loss_sum = [], [], 0.0
    with torch.no_grad():
        for x_vectors, gender_labels, _, _ in loader:
            x_vectors = x_vectors.to(device)
            gender_labels = gender_labels.unsqueeze(1).float().to(device)

            outputs = model(x_vectors)
            if criterion is not None:
                loss_sum += criterion(outputs, gender_labels).item()

            scores.extend(outputs.cpu().numpy())
            true_labels.extend(gender_labels.cpu().numpy())
    return true_labels, scores, loss_sum


for fold in range(10):
    print(f"\n{'='*60}")
    print(f"Starting fold {fold+1}/10")
    print(f"{'='*60}")

    # Start below any reachable AUC so the first combination is always
    # recorded and `best_fold_params` cannot stay None.
    best_fold_auc = float("-inf")
    best_fold_params = None
    best_model_weights_for_fold = None  # grid-search winner's weights (not reused below)

    # ---- 1. grid search -------------------------------------------------
    for (lr, bs, opt_name) in param_combinations:
        print(f"\n  Searching - (LR={lr}, BatchSize={bs}, Optim={opt_name})")

        train_loader = create_data_loader(
            all_folds_train_xvectors[fold],
            all_folds_yg_train[fold],
            all_folds_yh_train[fold],
            all_folds_filenames_train[fold],
            batch_size=bs, shuffle=True
        )
        val_loader = create_data_loader(
            all_folds_val_xvectors[fold],
            all_folds_yg_val[fold],
            all_folds_yh_val[fold],
            all_folds_filenames_val[fold],
            batch_size=bs, shuffle=False
        )

        model = GenderClassifier(input_dim, output_dim).to(device)
        criterion = nn.BCELoss()
        optimizer = _build_optimizer(opt_name, model.parameters(), lr)

        best_val_auc_for_combo = float("-inf")
        best_model_weights_for_combo = None

        for epoch in range(num_epochs):
            _train_gender_epoch(model, train_loader, criterion, optimizer)

            val_true, val_preds, _ = _score_gender(model, val_loader)
            val_auc = roc_auc_score(val_true, val_preds)

            if val_auc > best_val_auc_for_combo:
                best_val_auc_for_combo = val_auc
                best_model_weights_for_combo = copy.deepcopy(model.state_dict())

        if best_val_auc_for_combo > best_fold_auc:
            best_fold_auc = best_val_auc_for_combo
            best_fold_params = (lr, bs, opt_name)
            best_model_weights_for_fold = best_model_weights_for_combo

    print(f"\nBest hyperparameters for fold {fold+1}:")
    print(f"   LR={best_fold_params[0]}, BatchSize={best_fold_params[1]}, Optim={best_fold_params[2]}")
    print(f"   Best Val AUC from grid search = {best_fold_auc:.4f}")

    all_fold_best_params.append(best_fold_params)

    # ---- 2. retrain with the selected hyperparameters ---------------------
    final_bs = best_fold_params[1]
    final_train_loader = create_data_loader(
        all_folds_train_xvectors[fold],
        all_folds_yg_train[fold],
        all_folds_yh_train[fold],
        all_folds_filenames_train[fold],
        batch_size=final_bs, shuffle=True
    )
    final_val_loader = create_data_loader(
        all_folds_val_xvectors[fold],
        all_folds_yg_val[fold],
        all_folds_yh_val[fold],
        all_folds_filenames_val[fold],
        batch_size=final_bs, shuffle=False
    )
    test_loader = create_data_loader(
        all_folds_test_xvectors[fold],
        all_folds_yg_test[fold],
        all_folds_yh_test[fold],
        all_folds_filenames_test[fold],
        batch_size=final_bs, shuffle=False
    )

    model = GenderClassifier(input_dim, output_dim).to(device)
    optimizer_h = _build_optimizer(best_fold_params[2], model.parameters(), best_fold_params[0])
    criterion_h = nn.BCELoss()

    fold_train_losses = []
    fold_val_losses = []
    fold_val_aucs = []

    best_val_auc_this_fold = float("-inf")
    best_model_weights_this_fold = None
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        running_loss = _train_gender_epoch(model, final_train_loader, criterion_h, optimizer_h)
        train_loss = running_loss / len(final_train_loader)
        fold_train_losses.append(train_loss)

        val_true, val_preds, val_loss_accum = _score_gender(model, final_val_loader, criterion_h)
        val_loss_avg = val_loss_accum / len(final_val_loader)
        val_auc = roc_auc_score(val_true, val_preds)

        fold_val_losses.append(val_loss_avg)
        fold_val_aucs.append(val_auc)

        # Checkpoint on improvement, otherwise count towards early stopping.
        if val_auc > best_val_auc_this_fold:
            best_val_auc_this_fold = val_auc
            best_model_weights_this_fold = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"  [Early Stopping] No improvement for {patience} epochs. Stopping at epoch {epoch+1}.")
            break

        print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss_avg:.4f}, Val AUC: {val_auc:.4f}")

    # Weights after the last executed epoch.
    final_model_path = f"model_gender_fold_{fold+1}_finalNVZ.pth"
    torch.save(model.state_dict(), final_model_path)
    print(f"Saved FINAL model for Fold {fold+1} as '{final_model_path}'")

    # Weights of the epoch with the highest validation AUC.
    if best_model_weights_this_fold is not None:
        best_model_path = f"model_gender_fold_{fold+1}_bestValNVZ.pth"
        torch.save(best_model_weights_this_fold, best_model_path)
        print(f"Saved BEST-VAL model for Fold {fold+1} (AUC={best_val_auc_this_fold:.4f}) as '{best_model_path}'")
    else:
        print("No improvement was tracked, best_model_weights_this_fold is None (check logic).")

    # ---- 3. test evaluation with the best-validation checkpoint -----------
    if best_model_weights_this_fold is not None:
        model.load_state_dict(best_model_weights_this_fold)

    test_true, test_preds, _ = _score_gender(model, test_loader)

    labels.extend(test_true)
    predictions.extend(test_preds)

    test_auc = roc_auc_score(test_true, test_preds)
    all_fold_test_aucs.append(test_auc)
    print(f"\nFold {fold+1}, Test ROC AUC: {test_auc:.4f}")

    fpr, tpr, _ = roc_curve(test_true, test_preds)
    plt.figure()
    plt.plot(fpr, tpr, label=f'Fold {fold+1} (AUC={test_auc:.4f})')
    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f' Test ROC - Fold {fold+1}')
    plt.legend()
    plt.show()

    all_fold_train_losses.append(fold_train_losses)
    all_fold_val_losses.append(fold_val_losses)
    all_fold_val_aucs.append(fold_val_aucs)
    all_fold_best_val_aucs.append(best_val_auc_this_fold)

# Cross-fold summary: mean and sample standard deviation (ddof=1) of the
# per-fold test AUCs and of the per-fold best validation AUCs.
test_auc_array = np.asarray(all_fold_test_aucs)
avg_test_auc = test_auc_array.mean()
std_test_auc = test_auc_array.std(ddof=1)
print(f"\nAverage Test ROC AUC: {avg_test_auc:.4f} ± {std_test_auc:.4f}")

best_val_auc_array = np.asarray(all_fold_best_val_aucs)
avg_best_val_auc = best_val_auc_array.mean()
std_best_val_auc = best_val_auc_array.std(ddof=1)
print(f"Average BEST Val AUC Across Folds: {avg_best_val_auc:.4f} ± {std_best_val_auc:.4f}")

# Hyperparameters chosen for each fold, as (learning rate, batch size, optimizer).
print("\nBest Parameters per Fold:")
for fold_idx in range(len(all_fold_best_params)):
    params = all_fold_best_params[fold_idx]
    print(f"  Fold {fold_idx+1}: LR={params[0]}, BS={params[1]}, Optim={params[2]}")


# plt.figure(figsize=(10,5))
# for fold_idx, (train_losses, val_losses) in enumerate(zip(all_fold_train_losses, all_fold_val_losses), 1):
#     # Each fold might have a different number of epochs if early stopping triggered
#     epochs_range = range(1, len(train_losses) + 1)
#     plt.plot(epochs_range, train_losses, label=f'Fold {fold_idx} - Train')
#     plt.plot(epochs_range, val_losses, label=f'Fold {fold_idx} - Val')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Train & Validation Loss Curves')
# plt.legend()
# plt.show()

# plt.figure(figsize=(10,5))
# for fold_idx, val_aucs in enumerate(all_fold_val_aucs, 1):
#     epochs_range = range(1, len(val_aucs) + 1)
#     plt.plot(epochs_range, val_aucs, label=f'Fold {fold_idx}')
# plt.xlabel('Epoch')
# plt.ylabel('Val AUC')
# plt.title('Validation AUC Curves per Fold')
# plt.legend()
# plt.show()


# health classifier

In [16]:
# Re-seed every random number generator before the health-classifier
# experiment (Python, NumPy, PyTorch CPU and, if present, CUDA).
seed = 42
for seed_everything in (torch.manual_seed, np.random.seed, random.seed):
    seed_everything(seed)

gpu_present = torch.cuda.is_available()

if gpu_present:
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels, no auto-tuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Device that models and batches are moved to.
device = torch.device("cuda" if gpu_present else "cpu")


# class HealthClassifier(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim):
#         super(HealthClassifier, self).__init__()
#         self.linear1 = nn.Linear(input_dim, hidden_dim)
#         self.relu = nn.ReLU()
#         self.linear2 = nn.Linear(hidden_dim, output_dim)
#         self.sigm = nn.Sigmoid()
    
#     def forward(self, x):
#         x = self.relu(self.linear1(x))
#         x = self.sigm(self.linear2(x))
#         return x

# class HealthClassifier(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim):
#         super(HealthClassifier, self).__init__()
#         self.linear1 = nn.Linear(input_dim, hidden_dim)
#         self.relu = nn.ReLU()
#         self.linear2 = nn.Linear(hidden_dim, output_dim)
    
#     def forward(self, x):
#         x = self.relu(self.linear1(x))
#         x = self.linear2(x)  # No sigmoid
#         return x


In [None]:

# --- Model dimensions ---
input_dim = 512     # length of one input feature vector
hidden_dim = 128    # width of the hidden layer
output_dim = 1      # single output unit

# --- Training defaults ---
learning_rate = 1e-5
num_epochs = 20
BATCH_SIZE = 4
patience = 5        # early-stopping patience in epochs

# --- Hyperparameter grid ---
param_grid = {
    'learning_rate': [1e-3, 1e-4, 1e-5],
    'batch_size': [4],
    'optimizer': ['adam', 'sgd']
    # 'optimizer': ['adam']
}
# Every (learning_rate, batch_size, optimizer) combination of the grid.
param_combinations = list(
    itertools.product(
        param_grid['learning_rate'],
        param_grid['batch_size'],
        param_grid['optimizer'],
    )
)


class HealthClassifier(nn.Module):
    """Two-layer perceptron that maps an input vector to a probability.

    Architecture: Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, output_dim) -> Sigmoid. Because the sigmoid is applied
    inside ``forward``, the outputs lie in (0, 1) and pair with ``nn.BCELoss``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order is kept so state_dict keys and weight
        # initialisation order match previously saved checkpoints.
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities with shape ``(batch, output_dim)``."""
        hidden = self.relu(self.linear1(x))
        logits = self.linear2(hidden)
        return self.sigm(logits)


# Per-fold histories; each list receives one entry per completed fold.
all_fold_train_losses = []
all_fold_val_losses = []
all_fold_val_aucs = []
all_fold_test_aucs = []
all_fold_best_val_aucs = []
all_fold_best_params = []


def _make_health_optimizer(model, opt_name, lr):
    """Return Adam when ``opt_name == 'adam'``, otherwise SGD with momentum 0.9."""
    if opt_name == 'adam':
        return optim.Adam(model.parameters(), lr=lr)
    return optim.SGD(model.parameters(), lr=lr, momentum=0.9)


def _train_health_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is unpacked as ``(x_vectors, _, health_labels, _)``; only the
    x-vectors and the health labels are used.
    """
    model.train()
    running_loss = 0.0
    for x_vectors, _, health_labels, _ in loader:
        x_vectors = x_vectors.to(device)
        # BCELoss needs float targets shaped like the (batch, 1) model output.
        health_labels = health_labels.unsqueeze(1).float().to(device)

        optimizer.zero_grad()
        loss = criterion(model(x_vectors), health_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / max(len(loader), 1)


def _evaluate_health(model, loader, criterion):
    """Score ``loader`` without gradients.

    Returns ``(mean_batch_loss, labels, probabilities)``; the last two are flat
    lists that can be passed directly to ``roc_auc_score`` / ``roc_curve``.
    """
    model.eval()
    loss_sum = 0.0
    labels, probs = [], []
    with torch.no_grad():
        for x_vectors, _, health_labels, _ in loader:
            x_vectors = x_vectors.to(device)
            health_labels = health_labels.unsqueeze(1).float().to(device)

            outputs = model(x_vectors)
            loss_sum += criterion(outputs, health_labels).item()

            probs.extend(outputs.cpu().numpy().flatten())
            labels.extend(health_labels.cpu().numpy().flatten())
    return loss_sum / max(len(loader), 1), labels, probs


for fold in range(10):
    print(f"\n{'='*60}")
    print(f"Starting fold {fold+1}/10")
    print(f"{'='*60}")

    best_fold_auc = 0.0
    best_fold_params = None
    best_model_weights_for_fold = None

    # ------------------------------------------------------------------
    # Stage 1: grid search. Every combination is trained for `num_epochs`
    # and ranked by the best validation AUC it reaches at any epoch.
    # ------------------------------------------------------------------
    for (lr, bs, opt_name) in param_combinations:
        print(f"\n  Searching - (LR={lr}, BatchSize={bs}, Optim={opt_name})")

        train_loader = create_data_loader(
            all_folds_train_xvectors[fold],
            all_folds_yg_train[fold],
            all_folds_yh_train[fold],
            all_folds_filenames_train[fold],
            batch_size=bs, shuffle=True
        )
        val_loader = create_data_loader(
            all_folds_val_xvectors[fold],
            all_folds_yg_val[fold],
            all_folds_yh_val[fold],
            all_folds_filenames_val[fold],
            batch_size=bs, shuffle=False
        )

        model = HealthClassifier(input_dim, hidden_dim, output_dim).to(device)
        criterion = nn.BCELoss()
        optimizer = _make_health_optimizer(model, opt_name, lr)

        best_val_auc_for_combo = 0.0
        best_model_weights_for_combo = None

        for epoch in range(num_epochs):
            _train_health_epoch(model, train_loader, criterion, optimizer)
            _, val_true, val_preds = _evaluate_health(model, val_loader, criterion)
            val_auc = roc_auc_score(val_true, val_preds)

            # The `is None` test guarantees a snapshot even if the AUC never exceeds 0.
            if best_model_weights_for_combo is None or val_auc > best_val_auc_for_combo:
                best_val_auc_for_combo = val_auc
                best_model_weights_for_combo = copy.deepcopy(model.state_dict())

        # Same guard here so `best_fold_params` is always set once a combination ran.
        if best_fold_params is None or best_val_auc_for_combo > best_fold_auc:
            best_fold_auc = best_val_auc_for_combo
            best_fold_params = (lr, bs, opt_name)
            best_model_weights_for_fold = best_model_weights_for_combo

    if best_fold_params is None:
        raise RuntimeError("Grid search produced no result: `param_combinations` is empty.")

    print(f"\nBest hyperparameters for fold {fold+1}:")
    print(f"   LR={best_fold_params[0]}, BatchSize={best_fold_params[1]}, Optim={best_fold_params[2]}")
    print(f"   Best Val AUC from grid search = {best_fold_auc:.4f}")

    all_fold_best_params.append(best_fold_params)

    # ------------------------------------------------------------------
    # Stage 2: retrain from scratch with the winning hyper-parameters,
    # with early stopping on validation AUC.
    # ------------------------------------------------------------------
    final_lr, final_bs, final_opt_name = best_fold_params
    final_train_loader = create_data_loader(
        all_folds_train_xvectors[fold],
        all_folds_yg_train[fold],
        all_folds_yh_train[fold],
        all_folds_filenames_train[fold],
        batch_size=final_bs, shuffle=True
    )
    final_val_loader = create_data_loader(
        all_folds_val_xvectors[fold],
        all_folds_yg_val[fold],
        all_folds_yh_val[fold],
        all_folds_filenames_val[fold],
        batch_size=final_bs, shuffle=False
    )
    test_loader = create_data_loader(
        all_folds_test_xvectors[fold],
        all_folds_yg_test[fold],
        all_folds_yh_test[fold],
        all_folds_filenames_test[fold],
        batch_size=final_bs, shuffle=False
    )

    model_health = HealthClassifier(input_dim, hidden_dim, output_dim).to(device)
    optimizer_h = _make_health_optimizer(model_health, final_opt_name, final_lr)
    criterion_h = nn.BCELoss()

    fold_train_losses = []
    fold_val_losses = []
    fold_val_aucs = []

    best_val_auc_this_fold = 0.0
    best_model_weights_this_fold = None
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        train_loss = _train_health_epoch(model_health, final_train_loader, criterion_h, optimizer_h)
        val_loss_avg, val_true, val_preds = _evaluate_health(model_health, final_val_loader, criterion_h)
        val_auc = roc_auc_score(val_true, val_preds)

        fold_train_losses.append(train_loss)
        fold_val_losses.append(val_loss_avg)
        fold_val_aucs.append(val_auc)

        if best_model_weights_this_fold is None or val_auc > best_val_auc_this_fold:
            best_val_auc_this_fold = val_auc
            best_model_weights_this_fold = copy.deepcopy(model_health.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Report the epoch before a possible early stop so its metrics are not lost.
        print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss_avg:.4f}, Val AUC: {val_auc:.4f}")

        if epochs_without_improvement >= patience:
            print(f"  [Early Stopping] No improvement for {patience} epochs. Stopping at epoch {epoch+1}.")
            break

    # Keep both the last-epoch weights and the best-validation snapshot on disk.
    final_model_path = f"model_health_fold_{fold+1}_finalNVZ.pth"
    torch.save(model_health.state_dict(), final_model_path)
    print(f"Saved FINAL model for Fold {fold+1} as '{final_model_path}'")

    if best_model_weights_this_fold is not None:
        best_model_path = f"model_health_fold_{fold+1}_bestValNVZ.pth"
        torch.save(best_model_weights_this_fold, best_model_path)
        print(f"Saved BEST-VAL model for Fold {fold+1} (AUC={best_val_auc_this_fold:.4f}) as '{best_model_path}'")
        # The test set is scored with the best-validation weights, not the last epoch.
        model_health.load_state_dict(best_model_weights_this_fold)
    else:
        print("No improvement was tracked, best_model_weights_this_fold is None (check logic).")

    # ------------------------------------------------------------------
    # Stage 3: held-out test evaluation for this fold.
    # ------------------------------------------------------------------
    _, test_true, test_preds = _evaluate_health(model_health, test_loader, criterion_h)

    test_auc = roc_auc_score(test_true, test_preds)
    all_fold_test_aucs.append(test_auc)
    print(f"\nFold {fold+1}, Test ROC AUC: {test_auc:.4f}")

    fpr, tpr, _ = roc_curve(test_true, test_preds)
    plt.figure()
    plt.plot(fpr, tpr, label=f'Fold {fold+1} (AUC={test_auc:.4f})')
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'Health Test ROC - Fold {fold+1}')
    plt.legend()
    plt.show()

    all_fold_train_losses.append(fold_train_losses)
    all_fold_val_losses.append(fold_val_losses)
    all_fold_val_aucs.append(fold_val_aucs)
    all_fold_best_val_aucs.append(best_val_auc_this_fold)


# Cross-fold summary: mean and sample standard deviation (ddof=1) of the AUCs.
avg_test_auc, std_test_auc = np.mean(all_fold_test_aucs), np.std(all_fold_test_aucs, ddof=1)
print(f"\nAverage Test ROC AUC: {avg_test_auc:.4f} ± {std_test_auc:.4f}")

avg_best_val_auc, std_best_val_auc = (
    np.mean(all_fold_best_val_aucs),
    np.std(all_fold_best_val_aucs, ddof=1),
)
print(f"Average BEST Val AUC Across Folds: {avg_best_val_auc:.4f} ± {std_best_val_auc:.4f}")

# One line per fold with the hyper-parameters chosen by the grid search.
print("\nBest Parameters per Fold:")
for fold_number, (best_lr, best_bs, best_opt) in enumerate(all_fold_best_params, start=1):
    print(f"  Fold {fold_number}: LR={best_lr}, BS={best_bs}, Optim={best_opt}")


# plt.figure(figsize=(10,5))
# for fold_idx, (train_losses, val_losses) in enumerate(zip(all_fold_train_losses, all_fold_val_losses), 1):
#     # Each fold might have a different number of epochs if early stopping triggered
#     epochs_range = range(1, len(train_losses) + 1)
#     plt.plot(epochs_range, train_losses, label=f'Fold {fold_idx} - Train')
#     plt.plot(epochs_range, val_losses, label=f'Fold {fold_idx} - Val')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Train & Validation Loss Curves')
# plt.legend()
# plt.show()

# plt.figure(figsize=(10,5))
# for fold_idx, val_aucs in enumerate(all_fold_val_aucs, 1):
#     epochs_range = range(1, len(val_aucs) + 1)
#     plt.plot(epochs_range, val_aucs, label=f'Fold {fold_idx}')
# plt.xlabel('Epoch')
# plt.ylabel('Val AUC')
# plt.title('Validation AUC Curves per Fold')
# plt.legend()
# plt.show()


# Gradient Reversal Layer

In [10]:
from torch.autograd import Function

# Re-seed all generators so this cell starts from the same random state on every run.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# With a GPU present, also seed CUDA and pin cuDNN to deterministic mode.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class GradReverse(Function):
    """Autograd function: identity in the forward pass, negated gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Return a view so the output is a distinct tensor with the same values.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient.
        return -grad_output

def grad_reverse(x):
    """Pass ``x`` through ``GradReverse``: same values forward, sign-flipped gradient backward."""
    reversed_x = GradReverse.apply(x)
    return reversed_x


class Discrim(nn.Module):
    """Stand-alone attribute discriminator: one linear unit followed by a sigmoid.

    ``hidden_dim`` is accepted for signature compatibility but is not used;
    the model has no hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.linear1 = torch.nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return a probability of shape ``(batch, 1)`` for each row of ``x``."""
        return torch.sigmoid(self.linear1(x))




class Autoencoder(nn.Module):
    """Single-layer autoencoder with an adversarial attribute head on the latent code.

    * ``encode``: L2-normalise the input rows, then Linear -> ReLU -> BatchNorm1d.
    * ``decode``: Linear -> tanh, then L2-normalise the reconstruction rows.
    * ``discriminator``: Linear(latent_dim, 1) -> Sigmoid, fed with the
      gradient-reversed latent code, so the gradient that reaches the encoder
      through this head has its sign flipped.

    Args:
        input_dim: size of each input / reconstructed vector.
        latent_dim: size of the latent code ``z``.
    """

    def __init__(self, input_dim, latent_dim):
        super(Autoencoder, self).__init__()
        self.input_dim = input_dim
        # Attribute names and creation order are unchanged so that existing
        # checkpoints (state_dict keys) keep loading.
        self.linear1 = torch.nn.Linear(input_dim, latent_dim)
        self.bn1 = torch.nn.BatchNorm1d(num_features=latent_dim)
        self.linear2 = torch.nn.Linear(latent_dim, input_dim)
        self.discriminator = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 1),
            torch.nn.Sigmoid()
        )

    def encode(self, x):
        """Map ``x`` of shape ``(batch, input_dim)`` to a latent code ``(batch, latent_dim)``."""
        # normalize() divides by max(norm, eps), so an all-zero row stays zero
        # instead of turning into NaN as a bare ``x / norm`` would.
        x = torch.nn.functional.normalize(x, p=2, dim=1)
        # torch.relu is called directly: the notebook's import cell binds ``F``
        # to torchaudio.functional, which does not provide ``relu``.
        z = torch.relu(self.linear1(x))
        z = self.bn1(z)
        return z

    def decode(self, z):
        """Map a latent code back to a unit-norm vector of shape ``(batch, input_dim)``."""
        x = torch.tanh(self.linear2(z))
        x = torch.nn.functional.normalize(x, p=2, dim=1)
        return x

    def forward(self, x):
        """Return ``(reconstruction, latent_code, attribute_probability)``."""
        z = self.encode(x)
        rev_z = grad_reverse(z)
        att_pred = self.discriminator(rev_z)
        outputs = self.decode(z)
        return outputs, z, att_pred


def recons_loss_function(out_x, x):
    """Mean cosine distance ``1 - cos(out_x, x)`` between reconstructions and targets.

    Args:
        out_x: reconstructed vectors, shape ``(batch, dim)``.
        x: target vectors; reshaped to ``(-1, dim)`` before the comparison.

    Returns:
        Scalar tensor: 0 when every pair points in the same direction, 2 when
        every pair points in opposite directions.
    """
    # The feature size is read from ``out_x`` rather than the INPUT_SIZE global,
    # so the function does not depend on a constant defined further down the cell.
    targets = x.view(-1, out_x.size(-1))
    # Called through torch.nn.functional explicitly: the notebook's import cell
    # binds ``F`` to torchaudio.functional, which has no ``cosine_similarity``.
    similarity = torch.nn.functional.cosine_similarity(out_x, targets, dim=1)
    recons_loss = torch.mean(1 - similarity)
    return recons_loss


def discrim_loss_function(pred, lbl):
    """Mean binary cross-entropy between predicted probabilities and 0/1 labels."""
    return torch.nn.functional.binary_cross_entropy(pred, lbl)


# diagnosis_loss_function = nn.BCEWithLogitsLoss()


# Input vector size; INPUT_SIZE is read by recons_loss_function, input_dim by the models.
INPUT_SIZE = input_dim = 512

# Width of the autoencoder bottleneck, which is also the discriminator's input size.
latent_dim = 128
input_dim_discrim, hidden_dim_discrim = latent_dim, 128

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# model_ae  = Autoencoder(input_dim, latent_dim)
# optimizer_ae  = torch.optim.SGD(model_ae.parameters(), lr = 0.0001, momentum=0.9)
# model_ae.to(device)

# model_discrim   = Discrim(input_dim_discrim, hidden_dim_discrim)
# optimizer_discrim   = torch.optim.SGD(model_discrim.parameters(), lr = 0.0001, momentum=0.9)
# model_discrim.to(device)



# output_dim_health = 1  

# modelD = HealthClass(latent_dim, output_dim_health).to(device)
# optimizer_modelD = optim.Adam(modelD.parameters(), lr=learning_rate)


# Adversarial Disentanglement

In [None]:
# --- Gender prediction on the test split (one entry per fold) ---
all_folds_roc_curves, all_fold_auc_scores = [], []
all_fold_accuracies, all_fold_f1_scores = [], []
all_fold_recalls, all_fold_precisions = [], []

# --- Gender discriminator on the validation split ---
all_fold_val_auc_scores, all_fold_val_accuracies = [], []
all_fold_val_f1_scores, all_fold_val_recalls = [], []
all_fold_val_precisions = []

# --- Health classifier on the validation split ---
all_fold_val_health_auc_scores, all_fold_val_health_accuracies = [], []
all_fold_val_health_f1_scores, all_fold_val_health_recalls = [], []
all_fold_val_health_precisions = []

# --- Health classifier on the test split ---
all_fold_health_auc_scores, all_folds_health_roc_curves = [], []
all_fold_health_accuracies, all_fold_health_f1_scores = [], []
all_fold_health_recalls, all_fold_health_precisions = [], []

# Training length and search space; key order fixes the tuple layout
# (learning_rate, batch_size, optimizer, lmda).
EPOCHS = 20
param_grid = dict(
    learning_rate=[0.0001],
    batch_size=[4],
    optimizer=['adam', 'sgd'],
    lmda=[0.1, 0.9, 1, 2],
)
param_combinations = [combo for combo in itertools.product(*param_grid.values())]



# A combination is only eligible if the adversary's validation AUC for gender
# falls inside this band around chance level (0.5).
GENDER_AUC_BAND = (0.45, 0.55)


def _score_reconstructions(autoencoder, classifier, loader, label_index):
    """Classify the autoencoder's reconstructions of every batch in ``loader``.

    ``label_index`` selects which element of each batch tuple holds the ground
    truth: 1 for gender labels, 2 for health labels (the order in which the
    loops in this cell unpack a batch). Both models are put in eval mode and no
    gradients are tracked.

    Returns:
        ``(labels, probs)`` as flat NumPy arrays.
    """
    autoencoder.eval()
    classifier.eval()
    labels, probs = [], []
    with torch.no_grad():
        for batch in loader:
            reconstructed, _, _ = autoencoder(batch[0].to(device))
            probs.extend(classifier(reconstructed).cpu().numpy().flatten())
            labels.extend(batch[label_index].float().cpu().numpy().flatten())
    return np.array(labels), np.array(probs)


for fold in range(10):
    print(f"\nStarting training for fold {fold + 1}/10")

    best_val_auc = 0.0
    best_params = None
    best_model_ae_state = None

    # Frozen health classifier trained in the earlier cell; loaded once per fold
    # and reused for model selection and for the test evaluation.
    model_health = HealthClassifier(input_dim, hidden_dim, output_dim).to(device)
    model_health.load_state_dict(
        torch.load(f"model_health_fold_{fold+1}_bestValNVZ.pth", map_location=device)
    )
    model_health.eval()

    for (lr, bs, opt_name, lmda) in param_combinations:
        print(f"\n  Searching - (LR={lr}, BatchSize={bs}, Optim={opt_name}, Lambda={lmda})")

        train_loader = create_data_loader(
            all_folds_train_xvectors[fold], all_folds_yg_train[fold], all_folds_yh_train[fold],
            all_folds_filenames_train[fold], bs, shuffle=True
        )
        val_loader = create_data_loader(
            all_folds_val_xvectors[fold], all_folds_yg_val[fold], all_folds_yh_val[fold],
            all_folds_filenames_val[fold], bs, shuffle=False
        )

        model_ae = Autoencoder(input_dim, latent_dim).to(device)
        if opt_name == 'adam':
            optimizer_ae = optim.Adam(model_ae.parameters(), lr=lr)
        else:
            optimizer_ae = optim.SGD(model_ae.parameters(), lr=lr, momentum=0.9)

        # ---------------- training ----------------
        for epoch in range(EPOCHS):
            model_ae.train()
            running_loss = 0.0
            running_recons_loss = 0.0
            running_discrim_loss = 0.0
            num_batches = 0

            print(f"_____FOLD {fold + 1} - EPOCH: {epoch + 1}/{EPOCHS}_____")
            for x_vectors, gender_labels, _, _ in train_loader:
                local_batch = x_vectors.to(device)
                local_labels = gender_labels.unsqueeze(1).float().to(device)

                # BatchNorm1d cannot compute batch statistics from a single
                # sample in training mode, so a trailing batch of size 1 is skipped.
                if local_batch.size(0) < 2:
                    continue

                optimizer_ae.zero_grad()
                outputs, z, att_pred = model_ae(local_batch)

                # The discriminator head sits behind gradient reversal, so
                # minimising this sum trains the head to predict gender while
                # pushing the encoder in the opposite direction.
                loss_discrim = discrim_loss_function(att_pred, local_labels)
                recons_loss = recons_loss_function(outputs, local_batch)
                loss = recons_loss + loss_discrim * lmda
                loss.backward()
                optimizer_ae.step()

                running_loss += loss.item()
                running_recons_loss += recons_loss.item()
                running_discrim_loss += loss_discrim.item()
                num_batches += 1

            if num_batches > 0:
                print("..............")
                print(f"Training Loss: {running_loss / num_batches:.4f}")
                print(f"Reconstruction Loss: {running_recons_loss / num_batches:.4f}")
                print(f"Discriminator Loss: {running_discrim_loss / num_batches:.4f}")

        # ---------------- validation (once, after the last epoch) ----------------
        model_ae.eval()
        val_running_loss = 0.0
        val_running_recons_loss = 0.0
        val_running_discrim_loss = 0.0
        val_num_batches = 0

        val_true_labels = []
        val_probs = []

        print(f"_____(Validation)_____")
        with torch.no_grad():
            for x_vectors, gender_labels, _, _ in val_loader:
                local_batch = x_vectors.to(device)
                local_labels = gender_labels.unsqueeze(1).float().to(device)

                outputs, z, att_pred = model_ae(local_batch)

                val_loss_discrim = discrim_loss_function(att_pred, local_labels)
                val_recons_loss = recons_loss_function(outputs, local_batch)
                # Weighted like the training objective so the two are comparable.
                val_loss_total = val_recons_loss + val_loss_discrim * lmda

                val_running_loss += val_loss_total.item()
                val_running_recons_loss += val_recons_loss.item()
                val_running_discrim_loss += val_loss_discrim.item()
                val_num_batches += 1

                val_probs.extend(att_pred.cpu().numpy().flatten())
                val_true_labels.extend(local_labels.cpu().numpy().flatten())

        if val_num_batches > 0:
            avg_val_loss = val_running_loss / val_num_batches
            avg_val_recons = val_running_recons_loss / val_num_batches
            avg_val_discrim = val_running_discrim_loss / val_num_batches
        else:
            avg_val_loss = 0.0
            avg_val_recons = 0.0
            avg_val_discrim = 0.0

        # Gender AUC of the built-in adversary on the latent code.
        val_auc = roc_auc_score(val_true_labels, val_probs)
        all_fold_val_auc_scores.append(val_auc)

        # Health AUC of the frozen classifier on the reconstructed vectors.
        val_health_labels, val_health_probs = _score_reconstructions(
            model_ae, model_health, val_loader, label_index=2
        )
        val_health_auc = roc_auc_score(val_health_labels, val_health_probs)
        all_fold_val_health_auc_scores.append(val_health_auc)

        print(f"  Val Loss: {avg_val_loss:.4f} | Recons: {avg_val_recons:.4f} | Discrim: {avg_val_discrim:.4f}")
        print(f"  Val Gender AUC: {val_auc:.4f} | Val Health AUC: {val_health_auc:.4f}")

        # Selection rule: gender AUC near chance, then highest health AUC.
        gender_near_chance = GENDER_AUC_BAND[0] <= val_auc <= GENDER_AUC_BAND[1]
        if gender_near_chance and (best_params is None or val_health_auc > best_val_auc):
            best_val_auc = val_health_auc
            best_params = (lr, bs, opt_name, lmda)
            best_model_ae_state = copy.deepcopy(model_ae.state_dict())

    if best_params is None:
        # Nothing satisfied the band, so there is no state dict to save or load.
        # Skip the test evaluation instead of writing `None` to disk and
        # crashing in load_state_dict.
        print("\nNo optimal parameters found for this fold.")
        print(f"Skipping test evaluation for fold {fold + 1}.")
        continue

    print(f"\nBest hyperparameters found: LR={best_params[0]}, BS={best_params[1]}, Optim={best_params[2]}, Lambda={best_params[3]}")
    print(f"Best validation AUC for health classifier: {best_val_auc:.4f}")

    torch.save(best_model_ae_state, f"best_ae_for_health_fold_{fold+1}.pth")

    # Rebuild the selected autoencoder once; it is shared by both test evaluations.
    best_ae_for_fold = Autoencoder(input_dim, latent_dim).to(device)
    best_ae_for_fold.load_state_dict(best_model_ae_state)
    best_ae_for_fold.eval()

    test_loader = create_data_loader(
        all_folds_test_xvectors[fold], all_folds_yg_test[fold], all_folds_yh_test[fold],
        all_folds_filenames_test[fold], best_params[1], shuffle=False
    )

    # ---------------- test: gender from reconstructions ----------------
    print(f"\nEvaluating Gender Prediction for Fold {fold + 1}/10")

    model_gender = GenderClassifier(input_dim, output_dim).to(device)
    model_gender.load_state_dict(
        torch.load(f"model_gender_fold_{fold+1}_bestValNVZ.pth", map_location=device)
    )
    model_gender.eval()

    true_labels, att_probs = _score_reconstructions(
        best_ae_for_fold, model_gender, test_loader, label_index=1
    )

    auc_score = roc_auc_score(true_labels, att_probs)
    all_fold_auc_scores.append(auc_score)
    print(f"Fold {fold + 1} - AUC Score: {auc_score:.4f}")

    fpr, tpr, _ = roc_curve(true_labels, att_probs)
    all_folds_roc_curves.append((fpr, tpr))

    plt.figure(figsize=(4, 4))
    plt.plot(fpr, tpr, label=f'Fold {fold + 1} ROC curve (AUC = {auc_score:.4f})', color='darkorange', lw=2)
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--', lw=2)  # Diagonal line for random guessing
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve for Fold {fold + 1}')
    plt.legend(loc="lower right")
    plt.show()

    # ---------------- test: health from reconstructions ----------------
    print(f"\nEvaluating Health Prediction for Fold {fold + 1}/10")

    all_health_labels, all_health_probs = _score_reconstructions(
        best_ae_for_fold, model_health, test_loader, label_index=2
    )

    health_auc_score = roc_auc_score(all_health_labels, all_health_probs)
    all_fold_health_auc_scores.append(health_auc_score)
    all_folds_health_roc_curves.append(roc_curve(all_health_labels, all_health_probs)[:2])
    print(f"Fold {fold + 1} - Health AUC: {health_auc_score:.4f}")




def _mean_std(values):
    """Return ``(mean, population std)`` of ``values``, or ``(nan, nan)`` when empty.

    The empty case is handled explicitly so that NumPy does not emit a
    RuntimeWarning for lists that were never filled.
    """
    if len(values) == 0:
        return float("nan"), float("nan")
    return np.mean(values), np.std(values)


mean_val_auc, std_val_auc = _mean_std(all_fold_val_auc_scores)
mean_val_acc, std_val_acc = _mean_std(all_fold_val_accuracies)
mean_auc, std_auc = _mean_std(all_fold_auc_scores)

print("\n AVERAGE VALIDATION METRICS ACROSS FOLDS (validation)")
print(f"Val AUC:        {mean_val_auc:.4f} ± {std_val_auc:.4f}")
if len(all_fold_val_accuracies) > 0:
    print(f"Val Accuracy:   {mean_val_acc:.4f} ± {std_val_acc:.4f}")
else:
    print("Val Accuracy:   n/a (no validation accuracies were recorded)")

print(f"\nMean AUC across folds(TEST): {mean_auc:.4f} ± {std_auc:.4f}")

# Average the per-fold gender ROC curves on a common false-positive-rate grid.
if len(all_folds_roc_curves) > 0:
    mean_fpr = np.linspace(0, 1, 100)
    mean_tpr = np.mean([np.interp(mean_fpr, fpr, tpr) for fpr, tpr in all_folds_roc_curves], axis=0)

    plt.figure(figsize=(4, 4))
    plt.plot(mean_fpr, mean_tpr, label=f'Mean ROC Curve (AUC = {mean_auc:.4f})', color='blue', lw=2)
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--', lw=2)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Mean ROC Curve Across All Folds')
    plt.legend(loc="lower right")
    plt.show()
else:
    print("No per-fold ROC curves were recorded; skipping the mean ROC plot.")

mean_health_auc, std_health_auc = _mean_std(all_fold_health_auc_scores)

print(f"\nMean Health AUC across folds: {mean_health_auc:.4f} ± {std_health_auc:.4f}")