In [None]:
# Cell 1: Clean Installation
!pip uninstall -y seaborn scipy -q
!pip install numpy==1.24.3 scipy==1.10.1 ultralytics>=8.3.0 roboflow opencv-python pillow matplotlib pyyaml -q
!pip install seaborn
print("‚úÖ Packages installed successfully!")

In [None]:
# Cell 2: GPU Verification (Multi-GPU Support)
# Reports the PyTorch/CUDA setup and defines DEVICE for the training cell:
# a list of GPU ids when several GPUs exist, 0 for a single GPU, or 'cpu'.
import torch

print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")

if not cuda_ok:
    # CPU fallback - training still runs, only much slower.
    print("‚ùå No GPU detected! Training will be very slow.")
    print("Please check: Runtime > Change runtime type > Hardware accelerator > GPU")
    DEVICE = 'cpu'
else:
    gpu_count = torch.cuda.device_count()
    print(f"‚úÖ Found {gpu_count} GPU(s):")
    for gpu_idx in range(gpu_count):
        gpu_name = torch.cuda.get_device_name(gpu_idx)
        gpu_mem_gb = torch.cuda.get_device_properties(gpu_idx).total_memory / 1e9
        print(f"   GPU {gpu_idx}: {gpu_name}")
        print(f"   Memory: {gpu_mem_gb:.2f} GB")
    print(f"CUDA version: {torch.version.cuda}")

    # All GPUs as a list for multi-GPU training, otherwise the single GPU index.
    DEVICE = list(range(gpu_count)) if gpu_count > 1 else 0
    if gpu_count > 1:
        print(f"\nüöÄ Multi-GPU training enabled: Using GPUs {DEVICE}")
    else:
        print(f"\nüìå Single GPU training: Using GPU {DEVICE}")

In [None]:
# Cell 3: Fixed Imports (NO SEABORN - causing errors)
from ultralytics import YOLO
from roboflow import Roboflow
import numpy as np
import cv2
from PIL import Image
import yaml
from pathlib import Path
import os
import random
import shutil
# Visualization - using matplotlib only
import matplotlib.pyplot as plt

print("‚úÖ All imports successful!")

In [None]:
# Dataset setup: copy the read-only Kaggle input dataset into /kaggle/working, point data.yaml's
# split entries at the copy, report image counts per split, and write the updated data.yaml back.
dataset_path = '/kaggle/input/banana-datasets-early-v2/kaggle/working/combined_yolo_dataset'
print("="*60)
print("üìä ANALYZING: YOLO CLASSIFICATION DATASET")
print("="*60)


def _resolve_split_path(split_value, root):
    """Return the data.yaml split entry `split_value` as a path string under `root`.

    Only *leading* '../' segments are stripped (entries look like '../train/images').
    An already-absolute value comes back unchanged, because `root / '/abs/path'` is
    '/abs/path', so re-running this cell on an already-rewritten data.yaml is safe.
    """
    relative = str(split_value)
    while relative.startswith('../'):
        relative = relative[len('../'):]
    return str(root / relative)


def _count_images(directory):
    """Count the .jpg and .png files directly inside `directory` (0 if it does not exist)."""
    directory = Path(directory)
    return sum(len(list(directory.glob(pattern))) for pattern in ('*.jpg', '*.png'))


# Create a working copy in a writable directory
working_dir = Path('/kaggle/working')
dataset_copy_dir = working_dir / 'yolo_classification_dataset'

# Copy the dataset to the working directory if not already copied
if not dataset_copy_dir.exists():
    print(f"üìÅ Copying dataset to {dataset_copy_dir}...")
    shutil.copytree(dataset_path, dataset_copy_dir)
    print("‚úÖ Dataset copied to working directory")
else:
    print("‚úÖ Using existing copy in working directory")

# From here on, only the copy is used
data_yaml_path = dataset_copy_dir / 'data.yaml'
with open(data_yaml_path, 'r') as file:
    data_config = yaml.safe_load(file)

print(f"\nNumber of classes: {data_config['nc']}")
print(f"Class names: {data_config['names']}")

# Point every split data.yaml defines at the working copy (missing/empty entries are left alone).
for split in ('train', 'val', 'test'):
    if data_config.get(split):
        data_config[split] = _resolve_split_path(data_config[split], dataset_copy_dir)

# 'train' is required; 'val' and 'test' are optional and count as 0 when absent.
train_images = _count_images(data_config['train'])
val_images = _count_images(data_config['val']) if data_config.get('val') else 0
test_images = _count_images(data_config['test']) if data_config.get('test') else 0

print(f"\nüìà Dataset Statistics:")
print(f"Training images: {train_images}")
print(f"Validation images: {val_images}")
print(f"Test images: {test_images}")

# Save the updated config into the (writable) copy
data_config['path'] = str(dataset_copy_dir)
with open(data_yaml_path, 'w') as f:
    yaml.dump(data_config, f, default_flow_style=False)

print(f"\n‚úÖ Dataset configuration saved to: {data_yaml_path}")
print(f"Working directory: {dataset_copy_dir}")

In [None]:
# Cell 6: Visualize sample images with annotations (FIXED - Correct colors and class names)

# Bounding-box colour per class id, kept identical to data-labeling-classification.ipynb.
# Tuples are RGB (not OpenCV's BGR) because images are converted with COLOR_BGR2RGB before drawing.
CLASS_COLORS_RGB = dict(enumerate([
    (0, 100, 0),       # 0 Healthy - dark green   (#006400)
    (0, 255, 0),       # 1 Stage1  - green        (#00FF00)
    (144, 238, 144),   # 2 Stage2  - light green  (#90EE90)
    (173, 255, 47),    # 3 Stage3  - yellow green (#ADFF2F)
    (255, 255, 0),     # 4 Stage4  - yellow       (#FFFF00)
    (255, 165, 0),     # 5 Stage5  - orange       (#FFA500)
    (255, 0, 0),       # 6 Stage6  - red          (#FF0000)
]))

