# Sweet Potato Root Detection & Segmentation with YOLOv8

**Tuskegee AIFARMS Agricultural AI Research**

This notebook provides a complete training pipeline for sweet potato root detection and instance segmentation using Ultralytics YOLOv8.

**Model choice:** We use the **YOLOv8 m-seg** (medium) checkpoint instead of the smaller n (nano) variant to increase model capacity (more parameters and channels) for finer, more accurate segmentation masks. Training and inference APIs are unchanged.

## ‚ö†Ô∏è IMPORTANT: NumPy 2.x Fix Required

**If you see `AttributeError: _ARRAY_API not found`:**

1. **Run Cell 3 FIRST** (NumPy compatibility fix)
2. **Restart Kernel** (Kernel ‚Üí Restart)
3. Then continue with Cell 4 (Installation)

**Or fix from terminal before opening notebook:**
```bash
pip install 'numpy<2.0' --force-reinstall
```

See `FIX_NUMPY_ERROR.md` for detailed instructions.

## Features
- Root-specific preprocessing for soil backgrounds and lighting variations
- Custom metrics: root count, area coverage, size distribution
- Active learning: low-confidence prediction flagging
- Model comparison: YOLOv8 vs Mask R-CNN baseline
- Transfer learning from agricultural checkpoints
- Production-ready exports (ONNX, TorchScript)

## 1. Setup & Installation

In [32]:
# Local setup - no Google Drive needed
import os

# Check if running in Colab
try:
    from google.colab import drive
    IS_COLAB = True
    drive.mount('/content/drive', force_remount=True)
    print("‚úì Google Drive mounted (Colab mode)")
except ImportError:
    IS_COLAB = False
    print("‚úì Running in local mode (no Google Drive needed)")

‚úì Running in local mode (no Google Drive needed)


In [33]:
# CRITICAL FIX: NumPy 2.x Compatibility Issue
# This MUST run before any other imports to fix NumPy 2.x / opencv-python incompatibility
# Run this cell FIRST if you see _ARRAY_API errors

import subprocess
import sys
import os
import json

print("="*70)
print("FIXING NUMPY 2.X COMPATIBILITY (Must run before other cells)")
print("="*70)

# Check NumPy version using pip (without importing numpy)
print("Checking NumPy version...")
result = subprocess.run(
    [sys.executable, '-m', 'pip', 'show', 'numpy'],
    capture_output=True,
    text=True,
    timeout=30
)

numpy_installed = result.returncode == 0
numpy_version = None

if numpy_installed:
    # Parse version from pip show output
    for line in result.stdout.split('\n'):
        if line.startswith('Version:'):
            numpy_version = line.split(':', 1)[1].strip()
            break
    
    if numpy_version:
        major_version = int(numpy_version.split('.')[0])
        print(f"Current NumPy version: {numpy_version}")
        
        if major_version >= 2:
            print(f"\n‚ö† CRITICAL: NumPy {numpy_version} is incompatible with opencv-python!")
            print("  opencv-python was compiled for NumPy 1.x and will fail with _ARRAY_API error")
            print("\n  Downgrading NumPy to < 2.0...")
            
            # Clear any cached imports first
            modules_to_clear = ['numpy', 'cv2', 'ultralytics', 'opencv']
            for mod in modules_to_clear:
                if mod in sys.modules:
                    del sys.modules[mod]
            
            # Force downgrade NumPy (without dependencies to avoid conflicts)
            print("  Installing NumPy < 2.0...")
            fix_result = subprocess.run(
                [sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall'],
                capture_output=True,
                text=True,
                timeout=180
            )
            
            if fix_result.returncode == 0:
                # Verify the fix
                verify_result = subprocess.run(
                    [sys.executable, '-m', 'pip', 'show', 'numpy'],
                    capture_output=True,
                    text=True,
                    timeout=30
                )
                if verify_result.returncode == 0:
                    for line in verify_result.stdout.split('\n'):
                        if line.startswith('Version:'):
                            new_version = line.split(':', 1)[1].strip()
                            print(f"\n‚úì SUCCESS: NumPy downgraded to {new_version}")
                            print("  IMPORTANT: Restart kernel now, then proceed to Cell 4 (Installation)")
                            print("  In Jupyter: Kernel -> Restart & Run All")
                            break
                else:
                    print("\n‚úì NumPy downgrade completed (verification pending)")
            else:
                print(f"\n‚úó Automatic fix failed. Error:")
                if fix_result.stderr:
                    error_msg = fix_result.stderr[:500]
                    print(f"  {error_msg}")
                print(f"\n  Please run manually in terminal:")
                print(f"  {sys.executable} -m pip install 'numpy<2.0' --force-reinstall")
        else:
            print(f"‚úì NumPy {numpy_version} is compatible (< 2.0)")
            print("  No fix needed - proceed to next cell")
    else:
        print("‚ö† Could not determine NumPy version")
else:
    print("NumPy not installed - will be installed with correct version (< 2.0) in next cell")
    # Pre-install NumPy < 2.0 to avoid issues
    print("Pre-installing NumPy < 2.0...")
    pre_result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', 'numpy<2.0'],
        capture_output=True,
        text=True,
        timeout=120
    )
    if pre_result.returncode == 0:
        print("‚úì NumPy < 2.0 pre-installed")
    else:
        print("‚ö† Pre-installation failed, will try again in next cell")

print("="*70)
print("IMPORTANT: If NumPy was downgraded, RESTART KERNEL before continuing!")
print("  Kernel -> Restart & Run All")
print("="*70)

FIXING NUMPY 2.X COMPATIBILITY (Must run before other cells)
Checking NumPy version...
Current NumPy version: 1.26.4
‚úì NumPy 1.26.4 is compatible (< 2.0)
  No fix needed - proceed to next cell
IMPORTANT: If NumPy was downgraded, RESTART KERNEL before continuing!
  Kernel -> Restart & Run All


In [34]:
# Install dependencies with exact versions
# In Colab, packages will be installed; locally, check if installed
import subprocess
import sys
import importlib

def install_package(package, use_exact_version=True):
    """Install a package with error handling"""
    try:
        if use_exact_version:
            result = subprocess.run(
                [sys.executable, '-m', 'pip', 'install', package],
                capture_output=True,
                text=True,
                check=False
            )
            if result.returncode == 0:
                return True, None
            else:
                # Try without version pin if exact version fails
                if '==' in package:
                    pkg_name = package.split('==')[0]
                    print(f"    ‚ö† Exact version failed, trying latest: {pkg_name}")
                    return install_package(pkg_name, use_exact_version=False)
                return False, result.stderr
        else:
            result = subprocess.run(
                [sys.executable, '-m', 'pip', 'install', package],
                capture_output=True,
                text=True,
                check=False
            )
            return result.returncode == 0, result.stderr if result.returncode != 0 else None
    except Exception as e:
        return False, str(e)

if IS_COLAB:
    # Colab: install packages using subprocess (works in notebooks)
    packages = [
        'ultralytics==8.1.0', 'roboflow==1.1.1', 'torch==2.1.0', 'torchvision==0.16.0',
        'opencv-python==4.8.1.78', 'matplotlib==3.8.2', 'seaborn==0.13.0',
        'pandas==2.1.4', 'numpy==1.24.3', 'tqdm==4.66.1', 'pyyaml==6.0.1',
        'onnx==1.15.0', 'onnxruntime==1.16.3'
    ]
    for pkg in packages:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pkg, '-q'])
    print("‚úì All dependencies installed (Colab)")
else:
    # Local: check if packages are installed
    required = {
        'torch': 'torch', 
        'torchvision': 'torchvision', 
        'ultralytics': 'ultralytics', 
        'cv2': 'opencv-python', 
        'yaml': 'pyyaml', 
        'pandas': 'pandas', 
        'numpy': 'numpy', 
        'matplotlib': 'matplotlib', 
        'seaborn': 'seaborn', 
        'tqdm': 'tqdm'
    }
    missing = []
    for module_name, package_name in required.items():
        try:
            if module_name == 'cv2':
                importlib.import_module('cv2')
            else:
                importlib.import_module(module_name)
        except ImportError:
            missing.append(package_name)
    
    if missing:
        print(f"‚ö† Missing packages: {missing}")
        print("\nInstalling missing packages...")
        
        # Priority order: 
        # 1. numpy first (needed by opencv-python and others)
        # 2. torch and torchvision (large dependencies)
        # 3. opencv-python (after numpy to avoid version conflicts)
        # 4. ultralytics (depends on torch)
        # 5. Others
        priority_order = ['numpy', 'torch', 'torchvision', 'opencv-python', 'ultralytics']
        other_packages = [pkg for pkg in missing if pkg not in priority_order]
        ordered_missing = [pkg for pkg in priority_order if pkg in missing] + other_packages
        
        # CRITICAL: Check and fix NumPy version first (NumPy 2.x breaks opencv-python)
        print("  Checking NumPy version compatibility...")
        try:
            import numpy
            numpy_version = numpy.__version__
            major_version = int(numpy_version.split('.')[0])
            if major_version >= 2:
                print(f"  ‚ö† NumPy {numpy_version} detected - downgrading to < 2.0 for opencv-python compatibility")
                subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall', '--quiet'],
                             check=False, timeout=120)
                print("  ‚úì NumPy downgraded")
        except:
            pass  # NumPy not installed yet, will install correct version
        
        # Try installing from requirements.txt first (faster if it works)
        if os.path.exists('requirements.txt'):
            print("  Attempting to install from requirements.txt...")
            result = subprocess.run(
                [sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'],
                capture_output=True,
                text=True,
                check=False
            )
            if result.returncode == 0:
                print("  ‚úì Installed from requirements.txt")
                # Ensure NumPy < 2.0 after installation
                try:
                    import numpy
                    if int(numpy.__version__.split('.')[0]) >= 2:
                        print("  ‚ö† Ensuring NumPy < 2.0...")
                        subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall', '--quiet'],
                                     check=False, timeout=120)
                except:
                    pass
                # Re-check what's still missing
                still_missing = []
                for module_name, package_name in required.items():
                    try:
                        if module_name == 'cv2':
                            importlib.import_module('cv2')
                        else:
                            importlib.import_module(module_name)
                    except ImportError:
                        still_missing.append(package_name)
                if not still_missing:
                    print("‚úì All packages installed successfully!")
                else:
                    print(f"  ‚ö† Some packages still missing: {still_missing}")
                    ordered_missing = [pkg for pkg in priority_order if pkg in still_missing] + [pkg for pkg in still_missing if pkg not in priority_order]
            else:
                print(f"  ‚ö† Installation from requirements.txt failed, installing individually...")
                print(f"  Error: {result.stderr[:200] if result.stderr else 'Unknown error'}")
        
        # Install packages individually with special handling for numpy/opencv conflict
        failed_packages = []
        for pkg in ordered_missing:
            print(f"  Installing {pkg}...", end=' ')
            # Get version from requirements.txt if available
            pkg_with_version = pkg
            if os.path.exists('requirements.txt'):
                with open('requirements.txt', 'r') as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith('#') and '==' in line:
                            req_pkg = line.split('==')[0].strip().lower()
                            if req_pkg == pkg.lower() or req_pkg.replace('-', '_') == pkg.lower().replace('-', '_'):
                                pkg_with_version = line.split('#')[0].strip()
                                break
            
            # Special handling for numpy/opencv-python compatibility
            if pkg == 'opencv-python':
                # CRITICAL: Ensure numpy < 2.0 (opencv-python doesn't work with NumPy 2.x)
                try:
                    import numpy
                    numpy_version = numpy.__version__
                    major_version = int(numpy_version.split('.')[0])
                    if major_version >= 2:
                        print(f"\n    ‚ö† NumPy {numpy_version} detected - downgrading to < 2.0")
                        subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall', '--quiet'],
                                     check=False, timeout=120)
                except:
                    # NumPy not installed, install compatible version
                    subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--quiet'],
                                 check=False, timeout=120)
            
            success, error = install_package(pkg_with_version)
            if success:
                print("‚úì")
                # After installing opencv-python, ensure numpy < 2.0
                if pkg == 'opencv-python':
                    try:
                        import numpy
                        if int(numpy.__version__.split('.')[0]) >= 2:
                            print(f"    ‚ö† Fixing NumPy version...")
                            subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall', '--quiet'],
                                         check=False, timeout=120)
                    except:
                        pass
            else:
                print("‚úó")
                failed_packages.append((pkg, error))
        
        if failed_packages:
            print(f"\n‚ö† Failed to install {len(failed_packages)} package(s):")
            for pkg, error in failed_packages:
                print(f"  - {pkg}")
                if error:
                    error_lines = error.split('\n')[:3]  # Show first 3 lines of error
                    for line in error_lines:
                        if line.strip():
                            print(f"    {line[:100]}")
            print(f"\nüí° Try installing manually:")
            print(f"   pip install {' '.join([pkg for pkg, _ in failed_packages])}")
        else:
            print("\n‚úì All missing packages installed successfully!")
    else:
        print("‚úì All required packages are installed (local mode)")

