In [None]:
# Mount Google Drive at /content/drive; the data and log paths used by later
# cells all live beneath this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting google-pasta>=0.1.1 (from tensorflow)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Downloading libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl.metadata (5.2 kB)
Collecting protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.3 (from tensorflow)
  Downloading protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Collecting tensorboard~=2.19.0 (from tensorflow)
  Downloading tensorboard-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorflow-io-gcs-filesystem>=0.23.1 (from tensorf

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn import metrics
import math
import os
import time
import pandas as pd
from sklearn.model_selection import StratifiedKFold


# Root output directory on Drive: TensorBoard logs, model checkpoints and the
# result CSVs written by later cells are all placed beneath it.
LOG_DIR = '/content/drive/MyDrive/Clathrin-msCNN/log/esm2'
os.makedirs(LOG_DIR, exist_ok=True)  # exist_ok=True makes re-running the cell safe

In [None]:
# Load the data: feature arrays and labels stored as .npy files on Drive.
X_train = np.load('/content/drive/MyDrive/Clathrin-msCNN/data/t33_650M_UR50D/X_train.npy')
y_train = np.load('/content/drive/MyDrive/Clathrin-msCNN/data/t33_650M_UR50D/y_train.npy')
X_test = np.load('/content/drive/MyDrive/Clathrin-msCNN/data/t33_650M_UR50D/X_test.npy')
y_test = np.load('/content/drive/MyDrive/Clathrin-msCNN/data/t33_650M_UR50D/y_test.npy')

# Report the array shapes. The printed labels use "cv" for the training arrays
# and "ind" for the test arrays.
print("Kích thước dữ liệu:")
print("\n".join(
    f"{label}: {array.shape}"
    for label, array in (
        ("X_cv", X_train),
        ("y_cv", y_train),
        ("X_ind", X_test),
        ("y_ind", y_test),
    )
))

In [None]:
# Insert a singleton axis at position 1: (samples, height, width) becomes
# (samples, 1, height, width), the layout the model input expects.
# The ndim guard keeps the cell idempotent: without it, re-running the cell
# would insert a further axis each time and produce 5-D arrays.
if X_train.ndim == 3:
    X_train = X_train[:, np.newaxis, :, :]  # (samples, 1, height, width)
if X_test.ndim == 3:
    X_test = X_test[:, np.newaxis, :, :]

print("X_train shape:", X_train.shape)
print("X_test shape:", X_test.shape)

X_train shape: (2421, 1, 1022, 1280)
X_test shape: (485, 1, 1022, 1280)


# Cross Validation

In [None]:
import os
import time
import math

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.saving import register_keras_serializable

# === Hyperparameters ===
BATCH_SIZE      = 128   # batch size passed to fit() and predict()
NUM_CLASSES     = 1     # units of the final sigmoid layer
EPOCHS          = 50    # upper bound on epochs per fit() call
NUM_FILTERS     = 256   # filters of each SeparableConv2D branch
NUM_HIDDEN      = 1024  # units of the 'fc1' dense layer
WINDOW_SIZES    = [32]  # one convolution branch is built per entry
MAX_SEQ_LENGTH  = 1022  # second entry of the model input shape
EMBEDDING_WIDTH = 1280  # third entry of the model input shape
# Output directory for this configuration. The list is formatted with its
# brackets, so the folder name looks like 'MODELS_[32]_256F_1024H'.
LOG_MODEL = os.path.join(LOG_DIR, f'MODELS_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H')
os.makedirs(LOG_MODEL, exist_ok=True)

@register_keras_serializable()
def DeepScan(input_shape=(1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH),
             window_sizes=WINDOW_SIZES,
             num_filters=NUM_FILTERS,
             num_hidden=NUM_HIDDEN,
             num_classes=NUM_CLASSES):
    """Build the multi-window separable-convolution classifier.

    For every window size a SeparableConv2D branch with a (1, ws) kernel is
    applied to the input, max-pooled over all window positions and flattened.
    The branches are concatenated and passed through dropout, a ReLU dense
    layer ('fc1') and a sigmoid output layer.

    Args:
        input_shape: Per-sample shape (1, seq_length, width).
            NOTE(review): the kernel/pool sizes assume Keras' default
            channels_last layout, i.e. the second entry is the axis the
            window slides along - confirm if the image data format is changed.
        window_sizes: Non-empty iterable of window lengths, one branch each.
        num_filters: Number of filters per convolution branch.
        num_hidden: Units of the hidden dense layer.
        num_classes: Units of the sigmoid output layer.

    Returns:
        An uncompiled ``tf.keras.Model`` named 'DeepScan'.

    Raises:
        ValueError: If ``window_sizes`` is empty or contains a window longer
            than the sequence axis of ``input_shape``.
    """
    window_sizes = list(window_sizes)
    if not window_sizes:
        raise ValueError("window_sizes must contain at least one window size")

    # Take the pooled length from the argument rather than from the global
    # MAX_SEQ_LENGTH, so that a non-default input_shape still pools over the
    # whole convolution output.
    seq_length = input_shape[1]
    too_long = [ws for ws in window_sizes if ws > seq_length]
    if too_long:
        raise ValueError(
            f"window sizes {too_long} exceed the sequence length {seq_length}"
        )

    inputs = tf.keras.Input(shape=input_shape)

    branches = []
    for ws in window_sizes:
        x = layers.SeparableConv2D(
            filters=num_filters,
            kernel_size=(1, ws),
            strides=(1, 1),
            activation='relu',
            padding='valid',
            depthwise_regularizer=tf.keras.regularizers.l2(1e-4),
            pointwise_regularizer=tf.keras.regularizers.l2(1e-4),
            depthwise_initializer='glorot_uniform',
            pointwise_initializer='glorot_uniform'
        )(inputs)

        # 'valid' convolution leaves seq_length - ws + 1 positions; a pool of
        # exactly that width reduces them to a single maximum per filter.
        x = layers.MaxPooling2D(
            pool_size=(1, seq_length - ws + 1),
            strides=(1, 1),
            padding='valid'
        )(x)

        x = layers.Flatten()(x)
        branches.append(x)

    x = layers.Concatenate()(branches)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_hidden, activation='relu', name='fc1')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='DeepScan')
    return model

# === Callback on CV folds ===
class MetricsCallback(tf.keras.callbacks.Callback):
    """Append one row of validation metrics per epoch to the global `results`.

    After every epoch the model predicts on the held-out fold, the
    probabilities are thresholded at 0.5, and the confusion-matrix counts plus
    sensitivity, specificity, accuracy, MCC and F1 are written as a 'CV' row.
    The 'AUC' and 'Train_Time' cells are left as None here.

    Args:
        X_val: Validation inputs of the fold.
        y_val: Validation labels of the fold (assumed to be encoded as 0/1,
            as the binary cross-entropy setup suggests - confirm).
        fold: Fold identifier stored in the 'Fold' column.
    """

    def __init__(self, X_val, y_val, fold):
        super().__init__()
        self.X_val = X_val
        self.y_val = y_val
        self.fold  = fold
        self.fold_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started so its duration can be logged.
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        y_pred_probs  = self.model.predict(self.X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] always yields a 2x2 matrix. Without it a fold where
        # truth and prediction contain a single class collapses to 1x1, and
        # the former fallback counted an all-positive result as FN, not TP.
        cm = metrics.confusion_matrix(self.y_val, y_pred_labels, labels=[0, 1])
        # Plain Python ints keep the products below exact.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        # Every ratio falls back to 0 when its denominator is empty.
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0
        epoch_time = time.time() - self.epoch_start

        # Column order must match results_columns.
        results.loc[len(results)] = [
            'CV', self.fold, epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, None,
            None, epoch_time, self.model.count_params()
        ]

