## 📦 Setup & Environment


# 🚀 FasalVaidya: Unified Multi-Crop Nutrient Deficiency Detection

## ⚠️ CRITICAL: First-Time Setup

**BEFORE RUNNING ANY CELLS:**

1. **🔄 Restart Runtime** (REQUIRED if you've run this notebook before)
   - Click: **Runtime → Restart Runtime** (or press `Ctrl+M` then `.`)
   - This clears any cached settings from previous sessions
   - Ensures float32 policy is properly set

2. **📱 Run cells sequentially** from top to bottom
   - Each cell depends on previous cells
   - Don't skip cells or run out of order

3. **⏱️ Expected Training Time:**
   - Data copy: ~3-5 minutes (one-time)
   - Stage 2 (PlantVillage): ~15-20 minutes
   - Stage 3 (Nutrient): ~30-40 minutes
   - **Total: ~50-65 minutes** on T4 GPU

## 📋 What This Notebook Does

- **Stage 1:** Download PlantVillage dataset (optional transfer learning base)
- **Stage 2:** Train on PlantVillage for general leaf disease recognition
- **Stage 3:** Fine-tune on unified 4-crop nutrient deficiency dataset

## 🎯 Optimizations Included

- ✅ **Float32 precision** (no mixed precision issues)
- ✅ **Local SSD data** (10-50x faster than Drive I/O)
- ✅ **XLA/JIT compilation** (10-20% speedup)
- ✅ **AUTOTUNE prefetch** (maximizes GPU utilization)
- ✅ **Optimized batch size** (32 for better GPU usage)
- ✅ **Smart augmentation** (flip, brightness, contrast, saturation, hue)
- ✅ **Checkpoint resume** (can continue if interrupted)

## 🌾 Supported Crops

- 🌾 **Rice** (Nitrogen, Phosphorus, Potassium deficiencies)
- 🌾 **Wheat** (Nitrogen deficiency)
- 🍅 **Tomato** (Multi-nutrient deficiencies)
- 🌽 **Maize** (Nitrogen-deficient and healthy)

---

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install required packages
# Quote the version specifier; unquoted, the shell treats ">" as an output redirect
!pip install -q "tensorflow>=2.15.0" kaggle opendatasets scikit-learn matplotlib seaborn tqdm

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
import threading
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from pathlib import Path
from datetime import datetime, timedelta
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tqdm.auto import tqdm
import time

print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")
print(f"CPU cores available: {multiprocessing.cpu_count()}")

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# =============================================================
# ⏱️ SESSION TIME TRACKER (Important for 3-hour limit!)
# =============================================================
SESSION_START_TIME = datetime.now()
TRAINING_START_TIME = None

def get_session_time():
    """Get elapsed session time as 'Hh Mm'"""
    # total_seconds() covers the whole duration; .seconds would drop full days
    total_minutes = int((datetime.now() - SESSION_START_TIME).total_seconds()) // 60
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours}h {minutes}m"

def get_eta(current_epoch, total_epochs, epoch_time):
    """Calculate ETA for training completion"""
    remaining_epochs = total_epochs - current_epoch
    eta_seconds = remaining_epochs * epoch_time
    eta = timedelta(seconds=int(eta_seconds))
    return str(eta)

def check_time_limit(warn_minutes=150):
    """Warn if approaching 3-hour limit (180 min)"""
    elapsed = int((datetime.now() - SESSION_START_TIME).total_seconds()) // 60
    remaining = max(180 - elapsed, 0)  # Never report negative minutes
    if elapsed >= warn_minutes:
        print(f"⚠️ WARNING: {remaining} minutes remaining before typical Colab disconnect!")
        print(f"   Consider saving checkpoints and downloading results now.")
        return True
    return False

print(f"\n⏱️ Session started at: {SESSION_START_TIME.strftime('%H:%M:%S')}")
print(f"   Target: Complete training within 1-1.5 hours")
print(f"   Checkpoints auto-save to Drive (training resumes from checkpoint)")

TensorFlow version: 2.19.0
GPU available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
CPU cores available: 2

⏱️ Session started at: 10:45:28
   Target: Complete training within 1-1.5 hours
   Checkpoints auto-save to Drive (training resumes from checkpoint)
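
Elapsed-time helpers like these are easiest to check in isolation when they take a `timedelta` as an argument instead of reading the clock. The following is a minimal sketch under that assumption; `format_elapsed` and `minutes_remaining` are hypothetical names, not functions defined in this notebook, and the 180-minute limit is the same rough figure the tracker uses.

```python
from datetime import timedelta

def format_elapsed(elapsed: timedelta) -> str:
    """Format a duration as 'Hh Mm', counting full days as hours."""
    total_minutes = int(elapsed.total_seconds()) // 60
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours}h {minutes}m"

def minutes_remaining(elapsed: timedelta, limit_minutes: int = 180) -> int:
    """Minutes left before an assumed session limit (never negative)."""
    used = int(elapsed.total_seconds()) // 60
    return max(limit_minutes - used, 0)

print(format_elapsed(timedelta(hours=1, minutes=5)))  # 1h 5m
print(minutes_remaining(timedelta(minutes=150)))      # 30
```

Working from `total_seconds()` keeps the result correct even if a session somehow spans more than a day, which `timedelta.seconds` alone would not.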


## 🔑 Configuration

### Set your crop type and paths here


In [None]:
# =============================================================
# 🚀 OPTIMAL CONFIGURATION FOR FASTEST TRAINING
# =============================================================

# Root path to your "Leaf Nutrient Data Sets" folder on Google Drive
NUTRIENT_DATASETS_ROOT = '/content/drive/MyDrive/Leaf Nutrient Data Sets'

# 🚀 FAST MVP: Only 4 crops for quick training!
CROP_DATASETS = {
    'rice': 'Rice Nutrients',
    'wheat': 'Wheat Nitrogen',
    'tomato': 'Tomato Nutrients',
    'maize': 'Maize Nutrients',
}

# =============================================================
# 🎯 OPTIMAL TRAINING HYPERPARAMETERS
# =============================================================
IMG_SIZE = 224  # MobileNetV2 native resolution
BATCH_SIZE = 32  # Increased from 16 for better GPU utilization

# Training epochs (Aggressive but effective)
PLANTVILLAGE_EPOCHS = 8    # Stage 2: Quick transfer learning
UNIFIED_EPOCHS = 15        # Stage 3: Nutrient detection (increased for better convergence)

# Learning rates (Tuned for fast convergence)
LEARNING_RATE_STAGE2 = 1e-3  # Aggressive learning
LEARNING_RATE_STAGE3 = 5e-4  # Fine-tuning rate

# Regularization
DROPOUT_RATE = 0.3  # Increased to prevent overfitting

# =============================================================
# 🚀 PERFORMANCE OPTIMIZATIONS
# =============================================================
# Enable XLA (JIT) compilation; the actual speedup depends on the model and GPU
tf.config.optimizer.set_jit(True)

# GPU memory growth (prevents OOM errors)
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"✅ Enabled memory growth for {len(gpus)} GPU(s)")
    except RuntimeError as e:
        # Raised when the GPU was already initialized (restart the runtime to apply)
        print(f"⚠️ GPU config warning: {e}")

# Multiprocessing settings (use all available cores)
NUM_WORKERS = multiprocessing.cpu_count()
PREFETCH_BUFFER = tf.data.AUTOTUNE  # Let TF auto-tune prefetch size

# CRITICAL: Use float32 for full precision (no mixed precision issues)
tf.keras.mixed_precision.set_global_policy('float32')
print("✅ Using float32 policy for training (no mixed precision)")

# Output paths: OUTPUT_DIR is local to this session and is lost on disconnect;
# DRIVE_CHECKPOINT_DIR persists on Drive so training can resume from a checkpoint
OUTPUT_DIR = '/content/fasalvaidya_unified_model'
DRIVE_CHECKPOINT_DIR = '/content/drive/MyDrive/FasalVaidya_Checkpoints'
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(DRIVE_CHECKPOINT_DIR, exist_ok=True)

# =============================================================
# 📊 CONFIGURATION SUMMARY
# =============================================================
print("⚡ OPTIMIZED TRAINING CONFIGURATION")
print("="*60)
print(f"🌾 Crops: {len(CROP_DATASETS)} ({', '.join(CROP_DATASETS.keys())})")
print(f"\n🎯 Training Settings:")
print(f"   • Image size: {IMG_SIZE}×{IMG_SIZE}")
print(f"   • Batch size: {BATCH_SIZE} (optimized for GPU)")
print(f"   • Precision: float32 (no mixed precision)")
print(f"   • XLA compilation: {'✅ Enabled' if tf.config.optimizer.get_jit() else '❌ Disabled'}")
print(f"\n⚙️ Performance:")
print(f"   • CPU workers: {NUM_WORKERS}")
print(f"   • Prefetch: AUTOTUNE")
print(f"   • Memory growth: {'✅ Enabled' if gpus else '⚠️ No GPU detected'}")
print(f"\n📈 Epochs:")
print(f"   • Stage 2 (PlantVillage): {PLANTVILLAGE_EPOCHS}")
print(f"   • Stage 3 (Nutrient): {UNIFIED_EPOCHS}")
print(f"\n⚡ Expected training time: ~50-65 minutes total on a T4 GPU")



## 💾 Mount Google Drive


In [None]:
from google.colab import drive
import os
import glob

# Robust check for already-mounted Drive
drive_path = '/content/drive'
is_mounted = False

if os.path.exists(drive_path):
    # Check if directory has content (indicates already mounted)
    try:
        if os.listdir(drive_path):
            is_mounted = True
            print("✅ Google Drive already mounted!")
    except OSError:
        # Unreadable mount point: fall through and mount again
        pass

if not is_mounted:
    print("📁 Mounting Google Drive...")
    # Clean up if directory exists but is empty
    if os.path.exists(drive_path) and not os.listdir(drive_path):
        os.rmdir(drive_path)
    drive.mount(drive_path)
    print("✅ Google Drive mounted successfully!")

# =============================================================
# 🔍 SMART PATH DETECTION: Search All Possible Locations
# =============================================================
print("\n🔍 Searching for 'Leaf Nutrient Data Sets' folder...")

# List of possible locations to check
search_paths = [
    '/content/drive/MyDrive/Leaf Nutrient Data Sets',
    '/content/drive/Shareddrives/Leaf Nutrient Data Sets',
    '/content/drive/Shared drives/Leaf Nutrient Data Sets',
]

# Also search for shortcuts and nested locations
mydrive_base = '/content/drive/MyDrive'
if os.path.exists(mydrive_base):
    # Search in .shortcut-targets-by-id (where "Shared with me" shortcuts appear)
    shortcut_dir = os.path.join(mydrive_base, '.shortcut-targets-by-id')
    if os.path.exists(shortcut_dir):
        try:
            for folder_id in os.listdir(shortcut_dir):
                target_path = os.path.join(shortcut_dir, folder_id, 'Leaf Nutrient Data Sets')
                if os.path.exists(target_path):
                    search_paths.append(target_path)
        except OSError:
            pass

# Try each location
found_location = None

for search_path in search_paths:
    if os.path.exists(search_path):
        # Verify it has crop folders
        try:
            contents = os.listdir(search_path)
            crop_folders = [f for f in contents if os.path.isdir(os.path.join(search_path, f))]
            # Require at least as many folders as configured crops
            if len(crop_folders) >= len(CROP_DATASETS):
                print(f"✅ Found at: {search_path}")
                print(f"   Contains {len(crop_folders)} folders")
                found_location = search_path
                break
        except OSError:
            pass

if found_location:
    NUTRIENT_DATASETS_ROOT = found_location
    print(f"\n✅ Using dataset location: {NUTRIENT_DATASETS_ROOT}")
else:
    print(f"\n❌ 'Leaf Nutrient Data Sets' folder NOT FOUND!")
    print(f"\n📂 What's in your Drive:")
    try:
        mydrive_items = os.listdir(mydrive_base)[:10]  # Show first 10 items
        for item in mydrive_items:
            item_path = os.path.join(mydrive_base, item)
            if os.path.isdir(item_path):
                print(f"   📁 {item}")
            else:
                print(f"   📄 {item}")
    except OSError:
        print("   (Could not list Drive contents)")

    print(f"\n⚠️ The folder may be in 'Shared with me', which is not accessible from Colab.")
    print(f"\n✅ SOLUTION: Add shortcut to My Drive")
    print(f"   1. Open Google Drive in browser: https://drive.google.com")
    print(f"   2. Click 'Shared with me' in left sidebar")
    print(f"   3. Right-click 'Leaf Nutrient Data Sets' folder")
    print(f"   4. Select 'Add shortcut to Drive' or 'Organize' > 'Add shortcut'")
    print(f"   5. Choose 'My Drive' root (don't put it in a subfolder)")
    print(f"   6. Click 'Add' or 'Add shortcut'")
    print(f"   7. Come back here and re-run this cell")
    print(f"\n💡 After adding shortcut, the folder will appear at:")
    print(f"   /content/drive/MyDrive/Leaf Nutrient Data Sets")

# Verify ALL crop datasets exist (only if folder found)
if found_location:
    print("\n🔍 Verifying crop datasets...")
    missing_crops = []
    for crop, folder_name in CROP_DATASETS.items():
        crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
        if os.path.exists(crop_path):
            num_classes = len([d for d in os.listdir(crop_path) if os.path.isdir(os.path.join(crop_path, d))])
            print(f"✅ {crop.upper()}: {num_classes} classes")
        else:
            print(f"❌ {crop.upper()}: NOT FOUND")
            missing_crops.append(crop)

    if missing_crops:
        print(f"\n⚠️ WARNING: {len(missing_crops)} crop(s) not found: {', '.join(missing_crops)}")
    else:
        print(f"\n✅ All {len(CROP_DATASETS)} crop datasets verified!")


✅ Google Drive already mounted!

🔍 Searching for 'Leaf Nutrient Data Sets' folder...
✅ Found at: /content/drive/MyDrive/Leaf Nutrient Data Sets
   Contains 12 folders

✅ Using dataset location: /content/drive/MyDrive/Leaf Nutrient Data Sets

🔍 Verifying crop datasets...
✅ RICE: 3 classes
✅ WHEAT: 3 classes
✅ TOMATO: 2 classes
✅ MAIZE: 2 classes
✅ BANANA: 3 classes
✅ COFFEE: 4 classes
✅ CUCUMBER: 4 classes
✅ EGGPLANT: 4 classes
✅ ASHGOURD: 7 classes
✅ BITTERGOURD: 9 classes
✅ RIDGEGOURD: 4 classes
✅ SNAKEGOURD: 5 classes

✅ All 12 crop datasets verified!
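
The folder search above can also be expressed as a small pure function that accepts a candidate only if it contains every expected crop folder, which is stricter than counting subfolders. This is a sketch only: `find_dataset_root` is a hypothetical helper, not part of the notebook, and the demo runs against a temporary directory instead of Drive.

```python
import os
import tempfile

def find_dataset_root(candidates, required_folders):
    """Return the first candidate directory that contains every required subfolder."""
    for path in candidates:
        if not os.path.isdir(path):
            continue
        if all(os.path.isdir(os.path.join(path, name)) for name in required_folders):
            return path
    return None

# Demo on a throwaway directory tree (stands in for the mounted Drive)
with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "MyDrive", "Leaf Nutrient Data Sets")
    for name in ("Rice Nutrients", "Wheat Nitrogen"):
        os.makedirs(os.path.join(good, name))
    absent = os.path.join(tmp, "Shareddrives", "Leaf Nutrient Data Sets")
    found = find_dataset_root([absent, good], ["Rice Nutrients", "Wheat Nitrogen"])
    print(found == good)  # True
```

In the notebook, the required folders would be `CROP_DATASETS.values()`, so a directory with the right name but the wrong contents is rejected early.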


## 🚀 Speed Optimization: Copy Data to Local Disk

**This is the BIGGEST speed boost!** Reading images from Colab's local SSD instead of Drive makes data I/O roughly 10-50x faster.

In [None]:
# =============================================================
# 🚀 OPTIMIZED LOCAL SSD COPY (10-50x FASTER I/O)
# =============================================================
# Reading from Google Drive is SLOW (network I/O)
# Copying to /content/ uses Colab's fast local SSD
# This one-time copy saves HOURS during training!

import shutil
import time
from pathlib import Path

def copy_to_local_ssd(src_path, dest_name):
    """Optimized copy to local SSD with progress tracking"""
    local_path = f"/content/{dest_name}"
    
    # Skip if already exists and populated
    if os.path.exists(local_path):
        try:
            local_files = [p for p in Path(local_path).rglob('*') if p.is_file()]
            if len(local_files) > 100:  # Sanity check
                size_mb = sum(p.stat().st_size for p in local_files) / (1024 * 1024)
                print(f"✅ {dest_name}: Already on SSD ({len(local_files):,} files, {size_mb:.0f}MB)")
                return local_path
        except OSError:
            # Unreadable or partial copy: fall through and copy again
            pass
    
    if not os.path.exists(src_path):
        print(f"⚠️ {dest_name}: Source not found at {src_path}")
        return src_path
    
    print(f"🚀 Copying {dest_name} to local SSD...", end=" ", flush=True)
    start = time.time()
    
    # Remove existing if corrupted
    if os.path.exists(local_path):
        shutil.rmtree(local_path)
    
    # Fast copy with symlink preservation
    shutil.copytree(src_path, local_path, symlinks=True)
    
    # Calculate stats
    elapsed = time.time() - start
    copied = [p for p in Path(local_path).rglob('*') if p.is_file()]
    size_mb = sum(p.stat().st_size for p in copied) / (1024 * 1024)
    
    print(f"✅ {len(copied):,} files, {size_mb:.0f}MB in {elapsed:.1f}s")
    return local_path

# =============================================================
# 📦 COPY ALL NUTRIENT DATASETS TO LOCAL SSD
# =============================================================
print("=" * 70)
print("🚀 COPYING DATASETS TO LOCAL SSD")
print("=" * 70)
print("⏳ One-time setup (2-5 min) - saves HOURS during training!\n")

LOCAL_NUTRIENT_ROOT = '/content/local_nutrient_datasets'
os.makedirs(LOCAL_NUTRIENT_ROOT, exist_ok=True)

copy_success = []
copy_failed = []

for crop, folder_name in CROP_DATASETS.items():
    src = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
    dst = os.path.join(LOCAL_NUTRIENT_ROOT, folder_name)
    
    try:
        if os.path.exists(src):
            # Check if already copied (count .jpg/.jpeg/.png, case-insensitive)
            if os.path.exists(dst):
                num_files = sum(1 for p in Path(dst).rglob('*')
                                if p.suffix.lower() in ('.jpg', '.jpeg', '.png'))
                if num_files > 50:  # Sanity check
                    print(f"✅ {crop.upper()}: Already on SSD ({num_files:,} images)")
                    copy_success.append(crop)
                    continue
            
            # Copy to local
            print(f"🚀 {crop.upper()}: Copying...", end=" ", flush=True)
            start = time.time()
            
            if os.path.exists(dst):
                shutil.rmtree(dst)
            
            shutil.copytree(src, dst)
            
            num_files = sum(1 for p in Path(dst).rglob('*')
                            if p.suffix.lower() in ('.jpg', '.jpeg', '.png'))
            elapsed = time.time() - start
            
            print(f"✅ {num_files:,} images in {elapsed:.1f}s")
            copy_success.append(crop)
        else:
            print(f"⚠️ {crop.upper()}: Not found on Drive")
            copy_failed.append(crop)
    except Exception as e:
        print(f"❌ {crop.upper()}: Failed - {e}")
        copy_failed.append(crop)

# Switch to the local SSD copy only when every crop was copied;
# otherwise keep the previous root so later cells still find all crops
if copy_success and not copy_failed:
    NUTRIENT_DATASETS_ROOT = LOCAL_NUTRIENT_ROOT

print(f"\n{'='*70}")
print(f"✅ {len(copy_success)}/{len(CROP_DATASETS)} crops ready on local SSD")
if copy_failed:
    print(f"⚠️ Failed: {', '.join(copy_failed)}")
    print(f"⚠️ Dataset root unchanged: {NUTRIENT_DATASETS_ROOT}")
else:
    print(f"🚀 Data loading will now be 10-50x FASTER!")
print(f"{'='*70}\n")
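
A fixed file-count threshold cannot tell a complete copy from an interrupted one. One possible refinement, sketched here with hypothetical helpers (`count_images` and `copy_is_complete` are not defined in the notebook), compares the number of images on both sides before trusting the local copy.

```python
import os
import shutil
import tempfile
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def count_images(root):
    """Count image files under root, matching extensions case-insensitively."""
    return sum(
        1 for p in Path(root).rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS
    )

def copy_is_complete(src, dst):
    """True when dst exists and holds as many images as a non-empty src."""
    if not os.path.isdir(dst):
        return False
    expected = count_images(src)
    return expected > 0 and count_images(dst) == expected

# Demo on temporary directories (stand-ins for Drive and /content)
with tempfile.TemporaryDirectory() as tmp:
    src_root = os.path.join(tmp, "src")
    os.makedirs(os.path.join(src_root, "healthy"))
    for name in ("a.jpg", "b.JPG", "c.jpeg", "notes.txt"):
        Path(src_root, "healthy", name).write_bytes(b"")
    dst_root = os.path.join(tmp, "dst")
    print(copy_is_complete(src_root, dst_root))  # False: nothing copied yet
    shutil.copytree(src_root, dst_root)
    print(copy_is_complete(src_root, dst_root))  # True: 3 images on both sides
```

Counting files on Drive requires a directory walk over the network, so this check trades a little startup time for confidence that training sees the full dataset.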


## 🌱 Stage 1: Download PlantVillage Dataset from Kaggle


In [None]:
# Setup Kaggle credentials
# IMPORTANT: You need to manually download kaggle.json FIRST!
#
# 📝 HOW TO GET kaggle.json:
# 1. Go to https://www.kaggle.com/settings
# 2. Scroll down to "API" section
# 3. Click "Create New Token" button
# 4. This will DOWNLOAD a file called "kaggle.json" to your computer
# 5. Find the downloaded file (usually in your Downloads folder)
# 6. Then come back here and upload it when prompted below
#
# ⚠️ NOTE: If you only see the API key on screen but no download happened,
#    click "Create New Token" again - it should download the file

from google.colab import files

print("=" * 70)
print("📤 UPLOAD YOUR kaggle.json FILE")
print("=" * 70)
print("\n📝 If you haven't downloaded it yet:")
print("   1. Go to: https://www.kaggle.com/settings")
print("   2. Scroll to 'API' section")
print("   3. Click 'Create New Token' (downloads kaggle.json)")
print("   4. Find the file in your Downloads folder")
print("   5. Click 'Choose Files' below and select it")
print("\n⏳ Waiting for your kaggle.json file...\n")

uploaded = files.upload()

# Verify the file was uploaded
if 'kaggle.json' not in uploaded:
    print("\n❌ ERROR: kaggle.json was not uploaded!")
    print("   Please make sure you selected the correct file.")
    raise FileNotFoundError("kaggle.json not found in uploaded files")

# Copy kaggle.json to the standard location used by the kaggle CLI.
# The original stays in the working directory, where opendatasets looks for it;
# moving it away would make the download cell prompt for credentials.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

print("\n✅ Kaggle credentials configured successfully!")
print("📁 File saved to: ~/.kaggle/kaggle.json")
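
Uploading the wrong file, or a file saved with a different name, is a common failure at this step. A small check on the uploaded bytes can catch that before the download starts. This is a sketch: `parse_kaggle_token` is a hypothetical helper, and it relies only on the fact that a Kaggle API token is a JSON object with `username` and `key` fields.

```python
import json

def parse_kaggle_token(raw: bytes) -> dict:
    """Parse uploaded kaggle.json bytes and check the two required fields."""
    try:
        token = json.loads(raw.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        raise ValueError("kaggle.json is not valid JSON") from exc
    if not isinstance(token, dict):
        raise ValueError("kaggle.json must contain a JSON object")
    missing = [field for field in ("username", "key") if not token.get(field)]
    if missing:
        raise ValueError(f"kaggle.json is missing: {', '.join(missing)}")
    return token

# Demo with a placeholder token (not a real credential)
demo = parse_kaggle_token(b'{"username": "demo_user", "key": "0123abcd"}')
print(demo["username"])  # demo_user
```

In Colab, `files.upload()` returns a dict that maps each filename to its bytes, so the call would look like `parse_kaggle_token(uploaded['kaggle.json'])`.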


In [None]:
# Download PlantVillage dataset from Kaggle (SKIP IF ALREADY EXISTS)
import opendatasets as od

PLANTVILLAGE_URL = 'https://www.kaggle.com/datasets/emmarex/plantdisease'
PLANTVILLAGE_PATH = '/content/plantvillage'

# Known possible paths after download
POSSIBLE_PATHS = [
    os.path.join(PLANTVILLAGE_PATH, 'plantdisease', 'PlantVillage'),
    os.path.join(PLANTVILLAGE_PATH, 'PlantVillage'),
    os.path.join(PLANTVILLAGE_PATH, 'plantdisease', 'plantvillage', 'PlantVillage'),
]

def find_plantvillage_dataset():
    """Find PlantVillage dataset if it exists"""
    for path in POSSIBLE_PATHS:
        if os.path.exists(path) and os.path.isdir(path):
            subdirs = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
            if len(subdirs) >= 15:
                sample_dir = os.path.join(path, subdirs[0])
                sample_files = [f for f in os.listdir(sample_dir)
                              if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
                if len(sample_files) > 0:
                    return path
    return None

# Check if dataset already exists
existing_path = find_plantvillage_dataset()

if existing_path:
    print("✅ PlantVillage dataset ALREADY EXISTS! Skipping download...")
    print(f"📁 Using cached dataset at: {existing_path}")
    PLANTVILLAGE_PATH = existing_path
else:
    print("📥 Downloading PlantVillage dataset from Kaggle...")
    print("⏳ This will take 3-5 minutes (first time only)...")

    od.download(PLANTVILLAGE_URL, data_dir=PLANTVILLAGE_PATH)

    print("\n🔍 Locating dataset structure...")

    # Find the dataset path
    dataset_root = find_plantvillage_dataset()

    if not dataset_root:
        # Search recursively as fallback
        # (_filenames avoids shadowing the google.colab `files` module imported earlier)
        for root, dirs, _filenames in os.walk(PLANTVILLAGE_PATH):
            if len(dirs) >= 15:
                has_images = False
                for d in dirs[:3]:
                    dir_path = os.path.join(root, d)
                    if os.path.isdir(dir_path):
                        dir_files = os.listdir(dir_path)
                        if any(f.lower().endswith(('.jpg', '.jpeg', '.png')) for f in dir_files):
                            has_images = True
                            break
                if has_images:
                    dataset_root = root
                    break

    if dataset_root:
        PLANTVILLAGE_PATH = dataset_root
    else:
        raise FileNotFoundError("❌ PlantVillage dataset not found after download")

# Verify dataset
class_dirs = [d for d in os.listdir(PLANTVILLAGE_PATH)
              if os.path.isdir(os.path.join(PLANTVILLAGE_PATH, d))]
num_classes = len(class_dirs)

print(f"\n✅ PlantVillage dataset ready!")
print(f"📁 Path: {PLANTVILLAGE_PATH}")
print(f"🌿 Classes: {num_classes}")

# Quick image count
total_images = sum(len([f for f in os.listdir(os.path.join(PLANTVILLAGE_PATH, cls))
                        if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                   for cls in class_dirs[:5])
print(f"📊 Sample: First 5 classes have {total_images:,} images")

## 📊 Data Exploration & Preparation


In [None]:
# Analyze PlantVillage dataset
# Count only class directories; stray files in the dataset root are not classes
plantvillage_classes = sorted(
    d for d in os.listdir(PLANTVILLAGE_PATH)
    if os.path.isdir(os.path.join(PLANTVILLAGE_PATH, d))
)
print(f"🌱 PlantVillage Dataset:")
print(f"Total classes: {len(plantvillage_classes)}")
print(f"\nSample classes:")
for cls in plantvillage_classes[:5]:
    class_path = os.path.join(PLANTVILLAGE_PATH, cls)
    num_images = len(os.listdir(class_path))
    print(f"  - {cls}: {num_images} images")

# Build unified dataset info
print(f"\n🌾 UNIFIED Nutrient Dataset (ALL {len(CROP_DATASETS)} crops):")
total_classes = 0
total_images = 0

for crop, folder_name in CROP_DATASETS.items():
    crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
    if os.path.exists(crop_path):
        crop_classes = [d for d in os.listdir(crop_path) if os.path.isdir(os.path.join(crop_path, d))]
        crop_images = sum([len([f for f in os.listdir(os.path.join(crop_path, cls))
                                if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                          for cls in crop_classes])
        total_classes += len(crop_classes)
        total_images += crop_images
        print(f"  {crop.upper()}: {len(crop_classes)} classes, {crop_images} images")

print(f"\n📊 UNIFIED TOTALS:")
print(f"  Total classes: {total_classes}")
print(f"  Total images: {total_images}")
print(f"  Class format: {{crop}}_{{deficiency}} (e.g., rice_N, wheat_healthy)")
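
The `{crop}_{deficiency}` class format can be derived directly from the folder layout. The sketch below assumes each crop folder contains one subfolder per deficiency class; `build_unified_classes` is a hypothetical helper, and the subfolder names in the demo (`N`, `healthy`) are placeholders, since the real class folder names are not shown here.

```python
import os
import tempfile

def build_unified_classes(root, crop_folders):
    """Return sorted-per-crop labels of the form '<crop>_<class folder name>'."""
    labels = []
    for crop, folder in crop_folders.items():
        crop_path = os.path.join(root, folder)
        if not os.path.isdir(crop_path):
            continue  # Crop not present on disk; skip it
        for cls in sorted(os.listdir(crop_path)):
            if os.path.isdir(os.path.join(crop_path, cls)):
                labels.append(f"{crop}_{cls}")
    return labels

# Demo with placeholder class folders in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    for folder, classes in {"Rice Nutrients": ["N", "healthy"],
                            "Wheat Nitrogen": ["healthy"]}.items():
        for cls in classes:
            os.makedirs(os.path.join(tmp, folder, cls))
    demo_crops = {"rice": "Rice Nutrients", "wheat": "Wheat Nitrogen"}
    print(build_unified_classes(tmp, demo_crops))
    # ['rice_N', 'rice_healthy', 'wheat_healthy']
```

Prefixing with the crop key keeps classes such as `healthy` distinct across crops, which a single unified classifier needs.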


## 🔨 Create Data Pipelines


In [None]:
# =============================================================
# 🚀 OPTIMIZED DATA PIPELINE (Maximum Performance)
# =============================================================
# Key optimizations:
# 1. AUTOTUNE for all parallel operations
# 2. cache() after initial load (keeps data in memory)
# 3. Proper prefetch with AUTOTUNE
# 4. Efficient augmentation pipeline
# 5. Consistent normalization

AUTOTUNE = tf.data.AUTOTUNE

def create_optimized_dataset(data_dir, img_size, batch_size, validation_split=0.2, subset=None):
    """
    Create dataset with optimal settings for fastest training
    """
    print(f"📦 Loading {subset} data from {os.path.basename(data_dir)}...")
    
    dataset = tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=validation_split,
        subset=subset,
        seed=42,
        image_size=(img_size, img_size),
        batch_size=batch_size,
        label_mode='categorical',
        shuffle=True
    )
    
    return dataset

@tf.function(jit_compile=True)  # XLA compilation for speedup
def augment_image(image, label):
    """Fast augmentation pipeline with XLA compilation"""
    # image_dataset_from_directory yields float32 pixels in [0, 255]. Scale to [0, 1]
    # so the brightness delta below is a fraction of the full range, not 0.2 out of 255
    image = tf.cast(image, tf.float32) / 255.0
    
    # Random flip (horizontal only - vertical doesn't make sense for leaves)
    image = tf.image.random_flip_left_right(image)
    
    # Brightness and contrast (simulates different lighting)
    image = tf.image.random_brightness(image, 0.2)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    
    # Random saturation (simulates different leaf health)
    image = tf.image.random_saturation(image, 0.8, 1.2)
    
    # Random hue shift (slight color variation)
    image = tf.image.random_hue(image, 0.05)
    
    # Clip, then restore the [0, 255] range expected by the normalization step
    image = tf.clip_by_value(image, 0.0, 1.0) * 255.0
    
    return image, label

@tf.function(jit_compile=True)  # XLA compilation
def normalize_for_mobilenet(image, label):
    """Normalize to MobileNetV2 input range [-1, 1]"""
    image = tf.cast(image, tf.float32)
    image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
    return image, label

def build_optimized_pipeline(dataset, is_training=True, use_cache=True):
    """
    Build high-performance data pipeline
    
    Architecture:
    1. Parallel map operations (AUTOTUNE)
    2. Cache after initial load (optional)
    3. Augmentation (training only)
    4. Normalization
    5. Prefetch (AUTOTUNE)
    """
    # Configure threading for max parallelism
    options = tf.data.Options()
    options.threading.private_threadpool_size = NUM_WORKERS
    options.threading.max_intra_op_parallelism = 1
    options.deterministic = False  # Allow non-deterministic for speed
    dataset = dataset.with_options(options)
    
    # Cache the decoded (not yet normalized) batches in memory so later epochs skip disk reads.
    # Applied to the validation set only, which keeps memory use small and avoids OOM
    if use_cache and not is_training:
        dataset = dataset.cache()
    
    # Apply augmentation (training only)
    if is_training:
        dataset = dataset.map(augment_image, num_parallel_calls=AUTOTUNE)
    
    # Normalize for MobileNetV2
    dataset = dataset.map(normalize_for_mobilenet, num_parallel_calls=AUTOTUNE)
    
    # Prefetch batches for GPU (critical for performance)
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    
    return dataset


# =============================================================
# 📊 TQDM PROGRESS CALLBACK (Real-time tracking with ETA)
# =============================================================
class TQDMProgressCallback(tf.keras.callbacks.Callback):
    """Enhanced callback with progress bars and ETA"""
    
    def __init__(self, total_epochs, stage_name="Training"):
        super().__init__()
        self.total_epochs = total_epochs
        self.stage_name = stage_name
        self.epoch_pbar = None
        self.batch_pbar = None
        self.epoch_times = []
        self.stage_start_time = None
    
    def on_train_begin(self, logs=None):
        self.stage_start_time = time.time()
        print(f"\n🚀 {self.stage_name} Started")
        self.epoch_pbar = tqdm(
            total=self.total_epochs,
            desc=f"📈 {self.stage_name}",
            unit="epoch",
            position=0,
            leave=True,
            bar_format='{l_bar}{bar:30}{r_bar}'
        )
    
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()
        total_batches = self.params.get('steps', 0)
        
        self.batch_pbar = tqdm(
            total=total_batches,
            desc=f"  Epoch {epoch+1}/{self.total_epochs}",
            unit="batch",
            position=1,
            leave=False,
            bar_format='{l_bar}{bar:25}{r_bar}'
        )
    
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}  # Keras may pass None
        if self.batch_pbar:
            self.batch_pbar.update(1)
            self.batch_pbar.set_postfix({
                'loss': f"{logs.get('loss', 0):.4f}",
                'acc': f"{logs.get('accuracy', 0):.3f}"
            })
    
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # Keras may pass None
        if self.batch_pbar:
            self.batch_pbar.close()
        
        epoch_time = time.time() - self.epoch_start_time
        self.epoch_times.append(epoch_time)
        
        # Calculate ETA from the average epoch time so far
        avg_time = np.mean(self.epoch_times)
        remaining = self.total_epochs - (epoch + 1)
        eta_seconds = remaining * avg_time
        eta = str(timedelta(seconds=int(eta_seconds)))
        
        # Update progress
        self.epoch_pbar.update(1)
        self.epoch_pbar.set_postfix({
            'val_acc': f"{logs.get('val_accuracy', 0):.3f}",
            'val_loss': f"{logs.get('val_loss', 0):.4f}",
            'ETA': eta
        })
        
        print(f"\n   ✅ Epoch {epoch+1}: val_acc={logs.get('val_accuracy', 0):.4f}, "
              f"val_loss={logs.get('val_loss', 0):.4f}, time={epoch_time:.1f}s")
    
    def on_train_end(self, logs=None):
        if self.epoch_pbar:
            self.epoch_pbar.close()
        total_time = time.time() - self.stage_start_time
        print(f"\n✅ {self.stage_name} Complete in {str(timedelta(seconds=int(total_time)))}")
        print(f"   Avg epoch: {np.mean(self.epoch_times):.1f}s")


# =============================================================
# 📦 CREATE PLANTVILLAGE DATASETS
# =============================================================
print("\n" + "="*70)
print("📦 CREATING PLANTVILLAGE DATASETS")
print("="*70)

train_plantvillage_raw = create_optimized_dataset(
    PLANTVILLAGE_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='training'
)
val_plantvillage_raw = create_optimized_dataset(
    PLANTVILLAGE_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='validation'
)

# Apply optimized pipeline
print("🔧 Building optimized pipelines...")
train_plantvillage = build_optimized_pipeline(train_plantvillage_raw, is_training=True, use_cache=False)
val_plantvillage = build_optimized_pipeline(val_plantvillage_raw, is_training=False, use_cache=True)

# Get dataset info
train_batches = tf.data.experimental.cardinality(train_plantvillage_raw).numpy()
val_batches = tf.data.experimental.cardinality(val_plantvillage_raw).numpy()
num_classes = len(train_plantvillage_raw.class_names)

print(f"\n✅ PlantVillage Datasets Ready")
print(f"   Classes: {num_classes}")
print(f"   Training: {train_batches} batches × {BATCH_SIZE} = ~{train_batches * BATCH_SIZE:,} images")
print(f"   Validation: {val_batches} batches × {BATCH_SIZE} = ~{val_batches * BATCH_SIZE:,} images")
print(f"   ⚡ Optimizations: AUTOTUNE, XLA, {NUM_WORKERS} workers, cache (val)")
print(f"   🎨 Augmentations: flip, brightness, contrast, saturation, hue")
print("="*70 + "\n")
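
The batch counts printed above can be sanity-checked by hand. The sketch below assumes the split rule as I understand Keras to apply it: the validation subset gets `int(validation_split * n)` images, the training subset gets the rest, and the last batch of each may be partial. `expected_batches` is a hypothetical helper, not part of the notebook.

```python
import math

def expected_batches(num_images, batch_size, validation_split):
    """Return (train_batches, val_batches) for a directory of num_images files."""
    num_val = int(validation_split * num_images)
    num_train = num_images - num_val
    return math.ceil(num_train / batch_size), math.ceil(num_val / batch_size)

# 1000 images, batch size 32, 20% validation:
# 800 training images -> 25 batches; 200 validation images -> 7 batches
print(expected_batches(1000, 32, 0.2))  # (25, 7)
```

If the printed counts differ noticeably from this estimate, the usual cause is unreadable or non-image files in the class folders.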

## ✅ Pre-Training Validation

Run this cell to verify everything is set up correctly before training.


In [None]:
# ✅ PRE-TRAINING VALIDATION (T4 Optimized)
print("=" * 60)
print("🔍 PRE-TRAINING VALIDATION")
print(f"⏱️ Session time: {get_session_time()}")
print("=" * 60)

errors = []

# 1. GPU Check
print("\n1️⃣ GPU Check...")
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    gpu_name = gpus[0].name
    print(f"   ✅ GPU: {gpu_name}")
    # Report the GPU model (e.g. T4)
    try:
        !nvidia-smi --query-gpu=name --format=csv,noheader
    except Exception:
        pass
else:
    errors.append("No GPU detected!")
    print("   ⚠️ No GPU - training will be MUCH slower")

# 2. Quick data test
print("\n2️⃣ Data Pipeline Check...")
try:
    import time
    start = time.time()
    for batch_images, batch_labels in train_plantvillage.take(1):
        load_time = time.time() - start
        print(f"   ✅ Batch shape: {batch_images.shape}")
        print(f"   ✅ Labels shape: {batch_labels.shape}")
        print(f"   ✅ Image dtype: {batch_images.dtype}")
        print(f"   ⚡ First batch load: {load_time:.2f}s")
        break
except Exception as e:
    errors.append(f"Data pipeline error: {e}")
    print(f"   ❌ Error: {e}")

# 3. Memory check
print("\n3️⃣ Memory Status...")
try:
    import psutil
    vm = psutil.virtual_memory()
    ram_total = vm.total / (1024**3)
    ram_avail = vm.available / (1024**3)
    print(f"   ✅ RAM: {ram_avail:.1f}GB available / {ram_total:.1f}GB total")
    if ram_avail < 3:
        print("   ⚠️ Low RAM - data copied to local SSD helps prevent crashes")
except Exception:
    print("   ℹ️ Could not check RAM")

# 4. GPU Memory check
print("\n4️⃣ GPU Memory Check...")
try:
    !nvidia-smi --query-gpu=memory.total,memory.free --format=csv,noheader
except Exception:
    print("   ℹ️ Could not query GPU memory")

# 5. Existing checkpoints
print("\n5️⃣ Checkpoint Status...")
stage2_exists = os.path.exists(os.path.join(DRIVE_CHECKPOINT_DIR, 'stage2_plantvillage_best.keras'))
stage3_exists = os.path.exists(os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras'))
print(f"   Stage 2 checkpoint: {'✅ Found (will resume)' if stage2_exists else '❌ None (fresh start)'}")
print(f"   Stage 3 checkpoint: {'✅ Found (will resume)' if stage3_exists else '❌ None (fresh start)'}")

# Summary
print("\n" + "=" * 60)
if errors:
    print("❌ ISSUES FOUND:")
    for e in errors:
        print(f"   • {e}")
else:
    print("✅ ALL CHECKS PASSED!")
    print(f"\n🚀 T4 GPU OPTIMIZED SETTINGS:")
    print(f"   • Batch size: {BATCH_SIZE}")
    print(f"   • Precision: float32 (mixed precision disabled)")
    print(f"   • JIT/XLA compilation: Enabled")
    print(f"   • AUTOTUNE prefetch: Enabled")
    print(f"   • Data location: Local SSD (fast I/O)")
    print(f"\n⚡ Expected: ~1-2 min/epoch (10x faster than Drive I/O)")
print("=" * 60)

## 🏗️ Stage 2: Build Model with MobileNetV2 Base


In [None]:
# =============================================================
# 🏗️ OPTIMIZED MODEL ARCHITECTURE
# =============================================================
def create_model(num_classes, input_shape=(224, 224, 3), freeze_base=True):
    """
    Create MobileNetV2-based model optimized for fast training
    
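    Architecture (in layer order):
    - MobileNetV2 base (pretrained on ImageNet, frozen when freeze_base=True)
    - GlobalAveragePooling2D (reduce parameters)
    - Dropout(DROPOUT_RATE)
    - Dense(256, ReLU) with L2 regularization (1e-4)
    - BatchNormalization
    - Dropout(DROPOUT_RATE * 0.8)
    - Dense(num_classes) with softmax, float32 output

    Returns: (model, base_model)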
    """
    
    # Load MobileNetV2 with ImageNet weights
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet',
        pooling=None  # We'll add our own pooling
    )
    
    base_model.trainable = not freeze_base
    
    # Build classification head
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(DROPOUT_RATE),
        tf.keras.layers.Dense(
            256, 
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(1e-4)
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(DROPOUT_RATE * 0.8),
        # Float32 output for numerical stability
        tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')
    ])
    
    return model, base_model


# =============================================================
# 🎯 STAGE 2: CREATE PLANTVILLAGE MODEL
# =============================================================
print("\n" + "="*70)
print("🏗️ BUILDING STAGE 2 MODEL (PlantVillage)")
print("="*70)

# Get number of PlantVillage classes
num_plantvillage_classes = len(train_plantvillage_raw.class_names)
print(f"   Classes: {num_plantvillage_classes}")

# Check for existing checkpoint
STAGE2_CHECKPOINT = os.path.join(DRIVE_CHECKPOINT_DIR, 'stage2_plantvillage_best.keras')
STAGE2_LOCAL = os.path.join(OUTPUT_DIR, 'stage2_plantvillage_best.keras')

model_stage2 = None
resume_stage2 = False

# Try to load existing checkpoint (Drive copy first, then local copy)
for checkpoint_path in [STAGE2_CHECKPOINT, STAGE2_LOCAL]:
    if os.path.exists(checkpoint_path):
        try:
            print(f"\n🔄 Found existing Stage 2 checkpoint!")
            print(f"   Loading from: {checkpoint_path}")
            model_stage2 = tf.keras.models.load_model(checkpoint_path)
            
            # Verify correct output shape
            if model_stage2.output_shape[-1] == num_plantvillage_classes:
                resume_stage2 = True
                # Recover the base-model handle, which is otherwise only set
                # when a new model is built below (first layer of the Sequential)
                base_model_stage2 = model_stage2.layers[0]
                print(f"✅ Checkpoint valid ({num_plantvillage_classes} classes)")
                
                # Evaluate current performance
                print("📊 Evaluating checkpoint...")
                results = model_stage2.evaluate(val_plantvillage, verbose=0)
                print(f"   Current - Loss: {results[0]:.4f}, Accuracy: {results[1]:.4f}")
                
                if results[1] >= 0.85:
                    print("✅ Checkpoint performance is good - will use for Stage 3")
                break
            else:
                print(f"⚠️ Class mismatch - creating new model")
                model_stage2 = None
        except Exception as e:
            print(f"⚠️ Could not load: {e}")
            model_stage2 = None

# Create new model if needed
if model_stage2 is None:
    print(f"\n🏗️ Creating NEW Stage 2 model...")
    model_stage2, base_model_stage2 = create_model(
        num_classes=num_plantvillage_classes,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
        freeze_base=True
    )
    
    # Compile with optimizer
    model_stage2.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE2),
        loss='categorical_crossentropy',
        metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')],
        jit_compile=True  # XLA compilation for speedup
    )
    
    print(f"✅ Model created")

# Show model summary
trainable_params = sum([tf.keras.backend.count_params(w) for w in model_stage2.trainable_weights])
total_params = sum([tf.keras.backend.count_params(w) for w in model_stage2.weights])

print(f"\n📊 Model Architecture:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Frozen parameters: {total_params - trainable_params:,}")
print(f"   Dropout rate: {DROPOUT_RATE}")
print(f"   Base model: MobileNetV2 (ImageNet)")
print(f"   Output classes: {num_plantvillage_classes}")
print(f"   Precision: float32")
print(f"   XLA/JIT: Enabled")
print(f"\n🔄 Resume training: {'Yes' if resume_stage2 else 'No (fresh start)'}")
print("="*70 + "\n")

## 🎯 Stage 2: Train on PlantVillage Dataset
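
The training cell below stops early on `val_loss` with `patience=3` and `min_delta=0.005`. As a rough mental model (a simplified sketch, not Keras's actual implementation), an epoch only counts as an improvement when the loss beats the best value so far by more than `min_delta`. The helper `early_stop_epoch` and the loss values are hypothetical, chosen only to show that tiny improvements still use up patience.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.005):
    """Return the 1-based epoch at which training would stop, or None.

    Simplified model of early stopping on a loss that should decrease:
    an epoch is an improvement only if it beats the best loss so far
    by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best, wait = loss, 0      # real improvement: reset the counter
        else:
            wait += 1                 # no (or too small an) improvement
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: epochs 4-6 move by less than min_delta
losses = [0.90, 0.60, 0.50, 0.498, 0.497, 0.499, 0.40]
print(early_stop_epoch(losses))  # 6
```

In this made-up run, training stops at epoch 6 and never reaches the better loss at epoch 7, which is the trade-off a short patience makes for a shorter Colab session.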

In [None]:
print("🚀 Starting Stage 2: PlantVillage Fine-tuning")
print(f"⏱️ Epochs: {PLANTVILLAGE_EPOCHS} | LR: {LEARNING_RATE_STAGE2} | Batch: {BATCH_SIZE}")
if resume_stage2:
    print("🔄 RESUMING from previous checkpoint")
print("="*60)

# Callbacks with TQDM progress
callbacks_stage2 = [
    # TQDM Progress with ETA
    TQDMProgressCallback(PLANTVILLAGE_EPOCHS, stage_name="Stage 2: PlantVillage"),

    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True,
        verbose=1,
        min_delta=0.005
    ),

    # Reduce LR on plateau
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.3,
        patience=2,
        min_lr=1e-6,
        verbose=1
    ),

    # Save best to local (fast)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(OUTPUT_DIR, 'stage2_plantvillage_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0
    ),

    # Save best to Drive (persistent)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(DRIVE_CHECKPOINT_DIR, 'stage2_plantvillage_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0
    )
]

# Train with progress tracking
TRAINING_START_TIME = datetime.now()

history_stage2 = model_stage2.fit(
    train_plantvillage,
    validation_data=val_plantvillage,
    epochs=PLANTVILLAGE_EPOCHS,
    callbacks=callbacks_stage2,
    verbose=0  # Disable default output, use TQDM instead
)

# Calculate training stats
best_val_acc = max(history_stage2.history['val_accuracy'])
best_val_loss = min(history_stage2.history['val_loss'])
final_train_acc = history_stage2.history['accuracy'][-1]
training_time = datetime.now() - TRAINING_START_TIME

print(f"\n" + "="*60)
print(f"✅ Stage 2 completed in {training_time}")
print(f"📈 Best val accuracy: {best_val_acc:.4f}")
print(f"📉 Best val loss: {best_val_loss:.4f}")
print(f"📊 Final train accuracy: {final_train_acc:.4f}")

# Check for overfitting/underfitting
gap = final_train_acc - best_val_acc
if gap > 0.15:
    print(f"⚠️ Overfitting detected (train-val gap: {gap:.2%})")
elif best_val_acc < 0.7:
    print(f"⚠️ Possible underfitting (val_acc: {best_val_acc:.2%})")
else:
    print(f"✅ Good generalization (train-val gap: {gap:.2%})")

print(f"\n💾 Checkpoints saved to:")
print(f"   Local: {OUTPUT_DIR}")
print(f"   Drive: {DRIVE_CHECKPOINT_DIR}")

## 📈 Stage 2 Results Visualization
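
When reading the curves, it can help to apply the same rule of thumb the training cell prints: a train-validation accuracy gap above 0.15 is flagged as overfitting, and a best validation accuracy below 0.7 as possible underfitting. The sketch below restates that check as a standalone function. The name `diagnose_fit` and the sample accuracies are illustrative assumptions, and the thresholds are the notebook's heuristics rather than general rules.

```python
def diagnose_fit(final_train_acc, best_val_acc, gap_limit=0.15, min_val_acc=0.7):
    """Classify a run using the notebook's two heuristic thresholds."""
    gap = final_train_acc - best_val_acc
    if gap > gap_limit:
        return "overfitting"      # train accuracy far ahead of validation
    if best_val_acc < min_val_acc:
        return "underfitting"     # validation accuracy still low
    return "ok"


# Made-up accuracy pairs for illustration
print(diagnose_fit(0.98, 0.80))  # overfitting
print(diagnose_fit(0.65, 0.62))  # underfitting
print(diagnose_fit(0.93, 0.90))  # ok
```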


In [None]:
# Quick training visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history_stage2.history['accuracy'], 'b-', label='Train')
axes[0].plot(history_stage2.history['val_accuracy'], 'r-', label='Val')
axes[0].set_title('Stage 2: Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history_stage2.history['loss'], 'b-', label='Train')
axes[1].plot(history_stage2.history['val_loss'], 'r-', label='Val')
axes[1].set_title('Stage 2: Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'stage2_training_history.png'), dpi=150)
plt.savefig(os.path.join(DRIVE_CHECKPOINT_DIR, 'stage2_training_history.png'), dpi=150)
plt.show()

# Memory cleanup (preserve model references)
import gc
gc.collect()
print("🧹 Memory cleaned (models preserved for Stage 3)")

## 🔄 Stage 3: Build UNIFIED Dataset
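
The cell below merges every crop's class folders into a single directory whose class names follow the `crop_deficiency` pattern (for example `rice_N`). As a small illustration of that naming step, the sketch below builds the unified name and drops a crop prefix the source folder may already carry. `unified_class_name` is a hypothetical helper: unlike the notebook's `str.replace` calls, it only strips a leading prefix and ignores case.

```python
def unified_class_name(crop: str, class_name: str) -> str:
    """Return '<crop>_<deficiency>', stripping an existing crop prefix."""
    cleaned = class_name
    # Try the double-underscore prefix first so it is not half-stripped.
    for prefix in (f"{crop}__", f"{crop}_"):
        if cleaned.lower().startswith(prefix.lower()):
            cleaned = cleaned[len(prefix):]
            break
    return f"{crop}_{cleaned}"


print(unified_class_name("rice", "N"))               # rice_N
print(unified_class_name("tomato", "tomato__K"))     # tomato_K
print(unified_class_name("maize", "Maize_healthy"))  # maize_healthy
```

Splitting a unified name back into its parts with `name.split('_', 1)` keeps deficiency labels that contain underscores intact.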


In [None]:
# =============================================================
# 🔄 STAGE 3: BUILD UNIFIED NUTRIENT DATASET
# =============================================================
# Combine all crop datasets into one unified structure:
# crop_deficiency format (e.g., rice_N, wheat_N, tomato_healthy)

print("\n" + "="*70)
print("🏗️ BUILDING UNIFIED NUTRIENT DATASET")
print("="*70)

UNIFIED_DATASET_PATH = '/content/unified_nutrient_dataset'

# =============================================================
# 🔧 SMART FOLDER DETECTION (Handles multiple structures)
# =============================================================
def detect_nutrient_classes(crop_path):
    """
    Intelligently detect nutrient classes from various folder structures:
    
    Structure 1 (Flat):
        crop_folder/N/, crop_folder/P/, crop_folder/K/
        
    Structure 2 (Split):
        crop_folder/train/N/, crop_folder/test/N/, crop_folder/val/N/
        
    Returns: {class_name: [list_of_source_folders]}
    """
    nutrient_classes = {}
    
    if not os.path.exists(crop_path):
        return nutrient_classes
    
    subfolders = [d for d in os.listdir(crop_path) 
                  if os.path.isdir(os.path.join(crop_path, d))]
    
    # Detect if using train/test/val splits
    split_keywords = {'train', 'test', 'val', 'validation'}
    has_splits = any(f.lower() in split_keywords for f in subfolders)
    
    if has_splits:
        # Dataset has train/test/val structure
        for split_folder in subfolders:
            if split_folder.lower() in split_keywords:
                split_path = os.path.join(crop_path, split_folder)
                
                for class_name in os.listdir(split_path):
                    class_path = os.path.join(split_path, class_name)
                    
                    if not os.path.isdir(class_path):
                        continue
                    
                    # Check if folder contains images
                    files = os.listdir(class_path)[:20]  # Sample first 20
                    has_images = any(f.lower().endswith(('.jpg', '.jpeg', '.png')) for f in files)
                    
                    if has_images:
                        if class_name not in nutrient_classes:
                            nutrient_classes[class_name] = []
                        nutrient_classes[class_name].append(class_path)
    else:
        # Flat structure - subfolders are the classes
        for class_name in subfolders:
            class_path = os.path.join(crop_path, class_name)
            
            # Verify it contains images
            files = os.listdir(class_path)[:20]
            has_images = any(f.lower().endswith(('.jpg', '.jpeg', '.png')) for f in files)
            
            if has_images:
                nutrient_classes[class_name] = [class_path]
    
    return nutrient_classes


# =============================================================
# 🚀 BUILD UNIFIED DATASET (With validation)
# =============================================================
# Check if already built
if os.path.exists(UNIFIED_DATASET_PATH):
    existing_classes = [d for d in os.listdir(UNIFIED_DATASET_PATH)
                        if os.path.isdir(os.path.join(UNIFIED_DATASET_PATH, d))]
    
    # Validate class names: crop_deficiency, where the deficiency part
    # may itself contain underscores (e.g. tomato_Nitrogen_deficiency)
    valid_classes = [c for c in existing_classes if '_' in c and all(c.split('_', 1))]
    
    if len(valid_classes) >= len(CROP_DATASETS) * 2:  # At least 2 classes per crop
        print(f"✅ Unified dataset already exists with {len(valid_classes)} classes")
        print(f"   Classes: {', '.join(sorted(valid_classes)[:10])}{'...' if len(valid_classes) > 10 else ''}")
        unified_classes = valid_classes
        needs_rebuild = False
    else:
        print(f"⚠️ Existing dataset incomplete ({len(valid_classes)} classes) - rebuilding...")
        shutil.rmtree(UNIFIED_DATASET_PATH)
        needs_rebuild = True
else:
    needs_rebuild = True

# Build if needed
if needs_rebuild:
    os.makedirs(UNIFIED_DATASET_PATH, exist_ok=True)
    unified_classes = []
    skipped_crops = []
    
    print("📂 Combining crop datasets into unified structure...\n")
    
    for crop, folder_name in CROP_DATASETS.items():
        crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
        
        if not os.path.exists(crop_path):
            print(f"   ⚠️ {crop.upper()}: Not found")
            skipped_crops.append(crop)
            continue
        
        print(f"   🌾 {crop.upper()}: Processing...")
        
        try:
            # Detect nutrient classes
            nutrient_classes = detect_nutrient_classes(crop_path)
            
            if not nutrient_classes:
                print(f"      ⚠️ No valid classes found")
                skipped_crops.append(crop)
                continue
            
            # Process each nutrient deficiency class
            for class_name, source_paths in nutrient_classes.items():
                # Clean class name (strip the crop prefix; the double-underscore
                # form must be replaced first or it would leave a stray "_")
                clean_name = class_name.replace(f"{crop}__", "").replace(f"{crop}_", "")
                unified_class = f"{crop}_{clean_name}"
                
                # Create destination folder
                dst_dir = os.path.join(UNIFIED_DATASET_PATH, unified_class)
                os.makedirs(dst_dir, exist_ok=True)
                
                # Copy images from all source paths (train + test + val)
                total_copied = 0
                for src_dir in source_paths:
                    src_name = os.path.basename(os.path.dirname(src_dir))  # train/test/val
                    
                    for img_file in os.listdir(src_dir):
                        if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                            src_file = os.path.join(src_dir, img_file)
                            # Prevent filename collisions with prefix
                            dst_file = os.path.join(dst_dir, f"{src_name}_{img_file}")
                            
                            if not os.path.exists(dst_file):
                                shutil.copy2(src_file, dst_file)
                                total_copied += 1
                
                if total_copied > 0 and unified_class not in unified_classes:
                    unified_classes.append(unified_class)
            
            # Show crop summary
            crop_classes = [c for c in unified_classes if c.startswith(f"{crop}_")]
            class_names_short = [c.replace(f'{crop}_', '') for c in crop_classes]
            print(f"      ✅ {len(crop_classes)} classes: {', '.join(class_names_short)}")
            
        except Exception as e:
            print(f"      ❌ Error: {e}")
            skipped_crops.append(crop)
    
    if skipped_crops:
        print(f"\n⚠️ Skipped crops: {', '.join(skipped_crops)}")

# Validate final dataset
if len(unified_classes) == 0:
    # Try reading from existing
    unified_classes = [d for d in os.listdir(UNIFIED_DATASET_PATH)
                       if os.path.isdir(os.path.join(UNIFIED_DATASET_PATH, d)) and '_' in d]

if len(unified_classes) == 0:
    raise RuntimeError("❌ No valid classes found! Check dataset structure.")

# Sort and validate
class_names = sorted(unified_classes)
num_classes = len(class_names)
num_unified_classes = num_classes  # name checked by the Stage 3 setup cell

# Final validation - no train/test/val in class names
bad_classes = [c for c in class_names if any(x in c.lower() for x in ['_train', '_test', '_val'])]
if bad_classes:
    raise RuntimeError(f"❌ Invalid class names detected: {bad_classes[:3]}")

print(f"\n{'='*70}")
print(f"✅ UNIFIED DATASET READY")
print(f"   Total classes: {num_classes}")
print(f"   Format: crop_deficiency (e.g., rice_N, wheat_healthy)")
print(f"   Classes: {', '.join(class_names[:8])}{'...' if num_classes > 8 else ''}")
print(f"   Location: {UNIFIED_DATASET_PATH} (LOCAL SSD)")
print(f"{'='*70}\n")

# =============================================================
# 📦 CREATE OPTIMIZED UNIFIED DATASETS
# =============================================================
print("📦 Creating unified nutrient datasets...")

train_nutrient_raw = create_optimized_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='training'
)
val_nutrient_raw = create_optimized_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='validation'
)

# Apply optimized pipeline (no cache for larger dataset)
print("🔧 Building optimized pipelines...")
train_nutrient = build_optimized_pipeline(train_nutrient_raw, is_training=True, use_cache=False)
val_nutrient = build_optimized_pipeline(val_nutrient_raw, is_training=False, use_cache=False)

train_batches = tf.data.experimental.cardinality(train_nutrient_raw).numpy()
val_batches = tf.data.experimental.cardinality(val_nutrient_raw).numpy()

print(f"\n✅ Unified Nutrient Datasets Ready")
print(f"   Classes: {num_classes}")
print(f"   Training: {train_batches} batches × {BATCH_SIZE} = ~{train_batches * BATCH_SIZE:,} images")
print(f"   Validation: {val_batches} batches × {BATCH_SIZE} = ~{val_batches * BATCH_SIZE:,} images")
print(f"   ⚡ Optimizations: AUTOTUNE, XLA, {NUM_WORKERS} workers")
print(f"   🎨 Augmentations: flip, brightness, contrast, saturation, hue")
print("="*70 + "\n")

## 📊 Dataset Verification & Statistics

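## ⚖️ Balance Dataset (Fix Model Bias)

**Problem:** If some classes have many more images than others (e.g., maize has 1000 images while wheat has 200), the model becomes biased toward the majority class.

**Solution:** Balance classes by either:
- **Undersampling:** Randomly remove images from majority classes until they reach the target
- **Oversampling:** Add lightly augmented copies to minority classes until they reach the target

We'll use a hybrid approach: the target is the median class size (or the mean, when the median is not above 80% of the mean). Larger classes are undersampled down to it and smaller classes are oversampled up to it. Balancing edits the local copy of the dataset in place.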
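
As a rough illustration of how the hybrid target falls out of the class counts, the sketch below mirrors the rule used in the balancing cell: take the median class size if it exceeds 80% of the mean, otherwise take the mean. The helper names `hybrid_target` and `plan_balancing` and the example counts are made up for illustration and are not part of the notebook.

```python
def hybrid_target(class_counts: dict) -> int:
    """Pick the per-class target: median if it exceeds 80% of the mean, else mean."""
    counts = sorted(class_counts.values())
    median = counts[len(counts) // 2]      # upper median, as in the balancing cell
    mean = sum(counts) // len(counts)      # integer mean
    return median if median > mean * 0.8 else mean


def plan_balancing(class_counts: dict) -> dict:
    """Images to add (positive) or remove (negative) per class."""
    target = hybrid_target(class_counts)
    return {name: target - n for name, n in class_counts.items()}


# Hypothetical counts loosely based on the maize/wheat example above
counts = {"maize_healthy": 1000, "rice_N": 400, "wheat_N": 200}
print(hybrid_target(counts))   # 533
print(plan_balancing(counts))  # {'maize_healthy': -467, 'rice_N': 133, 'wheat_N': 333}
```

Here the median (400) is below 80% of the mean (533), so the mean becomes the target: maize loses 467 images while rice and wheat gain augmented copies.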

In [None]:
# =============================================================
# ⚖️ BALANCE DATASET - FIX MODEL BIAS
# =============================================================
# Ensures equal representation of all classes to prevent bias

import random
from PIL import Image, ImageEnhance, ImageOps
from collections import Counter

def augment_image_pil(img_path, save_path, augmentation_idx):
    """
    Create augmented version of image using PIL
    Lighter augmentation than training pipeline (preserves image identity)
    """
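    try:
        # Convert to RGB so palette/grayscale files can be enhanced and saved as JPEG
        with Image.open(img_path) as src:
            img = src.convert('RGB')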
        
        # Different augmentation based on index
        if augmentation_idx % 5 == 0:
            # Horizontal flip
            img = ImageOps.mirror(img)
        elif augmentation_idx % 5 == 1:
            # Brightness adjustment
            enhancer = ImageEnhance.Brightness(img)
            img = enhancer.enhance(random.uniform(0.85, 1.15))
        elif augmentation_idx % 5 == 2:
            # Contrast adjustment
            enhancer = ImageEnhance.Contrast(img)
            img = enhancer.enhance(random.uniform(0.85, 1.15))
        elif augmentation_idx % 5 == 3:
            # Color adjustment
            enhancer = ImageEnhance.Color(img)
            img = enhancer.enhance(random.uniform(0.9, 1.1))
        else:
            # Sharpness adjustment
            enhancer = ImageEnhance.Sharpness(img)
            img = enhancer.enhance(random.uniform(0.9, 1.1))
        
        img.save(save_path, quality=95)
        return True
    except Exception as e:
        print(f"      ⚠️ Augmentation failed: {e}")
        return False

def balance_dataset(dataset_path, target_images_per_class=None, strategy='hybrid'):
    """
    Balance dataset to prevent model bias
    
    Args:
        dataset_path: Path to unified dataset
        target_images_per_class: Target number of images per class (None = use median)
        strategy: 'undersample', 'oversample', or 'hybrid'
    
    Returns:
        Dictionary with balancing statistics
    """
    print("\n" + "="*70)
    print("⚖️ BALANCING DATASET TO PREVENT BIAS")
    print("="*70)
    
    # Count images per class
    class_counts = {}
    for class_name in os.listdir(dataset_path):
        class_path = os.path.join(dataset_path, class_name)
        if not os.path.isdir(class_path):
            continue
        
        images = [f for f in os.listdir(class_path) 
                  if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        class_counts[class_name] = len(images)
    
    if not class_counts:
        print("❌ No classes found!")
        return {}
    
    # Calculate statistics
    min_count = min(class_counts.values())
    max_count = max(class_counts.values())
    median_count = sorted(class_counts.values())[len(class_counts) // 2]
    mean_count = sum(class_counts.values()) // len(class_counts)
    
    print(f"\n📊 Current Distribution:")
    print(f"   Min: {min_count} images")
    print(f"   Max: {max_count} images")
    print(f"   Median: {median_count} images")
    print(f"   Mean: {mean_count} images")
    print(f"   Imbalance ratio: {min_count/max_count:.2f}")
    
    # Determine target
    if target_images_per_class is None:
        if strategy == 'undersample':
            target = min_count
        elif strategy == 'oversample':
            target = max_count
        else:  # hybrid
            # Use median or mean, whichever is more balanced
            target = median_count if median_count > mean_count * 0.8 else mean_count
    else:
        target = target_images_per_class
    
    print(f"\n🎯 Target: {target} images per class")
    print(f"   Strategy: {strategy}")
    
    # Balance each class
    stats = {
        'original': {},
        'final': {},
        'undersampled': [],
        'oversampled': [],
        'augmented_images': 0
    }
    
    for class_name, current_count in class_counts.items():
        class_path = os.path.join(dataset_path, class_name)
        images = [f for f in os.listdir(class_path) 
                  if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        
        stats['original'][class_name] = current_count
        
        if current_count > target:
            # Undersample: randomly remove excess images
            excess = current_count - target
            images_to_remove = random.sample(images, excess)
            
            for img_file in images_to_remove:
                os.remove(os.path.join(class_path, img_file))
            
            stats['final'][class_name] = target
            stats['undersampled'].append(class_name)
            
        elif current_count < target:
            # Oversample: add augmented copies until the class reaches the target
            augmentations_needed = target - current_count
            augmentation_idx = 0
            failures = 0
            max_failures = 10 * augmentations_needed  # guard against an endless loop
            
            while augmentations_needed > 0 and images and failures < max_failures:
                # Pick a random source image
                source_img = random.choice(images)
                source_path = os.path.join(class_path, source_img)
                
                # Create augmented filename
                base_name, ext = os.path.splitext(source_img)
                aug_name = f"{base_name}_aug{augmentation_idx}{ext}"
                aug_path = os.path.join(class_path, aug_name)
                
                # Skip names that already exist (e.g. from an earlier run)
                if os.path.exists(aug_path):
                    augmentation_idx += 1
                    continue
                
                # Augment and save
                if augment_image_pil(source_path, aug_path, augmentation_idx):
                    augmentations_needed -= 1
                    stats['augmented_images'] += 1
                else:
                    failures += 1
                
                augmentation_idx += 1
            
            # Record the count actually reached (below target if augmentation kept failing)
            stats['final'][class_name] = target - augmentations_needed
            stats['oversampled'].append(class_name)
        else:
            # Already balanced
            stats['final'][class_name] = current_count
    
    # Print results
    print(f"\n✅ Balancing Complete!")
    print(f"   Undersampled: {len(stats['undersampled'])} classes")
    print(f"   Oversampled: {len(stats['oversampled'])} classes")
    print(f"   Augmented images created: {stats['augmented_images']}")
    
    # Verify balance
    final_counts = stats['final'].values()
    if len(set(final_counts)) == 1:
        print(f"\n✅ Perfect balance achieved!")
        print(f"   All classes now have {target} images")
    else:
        print(f"\n📊 Final distribution:")
        print(f"   Min: {min(final_counts)} images")
        print(f"   Max: {max(final_counts)} images")
        print(f"   Balance ratio: {min(final_counts)/max(final_counts):.2f}")
    
    print("="*70 + "\n")
    
    return stats

# Run balancing
print("🔍 Checking if dataset needs balancing...")

# Count current distribution
current_counts = {}
for class_name in sorted(class_names):
    class_path = os.path.join(UNIFIED_DATASET_PATH, class_name)
    num_images = len([f for f in os.listdir(class_path) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
    current_counts[class_name] = num_images

min_imgs = min(current_counts.values())
max_imgs = max(current_counts.values())
imbalance_ratio = min_imgs / max_imgs

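if imbalance_ratio < 0.7:  # Smallest class has under 70% of the largest class's images
    print(f"⚠️ Imbalance detected! Ratio: {imbalance_ratio:.2f}")
    print(f"   This can cause bias (e.g., towards maize)")
    print(f"\n🔧 Applying automatic balancing...\n")
    
    # Use hybrid strategy: balance to the median (or mean) class size
    balance_stats = balance_dataset(
        UNIFIED_DATASET_PATH,
        target_images_per_class=None,  # Auto-determine from median/mean
        strategy='hybrid'
    )
    
    # Balancing adds and deletes files on disk, so the datasets created in the
    # previous cell may point at stale file lists. Rebuild them from the folders.
    print("🔄 Rebuilding datasets from balanced folders...")
    train_nutrient_raw = create_optimized_dataset(
        UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
        validation_split=0.2, subset='training'
    )
    val_nutrient_raw = create_optimized_dataset(
        UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
        validation_split=0.2, subset='validation'
    )
    train_nutrient = build_optimized_pipeline(train_nutrient_raw, is_training=True, use_cache=False)
    val_nutrient = build_optimized_pipeline(val_nutrient_raw, is_training=False, use_cache=False)
    
else:
    print(f"✅ Dataset is already balanced (ratio: {imbalance_ratio:.2f})")
    print(f"   No balancing needed!\n")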

In [None]:
# =============================================================
# 📊 DATASET VERIFICATION AND STATISTICS
# =============================================================
# Verify data quality and show class distribution

print("="*70)
print("📊 DATASET ANALYSIS")
print("="*70 + "\n")

# Count images per class
print("📈 Class Distribution:")
print("-" * 70)

class_stats = {}
total_images = 0

for class_name in sorted(class_names):
    class_path = os.path.join(UNIFIED_DATASET_PATH, class_name)
    
    # Count image files
    num_images = len([f for f in os.listdir(class_path) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
    
    class_stats[class_name] = num_images
    total_images += num_images
    
    # Show crop and deficiency separately
    crop, deficiency = class_name.split('_', 1)
    print(f"   {class_name:25s} {num_images:4d} images  ({crop:8s} | {deficiency})")

print("-" * 70)
print(f"   TOTAL: {total_images:,} images across {num_classes} classes")
print()

# Calculate statistics
avg_per_class = total_images / num_classes
min_images = min(class_stats.values())
max_images = max(class_stats.values())
min_class = min(class_stats, key=class_stats.get)
max_class = max(class_stats, key=class_stats.get)

print("📊 Statistics:")
print(f"   Average per class: {avg_per_class:.0f} images")
print(f"   Min: {min_images} images ({min_class})")
print(f"   Max: {max_images} images ({max_class})")
print(f"   Balance ratio: {min_images/max_images:.2f} (1.0 = perfectly balanced)")

# Check for severe imbalance
if min_images / max_images < 0.3:
    print(f"\n⚠️  WARNING: Severe class imbalance detected!")
    print(f"   Consider using class weights during training")
elif min_images / max_images < 0.5:
    print(f"\n⚠️  Moderate class imbalance - may affect performance")
else:
    print(f"\n✅ Dataset is reasonably balanced")

# Show breakdown by crop
print(f"\n📊 Breakdown by Crop:")
print("-" * 70)
for crop in CROP_DATASETS.keys():
    crop_classes = [c for c in class_names if c.startswith(f"{crop}_")]
    crop_images = sum(class_stats[c] for c in crop_classes)
    print(f"   {crop.upper():10s} {len(crop_classes):2d} classes  {crop_images:5d} images  "
          f"{', '.join([c.replace(f'{crop}_', '') for c in crop_classes])}")

print("="*70 + "\n")

# Verify data integrity (sample check: up to 2 images from each of up to 3 classes)
print("🔍 Data Integrity Check (sampling up to 6 random images)...")
import random
from PIL import Image

sample_classes = random.sample(class_names, min(3, len(class_names)))

integrity_ok = True
for class_name in sample_classes:
    class_path = os.path.join(UNIFIED_DATASET_PATH, class_name)
    image_files = [f for f in os.listdir(class_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

    # Check random images
    for img_file in random.sample(image_files, min(2, len(image_files))):
        img_path = os.path.join(class_path, img_file)
        try:
            # Context manager closes the file handle even if verify() raises
            with Image.open(img_path) as img:
                img.verify()  # Raises if the file is not a readable image
        except Exception as e:
            print(f"   ❌ Corrupted: {class_name}/{img_file}: {e}")
            integrity_ok = False

if integrity_ok:
    print("   ✅ All sampled images are valid")
else:
    print("   ⚠️ Some images may be corrupted - consider data cleaning")

print(f"\n✅ Dataset verification complete!")
print("="*70 + "\n")
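
The imbalance warning above suggests class weights. As a minimal sketch (not part of the original pipeline), the per-class counts in `class_stats` could be turned into weights with the common "balanced" heuristic `total / (num_classes * count)`. The function name `compute_class_weights` and the example counts are hypothetical, and every class is assumed to contain at least one image.

```python
def compute_class_weights(class_stats, class_names):
    """Return {class_index: weight} using total / (num_classes * count)."""
    total = sum(class_stats[name] for name in class_names)
    n_classes = len(class_names)
    return {
        idx: total / (n_classes * class_stats[name])
        for idx, name in enumerate(class_names)
    }


# Hypothetical counts, for illustration only
example_stats = {"rice_Nitrogen(N)": 400, "rice_Potassium(K)": 100, "wheat_healthy": 500}
example_names = sorted(example_stats)  # same alphabetical order a directory listing would give
weights = compute_class_weights(example_stats, example_names)

for idx, name in enumerate(example_names):
    print(f"{idx} {name:20s} weight={weights[idx]:.2f}")
```

The dictionary is keyed by class index, which is the form Keras accepts for the `class_weight` argument of `model.fit`. Whether that argument works with this notebook's `tf.data` pipeline is an assumption to verify before relying on it.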

## 🔧 Stage 3: Adapt Model for Unified Classes


In [None]:
# =============================================================
# 🔄 STAGE 3: Adapt Model for Unified Classes (with Resume)
# =============================================================
if 'num_unified_classes' not in locals() or num_unified_classes == 0:
    raise RuntimeError("⚠️ Run 'Build UNIFIED Dataset' cell first!")

print(f"🔧 Setting up Stage 3 for {num_unified_classes} unified classes...")

# Check for existing Stage 3 checkpoint
STAGE3_CHECKPOINT = os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras')
STAGE3_LOCAL = os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras')

model_stage3 = None
resume_stage3 = False
initial_epoch = 0

# Try to load existing checkpoint
for checkpoint_path in [STAGE3_CHECKPOINT, STAGE3_LOCAL]:
    if os.path.exists(checkpoint_path):
        try:
            print(f"🔄 Found existing Stage 3 checkpoint!")
            print(f"   Loading from: {checkpoint_path}")
            model_stage3 = tf.keras.models.load_model(checkpoint_path)

            # Verify correct output shape
            if model_stage3.output_shape[-1] == num_unified_classes:
                resume_stage3 = True
                print(f"✅ Checkpoint valid ({num_unified_classes} classes)")

                # Evaluate current performance
                print("📊 Evaluating checkpoint...")
                results = model_stage3.evaluate(val_nutrient, verbose=0)
                print(f"   Current - Loss: {results[0]:.4f}, Accuracy: {results[1]:.4f}")

                # Check training history for initial_epoch
                history_path = os.path.join(DRIVE_CHECKPOINT_DIR, 'stage3_history.json')
                if os.path.exists(history_path):
                    with open(history_path, 'r') as f:
                        prev_history = json.load(f)
                        initial_epoch = len(prev_history.get('accuracy', []))
                        print(f"   Resuming from epoch {initial_epoch}")
                break
            else:
                print(f"⚠️ Class mismatch ({model_stage3.output_shape[-1]} vs {num_unified_classes})")
                print("   Creating new model (classes changed)")
                model_stage3 = None
        except Exception as e:
            print(f"⚠️ Could not load: {e}")
            model_stage3 = None

# Create new model if needed
if model_stage3 is None:
    print(f"🏗️ Creating NEW unified model...")

    # Get base model from Stage 2
    base_model_stage2 = model_stage2.layers[0]
    base_model_stage2.trainable = False  # Keep frozen initially

    # Balanced classification head (prevents overfitting)
    model_stage3 = tf.keras.Sequential([
        base_model_stage2,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(DROPOUT_RATE),
        tf.keras.layers.Dense(384, activation='relu',
                              kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(DROPOUT_RATE * 0.8),
        tf.keras.layers.Dense(num_unified_classes, activation='softmax', dtype='float32')
    ], name='unified_nutrient_model')

# Compile with JIT for speed
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE3)

model_stage3.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')],
    jit_compile=True  # XLA compilation - 10-20% faster
)

trainable_params = sum([tf.keras.backend.count_params(w) for w in model_stage3.trainable_weights])
print(f"\n📊 Trainable params: {trainable_params:,}")
print(f"🎯 Output classes: {num_unified_classes}")
print(f"⚡ JIT/XLA compilation: Enabled")
print(f"🔄 Resume training: {'Yes (epoch ' + str(initial_epoch) + ')' if resume_stage3 else 'No (fresh start)'}")

## 🎯 Stage 3: Train on UNIFIED Nutrient Dataset


In [None]:
print("🚀 Starting Stage 3: UNIFIED Nutrient Detection")
print(f"🌾 Training ALL {len(CROP_DATASETS)} crops | Epochs: {UNIFIED_EPOCHS} | LR: {LEARNING_RATE_STAGE3}")
if resume_stage3:
    print(f"🔄 RESUMING from epoch {initial_epoch}")
print("="*60)

# Callbacks with TQDM progress
callbacks_stage3 = [
    # TQDM Progress with ETA
    TQDMProgressCallback(UNIFIED_EPOCHS, stage_name="Stage 3: Unified Nutrients"),

    # Early stopping - balanced patience
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1,
        min_delta=0.001
    ),

    # Reduce LR on plateau
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=2,
        min_lr=1e-7,
        verbose=1
    ),

    # Save best to local (fast)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0
    ),

    # Save best to Drive (persistent)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=0
    )
]

# Train with progress tracking
stage3_start = datetime.now()

history_stage3 = model_stage3.fit(
    train_nutrient,
    validation_data=val_nutrient,
    epochs=UNIFIED_EPOCHS,
    initial_epoch=initial_epoch,
    callbacks=callbacks_stage3,
    verbose=0  # Disable default output, use TQDM instead
)

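# Save training history for resume.
# Append to any earlier history so len(history['accuracy']) stays equal to the
# total number of epochs trained, which the resume cell uses as initial_epoch.
history_path = os.path.join(DRIVE_CHECKPOINT_DIR, 'stage3_history.json')
history_dict = {k: [float(v) for v in vals] for k, vals in history_stage3.history.items()}
if resume_stage3 and os.path.exists(history_path):
    with open(history_path, 'r') as f:
        prev_history = json.load(f)
    history_dict = {k: prev_history.get(k, []) + vals for k, vals in history_dict.items()}
with open(history_path, 'w') as f:
    json.dump(history_dict, f)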

# Calculate final stats
best_val_acc = max(history_stage3.history['val_accuracy'])
best_val_loss = min(history_stage3.history['val_loss'])
best_top3_acc = max(history_stage3.history['val_top3_acc'])
final_train_acc = history_stage3.history['accuracy'][-1]
stage3_time = datetime.now() - stage3_start

print(f"\n" + "="*60)
print(f"✅ Stage 3 completed in {stage3_time}")
print(f"📈 Best val accuracy: {best_val_acc:.4f}")
print(f"🎯 Best top-3 accuracy: {best_top3_acc:.4f}")
print(f"📉 Best val loss: {best_val_loss:.4f}")

# Total training time
if TRAINING_START_TIME:
    total_training_time = datetime.now() - TRAINING_START_TIME
    print(f"\n⏱️ TOTAL TRAINING TIME: {total_training_time}")

# Check for overfitting/underfitting
gap = final_train_acc - best_val_acc
if gap > 0.20:
    print(f"\n⚠️ Overfitting detected (gap: {gap:.2%})")
elif best_val_acc < 0.5:
    print(f"\n⚠️ Possible underfitting (val_acc: {best_val_acc:.2%})")
else:
    print(f"\n✅ Good generalization (gap: {gap:.2%})")

print(f"\n💾 Model & history saved to Drive")

🚀 Starting Stage 3: UNIFIED Nutrient Detection
🌾 Training ALL 12 crops | Epochs: 10 | LR: 0.0005

🚀 Stage 3: Unified Nutrients Started
   Target: 10 epochs


📈 Stage 3: Unified Nutrients:   0%|                              | 0/10 [00:00<?, ?epoch/s]

  Epoch 1/10:   0%|                    | 0/1306 [00:00<?, ?batch/s]

KeyboardInterrupt: 
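
The training cell combines `ReduceLROnPlateau` (patience 2, factor 0.5) with `EarlyStopping` (patience 5, `min_delta` 0.001), both watching `val_loss`. The sketch below is a simplified pure-Python model of how the two interact on a stalled validation loss. It is not the Keras implementation: cooldown, weight restoration, and other details are left out, `simulate_plateau` is a hypothetical helper, and the `lr_min_delta` default of 1e-4 is assumed to match the Keras default.

```python
def simulate_plateau(val_losses, lr=5e-4, factor=0.5, lr_patience=2, min_lr=1e-7,
                     lr_min_delta=1e-4, stop_patience=5, stop_min_delta=0.001):
    """Return (stop_epoch or None, final learning rate) for a val_loss sequence."""
    best_for_lr, wait_lr = float("inf"), 0
    best_for_stop, wait_stop = float("inf"), 0

    for epoch, loss in enumerate(val_losses, start=1):
        # ReduceLROnPlateau-style bookkeeping
        if loss < best_for_lr - lr_min_delta:
            best_for_lr, wait_lr = loss, 0
        else:
            wait_lr += 1
            if wait_lr >= lr_patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                wait_lr = 0

        # EarlyStopping-style bookkeeping
        if loss < best_for_stop - stop_min_delta:
            best_for_stop, wait_stop = loss, 0
        else:
            wait_stop += 1
            if wait_stop >= stop_patience:
                return epoch, lr

    return None, lr


# Hypothetical loss curve: improves twice, then stalls
stop_epoch, final_lr = simulate_plateau([1.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8])
print(f"stopped at epoch {stop_epoch}, learning rate {final_lr:.2e}")
```

With these settings a stalled run gets two learning-rate halvings (at the second and fourth stagnant epochs) before early stopping ends it at the fifth, so the last reduction has only one epoch to show an effect.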

## 📈 Stage 3 Results Visualization


In [None]:
# Quick training visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history_stage3.history['accuracy'], 'b-', label='Train')
axes[0].plot(history_stage3.history['val_accuracy'], 'r-', label='Val')
axes[0].set_title('Stage 3: Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history_stage3.history['loss'], 'b-', label='Train')
axes[1].plot(history_stage3.history['val_loss'], 'r-', label='Val')
axes[1].set_title('Stage 3: Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'stage3_training_history.png'), dpi=150)
plt.show()

## 🔍 Model Evaluation & Confusion Matrix


In [None]:
# Quick evaluation (skip heavy confusion matrix for speed)
print("🔍 Evaluating UNIFIED model...")
results = model_stage3.evaluate(val_nutrient, verbose=0)

print(f"\n📊 Validation Metrics:")
print(f"   Loss: {results[0]:.4f}")
print(f"   Accuracy: {results[1]:.4f}")
print(f"   Top-3 Accuracy: {results[2]:.4f}")

# Quick per-crop accuracy (sample-based for speed)
print(f"\n🌾 Per-Crop Performance (quick check):")
y_true, y_pred = [], []
for images, labels in val_nutrient.take(20):  # Sample only
    predictions = model_stage3.predict(images, verbose=0)
    y_true.extend(np.argmax(labels.numpy(), axis=1))
    y_pred.extend(np.argmax(predictions, axis=1))

for crop in list(CROP_DATASETS.keys())[:6]:  # First 6 crops
    crop_classes = [cls for cls in class_names if cls.startswith(f"{crop}_")]
    if not crop_classes:
        continue
    crop_indices = [class_names.index(cls) for cls in crop_classes]
    crop_mask = np.isin(y_true, crop_indices)
    if crop_mask.sum() > 0:
        crop_acc = (np.array(y_true)[crop_mask] == np.array(y_pred)[crop_mask]).mean()
        print(f"   {crop.upper():12s}: {crop_acc:.1%}")

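# Save classification report.
# The label set must include predicted classes as well as true ones; otherwise
# target_names would not match the labels scikit-learn infers and it raises.
report_labels = sorted(set(y_true) | set(y_pred))
report = classification_report(y_true, y_pred, labels=report_labels,
                               target_names=[class_names[i] for i in report_labels],
                               output_dict=True, zero_division=0)
with open(os.path.join(OUTPUT_DIR, 'unified_classification_report.json'), 'w') as f:
    json.dump(report, f, indent=2)

print(f"\n✅ Evaluation complete")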
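
The evaluation cell skips the confusion matrix for speed. For the sampled `y_true` and `y_pred` lists a matrix is cheap to build with NumPy alone. The following is a minimal sketch with hypothetical helper names (`confusion_matrix_np`, `per_class_recall`) and toy labels; rows are true classes and columns are predicted classes.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs into a num_classes x num_classes matrix."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same cell appears many times
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_recall(cm):
    """Diagonal divided by row sums; NaN for classes absent from the sample."""
    support = cm.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        recall = np.diag(cm) / support
    return np.where(support > 0, recall, np.nan)


# Toy labels for illustration (class 3 never appears)
cm = confusion_matrix_np([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], num_classes=4)
print(cm)
print(per_class_recall(cm))
```

Because the cell above only samples 20 batches, some classes may have no samples at all; returning NaN for them keeps "no data" distinct from "0% recall".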

## 💾 Export to TensorFlow Lite for Mobile Deployment


In [None]:
print("📦 Converting to TensorFlow Lite...")
print(f"⏱️ Session time: {get_session_time()}")
check_time_limit()  # Warn if approaching 3-hour limit

# CRITICAL: Make sure the global policy is float32 before loading the model
print("🔄 Ensuring float32 policy for FP32 conversion...")
tf.keras.mixed_precision.set_global_policy('float32')
tf.keras.backend.clear_session()

# Load best model (wrapped below in a function with an explicit FP32 signature)
best_model_path = os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras')
if not os.path.exists(best_model_path):
    best_model_path = os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras')

print("📥 Loading original model...")
original_model = tf.keras.models.load_model(best_model_path)

# Build the model by calling it with dummy input
print("🔧 Building model with dummy input...")
dummy_input = tf.zeros((1, 224, 224, 3), dtype=tf.float32)
_ = original_model(dummy_input, training=False)
print(f"   ✅ Model built successfully")

# Use concrete function with explicit FP32 signature
print("\n⚙️ Creating TFLite converter with explicit FP32 signature...")

@tf.function(input_signature=[tf.TensorSpec(shape=[1, 224, 224, 3], dtype=tf.float32)])
def serving_fn(input_image):
    x = tf.cast(input_image, tf.float32)
    output = original_model(x, training=False)
    return tf.cast(output, tf.float32)

# Get concrete function
concrete_func = serving_fn.get_concrete_function()

# Convert using concrete function (bypasses model signature issues)
print("💡 Converting using concrete function (explicit FP32)...")
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # 8-bit weight quantization

# Try standard ops first
try:
    print("   Attempting standard TFLite ops...")
    tflite_model = converter.convert()
    print("   ✅ Standard ops conversion successful!")
    uses_flex = False
except Exception as e:
    print(f"   ⚠️ Standard ops failed: {str(e)[:100]}...")
    print("   🔄 Falling back to TF Select ops (flex delegates)...")
    
    # Enable TF Select ops as fallback
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter._experimental_lower_tensor_list_ops = False
    tflite_model = converter.convert()
    print("   ✅ TF Select ops conversion successful!")
    uses_flex = True

# Save to both local and Drive
tflite_path = os.path.join(OUTPUT_DIR, 'fasalvaidya_unified.tflite')
tflite_drive_path = os.path.join(DRIVE_CHECKPOINT_DIR, 'fasalvaidya_unified.tflite')

with open(tflite_path, 'wb') as f:
    f.write(tflite_model)
with open(tflite_drive_path, 'wb') as f:
    f.write(tflite_model)

keras_size = os.path.getsize(best_model_path) / (1024 * 1024)
tflite_size = os.path.getsize(tflite_path) / (1024 * 1024)

print(f"\n✅ Conversion complete!")
print(f"📊 Keras: {keras_size:.1f}MB → TFLite: {tflite_size:.1f}MB ({(1-tflite_size/keras_size)*100:.0f}% smaller)")
print(f"🚀 Single model for {len(CROP_DATASETS)} crops!")
print(f"⚡ Optimized with 8-bit weight quantization")
if uses_flex:
    print(f"📱 Uses TF Select ops (requires TFLite with flex delegate)")
    print(f"   Note: App needs tensorflow-lite-select-tf-ops dependency")
else:
    print(f"✅ Uses standard TFLite runtime (no flex ops needed)")
print(f"🔄 FP32 input/output (mobile-friendly)")
print(f"\n💾 Saved to:")
print(f"   Local: {tflite_path}")
print(f"   Drive: {tflite_drive_path} (persistent)")
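
The conversion cell relies on `tf.lite.Optimize.DEFAULT` for 8-bit weight quantization. To give a feel for why this shrinks the file roughly fourfold at a small accuracy cost, here is a NumPy illustration of symmetric per-tensor int8 quantization. It sketches the general idea only and is not TFLite's exact scheme; the helper names and the random stand-in weight matrix are hypothetical.

```python
import numpy as np


def quantize_int8(weights):
    """Symmetric per-tensor quantization: float32 array -> (int8 array, float scale)."""
    max_abs = float(np.max(np.abs(weights)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Recover an approximate float32 array from the int8 values."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.05, size=(384, 128)).astype(np.float32)  # stand-in for a Dense kernel

q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

print(f"float32 bytes: {w.nbytes}, int8 bytes: {q.nbytes}")
print(f"max abs error: {np.max(np.abs(w - w_hat)):.6f} (bound scale/2 = {scale / 2:.6f})")
```

Each weight is stored in one byte instead of four, and the rounding error per weight is bounded by half the scale, which is why a quick accuracy check on the converted model is still worthwhile.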

## 🧪 Test TFLite Model Inference


In [None]:
# Quick TFLite verification
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("🔍 TFLite Model:")
print(f"   Input: {input_details[0]['shape']} ({input_details[0]['dtype']})")
print(f"   Output: {output_details[0]['shape']} ({num_unified_classes} classes)")

# Quick test
for images, labels in val_nutrient.take(1):
    test_image = images[0].numpy()
    input_data = np.expand_dims(test_image, axis=0).astype(np.float32)

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])

    pred_idx = np.argmax(output[0])
    true_idx = np.argmax(labels[0].numpy())

    print(f"\n🧪 Quick test:")
    print(f"   True: {class_names[true_idx]}")
    print(f"   Pred: {class_names[pred_idx]} ({output[0][pred_idx]:.1%})")
    print(f"   {'✅ CORRECT' if pred_idx == true_idx else '❌ INCORRECT'}")
    break

print("\n✅ TFLite model verified!")
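
The test above feeds a tensor taken straight from `val_nutrient`, which the data pipeline has already prepared. Outside the notebook, the caller has to build that tensor. The sketch below assumes the exported model expects float32 input scaled to [-1, 1] (the MobileNetV2 convention recorded in the metadata) and that scaling is not already done inside the model; if the model contains its own preprocessing layer, this step should be skipped. `to_model_input` is a hypothetical helper.

```python
import numpy as np


def to_model_input(image_uint8, img_size=224):
    """Convert an HxWx3 uint8 image (0-255) to a 1xHxWx3 float32 batch in [-1, 1]."""
    image = np.asarray(image_uint8)
    expected = (img_size, img_size, 3)
    if image.shape != expected:
        raise ValueError(f"expected shape {expected}, got {image.shape}")
    scaled = image.astype(np.float32) / 127.5 - 1.0
    return scaled[np.newaxis, ...]  # add the batch dimension


# Hypothetical inputs: an all-black and an all-white image
black = to_model_input(np.zeros((224, 224, 3), dtype=np.uint8))
white = to_model_input(np.full((224, 224, 3), 255, dtype=np.uint8))
print(black.shape, black.dtype, black.min(), white.max())
```

The result has the shape and dtype the interpreter reports in `input_details`, so it can be passed to `interpreter.set_tensor` in place of `input_data`.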

## 📤 Save Model Metadata & Class Labels


In [None]:
# Save metadata and labels (to both local and Drive)
print("📝 Saving metadata...")

crop_class_mapping = {crop: [c for c in class_names if c.startswith(f"{crop}_")]
                      for crop in CROP_DATASETS.keys()}

metadata = {
    'model_type': 'unified_multi_crop',
    'model_version': '2.0',
    'training_date': datetime.now().isoformat(),
    'architecture': 'MobileNetV2',
    'supported_crops': list(CROP_DATASETS.keys()),
    'num_crops': len(CROP_DATASETS),
    'input_shape': [IMG_SIZE, IMG_SIZE, 3],
    'num_classes': num_unified_classes,
    'class_names': class_names,
    'crop_class_mapping': crop_class_mapping,
    'metrics': {'accuracy': float(results[1]), 'top3_accuracy': float(results[2])},
    'preprocessing': {'method': 'MobileNetV2', 'normalization': '[-1, 1]'},
    'training_config': {
        'batch_size': BATCH_SIZE,
        'plantvillage_epochs': PLANTVILLAGE_EPOCHS,
        'unified_epochs': UNIFIED_EPOCHS,
        'dropout_rate': DROPOUT_RATE,
        'optimizations': ['float32_precision', 'jit_compile', 'autotune_prefetch']
    }
}

# Save to both locations
for save_dir in [OUTPUT_DIR, DRIVE_CHECKPOINT_DIR]:
    with open(os.path.join(save_dir, 'unified_model_metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)
    with open(os.path.join(save_dir, 'labels.txt'), 'w') as f:
        f.write('\n'.join(class_names))

print(f"✅ Saved: unified_model_metadata.json, labels.txt")
print(f"📊 {len(CROP_DATASETS)} crops, {num_unified_classes} classes")
print(f"💾 Saved to both local and Drive (persistent)")
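
The saved `labels.txt` and `crop_class_mapping` let a client turn the model's output vector back into readable predictions. As a minimal sketch, assuming labels follow the `crop_Deficiency` naming used above, the hypothetical helper below returns the top-k labels and can optionally restrict the result to one crop when the user has already said which crop is in the photo. Renormalizing the probabilities within that crop is an illustrative choice, not something the notebook itself does, and the example labels are made up.

```python
import numpy as np


def decode_prediction(probs, class_names, crop=None, top_k=3):
    """Return [(label, probability)] for the top_k classes, optionally within one crop."""
    probs = np.asarray(probs, dtype=np.float64)
    if crop is not None:
        mask = np.array([name.startswith(f"{crop}_") for name in class_names])
        if not mask.any():
            raise ValueError(f"no classes found for crop {crop!r}")
        probs = np.where(mask, probs, 0.0)
        total = probs.sum()
        if total > 0:
            probs = probs / total  # renormalize over the selected crop only
    order = np.argsort(probs)[::-1][:top_k]
    return [(class_names[i], float(probs[i])) for i in order]


# Hypothetical labels and model output
labels = ["rice_Nitrogen(N)", "rice_healthy", "wheat_Nitrogen(N)"]
output = [0.5, 0.1, 0.4]

print(decode_prediction(output, labels, top_k=2))
print(decode_prediction(output, labels, crop="rice", top_k=2))
```

Restricting to a known crop removes cross-crop confusions such as `wheat_Nitrogen(N)` competing with `rice_Nitrogen(N)`, at the cost of trusting the user's crop selection.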

## 📦 Download Models to Local Machine


In [None]:
# Create and download zip
import shutil

# Session summary
print("=" * 60)
print("🎉 TRAINING SESSION COMPLETE!")
print("=" * 60)
print(f"⏱️ Total session time: {get_session_time()}")
print(f"📊 Final validation accuracy: {results[1]:.4f}")
print(f"🎯 Final top-3 accuracy: {results[2]:.4f}")

zip_filename = 'fasalvaidya_unified_model'
shutil.make_archive(f'/content/{zip_filename}', 'zip', OUTPUT_DIR)

print(f"\n📦 Created: {zip_filename}.zip")
print(f"\n📂 Contents:")
print(f"   📱 fasalvaidya_unified.tflite ({tflite_size:.1f}MB)")
print(f"   💾 unified_nutrient_best.keras")
print(f"   📄 unified_model_metadata.json")
print(f"   🏷️ labels.txt ({num_unified_classes} classes)")
print(f"\n🌾 Supports: {', '.join(list(CROP_DATASETS.keys())[:6])}...")

print(f"\n💾 ALSO SAVED TO DRIVE (persistent):")
print(f"   {DRIVE_CHECKPOINT_DIR}")
print(f"   ✅ Can resume training if disconnected!")

from google.colab import files
files.download(f'/content/{zip_filename}.zip')
print(f"\n⬇️ Download started!")

## 🎉 Training Complete!

### 🚀 Performance Optimizations Applied

This notebook applies the following optimizations to shorten training time on a Colab GPU:

| Optimization | Benefit |
|--------------|---------|
| **Float32 Precision** | No mixed precision bugs, stable training |
| **Local SSD Data** | 10-50x faster I/O vs Google Drive |
| **XLA/JIT Compilation** | 10-20% faster training |
| **AUTOTUNE Prefetch** | Maximizes GPU utilization |
| **Parallel Data Loading** | Uses all CPU cores |
| **Smart Caching** | Validation set cached in memory |
| **Optimized Augmentation** | Fast TF ops with XLA compilation |
| **Batch Size 32** | Better GPU memory utilization |

### ⏱️ Expected Training Time (4 Crops)

| Stage | Time | Description |
|-------|------|-------------|
| Data Copy to SSD | 3-5 min | One-time setup per session |
| Stage 2: PlantVillage | 15-20 min | 8 epochs, transfer learning base |
| Stage 3: Unified Nutrient | 30-40 min | 15 epochs, 4 crops combined |
| Export & Download | 2-3 min | Model conversion and download |
| **TOTAL** | **~50-70 min** | Complete end-to-end |

### 📊 Training Configuration

```python
# Dataset
CROPS = ['rice', 'wheat', 'tomato', 'maize']
IMG_SIZE = 224
BATCH_SIZE = 32  # Optimized for GPU

# Training
PLANTVILLAGE_EPOCHS = 8
UNIFIED_EPOCHS = 15
LEARNING_RATE_STAGE2 = 1e-3
LEARNING_RATE_STAGE3 = 5e-4
DROPOUT_RATE = 0.3

# Performance
# Precision: float32 (full compatibility)
# XLA/JIT: Enabled
# AUTOTUNE: Enabled
# Workers: All CPU cores
# Data Location: Local SSD
```

### 📦 Output Files

| File | Size | Description |
|------|------|-------------|
| `unified_savedmodel/` | ~20MB | SavedModel format (for backend) |
| `fasalvaidya_unified.tflite` | ~4MB | Mobile TFLite model |
| `unified_nutrient_best.keras` | ~15MB | Full Keras checkpoint |
| `unified_model_metadata.json` | <1KB | Model info & class mappings |
| `labels.txt` | <1KB | All class labels, one per line |

### 🔄 Checkpoints Saved To

- **Local:** `/content/fasalvaidya_unified_model/`
- **Drive (Persistent):** `/content/drive/MyDrive/FasalVaidya_Checkpoints/`

If training is interrupted, re-run the cells in order. The Stage 3 setup cell finds the saved checkpoint and resumes from it automatically.

---

### 📱 Next Steps

1. **Download** the `unified_savedmodel` folder
2. **Copy** to `backend/ml/models/unified_savedmodel/`
3. **Test** with: `python backend/test_unified.py`
4. **Verify** predictions show 70-95% confidence (not 0.2%)
5. **Run** the Flask backend and Expo frontend

Expected output:
```
N score: 78.5% (moderate deficiency)
P score: 12.3% (healthy)
K score: 5.2% (healthy)
Detected: rice_Nitrogen(N)
Confidence: 78.5%
```