# BRCA1 supervised classification using ANN 

Date: July 22, 2025

In [None]:
# ANN Training
import sys
import glob
import gzip
import json
import math
import os
import argparse
import numpy as np
import pandas as pd
import random
import copy
import time
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, BatchNormalization
from tensorflow.keras.layers import Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras import layers, models

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve

from Bio.Seq import Seq
from Bio import SeqIO

In [None]:
# Command-line configuration. Example values used during development:
#   --REGION BRCA1_DATA --LAYER blocks.28.mlp.l3 --COMBO refvar --Y_LABEL clinvar
parser = argparse.ArgumentParser(description="Evo2 embeddings")
parser.add_argument(
    "--REGION", type=str, required=True,
    choices=["BRCA1_DATA", "RovHer_BRCA1", "RovHer_LDLR", "both"],
    help="BRCA1_DATA, RovHer_BRCA1 or RovHer_LDLR, both (BRCA1 + LDLR RVs)",
)
parser.add_argument("--LAYER", required=True, type=str, help="embedding layer")
parser.add_argument(
    "--COMBO", required=True, type=str, choices=["delta", "refvar"],
    help="delta, refvar",
)
# Defaults to "clinvar": the label column is indexed directly (data[y_label]),
# so leaving it unset would otherwise fail later with an opaque KeyError.
parser.add_argument(
    "--Y_LABEL", type=str, default="clinvar", choices=["clinvar", "class"],
    help="clinvar (0, 0.25, 0.5, 0.75,1); class (LOF, FUNC/INT)",
)
# NOTE(review): SUBSET_METHOD and MODEL_SIZE are parsed but not referenced in
# the visible code below.
parser.add_argument("--SUBSET_METHOD", type=str, help="random, top, bottom, balanced, all")
parser.add_argument("--MODEL_SIZE", type=str, help="7B or 40B")
args = parser.parse_args()

MODEL_SIZE = args.MODEL_SIZE
SUBSET_METHOD = args.SUBSET_METHOD
REGION = args.REGION
LAYER = args.LAYER
COMBO = args.COMBO
y_label = args.Y_LABEL

### Set input paths 

In [None]:
# Input directory holding the precomputed Evo2 embedding CSVs.
INPUT_DIR = Path("/mnt/nfs/rigenenfs/shared_resources/biobanks/UKBIOBANK/pangk/evo2/BRCA1_LDLR")
INPUT_DIR.mkdir(parents=True, exist_ok=True)


def _embedding_file(region, kind):
    """Return the path of the `<region>_<LAYER>_<kind>.csv` embedding file in INPUT_DIR."""
    return f"{INPUT_DIR}/{region}_{LAYER}_{kind}.csv"


# Input data: embeddings for the selected REGION / LAYER
delta_file = _embedding_file(REGION, "delta")
delta_rev_file = _embedding_file(REGION, "delta_rev")
ref_file = _embedding_file(REGION, "ref")
var_file = _embedding_file(REGION, "var")
ref_rev_file = _embedding_file(REGION, "ref_rev")
var_rev_file = _embedding_file(REGION, "var_rev")

# Label input files
if REGION == "BRCA1_DATA":
    file = "/mnt/nfs/rigenenfs/workspace/pangk/Softwares/evo2/data/BRCA1_DATA.xlsx" # training variants + labels
else:
    DIR = "/mnt/nfs/rigenenfs/shared_resources/biobanks/UKBIOBANK/pangk"
    label_file1 = f"{DIR}/RARity_monogenic_benchmark/BRCAexchange/BRCA1_clinvar_cleaned.txt"
    label_file2 = f"{DIR}/RARity_monogenic_benchmark/LOVD_LDLR/LDLR_clinvar_curated.txt" # British heart foundation-classified variants on LOVD

if REGION == "both":
    # Combined run: BRCA1 (…1) and LDLR (…2) embeddings are loaded separately and
    # stacked later. REGION itself must stay "both". Reassigning the global here
    # would make the later `REGION == "both"` check false and label every output
    # file as LDLR-only.
    ref_file1 = _embedding_file("RovHer_BRCA1", "ref")
    var_file1 = _embedding_file("RovHer_BRCA1", "var")
    ref_rev_file1 = _embedding_file("RovHer_BRCA1", "ref_rev")
    var_rev_file1 = _embedding_file("RovHer_BRCA1", "var_rev")
    ref_file2 = _embedding_file("RovHer_LDLR", "ref")
    var_file2 = _embedding_file("RovHer_LDLR", "var")
    ref_rev_file2 = _embedding_file("RovHer_LDLR", "ref_rev")
    var_rev_file2 = _embedding_file("RovHer_LDLR", "var_rev")