class SaveEveryEpochCallback(tf.keras.callbacks.Callback):
    """Save the complete model to a .keras file at the end of every epoch.

    Files are written to ``<base_dir>/<stage>/fold_<fold>/`` when a fold is
    given and to ``<base_dir>/<stage>/`` otherwise; the directory is created
    when the callback is constructed.

    Args:
        base_dir: Root directory for the saved models.
        stage: Stage label used as sub-directory and file-name prefix.
        fold: Optional fold identifier; None omits the fold part.
    """

    def __init__(self, base_dir, stage='CV', fold=None):
        super().__init__()
        self.base_dir = base_dir
        self.stage = stage
        self.fold = fold
        # An empty last component keeps the path at <base_dir>/<stage>/.
        fold_component = '' if fold is None else f'fold_{fold}'
        self.sub_dir = os.path.join(base_dir, stage, fold_component)
        os.makedirs(self.sub_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        # Keras epochs are 0-based; file names and messages use 1-based numbers.
        epoch_number = epoch + 1
        if self.fold is None:
            filename = f"{self.stage}_epoch{epoch_number:02d}.keras"
        else:
            filename = f"{self.stage}_fold{self.fold}_epoch{epoch_number:02d}.keras"
        path = os.path.join(self.sub_dir, filename)
        self.model.save(path)
        print(f"[Saved model] Epoch {epoch_number} saved to {path}")

# === Callback on independent test ===
class FinalMetricsCallback(tf.keras.callbacks.Callback):
    """Append one row of test-set metrics per epoch to the global `results`.

    After every epoch the model predicts on the test set, the probabilities
    are thresholded at 0.5, and the confusion-matrix counts, sensitivity,
    specificity, accuracy, MCC, F1 and ROC AUC are written as an
    'Independent' / 'Final' row. 'Train_Time' is left as None here.

    Args:
        X_test: Test inputs.
        y_test: Test labels (assumed to be encoded as 0/1, as the binary
            cross-entropy setup suggests - confirm).
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        self.epoch_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started so its duration can be logged.
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        y_pred_probs  = self.model.predict(self.X_test, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] always yields a 2x2 matrix. Without it a result where
        # truth and prediction contain a single class collapses to 1x1, and
        # the former fallback counted an all-positive result as FN, not TP.
        cm = metrics.confusion_matrix(self.y_test, y_pred_labels, labels=[0, 1])
        # Plain Python ints keep the products below exact.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        # Every ratio falls back to 0 when its denominator is empty.
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        # Threshold-free AUC from the raw probabilities.
        fpr, tpr, _ = roc_curve(self.y_test, y_pred_probs)
        roc_auc = auc(fpr, tpr)

        epoch_time = time.time() - self.epoch_start_time

        # Column order must match results_columns.
        results.loc[len(results)] = [
            'Independent', 'Final', epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, roc_auc,
            None, epoch_time, self.model.count_params()
        ]

# === Results table: one row per (stage, fold, epoch) ===
results_columns = [
    'Stage', 'Fold', 'Epoch', 'TP', 'FP', 'TN', 'FN',
    'Sens', 'Spec', 'Acc', 'MCC', 'F1', 'AUC',
    'Train_Time', 'Epoch_Time', 'Total_Params'
]
# Recreated on every run of the cell, so the callbacks append to an empty table.
results = pd.DataFrame(columns=results_columns)

# === 5-Fold Cross-Validation + ROC per fold ===
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train, y_train), start=1):
    print(f"\n=== Fold {fold}/5 ===")
    X_train_fold, X_val = X_train[train_idx], X_train[val_idx]
    y_train_fold, y_val = y_train[train_idx], y_train[val_idx]

    # A fresh, untrained model for every fold.
    model = DeepScan(window_sizes=WINDOW_SIZES,
                     num_filters=NUM_FILTERS,
                     num_hidden=NUM_HIDDEN)
    model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    metrics_cb = MetricsCallback(X_val, y_val, fold)
    # TensorBoard events go under LOG_MODEL (the per-configuration folder)
    # instead of LOG_DIR, so that runs with other window sizes or filter
    # counts no longer write into the same fold_N directory.
    tb_cb = tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(LOG_MODEL, 'tensorboard', f'fold_{fold}')
    )
    early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
    save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='CV', fold=fold)
    checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_CV_{fold}.keras')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')

    start_fold = time.time()
    model.fit(
        X_train_fold, y_train_fold,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[metrics_cb, tb_cb, early_stop_cb, save_cb, checkpoint_cb],
        verbose=1,
    )
    train_time = time.time() - start_fold

    # All epoch rows of this fold share the fold-level values filled in below.
    fold_rows = (results.Stage=='CV') & (results.Fold==fold)
    results.loc[fold_rows, 'Train_Time'] = train_time

    # One ROC AUC per fold, computed from the model as fit() left it; the same
    # value is written to every epoch row of the fold.
    y_val_probs = model.predict(X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
    fpr, tpr, _ = roc_curve(y_val, y_val_probs)
    roc_auc = auc(fpr, tpr)
    results.loc[fold_rows, 'AUC'] = roc_auc

# Save CV results
results_path = os.path.join(
    LOG_DIR,
    f'training_results_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H.csv'
)
results.to_csv(results_path, index=False)
print(f"\nCV results saved to {results_path}")

# === Final training and evaluation ===
# NOTE(review): the test set is passed as validation_data, and both
# EarlyStopping and ModelCheckpoint monitor val_accuracy. Stopping and
# checkpoint selection are therefore driven by the test set, so the reported
# "independent" metrics are optimistic. Consider holding out a validation
# split of the training data instead.
final_cb = FinalMetricsCallback(X_test, y_test)

final_model = DeepScan(window_sizes=WINDOW_SIZES,
                      num_filters=NUM_FILTERS,
                      num_hidden=NUM_HIDDEN)
final_model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
final_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='Independent')
final_checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_Independent.keras')
final_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(final_checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')
final_early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
# TensorBoard events go under LOG_MODEL (the per-configuration folder) instead
# of LOG_DIR, so different hyperparameter runs do not share one log directory.
final_tb_cb = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join(LOG_MODEL, 'tensorboard', 'final_model')
)

start_final = time.time()
history = final_model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[final_cb, final_tb_cb, final_checkpoint_cb, final_early_stop_cb, save_cb],
    verbose=1
)
# Fill Train_Time for the final run, mirroring what the CV loop does per fold.
independent_rows = results['Stage'] == 'Independent'
results.loc[independent_rows, 'Train_Time'] = time.time() - start_final

# Save final results
results.to_csv(results_path, index=False)
print(f"\nFinal results saved to {results_path}")

# The AUC column has object dtype (it held None for the CV rows), and idxmax
# is not defined for object dtype, so convert to float before searching.
independent_auc = pd.to_numeric(results.loc[independent_rows, 'AUC'], errors='coerce')
if independent_auc.notna().any():
    best_auc_idx = independent_auc.idxmax()
    best_auc_row = results.loc[best_auc_idx]
    print(f"\nBest Independent Test AUC: {best_auc_row['AUC']:.4f} (Epoch {best_auc_row['Epoch']})")
else:
    # No usable AUC was recorded (e.g. every value is NaN).
    best_auc_idx = None
    best_auc_row = None
    print("\nBest Independent Test AUC: n/a")



=== Fold 1/5 ===


Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 395ms/step - accuracy: 0.5140 - auc: 0.5284 - loss: 0.7723[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_1/CV_fold1_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 726ms/step - accuracy: 0.5222 - auc: 0.5382 - loss: 0.7662 - val_accuracy: 0.6289 - val_auc: 0.9462 - val_loss: 0.6157
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 384ms/step - accuracy: 0.6919 - auc: 0.7736 - loss: 0.6219[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_1/CV_fold1_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 550ms/step - accuracy: 0.6922 - auc: 0.7737 - loss: 0.6213 - val_accuracy: 0.8742 - val_auc: 0.9569 - val_loss: 0.5068
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 380ms/step - accuracy: 0

Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 357ms/step - accuracy: 0.5499 - auc: 0.5607 - loss: 0.7489[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_2/CV_fold2_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 701ms/step - accuracy: 0.5521 - auc: 0.5641 - loss: 0.7469 - val_accuracy: 0.8802 - val_auc: 0.9394 - val_loss: 0.6141
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 379ms/step - accuracy: 0.6667 - auc: 0.7326 - loss: 0.6491[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_2/CV_fold2_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 548ms/step - accuracy: 0.6699 - auc: 0.7368 - loss: 0.6462 - val_accuracy: 0.8864 - val_auc: 0.9580 - val_loss: 0.5266
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 355ms/step - accuracy: 0

Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 361ms/step - accuracy: 0.5768 - auc: 0.5949 - loss: 0.7324[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_3/CV_fold3_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 704ms/step - accuracy: 0.5776 - auc: 0.5968 - loss: 0.7312 - val_accuracy: 0.8079 - val_auc: 0.9046 - val_loss: 0.6232
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 382ms/step - accuracy: 0.7047 - auc: 0.7702 - loss: 0.6237[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_3/CV_fold3_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 549ms/step - accuracy: 0.7052 - auc: 0.7715 - loss: 0.6222 - val_accuracy: 0.7293 - val_auc: 0.9308 - val_loss: 0.5708
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 377ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 387ms/step - accuracy: 0.5421 - auc: 0.5361 - loss: 0.7612[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_4/CV_fold4_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 707ms/step - accuracy: 0.5480 - auc: 0.5450 - loss: 0.7564 - val_accuracy: 0.8471 - val_auc: 0.9214 - val_loss: 0.6027
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 360ms/step - accuracy: 0.6912 - auc: 0.7737 - loss: 0.6238[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_4/CV_fold4_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 550ms/step - accuracy: 0.6926 - auc: 0.7752 - loss: 0.6226 - val_accuracy: 0.8719 - val_auc: 0.9348 - val_loss: 0.5079
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 381ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 384ms/step - accuracy: 0.5430 - auc: 0.5488 - loss: 0.7478[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_5/CV_fold5_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 702ms/step - accuracy: 0.5471 - auc: 0.5559 - loss: 0.7443 - val_accuracy: 0.6880 - val_auc: 0.9034 - val_loss: 0.6273
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 376ms/step - accuracy: 0.7039 - auc: 0.7735 - loss: 0.6202[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[32]_256F_1024H/CV/fold_5/CV_fold5_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 540ms/step - accuracy: 0.7066 - auc: 0.7767 - loss: 0.6177 - val_accuracy: 0.6715 - val_auc: 0.9262 - val_loss: 0.5798
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 371ms/step - accuracy: 0

In [None]:
import os
import time
import math

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.saving import register_keras_serializable

# === Hyperparameters ===
BATCH_SIZE      = 128   # batch size passed to fit() and predict()
NUM_CLASSES     = 1     # units of the final sigmoid layer
EPOCHS          = 50    # upper bound on epochs per fit() call
NUM_FILTERS     = 256   # filters of each SeparableConv2D branch
NUM_HIDDEN      = 1024  # units of the 'fc1' dense layer
WINDOW_SIZES    = [28]  # one convolution branch is built per entry
MAX_SEQ_LENGTH  = 1022  # second entry of the model input shape
EMBEDDING_WIDTH = 1280  # third entry of the model input shape
# Output directory for this configuration. The list is formatted with its
# brackets, so the folder name looks like 'MODELS_[28]_256F_1024H'.
LOG_MODEL = os.path.join(LOG_DIR, f'MODELS_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H')
os.makedirs(LOG_MODEL, exist_ok=True)

@register_keras_serializable()
def DeepScan(input_shape=(1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH),
             window_sizes=WINDOW_SIZES,
             num_filters=NUM_FILTERS,
             num_hidden=NUM_HIDDEN,
             num_classes=NUM_CLASSES):
    """Build the multi-window separable-convolution classifier.

    For every window size a SeparableConv2D branch with a (1, ws) kernel is
    applied to the input, max-pooled over all window positions and flattened.
    The branches are concatenated and passed through dropout, a ReLU dense
    layer ('fc1') and a sigmoid output layer.

    Args:
        input_shape: Per-sample shape (1, seq_length, width).
            NOTE(review): the kernel/pool sizes assume Keras' default
            channels_last layout, i.e. the second entry is the axis the
            window slides along - confirm if the image data format is changed.
        window_sizes: Non-empty iterable of window lengths, one branch each.
        num_filters: Number of filters per convolution branch.
        num_hidden: Units of the hidden dense layer.
        num_classes: Units of the sigmoid output layer.

    Returns:
        An uncompiled ``tf.keras.Model`` named 'DeepScan'.

    Raises:
        ValueError: If ``window_sizes`` is empty or contains a window longer
            than the sequence axis of ``input_shape``.
    """
    window_sizes = list(window_sizes)
    if not window_sizes:
        raise ValueError("window_sizes must contain at least one window size")

    # Take the pooled length from the argument rather than from the global
    # MAX_SEQ_LENGTH, so that a non-default input_shape still pools over the
    # whole convolution output.
    seq_length = input_shape[1]
    too_long = [ws for ws in window_sizes if ws > seq_length]
    if too_long:
        raise ValueError(
            f"window sizes {too_long} exceed the sequence length {seq_length}"
        )

    inputs = tf.keras.Input(shape=input_shape)

    branches = []
    for ws in window_sizes:
        x = layers.SeparableConv2D(
            filters=num_filters,
            kernel_size=(1, ws),
            strides=(1, 1),
            activation='relu',
            padding='valid',
            depthwise_regularizer=tf.keras.regularizers.l2(1e-4),
            pointwise_regularizer=tf.keras.regularizers.l2(1e-4),
            depthwise_initializer='glorot_uniform',
            pointwise_initializer='glorot_uniform'
        )(inputs)

        # 'valid' convolution leaves seq_length - ws + 1 positions; a pool of
        # exactly that width reduces them to a single maximum per filter.
        x = layers.MaxPooling2D(
            pool_size=(1, seq_length - ws + 1),
            strides=(1, 1),
            padding='valid'
        )(x)

        x = layers.Flatten()(x)
        branches.append(x)

    x = layers.Concatenate()(branches)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_hidden, activation='relu', name='fc1')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='DeepScan')
    return model

# === Callback on CV folds ===
class MetricsCallback(tf.keras.callbacks.Callback):
    """Append one row of validation metrics per epoch to the global `results`.

    After every epoch the model predicts on the held-out fold, the
    probabilities are thresholded at 0.5, and the confusion-matrix counts plus
    sensitivity, specificity, accuracy, MCC and F1 are written as a 'CV' row.
    The 'AUC' and 'Train_Time' cells are left as None here.

    Args:
        X_val: Validation inputs of the fold.
        y_val: Validation labels of the fold (assumed to be encoded as 0/1,
            as the binary cross-entropy setup suggests - confirm).
        fold: Fold identifier stored in the 'Fold' column.
    """

    def __init__(self, X_val, y_val, fold):
        super().__init__()
        self.X_val = X_val
        self.y_val = y_val
        self.fold  = fold
        self.fold_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started so its duration can be logged.
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        y_pred_probs  = self.model.predict(self.X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] always yields a 2x2 matrix. Without it a fold where
        # truth and prediction contain a single class collapses to 1x1, and
        # the former fallback counted an all-positive result as FN, not TP.
        cm = metrics.confusion_matrix(self.y_val, y_pred_labels, labels=[0, 1])
        # Plain Python ints keep the products below exact.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        # Every ratio falls back to 0 when its denominator is empty.
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0
        epoch_time = time.time() - self.epoch_start

        # Column order must match results_columns.
        results.loc[len(results)] = [
            'CV', self.fold, epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, None,
            None, epoch_time, self.model.count_params()
        ]

class SaveEveryEpochCallback(tf.keras.callbacks.Callback):
    """Save the complete model to a .keras file at the end of every epoch.

    Files are written to ``<base_dir>/<stage>/fold_<fold>/`` when a fold is
    given and to ``<base_dir>/<stage>/`` otherwise; the directory is created
    when the callback is constructed.

    Args:
        base_dir: Root directory for the saved models.
        stage: Stage label used as sub-directory and file-name prefix.
        fold: Optional fold identifier; None omits the fold part.
    """

    def __init__(self, base_dir, stage='CV', fold=None):
        super().__init__()
        self.base_dir = base_dir
        self.stage = stage
        self.fold = fold
        # An empty last component keeps the path at <base_dir>/<stage>/.
        fold_component = '' if fold is None else f'fold_{fold}'
        self.sub_dir = os.path.join(base_dir, stage, fold_component)
        os.makedirs(self.sub_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        # Keras epochs are 0-based; file names and messages use 1-based numbers.
        epoch_number = epoch + 1
        if self.fold is None:
            filename = f"{self.stage}_epoch{epoch_number:02d}.keras"
        else:
            filename = f"{self.stage}_fold{self.fold}_epoch{epoch_number:02d}.keras"
        path = os.path.join(self.sub_dir, filename)
        self.model.save(path)
        print(f"[Saved model] Epoch {epoch_number} saved to {path}")

# === Callback on independent test ===
class FinalMetricsCallback(tf.keras.callbacks.Callback):
    """Append one row of test-set metrics per epoch to the global `results`.

    After every epoch the model predicts on the test set, the probabilities
    are thresholded at 0.5, and the confusion-matrix counts, sensitivity,
    specificity, accuracy, MCC, F1 and ROC AUC are written as an
    'Independent' / 'Final' row. 'Train_Time' is left as None here.

    Args:
        X_test: Test inputs.
        y_test: Test labels (assumed to be encoded as 0/1, as the binary
            cross-entropy setup suggests - confirm).
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        self.epoch_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started so its duration can be logged.
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        y_pred_probs  = self.model.predict(self.X_test, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] always yields a 2x2 matrix. Without it a result where
        # truth and prediction contain a single class collapses to 1x1, and
        # the former fallback counted an all-positive result as FN, not TP.
        cm = metrics.confusion_matrix(self.y_test, y_pred_labels, labels=[0, 1])
        # Plain Python ints keep the products below exact.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        # Every ratio falls back to 0 when its denominator is empty.
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        # Threshold-free AUC from the raw probabilities.
        fpr, tpr, _ = roc_curve(self.y_test, y_pred_probs)
        roc_auc = auc(fpr, tpr)

        epoch_time = time.time() - self.epoch_start_time

        # Column order must match results_columns.
        results.loc[len(results)] = [
            'Independent', 'Final', epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, roc_auc,
            None, epoch_time, self.model.count_params()
        ]

# === Results table: one row per (stage, fold, epoch) ===
results_columns = [
    'Stage', 'Fold', 'Epoch', 'TP', 'FP', 'TN', 'FN',
    'Sens', 'Spec', 'Acc', 'MCC', 'F1', 'AUC',
    'Train_Time', 'Epoch_Time', 'Total_Params'
]
# Recreated on every run of the cell, so the callbacks append to an empty table.
results = pd.DataFrame(columns=results_columns)

# === 5-Fold Cross-Validation + ROC per fold ===
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train, y_train), start=1):
    print(f"\n=== Fold {fold}/5 ===")
    X_train_fold, X_val = X_train[train_idx], X_train[val_idx]
    y_train_fold, y_val = y_train[train_idx], y_train[val_idx]

    # A fresh, untrained model for every fold.
    model = DeepScan(window_sizes=WINDOW_SIZES,
                     num_filters=NUM_FILTERS,
                     num_hidden=NUM_HIDDEN)
    model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    metrics_cb = MetricsCallback(X_val, y_val, fold)
    # TensorBoard events go under LOG_MODEL (the per-configuration folder)
    # instead of LOG_DIR, so that runs with other window sizes or filter
    # counts no longer write into the same fold_N directory.
    tb_cb = tf.keras.callbacks.TensorBoard(
        log_dir=os.path.join(LOG_MODEL, 'tensorboard', f'fold_{fold}')
    )
    early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
    save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='CV', fold=fold)
    checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_CV_{fold}.keras')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')

    start_fold = time.time()
    model.fit(
        X_train_fold, y_train_fold,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[metrics_cb, tb_cb, early_stop_cb, save_cb, checkpoint_cb],
        verbose=1,
    )
    train_time = time.time() - start_fold

    # All epoch rows of this fold share the fold-level values filled in below.
    fold_rows = (results.Stage=='CV') & (results.Fold==fold)
    results.loc[fold_rows, 'Train_Time'] = train_time

    # One ROC AUC per fold, computed from the model as fit() left it; the same
    # value is written to every epoch row of the fold.
    y_val_probs = model.predict(X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
    fpr, tpr, _ = roc_curve(y_val, y_val_probs)
    roc_auc = auc(fpr, tpr)
    results.loc[fold_rows, 'AUC'] = roc_auc

# Save CV results
results_path = os.path.join(
    LOG_DIR,
    f'training_results_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H.csv'
)
results.to_csv(results_path, index=False)
print(f"\nCV results saved to {results_path}")

# === Final training and evaluation ===
# NOTE(review): the test set is passed as validation_data, and both
# EarlyStopping and ModelCheckpoint monitor val_accuracy. Stopping and
# checkpoint selection are therefore driven by the test set, so the reported
# "independent" metrics are optimistic. Consider holding out a validation
# split of the training data instead.
final_cb = FinalMetricsCallback(X_test, y_test)

final_model = DeepScan(window_sizes=WINDOW_SIZES,
                      num_filters=NUM_FILTERS,
                      num_hidden=NUM_HIDDEN)
final_model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
final_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='Independent')
final_checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_Independent.keras')
final_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(final_checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')
final_early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
# TensorBoard events go under LOG_MODEL (the per-configuration folder) instead
# of LOG_DIR, so different hyperparameter runs do not share one log directory.
final_tb_cb = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join(LOG_MODEL, 'tensorboard', 'final_model')
)

