In [1]:
# cnn_lstm_models.py

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout,
    BatchNormalization, TimeDistributed, LSTM, Bidirectional,
    GlobalAveragePooling1D, LayerNormalization, MultiHeadAttention, Concatenate
)

def build_base_cnn(input_shape):
    """Build the per-frame CNN feature extractor shared across time steps.

    The network has two stages, with 32 and then 64 filters. Each stage is
    Conv2D(3x3, 'same' padding, ReLU) -> BatchNormalization -> MaxPooling2D(2x2).
    The result is flattened into a single feature vector.

    Args:
        input_shape: shape of one frame, without the batch axis.

    Returns:
        A Keras ``Model`` named "base_cnn" mapping one frame to a flat vector.
    """
    frame = Input(shape=input_shape)
    features = frame
    for n_filters in (32, 64):
        features = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(features)
        features = BatchNormalization()(features)
        features = MaxPooling2D((2, 2))(features)
    flat = Flatten()(features)
    return Model(frame, flat, name="base_cnn")

def build_model(model_type="A2", input_shape=(10, 32, 32, 10)):
    """Build and compile one CNN + temporal-aggregation model variant.

    A shared per-frame CNN (``build_base_cnn``) is applied to every time step
    through ``TimeDistributed``. The resulting (batch, T, F) sequence is then
    aggregated by the head selected with ``model_type``:

    - "A1": LSTM(64)
    - "A2": Bidirectional LSTM(64)
    - "A3": Bidirectional LSTM + dot-product self-attention, averaged over T
    - "A4": Bidirectional LSTM + MultiHeadAttention, averaged over T
    - "A5": pre-norm self-attention block with a residual connection,
      averaged over T
    - "A6": CNN features averaged over T (no recurrent layer)

    Every variant ends in the same Dropout(0.5) -> Dense(128, relu) ->
    Dense(1, sigmoid) head. Each is compiled with Adam, binary cross-entropy,
    accuracy and AUC.

    Args:
        model_type: one of "A1" .. "A6".
        input_shape: (T, H, W, C) of one input sequence, without the batch
            axis. Defaults to the previously hard-coded (10, 32, 32, 10).

    Returns:
        A compiled ``tf.keras.Model`` named ``model_<model_type>``.

    Raises:
        ValueError: if ``model_type`` is not one of the variants above.
    """
    sequence_input = Input(shape=input_shape)
    cnn_model = build_base_cnn(tuple(input_shape[1:]))
    x = TimeDistributed(cnn_model)(sequence_input)  # (batch, T, F)

    if model_type == "A1":  # CNN + LSTM
        x = LSTM(64)(x)

    elif model_type == "A2":  # CNN + BiLSTM
        x = Bidirectional(LSTM(64))(x)

    elif model_type == "A3":  # CNN + BiLSTM + dot attention
        x = Bidirectional(LSTM(64, return_sequences=True))(x)
        attention = tf.keras.layers.Attention()([x, x])
        # Average over the time axis with a Keras layer. Calling the raw
        # tf.reduce_mean on a symbolic functional tensor fails under Keras 3.
        x = GlobalAveragePooling1D()(attention)

    elif model_type == "A4":  # CNN + BiLSTM + MultiHeadAttention
        x = Bidirectional(LSTM(64, return_sequences=True))(x)
        x = MultiHeadAttention(num_heads=4, key_dim=32)(x, x)
        x = GlobalAveragePooling1D()(x)

    elif model_type == "A5":  # CNN + Transformer encoder block (pre-norm)
        # The residual branch must be the (batch, T, F) sequence entering the
        # block. The previous code added cnn_model.output instead. That tensor
        # belongs to the inner CNN's own graph: it has no time axis and is not
        # reachable from sequence_input, so the model could not be built.
        residual = x
        x = LayerNormalization()(x)
        # MultiHeadAttention projects back to the query's feature dimension by
        # default, so its output matches the shape of `residual`.
        x = MultiHeadAttention(num_heads=4, key_dim=32)(x, x)
        x = tf.keras.layers.Add()([residual, x])
        x = GlobalAveragePooling1D()(x)

    elif model_type == "A6":  # CNN only (avg pooling over T)
        x = GlobalAveragePooling1D()(x)

    else:
        raise ValueError(f"Unknown model type: {model_type}")

    x = Dropout(0.5)(x)
    x = Dense(128, activation='relu')(x)
    output = Dense(1, activation='sigmoid')(x)

    model = Model(sequence_input, output, name=f"model_{model_type}")
    model.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        # Fixed name: an unnamed AUC metric is auto-numbered (auc_1, auc_2, ...)
        # for each later model built in the same process. That renames the
        # History keys that callers look up as 'auc' / 'val_auc'.
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
    )
    return model


In [None]:
# train_models.py

import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
# from cnn_lstm_models import build_model