‚úì All required packages are installed (local mode)


In [35]:
# Import libraries
# Check if packages are installed before importing
import sys

# Track which imports succeed
failed_imports = []

print("Importing required packages...")

# CRITICAL: Import numpy FIRST before cv2 to avoid _ARRAY_API errors
# Also check for NumPy 2.x compatibility issue
try:
    import numpy as np
    numpy_version = np.__version__
    major_version = int(numpy_version.split('.')[0])
    if major_version >= 2:
        print(f"  ‚ö† NumPy {numpy_version} detected (incompatible with opencv-python)")
        print("     Downgrading to NumPy < 2.0...")
        import subprocess
        subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall', '--quiet'],
                     check=False, timeout=120)
        # Reload numpy
        import importlib
        if 'numpy' in sys.modules:
            del sys.modules['numpy']
        import numpy as np
        print(f"  ‚úì numpy {np.__version__} (downgraded)")
    else:
        print(f"  ‚úì numpy {numpy_version}")
except ImportError:
    print("  ‚úó numpy - not installed")
    failed_imports.append(('numpy', 'numpy'))
    # Create a dummy np to avoid NameError later
    np = None

# Now import cv2 (requires numpy to be imported first)
try:
    import cv2
    print("  ‚úì cv2 (opencv-python)")
except (ImportError, AttributeError) as e:
    # Handle both ImportError and AttributeError (_ARRAY_API issue)
    if '_ARRAY_API' in str(e) or 'AttributeError' in str(type(e).__name__):
        print("  ‚ö† cv2 import error (NumPy 2.x compatibility issue)")
        print("     Fixing by downgrading NumPy to < 2.0...")
        try:
            import subprocess
            import importlib
            # Downgrade NumPy to < 2.0
            subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall', '--quiet'], 
                         check=False, timeout=120)
            # Clear cached imports
            if 'numpy' in sys.modules:
                del sys.modules['numpy']
            if 'cv2' in sys.modules:
                del sys.modules['cv2']
            # Re-import
            import numpy as np
            import cv2
            print(f"  ‚úì cv2 (fixed - NumPy {np.__version__})")
        except Exception as fix_error:
            print(f"  ‚úó cv2 - failed to fix: {fix_error}")
            print("     Try manually: pip install 'numpy<2.0' opencv-python --force-reinstall")
            failed_imports.append(('cv2', 'opencv-python'))
    else:
        print("  ‚úó cv2 - opencv-python not installed")
        failed_imports.append(('cv2', 'opencv-python'))

# Import torch and torchvision
try:
    import torch
    print("  ‚úì torch")
except ImportError:
    print("  ‚úó torch - not installed")
    failed_imports.append(('torch', 'torch'))

try:
    import torchvision
    print("  ‚úì torchvision")
except ImportError:
    print("  ‚úó torchvision - not installed")
    failed_imports.append(('torchvision', 'torchvision'))

try:
    import matplotlib.pyplot as plt
    print("  ‚úì matplotlib")
except ImportError:
    print("  ‚úó matplotlib - not installed")
    failed_imports.append(('matplotlib', 'matplotlib'))

try:
    import seaborn as sns
    print("  ‚úì seaborn")
except ImportError:
    print("  ‚úó seaborn - not installed")
    failed_imports.append(('seaborn', 'seaborn'))

try:
    import pandas as pd
    print("  ‚úì pandas")
except ImportError:
    print("  ‚úó pandas - not installed")
    failed_imports.append(('pandas', 'pandas'))

try:
    import yaml
    print("  ‚úì yaml (pyyaml)")
except ImportError:
    print("  ‚úó yaml - pyyaml not installed")
    failed_imports.append(('yaml', 'pyyaml'))

# Import standard library modules (should always work)
import zipfile
import shutil
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

try:
    from tqdm import tqdm
    print("  ‚úì tqdm")
except ImportError:
    print("  ‚úó tqdm - not installed")
    failed_imports.append(('tqdm', 'tqdm'))

# Import ultralytics (depends on torch and cv2)
try:
    from ultralytics import YOLO
    from ultralytics.utils.metrics import ConfusionMatrix
    print("  ‚úì ultralytics")
except (ImportError, AttributeError) as e:
    error_msg = str(e)
    if '_ARRAY_API' in error_msg or 'cv2' in error_msg.lower():
        print("  ‚ö† ultralytics import error (likely cv2/numpy issue)")
        # If cv2 failed, ultralytics will also fail
        cv2_failed = any(mod == 'cv2' for mod, _ in failed_imports)
        if cv2_failed:
            print("     This is expected - cv2 must be fixed first")
        else:
            print("     Attempting to fix...")
            try:
                import subprocess
                import importlib
                # Fix NumPy version and reinstall opencv-python
                subprocess.run([sys.executable, '-m', 'pip', 'install', 'numpy<2.0', 'opencv-python', '--force-reinstall', '--quiet'], 
                             check=False, timeout=120)
                # Clear any cached imports
                if 'numpy' in sys.modules:
                    del sys.modules['numpy']
                if 'cv2' in sys.modules:
                    del sys.modules['cv2']
                if 'ultralytics' in sys.modules:
                    del sys.modules['ultralytics']
                # Re-import
                import numpy as np
                import cv2
                from ultralytics import YOLO
                from ultralytics.utils.metrics import ConfusionMatrix
                print("  ‚úì ultralytics (fixed)")
            except Exception as fix_error:
                print(f"  ‚úó ultralytics - failed to fix: {fix_error}")
                print("     Try manually: pip install 'numpy<2.0' opencv-python ultralytics --force-reinstall")
                failed_imports.append(('ultralytics', 'ultralytics'))
    else:
        print("  ‚úó ultralytics - requires torch to be installed first")
        failed_imports.append(('ultralytics', 'ultralytics'))

# Check for optional packages
try:
    import onnx
    ONNX_AVAILABLE = True
    print("  ‚úì ONNX (optional - for model export)")
except ImportError:
    ONNX_AVAILABLE = False
    print("  ‚ö† ONNX not available (optional - only needed for model export)")

# If critical imports failed, provide helpful error message
if failed_imports:
    print(f"\n‚úó Failed to import {len(failed_imports)} required package(s):")
    for module_name, package_name in failed_imports:
        print(f"   - {module_name} (install: {package_name})")
    print(f"\nüí° To fix this:")
    print(f"   1. Re-run Cell 3 (installation cell) above")
    print(f"   2. Or install manually: pip install {' '.join([pkg for _, pkg in failed_imports])}")
    print(f"   3. Or install all: pip install -r requirements.txt")
    print(f"\n‚ö† Please install the missing packages before continuing!")
    raise ImportError(f"Missing required packages: {', '.join([pkg for _, pkg in failed_imports])}")

# All imports successful
print(f"\n‚úì All required packages imported successfully!")
print(f"‚úì PyTorch version: {torch.__version__}")
print(f"‚úì CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"‚úì GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úì CUDA version: {torch.version.cuda}")
else:
    print("‚ö† No GPU detected. Training will be slow on CPU.")

Importing required packages...
  ‚úì numpy 1.26.4
  ‚úì cv2 (opencv-python)
  ‚úì torch
  ‚úì torchvision
  ‚úì matplotlib
  ‚úì seaborn
  ‚úì pandas
  ‚úì yaml (pyyaml)
  ‚úì tqdm
  ‚úì ultralytics
  ‚úì ONNX (optional - for model export)

‚úì All required packages imported successfully!
‚úì PyTorch version: 2.6.0+cu124
‚úì CUDA available: True
‚úì GPU: NVIDIA GeForce RTX 4070 Laptop GPU
‚úì CUDA version: 12.4


In [36]:
# Troubleshooting: Fix NumPy 2.x / opencv-python compatibility issues
# Run this cell if you see _ARRAY_API errors or cv2 import failures
# This fixes the NumPy 2.x incompatibility with opencv-python

import subprocess
import sys
import importlib

print("="*60)
print("FIXING NUMPY/OPENCV-PYTHON COMPATIBILITY")
print("="*60)

# Check current NumPy version
try:
    import numpy as np
    numpy_version = np.__version__
    major_version = int(numpy_version.split('.')[0])
    print(f"Current NumPy version: {numpy_version}")
    
    if major_version >= 2:
        print(f"\n‚ö† PROBLEM: NumPy {numpy_version} is incompatible with opencv-python")
        print("  opencv-python was compiled for NumPy 1.x and cannot run with NumPy 2.x")
        print("\n  Fixing by downgrading NumPy to < 2.0...")
        
        # Downgrade NumPy
        result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install', 'numpy<2.0', '--force-reinstall'],
            capture_output=True,
            text=True,
            timeout=120
        )
        
        if result.returncode == 0:
            # Clear cached imports
            if 'numpy' in sys.modules:
                del sys.modules['numpy']
            if 'cv2' in sys.modules:
                del sys.modules['cv2']
            if 'ultralytics' in sys.modules:
                del sys.modules['ultralytics']
            
            # Re-import to verify
            import numpy as np
            print(f"\n‚úì SUCCESS: NumPy downgraded to {np.__version__}")
            print("\n  Testing cv2 import...")
            try:
                import cv2
                print("  ‚úì cv2 imports successfully!")
                print("\n‚úÖ FIXED! You can now re-run the import cell (Cell 4)")
            except Exception as e:
                print(f"  ‚úó cv2 still fails: {e}")
                print("\n  Try manually: pip install 'numpy<2.0' opencv-python --force-reinstall")
        else:
            print(f"\n‚úó Installation failed. Error:")
            print(result.stderr[:500] if result.stderr else "Unknown error")
            print("\n  Try manually: pip install 'numpy<2.0' --force-reinstall")
    else:
        print(f"‚úì NumPy {numpy_version} is compatible (< 2.0)")
        try:
            import cv2
            print("‚úì cv2 imports successfully - no fix needed")
        except Exception as e:
            print(f"‚ö† cv2 import failed: {e}")
            print("  Try: pip install opencv-python --force-reinstall")
            
