# SmartCrop AI - Complete Training Pipeline


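**Steps:**
1. Mount Google Drive
2. Set up the project directory
3. Install dependencies
4. Verify dataset structure
5. (Optional) Reduce dataset size
6. Train the MobileNetV3 model
7. Export the model
8. Run predictions
9. Predict disease severity (classification + YOLO + SAM)
10. Run YOLOv8 object detection

Later sections cover optional extras: training additional models, training YOLOv8 on PlantDoc, and the growth stage classifier.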


## Step 1: Mount Google Drive


In [None]:
# Mount Google Drive at /content/drive so the project folder under MyDrive
# is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')


## Step 2: Setup Project Directory


In [None]:
import os

# Set your Google Drive folder path
# (it lives under the /content/drive mount point created in Step 1)
PROJECT_DIR = '/content/drive/MyDrive/SmartCrop-AI'
os.chdir(PROJECT_DIR)

# Extract project if needed (uncomment if you uploaded as zip)
# !unzip -q smartcrop-ai-colab.zip -d .

# Navigate to AI directory
# The later cells use paths relative to this directory (data/, outputs/, train.py, ...),
# so this cell must run before them.
os.chdir('smartcrop-ai/ai')
print(f"Current directory: {os.getcwd()}")
!ls -la


## Step 3: Install Dependencies


In [None]:
# Install PyTorch with CUDA support
# (the cu118 index URL serves wheels built against CUDA 11.8)
# NOTE(review): Colab runtimes usually ship with their own torch build - confirm that
# reinstalling from the cu118 index is still wanted for the current runtime.
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install computer vision libraries
!pip install -q opencv-python albumentations ultralytics segment-anything

# Install model export tools
!pip install -q onnx onnxruntime onnxscript tensorflow

# Install data processing libraries
!pip install -q pandas scikit-learn scikit-image

# Install visualization libraries
!pip install -q matplotlib seaborn grad-cam

# Install utilities
!pip install -q pyyaml omegaconf tqdm requests

# NOTE: printed unconditionally - a failing `!pip` line does not stop the cell,
# and `-q` hides most output, so scan for pip errors above.
print("\n‚úì All dependencies installed!")


In [None]:
# Verify installation and GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    !nvidia-smi
else:
    print("‚ö†Ô∏è  No GPU detected. Training will be slow on CPU.")
    print("Go to Runtime ‚Üí Change runtime type ‚Üí GPU")


## Step 4: Verify Dataset Structure


In [None]:
# Verify dataset structure
# Expected layout (as read by the code below): data/raw/<split>/<crop>/<disease>/<images>
import os
from pathlib import Path

# Extensions that count as dataset images; compared case-insensitively.
IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')


def count_images(folder):
    """Return how many image files sit directly inside *folder*.

    Files are matched by extension against IMAGE_SUFFIXES, ignoring case.
    Sub-directories are not searched. A missing folder counts as 0 images
    instead of raising.
    """
    folder = Path(folder)
    if not folder.is_dir():
        return 0
    return sum(
        1
        for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMAGE_SUFFIXES
    )


data_dir = Path('data/raw')
print("Checking dataset structure...")
print(f"Train folder exists: {(data_dir / 'train').exists()}")
print(f"Val folder exists: {(data_dir / 'val').exists()}")
print(f"Test folder exists: {(data_dir / 'test').exists()}")

# Collect everything that would stop training, so the final message is honest.
problems = [
    f"missing folder: {data_dir / split}"
    for split in ('train', 'val', 'test')
    if not (data_dir / split).is_dir()
]

# Count samples
train_crops = []
if (data_dir / 'train').is_dir():
    # Sorted so the "sample crops" preview is stable between runs.
    train_crops = sorted(d.name for d in (data_dir / 'train').iterdir() if d.is_dir())
    print(f"\n‚úì Found {len(train_crops)} crops in training set")
    print(f"Sample crops: {train_crops[:5]}")

    if not train_crops:
        problems.append(f"no crop folders inside {data_dir / 'train'}")
    else:
        # Count images in the first disease folder of the first crop as a spot check.
        first_crop = data_dir / 'train' / train_crops[0]
        diseases = sorted(d.name for d in first_crop.iterdir() if d.is_dir())
        if diseases:
            sample_count = count_images(first_crop / diseases[0])
            print(f"Sample: {train_crops[0]}/{diseases[0]} has {sample_count} images")
            if sample_count == 0:
                problems.append(f"no images found in {first_crop / diseases[0]}")

if problems:
    print("\nDataset is NOT ready for training:")
    for problem in problems:
        print(f"  - {problem}")
else:
    print("\n‚úì Dataset is ready for training!")


## Step 5: (Optional) Reduce Dataset Size


In [None]:
# (Optional) Reduce dataset for faster training
# Skip this cell if you want to use the full dataset
# This keeps small classes intact and reduces large classes
# (behaviour described for scripts/reduce_dataset.py; the script itself is not shown here)

# Uncomment the line below to run reduction:
# !python scripts/reduce_dataset.py

# When prompted, type 'y' to proceed
# As written, this cell only prints the reminder below and changes nothing on disk.
print("Skipping dataset reduction. Uncomment the line above to reduce dataset size.")


In [None]:
# Placeholder cell: intentionally a no-op.
# This cell is not needed since dataset is already organized
# Dataset structure is already in data/raw/train/, data/raw/val/, data/raw/test/
pass


In [None]:
# Placeholder cell: a no-op unless the command below is uncommented.
# Dataset is already organized - no need to run this
# If you need to reorganize, uncomment below:
# !python scripts/organize_datasets.py
pass


In [None]:
# Placeholder cell: intentionally a no-op.
# This cell is not needed - reduction is in Step 5 above
pass


## Step 6: Train MobileNetV3 Model


In [None]:
# Train MobileNetV3 (on-device model)
# This will take 30-60 minutes depending on dataset size
# Arguments passed to train.py: dataset root data/raw, 10 epochs, batch size 32,
# learning rate 0.001.

!python train.py --model mobilenet_v3 --data-dir data/raw --epochs 10 --batch-size 32 --lr 0.001

# NOTE: this message is printed unconditionally - a failing `!python` command does
# not stop the cell, so check the training output above before trusting it.
print("\n‚úì Training completed!")


In [None]:
# Check training results
# List the saved checkpoints with sizes, then show the last 50 lines of the training log.
!ls -lh outputs/models/checkpoints/
!tail -50 outputs/logs/training.log


In [None]:
# Confirm the trained checkpoint is on disk before the export step;
# if it is missing, list whatever .pth files are available instead.
import os
from pathlib import Path

checkpoint_dir = Path('outputs/models/checkpoints')
model_path = Path('outputs/models/checkpoints/best_model.pth')

if not model_path.exists():
    print(f"‚ö†Ô∏è  Model not found at: {model_path}")
    print("Available files in checkpoints:")
    if not checkpoint_dir.exists():
        print("  Checkpoints directory doesn't exist!")
    else:
        for candidate in checkpoint_dir.glob('*.pth'):
            print(f"  - {candidate.name}")
else:
    # Size is reported in MiB (bytes / 1024^2).
    size_mb = model_path.stat().st_size / (1024 * 1024)
    print(f"‚úì Found model: {model_path}")
    print(f"  Size: {size_mb:.2f} MB")


## Step 7: Export Model


In [None]:
# Export MobileNetV3 to mobile formats
# Note: Model is saved as 'best_model.pth' (not mobilenet_v3_best.pth)
!python export_model.py --model mobilenet_v3 --checkpoint outputs/models/checkpoints/best_model.pth

# Verify exports
# `2>/dev/null || echo ...` hides the ls error and prints a short notice instead
# when no file with that extension was produced.
!ls -lh outputs/models/*.tflite 2>/dev/null || echo "No TFLite files"
!ls -lh outputs/models/*.onnx 2>/dev/null || echo "No ONNX files"


## Step 8: Run Predictions


In [None]:
# Upload a test image
# `image_file` is left in the notebook namespace so later prediction cells can reuse it.
from google.colab import files
uploaded = files.upload()

# Get uploaded filename
import os
if not uploaded:
    # Cancelling the upload dialog yields an empty dict; fail with a clear
    # message instead of an IndexError on the lookup below.
    raise RuntimeError("No file was uploaded. Re-run this cell and choose an image.")

# If several files were uploaded, only the first one is used.
image_file = list(uploaded.keys())[0]
print(f"Testing on: {image_file}")


In [None]:
# Run prediction with heatmap
# Note: Model is saved as 'best_model.pth' (not mobilenet_v3_best.pth)
!python predict.py --image {image_file} --model outputs/models/checkpoints/best_model.pth --model-type mobilenet_v3 --heatmap --output outputs/result.jpg

# Display result
from IPython.display import Image, display
display(Image('outputs/result.jpg'))


## Step 9: Predict with Severity (Classification + YOLO + SAM)

**This combines all models to calculate disease severity:**
- Classification model → Disease type
- YOLOv8 → Lesion count
- SAM → Affected area percentage
- Combined → Severity level (Low/Moderate/High/Critical)

**Note:** Requires SAM checkpoint. See `SAM_SETUP_GUIDE.md` for download instructions.


In [None]:
# Step 1: Download SAM checkpoint (if not already downloaded)
import os
from pathlib import Path

sam_dir = Path('outputs/models/sam')
sam_dir.mkdir(parents=True, exist_ok=True)

sam_checkpoint = sam_dir / 'sam_vit_b.pth'

if not sam_checkpoint.exists():
    print("üì• Downloading SAM checkpoint (375MB)...")
    print("   This may take a few minutes...")
    checkpoint_path = str(sam_checkpoint)
    !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -O {checkpoint_path}
    print(f"‚úì SAM checkpoint downloaded to {sam_checkpoint}")
else:
    print(f"‚úì SAM checkpoint already exists at {sam_checkpoint}")

# Verify checkpoint size
if sam_checkpoint.exists():
    size_mb = sam_checkpoint.stat().st_size / (1024 * 1024)
    print(f"  Checkpoint size: {size_mb:.1f} MB")


In [None]:
# Step 2: Complete Severity Pipeline (Classification + Trained YOLOv8 + SAM)
# Uses the same image from Step 8, or upload a new one

# Use image from previous prediction cell, or upload new one
try:
    image_file  # Use image from Step 8
except NameError:
    print("‚ö†Ô∏è  No image found. Uploading new image...")
    from google.colab import files
    uploaded = files.upload()
    image_file = list(uploaded.keys())[0]

# Run the complete severity pipeline
# Using quotes around image_file to handle spaces/special characters
# Model type is auto-detected (defaults to efficientnet_b3)
# To specify model type: add --model-type mobilenet_v3 or --model-type efficientnet_b3
!python predict_severity_complete.py --image "{image_file}" --output outputs/severity_result.jpg

# Display result
from IPython.display import Image, display
from pathlib import Path

result_path = Path('outputs/severity_result.jpg')
if result_path.exists():
    display(Image('outputs/severity_result.jpg'))
    print("\n‚úÖ Severity prediction completed!")
    print("\nThe output above shows:")
    print("  - Disease type and confidence")
    print("  - Severity level (Low/Moderate/High/Critical)")
    print("  - Affected area percentage")
    print("  - Lesion count and density")
else:
    print("‚ö†Ô∏è  Result image not found. Check the output above for errors.")


## Step 10: Use YOLOv8 for Object Detection




In [None]:
# Use pretrained YOLOv8 for object detection
# Runs the project's YOLODetector wrapper on one image, prints the detections,
# draws the boxes and saves/displays the annotated image.
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))  # make the project's `src` package importable

from src.models.yolo_detector import YOLODetector
import cv2
import numpy as np
from IPython.display import Image, display
import matplotlib.pyplot as plt

# Load the same image you used for prediction (or upload a new one)
# If you haven't uploaded an image yet, uncomment below:
# from google.colab import files
# uploaded = files.upload()
# image_file = list(uploaded.keys())[0]

# Or use the image from previous prediction
try:
    image_file  # Use image from previous cell
except NameError:
    print("‚ö†Ô∏è  No image found. Please run the prediction cell first or upload an image.")
    print("Uploading new image...")
    from google.colab import files
    uploaded = files.upload()
    if not uploaded:
        # A cancelled upload returns an empty dict.
        raise RuntimeError("No file was uploaded. Re-run this cell and choose an image.")
    image_file = list(uploaded.keys())[0]

print(f"üì∏ Using image: {image_file}")

# Load image
image = cv2.imread(image_file)
if image is None:
    # cv2.imread signals "missing or unreadable file" by returning None rather
    # than raising; without this check cvtColor fails with an obscure error.
    raise FileNotFoundError(
        f"Could not read image '{image_file}'. "
        "Check that the file exists in the current directory and is a valid image."
    )
# NOTE(review): the detector is fed an RGB array here - confirm that
# YOLODetector.detect expects RGB rather than OpenCV's BGR ordering.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Initialize YOLOv8 detector (pretrained - auto-downloads on first use)
print("ü§ñ Loading pretrained YOLOv8-nano...")
detector = YOLODetector(model_size="n", pretrained=True)

# Run detection
print("üîç Running detection...")
results = detector.detect(image_rgb, conf_threshold=0.25)

# Print results
print(f"\n‚úÖ Detection Results:")
print(f"   Found {results['count']} objects")

if results['count'] > 0:
    print(f"\n   Detections:")
    for i, (box, score, cls) in enumerate(zip(results['boxes'], results['scores'], results['classes'])):
        x1, y1, x2, y2 = map(int, box)
        print(f"   {i+1}. Object class {cls}: confidence {score:.2f}, box [{x1}, {y1}, {x2}, {y2}]")
else:
    print("   No objects detected (try lowering confidence threshold)")

# Draw bounding boxes on the original BGR image (the format cv2.imwrite expects)
img_with_boxes = image.copy()
for box, score, cls in zip(results['boxes'], results['scores'], results['classes']):
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
    label = f"Class {cls} ({score:.2f})"
    # Keep the label inside the image when the box touches the top edge.
    label_y = max(y1 - 10, 15)
    cv2.putText(img_with_boxes, label, (x1, label_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

# Save and display
output_path = 'outputs/yolo_detection.jpg'
# cv2.imwrite does not create missing directories and reports failure only via
# its return value, so create the folder and check the result explicitly.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
if cv2.imwrite(output_path, img_with_boxes):
    display(Image(output_path))
else:
    print(f"Could not write annotated image to {output_path}")

print("\nüìù Note: YOLOv8 detects COCO classes (person, car, dog, etc.)")
print("   For plant disease lesion detection, you'd need to train on annotated plant data.")
print("   See YOLO_TRAINING_GUIDE.md for details.")


## Step 12: Save Models to Google Drive


## Step 13: (Optional) Train Additional Models

**Note:** MobileNetV3 is sufficient for most use cases! Only train these if you need specific features.

- **EfficientNet-B3**: For higher accuracy (server/cloud use)
- **YOLOv8**: For lesion detection (requires annotated data)
- **SAM**: No training needed (already pretrained)

See `MODELS_GUIDE.md` for details on when to use each model.


## Step 14: Train YOLOv8 on PlantDoc Dataset

**Train YOLOv8 to detect plant disease lesions:**
- Uses PlantDoc dataset (already in `data/yolo/`)
- Fine-tunes pretrained YOLOv8 on plant disease lesions
- After training, YOLOv8 will detect actual plant disease lesions (not just general objects)


### Manual Cleanup (if fix cell fails)

If you get "File exists" errors, run this cleanup cell first:


In [None]:
# Manual cleanup: Remove any existing 'valid' symlink/directory under data/yolo
import os
import shutil
from pathlib import Path


def remove_path(path):
    """Delete *path* whether it is a symlink, a directory tree or a plain file.

    Symlinks are checked first: Path.exists() follows the link and reports
    False for a dangling one, and shutil.rmtree refuses to operate on links.

    Returns True if something was removed, False if nothing was there.
    Raises OSError if the removal fails.
    """
    path = Path(path)
    if path.is_symlink():
        path.unlink()
        return True
    if path.is_dir():
        shutil.rmtree(path)
        return True
    if path.exists():
        path.unlink()
        return True
    return False


yolo_dir = Path('data/yolo')
valid_dir = yolo_dir / 'valid'
test_dir = yolo_dir / 'test'

print("Cleaning up any existing 'valid' directory/symlink...")

cleanup_ok = True
if valid_dir.is_symlink() or valid_dir.exists():
    if valid_dir.is_symlink():
        print(f"  Removing symlink: {valid_dir}")
    else:
        print(f"  Removing directory: {valid_dir}")
    try:
        remove_path(valid_dir)
        print(f"  ‚úì Cleaned up")
    except OSError as e:
        print(f"  ‚ö†Ô∏è  Error: {e}")
        # Best-effort second attempt that ignores per-file errors.
        if not valid_dir.is_symlink():
            shutil.rmtree(valid_dir, ignore_errors=True)
        if valid_dir.is_symlink() or valid_dir.exists():
            cleanup_ok = False
            print(f"  ‚úó Could not remove - you may need to delete manually")
        else:
            print(f"  ‚úì Force removed")
else:
    print(f"  ‚úì No 'valid' directory found - nothing to clean")

# Also check if test is a symlink pointing to valid (wrong direction).
# is_symlink() alone is used so a dangling 'test' link is reported too.
if test_dir.is_symlink():
    try:
        target = test_dir.readlink()
    except OSError as e:
        print(f"\n  ‚ö†Ô∏è  Error: {e}")
    else:
        if 'valid' in str(target):
            print(f"\n  ‚ö†Ô∏è  Warning: 'test' is a symlink pointing to 'valid' (wrong direction)")
            print(f"     This might cause issues. Consider removing it manually if needed.")

if cleanup_ok:
    print(f"\n‚úÖ Cleanup complete. Now run the fix cell.")
else:
    print(f"\nCleanup incomplete. Delete '{valid_dir}' manually before running the fix cell.")


### Quick Fix: Create 'valid' directory (if training failed)

If you got an error about missing 'valid/images', run this cell to fix it:


In [None]:
# Quick fix: Create 'valid' directory and fix data.yaml paths
#
# What this cell does, in order:
#   1. Reports which split folders exist under data/yolo.
#   2. Repairs a 'test' symlink that points at 'valid' (wrong direction).
#   3. Ensures data/yolo/valid/images exists by copying 'val' (or 'test').
#   4. Rewrites data.yaml so `path` is absolute and train/val/test are relative to it.
import os
import shutil
import stat
import time
import yaml
from pathlib import Path

yolo_dir = Path('data/yolo')
val_dir = yolo_dir / 'val'
valid_dir = yolo_dir / 'valid'
test_dir = yolo_dir / 'test'
yaml_file = yolo_dir / 'data.yaml'


def path_present(path):
    """Return True if *path* exists in any form, including a dangling symlink.

    Path.exists() follows symlinks and returns False for a broken link, which
    would leave the link in place and make a later copytree fail.
    """
    return path.is_symlink() or path.exists()


def force_remove(path):
    """Remove *path* (symlink, directory or file), retrying once more forcefully.

    Raises OSError if the path is still present after all attempts.
    """
    if path.is_symlink():
        kind = 'symlink'
    elif path.is_dir():
        kind = 'directory'
    else:
        kind = 'file'

    try:
        if kind == 'directory':
            shutil.rmtree(path)
        else:
            path.unlink()
        print(f"   ‚úì Removed {kind}")
        return
    except OSError as e:
        print(f"   ‚ö†Ô∏è  Error removing: {e}, trying force remove...")

    if kind == 'directory':
        # Restore owner read/write/execute so the tree can be traversed and
        # deleted (write-only directories cannot be listed).
        try:
            for root, dirs, names in os.walk(path):
                for entry in dirs + names:
                    full = os.path.join(root, entry)
                    if not os.path.islink(full):
                        os.chmod(full, stat.S_IRWXU)
        except OSError as e:
            print(f"   ‚ö†Ô∏è  Could not reset permissions: {e}")
        shutil.rmtree(path, ignore_errors=True)
    else:
        try:
            os.remove(path)
        except OSError as e:
            print(f"   ‚ö†Ô∏è  Error removing: {e}")

    if path_present(path):
        print(f"\n‚ö†Ô∏è  Warning: 'valid' still exists after removal attempt")
        print(f"   Trying one more time...")
        time.sleep(0.5)  # Brief pause before the final attempt (kept from the original cell)
        if path.is_symlink() or not path.is_dir():
            try:
                os.remove(path)
            except OSError as e:
                print(f"   ‚úó Failed: {e}")
        else:
            shutil.rmtree(path, ignore_errors=True)

    if path_present(path):
        print(f"   Please run this command manually:")
        print(f"   !rm -rf '{path}'")
        raise OSError(f"Could not remove {path}")
    print(f"   ‚úì Force removed")


print("Checking dataset structure...")
print(f"  train exists: {(yolo_dir / 'train').exists()}")
print(f"  test exists: {test_dir.exists()}")
print(f"  val exists: {val_dir.exists()}")
print(f"  valid exists: {valid_dir.exists()}")

# Check if test is a symlink (might be pointing to valid - wrong direction!)
test_is_symlink = test_dir.is_symlink()
if test_is_symlink:
    try:
        test_target = test_dir.readlink()
        print(f"  ‚ö†Ô∏è  'test' is a symlink pointing to: {test_target}")
        if 'valid' in str(test_target):
            print(f"     This is backwards! Fixing...")
            # A relative link target is relative to the directory holding the
            # link, not to the current working directory.
            target_dir = Path(test_target)
            if not target_dir.is_absolute():
                target_dir = test_dir.parent / target_dir
            test_backup = yolo_dir / 'test_backup'
            if not path_present(test_backup) and (target_dir / 'images').exists():
                # Copy the contents before removing the symlink, then move the
                # copy into place so 'test' becomes a real directory.
                shutil.copytree(target_dir, test_backup)
                test_dir.unlink()
                shutil.move(str(test_backup), str(test_dir))
                print(f"     ‚úì Fixed: restored 'test' as real directory")
            else:
                print(f"     Could not fix automatically; inspect '{test_dir}' manually.")
    except OSError as e:
        print(f"     Could not repair 'test' symlink: {e}")

# Evaluated after the 'test' repair above so the state is current.
valid_is_symlink = valid_dir.is_symlink()
valid_has_images = (valid_dir / 'images').exists()

# Fix valid directory - use COPY instead of symlink (more reliable on Google Drive)
if valid_has_images:
    print(f"\n‚úì 'valid' directory already exists and has images")
    if valid_is_symlink:
        try:
            link_target = valid_dir.readlink()
            print(f"   (It's a symlink to: {link_target})")
        except OSError as e:
            print(f"   (It's a symlink, but its target could not be read: {e})")
else:
    # ALWAYS remove existing valid directory/symlink (even if empty or broken)
    if path_present(valid_dir):
        print(f"\nüîß Removing existing 'valid' directory/symlink...")
        force_remove(valid_dir)

    # Create valid from val or test - USE COPY (more reliable than symlink on Drive)
    if val_dir.exists() and (val_dir / 'images').exists():
        print(f"\nüîß Creating 'valid' directory from 'val' (copying)...")
        shutil.copytree(val_dir, valid_dir)
        print(f"   ‚úì Copied {val_dir} to {valid_dir}")
    elif test_dir.exists() and (test_dir / 'images').exists():
        print(f"\nüîß No 'val' found. Using 'test' as 'valid' (copying)...")
        print(f"   Note: This will copy test set to valid (YOLO needs validation set)")
        shutil.copytree(test_dir, valid_dir)
        print(f"   ‚úì Copied {test_dir} to {valid_dir}")
    else:
        print(f"\n‚ö†Ô∏è  Could not create 'valid' directory")
        print(f"   Neither 'val' nor 'test' has images/ subdirectory")

# Fix data.yaml paths to be correct (YOLOv8 format: path + relative paths)
if yaml_file.exists():
    print(f"\nüìù Fixing data.yaml paths...")
    with open(yaml_file, 'r') as f:
        # An empty YAML file loads as None; treat it as an empty mapping.
        yaml_data = yaml.safe_load(f) or {}

    # Get absolute path to yolo directory
    yolo_abs = yolo_dir.absolute()

    print(f"   Current paths in data.yaml:")
    print(f"     path: {yaml_data.get('path', 'NOT SET')}")
    print(f"     train: {yaml_data.get('train', 'NOT SET')}")
    print(f"     val: {yaml_data.get('val', 'NOT SET')}")

    # YOLOv8 format: path is base directory, train/val/test are relative to path
    # Format: path: /absolute/path/to/dataset
    #         train: train/images  (relative to path)
    #         val: valid/images    (relative to path)
    yaml_data['path'] = str(yolo_abs)  # Absolute base path
    yaml_data['train'] = 'train/images'  # Relative to path
    yaml_data['val'] = 'valid/images'    # Relative to path (not val!)
    if test_dir.exists():
        yaml_data['test'] = 'test/images'  # Relative to path

    # Write updated yaml (sort_keys=False keeps the file's key order)
    with open(yaml_file, 'w') as f:
        yaml.dump(yaml_data, f, default_flow_style=False, sort_keys=False)

    print(f"\n   ‚úì Updated data.yaml (YOLOv8 format):")
    print(f"     path: {yaml_data['path']}")
    print(f"     train: {yaml_data['train']} (relative to path)")
    print(f"     val: {yaml_data['val']} (relative to path)")
    if 'test' in yaml_data:
        print(f"     test: {yaml_data['test']} (relative to path)")

    # Verify paths exist (combine path + relative path, resolve symlinks)
    print(f"\n   Verifying paths exist:")
    base_path = Path(yaml_data['path'])
    for key in ['train', 'val', 'test']:
        if key in yaml_data:
            # Combine base path with relative path
            full_path = base_path / yaml_data[key]
            # Resolve symlinks
            resolved_path = full_path.resolve()
            if resolved_path.exists():
                # Extension match is case-insensitive so .JPG/.PNG are counted too.
                count = sum(
                    1 for p in resolved_path.iterdir()
                    if p.suffix.lower() in ('.jpg', '.png')
                )
                print(f"     ‚úì {key}: {count} images")
                if full_path != resolved_path:
                    print(f"       (resolved from: {full_path} -> {resolved_path})")
            else:
                print(f"     ‚úó {key}: Path does not exist!")
                print(f"       Looking for: {full_path}")
                print(f"       Resolved to: {resolved_path}")

# Only claim readiness when the validation images are really in place.
if (valid_dir / 'images').exists():
    print(f"\n‚úÖ Ready for training! Run the training cell now.")
else:
    print(f"\nNot ready for training: '{valid_dir / 'images'}' is still missing (see messages above).")


In [None]:
# Step 1: Verify and fix dataset structure
#
# Locates data.yaml, detects which split folders actually contain images,
# mirrors 'val' as 'valid' when only 'val' exists, and rewrites data.yaml
# paths that are absolute or do not resolve.
import os
import shutil
import yaml
from pathlib import Path

yolo_dir = Path('data/yolo')
print("Checking YOLO dataset structure...")

# Check for common structures
possible_paths = [
    'data/yolo/data.yaml',
    'data/yolo/train/data.yaml',
    'data/yolo/plantdoc/data.yaml',
    'data/yolo/PlantDoc-1/data.yaml'
]

data_yaml = None
for path in possible_paths:
    if Path(path).exists():
        data_yaml = path
        print(f"‚úì Found data.yaml at: {path}")
        break

if not data_yaml:
    print("‚ö†Ô∏è  data.yaml not found. Checking directory structure...")
    print(f"Contents of data/yolo/:")
    if yolo_dir.exists():
        for item in sorted(yolo_dir.iterdir()):
            print(f"  - {item.name}")
    else:
        print("  data/yolo/ directory doesn't exist!")
else:
    # Check for images and labels
    yaml_dir = Path(data_yaml).parent
    print(f"\nDataset directory: {yaml_dir}")

    # Check actual structure: remember, per split, the first layout that holds images.
    actual_splits = {}
    for split_name in ['train', 'val', 'valid', 'test']:
        # Check different possible structures
        possible_dirs = [
            yaml_dir / split_name / 'images',
            yaml_dir / 'images' / split_name,
            yaml_dir / split_name,
        ]

        for dir_path in possible_dirs:
            if dir_path.exists():
                img_count = len(list(dir_path.glob('*.jpg'))) + len(list(dir_path.glob('*.png')))
                if img_count > 0:
                    actual_splits[split_name] = {
                        'images': dir_path,
                        'count': img_count
                    }
                    print(f"  ‚úì Found {split_name}: {img_count} images at {dir_path}")
                    break

    # Read and fix data.yaml if needed
    print(f"\nüìù Reading data.yaml...")
    with open(data_yaml, 'r') as f:
        # An empty YAML file loads as None; treat it as an empty mapping.
        yaml_data = yaml.safe_load(f) or {}

    print(f"  Current paths in data.yaml:")
    print(f"    train: {yaml_data.get('train', 'NOT SET')}")
    print(f"    val: {yaml_data.get('val', 'NOT SET')}")
    print(f"    test: {yaml_data.get('test', 'NOT SET')}")

    needs_fix = False
    yaml_path = yaml_dir

    # Dataset has 'val' but no 'valid': mirror it as 'valid', the folder name
    # the other YOLO cells in this notebook work with.
    if 'val' in actual_splits and 'valid' not in actual_splits:
        val_path = actual_splits['val']['images'].parent
        valid_path = yaml_path / 'valid'

        # Only the <root>/val/images layout has a 'val' folder worth mirroring;
        # for the other layouts .parent is the images root or the dataset root.
        if val_path.name == 'val' and val_path.exists():
            if valid_path.is_symlink() and not valid_path.exists():
                # Dangling link left by an earlier run: replace it.
                valid_path.unlink()
            if not valid_path.exists():
                print(f"\n  üîß Creating 'valid' symlink (YOLO expects 'valid' not 'val')...")
                # The target must be absolute: a relative target is resolved
                # against the link's own directory and would dangle.
                link_target = val_path.resolve()
                try:
                    os.symlink(link_target, valid_path)
                    print(f"    Created: {valid_path} -> {link_target}")
                except OSError as e:
                    # Some filesystems reject symlinks; fall back to a real copy.
                    print(f"    Could not create symlink ({e}); copying instead...")
                    shutil.copytree(val_path, valid_path)
                    print(f"    Copied {val_path} to {valid_path}")
                needs_fix = True

    # `path` must be absolute: Ultralytics resolves a relative `path` against
    # its own datasets directory, not against this notebook's working directory.
    yaml_abs = str(yaml_path.resolve())
    if yaml_data.get('path') != yaml_abs:
        yaml_data['path'] = yaml_abs
        needs_fix = True

    # For each yaml key, the detected split folders that may satisfy it, in
    # order of preference ('val' and 'valid' are interchangeable names).
    split_aliases = {
        'train': ('train',),
        'val': ('val', 'valid'),
        'valid': ('valid', 'val'),
        'test': ('test',),
    }

    # Update paths to be relative to `path`
    for key in ['train', 'val', 'valid', 'test']:
        path_val = yaml_data.get(key)
        if not isinstance(path_val, str):
            # Key missing, or a list of paths: leave untouched.
            continue
        # If it's an absolute path or doesn't exist, fix it
        if Path(path_val).is_absolute() or not (yaml_path / path_val).exists():
            for split_name in split_aliases[key]:
                if split_name in actual_splits:
                    rel_path = actual_splits[split_name]['images'].relative_to(yaml_path)
                    yaml_data[key] = str(rel_path)
                    needs_fix = True
                    break

    if needs_fix:
        print(f"\n  üíæ Updating data.yaml...")
        with open(data_yaml, 'w') as f:
            yaml.dump(yaml_data, f, default_flow_style=False, sort_keys=False)
        print(f"    ‚úì Updated data.yaml")

    print(f"\n  Final data.yaml paths:")
    for key in ['train', 'val', 'valid', 'test']:
        if key in yaml_data:
            print(f"    {key}: {yaml_data[key]}")

if data_yaml:
    print(f"\n‚úÖ Using data.yaml: {data_yaml}")
else:
    print("\nNo data.yaml found - place the PlantDoc dataset under data/yolo/ before training.")


In [None]:
# Step 2: Train YOLOv8 on PlantDoc dataset
import sys
import os
import shutil
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))  # make the project's `src` package importable

from src.models.yolo_detector import YOLODetector

# Find data.yaml by probing the same candidate locations as the verification cell.
data_yaml = None
for path in [
    'data/yolo/data.yaml',
    'data/yolo/train/data.yaml',
    'data/yolo/plantdoc/data.yaml',
    'data/yolo/PlantDoc-1/data.yaml',
]:
    if Path(path).exists():
        data_yaml = path
        break

if not data_yaml:
    print("‚ùå Error: data.yaml not found!")
    print("Please check that PlantDoc dataset is in data/yolo/")
else:
    # Ensure 'valid' directory exists next to data.yaml (mirrors 'val')
    yaml_dir = Path(data_yaml).parent
    val_dir = yaml_dir / 'val'
    valid_dir = yaml_dir / 'valid'

    if val_dir.exists() and not valid_dir.exists():
        if valid_dir.is_symlink():
            # exists() is False for a dangling link; remove it, otherwise both
            # the symlink and the copy below fail with "File exists".
            valid_dir.unlink()
        print(f"üîß Creating 'valid' symlink (YOLO expects 'valid' not 'val')...")
        # The target must be absolute: a relative target is resolved against
        # the link's own directory and would produce a dangling link.
        link_target = val_dir.resolve()
        try:
            os.symlink(link_target, valid_dir)
            print(f"   ‚úì Created: {valid_dir} -> {link_target}")
        except OSError as e:
            print(f"   ‚ö†Ô∏è  Could not create symlink: {e}")
            print(f"   Trying to copy instead...")
            shutil.copytree(val_dir, valid_dir)
            print(f"   ‚úì Copied {val_dir} to {valid_dir}")

    print(f"\nüìä Training YOLOv8 on: {data_yaml}")
    print("   This will take 2-4 hours depending on dataset size...")
    print("   Training uses pretrained YOLOv8 weights (transfer learning)")

    # Initialize YOLOv8 detector (pretrained)
    detector = YOLODetector(model_size="n", pretrained=True)

    # Train on PlantDoc dataset
    # Adjust epochs/batch based on your dataset size
    detector.train(
        data_yaml=data_yaml,
        epochs=50,      # 50 epochs is usually enough with pretrained weights
        imgsz=416,      # Image size: 416=faster training, 640=better accuracy (slower)
                        # For plant lesions, 416 is usually sufficient and trains 2-3x faster
        batch=16        # Adjust based on GPU memory (16 for T4, 32 for A100)
                        # With imgsz=416, you might be able to use batch=32 for faster training
    )

    print("\n‚úì YOLOv8 training completed!")
    print("  Model saved to: runs/detect/train/weights/best.pt")


In [None]:
# Step 3: Test trained YOLOv8 model
# Loads the newest fine-tuned weights (falling back to the pretrained model),
# runs detection on one image and saves/displays the annotated result.
import sys
import cv2
import numpy as np
from IPython.display import Image, display
from pathlib import Path

# Add project to path
sys.path.insert(0, str(Path.cwd()))

# Import YOLODetector
from src.models.yolo_detector import YOLODetector


def find_latest_weights(runs_root):
    """Return the most recently written train*/weights/best.pt under *runs_root*.

    Runs are ranked by the modification time of their best.pt rather than by
    directory name: names sort lexicographically, so 'train9' would rank after
    'train10'. Returns None when no run has a best.pt.
    """
    runs_root = Path(runs_root)
    if not runs_root.is_dir():
        return None
    weight_files = [
        run / 'weights' / 'best.pt'
        for run in runs_root.iterdir()
        if run.is_dir() and run.name.startswith('train')
    ]
    weight_files = [w for w in weight_files if w.exists()]
    if not weight_files:
        return None
    return max(weight_files, key=lambda w: w.stat().st_mtime)


# Use the same image from previous predictions, or upload new one
try:
    image_file  # Use image from Step 8
except NameError:
    print("‚ö†Ô∏è  No image found. Uploading new image...")
    from google.colab import files
    uploaded = files.upload()
    if not uploaded:
        # A cancelled upload returns an empty dict.
        raise RuntimeError("No file was uploaded. Re-run this cell and choose an image.")
    image_file = list(uploaded.keys())[0]

print(f"üì∏ Testing on: {image_file}")

# Load image
image = cv2.imread(image_file)
if image is None:
    # cv2.imread returns None instead of raising when the file is missing or unreadable.
    raise FileNotFoundError(
        f"Could not read image '{image_file}'. "
        "Check that the file exists in the current directory and is a valid image."
    )
# NOTE(review): the detector is fed an RGB array here - confirm that
# YOLODetector.detect expects RGB rather than OpenCV's BGR ordering.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Load trained YOLOv8 model
# YOLOv8 saves models in runs/detect/train*/weights/best.pt
runs_dir = Path('runs/detect')
trained_model_path = find_latest_weights(runs_dir)

if trained_model_path is not None:
    print(f"ü§ñ Loading trained YOLOv8 from: {trained_model_path}")
    detector = YOLODetector(model_size="n", weights_path=str(trained_model_path))
else:
    print("‚ö†Ô∏è  Trained model not found. Using pretrained model instead.")
    print(f"   Searched in: {runs_dir}")
    detector = YOLODetector(model_size="n", pretrained=True)

# Run detection
print("üîç Running detection...")
results = detector.detect(image_rgb, conf_threshold=0.25)

# Print results
print(f"\n‚úÖ Detection Results:")
print(f"   Found {results['count']} lesions")

if results['count'] > 0:
    print(f"\n   Detections:")
    for i, (box, score, cls) in enumerate(zip(results['boxes'], results['scores'], results['classes'])):
        x1, y1, x2, y2 = map(int, box)
        print(f"   {i+1}. Lesion class {cls}: confidence {score:.2f}, box [{x1}, {y1}, {x2}, {y2}]")
else:
    print("   No lesions detected")

# Draw bounding boxes on the original BGR image (the format cv2.imwrite expects)
img_with_boxes = image.copy()
for box, score, cls in zip(results['boxes'], results['scores'], results['classes']):
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), (0, 255, 0), 2)
    label = f"Lesion {cls} ({score:.2f})"
    # Keep the label inside the image when the box touches the top edge.
    label_y = max(y1 - 10, 15)
    cv2.putText(img_with_boxes, label, (x1, label_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

# Save and display
output_path = 'outputs/yolo_trained_detection.jpg'
# cv2.imwrite does not create missing directories and reports failure only via
# its return value, so create the folder and check the result explicitly.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
if cv2.imwrite(output_path, img_with_boxes):
    display(Image(output_path))
else:
    print(f"Could not write annotated image to {output_path}")

print("\nüìù Note: If you see 0 lesions, the model may need more training or different images.")
print("   Try lowering confidence threshold or training for more epochs.")


In [None]:
# (Optional) Train EfficientNet-B3 for higher accuracy
# Only train if MobileNetV3 accuracy < 90% or you need server verification
# This takes 6-8 hours

# Uncomment to train:
# !python train.py --model efficientnet_b3 --data-dir data/raw --epochs 5 --batch-size 32 --image-size 224 --lr 0.001

# As written, this cell only prints the two reminders below; no training is started.
print("Skipping EfficientNet-B3 training. MobileNetV3 is sufficient for most use cases.")
print("Uncomment the line above if you need higher accuracy.")


## Step 15: Growth Stage Classifier (Optional)

**Classify crop growth stage (Seedling, Vegetative, Flowering, Maturity):**
- Uses ResNet50 pretrained on ImageNet
- Only trains the classification head (fast training)
- Useful for growth recommendations and stage-specific advice

**Note:** Requires growth stage labeled data organized as:
```
data/growth_stage/
  train/
    Seedling/
    Vegetative/
    Flowering/
    Maturity/
  val/
    Seedling/
    Vegetative/
    Flowering/
    Maturity/
```

**📥 Where to Get Growth Stage Datasets:**
- See `GROWTH_STAGE_DATASET_GUIDE.md` for comprehensive dataset sources
- **Quick options:**
  - Create your own: Take photos of crops at different stages
  - Use PlantNet/iNaturalist: Filter by growth stage metadata
  - Search Kaggle: "crop growth stage" or "phenology dataset"
  - Research datasets: Agricultural research institutions


In [None]:
# Step 1: Verify Growth Stage Dataset Structure
# Expected layout: data/growth_stage/{train,val}/<stage>/<images>
from pathlib import Path

# Extensions that count as images; compared case-insensitively.
STAGE_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')


def count_stage_images(stage_dir):
    """Return how many image files sit directly inside *stage_dir*.

    Files are matched by extension against STAGE_IMAGE_SUFFIXES, ignoring
    case. A missing folder counts as 0 images instead of raising.
    """
    stage_dir = Path(stage_dir)
    if not stage_dir.is_dir():
        return 0
    return sum(
        1
        for entry in stage_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in STAGE_IMAGE_SUFFIXES
    )


growth_data_dir = Path('data/growth_stage')
print("Checking growth stage dataset structure...")

if not growth_data_dir.exists():
    print(f"‚ö†Ô∏è  Growth stage data directory not found: {growth_data_dir}")
    print("\nTo use growth stage classifier, organize your data as:")
    print("  data/growth_stage/")
    for split in ('train', 'val'):
        print(f"    {split}/")
        for stage_name in ('Seedling', 'Vegetative', 'Flowering', 'Maturity'):
            print(f"      {stage_name}/")
else:
    train_dir = growth_data_dir / 'train'
    val_dir = growth_data_dir / 'val'

    print(f"‚úì Found data directory: {growth_data_dir}")
    print(f"  Train exists: {train_dir.exists()}")
    print(f"  Val exists: {val_dir.exists()}")

    stages = []
    if train_dir.exists():
        stages = sorted([d.name for d in train_dir.iterdir() if d.is_dir()])
        print(f"\n  Found {len(stages)} growth stages: {', '.join(stages)}")

        # Count images per stage
        for stage in stages:
            count = count_stage_images(train_dir / stage)
            print(f"    {stage}: {count} images")

    # Only report success when both splits and at least one stage folder exist.
    if stages and val_dir.exists():
        print("\n‚úÖ Dataset structure verified!")
        print("   Ready for training (run next cell)")
    else:
        print("\nDataset structure is incomplete:")
        if not train_dir.exists():
            print(f"   - missing folder: {train_dir}")
        elif not stages:
            print(f"   - no stage folders inside {train_dir}")
        if not val_dir.exists():
            print(f"   - missing folder: {val_dir}")


In [None]:
# Step 2: Train Growth Stage Classifier
# Only train if you have growth stage labeled data

from pathlib import Path

growth_data_dir = Path('data/growth_stage')
if not growth_data_dir.exists():
    print("‚ö†Ô∏è  Growth stage data not found. Skipping training.")
    print("   See previous cell for dataset structure requirements.")
else:
    print("üå± Training Growth Stage Classifier...")
    print("   This will take 30-60 minutes depending on dataset size")
    print("   Model: ResNet50 (pretrained on ImageNet)")
    print("   Only classification head is trained (backbone frozen)\n")
    
    !python train_growth_stage.py \
        --data-dir data/growth_stage \
        --epochs 20 \
        --batch-size 32 \
        --image-size 224 \
        --lr 0.001 \
        --output-dir outputs/models/growth_stage
    
    print("\n‚úÖ Growth stage classifier training completed!")
    print("   Model saved to: outputs/models/growth_stage/best_model.pth")


In [None]:
# Step 3: Predict Growth Stage
# Use the same image from previous predictions, or upload a new one

from pathlib import Path
from IPython.display import Image, display

# Use image from previous prediction cell, or upload new one
try:
    image_file  # Use image from Step 8
except NameError:
    print("‚ö†Ô∏è  No image found. Uploading new image...")
    from google.colab import files
    uploaded = files.upload()
    image_file = list(uploaded.keys())[0]

print(f"üì∏ Predicting growth stage for: {image_file}\n")

# Check if model exists
model_path = Path('outputs/models/growth_stage/best_model.pth')
if not model_path.exists():
    print("‚ö†Ô∏è  Growth stage model not found!")
    print(f"   Expected: {model_path}")
    print("   Please train the model first (Step 15, Cell 2)")
else:
    # Run prediction
    # Using quotes around image_file to handle spaces/special characters
    !python predict_growth_stage.py --image "{image_file}" --model outputs/models/growth_stage/best_model.pth
    
    print("\n‚úÖ Growth stage prediction completed!")
    print("\nThe output above shows:")
    print("  - Predicted growth stage (Seedling/Vegetative/Flowering/Maturity)")
    print("  - Confidence score")
    print("  - Top 3 predictions with confidence scores")


In [None]:
# Create models directory in Drive
!mkdir -p /content/drive/MyDrive/SmartCrop-AI/models

# Copy trained models
!cp -r outputs/models/checkpoints/* /content/drive/MyDrive/SmartCrop-AI/models/

# Copy exported models
!cp outputs/models/*.tflite /content/drive/MyDrive/SmartCrop-AI/models/ 2>/dev/null || echo "No TFLite files"
!cp outputs/models/*.onnx /content/drive/MyDrive/SmartCrop-AI/models/ 2>/dev/null || echo "No ONNX files"

print("‚úì Models saved to Google Drive!")
print("Location: /content/drive/MyDrive/SmartCrop-AI/models/")
