In [None]:
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import MobileNetV2, ResNet50, VGG16
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

class ModelTrainer:
    """Create, train, persist, and run inference with an image classifier.

    The trainer reads its settings from ``config``. The attributes used by
    this class are ``IMAGE_SIZE``, ``LEARNING_RATE``, ``EPOCHS``,
    ``MODEL_FOLDER``, ``MODEL_NAME``, ``TOP_K_PREDICTIONS`` and
    ``CONFIDENCE_THRESHOLD``.

    Attributes:
        config: Configuration object supplied by the caller.
        model: The current Keras model, or ``None`` until one is created
            or loaded.
        class_names: Class labels, indexed by model output position.
        training_history: Per-epoch metrics from the last training run (or
            restored from a metadata file).
    """

    # Floor for ReduceLROnPlateau. It must sit well below the starting
    # learning rate, otherwise the callback can never lower the rate.
    MIN_LEARNING_RATE = 1e-6

    # Appended to the model file's stem to name the metadata sidecar file.
    METADATA_SUFFIX = '_metadata.json'

    def __init__(self, config: "Config"):
        self.config = config
        self.model = None
        self.class_names = []
        self.training_history = {}

    def create_model(self, num_classes: int, model_type: str = 'mobilenet') -> "tf.keras.Model":
        """
        Create a CNN model for image classification

        The created model is stored on ``self.model`` for every model type,
        including the custom one, so it can be trained right away.

        Args:
            num_classes: Number of classes to classify
            model_type: Type of model ('mobilenet', 'resnet', 'vgg', 'custom').
                Any value other than the three pretrained names builds the
                custom CNN.

        Returns:
            Compiled Keras model
        """
        input_shape = (*self.config.IMAGE_SIZE, 3)

        # Built here (not at class level) so the backbone names are only
        # resolved when a model is actually requested.
        backbones = {
            'mobilenet': MobileNetV2,
            'resnet': ResNet50,
            'vgg': VGG16,
        }
        backbone_cls = backbones.get(model_type)

        if backbone_cls is None:
            # Custom model: already compiled by the helper.
            model = self._create_custom_model(input_shape, num_classes)
        else:
            base_model = backbone_cls(
                weights='imagenet',
                include_top=False,
                input_shape=input_shape
            )

            # Freeze base model initially
            base_model.trainable = False

            # Add custom top layers
            model = models.Sequential([
                base_model,
                layers.GlobalAveragePooling2D(),
                layers.Dropout(0.2),
                layers.Dense(128, activation='relu'),
                layers.Dropout(0.2),
                layers.Dense(num_classes, activation='softmax')
            ])
            self._compile(model, ['accuracy', 'top_k_categorical_accuracy'])

        self.model = model
        return model

    def _compile(self, model, metrics: List[str]) -> None:
        """Compile ``model`` with Adam and categorical cross-entropy."""
        model.compile(
            optimizer=optimizers.Adam(learning_rate=self.config.LEARNING_RATE),
            loss='categorical_crossentropy',
            metrics=metrics
        )

    def _create_custom_model(self, input_shape: Tuple, num_classes: int) -> "tf.keras.Model":
        """Create a custom CNN model from scratch and return it compiled."""
        model = models.Sequential([
            layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(128, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(128, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(512, activation='relu'),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation='softmax')
        ])

        self._compile(model, ['accuracy'])
        return model

    @staticmethod
    def _metadata_path(model_path) -> Path:
        """Return the metadata sidecar path for ``model_path`` (str or Path).

        Only the final suffix of the file name is dropped, so the result is
        always different from the model path, whatever its extension, and
        directory names are never rewritten.
        """
        path = Path(model_path)
        return path.with_name(path.stem + ModelTrainer.METADATA_SUFFIX)

    @staticmethod
    def _jsonable_history(history: dict) -> dict:
        """Copy a Keras history dict with every value cast to a plain float.

        Keras logs may contain NumPy scalars, which ``json.dump`` rejects.
        """
        return {
            metric: [float(value) for value in values]
            for metric, values in history.items()
        }

    def train_model(self, train_data, validation_data, class_names: List[str]) -> dict:
        """
        Train the model with given data

        Training stops early when ``val_loss`` stops improving, the learning
        rate is reduced on ``val_loss`` plateaus, and the best model by
        ``val_accuracy`` is checkpointed to
        ``MODEL_FOLDER/<MODEL_NAME>_best.h5``.

        Args:
            train_data: Training data generator
            validation_data: Validation data generator
            class_names: List of class names

        Returns:
            Training history

        Raises:
            ValueError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ValueError("Model not created or loaded")

        self.class_names = class_names

        # Callbacks
        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=3,
                min_lr=self.MIN_LEARNING_RATE
            ),
            ModelCheckpoint(
                filepath=str(self.config.MODEL_FOLDER / f'{self.config.MODEL_NAME}_best.h5'),
                monitor='val_accuracy',
                save_best_only=True
            )
        ]

        # Train model
        history = self.model.fit(
            train_data,
            epochs=self.config.EPOCHS,
            validation_data=validation_data,
            callbacks=callbacks,
            verbose=1
        )

        self.training_history = history.history
        return history.history

    def save_model(self, model_path: Optional[str] = None):
        """Save the trained model and metadata

        Args:
            model_path: Destination of the model file (str or path-like).
                Defaults to ``MODEL_FOLDER/<MODEL_NAME>.h5``. The metadata is
                written next to it as ``<stem>_metadata.json``.

        Raises:
            ValueError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ValueError("Model not created or loaded")

        if model_path is None:
            model_path = self.config.MODEL_FOLDER / f'{self.config.MODEL_NAME}.h5'

        target = Path(model_path)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save(str(target))

        # Save metadata
        metadata = {
            'class_names': self.class_names,
            'image_size': self.config.IMAGE_SIZE,
            'model_name': self.config.MODEL_NAME,
            'training_date': datetime.now().isoformat(),
            'training_history': self._jsonable_history(self.training_history)
        }

        with open(self._metadata_path(target), 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Model saved to {model_path}")

    def load_model(self, model_path: str):
        """Load a trained model and its metadata

        A missing metadata file is not an error: a warning is logged and
        ``class_names`` / ``training_history`` keep their current values.

        Args:
            model_path: Location of the saved model (str or path-like).
        """
        self.model = tf.keras.models.load_model(str(model_path))

        # Load metadata
        metadata_path = self._metadata_path(model_path)
        try:
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except FileNotFoundError:
            logger.warning(f"Metadata file not found: {metadata_path}")
        else:
            self.class_names = metadata.get('class_names', [])
            self.training_history = metadata.get('training_history', {})

    def predict(self, image_array: "np.ndarray") -> List[Dict]:
        """
        Make predictions on preprocessed image

        Only the first row of the model output is used.

        Args:
            image_array: Preprocessed image array

        Returns:
            List of predictions with confidence scores: at most
            ``TOP_K_PREDICTIONS`` entries, highest confidence first, keeping
            only those at or above ``CONFIDENCE_THRESHOLD``. Each entry has
            the keys 'class', 'confidence' and 'percentage'.

        Raises:
            ValueError: If no model has been created or loaded.
        """
        if self.model is None:
            raise ValueError("Model not loaded")

        # Make prediction
        scores = self.model.predict(image_array, verbose=0)[0]

        # Get top predictions (argsort is ascending, so reverse it first)
        top_indices = np.argsort(scores)[::-1][:self.config.TOP_K_PREDICTIONS]

        results = []
        for idx in top_indices:
            confidence = float(scores[idx])
            if confidence < self.config.CONFIDENCE_THRESHOLD:
                continue
            # Fall back to a generic label when no name is known for the index.
            class_name = self.class_names[idx] if idx < len(self.class_names) else f"Class_{idx}"
            results.append({
                'class': class_name,
                'confidence': confidence,
                'percentage': confidence * 100
            })

        return results