## 📦 Setup & Import Libraries

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
import json
import shutil
from pathlib import Path
from datetime import datetime, timedelta
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tqdm.auto import tqdm
import time

print(f"TensorFlow version: {tf.__version__}")
print(f"NumPy version: {np.__version__}")

# Check for GPU (will use if available, but notebook optimized for CPU)
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"‚úÖ GPU detected: {gpus[0].name}")
    print("   Note: This notebook is optimized for CPU, but will use GPU if available")
    # Enable memory growth so TensorFlow does not grab all GPU memory up front.
    # This must happen before the GPU runtime is initialized; when the cell is
    # re-run in the same kernel TensorFlow raises RuntimeError instead, which is
    # reported rather than allowed to abort the whole cell.
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            print(f"   Could not enable memory growth on {gpu.name}: {err}")
else:
    print("üíª No GPU detected - training on CPU")
    print("   Expected time: 8-13 hours total")

# Set random seeds for reproducibility (NumPy and TensorFlow global seeds)
np.random.seed(42)
tf.random.set_seed(42)

# Session timing
SESSION_START_TIME = datetime.now()
print(f"\n‚è±Ô∏è Session started: {SESSION_START_TIME.strftime('%H:%M:%S')}")

## 🔑 Configuration - SET YOUR LOCAL PATHS HERE

**👉 EDIT THESE PATHS TO MATCH YOUR LOCAL SYSTEM:**

In [None]:
# =============================================================
# üîß LOCAL CONFIGURATION - EDIT THESE PATHS!
# =============================================================

# Path to your local "Leaf Nutrient Data Sets" folder
# Example Windows: r'D:\Datasets\Leaf Nutrient Data Sets'
# Example Mac/Linux: '/Users/yourname/Datasets/Leaf Nutrient Data Sets'
NUTRIENT_DATASETS_ROOT = r'B:\FasalVaidya\datasets\Leaf Nutrient Data Sets'

# Path to PlantVillage dataset (optional - only if doing Stage 2)
# Example: r'D:\Datasets\PlantVillage'
# NOTE: this is a plain string, not a pathlib.Path.
PLANTVILLAGE_PATH = r'B:\FasalVaidya\datasets\PlantVillage'

# Output directory for models and checkpoints (created if missing)
OUTPUT_DIR = r'B:\FasalVaidya\models\local_training'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# üåæ Crop datasets to include (4-crop MVP)
# Maps crop key (also used as the "<crop>_" class-name prefix) to its
# sub-folder under NUTRIENT_DATASETS_ROOT.
CROP_DATASETS = {
    'rice': 'Rice Nutrients',
    'wheat': 'Wheat Nitrogen',
    'tomato': 'Tomato Nutrients',
    'maize': 'Maize Nutrients',
}

# =============================================================
# üéØ CPU-OPTIMIZED TRAINING SETTINGS
# =============================================================
IMG_SIZE = 224
BATCH_SIZE = 8  # Reduced for CPU (GPU uses 32)

# Epochs - reduced for faster CPU training
# Note: You can increase these if you want better accuracy and have time
PLANTVILLAGE_EPOCHS = 5   # Stage 2: Takes ~3-5 hours on CPU
UNIFIED_EPOCHS = 10       # Stage 3: Takes ~5-8 hours on CPU

# Learning rates
LEARNING_RATE_STAGE2 = 1e-3
LEARNING_RATE_STAGE3 = 5e-4

# Regularization
DROPOUT_RATE = 0.3

# CPU-specific optimizations
# os.cpu_count() returns None when the count cannot be determined; fall back
# to a single worker instead of crashing in min().
NUM_WORKERS = min(os.cpu_count() or 1, 4)  # Limit workers to avoid overhead
PREFETCH_BUFFER = 2  # Small prefetch for CPU

# Use float32 (no mixed precision on CPU)
tf.keras.mixed_precision.set_global_policy('float32')
print("‚úÖ Using float32 policy")

# =============================================================
# üìä CONFIGURATION SUMMARY
# =============================================================
print("\n" + "="*70)
print("üíª LOCAL CPU TRAINING CONFIGURATION")
print("="*70)
print(f"üåæ Crops: {len(CROP_DATASETS)} ({', '.join(CROP_DATASETS.keys())})")
print(f"\nüìÅ Data Paths:")
print(f"   Nutrient datasets: {NUTRIENT_DATASETS_ROOT}")
print(f"   PlantVillage: {PLANTVILLAGE_PATH}")
print(f"   Output: {OUTPUT_DIR}")
print(f"\nüéØ Training Settings:")
print(f"   Image size: {IMG_SIZE}√ó{IMG_SIZE}")
print(f"   Batch size: {BATCH_SIZE} (CPU optimized)")
print(f"   CPU workers: {NUM_WORKERS}")
print(f"   Stage 2 epochs: {PLANTVILLAGE_EPOCHS}")
print(f"   Stage 3 epochs: {UNIFIED_EPOCHS}")
# Rough estimate: 40-60 minutes per epoch on CPU
print(f"\n‚è±Ô∏è Expected Time:")
print(f"   Stage 2: ~{PLANTVILLAGE_EPOCHS * 40}-{PLANTVILLAGE_EPOCHS * 60} min")
print(f"   Stage 3: ~{UNIFIED_EPOCHS * 40}-{UNIFIED_EPOCHS * 60} min")
print(f"   Total: ~{(PLANTVILLAGE_EPOCHS + UNIFIED_EPOCHS) * 40 // 60}-{(PLANTVILLAGE_EPOCHS + UNIFIED_EPOCHS) * 60 // 60} hours")
print("="*70 + "\n")

# Verify paths exist (reports problems but does not stop the notebook)
print("üîç Verifying paths...")
if not os.path.exists(NUTRIENT_DATASETS_ROOT):
    print(f"‚ùå ERROR: Nutrient dataset root not found!")
    print(f"   Path: {NUTRIENT_DATASETS_ROOT}")
    print(f"   Please update NUTRIENT_DATASETS_ROOT in the cell above")
else:
    print(f"‚úÖ Nutrient datasets found")

if not os.path.exists(PLANTVILLAGE_PATH):
    print(f"‚ö†Ô∏è PlantVillage not found (OK if skipping Stage 2)")
else:
    print(f"‚úÖ PlantVillage dataset found")

print(f"‚úÖ Output directory ready: {OUTPUT_DIR}")

## 🚀 Optimized Data Pipeline for CPU

In [None]:
# =============================================================
# üì¶ CPU-OPTIMIZED DATA PIPELINE
# =============================================================

# NOTE(review): build_pipeline uses NUM_WORKERS / PREFETCH_BUFFER rather than
# this constant.
AUTOTUNE = tf.data.AUTOTUNE

def create_dataset(data_dir, img_size, batch_size, validation_split=0.2, subset=None):
    """Load one subset of an image-folder dataset (one sub-folder per class).

    Args:
        data_dir: root folder whose sub-folders are the classes.
        img_size: every image is resized to ``img_size x img_size``.
        batch_size: images per batch.
        validation_split: fraction of images held out for validation.
        subset: ``'training'`` or ``'validation'`` (or None for no split).

    Returns:
        A batched, shuffled dataset of ``(image, one_hot_label)`` pairs.
        The fixed seed 42 makes separate 'training' and 'validation'
        calls on the same folder yield complementary subsets.
    """
    folder_label = os.path.basename(data_dir)
    print(f"üì¶ Loading {subset} data from {folder_label}...")

    loader_options = dict(
        validation_split=validation_split,
        subset=subset,
        seed=42,
        image_size=(img_size, img_size),
        batch_size=batch_size,
        label_mode='categorical',
        shuffle=True,
    )
    return tf.keras.utils.image_dataset_from_directory(data_dir, **loader_options)

@tf.function
def augment_image(image, label):
    """Light augmentation for CPU training (random flip, brightness, contrast).

    Expects pixel values in [0, 255]. That is the range ``create_dataset``
    yields, because ``image_dataset_from_directory`` does not rescale.
    Returns images in the same range. ``label`` is passed through unchanged.

    ``tf.image.random_brightness`` *adds* a delta drawn from
    ``[-max_delta, max_delta]`` to the pixel values, so the delta must be in
    pixel units. A raw ``0.2`` is a fraction meant for [0, 1] images; on
    [0, 255] data it changes nothing visible. It is therefore scaled by 255
    here (+/-20% of the pixel range).

    When ``image`` is a batch, the flip is decided per image. The brightness
    delta and contrast factor are drawn once per call and shared by the
    whole batch.
    """
    # Random horizontal flip
    image = tf.image.random_flip_left_right(image)

    # Brightness (+/-20% of the [0, 255] range) and contrast (0.8x-1.2x)
    image = tf.image.random_brightness(image, 0.2 * 255.0)
    image = tf.image.random_contrast(image, 0.8, 1.2)

    # Keep values in the valid pixel range expected by preprocess_input
    image = tf.clip_by_value(image, 0.0, 255.0)

    return image, label

@tf.function
def normalize_for_mobilenet(image, label):
    """Map [0, 255] pixels into the [-1, 1] range MobileNetV2 was trained on.

    Casts to float32 first, then applies MobileNetV2's ``preprocess_input``.
    The label is returned untouched.
    """
    scaled = tf.keras.applications.mobilenet_v2.preprocess_input(
        tf.cast(image, tf.float32)
    )
    return scaled, label

def build_pipeline(dataset, is_training=True):
    """Attach preprocessing stages and prefetching to a batched dataset.

    Args:
        dataset: ``(image, label)`` dataset, e.g. from ``create_dataset``.
        is_training: if True, ``augment_image`` runs before normalization;
            otherwise images are only normalized.

    Returns:
        The mapped dataset, prefetching ``PREFETCH_BUFFER`` elements and
        mapping with ``NUM_WORKERS`` parallel calls.
    """
    stage_fns = [augment_image] if is_training else []
    stage_fns.append(normalize_for_mobilenet)

    for stage_fn in stage_fns:
        dataset = dataset.map(stage_fn, num_parallel_calls=NUM_WORKERS)

    # Keep prefetching small to limit memory use on CPU
    return dataset.prefetch(buffer_size=PREFETCH_BUFFER)

# =============================================================
# üìä PROGRESS CALLBACK
# =============================================================
class CPUProgressCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints per-epoch metrics, timing and an ETA.

    Args:
        total_epochs: epoch count passed to ``fit``. Used for the
            "Epoch i/N" header and the ETA.
        stage_name: label printed when training starts and ends.
    """

    def __init__(self, total_epochs, stage_name="Training"):
        super().__init__()
        self.total_epochs = total_epochs
        self.stage_name = stage_name
        self.epoch_times = []    # wall-clock seconds of each finished epoch
        self.start_time = None   # set in on_train_begin
        self.epoch_start = None  # set in on_epoch_begin

    def on_train_begin(self, logs=None):
        self.start_time = time.time()
        # Fresh timing history per fit() so a reused callback does not mix in
        # epochs from an earlier run when averaging / estimating the ETA.
        self.epoch_times = []
        print(f"\nüöÄ {self.stage_name} Started")
        print(f"   Epochs: {self.total_epochs}")

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()
        print(f"\nüìà Epoch {epoch+1}/{self.total_epochs}")

    def on_epoch_end(self, epoch, logs=None):
        # The hook signature allows logs=None; treat it as "no metrics".
        logs = logs or {}
        epoch_time = time.time() - self.epoch_start
        self.epoch_times.append(epoch_time)

        # ETA = remaining epochs x mean epoch time so far. It does not account
        # for early stopping, and is clamped at 0 if fit() runs past total_epochs.
        avg_time = np.mean(self.epoch_times)
        remaining = max(self.total_epochs - (epoch + 1), 0)
        eta = str(timedelta(seconds=int(remaining * avg_time)))

        # Metrics that were not logged (e.g. no validation data) show as 0
        val_acc = logs.get('val_accuracy', 0)
        val_loss = logs.get('val_loss', 0)
        train_acc = logs.get('accuracy', 0)

        print(f"   ‚úÖ Complete: train_acc={train_acc:.4f}, val_acc={val_acc:.4f}, val_loss={val_loss:.4f}")
        print(f"   ‚è±Ô∏è Time: {epoch_time:.1f}s | Avg: {avg_time:.1f}s | ETA: {eta}")

    def on_train_end(self, logs=None):
        total_time = time.time() - self.start_time
        print(f"\n‚úÖ {self.stage_name} Complete!")
        print(f"   Total time: {str(timedelta(seconds=int(total_time)))}")
        # No finished epoch -> no average (np.mean([]) would print nan and warn).
        if self.epoch_times:
            print(f"   Avg epoch: {np.mean(self.epoch_times):.1f}s")

print("‚úÖ Data pipeline functions ready")

## 🌱 Stage 2: PlantVillage Training (OPTIONAL)

**⚠️ This stage takes 3-5 hours on CPU!**

You can skip this and go directly to Stage 3 if:
- You want faster training
- You already have a Stage 2 checkpoint
- You're okay with slightly lower accuracy

**To skip:** Set `SKIP_STAGE2 = True` in the cell that checks the PlantVillage path. It is set automatically when the PlantVillage folder is missing. The remaining Stage 2 cells then do nothing, and Stage 3 starts from ImageNet weights.

## 📥 Download PlantVillage Dataset (Optional)

If you don't have PlantVillage dataset locally, you can download it from Kaggle.

**Requirements:**
1. Kaggle account
2. Kaggle API key (`kaggle.json`)

**To get kaggle.json:**
1. Go to https://www.kaggle.com/settings
2. Scroll to "API" section
3. Click "Create New Token"
4. Place `kaggle.json` in: `~/.kaggle/` (Linux/Mac) or `C:\Users\YourName\.kaggle\` (Windows)

In [None]:
# =============================================================
# üì• DOWNLOAD PLANTVILLAGE DATASET FROM KAGGLE
# =============================================================

import zipfile
import subprocess

def download_plantvillage():
    """Make sure the PlantVillage dataset is present at ``PLANTVILLAGE_PATH``.

    If the folder already holds more than 10 class sub-folders, it is used
    as is. Otherwise, if Kaggle credentials exist at
    ``~/.kaggle/kaggle.json``, the function:

    1. Installs the ``kaggle`` package with pip if it is missing.
    2. Downloads ``arjuntejaswi/plant-village`` into the parent folder of
       ``PLANTVILLAGE_PATH``.
    3. Extracts the archive and moves the extracted folder to
       ``PLANTVILLAGE_PATH``.
    4. Deletes the zip file.

    Failures are reported with ``print`` instead of raised, so the notebook
    can continue without Stage 2.

    Returns:
        bool: True if the dataset is ready, False otherwise.
    """
    # PLANTVILLAGE_PATH is configured as a plain string; wrap it so the
    # pathlib methods below work (str has no .exists()/.iterdir()/.parent).
    pv_path = Path(PLANTVILLAGE_PATH)

    # Reuse an existing copy. Heuristic: a complete download has well over
    # 10 class folders.
    if pv_path.exists():
        try:
            subdirs = [d for d in pv_path.iterdir() if d.is_dir()]
        except OSError as err:  # e.g. the path is a file or is unreadable
            print(f"   Could not inspect {pv_path}: {err}")
            subdirs = []
        if len(subdirs) > 10:
            print("‚úÖ PlantVillage dataset already exists!")
            print(f"   Location: {pv_path}")
            print(f"   Classes found: {len(subdirs)}")
            return True
        print("‚ö†Ô∏è PlantVillage folder exists but seems incomplete")
        print("   Will re-download...")

    print("üì• Downloading PlantVillage dataset from Kaggle...")
    print("   This may take 5-10 minutes (~1-2 GB)")

    # Check for Kaggle credentials
    kaggle_config = Path.home() / '.kaggle' / 'kaggle.json'
    if not kaggle_config.exists():
        print("\n‚ùå Kaggle credentials not found!")
        print(f"   Expected: {kaggle_config}")
        print("\nüìù To set up:")
        print("   1. Go to https://www.kaggle.com/settings")
        print("   2. Scroll to 'API' section")
        print("   3. Click 'Create New Token'")
        print(f"   4. Place kaggle.json in: {kaggle_config.parent}")
        print("\n‚è© Skipping download - you can manually download from:")
        print("   https://www.kaggle.com/datasets/arjuntejaswi/plant-village")
        return False

    try:
        # Install kaggle package if not available
        try:
            import kaggle  # noqa: F401  (only checks that the package is installed)
        except ImportError:
            print("üì¶ Installing kaggle package...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kaggle"])
        from kaggle.api.kaggle_api_extended import KaggleApi

        # Download next to the configured dataset folder
        download_path = pv_path.parent
        download_path.mkdir(parents=True, exist_ok=True)

        # Snapshot what already exists in the parent folder. Only folders and
        # zips created by this download are considered below, so unrelated
        # datasets stored alongside PlantVillage are never mistaken for it
        # and moved.
        dirs_before = {d.name for d in download_path.iterdir() if d.is_dir()}
        zips_before = {z.name for z in download_path.glob('*.zip')}

        print("\n‚¨áÔ∏è Downloading from Kaggle...")
        api = KaggleApi()
        api.authenticate()

        print("   Downloading arjuntejaswi/plant-village...")
        api.dataset_download_files(
            'arjuntejaswi/plant-village',
            path=str(download_path),
            unzip=False
        )

        # Find downloaded zip file: expected name first, else a newly created zip
        zip_file = download_path / 'plant-village.zip'
        if not zip_file.exists():
            new_zips = sorted(
                z for z in download_path.glob('*.zip') if z.name not in zips_before
            )
            if not new_zips:
                print("‚ùå Downloaded file not found!")
                return False
            zip_file = new_zips[0]

        print(f"\nüì¶ Extracting {zip_file.name}...")
        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            zip_ref.extractall(download_path)

        # Locate the extracted folder: known names first, then any folder
        # created by this extraction that has more than 10 sub-folders.
        extracted_dir = None
        for name in ('PlantVillage', 'plant-village', 'Plant_Village'):
            candidate = download_path / name
            if candidate.is_dir():
                extracted_dir = candidate
                break
        if extracted_dir is None:
            for d in sorted(download_path.iterdir()):
                if d.is_dir() and d.name not in dirs_before:
                    if sum(1 for x in d.iterdir() if x.is_dir()) > 10:
                        extracted_dir = d
                        break

        # Move the extracted folder to the configured location (resolve() so
        # the same folder reached via a different spelling is not deleted).
        if extracted_dir is not None and extracted_dir.resolve() != pv_path.resolve():
            if pv_path.exists():
                shutil.rmtree(pv_path)
            extracted_dir.rename(pv_path)
            print(f"‚úÖ Renamed to: {pv_path}")

        # Clean up zip file
        zip_file.unlink()
        print("üóëÔ∏è Cleaned up zip file")

        # Verify extraction
        if pv_path.exists():
            class_count = len([d for d in pv_path.iterdir() if d.is_dir()])
            print(f"\n‚úÖ PlantVillage dataset ready!")
            print(f"   Location: {pv_path}")
            print(f"   Classes: {class_count}")
            return True
        print("\n‚ùå Extraction failed - path not found")
        return False

    except Exception as e:
        # Best effort: any failure just means Stage 2 can be skipped.
        print(f"\n‚ùå Download failed: {e}")
        print("\nüí° You can manually download from:")
        print("   https://www.kaggle.com/datasets/arjuntejaswi/plant-village")
        print(f"   Then extract to: {pv_path}")
        return False

# Run download
# Banner, then try to make PlantVillage available via download_plantvillage()
# and report the outcome. download_success only drives the message below; the
# SKIP_STAGE2 decision is made separately by checking PLANTVILLAGE_PATH on disk.
print("="*70)
print("üì• PLANTVILLAGE DATASET CHECK")
print("="*70)

download_success = download_plantvillage()

if download_success:
    print("\n‚úÖ Ready for Stage 2 training!")
else:
    print("\n‚è© You can skip Stage 2 or download manually")

print("="*70)

In [None]:
# Decide whether Stage 2 (PlantVillage pre-training) runs.
# SKIP_STAGE2 is forced to True when PLANTVILLAGE_PATH does not exist;
# otherwise it defaults to False and can be changed by hand below.
# Later Stage 2 cells are all guarded by `if not SKIP_STAGE2`.
# Check if PlantVillage path exists
if not os.path.exists(PLANTVILLAGE_PATH):
    print("‚ö†Ô∏è PlantVillage dataset not found")
    print(f"   Path: {PLANTVILLAGE_PATH}")
    print("\nüí° Options:")
    print("   1. Skip Stage 2 (go to Stage 3)")
    print("   2. Download PlantVillage and update path")
    SKIP_STAGE2 = True
else:
    print("‚úÖ PlantVillage dataset found")
    print("\nü§î Do you want to train Stage 2?")
    print("   Set SKIP_STAGE2 = True to skip (faster)")
    print("   Set SKIP_STAGE2 = False to train (better accuracy)")
    SKIP_STAGE2 = False  # Change to True to skip

if SKIP_STAGE2:
    print("\n‚è© Skipping Stage 2 (PlantVillage training)")
    print("   Will create model with ImageNet weights only")
else:
    print("\n‚è±Ô∏è Stage 2 will take approximately 3-5 hours")
    print("   Consider running overnight!")

In [None]:
# =============================================================
# üì¶ CREATE PLANTVILLAGE DATASETS (if not skipping)
# =============================================================

if not SKIP_STAGE2:
    print("üì¶ Creating PlantVillage datasets...")
    
    # Both calls use the same validation_split and the fixed seed inside
    # create_dataset, so the 'training' and 'validation' subsets are
    # complementary parts of the same folder.
    train_plantvillage_raw = create_dataset(
        PLANTVILLAGE_PATH, IMG_SIZE, BATCH_SIZE,
        validation_split=0.2, subset='training'
    )
    val_plantvillage_raw = create_dataset(
        PLANTVILLAGE_PATH, IMG_SIZE, BATCH_SIZE,
        validation_split=0.2, subset='validation'
    )
    
    # Build pipelines
    train_plantvillage = build_pipeline(train_plantvillage_raw, is_training=True)
    val_plantvillage = build_pipeline(val_plantvillage_raw, is_training=False)
    
    # Class count comes from the sub-folders found on disk.
    # cardinality() counts batches, not images; the last batch may be smaller
    # than BATCH_SIZE, so "batches x BATCH_SIZE" below is an upper bound.
    num_plantvillage_classes = len(train_plantvillage_raw.class_names)
    train_batches = tf.data.experimental.cardinality(train_plantvillage_raw).numpy()
    val_batches = tf.data.experimental.cardinality(val_plantvillage_raw).numpy()
    
    print(f"\n‚úÖ PlantVillage Datasets Ready")
    print(f"   Classes: {num_plantvillage_classes}")
    print(f"   Training: {train_batches} batches √ó {BATCH_SIZE}")
    print(f"   Validation: {val_batches} batches √ó {BATCH_SIZE}")
else:
    print("‚è© Skipped PlantVillage dataset creation")

## 🏗️ Build Stage 2 Model

In [None]:
# =============================================================
# üèóÔ∏è CREATE MODEL ARCHITECTURE
# =============================================================

def create_model(num_classes, input_shape=(224, 224, 3), freeze_base=True):
    """Build a MobileNetV2 classifier with a small dense head.

    Args:
        num_classes: number of units in the softmax output layer.
        input_shape: ``(height, width, channels)`` fed to MobileNetV2.
        freeze_base: if True, the ImageNet-pretrained backbone is not trained.

    Returns:
        ``(model, base_model)``: the ``Sequential`` classifier and the
        MobileNetV2 backbone it wraps (its first layer).
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights='imagenet',
    )
    backbone.trainable = not freeze_base

    head_layers = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(DROPOUT_RATE),
        tf.keras.layers.Dense(
            256,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(1e-4),
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(DROPOUT_RATE * 0.8),
        # Explicit float32 output layer (keeps softmax in full precision)
        tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32'),
    ]
    classifier = tf.keras.Sequential([backbone] + head_layers)
    return classifier, backbone

if not SKIP_STAGE2:
    # Number of top MobileNetV2 layers Stage 2 fine-tunes (0 = backbone frozen,
    # which is the original behaviour).
    # NOTE: with 0, Stage 2 trains only the classification head. Stage 3 reuses
    # just the backbone (model_stage2.layers[0]) and builds a new head, so the
    # PlantVillage training then has no effect on the final model. Set a positive
    # value (e.g. 30) to let Stage 2 adapt the backbone; expect slower epochs.
    STAGE2_FINE_TUNE_LAYERS = 0

    # Check for existing checkpoint
    stage2_checkpoint = os.path.join(OUTPUT_DIR, 'stage2_plantvillage_best.keras')
    model_stage2 = None

    if os.path.exists(stage2_checkpoint):
        print("üîÑ Found existing Stage 2 checkpoint")
        try:
            model_stage2 = tf.keras.models.load_model(stage2_checkpoint)
            print("‚úÖ Loaded checkpoint")
        except Exception as err:  # corrupt or incompatible checkpoint file
            print("‚ö†Ô∏è Could not load checkpoint, creating new model")
            print(f"   Reason: {err}")
            model_stage2 = None
        else:
            # A checkpoint trained on a different class set cannot be resumed.
            ckpt_classes = model_stage2.output_shape[-1]
            if ckpt_classes != num_plantvillage_classes:
                print(f"   Checkpoint has {ckpt_classes} classes but the dataset has "
                      f"{num_plantvillage_classes}; creating new model")
                model_stage2 = None
    else:
        print("üèóÔ∏è Creating new Stage 2 model...")

    if model_stage2 is None:
        model_stage2, base_stage2 = create_model(num_plantvillage_classes)
    else:
        # The backbone is the first layer of the Sequential built by create_model.
        base_stage2 = model_stage2.layers[0]

    if STAGE2_FINE_TUNE_LAYERS > 0:
        base_stage2.trainable = True
        for layer in base_stage2.layers[:-STAGE2_FINE_TUNE_LAYERS]:
            layer.trainable = False
        # Keep BatchNormalization frozen (inference mode) so small CPU batches
        # do not overwrite the pretrained statistics.
        for layer in base_stage2.layers:
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = False

    # Compile (also recompiles a loaded checkpoint with a fresh optimizer)
    model_stage2.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE2),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    
    print(f"\nüìä Stage 2 Model Ready")
    print(f"   Classes: {num_plantvillage_classes}")
    print(f"   Trainable params: {sum([tf.keras.backend.count_params(w) for w in model_stage2.trainable_weights]):,}")
else:
    print("‚è© Skipped Stage 2 model creation")

## 🎯 Train Stage 2 (PlantVillage)

**⚠️ This will take 3-5 hours!** Consider running overnight.

In [None]:
# Stage 2 training (runs only when SKIP_STAGE2 is False).
# Callbacks:
#   - CPUProgressCallback: per-epoch timing and ETA printout.
#   - ModelCheckpoint: rewrites stage2_plantvillage_best.keras (the same file
#     the model cell reloads) whenever val_accuracy improves.
#   - EarlyStopping on val_loss (patience 3), restoring the best weights.
#   - ReduceLROnPlateau on val_loss: halves the LR after 2 stagnant epochs.
# NOTE(review): checkpoint and early stopping monitor different metrics, so the
# in-memory model after fit() may differ from the saved "best" checkpoint.
# The checkpoint is only written at the end of an improving epoch; an
# interrupted epoch is not saved.
if not SKIP_STAGE2:
    print("üöÄ Starting Stage 2 Training")
    print(f"‚è±Ô∏è Expected time: ~{PLANTVILLAGE_EPOCHS * 40} - {PLANTVILLAGE_EPOCHS * 60} minutes")
    print("\nüí° You can stop anytime (Ctrl+C) - checkpoint will be saved!\n")
    
    # Callbacks
    callbacks_stage2 = [
        CPUProgressCallback(PLANTVILLAGE_EPOCHS, stage_name="Stage 2: PlantVillage"),
        
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(OUTPUT_DIR, 'stage2_plantvillage_best.keras'),
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        ),
        
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=3,
            restore_best_weights=True,
            verbose=1
        ),
        
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=2,
            min_lr=1e-6,
            verbose=1
        )
    ]
    
    # Train
    history_stage2 = model_stage2.fit(
        train_plantvillage,
        validation_data=val_plantvillage,
        epochs=PLANTVILLAGE_EPOCHS,
        callbacks=callbacks_stage2,
        verbose=2  # Progress per epoch
    )
    
    print("\n‚úÖ Stage 2 Training Complete!")
    print(f"   Best val accuracy: {max(history_stage2.history['val_accuracy']):.4f}")
else:
    print("‚è© Skipped Stage 2 training")
    print("   Will use ImageNet weights for Stage 3")

## 🔄 Stage 3: Build Unified Nutrient Dataset

This copies the images of every crop dataset into one folder tree, with one sub-folder per `<crop>_<class>` class (for example `rice_Nitrogen(N)`).

In [None]:
# =============================================================
# üèóÔ∏è BUILD UNIFIED DATASET
# =============================================================

print("üèóÔ∏è Building unified nutrient dataset...")

# Destination folder (inside OUTPUT_DIR) that receives one sub-folder per
# "<crop>_<class>" class when the crop datasets are merged.
UNIFIED_DATASET_PATH = os.path.join(OUTPUT_DIR, 'unified_nutrient_dataset')

def detect_nutrient_classes(crop_path, image_exts=('.jpg', '.jpeg', '.png')):
    """Map each class folder under ``crop_path`` to the folders holding its images.

    Two layouts are supported:

    * Split layout: ``crop_path`` contains ``train``/``test``/``val``/
      ``validation`` folders (matched case-insensitively), each holding one
      sub-folder per class. A class found in several splits gets one path
      per split. Other top-level folders are ignored in this layout.
    * Flat layout: ``crop_path`` contains the class folders directly.

    A folder counts as a class only if it directly contains at least one file
    whose name ends with one of ``image_exts`` (case-insensitive). Every entry
    is considered, not just a sample, because directory listing order is
    arbitrary.

    Args:
        crop_path: root folder of one crop's dataset.
        image_exts: file extensions treated as images.

    Returns:
        dict: ``{class_name: [class_folder, ...]}``. Empty if ``crop_path``
        is not an existing directory.
    """
    nutrient_classes = {}

    if not os.path.isdir(crop_path):
        return nutrient_classes

    image_exts = tuple(ext.lower() for ext in image_exts)

    def _subfolders(folder):
        # Sorted names of the immediate sub-directories of `folder`.
        with os.scandir(folder) as entries:
            return sorted(e.name for e in entries if e.is_dir())

    def _has_images(folder):
        # True as soon as one image file is found (any() stops early).
        with os.scandir(folder) as entries:
            return any(e.is_file() and e.name.lower().endswith(image_exts)
                       for e in entries)

    # Check for train/test/val splits
    split_keywords = {'train', 'test', 'val', 'validation'}
    top_level = _subfolders(crop_path)
    splits = [name for name in top_level if name.lower() in split_keywords]

    if splits:
        # Split layout: gather each class across all split folders
        for split in splits:
            split_path = os.path.join(crop_path, split)
            for cls in _subfolders(split_path):
                cls_path = os.path.join(split_path, cls)
                if _has_images(cls_path):
                    nutrient_classes.setdefault(cls, []).append(cls_path)
    else:
        # Flat layout: each top-level folder is a class
        for cls in top_level:
            cls_path = os.path.join(crop_path, cls)
            if _has_images(cls_path):
                nutrient_classes[cls] = [cls_path]

    return nutrient_classes

# Build unified dataset
# Reuse heuristic: an existing folder with more than 5 class sub-folders is
# treated as complete; anything smaller is deleted and rebuilt from scratch.
if os.path.exists(UNIFIED_DATASET_PATH):
    existing = [d for d in os.listdir(UNIFIED_DATASET_PATH) 
                if os.path.isdir(os.path.join(UNIFIED_DATASET_PATH, d))]
    if len(existing) > 5:
        print(f"‚úÖ Unified dataset exists with {len(existing)} classes")
        unified_classes = existing
        needs_rebuild = False
    else:
        print("‚ö†Ô∏è Incomplete dataset, rebuilding...")
        shutil.rmtree(UNIFIED_DATASET_PATH)
        needs_rebuild = True
else:
    needs_rebuild = True

if needs_rebuild:
    os.makedirs(UNIFIED_DATASET_PATH, exist_ok=True)
    unified_classes = []
    
    print("üìÇ Combining crop datasets...\n")
    
    for crop, folder_name in CROP_DATASETS.items():
        crop_path = os.path.join(NUTRIENT_DATASETS_ROOT, folder_name)
        
        if not os.path.exists(crop_path):
            print(f"   ‚ö†Ô∏è {crop.upper()}: Not found")
            continue
        
        print(f"   üåæ {crop.upper()}: Processing...")
        
        nutrient_classes = detect_nutrient_classes(crop_path)
        
        for cls_name, src_paths in nutrient_classes.items():
            # Strip a leading crop prefix ("rice_N", "Rice__N" -> "N") so the
            # final name carries "<crop>_" exactly once. Only a prefix is
            # removed (case-insensitively), together with all underscores
            # that follow it; the crop name elsewhere in the class is kept.
            clean_name = cls_name
            if cls_name.lower().startswith(f"{crop.lower()}_"):
                clean_name = cls_name[len(crop):].lstrip('_')
            unified_class = f"{crop}_{clean_name}"
            
            dst_dir = os.path.join(UNIFIED_DATASET_PATH, unified_class)
            os.makedirs(dst_dir, exist_ok=True)
            
            # Copy images. Each file is prefixed with its source's parent folder
            # name (the split name, or the crop folder for flat layouts) so
            # files from different splits do not collide.
            total = 0
            for src_dir in src_paths:
                src_name = os.path.basename(os.path.dirname(src_dir))
                for img_file in os.listdir(src_dir):
                    if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                        src_file = os.path.join(src_dir, img_file)
                        dst_file = os.path.join(dst_dir, f"{src_name}_{img_file}")
                        if not os.path.exists(dst_file):
                            shutil.copy2(src_file, dst_file)
                            total += 1
            
            if total > 0:
                if unified_class not in unified_classes:
                    unified_classes.append(unified_class)
            elif not os.listdir(dst_dir):
                # Nothing copied: remove the empty folder. Otherwise
                # image_dataset_from_directory would treat it as an extra
                # class that class_names does not list.
                os.rmdir(dst_dir)
        
        crop_classes = [c for c in unified_classes if c.startswith(f"{crop}_")]
        print(f"      ‚úÖ {len(crop_classes)} classes")

class_names = sorted(unified_classes)
num_classes = len(class_names)
if num_classes == 0:
    # Fail here with a clear message instead of an obscure error when the
    # (empty) dataset is loaded or the model is built.
    raise RuntimeError(
        f"No nutrient classes found for {list(CROP_DATASETS)} under "
        f"{NUTRIENT_DATASETS_ROOT!r}; check NUTRIENT_DATASETS_ROOT and CROP_DATASETS."
    )

print(f"\n‚úÖ Unified Dataset Ready")
print(f"   Total classes: {num_classes}")
print(f"   Classes: {', '.join(class_names[:6])}...")
print(f"   Location: {UNIFIED_DATASET_PATH}")

## 📦 Create Unified Training Datasets

In [None]:
print("üì¶ Creating unified nutrient datasets...")

train_nutrient_raw = create_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='training'
)
val_nutrient_raw = create_dataset(
    UNIFIED_DATASET_PATH, IMG_SIZE, BATCH_SIZE,
    validation_split=0.2, subset='validation'
)

# Build pipelines
train_nutrient = build_pipeline(train_nutrient_raw, is_training=True)
val_nutrient = build_pipeline(val_nutrient_raw, is_training=False)

train_batches = tf.data.experimental.cardinality(train_nutrient_raw).numpy()
val_batches = tf.data.experimental.cardinality(val_nutrient_raw).numpy()

print(f"\n‚úÖ Unified Datasets Ready")
print(f"   Classes: {num_classes}")
print(f"   Training: {train_batches} batches √ó {BATCH_SIZE}")
print(f"   Validation: {val_batches} batches √ó {BATCH_SIZE}")

## 🏗️ Build Stage 3 Model (Unified)

In [None]:
# =============================================================
# üîß CREATE STAGE 3 MODEL
# =============================================================

print("üèóÔ∏è Creating Stage 3 unified model...")

# Check for checkpoint (same file the Stage 3 ModelCheckpoint writes)
stage3_checkpoint = os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras')
model_stage3 = None

if os.path.exists(stage3_checkpoint):
    print("üîÑ Found existing Stage 3 checkpoint")
    try:
        model_stage3 = tf.keras.models.load_model(stage3_checkpoint)
        print("‚úÖ Loaded checkpoint")
    except Exception as err:  # corrupt or incompatible checkpoint file
        print("‚ö†Ô∏è Could not load, creating new model")
        print(f"   Reason: {err}")
        model_stage3 = None
    else:
        # A checkpoint built for a different class set would produce
        # wrong-sized outputs for the current labels; start over instead.
        ckpt_classes = model_stage3.output_shape[-1]
        if ckpt_classes != num_classes:
            print(f"   Checkpoint predicts {ckpt_classes} classes but the dataset has "
                  f"{num_classes}; creating new model")
            model_stage3 = None

if model_stage3 is None:
    if not SKIP_STAGE2 and 'model_stage2' in globals():
        # Use Stage 2 base (first layer of the Stage 2 Sequential model)
        print("   Using Stage 2 trained base")
        base_model = model_stage2.layers[0]
        base_model.trainable = False
    else:
        # Use ImageNet weights
        print("   Using ImageNet pretrained base")
        base_model = tf.keras.applications.MobileNetV2(
            input_shape=(IMG_SIZE, IMG_SIZE, 3),
            include_top=False,
            weights='imagenet'
        )
        base_model.trainable = False
    
    # Build model: frozen backbone + new classification head
    model_stage3 = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(DROPOUT_RATE),
        tf.keras.layers.Dense(
            384,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(1e-4)
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(DROPOUT_RATE * 0.8),
        tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')
    ], name='unified_nutrient_model')

# Compile (also recompiles a loaded checkpoint with a fresh optimizer).
# The top-3 metric is logged as 'top3_acc' / 'val_top3_acc'.
model_stage3.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE_STAGE3),
    loss='categorical_crossentropy',
    metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=3, name='top3_acc')]
)

print(f"\nüìä Stage 3 Model Ready")
print(f"   Classes: {num_classes}")
print(f"   Trainable params: {sum([tf.keras.backend.count_params(w) for w in model_stage3.trainable_weights]):,}")

## 🎯 Train Stage 3 (Unified Nutrient Detection)

**⚠️ This will take 5-8 hours on CPU!** We recommend running it overnight.

In [None]:
print("üöÄ Starting Stage 3 Training (UNIFIED)")
print(f"‚è±Ô∏è Expected time: ~{UNIFIED_EPOCHS * 40} - {UNIFIED_EPOCHS * 60} minutes")
print(f"   That's approximately {(UNIFIED_EPOCHS * 40) // 60} - {(UNIFIED_EPOCHS * 60) // 60} hours")
print("\nüí° This is the main training - be patient!")
print("   You can stop anytime - checkpoint will be saved!\n")

# Callbacks
callbacks_stage3 = [
    CPUProgressCallback(UNIFIED_EPOCHS, stage_name="Stage 3: Unified Nutrients"),
    
    tf.keras.callbacks.ModelCheckpoint(
        os.path.join(OUTPUT_DIR, 'unified_nutrient_best.keras'),
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        verbose=1
    )
]

# Train
history_stage3 = model_stage3.fit(
    train_nutrient,
    validation_data=val_nutrient,
    epochs=UNIFIED_EPOCHS,
    callbacks=callbacks_stage3,
    verbose=2
)

print("\n" + "="*70)
print("üéâ STAGE 3 TRAINING COMPLETE!")
print("="*70)
print(f"‚úÖ Best val accuracy: {max(history_stage3.history['val_accuracy']):.4f}")
print(f"‚úÖ Best top-3 accuracy: {max(history_stage3.history['top3_acc']):.4f}")
print(f"üíæ Model saved to: {OUTPUT_DIR}")

## 📊 Training Results Visualization

In [None]:
# Plot training history: accuracy (left) and loss (right), train vs validation,
# then save the figure next to the model outputs.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (axes[0], 'accuracy', 'Training Accuracy', 'Accuracy'),
    (axes[1], 'loss', 'Training Loss', 'Loss'),
]
for ax, metric, title, ylabel in panels:
    ax.plot(history_stage3.history[metric], 'b-', label='Train', linewidth=2)
    ax.plot(history_stage3.history[f'val_{metric}'], 'r-', label='Validation', linewidth=2)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

history_png = os.path.join(OUTPUT_DIR, 'training_history.png')
plt.tight_layout()
plt.savefig(history_png, dpi=150, bbox_inches='tight')
plt.show()

print(f"üìä Training curves saved to: {history_png}")

## 📦 Export Model (SavedModel Format)

In [None]:
# =============================================================
# üíæ EXPORT SAVEDMODEL FOR BACKEND
# =============================================================

print("üì¶ Exporting model to SavedModel format...")

savedmodel_path = os.path.join(OUTPUT_DIR, 'unified_savedmodel')

# Remove old if exists
if os.path.exists(savedmodel_path):
    shutil.rmtree(savedmodel_path)

# Export
try:
    # Keras 2 (TF <= 2.15): writes a full SavedModel directory.
    model_stage3.save(savedmodel_path, save_format='tf')
except (TypeError, ValueError):
    # Keras 3 (bundled with TF >= 2.16) rejects save_format='tf' and only saves
    # .keras/.h5 via save(); Model.export() writes an inference SavedModel.
    if os.path.exists(savedmodel_path):
        shutil.rmtree(savedmodel_path)
    model_stage3.export(savedmodel_path)

print(f"‚úÖ SavedModel exported to: {savedmodel_path}")

# Save metadata (accuracies are the best validation values seen during training)
metadata = {
    'model_name': 'FasalVaidya Unified Nutrient Model',
    'version': '1.0.0-local',
    'created_at': datetime.now().isoformat(),
    'num_classes': num_classes,
    'class_names': class_names,
    'crops': list(CROP_DATASETS.keys()),
    'img_size': IMG_SIZE,
    'trained_on': 'Local CPU',
    'epochs': UNIFIED_EPOCHS,
    'batch_size': BATCH_SIZE,
    'best_val_accuracy': float(max(history_stage3.history['val_accuracy'])),
    'best_top3_accuracy': float(max(history_stage3.history['val_top3_acc']))
}

# Explicit UTF-8 so non-ASCII class names do not depend on the OS default encoding
with open(os.path.join(savedmodel_path, 'metadata.json'), 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

# Save class labels, one per line, in model output order
with open(os.path.join(savedmodel_path, 'labels.txt'), 'w', encoding='utf-8') as f:
    f.write('\n'.join(class_names))

print(f"\nüìÑ Files created:")
print(f"   ‚îú‚îÄ‚îÄ saved_model.pb")
print(f"   ‚îú‚îÄ‚îÄ variables/")
print(f"   ‚îú‚îÄ‚îÄ metadata.json")
print(f"   ‚îî‚îÄ‚îÄ labels.txt")

# Calculate size (all files under the export directory, in MiB)
total_size = sum(f.stat().st_size for f in Path(savedmodel_path).rglob('*') if f.is_file()) / (1024 * 1024)
print(f"\nüìä Total size: {total_size:.1f} MB")

## 🎉 Training Complete!

### 📦 Next Steps

1. **Copy the SavedModel to your backend** (replace `<OUTPUT_DIR>` with the `OUTPUT_DIR` value from the configuration cell):
   ```bash
   # Copy the entire folder
   cp -r "<OUTPUT_DIR>/unified_savedmodel" "backend/ml/models/"
   ```

2. **Test the model:**
   ```bash
   cd backend
   python test_unified.py
   ```

3. **Expected output:**
   ```
   N score: 75.3% (moderate deficiency)
   P score: 18.2% (healthy)
   K score: 6.5% (healthy)
   Detected: rice_Nitrogen(N)
   Confidence: 75.3%
   ```

### 📊 Model Info

Your model is now ready for deployment! It supports:
- 🌾 **4 crops:** Rice, Wheat, Tomato, Maize
- 🔬 **Multiple deficiencies:** N, P, K, and healthy classes
- 📱 **Real-time inference** on mobile and backend
- 🎯 **Expected accuracy:** 70-95% on deficient leaves

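### 🚀 Performance Notes

This model was trained on CPU. CPU training is slower than GPU training but produces a model of comparable quality. The CPU settings use a smaller batch size, so results will not be bit-for-bit identical to a GPU run. The final model will work just as well in production!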

---