### Set output paths 

In [None]:
# One output directory per feature combination (created on demand).
OUTPUT_DIR = Path("/mnt/nfs/rigenenfs/shared_resources/biobanks/UKBIOBANK/pangk/evo2/NN") / f"BRCA1_LDLR_{COMBO}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Output files share a <REGION>_<LAYER>_<label> prefix.
output_prefix = f"{OUTPUT_DIR}/{REGION}_{LAYER}_{y_label}"
plot1 = output_prefix + "_AUC_loss.png"  # training loss / AUC curves
plot2 = output_prefix + "_ROC.png"       # test / validation ROC curves

### Define functions

In [None]:
from keras.losses import binary_crossentropy

def sample_data(df, sample_frac=1.0, balanced=True, disable=True, random_state=42):
    """Optionally down-sample *df*, with or without class balancing.

    Args:
        df: DataFrame. In balanced mode it needs a ``"class"`` column whose
            values include ``"LOF"`` and ``"FUNC/INT"``.
        sample_frac: Fraction to keep. In balanced mode it scales the size of
            the smaller class; otherwise it applies to the whole frame.
        balanced: Draw the same number of rows from each class.
        disable: If True (the default), return ``df`` unchanged.
        random_state: Seed passed to every ``DataFrame.sample`` call.

    Returns:
        The sampled DataFrame with a fresh 0..n-1 index, or ``df`` itself when
        ``disable`` is True.
    """
    if disable:
        return df
    if not balanced:
        return df.sample(frac=sample_frac, random_state=random_state).reset_index(drop=True)

    lof = df[df["class"] == "LOF"]
    func = df[df["class"] == "FUNC/INT"]
    # Size both groups from the genuinely smaller class. Sizing from LOF alone
    # made .sample(n=...) raise whenever FUNC/INT was the minority.
    n_per_class = math.ceil(min(len(lof), len(func)) * sample_frac)
    return (
        pd.concat(
            [
                lof.sample(n=n_per_class, random_state=random_state),
                func.sample(n=n_per_class, random_state=random_state),
            ]
        )
        .sample(frac=1.0, random_state=random_state)  # shuffle the two groups together
        .reset_index(drop=True)
    )

def subset_dataframe(df, seq):
    """Draw ``seq`` random rows from *df* without replacement (fixed seed 42).

    Returns:
        pandas.DataFrame with ``seq`` rows; original index labels are kept.

    Raises:
        ValueError: if ``seq`` exceeds the number of rows in *df*.
    """
    print("Number of rows to extract:", seq)
    n_available = len(df)
    if seq > n_available:
        raise ValueError(
            f"SEQ_LENGTH ({seq}) is greater than the number of rows in the DataFrame ({n_available})."
        )
    sampled = df.sample(n=seq, random_state=42)
    print("New subset:", sampled.shape)
    return sampled

# Compute Binary Cross-Entropy using NumPy
def binary_cross_entropy_np(y_true, y_pred, epsilon=1e-15):
    """Return the mean binary cross-entropy between labels and predicted probabilities.

    Args:
        y_true: Array-like of targets in [0, 1] (hard 0/1 or soft labels).
        y_pred: Array-like of predicted probabilities, same shape as ``y_true``.
        epsilon: Probabilities are clipped to ``[epsilon, 1 - epsilon]`` so that
            both logarithms stay finite.

    Returns:
        numpy.float64: ``-mean(y * log(p) + (1 - y) * log(1 - p))``.
    """
    # Work in float64. Keras predictions are float32, where 1 - 1e-15 rounds to
    # exactly 1.0, so the upper clip would not prevent log(0) = -inf (and a NaN loss).
    # asarray also accepts plain lists (``1 - list`` would raise TypeError).
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), epsilon, 1 - epsilon)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