except ImportError:
    print("‚ö† NumPy not installed")
    print("  Installing NumPy < 2.0...")
    result = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', 'numpy<2.0'],
        capture_output=True,
        text=True,
        timeout=120
    )
    if result.returncode == 0:
        print("‚úì NumPy installed")
    else:
        print("‚úó Installation failed")

print("="*60)

FIXING NUMPY/OPENCV-PYTHON COMPATIBILITY
Current NumPy version: 1.26.4
‚úì NumPy 1.26.4 is compatible (< 2.0)
‚úì cv2 imports successfully - no fix needed


In [37]:
# ROBUST PATH HELPER - Works in both notebooks and scripts
from pathlib import Path
import os

# Get the root directory (notebook location or current working directory)
def get_project_root():
    """Get the project root directory, works in notebooks and scripts"""
    try:
        # Try to get notebook file path (works in Jupyter/IPython)
        import __main__
        if hasattr(__main__, '__file__'):
            return Path(__main__.__file__).resolve().parent
    except:
        pass
    
    # Fallback: use current working directory
    return Path.cwd().resolve()

# Set project root
PROJECT_ROOT = get_project_root()

# Define dataset paths
DATASET_ROOT = PROJECT_ROOT / "sweetpotato_project" / "dataset"
DATA_YAML_PATH = DATASET_ROOT / "data.yaml"
WORK_DIR_PATH = PROJECT_ROOT / "sweetpotato_project"

# Convert to absolute paths (strings for compatibility)
DATASET_DIR = str(DATASET_ROOT.resolve())
DATA_YAML = str(DATA_YAML_PATH.resolve())
WORK_DIR = str(WORK_DIR_PATH.resolve())

print("="*60)
print("PATH CONFIGURATION")
print("="*60)
print(f"Project Root: {PROJECT_ROOT}")
print(f"Dataset Dir:  {DATASET_DIR}")
print(f"Data YAML:    {DATA_YAML}")
print(f"Work Dir:     {WORK_DIR}")

# Verify data.yaml exists
if not Path(DATA_YAML).exists():
    print(f"\n‚ö† WARNING: data.yaml not found at {DATA_YAML}")
    print(f"   It will be created in the dataset setup cell.")
else:
    print(f"\n‚úì data.yaml found at: {DATA_YAML}")

print("="*60)

PATH CONFIGURATION
Project Root: C:\Users\kensm\farm-photo-outliner
Dataset Dir:  C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset
Data YAML:    C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\data.yaml
Work Dir:     C:\Users\kensm\farm-photo-outliner\sweetpotato_project

‚úì data.yaml found at: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\data.yaml


## 2. Data Preparation

In [38]:
# SETUP: Find or create data.yaml before verification
# This ensures data_yaml_path and data_config are defined

import os
import yaml
from pathlib import Path

# Find or determine DATASET_DIR
if 'DATASET_DIR' not in globals() or ('DATASET_DIR' in globals() and not os.path.exists(DATASET_DIR)):
    # Try to find dataset directory
    possible_paths = [
        './sweetpotato_project/dataset',
        './dataset',
        '../dataset',
        os.path.join(os.getcwd(), 'dataset')
    ]
    
    # Also check for extracted dataset paths
    if 'DATASET_DIR_EXTRACTED' in globals() and DATASET_DIR_EXTRACTED and os.path.exists(DATASET_DIR_EXTRACTED):
        possible_paths.insert(0, DATASET_DIR_EXTRACTED)
    
    DATASET_DIR = None
    for path in possible_paths:
        if os.path.exists(path) and os.path.isdir(path):
            DATASET_DIR = os.path.abspath(path)
            break
    
    if not DATASET_DIR:
        # Create default dataset directory
        DATASET_DIR = os.path.abspath('./sweetpotato_project/dataset')
        os.makedirs(DATASET_DIR, exist_ok=True)
        print(f"‚ö† DATASET_DIR not found, created: {DATASET_DIR}")
    else:
        print(f"‚úì Using DATASET_DIR: {DATASET_DIR}")
else:
    DATASET_DIR = os.path.abspath(DATASET_DIR)
    print(f"‚úì DATASET_DIR: {DATASET_DIR}")

# Find or create data.yaml
DATA_YAML = "data.yaml"  # Configurable variable
data_yaml_path = os.path.join(DATASET_DIR, DATA_YAML)

if not os.path.exists(data_yaml_path):
    # Search in subdirectories
    print(f"Searching for {DATA_YAML}...")
    found = False
    for root, dirs, files in os.walk(DATASET_DIR):
        if DATA_YAML in files:
            data_yaml_path = os.path.join(root, DATA_YAML)
            found = True
            print(f"‚úì Found {DATA_YAML} at: {data_yaml_path}")
            break
    
    if not found:
        print(f"‚ö† {DATA_YAML} not found. Creating template...")
        
        # Create folder structure if missing
        for folder in ['train/images', 'train/labels', 'valid/images', 'valid/labels', 'test/images', 'test/labels']:
            folder_path = os.path.join(DATASET_DIR, folder)
            os.makedirs(folder_path, exist_ok=True)
        
        # Create default data.yaml template
        default_yaml = {
            'path': DATASET_DIR,
            'train': 'train/images',
            'val': 'valid/images',
            'test': 'test/images',
            'nc': 3,  # Sweet potato classes
            'names': ['Diseased', 'Healthy', 'Non-determined']
        }
        
        with open(data_yaml_path, 'w') as f:
            yaml.dump(default_yaml, f, default_flow_style=False, sort_keys=False)
        
        print(f"‚úì Created {DATA_YAML} at: {data_yaml_path}")
        print(f"‚úì Created dataset folder structure:")
        print(f"   - train/images, train/labels")
        print(f"   - valid/images, valid/labels")
        print(f"   - test/images, test/labels")
else:
    print(f"‚úì Found {DATA_YAML} at: {data_yaml_path}")

# Load data.yaml
try:
    with open(data_yaml_path, 'r') as f:
        data_config = yaml.safe_load(f)
    print(f"‚úì Loaded {DATA_YAML} configuration")
    print(f"  Classes: {data_config.get('nc', 'N/A')}")
    print(f"  Class names: {data_config.get('names', 'N/A')}")
except Exception as e:
    print(f"‚úó Error loading {DATA_YAML}: {e}")
    raise

‚úì DATASET_DIR: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset
‚úì Found data.yaml at: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\data.yaml
‚úì Loaded data.yaml configuration
  Classes: 3
  Class names: ['Diseased', 'Healthy', 'Non-determined']


In [39]:
# Configuration paths - ADAPT FOR YOUR SYSTEM
if IS_COLAB:
    DRIVE_PATH = '/content/drive/MyDrive'
    DATASET_ZIP = 'Sweetpotato_roots.v2i.yolov8 (1).zip'  # Update this path
    WORK_DIR = '/content/sweetpotato_project'
    DATASET_DIR = f'{WORK_DIR}/dataset'
else:
    # LOCAL MODE: Update these paths for your system
    # Option 1: Path to dataset zip file
    DATASET_ZIP = 'Sweetpotato_roots.v2i.yolov8 (1).zip'  # Update this path or set to None
    
    # Option 2: Path to already extracted dataset (if you have it unzipped)
    DATASET_DIR_EXTRACTED = None  # e.g., r'C:\path\to\Sweetpotato_roots.v2i.yolov8'
    
    # Working directory (where outputs will be saved)
    WORK_DIR = './sweetpotato_project'
    DATASET_DIR = f'{WORK_DIR}/dataset'

# Create working directory
os.makedirs(WORK_DIR, exist_ok=True)
if not IS_COLAB and DATASET_DIR_EXTRACTED is None:
    os.makedirs(DATASET_DIR, exist_ok=True)

print(f"‚úì Working directory: {os.path.abspath(WORK_DIR)}")
if not IS_COLAB:
    print(f"‚úì Running in LOCAL mode")
    print(f"  Dataset zip: {DATASET_ZIP if DATASET_ZIP else 'Not specified'}")
    print(f"  Extracted dataset: {DATASET_DIR_EXTRACTED if DATASET_DIR_EXTRACTED else 'Not specified'}")

‚úì Working directory: c:\Users\kensm\farm-photo-outliner\sweetpotato_project
‚úì Running in LOCAL mode
  Dataset zip: Sweetpotato_roots.v2i.yolov8 (1).zip
  Extracted dataset: Not specified


In [40]:
# Locate and unzip dataset
if IS_COLAB:
    # Colab mode: search in Google Drive
    zip_path = None
    for root, dirs, files in os.walk(DRIVE_PATH):
        if DATASET_ZIP in files:
            zip_path = os.path.join(root, DATASET_ZIP)
            break
    
    if zip_path is None:
        raise FileNotFoundError(f"Dataset zip file '{DATASET_ZIP}' not found in Google Drive. Please upload it.")
    
    print(f"‚úì Found dataset: {zip_path}")
    
    # Unzip dataset
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(DATASET_DIR)
        print(f"‚úì Dataset extracted to {DATASET_DIR}")
    
    # Find the actual dataset folder (may be nested)
    dataset_folders = [f for f in os.listdir(DATASET_DIR) if os.path.isdir(os.path.join(DATASET_DIR, f))]
    if len(dataset_folders) == 1:
        actual_dataset = os.path.join(DATASET_DIR, dataset_folders[0])
        # Move contents up one level if nested
        for item in os.listdir(actual_dataset):
            shutil.move(os.path.join(actual_dataset, item), os.path.join(DATASET_DIR, item))
        os.rmdir(actual_dataset)
    
    print(f"‚úì Dataset structure prepared")
