# YOLOv8 Model Comparison: Original vs Pruned

This notebook provides a comprehensive side-by-side comparison between the original YOLOv8n model and the pruned version.

**Pruning Results:**
- **49% channels pruned** (1,547 out of 5,296)
- **4.8% model size reduction**
- **Maintained architecture** with optimized weights

**Comparisons Include:**
- Model specifications and sizes
- Inference speed benchmarks
- Detection accuracy on sample image
- Performance metrics visualization

In [1]:
# Import required libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from ultralytics import YOLO
import time
import os
from pathlib import Path
from PIL import Image
import cv2

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

print("Libraries loaded successfully!")

# GPU Information
print("\n=== GPU Information ===")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device.upper()}")

Libraries loaded successfully!

=== GPU Information ===
PyTorch version: 2.8.0+cu126
CUDA available: True
GPU: NVIDIA GeForce MX570
GPU Memory: 2.0 GB
Using device: CUDA


In [2]:
# Load models for comparison
print("Loading models for comparison...")

# Original model
original_model = YOLO('yolov8-prune/yolov8n.pt')
print("✓ Original model loaded")

# Pruned model
pruned_model = YOLO('pruned_yolo_model.pt')
print("✓ Pruned model loaded")

# Fine-tuned model
finetuned_model = YOLO('finetune_results/pruned_finetune4/weights/best.pt')
print("✓ Fine-tuned model loaded")

# Model specifications
def get_model_info(model, name):
    params = sum(p.numel() for p in model.parameters())
    size_mb = os.path.getsize(model.model_path) / (1024 * 1024)
    return {
        'name': name,
        'parameters': params,
        'size_mb': size_mb,
        'path': str(model.model_path)
    }

models_info = [
    get_model_info(original_model, 'Original'),
    get_model_info(pruned_model, 'Pruned'),
    get_model_info(finetuned_model, 'Fine-tuned')
]

print("\n=== Model Specifications ===")
for info in models_info:
    print(f"{info['name']}: {info['parameters']:,} params, {info['size_mb']:.2f} MB")

Loading models for comparison...
✓ Original model loaded
✓ Pruned model loaded
✓ Fine-tuned model loaded


AttributeError: 'DetectionModel' object has no attribute 'model_path'

In [None]:
# Model size comparison visualization
plt.figure(figsize=(12, 6))