def recode_clinvar(value):
    """Translate an abbreviated ClinVar/ACMG call into a pathogenicity score.

    ``P`` -> 1, ``LP`` / ``"LP,P"`` -> 0.75, ``LB`` / ``"B/LB"`` -> 0.25,
    ``B`` -> 0. Any other hashable value (e.g. VUS or NaN) falls back to 0.5.
    """
    scores = {"P": 1, "B": 0}
    scores.update(dict.fromkeys(("LP", "LP,P"), 0.75))
    scores.update(dict.fromkeys(("LB", "B/LB"), 0.25))
    try:
        return scores[value]
    except KeyError:
        return 0.5


# Load training data

### a) Labels (either `clinvar` or `class`)

The *BRCA1* SNV dataset was obtained from [Findlay et al. (2018)](https://www.nature.com/articles/s41586-018-0461-z), which contains 3,893 SNVs. Among them, 631 SNVs have `clinvar` classification [0, 0.25, 0.5, 0.75, 1]. This dataset also contains functional `class` annotations (LOF or FUNC/INT).

The *BRCA1* RV dataset was obtained from the RovHer study using UKB WES data, which contains over 1,000 RVs, of which 717 RVs have `clinvar` classification [0, 0.25, 0.5, 0.75, 1].

In [None]:
# ClinVar significance -> pathogenicity score for the BRCA1_DATA (Findlay) table.
# "NA" marks calls with no usable label; those rows are removed below.
recode_map = {
    "Pathogenic": 1,
    'Pathogenic/Likely pathogenic': 1,
    'Likely pathogenic': 0.75,
    "Uncertain significance": 0.5,
    "Likely benign": 0.25,
    "Benign": 0,
    "absent": "NA",
    'Conflicting interpretations of pathogenicity': "NA",
}

# 1. Variant data + ClinVar labels
if REGION == "BRCA1_DATA":
    data = pd.read_excel(file, header=2)
    # Non-inplace rename of the column subset (inplace on a slice can trigger
    # SettingWithCopyWarning).
    data = data[['chromosome', 'position (hg19)', 'reference', 'alt', 'function.score.mean', 'func.class', 'clinvar',]].rename(
        columns={
            'chromosome': 'chrom', 'position (hg19)': 'pos',
            'reference': 'ref', 'alt': 'alt',
            'function.score.mean': 'score', 'func.class': 'class', 'clinvar': 'clinvar',
        }
    )
    # Collapse functional classes into the two used for classification
    data['class'] = data['class'].replace(['FUNC', 'INT'], 'FUNC/INT')
    data["class"] = data["class"].replace({0: "FUNC/INT", 1: "LOF"})
    # Variant ID "chrom:pos:ref:alt", the key used to join with the embedding files
    data['PLINK_SNP_NAME'] = data.apply(
        lambda row: f"{row['chrom']}:{row['pos']}:{row['ref']}:{row['alt']}", axis=1
    )
    # Recode `clinvar` column
    unique_clinvar_values = data['clinvar'].unique()
    print("Unique values in clinvar column:", unique_clinvar_values)
    data['clinvar'] = data['clinvar'].replace(recode_map)
else:
    # BRCA1
    ACMG_col1 = pd.read_csv(label_file1, sep="\t", usecols=["PLINK_SNP_NAME", "ACMG_final"])
    ACMG_col1 = ACMG_col1.rename(columns={"ACMG_final": "clinvar"})
    # LDLR
    ACMG_col2 = pd.read_csv(label_file2, sep="\t", usecols=["PLINK_SNP_NAME", "clinvar_clnsig"])
    ACMG_col2 = ACMG_col2.rename(columns={"clinvar_clnsig": "clinvar"})
    # Combine
    data = pd.concat([ACMG_col1, ACMG_col2], ignore_index=True)
    print(f"BRCA1 and LDLR merged: {data.shape}")
    # (883, 2)
    # read_csv parses empty and "NA" cells as NaN by default. Without the isna()
    # term they escape the isin() filter and recode_clinvar() scores them as VUS (0.5).
    unlabeled = data["clinvar"].isna() | data["clinvar"].isin(["", "NA", "CCP"])
    data = data[~unlabeled].copy()
    data["clinvar"] = data["clinvar"].apply(recode_clinvar)
    print(data["clinvar"].value_counts(dropna=False))

# Remove rows with missing clinvar annotation: the "NA" sentinel from recode_map
# and genuinely empty (NaN) cells, which would otherwise become NaN training targets.
data = data[data['clinvar'].notna() & (data['clinvar'] != "NA")]
print("After removing NA in clinvar:", data.shape)

# Fail early with a clear message (e.g. `class` is only present for BRCA1_DATA).
if y_label not in data.columns:
    raise ValueError(
        f"Label column {y_label!r} is not available for REGION={REGION!r}; "
        f"available columns: {list(data.columns)}"
    )


### b) Evo2 7B embeddings

* `delta`: delta + delta reverse complement embeddings concatenated (8192-dimensional)

* `refvar`: ref + ref reverse complement + var + var reverse complement concatenated (16384-dimensional)

In [None]:
# Load the Evo2 embedding tables for the chosen feature combination.
#   delta:  per-variant delta embeddings, forward strand + reverse complement.
#   refvar: reference and variant embeddings, each with its reverse complement.
#           For REGION == "both", BRCA1 and LDLR tables are stacked row-wise.
if COMBO == "delta":
    delta = pd.read_csv(delta_file)
    delta_reverse = pd.read_csv(delta_rev_file)

elif COMBO == "refvar":
    if REGION in ("BRCA1_DATA", "RovHer_BRCA1", "RovHer_LDLR"):
        var, var_reverse, ref, ref_reverse = [
            pd.read_csv(path)
            for path in (var_file, var_rev_file, ref_file, ref_rev_file)
        ]
    elif REGION == "both":
        brca1_tables = [
            pd.read_csv(path)
            for path in (var_file1, var_rev_file1, ref_file1, ref_rev_file1)
        ]
        ldlr_tables = [
            pd.read_csv(path)
            for path in (var_file2, var_rev_file2, ref_file2, ref_rev_file2)
        ]
        # BRCA1 rows first, then LDLR, with a fresh 0..n-1 index
        var, var_reverse, ref, ref_reverse = [
            pd.concat([brca1, ldlr], ignore_index=True)
            for brca1, ldlr in zip(brca1_tables, ldlr_tables)
        ]

## Data pre-processing

### a) Subset rows from training data

In [None]:
# Keep one label row per variant ID (first occurrence wins).
data = data[~data['PLINK_SNP_NAME'].duplicated(keep='first')]


def _align_to_labels(label_df, embedding_dfs):
    """Restrict labels and embeddings to shared variant IDs, in one common row order.

    Returns ``(label_df, [aligned embedding frames])``. Every returned frame has
    exactly one row per shared PLINK_SNP_NAME, in the order of the filtered
    ``label_df``. This matters because features and labels are later paired
    purely by row position (``.values`` / ``np.hstack``). Filtering each file
    with ``isin()`` alone keeps that file's own row order, which can silently
    mis-pair labels and embeddings even when the row counts agree.
    """
    shared = set(label_df['PLINK_SNP_NAME'])
    for emb in embedding_dfs:
        shared &= set(emb['PLINK_SNP_NAME'])
    label_df = label_df[label_df['PLINK_SNP_NAME'].isin(shared)].reset_index(drop=True)
    # A left merge keeps the key order of the left frame (the label order).
    key_order = label_df[['PLINK_SNP_NAME']]
    aligned = [
        key_order.merge(
            emb.drop_duplicates(subset='PLINK_SNP_NAME', keep='first'),
            on='PLINK_SNP_NAME', how='left', validate='one_to_one',
        )
        for emb in embedding_dfs
    ]
    return label_df, aligned


if COMBO == "delta":
    data, (delta, delta_reverse) = _align_to_labels(data, [delta, delta_reverse])
    # Tallies
    print("Filtered labels file (data):", data.shape)
    print("delta:", delta.shape, "delta_reverse:", delta_reverse.shape)
    # Sanity check before positional pairing
    if not (delta.shape[0] == data.shape[0] and
            delta_reverse.shape[0] == data.shape[0]):
        raise ValueError("Number of rows in embeddings do not match number of rows in data.")

if COMBO == "refvar":
    data, (var, var_reverse, ref, ref_reverse) = _align_to_labels(
        data, [var, var_reverse, ref, ref_reverse]
    )
    # Tallies
    print("Filtered labels file (data):", data.shape)
    print("var:", var.shape, "var_reverse:", var_reverse.shape, "ref:", ref.shape, "ref_reverse:", ref_reverse.shape)
    # Sanity check before positional pairing
    if not all(df.shape[0] == data.shape[0] for df in (var, var_reverse, ref, ref_reverse)):
        raise ValueError("Number of rows in embeddings do not match number of rows in data.")

print(f"---------- Values in {y_label} column -------------\n")
print(data[y_label].value_counts())
# NaN is a float, so exclude it explicitly when counting usable numeric labels.
is_number = data[y_label].apply(lambda x: isinstance(x, (int, float)))
numeric_rows = (is_number & data[y_label].notna()).sum()
# Denominator: rows remaining after the filtering above.
print("Number of non-missing labels:", numeric_rows, "of", len(data), "rows")

### b) Remove columns

In [None]:
# Strip the non-feature columns (variant ID plus input_file/layer bookkeeping)
# so only embedding dimensions remain for the feature matrix.
meta_columns = ['PLINK_SNP_NAME', 'input_file', 'layer']

if COMBO == "delta":
    delta, delta_reverse = [frame.drop(columns=meta_columns) for frame in (delta, delta_reverse)]
    print(f"Variant embeddings: {delta.shape}")
    print(f"Variant reverse comp. embeddings: {delta_reverse.shape}")

elif COMBO == "refvar":
    var, var_reverse, ref, ref_reverse = [
        frame.drop(columns=meta_columns) for frame in (var, var_reverse, ref, ref_reverse)
    ]
    for caption, frame in (
        ("Variant embeddings", var),
        ("Variant reverse comp. embeddings", var_reverse),
        ("Reference embeddings", ref),
        ("Reference reverse comp. embeddings", ref_reverse),
    ):
        print(f"{caption}: {frame.shape}")  # e.g. (631, 4096) each

## Build feature vector by concatenation

In [None]:
if COMBO == "delta":
    # One row per variant: [delta | delta reverse complement]
    feature_vec = np.hstack([
        delta.values,         # delta embeddings
        delta_reverse.values, # Reverse complement
    ])
elif COMBO == "refvar":
    # One row per variant: [ref | ref rev. comp. | var | var rev. comp.]
    feature_vec = np.hstack([
        ref.values,         # Reference embeddings
        ref_reverse.values, # Reverse complement of reference
        var.values,         # Variant embeddings
        var_reverse.values  # Reverse complement of variant
    ])
else:
    raise ValueError(f"Unknown COMBO {COMBO!r}; expected 'delta' or 'refvar'.")

print(f"feature_vec embeddings: {feature_vec.shape}")

# Extract the y-label vector; it is paired with feature_vec by row position.
print(f"y labels to extract: {y_label}\n")
if y_label == "class":
    # Keras needs numeric targets; the raw strings made the later float32 cast
    # fail. LOF is the positive class (1), so the sigmoid output reads as a
    # probability of loss of function, consistent with clinvar (pathogenic = 1).
    class_codes = {"LOF": 1.0, "FUNC/INT": 0.0}
    unexpected = set(data["class"].unique()) - set(class_codes)
    if unexpected:
        raise ValueError(f"Unexpected values in 'class' column: {sorted(map(str, unexpected))}")
    train_y = data["class"].map(class_codes).to_numpy()
else:
    train_y = data[y_label].values

if "class" in data.columns:
    lof_count = data[data["class"] == "LOF"].shape[0]
    print(f"Test 'LOF': {lof_count}\n")
    other_count = data[data["class"] == "FUNC/INT"].shape[0]
    print(f"Test 'FUNC/INT': {other_count}\n")

### Split dataset
1. Training Set: 64% of the data (80% of the non-test portion)
2. Test Set: 20% of the data (withheld entirely from training)
3. Validation Set: 16% of the data (the remaining 20% of the non-test portion), used only to monitor training

In [None]:
# Validation strategy:
#   "yes" -> hold out only the test set here; Keras carves a validation split
#            out of X_train inside fit().
#   "no"  -> also hold out an explicit validation set here
#            (train 64% / validation 16% / test 20%, stratified on the label).
internal_validation_split = "no"

if internal_validation_split == "yes":
    # Split dataset into test (20%) and remaining training data (80%)
    X_train, X_test, y_train, y_test = train_test_split(
        feature_vec, train_y, test_size=0.2, random_state=42, stratify=train_y
    )
    X_val = y_val = None  # no explicit validation set in this mode
else:
    # Test set (20%) first, so it never influences training or model selection
    X_train_val, X_test, y_train_val, y_test = train_test_split(
        feature_vec, train_y, test_size=0.2, random_state=42, stratify=train_y
    )
    # Split remaining training data into train (80%) and validation (20%)
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.2, random_state=42, stratify=y_train_val
    )
    # Only the arrays fed to Keras are cast. X_train_val is not used again in
    # the visible code, so casting it only duplicated the largest array in memory.
    X_val = X_val.astype('float32')
    y_val = y_val.astype('float32')
    print(f"Validation set size: {X_val.shape}")