else:
    # LOCAL MODE: Handle dataset location
    if DATASET_DIR_EXTRACTED and os.path.exists(DATASET_DIR_EXTRACTED):
        # Use already extracted dataset
        print(f"‚úì Using extracted dataset from: {DATASET_DIR_EXTRACTED}")
        DATASET_DIR = DATASET_DIR_EXTRACTED
    elif DATASET_ZIP and os.path.exists(DATASET_ZIP):
        # Unzip from local path
        print(f"‚úì Found dataset zip: {DATASET_ZIP}")
        with zipfile.ZipFile(DATASET_ZIP, 'r') as zip_ref:
            zip_ref.extractall(DATASET_DIR)
        print(f"‚úì Dataset extracted to {DATASET_DIR}")
        
        # Handle nested folders
        dataset_folders = [f for f in os.listdir(DATASET_DIR) if os.path.isdir(os.path.join(DATASET_DIR, f))]
        if len(dataset_folders) == 1:
            actual_dataset = os.path.join(DATASET_DIR, dataset_folders[0])
            for item in os.listdir(actual_dataset):
                shutil.move(os.path.join(actual_dataset, item), os.path.join(DATASET_DIR, item))
            os.rmdir(actual_dataset)
    else:
        # Search in current directory and subdirectories
        print(f"‚ö† Dataset zip not found at specified path. Searching...")
        zip_found = False
        search_paths = ['.', os.path.dirname(os.path.abspath('.'))]
        for search_root in search_paths:
            for root, dirs, files in os.walk(search_root):
                if DATASET_ZIP and DATASET_ZIP in files:
                    zip_path = os.path.join(root, DATASET_ZIP)
                    print(f"‚úì Found dataset: {zip_path}")
                    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                        zip_ref.extractall(DATASET_DIR)
                    zip_found = True
                    break
            if zip_found:
                break
        
        if not zip_found:
            raise FileNotFoundError(
                f"Dataset not found. Please either:\n"
                f"1. Place '{DATASET_ZIP}' in the current directory, or\n"
                f"2. Update DATASET_ZIP path in the configuration cell, or\n"
                f"3. Set DATASET_DIR_EXTRACTED to point to your extracted dataset"
            )
    
    print(f"‚úì Dataset structure prepared")

‚ö† Dataset zip not found at specified path. Searching...


‚úì Found dataset: c:\Users\kensm\Downloads\Sweetpotato_roots.v2i.yolov8 (1).zip
‚úì Dataset structure prepared


In [41]:
# FIX data.yaml - Ensure it has absolute 'path' field for YOLOv8
# This is critical for YOLOv8 to find the dataset images

from pathlib import Path
import yaml

# Get the data.yaml path (should be set in previous cells)
if 'data_yaml_path' not in globals():
    if 'DATA_YAML' in globals():
        data_yaml_path = DATA_YAML
    elif 'DATASET_DIR' in globals():
        data_yaml_path = os.path.join(DATASET_DIR, 'data.yaml')
    else:
        # Use the robust path helper
        data_yaml_path = str(PROJECT_ROOT / "sweetpotato_project" / "dataset" / "data.yaml")

# Ensure absolute path
data_yaml_path = str(Path(data_yaml_path).resolve())
DATASET_DIR = str(Path(data_yaml_path).parent.resolve())

print("="*60)
print("FIXING data.yaml FOR YOLOv8 COMPATIBILITY")
print("="*60)
print(f"Data YAML: {data_yaml_path}")
print(f"Dataset Dir: {DATASET_DIR}")

# Load existing data.yaml or create new one
if Path(data_yaml_path).exists():
    with open(data_yaml_path, 'r') as f:
        data_config = yaml.safe_load(f) or {}
    print("‚úì Loaded existing data.yaml")
else:
    data_config = {}
    print("‚ö† data.yaml not found, will create new one")

# CRITICAL: Set 'path' to absolute path (YOLOv8 requirement)
data_config['path'] = str(Path(DATASET_DIR).resolve())

# Ensure train/val/test paths are set (relative to 'path')
if 'train' not in data_config:
    data_config['train'] = 'train/images'
if 'val' not in data_config:
    data_config['val'] = 'valid/images'
if 'test' not in data_config:
    data_config['test'] = 'test/images'

# Set classes if not set
if 'nc' not in data_config:
    data_config['nc'] = 3
if 'names' not in data_config:
    data_config['names'] = ['Diseased', 'Healthy', 'Non-determined']

# Write the fixed data.yaml
with open(data_yaml_path, 'w') as f:
    yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)

print(f"\n‚úì Fixed data.yaml with absolute path:")
print(f"  path: {data_config['path']}")
print(f"  train: {data_config['train']}")
print(f"  val: {data_config['val']}")
print(f"  test: {data_config['test']}")

# Verify paths exist
train_path = Path(data_config['path']) / data_config['train']
val_path = Path(data_config['path']) / data_config['val']
test_path = Path(data_config['path']) / data_config['test']

print(f"\n‚úì Verifying paths:")
print(f"  Train: {train_path} - {'‚úì EXISTS' if train_path.exists() else '‚úó MISSING'}")
print(f"  Val:   {val_path} - {'‚úì EXISTS' if val_path.exists() else '‚úó MISSING'}")
print(f"  Test:  {test_path} - {'‚úì EXISTS' if test_path.exists() else '‚úó MISSING'}")

# Store the absolute path for use in training
data_yaml_path_absolute = str(Path(data_yaml_path).resolve())
print(f"\n‚úì data_yaml_path (absolute): {data_yaml_path_absolute}")
print("="*60)

FIXING data.yaml FOR YOLOv8 COMPATIBILITY
Data YAML: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\data.yaml
Dataset Dir: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset
‚úì Loaded existing data.yaml

‚úì Fixed data.yaml with absolute path:
  path: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset
  train: ../train/images
  val: ../valid/images
  test: ../test/images

‚úì Verifying paths:
  Train: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\..\train\images - ‚úó MISSING
  Val:   C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\..\valid\images - ‚úó MISSING
  Test:  C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\..\test\images - ‚úó MISSING

‚úì data_yaml_path (absolute): C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\data.yaml


In [42]:
# VERIFY DATA PATHS - Ensure model will use train, val, and test splits
print("="*60)
print("VERIFYING DATA PATHS FOR YOLOv8 TRAINING")
print("="*60)

# Resolve paths from data.yaml (YOLOv8 resolves paths relative to data.yaml location)
data_yaml_dir = os.path.dirname(os.path.abspath(data_yaml_path))

def resolve_path(path_str, base_dir):
    """Resolve relative or absolute path"""
    if os.path.isabs(path_str):
        return path_str
    # Resolve relative to data.yaml directory
    resolved = os.path.normpath(os.path.join(base_dir, path_str))
    return resolved

# Get paths from config
train_path = data_config.get('train', '')
val_path = data_config.get('val', '')
test_path = data_config.get('test', '')

# Resolve to absolute paths
train_abs = resolve_path(train_path, data_yaml_dir)
val_abs = resolve_path(val_path, data_yaml_dir)
test_abs = resolve_path(test_path, data_yaml_dir)

print(f"\nüìÅ Data paths from data.yaml (resolved):")
print(f"  Train: {train_abs}")
print(f"  Val:   {val_abs}")
print(f"  Test:  {test_abs}")

# Verify each path exists and count images
def verify_split(split_name, path):
    """Verify a split path exists and count images"""
    if not os.path.exists(path):
        return False, 0, f"Path does not exist: {path}"
    
    # Count images
    image_files = []
    if os.path.isdir(path):
        # Direct image directory
        image_files = [f for f in os.listdir(path) 
                      if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))]
    else:
        # Path might point to a file (shouldn't happen)
        return False, 0, f"Path is not a directory: {path}"
    
    return True, len(image_files), None

# Verify all splits
train_exists, train_count, train_error = verify_split('Train', train_abs)
val_exists, val_count, val_error = verify_split('Val', val_abs)
test_exists, test_count, test_error = verify_split('Test', test_abs)

print(f"\n‚úì Verification Results:")
print(f"  Train: {'‚úì' if train_exists else '‚úó'} {train_count} images" + (f" - {train_error}" if train_error else ""))
print(f"  Val:   {'‚úì' if val_exists else '‚úó'} {val_count} images" + (f" - {val_error}" if val_error else ""))
print(f"  Test:  {'‚úì' if test_exists else '‚úó'} {test_count} images" + (f" - {test_error}" if test_error else ""))

# Check if paths need fixing
needs_fix = False
if not train_exists or not val_exists or not test_exists:
    needs_fix = True
    print(f"\n‚ö† WARNING: Some paths are incorrect!")
    print(f"   Fixing data.yaml paths...")
    
    # Fix paths - use absolute paths or correct relative paths
    # YOLOv8 works best with paths relative to data.yaml location
    data_yaml_dir_abs = os.path.abspath(data_yaml_dir)
    
    # Try to find correct paths
    # Standard YOLOv8 structure: dataset/train/images, dataset/valid/images, dataset/test/images
    dataset_root = os.path.dirname(data_yaml_dir_abs) if 'data.yaml' in os.listdir(data_yaml_dir_abs) else data_yaml_dir_abs
    
    # Check for standard structure
    train_standard = os.path.join(dataset_root, 'train', 'images')
    val_standard = os.path.join(dataset_root, 'valid', 'images')
    test_standard = os.path.join(dataset_root, 'test', 'images')
    
    # Use standard paths if they exist, otherwise try current paths
    if os.path.exists(train_standard):
        train_path_fixed = 'train/images'
    elif os.path.exists(os.path.join(dataset_root, 'train')):
        train_path_fixed = 'train/images' if os.path.exists(train_standard) else 'train'
    else:
        train_path_fixed = train_path  # Keep original
    
    if os.path.exists(val_standard):
        val_path_fixed = 'valid/images'
    elif os.path.exists(os.path.join(dataset_root, 'valid')):
        val_path_fixed = 'valid/images' if os.path.exists(val_standard) else 'valid'
    else:
        val_path_fixed = val_path  # Keep original
    
    if os.path.exists(test_standard):
        test_path_fixed = 'test/images'
    elif os.path.exists(os.path.join(dataset_root, 'test')):
        test_path_fixed = 'test/images' if os.path.exists(test_standard) else 'test'
    else:
        test_path_fixed = test_path  # Keep original
    
    # Update data.yaml
    data_config['train'] = train_path_fixed
    data_config['val'] = val_path_fixed
    data_config['test'] = test_path_fixed
    
    # Also add 'path' field if missing (YOLOv8 uses this as base path)
    if 'path' not in data_config:
        data_config['path'] = dataset_root
    
    # Write updated config
    with open(data_yaml_path, 'w') as f:
        yaml.dump(data_config, f, default_flow_style=False, sort_keys=False)
    
    print(f"   ‚úì Updated data.yaml with corrected paths")
    print(f"   New paths:")
    print(f"     Train: {data_config['train']}")
    print(f"     Val:   {data_config['val']}")
    print(f"     Test:  {data_config['test']}")
    
    # Reload to verify
    with open(data_yaml_path, 'r') as f:
        data_config = yaml.safe_load(f)
    
    train_abs = resolve_path(data_config['train'], data_yaml_dir)
    val_abs = resolve_path(data_config['val'], data_yaml_dir)
    test_abs = resolve_path(data_config['test'], data_yaml_dir)
    
    train_exists, train_count, _ = verify_split('Train', train_abs)
    val_exists, val_count, _ = verify_split('Val', val_abs)
    test_exists, test_count, _ = verify_split('Test', test_abs)

