# In [None]:
# main_train.py

import os
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import f1_score
import numpy as np

# Import configurations and helper functions
from config import (
    FILE_PATH, TEXT_COLUMN, LABEL_COLUMNS, VALIDATION_SPLIT, MODEL_TYPE,
    MAX_WORDS, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM, GRU_UNITS, DENSE_UNITS,
    DROPOUT_RATE, L2_REG_FACTOR, LEARNING_RATE,
    BERT_MAX_LEN, XLSTM_UNITS, BERT_DENSE_UNITS, BERT_DROPOUT_RATE, BERT_L2_REG_FACTOR,
    BATCH_SIZE, EPOCHS,
    EARLY_STOPPING_PATIENCE, REDUCE_LR_FACTOR, REDUCE_LR_PATIENCE, MIN_LR,
    SAVE_DIR, MODEL_SAVE_PATH, TOKENIZER_SAVE_PATH
)
from data_handler import load_and_preprocess_data, tokenize_and_pad_sequences, get_bert_embeddings, split_data, save_tokenizer
from model_architectures import build_bigru_attention_model, build_bert_xlstm_model
from utils import plot_training_history

def _prepare_features(df):
    """Turn the text column of ``df`` into model inputs for the configured MODEL_TYPE.

    Args:
        df: DataFrame returned by ``load_and_preprocess_data``; its
            ``TEXT_COLUMN`` column is used as the model input text.

    Returns:
        A ``(X, tokenizer)`` pair. ``tokenizer`` is whatever
        ``tokenize_and_pad_sequences`` returns for the BiGRU path, and ``None``
        for the BERT path (nothing to persist). ``X`` is whatever the feature
        helper produced; the caller checks it for ``None``.
    """
    texts = df[TEXT_COLUMN]
    if MODEL_TYPE == 'BiGRU_Attention':
        X, tokenizer = tokenize_and_pad_sequences(texts, MAX_WORDS, MAX_SEQUENCE_LENGTH)
        return X, tokenizer
    # BERT_XLSTM: features are BERT embeddings; there is no tokenizer to save.
    X = get_bert_embeddings(texts, max_len=BERT_MAX_LEN, batch_size=BATCH_SIZE)
    return X, None


def _build_model():
    """Build (but do not compile) the network selected by MODEL_TYPE.

    Returns:
        The model returned by the matching builder, or ``None`` when
        MODEL_TYPE matches neither supported architecture.
    """
    num_labels = len(LABEL_COLUMNS)
    if MODEL_TYPE == 'BiGRU_Attention':
        return build_bigru_attention_model(
            max_words=MAX_WORDS,
            max_sequence_length=MAX_SEQUENCE_LENGTH,
            embedding_dim=EMBEDDING_DIM,
            gru_units=GRU_UNITS,
            dense_units=DENSE_UNITS,
            dropout_rate=DROPOUT_RATE,
            num_labels=num_labels,
            l2_reg_factor=L2_REG_FACTOR
        )
    if MODEL_TYPE == 'BERT_XLSTM':
        return build_bert_xlstm_model(
            num_labels=num_labels,
            lstm_units=XLSTM_UNITS,
            dense_units=BERT_DENSE_UNITS,
            dropout_rate=BERT_DROPOUT_RATE,
            l2_reg_factor=BERT_L2_REG_FACTOR
        )
    return None


def _print_history_metrics(history):
    """Print the last-epoch loss/precision/recall values recorded by ``model.fit``.

    Args:
        history: The ``History`` object returned by ``model.fit``.
    """
    metrics = history.history
    print("\nTraining History Metrics:")
    print(f"Final Training Loss: {metrics['loss'][-1]:.4f}")
    print(f"Final Validation Loss: {metrics['val_loss'][-1]:.4f}")
    # With EarlyStopping(restore_best_weights=True) the model that gets
    # evaluated/saved may come from the best epoch rather than the last one,
    # so the best validation loss is reported as well.
    print(f"Best Validation Loss: {min(metrics['val_loss']):.4f}")
    for key, label in (
        ('precision', 'Training Precision'),
        ('val_precision', 'Validation Precision'),
        ('recall', 'Training Recall'),
        ('val_recall', 'Validation Recall'),
    ):
        if key in metrics:
            print(f"Final {label} (threshold 0.5): {metrics[key][-1]:.4f}")