start_final = time.time()
history = final_model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[final_cb, final_tb_cb, final_checkpoint_cb, final_early_stop_cb, save_cb],
    verbose=1
)
# Fill Train_Time for the final run, mirroring what the CV loop does per fold.
independent_rows = results['Stage'] == 'Independent'
results.loc[independent_rows, 'Train_Time'] = time.time() - start_final

# Save final results
results.to_csv(results_path, index=False)
print(f"\nFinal results saved to {results_path}")

# The AUC column has object dtype (it held None for the CV rows), and idxmax
# is not defined for object dtype, so convert to float before searching.
independent_auc = pd.to_numeric(results.loc[independent_rows, 'AUC'], errors='coerce')
if independent_auc.notna().any():
    best_auc_idx = independent_auc.idxmax()
    best_auc_row = results.loc[best_auc_idx]
    print(f"\nBest Independent Test AUC: {best_auc_row['AUC']:.4f} (Epoch {best_auc_row['Epoch']})")
else:
    # No usable AUC was recorded (e.g. every value is NaN).
    best_auc_idx = None
    best_auc_row = None
    print("\nBest Independent Test AUC: n/a")



=== Fold 1/5 ===


Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 362ms/step - accuracy: 0.5317 - auc: 0.5505 - loss: 0.7474[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_1/CV_fold1_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 682ms/step - accuracy: 0.5347 - auc: 0.5558 - loss: 0.7448 - val_accuracy: 0.8021 - val_auc: 0.9265 - val_loss: 0.6161
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 363ms/step - accuracy: 0.6937 - auc: 0.7651 - loss: 0.6310[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_1/CV_fold1_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 535ms/step - accuracy: 0.6944 - auc: 0.7661 - loss: 0.6298 - val_accuracy: 0.8763 - val_auc: 0.9527 - val_loss: 0.5289
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 358ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0.5625 - auc: 0.5656 - loss: 0.7496[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_2/CV_fold2_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 680ms/step - accuracy: 0.5655 - auc: 0.5719 - loss: 0.7458 - val_accuracy: 0.6756 - val_auc: 0.8169 - val_loss: 0.6175
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.6747 - auc: 0.7502 - loss: 0.6354[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_2/CV_fold2_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 525ms/step - accuracy: 0.6758 - auc: 0.7513 - loss: 0.6341 - val_accuracy: 0.7748 - val_auc: 0.8861 - val_loss: 0.5459
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 359ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 363ms/step - accuracy: 0.5594 - auc: 0.5715 - loss: 0.7474[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_3/CV_fold3_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 685ms/step - accuracy: 0.5634 - auc: 0.5774 - loss: 0.7437 - val_accuracy: 0.8120 - val_auc: 0.8921 - val_loss: 0.6164
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 354ms/step - accuracy: 0.6900 - auc: 0.7594 - loss: 0.6319[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_3/CV_fold3_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 523ms/step - accuracy: 0.6920 - auc: 0.7622 - loss: 0.6297 - val_accuracy: 0.7810 - val_auc: 0.9244 - val_loss: 0.5506
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 353ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 354ms/step - accuracy: 0.5624 - auc: 0.5786 - loss: 0.7422[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_4/CV_fold4_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 677ms/step - accuracy: 0.5652 - auc: 0.5832 - loss: 0.7394 - val_accuracy: 0.7851 - val_auc: 0.9178 - val_loss: 0.5978
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.6896 - auc: 0.7785 - loss: 0.6190[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_4/CV_fold4_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 523ms/step - accuracy: 0.6909 - auc: 0.7793 - loss: 0.6180 - val_accuracy: 0.7913 - val_auc: 0.9350 - val_loss: 0.5166
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 359ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 354ms/step - accuracy: 0.5348 - auc: 0.5511 - loss: 0.8721[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_5/CV_fold5_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 676ms/step - accuracy: 0.5384 - auc: 0.5544 - loss: 0.8624 - val_accuracy: 0.7686 - val_auc: 0.8619 - val_loss: 0.6485
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 358ms/step - accuracy: 0.6471 - auc: 0.7363 - loss: 0.6636[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[28]_256F_1024H/CV/fold_5/CV_fold5_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 521ms/step - accuracy: 0.6505 - auc: 0.7374 - loss: 0.6607 - val_accuracy: 0.6736 - val_auc: 0.9223 - val_loss: 0.5940
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 356ms/step - accuracy: 0

In [None]:
import os
import time
import math

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.saving import register_keras_serializable

# === Hyperparameters ===
# BATCH_SIZE      - batch size passed to fit() and predict()
# NUM_CLASSES     - units in the sigmoid output layer
# EPOCHS          - epochs argument passed to fit()
# NUM_FILTERS     - filters per SeparableConv2D branch
# NUM_HIDDEN      - units in the 'fc1' dense layer
# WINDOW_SIZES    - convolution kernel widths; DeepScan builds one branch per entry
# MAX_SEQ_LENGTH  - size of the second input axis (presumably residues per sequence; TODO confirm)
# EMBEDDING_WIDTH - size of the last input axis (presumably the ESM-2 embedding size; TODO confirm)
BATCH_SIZE      = 128
NUM_CLASSES     = 1
EPOCHS          = 50
NUM_FILTERS     = 256
NUM_HIDDEN      = 1024
WINDOW_SIZES    = [24]
MAX_SEQ_LENGTH  = 1022
EMBEDDING_WIDTH = 1280
# Per-configuration output directory. The list repr of WINDOW_SIZES is embedded
# as-is, so the folder name contains brackets, e.g. "MODELS_[24]_256F_1024H".
LOG_MODEL = os.path.join(LOG_DIR, f'MODELS_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H')
os.makedirs(LOG_MODEL, exist_ok=True)

@register_keras_serializable()
def DeepScan(input_shape=(1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH),
             window_sizes=WINDOW_SIZES,
             num_filters=NUM_FILTERS,
             num_hidden=NUM_HIDDEN,
             num_classes=NUM_CLASSES):
    """Build the multi-window separable-CNN classifier as a Keras functional model.

    Every entry of ``window_sizes`` gets its own branch: a ``SeparableConv2D``
    with a ``(1, ws)`` kernel, a max-pool that spans every position the
    convolution leaves on that axis, and a flatten. The branch outputs are
    concatenated, passed through dropout (0.5) and the ``fc1`` dense layer,
    and finished with a sigmoid output layer.

    Args:
        input_shape: Per-sample input shape (no batch axis). ``input_shape[1]``
            is the axis the kernels slide along (assuming Keras' default
            channels-last layout -- TODO confirm it is not overridden).
        window_sizes: Iterable of kernel widths, one branch per entry.
        num_filters: Number of filters in each convolution branch.
        num_hidden: Number of units in the ``fc1`` dense layer.
        num_classes: Number of units in the sigmoid output layer.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'DeepScan'``.

    Raises:
        ValueError: If a window size is smaller than 1 or larger than
            ``input_shape[1]``.
    """
    inputs = tf.keras.Input(shape=input_shape)
    # Take the sequence length from the requested input shape rather than the
    # MAX_SEQ_LENGTH global, so a non-default input_shape gets a matching pool.
    seq_length = input_shape[1]

    branches = []
    for ws in window_sizes:
        if not 1 <= ws <= seq_length:
            raise ValueError(
                f"window size {ws} must be between 1 and the sequence length {seq_length}"
            )

        x = layers.SeparableConv2D(
            filters=num_filters,
            kernel_size=(1, ws),
            strides=(1, 1),
            activation='relu',
            padding='valid',
            depthwise_regularizer=tf.keras.regularizers.l2(1e-4),
            pointwise_regularizer=tf.keras.regularizers.l2(1e-4),
            depthwise_initializer='glorot_uniform',
            pointwise_initializer='glorot_uniform'
        )(inputs)

        # A 'valid' convolution leaves seq_length - ws + 1 positions; pooling
        # over all of them reduces each filter to a single value.
        x = layers.MaxPooling2D(
            pool_size=(1, seq_length - ws + 1),
            strides=(1, 1),
            padding='valid'
        )(x)

        x = layers.Flatten()(x)
        branches.append(x)

    x = layers.Concatenate()(branches)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_hidden, activation='relu', name='fc1')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='DeepScan')
    return model

# === Callback on CV folds ===
class MetricsCallback(tf.keras.callbacks.Callback):
    """Append per-epoch validation metrics for one CV fold to the global ``results`` frame.

    At the end of every epoch the model predicts on the stored validation
    data, predictions are thresholded at 0.5, and the confusion-matrix counts
    plus sensitivity, specificity, accuracy, MCC and F1 are written as a new
    ``'CV'`` row. ``AUC`` and ``Train_Time`` are stored as ``None`` here.

    Args:
        X_val: Validation inputs passed to ``model.predict``.
        y_val: Validation labels (expected to be binary 0/1 -- TODO confirm).
        fold: Fold identifier written to the ``Fold`` column.
    """

    def __init__(self, X_val, y_val, fold):
        super().__init__()
        self.X_val = X_val
        self.y_val = y_val
        self.fold  = fold
        self.fold_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Start the wall-clock timer for this epoch."""
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Score the validation data and append one row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)
        # labels=[0, 1] pins the matrix to 2x2 even when only one class shows up
        # in labels and predictions, so the counts always unpack to the right
        # cells (the old 1x1 fallback booked all-positive data as FN, not TP).
        cm = metrics.confusion_matrix(self.y_val, y_pred_labels, labels=[0, 1])
        # Python ints keep the four-factor MCC denominator free of int64 overflow.
        TN, FP, FN, TP = (int(v) for v in cm.ravel())
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0
        # Measured after predict(), so the value includes this scoring pass.
        epoch_time = time.time() - self.epoch_start
        results.loc[len(results)] = [
            'CV', self.fold, epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, None,
            None, epoch_time, self.model.count_params()
        ]

class SaveEveryEpochCallback(tf.keras.callbacks.Callback):
    """Save a ``.keras`` snapshot of the model at the end of every epoch.

    Snapshots go to ``<base_dir>/<stage>/fold_<fold>/`` when a fold is given
    and to ``<base_dir>/<stage>/`` otherwise; the directory is created when
    the callback is constructed.

    Args:
        base_dir: Root directory for the snapshots.
        stage: Label used as sub-directory and file-name prefix.
        fold: Optional fold identifier; ``None`` omits the fold level.
    """

    def __init__(self, base_dir, stage='CV', fold=None):
        super().__init__()
        self.base_dir = base_dir
        self.stage = stage
        self.fold = fold
        # An empty last component keeps the path at <base_dir>/<stage>/.
        if fold is None:
            fold_part = ''
        else:
            fold_part = f'fold_{fold}'
        self.sub_dir = os.path.join(base_dir, stage, fold_part)
        os.makedirs(self.sub_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Save the current model under a 1-based, zero-padded epoch number."""
        if self.fold is None:
            filename = f"{self.stage}_epoch{epoch+1:02d}.keras"
        else:
            filename = f"{self.stage}_fold{self.fold}_epoch{epoch+1:02d}.keras"
        path = os.path.join(self.sub_dir, filename)
        self.model.save(path)
        print(f"[Saved model] Epoch {epoch+1} saved to {path}")

# === Callback on independent test ===
class FinalMetricsCallback(tf.keras.callbacks.Callback):
    """Append per-epoch independent-test metrics to the global ``results`` frame.

    At the end of every epoch the model predicts on the stored test data,
    predictions are thresholded at 0.5, and the confusion-matrix counts,
    sensitivity, specificity, accuracy, MCC, F1 and ROC AUC are written as a
    new ``'Independent'`` / ``'Final'`` row. ``Train_Time`` is stored as ``None``.

    Args:
        X_test: Test inputs passed to ``model.predict``.
        y_test: Test labels (expected to be binary 0/1 -- TODO confirm).
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        self.epoch_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Restart the wall-clock timer for this epoch."""
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Score the test data and append one row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_test, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] pins the matrix to 2x2 even when only one class shows up
        # in labels and predictions, so the counts always unpack to the right
        # cells (the old 1x1 fallback booked all-positive data as FN, not TP).
        cm = metrics.confusion_matrix(self.y_test, y_pred_labels, labels=[0, 1])
        # Python ints keep the four-factor MCC denominator free of int64 overflow.
        TN, FP, FN, TP = (int(v) for v in cm.ravel())

        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        # Threshold-free score computed from the raw probabilities.
        fpr, tpr, _ = roc_curve(self.y_test, y_pred_probs)
        roc_auc = auc(fpr, tpr)

        # Measured after predict(), so the value includes this scoring pass.
        epoch_time = time.time() - self.epoch_start_time

        results.loc[len(results)] = [
            'Independent', 'Final', epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, roc_auc,
            None, epoch_time, self.model.count_params()
        ]

# === Results table schema ===
# Empty frame that the metric callbacks append to, one row per epoch.
results_columns = (
    ['Stage', 'Fold', 'Epoch']                          # row identity
    + ['TP', 'FP', 'TN', 'FN']                          # confusion-matrix counts
    + ['Sens', 'Spec', 'Acc', 'MCC', 'F1', 'AUC']       # derived scores
    + ['Train_Time', 'Epoch_Time', 'Total_Params']      # cost bookkeeping
)
results = pd.DataFrame(columns=results_columns)

# === 5-Fold Cross-Validation + ROC per fold ===
# For every stratified fold: build and compile a fresh DeepScan model, train it
# with the metric / TensorBoard / early-stopping / snapshot / checkpoint
# callbacks, then store the fold's wall-clock training time and validation
# ROC AUC on that fold's rows of `results`.
N_SPLITS = 5
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train, y_train), start=1):
    print(f"\n=== Fold {fold}/{N_SPLITS} ===")
    X_train_fold, X_val = X_train[train_idx], X_train[val_idx]
    y_train_fold, y_val = y_train[train_idx], y_train[val_idx]

    # Models are created in a loop; reset Keras' global state so it does not
    # keep growing from one fold to the next.
    tf.keras.backend.clear_session()

    # DeepScan returns a functional model, which is built on construction,
    # so no explicit build() call is needed before summary()/compile().
    model = DeepScan(window_sizes=WINDOW_SIZES,
                     num_filters=NUM_FILTERS,
                     num_hidden=NUM_HIDDEN)
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    metrics_cb = MetricsCallback(X_val, y_val, fold)
    # Keep TensorBoard logs under LOG_MODEL so runs with different
    # hyperparameters do not write event files into the same fold directory.
    tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_MODEL, 'tensorboard', f'fold_{fold}'))
    early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
    save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='CV', fold=fold)
    checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_CV_{fold}.keras')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')

    start_fold = time.time()
    model.fit(
        X_train_fold, y_train_fold,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[metrics_cb, tb_cb, early_stop_cb, save_cb, checkpoint_cb],
        verbose=1,
    )
    train_time = time.time() - start_fold

    # Rows the metrics callback appended for this fold.
    fold_rows = (results.Stage=='CV') & (results.Fold==fold)
    results.loc[fold_rows, 'Train_Time'] = train_time

    # One AUC per fold, computed with the weights the model holds after fit();
    # the same value is written to every epoch row of the fold.
    y_val_probs = model.predict(X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
    fpr, tpr, _ = roc_curve(y_val, y_val_probs)
    roc_auc = auc(fpr, tpr)
    results.loc[fold_rows, 'AUC'] = roc_auc

# Save CV results
results_path = os.path.join(
    LOG_DIR,
    f'training_results_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H.csv'
)
results.to_csv(results_path, index=False)
print(f"\nCV results saved to {results_path}")

# === Final training and evaluation ===
# Train one model on the full training set, score it on the test set after
# every epoch through FinalMetricsCallback, and report the best test AUC.
#
# NOTE(review): the test set is passed as validation_data, and both
# EarlyStopping and ModelCheckpoint monitor val_accuracy, so the stopping
# epoch and the saved "best" model are selected on the test set itself.
# Metrics reported on it are therefore optimistic -- consider a separate
# validation split for model selection.
final_cb = FinalMetricsCallback(X_test, y_test)

# DeepScan returns a functional model, which is built on construction,
# so no explicit build() call is needed before compile().
final_model = DeepScan(window_sizes=WINDOW_SIZES,
                      num_filters=NUM_FILTERS,
                      num_hidden=NUM_HIDDEN)
final_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='Independent')
final_checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_Independent.keras')
final_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(final_checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')
final_early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
# Keep TensorBoard logs under LOG_MODEL so runs with different
# hyperparameters do not write event files into the same directory.
final_tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_MODEL, 'tensorboard', 'final_model'))

