# In [None]:
# Premonitor: train_models.py
# This script is responsible for training the core AI models for the Premonitor project.
# It should be run on a powerful development PC with a GPU.

import os
import argparse
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers, losses, models, layers, callbacks

# Import our custom project files
import config
import model_blueprints
import utils # This will be created next

# --- Custom Loss Function for SimSiam Pre-training ---
def sim_siam_loss(p, z):
    """Return the negated mean cosine similarity between `p` and `z`.

    Both inputs are L2-normalised along axis 1, multiplied element-wise and
    summed along axis 1 to give one cosine similarity per row; the result is
    the negative of the mean over those rows.

    Gradients are blocked from flowing through `z`, so only the `p` branch
    is updated by this loss (the original author notes this is what prevents
    model collapse).

    Args:
        p: Tensor from the branch being trained
            (assumes rank >= 2 with features on axis 1 -- TODO confirm).
        z: Tensor treated as a constant target, same shape as `p`.

    Returns:
        A scalar tensor; lower values mean `p` and `z` are more aligned.
    """
    # Freeze the target branch before normalising it.
    target = tf.math.l2_normalize(tf.stop_gradient(z), axis=1)
    prediction = tf.math.l2_normalize(p, axis=1)

    # Dot product of unit vectors == cosine similarity, one value per row.
    cosine_per_row = tf.reduce_sum(prediction * target, axis=1)
    return -tf.reduce_mean(cosine_per_row)

# --- Main Training Functions ---

def _pretrain_thermal_encoder(epochs, batch_size):
    """Run Stage 1 (SimSiam pre-training) and save the resulting encoder.

    Args:
        epochs: Number of passes to make over the unlabeled dataset.
        batch_size: Forwarded to `utils.create_thermal_dataset_generator`.

    Returns:
        The file path the pre-trained encoder was saved to.

    Raises:
        ValueError: If the dataset yields no batches in the first epoch.
    """
    print("\n--- STAGE 1: Self-Supervised Pre-training ---")

    # 1. Load Data for Pre-training (unlabeled images)
    # This function in utils.py should load paths from AAU VAP and FLIR datasets.
    unlabeled_image_paths = utils.load_all_thermal_image_paths()
    train_dataset = utils.create_thermal_dataset_generator(unlabeled_image_paths, batch_size)

    # 2. Get Model Blueprints
    siamese_model, encoder, _ = model_blueprints.get_thermal_anomaly_model()
    # compile() is only used to attach the optimizer; the loss is applied
    # manually in the loop below.
    siamese_model.compile(optimizer=optimizers.Adam(0.001))

    # 3. Custom Training Loop for SimSiam
    print(f"Starting SimSiam pre-training for {epochs} epochs...")
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for view1, view2 in train_dataset:
            with tf.GradientTape() as tape:
                p_a, z_a, p_b, z_b = siamese_model([view1, view2], training=True)
                # Symmetrised loss: each branch predicts the other's projection.
                loss = (sim_siam_loss(p_a, z_b) + sim_siam_loss(p_b, z_a)) / 2

            gradients = tape.gradient(loss, siamese_model.trainable_variables)
            siamese_model.optimizer.apply_gradients(zip(gradients, siamese_model.trainable_variables))
            total_loss += loss
            num_batches += 1

        if num_batches == 0:
            if epoch == 0:
                # Nothing to train on at all: fail with a clear message
                # instead of an UnboundLocalError on the batch counter.
                raise ValueError(
                    "Thermal pre-training dataset yielded no batches; "
                    "check utils.load_all_thermal_image_paths()."
                )
            # The dataset was exhausted by an earlier epoch (e.g. a one-shot
            # generator). Keep what was learned so far and stop early rather
            # than reporting a meaningless loss for the remaining epochs.
            print(
                f"Warning: no batches in epoch {epoch+1}; the dataset appears "
                "to be single-pass. Stopping pre-training early."
            )
            break

        # Convert to a Python float so the format spec works regardless of
        # how the installed TensorFlow formats tensors.
        avg_loss = float(total_loss) / num_batches
        print(f"Epoch {epoch+1}/{epochs}, Pre-training Loss: {avg_loss:.4f}")

    # 4. Save the Pre-trained Encoder
    encoder_save_path = os.path.join(config.MODEL_DIR, "thermal_encoder_pretrained.h5")
    os.makedirs(config.MODEL_DIR, exist_ok=True)
    encoder.save(encoder_save_path)
    print(f"Pre-trained encoder saved to {encoder_save_path}")
    return encoder_save_path


