## 📦 Setup & Environment


In [None]:
# Install required packages
!pip install -q tensorflow>=2.15.0 kaggle opendatasets scikit-learn matplotlib seaborn

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
from pathlib import Path
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

## 🔑 Configuration

### Set your dataset paths and training settings here


In [None]:
# ========== UNIFIED MODEL - ALL CROPS ==========
# Root path to your "Leaf Nutrient Data Sets" folder on Google Drive
NUTRIENT_DATASETS_ROOT = '/content/drive/MyDrive/Leaf Nutrient Data Sets'

# ALL 12 CROPS - Combined into ONE model automatically!
# Maps crop key -> dataset folder name under NUTRIENT_DATASETS_ROOT. The crop key is
# also used as the prefix of the unified class names ("{crop}_{deficiency}").
CROP_DATASETS = {
    'rice': 'Rice Nutrients',
    'wheat': 'Wheat Nitrogen',
    'tomato': 'Tomato Nutrients',
    'maize': 'Maize Nutrients',
    'banana': 'Banana leaves Nutrient',
    'coffee': 'Coffee Nutrients',
    'cucumber': 'Cucumber Nutrients',
    'eggplant': 'EggPlant Nutrients',
    'ashgourd': 'Ashgourd Nutrients',
    'bittergourd': 'Bittergourd Nutrients',
    'ridgegourd': 'Ridgegourd',
    'snakegourd': 'Snakegourd Nutrients'
}

# =============================================================
# 🚀 ULTRA MEMORY-SAFE SETTINGS FOR COLAB FREE TIER
# =============================================================
# ⚠️ Colab free tier has ~12GB RAM - the input pipelines therefore avoid
# Dataset.cache() and large in-memory shuffle buffers.
IMG_SIZE = 224  # square input resolution (pixels) fed to MobileNetV2
BATCH_SIZE = 32   # Small batches = less RAM
PLANTVILLAGE_EPOCHS = 8   # Stage 2 epoch budget (EarlyStopping may stop earlier)
UNIFIED_EPOCHS = 15  # Stage 3 epoch budget (EarlyStopping may stop earlier)
LEARNING_RATE_STAGE2 = 3e-4  # Slightly lower for stability
LEARNING_RATE_STAGE3 = 1e-4

# Enable mixed precision training: layers compute in float16 while variables stay
# float32 (Keras 'mixed_float16' policy); this is a process-wide global setting.
tf.keras.mixed_precision.set_global_policy('mixed_float16')
print("üöÄ Mixed precision enabled (FP16)")

# Output paths (every model, plot, report and metadata file is written here)
OUTPUT_DIR = '/content/fasalvaidya_unified_model'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("üåæ UNIFIED MULTI-CROP MODEL TRAINING")
print("="*50)
print(f"Training ONE model for ALL {len(CROP_DATASETS)} crops")
print(f"\nüî• ULTRA MEMORY-SAFE MODE (prevents RAM crash):")
print(f"   - Batch size: {BATCH_SIZE} (small = safe)")
print(f"   - NO caching (prevents RAM explosion)")
print(f"   - NO shuffle buffer (prevents RAM explosion)")
print(f"   - Stage 2: {PLANTVILLAGE_EPOCHS} epochs")
print(f"   - Stage 3: {UNIFIED_EPOCHS} epochs")
print(f"\n‚ö° Training will be slower but WON'T CRASH!")

## 💾 Mount Google Drive


In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Verify ALL crop datasets exist.
# Each crop folder is expected to hold one sub-directory per class; only
# sub-directories are counted. Missing crops are collected and reported as a
# warning - the cell does not raise.
print("üîç Verifying crop datasets...")
missing_crops = []
for crop, folder_name in CROP_DATASETS.items():
    crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
    if os.path.exists(crop_path):
        num_classes = len([d for d in os.listdir(crop_path) if os.path.isdir(os.path.join(crop_path, d))])
        print(f"‚úÖ {crop.upper()}: {num_classes} classes at {crop_path}")
    else:
        print(f"‚ùå {crop.upper()}: NOT FOUND at {crop_path}")
        missing_crops.append(crop)

if missing_crops:
    print(f"\n‚ö†Ô∏è WARNING: {len(missing_crops)} crop(s) not found: {', '.join(missing_crops)}")
    print("Please verify paths in NUTRIENT_DATASETS_ROOT and CROP_DATASETS")
else:
    print(f"\n‚úÖ All {len(CROP_DATASETS)} crop datasets verified!")


## 🌱 Stage 1: Download PlantVillage Dataset from Kaggle


In [None]:
# Setup Kaggle credentials
# IMPORTANT: You need to manually download kaggle.json FIRST!
# 
# 📝 HOW TO GET kaggle.json:
# 1. Go to https://www.kaggle.com/settings
# 2. Scroll down to "API" section
# 3. Click "Create New Token" button
# 4. This will DOWNLOAD a file called "kaggle.json" to your computer
# 5. Find the downloaded file (usually in your Downloads folder)
# 6. Then come back here and upload it when prompted below
#
# ⚠️ NOTE: If you only see the API key on screen but no download happened,
#    click "Create New Token" again - it should download the file

from google.colab import files

print("=" * 70)
print("üì§ UPLOAD YOUR kaggle.json FILE")
print("=" * 70)
print("\nüìù If you haven't downloaded it yet:")
print("   1. Go to: https://www.kaggle.com/settings")
print("   2. Scroll to 'API' section")
print("   3. Click 'Create New Token' (downloads kaggle.json)")
print("   4. Find the file in your Downloads folder")
print("   5. Click 'Choose Files' below and select it")
print("\n‚è≥ Waiting for your kaggle.json file...\n")

# files.upload() returns {uploaded_file_name: file_bytes}
uploaded = files.upload()

# Browsers rename repeated downloads (e.g. "kaggle (1).json"), so accept any
# uploaded .json file instead of requiring the exact name "kaggle.json".
json_uploads = [name for name in uploaded if name.lower().endswith('.json')]
if not json_uploads:
    print("\n‚ùå ERROR: kaggle.json was not uploaded!")
    print("   Please make sure you selected the correct file.")
    raise FileNotFoundError("kaggle.json not found in uploaded files")

# Write the credentials with Python rather than shell `mv`, so a renamed upload
# still ends up at ~/.kaggle/kaggle.json, readable by the owner only (mode 600).
kaggle_dir = os.path.expanduser('~/.kaggle')
os.makedirs(kaggle_dir, exist_ok=True)
kaggle_json_path = os.path.join(kaggle_dir, 'kaggle.json')
with open(kaggle_json_path, 'wb') as f:
    f.write(uploaded[json_uploads[0]])
os.chmod(kaggle_json_path, 0o600)

# Remove the uploaded copy from the working directory (the old `mv` did the same)
if os.path.exists(json_uploads[0]):
    os.remove(json_uploads[0])

print("\n‚úÖ Kaggle credentials configured successfully!")
print("üìÅ File saved to: ~/.kaggle/kaggle.json")


In [None]:
# Download PlantVillage dataset from Kaggle (SKIP IF ALREADY EXISTS)
import opendatasets as od

# Kaggle dataset page passed to od.download() when no local copy is found
PLANTVILLAGE_URL = 'https://www.kaggle.com/datasets/emmarex/plantdisease'
# Download directory; later reassigned to the folder that holds the class sub-directories
PLANTVILLAGE_PATH = '/content/plantvillage'

# Known possible paths after download
# (candidate dataset roots, checked in this order by find_plantvillage_dataset)
POSSIBLE_PATHS = [
    os.path.join(PLANTVILLAGE_PATH, 'plantdisease', 'PlantVillage'),
    os.path.join(PLANTVILLAGE_PATH, 'PlantVillage'),
    os.path.join(PLANTVILLAGE_PATH, 'plantdisease', 'plantvillage', 'PlantVillage'),
]

def find_plantvillage_dataset():
    """Locate an already-downloaded PlantVillage dataset.

    Scans POSSIBLE_PATHS in order and returns the first candidate that is a
    directory with at least 15 sub-directories, where the first sub-directory
    listed contains at least one .jpg/.jpeg/.png file (case-insensitive).

    Returns:
        The matching path, or None when no candidate qualifies.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')
    for candidate in POSSIBLE_PATHS:
        if not (os.path.exists(candidate) and os.path.isdir(candidate)):
            continue
        class_folders = [
            entry for entry in os.listdir(candidate)
            if os.path.isdir(os.path.join(candidate, entry))
        ]
        if len(class_folders) < 15:
            continue
        # Spot-check a single class folder rather than scanning the whole tree
        first_folder = os.path.join(candidate, class_folders[0])
        if any(name.lower().endswith(image_suffixes) for name in os.listdir(first_folder)):
            return candidate
    return None

# Check if dataset already exists
# (re-running the cell reuses an earlier download instead of fetching it again)
existing_path = find_plantvillage_dataset()

if existing_path:
    print("‚úÖ PlantVillage dataset ALREADY EXISTS! Skipping download...")
    print(f"üìÅ Using cached dataset at: {existing_path}")
    PLANTVILLAGE_PATH = existing_path
else:
    # NOTE(review): 54,305 is the size of the full 38-class PlantVillage release; the
    # emmarex/plantdisease mirror may be a smaller subset - confirm this message.
    print("üì• Downloading PlantVillage dataset (54,305 images)...")
    print("‚è≥ This will take 3-5 minutes (first time only)...")
    
    od.download(PLANTVILLAGE_URL, data_dir=PLANTVILLAGE_PATH)
    
    print("\nüîç Locating dataset structure...")
    
    # Find the dataset path
    dataset_root = find_plantvillage_dataset()
    
    if not dataset_root:
        # Search recursively as fallback: take the first directory (os.walk order)
        # with >= 15 sub-directories, one of the first 3 of which contains an image.
        # NOTE(review): the loop variable `files` rebinds the google.colab `files`
        # module imported earlier; re-import it before calling files.upload/download.
        for root, dirs, files in os.walk(PLANTVILLAGE_PATH):
            if len(dirs) >= 15:
                has_images = False
                for d in dirs[:3]:
                    dir_path = os.path.join(root, d)
                    if os.path.isdir(dir_path):
                        dir_files = os.listdir(dir_path)
                        if any(f.lower().endswith(('.jpg', '.jpeg', '.png')) for f in dir_files):
                            has_images = True
                            break
                if has_images:
                    dataset_root = root
                    break
    
    if dataset_root:
        PLANTVILLAGE_PATH = dataset_root
    else:
        raise FileNotFoundError("‚ùå PlantVillage dataset not found after download")

# Verify dataset
# Count class sub-directories only (stray files in the dataset root are ignored)
class_dirs = [d for d in os.listdir(PLANTVILLAGE_PATH) 
              if os.path.isdir(os.path.join(PLANTVILLAGE_PATH, d))]
num_classes = len(class_dirs)

print(f"\n‚úÖ PlantVillage dataset ready!")
print(f"üìÅ Path: {PLANTVILLAGE_PATH}")
print(f"üåø Classes: {num_classes}")

# Quick image count
# (only the first 5 class folders are counted, to keep this cell fast)
total_images = sum(len([f for f in os.listdir(os.path.join(PLANTVILLAGE_PATH, cls))
                        if f.lower().endswith(('.jpg', '.jpeg', '.png'))]) 
                   for cls in class_dirs[:5])
print(f"üìä Sample: First 5 classes have {total_images:,} images")

## 📊 Data Exploration & Preparation


In [None]:
# Analyze PlantVillage dataset.
# Only sub-directories are classes (that is how image_dataset_from_directory reads
# the folder), so stray files in the dataset root must not be counted: the length
# of plantvillage_classes is reported as the class count.
image_exts = ('.jpg', '.jpeg', '.png')

plantvillage_classes = sorted(
    d for d in os.listdir(PLANTVILLAGE_PATH)
    if os.path.isdir(os.path.join(PLANTVILLAGE_PATH, d))
)
print(f"üå± PlantVillage Dataset:")
print(f"Total classes: {len(plantvillage_classes)}")
print(f"\nSample classes:")
for cls in plantvillage_classes[:5]:
    class_path = os.path.join(PLANTVILLAGE_PATH, cls)
    # Count image files only, consistent with the other image counts in this cell
    num_images = len([f for f in os.listdir(class_path) if f.lower().endswith(image_exts)])
    print(f"  - {cls}: {num_images} images")

# Build unified dataset info
print(f"\nüåæ UNIFIED Nutrient Dataset (ALL {len(CROP_DATASETS)} crops):")
total_classes = 0
total_images = 0

for crop, folder_name in CROP_DATASETS.items():
    crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
    if not os.path.exists(crop_path):
        continue  # missing crops are skipped silently here
    crop_classes = [d for d in os.listdir(crop_path) if os.path.isdir(os.path.join(crop_path, d))]
    crop_images = sum(
        len([f for f in os.listdir(os.path.join(crop_path, cls)) if f.lower().endswith(image_exts)])
        for cls in crop_classes
    )
    total_classes += len(crop_classes)
    total_images += crop_images
    print(f"  {crop.upper()}: {len(crop_classes)} classes, {crop_images} images")

print(f"\nüìä UNIFIED TOTALS:")
print(f"  Total classes: {total_classes}")
print(f"  Total images: {total_images}")
print(f"  Class format: {{crop}}_{{deficiency}} (e.g., rice_N, wheat_healthy)")


## 🔨 Create Data Pipelines


In [None]:
# =============================================================
# üöÄ ULTRA MEMORY-SAFE DATA PIPELINE (NO CACHING, NO SHUFFLE BUFFER)
# =============================================================
# ‚ö†Ô∏è Colab free tier has ~12GB RAM - caching causes crashes!
# Solution: NO caching, NO shuffle buffer - slower but STABLE!

AUTOTUNE = tf.data.AUTOTUNE

def create_dataset(data_dir, img_size, batch_size, validation_split=0.2, subset=None):
    """Load a batched, labelled image dataset from a class-per-folder directory.

    Args:
        data_dir: Root directory; each sub-directory is one class.
        img_size: Images are resized to (img_size, img_size).
        batch_size: Number of images per batch.
        validation_split: Fraction of files reserved for validation.
        subset: 'training', 'validation', or None for all files.

    Returns:
        A tf.data.Dataset of (images, one-hot labels) batches; its `class_names`
        attribute lists the class folders in label-index order.
    """
    loader_kwargs = dict(
        validation_split=validation_split,
        subset=subset,
        # Fixed seed: calls for 'training' and 'validation' must use the same
        # seed so the two subsets come from the same split of the file list.
        seed=42,
        image_size=(img_size, img_size),
        batch_size=batch_size,
        label_mode='categorical',
        shuffle=True,  # Keras' built-in shuffling; no .cache() or extra buffer is added here
    )
    return tf.keras.preprocessing.image_dataset_from_directory(data_dir, **loader_kwargs)

@tf.function
def augment_light(image, label):
    """Apply light, label-preserving augmentation to a batch of raw images.

    Runs before normalize_mobilenet, i.e. on float pixel values in [0, 255] as
    produced by image_dataset_from_directory. Labels pass through unchanged.
    """
    image = tf.image.random_flip_left_right(image)
    # random_brightness adds a delta drawn from [-max_delta, max_delta]. On a 0-255
    # scale the previous max_delta=0.1 shifted pixels by at most 0.1 (a no-op), so
    # express the intended +/-10% shift in pixel units and clip back into range.
    image = tf.image.random_brightness(image, 0.1 * 255.0)
    image = tf.clip_by_value(image, 0.0, 255.0)
    return image, label

@tf.function
def normalize_mobilenet(image, label):
    """Map raw pixel values into the input range expected by MobileNetV2.

    The image is cast to float32 and passed through MobileNetV2's
    `preprocess_input`; the label is returned untouched.
    """
    pixels = tf.cast(image, tf.float32)
    return tf.keras.applications.mobilenet_v2.preprocess_input(pixels), label

def build_pipeline(dataset, is_training=True):
    """Attach per-batch preprocessing to a raw dataset without caching it.

    Training data is augmented (augment_light) and then normalized
    (normalize_mobilenet); evaluation data is only normalized.

    Args:
        dataset: Batched (images, labels) dataset from create_dataset.
        is_training: Whether to include the augmentation stage.

    Returns:
        The mapped dataset, prefetching a single batch.
    """
    stages = (augment_light, normalize_mobilenet) if is_training else (normalize_mobilenet,)
    for stage in stages:
        dataset = dataset.map(stage, num_parallel_calls=AUTOTUNE)
    # Prefetch only one batch (not AUTOTUNE) to keep host RAM usage small
    return dataset.prefetch(1)

# Create PlantVillage datasets
print("üì¶ Creating PlantVillage datasets (MEMORY-SAFE mode)...")
print("‚ö†Ô∏è NO caching = stable but slower (prevents RAM crash)")

# Both calls use the same validation_split (and the fixed seed inside create_dataset),
# so 'training' and 'validation' are two parts of the same file split.
train_plantvillage_raw = create_dataset(
    PLANTVILLAGE_PATH, IMG_SIZE, BATCH_SIZE, 
    validation_split=0.2, subset='training'
)
val_plantvillage_raw = create_dataset(
    PLANTVILLAGE_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='validation'
)

# Apply memory-safe pipeline (NO caching!)
train_plantvillage = build_pipeline(train_plantvillage_raw, is_training=True)
val_plantvillage = build_pipeline(val_plantvillage_raw, is_training=False)

# Batch counts come from the raw datasets; map/prefetch do not change the count
train_batches = tf.data.experimental.cardinality(train_plantvillage_raw).numpy()
val_batches = tf.data.experimental.cardinality(val_plantvillage_raw).numpy()

# The image count below is an upper bound (the last batch may be partial)
print(f"\n‚úÖ PlantVillage datasets ready (MEMORY-SAFE)")
print(f"   Training: {train_batches} batches √ó {BATCH_SIZE} = ~{train_batches * BATCH_SIZE:,} images")
print(f"   Validation: {val_batches} batches")
print(f"   ‚úÖ No caching = No RAM crash!")
print(f"   ‚è±Ô∏è Each epoch reads from disk (slower but stable)")

## ✅ Pre-Training Validation

Run this cell to verify everything is set up correctly before training.


In [None]:
# ✅ PRE-TRAINING VALIDATION (Memory-Safe)
# Quick sanity checks (GPU, one data batch, MobileNetV2 weights, free RAM).
# Problems are collected in `errors` and summarized at the end instead of
# stopping at the first failure.
print("=" * 60)
print("üîç PRE-TRAINING VALIDATION")
print("=" * 60)

errors = []

# 1. GPU Check
print("\n1Ô∏è‚É£ GPU Check...")
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"   ‚úÖ GPU: {gpus[0].name}")
else:
    errors.append("No GPU detected!")
    print("   ‚ö†Ô∏è No GPU - training will be slow")

# 2. Quick data test (just 1 batch)
print("\n2Ô∏è‚É£ Data Pipeline Check...")
try:
    for batch_images, batch_labels in train_plantvillage.take(1):
        print(f"   ‚úÖ Batch shape: {batch_images.shape}")
        print(f"   ‚úÖ Labels shape: {batch_labels.shape}")
        print(f"   ‚úÖ Image dtype: {batch_images.dtype}")
except Exception as e:
    errors.append(f"Data pipeline error: {e}")
    print(f"   ‚ùå Error: {e}")

# 3. MobileNetV2 base check (quick): loads the ImageNet weights once
print("\n3Ô∏è‚É£ MobileNetV2 Base Check...")
try:
    test_base = tf.keras.applications.MobileNetV2(
        input_shape=(IMG_SIZE, IMG_SIZE, 3),  # match the configured input size
        include_top=False,
        weights='imagenet'
    )
    print(f"   ‚úÖ MobileNetV2 loaded ({len(test_base.layers)} layers)")
    del test_base  # Free memory
except Exception as e:
    errors.append(f"MobileNetV2 error: {e}")
    print(f"   ‚ùå Error: {e}")

# 4. Memory check (psutil is optional; this check is best-effort)
print("\n4Ô∏è‚É£ Memory Status...")
try:
    import psutil
    ram_gb = psutil.virtual_memory().available / (1024**3)
except Exception:
    # `except Exception` instead of a bare `except:` so Ctrl-C still interrupts
    print("   ‚ÑπÔ∏è Could not check RAM")
else:
    print(f"   ‚úÖ Available RAM: {ram_gb:.1f} GB")
    if ram_gb < 2:
        print("   ‚ö†Ô∏è Low RAM - may crash during training")

# Summary
print("\n" + "=" * 60)
if errors:
    print("‚ùå ISSUES FOUND:")
    for issue in errors:
        print(f"   ‚Ä¢ {issue}")
else:
    print("‚úÖ ALL CHECKS PASSED!")
    print(f"\nüöÄ Ready to train with:")
    print(f"   ‚Ä¢ Batch size: {BATCH_SIZE}")
    print(f"   ‚Ä¢ Memory-safe mode (no caching)")
    print(f"   ‚Ä¢ Mixed precision: FP16")
print("=" * 60)

## 🏗️ Stage 2: Build Model with MobileNetV2 Base


In [None]:
def create_model(num_classes, input_shape=(224, 224, 3), freeze_base=True):
    """Build a MobileNetV2 image classifier with a small dense head.

    Args:
        num_classes: Width of the softmax output layer.
        input_shape: (height, width, channels) of the input images.
        freeze_base: If True, the ImageNet-pretrained base is not trainable.

    Returns:
        (model, base_model): the full Sequential model and its MobileNetV2 base,
        which is also the model's first layer.
    """
    backbone = tf.keras.applications.MobileNetV2(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
    )
    backbone.trainable = not freeze_base

    head = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.25),
        # The softmax output is forced to float32 so it stays numerically stable
        # when the global mixed_float16 policy is active.
        tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32'),
    ]
    classifier = tf.keras.Sequential([backbone, *head])
    return classifier, backbone

# Get number of PlantVillage classes from the dataset Keras actually built, so the
# softmax width always matches the one-hot label width of train_plantvillage.
num_plantvillage_classes = len(train_plantvillage_raw.class_names)

# Number of final MobileNetV2 layers fine-tuned on PlantVillage.
# With the base fully frozen, Stage 2 trained only a classification head that
# Stage 3 throws away, so Stage 3 reused unchanged ImageNet features. Set to 0 to
# get that frozen-base behavior back.
STAGE2_FINE_TUNE_LAYERS = 30

print(f"üèóÔ∏è Creating model for PlantVillage ({num_plantvillage_classes} classes)...")

model_stage2, base_model = create_model(num_plantvillage_classes, freeze_base=True)

if STAGE2_FINE_TUNE_LAYERS > 0:
    base_model.trainable = True
    for layer in base_model.layers[:-STAGE2_FINE_TUNE_LAYERS]:
        layer.trainable = False
    # Keep BatchNormalization frozen: in TF2 Keras a non-trainable BN layer runs in
    # inference mode, so its ImageNet moving statistics are not overwritten.
    for layer in base_model.layers[-STAGE2_FINE_TUNE_LAYERS:]:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False

# Compile (must happen after the trainable flags above are final)
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE2)

model_stage2.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy']  # Simplified metrics for faster training
)

# Count trainable params
trainable_params = sum([tf.keras.backend.count_params(w) for w in model_stage2.trainable_weights])
print(f"\nüîí Base model frozen: {not base_model.trainable}")
print(f"   Fine-tuned base layers: {STAGE2_FINE_TUNE_LAYERS} (BatchNorm kept frozen)")
print(f"üìä Trainable parameters: {trainable_params:,}")
print(f"üíæ Mixed precision: FP16 enabled")

## 🎯 Stage 2: Train on PlantVillage Dataset


In [None]:
print("üöÄ Starting Stage 2: PlantVillage Fine-tuning")
print(f"‚è±Ô∏è Epochs: {PLANTVILLAGE_EPOCHS} | LR: {LEARNING_RATE_STAGE2}")
print("="*60)

# Minimal callbacks for speed.
# NOTE: EarlyStopping watches val_loss while the checkpoint keeps the epoch with
# the best val_accuracy, so the saved file and the in-memory weights can differ.
callbacks_stage2 = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=2,  # Aggressive early stopping
        restore_best_weights=True,
        verbose=1
    ),
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(OUTPUT_DIR, 'stage2_plantvillage_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0  # Silent for speed
    )
]

# Train with verbose=2 for cleaner output
history_stage2 = model_stage2.fit(
    train_plantvillage,
    validation_data=val_plantvillage,
    epochs=PLANTVILLAGE_EPOCHS,
    callbacks=callbacks_stage2,
    verbose=2  # One line per epoch (faster)
)

print(f"\n‚úÖ Stage 2 completed!")
print(f"üìà Best val accuracy: {max(history_stage2.history['val_accuracy']):.4f}")

## 📈 Stage 2 Results Visualization


In [None]:
# Quick training visualization: train vs. validation curves, one panel per metric
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

for ax, (metric, title) in zip(axes, (('accuracy', 'Accuracy'), ('loss', 'Loss'))):
    ax.plot(history_stage2.history[metric], 'b-', label='Train')
    ax.plot(history_stage2.history[f'val_{metric}'], 'r-', label='Val')
    ax.set_title(f'Stage 2: {title}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'stage2_training_history.png'), dpi=150)
plt.show()

# Release Python garbage and reset Keras global state before Stage 3
import gc
gc.collect()
tf.keras.backend.clear_session()
print("üßπ Memory cleared for Stage 3")

## 🔄 Stage 3: Build UNIFIED Dataset


In [None]:
# Build unified dataset by combining all crops.
# Every crop's class folder on Drive is symlinked into one directory as
# "{crop}_{deficiency}", so image_dataset_from_directory sees a single dataset
# whose classes span all crops.
import shutil

print("üèóÔ∏è Building UNIFIED dataset...")

UNIFIED_DATASET_PATH = '/content/unified_nutrient_dataset'


def _unified_class_name(crop, folder):
    """Return the unified class name "{crop}_{deficiency}" for a class folder.

    A leading "{crop}__" or "{crop}_" (case-insensitive) is stripped from the
    folder name first, so "N", "rice_N", "Rice_N" and "rice__N" all map to
    "rice_N". Only a prefix is removed; the rest of the name is kept as-is.
    """
    lowered = folder.lower()
    # Try the longer "{crop}__" first; matching "{crop}_" would leave a stray "_"
    for prefix in (f"{crop}__", f"{crop}_"):
        if lowered.startswith(prefix.lower()):
            return f"{crop}_{folder[len(prefix):]}"
    return f"{crop}_{folder}"


# Symlinks are cheap to recreate, so rebuild the tree on every run; reusing an old
# tree could keep stale or partial classes after CROP_DATASETS or the Drive folders
# change. shutil.rmtree unlinks the symlinks without touching their targets.
if os.path.exists(UNIFIED_DATASET_PATH):
    shutil.rmtree(UNIFIED_DATASET_PATH)
os.makedirs(UNIFIED_DATASET_PATH)

unified_classes = []
skipped_crops = []

for crop, folder_name in CROP_DATASETS.items():
    crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)

    if not os.path.exists(crop_path):
        skipped_crops.append(crop)
        continue

    try:
        crop_classes = sorted(d for d in os.listdir(crop_path)
                              if os.path.isdir(os.path.join(crop_path, d)))
    except OSError as err:
        print(f"  WARNING: could not list {crop_path}: {err}")
        skipped_crops.append(crop)
        continue

    linked = 0
    for cls in crop_classes:
        unified_class_name = _unified_class_name(crop, cls)
        dst_dir = os.path.join(UNIFIED_DATASET_PATH, unified_class_name)

        # Two source folders normalizing to the same name used to be dropped
        # silently; report the collision so the source data can be fixed.
        if os.path.lexists(dst_dir):
            print(f"  WARNING: {crop}/{cls} maps to existing class '{unified_class_name}' - skipped")
            continue

        try:
            os.symlink(os.path.join(crop_path, cls), dst_dir)
        except OSError as err:
            print(f"  WARNING: could not link {crop}/{cls}: {err}")
            continue

        unified_classes.append(unified_class_name)
        linked += 1

    print(f"  ‚úÖ {crop.upper()}: {linked} classes")

if skipped_crops:
    print(f"‚ö†Ô∏è Skipped: {', '.join(skipped_crops)}")

if len(unified_classes) == 0:
    raise RuntimeError("‚ùå No classes! Check Google Drive paths.")

print(f"\n‚úÖ Unified dataset: {len(unified_classes)} classes")

# Create MEMORY-SAFE datasets
print("üì¶ Creating datasets (MEMORY-SAFE)...")

train_nutrient_raw = create_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='training'
)
val_nutrient_raw = create_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='validation'
)

# Take the class order from the dataset itself, so label index i is exactly
# class_names[i] for evaluation, labels.txt and the metadata.
class_names = list(train_nutrient_raw.class_names)
num_unified_classes = len(class_names)

train_nutrient = build_pipeline(train_nutrient_raw, is_training=True)
val_nutrient = build_pipeline(val_nutrient_raw, is_training=False)

print(f"‚úÖ Datasets ready (MEMORY-SAFE)")
print(f"   Training: {tf.data.experimental.cardinality(train_nutrient_raw).numpy()} batches")
print(f"   Validation: {tf.data.experimental.cardinality(val_nutrient_raw).numpy()} batches")

## 🔧 Stage 3: Adapt Model for Unified Classes


In [None]:
# Adapt model head for unified multi-crop detection.
# Reuses the MobileNetV2 base from Stage 2 (frozen) and attaches a new, wider
# head sized for the unified class list.
if 'num_unified_classes' not in locals() or num_unified_classes == 0:
    raise RuntimeError("‚ö†Ô∏è Run 'Build UNIFIED Dataset' cell first!")

print(f"üîß Adapting model for {num_unified_classes} unified classes...")

# Get the base model from Stage 2
# (layer 0 of the Sequential built by create_model is the MobileNetV2 base)
base_model_stage2 = model_stage2.layers[0]
# Freeze the whole base: only the new head below is trained in Stage 3
base_model_stage2.trainable = False

# Streamlined classification head for speed
model_stage3 = tf.keras.Sequential([
    base_model_stage2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(384, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dropout(0.25),
    # float32 softmax output keeps probabilities stable under mixed_float16
    tf.keras.layers.Dense(num_unified_classes, activation='softmax', dtype='float32')
], name='unified_nutrient_model')

# Compile
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE3)

model_stage3.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    # 'top3_acc' shows up as 'val_top3_acc' in the training history
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')]
)

trainable_params = sum([tf.keras.backend.count_params(w) for w in model_stage3.trainable_weights])
print(f"üìä Trainable params: {trainable_params:,}")
print(f"üéØ Output classes: {num_unified_classes}")

## 🎯 Stage 3: Train on UNIFIED Nutrient Dataset


In [None]:
print("üöÄ Starting Stage 3: UNIFIED Nutrient Detection")
print(f"üåæ Training ALL {len(CROP_DATASETS)} crops | Epochs: {UNIFIED_EPOCHS} | LR: {LEARNING_RATE_STAGE3}")
print("="*60)

# Optimized callbacks.
# ReduceLROnPlateau halves the LR after 2 stagnant val_loss epochs; EarlyStopping
# stops after 4. The checkpoint keeps the best val_accuracy epoch - this is the
# file the TFLite export loads.
callbacks_stage3 = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=4,
        restore_best_weights=True,
        verbose=1
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=2,
        min_lr=1e-7,
        verbose=1
    ),
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0
    )
]

# Train
history_stage3 = model_stage3.fit(
    train_nutrient,
    validation_data=val_nutrient,
    epochs=UNIFIED_EPOCHS,
    callbacks=callbacks_stage3,
    verbose=2
)

print(f"\n‚úÖ Stage 3 completed!")
print(f"üìà Best val accuracy: {max(history_stage3.history['val_accuracy']):.4f}")
print(f"üéØ Top-3 accuracy: {max(history_stage3.history['val_top3_acc']):.4f}")

## 📈 Stage 3 Results Visualization


In [None]:
# Quick training visualization: train vs. validation curves, one panel per metric
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

for ax, (metric, title) in zip(axes, (('accuracy', 'Accuracy'), ('loss', 'Loss'))):
    ax.plot(history_stage3.history[metric], 'b-', label='Train')
    ax.plot(history_stage3.history[f'val_{metric}'], 'r-', label='Val')
    ax.set_title(f'Stage 3: {title}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'stage3_training_history.png'), dpi=150)
plt.show()

## 🔍 Model Evaluation & Classification Report


In [None]:
# Evaluate the UNIFIED model on the validation split
print("üîç Evaluating UNIFIED model...")
results = model_stage3.evaluate(val_nutrient, verbose=0)

# results = [loss, accuracy, top3_acc], in the order given to compile()
print(f"\nüìä Validation Metrics:")
print(f"   Loss: {results[0]:.4f}")
print(f"   Accuracy: {results[1]:.4f}")
print(f"   Top-3 Accuracy: {results[2]:.4f}")

# Per-crop accuracy over the whole validation split (previously only 20 batches
# and only the first 6 crops were checked).
print(f"\nüåæ Per-Crop Performance (quick check):")
y_true, y_pred = [], []
for images, labels in val_nutrient:
    # Calling the model directly avoids predict()'s per-call setup cost per batch
    predictions = model_stage3(images, training=False)
    y_true.extend(np.argmax(labels.numpy(), axis=1))
    y_pred.extend(np.argmax(predictions.numpy(), axis=1))
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)

for crop in CROP_DATASETS:
    crop_indices = [i for i, cls in enumerate(class_names) if cls.startswith(f"{crop}_")]
    if not crop_indices:
        continue
    crop_mask = np.isin(y_true, crop_indices)
    if crop_mask.any():
        crop_acc = (y_true[crop_mask] == y_pred[crop_mask]).mean()
        print(f"   {crop.upper():12s}: {crop_acc:.1%}")

# Save classification report.
# `labels` lists every class index so it lines up with target_names. Without it,
# sklearn infers the labels from y_true and y_pred together and raises ValueError
# as soon as the model predicts a class that never occurs in y_true.
report = classification_report(
    y_true,
    y_pred,
    labels=list(range(len(class_names))),
    target_names=list(class_names),
    output_dict=True,
    zero_division=0,
)
with open(os.path.join(OUTPUT_DIR, 'unified_classification_report.json'), 'w') as f:
    json.dump(report, f, indent=2)

print(f"\n‚úÖ Evaluation complete")

## 💾 Export to TensorFlow Lite for Mobile Deployment


In [None]:
print("üì¶ Converting to TensorFlow Lite...")

# Load best model
# (the Stage 3 checkpoint saved by ModelCheckpoint for the best val_accuracy)
best_model_path = os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras')
best_model = tf.keras.models.load_model(best_model_path)

# Convert to TFLite with FP16 quantization
# (Optimize.DEFAULT + float16 as the supported type stores weights as float16)
converter = tf.lite.TFLiteConverter.from_keras_model(best_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]

print("‚öôÔ∏è Converting with FP16 quantization...")
tflite_model = converter.convert()

# Save (convert() returns the serialized flatbuffer as bytes)
tflite_path = os.path.join(OUTPUT_DIR, 'fasalvaidya_unified.tflite')
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

# File sizes in MiB, for the size comparison printed below
keras_size = os.path.getsize(best_model_path) / (1024 * 1024)
tflite_size = os.path.getsize(tflite_path) / (1024 * 1024)

print(f"\n‚úÖ Conversion complete!")
print(f"üìä Keras: {keras_size:.1f}MB ‚Üí TFLite: {tflite_size:.1f}MB ({(1-tflite_size/keras_size)*100:.0f}% smaller)")
print(f"üöÄ Single model for {len(CROP_DATASETS)} crops!")

## 🧪 Test TFLite Model Inference


In [None]:
# Quick TFLite verification
# Runs one validation image through the exported .tflite file to confirm that it
# loads, accepts the input, and returns one probability per class.
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("üîç TFLite Model:")
print(f"   Input: {input_details[0]['shape']} ({input_details[0]['dtype']})")
print(f"   Output: {output_details[0]['shape']} ({num_unified_classes} classes)")

# Quick test
# val_nutrient images are already normalized by build_pipeline; the exported model
# contains no preprocessing layer, so apps must apply the same normalization.
for images, labels in val_nutrient.take(1):
    test_image = images[0].numpy()
    input_data = np.expand_dims(test_image, axis=0).astype(np.float32)  # add batch dim
    
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
    
    pred_idx = np.argmax(output[0])
    true_idx = np.argmax(labels[0].numpy())
    
    print(f"\nüß™ Quick test:")
    print(f"   True: {class_names[true_idx]}")
    print(f"   Pred: {class_names[pred_idx]} ({output[0][pred_idx]:.1%})")
    print(f"   {'‚úÖ CORRECT' if pred_idx == true_idx else '‚ùå INCORRECT'}")
    break

print("\n‚úÖ TFLite model verified!")

## 📤 Save Model Metadata & Class Labels


In [None]:
# Save metadata and labels
print("üìù Saving metadata...")

# Classes belonging to each crop, identified by their "{crop}_" name prefix
crop_class_mapping = {crop: [c for c in class_names if c.startswith(f"{crop}_")]
                      for crop in CROP_DATASETS.keys()}

metadata = {
    'model_type': 'unified_multi_crop',
    'model_version': '2.0',
    'training_date': datetime.now().isoformat(),
    'architecture': 'MobileNetV2',
    'supported_crops': list(CROP_DATASETS.keys()),
    'num_crops': len(CROP_DATASETS),
    'input_shape': [IMG_SIZE, IMG_SIZE, 3],
    'num_classes': num_unified_classes,
    # Index i of the model output corresponds to class_names[i]
    'class_names': class_names,
    'crop_class_mapping': crop_class_mapping,
    # NOTE: `results` come from evaluating model_stage3; the exported TFLite file is
    # built from the best val_accuracy checkpoint, which may be a different epoch.
    'metrics': {'accuracy': float(results[1]), 'top3_accuracy': float(results[2])},
    'preprocessing': {'method': 'MobileNetV2', 'normalization': '[-1, 1]'}
}

with open(os.path.join(OUTPUT_DIR, 'unified_model_metadata.json'), 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

# One class name per line, in output-index order
with open(os.path.join(OUTPUT_DIR, 'labels.txt'), 'w', encoding='utf-8') as f:
    f.write('\n'.join(class_names))

# Report the file names actually written above (not "metadata.json")
print(f"‚úÖ Saved: unified_model_metadata.json, labels.txt")
print(f"üìä {len(CROP_DATASETS)} crops, {num_unified_classes} classes")

## 📦 Download Models to Local Machine


In [None]:
# Create and download zip
import shutil

zip_filename = 'fasalvaidya_unified_model'
# Archive every file written to OUTPUT_DIR into /content/<zip_filename>.zip
shutil.make_archive(f'/content/{zip_filename}', 'zip', OUTPUT_DIR)

print(f"üì¶ Created: {zip_filename}.zip")
print(f"\nüìÇ Contents:")
print(f"   üì± fasalvaidya_unified.tflite ({tflite_size:.1f}MB)")
print(f"   üíæ unified_nutrient_best.keras")
print(f"   üìÑ unified_model_metadata.json")
print(f"   üè∑Ô∏è labels.txt ({num_unified_classes} classes)")
print(f"\nüåæ Supports: {', '.join(list(CROP_DATASETS.keys())[:6])}...")

# Re-import: `files` may have been rebound by an earlier os.walk loop variable
from google.colab import files
files.download(f'/content/{zip_filename}.zip')
print(f"\n‚¨áÔ∏è Download started!")

## 🎉 UNIFIED Model Training Complete!

### 🚀 What You Got:

**ONE powerful model for ALL crops** instead of 12 separate models:

- 📱 Single TFLite file: `fasalvaidya_unified.tflite` (~10-15MB)
- 🌾 Handles 12 crops automatically: rice, wheat, tomato, maize, banana, coffee, cucumber, eggplant, ashgourd, bittergourd, ridgegourd, snakegourd
- 🎯 40-60+ deficiency classes (varies by available data)
- 🔥 Class format: `{crop}_{deficiency}` (e.g., `rice_N`, `wheat_healthy`, `tomato_K`)

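### 📊 Files Included:

- `fasalvaidya_unified.tflite` - Main model
- `unified_nutrient_best.keras` - Best Keras checkpoint from Stage 3
- `stage2_plantvillage_best.keras` - Best Stage 2 (PlantVillage) checkpoint
- `unified_model_metadata.json` - Complete model info & crop mappings
- `labels.txt` - All class labels
- `unified_classification_report.json` - Performance metrics (per-class precision, recall, F1)
- Training history plots (`stage2_training_history.png`, `stage3_training_history.png`)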

### 🎯 Model Architecture:

- **Stage 1**: ImageNet pretrained weights (general vision)
- **Stage 2**: PlantVillage fine-tuning (plant disease patterns)
- **Stage 3**: Unified nutrient training (ALL crops combined)

### 💡 Benefits:

- ✅ **12× smaller**: ~15MB vs 120MB (12 separate models)
- ✅ **Simpler deployment**: One model to manage
- ✅ **Consistent performance**: Same base features for all crops
- ✅ **Easy updates**: Retrain once, update all crops
- ✅ **Mobile-optimized**: FP16 quantization
- ✅ **Grad-CAM ready**: Visualize deficiency regions

### 📦 Integration Steps:

1. **Extract the ZIP file** you downloaded

2. **Copy the unified TFLite model** to your app:

   ```
   frontend/assets/models/fasalvaidya_unified.tflite
   ```

3. **Copy the labels file**:

   ```
   frontend/assets/models/labels.txt
   ```

4. **Update your inference code** to:

   - Load the single unified model
   - Parse predictions: split the class name on the first `_` to get crop and deficiency (crop names contain no `_`, but deficiency names might)
   - Example: `rice_N` → crop=`rice`, deficiency=`N`

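5. **Grad-CAM visualization** (optional):
   - Use the `Conv_1` layer of the MobileNetV2 base (the first layer of the saved `.keras` model)
   - Run Grad-CAM on the Keras model; the TFLite file does not expose intermediate layers or gradients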

### 🔄 To Retrain:

1. Update datasets in Google Drive
2. Verify paths in configuration cell
3. Run all cells again
4. ONE training session updates ALL crops!