# Keras trains in float32 by default
X_train = X_train.astype('float32')
y_train = y_train.astype('float32')
X_test = X_test.astype('float32')
y_test = y_test.astype('float32')

print(f"Training set size: {X_train.shape}") # (403, 16384)
print(f"Test set size: {X_test.shape}") # (127, 16384)

### Train ANN
* Input Layer: 8,192 features (`delta`) or 16,384 features (`refvar`).
* Hidden Layers: 512 → 128 → 32 neurons.
* Output Layer: Binary classification (pathogenic probability).
* Activation: ReLU for hidden layers, Sigmoid for the output layer.
* Batch Normalization after each hidden layer; Dropout (𝑝=0.3) after the first two hidden layers only.

In [None]:
# Input width follows the chosen feature set (columns of feature_vec).
input_dim = feature_vec.shape[1]


def build_model():
    """Build and compile the fully connected binary classifier.

    Architecture: input_dim -> 512 -> 128 -> 32 ReLU units, each followed by
    BatchNormalization (Dropout 0.3 after the first two), then 1 sigmoid unit.
    Compiled with Adam, binary cross-entropy and an AUC metric logged as "auc".
    """
    model = Sequential([
        keras.Input(shape=(input_dim,)),  # explicit input instead of Dense(input_dim=...)
        Dense(512, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(32, activation='relu'),
        BatchNormalization(),
        Dense(1, activation='sigmoid'),
    ])
    # Name the metric explicitly so history keys are always 'auc' / 'val_auc'.
    # String metrics can get suffixed names (e.g. 'auc_1') when several models
    # are built in one session, which breaks the plotting cell.
    # NOTE(review): with clinvar soft labels (0.25/0.5/0.75) the training AUC
    # must binarise y_true internally, so it is not directly comparable to the
    # VUS-free test AUC computed after training.
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=[keras.metrics.AUC(name='auc')])
    return model