# Final summary
print(f"\n" + "="*60)
print("FINAL VERIFICATION - Model will use:")
print("="*60)
if train_exists and train_count > 0:
    print(f"  ‚úì TRAIN: {train_count} images - {train_abs}")
else:
    print(f"  ‚úó TRAIN: NOT FOUND or EMPTY")
    
if val_exists and val_count > 0:
    print(f"  ‚úì VAL:   {val_count} images - {val_abs}")
else:
    print(f"  ‚úó VAL:   NOT FOUND or EMPTY")
    
if test_exists and test_count > 0:
    print(f"  ‚úì TEST:  {test_count} images - {test_abs}")
else:
    print(f"  ‚úó TEST:  NOT FOUND or EMPTY")

if train_exists and val_exists and test_exists and train_count > 0 and val_count > 0 and test_count > 0:
    print(f"\n‚úÖ SUCCESS: All three splits are configured correctly!")
    print(f"   The model will use:")
    print(f"   - {train_count} training images")
    print(f"   - {val_count} validation images") 
    print(f"   - {test_count} test images")
    print(f"\n   You can proceed with training!")
else:
    print(f"\n‚ö† WARNING: Some splits are missing or empty!")
    print(f"   Please check your dataset structure.")
print("="*60)

VERIFYING DATA PATHS FOR YOLOv8 TRAINING

üìÅ Data paths from data.yaml (resolved):
  Train: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\train\images
  Val:   C:\Users\kensm\farm-photo-outliner\sweetpotato_project\valid\images
  Test:  C:\Users\kensm\farm-photo-outliner\sweetpotato_project\test\images

‚úì Verification Results:
  Train: ‚úó 0 images - Path does not exist: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\train\images
  Val:   ‚úó 0 images - Path does not exist: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\valid\images
  Test:  ‚úó 0 images - Path does not exist: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\test\images

   Fixing data.yaml paths...
   ‚úì Updated data.yaml with corrected paths
   New paths:
     Train: ../train/images
     Val:   ../valid/images
     Test:  ../test/images

FINAL VERIFICATION - Model will use:
  ‚úó TRAIN: NOT FOUND or EMPTY
  ‚úó VAL:   NOT FOUND or EMPTY
  ‚úó TEST:  NOT FOUND or EMPTY

   Please ch

In [43]:
# Verify dataset structure
required_folders = ['train', 'valid', 'test']
missing_folders = []

for folder in required_folders:
    folder_path = os.path.join(DATASET_DIR, folder)
    if not os.path.exists(folder_path):
        missing_folders.append(folder)
    else:
        print(f"‚úì Found {folder}/ folder")

if missing_folders:
    raise FileNotFoundError(f"Missing required folders: {missing_folders}")

# Check for data.yaml
data_yaml_path = os.path.join(DATASET_DIR, 'data.yaml')
if not os.path.exists(data_yaml_path):
    # Search in subdirectories
    for root, dirs, files in os.walk(DATASET_DIR):
        if 'data.yaml' in files:
            data_yaml_path = os.path.join(root, 'data.yaml')
            break
    if not os.path.exists(data_yaml_path):
        raise FileNotFoundError("data.yaml not found. Creating default...")
        # Create default data.yaml
        default_yaml = {
            'path': DATASET_DIR,
            'train': 'train/images',
            'val': 'valid/images',
            'test': 'test/images',
            'nc': 2,
            'names': ['sweetpotato_root', 'background']
        }
        with open(data_yaml_path, 'w') as f:
            yaml.dump(default_yaml, f)

print(f"‚úì Found data.yaml: {data_yaml_path}")

# Load and display data.yaml
with open(data_yaml_path, 'r') as f:
    data_config = yaml.safe_load(f)

print("\nDataset Configuration:")
print(f"  Classes: {data_config.get('nc', 'N/A')}")
print(f"  Class names: {data_config.get('names', 'N/A')}")
print(f"  Train: {data_config.get('train', 'N/A')}")
print(f"  Val: {data_config.get('val', 'N/A')}")
print(f"  Test: {data_config.get('test', 'N/A')}")

‚úì Found train/ folder
‚úì Found valid/ folder
‚úì Found test/ folder
‚úì Found data.yaml: C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\data.yaml

Dataset Configuration:
  Classes: 3
  Class names: ['Diseased', 'Healthy', 'Non-determined']
  Train: ../train/images
  Val: ../valid/images
  Test: ../test/images


In [44]:
# Validate image-annotation pairs
def validate_dataset(split='train'):
    split_path = os.path.join(DATASET_DIR, split)
    images_dir = os.path.join(split_path, 'images')
    labels_dir = os.path.join(split_path, 'labels')
    
    if not os.path.exists(images_dir):
        images_dir = split_path
        labels_dir = split_path
    
    if not os.path.exists(labels_dir):
        print(f"‚ö† Warning: {split}/labels not found, assuming labels are in {split}/")
        labels_dir = split_path
    
    image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    label_files = [f for f in os.listdir(labels_dir) if f.endswith('.txt')]
    
    missing_labels = []
    missing_images = []
    
    for img_file in image_files:
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(labels_dir, label_file)
        if not os.path.exists(label_path):
            missing_labels.append(img_file)
    
    for label_file in label_files:
        img_file = os.path.splitext(label_file)[0] + '.jpg'
        if not img_file in [f.lower() for f in image_files]:
            # Try other extensions
            found = False
            for ext in ['.png', '.jpeg', '.JPG', '.PNG']:
                if os.path.splitext(label_file)[0] + ext in image_files:
                    found = True
                    break
            if not found:
                missing_images.append(label_file)
    
    return len(image_files), len(label_files), missing_labels, missing_images

# Validate all splits
print("Validating dataset splits...")
for split in ['train', 'valid', 'test']:
    img_count, label_count, missing_lbls, missing_imgs = validate_dataset(split)
    print(f"\n{split.upper()}:")
    print(f"  Images: {img_count}")
    print(f"  Labels: {label_count}")
    if missing_lbls:
        print(f"  ‚ö† Missing labels for {len(missing_lbls)} images")
    if missing_imgs:
        print(f"  ‚ö† Missing images for {len(missing_imgs)} labels")
    if not missing_lbls and not missing_imgs:
        print(f"  ‚úì All pairs validated")

print("\n‚úì Dataset validation complete")

Validating dataset splits...

TRAIN:
  Images: 57
  Labels: 57
  ‚ö† Missing images for 57 labels

VALID:
  Images: 5
  Labels: 5
  ‚ö† Missing images for 5 labels

TEST:
  Images: 3
  Labels: 3
  ‚ö† Missing images for 3 labels

‚úì Dataset validation complete


In [45]:
# GPU DIAGNOSTICS - Enhanced for RTX 4070
# Bulletproof GPU detection and verification

import torch
import subprocess
import sys

print("="*70)
print("GPU DIAGNOSTICS")
print("="*70)

# Check PyTorch version
print(f"\nPyTorch: {torch.__version__}")
if '+cpu' in torch.__version__:
    print("üö® PROBLEM: CPU-only PyTorch installed!")
    print("   This is why GPU is not being used.")
else:
    print("‚úì PyTorch has CUDA support")

# Check CUDA availability
cuda_available = torch.cuda.is_available()
print(f"\nCUDA available: {cuda_available}")

if cuda_available:
    print(f"‚úì CUDA version: {torch.version.cuda}")
    print(f"‚úì GPU count: {torch.cuda.device_count()}")
    
    for i in range(torch.cuda.device_count()):
        gpu_name = torch.cuda.get_device_name(i)
        props = torch.cuda.get_device_properties(i)
        gpu_memory_gb = props.total_memory / 1024**3
        
        print(f"\n‚úì GPU {i}: {gpu_name}")
        print(f"  Memory: {gpu_memory_gb:.1f} GB")
        print(f"  Compute Capability: {props.major}.{props.minor}")
        
        # For RTX 4070, verify it's detected
        if '4070' in gpu_name or 'RTX' in gpu_name:
            print(f"  ‚úÖ RTX 4070 detected!")
    
    # Clear cache
    torch.cuda.empty_cache()
    
    # Set device variables
    device = torch.device("cuda")
    TRAINING_DEVICE = '0'  # Use first GPU (RTX 4070)
    
    print(f"\n‚úì Device Selected: {device}")
    print(f"‚úì Training will use: GPU 0 ({torch.cuda.get_device_name(0)})")
else:
    print("\nüö® GPU NOT AVAILABLE - check drivers/CUDA toolkit")
    
    # Check NVIDIA drivers
    print("\nChecking NVIDIA drivers...")
    try:
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            print("‚úì NVIDIA drivers installed")
            print("‚ö† But PyTorch can't see GPU - need CUDA-enabled PyTorch")
            print("\nSOLUTION: Install CUDA-enabled PyTorch")
            print("Run these commands in terminal:")
            print("="*60)
            print("pip uninstall -y torch torchvision torchaudio ultralytics")
            print("pip install --index-url https://download.pytorch.org/whl/cu124 torch torchvision torchaudio --upgrade")
            print("pip install ultralytics --upgrade")
            print("="*60)
        else:
            print("‚úó nvidia-smi not found")
    except:
        print("‚ö† Could not check NVIDIA drivers")
    
    # Fallback to CPU
    device = torch.device("cpu")
    TRAINING_DEVICE = 'cpu'

print("="*70)
print(f"Device: {device}")
print(f"TRAINING_DEVICE: {TRAINING_DEVICE}")
print("="*70)

GPU DIAGNOSTICS

PyTorch: 2.6.0+cu124
‚úì PyTorch has CUDA support

CUDA available: True
‚úì CUDA version: 12.4
‚úì GPU count: 1

‚úì GPU 0: NVIDIA GeForce RTX 4070 Laptop GPU
  Memory: 8.0 GB
  Compute Capability: 8.9
  ‚úÖ RTX 4070 detected!

‚úì Device Selected: cuda
‚úì Training will use: GPU 0 (NVIDIA GeForce RTX 4070 Laptop GPU)
Device: cuda
TRAINING_DEVICE: 0


In [46]:
# AUTO-FIX: Install GPU PyTorch if CUDA not available
# This cell will automatically install CUDA-enabled PyTorch if GPU is not detected

import torch
import subprocess
import sys
import importlib