# === Load sequence data ===
def load_sequence_data(patches_dir="data/sequences"):
    """Load every valid (sequence, label) pair stored under ``patches_dir``.

    Expected layout:

    - ``patches_dir/<region>/<name>.npy`` holds one sequence;
    - ``patches_dir/<region>/<name>_label.npy`` holds its label.

    A pair is kept only when all of the following hold: the label file exists,
    both files load, the sequence has shape ``(10, 32, 32, 10)``, and the label
    is not NaN. Other pairs are skipped silently. Non-directory entries directly
    inside ``patches_dir`` are ignored.

    Regions and files are visited in sorted order. This makes the returned
    arrays, and any seeded shuffle/split applied to them, reproducible across
    file systems.

    Args:
        patches_dir: root directory containing one sub-directory per region.

    Returns:
        ``(X, Y)`` as float32 NumPy arrays, with ``X`` of shape
        ``(N, 10, 32, 32, 10)``.
    """
    X, Y = [], []
    for region in sorted(os.listdir(patches_dir)):
        region_dir = os.path.join(patches_dir, region)
        if not os.path.isdir(region_dir):
            # Stray files next to the region folders previously crashed
            # os.listdir(region_dir) below.
            continue
        for file in sorted(os.listdir(region_dir)):
            if not file.endswith(".npy") or file.endswith("_label.npy"):
                continue
            path_seq = os.path.join(region_dir, file)
            # Replace only the trailing extension. str.replace would also rewrite
            # any ".npy" that appears earlier in the file name.
            label_file = file[:-len(".npy")] + "_label.npy"
            path_label = os.path.join(region_dir, label_file)

            if not os.path.exists(path_label):
                continue
            try:
                seq = np.load(path_seq)
                label = np.load(path_label)
                keep = seq.shape == (10, 32, 32, 10) and not np.isnan(label)
            except (OSError, EOFError, ValueError, TypeError, AttributeError):
                # Best-effort loading: skip unreadable, truncated, pickled or
                # non-numeric files. Unlike the old bare `except:`, this no
                # longer swallows KeyboardInterrupt or unrelated errors.
                continue
            if keep:
                X.append(seq)
                Y.append(label)

    X, Y = np.array(X, dtype=np.float32), np.array(Y, dtype=np.float32)
    print(f"✅ Loaded {len(X)} sequences, shape: {X.shape}")
    return X, Y

# === Train a model variant ===
def _plot_history(history, model_type, output_dir):
    """Save train/validation loss, accuracy and AUC curves as a PNG.

    The figure is written to ``<output_dir>/training_plot_<model_type>.png``.
    """
    hist = history.history
    # Unnamed Keras metrics are auto-numbered (auc, auc_1, auc_2, ...) when
    # several models are created in one process. Match the AUC key by prefix
    # instead of assuming it is exactly "auc".
    auc_key = next((key for key in hist if key.startswith("auc")), "auc")
    panels = (
        ("loss", "Loss", f"Loss - {model_type}"),
        ("accuracy", "Acc", "Accuracy"),
        (auc_key, "AUC", "AUC"),
    )
    fig = plt.figure(figsize=(12, 4))
    try:
        for position, (key, short_name, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.plot(hist[key], label=f"Train {short_name}")
            plt.plot(hist[f"val_{key}"], label=f"Val {short_name}")
            plt.title(title)
            plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"training_plot_{model_type}.png"))
    finally:
        # Release the figure even if a metric key is missing or saving fails.
        plt.close(fig)


def train_model(model_type, output_dir="models", data=None, class_weights=None):
    """Train one model variant, checkpoint the best model and plot its curves.

    Steps:

    1. Shuffle the data and hold out 20% for validation; both steps use
       ``random_state=42``.
    2. Train with early stopping on ``val_loss`` (patience 5, best weights
       restored).
    3. Save the best model to ``<output_dir>/best_model_<model_type>.h5``.
    4. Save the training curves to
       ``<output_dir>/training_plot_<model_type>.png``.

    Args:
        model_type: variant id passed to ``build_model`` (e.g. "A1").
        output_dir: destination directory; created if missing.
        data: optional ``(X, Y)`` tuple. When ``None`` the data is read with
            ``load_sequence_data()``. Pass it explicitly to avoid re-reading
            the dataset from disk for every variant.
        class_weights: optional Keras ``class_weight`` mapping. Defaults to
            ``{0: 1.0, 1: 5.0}``, which up-weights class 1.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    os.makedirs(output_dir, exist_ok=True)
    X, Y = load_sequence_data() if data is None else data
    X, Y = shuffle(X, Y, random_state=42)
    X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.2, random_state=42)

    model = build_model(model_type)

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        ModelCheckpoint(
            os.path.join(output_dir, f"best_model_{model_type}.h5"),
            monitor='val_loss', save_best_only=True,
        ),
    ]

    if class_weights is None:
        # Default weighting; tune it to the dataset's label balance.
        class_weights = {0: 1.0, 1: 5.0}

    history = model.fit(
        X_train, Y_train,
        validation_data=(X_val, Y_val),
        batch_size=16,
        epochs=30,
        class_weight=class_weights,
        callbacks=callbacks,
        verbose=1,
    )

    _plot_history(history, model_type, output_dir)

    print(f"✅ Training complete for {model_type}. Model and plot saved.")
    return history

if __name__ == "__main__":
    # Trains variants A1, A2, A4 and A6 in sequence; A3 and A5 are not in this
    # list (NOTE(review): presumably on purpose — confirm).
    for variant in ("A1", "A2", "A4", "A6"):
        banner = f"\n===== Training model: {variant} ====="
        print(banner)
        train_model(variant)



===== Training model: A1 =====