history = final_model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[final_cb, final_tb_cb, final_checkpoint_cb, final_early_stop_cb, save_cb],
    verbose=1
)

# Save final results
results.to_csv(results_path, index=False)
print(f"\nFinal results saved to {results_path}")

# `results` starts without dtypes and mixes None with floats in 'AUC', so the
# column may be object-typed; coerce to numeric before idxmax, which pandas
# can reject on object data.
independent_auc = pd.to_numeric(
    results.loc[results['Stage'] == 'Independent', 'AUC'],
    errors='coerce'
)
best_auc_idx = independent_auc.idxmax()
best_auc_row = results.loc[best_auc_idx]
print(f"\nBest Independent Test AUC: {best_auc_row['AUC']:.4f} (Epoch {best_auc_row['Epoch']})")



=== Fold 1/5 ===


Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 359ms/step - accuracy: 0.5200 - auc: 0.5485 - loss: 0.7482[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_1/CV_fold1_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 684ms/step - accuracy: 0.5248 - auc: 0.5535 - loss: 0.7456 - val_accuracy: 0.7918 - val_auc: 0.9016 - val_loss: 0.6191
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 353ms/step - accuracy: 0.6731 - auc: 0.7407 - loss: 0.6485[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_1/CV_fold1_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 522ms/step - accuracy: 0.6739 - auc: 0.7422 - loss: 0.6469 - val_accuracy: 0.8082 - val_auc: 0.9355 - val_loss: 0.5467
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 350ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 350ms/step - accuracy: 0.5250 - auc: 0.5329 - loss: 0.7717[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_2/CV_fold2_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 675ms/step - accuracy: 0.5313 - auc: 0.5409 - loss: 0.7669 - val_accuracy: 0.7004 - val_auc: 0.9176 - val_loss: 0.6241
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 356ms/step - accuracy: 0.6747 - auc: 0.7379 - loss: 0.6472[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_2/CV_fold2_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 525ms/step - accuracy: 0.6757 - auc: 0.7402 - loss: 0.6453 - val_accuracy: 0.7831 - val_auc: 0.9552 - val_loss: 0.5450
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 351ms/step - accuracy: 0.5258 - auc: 0.5388 - loss: 0.7995[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_3/CV_fold3_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 673ms/step - accuracy: 0.5284 - auc: 0.5415 - loss: 0.7955 - val_accuracy: 0.6116 - val_auc: 0.8888 - val_loss: 0.6643
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0.6562 - auc: 0.7203 - loss: 0.6605[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_3/CV_fold3_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 524ms/step - accuracy: 0.6581 - auc: 0.7231 - loss: 0.6586 - val_accuracy: 0.8512 - val_auc: 0.9438 - val_loss: 0.5776
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 350ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 358ms/step - accuracy: 0.5271 - auc: 0.5306 - loss: 0.7660[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_4/CV_fold4_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 679ms/step - accuracy: 0.5312 - auc: 0.5366 - loss: 0.7625 - val_accuracy: 0.7149 - val_auc: 0.9448 - val_loss: 0.6164
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 352ms/step - accuracy: 0.6602 - auc: 0.7250 - loss: 0.6556[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_4/CV_fold4_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 518ms/step - accuracy: 0.6620 - auc: 0.7271 - loss: 0.6540 - val_accuracy: 0.6860 - val_auc: 0.9622 - val_loss: 0.5623
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 356ms/step - accuracy: 0.5499 - auc: 0.5634 - loss: 0.7392[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_5/CV_fold5_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 678ms/step - accuracy: 0.5519 - auc: 0.5674 - loss: 0.7376 - val_accuracy: 0.5640 - val_auc: 0.8680 - val_loss: 0.6736
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 353ms/step - accuracy: 0.6582 - auc: 0.7343 - loss: 0.6513[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[24]_256F_1024H/CV/fold_5/CV_fold5_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 522ms/step - accuracy: 0.6610 - auc: 0.7364 - loss: 0.6492 - val_accuracy: 0.7872 - val_auc: 0.9147 - val_loss: 0.5874
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0

In [None]:
import os
import time
import math

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.saving import register_keras_serializable

# === Hyperparameters ===
# BATCH_SIZE      - batch size passed to fit() and predict()
# NUM_CLASSES     - units in the sigmoid output layer
# EPOCHS          - epochs argument passed to fit()
# NUM_FILTERS     - filters per SeparableConv2D branch
# NUM_HIDDEN      - units in the 'fc1' dense layer
# WINDOW_SIZES    - convolution kernel widths; DeepScan builds one branch per entry
# MAX_SEQ_LENGTH  - size of the second input axis (presumably residues per sequence; TODO confirm)
# EMBEDDING_WIDTH - size of the last input axis (presumably the ESM-2 embedding size; TODO confirm)
BATCH_SIZE      = 128
NUM_CLASSES     = 1
EPOCHS          = 50
NUM_FILTERS     = 256
NUM_HIDDEN      = 1024
WINDOW_SIZES    = [20]
MAX_SEQ_LENGTH  = 1022
EMBEDDING_WIDTH = 1280
# Per-configuration output directory. The list repr of WINDOW_SIZES is embedded
# as-is, so the folder name contains brackets, e.g. "MODELS_[20]_256F_1024H".
LOG_MODEL = os.path.join(LOG_DIR, f'MODELS_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H')
os.makedirs(LOG_MODEL, exist_ok=True)

@register_keras_serializable()
def DeepScan(input_shape=(1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH),
             window_sizes=WINDOW_SIZES,
             num_filters=NUM_FILTERS,
             num_hidden=NUM_HIDDEN,
             num_classes=NUM_CLASSES):
    """Build the multi-window separable-CNN classifier as a Keras functional model.

    Every entry of ``window_sizes`` gets its own branch: a ``SeparableConv2D``
    with a ``(1, ws)`` kernel, a max-pool that spans every position the
    convolution leaves on that axis, and a flatten. The branch outputs are
    concatenated, passed through dropout (0.5) and the ``fc1`` dense layer,
    and finished with a sigmoid output layer.

    Args:
        input_shape: Per-sample input shape (no batch axis). ``input_shape[1]``
            is the axis the kernels slide along (assuming Keras' default
            channels-last layout -- TODO confirm it is not overridden).
        window_sizes: Iterable of kernel widths, one branch per entry.
        num_filters: Number of filters in each convolution branch.
        num_hidden: Number of units in the ``fc1`` dense layer.
        num_classes: Number of units in the sigmoid output layer.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'DeepScan'``.

    Raises:
        ValueError: If a window size is smaller than 1 or larger than
            ``input_shape[1]``.
    """
    inputs = tf.keras.Input(shape=input_shape)
    # Take the sequence length from the requested input shape rather than the
    # MAX_SEQ_LENGTH global, so a non-default input_shape gets a matching pool.
    seq_length = input_shape[1]

    branches = []
    for ws in window_sizes:
        if not 1 <= ws <= seq_length:
            raise ValueError(
                f"window size {ws} must be between 1 and the sequence length {seq_length}"
            )

        x = layers.SeparableConv2D(
            filters=num_filters,
            kernel_size=(1, ws),
            strides=(1, 1),
            activation='relu',
            padding='valid',
            depthwise_regularizer=tf.keras.regularizers.l2(1e-4),
            pointwise_regularizer=tf.keras.regularizers.l2(1e-4),
            depthwise_initializer='glorot_uniform',
            pointwise_initializer='glorot_uniform'
        )(inputs)

        # A 'valid' convolution leaves seq_length - ws + 1 positions; pooling
        # over all of them reduces each filter to a single value.
        x = layers.MaxPooling2D(
            pool_size=(1, seq_length - ws + 1),
            strides=(1, 1),
            padding='valid'
        )(x)

        x = layers.Flatten()(x)
        branches.append(x)

    x = layers.Concatenate()(branches)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_hidden, activation='relu', name='fc1')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='DeepScan')
    return model

# === Callback on CV folds ===
class MetricsCallback(tf.keras.callbacks.Callback):
    """Append per-epoch validation metrics for one CV fold to the global ``results`` frame.

    At the end of every epoch the model predicts on the stored validation
    data, predictions are thresholded at 0.5, and the confusion-matrix counts
    plus sensitivity, specificity, accuracy, MCC and F1 are written as a new
    ``'CV'`` row. ``AUC`` and ``Train_Time`` are stored as ``None`` here.

    Args:
        X_val: Validation inputs passed to ``model.predict``.
        y_val: Validation labels (expected to be binary 0/1 -- TODO confirm).
        fold: Fold identifier written to the ``Fold`` column.
    """

    def __init__(self, X_val, y_val, fold):
        super().__init__()
        self.X_val = X_val
        self.y_val = y_val
        self.fold  = fold
        self.fold_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Start the wall-clock timer for this epoch."""
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Score the validation data and append one row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)
        # labels=[0, 1] pins the matrix to 2x2 even when only one class shows up
        # in labels and predictions, so the counts always unpack to the right
        # cells (the old 1x1 fallback booked all-positive data as FN, not TP).
        cm = metrics.confusion_matrix(self.y_val, y_pred_labels, labels=[0, 1])
        # Python ints keep the four-factor MCC denominator free of int64 overflow.
        TN, FP, FN, TP = (int(v) for v in cm.ravel())
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0
        # Measured after predict(), so the value includes this scoring pass.
        epoch_time = time.time() - self.epoch_start
        results.loc[len(results)] = [
            'CV', self.fold, epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, None,
            None, epoch_time, self.model.count_params()
        ]

class SaveEveryEpochCallback(tf.keras.callbacks.Callback):
    """Save a ``.keras`` snapshot of the model at the end of every epoch.

    Snapshots go to ``<base_dir>/<stage>/fold_<fold>/`` when a fold is given
    and to ``<base_dir>/<stage>/`` otherwise; the directory is created when
    the callback is constructed.

    Args:
        base_dir: Root directory for the snapshots.
        stage: Label used as sub-directory and file-name prefix.
        fold: Optional fold identifier; ``None`` omits the fold level.
    """

    def __init__(self, base_dir, stage='CV', fold=None):
        super().__init__()
        self.base_dir = base_dir
        self.stage = stage
        self.fold = fold
        # An empty last component keeps the path at <base_dir>/<stage>/.
        if fold is None:
            fold_part = ''
        else:
            fold_part = f'fold_{fold}'
        self.sub_dir = os.path.join(base_dir, stage, fold_part)
        os.makedirs(self.sub_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Save the current model under a 1-based, zero-padded epoch number."""
        if self.fold is None:
            filename = f"{self.stage}_epoch{epoch+1:02d}.keras"
        else:
            filename = f"{self.stage}_fold{self.fold}_epoch{epoch+1:02d}.keras"
        path = os.path.join(self.sub_dir, filename)
        self.model.save(path)
        print(f"[Saved model] Epoch {epoch+1} saved to {path}")

# === Callback on independent test ===
class FinalMetricsCallback(tf.keras.callbacks.Callback):
    """Append per-epoch independent-test metrics to the global ``results`` frame.

    At the end of every epoch the model predicts on the stored test data,
    predictions are thresholded at 0.5, and the confusion-matrix counts,
    sensitivity, specificity, accuracy, MCC, F1 and ROC AUC are written as a
    new ``'Independent'`` / ``'Final'`` row. ``Train_Time`` is stored as ``None``.

    Args:
        X_test: Test inputs passed to ``model.predict``.
        y_test: Test labels (expected to be binary 0/1 -- TODO confirm).
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        self.epoch_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Restart the wall-clock timer for this epoch."""
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Score the test data and append one row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_test, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] pins the matrix to 2x2 even when only one class shows up
        # in labels and predictions, so the counts always unpack to the right
        # cells (the old 1x1 fallback booked all-positive data as FN, not TP).
        cm = metrics.confusion_matrix(self.y_test, y_pred_labels, labels=[0, 1])
        # Python ints keep the four-factor MCC denominator free of int64 overflow.
        TN, FP, FN, TP = (int(v) for v in cm.ravel())

        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        # Threshold-free score computed from the raw probabilities.
        fpr, tpr, _ = roc_curve(self.y_test, y_pred_probs)
        roc_auc = auc(fpr, tpr)

        # Measured after predict(), so the value includes this scoring pass.
        epoch_time = time.time() - self.epoch_start_time

        results.loc[len(results)] = [
            'Independent', 'Final', epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, roc_auc,
            None, epoch_time, self.model.count_params()
        ]

# === Results table schema ===
# Empty frame that the metric callbacks append to, one row per epoch.
results_columns = (
    ['Stage', 'Fold', 'Epoch']                          # row identity
    + ['TP', 'FP', 'TN', 'FN']                          # confusion-matrix counts
    + ['Sens', 'Spec', 'Acc', 'MCC', 'F1', 'AUC']       # derived scores
    + ['Train_Time', 'Epoch_Time', 'Total_Params']      # cost bookkeeping
)
results = pd.DataFrame(columns=results_columns)

# === 5-Fold Cross-Validation + ROC per fold ===
# For every stratified fold: build and compile a fresh DeepScan model, train it
# with the metric / TensorBoard / early-stopping / snapshot / checkpoint
# callbacks, then store the fold's wall-clock training time and validation
# ROC AUC on that fold's rows of `results`.
N_SPLITS = 5
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train, y_train), start=1):
    print(f"\n=== Fold {fold}/{N_SPLITS} ===")
    X_train_fold, X_val = X_train[train_idx], X_train[val_idx]
    y_train_fold, y_val = y_train[train_idx], y_train[val_idx]

    # Models are created in a loop; reset Keras' global state so it does not
    # keep growing from one fold to the next.
    tf.keras.backend.clear_session()

    # DeepScan returns a functional model, which is built on construction,
    # so no explicit build() call is needed before summary()/compile().
    model = DeepScan(window_sizes=WINDOW_SIZES,
                     num_filters=NUM_FILTERS,
                     num_hidden=NUM_HIDDEN)
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    metrics_cb = MetricsCallback(X_val, y_val, fold)
    # Keep TensorBoard logs under LOG_MODEL so runs with different
    # hyperparameters do not write event files into the same fold directory.
    tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_MODEL, 'tensorboard', f'fold_{fold}'))
    early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
    save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='CV', fold=fold)
    checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_CV_{fold}.keras')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')

    start_fold = time.time()
    model.fit(
        X_train_fold, y_train_fold,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[metrics_cb, tb_cb, early_stop_cb, save_cb, checkpoint_cb],
        verbose=1,
    )
    train_time = time.time() - start_fold

    # Rows the metrics callback appended for this fold.
    fold_rows = (results.Stage=='CV') & (results.Fold==fold)
    results.loc[fold_rows, 'Train_Time'] = train_time

    # One AUC per fold, computed with the weights the model holds after fit();
    # the same value is written to every epoch row of the fold.
    y_val_probs = model.predict(X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
    fpr, tpr, _ = roc_curve(y_val, y_val_probs)
    roc_auc = auc(fpr, tpr)
    results.loc[fold_rows, 'AUC'] = roc_auc

# Save CV results
results_path = os.path.join(
    LOG_DIR,
    f'training_results_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H.csv'
)
results.to_csv(results_path, index=False)
print(f"\nCV results saved to {results_path}")

# === Final training and evaluation ===
# Train one model on the full training set, score it on the test set after
# every epoch through FinalMetricsCallback, and report the best test AUC.
#
# NOTE(review): the test set is passed as validation_data, and both
# EarlyStopping and ModelCheckpoint monitor val_accuracy, so the stopping
# epoch and the saved "best" model are selected on the test set itself.
# Metrics reported on it are therefore optimistic -- consider a separate
# validation split for model selection.
final_cb = FinalMetricsCallback(X_test, y_test)

# DeepScan returns a functional model, which is built on construction,
# so no explicit build() call is needed before compile().
final_model = DeepScan(window_sizes=WINDOW_SIZES,
                      num_filters=NUM_FILTERS,
                      num_hidden=NUM_HIDDEN)
final_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='Independent')
final_checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_Independent.keras')
final_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(final_checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')
final_early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
# Keep TensorBoard logs under LOG_MODEL so runs with different
# hyperparameters do not write event files into the same directory.
final_tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_MODEL, 'tensorboard', 'final_model'))