ANN_model = build_model()
ANN_model.summary()

In [None]:
start_time = time.time()

if internal_validation_split == "yes":
    # Keras holds out the last 15% of (X_train, y_train), taken before shuffling,
    # as validation data.
    history = ANN_model.fit(
        X_train, y_train,
        epochs=100,
        batch_size=64,
        validation_split=0.15,
    )
    print("Internal validation split enabled during training\n")
else:
    # Monitor the explicit validation set from the split cell so that history
    # records 'val_loss' / 'val_auc' (the loss/AUC plot reads them). Validation
    # data is only evaluated after each epoch, never used to update weights.
    history = ANN_model.fit(
        X_train, y_train,
        epochs=100,
        batch_size=64,
        validation_data=(X_val, y_val),
        verbose=2
    )
    print("No internal validation split occured.\n")

end_time = time.time()
exe_time = end_time - start_time
print("Training time: ", exe_time)

### Evaluate on test and validation sets
If the labels are `clinvar` (0, 0.25, 0.5, 0.75, 1):
* AUC evaluates P/LP versus B/LB: VUS (label 0.5) are removed, LP (0.75) is counted as pathogenic (1) and LB (0.25) as benign (0)

If labels are `class` (LOF as 1; FUNC/INT as 0):
* AUC evaluates LOF versus FUNC/INT