def _finetune_thermal_classifier(encoder_save_path, epochs, batch_size):
    """Run Stage 2: train a binary classifier head on the frozen encoder.

    Args:
        encoder_save_path: Path of the encoder saved by Stage 1.
        epochs: Number of fine-tuning epochs passed to `fit`.
        batch_size: Forwarded to `utils.load_labeled_thermal_data`.

    Returns:
        The checkpoint path the best classifier is written to.
    """
    print("\n--- STAGE 2: Supervised Fine-tuning ---")

    # 1. Load Labeled Data for Fine-tuning
    # This function in utils.py should load your small, custom-labeled dataset.
    # It should return datasets of (image, label) pairs.
    labeled_train_ds, labeled_val_ds = utils.load_labeled_thermal_data(batch_size)

    # 2. Build the Classifier Model
    # The encoder is only used as a frozen feature extractor, so its training
    # configuration does not need to be restored.
    pretrained_encoder = models.load_model(encoder_save_path, compile=False)
    pretrained_encoder.trainable = False  # Start with the backbone frozen

    classifier_input = layers.Input(shape=config.THERMAL_MODEL_INPUT_SHAPE)
    x = pretrained_encoder(classifier_input, training=False)
    # Dropout for regularization to limit overfitting on the small labeled set.
    x = layers.Dropout(0.5)(x)
    # Single sigmoid unit for binary classification (normal vs anomaly).
    classifier_output = layers.Dense(1, activation='sigmoid')(x)
    classifier_model = models.Model(classifier_input, classifier_output, name="thermal_classifier")

    # 3. Compile the Classifier
    classifier_model.compile(
        optimizer=optimizers.Adam(learning_rate=0.0001),  # Smaller learning rate for fine-tuning
        loss=losses.BinaryCrossentropy(),
        metrics=['accuracy']
    )

    # 4. Checkpoint only when validation accuracy improves.
    checkpoint_path = os.path.join(config.MODEL_DIR, "thermal_classifier_best.h5")
    model_checkpoint_callback = callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=False,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    print(f"Starting classifier fine-tuning for {epochs} epochs...")
    classifier_model.fit(
        labeled_train_ds,
        epochs=epochs,
        validation_data=labeled_val_ds,
        callbacks=[model_checkpoint_callback]
    )
    return checkpoint_path


def train_thermal_model(epochs=50, batch_size=32):
    """
    Orchestrates the full, two-stage training process for the thermal model.
    Stage 1: Self-supervised pre-training on unlabeled data.
    Stage 2: Supervised fine-tuning on labeled data.

    Args:
        epochs: Number of pre-training epochs. Fine-tuning runs for
            `epochs // 2` epochs (floor division).
        batch_size: Batch size forwarded to the data loaders in utils.py.

    Raises:
        ValueError: If the unlabeled dataset yields no batches at all.
    """
    print("--- Starting Full Thermal Model Training Pipeline ---")

    encoder_save_path = _pretrain_thermal_encoder(epochs, batch_size)

    # Fine-tuning usually requires fewer epochs than pre-training.
    checkpoint_path = _finetune_thermal_classifier(encoder_save_path, epochs // 2, batch_size)

    print(f"Fine-tuning complete. Best model saved to {checkpoint_path}")
    # This 'thermal_classifier_best.h5' is the final model you will convert to .tflite

def train_acoustic_model(epochs=30, batch_size=64, data_dir='path/to/your/mimii/0_dB_fan/'):
    """Orchestrates the training for the acoustic anomaly model.

    Args:
        epochs: Number of training epochs passed to `fit`.
        batch_size: Batch size passed to `fit`.
        data_dir: Directory handed to `utils.create_acoustic_dataset`. The
            default is the original placeholder path and must be replaced
            with a real dataset location before training.
    """
    print("--- Starting Acoustic Model Training ---")

    # 1. Load Data (from MIMII, Urbansound8K, etc.)
    # This function in utils.py will handle loading audio, creating spectrograms,
    # and making the pseudo-anomaly pairs.
    spectrograms, labels = utils.create_acoustic_dataset(data_dir=data_dir)

    # 2. Get Model Blueprint
    acoustic_model = model_blueprints.get_acoustic_anomaly_model()

    # 3. Compile the Model
    acoustic_model.compile(
        optimizer=optimizers.Adam(learning_rate=0.001),
        loss=losses.BinaryCrossentropy(),
        metrics=['accuracy']
    )

    # 4. Add Checkpointing
    # Ensure the output directory exists even when this function is called
    # directly rather than through the command-line entry point.
    os.makedirs(config.MODEL_DIR, exist_ok=True)
    checkpoint_path = os.path.join(config.MODEL_DIR, "acoustic_anomaly_model_best.h5")
    # Keep only the model with the best validation accuracy seen so far.
    model_checkpoint_callback = callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=False,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # 5. Training
    # NOTE(review): Keras takes the validation split from the *last* 20% of
    # the arrays before shuffling; if create_acoustic_dataset returns samples
    # grouped by label, the validation set may be single-class -- confirm.
    print(f"Starting acoustic model training for {epochs} epochs...")
    acoustic_model.fit(
        spectrograms,
        labels,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.2,
        callbacks=[model_checkpoint_callback]
    )
    print(f"Training complete. Best model saved to {checkpoint_path}")

# --- Command-Line Interface ---
if __name__ == "__main__":
    # Parse which model to train; argparse rejects anything outside `choices`,
    # so no separate "invalid model" branch is needed below.
    parser = argparse.ArgumentParser(description="Premonitor AI Model Training Script")
    parser.add_argument(
        "--model", type=str, required=True, choices=["thermal", "acoustic"],
        help="The type of model to train ('thermal' or 'acoustic')."
    )
    args = parser.parse_args()

    # Create the model directory if it doesn't exist (safe if it already does).
    os.makedirs(config.MODEL_DIR, exist_ok=True)

    if args.model == "thermal":
        # NOTE: For this to run, you must implement the data loading functions
        # in utils.py: load_all_thermal_image_paths() and load_labeled_thermal_data()
        print("Running in placeholder mode for thermal training.")
        # train_thermal_model()
    elif args.model == "acoustic":
        # NOTE: For this to run, you must implement create_acoustic_dataset() in utils.py
        print("Running in placeholder mode for acoustic training.")
        # train_acoustic_model()