history = final_model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[final_cb, final_tb_cb, final_checkpoint_cb, final_early_stop_cb, save_cb],
    verbose=1
)

# Save final results
results.to_csv(results_path, index=False)
print(f"\nFinal results saved to {results_path}")

# `results` starts without dtypes and mixes None with floats in 'AUC', so the
# column may be object-typed; coerce to numeric before idxmax, which pandas
# can reject on object data.
independent_auc = pd.to_numeric(
    results.loc[results['Stage'] == 'Independent', 'AUC'],
    errors='coerce'
)
best_auc_idx = independent_auc.idxmax()
best_auc_row = results.loc[best_auc_idx]
print(f"\nBest Independent Test AUC: {best_auc_row['AUC']:.4f} (Epoch {best_auc_row['Epoch']})")



=== Fold 1/5 ===


Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 363ms/step - accuracy: 0.5455 - auc: 0.5675 - loss: 0.7440[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_1/CV_fold1_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 684ms/step - accuracy: 0.5492 - auc: 0.5719 - loss: 0.7414 - val_accuracy: 0.8062 - val_auc: 0.9360 - val_loss: 0.6117
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 358ms/step - accuracy: 0.6755 - auc: 0.7241 - loss: 0.6577[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_1/CV_fold1_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 556ms/step - accuracy: 0.6783 - auc: 0.7278 - loss: 0.6550 - val_accuracy: 0.8660 - val_auc: 0.9579 - val_loss: 0.5258
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 359ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.5437 - auc: 0.5508 - loss: 0.7533[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_2/CV_fold2_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 680ms/step - accuracy: 0.5466 - auc: 0.5557 - loss: 0.7506 - val_accuracy: 0.6715 - val_auc: 0.9141 - val_loss: 0.6467
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0.6616 - auc: 0.7162 - loss: 0.6649[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_2/CV_fold2_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 525ms/step - accuracy: 0.6625 - auc: 0.7179 - loss: 0.6630 - val_accuracy: 0.8450 - val_auc: 0.9472 - val_loss: 0.5714
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 353ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.5516 - auc: 0.5552 - loss: 0.7649[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_3/CV_fold3_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 682ms/step - accuracy: 0.5549 - auc: 0.5601 - loss: 0.7611 - val_accuracy: 0.6033 - val_auc: 0.9093 - val_loss: 0.6514
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 359ms/step - accuracy: 0.6374 - auc: 0.7042 - loss: 0.6672[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_3/CV_fold3_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 527ms/step - accuracy: 0.6399 - auc: 0.7066 - loss: 0.6656 - val_accuracy: 0.6178 - val_auc: 0.9374 - val_loss: 0.6124
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 362ms/step - accuracy: 0.5448 - auc: 0.5613 - loss: 0.7604[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_4/CV_fold4_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 682ms/step - accuracy: 0.5498 - auc: 0.5674 - loss: 0.7561 - val_accuracy: 0.8017 - val_auc: 0.9323 - val_loss: 0.6066
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.6546 - auc: 0.7253 - loss: 0.6527[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_4/CV_fold4_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 527ms/step - accuracy: 0.6564 - auc: 0.7281 - loss: 0.6508 - val_accuracy: 0.8967 - val_auc: 0.9591 - val_loss: 0.5281
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 354ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 353ms/step - accuracy: 0.5374 - auc: 0.5292 - loss: 0.8001[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_5/CV_fold5_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 677ms/step - accuracy: 0.5395 - auc: 0.5344 - loss: 0.7944 - val_accuracy: 0.6178 - val_auc: 0.8644 - val_loss: 0.6612
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 364ms/step - accuracy: 0.6527 - auc: 0.7109 - loss: 0.6645[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[20]_256F_1024H/CV/fold_5/CV_fold5_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 530ms/step - accuracy: 0.6557 - auc: 0.7151 - loss: 0.6621 - val_accuracy: 0.8120 - val_auc: 0.8986 - val_loss: 0.5954
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 353ms/step - accuracy: 0

In [None]:
import os
import time
import math

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.saving import register_keras_serializable

# === Hyperparameters ===
# BATCH_SIZE      - batch size passed to fit() and predict()
# NUM_CLASSES     - units in the sigmoid output layer
# EPOCHS          - epochs argument passed to fit()
# NUM_FILTERS     - filters per SeparableConv2D branch
# NUM_HIDDEN      - units in the 'fc1' dense layer
# WINDOW_SIZES    - convolution kernel widths; DeepScan builds one branch per entry
# MAX_SEQ_LENGTH  - size of the second input axis (presumably residues per sequence; TODO confirm)
# EMBEDDING_WIDTH - size of the last input axis (presumably the ESM-2 embedding size; TODO confirm)
BATCH_SIZE      = 128
NUM_CLASSES     = 1
EPOCHS          = 50
NUM_FILTERS     = 256
NUM_HIDDEN      = 1024
WINDOW_SIZES    = [16]
MAX_SEQ_LENGTH  = 1022
EMBEDDING_WIDTH = 1280
# Per-configuration output directory. The list repr of WINDOW_SIZES is embedded
# as-is, so the folder name contains brackets, e.g. "MODELS_[16]_256F_1024H".
LOG_MODEL = os.path.join(LOG_DIR, f'MODELS_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H')
os.makedirs(LOG_MODEL, exist_ok=True)

@register_keras_serializable()
def DeepScan(input_shape=(1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH),
             window_sizes=WINDOW_SIZES,
             num_filters=NUM_FILTERS,
             num_hidden=NUM_HIDDEN,
             num_classes=NUM_CLASSES):
    """Build the multi-window separable-CNN classifier as a Keras functional model.

    Every entry of ``window_sizes`` gets its own branch: a ``SeparableConv2D``
    with a ``(1, ws)`` kernel, a max-pool that spans every position the
    convolution leaves on that axis, and a flatten. The branch outputs are
    concatenated, passed through dropout (0.5) and the ``fc1`` dense layer,
    and finished with a sigmoid output layer.

    Args:
        input_shape: Per-sample input shape (no batch axis). ``input_shape[1]``
            is the axis the kernels slide along (assuming Keras' default
            channels-last layout -- TODO confirm it is not overridden).
        window_sizes: Iterable of kernel widths, one branch per entry.
        num_filters: Number of filters in each convolution branch.
        num_hidden: Number of units in the ``fc1`` dense layer.
        num_classes: Number of units in the sigmoid output layer.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'DeepScan'``.

    Raises:
        ValueError: If a window size is smaller than 1 or larger than
            ``input_shape[1]``.
    """
    inputs = tf.keras.Input(shape=input_shape)
    # Take the sequence length from the requested input shape rather than the
    # MAX_SEQ_LENGTH global, so a non-default input_shape gets a matching pool.
    seq_length = input_shape[1]

    branches = []
    for ws in window_sizes:
        if not 1 <= ws <= seq_length:
            raise ValueError(
                f"window size {ws} must be between 1 and the sequence length {seq_length}"
            )

        x = layers.SeparableConv2D(
            filters=num_filters,
            kernel_size=(1, ws),
            strides=(1, 1),
            activation='relu',
            padding='valid',
            depthwise_regularizer=tf.keras.regularizers.l2(1e-4),
            pointwise_regularizer=tf.keras.regularizers.l2(1e-4),
            depthwise_initializer='glorot_uniform',
            pointwise_initializer='glorot_uniform'
        )(inputs)

        # A 'valid' convolution leaves seq_length - ws + 1 positions; pooling
        # over all of them reduces each filter to a single value.
        x = layers.MaxPooling2D(
            pool_size=(1, seq_length - ws + 1),
            strides=(1, 1),
            padding='valid'
        )(x)

        x = layers.Flatten()(x)
        branches.append(x)

    x = layers.Concatenate()(branches)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_hidden, activation='relu', name='fc1')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='DeepScan')
    return model

# === Callback on CV folds ===
class MetricsCallback(tf.keras.callbacks.Callback):
    """Append per-epoch validation metrics for one CV fold to the global ``results`` frame.

    At the end of every epoch the model predicts on the stored validation
    data, predictions are thresholded at 0.5, and the confusion-matrix counts
    plus sensitivity, specificity, accuracy, MCC and F1 are written as a new
    ``'CV'`` row. ``AUC`` and ``Train_Time`` are stored as ``None`` here.

    Args:
        X_val: Validation inputs passed to ``model.predict``.
        y_val: Validation labels (expected to be binary 0/1 -- TODO confirm).
        fold: Fold identifier written to the ``Fold`` column.
    """

    def __init__(self, X_val, y_val, fold):
        super().__init__()
        self.X_val = X_val
        self.y_val = y_val
        self.fold  = fold
        self.fold_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Start the wall-clock timer for this epoch."""
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Score the validation data and append one row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)
        # labels=[0, 1] pins the matrix to 2x2 even when only one class shows up
        # in labels and predictions, so the counts always unpack to the right
        # cells (the old 1x1 fallback booked all-positive data as FN, not TP).
        cm = metrics.confusion_matrix(self.y_val, y_pred_labels, labels=[0, 1])
        # Python ints keep the four-factor MCC denominator free of int64 overflow.
        TN, FP, FN, TP = (int(v) for v in cm.ravel())
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0
        # Measured after predict(), so the value includes this scoring pass.
        epoch_time = time.time() - self.epoch_start
        results.loc[len(results)] = [
            'CV', self.fold, epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, None,
            None, epoch_time, self.model.count_params()
        ]

class SaveEveryEpochCallback(tf.keras.callbacks.Callback):
    """Write the full model to disk at the end of every epoch.

    Files go to ``<base_dir>/<stage>/fold_<fold>/`` when ``fold`` is given and
    to ``<base_dir>/<stage>/`` otherwise; the directory is created eagerly in
    the constructor.

    Args:
        base_dir: Root directory for the saved models.
        stage: Label used as sub-directory and file-name prefix.
        fold: Optional fold identifier; ``None`` omits the fold component.
    """

    def __init__(self, base_dir, stage='CV', fold=None):
        super().__init__()
        self.base_dir = base_dir
        self.stage = stage
        self.fold = fold
        fold_component = '' if fold is None else f'fold_{fold}'
        self.sub_dir = os.path.join(base_dir, stage, fold_component)
        os.makedirs(self.sub_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Save the current model under a zero-padded, 1-based epoch name."""
        if self.fold is None:
            filename = f"{self.stage}_epoch{epoch+1:02d}.keras"
        else:
            filename = f"{self.stage}_fold{self.fold}_epoch{epoch+1:02d}.keras"
        path = os.path.join(self.sub_dir, filename)
        self.model.save(path)
        print(f"[Saved model] Epoch {epoch+1} saved to {path}")

# === Callback on independent test ===
class FinalMetricsCallback(tf.keras.callbacks.Callback):
    """Record per-epoch metrics on the independent test set.

    At the end of every epoch the callback predicts on the test data,
    thresholds the model outputs at 0.5, computes the ROC AUC from the raw
    outputs and appends one row (stage ``'Independent'``, fold ``'Final'``)
    to the module-level ``results`` DataFrame. ``Train_Time`` is written as
    ``None``.

    Args:
        X_test: Test inputs handed to ``model.predict``.
        y_test: Ground-truth labels for ``X_test`` (assumed to be encoded as
            0/1 -- TODO confirm against the data loader).
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        self.epoch_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Remember when the epoch started so its duration can be logged."""
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the test set and append a row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_test, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs,
        # so the four counts are always attributed to the right cell (the old
        # 1x1 fallback counted an all-positive, all-correct set as FN).
        cm = metrics.confusion_matrix(self.y_test, y_pred_labels, labels=[0, 1])
        # Python ints keep the MCC denominator product free of int64 overflow.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        fpr, tpr, _ = roc_curve(self.y_test, y_pred_probs)
        roc_auc = auc(fpr, tpr)

        # Measured after predict(), so it includes the evaluation above.
        epoch_time = time.time() - self.epoch_start_time

        results.loc[len(results)] = [
            'Independent', 'Final', epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, roc_auc,
            None, epoch_time, self.model.count_params()
        ]

# === Results table ===
# One row per evaluated epoch; MetricsCallback / FinalMetricsCallback append to it.
results_columns = [
    'Stage', 'Fold', 'Epoch', 'TP', 'FP', 'TN', 'FN',
    'Sens', 'Spec', 'Acc', 'MCC', 'F1', 'AUC',
    'Train_Time', 'Epoch_Time', 'Total_Params'
]
results = pd.DataFrame(columns=results_columns)

# === 5-Fold Cross-Validation + ROC per fold ===
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
n_splits = kf.get_n_splits()
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train, y_train), start=1):
    print(f"\n=== Fold {fold}/{n_splits} ===")
    # A fresh model is built for every fold; release the Keras global state
    # left behind by earlier models so memory does not pile up across folds.
    tf.keras.backend.clear_session()

    X_train_fold, X_val = X_train[train_idx], X_train[val_idx]
    y_train_fold, y_val = y_train[train_idx], y_train[val_idx]

    model = DeepScan(window_sizes=WINDOW_SIZES,
                     num_filters=NUM_FILTERS,
                     num_hidden=NUM_HIDDEN)
    model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    metrics_cb = MetricsCallback(X_val, y_val, fold)
    tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_DIR, f'fold_{fold}'))
    early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
    save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='CV', fold=fold)
    checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_CV_{fold}.keras')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')

    start_fold = time.time()
    model.fit(
        X_train_fold, y_train_fold,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[metrics_cb, tb_cb, early_stop_cb, save_cb, checkpoint_cb],
        verbose=1,
    )
    # Wall-clock time of the whole fit() call, copied onto every epoch row of this fold.
    train_time = time.time() - start_fold
    results.loc[(results.Stage=='CV') & (results.Fold==fold), 'Train_Time'] = train_time

    # A single AUC per fold, computed from the model as fit() left it, is
    # written to every epoch row of the fold (it is not a per-epoch value).
    y_val_probs = model.predict(X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
    fpr, tpr, _ = roc_curve(y_val, y_val_probs)
    roc_auc = auc(fpr, tpr)
    results.loc[(results.Stage=='CV') & (results.Fold==fold), 'AUC'] = roc_auc

# Save CV results
results_path = os.path.join(
    LOG_DIR,
    f'training_results_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H.csv'
)
results.to_csv(results_path, index=False)
print(f"\nCV results saved to {results_path}")

# === Final training and evaluation ===
# NOTE(review): the test split is passed as validation_data below, so
# EarlyStopping and ModelCheckpoint pick epochs using the test set, and the
# "best" AUC reported at the end is selected on that same set. Confirm this is
# intended; otherwise carve a validation split out of the training data.
final_cb = FinalMetricsCallback(X_test, y_test)

final_model = DeepScan(window_sizes=WINDOW_SIZES,
                      num_filters=NUM_FILTERS,
                      num_hidden=NUM_HIDDEN)
final_model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
final_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='Independent')
final_checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_Independent.keras')
final_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(final_checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')
final_early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
final_tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_DIR, 'final_model'))

history = final_model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[final_cb, final_tb_cb, final_checkpoint_cb, final_early_stop_cb, save_cb],
    verbose=1
)