In [None]:
def _drop_vus_and_binarize(X, y):
    """Remove VUS rows (label 0.5) and collapse LP (0.75) -> 1 and LB (0.25) -> 0."""
    keep = y != 0.5
    y_kept = np.where(y[keep] == 0.75, 1, y[keep])
    y_kept = np.where(y_kept == 0.25, 0, y_kept)
    return X[keep], y_kept


# The explicit validation set only exists when the split cell held one out.
has_val_set = internal_validation_split == "no"

if y_label == "clinvar":
    print(f"\n------------- AUC (P/LP vs B/LB) ------------\n")
    X_test_filtered, y_test_filtered = _drop_vus_and_binarize(X_test, y_test)
    print(f"y_test: {y_test.shape}")
    print(f"y_test_filtered: {y_test_filtered.shape}\n")
    if has_val_set:
        X_val_filtered, y_val_filtered = _drop_vus_and_binarize(X_val, y_val)
else:
    print(f"\n------------- AUC (LOF vs FUNC/INT) ------------\n")
    X_test_filtered, y_test_filtered = X_test, y_test
    if has_val_set:
        X_val_filtered, y_val_filtered = X_val, y_val

# Test set: predicted probabilities, AUROC and ROC curve points
y_test_pred_prob = ANN_model.predict(X_test_filtered).ravel()
auc_test = roc_auc_score(y_test_filtered, y_test_pred_prob)
fpr_test, tpr_test, thresholds_test = roc_curve(y_test_filtered, y_test_pred_prob)