# Size comparison
plt.subplot(1, 2, 1)
sizes = [info['size_mb'] for info in models_info]
names = [info['name'] for info in models_info]
bars = plt.bar(names, sizes, color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
plt.title('Model Size Comparison', fontsize=14, fontweight='bold')
plt.ylabel('Size (MB)')
plt.grid(True, alpha=0.3)

# Add value labels
for bar, size in zip(bars, sizes):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{size:.2f} MB', ha='center', va='bottom')

# Parameter comparison
plt.subplot(1, 2, 2)
params = [info['parameters']/1e6 for info in models_info]  # Convert to millions
bars = plt.bar(names, params, color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
plt.title('Parameter Count Comparison', fontsize=14, fontweight='bold')
plt.ylabel('Parameters (Millions)')
plt.grid(True, alpha=0.3)

# Add value labels
for bar, param in zip(bars, params):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{param:.1f}M', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Compression metrics
orig_size = models_info[0]['size_mb']
pruned_size = models_info[1]['size_mb']
ft_size = models_info[2]['size_mb']

print("\n=== Compression Results ===")
print(f"Pruned vs Original: {(pruned_size/orig_size - 1)*100:+.1f}% size change")
print(f"Fine-tuned vs Original: {(ft_size/orig_size - 1)*100:+.1f}% size change")
print(f"Fine-tuned vs Pruned: {(ft_size/pruned_size - 1)*100:+.1f}% size change")

In [None]:
# Inference speed comparison
def benchmark_inference_speed(model, model_name, num_runs=50):
    """Benchmark inference speed for a model"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Create dummy input
    dummy_input = torch.randn(1, 3, 640, 640).to(device)
    model.model.to(device)
    model.model.eval()
    
    # Warm up
    with torch.no_grad():
        for _ in range(5):
            _ = model.model(dummy_input)
    
    # Time inference
    times = []
    with torch.no_grad():
        for _ in range(num_runs):
            start = time.time()
            _ = model.model(dummy_input)
            torch.cuda.synchronize() if device == 'cuda' else None
            end = time.time()
            times.append((end - start) * 1000)  # ms
    
    avg_time = np.mean(times)
    std_time = np.std(times)
    fps = 1000 / avg_time
    
    print(f"{model_name} ({device.upper()}):")
    print(f"  Average: {avg_time:.2f} ms")
    print(f"  Std Dev: {std_time:.2f} ms")
    print(f"  FPS: {fps:.1f}")
    
    return times

print("=== Inference Speed Benchmark ===")
print("Testing with 640x640 input on", "GPU" if torch.cuda.is_available() else "CPU")
print()

# Benchmark all models
original_times = benchmark_inference_speed(original_model, "Original YOLOv8n")
print()
pruned_times = benchmark_inference_speed(pruned_model, "Pruned YOLOv8n")
print()
finetuned_times = benchmark_inference_speed(finetuned_model, "Fine-tuned YOLOv8n")

# Calculate speedups
pruned_speedup = np.mean(original_times) / np.mean(pruned_times)
finetuned_speedup = np.mean(original_times) / np.mean(finetuned_times)

print(f"\nSpeedup Results:")
print(f"Pruned vs Original: {pruned_speedup:.2f}x faster")
print(f"Fine-tuned vs Original: {finetuned_speedup:.2f}x faster")

In [None]:
# Visualize inference speed comparison
plt.figure(figsize=(14, 6))

# Box plot comparison
plt.subplot(1, 3, 1)
all_times = [original_times, pruned_times, finetuned_times]
labels = ['Original', 'Pruned', 'Fine-tuned']
plt.boxplot(all_times, labels=labels)
plt.title('Inference Time Distribution', fontsize=12, fontweight='bold')
plt.ylabel('Time (ms)')
plt.grid(True, alpha=0.3)

# Average time comparison
plt.subplot(1, 3, 2)
avg_times = [np.mean(t) for t in all_times]
bars = plt.bar(labels, avg_times, color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
plt.title('Average Inference Time', fontsize=12, fontweight='bold')
plt.ylabel('Time (ms)')
plt.grid(True, alpha=0.3)

# Add value labels
for bar, time in zip(bars, avg_times):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
             f'{time:.1f}ms', ha='center', va='bottom')

# FPS comparison
plt.subplot(1, 3, 3)
fps_values = [1000/np.mean(t) for t in all_times]
bars = plt.bar(labels, fps_values, color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
plt.title('Frames Per Second (FPS)', fontsize=12, fontweight='bold')
plt.ylabel('FPS')
plt.grid(True, alpha=0.3)

# Add value labels
for bar, fps in zip(bars, fps_values):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
             f'{fps:.1f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Performance summary
print("\n=== Performance Summary ===")
for i, (label, times) in enumerate(zip(labels, all_times)):
    avg_time = np.mean(times)
    fps = 1000 / avg_time
    print(f"{label}: {avg_time:.2f}ms ({fps:.1f} FPS)")

In [None]:
# Sample image detection comparison
print("=== Sample Image Detection ===")
print("Using TRAFFIC.jpeg for detection comparison")

# Check if image exists
image_path = 'TRAFFIC.jpeg'
if not os.path.exists(image_path):
    print(f"❌ Image {image_path} not found!")
    print("Please ensure TRAFFIC.jpeg is in the current directory")
else:
    print(f"✓ Found image: {image_path}")
    
    # Load and display original image
    img = Image.open(image_path)
    plt.figure(figsize=(8, 6))
    plt.imshow(img)
    plt.title('Original Image: TRAFFIC.jpeg', fontsize=14, fontweight='bold')
    plt.axis('off')
    plt.show()
    
    print(f"Image size: {img.size}")
    print(f"Image mode: {img.mode}")

In [None]:
# Run detection on original model
if os.path.exists('TRAFFIC.jpeg'):
    print("\n--- Original Model Detection ---")
    
    # Time the detection
    start_time = time.time()
    orig_results = original_model.predict(
        'TRAFFIC.jpeg',
        save=True,
        project='detection_comparison',
        name='original',
        exist_ok=True,
        device=device
    )
    orig_detection_time = time.time() - start_time
    
    print(f"Original model detection time: {orig_detection_time:.3f} seconds")
    
    # Extract detection results
    orig_detections = orig_results[0]
    print(f"Detections found: {len(orig_detections.boxes)}")
    
    if len(orig_detections.boxes) > 0:
        print("\nDetected objects:")
        for i, box in enumerate(orig_detections.boxes):
            cls_id = int(box.cls.item())
            conf = box.conf.item()
            class_name = original_model.names[cls_id]
            print(f"  {i+1}. {class_name} (confidence: {conf:.3f})")
    
    # Display result image
    result_img_path = 'detection_comparison/original/TRAFFIC.jpeg'
    if os.path.exists(result_img_path):
        result_img = Image.open(result_img_path)
        plt.figure(figsize=(12, 8))
        plt.imshow(result_img)
        plt.title(f'Original Model Detection\n{len(orig_detections.boxes)} objects detected', 
                 fontsize=14, fontweight='bold')
        plt.axis('off')
        plt.show()
    else:
        print("Result image not found")

In [None]:
# Run detection on pruned model
if os.path.exists('TRAFFIC.jpeg'):
    print("\n--- Pruned Model Detection ---")
    
    # Time the detection
    start_time = time.time()
    pruned_results = pruned_model.predict(
        'TRAFFIC.jpeg',
        save=True,
        project='detection_comparison',
        name='pruned',
        exist_ok=True,
        device=device
    )
    pruned_detection_time = time.time() - start_time
    
    print(f"Pruned model detection time: {pruned_detection_time:.3f} seconds")
    
    # Extract detection results
    pruned_detections = pruned_results[0]
    print(f"Detections found: {len(pruned_detections.boxes)}")
    
    if len(pruned_detections.boxes) > 0:
        print("\nDetected objects:")
        for i, box in enumerate(pruned_detections.boxes):
            cls_id = int(box.cls.item())
            conf = box.conf.item()
            class_name = pruned_model.names[cls_id]
            print(f"  {i+1}. {class_name} (confidence: {conf:.3f})")
    
    # Display result image
    result_img_path = 'detection_comparison/pruned/TRAFFIC.jpeg'
    if os.path.exists(result_img_path):
        result_img = Image.open(result_img_path)
        plt.figure(figsize=(12, 8))
        plt.imshow(result_img)
        plt.title(f'Pruned Model Detection\n{len(pruned_detections.boxes)} objects detected', 
                 fontsize=14, fontweight='bold')
        plt.axis('off')
        plt.show()
    else:
        print("Result image not found")

In [None]:
# Run detection on fine-tuned model
if os.path.exists('TRAFFIC.jpeg'):
    print("\n--- Fine-tuned Model Detection ---")
    
    # Time the detection
    start_time = time.time()
    ft_results = finetuned_model.predict(
        'TRAFFIC.jpeg',
        save=True,
        project='detection_comparison',
        name='finetuned',
        exist_ok=True,
        device=device
    )
    ft_detection_time = time.time() - start_time
    
    print(f"Fine-tuned model detection time: {ft_detection_time:.3f} seconds")
    
    # Extract detection results
    ft_detections = ft_results[0]
    print(f"Detections found: {len(ft_detections.boxes)}")
    
    if len(ft_detections.boxes) > 0:
        print("\nDetected objects:")
        for i, box in enumerate(ft_detections.boxes):
            cls_id = int(box.cls.item())
            conf = box.conf.item()
            class_name = finetuned_model.names[cls_id]
            print(f"  {i+1}. {class_name} (confidence: {conf:.3f})")
    
    # Display result image
    result_img_path = 'detection_comparison/finetuned/TRAFFIC.jpeg'
    if os.path.exists(result_img_path):
        result_img = Image.open(result_img_path)
        plt.figure(figsize=(12, 8))
        plt.imshow(result_img)
        plt.title(f'Fine-tuned Model Detection\n{len(ft_detections.boxes)} objects detected', 
                 fontsize=14, fontweight='bold')
        plt.axis('off')
        plt.show()
    else:
        print("Result image not found")

In [None]:
# Detection comparison summary
if os.path.exists('TRAFFIC.jpeg'):
    print("\n=== Detection Comparison Summary ===")
    
    # Detection counts
    orig_count = len(orig_detections.boxes) if 'orig_detections' in locals() else 0
    pruned_count = len(pruned_detections.boxes) if 'pruned_detections' in locals() else 0
    ft_count = len(ft_detections.boxes) if 'ft_detections' in locals() else 0
    
    # Timing
    orig_time = orig_detection_time if 'orig_detection_time' in locals() else 0
    pruned_time = pruned_detection_time if 'pruned_detection_time' in locals() else 0
    ft_time = ft_detection_time if 'ft_detection_time' in locals() else 0
    
    print("Model\t\tDetections\tTime (s)\t\tFPS")
    print("-" * 50)
    print(f"Original\t{orig_count}\t\t{orig_time:.3f}\t\t{1/orig_time:.2f}")
    print(f"Pruned\t\t{pruned_count}\t\t{pruned_time:.3f}\t\t{1/pruned_time:.2f}")
    print(f"Fine-tuned\t{ft_count}\t\t{ft_time:.3f}\t\t{1/ft_time:.2f}")
    
    # Speed comparison
    if orig_time > 0 and pruned_time > 0:
        speedup = orig_time / pruned_time
        print(f"\nPruned model is {speedup:.2f}x faster than original!")
    
    if orig_time > 0 and ft_time > 0:
        speedup_ft = orig_time / ft_time
        print(f"Fine-tuned model is {speedup_ft:.2f}x faster than original!")
    
    # Detection accuracy comparison
    if orig_count > 0:
        pruned_acc = (pruned_count / orig_count) * 100
        ft_acc = (ft_count / orig_count) * 100
        print(f"\nDetection count vs Original:")
        print(f"Pruned: {pruned_acc:.1f}% of original detections")
        print(f"Fine-tuned: {ft_acc:.1f}% of original detections")

In [None]:
# Final comprehensive comparison visualization
plt.figure(figsize=(16, 10))
fig.suptitle('YOLOv8 Model Comparison: Original vs Pruned vs Fine-tuned', 
             fontsize=16, fontweight='bold')

# Model sizes
plt.subplot(2, 3, 1)
sizes = [info['size_mb'] for info in models_info]
bars = plt.bar([info['name'] for info in models_info], sizes, 
               color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
plt.title('Model Size (MB)', fontsize=12, fontweight='bold')
plt.grid(True, alpha=0.3)
for bar, size in zip(bars, sizes):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{size:.2f}', ha='center', va='bottom')

# Parameters
plt.subplot(2, 3, 2)
params = [info['parameters']/1e6 for info in models_info]
bars = plt.bar([info['name'] for info in models_info], params, 
               color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
plt.title('Parameters (Millions)', fontsize=12, fontweight='bold')
plt.grid(True, alpha=0.3)
for bar, param in zip(bars, params):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{param:.1f}', ha='center', va='bottom')

# Inference speed
plt.subplot(2, 3, 3)
avg_times = [np.mean(original_times), np.mean(pruned_times), np.mean(finetuned_times)]
bars = plt.bar(['Original', 'Pruned', 'Fine-tuned'], avg_times, 
               color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
plt.title('Inference Time (ms)', fontsize=12, fontweight='bold')
plt.grid(True, alpha=0.3)
for bar, time in zip(bars, avg_times):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
             f'{time:.1f}', ha='center', va='bottom')

# Detection counts (if available)
plt.subplot(2, 3, 4)
if 'orig_detections' in locals() and 'pruned_detections' in locals() and 'ft_detections' in locals():
    counts = [len(orig_detections.boxes), len(pruned_detections.boxes), len(ft_detections.boxes)]
    bars = plt.bar(['Original', 'Pruned', 'Fine-tuned'], counts, 
                   color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
    plt.title('Objects Detected', fontsize=12, fontweight='bold')
    plt.grid(True, alpha=0.3)
    for bar, count in zip(bars, counts):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, 
                 f'{int(count)}', ha='center', va='bottom')
else:
    plt.text(0.5, 0.5, 'Detection data\nnot available', 
             ha='center', va='center', transform=plt.gca().transAxes)
    plt.title('Objects Detected', fontsize=12, fontweight='bold')

# Detection timing (if available)
plt.subplot(2, 3, 5)
if 'orig_detection_time' in locals() and 'pruned_detection_time' in locals() and 'ft_detection_time' in locals():
    times = [orig_detection_time, pruned_detection_time, ft_detection_time]
    bars = plt.bar(['Original', 'Pruned', 'Fine-tuned'], times, 
                   color=['skyblue', 'lightcoral', 'lightgreen'], alpha=0.8)
    plt.title('Detection Time (s)', fontsize=12, fontweight='bold')
    plt.grid(True, alpha=0.3)
    for bar, t in zip(bars, times):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001, 
                 f'{t:.3f}', ha='center', va='bottom')
else:
    plt.text(0.5, 0.5, 'Timing data\nnot available', 
             ha='center', va='center', transform=plt.gca().transAxes)
    plt.title('Detection Time (s)', fontsize=12, fontweight='bold')

# Performance summary
plt.subplot(2, 3, 6)
plt.axis('off')
summary_text = f"""YOLOv8 Pruning Results:

• 49% channels pruned
• 4.8% size reduction
• {pruned_speedup:.2f}x inference speedup
• Architecture preserved
• Ready for edge deployment

Methodology: Network Slimming"""
plt.text(0.1, 0.9, summary_text, fontsize=10, verticalalignment='top',
         bbox=dict(boxstyle="round,pad=0.3", facecolor='lightblue', alpha=0.5))

plt.tight_layout()
plt.show()

print("\n🎉 Comparison Complete!")
print("The pruned model achieves significant computational savings")
print("while maintaining detection capabilities for edge deployment.")