# =============================================================================
# NOTEBOOK D'ENTRAÎNEMENT - DRONE AGRI AI
# =============================================================================

# 🌱 Entraînement du Modèle de Classification de Plantes
# 
# Ce notebook entraîne le modèle multi-sorties pour:
# 1. Détection plante/non-plante
# 2. Classification espèce et maladie
# 3. Score de santé
# 4. Stade de croissance

In [None]:
# Installation des dépendances
!pip install -q tensorflow tensorflow-model-optimization albumentations loguru

# Imports
import os
import sys
import json
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime

# Check which accelerators TensorFlow can see
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU disponible: {tf.config.list_physical_devices('GPU')}")

# Mount Google Drive (the last cell copies the trained artifacts there)
from google.colab import drive
drive.mount('/content/drive')

# Working paths
PROJECT_DIR = Path('/content/drone-agri-ai')
DATA_DIR = Path('/content/data')
MODELS_DIR = PROJECT_DIR / 'models'
# parents=True: nothing earlier in the notebook creates PROJECT_DIR, so a plain
# mkdir() would raise FileNotFoundError on a fresh runtime.
MODELS_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Download the datasets from Kaggle.
# NOTE(review): the kaggle CLI reads credentials from ~/.kaggle/kaggle.json, but
# both ways of providing that file below are commented out. Enable one of them
# (or make sure the file already exists) before the download step — TODO confirm.

# Kaggle setup
!pip install -q kaggle
from google.colab import files

print("Uploadez votre fichier kaggle.json")
# files.upload()  # uncomment for an interactive upload

# Or reuse existing credentials stored on Google Drive
!mkdir -p ~/.kaggle
# !cp /content/drive/MyDrive/kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json

# Download the PlantVillage archive into /content/data and extract it there
!kaggle datasets download -d abdallahalidev/plantvillage-dataset -p /content/data
!unzip -q /content/data/plantvillage-dataset.zip -d /content/data/

In [None]:
# Explore the extracted data: one sub-folder per class, print a few image counts
import glob

data_path = Path('/content/data/plantvillage dataset/color')
classes = sorted(entry.name for entry in data_path.iterdir() if entry.is_dir())

print(f"Nombre de classes: {len(classes)}")
print("\nExemples de classes:")
for class_dir_name in classes[:10]:
    # Counts every entry in the class folder
    n_images = sum(1 for _ in (data_path / class_dir_name).glob('*'))
    print(f"  - {class_dir_name}: {n_images} images")

In [None]:
# Prepare the data
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Parameters
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
VALIDATION_SPLIT = 0.2

# Training images: rescaled to [0, 1] and randomly augmented.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest'
)

# Validation images: same rescaling and same validation_split on the same
# directory, so the same files are held out, but WITHOUT augmentation.
# Reusing train_datagen here would randomly rotate/shift/flip the validation
# images, making every validation metric (and the callbacks driven by them) noisy.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT
)

# Iterators over the two disjoint subsets of the same directory tree
train_generator = train_datagen.flow_from_directory(
    data_path,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    subset='training',
    shuffle=True
)

val_generator = val_datagen.flow_from_directory(
    data_path,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='sparse',
    subset='validation',
    shuffle=False
)

print(f"\nImages d'entraînement: {train_generator.samples}")
print(f"Images de validation: {val_generator.samples}")
print(f"Nombre de classes: {train_generator.num_classes}")

# Save the class mapping. Names are ordered by their integer index so that
# class_names[i] is the class of output unit i, regardless of dict order.
class_to_idx = train_generator.class_indices
class_mapping = {
    "class_names": sorted(class_to_idx, key=class_to_idx.get),
    "class_to_idx": class_to_idx,
    "class_info": {}
}

# Parse "<plant>___<condition>" folder names into plant / condition / health flag;
# names that do not split into exactly two parts get no class_info entry.
for class_name in class_mapping["class_names"]:
    parts = class_name.split("___")
    if len(parts) == 2:
        plant, condition = parts
        class_mapping["class_info"][class_name] = {
            "plant": plant.replace("_", " "),
            "condition": condition.replace("_", " "),
            "is_healthy": "healthy" in condition.lower()
        }

with open(MODELS_DIR / 'class_mapping.json', 'w', encoding='utf-8') as f:
    json.dump(class_mapping, f, indent=2)

print("Mapping sauvegardé!")

In [None]:
# Construire le modèle
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, Model