print("="*70)
print("AUTO-FIX: GPU PyTorch Installation")
print("="*70)

if not torch.cuda.is_available():
    print("\nüö® CUDA not available - installing GPU PyTorch...")
    print("This will uninstall CPU PyTorch and install CUDA 12.4 version")
    print("(Compatible with RTX 4070 and RTX 40-series GPUs)\n")
    
    try:
        # Step 1: Uninstall CPU PyTorch
        print("Step 1: Uninstalling CPU PyTorch...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "uninstall", "-y", 
            "torch", "torchvision", "torchaudio", "ultralytics"
        ], timeout=120)
        print("‚úì Uninstalled CPU PyTorch")
        
        # Step 2: Install CUDA 12.4 PyTorch (for RTX 4070)
        print("\nStep 2: Installing CUDA 12.4 PyTorch...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", 
            "--index-url", "https://download.pytorch.org/whl/cu124",
            "torch", "torchvision", "torchaudio", "--upgrade"
        ], timeout=300)
        print("‚úì Installed CUDA PyTorch")
        
        # Step 3: Reinstall ultralytics
        print("\nStep 3: Reinstalling ultralytics...")
        subprocess.check_call([
            sys.executable, "-m", "pip", "install", "ultralytics", "--upgrade"
        ], timeout=120)
        print("‚úì Reinstalled ultralytics")
        
        print("\n" + "="*70)
        print("‚úÖ INSTALLATION COMPLETE!")
        print("="*70)
        print("\n‚ö† IMPORTANT: Restart kernel now!")
        print("   Kernel ‚Üí Restart (or Ctrl+Shift+P ‚Üí 'Restart')")
        print("\n   Then re-run Cell 16 (GPU Diagnostics) to verify GPU is detected")
        print("="*70)
        
    except subprocess.TimeoutExpired:
        print("‚ö† Installation timed out - try running commands manually in terminal")
    except Exception as e:
        print(f"‚ö† Installation error: {e}")
        print("\nPlease run these commands manually in terminal:")
        print("="*60)
        print("pip uninstall -y torch torchvision torchaudio ultralytics")
        print("pip install --index-url https://download.pytorch.org/whl/cu124 torch torchvision torchaudio --upgrade")
        print("pip install ultralytics --upgrade")
        print("="*60)
else:
    print("‚úì CUDA is already available!")
    print(f"‚úì GPU: {torch.cuda.get_device_name(0)}")
    print("‚úì No installation needed")

AUTO-FIX: GPU PyTorch Installation
‚úì CUDA is already available!
‚úì GPU: NVIDIA GeForce RTX 4070 Laptop GPU
‚úì No installation needed


## 3. Model Training

In [47]:
# Load training configuration
if IS_COLAB:
    config_path = '/content/config.yaml'  # Will be uploaded or created
else:
    config_path = './config.yaml'  # Local config file

# Default configuration (yolov8m-seg = medium; more capacity than n for finer masks)
default_config = {
    'model': 'yolov8m-seg.pt',  # Options: yolov8n-seg.pt, yolov8s-seg.pt, yolov8m-seg.pt, yolov8l-seg.pt, yolov8x-seg.pt
    'epochs': 100,
    'imgsz': 640,
    'batch': 16,
    'optimizer': 'AdamW',
    'lr0': 0.01,
    'lrf': 0.01,
    'momentum': 0.937,
    'weight_decay': 0.0005,
    'warmup_epochs': 3.0,
    'warmup_momentum': 0.8,
    'warmup_bias_lr': 0.1,
    'box': 7.5,
    'cls': 0.5,
    'dfl': 1.5,
    'pose': 12.0,
    'kobj': 2.0,
    'label_smoothing': 0.0,
    'nbs': 64,
    'hsv_h': 0.015,
    'hsv_s': 0.7,
    'hsv_v': 0.4,
    'degrees': 0.0,
    'translate': 0.1,
    'scale': 0.5,
    'shear': 0.0,
    'perspective': 0.0,
    'flipud': 0.0,
    'fliplr': 0.5,
    'mosaic': 1.0,
    'mixup': 0.15,
    'copy_paste': 0.0,
    'auto_augment': 'randaugment',
    'erasing': 0.4,
    'crop_fraction': 1.0
}

if os.path.exists(config_path):
    with open(config_path, 'r') as f:
        user_config = yaml.safe_load(f)
    default_config.update(user_config)
    print(f"‚úì Loaded config from {config_path}")
else:
    print(f"‚ö† Config file not found, using defaults")

print("\nTraining Configuration:")
for key, value in default_config.items():
    print(f"  {key}: {value}")

‚úì Loaded config from ./config.yaml

Training Configuration:
  model: yolov8m-seg.pt
  epochs: 100
  imgsz: 640
  batch: 16
  optimizer: AdamW
  lr0: 0.01
  lrf: 0.01
  momentum: 0.937
  weight_decay: 0.0005
  warmup_epochs: 3.0
  warmup_momentum: 0.8
  warmup_bias_lr: 0.1
  box: 7.5
  cls: 0.5
  dfl: 1.5
  pose: 12.0
  kobj: 2.0
  label_smoothing: 0.0
  nbs: 64
  hsv_h: 0.015
  hsv_s: 0.7
  hsv_v: 0.4
  degrees: 0.0
  translate: 0.1
  scale: 0.5
  shear: 0.0
  perspective: 0.0
  flipud: 0.0
  fliplr: 0.5
  mosaic: 1.0
  mixup: 0.15
  copy_paste: 0.0
  auto_augment: randaugment
  erasing: 0.4
  crop_fraction: 1.0
  save_period: 10
  patience: 50
  seed: 42
  deterministic: True
  amp: True
  overlap_mask: True
  mask_ratio: 4
  pretrained: True
  freeze: None
  val: True
  plots: True


In [48]:
# Fix for PyTorch 2.6+ weights_only issue (recursion-safe)
# Use torch.serialization.load so we never re-enter torch.load and avoid RecursionError
# when Ultralytics also patches torch.load.
import torch
from pathlib import Path
from ultralytics import YOLO

_real_torch_load = getattr(torch.serialization, "load", torch.load)

def _patched_torch_load(*args, **kwargs):
    if "weights_only" not in kwargs:
        kwargs["weights_only"] = False
    return _real_torch_load(*args, **kwargs)

torch.load = _patched_torch_load
print("‚úì Applied PyTorch 2.6+ compatibility fix for YOLOv8 model loading")

# Initialize YOLOv8 segmentation model exactly once (m-seg = medium; more capacity for finer masks)
MODEL_NAME = default_config.get("model", "yolov8m-seg.pt")
if hasattr(MODEL_NAME, "strip"):
    MODEL_NAME = str(MODEL_NAME).strip()
else:
    MODEL_NAME = str(MODEL_NAME)

suffix = Path(MODEL_NAME).suffix.lower()
if suffix in (".pt", ".onnx", ".engine", ".yaml") or (suffix == "" and MODEL_NAME.strip()):
    model = YOLO(MODEL_NAME)
else:
    raise ValueError(
        f"Invalid model spec: {MODEL_NAME!r} ‚Äì expected a weights/config path or YOLO name (e.g. yolov8m-seg.pt)."
    )

print(f"‚úì Loaded model: {MODEL_NAME}")
print(f"‚úì Model parameters: {sum(p.numel() for p in model.model.parameters()):,}")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"‚úì Using device: {device}")
if device == "cpu":
    print("‚ö† Warning: Training on CPU will be very slow. Consider enabling GPU in Colab.")

‚úì Applied PyTorch 2.6+ compatibility fix for YOLOv8 model loading


RecursionError: maximum recursion depth exceeded

In [None]:
# Train the model with error handling
# Ensure GPU is used if available

# Get device from GPU diagnostics cell or detect automatically
if 'TRAINING_DEVICE' in globals():
    train_device = TRAINING_DEVICE
    # Convert 'cuda' to '0' for YOLO
    if train_device == 'cuda' and torch.cuda.is_available():
        train_device = '0'
elif 'device' in globals() and hasattr(device, 'type') and device.type == 'cuda':
    train_device = '0'  # Use first GPU
elif torch.cuda.is_available():
    train_device = '0'  # Use first GPU
else:
    train_device = 'cpu'  # Fallback to CPU

