## 📦 Setup & Environment


In [None]:
# Install required packages
!pip install -q tensorflow>=2.15.0 kaggle opendatasets scikit-learn matplotlib seaborn tqdm

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
import threading
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from pathlib import Path
from datetime import datetime, timedelta
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tqdm.auto import tqdm
import time

# Report the runtime environment so a missing GPU is noticed immediately.
print("TensorFlow version: {}".format(tf.__version__))
print("GPU available: {}".format(tf.config.list_physical_devices('GPU')))
print("CPU cores available: {}".format(multiprocessing.cpu_count()))

# Fix the NumPy and TensorFlow global seeds so runs are reproducible.
np.random.seed(seed=42)
tf.random.set_seed(seed=42)

# =============================================================
# ‚è±Ô∏è SESSION TIME TRACKER (Important for 3-hour limit!)
# =============================================================
SESSION_START_TIME = datetime.now()
TRAINING_START_TIME = None


def get_session_time(start=None):
    """Return the time elapsed since the session started as ``"<h>h <m>m"``.

    Args:
        start: Reference ``datetime``; defaults to ``SESSION_START_TIME``.

    Returns:
        str: Whole hours and minutes, e.g. ``"1h 5m"`` (hours are not capped
        at 24).
    """
    if start is None:
        start = SESSION_START_TIME
    # total_seconds() includes whole days; timedelta.seconds alone wraps to 0
    # every 24 hours.
    elapsed_seconds = int((datetime.now() - start).total_seconds())
    hours, remainder = divmod(elapsed_seconds, 3600)
    return f"{hours}h {remainder // 60}m"

def get_eta(current_epoch, total_epochs, epoch_time):
    """Estimate how long the remaining epochs will take.

    Args:
        current_epoch: Epoch count subtracted from ``total_epochs`` (the
            epochs already done).
        total_epochs: Total number of planned epochs.
        epoch_time: Duration of one epoch in seconds.

    Returns:
        str: ``str(timedelta)`` of the remaining time, truncated to whole
        seconds (e.g. ``"0:00:30"``).
    """
    seconds_left = int((total_epochs - current_epoch) * epoch_time)
    return str(timedelta(seconds=seconds_left))

def check_time_limit(warn_minutes=150, limit_minutes=180, start=None):
    """Warn when the session is approaching the assumed Colab runtime limit.

    Args:
        warn_minutes: Elapsed minutes at or after which the warning is printed.
        limit_minutes: Assumed session limit in minutes (default 180 = 3 hours).
        start: Session start ``datetime``; defaults to ``SESSION_START_TIME``.

    Returns:
        bool: True if the warning was printed, False otherwise.
    """
    if start is None:
        start = SESSION_START_TIME
    # total_seconds() counts whole days; timedelta.seconds would wrap every 24h.
    elapsed = int((datetime.now() - start).total_seconds()) // 60
    if elapsed < warn_minutes:
        return False
    # Never report a negative number of minutes once the limit has passed.
    remaining = max(limit_minutes - elapsed, 0)
    print(f"‚ö†Ô∏è WARNING: {remaining} minutes remaining before typical Colab disconnect!")
    print("   Consider saving checkpoints and downloading results now.")
    return True

# Announce when the session started and the time budget this notebook aims for.
print("\n".join((
    f"\n‚è±Ô∏è Session started at: {SESSION_START_TIME.strftime('%H:%M:%S')}",
    "   Target: Complete training within 1-1.5 hours",
    "   Checkpoints auto-save to Drive (training resumes from checkpoint)",
)))

## 🔑 Configuration

### Set dataset paths and training settings here


In [None]:
# ========== UNIFIED MODEL - ALL CROPS ==========
# Default location of the "Leaf Nutrient Data Sets" folder on Google Drive.
# Later cells may reassign this (e.g. to a detected location or a local copy).
NUTRIENT_DATASETS_ROOT = '/content/drive/MyDrive/Leaf Nutrient Data Sets'

# crop key -> name of that crop's sub-folder inside NUTRIENT_DATASETS_ROOT.
# All 12 crops are combined into ONE unified model.
CROP_DATASETS = dict(
    rice='Rice Nutrients',
    wheat='Wheat Nitrogen',
    tomato='Tomato Nutrients',
    maize='Maize Nutrients',
    banana='Banana leaves Nutrient',
    coffee='Coffee Nutrients',
    cucumber='Cucumber Nutrients',
    eggplant='EggPlant Nutrients',
    ashgourd='Ashgourd Nutrients',
    bittergourd='Bittergourd Nutrients',
    ridgegourd='Ridgegourd',
    snakegourd='Snakegourd Nutrients',
)


# =============================================================
# FAST MVP PROTOTYPING MODE (target: train in 1-1.5 hours)
# =============================================================
# Square input resolution, and batch size 16 (as requested by the user).
IMG_SIZE, BATCH_SIZE = 224, 16

# Epoch budget per stage, chosen to fit a 1-1.5 hour run:
#   Stage 2 - quick transfer learning on PlantVillage
#   Stage 3 - unified nutrient-deficiency training
PLANTVILLAGE_EPOCHS, UNIFIED_EPOCHS = 5, 10

# Learning rates: aggressive for Stage 2, lower for Stage 3 fine-tuning.
LEARNING_RATE_STAGE2, LEARNING_RATE_STAGE3 = 1e-3, 5e-4

# Regularization
DROPOUT_RATE = 0.25

# Input-pipeline parallelism: one worker per CPU core; number of batches to prefetch.
NUM_WORKERS, PREFETCH_BUFFER = multiprocessing.cpu_count(), 4

# Mixed precision: layers compute in float16 while variables stay float32,
# which speeds up training on GPUs with float16 support such as the T4.
tf.keras.mixed_precision.set_global_policy(tf.keras.mixed_precision.Policy('mixed_float16'))
print("üöÄ Mixed precision enabled (FP16)")

# Output paths. OUTPUT_DIR and GRADCAM_DIR are on the local runtime disk;
# DRIVE_CHECKPOINT_DIR is on Google Drive so checkpoints survive a disconnect.
OUTPUT_DIR = '/content/fasalvaidya_unified_model'
DRIVE_CHECKPOINT_DIR = '/content/drive/MyDrive/FasalVaidya_Checkpoints'
GRADCAM_DIR = '/content/gradcam_outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(GRADCAM_DIR, exist_ok=True)

# DRIVE_CHECKPOINT_DIR may only be created once Drive is actually mounted.
# Creating it earlier writes /content/drive/MyDrive/... to the LOCAL disk:
# - checkpoints would be lost on disconnect;
# - the now non-empty /content/drive can fool mount-detection logic that
#   treats "non-empty" as "already mounted", so Drive never gets mounted.
if not os.path.ismount('/content/drive'):
    from google.colab import drive
    drive.mount('/content/drive')
os.makedirs(DRIVE_CHECKPOINT_DIR, exist_ok=True)

# Print a summary of the run configuration.
# NOTE(review): "XLA/JIT compilation: Enabled" is reported, but nothing in this
# cell enables XLA -- confirm it is turned on where the model is compiled.
print("\n".join([
    "‚ö° FAST MVP PROTOTYPING MODE (1-1.5 hour target)",
    "=" * 60,
    f"Training ONE model for ALL {len(CROP_DATASETS)} crops",
    "\nüöÄ RAPID PROTOTYPING SETTINGS:",
    f"   - Batch size: {BATCH_SIZE}",
    "   - Mixed precision: FP16 (2x faster)",
    "   - XLA/JIT compilation: Enabled",
    f"   - Multiprocessing workers: {NUM_WORKERS}",
    f"   - Prefetch buffer: {PREFETCH_BUFFER}",
    f"   - Stage 2: {PLANTVILLAGE_EPOCHS} epochs",
    f"   - Stage 3: {UNIFIED_EPOCHS} epochs",
    "   - Grad-CAM: Enabled (visualize during training)",
    "\n‚ö° Expected total time: ~1-1.5 hours for full training",
]))

## 🔥 Grad-CAM Visualization Module

In [None]:
# =============================================================
# üî• GRAD-CAM VISUALIZATION MODULE
# =============================================================
# Generates heatmaps showing what the model focuses on

import cv2

class GradCAM:
    """Grad-CAM implementation for model interpretability.

    Builds a heatmap over the spatial positions of a rank-4 feature map (a
    conv layer's output). Each channel is weighted by the mean gradient of the
    chosen class score with respect to that channel.
    """

    # Used when no rank-4 layer can be found automatically: the name of the
    # last conv layer in Keras' MobileNetV2.
    _FALLBACK_LAYER_NAME = 'Conv_1'

    def __init__(self, model, layer_name=None):
        """
        Initialize Grad-CAM.

        Args:
            model: Keras model to explain.
            layer_name: Name of the conv layer to visualize (auto-detected if None).
        """
        self.model = model
        self.layer_name = layer_name or self._find_target_layer()

    @staticmethod
    def _output_rank(layer):
        """Return the rank of ``layer``'s output, or None if it cannot be determined.

        Tries ``layer.output.shape`` first (works in Keras 2 and Keras 3), then
        the Keras-2-only ``layer.output_shape`` attribute. Layers that were never
        called, or whose output is ambiguous (several outputs/inbound nodes),
        raise on access and yield None.
        """
        try:
            return len(layer.output.shape)
        except (AttributeError, RuntimeError, TypeError, ValueError):
            pass
        try:
            return len(layer.output_shape)
        except (AttributeError, RuntimeError, TypeError, ValueError):
            return None

    @classmethod
    def _is_feature_map(cls, layer):
        """True for non-input layers with a rank-4 output (presumably (batch, H, W, C)).

        Input layers are skipped: an image input is rank 4 but carries no
        learned features.
        """
        if type(layer).__name__ == 'InputLayer':
            return False
        return cls._output_rank(layer) == 4

    def _find_target_layer(self):
        """Return the name of the last layer with a rank-4 output.

        Layers are scanned from the end of the model. A layer that has its own
        ``layers`` (a nested base model) is searched from its end first. Falls
        back to ``'Conv_1'`` when nothing qualifies.
        """
        for layer in reversed(self.model.layers):
            for sublayer in reversed(getattr(layer, 'layers', None) or []):
                if self._is_feature_map(sublayer):
                    return sublayer.name
            if self._is_feature_map(layer):
                return layer.name
        return self._FALLBACK_LAYER_NAME

    def _resolve_target_layer(self):
        """Return the layer object named ``self.layer_name``.

        Looks the name up in the model's first layer when that layer supports
        ``get_layer`` (a nested base model), otherwise in the model itself. If
        the name is not found, falls back to the last rank-4 sublayer of the
        first layer, or to the first layer itself.
        """
        first = self.model.layers[0]
        try:
            if hasattr(first, 'get_layer'):
                return first.get_layer(self.layer_name)
            return self.model.get_layer(self.layer_name)
        except ValueError:  # Keras' get_layer raises ValueError for unknown names
            for layer in reversed(getattr(first, 'layers', None) or []):
                if self._is_feature_map(layer):
                    return layer
            return first

    def get_gradcam_heatmap(self, img_array, pred_index=None):
        """
        Generate a Grad-CAM heatmap.

        Args:
            img_array: Preprocessed image batch of shape (1, H, W, 3).
            pred_index: Class index to visualize (None = predicted class).

        Returns:
            np.ndarray: float32 heatmap at the target layer's spatial resolution,
            scaled to [0, 1].

        Raises:
            ValueError: If no gradient flows from the class score to the target layer.
        """
        target_layer = self._resolve_target_layer()

        # Model mapping the input to (target feature map, predictions).
        grad_model = tf.keras.models.Model(
            inputs=self.model.input,
            outputs=[target_layer.output, self.model.output]
        )

        with tf.GradientTape() as tape:
            conv_outputs, predictions = grad_model(img_array)
            if pred_index is None:
                pred_index = tf.argmax(predictions[0])
            class_channel = predictions[:, pred_index]

        grads = tape.gradient(class_channel, conv_outputs)
        if grads is None:
            raise ValueError(
                f"No gradient from the class score to layer '{target_layer.name}'"
            )

        # Under the mixed_float16 policy these tensors may be float16. In
        # float16 the 1e-8 epsilon below rounds to 0, which gives NaN for an
        # all-zero map. Work in float32 instead, which also hands float32 to
        # downstream NumPy/OpenCV code.
        grads = tf.cast(grads, tf.float32)
        conv_outputs = tf.cast(conv_outputs, tf.float32)

        # Channel importance = gradient averaged over batch and spatial axes.
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

        # Importance-weighted sum of the feature-map channels.
        heatmap = tf.squeeze(conv_outputs[0] @ pooled_grads[..., tf.newaxis])

        # Keep positive evidence only and scale to [0, 1].
        heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-8)
        return heatmap.numpy()

    def overlay_heatmap(self, img, heatmap, alpha=0.4):
        """
        Overlay a heatmap on the original image.

        Args:
            img: Original image (H, W, 3) in [0, 255].
            heatmap: Grad-CAM heatmap (h, w), values expected in [0, 1].
            alpha: Weight of the heatmap in the blend.

        Returns:
            np.ndarray: uint8 RGB image of shape (H, W, 3).
        """
        # Give OpenCV float32 (the heatmap may arrive as float16 or float64);
        # cv2.resize takes the target size as (width, height).
        heatmap = cv2.resize(np.asarray(heatmap, dtype=np.float32), (img.shape[1], img.shape[0]))

        # NaN-safe conversion to 0-255 before colour-mapping.
        heatmap = np.uint8(255 * np.clip(np.nan_to_num(heatmap), 0.0, 1.0))
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # colour maps come out BGR

        superimposed = heatmap * alpha + img * (1 - alpha)
        return np.clip(superimposed, 0, 255).astype(np.uint8)


def visualize_gradcam_batch(model, images, labels, class_names, num_samples=4, save_path=None):
    """
    Visualize Grad-CAM for the first samples of a batch.

    Draws one row per sample: the image with its true label, the raw
    heatmap, and the overlay titled with the predicted label and confidence.

    Args:
        model: Trained model.
        images: Batch of images (B, H, W, 3), MobileNetV2-preprocessed to [-1, 1].
        labels: One-hot labels (B, num_classes).
        class_names: List of class names indexed by class id.
        num_samples: Maximum number of samples to visualize.
        save_path: If given, the figure is also saved to this path.
    """
    num_samples = min(num_samples, len(images))
    if num_samples < 1:
        # plt.subplots(0, 3) would raise; there is simply nothing to draw.
        print("Grad-CAM: no samples to visualize.")
        return

    gradcam = GradCAM(model)

    # One batched forward pass instead of one predict() call per sample.
    preds = model.predict(np.asarray(images[:num_samples]), verbose=0)

    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    for i, (img, label, pred) in enumerate(zip(images[:num_samples], labels[:num_samples], preds)):
        pred_idx = int(np.argmax(pred))
        true_idx = int(np.argmax(label))
        confidence = pred[pred_idx]

        try:
            heatmap = gradcam.get_gradcam_heatmap(np.expand_dims(img, 0), pred_idx)
        except Exception as e:  # best effort: still plot the image with a blank map
            print(f"Grad-CAM error for sample {i}: {e}")
            heatmap = np.zeros((7, 7))

        # Undo MobileNetV2 preprocessing ([-1, 1] -> [0, 255]). Clip first so
        # slightly out-of-range values don't wrap around in the uint8 cast.
        display_img = np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8)

        # Original image
        axes[i, 0].imshow(display_img)
        axes[i, 0].set_title(f"True: {class_names[true_idx][:20]}", fontsize=10)
        axes[i, 0].axis('off')

        # Heatmap
        axes[i, 1].imshow(heatmap, cmap='jet')
        axes[i, 1].set_title("Grad-CAM Heatmap", fontsize=10)
        axes[i, 1].axis('off')

        # Overlay
        overlay = gradcam.overlay_heatmap(display_img, heatmap)
        axes[i, 2].imshow(overlay)
        correct = "‚úÖ" if pred_idx == true_idx else "‚ùå"
        axes[i, 2].set_title(f"Pred: {class_names[pred_idx][:15]} ({confidence:.1%}) {correct}", fontsize=10)
        axes[i, 2].axis('off')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"üíæ Grad-CAM saved to: {save_path}")

    plt.show()
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)


# Confirm the Grad-CAM helpers above are defined.
print("\n".join([
    "‚úÖ Grad-CAM module loaded!",
    "   - Auto-detects last conv layer",
    "   - Generates heatmaps showing model attention",
    "   - Overlays on original images for interpretability",
]))

## 💾 Mount Google Drive


In [None]:
from google.colab import drive
import os
import glob

# Mount Google Drive unless it is already mounted.
drive_path = '/content/drive'

# os.path.ismount() checks for a real mount point. Treating any non-empty
# /content/drive as "mounted" is fooled by local folders created there before
# mounting (e.g. os.makedirs on a Drive path), which silently skips the mount.
is_mounted = os.path.ismount(drive_path)

if is_mounted:
    print("‚úÖ Google Drive already mounted!")
else:
    print("üìÅ Mounting Google Drive...")
    if os.path.isdir(drive_path):
        if os.listdir(drive_path):
            # Local (non-Drive) leftovers occupy the mount point. Move them
            # aside instead of deleting them, so nothing is lost.
            backup_path = f"{drive_path}_local_backup_{time.strftime('%Y%m%d_%H%M%S')}"
            os.rename(drive_path, backup_path)
            print(f"   Moved local files found in {drive_path} to {backup_path}")
        else:
            os.rmdir(drive_path)
    drive.mount(drive_path)
    print("‚úÖ Google Drive mounted successfully!")

# Make sure the checkpoint folder exists on the real Drive.
os.makedirs(DRIVE_CHECKPOINT_DIR, exist_ok=True)

# =============================================================
# SMART PATH DETECTION: search all possible locations
# =============================================================
print("\nüîç Searching for 'Leaf Nutrient Data Sets' folder...")

# Candidate locations of the dataset folder
search_paths = [
    '/content/drive/MyDrive/Leaf Nutrient Data Sets',
    '/content/drive/Shareddrives/Leaf Nutrient Data Sets',
    '/content/drive/Shared drives/Leaf Nutrient Data Sets',
]

mydrive_base = '/content/drive/MyDrive'

# "Shared with me" shortcuts are exposed under .shortcut-targets-by-id/<id>/.
# NOTE(review): in Colab this directory normally sits at the Drive root
# (/content/drive). The MyDrive location that was searched before is kept too.
for shortcut_dir in ('/content/drive/.shortcut-targets-by-id',
                     os.path.join(mydrive_base, '.shortcut-targets-by-id')):
    try:
        folder_ids = os.listdir(shortcut_dir)
    except OSError:  # missing or unreadable: nothing to add from here
        continue
    for folder_id in folder_ids:
        target_path = os.path.join(shortcut_dir, folder_id, 'Leaf Nutrient Data Sets')
        if os.path.exists(target_path):
            search_paths.append(target_path)

# Take the first candidate that contains at least 5 sub-folders (crop datasets).
found_location = None
for search_path in search_paths:
    try:
        contents = os.listdir(search_path)
    except OSError:  # missing or unreadable
        continue
    crop_folders = [f for f in contents if os.path.isdir(os.path.join(search_path, f))]
    if len(crop_folders) >= 5:
        print(f"‚úÖ Found at: {search_path}")
        print(f"   Contains {len(crop_folders)} folders")
        found_location = search_path
        break

if found_location:
    NUTRIENT_DATASETS_ROOT = found_location
    print(f"\n‚úÖ Using dataset location: {NUTRIENT_DATASETS_ROOT}")
else:
    print(f"\n‚ùå 'Leaf Nutrient Data Sets' folder NOT FOUND!")
    print(f"\nüìÇ What's in your Drive:")
    try:
        mydrive_items = os.listdir(mydrive_base)[:10]  # Show first 10 items
    except OSError:
        print("   (Could not list Drive contents)")
    else:
        for item in mydrive_items:
            if os.path.isdir(os.path.join(mydrive_base, item)):
                print(f"   üìÅ {item}")
            else:
                print(f"   üìÑ {item}")

    print(f"\n‚ö†Ô∏è FOLDER IS IN 'SHARED WITH ME' - NOT ACCESSIBLE!")
    print(f"\n‚úÖ SOLUTION: Add shortcut to My Drive")
    print(f"   1. Open Google Drive in browser: https://drive.google.com")
    print(f"   2. Click 'Shared with me' in left sidebar")
    print(f"   3. Right-click 'Leaf Nutrient Data Sets' folder")
    print(f"   4. Select 'Add shortcut to Drive' or 'Organize' > 'Add shortcut'")
    print(f"   5. Choose 'My Drive' root (don't put it in a subfolder)")
    print(f"   6. Click 'Add' or 'Add shortcut'")
    print(f"   7. Come back here and re-run this cell")
    print(f"\nüí° After adding shortcut, the folder will appear at:")
    print(f"   /content/drive/MyDrive/Leaf Nutrient Data Sets")

# Check every configured crop folder, but only when the dataset root was found.
if found_location:
    print("\nüîç Verifying crop datasets...")
    missing_crops = []
    for crop, folder_name in CROP_DATASETS.items():
        crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
        if not os.path.exists(crop_path):
            print(f"‚ùå {crop.upper()}: NOT FOUND")
            missing_crops.append(crop)
            continue
        # One sub-directory per class.
        num_classes = sum(
            1 for entry in os.listdir(crop_path)
            if os.path.isdir(os.path.join(crop_path, entry))
        )
        print(f"‚úÖ {crop.upper()}: {num_classes} classes")

    if not missing_crops:
        print(f"\n‚úÖ All {len(CROP_DATASETS)} crop datasets verified!")
    else:
        print(f"\n‚ö†Ô∏è WARNING: {len(missing_crops)} crop(s) not found: {', '.join(missing_crops)}")
        

## 🚀 Speed Optimization: Copy Data to Local Disk

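**This is the biggest speed boost.** Reading images straight from Google Drive goes over the network. Copying the datasets once to the Colab VM's local disk typically makes data loading (and therefore training) roughly 10-50x faster.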

In [None]:
# =============================================================
# üöÄ MASSIVE SPEEDUP: Copy Data from Drive to Local SSD
# =============================================================
# Reading from Google Drive is SLOW (network I/O)
# Copying to /content/ uses Colab's fast local SSD = 10-50x faster!

import shutil
import time

def copy_folder_to_local(src_path, dest_name, force_copy=False, dest_root="/content"):
    """Copy a folder (e.g. from Google Drive) to local disk for fast I/O.

    The copy is written to ``<destination>.partial`` and renamed into place
    only after it completes. An interrupted copy is therefore never mistaken
    for a finished one on the next run.

    Args:
        src_path: Folder to copy.
        dest_name: Name of the destination folder under ``dest_root``.
        force_copy: Re-copy even if a non-empty local copy already exists.
        dest_root: Parent directory of the local copy (default ``/content``).

    Returns:
        str: Path of the local copy, or ``src_path`` unchanged if it does not exist.
    """
    local_path = os.path.join(dest_root, dest_name)

    # Reuse an existing non-empty copy (e.g. when resuming the session).
    if os.path.exists(local_path) and not force_copy:
        num_items = len(os.listdir(local_path))
        if num_items > 0:
            print(f"‚úÖ {dest_name} already on local disk ({num_items} items)")
            return local_path

    if not os.path.exists(src_path):
        print(f"‚ö†Ô∏è Source not found: {src_path}")
        return src_path  # fall back to reading from the original location

    print(f"üöÄ Copying {dest_name} to local SSD...")
    start = time.time()

    staging_path = local_path + ".partial"
    for stale_path in (local_path, staging_path):
        if os.path.exists(stale_path):
            shutil.rmtree(stale_path)
    try:
        shutil.copytree(src_path, staging_path)
    except BaseException:
        # Don't leave a half-written copy behind.
        shutil.rmtree(staging_path, ignore_errors=True)
        raise
    os.rename(staging_path, local_path)

    elapsed = time.time() - start
    size_mb = sum(
        os.path.getsize(os.path.join(root, filename))
        for root, _, filenames in os.walk(local_path)
        for filename in filenames
    ) / (1024 * 1024)

    print(f"‚úÖ Copied {dest_name}: {size_mb:.0f}MB in {elapsed:.1f}s")
    return local_path

# Copy each crop's nutrient dataset from Drive to the local disk.
print("=" * 60)
print("üöÄ COPYING DATASETS TO LOCAL SSD (One-time per session)")
print("=" * 60)
print("‚è≥ This takes 2-5 minutes but saves HOURS of training time!\n")

LOCAL_NUTRIENT_ROOT = '/content/local_nutrient_datasets'
os.makedirs(LOCAL_NUTRIENT_ROOT, exist_ok=True)

copy_success = 0
copy_failed = []

for crop, folder_name in CROP_DATASETS.items():
    src = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
    dst = os.path.join(LOCAL_NUTRIENT_ROOT, folder_name)

    if not os.path.exists(src):
        print(f"  ‚ö†Ô∏è {crop}: Not found on Drive")
        copy_failed.append(crop)
        continue

    if os.path.exists(dst):
        print(f"  ‚è© {crop}: Already local")
        copy_success += 1
        continue

    # Copy into a staging folder and rename it only on success. A copy that
    # fails halfway (e.g. a Drive read error) must not leave a partial `dst`
    # that the next run would report as "Already local".
    staging = dst + '.partial'
    shutil.rmtree(staging, ignore_errors=True)
    try:
        shutil.copytree(src, staging)
        os.rename(staging, dst)
        print(f"  ‚úÖ {crop}: Copied")
        copy_success += 1
    except Exception as e:
        shutil.rmtree(staging, ignore_errors=True)
        print(f"  ‚ùå {crop}: Failed ({e})")
        copy_failed.append(crop)

# Read the datasets from the local copy from now on.
NUTRIENT_DATASETS_ROOT = LOCAL_NUTRIENT_ROOT

print(f"\n‚úÖ {copy_success}/{len(CROP_DATASETS)} crops copied to local SSD")
if copy_failed:
    print(f"‚ö†Ô∏è Failed: {', '.join(copy_failed)}")
print("üöÄ Training will now be 10-50x FASTER!")

## 🌱 Stage 1: Download PlantVillage Dataset from Kaggle


In [None]:
# Setup Kaggle credentials
# IMPORTANT: You need to manually download kaggle.json FIRST!
# 
# üìù HOW TO GET kaggle.json:
# 1. Go to https://www.kaggle.com/settings
# 2. Scroll down to "API" section
# 3. Click "Create New Token" button
# 4. This will DOWNLOAD a file called "kaggle.json" to your computer
# 5. Find the downloaded file (usually in your Downloads folder)
# 6. Then come back here and upload it when prompted below
#
# ‚ö†Ô∏è NOTE: If you only see the API key on screen but no download happened,
#    click "Create New Token" again - it should download the file

from google.colab import files
import os

print("=" * 70)
print("üì§ UPLOAD YOUR kaggle.json FILE")
print("=" * 70)
print("\nüìù If you haven't downloaded it yet:")
print("   1. Go to: https://www.kaggle.com/settings")
print("   2. Scroll to 'API' section")
print("   3. Click 'Create New Token' (downloads kaggle.json)")
print("   4. Find the file in your Downloads folder")
print("   5. Click 'Choose Files' below and select it")
print("\n‚è≥ Waiting for your kaggle.json file...\n")

uploaded = files.upload()

# files.upload() returns {saved filename: file contents}. Prefer an exact
# "kaggle.json". Otherwise accept a renamed copy such as "kaggle (1).json",
# which is how browsers usually name a repeated download.
if 'kaggle.json' in uploaded:
    credentials_name = 'kaggle.json'
else:
    credentials_name = next(
        (name for name in uploaded
         if name.lower().startswith('kaggle') and name.lower().endswith('.json')),
        None,
    )

if credentials_name is None:
    print("\n‚ùå ERROR: kaggle.json was not uploaded!")
    print("   Please make sure you selected the correct file.")
    raise FileNotFoundError("kaggle.json not found in uploaded files")

# Install the credentials at ~/.kaggle/kaggle.json straight from the uploaded
# bytes (no shell commands). The file is created owner-read/write only (600).
kaggle_dir = os.path.expanduser('~/.kaggle')
os.makedirs(kaggle_dir, exist_ok=True)
kaggle_json_path = os.path.join(kaggle_dir, 'kaggle.json')
fd = os.open(kaggle_json_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, 'wb') as fh:
    fh.write(uploaded[credentials_name])
# The O_CREAT mode only applies to new files; enforce 600 on an existing one too.
os.chmod(kaggle_json_path, 0o600)

# Remove the copy files.upload() saved in the working directory (as the old
# `mv` did), so the API key is not left lying around there.
if os.path.exists(credentials_name):
    os.remove(credentials_name)

print("\n‚úÖ Kaggle credentials configured successfully!")
print("üìÅ File saved to: ~/.kaggle/kaggle.json")


In [None]:
# Download PlantVillage dataset from Kaggle (SKIP IF ALREADY EXISTS)
import opendatasets as od

# Kaggle page of the PlantVillage dataset, and the folder it is downloaded into.
PLANTVILLAGE_URL = 'https://www.kaggle.com/datasets/emmarex/plantdisease'
PLANTVILLAGE_PATH = '/content/plantvillage'

# Candidate locations of the class folders after download, relative to
# PLANTVILLAGE_PATH.
POSSIBLE_PATHS = [
    os.path.join(PLANTVILLAGE_PATH, *parts)
    for parts in (
        ('plantdisease', 'PlantVillage'),
        ('PlantVillage',),
        ('plantdisease', 'plantvillage', 'PlantVillage'),
    )
]

def find_plantvillage_dataset(paths=None, min_classes=15):
    """Return the first candidate directory that looks like the PlantVillage class tree.

    A directory qualifies when it has at least ``min_classes`` sub-directories
    and at least one of them directly contains a .jpg/.jpeg/.png file. All
    sub-directories are considered, not just whichever ``os.listdir`` happens
    to list first, so one empty class folder cannot hide a valid dataset.

    Args:
        paths: Candidate directories to check in order; defaults to ``POSSIBLE_PATHS``.
        min_classes: Minimum number of class sub-directories required.

    Returns:
        str | None: The first qualifying path, or None.
    """
    if paths is None:
        paths = POSSIBLE_PATHS
    for path in paths:
        if not os.path.isdir(path):
            continue
        subdirs = sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)))
        if len(subdirs) < min_classes:
            continue
        for subdir in subdirs:
            if any(f.lower().endswith(('.jpg', '.jpeg', '.png'))
                   for f in os.listdir(os.path.join(path, subdir))):
                return path
    return None

# Reuse an existing download when one is found; otherwise download and locate it.
existing_path = find_plantvillage_dataset()

if existing_path:
    print("‚úÖ PlantVillage dataset ALREADY EXISTS! Skipping download...")
    print(f"üìÅ Using cached dataset at: {existing_path}")
    PLANTVILLAGE_PATH = existing_path
else:
    print("üì• Downloading PlantVillage dataset (54,305 images)...")
    print("‚è≥ This will take 3-5 minutes (first time only)...")

    od.download(PLANTVILLAGE_URL, data_dir=PLANTVILLAGE_PATH)

    print("\nüîç Locating dataset structure...")

    dataset_root = find_plantvillage_dataset()

    if not dataset_root:
        # Fallback: walk the download and take the first directory with at
        # least 15 sub-directories where one of the first three holds an image.
        # The per-directory file list is bound to `_filenames`, not `files`, so
        # the walk does not clobber `files` (google.colab.files from the Kaggle
        # credentials cell).
        for root, dirs, _filenames in os.walk(PLANTVILLAGE_PATH):
            if len(dirs) < 15:
                continue
            has_images = any(
                os.path.isdir(os.path.join(root, d))
                and any(f.lower().endswith(('.jpg', '.jpeg', '.png'))
                        for f in os.listdir(os.path.join(root, d)))
                for d in dirs[:3]
            )
            if has_images:
                dataset_root = root
                break

    if not dataset_root:
        raise FileNotFoundError("‚ùå PlantVillage dataset not found after download")
    PLANTVILLAGE_PATH = dataset_root

# Verify dataset: every sub-directory is one class.
class_dirs = [d for d in os.listdir(PLANTVILLAGE_PATH)
              if os.path.isdir(os.path.join(PLANTVILLAGE_PATH, d))]
num_classes = len(class_dirs)

print(f"\n‚úÖ PlantVillage dataset ready!")
print(f"üìÅ Path: {PLANTVILLAGE_PATH}")
print(f"üåø Classes: {num_classes}")

# Quick image count over the first 5 class folders only.
total_images = sum(len([f for f in os.listdir(os.path.join(PLANTVILLAGE_PATH, cls))
                        if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                   for cls in class_dirs[:5])
print(f"üìä Sample: First 5 classes have {total_images:,} images")

## 📊 Data Exploration & Preparation


In [None]:
# Analyze PlantVillage dataset
def _count_images(folder):
    """Return the number of .jpg/.jpeg/.png files directly inside ``folder``."""
    return sum(1 for f in os.listdir(folder) if f.lower().endswith(('.jpg', '.jpeg', '.png')))


# Analyze the PlantVillage dataset. Only sub-directories count as classes
# (stray files are ignored) and only image files count as images, matching
# how the nutrient datasets are counted below.
plantvillage_classes = sorted(
    d for d in os.listdir(PLANTVILLAGE_PATH)
    if os.path.isdir(os.path.join(PLANTVILLAGE_PATH, d))
)
print(f"üå± PlantVillage Dataset:")
print(f"Total classes: {len(plantvillage_classes)}")
print(f"\nSample classes:")
for cls in plantvillage_classes[:5]:
    class_path = os.path.join(PLANTVILLAGE_PATH, cls)
    num_images = _count_images(class_path)
    print(f"  - {cls}: {num_images} images")

# Build unified dataset info
print(f"\nüåæ UNIFIED Nutrient Dataset (ALL {len(CROP_DATASETS)} crops):")
total_classes = 0
total_images = 0

for crop, folder_name in CROP_DATASETS.items():
    crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
    if not os.path.exists(crop_path):
        continue
    crop_classes = [d for d in os.listdir(crop_path) if os.path.isdir(os.path.join(crop_path, d))]
    crop_images = sum(_count_images(os.path.join(crop_path, cls)) for cls in crop_classes)
    total_classes += len(crop_classes)
    total_images += crop_images
    print(f"  {crop.upper()}: {len(crop_classes)} classes, {crop_images} images")

print(f"\nüìä UNIFIED TOTALS:")
print(f"  Total classes: {total_classes}")
print(f"  Total images: {total_images}")
print(f"  Class format: {{crop}}_{{deficiency}} (e.g., rice_N, wheat_healthy)")


## 🔨 Create Data Pipelines


In [None]:
# =============================================================
# üöÄ OPTIMIZED DATA PIPELINE (Multiprocessing + Threading)
# =============================================================
# Parallel data loading with progress tracking

AUTOTUNE = tf.data.AUTOTUNE


def create_dataset(data_dir, img_size, batch_size, validation_split=0.2, subset=None):
    """Load a batched image-classification dataset from a directory tree.

    Args:
        data_dir: Root folder; each sub-directory is one class.
        img_size: Images are resized to ``(img_size, img_size)``.
        batch_size: Number of images per batch.
        validation_split: Fraction reserved for validation. Only applied when
            ``subset`` is given.
        subset: ``'training'``, ``'validation'``, or None for the whole directory.

    Returns:
        tf.data.Dataset of ``(images, one_hot_labels)`` batches. Images hold
        raw (un-rescaled) pixel values; the order is shuffled with seed 42.
    """
    print(f"üì¶ Loading {subset} dataset from {data_dir}...")
    return tf.keras.utils.image_dataset_from_directory(
        data_dir,
        # Keras rejects validation_split without subset (and vice versa), so
        # only pass the split when a subset is requested.
        validation_split=validation_split if subset else None,
        subset=subset,
        seed=42,
        image_size=(img_size, img_size),
        batch_size=batch_size,
        label_mode='categorical',
        shuffle=True
    )

@tf.function
def augment_training(image, label):
    """Randomly flip and jitter brightness/contrast of training images.

    Runs before MobileNetV2 normalization in the pipeline, so ``image`` holds
    pixel values in [0, 255].
    """
    image = tf.image.random_flip_left_right(image)
    # random_brightness adds a delta in [-max_delta, max_delta] in the image's
    # own units. On 0-255 pixels a max_delta of 0.15 is invisible, so scale it
    # to a 15% shift of the full range.
    image = tf.image.random_brightness(image, 0.15 * 255.0)
    image = tf.image.random_contrast(image, 0.9, 1.1)
    # Neither op clips; keep pixels in the range preprocess_input expects.
    image = tf.clip_by_value(image, 0.0, 255.0)
    return image, label

@tf.function
def normalize_mobilenet(image, label):
    """Cast to float32 and rescale pixels to [-1, 1] with MobileNetV2's preprocess_input."""
    return tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(image, tf.float32)), label

def build_pipeline(dataset, is_training=True):
    """Attach threading options, augmentation, normalization and prefetching.

    Args:
        dataset: Batched ``(image, label)`` dataset.
        is_training: When True, ``augment_training`` runs before normalization.

    Returns:
        The transformed ``tf.data.Dataset``.
    """
    # Size tf.data's private thread pool and intra-op parallelism to NUM_WORKERS.
    opts = tf.data.Options()
    opts.threading.private_threadpool_size = NUM_WORKERS
    opts.threading.max_intra_op_parallelism = NUM_WORKERS

    map_fns = ([augment_training] if is_training else []) + [normalize_mobilenet]

    pipeline = dataset.with_options(opts)
    for fn in map_fns:
        pipeline = pipeline.map(fn, num_parallel_calls=AUTOTUNE)
    return pipeline.prefetch(buffer_size=PREFETCH_BUFFER)


# =============================================================
# üìä TQDM PROGRESS CALLBACK (Real-time ETA)
# =============================================================
class TQDMProgressCallback(tf.keras.callbacks.Callback):
    """Keras callback showing tqdm progress bars (epochs and batches) with an ETA.

    Args:
        total_epochs: Epoch count used for the epoch bar total and the ETA.
        stage_name: Label used in the bar description and printed messages.
    """

    def __init__(self, total_epochs, stage_name="Training"):
        super().__init__()
        self.total_epochs = total_epochs
        self.stage_name = stage_name
        self.epoch_pbar = None
        self.batch_pbar = None
        self.epoch_times = []
        self.stage_start_time = None
        self.epoch_start_time = None

    def on_train_begin(self, logs=None):
        self.stage_start_time = time.time()
        # Start fresh if the same callback instance is reused for another fit().
        self.epoch_times = []
        print(f"\nüöÄ {self.stage_name} Started")
        print(f"   Target: {self.total_epochs} epochs")
        self.epoch_pbar = tqdm(
            total=self.total_epochs,
            desc=f"üìà {self.stage_name}",
            unit="epoch",
            position=0,
            leave=True,
            bar_format='{l_bar}{bar:30}{r_bar}{bar:-10b}'
        )

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()
        # Steps per epoch as reported by Keras (may be None if the dataset
        # size is unknown; tqdm then shows a plain counter).
        total_batches = self.params.get('steps', 0)
        self.batch_pbar = tqdm(
            total=total_batches,
            desc=f"  Epoch {epoch+1}/{self.total_epochs}",
            unit="batch",
            position=1,
            leave=False,
            bar_format='{l_bar}{bar:20}{r_bar}'
        )

    def on_batch_end(self, batch, logs=None):
        # Compare against None explicitly. tqdm defines __bool__, which is
        # False for total == 0 and raises TypeError when total is None.
        if self.batch_pbar is None:
            return
        logs = logs or {}
        self.batch_pbar.update(1)
        self.batch_pbar.set_postfix({
            'loss': f"{float(logs.get('loss', 0)):.4f}",
            'acc': f"{float(logs.get('accuracy', 0)):.3f}",
        })

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if self.batch_pbar is not None:
            self.batch_pbar.close()
            self.batch_pbar = None

        epoch_time = time.time() - self.epoch_start_time
        self.epoch_times.append(epoch_time)

        # ETA from the mean epoch time so far; never negative, even if more
        # epochs run than total_epochs.
        avg_epoch_time = float(np.mean(self.epoch_times))
        remaining_epochs = max(self.total_epochs - (epoch + 1), 0)
        eta = str(timedelta(seconds=int(remaining_epochs * avg_epoch_time)))

        elapsed_str = str(timedelta(seconds=int(time.time() - self.stage_start_time)))

        val_loss = logs.get('val_loss', 0)
        val_acc = logs.get('val_accuracy', 0)
        train_acc = logs.get('accuracy', 0)

        if self.epoch_pbar is not None:
            self.epoch_pbar.update(1)
            self.epoch_pbar.set_postfix({
                'val_acc': f'{val_acc:.3f}',
                'train_acc': f'{train_acc:.3f}',
                'ETA': eta,
                'elapsed': elapsed_str
            })

        # Print detailed epoch summary
        print(f"\n   ‚úÖ Epoch {epoch+1}: val_acc={val_acc:.4f}, val_loss={val_loss:.4f}, "
              f"time={epoch_time:.1f}s, ETA={eta}")

    def on_train_end(self, logs=None):
        # Defensive: close a batch bar that is still open.
        if self.batch_pbar is not None:
            self.batch_pbar.close()
            self.batch_pbar = None
        if self.epoch_pbar is not None:
            self.epoch_pbar.close()
        total_time = time.time() - self.stage_start_time
        print(f"\n‚úÖ {self.stage_name} Complete!")
        print(f"   Total time: {str(timedelta(seconds=int(total_time)))}")
        # np.mean([]) would print "nan" when no epoch finished.
        if self.epoch_times:
            print(f"   Avg epoch time: {np.mean(self.epoch_times):.1f}s")


# =============================================================
# üî• GRAD-CAM CALLBACK (Visualize during training)
# =============================================================
class GradCAMCallback(tf.keras.callbacks.Callback):
    """Periodically save Grad-CAM visualizations of validation images during training.

    Every ``frequency`` epochs the first validation batch is rendered with
    ``visualize_gradcam_batch`` and saved as ``gradcam_epoch_<N>.png`` in
    ``save_dir``. By default the render runs on a background thread so the
    training loop is not blocked. At most one render is in flight at a time,
    because the previous thread is joined before a new one starts. Any pending
    render is joined in ``on_train_end`` so the last image is fully written.

    NOTE(review): a background render reads ``self.model`` while training
    continues, so the heatmap reflects whatever weights are current when it
    runs. Pass ``run_in_background=False`` for epoch-exact snapshots.
    """

    def __init__(self, validation_data, class_names, save_dir, frequency=2,
                 num_samples=4, run_in_background=True):
        """
        Args:
            validation_data: Validation dataset yielding ``(images, labels)`` batches
            class_names: List of class names
            save_dir: Directory to save Grad-CAM images (created if missing)
            frequency: Generate Grad-CAM every N epochs; must be >= 1
            num_samples: Max number of images from the first batch to visualize
            run_in_background: Render on a worker thread (True) or inline (False)

        Raises:
            ValueError: If ``frequency`` < 1. A value of 0 would otherwise raise
                ZeroDivisionError at the end of the first epoch.
        """
        super().__init__()
        if frequency < 1:
            raise ValueError(f"frequency must be >= 1, got {frequency}")
        self.validation_data = validation_data
        self.class_names = class_names
        self.save_dir = save_dir
        self.frequency = frequency
        self.num_samples = num_samples
        self.run_in_background = run_in_background
        self._thread = None  # most recently started render thread, if any
        os.makedirs(save_dir, exist_ok=True)

    def _render(self, images, labels, save_path):
        """Render one Grad-CAM grid to ``save_path``; errors are printed, not raised."""
        try:
            visualize_gradcam_batch(
                self.model,
                images,
                labels,
                self.class_names,
                num_samples=len(images),  # never ask for more than the batch holds
                save_path=save_path
            )
        except Exception as e:  # best-effort: a plotting failure must not stop training
            print(f"‚ö†Ô∏è Grad-CAM error: {e}")

    def _join_pending(self):
        """Block until the previously started background render (if any) finishes."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def on_epoch_end(self, epoch, logs=None):
        """Render Grad-CAM for the first validation batch every ``frequency`` epochs."""
        if (epoch + 1) % self.frequency != 0:
            return
        print(f"\nüî• Generating Grad-CAM for epoch {epoch + 1}...")

        batch = next(iter(self.validation_data.take(1)), None)
        if batch is None:  # empty validation set: nothing to visualize
            return
        images, labels = batch
        # Convert on the calling thread so the worker only handles NumPy arrays
        images = images.numpy()[:self.num_samples]
        labels = labels.numpy()[:self.num_samples]
        save_path = os.path.join(self.save_dir, f'gradcam_epoch_{epoch+1}.png')

        if not self.run_in_background:
            self._render(images, labels, save_path)
            return

        # One render at a time: overlapping threads could drive plotting code
        # concurrently (presumably pyplot, which is not thread-safe; verify).
        self._join_pending()
        self._thread = threading.Thread(
            target=self._render, args=(images, labels, save_path), daemon=True
        )
        self._thread.start()

    def on_train_end(self, logs=None):
        """Wait for the final background render so its PNG is complete."""
        self._join_pending()


# Create PlantVillage datasets
# Both splits come from the same directory with a 20% validation split; the raw
# datasets are then wrapped by build_pipeline (training split augmented).
print("üì¶ Creating PlantVillage datasets...")
print(f"   Using {NUM_WORKERS} CPU workers for parallel loading")

train_plantvillage_raw, val_plantvillage_raw = (
    create_dataset(PLANTVILLAGE_PATH, IMG_SIZE, BATCH_SIZE,
                   validation_split=0.2, subset=split_name)
    for split_name in ('training', 'validation')
)

# Apply optimized pipeline
train_plantvillage, val_plantvillage = (
    build_pipeline(train_plantvillage_raw, is_training=True),
    build_pipeline(val_plantvillage_raw, is_training=False),
)

# Batch counts are read from the raw (pre-pipeline) datasets
train_batches, val_batches = (
    tf.data.experimental.cardinality(raw_ds).numpy()
    for raw_ds in (train_plantvillage_raw, val_plantvillage_raw)
)

print(f"\n‚úÖ PlantVillage datasets ready")
print(f"   Training: {train_batches} batches √ó {BATCH_SIZE} = ~{train_batches * BATCH_SIZE:,} images")
print(f"   Validation: {val_batches} batches √ó {BATCH_SIZE} = ~{val_batches * BATCH_SIZE:,} images")
print(f"   ‚ö° Multiprocessing: {NUM_WORKERS} workers")
print(f"   üé® Training augmentation: flip, brightness, contrast")

## ✅ Pre-Training Validation

Run this cell to verify everything is set up correctly before training.


In [None]:
# ‚úÖ PRE-TRAINING VALIDATION (T4 Optimized)
print("=" * 60)
print("üîç PRE-TRAINING VALIDATION")
print(f"‚è±Ô∏è Session time: {get_session_time()}")
print("=" * 60)

errors = []

# 1. GPU Check
print("\n1Ô∏è‚É£ GPU Check...")
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    gpu_name = gpus[0].name
    print(f"   ‚úÖ GPU: {gpu_name}")
    # Check if T4
    try:
        !nvidia-smi --query-gpu=name --format=csv,noheader
    except:
        pass
else:
    errors.append("No GPU detected!")
    print("   ‚ö†Ô∏è No GPU - training will be MUCH slower")

# 2. Quick data test
print("\n2Ô∏è‚É£ Data Pipeline Check...")
try:
    import time
    start = time.time()
    for batch_images, batch_labels in train_plantvillage.take(1):
        load_time = time.time() - start
        print(f"   ‚úÖ Batch shape: {batch_images.shape}")
        print(f"   ‚úÖ Labels shape: {batch_labels.shape}")
        print(f"   ‚úÖ Image dtype: {batch_images.dtype}")
        print(f"   ‚ö° First batch load: {load_time:.2f}s")
        break
except Exception as e:
    errors.append(f"Data pipeline error: {e}")
    print(f"   ‚ùå Error: {e}")

# 3. Memory check
print("\n3Ô∏è‚É£ Memory Status...")
try:
    import psutil
    ram_total = psutil.virtual_memory().total / (1024**3)
    ram_avail = psutil.virtual_memory().available / (1024**3)
    print(f"   ‚úÖ RAM: {ram_avail:.1f}GB available / {ram_total:.1f}GB total")
    if ram_avail < 3:
        print("   ‚ö†Ô∏è Low RAM - data copied to local SSD helps prevent crashes")
except:
    print("   ‚ÑπÔ∏è Could not check RAM")

# 4. GPU Memory check
print("\n4Ô∏è‚É£ GPU Memory Check...")
try:
    !nvidia-smi --query-gpu=memory.total,memory.free --format=csv,noheader
except:
    print("   ‚ÑπÔ∏è Could not query GPU memory")

# 5. Existing checkpoints
print("\n5Ô∏è‚É£ Checkpoint Status...")
stage2_exists = os.path.exists(os.path.join(DRIVE_CHECKPOINT_DIR, 'stage2_plantvillage_best.keras'))
stage3_exists = os.path.exists(os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras'))
print(f"   Stage 2 checkpoint: {'‚úÖ Found (will resume)' if stage2_exists else '‚ùå None (fresh start)'}")
print(f"   Stage 3 checkpoint: {'‚úÖ Found (will resume)' if stage3_exists else '‚ùå None (fresh start)'}")

# Summary
print("\n" + "=" * 60)
if errors:
    print("‚ùå ISSUES FOUND:")
    for e in errors:
        print(f"   ‚Ä¢ {e}")
else:
    print("‚úÖ ALL CHECKS PASSED!")
    print(f"\nüöÄ T4 GPU OPTIMIZED SETTINGS:")
    print(f"   ‚Ä¢ Batch size: {BATCH_SIZE} (fills 16GB VRAM)")
    print(f"   ‚Ä¢ Mixed precision: FP16")
    print(f"   ‚Ä¢ JIT/XLA compilation: Enabled")
    print(f"   ‚Ä¢ AUTOTUNE prefetch: Enabled")
    print(f"   ‚Ä¢ Data location: Local SSD (fast I/O)")
    print(f"\n‚ö° Expected: ~1-2 min/epoch (10x faster than Drive I/O)")
print("=" * 60)

## 🏗️ Stage 2: Build Model with MobileNetV2 Base


In [None]:
def create_model(num_classes, input_shape=(224, 224, 3), freeze_base=True, dropout_rate=0.3):
    """Build a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone (no top) is followed by a small
    regularized head: global average pooling, dropout, a 256-unit ReLU layer
    with L2(1e-4) weight decay, batch normalization, a lighter second dropout,
    and a softmax output pinned to float32 for numerical stability under
    mixed precision.

    Args:
        num_classes: Number of output classes.
        input_shape: Input image shape (H, W, C).
        freeze_base: If True, the backbone's weights are not trainable.
        dropout_rate: Rate of the first dropout; the second uses 0.8x this rate.

    Returns:
        A ``(model, base_model)`` tuple: the full ``Sequential`` model and its
        backbone, which is ``model.layers[0]``.
    """
    backbone = tf.keras.applications.MobileNetV2(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape,
    )
    backbone.trainable = not freeze_base

    layers = tf.keras.layers
    head = [
        layers.GlobalAveragePooling2D(),
        layers.Dropout(dropout_rate),
        layers.Dense(256, activation='relu',
                     kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
        layers.BatchNormalization(),
        layers.Dropout(dropout_rate * 0.8),
        layers.Dense(num_classes, activation='softmax', dtype='float32'),
    ]
    classifier = tf.keras.Sequential([backbone, *head])
    return classifier, backbone

# Get number of PlantVillage classes
num_plantvillage_classes = len(plantvillage_classes)

# =============================================================
# üîÑ CHECKPOINT RESUME: Load existing model if available
# =============================================================
# Try the Drive checkpoint first, then the local copy. A checkpoint is used
# only if it loads, has the expected number of output classes, and evaluates
# successfully. On success, `val_loss`/`val_acc` hold its validation metrics.
STAGE2_CHECKPOINT = os.path.join(DRIVE_CHECKPOINT_DIR, 'stage2_plantvillage_best.keras')
STAGE2_LOCAL = os.path.join(OUTPUT_DIR, 'stage2_plantvillage_best.keras')

model_stage2 = None
resume_stage2 = False

for checkpoint_path in [STAGE2_CHECKPOINT, STAGE2_LOCAL]:
    if not os.path.exists(checkpoint_path):
        continue
    try:
        print(f"üîÑ Found existing Stage 2 checkpoint!")
        print(f"   Loading from: {checkpoint_path}")
        model_stage2 = tf.keras.models.load_model(checkpoint_path)

        # Verify it has correct output shape
        if model_stage2.output_shape[-1] != num_plantvillage_classes:
            print(f"‚ö†Ô∏è Checkpoint has different classes ({model_stage2.output_shape[-1]} vs {num_plantvillage_classes})")
            model_stage2 = None
            continue

        print(f"‚úÖ Resuming from checkpoint ({num_plantvillage_classes} classes)")

        # Evaluate current performance
        print("üìä Evaluating checkpoint...")
        val_loss, val_acc = model_stage2.evaluate(val_plantvillage, verbose=0)
        print(f"   Current val_accuracy: {val_acc:.4f}")

        # Set the flag only after load AND evaluate succeed. Setting it earlier
        # left resume_stage2=True after a failed evaluate(), even though a
        # fresh model is then built below.
        resume_stage2 = True
        break
    except Exception as e:
        print(f"‚ö†Ô∏è Could not load checkpoint: {e}")
        model_stage2 = None
        resume_stage2 = False

# Create new model if no valid checkpoint
if model_stage2 is None:
    print(f"üèóÔ∏è Creating NEW model for PlantVillage ({num_plantvillage_classes} classes)...")
    model_stage2, base_model = create_model(
        num_plantvillage_classes,
        freeze_base=True,
        dropout_rate=DROPOUT_RATE
    )
else:
    # The backbone is the first layer of the Sequential model built by create_model
    base_model = model_stage2.layers[0]

# Compile with JIT compilation for 10-20% speedup
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE2)

model_stage2.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    jit_compile=True  # XLA compilation - 10-20% faster on T4
)

# Count trainable params
trainable_params = sum([tf.keras.backend.count_params(w) for w in model_stage2.trainable_weights])
print(f"\nüîí Base model frozen: {not base_model.trainable}")
print(f"üìä Trainable parameters: {trainable_params:,}")
print(f"üíæ Mixed precision: FP16 enabled")
print(f"‚ö° JIT/XLA compilation: Enabled")

## 🎯 Stage 2: Train on PlantVillage Dataset

In [None]:
print("üöÄ Starting Stage 2: PlantVillage Fine-tuning")
print(f"‚è±Ô∏è Epochs: {PLANTVILLAGE_EPOCHS} | LR: {LEARNING_RATE_STAGE2} | Batch: {BATCH_SIZE}")
if resume_stage2:
    print("üîÑ RESUMING from previous checkpoint")
print("="*60)

# When resuming, seed both ModelCheckpoint callbacks with the resumed model's
# measured val_accuracy (`val_acc`, set by the model-setup cell's evaluate()).
# Without a threshold, ModelCheckpoint treats the first epoch as an improvement
# and overwrites the existing best checkpoint even if that epoch is worse.
# None (fresh start) is ModelCheckpoint's default.
_resumed_val_acc = globals().get('val_acc') if resume_stage2 else None
stage2_best_so_far = None if _resumed_val_acc is None else float(_resumed_val_acc)

# Callbacks with TQDM progress and Grad-CAM
callbacks_stage2 = [
    # TQDM Progress with ETA
    TQDMProgressCallback(PLANTVILLAGE_EPOCHS, stage_name="Stage 2: PlantVillage"),

    # Grad-CAM visualization every 2 epochs
    GradCAMCallback(
        validation_data=val_plantvillage,
        class_names=plantvillage_classes,
        save_dir=os.path.join(GRADCAM_DIR, 'stage2'),
        frequency=2
    ),

    # Early stopping
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        restore_best_weights=True,
        verbose=1,
        min_delta=0.005
    ),

    # Reduce LR on plateau
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.3,
        patience=2,
        min_lr=1e-6,
        verbose=1
    ),

    # Save best to local (fast)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(OUTPUT_DIR, 'stage2_plantvillage_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        initial_value_threshold=stage2_best_so_far,
        verbose=0
    ),

    # Save best to Drive (persistent)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(DRIVE_CHECKPOINT_DIR, 'stage2_plantvillage_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        initial_value_threshold=stage2_best_so_far,
        verbose=0
    )
]

# Train with progress tracking
TRAINING_START_TIME = datetime.now()

history_stage2 = model_stage2.fit(
    train_plantvillage,
    validation_data=val_plantvillage,
    epochs=PLANTVILLAGE_EPOCHS,
    callbacks=callbacks_stage2,
    verbose=0  # Disable default output, use TQDM instead
)

# Calculate training stats
best_val_acc = max(history_stage2.history['val_accuracy'])
best_val_loss = min(history_stage2.history['val_loss'])
final_train_acc = history_stage2.history['accuracy'][-1]
training_time = datetime.now() - TRAINING_START_TIME

print(f"\n" + "="*60)
print(f"‚úÖ Stage 2 completed in {training_time}")
print(f"üìà Best val accuracy: {best_val_acc:.4f}")
print(f"üìâ Best val loss: {best_val_loss:.4f}")
print(f"üìä Final train accuracy: {final_train_acc:.4f}")

# Check for overfitting/underfitting
gap = final_train_acc - best_val_acc
if gap > 0.15:
    print(f"‚ö†Ô∏è Overfitting detected (train-val gap: {gap:.2%})")
elif best_val_acc < 0.7:
    print(f"‚ö†Ô∏è Possible underfitting (val_acc: {best_val_acc:.2%})")
else:
    print(f"‚úÖ Good generalization (train-val gap: {gap:.2%})")

print(f"\nüíæ Checkpoints saved to:")
print(f"   Local: {OUTPUT_DIR}")
print(f"   Drive: {DRIVE_CHECKPOINT_DIR}")
print(f"   üî• Grad-CAM: {GRADCAM_DIR}/stage2")

## 📈 Stage 2 Results Visualization


In [None]:
# Quick training visualization: Stage 2 train vs. validation curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

stage2_hist = history_stage2.history
for panel, metric, panel_title in zip(axes,
                                      ('accuracy', 'loss'),
                                      ('Stage 2: Accuracy', 'Stage 2: Loss')):
    panel.plot(stage2_hist[metric], 'b-', label='Train')
    panel.plot(stage2_hist[f'val_{metric}'], 'r-', label='Val')
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
# Same figure goes to the local output dir and to Drive
for target_dir in (OUTPUT_DIR, DRIVE_CHECKPOINT_DIR):
    plt.savefig(os.path.join(target_dir, 'stage2_training_history.png'), dpi=150)
plt.show()

# Memory cleanup (preserve model references)
import gc
gc.collect()
print("üßπ Memory cleaned (models preserved for Stage 3)")

## 🔄 Stage 3: Build UNIFIED Dataset


In [None]:
# Build unified dataset by combining all crops.
# Each crop's class folders are symlinked into one directory as "<crop>_<class>",
# so a single directory-based dataset sees every crop's classes together.
print("üèóÔ∏è Building UNIFIED dataset...")

UNIFIED_DATASET_PATH = '/content/unified_nutrient_dataset'

# Reuse an existing unified directory only if it looks fully built (>10 classes);
# otherwise wipe it and rebuild from scratch.
if os.path.exists(UNIFIED_DATASET_PATH):
    existing_classes = [d for d in os.listdir(UNIFIED_DATASET_PATH)
                        if os.path.isdir(os.path.join(UNIFIED_DATASET_PATH, d))]
    if len(existing_classes) > 10:
        print(f"‚úÖ Already exists with {len(existing_classes)} classes!")
        unified_classes = existing_classes
    else:
        import shutil
        shutil.rmtree(UNIFIED_DATASET_PATH)
        os.makedirs(UNIFIED_DATASET_PATH)
        unified_classes = []
else:
    os.makedirs(UNIFIED_DATASET_PATH)
    unified_classes = []

if len(unified_classes) == 0:
    skipped_crops = []

    for crop, folder_name in CROP_DATASETS.items():
        crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)

        if not os.path.exists(crop_path):
            skipped_crops.append(crop)
            continue

        try:
            crop_classes = [d for d in os.listdir(crop_path)
                            if os.path.isdir(os.path.join(crop_path, d))]
        except OSError as e:
            print(f"  Skipping {crop}: cannot list {crop_path} ({e})")
            skipped_crops.append(crop)
            continue

        for cls in crop_classes:
            # NOTE(review): "{crop}_" is replaced before "{crop}__", so a folder
            # named "<crop>__X" becomes "_X" (unified name "<crop>__X"), and
            # replace() also strips matches in the middle of a name. Kept as-is
            # because unified class names must stay stable for existing
            # checkpoints and labels.txt.
            clean_cls = cls.replace(f"{crop}_", "").replace(f"{crop}__", "")
            unified_class_name = f"{crop}_{clean_cls}"

            src_dir = os.path.join(crop_path, cls)
            dst_dir = os.path.join(UNIFIED_DATASET_PATH, unified_class_name)

            if os.path.exists(dst_dir):
                continue
            try:
                os.symlink(src_dir, dst_dir)
            except OSError as e:
                # e.g. a dangling link or a name collision; skip this class but say why
                print(f"  Could not link {src_dir} -> {dst_dir}: {e}")
                continue
            unified_classes.append(unified_class_name)

        # Match on "<crop>_" so one crop's count cannot include another crop
        # whose name merely starts with the same letters.
        crop_count = sum(1 for c in unified_classes if c.startswith(f"{crop}_"))
        print(f"  ‚úÖ {crop.upper()}: {crop_count} classes")

    if skipped_crops:
        print(f"‚ö†Ô∏è Skipped: {', '.join(skipped_crops)}")

if len(unified_classes) == 0:
    raise RuntimeError("‚ùå No classes! Check Google Drive paths.")

class_names = sorted(unified_classes)
num_unified_classes = len(class_names)

print(f"\n‚úÖ Unified dataset: {num_unified_classes} classes")

# Create MEMORY-SAFE datasets
print("üì¶ Creating datasets (MEMORY-SAFE)...")

train_nutrient_raw = create_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='training'
)
val_nutrient_raw = create_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='validation'
)

train_nutrient = build_pipeline(train_nutrient_raw, is_training=True)
val_nutrient = build_pipeline(val_nutrient_raw, is_training=False)

print(f"‚úÖ Datasets ready (MEMORY-SAFE)")
print(f"   Training: {tf.data.experimental.cardinality(train_nutrient_raw).numpy()} batches")
print(f"   Validation: {tf.data.experimental.cardinality(val_nutrient_raw).numpy()} batches")

## 🔧 Stage 3: Adapt Model for Unified Classes


In [None]:
# =============================================================
# üîÑ STAGE 3: Adapt Model for Unified Classes (with Resume)
# =============================================================
if 'num_unified_classes' not in locals() or num_unified_classes == 0:
    raise RuntimeError("‚ö†Ô∏è Run 'Build UNIFIED Dataset' cell first!")

print(f"üîß Setting up Stage 3 for {num_unified_classes} unified classes...")

# Check for existing Stage 3 checkpoint (Drive first, then local)
STAGE3_CHECKPOINT = os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras')
STAGE3_LOCAL = os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras')

model_stage3 = None
resume_stage3 = False
initial_epoch = 0

# A checkpoint is used only if it loads, matches the class count, and evaluates.
# On success `results` holds [loss, accuracy, ...] and `initial_epoch` is the
# number of epochs recorded in the saved Stage 3 history (if any).
for checkpoint_path in [STAGE3_CHECKPOINT, STAGE3_LOCAL]:
    if not os.path.exists(checkpoint_path):
        continue
    try:
        print(f"üîÑ Found existing Stage 3 checkpoint!")
        print(f"   Loading from: {checkpoint_path}")
        model_stage3 = tf.keras.models.load_model(checkpoint_path)

        # Verify correct output shape
        if model_stage3.output_shape[-1] != num_unified_classes:
            print(f"‚ö†Ô∏è Class mismatch ({model_stage3.output_shape[-1]} vs {num_unified_classes})")
            print("   Creating new model (classes changed)")
            model_stage3 = None
            continue

        print(f"‚úÖ Checkpoint valid ({num_unified_classes} classes)")

        # Evaluate current performance
        print("üìä Evaluating checkpoint...")
        results = model_stage3.evaluate(val_nutrient, verbose=0)
        print(f"   Current - Loss: {results[0]:.4f}, Accuracy: {results[1]:.4f}")

        # Check training history for initial_epoch
        history_path = os.path.join(DRIVE_CHECKPOINT_DIR, 'stage3_history.json')
        if os.path.exists(history_path):
            with open(history_path, 'r') as f:
                prev_history = json.load(f)
            initial_epoch = len(prev_history.get('accuracy', []))
            print(f"   Resuming from epoch {initial_epoch}")

        # Mark as resumed only after everything above succeeded; previously a
        # failure after the flag was set left resume_stage3=True (and a stale
        # initial_epoch) while a fresh model was built.
        resume_stage3 = True
        break
    except Exception as e:
        print(f"‚ö†Ô∏è Could not load: {e}")
        model_stage3 = None
        resume_stage3 = False
        initial_epoch = 0

# Create new model if needed
if model_stage3 is None:
    print(f"üèóÔ∏è Creating NEW unified model...")

    # Get base model from Stage 2 (first layer of its Sequential model).
    # This object is shared with model_stage2, not copied.
    base_model_stage2 = model_stage2.layers[0]
    base_model_stage2.trainable = False  # Keep frozen initially

    # Balanced classification head (prevents overfitting)
    model_stage3 = tf.keras.Sequential([
        base_model_stage2,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(DROPOUT_RATE),
        tf.keras.layers.Dense(384, activation='relu',
                              kernel_regularizer=tf.keras.regularizers.l2(1e-4)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(DROPOUT_RATE * 0.8),
        tf.keras.layers.Dense(num_unified_classes, activation='softmax', dtype='float32')
    ], name='unified_nutrient_model')

# Compile with JIT for speed
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE3)

model_stage3.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')],
    jit_compile=True  # XLA compilation - 10-20% faster
)

trainable_params = sum([tf.keras.backend.count_params(w) for w in model_stage3.trainable_weights])
print(f"\nüìä Trainable params: {trainable_params:,}")
print(f"üéØ Output classes: {num_unified_classes}")
print(f"‚ö° JIT/XLA compilation: Enabled")
print(f"üîÑ Resume training: {'Yes (epoch ' + str(initial_epoch) + ')' if resume_stage3 else 'No (fresh start)'}")

## 🎯 Stage 3: Train on UNIFIED Nutrient Dataset


In [None]:
print("üöÄ Starting Stage 3: UNIFIED Nutrient Detection")
print(f"üåæ Training ALL {len(CROP_DATASETS)} crops | Epochs: {UNIFIED_EPOCHS} | LR: {LEARNING_RATE_STAGE3}")
if resume_stage3:
    print(f"üîÑ RESUMING from epoch {initial_epoch}")
print("="*60)

# When resuming, seed ModelCheckpoint with the resumed checkpoint's measured
# accuracy (`results[1]` from the setup cell's evaluate()). Otherwise the first
# epoch after resuming always overwrites the existing best checkpoint, even if
# that epoch is worse. None (fresh start) is ModelCheckpoint's default.
_resumed_results = globals().get('results') if resume_stage3 else None
stage3_best_so_far = None if _resumed_results is None else float(_resumed_results[1])

# Callbacks with TQDM progress and Grad-CAM
callbacks_stage3 = [
    # TQDM Progress with ETA
    TQDMProgressCallback(UNIFIED_EPOCHS, stage_name="Stage 3: Unified Nutrients"),

    # Grad-CAM visualization every 3 epochs
    GradCAMCallback(
        validation_data=val_nutrient,
        class_names=class_names,
        save_dir=os.path.join(GRADCAM_DIR, 'stage3'),
        frequency=3
    ),

    # Early stopping - balanced patience
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1,
        min_delta=0.001
    ),

    # Reduce LR on plateau
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=2,
        min_lr=1e-7,
        verbose=1
    ),

    # Save best to local (fast)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        initial_value_threshold=stage3_best_so_far,
        verbose=0
    ),

    # Save best to Drive (persistent)
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        initial_value_threshold=stage3_best_so_far,
        verbose=0
    )
]

# Train with progress tracking
stage3_start = datetime.now()

history_stage3 = model_stage3.fit(
    train_nutrient,
    validation_data=val_nutrient,
    epochs=UNIFIED_EPOCHS,
    initial_epoch=initial_epoch,
    callbacks=callbacks_stage3,
    verbose=0  # Disable default output, use TQDM instead
)

# Combine with the previous run's history when resuming. The setup cell reads
# len(history['accuracy']) as initial_epoch, so the saved file must cover ALL
# completed epochs. Saving only this run's epochs would make a second resume
# restart from the wrong epoch.
stage3_history_path = os.path.join(DRIVE_CHECKPOINT_DIR, 'stage3_history.json')
prev_stage3_history = {}
if resume_stage3 and os.path.exists(stage3_history_path):
    with open(stage3_history_path, 'r') as f:
        prev_stage3_history = json.load(f)

history_dict = {k: [float(v) for v in vals] for k, vals in prev_stage3_history.items()}
for k, vals in history_stage3.history.items():
    history_dict.setdefault(k, []).extend(float(v) for v in vals)

# fit() returns an empty history when initial_epoch >= UNIFIED_EPOCHS; fail
# with a clear message instead of max() raising on an empty sequence.
if not history_dict.get('val_accuracy'):
    raise RuntimeError(
        "No Stage 3 epochs recorded: set UNIFIED_EPOCHS above "
        f"initial_epoch ({initial_epoch}) to train."
    )

# Save training history for resume
with open(stage3_history_path, 'w') as f:
    json.dump(history_dict, f)

# Calculate final stats (across all recorded epochs, including earlier runs)
best_val_acc = max(history_dict['val_accuracy'])
best_val_loss = min(history_dict['val_loss'])
best_top3_acc = max(history_dict['val_top3_acc'])
final_train_acc = history_dict['accuracy'][-1]
stage3_time = datetime.now() - stage3_start

print(f"\n" + "="*60)
print(f"‚úÖ Stage 3 completed in {stage3_time}")
print(f"üìà Best val accuracy: {best_val_acc:.4f}")
print(f"üéØ Best top-3 accuracy: {best_top3_acc:.4f}")
print(f"üìâ Best val loss: {best_val_loss:.4f}")

# Total training time
if TRAINING_START_TIME:
    total_training_time = datetime.now() - TRAINING_START_TIME
    print(f"\n‚è±Ô∏è TOTAL TRAINING TIME: {total_training_time}")

# Check for overfitting/underfitting
gap = final_train_acc - best_val_acc
if gap > 0.20:
    print(f"\n‚ö†Ô∏è Overfitting detected (gap: {gap:.2%})")
elif best_val_acc < 0.5:
    print(f"\n‚ö†Ô∏è Possible underfitting (val_acc: {best_val_acc:.2%})")
else:
    print(f"\n‚úÖ Good generalization (gap: {gap:.2%})")

print(f"\nüíæ Model & history saved to Drive")
print(f"üî• Grad-CAM visualizations: {GRADCAM_DIR}/stage3")

## 📈 Stage 3 Results Visualization


In [None]:
# Quick training visualization: Stage 3 train vs. validation curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history_stage3.history['accuracy'], 'b-', label='Train')
axes[0].plot(history_stage3.history['val_accuracy'], 'r-', label='Val')
axes[0].set_title('Stage 3: Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history_stage3.history['loss'], 'b-', label='Train')
axes[1].plot(history_stage3.history['val_loss'], 'r-', label='Val')
axes[1].set_title('Stage 3: Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
# Save locally and to Drive (as the Stage 2 plot does) so the figure survives
# a Colab disconnect along with the other Drive artifacts.
plt.savefig(os.path.join(OUTPUT_DIR, 'stage3_training_history.png'), dpi=150)
plt.savefig(os.path.join(DRIVE_CHECKPOINT_DIR, 'stage3_training_history.png'), dpi=150)
plt.show()

## 🔍 Model Evaluation & Confusion Matrix


In [None]:
# Quick evaluation (skip heavy confusion matrix for speed)
print("üîç Evaluating UNIFIED model...")
results = model_stage3.evaluate(val_nutrient, verbose=0)  # [loss, accuracy, top3_acc]

print(f"\nüìä Validation Metrics:")
print(f"   Loss: {results[0]:.4f}")
print(f"   Accuracy: {results[1]:.4f}")
print(f"   Top-3 Accuracy: {results[2]:.4f}")

# Quick per-crop accuracy (sample-based for speed: first 20 validation batches)
print(f"\nüåæ Per-Crop Performance (quick check):")
y_true, y_pred = [], []
for images, labels in val_nutrient.take(20):  # Sample only
    predictions = model_stage3.predict(images, verbose=0)
    y_true.extend(np.argmax(labels.numpy(), axis=1))
    y_pred.extend(np.argmax(predictions, axis=1))

y_true_arr = np.asarray(y_true)
y_pred_arr = np.asarray(y_pred)
class_index = {name: i for i, name in enumerate(class_names)}  # O(1) name -> index

for crop in list(CROP_DATASETS.keys())[:6]:  # First 6 crops
    crop_indices = [class_index[cls] for cls in class_names if cls.startswith(f"{crop}_")]
    if not crop_indices:
        continue
    crop_mask = np.isin(y_true_arr, crop_indices)
    if crop_mask.any():
        crop_acc = (y_true_arr[crop_mask] == y_pred_arr[crop_mask]).mean()
        print(f"   {crop.upper():12s}: {crop_acc:.1%}")

# Save classification report. Pass `labels` explicitly so `target_names` stays
# aligned with the labels sklearn reports on. Without it, sklearn uses
# y_true ∪ y_pred, and any class that is predicted but never true makes
# target_names too short (ValueError).
report_labels = sorted(set(y_true) | set(y_pred))
report = classification_report(
    y_true, y_pred,
    labels=report_labels,
    target_names=[class_names[i] for i in report_labels],
    output_dict=True,
    zero_division=0,
)
with open(os.path.join(OUTPUT_DIR, 'unified_classification_report.json'), 'w') as f:
    json.dump(report, f, indent=2)

print(f"\n‚úÖ Evaluation complete")

## 🔥 Final Grad-CAM Analysis

In [None]:
# =============================================================
# üî• FINAL GRAD-CAM ANALYSIS
# =============================================================
# Render up to 12 Grad-CAM overlays from the first 3 validation batches,
# preferring crops that have not been shown yet, and save the grid locally
# and to Drive.

print("üî• Generating Final Grad-CAM Analysis...")
print("="*60)

# Get multiple batches for comprehensive analysis
all_images = []
all_labels = []

for images, labels in val_nutrient.take(3):
    all_images.extend(images.numpy())
    all_labels.extend(labels.numpy())

all_images = np.array(all_images)
all_labels = np.array(all_labels)

# Generate Grad-CAM for different crops
print("\nüìä Analyzing model attention per crop...")

gradcam = GradCAM(model_stage3)

# Crops successfully rendered so far (one entry per rendered panel)
crops_analyzed = []
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
axes = axes.flatten()

sample_idx = 0
for img, label in zip(all_images, all_labels):
    if sample_idx >= 12:
        break

    true_idx = np.argmax(label)
    class_name = class_names[true_idx]
    crop = class_name.split('_')[0]

    # Until 8 panels are filled, skip crops that are already shown
    if crop in crops_analyzed and len(crops_analyzed) < 8:
        continue

    # Get prediction and heatmap
    pred = model_stage3.predict(np.expand_dims(img, 0), verbose=0)
    pred_idx = np.argmax(pred[0])

    try:
        heatmap = gradcam.get_gradcam_heatmap(np.expand_dims(img, 0), pred_idx)
        # The metadata records [-1, 1] input normalization; map back to 0..255,
        # clipping so any overshoot cannot wrap around in the uint8 cast.
        display_img = np.clip((img + 1) * 127.5, 0, 255).astype(np.uint8)
        overlay = gradcam.overlay_heatmap(display_img, heatmap)
    except Exception as e:
        print(f"   Skipping {class_name}: Grad-CAM failed ({e})")
        continue

    ax = axes[sample_idx]
    ax.imshow(overlay)
    correct = "‚úÖ" if pred_idx == true_idx else "‚ùå"
    ax.set_title(f"{class_name[:20]}\n{pred[0][pred_idx]:.1%} {correct}", fontsize=9)
    ax.axis('off')
    # Record the crop only once its panel actually rendered, so a failed crop
    # can still be tried again with a later image.
    crops_analyzed.append(crop)
    sample_idx += 1

# Hide unused subplots
for j in range(sample_idx, 12):
    axes[j].axis('off')

plt.suptitle("üî• Grad-CAM: Model Attention Analysis (Final Model)", fontsize=14, fontweight='bold')
plt.tight_layout()

# Save to both locations
final_gradcam_path = os.path.join(OUTPUT_DIR, 'final_gradcam_analysis.png')
plt.savefig(final_gradcam_path, dpi=150, bbox_inches='tight')
plt.savefig(os.path.join(DRIVE_CHECKPOINT_DIR, 'final_gradcam_analysis.png'), dpi=150, bbox_inches='tight')
plt.show()

print(f"\n‚úÖ Final Grad-CAM analysis complete!")
# dict.fromkeys de-duplicates while keeping first-seen order (set() order varies)
print(f"   Analyzed crops: {', '.join(dict.fromkeys(crops_analyzed))}")
print(f"   Saved to: {final_gradcam_path}")

## 💾 Export to TensorFlow Lite for Mobile Deployment


In [None]:
print("üì¶ Converting to TensorFlow Lite...")
print(f"‚è±Ô∏è Session time: {get_session_time()}")
check_time_limit()  # Warn if approaching 3-hour limit

# Pick the best checkpoint: the local copy if present, otherwise the Drive copy
local_best = os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras')
drive_best = os.path.join(DRIVE_CHECKPOINT_DIR, 'unified_nutrient_best.keras')
best_model_path = local_best if os.path.exists(local_best) else drive_best

best_model = tf.keras.models.load_model(best_model_path)

# FP16 post-training quantization of the loaded Keras model
converter = tf.lite.TFLiteConverter.from_keras_model(best_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]

print("‚öôÔ∏è Converting with FP16 quantization...")
tflite_model = converter.convert()

# Write identical bytes to the local output dir and to Drive
tflite_path = os.path.join(OUTPUT_DIR, 'fasalvaidya_unified.tflite')
tflite_drive_path = os.path.join(DRIVE_CHECKPOINT_DIR, 'fasalvaidya_unified.tflite')
for destination in (tflite_path, tflite_drive_path):
    with open(destination, 'wb') as out_file:
        out_file.write(tflite_model)

bytes_per_mb = 1024 * 1024
keras_size = os.path.getsize(best_model_path) / bytes_per_mb
tflite_size = os.path.getsize(tflite_path) / bytes_per_mb

print(f"\n‚úÖ Conversion complete!")
print(f"üìä Keras: {keras_size:.1f}MB ‚Üí TFLite: {tflite_size:.1f}MB ({(1-tflite_size/keras_size)*100:.0f}% smaller)")
print(f"üöÄ Single model for {len(CROP_DATASETS)} crops!")
print(f"\nüíæ Saved to:")
print(f"   Local: {tflite_path}")
print(f"   Drive: {tflite_drive_path} (persistent)")

## 🧪 Test TFLite Model Inference


In [None]:
# Quick TFLite verification: run one validation image through the exported model
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
in_spec = input_details[0]
out_spec = output_details[0]

print("üîç TFLite Model:")
print(f"   Input: {in_spec['shape']} ({in_spec['dtype']})")
print(f"   Output: {out_spec['shape']} ({num_unified_classes} classes)")

# Quick test on the first image of the first validation batch
for images, labels in val_nutrient.take(1):
    batch_input = images[0].numpy()[np.newaxis, ...].astype(np.float32)

    interpreter.set_tensor(in_spec['index'], batch_input)
    interpreter.invoke()
    probs = interpreter.get_tensor(out_spec['index'])[0]

    predicted = np.argmax(probs)
    actual = np.argmax(labels[0].numpy())
    verdict = '‚úÖ CORRECT' if predicted == actual else '‚ùå INCORRECT'

    print(f"\nüß™ Quick test:")
    print(f"   True: {class_names[actual]}")
    print(f"   Pred: {class_names[predicted]} ({probs[predicted]:.1%})")
    print(f"   {verdict}")
    break

print("\n‚úÖ TFLite model verified!")

## 📤 Save Model Metadata & Class Labels


In [None]:
# Save metadata and labels (to both local and Drive).
# labels.txt lists class_names one per line, in model output-index order.
print("üìù Saving metadata...")

crop_class_mapping = {crop: [c for c in class_names if c.startswith(f"{crop}_")]
                      for crop in CROP_DATASETS.keys()}

metadata = {
    'model_type': 'unified_multi_crop',
    'model_version': '2.0',
    'training_date': datetime.now().isoformat(),
    'architecture': 'MobileNetV2',
    'supported_crops': list(CROP_DATASETS.keys()),
    'num_crops': len(CROP_DATASETS),
    'input_shape': [IMG_SIZE, IMG_SIZE, 3],
    'num_classes': num_unified_classes,
    'class_names': class_names,
    'crop_class_mapping': crop_class_mapping,
    # `results` is [loss, accuracy, top3_acc] from the latest evaluate()
    'metrics': {'accuracy': float(results[1]), 'top3_accuracy': float(results[2])},
    'preprocessing': {'method': 'MobileNetV2', 'normalization': '[-1, 1]'},
    'training_config': {
        'batch_size': BATCH_SIZE,
        'plantvillage_epochs': PLANTVILLAGE_EPOCHS,
        'unified_epochs': UNIFIED_EPOCHS,
        'dropout_rate': DROPOUT_RATE,
        'optimizations': ['mixed_precision_fp16', 'jit_compile', 'autotune_prefetch']
    }
}

# Save to both locations (explicit UTF-8 so output does not depend on the locale)
for save_dir in [OUTPUT_DIR, DRIVE_CHECKPOINT_DIR]:
    with open(os.path.join(save_dir, 'unified_model_metadata.json'), 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
    with open(os.path.join(save_dir, 'labels.txt'), 'w', encoding='utf-8') as f:
        f.write('\n'.join(class_names))

# Report the actual filename written above (previously printed "metadata.json")
print(f"‚úÖ Saved: unified_model_metadata.json, labels.txt")
print(f"üìä {len(CROP_DATASETS)} crops, {num_unified_classes} classes")
print(f"üíæ Saved to both local and Drive (persistent)")

## 📦 Download Models to Local Machine


In [None]:
# Create and download zip: archive OUTPUT_DIR and send it to the browser
import shutil

# Session summary
banner = "=" * 60
print(banner)
print("üéâ TRAINING SESSION COMPLETE!")
print(banner)
print(f"‚è±Ô∏è Total session time: {get_session_time()}")
print(f"üìä Final validation accuracy: {results[1]:.4f}")
print(f"üéØ Final top-3 accuracy: {results[2]:.4f}")

zip_filename = 'fasalvaidya_unified_model'
archive_base = f'/content/{zip_filename}'
shutil.make_archive(archive_base, 'zip', OUTPUT_DIR)

contents_lines = [
    f"   üì± fasalvaidya_unified.tflite ({tflite_size:.1f}MB)",
    "   üíæ unified_nutrient_best.keras",
    "   üìÑ unified_model_metadata.json",
    f"   üè∑Ô∏è labels.txt ({num_unified_classes} classes)",
]

print(f"\nüì¶ Created: {zip_filename}.zip")
print(f"\nüìÇ Contents:")
for contents_line in contents_lines:
    print(contents_line)
print(f"\nüåæ Supports: {', '.join(list(CROP_DATASETS)[:6])}...")

print(f"\nüíæ ALSO SAVED TO DRIVE (persistent):")
print(f"   {DRIVE_CHECKPOINT_DIR}")
print(f"   ‚úÖ Can resume training if disconnected!")

from google.colab import files
files.download(f'{archive_base}.zip')
print(f"\n‚¨áÔ∏è Download started!")

## 🎉 UNIFIED Model Training Complete!

### 🚀 Performance Optimizations Applied

This notebook is tuned to finish training in about **1–1.5 hours**, with real-time progress tracking:

| Feature | Description |
|---------|-------------|
| **TQDM Progress** | Real-time ETA, batch progress, metrics display |
| **Grad-CAM** | Visualizes model attention during training |
| **Multiprocessing** | Uses all CPU cores for data loading |
| **Mixed Precision** | FP16 training (2x faster on GPU) |
| **Threading** | Grad-CAM runs in background threads |
| **XLA/JIT** | TensorFlow graph optimization |

### ⏱️ Training Time Budget

| Stage | Estimated Time |
|-------|---------------|
| Setup & Data Copy | 5-10 min |
| Stage 2 (PlantVillage, 5 epochs) | 15-25 min |
| Stage 3 (Unified, 10 epochs) | 30-50 min |
| Export & Analysis | 5-10 min |
| **Total** | **~1-1.5 hours** |

### 🔥 Grad-CAM Outputs

Grad-CAM visualizations are saved at:
- `stage2/gradcam_epoch_*.png` - PlantVillage training
- `stage3/gradcam_epoch_*.png` - Unified training  
- `final_gradcam_analysis.png` - Comprehensive analysis

### 📊 Progress Tracking Features

- **Per-batch progress bar** with loss/accuracy
- **Per-epoch progress bar** with ETA
- **Real-time ETA** based on a moving average
- **Elapsed time** tracking
- **Session time** monitoring (3-hour Colab limit)

### 📦 Output Files

| File | Description |
|------|-------------|
| `fasalvaidya_unified.tflite` | Mobile-optimized model |
| `unified_nutrient_best.keras` | Full Keras model |
| `unified_model_metadata.json` | Model info & class mappings |
| `labels.txt` | All class labels |
| `final_gradcam_analysis.png` | Model attention visualization |

### 🔄 Settings Used

```python
BATCH_SIZE = 16
PLANTVILLAGE_EPOCHS = 5
UNIFIED_EPOCHS = 10
NUM_WORKERS = multiprocessing.cpu_count()
# Mixed precision: FP16
# JIT/XLA: enabled
```