def visualize_yolo_annotations(image_path, label_path, class_names_dict):
    """Draw YOLO-format bounding boxes and class labels onto an image.

    Box colours come from the module-level CLASS_COLORS_RGB (gray for unknown ids).

    Args:
        image_path: path to the image file (anything ``str()`` accepts).
        label_path: path to the matching YOLO label file (``str`` or ``pathlib.Path``).
            Each line is ``class_id x_center y_center width height`` with coordinates
            relative to the image size; lines with any other field count are skipped.
        class_names_dict: mapping ``{class_id: name}``; a list/tuple indexed by class id
            (the other layout data.yaml allows) is accepted too. Ids without a name are
            labelled ``Class_<id>``.

    Returns:
        The annotated RGB image (numpy array), the un-annotated RGB image when the label
        file is missing, or None when the image cannot be read or processing fails.
    """
    try:
        img = cv2.imread(str(image_path))
        if img is None:
            print(f"‚ö†Ô∏è Cannot read image: {image_path}")
            return None

        # OpenCV loads BGR; convert so the RGB class colours and matplotlib display agree.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        label_path = Path(label_path)
        if not label_path.exists():
            print(f"‚ö†Ô∏è No label file for: {Path(image_path).name}")
            return img

        if isinstance(class_names_dict, (list, tuple)):
            class_names_dict = dict(enumerate(class_names_dict))

        with open(label_path, 'r') as f:
            annotations = f.readlines()

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.7
        thickness = 2

        for ann in annotations:
            parts = ann.strip().split()
            if len(parts) != 5:
                continue  # blank line or a non-bbox (e.g. polygon) row

            class_id, x_center, y_center, width, height = map(float, parts)
            class_id = int(class_id)

            # Relative centre/size -> absolute pixel corners.
            x1 = int((x_center - width / 2) * w)
            y1 = int((y_center - height / 2) * h)
            x2 = int((x_center + width / 2) * w)
            y2 = int((y_center + height / 2) * h)

            color = CLASS_COLORS_RGB.get(class_id, (128, 128, 128))  # gray for unknown ids
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)

            label = str(class_names_dict.get(class_id, f"Class_{class_id}"))
            (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
            # The label normally sits above the box; for boxes touching the top edge, push it
            # down so the filled background and text stay inside the image.
            label_bottom = max(y1, text_height + 10)
            cv2.rectangle(img, (x1, label_bottom - text_height - 10),
                          (x1 + text_width + 10, label_bottom), color, -1)
            cv2.putText(img, label, (x1 + 5, label_bottom - 5), font, font_scale, (0, 0, 0), thickness)

        return img

    except Exception as e:
        # Best-effort preview: report and skip this image instead of aborting the whole grid.
        print(f"Error processing {image_path}: {e}")
        return None

# Resolve the training image/label directories from data_config['train'], which the dataset-setup
# cell normally rewrites to an absolute path inside the working copy.
print("üîç Looking for images...\n")


def _labels_dir_for(images_dir):
    """Return the YOLO label directory paired with `images_dir`.

    Swaps the last path component named 'images' for 'labels'
    (<root>/train/images -> <root>/train/labels, <root>/images/train -> <root>/labels/train).
    Without an 'images' component, falls back to <images_dir>/../../train/labels.
    """
    parts = list(images_dir.parts)
    for idx in range(len(parts) - 1, -1, -1):
        if parts[idx] == 'images':
            parts[idx] = 'labels'
            return Path(*parts)
    return images_dir.parent.parent / 'train' / 'labels'


train_path = Path(data_config['train'])
base_path = train_path.parent.parent  # dataset root for the usual <root>/train/images layout
train_img_dir = train_path
train_label_dir = _labels_dir_for(train_img_dir)

print(f"Base dataset path: {base_path}")
print(f"Image directory: {train_img_dir}")
print(f"Label directory: {train_label_dir}")
print(f"Image directory exists: {train_img_dir.exists()}")
print(f"Label directory exists: {train_label_dir.exists()}")

# Get ALL image files (jpg and png)
all_images = list(train_img_dir.glob('*.jpg')) + list(train_img_dir.glob('*.png'))
print(f"Total images found: {len(all_images)}")

if len(all_images) == 0:
    print("‚ùå No images found! Check the paths.")
    print("Trying alternative path resolution...")
    alt_img_dir = base_path / 'train' / 'images'
    if alt_img_dir == train_img_dir:
        print("Alternative path is identical to the configured one; nothing else to try.")
    else:
        train_img_dir = alt_img_dir
        train_label_dir = base_path / 'train' / 'labels'
        all_images = list(train_img_dir.glob('*.jpg')) + list(train_img_dir.glob('*.png'))
        print(f"Alternative path - Images found: {len(all_images)}")
        print(f"Alternative image dir: {train_img_dir}")
        print(f"Alternative label dir: {train_label_dir}")

def _label_class_ids(label_file, num_classes):
    """Return the class ids listed in YOLO label file `label_file`, in file order.

    Only ids with 0 <= id < num_classes are kept. A missing file or blank lines yield no ids;
    ids are parsed through float() so values written as e.g. '3.0' are accepted.
    """
    if not label_file.exists():
        return []
    class_ids = []
    with open(label_file, 'r') as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                continue
            class_id = int(float(fields[0]))
            if 0 <= class_id < num_classes:
                class_ids.append(class_id)
    return class_ids


if len(all_images) > 0:
    # Class names from data.yaml, normalised to {class_id: name} (data.yaml may store a list).
    raw_names = data_config['names']
    class_names_map = dict(enumerate(raw_names)) if isinstance(raw_names, (list, tuple)) else dict(raw_names)
    # Number of classes comes from data.yaml rather than being hard-coded.
    num_stage_classes = int(data_config.get('nc', len(class_names_map)))
    max_samples = 4  # size of the 2x2 preview grid

    # Group images by the stages they contain; each label file is parsed once and cached.
    print("üìã Grouping images by stage...")
    images_by_stage = {stage_id: [] for stage_id in range(num_stage_classes)}
    label_ids_by_image = {}
    for img_path in all_images:
        ids = _label_class_ids(train_label_dir / f"{img_path.stem}.txt", num_stage_classes)
        label_ids_by_image[img_path] = ids
        for class_id in dict.fromkeys(ids):  # unique ids -> image listed once per stage
            images_by_stage[class_id].append(img_path)

    print("\nüìä Available images per stage:")
    for stage_id in range(num_stage_classes):
        stage_name = class_names_map.get(stage_id, f"Stage{stage_id}")
        print(f"  {stage_name}: {len(images_by_stage[stage_id])} images")

    # One image per stage. With more stages than grid cells, show a random subset of stages
    # (kept in stage order) instead of always truncating to the first four.
    stages_to_show = [s for s in range(num_stage_classes) if images_by_stage[s]]
    if len(stages_to_show) > max_samples:
        stages_to_show = sorted(random.sample(stages_to_show, max_samples))

    sample_images = []
    selected_image_paths = set()  # avoid picking the same image for two stages
    for stage_id in stages_to_show:
        candidates = [img for img in images_by_stage[stage_id] if img not in selected_image_paths]
        if candidates:
            selected_img = random.choice(candidates)
            sample_images.append(selected_img)
            selected_image_paths.add(selected_img)
            stage_name = class_names_map.get(stage_id, f"Stage{stage_id}")
            print(f"‚úì Selected {stage_name} sample: {selected_img.name}")

    # Fewer stages than grid cells: top up with random, not-yet-selected images.
    if len(sample_images) < max_samples:
        remaining_needed = max_samples - len(sample_images)
        remaining_images = [img for img in all_images if img not in selected_image_paths]
        if remaining_images:
            additional_samples = random.sample(remaining_images, min(remaining_needed, len(remaining_images)))
            sample_images.extend(additional_samples)
            selected_image_paths.update(additional_samples)
            print(f"‚úì Added {len(additional_samples)} additional random samples")

    sample_images = sample_images[:max_samples]
    print(f"\n‚úÖ Selected {len(sample_images)} diverse samples from different stages")

    fig, axes = plt.subplots(2, 2, figsize=(16, 14))
    axes = axes.ravel()
    for ax in axes:
        ax.axis('off')  # cells that never receive an image stay blank

    images_processed = 0
    for idx, img_path in enumerate(sample_images):
        label_path = train_label_dir / f"{img_path.stem}.txt"
        ids = label_ids_by_image[img_path]
        # Title uses the last class listed in the label file; all stages are printed below.
        primary_stage = class_names_map.get(ids[-1], f"Stage{ids[-1]}") if ids else "Unknown"

        img = visualize_yolo_annotations(img_path, label_path, class_names_map)
        if img is None:
            continue

        axes[idx].imshow(img)
        axes[idx].set_title(f'Sample {idx+1}: {primary_stage}\n{img_path.name}', fontsize=10, fontweight='bold')
        images_processed += 1

        if label_path.exists():
            with open(label_path, 'r') as f:
                num_objects = len(f.readlines())
            stages_str = ", ".join(class_names_map.get(sid, f"Stage{sid}") for sid in sorted(set(ids)))
            print(f"‚úì {img_path.name}: {num_objects} object(s) - Stages: {stages_str}")

    plt.tight_layout()
    plt.savefig('sample_annotations.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n‚úÖ Processed {images_processed} images")
    print("‚úÖ Visualization saved as 'sample_annotations.png'")

    # Class distribution for every class (sampled from the first 100 label files).
    print("\nüìä Quick Class Distribution Check (first 100 labels):")
    class_counts = dict.fromkeys(range(num_stage_classes), 0)
    for label_file in list(train_label_dir.glob('*.txt'))[:100]:
        for class_id in _label_class_ids(label_file, num_stage_classes):
            class_counts[class_id] += 1

    for class_id in range(num_stage_classes):
        class_name = class_names_map.get(class_id, f"Class_{class_id}")
        print(f"  {class_name}: {class_counts[class_id]} instances")
else:
    print("\n‚ùå Could not find images. Please verify the dataset path structure.")

In [None]:
# Cell 7: Setup YOLO Model

# Pretrained YOLO12 nano checkpoint. Swap for a larger variant (e.g. 'yolo12s.pt') or a YOLO11
# checkpoint (e.g. 'yolo11n.pt') to change model size/family.
MODEL_WEIGHTS = 'yolo12n.pt'
model = YOLO(MODEL_WEIGHTS)

print("‚úÖ YOLO model loaded successfully!")
print(f"Model type: {type(model)}")

In [None]:
# Cell 8: Train YOLO Model with Early Detection Optimization
print("üöÄ TRAINING YOLO MODEL FOR EARLY DISEASE DETECTION")
print("=" * 60)
data_yaml_path = '/kaggle/working/yolo_classification_dataset/data.yaml'

# Use the device chosen in Cell 2 (a list of GPU ids there means multi-GPU training);
# fall back to one GPU / CPU when Cell 2 has not run in this kernel.
TRAIN_DEVICE = DEVICE if 'DEVICE' in globals() else (0 if torch.cuda.is_available() else 'cpu')
# The simplified retry always uses a single device, so it can recover from multi-GPU failures.
FALLBACK_DEVICE = 0 if torch.cuda.is_available() else 'cpu'
RUN_PROJECT = 'runs/detect'
# The run name says "yolo11" although Cell 7 loads YOLO12 weights; the evaluation and
# overfitting cells look runs up by this exact directory name, so it is kept unchanged.
RUN_NAME = 'banana_pest_disease_yolo11'

training_config = {
    # Basic settings
    'data': data_yaml_path,
    'epochs': 150,
    'imgsz': 736,
    'batch': 40,               # NOTE(review): with multiple GPUs this likely must divide evenly by the GPU count - confirm
    'patience': 15,
    'save': True,
    'device': TRAIN_DEVICE,

    # Optimizer
    'optimizer': 'AdamW',
    'lr0': 0.001,              # Lower LR for fine-tuning
    'lrf': 0.01,               # Final learning rate = lr0 * lrf
    'momentum': 0.937,         # SGD momentum / Adam beta1
    'weight_decay': 0.0005,
    'warmup_epochs': 5,        # Longer warmup for stability
    'warmup_momentum': 0.8,
    'warmup_bias_lr': 0.1,

    # Loss weights
    'box': 7.5,                # Box regression loss gain
    'cls': 0.5,                # Classification loss gain (disease stages)
    'dfl': 1.5,                # Distribution focal loss gain

    # Data augmentation
    'hsv_h': 0.015,            # Hue variation (lighting conditions)
    'hsv_s': 0.7,              # Saturation (camera settings)
    'hsv_v': 0.4,              # Brightness (outdoor/indoor variations)
    'degrees': 10.0,           # Rotation (leaves at different angles)
    'translate': 0.1,          # Translation
    'fliplr': 0.5,             # Horizontal flip
    'scale': 0.7,              # Scale gain (images resized by 0.3x-1.7x)
    'shear': 2.0,              # Small shear transformation
    'perspective': 0.0001,     # Slight perspective change
    'flipud': 0.0,             # No vertical flip
    'mosaic': 1.0,             # Mosaic augmentation
    'mixup': 0.1,              # Mixup
    'copy_paste': 0.1,         # Copy-paste augmentation
    'auto_augment': 'randaugment',  # NOTE(review): Ultralytics documents this for classification training; likely ignored here
    'erasing': 0.4,            # NOTE(review): Ultralytics documents this for classification training; likely ignored here

    # Scale / NMS / schedule
    'multi_scale': False,      # Keep single scale for consistency
    'conf': 0.25,              # NOTE(review): also used by in-training validation; mAP is normally measured at a much lower conf - confirm 0.25 is intended
    'iou': 0.7,                # IoU threshold for NMS
    'close_mosaic': 15,        # Disable mosaic for the last 15 epochs
    'cos_lr': True,            # Cosine learning-rate schedule

    'nbs': 64,                 # Nominal batch size for loss scaling
    'overlap_mask': True,      # Segmentation-only setting; no effect on detection
    'mask_ratio': 4,           # Segmentation-only setting; no effect on detection
    'dropout': 0.0,            # Classification-only setting

    'val': True,               # Validate during training
    'plots': True,             # Generate training plots
    'save_json': True,         # Save results in JSON
    'verbose': True,
    'deterministic': False,    # Faster training (set True for reproducibility)

    # Project settings
    'name': RUN_NAME,
    'project': RUN_PROJECT,
    'exist_ok': True,
    'workers': 8,              # Data-loading worker processes
}

print("üìã TRAINING CONFIGURATION:")
for key, value in training_config.items():
    print(f"   {key}: {value}")

print(f"\n‚è≥ Starting training with {data_config['nc']} classes:")
# data.yaml may store names as a list or as an {id: name} mapping.
class_names = data_config['names']
name_items = enumerate(class_names) if isinstance(class_names, (list, tuple)) else sorted(class_names.items())
for class_id, class_name in name_items:
    print(f"   Class {class_id}: {class_name}")

# Reset so a `results` object left over from an earlier execution is never summarised,
# and track the run directory actually used (main or fallback).
results = None
training_finished = False
run_dir = Path(RUN_PROJECT) / RUN_NAME

try:
    results = model.train(**training_config)
    training_finished = True
    print("‚úÖ Training completed successfully!")

    if hasattr(results, 'save_dir'):
        print(f"\nüìÅ Model saved to: {results.save_dir}")
        print(f"üìÅ Best model: {results.save_dir / 'weights' / 'best.pt'}")
        print(f"üìÅ Last model: {results.save_dir / 'weights' / 'last.pt'}")

except Exception as e:
    print(f"‚ùå Training error: {e}")
    print("üîÑ Trying with simplified configuration...")

    # Minimal fallback configuration (short run, small batch, single device)
    simple_config = {
        'data': data_yaml_path,
        'epochs': 10,
        'imgsz': 640,
        'batch': 8,
        'patience': 10,
        'save': True,
        'device': FALLBACK_DEVICE,
        'name': f'{RUN_NAME}_simple',
        'project': RUN_PROJECT,
        'exist_ok': True,
    }
    run_dir = Path(RUN_PROJECT) / simple_config['name']

    results = model.train(**simple_config)
    training_finished = True
    print("‚úÖ Training completed with simplified config!")

# Final summary
if training_finished:
    print("\n" + "="*60)
    print("üéâ TRAINING SUMMARY")
    print("="*60)
    if hasattr(results, 'save_dir'):
        print(f"üìÅ Output directory: {results.save_dir}")
        print(f"üìä View training plots: {results.save_dir / 'results.png'}")
        print(f"üìä View confusion matrix: {results.save_dir / 'confusion_matrix.png'}")
    else:
        print(f"üìÅ Output directory: {run_dir}")
    print("="*60)

In [None]:
# ============================================
# CELL 9: COMPREHENSIVE PERFORMANCE EVALUATION
# ============================================

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ultralytics import YOLO

print("üìä COMPREHENSIVE PERFORMANCE EVALUATION")
print("=" * 80)

# ============================================
# VERIFY DEPENDENCIES FROM PREVIOUS CELLS
# ============================================
# Ensure data_config and data_yaml_path are available
if 'data_config' not in globals():
    print("‚ö†Ô∏è data_config not found. Loading from data.yaml...")
    data_yaml_path = Path('/kaggle/working/yolo_classification_dataset/data.yaml')
    import yaml
    with open(data_yaml_path, 'r') as f:
        data_config = yaml.safe_load(f)
else:
    # Convert Path to string if needed
    if isinstance(data_yaml_path, Path):
        data_yaml_path = str(data_yaml_path)
    else:
        data_yaml_path = str(data_yaml_path)

# Class names from data_config, normalised to {class_id: name} (data.yaml may store a list).
_raw_names = data_config['names']
class_names_map = dict(enumerate(_raw_names)) if isinstance(_raw_names, (list, tuple)) else dict(_raw_names)
num_classes = int(data_config.get('nc', len(class_names_map)))

# Class colors matching Cell 6 (RGB tuples)
CLASS_COLORS_RGB = {
    0: (0, 100, 0),        # Healthy: Dark Green
    1: (0, 255, 0),        # Stage1: Green
    2: (144, 238, 144),    # Stage2: Light Green
    3: (173, 255, 47),     # Stage3: Yellow Green
    4: (255, 255, 0),      # Stage4: Yellow
    5: (255, 165, 0),      # Stage5: Orange
    6: (255, 0, 0)         # Stage6: Red
}

def rgb_to_hex(rgb):
    """Format an (r, g, b) triple of 0-255 integers as a lowercase '#rrggbb' colour string for matplotlib."""
    red, green, blue = rgb
    return f'#{red:02x}{green:02x}{blue:02x}'

# One hex colour per class id; ids without a configured colour (nc > 7) fall back to gray.
class_colors = [rgb_to_hex(CLASS_COLORS_RGB.get(i, (128, 128, 128))) for i in range(num_classes)]

# ============================================
# LOAD TRAINED MODEL
# ============================================
print("\n1Ô∏è‚É£ Loading trained model...")

possible_paths = []
# Prefer the exact run directory reported by the training cell when it ran in this kernel.
if 'results' in globals() and hasattr(results, 'save_dir'):
    _run_weights = Path(results.save_dir) / 'weights'
    possible_paths += [str(_run_weights / 'best.pt'), str(_run_weights / 'last.pt')]
# Known run names (best.pt before last.pt), including the simplified fallback run from training.
possible_paths += [
    'runs/detect/banana_pest_disease_yolo11/weights/best.pt',
    'runs/detect/banana_pest_disease_yolo12/weights/best.pt',
    'runs/detect/banana_pest_disease_yolo11_simple/weights/best.pt',
    'runs/detect/banana_pest_disease_yolo11/weights/last.pt',
    'runs/detect/banana_pest_disease_yolo12/weights/last.pt',
    'runs/detect/banana_pest_disease_yolo11_simple/weights/last.pt',
]

model_path = next((path for path in possible_paths if Path(path).exists()), None)

if model_path is None:
    print("‚ùå No trained model found! Train the model first (Cell 8)")
    raise FileNotFoundError("Model not found. Please run Cell 8 to train the model first.")

print(f"‚úÖ Model loaded from: {model_path}")
trained_model = YOLO(model_path)

# ============================================
# RUN VALIDATION ON TEST SET
# ============================================
print("\n2Ô∏è‚É£ Running validation on test set...")
print("=" * 80)

# model.val() evaluates the 'val' split unless `split=` is passed, so request 'test' explicitly
# whenever data.yaml defines it (otherwise fall back to 'val').
eval_split = 'test' if data_config.get('test') else 'val'
# Evaluate at the training resolution when the training cell's config is available.
eval_imgsz = training_config.get('imgsz', 640) if 'training_config' in globals() else 640
# mAP integrates over the whole confidence range, so keep the threshold near Ultralytics' val
# default (0.001); a deployment-style 0.25 truncates the PR curve and under-reports mAP.
EVAL_CONF = 0.001
print(f"   split={eval_split}, imgsz={eval_imgsz}, conf={EVAL_CONF}")

try:
    test_results = trained_model.val(
        data=data_yaml_path,
        split=eval_split,
        batch=16,
        imgsz=eval_imgsz,
        conf=EVAL_CONF,
        iou=0.7,
        verbose=True
    )
except Exception as e:
    print(f"‚ö†Ô∏è Error during validation: {e}")
    print("üîÑ Trying with default settings...")
    test_results = trained_model.val(data=data_yaml_path, split=eval_split, verbose=True)

# ============================================
# EXTRACT OVERALL METRICS
# ============================================
print("\nüìà TEST SET PERFORMANCE METRICS")
print("=" * 80)

print("\n3Ô∏è‚É£ OVERALL METRICS:")


def _extract_overall_metrics(val_results):
    """Return (mAP50, mAP50-95, mean precision, mean recall) from an Ultralytics validation result.

    Reads the `box` metrics; if any of them raises AttributeError, falls back to the
    `results_dict` keys (0.0 for missing ones), and raises ValueError when that is unavailable too.
    """
    try:
        return (
            float(val_results.box.map50),
            float(val_results.box.map),
            float(val_results.box.mp),
            float(val_results.box.mr),
        )
    except AttributeError as e:
        print(f"‚ö†Ô∏è Error extracting metrics: {e}")
        print("   Using alternative metric extraction...")
        if not hasattr(val_results, 'results_dict'):
            raise ValueError("Cannot extract metrics from validation results")
        metrics_dict = val_results.results_dict
        return (
            metrics_dict.get('metrics/mAP50(B)', 0.0),
            metrics_dict.get('metrics/mAP50-95(B)', 0.0),
            metrics_dict.get('metrics/precision(B)', 0.0),
            metrics_dict.get('metrics/recall(B)', 0.0),
        )


map50, map50_95, precision, recall = _extract_overall_metrics(test_results)

# F1 = harmonic mean of precision and recall (0 when both are 0).
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

for metric_label, metric_value in (
    ('mAP50', map50),
    ('mAP50-95', map50_95),
    ('Precision', precision),
    ('Recall', recall),
    ('F1 Score', f1_score),
):
    metric_heading = f"{metric_label}:"
    print(f"   ‚Ä¢ {metric_heading:<11}{metric_value:.3f} ({metric_value*100:.1f}%)")

# ============================================
# PER-CLASS METRICS
# ============================================
print("\n4Ô∏è‚É£ PER-CLASS PERFORMANCE:")
print("=" * 80)


def _per_class_vector(values):
    """Coerce a per-class AP container to a 1-D float array, averaging any trailing IoU axis."""
    arr = np.asarray(values, dtype=float)
    return arr.mean(axis=-1) if arr.ndim > 1 else arr.reshape(-1)


# In Ultralytics' box metrics, position k of `ap50` / `ap` belongs to class id
# `ap_class_index[k]` (only classes present in the evaluated split are listed), which is not
# necessarily k. Map positions back to class ids before building the table.
ap50_by_class = {}
ap_by_class = {}
try:
    box_metrics = test_results.box
    ap50_values = _per_class_vector(box_metrics.ap50)
    ap_values = _per_class_vector(box_metrics.ap)
    class_index = getattr(box_metrics, 'ap_class_index', None)
    if class_index is None or len(class_index) != len(ap50_values):
        # NOTE(review): no usable index - assume positions are class ids.
        class_index = range(len(ap50_values))
    for pos, cid in enumerate(class_index):
        ap50_by_class[int(cid)] = float(ap50_values[pos])
        if pos < len(ap_values):
            ap_by_class[int(cid)] = float(ap_values[pos])
except AttributeError:
    print("‚ö†Ô∏è Per-class metrics not available in this format")

# Table header
print(f"\n{'Class':<15} | {'mAP50':<10} | {'mAP50-95':<10}")
print("-" * 45)

class_metrics = {}

for class_id in range(num_classes):
    class_name = class_names_map.get(class_id, f"Class_{class_id}")
    # Classes with no instances in the evaluated split have no AP; they are reported as 0.
    ap50 = ap50_by_class.get(class_id, 0.0)
    ap = ap_by_class.get(class_id, 0.0)

    class_metrics[class_name] = {
        'ap50': ap50,
        'ap': ap
    }

    note = "  (no instances in split)" if ap50_by_class and class_id not in ap50_by_class else ""
    print(f"{class_name:<15} | {ap50:<10.3f} | {ap:<10.3f}{note}")

# ============================================
# VISUALIZATION: PER-CLASS PERFORMANCE
# ============================================
print("\n5Ô∏è‚É£ Generating per-class performance visualization...")

fig, axes = plt.subplots(1, 2, figsize=(16, 7))
fig.suptitle('üéØ Per-Class Performance Analysis (Test Set)', 
             fontsize=16, fontweight='bold')

# Same name fallback as the per-class table, so every entry is a key of class_metrics.
classes = [class_names_map.get(i, f"Class_{i}") for i in range(num_classes)]

# Extract metrics for plotting
ap50s = [class_metrics[c]['ap50'] for c in classes]
aps = [class_metrics[c]['ap'] for c in classes]

# One bar panel per metric: (axis, per-class values, metric label).
for ax, values, metric_label in ((axes[0], ap50s, 'mAP50'), (axes[1], aps, 'mAP50-95')):
    bars = ax.bar(classes, values, color=class_colors[:len(classes)])
    ax.set_ylabel(metric_label, fontsize=12, fontweight='bold')
    ax.set_title(f'{metric_label} per Disease Class', fontsize=13, fontweight='bold')
    ax.set_ylim([0, 1.05])
    ax.grid(True, alpha=0.3, axis='y', linestyle='--')
    # Bars sit at positions 0..n-1; fixing the ticks first avoids matplotlib's warning about
    # set_xticklabels() being used without a fixed set of ticks.
    ax.set_xticks(range(len(classes)))
    ax.set_xticklabels(classes, rotation=15, ha='right', fontsize=10)

    # Value label above each non-zero bar
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                    f'{height:.3f}', ha='center', va='bottom',
                    fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig('per_class_performance.png', dpi=300, bbox_inches='tight')
print("‚úÖ Visualization saved as 'per_class_performance.png'")
plt.show()

# ============================================
# DETAILED COMPARISON TABLE
# ============================================
print("\n6Ô∏è‚É£ DETAILED METRICS COMPARISON:")
print("=" * 80)

# Overall value next to the best and worst class for each metric
print(f"\n{'Metric':<20} | {'Overall':<12} | {'Best Class':<20} | {'Worst Class':<20}")
print("-" * 75)


def _best_and_worst(metrics_by_class, metric_key):
    """Return the (name, metrics) items with the highest and lowest `metric_key`; ties keep the first seen."""
    def score(item):
        return item[1][metric_key]
    entries = metrics_by_class.items()
    return max(entries, key=score), min(entries, key=score)


if class_metrics:
    best_ap50_class, worst_ap50_class = _best_and_worst(class_metrics, 'ap50')
    best_ap_class, worst_ap_class = _best_and_worst(class_metrics, 'ap')

    for row_label, overall_value, metric_key, best, worst in (
        ('mAP50', map50, 'ap50', best_ap50_class, worst_ap50_class),
        ('mAP50-95', map50_95, 'ap', best_ap_class, worst_ap_class),
    ):
        print(f"{row_label:<20} | {overall_value:<12.3f} | {best[0]:<20} ({best[1][metric_key]:.3f}) | {worst[0]:<20} ({worst[1][metric_key]:.3f})")

# ============================================
# PERFORMANCE INTERPRETATION
# ============================================
print("\n7Ô∏è‚É£ PERFORMANCE INTERPRETATION:")
print("=" * 80)

# Overall verdict from mAP50 (heuristic cut-offs).
if map50 > 0.8:
    status = "üéâ EXCELLENT"
    interpretation = "Model performing very well on test set! Ready for deployment."
elif map50 > 0.7:
    status = "‚úÖ GOOD"
    interpretation = "Model shows strong performance. Consider fine-tuning for production."
elif map50 > 0.6:
    status = "üìä ACCEPTABLE"
    interpretation = "Model is functional but has room for improvement."
else:
    status = "‚ö†Ô∏è NEEDS IMPROVEMENT"
    interpretation = "Consider more training, data augmentation, or larger model."

print(f"\n   Status: {status}")
print(f"   {interpretation}")

# Precision/Recall balance. Every combination gets a verdict: both high, both low
# (a threshold change alone cannot fix that), one side low, or the in-between band.
print("\n   Balance Analysis:")
if precision > 0.7 and recall > 0.7:
    print(f"   ‚úÖ Well-balanced precision ({precision:.3f}) and recall ({recall:.3f})")
elif precision < 0.6 and recall < 0.6:
    print(f"   ‚ö†Ô∏è Low precision ({precision:.3f}) - Too many false positives")
    print(f"   ‚ö†Ô∏è Low recall ({recall:.3f}) - Missing many true positives")
    print("   Recommendation: both are low, so tuning the confidence threshold will not help - improve the model/data instead")
elif precision < 0.6:
    print(f"   ‚ö†Ô∏è Low precision ({precision:.3f}) - Too many false positives")
    print(f"   üí° Recommendation: Increase confidence threshold (e.g., conf=0.35)")
elif recall < 0.6:
    print(f"   ‚ö†Ô∏è Low recall ({recall:.3f}) - Missing many true positives")
    print(f"   üí° Recommendation: Decrease confidence threshold (e.g., conf=0.20)")
else:
    print(f"   Moderate precision ({precision:.3f}) and recall ({recall:.3f}) - neither is clearly the bottleneck")

# Class-specific insights
if class_metrics:
    print(f"\n   Class-Specific Insights:")
    print(f"   ‚Ä¢ Best performing: {best_ap50_class[0]} (mAP50: {best_ap50_class[1]['ap50']:.3f})")
    print(f"   ‚Ä¢ Needs attention: {worst_ap50_class[0]} (mAP50: {worst_ap50_class[1]['ap50']:.3f})")

    # Classes below a fixed mAP50 threshold
    threshold = 0.5
    low_performers = [name for name, metrics in class_metrics.items()
                      if metrics['ap50'] < threshold]
    if low_performers:
        print(f"   ‚Ä¢ Classes below {threshold} mAP50: {', '.join(low_performers)}")
        print(f"   üí° Consider: More training data or class-specific augmentation for these classes")

# ============================================
# FINAL SUMMARY
# ============================================
# Validation artefacts (confusion matrix, curves) are written to the validator's save_dir, which
# may be auto-incremented between runs; report that directory rather than a fixed path.
eval_dir = getattr(test_results, 'save_dir', 'runs/detect/val')

print("\n" + "=" * 80)
print("üìã EVALUATION SUMMARY")
print("=" * 80)
print(f"""
‚úÖ TEST SET RESULTS:

   ‚Ä¢ Dataset: {data_yaml_path}
   ‚Ä¢ Model: {model_path}
   ‚Ä¢ Classes: {num_classes} ({', '.join(classes)})

üìä OVERALL PERFORMANCE:

   ‚Ä¢ mAP50:     {map50:.3f} ({map50*100:.1f}%)
   ‚Ä¢ mAP50-95:  {map50_95:.3f} ({map50_95*100:.1f}%)
   ‚Ä¢ Precision: {precision:.3f} ({precision*100:.1f}%)
   ‚Ä¢ Recall:    {recall:.3f} ({recall*100:.1f}%)
   ‚Ä¢ F1 Score:  {f1_score:.3f} ({f1_score*100:.1f}%)

üìÅ OUTPUTS SAVED:

   ‚Ä¢ per_class_performance.png - Per-class visualization
   ‚Ä¢ Confusion matrix: {eval_dir}/confusion_matrix.png
   ‚Ä¢ Other metrics: {eval_dir}/

""")
print("=" * 80)
print("‚úÖ Evaluation completed successfully!")
print("=" * 80)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob

# ============================================
# OVERFITTING & UNDERFITTING DETECTION (OPTIMIZED)
# ============================================
print("üîç DETECTING OVERFITTING & UNDERFITTING")
print("=" * 80)

# ============================================
# SMART PATH FINDING
# ============================================
def find_results_csv():
    """Locate the training-run ``results.csv``.

    Lookup order:
      1. A fixed list of known run directories under ``runs/detect``; the first
         one whose ``results.csv`` exists wins.
      2. Glob patterns ``runs/detect/*/results.csv`` then
         ``runs/detect/*/*/results.csv``; for the first pattern that matches
         anything, the most recently modified file is returned.

    Returns:
        pathlib.Path to the CSV, or None when nothing is found.
    """
    known_runs = (
        'banana_pest_disease_yolo11',
        'banana_pest_disease_yolo12',
        'banana_pest_disease_yolo11_simple',
    )
    for run_name in known_runs:
        candidate = Path('runs/detect') / run_name / 'results.csv'
        if candidate.exists():
            return candidate

    for pattern in ('runs/detect/*/results.csv', 'runs/detect/*/*/results.csv'):
        hits = glob.glob(pattern)
        if not hits:
            continue
        # Newest file (by modification time) among this pattern's matches
        return Path(max(hits, key=lambda hit: Path(hit).stat().st_mtime))

    return None

# Locate the per-epoch training log; stop the cell early if training never wrote one.
results_csv = find_results_csv()

if results_csv is None or not results_csv.exists():
    print("‚ùå Results CSV not found!")
    print("   Searched in:")
    print("   ‚Ä¢ runs/detect/banana_pest_disease_yolo11/")
    print("   ‚Ä¢ runs/detect/banana_pest_disease_yolo12/")
    print("   ‚Ä¢ runs/detect/*/")
    print("\nüí° Make sure training has completed and results.csv exists.")
    raise FileNotFoundError("results.csv not found. Please run training first.")

results_dir = results_csv.parent
print(f"‚úÖ Found results: {results_csv}")
print(f"üìÅ Results directory: {results_dir}\n")

# Load the log with a 0-based RangeIndex and whitespace-stripped column names
# (the names may carry padding); any failure is reported, then re-raised.
try:
    df = pd.read_csv(results_csv).reset_index(drop=True)
    df.columns = df.columns.str.strip()

    if not len(df):
        raise ValueError("Results CSV is empty!")

    print(f"‚úÖ Loaded {len(df)} epochs")
    print(f"üìä Columns found: {df.shape[1]}")
except Exception as load_error:
    print(f"‚ùå Error loading results: {load_error}")
    raise

# ============================================
# SMART COLUMN DETECTION
# ============================================
def find_column(df, possible_names):
    """Return the first column of ``df`` whose name is in ``possible_names``.

    Candidates are checked in the order given, so earlier names take priority.

    Args:
        df: pandas DataFrame to search.
        possible_names: iterable of candidate column names.

    Returns:
        The matching column as a pandas Series, or None if no candidate exists.
    """
    present = [candidate for candidate in possible_names if candidate in df.columns]
    return df[present[0]] if present else None

# Resolve each metric column, trying alternative column spellings in order;
# optional columns that are missing end up as None.
train_box = find_column(df, ('train/box_loss', 'train_box_loss', 'box_loss'))
train_cls = find_column(df, ('train/cls_loss', 'train_cls_loss', 'cls_loss'))
val_box = find_column(df, ('val/box_loss', 'val_box_loss'))
val_cls = find_column(df, ('val/cls_loss', 'val_cls_loss'))

map50 = find_column(df, ('metrics/mAP50(B)', 'metrics/mAP50', 'mAP50', 'map50'))
map50_95 = find_column(df, ('metrics/mAP50-95(B)', 'metrics/mAP50-95', 'mAP50-95', 'map50_95'))
precision = find_column(df, ('metrics/precision(B)', 'metrics/precision', 'precision'))
recall = find_column(df, ('metrics/recall(B)', 'metrics/recall', 'recall'))

# Train box loss and mAP50 feed every plot below, so they are mandatory.
if train_box is None:
    raise ValueError("Required column 'train/box_loss' not found!")
if map50 is None:
    raise ValueError("Required column 'metrics/mAP50(B)' not found!")

print("‚úÖ Required columns found\n")

# ============================================
# FUNCTION 1: TRAIN vs VAL LOSS COMPARISON (OPTIMIZED)
# ============================================
def _fit_diagnosis(train_box, val_box, map50, map50_95, n_epochs):
    """Build the text diagnosis and panel colour for the summary panel.

    Only final-epoch values are used. Heuristic thresholds:
      * underfitting: mAP50 < 0.5 and train box loss > 0.5
      * overfitting:  (val box loss - train box loss) > 0.15
      * good fit:     gap < 0.05 and mAP50 > 0.7
    Without a validation loss only the mAP50 level is reported.

    Returns:
        (lines, color): list of text lines and a matplotlib colour name.
    """
    final_train_loss = train_box.iloc[-1]
    final_map = map50.iloc[-1]

    if final_map < 0.5 and final_train_loss > 0.5:
        lines = [
            "‚ö†Ô∏è UNDERFITTING DETECTED",
            "",
            "Symptoms:",
            f"‚Ä¢ Low accuracy: mAP={final_map:.3f} < 0.5",
            f"‚Ä¢ High train loss: {final_train_loss:.4f}",
            "",
            "Solutions:",
            "1. Train longer (more epochs)",
            "2. Use larger model",
            "3. Reduce augmentation",
            "4. Check data quality",
        ]
        color = 'red'
    elif val_box is not None:
        gap = val_box.iloc[-1] - final_train_loss
        if gap > 0.15:
            lines = [
                "‚ö†Ô∏è OVERFITTING DETECTED",
                "",
                "Symptoms:",
                f"‚Ä¢ Large gap: {gap:.4f} > 0.15",
                "‚Ä¢ Val loss > Train loss",
                "",
                "Solutions:",
                "1. Increase augmentation",
                "2. Add more data",
                "3. Use regularization",
                "4. Early stopping",
                "5. Reduce model size",
            ]
            color = 'orange'
        elif gap < 0.05 and final_map > 0.7:
            lines = [
                "‚úÖ GOOD FIT",
                "",
                "Indicators:",
                f"‚Ä¢ Small gap: {gap:.4f}",
                f"‚Ä¢ Good mAP: {final_map:.3f}",
                "‚Ä¢ Healthy model",
                "",
                "Recommendation:",
                "‚Ä¢ Continue training",
                "‚Ä¢ Monitor closely",
            ]
            color = 'lightgreen'
        else:
            lines = [
                "üìä ACCEPTABLE",
                "",
                "Metrics:",
                f"‚Ä¢ Gap: {gap:.4f}",
                f"‚Ä¢ mAP: {final_map:.3f}",
                "",
                "Actions:",
                "‚Ä¢ Continue monitoring",
                "‚Ä¢ Check test set",
            ]
            color = 'lightyellow'
    else:
        lines = [
            "‚ÑπÔ∏è LIMITED ANALYSIS",
            "",
            "No validation loss",
            "available.",
            "",
            f"‚úÖ Good mAP: {final_map:.3f}" if final_map > 0.7 else f"‚ö†Ô∏è Low mAP: {final_map:.3f}",
        ]
        color = 'lightblue'

    # Raw final-epoch numbers appended to every diagnosis
    lines += [
        "",
        "‚îÄ" * 35,
        "METRICS:",
        f"Epochs: {n_epochs}",
        f"Train Loss: {final_train_loss:.4f}",
    ]
    if val_box is not None:
        lines.append(f"Val Loss: {val_box.iloc[-1]:.4f}")
    lines.append(f"mAP50: {final_map:.3f}")
    if map50_95 is not None:
        lines.append(f"mAP50-95: {map50_95.iloc[-1]:.3f}")
    return lines, color


def plot_overfitting_detection(df, train_box, val_box, train_cls, val_cls,
                                map50, map50_95, save_path=None):
    """Plot train-vs-val losses and mAP progression, and diagnose over/underfitting.

    Produces a 2x2 figure:
      * top-left:     box loss, train vs val, with the gap shaded
      * top-right:    classification loss, train vs val (when available)
      * bottom-left:  mAP50 / mAP50-95 per epoch with best-epoch and plateau markers
      * bottom-right: text diagnosis built from the final-epoch values

    Args:
        df: per-epoch results DataFrame; only its length is used (x axis, epoch count).
        train_box, map50: required per-epoch pandas Series aligned with ``df``.
        val_box, train_cls, val_cls, map50_95: optional per-epoch Series, or None.
        save_path: optional path; when given the figure is saved at 300 dpi.

    Returns:
        list[str]: the lines of the text diagnosis.

    Raises:
        ValueError: if ``df`` has no rows (checked before any figure is created).
    """
    if len(df) == 0:
        raise ValueError("Cannot analyse an empty results table")

    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    fig.suptitle('üî¨ OVERFITTING & UNDERFITTING ANALYSIS',
                 fontsize=16, fontweight='bold', y=0.995)

    # 1-based positional epoch numbers, independent of df's index labels
    epochs = np.arange(1, len(df) + 1)
    last_epoch = epochs[-1]

    # ==========================================
    # PLOT 1: Box Loss (Train vs Val)
    # ==========================================
    ax1 = axes[0, 0]
    ax1.plot(epochs, train_box, label='Train Loss', linewidth=2.5, color='blue', marker='o', markersize=2)

    if val_box is not None:
        ax1.plot(epochs, val_box, label='Val Loss', linewidth=2.5, color='red', marker='s', markersize=2)

        # interpolate=True extends each shaded region up to the crossing point
        # instead of leaving a gap where the curves swap order.
        ax1.fill_between(epochs, train_box, val_box,
                         where=(val_box >= train_box), interpolate=True,
                         alpha=0.2, color='red', label='Overfitting Gap')
        ax1.fill_between(epochs, train_box, val_box,
                         where=(val_box < train_box), interpolate=True,
                         alpha=0.2, color='green', label='Good Generalization')

        final_gap = val_box.iloc[-1] - train_box.iloc[-1]
        ax1.annotate(f'Final Gap: {final_gap:.4f}',
                     xy=(last_epoch, (train_box.iloc[-1] + val_box.iloc[-1]) / 2),
                     xytext=(last_epoch * 0.7, max(train_box.max(), val_box.max()) * 0.8),
                     bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7),
                     fontsize=10, fontweight='bold',
                     arrowprops=dict(arrowstyle='->', color='black', lw=1.5))
    else:
        # The train curve is already drawn above; only flag the missing val data.
        ax1.text(0.5, 0.95, '‚ö†Ô∏è Validation loss not available',
                 transform=ax1.transAxes, ha='center', va='top',
                 bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7),
                 fontsize=10)

    ax1.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax1.set_ylabel('Box Loss', fontsize=11, fontweight='bold')
    ax1.set_title('üì¶ Box Loss: Train vs Validation', fontsize=12, fontweight='bold')
    ax1.legend(fontsize=10, loc='best')
    ax1.grid(True, alpha=0.3, linestyle='--')

    # ==========================================
    # PLOT 2: Classification Loss (Train vs Val)
    # ==========================================
    ax2 = axes[0, 1]

    if train_cls is not None:
        ax2.plot(epochs, train_cls, label='Train Loss', linewidth=2.5, color='blue', marker='o', markersize=2)

        if val_cls is not None:
            ax2.plot(epochs, val_cls, label='Val Loss', linewidth=2.5, color='red', marker='s', markersize=2)
            gap_cls = val_cls.iloc[-1] - train_cls.iloc[-1]
            ax2.fill_between(epochs, train_cls, val_cls, alpha=0.2, color='orange')
            ax2.text(last_epoch * 0.7, max(train_cls.max(), val_cls.max()) * 0.8,
                     f'Gap: {gap_cls:.4f}',
                     bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7),
                     fontsize=10, fontweight='bold')
    else:
        ax2.text(0.5, 0.5, 'Classification loss not available',
                 transform=ax2.transAxes, ha='center', va='center',
                 fontsize=12, style='italic')

    ax2.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax2.set_ylabel('Classification Loss', fontsize=11, fontweight='bold')
    ax2.set_title('üéØ Classification Loss: Train vs Validation', fontsize=12, fontweight='bold')
    if train_cls is not None:
        ax2.legend(fontsize=10, loc='best')
    ax2.grid(True, alpha=0.3, linestyle='--')

    # ==========================================
    # PLOT 3: mAP Performance Over Time
    # ==========================================
    ax3 = axes[1, 0]

    ax3.plot(epochs, map50, label='mAP50', linewidth=2.5, color='blue',
             marker='o', markersize=3, alpha=0.8)
    if map50_95 is not None:
        ax3.plot(epochs, map50_95, label='mAP50-95', linewidth=2.5, color='orange',
                 marker='s', markersize=3, alpha=0.8)

    # Plateau: mAP50 moved by less than 0.01 across the last 10 epochs
    # (iloc[-11] -> iloc[-1] covers 10 epoch-to-epoch steps).
    if len(map50) > 10:
        last_10_change = map50.iloc[-1] - map50.iloc[-11]
        if abs(last_10_change) < 0.01:
            ax3.axhline(y=map50.iloc[-1], color='orange', linestyle='--',
                        linewidth=2, alpha=0.7, label=f'Plateaued at {map50.iloc[-1]:.3f}')

    # Best epoch located by position so it lines up with `epochs`
    best_pos = map50.reset_index(drop=True).idxmax()
    best_epoch = best_pos + 1
    best_map = map50.iloc[best_pos]
    ax3.plot(best_epoch, best_map, 'g*', markersize=15,
             label=f'Best: Epoch {best_epoch} ({best_map:.3f})')

    ax3.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax3.set_ylabel('mAP Score', fontsize=11, fontweight='bold')
    ax3.set_title('üìà Validation Accuracy (mAP) Progression', fontsize=12, fontweight='bold')
    ax3.legend(fontsize=10, loc='best')
    ax3.grid(True, alpha=0.3, linestyle='--')
    ax3.set_ylim([0, max(1.0, map50.max() * 1.1)])

    # ==========================================
    # PLOT 4: DIAGNOSTIC SUMMARY
    # ==========================================
    ax4 = axes[1, 1]
    ax4.axis('off')

    status, color = _fit_diagnosis(train_box, val_box, map50, map50_95, len(df))
    ax4.text(0.05, 0.5, '\n'.join(status),
             fontsize=10,
             family='monospace',
             verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor=color, alpha=0.6, edgecolor='black', linewidth=1.5))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"üíæ Saved: {save_path}")

    plt.show()

    return status