def build_model(num_classes, input_shape=(224, 224, 3), freeze_ratio=0.7):
    """Build the multi-output plant model on an ImageNet EfficientNetB0 backbone.

    Args:
        num_classes: number of units in the ``classification`` softmax head.
        input_shape: shape of one input image (height, width, channels).
        freeze_ratio: fraction of the backbone layers, taken from the start of
            ``base_model.layers``, whose weights are frozen. The rest are fine-tuned.

    Returns:
        An uncompiled ``keras.Model`` with one image input and a dict of outputs:
        ``is_plant`` (sigmoid, 1 unit), ``classification`` (softmax,
        ``num_classes`` units), ``health_score`` (sigmoid, 1 unit) and
        ``growth_stage`` (softmax, 4 units).

    The model takes pixels scaled to [0, 1], the range produced by the
    ``rescale=1./255`` data generators. Keras' EfficientNet embeds its own
    preprocessing and expects [0, 255] inputs, so a ``Rescaling(255.0)`` layer
    maps the input back to that range before the backbone.
    """
    # Pretrained backbone without its ImageNet classification head
    base_model = EfficientNetB0(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet'
    )

    # Freeze the first `freeze_ratio` of the backbone layers
    for layer in base_model.layers[:int(len(base_model.layers) * freeze_ratio)]:
        layer.trainable = False

    inputs = keras.Input(shape=input_shape)

    # [0, 1] -> [0, 255]: the input range EfficientNet's built-in preprocessing expects
    x = layers.Rescaling(255.0)(inputs)

    # Backbone + spatial pooling to one feature vector per image
    x = base_model(x)
    x = layers.GlobalAveragePooling2D()(x)

    # Shared trunk feeding every head
    x = layers.Dense(512, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)

    # Output 1: is this a plant? (binary probability)
    is_plant = layers.Dense(64, activation='relu')(x)
    is_plant = layers.Dense(1, activation='sigmoid', name='is_plant')(is_plant)

    # Output 2: species/disease classification
    classification = layers.Dense(128, activation='relu')(x)
    classification = layers.Dense(num_classes, activation='softmax', name='classification')(classification)

    # Output 3: health score in [0, 1]
    health = layers.Dense(64, activation='relu')(x)
    health = layers.Dense(1, activation='sigmoid', name='health_score')(health)

    # Output 4: growth stage (4 classes)
    growth = layers.Dense(64, activation='relu')(x)
    growth = layers.Dense(4, activation='softmax', name='growth_stage')(growth)

    # Output keys must match the loss/metric dict keys used at compile time
    model = Model(
        inputs=inputs,
        outputs={
            'is_plant': is_plant,
            'classification': classification,
            'health_score': health,
            'growth_stage': growth
        }
    )

    return model

# Instantiate the model with one classification unit per dataset class
model = build_model(num_classes=train_generator.num_classes)
model.summary()

In [None]:
# Build the per-head training targets
class MultiOutputGenerator(keras.utils.Sequence):
    """Wrap a sparse-label image iterator so each batch carries one target per model head.

    Each item is ``(X, outputs)``. ``X`` is the wrapped iterator's image batch,
    and ``outputs`` maps the model's output names to targets:

    - ``is_plant``: all ones, shape (batch, 1).
    - ``classification``: the wrapped iterator's sparse labels, unchanged.
    - ``health_score``: 1.0 for classes whose ``class_info`` says healthy (or
      that have no ``class_info`` entry), 0.3 otherwise; shape (batch, 1).
    - ``growth_stage``: random integers in [0, 4), shape (batch,).

    NOTE(review): ``is_plant`` never sees a negative example, and the
    ``growth_stage`` labels are random placeholders redrawn on every access.
    Neither head can learn anything meaningful from these targets.
    """

    def __init__(self, generator, class_mapping):
        """
        Args:
            generator: indexable batch iterator returning ``(X, y)`` with
                integer-valued sparse class labels ``y`` (e.g. the result of
                ``flow_from_directory(class_mode='sparse')``).
            class_mapping: dict with ``class_to_idx`` (name -> index) and
                ``class_info`` (name -> dict with an ``is_healthy`` flag).
        """
        # In Keras 3 this base class is PyDataset, which warns when its
        # constructor is skipped; the call is a no-op on the Keras 2 Sequence.
        super().__init__()
        self.generator = generator
        self.class_mapping = class_mapping
        class_to_idx = class_mapping['class_to_idx']
        # class_names[i] is the class whose label index is i; ordering by the
        # index itself avoids depending on the dict's insertion order.
        self.class_names = sorted(class_to_idx, key=class_to_idx.get)
        class_info = class_mapping['class_info']
        # Health target per label index: 1.0 if healthy, 0.3 if diseased.
        # Computed once here instead of once per sample per batch.
        self._health_by_class = np.array([
            1.0 if class_info.get(name, {}).get('is_healthy', True) else 0.3
            for name in self.class_names
        ])

    def __len__(self):
        """Number of batches: the same as the wrapped iterator."""
        return len(self.generator)

    def __getitem__(self, idx):
        """Return ``(X, outputs)`` for batch ``idx`` of the wrapped iterator."""
        X, y = self.generator[idx]
        labels = np.asarray(y).astype(int)
        batch_size = len(labels)

        outputs = {
            # Every image in this dataset is treated as a plant
            'is_plant': np.ones((batch_size, 1)),
            'classification': y,
            'health_score': self._health_by_class[labels].reshape(batch_size, 1),
            # Simulated growth stage: random placeholder labels
            'growth_stage': np.random.randint(0, 4, size=batch_size).astype(float),
        }
        return X, outputs

    def on_epoch_end(self):
        """Forward Keras' end-of-epoch hook to the wrapped iterator."""
        self.generator.on_epoch_end()