# Validation set: same metrics, or None placeholders when there is no explicit set
if has_val_set:
    y_val_pred_prob = ANN_model.predict(X_val_filtered).ravel()
    auc_val = roc_auc_score(y_val_filtered, y_val_pred_prob)
    fpr_val, tpr_val, thresholds_val = roc_curve(y_val_filtered, y_val_pred_prob)
else:
    y_val_pred_prob = auc_val = fpr_val = tpr_val = thresholds_val = None

## Plot training loss & AUC
* Loss vs Epochs and AUC vs Epochs (Training)
* ROC curve (Test and Validation set)

In [None]:
# Per-epoch training curves: loss (left) and AUC (right). Validation curves are
# drawn only when fit() recorded them (i.e. when validation data was supplied).
hist = history.history
fig, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(hist['loss'], label='Training Loss')
if 'val_loss' in hist:
    ax_loss.plot(hist['val_loss'], label='Validation Loss')
ax_loss.set_title('Training Loss')
ax_loss.set_xlabel('Epochs')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_auc.plot(hist['auc'], label='Training AUC')
if 'val_auc' in hist:
    ax_auc.plot(hist['val_auc'], label='Validation AUC')
ax_auc.set_title('Training AUC')
ax_auc.set_xlabel('Epochs')
ax_auc.set_ylabel('AUC')
ax_auc.legend()

fig.tight_layout()
# Save before show(): with the Jupyter inline backend, show() closes the figure,
# so saving afterwards wrote a blank image. plot1 includes y_label, so clinvar
# and class runs no longer overwrite each other's plot.
fig.savefig(plot1)
plt.show()
plt.close(fig)
print("Loss plot:", plot1)

## Plot test/validation ROC curve 

In [None]:
# ROC curves for the test set and, when one exists, the explicit validation set.
fig, ax = plt.subplots(figsize=(8, 6))

# Test set
ax.plot(fpr_test, tpr_test, label=f"Test Set AUC = {auc_test:.4f}", color="blue", linewidth=2)

# Validation set (auc_val is None when no explicit validation set was held out)
if auc_val is not None:
    ax.plot(fpr_val, tpr_val, label=f"Validation Set AUC = {auc_val:.4f}", color="green", linewidth=2)

# Random-guess diagonal. The linestyle is passed as a keyword: the former 'k--'
# format string also set black, which conflicted with color="gray" and made
# matplotlib emit a warning.
ax.plot([0, 1], [0, 1], linestyle="--", label="Random Guess", color="gray", linewidth=1.5)

ax.set_xlabel("False Positive Rate", fontsize=12)
ax.set_ylabel("True Positive Rate", fontsize=12)
ax.set_title(f"ROC Curve for {y_label}", fontsize=14)
ax.legend(loc="lower right", fontsize=10)

# Grid and layout adjustments
ax.grid(alpha=0.3)
fig.tight_layout()

# plot2 includes y_label, so clinvar and class runs keep separate ROC images.
fig.savefig(plot2)
plt.show()
plt.close(fig)
if auc_val is not None:
    print(f"AUC (Test): {auc_test:.4f}  (Validation): {auc_val:.4f}")
else:
    print(f"AUC (Test): {auc_test:.4f}")
print("Results in:", f"{OUTPUT_DIR}\n")