# ============================================
# FUNCTION 2: LEARNING CURVE ANALYSIS (OPTIMIZED)
# ============================================
def plot_learning_curves(df, train_box, save_path=None):
    """Plot three learning-curve diagnostics for the training box loss.

    Panels:
      1. Raw loss plus a moving average (window = min(5, n // 4), at least 1).
      2. Epoch-to-epoch loss change (negative = improving, positive = worsening).
      3. Rolling standard deviation of the loss (window = min(10, n // 2)) as a
         stability indicator, drawn only when that window is at least 2.

    Args:
        df: per-epoch results DataFrame; only its length is used for the x axis.
        train_box: per-epoch training box loss (pandas Series aligned with ``df``).
        save_path: optional path; when given the figure is saved at 300 dpi.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle('üìö LEARNING CURVE ANALYSIS', fontsize=16, fontweight='bold')

    n_epochs = len(train_box)
    # 1-based positional epoch numbers, independent of df's index labels
    epochs = np.arange(1, len(df) + 1)

    # Plot 1: Loss curves with moving average
    ax1 = axes[0]
    window = max(1, min(5, n_epochs // 4))  # adaptive window, never below 1
    train_box_ma = train_box.rolling(window=window, min_periods=1).mean()

    ax1.plot(epochs, train_box, alpha=0.3, color='blue', label='Raw Train Loss', linewidth=1)
    ax1.plot(epochs, train_box_ma, linewidth=2.5, color='blue', label=f'Train Loss (MA-{window})')

    ax1.set_xlabel('Epoch', fontsize=11, fontweight='bold')
    ax1.set_ylabel('Box Loss', fontsize=11, fontweight='bold')
    ax1.set_title('Loss Smoothing (Moving Average)', fontsize=12, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3, linestyle='--')

    # Plot 2: Loss derivative
    ax2 = axes[1]
    if n_epochs > 1:
        # diff() leaves NaN in the first position; drop it positionally.
        loss_change = train_box.diff().iloc[1:].to_numpy()
        change_epochs = epochs[1:]

        ax2.plot(change_epochs, loss_change, color='purple', linewidth=2)
        ax2.axhline(y=0, color='red', linestyle='--', linewidth=1.5)
        # interpolate=True fills each region right up to the zero crossing
        # instead of leaving a gap between an improving and a worsening epoch.
        ax2.fill_between(change_epochs, 0, loss_change,
                         where=loss_change < 0, interpolate=True,
                         alpha=0.3, color='green', label='Improving')
        ax2.fill_between(change_epochs, 0, loss_change,
                         where=loss_change > 0, interpolate=True,
                         alpha=0.3, color='red', label='Worsening')

        ax2.set_xlabel('Epoch', fontsize=11, fontweight='bold')
        ax2.set_ylabel('Loss Change', fontsize=11, fontweight='bold')
        ax2.set_title('Loss Improvement Rate', fontsize=12, fontweight='bold')
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3, linestyle='--')
    else:
        ax2.text(0.5, 0.5, 'Not enough data for derivative',
                 transform=ax2.transAxes, ha='center', va='center')

    # Plot 3: Convergence indicator
    ax3 = axes[2]
    window_size = min(10, n_epochs // 2)
    # window_size <= n_epochs // 2 < n_epochs, so the window always fits the data.
    if window_size >= 2:
        rolling_std = train_box.rolling(window=window_size, min_periods=1).std()
        ax3.plot(epochs, rolling_std, color='orange', linewidth=2.5)
        ax3.set_xlabel('Epoch', fontsize=11, fontweight='bold')
        ax3.set_ylabel(f'Loss Std Dev (window={window_size})', fontsize=11, fontweight='bold')
        ax3.set_title('Training Stability', fontsize=12, fontweight='bold')
        ax3.grid(True, alpha=0.3, linestyle='--')

        convergence_threshold = 0.01
        ax3.axhline(y=convergence_threshold, color='green', linestyle='--',
                    label=f'Converged < {convergence_threshold}', linewidth=2)
        ax3.legend(fontsize=10)
    else:
        ax3.text(0.5, 0.5, 'Not enough data for stability analysis',
                 transform=ax3.transAxes, ha='center', va='center')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"üíæ Saved: {save_path}")

    plt.show()

# ============================================
# RUN ANALYSIS
# ============================================

# Overfitting figure (losses, mAP, diagnosis), saved next to results.csv
print("üìä Generating overfitting analysis...")
status = plot_overfitting_detection(
    df=df,
    train_box=train_box,
    val_box=val_box,
    train_cls=train_cls,
    val_cls=val_cls,
    map50=map50,
    map50_95=map50_95,
    save_path=results_dir / 'overfitting_analysis.png',
)

# Learning-curve figure (smoothed loss, per-epoch change, stability)
print("\nüìä Generating learning curves...")
plot_learning_curves(df=df, train_box=train_box, save_path=results_dir / 'learning_curves.png')

# ============================================
# DETAILED DIAGNOSIS REPORT (IMPROVED)
# ============================================
print("\n" + "=" * 80)
print("üìã DETAILED DIAGNOSIS REPORT")
print("=" * 80)

# Final-epoch values; the precision/recall columns are optional (None when absent).
final_train_loss = train_box.iloc[-1]
final_map50 = map50.iloc[-1]
final_precision = precision.iloc[-1] if precision is not None else None
final_recall = recall.iloc[-1] if recall is not None else None

print(f"\n1Ô∏è‚É£ TRAINING METRICS:")
print(f"   ‚Ä¢ Final Train Loss: {final_train_loss:.4f}")
print(f"   ‚Ä¢ Final mAP50: {final_map50:.3f}")
if final_precision is not None:
    print(f"   ‚Ä¢ Final Precision: {final_precision:.3f}")
if final_recall is not None:
    print(f"   ‚Ä¢ Final Recall: {final_recall:.3f}")

# Check validation loss if available
if val_box is not None:
    final_val_loss = val_box.iloc[-1]
    gap = final_val_loss - final_train_loss
    print(f"   ‚Ä¢ Final Val Loss: {final_val_loss:.4f}")
    print(f"   ‚Ä¢ Train-Val Gap: {gap:.4f}")

# Trend analysis: mAP50 change over the last `recent_epochs` epoch-to-epoch
# steps (at most 20). N logged epochs give at most N-1 steps, so a
# single-epoch run has no trend to report.
print(f"\n2Ô∏è‚É£ TREND ANALYSIS:")
recent_epochs = min(20, len(df) - 1)
if recent_epochs >= 1:
    map_start = map50.iloc[-(recent_epochs + 1)]
    map_end = map50.iloc[-1]
    map_improvement = map_end - map_start

    if map_improvement > 0.05:
        print(f"   ‚úÖ Still improving (+{map_improvement:.3f} in last {recent_epochs} epochs)")
        print(f"   üí° Continue training for better results")
    elif map_improvement > 0:
        print(f"   üìä Slow improvement (+{map_improvement:.3f} in last {recent_epochs} epochs)")
        print(f"   üí° May be close to optimal")
    else:
        print(f"   ‚ö†Ô∏è Not improving ({map_improvement:.3f} in last {recent_epochs} epochs)")
        print(f"   üí° Consider stopping or adjusting")
else:
    print("   Not enough epochs for trend analysis (need at least 2).")

# Final recommendations
print(f"\n3Ô∏è‚É£ ACTIONABLE RECOMMENDATIONS:")

recommendations = []

# Based on mAP50 level (thresholds 0.5 / 0.7)
if final_map50 < 0.5:
    recommendations.extend([
        "‚ùå Low accuracy - Need improvements:",
        "   ‚Ä¢ Add more diverse training data",
        "   ‚Ä¢ Use larger model (yolov8m/l)",
        "   ‚Ä¢ Check data quality and labels",
    ])
elif final_map50 < 0.7:
    recommendations.extend([
        "‚ö†Ô∏è Moderate accuracy - Room for improvement:",
        "   ‚Ä¢ Increase epochs",
        "   ‚Ä¢ Fine-tune augmentation",
        "   ‚Ä¢ Verify label quality",
    ])
else:
    recommendations.extend([
        "‚úÖ Good accuracy - Model performing well!",
        "   ‚Ä¢ Ready for testing",
        "   ‚Ä¢ Fine-tune confidence threshold",
    ])

# Based on precision-recall balance. A separate "" entry gives the blank
# separator line, so the heading keeps the same indent as every other entry
# (a leading "\n" inside the string would print it unindented).
if final_precision is not None and final_precision < 0.6:
    recommendations.extend([
        "",
        "‚ö†Ô∏è Low precision - Too many false positives:",
        "   ‚Ä¢ Increase confidence threshold (conf=0.3-0.4)",
    ])

if final_recall is not None and final_recall < 0.6:
    recommendations.extend([
        "",
        "‚ö†Ô∏è Low recall - Missing detections:",
        "   ‚Ä¢ Decrease confidence threshold (conf=0.15-0.25)",
    ])

for rec in recommendations:
    print(f"   {rec}")

print("\n" + "=" * 80)
print("‚úÖ Analysis complete! Check generated images:")
print(f"   üìÅ {results_dir}")
print("=" * 80)

In [None]:
# ============================================
# Cell 10: Enhanced Prediction Visualization (OPTIMIZED)
# ============================================
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path
import random
import glob
from ultralytics import YOLO

print("üîç ENHANCED PREDICTION VISUALIZATION")
print("=" * 80)

# ============================================
# SMART MODEL PATH FINDING
# ============================================
def find_trained_model():
    """Locate trained weights, preferring ``best.pt`` over ``last.pt``.

    For each weights file name (``best.pt`` first, then ``last.pt``):
      1. Check a fixed list of known run directories; the first existing file wins.
      2. Otherwise glob ``runs/detect/*/weights/<name>`` then
         ``runs/detect/*/*/weights/<name>``; for the first pattern with matches,
         return the most recently modified file.
    A warning is printed when ``last.pt`` is returned.

    Returns:
        pathlib.Path to the weights file, or None when nothing is found.
    """
    known_runs = (
        'runs/detect/banana_pest_disease_yolo11',
        'runs/detect/banana_pest_disease_yolo12',
        'runs/detect/banana_pest_disease_yolo11_simple',
    )
    search_roots = ('runs/detect/*', 'runs/detect/*/*')

    def _locate(weights_name):
        """Return the best-matching path for one weights file name, or None."""
        for run_dir in known_runs:
            candidate = Path(run_dir) / 'weights' / weights_name
            if candidate.exists():
                return candidate
        for root in search_roots:
            matches = glob.glob(f'{root}/weights/{weights_name}')
            if matches:
                # Most recently modified match for this pattern
                return Path(max(matches, key=lambda p: Path(p).stat().st_mtime))
        return None

    best = _locate('best.pt')
    if best is not None:
        return best

    # Fallback: search for last.pt with the same known-dir + dynamic-glob
    # strategy (previously only the known run directories were checked).
    last = _locate('last.pt')
    if last is not None:
        print("‚ö†Ô∏è Using last.pt instead of best.pt")
    return last

# ============================================
# SMART TEST IMAGE PATH FINDING
# ============================================
def find_test_images_dir():
    """Locate the directory that holds the test images.

    Search order:
      1. If a global ``data_config`` mapping exists (presumably the parsed
         dataset YAML — confirm) with a string ``test`` entry: that path, its
         sibling ``images`` dir and its ``images`` subdir. When ``test`` is
         relative and ``data_config['path']`` is set, the same three candidates
         are also tried under that root. Only directories containing images
         (.jpg/.jpeg/.png, any case) are accepted here.
      2. A list of common dataset locations: the first one containing images
         is returned; failing that, the first one that merely exists.

    Returns:
        pathlib.Path, or None when no candidate directory exists.
    """
    image_suffixes = {'.jpg', '.jpeg', '.png'}

    def _has_images(directory):
        """True if `directory` is a directory containing at least one image file."""
        return directory.is_dir() and any(
            entry.is_file() and entry.suffix.lower() in image_suffixes
            for entry in directory.iterdir()
        )

    # 1) Candidates derived from the dataset config, when one is loaded
    config = globals().get('data_config')
    if config is not None and hasattr(config, 'get'):
        test_entry = config.get('test', '')
        if test_entry and isinstance(test_entry, (str, Path)):
            test_dir = Path(test_entry)
            bases = [test_dir]
            root = config.get('path')
            if root and isinstance(root, (str, Path)) and not test_dir.is_absolute():
                bases.append(Path(root) / test_dir)
            for base in bases:
                for candidate in (base, base.parent / 'images', base / 'images'):
                    if _has_images(candidate):
                        return candidate

    # 2) Common locations. The absolute Kaggle path is where the working copy of
    # the dataset is created; the old relative variant is kept for compatibility.
    fallback_dirs = [
        Path('/kaggle/working/yolo_classification_dataset/test/images'),
        Path('kaggle/working/yolo_classification_dataset/test/images'),
        Path('yolo_classification_dataset/test/images'),
        Path('dataset/test/images'),
        Path('test/images'),
    ]
    existing = [directory for directory in fallback_dirs if directory.is_dir()]
    for directory in existing:
        if _has_images(directory):
            return directory
    return existing[0] if existing else None

# ============================================
# ENHANCED VISUALIZATION FUNCTION
# ============================================
def visualize_enhanced_predictions(model, test_images_dir, class_names_map, 
                                   num_samples=50, conf_threshold=0.2):
    """
    Visualize model predictions with confidence-based coloring
    
    Color Coding:
    - GREEN (>70%): High confidence - Clear disease detection
    - YELLOW (50-70%): Medium confidence - Likely disease
    - ORANGE (30-50%): Low confidence - Possible disease
    - RED (<30%): Very low confidence - Uncertain
    """
    
    test_images_dir = Path(test_images_dir)
    
    if not test_images_dir.exists():
        print(f"‚ùå Directory not found: {test_images_dir}")
        print(f"   Absolute path: {test_images_dir.absolute()}")
        return None
    
    # Find all test images
    image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.JPG', '*.JPEG', '*.PNG']
    test_images = []
    for ext in image_extensions:
        test_images.extend(list(test_images_dir.glob(ext)))
    
    if len(test_images) == 0:
        print(f"‚ùå No images found in {test_images_dir}")
        print("üí° Check if path is correct")
        return None
    
    print(f"üì∏ Found {len(test_images)} test images")
    
    # Sample images
    num_samples = min(num_samples, len(test_images))
    sample_images = random.sample(test_images, num_samples)
    print(f"üé≤ Randomly selected {num_samples} images for visualization\n")
    
    # ============================================
    # AUTO-GRID LAYOUT
    # ============================================
    cols = 4  # Images per row
    rows = int(np.ceil(num_samples / cols))
    
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
    
    # Handle different axis shapes
    if rows == 1 and cols == 1:
        axes = np.array([axes])
    elif rows == 1:
        axes = axes.reshape(-1)
    elif cols == 1:
        axes = axes.reshape(-1)
    else:
        axes = axes.ravel()
    
    # ============================================
    # CONFIDENCE COLOR MAPPING (RGB for matplotlib)
    # ============================================
    confidence_colors = {
        'high': (0, 1, 0),        # Green: >70%
        'medium': (1, 1, 0),      # Yellow: 50-70%
        'low': (1, 0.65, 0),      # Orange: 30-50%
        'very_low': (1, 0, 0)     # Red: <30%
    }
    
    # For OpenCV (BGR)
    confidence_colors_cv = {
        'high': (0, 255, 0),
        'medium': (0, 255, 255),    # Yellow in BGR
        'low': (0, 165, 255),       # Orange in BGR
        'very_low': (0, 0, 255)
    }
    
    # ============================================
    # PROCESS EACH IMAGE
    # ============================================
    all_detections = []
    processed_count = 0
    
    print("üîÑ Processing images...")
    for idx, img_path in enumerate(sample_images):
        try:
            # Get model predictions
            results = model.predict(str(img_path), conf=conf_threshold, verbose=False)
            
            # Read and prepare image
            img = cv2.imread(str(img_path))
            if img is None:
                print(f"   ‚ö†Ô∏è Could not read: {img_path.name}")
                axes[idx].text(0.5, 0.5, 'Image Error', ha='center', va='center',
                             fontsize=10, color='red')
                axes[idx].axis('off')
                all_detections.append({'image': img_path.name, 'count': 0, 'details': []})
                continue
            
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            original_img = img.copy()
            
            detections_count = 0
            detection_details = []
            
            # Process detections
            for result in results:
                boxes = result.boxes
                if boxes is not None and len(boxes) > 0:
                    for box in boxes:
                        confidence = float(box.conf.item())
                        class_id = int(box.cls.item())
                        class_name = class_names_map.get(class_id, f'Class_{class_id}')
                        
                        # Determine color based on confidence
                        if confidence > 0.7:
                            color_cv = confidence_colors_cv['high']
                            conf_label = 'HIGH'
                        elif confidence > 0.5:
                            color_cv = confidence_colors_cv['medium']
                            conf_label = 'MEDIUM'
                        elif confidence > 0.3:
                            color_cv = confidence_colors_cv['low']
                            conf_label = 'LOW'
                        else:
                            color_cv = confidence_colors_cv['very_low']
                            conf_label = 'VERY LOW'
                        
                        # Draw bounding box
                        x1, y1, x2, y2 = map(int, box.xyxy[0])
                        cv2.rectangle(img, (x1, y1), (x2, y2), color_cv, 3)
                        
                        # Prepare label
                        label = f"{class_name}: {confidence:.2f}"
                        
                        # Draw label background
                        font = cv2.FONT_HERSHEY_SIMPLEX
                        font_scale = 0.6
                        thickness = 2
                        (text_width, text_height), baseline = cv2.getTextSize(
                            label, font, font_scale, thickness
                        )
                        
                        # Background rectangle
                        cv2.rectangle(
                            img, 
                            (x1, y1 - text_height - 10), 
                            (x1 + text_width + 10, y1), 
                            color_cv, 
                            -1
                        )
                        
                        # Label text
                        cv2.putText(
                            img, 
                            label, 
                            (x1 + 5, y1 - 5), 
                            font, 
                            font_scale, 
                            (0, 0, 0),  # Black text
                            thickness
                        )
                        
                        detections_count += 1
                        detection_details.append({
                            'class': class_name,
                            'confidence': confidence,
                            'conf_level': conf_label,
                            'bbox': [x1, y1, x2, y2]
                        })
            
            # Display image
            axes[idx].imshow(img)
            
            # Create title with detection info
            title_max_len = 35
            img_name = img_path.name[:title_max_len] + ('...' if len(img_path.name) > title_max_len else '')
            
            if detections_count > 0:
                title = f'{img_name}\n‚úÖ {detections_count} detection(s)'
                axes[idx].set_title(title, fontsize=9, color='green', fontweight='bold')
            else:
                title = f'{img_name}\n‚ùå No detections'
                axes[idx].set_title(title, fontsize=9, color='red')
            
            axes[idx].axis('off')
            
            # Store detection info
            all_detections.append({
                'image': img_path.name,
                'count': detections_count,
                'details': detection_details
            })
            
            processed_count += 1
            if (idx + 1) % 10 == 0:
                print(f"   Processed {idx + 1}/{num_samples} images...")
            
        except Exception as e:
            print(f"   ‚ùå Error processing {img_path.name}: {str(e)[:50]}")
            axes[idx].text(0.5, 0.5, f'Error', 
                          ha='center', va='center', fontsize=8, color='red')
            axes[idx].axis('off')
            all_detections.append({'image': img_path.name, 'count': 0, 'details': []})
    
    print(f"‚úÖ Processed {processed_count}/{num_samples} images\n")
    
    # ============================================
    # HIDE UNUSED SUBPLOTS
    # ============================================
    for idx in range(num_samples, len(axes)):
        axes[idx].axis('off')
    
    # ============================================
    # ADD COLOR LEGEND
    # ============================================
    legend_elements = [
        mpatches.Patch(facecolor=confidence_colors['high'], 
                      edgecolor='black', linewidth=1,
                      label='High Confidence (>70%)'),
        mpatches.Patch(facecolor=confidence_colors['medium'], 
                      edgecolor='black', linewidth=1,
                      label='Medium Confidence (50-70%)'),
        mpatches.Patch(facecolor=confidence_colors['low'], 
                      edgecolor='black', linewidth=1,
                      label='Low Confidence (30-50%)'),
        mpatches.Patch(facecolor=confidence_colors['very_low'], 
                      edgecolor='black', linewidth=1,
                      label='Very Low Confidence (<30%)')
    ]
    
    fig.legend(
        handles=legend_elements,
        loc='upper center',
        bbox_to_anchor=(0.5, 0.98),
        ncol=4,
        fontsize=11,
        frameon=True,
        fancybox=True,
        shadow=True,
        framealpha=0.9
    )
    
    plt.suptitle(
        'üåø Banana Disease Detection - Enhanced Predictions with Confidence Levels',
        fontsize=16,
        fontweight='bold',
        y=0.995
    )
    
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    
    # Save figure
    output_path = 'enhanced_predictions.png'
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"üíæ Predictions saved as '{output_path}'")
    
    plt.show()
    
    # ============================================
    # DETECTION SUMMARY (ENHANCED)
    # ============================================
    print("\n" + "=" * 80)
    print("üìä DETECTION SUMMARY")
    print("=" * 80)
    
    total_detections = sum(d['count'] for d in all_detections)
    images_with_detections = sum(1 for d in all_detections if d['count'] > 0)
    images_without_detections = len(all_detections) - images_with_detections
    
    print(f"\nüìà OVERALL STATISTICS:")
    print(f"   ‚Ä¢ Total images analyzed: {len(all_detections)}")
    print(f"   ‚Ä¢ Images with detections: {images_with_detections} ({images_with_detections/len(all_detections)*100:.1f}%)")
    print(f"   ‚Ä¢ Images without detections: {images_without_detections} ({images_without_detections/len(all_detections)*100:.1f}%)")
    print(f"   ‚Ä¢ Total detections: {total_detections}")
    if len(all_detections) > 0:
        print(f"   ‚Ä¢ Average detections per image: {total_detections/len(all_detections):.2f}")
    
    # Count confidence levels
    confidence_counts = {'HIGH': 0, 'MEDIUM': 0, 'LOW': 0, 'VERY LOW': 0}
    class_counts = {}
    confidence_values = []
    
    for detection in all_detections:
        for detail in detection['details']:
            confidence_counts[detail['conf_level']] += 1
            class_name = detail['class']
            class_counts[class_name] = class_counts.get(class_name, 0) + 1
            confidence_values.append(detail['confidence'])
    
    if total_detections > 0:
        print(f"\nüéØ CONFIDENCE DISTRIBUTION:")
        for level, count in confidence_counts.items():
            percentage = (count / total_detections) * 100
            bar = '‚ñà' * int(percentage / 2)  # Visual bar
            print(f"   ‚Ä¢ {level:12s}: {count:3d} ({percentage:5.1f}%) {bar}")
        
        if confidence_values:
            avg_conf = np.mean(confidence_values)
            print(f"\n   ‚Ä¢ Average Confidence: {avg_conf:.3f}")
        
        print(f"\nüè∑Ô∏è CLASS DISTRIBUTION:")
        sorted_classes = sorted(class_counts.items(), key=lambda x: x[1], reverse=True)
        for class_name, count in sorted_classes:
            percentage = (count / total_detections) * 100
            bar = '‚ñà' * int(percentage / 2)
            print(f"   ‚Ä¢ {class_name:20s}: {count:3d} ({percentage:5.1f}%) {bar}")
    
    # Top detections
    if all_detections:
        top_detections = sorted(all_detections, key=lambda x: x['count'], reverse=True)[:5]
        print(f"\nüîù TOP 5 IMAGES WITH MOST DETECTIONS:")
        for i, det in enumerate(top_detections, 1):
            print(f"   {i}. {det['image'][:40]:40s} - {det['count']} detection(s)")
    
    print("\n" + "=" * 80)
    
    return all_detections

# ============================================
# RUN VISUALIZATION
# ============================================
# Tunables for the visualization run.
VIS_NUM_SAMPLES = 50        # number of test images to plot; change to show more/less images
VIS_CONF_THRESHOLD = 0.2    # minimum confidence for a box to be drawn
DEFAULT_NUM_CLASSES = 7     # last-resort placeholder count when no class names are available


def normalize_class_names(names):
    """Return class names as a ``{class_id: name}`` dict.

    Ultralytics dataset YAML files may declare ``names`` either as a list
    (position == class id) or as a mapping; both forms are accepted by the
    trainer. Downstream code here iterates ``.items()``, so lists are
    converted to an index-keyed dict. ``None`` or empty input yields ``{}``.
    """
    if not names:
        return {}
    if isinstance(names, dict):
        return dict(names)
    return {class_id: name for class_id, name in enumerate(names)}


def resolve_class_names(config, model, default_count=DEFAULT_NUM_CLASSES):
    """Choose the best available class-name mapping.

    Preference order:
      1. ``config['names']`` when ``config`` is a dict with non-empty names;
      2. ``model.names`` (Ultralytics models expose their trained class names);
      3. placeholder names ``Class_0 .. Class_{default_count-1}``.
    """
    if isinstance(config, dict):
        names = normalize_class_names(config.get('names'))
        if names:
            return names
    names = normalize_class_names(getattr(model, 'names', None))
    if names:
        return names
    return {i: f'Class_{i}' for i in range(default_count)}


print("\nüì¶ Loading trained model...")

try:
    # Find model
    model_path = find_trained_model()
    
    if model_path is None:
        print("‚ùå No trained model found!")
        print("üí° Train the model first using Cell 8")
        print("\nüí° Searched in:")
        print("   ‚Ä¢ runs/detect/banana_pest_disease_yolo11/weights/")
        print("   ‚Ä¢ runs/detect/banana_pest_disease_yolo12/weights/")
        print("   ‚Ä¢ runs/detect/*/weights/")
    else:
        model = YOLO(str(model_path))
        print(f"‚úÖ Model loaded from: {model_path}")
        
        # Get class names. data_config (if present) comes from an earlier cell;
        # its 'names' entry may be a list or a dict, so it is normalized.
        config = globals().get('data_config')
        if not isinstance(config, dict):
            print("‚ö†Ô∏è data_config not found, using default class names")
        class_names_map = resolve_class_names(config, model)
        
        print(f"üìã Classes: {len(class_names_map)}")
        for class_id, class_name in sorted(class_names_map.items()):
            print(f"   ‚Ä¢ {class_id}: {class_name}")
        
        # Find test images directory
        test_images_dir = find_test_images_dir()
        
        if test_images_dir is None:
            print("\n‚ùå Test images directory not found!")
            print("\nüí° Possible fixes:")
            print("   1. Check if test images are in the correct location")
            print("   2. Verify your dataset structure")
            print("   3. Update data_config['test'] path")
            
            # Try manual path. Non-interactive runs (e.g. Kaggle "Save & Run All")
            # have no stdin: input() raises EOFError, or IPython's
            # StdinNotImplementedError (a NotImplementedError subclass).
            try:
                manual_path = input("\nüí° Enter test images path (or press Enter to skip): ").strip()
            except (EOFError, NotImplementedError):
                manual_path = ''
            if not manual_path:
                raise FileNotFoundError("Test images directory not found")
            test_images_dir = Path(manual_path)
            # Validate user input before handing it to the visualizer.
            if not test_images_dir.is_dir():
                raise FileNotFoundError(f"Test images directory not found: {test_images_dir}")
        
        print(f"üìÅ Test images directory: {test_images_dir}")
        
        # Run visualization
        print("\n" + "=" * 80)
        detection_results = visualize_enhanced_predictions(
            model=model,
            test_images_dir=test_images_dir,
            class_names_map=class_names_map,
            num_samples=VIS_NUM_SAMPLES,
            conf_threshold=VIS_CONF_THRESHOLD
        )
        
        if detection_results:
            print("\n‚úÖ Enhanced visualization completed!")
        else:
            print("\n‚ö†Ô∏è Visualization completed with errors")
            
except FileNotFoundError as e:
    print(f"\n‚ùå Error: {e}")
except Exception as e:
    print(f"\n‚ùå Unexpected error: {e}")
    import traceback
    traceback.print_exc()