# Wrap both Keras iterators so each batch carries a target for every output head
train_gen = MultiOutputGenerator(generator=train_generator, class_mapping=class_mapping)
val_gen = MultiOutputGenerator(generator=val_generator, class_mapping=class_mapping)

In [None]:
# Compile the model: one loss, loss weight and metric list per output head.
# Classification dominates the total loss (weight 2.0); each auxiliary head gets 0.5.
head_losses = {
    'is_plant': 'binary_crossentropy',
    'classification': 'sparse_categorical_crossentropy',
    'health_score': 'mse',
    'growth_stage': 'sparse_categorical_crossentropy',
}
head_loss_weights = {
    'is_plant': 0.5,
    'classification': 2.0,
    'health_score': 0.5,
    'growth_stage': 0.5,
}
head_metrics = {
    'is_plant': ['accuracy'],
    'classification': ['accuracy'],
    'health_score': ['mae'],
    'growth_stage': ['accuracy'],
}

optimizer = keras.optimizers.AdamW(learning_rate=0.001, weight_decay=1e-5)
model.compile(
    optimizer=optimizer,
    loss=head_losses,
    loss_weights=head_loss_weights,
    metrics=head_metrics,
)

In [None]:
# Callbacks
# The checkpoint and early stopping both select on validation classification
# accuracy, so the weights EarlyStopping restores (and that get exported) agree
# with best_model.keras. val_loss is a poor selector here because it also sums
# the growth_stage loss, whose training labels are random placeholders.
SELECTION_METRIC = 'val_classification_accuracy'

callbacks = [
    # Keep the best weights (by validation classification accuracy) on disk
    keras.callbacks.ModelCheckpoint(
        str(MODELS_DIR / 'best_model.keras'),
        monitor=SELECTION_METRIC,
        mode='max',
        save_best_only=True,
        verbose=1
    ),
    # Stop after 10 epochs without improvement of the same metric; restore best weights
    keras.callbacks.EarlyStopping(
        monitor=SELECTION_METRIC,
        mode='max',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    # Halve the learning rate when the total validation loss stalls for 5 epochs
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    keras.callbacks.TensorBoard(
        log_dir=str(MODELS_DIR / 'logs'),
        histogram_freq=1
    )
]

In [None]:
# Training
EPOCHS = 50

history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1
)

# Plot the training curves: one panel per tracked quantity, train vs. validation.
# Each entry: (subplot position, history key without the "val_" prefix, title).
PLOT_PANELS = [
    ((0, 0), 'classification_accuracy', 'Classification Accuracy'),
    ((0, 1), 'loss', 'Total Loss'),
    ((1, 0), 'health_score_mae', 'Health Score MAE'),
    ((1, 1), 'growth_stage_accuracy', 'Growth Stage Accuracy'),
]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
for position, metric_key, title in PLOT_PANELS:
    ax = axes[position]
    ax.plot(history.history[metric_key], label='Train')
    ax.plot(history.history[f'val_{metric_key}'], label='Val')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig(str(MODELS_DIR / 'training_history.png'))
plt.show()

# Save the final in-memory model
final_model_path = MODELS_DIR / 'plant_model.keras'
model.save(str(final_model_path))
print(f"Modèle sauvegardé: {final_model_path}")

In [None]:
# Convert the in-memory model to TFLite with default optimizations and float16 types
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]

tflite_model = converter.convert()

# convert() returns the serialized flatbuffer as bytes
tflite_path = MODELS_DIR / 'plant_model.tflite'
tflite_path.write_bytes(tflite_model)

print(f"Modèle TFLite: {tflite_path}")
print(f"Taille: {len(tflite_model) / 1024 / 1024:.2f} MB")

In [None]:
# Tester le mod√®le TFLite
interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("Input:", input_details[0]['shape'], input_details[0]['dtype'])
print("\nOutputs:")
for out in output_details:
    print(f"  - {out['name']}: {out['shape']}")

# Copier sur Google Drive
!cp -r {MODELS_DIR}/* /content/drive/MyDrive/drone-agri-ai/models/
print("Fichiers copi√©s sur Google Drive!")

# T√©l√©charger les fichiers
from google.colab import files
files.download(str(MODELS_DIR / 'plant_model.tflite'))
files.download(str(MODELS_DIR / 'class_mapping.json'))