# 🌽 Entrenamiento MobileNetV3 - Corn Diseases Detection

**Arquitectura "10/10"** (cabeza Dense(256) → Dense(128) sobre MobileNetV3-Large) — objetivo: >85% de accuracy y >80% de recall

---

## 📋 Pasos:
1. ✅ Setup y Verificación
2. ✅ Configuración y Modelo
3. ✅ Entrenamiento Inicial (hasta 40 épocas, con early stopping)
4. ✅ Fine-tuning (hasta 20 épocas, con early stopping)
5. ✅ Evaluación y Guardado

## 🔧 BLOQUE 1: Setup y Verificación

In [None]:
# 1.1 Montar Google Drive
from google.colab import drive
drive.mount('/content/drive')

# 1.2 Clonar repositorio
!git clone -b main https://github.com/ojgonzalezz/corn-diseases-detection.git
%cd corn-diseases-detection/entrenamiento_modelos

# 1.3 Instalar dependencias
!pip install -q -r requirements.txt

# 1.4 Crear directorios necesarios en Drive
!mkdir -p /content/drive/MyDrive/corn-diseases-detection/models
!mkdir -p /content/drive/MyDrive/corn-diseases-detection/logs
!mkdir -p /content/drive/MyDrive/corn-diseases-detection/mlruns

print("\n✅ Setup completado!")

## 🏗️ BLOQUE 2: Configuración y Creación del Modelo

In [None]:
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV3Large
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from sklearn.utils.class_weight import compute_class_weight

# Importar configuración
from config import *
from utils import *

# Configure the GPU via the project helper (limit taken from config).
setup_gpu(GPU_MEMORY_LIMIT)

# Build the train/validation/test generators from the project helper.
print("\nCreando generadores de datos...")
train_gen, val_gen, test_gen = create_data_generators(
    DATA_DIR, IMAGE_SIZE, BATCH_SIZE, TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT, RANDOM_SEED, DATA_AUGMENTATION
)

print(f"\n📊 Dataset:")
print(f"  Training:   {train_gen.samples} imágenes")
print(f"  Validation: {val_gen.samples} imágenes")
print(f"  Test:       {test_gen.samples} imágenes")

# 'balanced' class weights counteract class imbalance in the training split.
present_classes = np.unique(train_gen.classes)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=train_gen.classes
)
# Key each weight by its real class index. `dict(enumerate(...))` keys by
# position in `present_classes`, which silently shifts weights onto the wrong
# classes if any class index is absent from the training split. Plain
# int/float also keep the printed dict readable.
class_weight_dict = {
    int(class_idx): float(weight)
    for class_idx, weight in zip(present_classes, class_weights)
}
print(f"\n⚖️ Class weights: {class_weight_dict}")

print("\n✅ Configuración completada!")

In [None]:
# Build the model with the "10/10" head: Dense(256) -> Dense(128).
def create_mobilenetv3_model(num_classes, image_size, learning_rate):
    """Build and compile a MobileNetV3-Large classifier with a frozen backbone.

    Args:
        num_classes: number of units in the final softmax layer.
        image_size: (height, width) of the input images; 3 channels are appended.
        learning_rate: learning rate for the Adam optimizer.

    Returns:
        A compiled Keras ``Model`` (categorical cross-entropy, accuracy metric)
        whose ImageNet-pretrained backbone has ``trainable=False``. The backbone
        is called with ``training=False``, so it runs in inference mode even if
        it is unfrozen later for fine-tuning.
    """
    input_shape = (*image_size, 3)

    backbone = MobileNetV3Large(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet',
    )
    # Phase 1 trains only the new classification head.
    backbone.trainable = False

    # NOTE(review): Keras' MobileNetV3 applies its own input rescaling by default
    # and expects raw [0, 255] pixels — confirm create_data_generators does not
    # also rescale to [0, 1].
    inputs = tf.keras.Input(shape=input_shape)
    backbone_out = backbone(inputs, training=False)
    hidden = GlobalAveragePooling2D()(backbone_out)

    # Hidden head layers as (units, dropout rate); each Dense uses L2(1e-3).
    for units, drop_rate in ((256, 0.35), (128, 0.3)):
        hidden = Dense(
            units,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.001),
        )(hidden)
        hidden = Dropout(drop_rate)(hidden)

    outputs = Dense(num_classes, activation='softmax')(hidden)
    classifier = Model(inputs, outputs)

    classifier.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return classifier

# Instantiate the classifier using the config values and report its size.
print("\n🏗️ Creando modelo MobileNetV3-Large...")
model = create_mobilenetv3_model(NUM_CLASSES, IMAGE_SIZE, LEARNING_RATE)
param_count = model.count_params()
print(f"📐 Total parámetros: {param_count:,}")
print("\n✅ Modelo creado!")

## 🚀 BLOQUE 3: Entrenamiento Inicial (hasta 40 épocas)

In [None]:
# Phase-1 callbacks (backbone frozen, only the head trains):
#  - EarlyStopping: stop after EARLY_STOPPING_PATIENCE epochs without a
#    val_accuracy improvement and roll back to the best weights;
#  - ReduceLROnPlateau: halve the LR when val_loss stalls (floor 1e-7);
#  - ModelCheckpoint: keep the best-val_accuracy model on disk.
# The explicit `mode` values are the ones Keras infers from these metric names.
best_checkpoint_path = str(MODELS_DIR / 'mobilenetv3_best.keras')

callbacks = [
    EarlyStopping(
        monitor='val_accuracy',
        mode='max',
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        verbose=1,
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        mode='min',
        factor=0.5,
        patience=REDUCE_LR_PATIENCE,
        min_lr=1e-7,
        verbose=1,
    ),
    ModelCheckpoint(
        best_checkpoint_path,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True,
        verbose=1,
    ),
]

