## 1️⃣ Setup & Mount Drive


In [None]:
# Mount Google Drive at /content/drive; the dataset and model output paths
# configured later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

# Report the installed TensorFlow version and the GPUs TensorFlow can see
# (an empty list means no GPU is visible to this runtime).
import tensorflow as tf
print(f"TensorFlow: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

In [None]:
# Install the third-party packages this notebook imports (Pillow, tqdm,
# scikit-learn). The leading "!" is IPython/Colab shell syntax, not Python;
# -q keeps pip's output quiet.
!pip install -q pillow tqdm scikit-learn

## 2️⃣ Configuration

**IMPORTANT:** Update `DATASET_ROOT` to point to your dataset in Google Drive.

Expected structure:

```
Leaf Nutrient Data Sets/
├── Rice Nutrients/
│   ├── Nitrogen(N)/
│   ├── Phosphorus(P)/
│   └── Potassium(K)/
├── Tomato Nutrients/
│   └── train/
│       ├── Tomato - Healthy/
│       ├── Tomato - Nitrogen Deficiency/
│       └── ...
└── ... other crops
```


In [None]:
import os
import json
import numpy as np
import random
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
from PIL import Image, ImageEnhance, ImageFilter
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import EfficientNetB0, EfficientNetB2, MobileNetV3Large

# ========================================
# 📁 UPDATE THESE PATHS TO YOUR DRIVE LOCATION
# ========================================
# DATASET_ROOT: folder holding one sub-folder per crop dataset.
# MODEL_OUTPUT: folder where trained models and metadata are written.
DATASET_ROOT = Path('/content/drive/MyDrive/Leaf Nutrient Data Sets')
MODEL_OUTPUT = Path('/content/drive/MyDrive/FasalVaidya_Models')

# Make sure the output folder exists (a no-op when it is already there)
MODEL_OUTPUT.mkdir(exist_ok=True, parents=True)

# Tell the user whether the dataset root is reachable; list its entries if so
if not DATASET_ROOT.exists():
    print(f"‚ùå Dataset NOT found at: {DATASET_ROOT}")
    print("Please update DATASET_ROOT to your Google Drive path")
else:
    print(f"‚úÖ Dataset found at: {DATASET_ROOT}")
    print(f"üìÇ Contents: {[entry.name for entry in DATASET_ROOT.iterdir()]}")

## 3️⃣ Crop Configurations

Each crop has specific folder-to-label mappings for N, P, K, Mg deficiencies.


In [None]:
# Seed NumPy, TensorFlow and Python's `random` module for reproducibility
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
random.seed(SEED)

# Image settings
# IMG_SIZE: side length in pixels that images are resized to and that the
#           model input is built with.
# MAX_SAMPLES_PER_CLASS: upper bound on images taken from one class folder.
IMG_SIZE = 224
MAX_SAMPLES_PER_CLASS = 2000

# Crop configurations with folder -> label mappings.
#
# Each entry is keyed by a crop id and contains:
#   'name'             - display name used in log output
#   'dataset_path'     - crop folder under DATASET_ROOT
#   'use_train_folder' - optional; when True the class folders are looked up
#                        under '<dataset_path>/train'
#   'class_mapping'    - class folder name -> multi-label target for the four
#                        outputs N, P, K, Mg (1 = deficient, 0 = not)
#   'has_healthy', 'outputs' - descriptive fields; not read by the code
#                        visible in this notebook
#
# Folders mapped to all zeros (healthy classes, but also classes such as
# 'Tomato - Leaf Miner', 'boron', 'PM', 'DM', ...) are treated as having none
# of the four tracked deficiencies.
# NOTE(review): the short codes (NAB, ZNAB, PM, DM, JAS, LS, PC) are taken
# as-is from the dataset folder names - confirm their meaning with the
# dataset documentation.
CROP_CONFIGS = {
    'rice': {
        'name': 'Rice',
        'dataset_path': DATASET_ROOT / 'Rice Nutrients',
        'class_mapping': {
            'Nitrogen(N)': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'Phosphorus(P)': {'N': 0, 'P': 1, 'K': 0, 'Mg': 0},
            'Potassium(K)': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
        },
        'has_healthy': False,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'tomato': {
        'name': 'Tomato',
        'dataset_path': DATASET_ROOT / 'Tomato Nutrients',
        'use_train_folder': True,
        'class_mapping': {
            'Tomato - Healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'Tomato - Nitrogen Deficiency': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'Tomato - Potassium Deficiency': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'Tomato - Nitrogen and Potassium Deficiency': {'N': 1, 'P': 0, 'K': 1, 'Mg': 0},
            'Tomato - Jassid and Mite': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'Tomato - Leaf Miner': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'Tomato - Mite': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'wheat': {
        'name': 'Wheat',
        'dataset_path': DATASET_ROOT / 'Wheat Nitrogen',
        'use_train_folder': True,
        'class_mapping': {
            'control': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'deficiency': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'maize': {
        'name': 'Maize',
        'dataset_path': DATASET_ROOT / 'Maize Nutrients',
        'use_train_folder': True,
        'class_mapping': {
            'ALL Present': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'NAB': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'PAB': {'N': 0, 'P': 1, 'K': 0, 'Mg': 0},
            'KAB': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'ALLAB': {'N': 1, 'P': 1, 'K': 1, 'Mg': 0},
            'ZNAB': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'banana': {
        'name': 'Banana',
        'dataset_path': DATASET_ROOT / 'Banana leaves Nutrient',
        'class_mapping': {
            'healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'potassium': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'magnesium': {'N': 0, 'P': 0, 'K': 0, 'Mg': 1},
            'boron': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'calcium': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'iron': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'manganese': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'sulphur': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'zinc': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'coffee': {
        'name': 'Coffee',
        'dataset_path': DATASET_ROOT / 'Coffee Nutrients',
        'class_mapping': {
            'healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'nitrogen-N': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'phosphorus-P': {'N': 0, 'P': 1, 'K': 0, 'Mg': 0},
            'potasium-K': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'cucumber': {
        'name': 'Cucumber',
        'dataset_path': DATASET_ROOT / 'Cucumber Nutrients',
        'class_mapping': {
            'healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'N': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'K': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'N_K': {'N': 1, 'P': 0, 'K': 1, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'eggplant': {
        'name': 'Eggplant',
        'dataset_path': DATASET_ROOT / 'EggPlant Nutrients',
        'class_mapping': {
            'healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'N': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'K': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'N_K': {'N': 1, 'P': 0, 'K': 1, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'ashgourd': {
        'name': 'Ash Gourd',
        'dataset_path': DATASET_ROOT / 'Ashgourd Nutrients',
        'class_mapping': {
            'ash_gourd__healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'N': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'K': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'N_K': {'N': 1, 'P': 0, 'K': 1, 'Mg': 0},
            'K_Mg': {'N': 0, 'P': 0, 'K': 1, 'Mg': 1},
            'N_Mg': {'N': 1, 'P': 0, 'K': 0, 'Mg': 1},
            'PM': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'bittergourd': {
        'name': 'Bitter Gourd',
        'dataset_path': DATASET_ROOT / 'Bittergourd Nutrients',
        'class_mapping': {
            'healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'N': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'K': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'N_K': {'N': 1, 'P': 0, 'K': 1, 'Mg': 0},
            'K_Mg': {'N': 0, 'P': 0, 'K': 1, 'Mg': 1},
            'N_Mg': {'N': 1, 'P': 0, 'K': 0, 'Mg': 1},
            'DM': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'JAS': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'LS': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'ridgegourd': {
        'name': 'Ridge Gourd',
        'dataset_path': DATASET_ROOT / 'Ridgegourd',
        'class_mapping': {
            'healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'N': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'N_Mg': {'N': 1, 'P': 0, 'K': 0, 'Mg': 1},
            'PC': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
    'snakegourd': {
        'name': 'Snake Gourd',
        'dataset_path': DATASET_ROOT / 'Snakegourd Nutrients',
        'class_mapping': {
            'healthy': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
            'N': {'N': 1, 'P': 0, 'K': 0, 'Mg': 0},
            'K': {'N': 0, 'P': 0, 'K': 1, 'Mg': 0},
            'N_K': {'N': 1, 'P': 0, 'K': 1, 'Mg': 0},
            'LS': {'N': 0, 'P': 0, 'K': 0, 'Mg': 0},
        },
        'has_healthy': True,
        'outputs': ['N', 'P', 'K', 'Mg'],
    },
}

# Show the crop ids that can be passed to the training function
print(f"üìã Available crops: {list(CROP_CONFIGS.keys())}")

## 4️⃣ Data Loading & Augmentation


In [None]:
import gc

def load_and_preprocess_image(img_path, target_size=(IMG_SIZE, IMG_SIZE)):
    """Load one image from disk as a resized RGB ``float32`` array.

    Args:
        img_path: Path of the image file to read.
        target_size: ``(width, height)`` handed to ``Image.resize``.

    Returns:
        A ``float32`` NumPy array holding the resized RGB image. Pixel values
        are left on the 0-255 scale (no rescaling happens here). Returns
        ``None`` when the file cannot be opened or decoded.
    """
    try:
        # The context manager closes the underlying file handle as soon as
        # the pixel data has been decoded; convert() returns an independent
        # in-memory image, so it stays usable after the file is closed.
        with Image.open(img_path) as source:
            rgb = source.convert('RGB')
        resized = rgb.resize(target_size, Image.LANCZOS)
        return np.array(resized, dtype=np.float32)
    except Exception as e:
        # Deliberately best-effort: a single unreadable file must not abort
        # a whole training run, so report it and let the caller skip it.
        print(f"Error loading {img_path}: {e}")
        return None


def get_image_paths_and_labels(crop_id):
    """Collect image file paths and multi-label targets for one crop.

    Only file paths are gathered; no image is opened or loaded into memory.

    Files are matched by extension case-insensitively (``.jpg``, ``.jpeg``,
    ``.png``), listed in sorted order and, when a class has more than
    ``MAX_SAMPLES_PER_CLASS`` files, down-sampled with a private RNG seeded
    from ``SEED``. The result is therefore identical on every call, which
    keeps a later seeded train/validation split stable between runs.

    Args:
        crop_id: Key into ``CROP_CONFIGS`` (e.g. ``'rice'``).

    Returns:
        ``(image_paths, labels)`` where ``image_paths`` is a list of ``str``
        and ``labels`` is a ``float32`` array of shape ``(len(image_paths), 4)``
        with columns ordered N, P, K, Mg.

    Raises:
        KeyError: If ``crop_id`` is not in ``CROP_CONFIGS``.
        FileNotFoundError: If the crop's dataset directory does not exist.
    """
    config = CROP_CONFIGS[crop_id]
    dataset_path = config['dataset_path']
    class_mapping = config['class_mapping']

    if config.get('use_train_folder', False):
        dataset_path = dataset_path / 'train'

    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset not found: {dataset_path}")

    image_extensions = {'.jpg', '.jpeg', '.png'}
    # Private RNG: sampling neither depends on nor disturbs global random state
    rng = random.Random(SEED)

    image_paths = []
    labels = []
    class_counts = {}

    print(f"\nüìÇ Scanning {config['name']} from: {dataset_path}")

    for folder_name, label_dict in class_mapping.items():
        folder_path = dataset_path / folder_name
        if not folder_path.is_dir():
            print(f"  ‚ö†Ô∏è Folder not found: {folder_name}")
            continue

        # One directory pass with a lower-cased suffix check: picks up any
        # capitalisation (e.g. ".Png") and can never list a file twice, which
        # separate per-case globs do on case-insensitive filesystems.
        img_files = sorted(
            entry for entry in folder_path.iterdir()
            if entry.is_file() and entry.suffix.lower() in image_extensions
        )

        # Limit samples per class to prevent memory issues
        if len(img_files) > MAX_SAMPLES_PER_CLASS:
            img_files = rng.sample(img_files, MAX_SAMPLES_PER_CLASS)

        class_counts[folder_name] = len(img_files)

        label = [label_dict.get('N', 0), label_dict.get('P', 0),
                 label_dict.get('K', 0), label_dict.get('Mg', 0)]
        for img_path in img_files:
            image_paths.append(str(img_path))
            labels.append(label)

    print(f"  üìä Class distribution: {class_counts}")
    print(f"  ‚úÖ Found {len(image_paths)} images")

    # reshape keeps a well-formed (0, 4) array even when nothing was found
    return image_paths, np.array(labels, dtype=np.float32).reshape(-1, 4)


class DataGenerator(keras.utils.Sequence):
    """Batch generator that reads images from disk only when a batch is requested.

    Args:
        image_paths: Sequence of image file paths.
        labels: Array-like of per-image label vectors, aligned with
            ``image_paths``.
        batch_size: Maximum number of images per batch.
        img_size: Side length images are resized to.
        shuffle: Shuffle the sample order at construction and after each epoch.
        augment: Stored for API compatibility; this class performs no
            augmentation itself.
        **kwargs: Forwarded to the Keras base class (e.g. ``workers``,
            ``use_multiprocessing`` on Keras versions that accept them).
    """

    def __init__(self, image_paths, labels, batch_size=32, img_size=IMG_SIZE,
                 shuffle=True, augment=False, **kwargs):
        # Newer Keras versions expect the dataset base class to be
        # initialised and warn when a subclass skips this call.
        super().__init__(**kwargs)
        self.image_paths = image_paths
        self.labels = labels
        self.batch_size = batch_size
        self.img_size = img_size
        self.shuffle = shuffle
        self.augment = augment
        self.indices = np.arange(len(self.image_paths))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        """Return the number of batches per epoch (the last one may be short)."""
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, idx):
        """Load and return batch ``idx`` as ``(images, labels)`` arrays.

        Images that fail to load are skipped, so a batch can contain fewer
        than ``batch_size`` samples.

        Raises:
            RuntimeError: If no image in the batch could be loaded. Returning
                an empty, wrongly-shaped array here would otherwise surface
                later as an obscure shape error inside Keras.
        """
        start = idx * self.batch_size
        batch_indices = self.indices[start:start + self.batch_size]
        target_size = (self.img_size, self.img_size)

        batch_images = []
        batch_labels = []
        for i in batch_indices:
            img = load_and_preprocess_image(self.image_paths[i], target_size)
            if img is not None:
                batch_images.append(img)
                batch_labels.append(self.labels[i])

        if not batch_images:
            raise RuntimeError(
                f"Batch {idx} is empty: none of its {len(batch_indices)} "
                f"image(s) could be loaded"
            )

        return np.array(batch_images), np.array(batch_labels)

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)


def clear_memory():
    """Release memory held by earlier models.

    Clears the Keras backend session, then forces a Python garbage-collection
    pass, and finally prints a confirmation line.
    """
    # Order matters: drop Keras-held references first so the collector can
    # actually reclaim them.
    keras.backend.clear_session()
    gc.collect()
    print("üßπ Memory cleared")


# Random augmentation pipeline. It is inserted into the model graph by
# create_model() instead of being applied inside DataGenerator.
data_augmentation = keras.Sequential(
    name='data_augmentation',
    layers=[
        layers.RandomFlip("horizontal_and_vertical"),
        layers.RandomRotation(0.2),
        layers.RandomZoom(0.15),
        layers.RandomContrast(0.15),
    ],
)

print("‚úÖ Memory-efficient data loading ready")

## 5️⃣ Model Architecture


In [None]:
def create_model(backbone='efficientnetb0', num_outputs=4):
    """Build the multi-label nutrient-deficiency classifier.

    The model takes raw RGB pixels on the 0-255 scale (exactly what
    ``load_and_preprocess_image`` produces), applies the shared
    ``data_augmentation`` pipeline, runs a frozen ImageNet-pretrained backbone
    and finishes with a small dense head with one sigmoid unit per output.

    Args:
        backbone: One of ``'efficientnetb0'``, ``'efficientnetb2'``,
            ``'mobilenetv3large'``.
        num_outputs: Number of independent sigmoid outputs (4 = N, P, K, Mg).

    Returns:
        ``(model, base)``: the full Keras model and the backbone sub-model
        (returned separately so the caller can unfreeze it for fine-tuning).

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """
    backbone_builders = {
        'efficientnetb0': EfficientNetB0,
        'efficientnetb2': EfficientNetB2,
        'mobilenetv3large': MobileNetV3Large,
    }
    if backbone not in backbone_builders:
        raise ValueError(f"Unknown backbone: {backbone}")

    base = backbone_builders[backbone](
        weights='imagenet',
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )

    # Freeze base initially
    base.trainable = False

    # Build model
    inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    # No Rescaling/Normalization here on purpose: the Keras EfficientNet and
    # MobileNetV3 application models already contain their own input
    # preprocessing and expect raw pixel values in [0, 255]. Normalising
    # again in front of them feeds the pretrained weights inputs on the wrong
    # scale. Keeping the pixels in [0, 255] also matches the value range the
    # augmentation layers are documented for (RandomContrast clips its output
    # to [0, 255], which would mangle mean/std-normalised values).
    x = data_augmentation(inputs)  # random transforms are active only in training

    # Backbone; training=False keeps its BatchNorm layers in inference mode
    # even after layers are unfrozen for fine-tuning.
    x = base(x, training=False)

    # Classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.2)(x)

    # Independent sigmoid per nutrient: several deficiencies can co-occur
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='npkmg_output')(x)

    model = keras.Model(inputs, outputs, name=f'fasalvaidya_{backbone}')
    return model, base


# Smoke test: build the default model once, then print its parameter count
# and layer summary.
# NOTE(review): create_model() requests weights='imagenet', so the first run
# presumably needs network access to fetch the pretrained weights - confirm.
test_model, _ = create_model()
print(f"‚úÖ Model created: {test_model.count_params():,} parameters")
test_model.summary()

## 6️⃣ Training Function


In [None]:
def _compile_for_training(model, learning_rate):
    """Compile ``model`` with the loss and metrics used by every training phase."""
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy', keras.metrics.AUC(name='auc')]
    )


def _build_training_callbacks(checkpoint_path, baseline_auc=None):
    """Create the checkpoint, early-stopping and LR-reduction callbacks.

    ``baseline_auc`` is the validation AUC a new checkpoint has to beat before
    the file at ``checkpoint_path`` is overwritten (``None`` = no threshold).
    """
    return [
        ModelCheckpoint(
            str(checkpoint_path),
            monitor='val_auc',
            mode='max',
            save_best_only=True,
            initial_value_threshold=baseline_auc,
            verbose=1
        ),
        EarlyStopping(
            monitor='val_auc',
            mode='max',
            patience=8,
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=4,
            min_lr=1e-7,
            verbose=1
        )
    ]


def train_crop_model(crop_id, epochs=50, batch_size=16, backbone='efficientnetb0', fine_tune=True):
    """
    Train or improve a model for a specific crop (MEMORY-EFFICIENT VERSION).

    - Uses data generators to load images on-the-fly (prevents OOM crashes)
    - If ``<crop_id>_best.keras`` already exists: loads it and continues
      training at a low learning rate. The existing checkpoint is only
      overwritten by an epoch whose validation AUC beats the loaded model's.
    - If no model exists: trains a new one. With ``fine_tune=True`` the epoch
      budget is split into a head-only phase and a backbone fine-tuning
      phase; with ``fine_tune=False`` all epochs go to the head.

    Args:
        crop_id: Crop identifier (e.g., 'rice', 'tomato'); key of CROP_CONFIGS
        epochs: Total number of training epochs across all phases
        batch_size: Batch size (default 16 to prevent OOM)
        backbone: Model backbone (only used for new models)
        fine_tune: Whether to fine-tune backbone layers (new models only)

    Returns:
        ``(model, results)`` where ``results`` is ``[loss, accuracy, auc]`` on
        the validation split, or ``(None, None)`` if no images were found.

    Side effects:
        Writes ``<crop_id>_best.keras``, ``<crop_id>_final.keras`` and
        ``metadata.json`` into ``MODEL_OUTPUT / crop_id``.
    """
    # Clear memory before starting
    clear_memory()

    config = CROP_CONFIGS[crop_id]
    print(f"\n{'='*60}")
    print(f"üå± Training model for: {config['name']}")
    print(f"{'='*60}")

    # Check for existing model
    output_dir = MODEL_OUTPUT / crop_id
    output_dir.mkdir(parents=True, exist_ok=True)
    existing_model_path = output_dir / f'{crop_id}_best.keras'

    is_continuing = existing_model_path.exists()

    # Get image paths and labels (doesn't load images into memory!)
    image_paths, labels = get_image_paths_and_labels(crop_id)

    if len(image_paths) == 0:
        print(f"‚ùå No images found for {crop_id}")
        return None, None

    # Split into train/val
    from sklearn.model_selection import train_test_split
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=SEED
    )
    print(f"\nüìä Train: {len(train_paths)}, Validation: {len(val_paths)}")

    # Create data generators (memory-efficient!)
    train_gen = DataGenerator(train_paths, train_labels, batch_size=batch_size, shuffle=True, augment=True)
    val_gen = DataGenerator(val_paths, val_labels, batch_size=batch_size, shuffle=False, augment=False)

    # Either load existing model or create new one
    if is_continuing:
        print(f"\nüîÑ Found existing model - CONTINUING training to improve accuracy")
        print(f"   Loading: {existing_model_path}")
        model = keras.models.load_model(str(existing_model_path))
        base_model = None
        initial_lr = 1e-5
    else:
        print(f"\nüÜï No existing model - training from SCRATCH")
        model, base_model = create_model(backbone=backbone)
        initial_lr = 1e-3

    # Compile before any evaluation so the metric order is always
    # [loss, accuracy, auc], independent of what was stored in the saved file.
    _compile_for_training(model, initial_lr)

    old_results = None
    baseline_auc = None
    if is_continuing:
        # Baseline on the FULL validation split: it is compared against the
        # final full-validation metrics below and is the score a new
        # checkpoint has to beat.
        old_results = model.evaluate(val_gen, verbose=0)
        baseline_auc = float(old_results[2])
        print(f"   Previous (validation set) - Accuracy: {old_results[1]:.4f}, AUC: {old_results[2]:.4f}")

    # Without the baseline threshold the very first epoch would replace the
    # existing best checkpoint even when it is worse.
    callbacks = _build_training_callbacks(existing_model_path, baseline_auc)

    if is_continuing:
        # For existing models: single phase of continued training
        print(f"\nüü¢ Continuing training for {epochs} epochs...")
        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )
    else:
        # For new models: head-only phase, optionally followed by fine-tuning
        will_fine_tune = fine_tune and base_model is not None
        if will_fine_tune:
            # At least 5 head epochs, but never more than the total budget
            phase1_epochs = min(epochs, max(epochs // 2, 5))
        else:
            # No second phase: spend the whole budget on the head
            phase1_epochs = epochs
        phase2_epochs = epochs - phase1_epochs

        # Phase 1: Train classifier head
        print(f"\nüîµ Phase 1: Training classifier head ({phase1_epochs} epochs)...")
        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=phase1_epochs,
            callbacks=callbacks,
            verbose=1
        )

        # Phase 2: Fine-tune backbone (skipped when no epochs are left)
        if will_fine_tune and phase2_epochs > 0:
            print(f"\nüü¢ Phase 2: Fine-tuning backbone ({phase2_epochs} epochs)...")
            base_model.trainable = True

            # Freeze early layers, train later ones
            for layer in base_model.layers[:-20]:
                layer.trainable = False

            # Trainability changes only take effect after recompiling
            _compile_for_training(model, 1e-5)

            model.fit(
                train_gen,
                validation_data=val_gen,
                epochs=phase2_epochs,
                callbacks=callbacks,
                verbose=1
            )

    # Evaluate on full validation set
    print(f"\nüìà Final Evaluation:")
    results = model.evaluate(val_gen, verbose=0)
    print(f"  Loss: {results[0]:.4f}")
    print(f"  Accuracy: {results[1]:.4f}")
    print(f"  AUC: {results[2]:.4f}")

    if old_results is not None:
        acc_change = results[1] - old_results[1]
        auc_change = results[2] - old_results[2]
        print(f"\nüìä Improvement (vs previous model):")
        print(f"  Accuracy: {'+' if acc_change >= 0 else ''}{acc_change:.4f}")
        print(f"  AUC: {'+' if auc_change >= 0 else ''}{auc_change:.4f}")

    # Save final model
    model.save(str(output_dir / f'{crop_id}_final.keras'))
    print(f"\nüíæ Model saved to: {output_dir}")

    # Save/update metadata
    metadata = {
        'crop_id': crop_id,
        'crop_name': config['name'],
        'backbone': backbone if not is_continuing else 'continued',
        'outputs': ['N', 'P', 'K', 'Mg'],
        'val_accuracy': float(results[1]),
        'val_auc': float(results[2]),
        'trained_at': datetime.now().isoformat(),
        'train_samples': len(train_paths),
        'val_samples': len(val_paths),
        'training_mode': 'continued' if is_continuing else 'from_scratch',
    }
    with open(output_dir / 'metadata.json', 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    # Clear memory after training
    clear_memory()

    return model, results

print("‚úÖ Training function ready (memory-efficient with continue-training support)")

## 7️⃣ Train a Single Crop (Demo)

Run this cell to train a model for a single crop. Change `CROP_TO_TRAIN` to train different crops.


In [None]:
# ========================================
# 🎯 SELECT CROP TO TRAIN
# ========================================
# CROP_TO_TRAIN must be a key of CROP_CONFIGS (rice, tomato, wheat, maize, banana, coffee, ...)
CROP_TO_TRAIN = 'rice'
# Author's guidance: raise EPOCHS for better results (50-100 recommended)
EPOCHS = 30
# Author's guidance: keep at 16 to prevent OOM crashes on T4
BATCH_SIZE = 16
# One of 'efficientnetb0', 'efficientnetb2', 'mobilenetv3large'
BACKBONE = 'efficientnetb0'

# Reject unknown crop ids up front; otherwise launch training
if CROP_TO_TRAIN not in CROP_CONFIGS:
    print(f"‚ùå Unknown crop: {CROP_TO_TRAIN}")
    print(f"Available: {list(CROP_CONFIGS.keys())}")
else:
    print(f"üöÄ Starting training for: {CROP_TO_TRAIN}")
    print(f"üí° Tip: If you run this again, it will IMPROVE the existing model!")
    model, results = train_crop_model(
        CROP_TO_TRAIN,
        backbone=BACKBONE,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
    )

## 8️⃣ Train All Crops (Full Run)

⚠️ This will take a while! Only run if you want to train models for all crops.


In [None]:
# ========================================
# 🚀 TRAIN ALL CROPS
# ========================================
# Trains/improves a model for every crop in CROP_CONFIGS, one after another.
# A failure for one crop is recorded and does not stop the remaining crops;
# memory is cleared after a failure so the next crop starts clean.

EPOCHS = 40
BATCH_SIZE = 16  # Keep at 16 to prevent OOM crashes

results_summary = {}
total_crops = len(CROP_CONFIGS)

# enumerate() supplies the running position directly instead of searching
# the key list again on every iteration
for crop_number, crop_id in enumerate(CROP_CONFIGS, start=1):
    try:
        print(f"\n\n{'#'*70}")
        print(f"# Training crop {crop_number}/{total_crops}: {crop_id}")
        print(f"{'#'*70}")

        model, results = train_crop_model(crop_id, epochs=EPOCHS, batch_size=BATCH_SIZE)

        # train_crop_model returns (None, None) when the crop has no images
        if results is not None:
            results_summary[crop_id] = {
                'accuracy': float(results[1]),
                'auc': float(results[2]),
                'status': 'success'
            }
        else:
            results_summary[crop_id] = {'status': 'skipped', 'error': 'No images found'}

    except Exception as e:
        # Top-level boundary: record the failure and carry on with the next crop
        print(f"‚ùå Failed to train {crop_id}: {e}")
        results_summary[crop_id] = {'status': 'failed', 'error': str(e)}
        clear_memory()  # Clear memory on failure to recover

# Save summary
with open(MODEL_OUTPUT / 'training_summary.json', 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=2)

print("\n" + "="*70)
print("üìä TRAINING SUMMARY")
print("="*70)
for crop, res in results_summary.items():
    if res['status'] == 'success':
        print(f"  ‚úÖ {crop}: Accuracy={res['accuracy']:.4f}, AUC={res['auc']:.4f}")
    elif res['status'] == 'skipped':
        print(f"  ‚è≠Ô∏è {crop}: SKIPPED - {res.get('error', 'Unknown')}")
    else:
        print(f"  ‚ùå {crop}: FAILED - {res.get('error', 'Unknown error')}")

## 9️⃣ Download Trained Models

After training, your models are saved in Google Drive at:
`/content/drive/MyDrive/FasalVaidya_Models/<crop_id>/`

Each folder contains:

- `<crop>_best.keras` - Best model by validation AUC
- `<crop>_final.keras` - Final model after all epochs
- `metadata.json` - Training info and metrics


In [None]:
# List every saved .keras model under MODEL_OUTPUT, grouped by crop folder.
print("üì¶ Trained models:")
found_any = False
if MODEL_OUTPUT.exists():
    # sorted() gives a stable listing regardless of filesystem order
    for crop_dir in sorted(MODEL_OUTPUT.iterdir()):
        if not crop_dir.is_dir():
            continue
        files = sorted(crop_dir.glob('*.keras'))
        if not files:
            continue
        found_any = True
        print(f"  {crop_dir.name}/")
        for f in files:
            size_mb = f.stat().st_size / (1024*1024)
            print(f"    - {f.name} ({size_mb:.1f} MB)")

# Also covers the case where the output folder exists but holds no models yet
if not found_any:
    print("  No models found yet. Run training first!")