# YOLO11 Pose Estimation for Waste Throwing Detection

This comprehensive notebook guides you through training a YOLO11 pose estimation model to detect people throwing waste from video data.

## Table of Contents
1. [Install Required Libraries](#install)
2. [Import Libraries and Setup](#setup)
3. [Prepare Video Dataset](#prepare)
4. [Extract Frames from Videos](#extract)
5. [Annotate Data for Pose Estimation](#annotate)
6. [Convert Annotations to YOLO Format](#convert)
7. [Setup YOLO11 Pose Model](#model)
8. [Configure Training Parameters](#config)
9. [Train the Model](#train)
10. [Validate Model Performance](#validate)
11. [Test on New Video Data](#test)
12. [Visualize Pose Detection Results](#visualize)

## Project Overview
- **Goal**: Detect people throwing waste using pose estimation
- **Model**: YOLO11 with pose estimation capabilities
- **Input**: Videos of people throwing waste
- **Output**: Trained model that can detect throwing poses in real-time

## 1. Install Required Libraries and Dependencies {#install}

First, let's install all the necessary libraries for our pose estimation project.

In [None]:
# Install required packages
# Run this cell only once when setting up the environment

import subprocess
import sys

def install_package(package):
    """Install one pip requirement into the interpreter running this notebook.

    Args:
        package: A pip requirement specifier, e.g. ``"numpy>=1.21.0"``.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # sys.executable makes pip target the kernel's own environment instead of
    # whichever "pip" happens to be first on PATH.
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Core packages for YOLO11 and pose estimation
packages = [
    # The yolo11*-pose.pt weights used later in this notebook require an
    # Ultralytics release that ships YOLO11 (8.3 series or newer).
    "ultralytics>=8.3.0",  # YOLO11 framework
    "opencv-python>=4.7.0",  # Computer vision
    "torch>=1.13.0",  # PyTorch
    "torchvision>=0.14.0",  # PyTorch vision
    "numpy>=1.21.0",  # Numerical computing
    "matplotlib>=3.5.0",  # Plotting
    "seaborn>=0.11.0",  # Statistical visualization
    "pandas>=1.4.0",  # Data manipulation
    "Pillow>=9.0.0",  # Image processing
    "tqdm>=4.64.0",  # Progress bars
    "PyYAML>=6.0",  # YAML configuration
    "scikit-learn>=1.1.0",  # Machine learning utilities
    "moviepy>=1.0.0",  # Video processing
    "wandb>=0.13.0",  # Experiment tracking (optional)
]

print("Installing required packages...")
for package in packages:
    # Best effort: one failing package is reported but must not stop the
    # remaining installs.
    try:
        install_package(package)
        print(f"✅ Successfully installed {package}")
    except Exception as e:
        print(f"❌ Failed to install {package}: {e}")

print("\n🎉 Installation complete! Please restart the kernel after installation.")

## 2. Import Libraries and Setup Environment {#setup}

Now let's import all necessary libraries and set up our working environment.

In [None]:
# Import essential libraries
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
import yaml
import json
from tqdm.auto import tqdm
import shutil
import random
from datetime import datetime

# Import YOLO11 and PyTorch
from ultralytics import YOLO
import torch
import torchvision

# Import scikit-learn for data splitting
from sklearn.model_selection import train_test_split

# Plot styling: stock matplotlib theme with seaborn's "husl" colour cycle
plt.style.use('default')
sns.set_palette("husl")

# Report library versions and whether a CUDA device can be used
has_gpu = torch.cuda.is_available()
print("🖥️  System Information:")
print(f"PyTorch version: {torch.__version__}")
print(f"OpenCV version: {cv2.__version__}")
print(f"GPU available: {has_gpu}")
if has_gpu:
    print(f"GPU device: {torch.cuda.get_device_name()}")
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU memory: {gpu_total_bytes / 1024**3:.1f} GB")

# Seed every random number generator the notebook relies on
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(42)
if has_gpu:
    torch.cuda.manual_seed(42)

# Project layout, relative to the notebooks/ folder this file is assumed to live in
project_root = Path("../")
data_dir = project_root / "data"
videos_dir, frames_dir, annotations_dir = (
    data_dir / sub for sub in ("videos", "frames", "annotations")
)
models_dir, results_dir = (project_root / sub for sub in ("models", "results"))

# Make sure every working directory exists before later cells write into it
for directory in (videos_dir, frames_dir, annotations_dir, models_dir, results_dir):
    directory.mkdir(parents=True, exist_ok=True)

print("\n📁 Project Structure Created:")
print(f"Videos: {videos_dir}")
print(f"Frames: {frames_dir}")
print(f"Annotations: {annotations_dir}")
print(f"Models: {models_dir}")
print(f"Results: {results_dir}")

# COCO keypoint order (17 points): the nose, then a left/right pair for each
# remaining landmark, listed head to toe.
KEYPOINT_NAMES = ['nose'] + [
    f'{side}_{part}'
    for part in ('eye', 'ear', 'shoulder', 'elbow', 'wrist', 'hip', 'knee', 'ankle')
    for side in ('left', 'right')
]

# Subset of joints that matter most for recognising a throwing motion
THROWING_KEYPOINTS = [
    name for name in KEYPOINT_NAMES
    if name.split('_')[-1] in ('shoulder', 'elbow', 'wrist', 'hip')
]

print(f"\n🎯 Pose keypoints defined: {len(KEYPOINT_NAMES)} total keypoints")
print(f"🎯 Throwing-specific keypoints: {len(THROWING_KEYPOINTS)} keypoints")

## 3. Prepare Video Dataset {#prepare}

In this section, we'll prepare your video dataset for training. Place your videos of people throwing waste in the `data/videos/` directory.

In [None]:
# Function to analyze video dataset
def analyze_video_dataset(videos_directory):
    """Scan a directory for videos, print per-file statistics and return them.

    Args:
        videos_directory: ``pathlib.Path`` of the folder to scan
            (non-recursive).

    Returns:
        A list with one dict per video that could be opened, holding
        ``filename``, ``duration`` (seconds), ``fps``, ``frames``,
        ``resolution`` (``"WxH"``) and ``size_mb``. The list is empty when no
        video file is found or none of them can be opened.
    """
    video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv']

    # The lower- and upper-case patterns can match the same file twice where
    # globbing is case-insensitive (e.g. Windows), so collect into a set.
    # Sorting gives a stable listing order.
    video_files = sorted({
        path
        for ext in video_extensions
        for pattern in (f'*{ext}', f'*{ext.upper()}')
        for path in videos_directory.glob(pattern)
    })

    if not video_files:
        print("❌ No video files found!")
        print(f"📁 Please place your videos in: {videos_directory}")
        print(f"📹 Supported formats: {', '.join(video_extensions)}")
        return []

    print(f"📹 Found {len(video_files)} video files")
    print("=" * 50)

    total_duration = 0
    video_info = []

    for video_path in video_files:
        cap = None
        try:
            # Open video and get properties
            cap = cv2.VideoCapture(str(video_path))

            if not cap.isOpened():
                print(f"❌ Cannot open: {video_path.name}")
                continue

            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            # A reported FPS of 0 would otherwise divide by zero.
            duration = frame_count / fps if fps > 0 else 0
            size_mb = video_path.stat().st_size / (1024 * 1024)

            total_duration += duration

            video_info.append({
                'filename': video_path.name,
                'duration': duration,
                'fps': fps,
                'frames': frame_count,
                'resolution': f"{width}x{height}",
                'size_mb': size_mb
            })

            print(f"📄 {video_path.name}")
            print(f"   ⏱️  Duration: {duration:.1f} seconds")
            print(f"   🎬 FPS: {fps:.1f}")
            print(f"   🖼️  Resolution: {width}x{height}")
            print(f"   📊 Frames: {frame_count}")
            print(f"   💾 Size: {size_mb:.1f} MB")
            print()

        except Exception as e:
            # Best effort: one unreadable file must not abort the whole scan.
            print(f"❌ Error processing {video_path.name}: {e}")
        finally:
            # Release the capture handle on every path, including the
            # "cannot open" and error paths.
            if cap is not None:
                cap.release()

    # Every file may have failed to open, so guard the average against an
    # empty result list.
    average_duration = total_duration / len(video_info) if video_info else 0.0

    print("📊 Dataset Summary:")
    print(f"   🎬 Total videos: {len(video_info)}")
    print(f"   ⏱️  Total duration: {total_duration:.1f} seconds ({total_duration/60:.1f} minutes)")
    print(f"   📊 Average duration: {average_duration:.1f} seconds")
    print(f"   💾 Total size: {sum(info['size_mb'] for info in video_info):.1f} MB")

    return video_info

# Run the dataset scan and keep the per-video statistics for later cells
print("🔍 Analyzing your video dataset...")
video_info = analyze_video_dataset(videos_dir)

if not video_info:
    # Nothing usable was found: explain how to populate the dataset
    for line in (
        "\n📝 Next Steps:",
        "1. Add your video files to the data/videos/ directory",
        "2. Supported formats: .mp4, .avi, .mov, .mkv, .wmv, .flv",
        "3. Re-run this cell to analyze your dataset",
    ):
        print(line)
else:
    # Pick the advice tier from the total footage length (in seconds)
    total_duration = sum(entry['duration'] for entry in video_info)

    if total_duration < 300:  # Less than 5 minutes
        advice = (
            "⚠️  Small dataset detected (< 5 minutes total)",
            "   - Consider adding more videos for better training",
            "   - Use data augmentation strategies",
            "   - Extract frames at higher frequency",
        )
    elif total_duration < 1800:  # Less than 30 minutes
        advice = (
            "✅ Moderate dataset size (5-30 minutes)",
            "   - Good for initial training",
            "   - Consider fine-tuning strategies",
        )
    else:
        advice = (
            "✅ Large dataset detected (> 30 minutes)",
            "   - Excellent for robust training",
            "   - Can use standard training procedures",
        )

    print("\n💡 Recommendations:")
    for line in advice:
        print(line)

## 4. Extract Frames from Videos {#extract}

Now we'll extract frames from your videos to create training data. We'll use intelligent extraction strategies to get the best frames for annotation.

In [None]:
class FrameExtractor:
    """Extract still frames from videos and record where each one came from.

    Two strategies are available:

    * ``'uniform'`` keeps every Nth frame so that roughly ``fps_target``
      frames are saved per second of video.
    * ``'motion'`` keeps a frame when the summed absolute grayscale
      difference to the last kept frame exceeds ``motion_threshold``.

    Each saved frame gets one dict in ``self.metadata``; ``process_videos``
    dumps that list to ``extraction_metadata.json`` in ``output_dir``.
    """

    VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv')

    def __init__(self, output_dir, strategy='uniform', fps_target=2,
                 motion_threshold=15000):
        """Configure the extractor.

        Args:
            output_dir: Folder that receives one sub-folder of frames per video.
            strategy: ``'uniform'`` or ``'motion'``.
            fps_target: Frames to keep per second of video (uniform strategy).
            motion_threshold: Minimum summed pixel difference that counts as
                motion (motion strategy). The sum runs over the whole frame,
                so a suitable value depends on the video resolution.
        """
        self.output_dir = Path(output_dir)
        self.strategy = strategy
        self.fps_target = fps_target
        self.motion_threshold = motion_threshold
        self.metadata = []

    def _store_frame(self, frame, frame_path, record):
        """Write ``frame`` to disk and log ``record`` only if the write succeeded."""
        if not cv2.imwrite(str(frame_path), frame):
            print(f"⚠️  Failed to write frame: {frame_path}")
            return False
        self.metadata.append(record)
        return True

    def extract_uniform(self, video_path, video_name):
        """Save frames of one video at a fixed interval.

        Args:
            video_path: Path of the video file.
            video_name: Name used for the output sub-folder and file prefix.

        Returns:
            Number of frames written.
        """
        cap = cv2.VideoCapture(str(video_path))
        pbar = None
        try:
            if not cap.isOpened():
                print(f"❌ Cannot open: {video_name}")
                return 0

            fps = cap.get(cv2.CAP_PROP_FPS)
            # An FPS of 0 can be reported; keep every frame in that case
            # instead of dividing by zero.
            if fps > 0 and self.fps_target > 0:
                frame_interval = max(1, int(fps / self.fps_target))
            else:
                frame_interval = 1

            video_output_dir = self.output_dir / video_name
            video_output_dir.mkdir(parents=True, exist_ok=True)

            frame_count = 0
            extracted_count = 0

            pbar = tqdm(desc=f"Extracting from {video_name}")

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_interval == 0:
                    frame_filename = f"{video_name}_frame_{extracted_count:06d}.jpg"
                    frame_path = video_output_dir / frame_filename
                    record = {
                        'video_name': video_name,
                        'frame_number': frame_count,
                        'extracted_frame_id': extracted_count,
                        'timestamp': frame_count / fps if fps > 0 else 0.0,
                        'frame_path': str(frame_path)
                    }
                    if self._store_frame(frame, frame_path, record):
                        extracted_count += 1

                frame_count += 1
                pbar.update(1)

            return extracted_count
        finally:
            # Release the capture and progress bar on every exit path.
            cap.release()
            if pbar is not None:
                pbar.close()

    def extract_motion_based(self, video_path, video_name):
        """Save frames of one video whenever enough has changed.

        The reference frame only advances when a frame is kept, so slow
        changes accumulate until they cross ``motion_threshold``.

        Args:
            video_path: Path of the video file.
            video_name: Name used for the output sub-folder and file prefix.

        Returns:
            Number of frames written (0 if the video yields no frame).
        """
        cap = cv2.VideoCapture(str(video_path))
        pbar = None
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)

            ret, prev_frame = cap.read()
            if not ret:
                return 0

            prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)

            video_output_dir = self.output_dir / video_name
            video_output_dir.mkdir(parents=True, exist_ok=True)

            # Frame 0 was consumed above as the initial reference, so the
            # first frame read inside the loop is frame 1.
            frame_count = 1
            extracted_count = 0

            pbar = tqdm(desc=f"Motion-based extraction from {video_name}")

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                motion_score = np.sum(cv2.absdiff(prev_gray, gray))

                if motion_score > self.motion_threshold:
                    frame_filename = f"{video_name}_motion_{extracted_count:06d}.jpg"
                    frame_path = video_output_dir / frame_filename
                    record = {
                        'video_name': video_name,
                        'frame_number': frame_count,
                        'extracted_frame_id': extracted_count,
                        'timestamp': frame_count / fps if fps > 0 else 0.0,
                        'motion_score': float(motion_score),
                        'frame_path': str(frame_path)
                    }
                    if self._store_frame(frame, frame_path, record):
                        extracted_count += 1
                    prev_gray = gray

                frame_count += 1
                pbar.update(1)

            return extracted_count
        finally:
            cap.release()
            if pbar is not None:
                pbar.close()

    def process_videos(self, videos_directory):
        """Run the configured strategy on every video in a directory.

        Args:
            videos_directory: Folder to scan (non-recursive) for video files.

        Returns:
            Total number of frames written; 0 when no video file is found.
        """
        videos_directory = Path(videos_directory)

        # A set removes duplicates produced when both the lower- and
        # upper-case pattern match the same file (case-insensitive globbing).
        video_files = sorted({
            path
            for ext in self.VIDEO_EXTENSIONS
            for pattern in (f'*{ext}', f'*{ext.upper()}')
            for path in videos_directory.glob(pattern)
        })

        if not video_files:
            print("❌ No video files found!")
            return 0

        print(f"🎬 Processing {len(video_files)} videos...")
        print(f"📊 Strategy: {self.strategy}")
        print(f"🎯 Target FPS: {self.fps_target}")

        total_extracted = 0

        for video_path in video_files:
            video_name = video_path.stem
            print(f"\n📹 Processing: {video_name}")

            if self.strategy == 'uniform':
                extracted = self.extract_uniform(video_path, video_name)
            elif self.strategy == 'motion':
                extracted = self.extract_motion_based(video_path, video_name)
            else:
                print(f"Unknown strategy: {self.strategy}")
                continue

            print(f"✅ Extracted {extracted} frames")
            total_extracted += extracted

        # Save metadata (the output folder may not exist yet if nothing was
        # extracted)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = self.output_dir / 'extraction_metadata.json'
        with open(metadata_path, 'w') as f:
            json.dump(self.metadata, f, indent=2)

        print(f"\n🎉 Extraction Complete!")
        print(f"📊 Total frames extracted: {total_extracted}")
        print(f"📄 Metadata saved: {metadata_path}")

        return total_extracted

# Extraction settings
EXTRACTION_STRATEGY = 'uniform'  # Options: 'uniform', 'motion'
TARGET_FPS = 2  # Frames per second to extract

print("🎬 Frame Extraction Configuration:")
print(f"Strategy: {EXTRACTION_STRATEGY}")
print(f"Target FPS: {TARGET_FPS}")

# Build the extractor that writes into the frames directory
extractor = FrameExtractor(frames_dir, EXTRACTION_STRATEGY, TARGET_FPS)

# Extraction only makes sense when the dataset scan found videos
if not video_info:
    print("⚠️  No videos found. Please add videos to data/videos/ directory first.")
else:
    total_frames = extractor.process_videos(videos_dir)

    if total_frames > 0:
        for line in (
            f"\n💡 Next Steps:",
            f"1. Review extracted frames in: {frames_dir}",
            f"2. Proceed to annotation step",
            f"3. For best results, aim for 500-1000+ annotated frames",
        ):
            print(line)

## 5. Annotate Data for Pose Estimation {#annotate}

This is a critical step where you'll annotate the keypoints for human poses. Since this requires manual annotation, we'll provide guidance and tools to make the process efficient.

### Annotation Tools Options:

1. **CVAT (Computer Vision Annotation Tool)** - Recommended
   - Web-based annotation platform
   - Supports pose keypoint annotation
   - Team collaboration features
   
2. **LabelImg** - For bounding boxes (if needed)
   - Simple desktop application
   - Good for object detection

3. **Roboflow** - Online annotation platform
   - User-friendly interface
   - Built-in data management

### YOLO11 Pose Annotation Format:

Each annotation file should contain:
```
class_id x_center y_center width height x1 y1 v1 x2 y2 v2 ... x17 y17 v17
```

Where:
- `class_id`: Object class (0 for person)
- `x_center, y_center, width, height`: Bounding box (normalized 0-1)
- `xi, yi, vi`: Keypoint coordinates (normalized) and visibility
  - `vi = 0`: Not labeled
  - `vi = 1`: Labeled but not visible (occluded)
  - `vi = 2`: Labeled and visible

In [None]:
# Annotation Helper Functions

def create_sample_annotation():
    """Write a commented reference file describing the pose label format.

    The file is saved as ``annotation_format_reference.txt`` inside
    ``annotations_dir`` and its location is printed.
    """
    reference_path = annotations_dir / 'annotation_format_reference.txt'

    reference_text = """# Sample YOLO11 Pose Annotation Format
# Each line represents one person in the image
# Format: class_id x_center y_center width height x1 y1 v1 x2 y2 v2 ... x17 y17 v17

# Example annotation for a person:
# 0 0.5 0.6 0.3 0.8 0.52 0.32 2 0.48 0.31 2 0.56 0.31 2 0.45 0.33 1 0.59 0.33 1 0.42 0.45 2 0.58 0.45 2 0.38 0.55 2 0.62 0.55 2 0.35 0.58 2 0.65 0.58 2 0.47 0.78 2 0.53 0.78 2 0.45 0.95 2 0.55 0.95 2 0.43 1.0 1 0.57 1.0 1

# Keypoint order (17 keypoints):
# 0: nose
# 1: left_eye, 2: right_eye
# 3: left_ear, 4: right_ear  
# 5: left_shoulder, 6: right_shoulder
# 7: left_elbow, 8: right_elbow
# 9: left_wrist, 10: right_wrist
# 11: left_hip, 12: right_hip
# 13: left_knee, 14: right_knee
# 15: left_ankle, 16: right_ankle

# For throwing detection, focus on:
# - Shoulders (5, 6): Position indicates body orientation
# - Elbows (7, 8): Arm bending during throw
# - Wrists (9, 10): Hand position and motion direction
# - Hips (11, 12): Lower body stability during throw
"""

    reference_path.write_text(reference_text)

    print(f"📝 Sample annotation saved to: {reference_path}")

def validate_annotation_file(annotation_path, num_keypoints=17):
    """Check every label line of one YOLO pose annotation file.

    A line counts as valid only when it has exactly ``5 + 3 * num_keypoints``
    fields, its class id is 0, all four bounding-box values lie in [0, 1],
    and every keypoint has x and y in [0, 1] with a visibility of 0, 1 or 2.
    Blank lines and lines starting with ``#`` are ignored.

    Args:
        annotation_path: Path of the ``.txt`` label file.
        num_keypoints: Keypoints expected per object (17 for the COCO layout).

    Returns:
        A ``(valid_lines, errors)`` tuple: the number of lines that passed
        every check and a list of human-readable error messages. A file that
        cannot be read yields ``(0, ["File error: ..."])``.
    """
    # 1 class id + 4 bbox values + (x, y, visibility) per keypoint
    expected_fields = 5 + 3 * num_keypoints

    try:
        with open(annotation_path, 'r') as f:
            lines = f.readlines()
    except (OSError, ValueError, TypeError) as e:
        # ValueError also covers undecodable bytes (UnicodeDecodeError).
        return 0, [f"File error: {e}"]

    valid_lines = 0
    errors = []

    for line_no, raw_line in enumerate(lines, start=1):
        line = raw_line.strip()
        if not line or line.startswith('#'):
            continue

        parts = line.split()

        if len(parts) != expected_fields:
            errors.append(f"Line {line_no}: Expected {expected_fields} values, got {len(parts)}")
            continue

        # Remember how many errors existed so the line is only counted as
        # valid when none of the checks below added one.
        errors_before = len(errors)

        try:
            # Validate class_id
            class_id = int(parts[0])
            if class_id != 0:
                errors.append(f"Line {line_no}: Class ID should be 0 for person, got {class_id}")

            # Validate bbox (normalized 0-1)
            bbox = [float(value) for value in parts[1:5]]
            if not all(0 <= value <= 1 for value in bbox):
                errors.append(f"Line {line_no}: Bbox values should be normalized (0-1)")

            # Validate keypoints; report only the first bad one per line
            keypoints = [float(value) for value in parts[5:]]
            for start in range(0, len(keypoints), 3):
                x, y, v = keypoints[start:start + 3]
                if not (0 <= x <= 1 and 0 <= y <= 1 and v in (0, 1, 2)):
                    errors.append(f"Line {line_no}: Invalid keypoint {start // 3 + 1}")
                    break

        except ValueError as e:
            errors.append(f"Line {line_no}: Value error - {e}")

        if len(errors) == errors_before:
            valid_lines += 1

    return valid_lines, errors

def check_annotation_status():
    """Report whether usable annotation files exist in ``annotations_dir``.

    Lists the ``.txt`` files (ignoring the generated format reference),
    validates the first ten of them and prints a summary.

    Returns:
        True when at least one validated line is valid, otherwise False
        (including when no annotation file exists at all).
    """
    # Collect label files, skipping the reference file this notebook writes
    annotation_files = [
        candidate
        for candidate in annotations_dir.glob('*.txt')
        if not candidate.name.startswith('annotation_format')
    ]

    if not annotation_files:
        for line in (
            "❌ No annotation files found!",
            f"📁 Expected location: {annotations_dir}",
            "\n📋 To create annotations:",
            "1. Use CVAT, LabelImg, or similar tool",
            "2. Export in YOLO format",
            "3. Place .txt files in annotations/ directory",
            "4. Each .txt file should match a frame filename",
        ):
            print(line)
        return False

    print(f"📊 Found {len(annotation_files)} annotation files")

    total_valid = 0
    total_errors = 0

    print("\n🔍 Validating annotations...")
    # Only a sample of the first 10 files is checked
    for ann_file in tqdm(annotation_files[:10]):
        valid_lines, errors = validate_annotation_file(ann_file)
        total_valid += valid_lines
        total_errors += len(errors)

        if not errors:
            continue

        print(f"⚠️  {ann_file.name}: {len(errors)} errors")
        # Show at most three messages per file, then a count of the rest
        for error in errors[:3]:
            print(f"   {error}")
        hidden = len(errors) - 3
        if hidden > 0:
            print(f"   ... and {hidden} more errors")

    print(f"\n📈 Validation Results:")
    print(f"   ✅ Valid annotations: {total_valid}")
    print(f"   ❌ Total errors: {total_errors}")

    ready = total_valid > 0
    if ready:
        print("✅ Ready for training preparation!")
    else:
        print("❌ Please fix annotation errors before proceeding")
    return ready

# Write the label-format reference next to the annotations
create_sample_annotation()

# Find out whether annotations are available and valid
print("🔍 Checking annotation status...")
annotations_ready = check_annotation_status()

if not annotations_ready:
    # Walk the user through what is still missing
    for line in (
        "\n📋 Annotation Checklist:",
        "□ Extract frames from videos (completed above)",
        "□ Select representative frames for annotation (~200-500 minimum)",
        "□ Use CVAT or similar tool to annotate keypoints",
        "□ Export annotations in YOLO pose format",
        "□ Place annotation files in data/annotations/",
        "□ Validate annotations (run this cell again)",
        "□ Proceed to training",
        "\n🔗 Useful Resources:",
        "- CVAT: https://cvat.org",
        "- YOLO format guide: https://docs.ultralytics.com/datasets/pose/",
        "- Pose annotation tutorial: https://blog.roboflow.com/pose-estimation-annotation/",
    ):
        print(line)

## 6. Convert Annotations to YOLO Format {#convert}

Now we'll organize our annotated data into the proper YOLO11 training format with train/validation/test splits.

In [None]:
def prepare_yolo_dataset(train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    """Assemble annotated frames into a YOLO train/val/test layout.

    Every image found under ``frames_dir`` is paired with the label file
    ``<image stem>.txt`` in ``annotations_dir``. The pairs are split randomly
    (fixed seed 42), copied to ``data_dir/<split>/images`` and
    ``data_dir/<split>/labels``, and described by ``dataset.yaml`` and
    ``dataset_summary.json`` written into ``data_dir``.

    Args:
        train_ratio: Share of the pairs used for training.
        val_ratio: Share of the pairs used for validation.
        test_ratio: Share of the pairs held out for testing; 0 disables the
            test split.

    Returns:
        True when the dataset was written, False when fewer than 10
        image/label pairs exist.
    """
    print("📦 Preparing YOLO dataset...")

    # Find all frames. The set removes files matched by both the lower- and
    # upper-case pattern (case-insensitive globbing), which could otherwise
    # land in two splits; sorting makes the seeded split reproducible
    # regardless of directory listing order.
    frame_files = sorted({
        path
        for ext in ('.jpg', '.jpeg', '.png')
        for pattern in (f'**/*{ext}', f'**/*{ext.upper()}')
        for path in frames_dir.glob(pattern)
    })

    # Find corresponding annotations
    pairs = []
    missing_annotations = []

    for img_path in frame_files:
        ann_path = annotations_dir / (img_path.stem + '.txt')
        if ann_path.exists():
            pairs.append((img_path, ann_path))
        else:
            missing_annotations.append(img_path)

    print(f"📊 Found {len(pairs)} image-annotation pairs")
    if missing_annotations:
        print(f"⚠️  {len(missing_annotations)} images without annotations")

    if len(pairs) < 10:
        print("❌ Insufficient annotated data! Need at least 10 pairs for training.")
        return False

    # Split dataset
    print(f"📈 Splitting dataset: {train_ratio:.0%} train, {val_ratio:.0%} val, {test_ratio:.0%} test")

    # First split: separate test set. train_test_split rejects a test size
    # of 0, so skip the hold-out split when no test share is requested.
    if test_ratio > 0:
        train_val_pairs, test_pairs = train_test_split(
            pairs, test_size=test_ratio, random_state=42
        )
    else:
        train_val_pairs, test_pairs = list(pairs), []

    # Second split: validation share relative to what is left after the
    # test split
    val_size = val_ratio / (train_ratio + val_ratio)
    train_pairs, val_pairs = train_test_split(
        train_val_pairs, test_size=val_size, random_state=42
    )

    print(f"✅ Train: {len(train_pairs)} samples")
    print(f"✅ Validation: {len(val_pairs)} samples")
    print(f"✅ Test: {len(test_pairs)} samples")

    splits = {'train': train_pairs, 'val': val_pairs, 'test': test_pairs}

    # Create directory structure
    for split_name in splits:
        (data_dir / split_name / 'images').mkdir(parents=True, exist_ok=True)
        (data_dir / split_name / 'labels').mkdir(parents=True, exist_ok=True)

    # Copy files to appropriate directories.
    # NOTE(review): files left in the split folders by an earlier run are not
    # removed here; clear them manually if the set of annotated frames changed.
    for split_name, split_pairs in splits.items():
        if not split_pairs:
            continue
        print(f"📁 Copying {split_name} files...")
        for img_path, ann_path in tqdm(split_pairs):
            shutil.copy2(img_path, data_dir / split_name / 'images' / img_path.name)
            shutil.copy2(ann_path, data_dir / split_name / 'labels' / ann_path.name)

    # Create dataset configuration file
    dataset_config = {
        'path': str(data_dir.absolute()),
        'train': 'train/images',
        'val': 'val/images',
        'test': 'test/images',
        'names': {0: 'person'},
        'nc': 1,  # number of classes
        'kpt_shape': [17, 3],  # 17 keypoints, each with x, y, visibility
        # Index each keypoint maps to under a horizontal flip (left/right
        # swapped, COCO order) so flip augmentation keeps the labels correct.
        'flip_idx': [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15],
    }

    config_path = data_dir / 'dataset.yaml'
    with open(config_path, 'w') as f:
        yaml.dump(dataset_config, f, default_flow_style=False)

    print(f"✅ Dataset configuration saved: {config_path}")

    # Create dataset summary
    summary = {
        'total_samples': len(pairs),
        'train_samples': len(train_pairs),
        'val_samples': len(val_pairs),
        'test_samples': len(test_pairs),
        'missing_annotations': len(missing_annotations),
        'dataset_ready': True,
        'config_path': str(config_path)
    }

    summary_path = data_dir / 'dataset_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"✅ Dataset summary saved: {summary_path}")

    return True

def visualize_dataset_sample():
    """Plot one random training image with its bounding boxes and keypoints.

    Reads from ``data_dir/train/images`` and ``data_dir/train/labels``.
    Arm keypoints (shoulders, elbows, wrists) are drawn in red, all others in
    blue; shoulders and wrists are also labelled by name. Prints a message and
    returns without plotting when the training set, the label file or the
    image itself is unavailable.
    """
    train_images_dir = data_dir / 'train' / 'images'
    train_labels_dir = data_dir / 'train' / 'labels'

    if not train_images_dir.exists():
        print("❌ Training dataset not prepared yet!")
        return

    # Consider every image type the dataset preparation step copies, in
    # either letter case; the set removes double matches.
    image_files = sorted({
        path
        for ext in ('.jpg', '.jpeg', '.png')
        for pattern in (f'*{ext}', f'*{ext.upper()}')
        for path in train_images_dir.glob(pattern)
    })

    if not image_files:
        print("❌ No training images found!")
        return

    sample_img = random.choice(image_files)
    sample_label = train_labels_dir / f"{sample_img.stem}.txt"

    if not sample_label.exists():
        print(f"❌ No label file for {sample_img.name}")
        return

    # Load image; cv2.imread signals an unreadable file by returning None
    img = cv2.imread(str(sample_img))
    if img is None:
        print(f"❌ Could not read image: {sample_img.name}")
        return
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img_rgb.shape[:2]

    # Load annotations
    with open(sample_label, 'r') as f:
        lines = f.readlines()

    plt.figure(figsize=(12, 8))
    plt.imshow(img_rgb)
    plt.title(f'Training Sample: {sample_img.name}')
    plt.axis('off')

    arm_indices = (5, 6, 7, 8, 9, 10)   # shoulders, elbows, wrists
    named_indices = (5, 6, 9, 10)       # shoulders and wrists get a text label

    # Draw annotations
    for line in lines:
        parts = line.strip().split()
        if len(parts) < 56:
            continue

        # Bounding box: normalized centre/size -> pixel corner coordinates
        x_center, y_center, width, height = map(float, parts[1:5])
        x_center *= w
        y_center *= h
        width *= w
        height *= h

        x1 = x_center - width/2
        y1 = y_center - height/2
        x2 = x_center + width/2
        y2 = y_center + height/2

        plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r-', linewidth=2)

        # Keypoints come as consecutive (x, y, visibility) triples
        keypoints = list(map(float, parts[5:]))
        for i in range(0, len(keypoints), 3):
            if i+2 >= len(keypoints):
                continue
            x, y, v = keypoints[i], keypoints[i+1], keypoints[i+2]
            if v <= 0:  # visibility 0: keypoint is not drawn
                continue

            kp_index = i//3
            x_px, y_px = x * w, y * h
            color = 'red' if kp_index in arm_indices else 'blue'  # Highlight arm keypoints
            plt.scatter(x_px, y_px, c=color, s=30, alpha=0.8)

            # Add keypoint labels for key points
            if kp_index in named_indices:
                kp_name = KEYPOINT_NAMES[kp_index] if kp_index < len(KEYPOINT_NAMES) else str(kp_index)
                plt.text(x_px, y_px-10, kp_name, fontsize=8, ha='center', color='white',
                       bbox=dict(boxstyle='round,pad=0.2', facecolor=color, alpha=0.7))

    plt.tight_layout()
    plt.show()

    print(f"📷 Sample visualization shown for: {sample_img.name}")
    print(f"🎯 Red points: Arm/throwing keypoints")
    print(f"🔵 Blue points: Other body keypoints")

# Build the YOLO dataset once annotations have been validated
if not annotations_ready:
    print("⚠️  Please complete annotation step first!")
else:
    dataset_ready = prepare_yolo_dataset()

    if not dataset_ready:
        print("❌ Dataset preparation failed!")
    else:
        print("\n🎉 Dataset preparation complete!")
        print("\n📊 Dataset Structure:")
        # Count the files that ended up in each split folder
        for split in ('train', 'val', 'test'):
            split_dir = data_dir / split
            if not split_dir.exists():
                continue
            img_count = sum(1 for _ in (split_dir / 'images').glob('*.*'))
            label_count = sum(1 for _ in (split_dir / 'labels').glob('*.txt'))
            print(f"   {split}: {img_count} images, {label_count} labels")

        # Show a sample
        print("\n🖼️  Visualizing training sample...")
        visualize_dataset_sample()

## 7. Setup YOLO11 Pose Model {#model}

Now we'll set up the YOLO11 pose estimation model. YOLO11 comes in different sizes, each with trade-offs between speed and accuracy.

In [None]:
# YOLO11 Model Configuration

# Available model sizes (choose based on your needs)
# Size key -> pretrained weights file and a short speed/accuracy description
MODEL_SIZES = {
    size: {'name': weights, 'description': description}
    for size, weights, description in (
        ('n', 'yolo11n-pose.pt', 'Nano - Fastest, smallest, lower accuracy'),
        ('s', 'yolo11s-pose.pt', 'Small - Good balance of speed and accuracy'),
        ('m', 'yolo11m-pose.pt', 'Medium - Better accuracy, moderate speed'),
        ('l', 'yolo11l-pose.pt', 'Large - High accuracy, slower'),
        ('x', 'yolo11x-pose.pt', 'Extra Large - Highest accuracy, slowest'),
    )
}

def select_model_size():
    """Print the available model sizes and suggest one for this machine.

    Returns:
        The recommended key of ``MODEL_SIZES``: ``'m'`` for a GPU with at
        least 8 GB, ``'s'`` for at least 6 GB, otherwise ``'n'`` (also used
        when no GPU is available).
    """
    print("🤖 YOLO11 Pose Model Sizes:")
    print("=" * 50)

    for size_key, details in MODEL_SIZES.items():
        print(f"{size_key.upper()}: {details['description']}")
        print(f"    Model: {details['name']}")

    print("\n💡 Recommendations:")

    if not torch.cuda.is_available():
        print("⚠️  No GPU detected - Recommended: Nano (N) for CPU training")
        return 'n'

    # Total memory of the first CUDA device, in GiB
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3

    if gpu_memory >= 8:
        print("✅ GPU with 8GB+ memory detected - Can use Medium (M) or Large (L)")
        return 'm'
    if gpu_memory >= 6:
        print("✅ GPU with 6GB+ memory detected - Recommended: Small (S) or Medium (M)")
        return 's'

    print("⚠️  Limited GPU memory - Recommended: Nano (N) or Small (S)")
    return 'n'

def initialize_model(model_size='s'):
    """Load the pretrained YOLO11 pose checkpoint for the given size code.

    Args:
        model_size: Key into ``MODEL_SIZES`` ('n', 's', 'm', 'l' or 'x').
            An unknown key raises ``KeyError`` before any loading is tried.

    Returns:
        The loaded ``YOLO`` object, or ``None`` when loading (or printing the
        model summary) raised an exception.
    """
    checkpoint = MODEL_SIZES[model_size]['name']
    print(f"🚀 Initializing YOLO11 pose model: {checkpoint}")

    try:
        # Load pretrained weights by checkpoint name
        pose_model = YOLO(checkpoint)

        report = (
            "✅ Model loaded successfully!",
            "📊 Model info:",
            f"   - Architecture: YOLO11{model_size.upper()}",
            "   - Task: Pose Estimation",
            "   - Keypoints: 17 (COCO format)",
            "   - Classes: 1 (person)",
        )
        for line in report:
            print(line)

        # Compact model summary
        pose_model.info(verbose=False)
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        print("💡 Make sure ultralytics is properly installed")
        return None
    else:
        return pose_model

def verify_dataset_config():
    """Check that ``dataset.yaml`` exists, parses, and report its image folders.

    Reads ``data_dir / 'dataset.yaml'``, prints its ``path``, ``nc`` and
    ``kpt_shape`` entries, and prints an image-file count for each of the
    train/val/test splits (a missing folder is reported, not treated as an
    error).

    Returns:
        bool: True when the file exists and parsed into a mapping; False when
        it is missing, empty/not a mapping, or reading it raised.
    """

    config_path = data_dir / 'dataset.yaml'

    if not config_path.exists():
        print("❌ Dataset configuration not found!")
        print("Please complete the dataset preparation step first.")
        return False

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # An empty file parses to None; anything that is not a mapping cannot
        # be a dataset description, so report it clearly instead of failing
        # later with an opaque AttributeError.
        if not isinstance(config, dict):
            print("❌ Error reading dataset config: file is empty or not a YAML mapping")
            return False

        print("✅ Dataset configuration found:")
        print(f"   - Path: {config.get('path', 'Not specified')}")
        print(f"   - Classes: {config.get('nc', 'Not specified')}")
        print(f"   - Keypoints: {config.get('kpt_shape', 'Not specified')}")

        # Fall back to the folder holding dataset.yaml when 'path' is absent;
        # the rest of this notebook already treats data_dir as the dataset root.
        base_path = Path(config.get('path') or config_path.parent)

        # Check if directories exist
        for split in ['train', 'val', 'test']:
            entry = config.get(split, f"{split}/images")

            # A key left blank (e.g. an optional `test:` line) parses to None.
            if entry is None:
                print(f"   - {split}: Not specified")
                continue

            # A split may be a single folder or a list of folders.
            folders = entry if isinstance(entry, (list, tuple)) else [entry]
            for folder in folders:
                split_path = base_path / str(folder)
                if split_path.exists():
                    img_count = len(list(split_path.glob('*.*')))
                    print(f"   - {split}: {img_count} images")
                else:
                    print(f"   - {split}: Directory not found")

        return True

    except Exception as e:
        # Best-effort check: any read/parse problem is reported, not raised.
        print(f"❌ Error reading dataset config: {e}")
        return False

# Model selection and initialization
print("🎯 Model Selection and Setup")
print("=" * 40)

# Ask the helper which size suits the detected hardware
recommended_size = select_model_size()
print("\n🎯 Auto-selected model size: " + recommended_size.upper())

# Override here with 'n', 's', 'm', 'l', or 'x' to choose a size manually
MODEL_SIZE = recommended_size

# The model is only loaded once the dataset configuration checks out
print("\n📋 Verifying dataset configuration...")
dataset_ready = verify_dataset_config()

if not dataset_ready:
    print("❌ Dataset not ready. Please complete previous steps first.")
else:
    print("\n🤖 Initializing model...")
    model = initialize_model(MODEL_SIZE)

    if model is None:
        print("❌ Model initialization failed!")
    else:
        print("\n✅ Setup Complete!")
        print(f"📦 Model: YOLO11{MODEL_SIZE.upper()}")
        print("📊 Ready for training configuration")

## 8. Configure Training Parameters {#config}

Let's configure the training parameters optimized for pose estimation and waste throwing detection.

In [None]:
# Training Configuration for Pose Estimation

def create_training_config(dataset_size='small'):
    """Build the keyword-argument dict that is later passed to ``model.train``.

    Starts from a fixed baseline, overlays a few values chosen by dataset
    size, then sets the batch size (and, on CPU, the worker count) from the
    hardware that torch reports.

    Args:
        dataset_size: 'small' or 'medium'; any other value selects the
            large-dataset overrides.

    Returns:
        dict: Training parameters keyed by argument name.
    """
    settings = dict(
        # Training duration
        epochs=200,
        patience=30,            # Early stopping patience

        # Data loading
        imgsz=640,              # Input image size
        batch=16,               # Replaced below according to hardware
        workers=4,              # Number of dataloader workers
        cache=True,             # Cache images for faster training

        # Optimization
        optimizer='AdamW',
        lr0=0.001,              # Initial learning rate
        lrf=0.01,               # Final learning rate (lr0 * lrf)
        momentum=0.937,
        weight_decay=0.0005,
        warmup_epochs=3,
        warmup_momentum=0.8,
        warmup_bias_lr=0.1,

        # Loss weights
        # NOTE(review): box=0.05 looks like a YOLOv5-era gain and is tiny
        # next to pose=12.0 - confirm this is intentional.
        box=0.05,               # Box loss weight
        cls=0.5,                # Classification loss weight
        kobj=1.0,               # Keypoint objectness loss weight
        pose=12.0,              # Pose loss weight

        # Data augmentation
        mosaic=1.0,             # Mosaic augmentation probability
        mixup=0.1,              # Mixup augmentation probability
        copy_paste=0.1,         # Copy-paste augmentation
        degrees=0.0,            # Rotation range (degrees)
        translate=0.1,          # Translation fraction
        scale=0.5,              # Scaling range
        shear=0.0,              # Shear range
        perspective=0.0,        # Perspective transformation
        flipud=0.0,             # Vertical flip probability
        fliplr=0.5,             # Horizontal flip probability
        hsv_h=0.015,            # Hue augmentation range
        hsv_s=0.7,              # Saturation augmentation range
        hsv_v=0.4,              # Value augmentation range

        # Training behavior
        save=True,
        save_period=10,         # Save checkpoint every N epochs
        cos_lr=True,            # Cosine learning rate scheduler
        close_mosaic=10,        # Disable mosaic in last N epochs
        amp=True,               # Automatic Mixed Precision
        single_cls=False,
        rect=False,             # Rectangular training
        resume=False,
        exist_ok=True,
        pretrained=True,
        verbose=True,
        seed=42,
        deterministic=True,
        plots=True,
        profile=False,
    )

    # Per-size overrides paired with the message announcing them.
    size_profiles = {
        # < 500 images: train longer, lean harder on mixup/copy-paste
        'small': (
            {
                'epochs': 300,
                'lr0': 0.002,
                'patience': 50,
                'mosaic': 0.8,
                'mixup': 0.15,
                'copy_paste': 0.15,
            },
            "📊 Small dataset configuration applied",
        ),
        # 500-2000 images
        'medium': (
            {'epochs': 200, 'lr0': 0.001, 'patience': 30},
            "📊 Medium dataset configuration applied",
        ),
        # > 2000 images: fewer epochs, lower learning rate, less patience
        'large': (
            {'epochs': 150, 'lr0': 0.0008, 'patience': 20},
            "📊 Large dataset configuration applied",
        ),
    }

    # Anything that is not 'small' or 'medium' is treated as 'large'.
    chosen = next(
        (label for label in ('small', 'medium') if dataset_size == label),
        'large',
    )
    overrides, banner = size_profiles[chosen]
    settings.update(overrides)
    print(banner)

    # CPU path: smallest batch and fewer loader workers.
    if not torch.cuda.is_available():
        settings['batch'] = 4
        settings['workers'] = 2
        print("⚠️  CPU training - using batch size 4")
        return settings

    # GPU path: total_memory is bytes, thresholds are in GiB, highest first.
    vram_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    gpu_tiers = (
        (12, 32, "🚀 High-end GPU detected - using batch size 32"),
        (8, 24, "🚀 Mid-range GPU detected - using batch size 24"),
        (6, 16, "🚀 Entry-level GPU detected - using batch size 16"),
    )
    for min_gib, batch_size, notice in gpu_tiers:
        if vram_gib >= min_gib:
            settings['batch'] = batch_size
            print(notice)
            break
    else:
        settings['batch'] = 8
        print("⚠️  Limited GPU memory - using batch size 8")

    return settings

def setup_training_environment():
    """Create the experiment folder and probe for Weights & Biases.

    The experiment name is built from ``MODEL_SIZE`` and the current
    timestamp, and a folder of that name is created under ``results_dir``.

    Returns:
        tuple: ``(experiment_name, use_wandb)`` where ``use_wandb`` is True
        only when ``wandb`` imported and ``wandb.login`` did not raise.
    """

    # Create results directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_name = f"yolo11{MODEL_SIZE}_pose_{timestamp}"
    experiment_dir = results_dir / experiment_name
    # parents=True so a not-yet-created results_dir does not abort the setup.
    experiment_dir.mkdir(parents=True, exist_ok=True)

    print(f"📁 Experiment directory: {experiment_dir}")

    # Setup Weights & Biases (optional)
    try:
        import wandb
        wandb.login(anonymous="allow")
    except Exception:
        # Covers a missing package as well as login errors. Unlike a bare
        # `except:`, Ctrl+C and SystemExit still propagate to the user.
        use_wandb = False
        print("⚠️  Weights & Biases not available - using local logging only")
    else:
        use_wandb = True
        print("✅ Weights & Biases available for experiment tracking")

    return experiment_name, use_wandb

def estimate_training_time(config, num_samples):
    """Print a very rough training-time estimate; returns nothing.

    Args:
        config: Training configuration; only ``'batch'`` and ``'epochs'``
            are read.
        num_samples: Number of training images.
    """
    # Pick a base per-epoch cost from the kind of hardware available.
    # NOTE(review): the unit of these base figures is not stated; the result
    # is divided by 60 and then labelled as minutes - confirm the intent.
    if not torch.cuda.is_available():
        per_epoch = 10.0  # CPU training
    else:
        device_name = torch.cuda.get_device_name().lower()
        if any(tag in device_name for tag in ('rtx 40', 'a100')):
            per_epoch = 0.5  # High-end GPU
        elif any(tag in device_name for tag in ('rtx 30', 'rtx 20')):
            per_epoch = 1.0  # Mid-range GPU
        else:
            per_epoch = 2.0  # Entry-level GPU

    # Base figures are scaled relative to 1000 samples at batch size 16.
    per_epoch *= (num_samples / 1000) * (16 / config['batch'])

    total_time = per_epoch * config['epochs'] / 60

    print(f"⏱️  Estimated training time: {total_time:.1f} minutes ({total_time/60:.1f} hours)")

    # Anything above 120 gets a hint to shorten the run.
    if total_time > 120:
        print("💡 Consider reducing epochs or using a smaller model for faster training")

# Determine dataset size for configuration
config_path = data_dir / 'dataset.yaml'
if not config_path.exists():
    print("❌ Dataset configuration not found. Please prepare dataset first.")
else:
    train_dir = data_dir / 'train' / 'images'
    if not train_dir.exists():
        print("❌ Training directory not found. Please prepare dataset first.")
    else:
        # Every file with an extension in the training folder counts as a sample
        num_train_samples = sum(1 for _match in train_dir.glob('*.*'))

        # Buckets: under 500 -> small, under 2000 -> medium, otherwise large
        dataset_size = (
            'small' if num_train_samples < 500
            else 'medium' if num_train_samples < 2000
            else 'large'
        )

        print("📊 Dataset Analysis:")
        print(f"   Training samples: {num_train_samples}")
        print(f"   Dataset size category: {dataset_size}")

        # Build the configuration, prepare the experiment folder, then
        # print the rough time estimate
        train_config = create_training_config(dataset_size)
        experiment_name, use_wandb = setup_training_environment()
        estimate_training_time(train_config, num_train_samples)

        print("\n⚙️  Training Configuration Summary:")
        for summary_line in (
            f"   Epochs: {train_config['epochs']}",
            f"   Batch size: {train_config['batch']}",
            f"   Learning rate: {train_config['lr0']}",
            f"   Image size: {train_config['imgsz']}",
            f"   Pose loss weight: {train_config['pose']}",
            f"   Augmentation: Mosaic={train_config['mosaic']}, Mixup={train_config['mixup']}",
        ):
            print(summary_line)

        # Persist the configuration next to the experiment results
        config_save_path = results_dir / f"{experiment_name}_config.json"
        with open(config_save_path, 'w') as f:
            f.write(json.dumps(train_config, indent=2))

        print(f"\n✅ Configuration saved: {config_save_path}")
        print("🚀 Ready to start training!")

## 9. Train the Model {#train}

Now we'll execute the training process. This is where the magic happens! The model will learn to detect human poses and specifically recognize throwing motions.

In [None]:
# Execute Training Process

def start_training():
    """Run ``model.train`` with the notebook's configuration and log a summary.

    Relies on notebook-level state created by earlier cells: ``model``,
    ``MODEL_SIZE``, ``train_config``, ``experiment_name``, ``data_dir`` and
    ``results_dir``.

    Returns:
        The object returned by ``model.train``, or ``None`` when a
        prerequisite is missing, the run is interrupted, or training raises.
    """

    # Earlier cells store their results as module-level (notebook) globals.
    # Inside a function, locals() only holds this function's own names, so
    # the prerequisite checks must look in globals() instead.
    notebook_state = globals()

    if notebook_state.get('model') is None or 'MODEL_SIZE' not in notebook_state:
        print("❌ Model not initialized. Please run the model setup cell first.")
        return None

    if 'train_config' not in notebook_state or 'experiment_name' not in notebook_state:
        print("❌ Training configuration not found. Please run the configuration cell first.")
        return None

    config_path = data_dir / 'dataset.yaml'
    if not config_path.exists():
        print("❌ Dataset configuration not found. Please prepare the dataset first.")
        return None

    print("🚀 Starting YOLO11 Pose Estimation Training...")
    print("=" * 50)
    print(f"📦 Model: YOLO11{MODEL_SIZE.upper()}")
    print(f"📊 Dataset: {config_path}")
    print(f"⚙️  Configuration: {len(train_config)} parameters")
    print(f"🎯 Experiment: {experiment_name}")

    # Start training
    try:
        start_time = datetime.now()
        print(f"⏰ Training started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

        # Execute training
        results = model.train(
            data=str(config_path),
            project=str(results_dir),
            name=experiment_name,
            **train_config
        )

        end_time = datetime.now()

    except KeyboardInterrupt:
        print("\n⚠️  Training interrupted by user")
        print("💡 You can resume training using the checkpoint in the results directory")
        return None

    except Exception as e:
        print(f"❌ Training failed with error: {e}")
        print("💡 Check GPU memory, dataset format, and configuration")
        return None

    training_duration = end_time - start_time

    print("🎉 Training completed successfully!")
    print(f"⏰ Training duration: {training_duration}")
    print(f"📁 Results saved to: {results_dir / experiment_name}")

    # Save training summary. This is kept outside the training try-block so
    # that a failure to write the summary is not reported as a failed run and
    # does not discard the training results.
    training_summary = {
        'start_time': start_time.isoformat(),
        'end_time': end_time.isoformat(),
        'duration': str(training_duration),
        'model_size': MODEL_SIZE,
        'experiment_name': experiment_name,
        'dataset_path': str(config_path),
        'configuration': train_config,
        'results_path': str(results_dir / experiment_name)
    }

    summary_path = results_dir / f"{experiment_name}_training_summary.json"
    try:
        with open(summary_path, 'w') as f:
            json.dump(training_summary, f, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print(f"⚠️  Could not save training summary: {e}")
    else:
        print(f"📄 Training summary saved: {summary_path}")

    return results

def _format_metric(value):
    """Return *value* with four decimals, or 'N/A' if it is missing or non-numeric."""
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return "N/A"


def monitor_training_progress(results_path):
    """Print the latest metrics from a training run and plot its curves.

    Args:
        results_path: ``Path`` of the experiment folder holding the CSV log.

    Prints the number of logged epochs and the last row's pose metrics, and
    draws a 2x2 figure (mAP, training losses, precision/recall, learning
    rate) when more than five epochs are logged. Columns that are absent are
    shown as 'N/A' or skipped in the plots. Any error while reading or
    plotting is printed rather than raised.
    """

    if not results_path.exists():
        print("❌ Training results not found")
        return

    # Look for training log files; sorted so the choice is deterministic.
    log_files = sorted(results_path.glob("*.csv"))
    if not log_files:
        print("📊 Training logs not available yet")
        return

    # Prefer results.csv when present instead of whichever file glob lists first.
    preferred = results_path / "results.csv"
    log_file = preferred if preferred.exists() else log_files[0]

    # Read training metrics
    try:
        df = pd.read_csv(log_file)
        # Some log files pad the header cells with spaces; normalise them so
        # the column lookups below match.
        df.columns = [str(col).strip() for col in df.columns]

        print("📈 Training Progress:")
        print(f"   Epochs completed: {len(df)}")

        if len(df) > 0:
            latest = df.iloc[-1]
            # The numeric format spec must not be applied to the 'N/A'
            # fallback string, so formatting goes through _format_metric.
            print(f"   Latest mAP@50: {_format_metric(latest.get('metrics/mAP50(P)'))}")
            print(f"   Latest mAP@50:95: {_format_metric(latest.get('metrics/mAP50-95(P)'))}")
            print(f"   Training loss: {latest.get('train/pose_loss', 'N/A')}")
            print(f"   Validation loss: {latest.get('val/pose_loss', 'N/A')}")

        # Plot training curves
        if len(df) > 5:  # Only plot if we have enough data
            columns = set(df.columns)
            # Fall back to a 1-based row counter when no 'epoch' column exists.
            epochs = df['epoch'] if 'epoch' in columns else range(1, len(df) + 1)

            fig, axes = plt.subplots(2, 2, figsize=(15, 10))

            # mAP curves (each line is drawn only if its column exists)
            map_series = [
                ('metrics/mAP50(P)', 'mAP@50'),
                ('metrics/mAP50-95(P)', 'mAP@50:95'),
            ]
            if 'metrics/mAP50(P)' in columns:
                for col, label in map_series:
                    if col in columns:
                        axes[0, 0].plot(epochs, df[col], label=label)
                axes[0, 0].set_title('Mean Average Precision')
                axes[0, 0].set_xlabel('Epoch')
                axes[0, 0].set_ylabel('mAP')
                axes[0, 0].legend()
                axes[0, 0].grid(True)

            # Loss curves
            loss_cols = [col for col in df.columns if 'loss' in col and 'train' in col]
            if loss_cols:
                for col in loss_cols[:3]:  # Plot first 3 loss types
                    axes[0, 1].plot(epochs, df[col], label=col.replace('train/', ''))
                axes[0, 1].set_title('Training Losses')
                axes[0, 1].set_xlabel('Epoch')
                axes[0, 1].set_ylabel('Loss')
                axes[0, 1].legend()
                axes[0, 1].grid(True)

            # Precision and Recall
            pr_series = [
                ('metrics/precision(P)', 'Precision'),
                ('metrics/recall(P)', 'Recall'),
            ]
            if 'metrics/precision(P)' in columns:
                for col, label in pr_series:
                    if col in columns:
                        axes[1, 0].plot(epochs, df[col], label=label)
                axes[1, 0].set_title('Precision and Recall')
                axes[1, 0].set_xlabel('Epoch')
                axes[1, 0].set_ylabel('Score')
                axes[1, 0].legend()
                axes[1, 0].grid(True)

            # Learning rate
            if 'lr/pg0' in columns:
                axes[1, 1].plot(epochs, df['lr/pg0'])
                axes[1, 1].set_title('Learning Rate')
                axes[1, 1].set_xlabel('Epoch')
                axes[1, 1].set_ylabel('Learning Rate')
                axes[1, 1].grid(True)

            plt.tight_layout()
            plt.show()

    except Exception as e:
        # Display helper: report the problem instead of breaking the notebook.
        print(f"❌ Error reading training logs: {e}")

# Check if everything is ready for training
def check_training_readiness():
    """Print a checklist of training prerequisites and say whether all are met.

    Required items: an initialized ``model``, ``dataset.yaml``, a
    ``train_config``, and the train/val image folders under ``data_dir``.
    GPU availability is reported as well but does not block training.

    Returns:
        bool: True when every required item is satisfied.
    """

    # Earlier cells store their results as notebook-level globals. Inside a
    # function, locals() never contains them, so look them up in globals().
    notebook_state = globals()

    required = {
        'Model initialized': notebook_state.get('model') is not None,
        'Dataset prepared': (data_dir / 'dataset.yaml').exists(),
        'Training config ready': 'train_config' in notebook_state,
        'Training images exist': (data_dir / 'train' / 'images').exists(),
        'Validation images exist': (data_dir / 'val' / 'images').exists(),
    }

    print("🔍 Training Readiness Check:")
    print("=" * 30)

    for check, status in required.items():
        status_icon = "✅" if status else "❌"
        print(f"{status_icon} {check}")

    # Advisory only: the model-size and batch-size helpers in this notebook
    # both have an explicit CPU path, so a missing GPU must not block training.
    if torch.cuda.is_available():
        print("✅ GPU available")
    else:
        print("⚠️  GPU available (optional - training will run on CPU)")

    return all(required.values())

# Perform readiness check
ready_to_train = check_training_readiness()

if not ready_to_train:
    print("\n❌ Not ready for training. Please complete the missing steps above.")
    print("💡 Fix the issues marked with ❌ and run this cell again.")
else:
    print("\n🎯 All systems ready for training!")
    print("\n⚡ To start training, run the next cell")
    print("📊 Monitor progress in the output below")
    print(f"⏰ Estimated time: {train_config['epochs']} epochs")

    # Training starts automatically after a short countdown
    # (comment this part out if you want manual control)
    print("\n🚀 Starting training in 5 seconds...")
    print("💡 Press Ctrl+C to cancel")

    import time
    try:
        # Count down 5..1, overwriting the same console line each second
        for i in range(5, 0, -1):
            print(f"⏰ Starting in {i}...", end='\r')
            time.sleep(1)
        print("🚀 Starting training now!   ")

        # Run the training, then summarize the logged metrics on success
        training_results = start_training()

        if training_results:
            print("\n📊 Training completed! Monitoring results...")
            monitor_training_progress(results_dir / experiment_name)

    except KeyboardInterrupt:
        print("\n⚠️  Training start cancelled by user")
        print("💡 You can start training manually by calling: start_training()")