# =============================================================================
# NOTEBOOK 03 - OPTIMISATION DU MODÈLE POUR RASPBERRY PI
# =============================================================================


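## ⚡ Optimisation du Modèle pour Déploiement Edge
# 
# Ce notebook optimise le modèle pour :
# 1. Quantification (INT8, FP16, dynamique)
# 2. Pruning (élagage des poids) — configuration uniquement : le modèle élagué n'est pas ré-entraîné ici
# 3. Conversion TensorFlow Lite
# 4. Préparation pour Coral Edge TPU (modèle INT8 ; la compilation edgetpu_compiler n'est pas incluse)
# 5. Benchmark de performance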

In [None]:
# Imports
!pip install -q tensorflow tensorflow-model-optimization

import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot
from pathlib import Path
import time
import json

print(f"TensorFlow: {tf.__version__}")

In [None]:
# Mount Google Drive: the trained model and the training images live there.
from google.colab import drive
drive.mount('/content/drive')

# Locations: model and data are read from Drive; optimized artifacts are
# written to the local Colab disk (copied back to Drive at the end).
MODEL_DIR = Path('/content/drive/MyDrive/drone-agri-ai/models')
DATA_DIR = Path('/content/drive/MyDrive/drone-agri-ai/data/train')
OUTPUT_DIR = Path('/content/optimized_models')
OUTPUT_DIR.mkdir(exist_ok=True)

# Load the trained Keras model and show its architecture.
model = keras.models.load_model(MODEL_DIR / 'plant_model.keras')
model.summary()

def create_representative_dataset(data_dir, num_samples=200, max_classes=20,
                                  per_class=10, image_size=(224, 224)):
    """Build a small preprocessed image set for quantization calibration.

    Visits up to ``max_classes`` class sub-directories of ``data_dir`` (sorted
    by name so the selection is reproducible) and takes up to ``per_class``
    ``*.jpg`` images from each (also sorted). Collection stops as soon as
    ``num_samples`` images are gathered. Each image is decoded as RGB,
    resized to ``image_size``, scaled to [0, 1] and normalized with the
    ImageNet channel mean/std.

    Args:
        data_dir: Directory (``Path`` or ``str``) whose sub-directories are
            the classes.
        num_samples: Maximum total number of images returned.
        max_classes: Maximum number of class directories visited.
        per_class: Maximum number of images read from each class.
        image_size: ``(height, width)`` passed to ``tf.image.resize``.

    Returns:
        ``float32`` array of shape ``(N, height, width, 3)`` with
        ``0 <= N <= num_samples``. When no image is found (or
        ``num_samples <= 0``) the array is empty but keeps its 4-D shape.
    """
    height, width = image_size
    empty = np.empty((0, height, width, 3), dtype=np.float32)
    if num_samples <= 0:
        return empty

    # ImageNet normalization constants, built once rather than per image.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Keep only directories *before* slicing, so stray files in data_dir do
    # not use up class slots; sorting makes the chosen subset deterministic.
    class_dirs = sorted(p for p in Path(data_dir).iterdir() if p.is_dir())[:max_classes]

    images = []
    for class_dir in class_dirs:
        for img_path in sorted(class_dir.glob('*.jpg'))[:per_class]:
            img = tf.io.read_file(str(img_path))
            img = tf.image.decode_jpeg(img, channels=3)
            img = tf.image.resize(img, [height, width])
            img = tf.cast(img, tf.float32) / 255.0
            images.append((img.numpy() - mean) / std)
            if len(images) >= num_samples:
                return np.stack(images)

    return np.stack(images) if images else empty

# Build the calibration set once; representative_dataset_gen reads it later.
print("Cr√©ation du dataset repr√©sentatif...")
representative_data = create_representative_dataset(DATA_DIR)
print(f"Dataset: {representative_data.shape}")

# --- Baseline TFLite conversion (FP32, no optimization flags) ---
print("üì¶ Conversion TFLite FP32...")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_fp32 = converter.convert()
fp32_path = OUTPUT_DIR / 'plant_model_fp32.tflite'
fp32_path.write_bytes(tflite_fp32)
print(f"‚úÖ FP32: {len(tflite_fp32) / 1024 / 1024:.2f} MB")

# --- Float16 post-training quantization ---
print("üì¶ Conversion TFLite FP16...")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_fp16 = converter.convert()
fp16_path = OUTPUT_DIR / 'plant_model_fp16.tflite'
fp16_path.write_bytes(tflite_fp16)
print(f"‚úÖ FP16: {len(tflite_fp16) / 1024 / 1024:.2f} MB")

# Full-integer (INT8) TFLite conversion.
print("üì¶ Conversion TFLite INT8...")

def representative_dataset_gen(data=None, max_samples=100):
    """Yield calibration batches for ``TFLiteConverter.representative_dataset``.

    The converter calls this with no arguments. The defaults therefore give
    the original behavior: up to 100 single-sample batches taken from the
    module-level ``representative_data`` array.

    Args:
        data: Array whose first axis indexes samples. Defaults to the global
            ``representative_data``.
        max_samples: Maximum number of batches to yield.

    Yields:
        A one-element list holding a ``float32`` batch containing one sample.
        The batch is a slice, so the leading batch axis has size 1.
    """
    if data is None:
        data = representative_data
    for i in range(min(max_samples, len(data))):
        # Slice (not index) so the batch dimension is preserved.
        yield [data[i:i + 1].astype(np.float32)]

# Defined before the conversion so later cells can always reference it, even
# when the conversion fails.
int8_path = OUTPUT_DIR / 'plant_model_int8.tflite'

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# NOTE: with a uint8 input tensor the model no longer takes the
# ImageNet-normalized floats used for calibration. Each input must be
# quantized with the (scale, zero_point) from the interpreter's
# get_input_details()[0]['quantization'].
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.float32

try:
    tflite_int8 = converter.convert()
except Exception as e:  # best effort: conversion fails e.g. if an op has no INT8 kernel
    print(f"‚ö†Ô∏è Erreur INT8: {e}")
    tflite_int8 = None
    # Drop a file left by an earlier run so it is not benchmarked as if it
    # came from this conversion.
    if int8_path.exists():
        int8_path.unlink()
else:
    int8_path.write_bytes(tflite_int8)
    print(f"‚úÖ INT8: {len(tflite_int8) / 1024 / 1024:.2f} MB")

# Dynamic-range quantization: Optimize.DEFAULT with no representative
# dataset, so no calibration data is involved.
print("üì¶ Conversion avec quantification dynamique...")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_dynamic = converter.convert()

dynamic_path = OUTPUT_DIR / 'plant_model_dynamic.tflite'
dynamic_path.write_bytes(tflite_dynamic)
print(f"‚úÖ Dynamic: {len(tflite_dynamic) / 1024 / 1024:.2f} MB")

In [None]:
# Magnitude-based pruning setup (tensorflow_model_optimization).
print("‚úÇÔ∏è Application du Pruning...")

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Target sparsity ramps polynomially from 20 % at step 0 to 70 % at step 1000.
pruning_params = dict(
    pruning_schedule=tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.20,
        final_sparsity=0.70,
        begin_step=0,
        end_step=1000,
    ),
)

# Only Dense layers are wrapped for pruning.
def apply_pruning_to_dense(layer):
    """``clone_function`` hook for ``keras.models.clone_model``.

    Args:
        layer: The layer currently being cloned.

    Returns:
        ``layer`` wrapped by ``prune_low_magnitude`` (with ``pruning_params``)
        when it is a ``keras.layers.Dense``; otherwise ``layer`` itself.
    """
    if not isinstance(layer, keras.layers.Dense):
        return layer
    return prune_low_magnitude(layer, **pruning_params)

# Clone the model, passing every layer through apply_pruning_to_dense.
# NOTE(review): in this notebook the pruned clone is never compiled,
# fine-tuned (which requires the tfmot UpdatePruningStep callback), stripped
# with strip_pruning, or converted. The TFLite files benchmarked here are
# therefore not pruned.
model_for_pruning = keras.models.clone_model(model, clone_function=apply_pruning_to_dense)

print("‚úÖ Pruning configur√©")

In [None]:
# Benchmark the converted models.
print("‚è±Ô∏è Benchmark des mod√®les...")

def benchmark_tflite(model_path, num_runs=50, warmup_runs=10, seed=0):
    """Measure single-input inference latency of a TFLite model.

    A single random input matching the first input tensor's shape and dtype
    is generated. For integer tensors (uint8 *and* int8) it is uniform over
    the dtype's full range; for float tensors it is uniform over [0, 1). The
    input is fed ``warmup_runs`` times without timing, then ``num_runs``
    times with timing. Each timed run covers ``set_tensor`` + ``invoke``.

    Args:
        model_path: Path (``str`` or ``Path``) of the ``.tflite`` file.
        num_runs: Number of timed runs; must be >= 1.
        warmup_runs: Number of untimed runs executed first.
        seed: Seed of the random input, for reproducible runs.

    Returns:
        dict with ``mean_ms``, ``std_ms``, ``min_ms``, ``max_ms`` (latency in
        milliseconds, as Python floats) and ``fps`` (``1000 / mean_ms``).

    Raises:
        ValueError: If ``num_runs`` < 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    interpreter = tf.lite.Interpreter(model_path=str(model_path))
    interpreter.allocate_tensors()

    input_detail = interpreter.get_input_details()[0]
    input_index = input_detail['index']
    input_shape = tuple(input_detail['shape'])
    input_dtype = np.dtype(input_detail['dtype'])

    rng = np.random.default_rng(seed)
    if np.issubdtype(input_dtype, np.integer):
        # endpoint=True so the dtype's maximum (e.g. 255 for uint8) is reachable.
        info = np.iinfo(input_dtype)
        input_data = rng.integers(info.min, info.max, size=input_shape,
                                  dtype=input_dtype, endpoint=True)
    else:
        input_data = rng.random(input_shape).astype(input_dtype)

    # Warmup: first invocations include one-off setup costs.
    for _ in range(warmup_runs):
        interpreter.set_tensor(input_index, input_data)
        interpreter.invoke()

    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        interpreter.set_tensor(input_index, input_data)
        interpreter.invoke()
        times.append((time.perf_counter() - start) * 1000)

    mean_ms = float(np.mean(times))
    return {
        'mean_ms': mean_ms,
        'std_ms': float(np.std(times)),
        'min_ms': float(np.min(times)),
        'max_ms': float(np.max(times)),
        'fps': 1000 / mean_ms,
    }

# Benchmark every model produced in this session. INT8 is included only when
# its conversion succeeded (tflite_int8 is None after a failure). This avoids
# a NameError on int8_path when the INT8 conversion raised.
results = {}
candidates = [
    ('FP32', fp32_path),
    ('FP16', fp16_path),
    ('Dynamic', dynamic_path),
]
if tflite_int8 is not None:
    candidates.append(('INT8', OUTPUT_DIR / 'plant_model_int8.tflite'))

for name, path in candidates:
    if not path.exists():
        continue
    print(f"Benchmark {name}...")
    results[name] = benchmark_tflite(path)
    results[name]['size_mb'] = path.stat().st_size / 1024 / 1024

# Print a per-model summary of size, latency and throughput.
print("\n" + "=" * 60)
print("üìä R√âSULTATS DU BENCHMARK")
print("=" * 60)

for name, stats in results.items():
    print(
        f"\n{name}:\n"
        f"  Taille: {stats['size_mb']:.2f} MB\n"
        f"  Temps: {stats['mean_ms']:.2f} ¬± {stats['std_ms']:.2f} ms\n"
        f"  FPS: {stats['fps']:.1f}"
    )

# Three side-by-side bar charts: size, latency and throughput per model.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

names = list(results)
sizes = [results[n]['size_mb'] for n in names]
times = [results[n]['mean_ms'] for n in names]
fps = [results[n]['fps'] for n in names]

for ax, values, color, ylabel, title in zip(
    axes,
    (sizes, times, fps),
    ('steelblue', 'coral', 'forestgreen'),
    ('Taille (MB)', 'Temps (ms)', 'FPS'),
    ('Taille du mod√®le', 'Temps d\'inf√©rence', 'Images par seconde'),
):
    ax.bar(names, values, color=color)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'benchmark_comparison.png', dpi=150)
plt.show()

In [None]:
# Report the model recommended for the Raspberry Pi.
print("\nüéØ RECOMMANDATION POUR RASPBERRY PI")
print("=" * 50)

# The recommendation is fixed to FP16; it is not derived from the results.
# NOTE(review): the TFLite float16-quantization docs say FP16 weights are
# dequantized to float32 on CPU. On a Raspberry Pi this mostly saves size,
# not latency, so cross-check against the measured numbers above.
best_model, best_path = 'FP16', fp16_path

print(f"""
Mod√®le recommand√©: {best_model}
- Taille: {results[best_model]['size_mb']:.2f} MB
- Temps: {results[best_model]['mean_ms']:.2f} ms
- FPS: {results[best_model]['fps']:.1f}

‚úÖ Ce mod√®le offre le meilleur compromis taille/performance
   pour un Raspberry Pi 4.

Pour Coral Edge TPU, utilisez le mod√®le INT8.
""")

In [None]:
# Copier le modèle final
!cp {best_path} /content/drive/MyDrive/drone-agri-ai/models/plant_model.tflite
!cp {OUTPUT_DIR}/* /content/drive/MyDrive/drone-agri-ai/models/

# Persist the benchmark numbers next to the optimized models.
(OUTPUT_DIR / 'benchmark_results.json').write_text(json.dumps(results, indent=2))

!cp {OUTPUT_DIR}/benchmark_results.json /content/drive/MyDrive/drone-agri-ai/models/

print("‚úÖ Mod√®les et r√©sultats copi√©s sur Google Drive!")

# Trigger a browser download of the recommended model (Colab only).
from google.colab import files
files.download(f"{best_path}")
print(f"‚úÖ Mod√®le {best_model} t√©l√©charg√©!")