In [None]:
import os
import gc
import tensorflow as tf

# Import the refactored components
from src.config import Config, setup_environment
from src.data_handler import DatasetProvider
from src.trainer import ModelTrainer

# --- Run configuration -------------------------------------------------------
CONFIG = Config()  # single configuration object shared by every cell below
setup_environment(CONFIG)  # environment setup driven by CONFIG (defined in src.config)

In [None]:
# Main loop: train, fine-tune and evaluate one specialist model for each method
# in CONFIG.METHODS_TO_TRAIN. The per-method work lives in a helper so that every
# dataset/model reference is a local variable that is released when the helper
# returns, on both the success path and the skip path.


def _train_specialist(method):
    """Train, fine-tune and evaluate the specialist model for a single method.

    Args:
        method: One entry of ``CONFIG.METHODS_TO_TRAIN``; passed through to
            ``DatasetProvider`` and ``ModelTrainer``.

    Returns:
        True if the model was trained, False if the method was skipped because
        its training or validation dataset was missing (falsy).
    """
    # 1. Prepare datasets for the current method.
    data_provider = DatasetProvider(config=CONFIG, method=method)
    train_ds = data_provider.train_ds
    val_ds = data_provider.val_ds
    test_ds = data_provider.test_ds

    if not all([train_ds, val_ds]):
        print(f"Skipping {method} due to data loading issues.")
        return False

    # 2. Build and train the model. clear_session() resets the global Keras
    #    state accumulated by the previous method's model before building anew.
    tf.keras.backend.clear_session()
    trainer = ModelTrainer(config=CONFIG, method_name=method)

    # Stage 1: train the head.
    trainer.train_head(train_ds, val_ds)

    # Stage 2: fine-tune the model.
    trainer.fine_tune(train_ds, val_ds)

    # 3. Evaluate the final model. The guard above validates only train/val,
    #    so a missing test split must not crash the remaining methods' runs.
    if test_ds is None:
        print(f"No test dataset for {method}; skipping evaluation.")
        print(f"\n--- COMPLETED TRAINING (NO EVALUATION) FOR {method} ---")
    else:
        trainer.evaluate(test_ds)
        print(f"\n--- COMPLETED TRAINING AND EVALUATION FOR {method} ---")
    return True


skipped_methods = []
for method in CONFIG.METHODS_TO_TRAIN:
    if not _train_specialist(method):
        skipped_methods.append(method)
    # 4. The helper's locals (provider, datasets, trainer) are out of scope now;
    #    collect before the next method loads its data.
    gc.collect()

# Report skipped methods instead of unconditionally claiming that all trained.
if skipped_methods:
    print(
        f"\n--- FINISHED. SKIPPED {len(skipped_methods)} METHOD(S): "
        f"{', '.join(str(m) for m in skipped_methods)} ---"
    )
else:
    print("\n--- ALL SPECIALIST MODELS HAVE BEEN TRAINED. ---")