# Save final results (overwrites the CV-only CSV with CV + independent rows)
results.to_csv(results_path, index=False)
print(f"\nFinal results saved to {results_path}")

# The results frame starts empty and its rows carry None placeholders, so the
# 'AUC' column can be object-dtype; idxmax() refuses object columns, hence the
# explicit numeric conversion before looking up the best epoch.
independent_auc = pd.to_numeric(results.loc[results['Stage'] == 'Independent', 'AUC'])
best_auc_idx = independent_auc.idxmax()
best_auc_row = results.loc[best_auc_idx]
print(f"\nBest Independent Test AUC: {best_auc_row['AUC']:.4f} (Epoch {best_auc_row['Epoch']})")



=== Fold 1/5 ===


Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.5523 - auc: 0.5708 - loss: 0.7553[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_1/CV_fold1_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 683ms/step - accuracy: 0.5554 - auc: 0.5740 - loss: 0.7527 - val_accuracy: 0.8371 - val_auc: 0.9307 - val_loss: 0.6361
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 362ms/step - accuracy: 0.6650 - auc: 0.7189 - loss: 0.6616[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_1/CV_fold1_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 528ms/step - accuracy: 0.6659 - auc: 0.7203 - loss: 0.6605 - val_accuracy: 0.8701 - val_auc: 0.9556 - val_loss: 0.5683
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 363ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.5503 - auc: 0.5562 - loss: 0.7548[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_2/CV_fold2_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 680ms/step - accuracy: 0.5525 - auc: 0.5607 - loss: 0.7520 - val_accuracy: 0.5826 - val_auc: 0.8678 - val_loss: 0.6691
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 359ms/step - accuracy: 0.6465 - auc: 0.7119 - loss: 0.6671[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_2/CV_fold2_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 528ms/step - accuracy: 0.6494 - auc: 0.7148 - loss: 0.6649 - val_accuracy: 0.6653 - val_auc: 0.9068 - val_loss: 0.6038
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 358ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 356ms/step - accuracy: 0.5337 - auc: 0.5392 - loss: 0.7660[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_3/CV_fold3_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 715ms/step - accuracy: 0.5355 - auc: 0.5419 - loss: 0.7638 - val_accuracy: 0.6260 - val_auc: 0.9012 - val_loss: 0.6483
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 356ms/step - accuracy: 0.6766 - auc: 0.7400 - loss: 0.6480[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_3/CV_fold3_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 525ms/step - accuracy: 0.6768 - auc: 0.7412 - loss: 0.6468 - val_accuracy: 0.7872 - val_auc: 0.9152 - val_loss: 0.5701
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 358ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 360ms/step - accuracy: 0.5771 - auc: 0.5863 - loss: 0.7325[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_4/CV_fold4_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 679ms/step - accuracy: 0.5759 - auc: 0.5861 - loss: 0.7324 - val_accuracy: 0.7128 - val_auc: 0.8990 - val_loss: 0.6415
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 357ms/step - accuracy: 0.6641 - auc: 0.7233 - loss: 0.6579[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_4/CV_fold4_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 528ms/step - accuracy: 0.6641 - auc: 0.7233 - loss: 0.6574 - val_accuracy: 0.8905 - val_auc: 0.9481 - val_loss: 0.5620
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 359ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 355ms/step - accuracy: 0.5200 - auc: 0.5351 - loss: 0.8052[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_5/CV_fold5_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 676ms/step - accuracy: 0.5238 - auc: 0.5382 - loss: 0.8008 - val_accuracy: 0.5971 - val_auc: 0.8742 - val_loss: 0.6772
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 358ms/step - accuracy: 0.6424 - auc: 0.7010 - loss: 0.6723[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[16]_256F_1024H/CV/fold_5/CV_fold5_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 525ms/step - accuracy: 0.6421 - auc: 0.7022 - loss: 0.6712 - val_accuracy: 0.7624 - val_auc: 0.9136 - val_loss: 0.6129
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 356ms/step - accuracy: 0

In [None]:
import os
import time
import math

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.saving import register_keras_serializable

# === Hyperparameters ===
BATCH_SIZE      = 128   # batch size used for both fit() and predict()
NUM_CLASSES     = 1     # units of the sigmoid output layer
EPOCHS          = 50    # upper bound on epochs per fit(); early stopping may end sooner
NUM_FILTERS     = 256   # filters of each SeparableConv2D branch
NUM_HIDDEN      = 1024  # units of the 'fc1' dense layer
WINDOW_SIZES    = [12]  # convolution kernel widths; one branch is built per entry
MAX_SEQ_LENGTH  = 1022  # second dimension of the model input shape
EMBEDDING_WIDTH = 1280  # last dimension of the model input shape
# Per-configuration directory for saved models; LOG_DIR is defined in an earlier cell.
LOG_MODEL = os.path.join(LOG_DIR, f'MODELS_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H')
os.makedirs(LOG_MODEL, exist_ok=True)

@register_keras_serializable()
def DeepScan(input_shape=(1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH),
             window_sizes=WINDOW_SIZES,
             num_filters=NUM_FILTERS,
             num_hidden=NUM_HIDDEN,
             num_classes=NUM_CLASSES):
    """Build the multi-window separable-convolution classifier.

    One branch is created per entry of ``window_sizes``: a ``SeparableConv2D``
    with kernel ``(1, ws)`` followed by a max-pool spanning the whole
    convolution output along the sequence axis. The flattened branches are
    concatenated, passed through dropout and a hidden dense layer, and mapped
    to ``num_classes`` sigmoid outputs.

    Args:
        input_shape: Per-sample input shape; ``input_shape[1]`` is used as the
            sequence length when sizing the pooling window.
        window_sizes: Iterable of convolution kernel widths.
        num_filters: Number of filters in every convolution branch.
        num_hidden: Units of the hidden dense layer ``fc1``.
        num_classes: Units of the sigmoid output layer.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'DeepScan'``.

    Raises:
        ValueError: If a window size is not between 1 and the sequence length.
    """
    inputs = tf.keras.Input(shape=input_shape)
    # Take the sequence length from the argument rather than the module-level
    # constant so a non-default input_shape still gets a full-width pool.
    seq_length = input_shape[1]

    branches = []
    for ws in window_sizes:
        if not 1 <= ws <= seq_length:
            raise ValueError(
                f"window size {ws} must be between 1 and the sequence length {seq_length}"
            )

        x = layers.SeparableConv2D(
            filters=num_filters,
            kernel_size=(1, ws),
            strides=(1, 1),
            activation='relu',
            padding='valid',
            depthwise_regularizer=tf.keras.regularizers.l2(1e-4),
            pointwise_regularizer=tf.keras.regularizers.l2(1e-4),
            depthwise_initializer='glorot_uniform',
            pointwise_initializer='glorot_uniform'
        )(inputs)

        # 'valid' convolution leaves seq_length - ws + 1 positions; pooling over
        # all of them keeps a single maximum per filter.
        x = layers.MaxPooling2D(
            pool_size=(1, seq_length - ws + 1),
            strides=(1, 1),
            padding='valid'
        )(x)

        x = layers.Flatten()(x)
        branches.append(x)

    x = layers.Concatenate()(branches)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_hidden, activation='relu', name='fc1')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='DeepScan')
    return model

# === Callback on CV folds ===
class MetricsCallback(tf.keras.callbacks.Callback):
    """Record per-epoch validation metrics for one cross-validation fold.

    At the end of every epoch the callback predicts on the held-out fold,
    thresholds the model outputs at 0.5 and appends one row with stage
    ``'CV'`` to the module-level ``results`` DataFrame. The ``AUC`` and
    ``Train_Time`` cells are written as ``None`` placeholders.

    Args:
        X_val: Validation inputs handed to ``model.predict``.
        y_val: Ground-truth labels for ``X_val`` (assumed to be encoded as
            0/1 -- TODO confirm against the data loader).
        fold: Fold identifier stored in the ``Fold`` column.
    """

    def __init__(self, X_val, y_val, fold):
        super().__init__()
        self.X_val = X_val
        self.y_val = y_val
        self.fold  = fold
        # Stored at construction time; not read anywhere else in this class.
        self.fold_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Remember when the epoch started so its duration can be logged."""
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the validation fold and append a row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs,
        # so the four counts are always attributed to the right cell (the old
        # 1x1 fallback counted an all-positive, all-correct fold as FN).
        cm = metrics.confusion_matrix(self.y_val, y_pred_labels, labels=[0, 1])
        # Python ints keep the MCC denominator product free of int64 overflow.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        # Measured after predict(), so it includes the evaluation above.
        epoch_time = time.time() - self.epoch_start
        results.loc[len(results)] = [
            'CV', self.fold, epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, None,
            None, epoch_time, self.model.count_params()
        ]

class SaveEveryEpochCallback(tf.keras.callbacks.Callback):
    """Write the full model to disk at the end of every epoch.

    Files go to ``<base_dir>/<stage>/fold_<fold>/`` when ``fold`` is given and
    to ``<base_dir>/<stage>/`` otherwise; the directory is created eagerly in
    the constructor.

    Args:
        base_dir: Root directory for the saved models.
        stage: Label used as sub-directory and file-name prefix.
        fold: Optional fold identifier; ``None`` omits the fold component.
    """

    def __init__(self, base_dir, stage='CV', fold=None):
        super().__init__()
        self.base_dir = base_dir
        self.stage = stage
        self.fold = fold
        fold_component = '' if fold is None else f'fold_{fold}'
        self.sub_dir = os.path.join(base_dir, stage, fold_component)
        os.makedirs(self.sub_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Save the current model under a zero-padded, 1-based epoch name."""
        if self.fold is None:
            filename = f"{self.stage}_epoch{epoch+1:02d}.keras"
        else:
            filename = f"{self.stage}_fold{self.fold}_epoch{epoch+1:02d}.keras"
        path = os.path.join(self.sub_dir, filename)
        self.model.save(path)
        print(f"[Saved model] Epoch {epoch+1} saved to {path}")

# === Callback on independent test ===
class FinalMetricsCallback(tf.keras.callbacks.Callback):
    """Record per-epoch metrics on the independent test set.

    At the end of every epoch the callback predicts on the test data,
    thresholds the model outputs at 0.5, computes the ROC AUC from the raw
    outputs and appends one row (stage ``'Independent'``, fold ``'Final'``)
    to the module-level ``results`` DataFrame. ``Train_Time`` is written as
    ``None``.

    Args:
        X_test: Test inputs handed to ``model.predict``.
        y_test: Ground-truth labels for ``X_test`` (assumed to be encoded as
            0/1 -- TODO confirm against the data loader).
    """

    def __init__(self, X_test, y_test):
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        self.epoch_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Remember when the epoch started so its duration can be logged."""
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the test set and append a row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_test, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs,
        # so the four counts are always attributed to the right cell (the old
        # 1x1 fallback counted an all-positive, all-correct set as FN).
        cm = metrics.confusion_matrix(self.y_test, y_pred_labels, labels=[0, 1])
        # Python ints keep the MCC denominator product free of int64 overflow.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        fpr, tpr, _ = roc_curve(self.y_test, y_pred_probs)
        roc_auc = auc(fpr, tpr)

        # Measured after predict(), so it includes the evaluation above.
        epoch_time = time.time() - self.epoch_start_time

        results.loc[len(results)] = [
            'Independent', 'Final', epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, roc_auc,
            None, epoch_time, self.model.count_params()
        ]

# === Results table ===
# One row per evaluated epoch; MetricsCallback / FinalMetricsCallback append to it.
results_columns = [
    'Stage', 'Fold', 'Epoch', 'TP', 'FP', 'TN', 'FN',
    'Sens', 'Spec', 'Acc', 'MCC', 'F1', 'AUC',
    'Train_Time', 'Epoch_Time', 'Total_Params'
]
results = pd.DataFrame(columns=results_columns)

# === 5-Fold Cross-Validation + ROC per fold ===
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
n_splits = kf.get_n_splits()
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train, y_train), start=1):
    print(f"\n=== Fold {fold}/{n_splits} ===")
    # A fresh model is built for every fold; release the Keras global state
    # left behind by earlier models so memory does not pile up across folds.
    tf.keras.backend.clear_session()

    X_train_fold, X_val = X_train[train_idx], X_train[val_idx]
    y_train_fold, y_val = y_train[train_idx], y_train[val_idx]

    model = DeepScan(window_sizes=WINDOW_SIZES,
                     num_filters=NUM_FILTERS,
                     num_hidden=NUM_HIDDEN)
    model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    metrics_cb = MetricsCallback(X_val, y_val, fold)
    tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_DIR, f'fold_{fold}'))
    early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
    save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='CV', fold=fold)
    checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_CV_{fold}.keras')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')

    start_fold = time.time()
    model.fit(
        X_train_fold, y_train_fold,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[metrics_cb, tb_cb, early_stop_cb, save_cb, checkpoint_cb],
        verbose=1,
    )
    # Wall-clock time of the whole fit() call, copied onto every epoch row of this fold.
    train_time = time.time() - start_fold
    results.loc[(results.Stage=='CV') & (results.Fold==fold), 'Train_Time'] = train_time

    # A single AUC per fold, computed from the model as fit() left it, is
    # written to every epoch row of the fold (it is not a per-epoch value).
    y_val_probs = model.predict(X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
    fpr, tpr, _ = roc_curve(y_val, y_val_probs)
    roc_auc = auc(fpr, tpr)
    results.loc[(results.Stage=='CV') & (results.Fold==fold), 'AUC'] = roc_auc

# Save CV results
results_path = os.path.join(
    LOG_DIR,
    f'training_results_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H.csv'
)
results.to_csv(results_path, index=False)
print(f"\nCV results saved to {results_path}")

# === Final training and evaluation ===
# NOTE(review): the test split is passed as validation_data below, so
# EarlyStopping and ModelCheckpoint pick epochs using the test set, and the
# "best" AUC reported at the end is selected on that same set. Confirm this is
# intended; otherwise carve a validation split out of the training data.
final_cb = FinalMetricsCallback(X_test, y_test)

final_model = DeepScan(window_sizes=WINDOW_SIZES,
                      num_filters=NUM_FILTERS,
                      num_hidden=NUM_HIDDEN)
final_model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
final_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='Independent')
final_checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_Independent.keras')
final_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(final_checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')
final_early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
final_tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_DIR, 'final_model'))

history = final_model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[final_cb, final_tb_cb, final_checkpoint_cb, final_early_stop_cb, save_cb],
    verbose=1
)

# Save final results (overwrites the CV-only CSV with CV + independent rows)
results.to_csv(results_path, index=False)
print(f"\nFinal results saved to {results_path}")

# The results frame starts empty and its rows carry None placeholders, so the
# 'AUC' column can be object-dtype; idxmax() refuses object columns, hence the
# explicit numeric conversion before looking up the best epoch.
independent_auc = pd.to_numeric(results.loc[results['Stage'] == 'Independent', 'AUC'])
best_auc_idx = independent_auc.idxmax()
best_auc_row = results.loc[best_auc_idx]
print(f"\nBest Independent Test AUC: {best_auc_row['AUC']:.4f} (Epoch {best_auc_row['Epoch']})")



=== Fold 1/5 ===


Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 244ms/step - accuracy: 0.4908 - auc: 0.4930 - loss: 0.8373[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_1/CV_fold1_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 574ms/step - accuracy: 0.4949 - auc: 0.4976 - loss: 0.8315 - val_accuracy: 0.7505 - val_auc: 0.8540 - val_loss: 0.6469
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 243ms/step - accuracy: 0.6466 - auc: 0.7000 - loss: 0.6696[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_1/CV_fold1_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 419ms/step - accuracy: 0.6471 - auc: 0.7016 - loss: 0.6685 - val_accuracy: 0.8021 - val_auc: 0.8896 - val_loss: 0.5874
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 247ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 245ms/step - accuracy: 0.5148 - auc: 0.5321 - loss: 0.8811[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_2/CV_fold2_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 579ms/step - accuracy: 0.5174 - auc: 0.5338 - loss: 0.8726 - val_accuracy: 0.7355 - val_auc: 0.8551 - val_loss: 0.6671
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 242ms/step - accuracy: 0.6101 - auc: 0.6749 - loss: 0.6990[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_2/CV_fold2_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 417ms/step - accuracy: 0.6119 - auc: 0.6744 - loss: 0.6977 - val_accuracy: 0.6674 - val_auc: 0.9236 - val_loss: 0.6162
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 247ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 246ms/step - accuracy: 0.5103 - auc: 0.5265 - loss: 0.8595[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_3/CV_fold3_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 577ms/step - accuracy: 0.5146 - auc: 0.5304 - loss: 0.8505 - val_accuracy: 0.7934 - val_auc: 0.8842 - val_loss: 0.6560
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 250ms/step - accuracy: 0.6240 - auc: 0.6837 - loss: 0.6900[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_3/CV_fold3_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 424ms/step - accuracy: 0.6276 - auc: 0.6864 - loss: 0.6874 - val_accuracy: 0.7603 - val_auc: 0.9091 - val_loss: 0.5915
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 244ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 246ms/step - accuracy: 0.5612 - auc: 0.5597 - loss: 0.7542[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_4/CV_fold4_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 575ms/step - accuracy: 0.5610 - auc: 0.5620 - loss: 0.7526 - val_accuracy: 0.8140 - val_auc: 0.9204 - val_loss: 0.6411
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 246ms/step - accuracy: 0.6343 - auc: 0.6825 - loss: 0.6798[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_4/CV_fold4_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 422ms/step - accuracy: 0.6354 - auc: 0.6846 - loss: 0.6787 - val_accuracy: 0.8140 - val_auc: 0.9444 - val_loss: 0.5822
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 246ms/step - accuracy: 0

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 249ms/step - accuracy: 0.5556 - auc: 0.5635 - loss: 0.7463[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_5/CV_fold5_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 574ms/step - accuracy: 0.5583 - auc: 0.5678 - loss: 0.7439 - val_accuracy: 0.6818 - val_auc: 0.8520 - val_loss: 0.6586
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 248ms/step - accuracy: 0.6488 - auc: 0.6942 - loss: 0.6746[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[12]_256F_1024H/CV/fold_5/CV_fold5_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 425ms/step - accuracy: 0.6502 - auc: 0.6962 - loss: 0.6731 - val_accuracy: 0.7231 - val_auc: 0.8951 - val_loss: 0.5999
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 249ms/step - accuracy: 0

In [None]:
import os
import time
import math

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.metrics import roc_curve, auc
from tensorflow.keras.saving import register_keras_serializable

# === Hyperparameters ===
BATCH_SIZE      = 128   # batch size used for both fit() and predict()
NUM_CLASSES     = 1     # units of the sigmoid output layer
EPOCHS          = 50    # upper bound on epochs per fit(); early stopping may end sooner
NUM_FILTERS     = 256   # filters of each SeparableConv2D branch
NUM_HIDDEN      = 1024  # units of the 'fc1' dense layer
WINDOW_SIZES    = [8]  # convolution kernel widths; one branch is built per entry
MAX_SEQ_LENGTH  = 1022  # second dimension of the model input shape
EMBEDDING_WIDTH = 1280  # last dimension of the model input shape
# Per-configuration directory for saved models; LOG_DIR is defined in an earlier cell.
LOG_MODEL = os.path.join(LOG_DIR, f'MODELS_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H')
os.makedirs(LOG_MODEL, exist_ok=True)

@register_keras_serializable()
def DeepScan(input_shape=(1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH),
             window_sizes=WINDOW_SIZES,
             num_filters=NUM_FILTERS,
             num_hidden=NUM_HIDDEN,
             num_classes=NUM_CLASSES):
    """Build the multi-window separable-convolution classifier.

    One branch is created per entry of ``window_sizes``: a ``SeparableConv2D``
    with kernel ``(1, ws)`` followed by a max-pool spanning the whole
    convolution output along the sequence axis. The flattened branches are
    concatenated, passed through dropout and a hidden dense layer, and mapped
    to ``num_classes`` sigmoid outputs.

    Args:
        input_shape: Per-sample input shape; ``input_shape[1]`` is used as the
            sequence length when sizing the pooling window.
        window_sizes: Iterable of convolution kernel widths.
        num_filters: Number of filters in every convolution branch.
        num_hidden: Units of the hidden dense layer ``fc1``.
        num_classes: Units of the sigmoid output layer.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'DeepScan'``.

    Raises:
        ValueError: If a window size is not between 1 and the sequence length.
    """
    inputs = tf.keras.Input(shape=input_shape)
    # Take the sequence length from the argument rather than the module-level
    # constant so a non-default input_shape still gets a full-width pool.
    seq_length = input_shape[1]

    branches = []
    for ws in window_sizes:
        if not 1 <= ws <= seq_length:
            raise ValueError(
                f"window size {ws} must be between 1 and the sequence length {seq_length}"
            )

        x = layers.SeparableConv2D(
            filters=num_filters,
            kernel_size=(1, ws),
            strides=(1, 1),
            activation='relu',
            padding='valid',
            depthwise_regularizer=tf.keras.regularizers.l2(1e-4),
            pointwise_regularizer=tf.keras.regularizers.l2(1e-4),
            depthwise_initializer='glorot_uniform',
            pointwise_initializer='glorot_uniform'
        )(inputs)

        # 'valid' convolution leaves seq_length - ws + 1 positions; pooling over
        # all of them keeps a single maximum per filter.
        x = layers.MaxPooling2D(
            pool_size=(1, seq_length - ws + 1),
            strides=(1, 1),
            padding='valid'
        )(x)

        x = layers.Flatten()(x)
        branches.append(x)

    x = layers.Concatenate()(branches)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(num_hidden, activation='relu', name='fc1')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs, name='DeepScan')
    return model

# === Callback on CV folds ===
class MetricsCallback(tf.keras.callbacks.Callback):
    """Record per-epoch validation metrics for one cross-validation fold.

    At the end of every epoch the callback predicts on the held-out fold,
    thresholds the model outputs at 0.5 and appends one row with stage
    ``'CV'`` to the module-level ``results`` DataFrame. The ``AUC`` and
    ``Train_Time`` cells are written as ``None`` placeholders.

    Args:
        X_val: Validation inputs handed to ``model.predict``.
        y_val: Ground-truth labels for ``X_val`` (assumed to be encoded as
            0/1 -- TODO confirm against the data loader).
        fold: Fold identifier stored in the ``Fold`` column.
    """

    def __init__(self, X_val, y_val, fold):
        super().__init__()
        self.X_val = X_val
        self.y_val = y_val
        self.fold  = fold
        # Stored at construction time; not read anywhere else in this class.
        self.fold_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Remember when the epoch started so its duration can be logged."""
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the validation fold and append a row to ``results``."""
        y_pred_probs  = self.model.predict(self.X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs,
        # so the four counts are always attributed to the right cell (the old
        # 1x1 fallback counted an all-positive, all-correct fold as FN).
        cm = metrics.confusion_matrix(self.y_val, y_pred_labels, labels=[0, 1])
        # Python ints keep the MCC denominator product free of int64 overflow.
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        # Measured after predict(), so it includes the evaluation above.
        epoch_time = time.time() - self.epoch_start
        results.loc[len(results)] = [
            'CV', self.fold, epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, None,
            None, epoch_time, self.model.count_params()
        ]

class SaveEveryEpochCallback(tf.keras.callbacks.Callback):
    """Write the full model to disk at the end of every epoch.

    Files go to ``<base_dir>/<stage>/fold_<fold>/`` when ``fold`` is given and
    to ``<base_dir>/<stage>/`` otherwise; the directory is created eagerly in
    the constructor.

    Args:
        base_dir: Root directory for the saved models.
        stage: Label used as sub-directory and file-name prefix.
        fold: Optional fold identifier; ``None`` omits the fold component.
    """

    def __init__(self, base_dir, stage='CV', fold=None):
        super().__init__()
        self.base_dir = base_dir
        self.stage = stage
        self.fold = fold
        fold_component = '' if fold is None else f'fold_{fold}'
        self.sub_dir = os.path.join(base_dir, stage, fold_component)
        os.makedirs(self.sub_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Save the current model under a zero-padded, 1-based epoch name."""
        if self.fold is None:
            filename = f"{self.stage}_epoch{epoch+1:02d}.keras"
        else:
            filename = f"{self.stage}_fold{self.fold}_epoch{epoch+1:02d}.keras"
        path = os.path.join(self.sub_dir, filename)
        self.model.save(path)
        print(f"[Saved model] Epoch {epoch+1} saved to {path}")

# === Callback on independent test ===
class FinalMetricsCallback(tf.keras.callbacks.Callback):
    """Keras callback that scores the model on a fixed test set after every epoch.

    Each epoch appends one row with Stage='Independent' and Fold='Final' to the
    module-level ``results`` DataFrame, including the ROC AUC of the predicted
    probabilities.
    """

    def __init__(self, X_test, y_test):
        """Keep references to the evaluation data.

        Args:
            X_test: Inputs passed to ``self.model.predict`` each epoch.
            y_test: Ground-truth labels compared against predictions thresholded at 0.5.
        """
        super().__init__()
        self.X_test = X_test
        self.y_test = y_test
        self.epoch_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        """Record the wall-clock start time of the epoch."""
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Compute confusion counts, derived metrics and ROC AUC, then append a results row.

        Args:
            epoch: Epoch index passed by Keras; recorded in the row as ``epoch + 1``.
            logs: Keras logs for the epoch (unused).
        """
        y_pred_probs  = self.model.predict(self.X_test, batch_size=BATCH_SIZE, verbose=0).ravel()
        y_pred_labels = (y_pred_probs >= 0.5).astype(int)

        # Pin the label set so the matrix is always 2x2. When only one class occurs in both
        # y_test and the predictions, confusion_matrix would otherwise return a 1x1 matrix
        # that cannot be unpacked into TN/FP/FN/TP (an all-positive, all-correct set must
        # count as TP, not FN).
        cm = metrics.confusion_matrix(self.y_test, y_pred_labels, labels=[0, 1])
        # Plain Python ints keep the MCC denominator product exact (no fixed-width overflow).
        TN, FP, FN, TP = (int(count) for count in cm.ravel())

        # Every ratio falls back to 0 when its denominator is 0.
        Sens = TP/(TP+FN) if TP+FN>0 else 0
        Spec = TN/(TN+FP) if TN+FP>0 else 0
        Acc  = (TP+TN)/(TP+FP+TN+FN) if TP+FP+TN+FN>0 else 0
        denom = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)
        MCC = (TP*TN - FP*FN)/math.sqrt(denom) if denom>0 else 0
        F1  = 2*TP/(2*TP + FP + FN) if 2*TP+FP+FN>0 else 0

        # AUC is threshold-free: it uses the raw probabilities, not the 0.5-thresholded labels.
        fpr, tpr, _ = roc_curve(self.y_test, y_pred_probs)
        roc_auc = auc(fpr, tpr)

        # Includes the time spent in the prediction/metric computation above.
        epoch_time = time.time() - self.epoch_start_time

        # Column order follows results_columns: Stage, Fold, Epoch, TP, FP, TN, FN,
        # Sens, Spec, Acc, MCC, F1, AUC, Train_Time, Epoch_Time, Total_Params.
        results.loc[len(results)] = [
            'Independent', 'Final', epoch+1, TP, FP, TN, FN,
            Sens, Spec, Acc, MCC, F1, roc_auc,
            None, epoch_time, self.model.count_params()
        ]

# === Results table schema ===
# The metric callbacks append one 16-value row per epoch, so the column order below
# must stay aligned with the row lists they write.
results_columns = (
    ['Stage', 'Fold', 'Epoch']                       # row identity
    + ['TP', 'FP', 'TN', 'FN']                       # confusion counts
    + ['Sens', 'Spec', 'Acc', 'MCC', 'F1', 'AUC']    # derived metrics
    + ['Train_Time', 'Epoch_Time', 'Total_Params']   # timing and model size
)
results = pd.DataFrame(columns=results_columns)

# === 5-Fold Cross-Validation + ROC per fold ===
# A fresh model is trained on each stratified fold. MetricsCallback appends one 'CV' row
# to `results` per epoch; the fold's total training time and ROC AUC are filled in after fit().
N_SPLITS = 5  # single source of truth for the splitter and the progress message
kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X_train, y_train), start=1):
    print(f"\n=== Fold {fold}/{N_SPLITS} ===")

    # Release the Keras global state kept for previously built models; creating several
    # models in a loop otherwise keeps accumulating memory across folds.
    tf.keras.backend.clear_session()

    X_train_fold, X_val = X_train[train_idx], X_train[val_idx]
    y_train_fold, y_val = y_train[train_idx], y_train[val_idx]

    model = DeepScan(window_sizes=WINDOW_SIZES,
                     num_filters=NUM_FILTERS,
                     num_hidden=NUM_HIDDEN)
    model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
    model.summary()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )

    # Per-fold callbacks: metric logging, TensorBoard, early stopping, per-epoch snapshots
    # and a best-val_accuracy checkpoint.
    metrics_cb = MetricsCallback(X_val, y_val, fold)
    tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_DIR, f'fold_{fold}'))
    early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
    save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='CV', fold=fold)
    checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_CV_{fold}.keras')
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')

    start_fold = time.time()
    model.fit(
        X_train_fold, y_train_fold,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=(X_val, y_val),
        callbacks=[metrics_cb, tb_cb, early_stop_cb, save_cb, checkpoint_cb],
        verbose=1,
    )
    train_time = time.time() - start_fold

    # Rows appended by MetricsCallback for this fold; computed once after fit() and reused.
    fold_rows = (results.Stage=='CV') & (results.Fold==fold)
    results.loc[fold_rows, 'Train_Time'] = train_time

    # The AUC is computed once from the model as left by fit() and written to every
    # epoch row of this fold (it is not a per-epoch value).
    y_val_probs = model.predict(X_val, batch_size=BATCH_SIZE, verbose=0).ravel()
    fpr, tpr, _ = roc_curve(y_val, y_val_probs)
    roc_auc = auc(fpr, tpr)
    results.loc[fold_rows, 'AUC'] = roc_auc

# Persist the per-epoch CV metrics as CSV; the file name encodes the model hyper-parameters.
results_filename = f'training_results_{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H.csv'
results_path = os.path.join(LOG_DIR, results_filename)
results.to_csv(results_path, index=False)
print(f"\nCV results saved to {results_path}")

# === Final training and evaluation ===
# One model is fitted on the full training set while FinalMetricsCallback scores
# (X_test, y_test) after every epoch.
# NOTE(review): validation_data below is the independent test set, and both ModelCheckpoint
# and EarlyStopping monitor val_accuracy, so checkpoint selection and the stopping epoch
# are driven by test data — confirm this is intended before reporting these numbers as
# held-out performance.
final_model = DeepScan(
    window_sizes=WINDOW_SIZES,
    num_filters=NUM_FILTERS,
    num_hidden=NUM_HIDDEN,
)
final_model.build((None, 1, MAX_SEQ_LENGTH, EMBEDDING_WIDTH))
final_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
)

# Callbacks: per-epoch test metrics, TensorBoard, best-model checkpoint, early stopping
# and per-epoch snapshots.
final_cb = FinalMetricsCallback(X_test, y_test)
final_tb_cb = tf.keras.callbacks.TensorBoard(log_dir=os.path.join(LOG_DIR, 'final_model'))
final_checkpoint_path = os.path.join(LOG_MODEL, f'best_model__{WINDOW_SIZES}_{NUM_FILTERS}F_{NUM_HIDDEN}H_Independent.keras')
final_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(final_checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max')
final_early_stop_cb = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=5, restore_best_weights=True)
save_cb = SaveEveryEpochCallback(base_dir=LOG_MODEL, stage='Independent')

# The list order is the order in which the callbacks are handed to fit(); keep it stable.
final_callbacks = [final_cb, final_tb_cb, final_checkpoint_cb, final_early_stop_cb, save_cb]

history = final_model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=final_callbacks,
    verbose=1
)

# Save final results (CV rows plus the per-epoch 'Independent' rows) to the same CSV.
results.to_csv(results_path, index=False)
print(f"\nFinal results saved to {results_path}")

# Report the epoch with the highest independent-test AUC.
# The 'AUC' column is object dtype because rows mix None and floats, and idxmax() rejects
# object dtype, so coerce to numeric first (non-numeric cells become NaN).
independent_auc = pd.to_numeric(
    results.loc[results['Stage'] == 'Independent', 'AUC'],
    errors='coerce',
)
if independent_auc.notna().any():
    best_auc_idx = independent_auc.idxmax()
    best_auc_row = results.loc[best_auc_idx]
    print(f"\nBest Independent Test AUC: {best_auc_row['AUC']:.4f} (Epoch {best_auc_row['Epoch']})")
else:
    # No usable AUC (no 'Independent' rows, or every AUC is missing/NaN): say so instead of raising.
    best_auc_idx = None
    best_auc_row = None
    print("\nNo Independent-stage AUC values were recorded.")



=== Fold 1/5 ===


Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 201ms/step - accuracy: 0.5302 - auc: 0.5280 - loss: 0.7644[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_1/CV_fold1_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 537ms/step - accuracy: 0.5323 - auc: 0.5317 - loss: 0.7624 - val_accuracy: 0.6412 - val_auc: 0.8978 - val_loss: 0.6619
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 200ms/step - accuracy: 0.6420 - auc: 0.6886 - loss: 0.6793[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_1/CV_fold1_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 379ms/step - accuracy: 0.6433 - auc: 0.6912 - loss: 0.6778 - val_accuracy: 0.7320 - val_auc: 0.9435 - val_loss: 0.6013
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 197ms/step - accuracy: 0.67

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 201ms/step - accuracy: 0.5175 - auc: 0.5215 - loss: 0.7685[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_2/CV_fold2_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 533ms/step - accuracy: 0.5225 - auc: 0.5276 - loss: 0.7647 - val_accuracy: 0.6467 - val_auc: 0.8843 - val_loss: 0.6579
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 204ms/step - accuracy: 0.6372 - auc: 0.6778 - loss: 0.6822[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_2/CV_fold2_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 383ms/step - accuracy: 0.6386 - auc: 0.6805 - loss: 0.6807 - val_accuracy: 0.8554 - val_auc: 0.9333 - val_loss: 0.5953
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 196ms/step - accuracy: 0.68

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 197ms/step - accuracy: 0.5097 - auc: 0.5276 - loss: 0.7572[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_3/CV_fold3_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 532ms/step - accuracy: 0.5158 - auc: 0.5354 - loss: 0.7534 - val_accuracy: 0.6653 - val_auc: 0.8316 - val_loss: 0.6483
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 196ms/step - accuracy: 0.6633 - auc: 0.7165 - loss: 0.6644[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_3/CV_fold3_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 376ms/step - accuracy: 0.6655 - auc: 0.7191 - loss: 0.6625 - val_accuracy: 0.7831 - val_auc: 0.8723 - val_loss: 0.5861
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 201ms/step - accuracy: 0.71

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 197ms/step - accuracy: 0.5316 - auc: 0.5315 - loss: 0.7633[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_4/CV_fold4_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 526ms/step - accuracy: 0.5337 - auc: 0.5345 - loss: 0.7614 - val_accuracy: 0.7273 - val_auc: 0.8729 - val_loss: 0.6582
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 202ms/step - accuracy: 0.6594 - auc: 0.7165 - loss: 0.6637[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_4/CV_fold4_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 382ms/step - accuracy: 0.6587 - auc: 0.7169 - loss: 0.6631 - val_accuracy: 0.7459 - val_auc: 0.9338 - val_loss: 0.5892
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 194ms/step - accuracy: 0.68

Epoch 1/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 201ms/step - accuracy: 0.5147 - auc: 0.5181 - loss: 0.7788[Saved model] Epoch 1 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_5/CV_fold5_epoch01.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 533ms/step - accuracy: 0.5192 - auc: 0.5242 - loss: 0.7745 - val_accuracy: 0.7562 - val_auc: 0.8623 - val_loss: 0.6632
Epoch 2/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 200ms/step - accuracy: 0.6569 - auc: 0.7131 - loss: 0.6669[Saved model] Epoch 2 saved to /content/drive/MyDrive/Clathrin-msCNN/log/esm2/MODELS_[8]_256F_1024H/CV/fold_5/CV_fold5_epoch02.keras
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 379ms/step - accuracy: 0.6593 - auc: 0.7159 - loss: 0.6652 - val_accuracy: 0.8140 - val_auc: 0.8962 - val_loss: 0.6076
Epoch 3/50
[1m15/16[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 206ms/step - accuracy: 0.70