# Display device info
if train_device == '0' and torch.cuda.is_available():
    print(f"‚úì Training on GPU {train_device}: {torch.cuda.get_device_name(0)}")
    print(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
elif train_device == 'cpu':
    print(f"‚ö† Training on CPU (GPU not available)")
    print("  Install CUDA-enabled PyTorch to use GPU")
    print("  Run: pip uninstall -y torch torchvision torchaudio")
    print("       pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio")

try:
    results = model.train(
        data=data_yaml_path,
        device=train_device,  # Explicitly set device: '0' for GPU, 'cpu' for CPU
        epochs=default_config['epochs'],
        imgsz=default_config['imgsz'],
        batch=default_config['batch'],
        optimizer=default_config['optimizer'],
        lr0=default_config['lr0'],
        lrf=default_config['lrf'],
        momentum=default_config['momentum'],
        weight_decay=default_config['weight_decay'],
        warmup_epochs=default_config['warmup_epochs'],
        warmup_momentum=default_config['warmup_momentum'],
        warmup_bias_lr=default_config['warmup_bias_lr'],
        box=default_config['box'],
        cls=default_config['cls'],
        dfl=default_config['dfl'],
        label_smoothing=default_config['label_smoothing'],
        nbs=default_config['nbs'],
        hsv_h=default_config['hsv_h'],
        hsv_s=default_config['hsv_s'],
        hsv_v=default_config['hsv_v'],
        degrees=default_config['degrees'],
        translate=default_config['translate'],
        scale=default_config['scale'],
        shear=default_config['shear'],
        perspective=default_config['perspective'],
        flipud=default_config['flipud'],
        fliplr=default_config['fliplr'],
        mosaic=default_config['mosaic'],
        mixup=default_config['mixup'],
        copy_paste=default_config['copy_paste'],
        auto_augment=default_config['auto_augment'],
        erasing=default_config['erasing'],
        crop_fraction=default_config['crop_fraction'],
        save=True,
        save_period=10,  # Save checkpoint every 10 epochs
        project=f'{WORK_DIR}/runs/segment',
        name='sweetpotato_exp',
        exist_ok=True,
        pretrained=True,
        verbose=True,
        seed=42,
        deterministic=True,
        single_cls=False,
        rect=False,
        cos_lr=False,
        close_mosaic=10,
        resume=False,
        amp=True,  # Automatic Mixed Precision for faster training
        fraction=1.0,
        profile=False,
        freeze=None,
        # Multi-scale training
        multi_scale=False,
        overlap_mask=True,
        mask_ratio=4,
        dropout=0.0
    )
    
    print("\n‚úì Training completed successfully!")
    print(f"‚úì Best model saved to: {results.save_dir}/weights/best.pt")
    
except RuntimeError as e:
    if "out of memory" in str(e).lower() or "oom" in str(e).lower():
        print("\n‚ö† GPU Out of Memory! Trying with smaller batch size...")
        # Retry with smaller batch size
        torch.cuda.empty_cache()
        default_config['batch'] = max(4, default_config['batch'] // 2)
        print(f"Retrying with batch size: {default_config['batch']}")
        # Re-run training with smaller batch
        # (User should re-run the cell)
    else:
        raise e
except Exception as e:
    print(f"\n‚úó Training failed with error: {e}")
    raise e

‚úì Training on GPU 0: NVIDIA GeForce RTX 4070 Laptop GPU
  GPU Memory: 8.59 GB
Ultralytics 8.4.12  Python-3.13.5 torch-2.6.0+cu124 CUDA:0 (NVIDIA GeForce RTX 4070 Laptop GPU, 8188MiB)
[34m[1mengine\trainer: [0magnostic_nms=False, amp=True, angle=1.0, augment=False, auto_augment=randaugment, batch=16, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, compile=False, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\data.yaml, degrees=0.0, deterministic=True, device=0, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, end2end=None, epochs=100, erasing=0.4, exist_ok=True, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=640, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.15, mode=train, model=yolov8m-seg.pt, mome

## 4. Evaluation & Metrics

In [None]:
# Load the best model from training
best_model_path = f'{WORK_DIR}/runs/segment/sweetpotato_exp/weights/best.pt'
if not os.path.exists(best_model_path):
    # Try to find the latest run
    runs_dir = f'{WORK_DIR}/runs/segment'
    if os.path.exists(runs_dir):
        runs = sorted([d for d in os.listdir(runs_dir) if os.path.isdir(os.path.join(runs_dir, d))])
        if runs:
            latest_run = runs[-1]
            best_model_path = f'{runs_dir}/{latest_run}/weights/best.pt'

if os.path.exists(best_model_path):
    model = YOLO(best_model_path)
    print(f"‚úì Loaded best model: {best_model_path}")
else:
    print("‚ö† Best model not found, using last checkpoint")
    model = YOLO(model_name)  # Fallback to pretrained

‚úì Loaded best model: ./sweetpotato_project/runs/segment/sweetpotato_exp/weights/best.pt


In [None]:
# Evaluate on validation and test sets
print("Evaluating on validation set...")
val_metrics = model.val(data=data_yaml_path, split='val')

print("\nEvaluating on test set...")
test_metrics = model.val(data=data_yaml_path, split='test')

print("\n" + "="*50)
print("VALIDATION METRICS")
print("="*50)
print(f"mAP50 (bbox): {val_metrics.box.map50:.4f}")
print(f"mAP50-95 (bbox): {val_metrics.box.map:.4f}")
print(f"mAP50 (mask): {val_metrics.seg.map50:.4f}")
print(f"mAP50-95 (mask): {val_metrics.seg.map:.4f}")
print(f"Precision: {val_metrics.box.mp:.4f}")
print(f"Recall: {val_metrics.box.mr:.4f}")

print("\n" + "="*50)
print("TEST METRICS")
print("="*50)
print(f"mAP50 (bbox): {test_metrics.box.map50:.4f}")
print(f"mAP50-95 (bbox): {test_metrics.box.map:.4f}")
print(f"mAP50 (mask): {test_metrics.seg.map50:.4f}")
print(f"mAP50-95 (mask): {test_metrics.seg.map:.4f}")
print(f"Precision: {test_metrics.box.mp:.4f}")
print(f"Recall: {test_metrics.box.mr:.4f}")

Evaluating on validation set...
Ultralytics 8.4.12  Python-3.13.5 torch-2.6.0+cu124 CUDA:0 (NVIDIA GeForce RTX 4070 Laptop GPU, 8188MiB)
YOLOv8m-seg summary (fused): 105 layers, 27,224,121 parameters, 0 gradients, 104.3 GFLOPs
[34m[1mval: [0mFast image access  (ping: 0.10.0 ms, read: 174.374.4 MB/s, size: 44.3 KB)
[K[34m[1mval: [0mScanning C:\Users\kensm\farm-photo-outliner\sweetpotato_project\dataset\valid\labels.cache... 5 images, 0 backgrounds, 0 corrupt: 100% ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 5/5 1.2Mit/s 0.0s
[K                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95)     Mask(P          R      mAP50  mAP50-95): 100% ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 1/1 4.2s/it 4.2s
                   all          5          7     0.0342     0.0833     0.0356     0.0108   0.000682     0.0833   0.000918   0.000375
              Diseased          1          1          0          0     0.0711     0.0182          0          0          0          0
        

In [None]:
# Generate confusion matrix
from ultralytics.utils.plotting import Annotator

confusion_matrix_path = f'{WORK_DIR}/runs/segment/sweetpotato_exp/confusion_matrix.png'
if os.path.exists(confusion_matrix_path):
    img = cv2.imread(confusion_matrix_path)
    plt.figure(figsize=(10, 8))
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title('Confusion Matrix', fontsize=16)
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    print("‚úì Confusion matrix displayed")

In [None]:
# Custom metrics: Root count, area coverage, size distribution
def calculate_custom_metrics(results_list, class_names):
    """Calculate root-specific metrics"""
    metrics = {
        'root_counts': [],
        'total_areas': [],
        'avg_areas': [],
        'size_distribution': {'small': 0, 'medium': 0, 'large': 0}
    }
    
    for result in results_list:
        if result.masks is not None:
            root_count = len(result.boxes)
            total_area = 0
            areas = []
            
            for mask in result.masks.data:
                area = mask.sum().item()
                total_area += area
                areas.append(area)
            
            metrics['root_counts'].append(root_count)
            metrics['total_areas'].append(total_area)
            if areas:
                metrics['avg_areas'].append(np.mean(areas))
                
                # Size distribution (based on area percentiles)
                for area in areas:
                    if area < np.percentile(areas, 33):
                        metrics['size_distribution']['small'] += 1
                    elif area < np.percentile(areas, 67):
                        metrics['size_distribution']['medium'] += 1
                    else:
                        metrics['size_distribution']['large'] += 1
    
    return metrics

# Run inference on test set for custom metrics
test_images_dir = os.path.join(DATASET_DIR, 'test', 'images')
if not os.path.exists(test_images_dir):
    test_images_dir = os.path.join(DATASET_DIR, 'test')

test_images = [f for f in os.listdir(test_images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
test_results = []

print(f"Running inference on {len(test_images)} test images...")
for img_file in tqdm(test_images[:20]):  # Limit to 20 for demo
    img_path = os.path.join(test_images_dir, img_file)
    results = model.predict(img_path, conf=0.25, iou=0.45, verbose=False)
    test_results.extend(results)

# Calculate custom metrics
custom_metrics = calculate_custom_metrics(test_results, data_config.get('names', []))

print("\n" + "="*50)
print("CUSTOM ROOT METRICS")
print("="*50)
print(f"Average roots per image: {np.mean(custom_metrics['root_counts']):.2f}")
print(f"Total area coverage (avg): {np.mean(custom_metrics['total_areas']):.2f} pixels")
print(f"Average root area: {np.mean(custom_metrics['avg_areas']):.2f} pixels")
print(f"\nSize Distribution:")
print(f"  Small: {custom_metrics['size_distribution']['small']}")
print(f"  Medium: {custom_metrics['size_distribution']['medium']}")
print(f"  Large: {custom_metrics['size_distribution']['large']}")

Running inference on 3 test images...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:00<00:00,  8.86it/s]


CUSTOM ROOT METRICS
Average roots per image: 1.50
Total area coverage (avg): 164341.00 pixels
Average root area: 153383.25 pixels

Size Distribution:
  Small: 1
  Medium: 0
  Large: 2





In [None]:
# Active Learning: Flag low-confidence predictions
def flag_low_confidence(results_list, confidence_threshold=0.5, output_dir=None):
    """Flag images with low-confidence predictions for re-annotation"""
    flagged_images = []
    
    for result in results_list:
        if result.boxes is not None and len(result.boxes) > 0:
            confidences = result.boxes.conf.cpu().numpy()
            avg_confidence = np.mean(confidences)
            min_confidence = np.min(confidences)
            
            if avg_confidence < confidence_threshold or min_confidence < confidence_threshold * 0.7:
                flagged_images.append({
                    'image': result.path,
                    'avg_confidence': avg_confidence,
                    'min_confidence': min_confidence,
                    'num_detections': len(result.boxes)
                })
    
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        df = pd.DataFrame(flagged_images)
        csv_path = os.path.join(output_dir, 'low_confidence_predictions.csv')
        df.to_csv(csv_path, index=False)
        print(f"‚úì Saved flagged images to {csv_path}")
    
    return flagged_images

# Flag low-confidence predictions
flagged = flag_low_confidence(test_results, confidence_threshold=0.5, 
                              output_dir=f'{WORK_DIR}/runs/segment/sweetpotato_exp')

print(f"\nFlagged {len(flagged)} images for re-annotation")
if flagged:
    print("\nTop 5 lowest confidence predictions:")
    flagged_sorted = sorted(flagged, key=lambda x: x['avg_confidence'])
    for i, item in enumerate(flagged_sorted[:5], 1):
        print(f"{i}. {os.path.basename(item['image'])}: avg_conf={item['avg_confidence']:.3f}")

‚úì Saved flagged images to ./sweetpotato_project/runs/segment/sweetpotato_exp\low_confidence_predictions.csv

Flagged 0 images for re-annotation


## 5. Inference & Visualization

In [None]:
# Run inference on test images
output_dir = f'{WORK_DIR}/outputs/predictions'
os.makedirs(output_dir, exist_ok=True)

test_images = [f for f in os.listdir(test_images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
test_images = test_images[:20]  # Process 20 images

print(f"Running inference on {len(test_images)} test images...")

all_predictions = []

for img_file in tqdm(test_images):
    img_path = os.path.join(test_images_dir, img_file)
    
    # Run prediction
    results = model.predict(
        img_path,
        conf=0.25,
        iou=0.45,
        save=True,
        save_txt=True,
        save_conf=True,
        project=output_dir,
        name='predictions',
        exist_ok=True
    )
    
    # Extract prediction data
    result = results[0]
    if result.boxes is not None:
        for i, box in enumerate(result.boxes):
            bbox = box.xyxy[0].cpu().numpy()  # [x1, y1, x2, y2]
            conf = box.conf[0].cpu().item()
            cls = int(box.cls[0].cpu().item())
            
            # Get mask polygon if available
            mask_polygon = None
            if result.masks is not None and i < len(result.masks.data):
                mask = result.masks.data[i].cpu().numpy()
                # Convert mask to polygon
                contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if contours:
                    # Get largest contour
                    largest_contour = max(contours, key=cv2.contourArea)
                    mask_polygon = largest_contour.reshape(-1, 2).tolist()
            
            all_predictions.append({
                'image_name': img_file,
                'bbox_x1': bbox[0],
                'bbox_y1': bbox[1],
                'bbox_x2': bbox[2],
                'bbox_y2': bbox[3],
                'confidence': conf,
                'class': cls,
                'class_name': data_config.get('names', ['unknown'])[cls] if cls < len(data_config.get('names', [])) else 'unknown',
                'mask_polygon': str(mask_polygon) if mask_polygon else None
            })

print(f"\n‚úì Inference complete. Results saved to {output_dir}")
print(f"‚úì Total predictions: {len(all_predictions)}")

Running inference on 3 test images...


  0%|          | 0/3 [00:00<?, ?it/s]

Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
1 label saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels
Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
1 label saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels


 67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 2/3 [00:00<00:00, 16.60it/s]

Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:00<00:00, 17.54it/s]


‚úì Inference complete. Results saved to ./sweetpotato_project/outputs/predictions
‚úì Total predictions: 3





In [None]:
# Save predictions to CSV
df_predictions = pd.DataFrame(all_predictions)
csv_path = f'{WORK_DIR}/runs/segment/sweetpotato_exp/predictions.csv'
df_predictions.to_csv(csv_path, index=False)
print(f"‚úì Predictions saved to: {csv_path}")
print(f"\nFirst few predictions:")
print(df_predictions.head())

‚úì Predictions saved to: ./sweetpotato_project/runs/segment/sweetpotato_exp/predictions.csv

First few predictions:
                                          image_name     bbox_x1   bbox_y1  \
0  Hr-23_jpg.rf.b6b8a13ec4bf745687738e0663c99023.jpg  243.222412  9.757080   
1  Hr-23_jpg.rf.b6b8a13ec4bf745687738e0663c99023.jpg  272.502380  8.479111   
2  Inf-44_jpg.rf.460ec20282d0e0c9daece65791d1e0ba...    0.000000  2.480560   

      bbox_x2     bbox_y2  confidence  class      class_name  \
0  516.458374  327.337585    0.624851      2  Non-determined   
1  633.675110  515.169922    0.509105      2  Non-determined   
2  638.666870  491.609955    0.578858      0        Diseased   

                                        mask_polygon  
0  [[414, 106], [414, 115], [416, 117], [416, 118...  
1  [[507, 263], [506, 264], [502, 264], [501, 265...  
2  [[0, 2], [0, 133], [1, 133], [2, 134], [2, 136...  


In [None]:
# Visualize predictions on sample images
def visualize_predictions(image_path, results, save_path=None):
    """Visualize predictions with masks and bounding boxes"""
    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    result = results[0]
    
    # Draw masks
    if result.masks is not None:
        for i, mask in enumerate(result.masks.data):
            mask_np = mask.cpu().numpy().astype(np.uint8)
            # Create colored overlay
            color = np.random.randint(0, 255, 3).tolist()
            colored_mask = np.zeros_like(img_rgb)
            colored_mask[mask_np > 0] = color
            img_rgb = cv2.addWeighted(img_rgb, 0.7, colored_mask, 0.3, 0)
    
    # Draw bounding boxes
    if result.boxes is not None:
        for box in result.boxes:
            bbox = box.xyxy[0].cpu().numpy().astype(int)
            conf = box.conf[0].cpu().item()
            cls = int(box.cls[0].cpu().item())
            
            cv2.rectangle(img_rgb, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
            label = f"{data_config.get('names', ['class'])[cls]}: {conf:.2f}"
            cv2.putText(img_rgb, label, (bbox[0], bbox[1]-10), 
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    
    plt.figure(figsize=(12, 8))
    plt.imshow(img_rgb)
    plt.title(f"Predictions: {os.path.basename(image_path)}", fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    
    plt.show()

# Visualize a few sample predictions
sample_images = test_images[:5]
for img_file in sample_images:
    img_path = os.path.join(test_images_dir, img_file)
    results = model.predict(img_path, conf=0.25, iou=0.45, verbose=False)
    save_path = f'{output_dir}/visualizations/{os.path.splitext(img_file)[0]}_pred.png'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    visualize_predictions(img_path, results, save_path)

Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels


<Figure size 1200x800 with 1 Axes>

Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels


<Figure size 1200x800 with 1 Axes>

Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels


<Figure size 1200x800 with 1 Axes>

## 6. Model Export

In [None]:
# Export to ONNX for edge deployment
export_dir = f'{WORK_DIR}/exports'
os.makedirs(export_dir, exist_ok=True)

print("Exporting to ONNX format...")
onnx_path = model.export(format='onnx', imgsz=640, simplify=True, opset=12)
print(f"‚úì ONNX model exported to: {onnx_path}")

# Also export to TorchScript
print("\nExporting to TorchScript format...")
torchscript_path = model.export(format='torchscript', imgsz=640)
print(f"‚úì TorchScript model exported to: {torchscript_path}")

# Export with FP16 quantization for faster inference
print("\nExporting ONNX with FP16 quantization...")
onnx_fp16_path = model.export(format='onnx', imgsz=640, simplify=True, opset=12, half=True)
print(f"‚úì FP16 ONNX model exported to: {onnx_fp16_path}")

Exporting to ONNX format...
Ultralytics 8.4.12  Python-3.13.5 torch-2.6.0+cu124 CPU (13th Gen Intel Core i9-13900H)
 ProTip: Export to OpenVINO format for best performance on Intel hardware. Learn more at https://docs.ultralytics.com/integrations/openvino/

[34m[1mPyTorch:[0m starting from 'sweetpotato_project\runs\segment\sweetpotato_exp\weights\best.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) ((1, 39, 8400), (1, 32, 160, 160)) (312.5 MB)
[31m[1mrequirements:[0m Ultralytics requirements ['onnx>=1.12.0,<2.0.0', 'onnxslim>=0.1.71', 'onnxruntime-gpu'] not found, attempting AutoUpdate...
Collecting onnx<2.0.0,>=1.12.0
  Downloading onnx-1.20.1-cp312-abi3-win_amd64.whl.metadata (8.6 kB)
Collecting onnxslim>=0.1.71
  Downloading onnxslim-0.1.84-py3-none-any.whl.metadata (10 kB)
Collecting onnxruntime-gpu
  Downloading onnxruntime_gpu-1.24.1-cp313-cp313-win_amd64.whl.metadata (5.5 kB)
Collecting ml_dtypes>=0.5.0 (from onnx<2.0.0,>=1.12.0)
  Downloading ml_dtypes-0.5.

In [None]:
# Test ONNX model inference
try:
    import onnxruntime as ort
    
    # Load ONNX model
    session = ort.InferenceSession(onnx_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    
    # Test with a sample image
    test_img_path = os.path.join(test_images_dir, test_images[0])
    test_img = cv2.imread(test_img_path)
    test_img = cv2.resize(test_img, (640, 640))
    test_img = test_img.transpose(2, 0, 1)  # HWC to CHW
    test_img = test_img.astype(np.float32) / 255.0
    test_img = np.expand_dims(test_img, axis=0)
    
    # Run inference
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: test_img})
    
    print(f"‚úì ONNX model inference successful!")
    print(f"  Output shape: {outputs[0].shape}")
    
except Exception as e:
    print(f"‚ö† ONNX inference test failed: {e}")
    print("  Model exported but may need verification")

‚úì ONNX model inference successful!
  Output shape: (1, 39, 8400)


## 7. Download Results

In [None]:
# Package all results for download
results_zip = f'{WORK_DIR}/sweetpotato_training_results.zip'

with zipfile.ZipFile(results_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add best model
    if os.path.exists(best_model_path):
        zipf.write(best_model_path, 'best.pt')
    
    # Add ONNX models
    if os.path.exists(onnx_path):
        zipf.write(onnx_path, 'model.onnx')
    if os.path.exists(onnx_fp16_path):
        zipf.write(onnx_fp16_path, 'model_fp16.onnx')
    
    # Add predictions CSV
    if os.path.exists(csv_path):
        zipf.write(csv_path, 'predictions.csv')
    
    # Add training metrics
    run_dir = f'{WORK_DIR}/runs/segment/sweetpotato_exp'
    if os.path.exists(run_dir):
        for root, dirs, files in os.walk(run_dir):
            for file in files:
                if file.endswith(('.png', '.jpg', '.csv', '.txt', '.json')):
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, run_dir)
                    zipf.write(file_path, f'results/{arcname}')

print(f"‚úì Results packaged to: {results_zip}")
print(f"\nTo download, run:")
print(f"from google.colab import files")
print(f"files.download('{results_zip}')")

# Download results
if IS_COLAB:
    from google.colab import files
    files.download(results_zip)
else:
    print(f"\n‚úì Results saved locally to: {os.path.abspath(results_zip)}")
    print(f"  You can find all outputs in: {os.path.abspath(WORK_DIR)}")

‚úì Results packaged to: ./sweetpotato_project/sweetpotato_training_results.zip

To download, run:
from google.colab import files
files.download('./sweetpotato_project/sweetpotato_training_results.zip')

‚úì Results saved locally to: c:\Users\kensm\farm-photo-outliner\sweetpotato_project\sweetpotato_training_results.zip
  You can find all outputs in: c:\Users\kensm\farm-photo-outliner\sweetpotato_project


## 8. Training Summary

### Model Performance
- **Best mAP@0.5**: Check validation metrics above
- **Model Size**: Check model parameters above
- **Inference Speed**: Run speed benchmark below

### Files Generated
- `best.pt`: Best trained model weights
- `model.onnx`: ONNX export for deployment
- `predictions.csv`: All predictions with bboxes and masks
- `low_confidence_predictions.csv`: Images flagged for re-annotation
- Training logs and visualizations in `runs/segment/sweetpotato_exp/`

### Next Steps
1. Review low-confidence predictions for active learning
2. Fine-tune hyperparameters if mAP < 0.85
3. Test ONNX model on edge devices
4. Compare with Mask R-CNN baseline (see comparison notebook)

In [None]:
# Inference speed benchmark
import time

test_img_path = os.path.join(test_images_dir, test_images[0])
num_runs = 100

print(f"Benchmarking inference speed ({num_runs} runs)...")

# Warmup
for _ in range(10):
    _ = model.predict(test_img_path, verbose=False)

# Benchmark
start_time = time.time()
for _ in range(num_runs):
    _ = model.predict(test_img_path, verbose=False)
end_time = time.time()

avg_time = (end_time - start_time) / num_runs
fps = 1.0 / avg_time

print(f"\n‚úì Average inference time: {avg_time*1000:.2f} ms")
print(f"‚úì FPS: {fps:.2f}")
print(f"‚úì {'Meets' if fps >= 30 else 'Below'} real-time requirement (30+ FPS)")

Benchmarking inference speed (100 runs)...
Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels
Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels
Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions\labels
Results saved to [1mC:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_project\outputs\predictions\predictions[0m
2 labels saved to C:\Users\kensm\farm-photo-outliner\runs\segment\sweetpotato_proj