In [None]:
# train_model_a2.py

import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from cnn_lstm_models import build_model

def _is_valid_label(label):
    """Return True if *label* is an ndarray holding exactly one non-NaN numeric value."""
    if not isinstance(label, np.ndarray) or label.size != 1:
        return False
    try:
        return not bool(np.isnan(label).any())
    except TypeError:
        # np.isnan is undefined for non-numeric dtypes (e.g. object/str labels).
        return False

def load_sequence_data(patches_dir="data/sequences", expected_shape=(10, 32, 32, 10)):
    """Load (sequence, label) pairs stored as .npy files under *patches_dir*.

    Expected layout: ``patches_dir/<region>/<name>.npy`` with a companion
    ``patches_dir/<region>/<name>_label.npy``. A pair is skipped when:
    the label file is missing, either file cannot be loaded, the sequence
    shape differs from *expected_shape*, or the label is not a single
    non-NaN numeric value. Non-directory entries in *patches_dir* are ignored.

    Args:
        patches_dir: Root directory containing one sub-directory per region.
        expected_shape: Required shape of each sequence array. The default
            is the shape the original script hard-coded.

    Returns:
        (X, Y) as float32 arrays. X has shape (N, *expected_shape) when N > 0.
        Y stacks the labels exactly as stored (0-d or size-1 arrays).

    Raises:
        FileNotFoundError / NotADirectoryError: if *patches_dir* itself is not
            a readable directory.
    """
    expected_shape = tuple(expected_shape)
    X, Y = [], []
    # Sorted so the returned sample order is deterministic; os.listdir order
    # is filesystem-dependent.
    for region in sorted(os.listdir(patches_dir)):
        region_dir = os.path.join(patches_dir, region)
        if not os.path.isdir(region_dir):
            continue  # stray files (e.g. .DS_Store) next to region folders
        for file in sorted(os.listdir(region_dir)):
            if not file.endswith(".npy") or file.endswith("_label.npy"):
                continue
            path_seq = os.path.join(region_dir, file)
            # Replace only the trailing ".npy", not every occurrence in the name.
            path_label = os.path.join(region_dir, file[: -len(".npy")] + "_label.npy")
            if not os.path.exists(path_label):
                continue

            try:
                seq = np.load(path_seq)
                label = np.load(path_label)
            except (OSError, ValueError, EOFError) as err:
                # Unreadable/corrupt/pickled files: skip, but make it visible.
                print(f"⚠️ Skipping unreadable pair {path_seq}: {err}")
                continue

            if not isinstance(seq, np.ndarray) or seq.shape != expected_shape:
                continue
            if not _is_valid_label(label):
                continue
            X.append(seq)
            Y.append(label)

    X = np.array(X, dtype=np.float32)
    Y = np.array(Y, dtype=np.float32)
    print(f"✅ Loaded {len(X)} sequences, shape: {X.shape}")
    return X, Y

def _find_metric_key(history_dict, name):
    """Return the history key recorded for metric *name*, or None if absent.

    Besides the exact name, accepts a numeric-suffixed variant such as
    ``auc_1``. Keras may append such suffixes when several metric objects
    of the same type exist in one session.
    """
    if name in history_dict:
        return name
    prefix = name + "_"
    for key in history_dict:
        if key.startswith(prefix) and key[len(prefix):].isdigit():
            return key
    return None

def _plot_history(history_dict, out_path):
    """Save a 1x3 figure of train/val loss, accuracy and AUC curves to *out_path*.

    A metric missing from *history_dict* gets an empty, labelled panel
    instead of raising KeyError after training has already finished.
    """
    panels = [
        ("loss", "Loss", "Loss"),
        ("accuracy", "Acc", "Accuracy"),
        ("auc", "AUC", "AUC"),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(12, 4))
    for ax, (metric, short, title) in zip(axes, panels):
        key = _find_metric_key(history_dict, metric)
        if key is None:
            ax.set_title(f"{title} (not recorded)")
            continue
        ax.plot(history_dict[key], label=f"Train {short}")
        val_key = f"val_{key}"
        if val_key in history_dict:
            ax.plot(history_dict[val_key], label=f"Val {short}")
        ax.set_title(title)
        ax.legend()
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)

def train_model_a2(*, patches_dir="data/sequences", output_dir="models",
                   class_weights=None, epochs=30, batch_size=16):
    """Train model "A2" on the sequence patches and save artifacts to *output_dir*.

    Pipeline: load sequences, do a seeded 80/20 train/validation split,
    build the model via ``build_model("A2")``, and fit with early stopping
    on val_loss (patience 5, best weights restored). Two files are written
    to *output_dir*:
      - ``best_model_A2.h5``: checkpoint with the best val_loss.
      - ``training_plot_A2.png``: training-history plot.

    All arguments are keyword-only. The defaults reproduce the original
    hard-coded configuration, so ``train_model_a2()`` behaves as before.

    Args:
        patches_dir: Root directory passed to ``load_sequence_data``.
        output_dir: Directory for the checkpoint and plot (created if needed).
        class_weights: Per-class loss weights for ``model.fit``. The default
            ``None`` means ``{0: 1.0, 1: 5.0}``.
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size.

    Raises:
        ValueError: if no valid sequences were found under *patches_dir*.
    """
    model_type = "A2"
    os.makedirs(output_dir, exist_ok=True)

    X, Y = load_sequence_data(patches_dir)
    if len(X) == 0:
        raise ValueError(
            f"No valid sequences found under {patches_dir!r}; nothing to train on."
        )

    # train_test_split also shuffles; the explicit seeded shuffle is kept so
    # the resulting split matches earlier runs of this script.
    X, Y = shuffle(X, Y, random_state=42)
    X_train, X_val, Y_train, Y_val = train_test_split(X, Y, test_size=0.2, random_state=42)

    if class_weights is None:
        # Class 1 weighted 5x relative to class 0; adjust based on imbalance.
        class_weights = {0: 1.0, 1: 5.0}

    model = build_model(model_type)

    callbacks = [
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        ModelCheckpoint(
            os.path.join(output_dir, f"best_model_{model_type}.h5"),
            monitor='val_loss', save_best_only=True
        )
    ]

    history = model.fit(
        X_train, Y_train,
        validation_data=(X_val, Y_val),
        batch_size=batch_size,
        epochs=epochs,
        class_weight=class_weights,
        callbacks=callbacks,
        verbose=1
    )

    _plot_history(
        history.history,
        os.path.join(output_dir, f"training_plot_{model_type}.png"),
    )

    print(f"✅ Model {model_type} trained and saved in {output_dir}")

if __name__ == "__main__":
    # Script entry point: run the A2 training pipeline with its default settings.
    train_model_a2()