print("\n🚀 Iniciando entrenamiento inicial...\n")
start_time = time.time()

history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1,
)

training_time = time.time() - start_time
best_val_accuracy = max(history.history['val_accuracy'])
print(f"\n✅ Entrenamiento completado en {training_time/60:.2f} minutos")
print(f"📊 Mejor Val Accuracy: {best_val_accuracy:.4f}")

## 🎯 BLOQUE 4: Fine-tuning (hasta 20 épocas)

In [None]:
print("\n🎯 Iniciando fine-tuning...\n")

# Locate the MobileNetV3 backbone as the nested Model inside the classifier
# instead of relying on a hard-coded position in model.layers.
base_model = next(
    (layer for layer in model.layers if isinstance(layer, tf.keras.Model)),
    None,
)
if base_model is None:
    raise ValueError("No se encontró el backbone MobileNetV3 dentro del modelo")

# Unfreeze only the last 20 backbone layers; everything earlier stays frozen.
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False

trainable_layers = sum(layer.trainable for layer in base_model.layers)
print(f"🔓 Capas descongeladas: {trainable_layers} de {len(base_model.layers)}\n")

# Recompile so the trainable changes take effect, with a 20x lower LR.
model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE * 0.05),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Best val_accuracy reached so far. A new ModelCheckpoint starts with no "best"
# value, so without this threshold the first fine-tuning epoch would overwrite
# the phase-1 best model even when it is worse.
best_val_accuracy_so_far = max(history.history['val_accuracy'])

finetune_callbacks = [
    EarlyStopping(
        monitor='val_accuracy',
        patience=8,
        restore_best_weights=True,
        verbose=1,
        mode='max'
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=4,
        min_lr=1e-7,
        verbose=1,
        mode='min'
    ),
    ModelCheckpoint(
        str(MODELS_DIR / 'mobilenetv3_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        initial_value_threshold=best_val_accuracy_so_far,
        verbose=1,
        mode='max'
    )
]

# Time the fine-tuning fit itself. Measuring from phase 1's start_time would
# also count the idle time between notebook cells.
finetune_start = time.time()
history_finetune = model.fit(
    train_gen,
    epochs=20,
    validation_data=val_gen,
    callbacks=finetune_callbacks,
    class_weight=class_weight_dict,
    verbose=1
)
finetune_time = time.time() - finetune_start
total_time = training_time + finetune_time

# Append the fine-tuning epochs to the phase-1 history so plots and logs
# cover both phases (assumes both fits logged the same metric keys).
for key, values in history.history.items():
    values.extend(history_finetune.history[key])

print(f"\n✅ Fine-tuning completado en {finetune_time/60:.2f} minutos")
print(f"⏱️ Tiempo total: {total_time/60:.2f} minutos")

## 📊 BLOQUE 5: Evaluación y Guardado de Resultados

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import json
from datetime import datetime

print("\n📊 Evaluando modelo en test set...\n")

# Evaluate on the held-out test split via the project helper.
evaluation_results = evaluate_model(model, test_gen, CLASSES)

print(f"\n✅ RESULTADOS FINALES:")
print(f"  Test Accuracy: {evaluation_results['test_accuracy']:.4f} ({evaluation_results['test_accuracy']*100:.2f}%)")
print(f"  Test Loss:     {evaluation_results['test_loss']:.4f}")

# Per-class precision / recall / F1 from the helper's classification report.
print(f"\n📋 Métricas por Clase:")
for class_name in CLASSES:
    metrics = evaluation_results['classification_report'][class_name]
    print(f"\n  {class_name}:")
    print(f"    Precision: {metrics['precision']:.4f}")
    print(f"    Recall:    {metrics['recall']:.4f}")
    print(f"    F1-Score:  {metrics['f1-score']:.4f}")

# Training curves (merged phase-1 + fine-tuning history).
plot_path = LOGS_DIR / 'mobilenetv3_training_history.png'
plot_training_history(history, plot_path)
print(f"\n💾 Gráfico guardado: {plot_path}")

# Confusion matrix; the returned matrix is also written to the JSON log below.
cm_path = LOGS_DIR / 'mobilenetv3_confusion_matrix.png'
cm = plot_confusion_matrix(
    evaluation_results['y_true'],
    evaluation_results['y_pred'],
    CLASSES,
    cm_path
)
print(f"💾 Matriz de confusión guardada: {cm_path}")

# Save the model as it stands after fine-tuning (separate from the
# best-val_accuracy checkpoint written by ModelCheckpoint).
model_path = MODELS_DIR / 'mobilenetv3_final.keras'
model.save(str(model_path))
print(f"💾 Modelo final guardado: {model_path}")

# 'epochs' is only the phase-1 maximum. Also record the epochs actually run,
# because training had two phases and early stopping may shorten either one.
hyperparameters = {
    'model_name': 'MobileNetV3-Large',
    'image_size': IMAGE_SIZE,
    'batch_size': BATCH_SIZE,
    'epochs': EPOCHS,
    'learning_rate': LEARNING_RATE,
    'architecture': 'Dense(256)->Dense(128) [10/10]',
    'fine_tune_epochs_run': len(history_finetune.history['loss']),
    'total_epochs_run': len(history.history['loss']),
}

log_path = LOGS_DIR / 'mobilenetv3_training_log.json'
save_training_log(
    log_path,
    'MobileNetV3-Large',
    hyperparameters,
    history,
    evaluation_results,
    cm,
    total_time
)
print(f"💾 Log guardado: {log_path}")

print("\n🎉 ¡ENTRENAMIENTO COMPLETADO EXITOSAMENTE!")