def main():
    """Load data, then build, train, evaluate and save the configured model.

    Everything is driven by constants imported from ``config``; MODEL_TYPE
    selects between the 'BiGRU_Attention' and 'BERT_XLSTM' pipelines.
    Problems with the model type, data loading, feature generation or model
    construction are printed and cause an early return instead of raising.
    On success the model is saved to MODEL_SAVE_PATH and, for the BiGRU
    pipeline, the fitted tokenizer is saved to TOKENIZER_SAVE_PATH.
    """
    # Reject an unsupported model type before doing any expensive work
    # (data loading and, for BERT, embedding generation).
    if MODEL_TYPE not in ('BiGRU_Attention', 'BERT_XLSTM'):
        print(f"Error: Unknown MODEL_TYPE '{MODEL_TYPE}'. Please choose 'BiGRU_Attention' or 'BERT_XLSTM'.")
        return

    # Create save directory if it doesn't exist; exist_ok avoids a race if it
    # appears between the check and the call.
    if not os.path.exists(SAVE_DIR):
        os.makedirs(SAVE_DIR, exist_ok=True)
        print(f"Created directory: {SAVE_DIR}")

    # --- 1. Data Loading and Preprocessing ---
    try:
        df = load_and_preprocess_data(FILE_PATH, TEXT_COLUMN, LABEL_COLUMNS)
    except (FileNotFoundError, ValueError) as e:
        print(f"Critical error during data loading: {e}")
        return

    # Multi-label targets: one float32 column per entry in LABEL_COLUMNS.
    y = df[LABEL_COLUMNS].values.astype(np.float32)
    print(f"Labels shape: {y.shape}")
    print(f"Number of labels: {len(LABEL_COLUMNS)}")

    # --- 2. Feature generation (token sequences or BERT embeddings) ---
    X, tokenizer_for_save = _prepare_features(df)
    if X is None:
        print("Error: Features (X) could not be generated. Exiting.")
        return

    # --- 3. Data Splitting ---
    X_train, X_val, y_train, y_val = split_data(X, y, VALIDATION_SPLIT)

    # --- 4. Build the Model ---
    model = _build_model()
    if model is None:
        print("Error: Model could not be built. Exiting.")
        return

    # --- 5. Compile the Model ---
    # Explicit metric names keep the history keys stable ('precision',
    # 'val_precision', ...) so the report below can find them.
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=[
            tf.keras.metrics.Precision(thresholds=0.5, name='precision'),
            tf.keras.metrics.Recall(thresholds=0.5, name='recall'),
        ],
    )
    print("Model Summary:")
    model.summary()

    # --- 6. Train the Model ---
    print("\nTraining the model...")
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True
    )
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=REDUCE_LR_FACTOR,
        patience=REDUCE_LR_PATIENCE,
        min_lr=MIN_LR
    )

    history = model.fit(
        X_train, y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(X_val, y_val),
        callbacks=[early_stopping, reduce_lr]
    )
    print("\nModel training complete.")

    # --- 7. Evaluate the Model and Calculate Macro-averaged F1-score ---
    print("\nEvaluating the model on validation data...")
    y_pred_proba = model.predict(X_val, batch_size=BATCH_SIZE)
    # Same 0.5 cut-off as the Precision/Recall metrics configured above.
    y_pred_binary = (y_pred_proba > 0.5).astype(int)

    # Integer targets so sklearn sees a clean multilabel-indicator matrix;
    # zero_division=0 scores labels that are never predicted as 0 without
    # emitting an UndefinedMetricWarning.
    y_val_int = np.asarray(y_val).astype(int)
    macro_f1 = f1_score(y_val_int, y_pred_binary, average='macro', zero_division=0)
    print(f"\nMacro-averaged F1-score on validation set: {macro_f1:.4f}")

    _print_history_metrics(history)

    # --- 8. Save the Model and Tokenizer ---
    print(f"\nSaving the model to {MODEL_SAVE_PATH}...")
    model.save(MODEL_SAVE_PATH)
    print("Model saved successfully.")

    # Only the BiGRU pipeline has a fitted tokenizer to persist.
    if MODEL_TYPE == 'BiGRU_Attention' and tokenizer_for_save:
        save_tokenizer(tokenizer_for_save, TOKENIZER_SAVE_PATH)

    # --- 9. Plot Training History ---
    plot_training_history(history)


if __name__ == "__main__":
    # Script entry point: run the full training pipeline only when this file
    # is executed directly, not when it is imported as a module.
    main()