# YOLO Cookie Defect Detection

This notebook trains and compiles YOLOv8 models for cookie defect detection using Amazon SageMaker.

## Overview

This notebook implements a complete pipeline for:
1. Environment setup and configuration
2. Dataset acquisition and preparation
3. Format conversion (Lookout for Vision → YOLO)
4. Model training using YOLOv8
5. Model compilation for multiple target platforms
6. Model comparison and validation

**Supported Model Types:**
- YOLO Object Detection (bounding boxes)
- YOLO Instance Segmentation (polygon masks)

**Target Platforms:**
- Jetson Xavier GPU
- x86_64 CPU
- ARM64 CPU

## 1. Environment Setup

Initialize the SageMaker environment and configure project settings.

### 1.1 Library Imports

Import required libraries for SageMaker operations, S3 access, and data processing.

In [None]:
# AWS and SageMaker libraries
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch

# Standard libraries
import json
import datetime
import os
from pathlib import Path
import time

# Display library versions
print(f"SageMaker SDK version: {sagemaker.__version__}")
print(f"Boto3 version: {boto3.__version__}")

### 1.2 SageMaker Session Initialization

Create a SageMaker session and retrieve the default S3 bucket, the AWS region, and the execution role.

In [None]:
# Build the session first; the bucket and region below are read from it
sagemaker_session = sagemaker.Session()
default_bucket = sagemaker_session.default_bucket()
region = sagemaker_session.boto_region_name

# Execution role (passed to the estimator in section 5)
role = sagemaker.get_execution_role()

# Report what was resolved
print(f"✅ SageMaker session initialized")
for label, value in (
    ("Region", region),
    ("Default S3 bucket", default_bucket),
    ("Execution role", role),
):
    print(f"   {label}: {value}")

### 1.3 Project Configuration and S3 Folder Structure

Define project parameters and create the S3 folder structure for outputs.

In [None]:
# Project identity; the timestamp gives every run its own S3 prefix
project_name = "yolo-cookie-defect-detection"
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
s3_prefix = f"{project_name}/{timestamp}"

# One S3 URI per output area, all rooted at the run prefix
s3_paths = {
    folder_key: f"s3://{default_bucket}/{s3_prefix}/{folder_suffix}"
    for folder_key, folder_suffix in (
        ('training_output', 'training-output'),
        ('compilation_output', 'compilation-output'),
        ('dataset', 'dataset'),
        ('models', 'models'),
    )
}

# S3 client bound to the session's region
s3_client = boto3.client('s3', region_name=region)

# Materialise each "folder" by uploading an empty .folder_marker object under it
for folder_name, s3_path in s3_paths.items():
    # "s3://bucket/some/key" -> bucket, "some/key"
    bucket, _, key_prefix = s3_path.replace('s3://', '').partition('/')
    key = key_prefix + '/.folder_marker'

    try:
        s3_client.put_object(Bucket=bucket, Key=key, Body=b'')
    except Exception as e:
        # Report which folder failed, then let the error stop the cell
        print(f"❌ Failed to create S3 folder {folder_name}: {str(e)}")
        raise
    else:
        print(f"✅ Created S3 folder: {s3_path}")

# Summary of the layout that was created
print("\n📁 S3 Folder Structure:")
for folder_name, s3_path in s3_paths.items():
    print(f"   {folder_name}: {s3_path}")

## 2. Dataset Acquisition

Download and prepare the cookie dataset from the amazon-lookout-for-vision repository.

### 2.1 Clone Repository

Clone the amazon-lookout-for-vision repository from GitHub.

In [None]:
import subprocess
import shutil

# Repository URL and the local directory it is cloned into
repo_url = "https://github.com/aws-samples/amazon-lookout-for-vision.git"
repo_dir = "amazon-lookout-for-vision"

# Start from a clean slate: git refuses to clone into a non-empty directory
if os.path.exists(repo_dir):
    print(f"🗑️  Removing existing repository: {repo_dir}")
    shutil.rmtree(repo_dir)

# Clone repository
print(f"📥 Cloning repository: {repo_url}")
try:
    # Only the current dataset files are needed (the clone is deleted in 2.4),
    # so a shallow clone avoids downloading the full commit history.
    subprocess.run(
        ["git", "clone", "--depth", "1", repo_url, repo_dir],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError on a non-zero exit code
    )
    print(f"✅ Repository cloned successfully to: {repo_dir}")
except subprocess.CalledProcessError as e:
    # stderr was captured above, so surface git's own message before re-raising
    print(f"❌ Failed to clone repository: {e.stderr}")
    raise

### 2.2 Extract Cookie Dataset

Copy the cookie-dataset folder to the working directory.

In [None]:
# Where the dataset lives inside the clone, and where it is copied to
source_dataset = os.path.join(repo_dir, "cookie-dataset")
dest_dataset = "cookie-dataset"

# copytree needs a destination that does not exist yet
if os.path.exists(dest_dataset):
    print(f"🗑️  Removing existing dataset: {dest_dataset}")
    shutil.rmtree(dest_dataset)

print(f"📦 Extracting cookie dataset from: {source_dataset}")
try:
    if not os.path.exists(source_dataset):
        # List what the clone actually contains to help diagnose a moved dataset
        if os.path.exists(repo_dir):
            repo_listing = os.listdir(repo_dir)
        else:
            repo_listing = 'Repository not found'
        raise FileNotFoundError(
            f"Cookie dataset not found at: {source_dataset}\n"
            f"Expected location: {os.path.abspath(source_dataset)}\n"
            f"Available directories in repo: {repo_listing}"
        )

    shutil.copytree(source_dataset, dest_dataset)
    print(f"✅ Dataset extracted successfully to: {dest_dataset}")
except Exception as e:
    # Print a short summary, then re-raise so the notebook stops here
    print(f"❌ Failed to extract dataset: {str(e)}")
    raise

### 2.3 Validate Dataset Structure

Verify the presence of required files and count images.

In [None]:
# File extensions recognised when counting dataset files.
# '.manifest' is included because the conversion step reads 'output.manifest'.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')
MANIFEST_EXTENSIONS = ('.json', '.jsonl', '.manifest')

# Define expected dataset structure
dataset_structure = {
    'training_images': os.path.join(dest_dataset, 'dataset-files', 'training-images'),
    'mask_images': os.path.join(dest_dataset, 'dataset-files', 'mask-images'),
    'manifests': os.path.join(dest_dataset, 'dataset-files', 'manifests')
}

print("🔍 Validating dataset structure...\n")

# Validate each component; every component is checked so all problems are
# reported at once instead of stopping at the first missing directory.
validation_passed = True
for component_name, component_path in dataset_structure.items():
    # Must be a directory: the counts below call os.listdir() on each path
    if not os.path.isdir(component_path):
        print(f"❌ Missing: {component_name} at {component_path}")
        validation_passed = False
    else:
        print(f"✅ Found: {component_name} at {component_path}")

if not validation_passed:
    raise ValueError(
        f"Invalid dataset structure.\n"
        f"Missing required components.\n"
        f"Expected structure:\n"
        f"  - {dataset_structure['training_images']}\n"
        f"  - {dataset_structure['mask_images']}\n"
        f"  - {dataset_structure['manifests']}"
    )

# Count images
print("\n📊 Dataset Statistics:")

# Count training images (extension match is case-insensitive)
training_images = [f for f in os.listdir(dataset_structure['training_images'])
                   if f.lower().endswith(IMAGE_EXTENSIONS)]
print(f"   Training images: {len(training_images)}")

# Count mask images
mask_images = [f for f in os.listdir(dataset_structure['mask_images'])
               if f.lower().endswith(IMAGE_EXTENSIONS)]
print(f"   Mask images: {len(mask_images)}")

# List manifest files (sorted so the listing is stable between runs)
manifest_files = sorted(f for f in os.listdir(dataset_structure['manifests'])
                        if f.endswith(MANIFEST_EXTENSIONS))
print(f"   Manifest files: {len(manifest_files)}")
for manifest in manifest_files:
    print(f"      - {manifest}")

print("\n✅ Dataset validation complete!")

### 2.4 Cleanup Cloned Repository

Remove the cloned repository to save space (keeping only the extracted dataset).

In [None]:
# Delete the clone; the extracted dataset copy is left in place
if not os.path.exists(repo_dir):
    print(f"ℹ️  Repository directory not found: {repo_dir}")
else:
    print(f"🗑️  Cleaning up: Removing {repo_dir}")
    try:
        shutil.rmtree(repo_dir)
    except Exception as e:
        # Best effort only: a failed cleanup is reported but does not stop the notebook
        print(f"⚠️  Warning: Failed to remove repository: {str(e)}")
        print(f"   You may need to manually delete: {repo_dir}")
    else:
        print(f"✅ Repository removed successfully")
        print(f"   Dataset retained at: {dest_dataset}")

## 3. Format Conversion

Convert Lookout for Vision annotations to YOLO format (detection and segmentation).

*Sections 3.1–3.5 build the YOLO detection dataset; section 3.6 optionally builds a YOLO segmentation dataset from the same manifest and masks.*

### 3.1 Import Format Converter

Import the YOLO format converter module with helper functions for converting Lookout for Vision annotations to YOLO format.

In [None]:
# Import format converter functions
from yolo_format_converter import (
    extract_bounding_boxes,
    normalize_coordinates,
    convert_to_yolo_format,
    read_manifest,
    write_yolo_annotations,
    create_data_yaml
)

import cv2
import numpy as np

print("✅ Format converter imported successfully")

### 3.2 Process Manifest and Extract Bounding Boxes

Read the Lookout for Vision manifest file and extract bounding boxes from segmentation masks.

In [None]:
# Define paths
manifest_path = 'cookie-dataset/dataset-files/manifests/output.manifest'
mask_images_dir = 'cookie-dataset/dataset-files/mask-images'
training_images_dir = 'cookie-dataset/dataset-files/training-images'

# Read manifest file
print("📖 Reading manifest file...")
manifest_records = read_manifest(manifest_path)
print(f"✅ Loaded {len(manifest_records)} records from manifest")

# Process each record and extract bounding boxes
annotations = {}  # Dictionary: image_filename -> list of YOLO annotation lines

print("\n🔍 Extracting bounding boxes from masks...")
for i, record in enumerate(manifest_records):
    # Get image filename from source-ref
    source_ref = record['source-ref']
    image_filename = os.path.basename(source_ref)

    # Get anomaly label (0=normal, 1=anomaly)
    # NOTE(review): assumes 'anomaly-label' holds the integers 0/1 — confirm against the manifest
    class_id = record['anomaly-label']

    if class_id == 1:
        # Anomaly: derive the mask name from the image stem so that any image
        # extension (.jpg, .JPG, .jpeg, .png, ...) maps to "<stem>_mask.png"
        mask_filename = os.path.splitext(image_filename)[0] + '_mask.png'
        mask_path = os.path.join(mask_images_dir, mask_filename)

        if os.path.exists(mask_path):
            # Extract bounding boxes
            bboxes = extract_bounding_boxes(mask_path)

            # Get image dimensions; cv2.imread returns None instead of raising
            # when the file is missing or unreadable, so check explicitly
            image_path = os.path.join(training_images_dir, image_filename)
            img = cv2.imread(image_path)
            if img is None:
                raise FileNotFoundError(
                    f"Could not read training image: {image_path}\n"
                    f"Referenced by manifest record: {source_ref}"
                )
            img_height, img_width = img.shape[:2]

            # Normalize each bbox and render it as a YOLO annotation line
            annotations[image_filename] = [
                convert_to_yolo_format(
                    normalize_coordinates(bbox, img_width, img_height),
                    class_id,
                )
                for bbox in bboxes
            ]
        else:
            # No entry is recorded for this image, so no label file is written for it
            print(f"⚠️  Warning: Mask file not found for {image_filename}")
    else:
        # Normal image - no annotations needed (or empty annotation file)
        annotations[image_filename] = []

    # Report progress for every 10th record, whichever branch it took
    if (i + 1) % 10 == 0:
        print(f"  Processed {i + 1}/{len(manifest_records)} images...")

print(f"\n✅ Extracted bounding boxes for {len(annotations)} images")
print(f"   Images with defects: {sum(1 for v in annotations.values() if len(v) > 0)}")
print(f"   Normal images: {sum(1 for v in annotations.values() if len(v) == 0)}")

### 3.3 Write YOLO Detection Annotations

Write the YOLO format annotation files (.txt) for each image.

In [None]:
# Label files go next to the images, under labels/train
annotations_dir = 'yolo-dataset/labels/train'
os.makedirs(annotations_dir, exist_ok=True)

print("📝 Writing YOLO annotation files...")
write_yolo_annotations(annotations, annotations_dir)

# Count what actually landed on disk rather than trusting the in-memory dict
annotation_files = list(Path(annotations_dir).glob('*.txt'))
print(f"✅ Wrote {len(annotation_files)} annotation files to {annotations_dir}")

# Show one file as a sanity check (glob order decides which one)
if annotation_files:
    sample_file = annotation_files[0]
    print(f"\n📄 Sample annotation file: {sample_file.name}")
    with open(sample_file, 'r') as f:
        content = f.read()
    # An empty label file belongs to an image with no defects
    print(f"   Content: {content.strip()}" if content
          else "   (Empty - normal image with no defects)")

### 3.4 Create data.yaml Configuration

Generate the YOLO dataset configuration file specifying class names and dataset paths.

In [None]:
# Index in this list is the YOLO class id (0 -> normal, 1 -> anomaly)
class_names = ['normal', 'anomaly']
data_yaml_path = 'yolo-dataset/data.yaml'

print("📝 Creating data.yaml configuration...")
create_data_yaml(class_names, data_yaml_path)
print(f"✅ Created {data_yaml_path}")

# Echo the generated file so the configuration can be checked by eye
print("\n📄 Configuration content:")
with open(data_yaml_path, 'r') as f:
    yaml_text = f.read()
print(yaml_text)

## 3.6 YOLO Segmentation Format Conversion (Optional)

This section converts the Lookout for Vision annotations to YOLO segmentation format.
YOLO segmentation uses polygon coordinates instead of bounding boxes for more precise defect localization.

**Note:** Skip this section if you only want to train detection models. Segmentation models require more computational resources but provide pixel-level accuracy.

### 3.6.1 Process Manifest and Extract Polygons

Extract polygon coordinates from segmentation masks using contour detection and Douglas-Peucker approximation.

In [None]:
# Import segmentation functions
from yolo_format_converter import (
    extract_polygons,
    approximate_polygon,
    convert_to_yolo_segment_format
)

# Inputs (manifest + masks) and output label directory for the segmentation dataset
manifest_path_seg = 'cookie-dataset/dataset-files/manifests/output.manifest'
mask_dir_seg = 'cookie-dataset/dataset-files/mask-images'
output_dir_seg = 'yolo-dataset-segmentation/labels/train'

# Make sure the label directory exists before anything is written to it
os.makedirs(output_dir_seg, exist_ok=True)

# Load the manifest records
manifest_records_seg = read_manifest(manifest_path_seg)

for status_line in (
    f"Loaded {len(manifest_records_seg)} records from manifest",
    f"Output directory: {output_dir_seg}",
):
    print(f"\u2705 {status_line}")

In [None]:
# Imported once here rather than on every loop iteration
import cv2

# Process each record and extract polygons
segmentation_annotations = {}  # image_filename -> list of YOLO segmentation lines
epsilon = 0.01  # Douglas-Peucker approximation parameter
training_images_dir_seg = 'cookie-dataset/dataset-files/training-images'

for record in manifest_records_seg:
    # Get image filename from source-ref
    source_ref = record['source-ref']
    image_filename = os.path.basename(source_ref)

    # Get class label
    # NOTE(review): assumes 'anomaly-label' holds the integers 0/1 — confirm against the manifest
    class_id = record['anomaly-label']

    # Skip normal images (no defects to segment); they still get an empty entry
    if class_id == 0:
        segmentation_annotations[image_filename] = []
        continue

    # Derive the mask name from the image stem so that any image extension
    # (.jpg, .JPG, .jpeg, .png, ...) maps to "<stem>_mask.png"
    mask_filename = os.path.splitext(image_filename)[0] + '_mask.png'
    mask_path = os.path.join(mask_dir_seg, mask_filename)

    if not os.path.exists(mask_path):
        print(f"\u26a0\ufe0f  Warning: Mask not found for {image_filename}")
        segmentation_annotations[image_filename] = []
        continue

    # Extract polygons from mask
    polygons = extract_polygons(mask_path)

    # Get image dimensions; cv2.imread returns None instead of raising when
    # the file is missing or unreadable, so check explicitly
    image_path = os.path.join(training_images_dir_seg, image_filename)
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(
            f"Could not read training image: {image_path}\n"
            f"Referenced by manifest record: {source_ref}"
        )
    img_height, img_width = img.shape[:2]

    # Convert each polygon to YOLO segmentation format
    annotation_lines = []
    for polygon in polygons:
        # Approximate polygon to reduce number of points
        approx_polygon = approximate_polygon(polygon, epsilon=epsilon)

        # Convert to YOLO format
        yolo_line = convert_to_yolo_segment_format(
            approx_polygon, class_id, img_width, img_height
        )
        annotation_lines.append(yolo_line)

    segmentation_annotations[image_filename] = annotation_lines

print(f"\u2705 Processed {len(segmentation_annotations)} images")
print(f"\u2705 Found {sum(len(lines) for lines in segmentation_annotations.values())} polygon annotations")

### 3.6.2 Write YOLO Segmentation Annotations

Write polygon annotations to .txt files in YOLO segmentation format.

In [None]:
# Write segmentation annotations to files
write_yolo_annotations(segmentation_annotations, output_dir_seg)

# Count annotation files actually present in the output directory
annotation_files = [f for f in os.listdir(output_dir_seg) if f.endswith('.txt')]

print(f"\u2705 Created {len(annotation_files)} annotation files")
print(f"\u2705 Annotations saved to: {output_dir_seg}")

# Display sample annotation
if annotation_files:
    sample_file = os.path.join(output_dir_seg, annotation_files[0])
    with open(sample_file, 'r') as f:
        sample_content = f.read()
    # \U0001F50D is the magnifying-glass emoji. The previous "\ud83d\udd0d"
    # spelling produced two lone surrogates, which cannot be encoded as UTF-8
    # and make print() fail.
    print(f"\n\U0001F50D Sample annotation ({annotation_files[0]}):")
    # Polygon lines can be long, so show at most the first 200 characters
    print(sample_content[:200] + '...' if len(sample_content) > 200 else sample_content)

### 3.6.3 Create data.yaml for Segmentation

Generate the YOLO dataset configuration file for segmentation training.

In [None]:
# Create data.yaml for segmentation
data_yaml_path_seg = 'yolo-dataset-segmentation/data.yaml'
class_names_seg = ['normal', 'anomaly']  # list index is the YOLO class id

# NOTE(review): val_path points at 'images/val', but no cell shown in this
# section creates that directory — confirm the training script tolerates it.
create_data_yaml(
    class_names=class_names_seg,
    output_path=data_yaml_path_seg,
    train_path='images/train',
    val_path='images/val'
)

print(f"\u2705 Created data.yaml for segmentation: {data_yaml_path_seg}")

# Display contents
with open(data_yaml_path_seg, 'r') as f:
    yaml_content = f.read()

# \U0001F50D is the magnifying-glass emoji. The previous "\ud83d\udd0d"
# spelling produced two lone surrogates, which cannot be encoded as UTF-8
# and make print() fail.
print(f"\n\U0001F50D data.yaml contents:")
print(yaml_content)

### 3.6.4 Copy Training Images for Segmentation

Copy training images to the segmentation dataset directory.

In [None]:
import shutil

# Create images directory for segmentation
images_dir_seg = 'yolo-dataset-segmentation/images/train'
os.makedirs(images_dir_seg, exist_ok=True)

# Copy all training images (extension match is case-insensitive)
source_images_dir = 'cookie-dataset/dataset-files/training-images'
image_files = [f for f in os.listdir(source_images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

for image_file in image_files:
    src = os.path.join(source_images_dir, image_file)
    dst = os.path.join(images_dir_seg, image_file)
    shutil.copy2(src, dst)  # copy2 also preserves file metadata

print(f"\u2705 Copied {len(image_files)} training images to {images_dir_seg}")
# \U0001F389 is the party-popper emoji. The previous "\ud83c\udf89" spelling
# produced two lone surrogates, which cannot be encoded as UTF-8 and make
# print() fail.
print(f"\n\U0001F389 YOLO segmentation dataset ready!")
print(f"   Dataset location: yolo-dataset-segmentation/")
print(f"   Images: {len(image_files)}")
# annotation_files is the list built when the segmentation labels were written (3.6.2)
print(f"   Annotations: {len(annotation_files)}")

### 3.5 Copy Training Images

Copy training images to the YOLO dataset structure.

In [None]:
import shutil

# Images sit beside the labels: yolo-dataset/images/train <-> yolo-dataset/labels/train
images_dir = 'yolo-dataset/images/train'
os.makedirs(images_dir, exist_ok=True)

print("📁 Copying training images...")
# Only files matching *.jpg are picked up from the training-images directory
training_images = list(Path(training_images_dir).glob('*.jpg'))
for img_path in training_images:
    shutil.copy2(img_path, os.path.join(images_dir, img_path.name))

print(f"✅ Copied {len(training_images)} images to {images_dir}")

# Recount from disk so the summary reflects what is really in the dataset
copied_image_count = len(list(Path(images_dir).glob('*.jpg')))
label_file_count = len(list(Path(annotations_dir).glob('*.txt')))

print("\n📊 YOLO Dataset Structure:")
print(f"   Images: {copied_image_count} files")
print(f"   Labels: {label_file_count} files")
print(f"   Config: data.yaml")
print("\n✅ YOLO detection format conversion complete!")

## 4. S3 Upload for Training Data

Upload the YOLO-formatted dataset to S3 for SageMaker training jobs.

### 4.1 Upload Training Images to S3

Upload all training images to the `dataset/images/train` prefix of this run's S3 folder.

In [None]:
from boto3.exceptions import S3UploadFailedError
from botocore.exceptions import ClientError

# Define S3 paths for dataset
s3_dataset_prefix = f"{s3_prefix}/dataset"
s3_images_prefix = f"{s3_dataset_prefix}/images/train"
s3_labels_prefix = f"{s3_dataset_prefix}/labels/train"

# Upload training images
print("📤 Uploading training images to S3...")
images_to_upload = list(Path(images_dir).glob('*.jpg'))
uploaded_images = 0
failed_uploads = []  # one dict per failed file: file, error, s3_destination

for img_path in images_to_upload:
    s3_key = f"{s3_images_prefix}/{img_path.name}"

    try:
        s3_client.upload_file(
            str(img_path),
            default_bucket,
            s3_key
        )
        uploaded_images += 1

        # Progress indicator
        if uploaded_images % 10 == 0:
            print(f"  Uploaded {uploaded_images}/{len(images_to_upload)} images...")

    except (ClientError, S3UploadFailedError) as e:
        # upload_file() wraps service errors in S3UploadFailedError, which has
        # no .response; fall back to the exception type and message in that case.
        error_details = getattr(e, 'response', {}).get('Error', {})
        error_code = error_details.get('Code', type(e).__name__)
        error_msg = error_details.get('Message', str(e))
        failed_uploads.append({
            'file': str(img_path),
            'error': f"{error_code}: {error_msg}",
            's3_destination': f"s3://{default_bucket}/{s3_key}"
        })
        # Keep going: remaining files are still attempted and failures are summarised below
        print(f"❌ S3 upload failed: {error_code}")
        print(f"   File: {img_path}")
        print(f"   Destination: s3://{default_bucket}/{s3_key}")
        print(f"   Error: {error_msg}")

# Display results
if failed_uploads:
    print(f"\n⚠️  Upload completed with errors:")
    print(f"   Successfully uploaded: {uploaded_images}/{len(images_to_upload)} images")
    print(f"   Failed uploads: {len(failed_uploads)}")
    for failure in failed_uploads:
        print(f"      - {failure['file']}: {failure['error']}")
else:
    print(f"\n✅ Successfully uploaded {uploaded_images} training images")
    print(f"   S3 location: s3://{default_bucket}/{s3_images_prefix}/")

### 4.2 Upload YOLO Annotations to S3

Upload all YOLO `.txt` annotation files to the `dataset/labels/train` prefix of this run's S3 folder.

In [None]:
from boto3.exceptions import S3UploadFailedError
from botocore.exceptions import ClientError

# Upload annotation files
print("📤 Uploading YOLO annotations to S3...")
annotations_to_upload = list(Path(annotations_dir).glob('*.txt'))
uploaded_annotations = 0
failed_annotation_uploads = []  # one dict per failed file: file, error, s3_destination

for ann_path in annotations_to_upload:
    s3_key = f"{s3_labels_prefix}/{ann_path.name}"

    try:
        s3_client.upload_file(
            str(ann_path),
            default_bucket,
            s3_key
        )
        uploaded_annotations += 1

        # Progress indicator
        if uploaded_annotations % 10 == 0:
            print(f"  Uploaded {uploaded_annotations}/{len(annotations_to_upload)} annotations...")

    except (ClientError, S3UploadFailedError) as e:
        # upload_file() wraps service errors in S3UploadFailedError, which has
        # no .response; fall back to the exception type and message in that case.
        error_details = getattr(e, 'response', {}).get('Error', {})
        error_code = error_details.get('Code', type(e).__name__)
        error_msg = error_details.get('Message', str(e))
        failed_annotation_uploads.append({
            'file': str(ann_path),
            'error': f"{error_code}: {error_msg}",
            's3_destination': f"s3://{default_bucket}/{s3_key}"
        })
        # Keep going: remaining files are still attempted and failures are summarised below
        print(f"❌ S3 upload failed: {error_code}")
        print(f"   File: {ann_path}")
        print(f"   Destination: s3://{default_bucket}/{s3_key}")
        print(f"   Error: {error_msg}")

# Display results
if failed_annotation_uploads:
    print(f"\n⚠️  Upload completed with errors:")
    print(f"   Successfully uploaded: {uploaded_annotations}/{len(annotations_to_upload)} annotations")
    print(f"   Failed uploads: {len(failed_annotation_uploads)}")
    for failure in failed_annotation_uploads:
        print(f"      - {failure['file']}: {failure['error']}")
else:
    print(f"\n✅ Successfully uploaded {uploaded_annotations} annotation files")
    print(f"   S3 location: s3://{default_bucket}/{s3_labels_prefix}/")

### 4.3 Upload data.yaml Configuration to S3

Upload the YOLO dataset configuration file to S3.

In [None]:
from boto3.exceptions import S3UploadFailedError
from botocore.exceptions import ClientError

# Upload data.yaml
print("📤 Uploading data.yaml to S3...")
data_yaml_s3_key = f"{s3_dataset_prefix}/data.yaml"

try:
    s3_client.upload_file(
        data_yaml_path,
        default_bucket,
        data_yaml_s3_key
    )
    print(f"✅ Successfully uploaded data.yaml")
    print(f"   S3 location: s3://{default_bucket}/{data_yaml_s3_key}")

except (ClientError, S3UploadFailedError) as e:
    # upload_file() wraps service errors in S3UploadFailedError, which has
    # no .response; fall back to the exception type and message in that case.
    error_details = getattr(e, 'response', {}).get('Error', {})
    error_code = error_details.get('Code', type(e).__name__)
    error_msg = error_details.get('Message', str(e))
    print(f"❌ S3 upload failed: {error_code}")
    print(f"   File: {data_yaml_path}")
    print(f"   Destination: s3://{default_bucket}/{data_yaml_s3_key}")
    print(f"   Error: {error_msg}")
    # Unlike the per-file uploads, a missing data.yaml is fatal, so stop here.
    # "from e" keeps the original boto3 error attached as the cause.
    raise Exception(
        f"Failed to upload data.yaml to S3.\n"
        f"File: {data_yaml_path}\n"
        f"Bucket: {default_bucket}\n"
        f"Key: {data_yaml_s3_key}\n"
        f"Error: {error_code} - {error_msg}"
    ) from e

### 4.4 Verify Uploads and Display S3 URIs

Verify all files were uploaded successfully and display the S3 URIs for training.

In [None]:
# Verify uploads by listing S3 objects
print("🔍 Verifying S3 uploads...\n")


def _count_s3_objects(bucket_name, prefix):
    """Return the number of S3 objects whose key starts with ``prefix``.

    A single list_objects_v2 call returns at most 1000 keys, so the count is
    accumulated over every page to stay correct for larger datasets.

    Raises:
        botocore.exceptions.ClientError: if the listing request fails.
    """
    paginator = s3_client.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
    return sum(page.get('KeyCount', 0) for page in pages)


# Check images
try:
    s3_image_count = _count_s3_objects(default_bucket, s3_images_prefix)
    print(f"✅ Images in S3: {s3_image_count} files")

except ClientError as e:
    print(f"❌ Failed to list images in S3: {e.response['Error']['Code']}")
    print(f"   Bucket: {default_bucket}")
    print(f"   Prefix: {s3_images_prefix}")
    s3_image_count = 0  # treated as "nothing found" by the mismatch check below

# Check annotations
try:
    s3_label_count = _count_s3_objects(default_bucket, s3_labels_prefix)
    print(f"✅ Annotations in S3: {s3_label_count} files")

except ClientError as e:
    print(f"❌ Failed to list annotations in S3: {e.response['Error']['Code']}")
    print(f"   Bucket: {default_bucket}")
    print(f"   Prefix: {s3_labels_prefix}")
    s3_label_count = 0  # treated as "nothing found" by the mismatch check below

# Check data.yaml
try:
    s3_client.head_object(
        Bucket=default_bucket,
        Key=data_yaml_s3_key
    )
    print(f"✅ data.yaml exists in S3")
    data_yaml_exists = True

except ClientError as e:
    # Error code '404' is handled as "object missing"; any other code is reported as-is
    if e.response['Error']['Code'] == '404':
        print(f"❌ data.yaml not found in S3")
        print(f"   Expected location: s3://{default_bucket}/{data_yaml_s3_key}")
    else:
        print(f"❌ Failed to check data.yaml: {e.response['Error']['Code']}")
    data_yaml_exists = False

# Display S3 URIs for training
print("\n📍 S3 Dataset URIs for Training:")
print(f"   Dataset root: s3://{default_bucket}/{s3_dataset_prefix}/")
print(f"   Training images: s3://{default_bucket}/{s3_images_prefix}/")
print(f"   Annotations: s3://{default_bucket}/{s3_labels_prefix}/")
print(f"   Configuration: s3://{default_bucket}/{data_yaml_s3_key}")

# Summary
print("\n📊 Upload Summary:")
print(f"   Local images: {len(images_to_upload)}")
print(f"   S3 images: {s3_image_count}")
print(f"   Local annotations: {len(annotations_to_upload)}")
print(f"   S3 annotations: {s3_label_count}")
print(f"   data.yaml: {'Present' if data_yaml_exists else 'Missing'}")

# Check for mismatches (warnings only; a missing data.yaml is the one hard failure)
if s3_image_count != len(images_to_upload):
    print(f"\n⚠️  Warning: Image count mismatch!")
    print(f"   Expected {len(images_to_upload)} images, found {s3_image_count} in S3")

if s3_label_count != len(annotations_to_upload):
    print(f"\n⚠️  Warning: Annotation count mismatch!")
    print(f"   Expected {len(annotations_to_upload)} annotations, found {s3_label_count} in S3")

if not data_yaml_exists:
    print(f"\n❌ Error: data.yaml is missing from S3!")
    raise FileNotFoundError(
        f"data.yaml not found in S3.\n"
        f"Expected location: s3://{default_bucket}/{data_yaml_s3_key}\n"
        f"This file is required for YOLO training."
    )

if (s3_image_count == len(images_to_upload) and
    s3_label_count == len(annotations_to_upload) and
    data_yaml_exists):
    print("\n✅ All files uploaded successfully! Dataset ready for training.")

## 5. Model Training

Train YOLOv8 models using SageMaker training jobs with custom PyTorch training scripts.

### 5.1 Create PyTorch Estimator

Configure the PyTorch estimator with the custom training script and hyperparameters.

In [None]:
from sagemaker.pytorch import PyTorch

# Entry-point script handed to the estimator; fail early if it is not next to the notebook
training_script = 'yolo_training.py'
if not os.path.exists(training_script):
    raise FileNotFoundError(
        f"Training script not found: {training_script}\n"
        f"Expected location: {os.path.abspath(training_script)}\n"
        f"Please ensure the training script is in the current directory."
    )

# Hyperparameters passed to the estimator for YOLO training
hyperparameters = {
    'model-size': 'yolov8n',  # Options: yolov8n, yolov8s, yolov8m, yolov8l, yolov8x
    'task': 'detect',          # Options: detect, segment
    'epochs': 50,              # Number of training epochs
    'batch-size': 16,          # Training batch size
    'img-size': 640,           # Input image size
}

print("🚀 Creating PyTorch estimator...\n")

# Collect the estimator settings in one place, then construct it
estimator_settings = dict(
    entry_point=training_script,
    role=role,
    framework_version='2.0.0',
    py_version='py310',
    instance_count=1,
    instance_type='ml.g4dn.4xlarge',  # GPU instance for training
    volume_size=20,  # GB
    max_run=7200,  # Maximum runtime in seconds (2 hours)
    hyperparameters=hyperparameters,
    output_path=s3_paths['training_output'],
    sagemaker_session=sagemaker_session,
    base_job_name='yolo-cookie-training',
)
pytorch_estimator = PyTorch(**estimator_settings)

print("✅ PyTorch estimator created successfully")

# Read the values back from the estimator object itself
max_run_hours = pytorch_estimator.max_run // 3600
print("\n📊 Estimator Configuration:")
for config_line in (
    f"Framework: PyTorch {pytorch_estimator.framework_version}",
    f"Python version: {pytorch_estimator.py_version}",
    f"Instance type: {pytorch_estimator.instance_type}",
    f"Instance count: {pytorch_estimator.instance_count}",
    f"Volume size: {pytorch_estimator.volume_size} GB",
    f"Max runtime: {pytorch_estimator.max_run} seconds ({max_run_hours} hours)",
    f"Output path: {pytorch_estimator.output_path}",
):
    print(f"   {config_line}")

print("\n🎯 Hyperparameters:")
for key, value in hyperparameters.items():
    print(f"   {key}: {value}")

### 5.2 Configure Input Data Channels

Specify the S3 URI for the training data and configure the data distribution type.

In [None]:
from sagemaker.inputs import TrainingInput

# Define S3 input data path
training_data_s3_uri = f"s3://{default_bucket}/{s3_dataset_prefix}/"

# Create training input configuration
training_input = TrainingInput(
    s3_data=training_data_s3_uri,
    distribution='FullyReplicated',  # Copy all data to each instance
    content_type='application/x-directory',
    s3_data_type='S3Prefix'
)

# Create input channels dictionary (channel name -> input configuration)
input_channels = {
    'training': training_input
}

print("✅ Input data channels configured")
print("\n📍 Training Data Configuration:")
print(f"   S3 URI: {training_data_s3_uri}")
print(f"   Distribution: FullyReplicated")
print(f"   Content type: application/x-directory")
print(f"   S3 data type: S3Prefix")

# Verify S3 data exists
print("\n🔍 Verifying training data in S3...")
try:
    response = s3_client.list_objects_v2(
        Bucket=default_bucket,
        Prefix=s3_dataset_prefix,
        MaxKeys=10  # a small sample is enough to prove the prefix is populated
    )

    # The empty '.folder_marker' object created with the S3 folder structure
    # lives under this same prefix; ignore it so an otherwise empty dataset
    # folder is not mistaken for uploaded training data.
    data_keys = [
        obj['Key'] for obj in response.get('Contents', [])
        if not obj['Key'].endswith('/.folder_marker')
    ]

    if data_keys:
        print(f"✅ Training data found in S3")
        print(f"   Files found: {len(data_keys)}+ objects")
    else:
        print(f"❌ No training data found in S3")
        print(f"   Location: {training_data_s3_uri}")
        raise FileNotFoundError(
            f"No training data found in S3.\n"
            f"Expected location: {training_data_s3_uri}\n"
            f"Please ensure the dataset has been uploaded to S3."
        )

except ClientError as e:
    print(f"❌ Failed to verify training data: {e.response['Error']['Code']}")
    print(f"   Bucket: {default_bucket}")
    print(f"   Prefix: {s3_dataset_prefix}")
    raise

### 5.3 Launch Training Job

Generate a unique training job name and launch the SageMaker training job.

In [None]:
# Generate unique training job name with timestamp
training_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
training_job_name = f"yolo-cookie-{hyperparameters['task']}-{training_timestamp}"

print(f"🚀 Launching training job: {training_job_name}\n")

# Launch training job
try:
    pytorch_estimator.fit(
        inputs=input_channels,
        job_name=training_job_name,
        wait=False  # Don't wait for completion in this cell
    )

    # Describe the job once so the ARN and status come from the same snapshot
    # (and only one describe request is made)
    job_description = pytorch_estimator.latest_training_job.describe()
    training_job_arn = job_description['TrainingJobArn']
    training_job_status = job_description['TrainingJobStatus']

    print("✅ Training job launched successfully!\n")
    print("📊 Training Job Details:")
    print(f"   Job name: {training_job_name}")
    print(f"   Job ARN: {training_job_arn}")
    print(f"   Status: {training_job_status}")
    print(f"   Model: {hyperparameters['model-size']}")
    print(f"   Task: {hyperparameters['task']}")
    print(f"   Epochs: {hyperparameters['epochs']}")
    print(f"   Batch size: {hyperparameters['batch-size']}")

    print(f"\n🔗 View training job in SageMaker Console:")
    console_url = f"https://console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{training_job_name}"
    print(f"   {console_url}")

    print(f"\nℹ️  Training job is running. Use the next cell to monitor progress.")

except Exception as e:
    # Print the most likely causes, then re-raise so the original traceback is kept
    print(f"❌ Failed to launch training job: {str(e)}")
    print(f"\n🔍 Troubleshooting:")
    print(f"   - Verify the training script exists: {training_script}")
    print(f"   - Verify training data is in S3: {training_data_s3_uri}")
    print(f"   - Check IAM role permissions: {role}")
    print(f"   - Verify instance type availability: {pytorch_estimator.instance_type}")
    raise

### 5.4 Monitor Training Job

Poll the training job status and display progress until completion.

In [None]:
import time
from botocore.exceptions import ClientError

# Poll the training job once a minute until it reaches a terminal state.
#
# Reads:  region, training_job_name (set by the launch cell)
# Writes: sagemaker_client (reused by the compilation cells) and, only when the
#         job completes successfully, model_artifact_s3_uri / trained_model_s3_uri
# Raises: re-raises ClientError and any unexpected exception after printing
#         context; a KeyboardInterrupt only stops the monitoring, not the job.

# Create SageMaker client for monitoring
sagemaker_client = boto3.client('sagemaker', region_name=region)

# The magnifying-glass emoji is written literally. The previous "\ud83d\udd0d"
# escape is a JSON-style surrogate pair; in Python source it yields two lone
# surrogates, which cannot be encoded as UTF-8 when printed.
print(f"🔍 Monitoring training job: {training_job_name}\n")
print("Status updates every 60 seconds...\n")

# Initialize monitoring variables
previous_status = None
start_time = time.time()
dots_printed = 0

while True:
    try:
        # Get training job status
        response = sagemaker_client.describe_training_job(
            TrainingJobName=training_job_name
        )

        current_status = response['TrainingJobStatus']

        # Display status change
        if current_status != previous_status:
            if previous_status is not None:
                print()  # New line after dots

            elapsed_time = int(time.time() - start_time)
            elapsed_minutes = elapsed_time // 60
            elapsed_seconds = elapsed_time % 60

            print(f"[{elapsed_minutes:02d}:{elapsed_seconds:02d}] Status: {current_status}")
            previous_status = current_status
            dots_printed = 0

        # Check if training is complete
        if current_status in ['Completed', 'Failed', 'Stopped']:
            print()  # New line after dots

            # Display completion details
            if current_status == 'Completed':
                print("\n\u2705 Training job completed successfully!\n")

                # Get model artifact location
                model_artifact_s3_uri = response['ModelArtifacts']['S3ModelArtifacts']

                # Display training metrics if available
                if 'FinalMetricDataList' in response:
                    print("📊 Final Training Metrics:")
                    for metric in response['FinalMetricDataList']:
                        print(f"   {metric['MetricName']}: {metric['Value']:.4f}")
                    print()

                # Display training time
                training_time_seconds = response.get('TrainingTimeInSeconds', 0)
                training_hours = training_time_seconds // 3600
                training_minutes = (training_time_seconds % 3600) // 60
                training_seconds = training_time_seconds % 60

                print("⏱️  Training Duration:")
                if training_hours > 0:
                    print(f"   {training_hours}h {training_minutes}m {training_seconds}s")
                else:
                    print(f"   {training_minutes}m {training_seconds}s")

                # Display billable time
                billable_time_seconds = response.get('BillableTimeInSeconds', 0)
                billable_minutes = billable_time_seconds // 60
                billable_seconds = billable_time_seconds % 60
                print(f"   Billable time: {billable_minutes}m {billable_seconds}s")

                # Display model artifact location
                print(f"\n📦 Trained Model Artifact:")
                print(f"   S3 URI: {model_artifact_s3_uri}")

                # Store for later use (read by the model download cell)
                trained_model_s3_uri = model_artifact_s3_uri

            elif current_status == 'Failed':
                print("\n\u274c Training job failed!\n")

                # Get failure reason
                failure_reason = response.get('FailureReason', 'Unknown failure reason')
                print(f"🚫 Failure Reason:")
                print(f"   {failure_reason}\n")

                # Provide troubleshooting guidance
                print("🔍 Common Failure Causes:")
                print("   1. Insufficient instance resources")
                print("      - Try reducing batch size or image size")
                print("      - Use a larger instance type")
                print("   2. Invalid data format")
                print("      - Verify YOLO annotation format is correct")
                print("      - Check data.yaml configuration")
                print("   3. Training timeout exceeded")
                print("      - Reduce number of epochs")
                print("      - Increase max_run parameter")
                print("   4. Missing dependencies")
                print("      - Verify training script has all required imports")
                print("      - Check PyTorch and Ultralytics versions")

                # Display CloudWatch logs link
                # NOTE(review): this link targets a log stream named exactly
                # like the job. Training log streams are presumably named
                # "<job-name>/algo-N-<timestamp>", so the link may not resolve
                # to an existing stream - confirm, and consider linking to the
                # log group with a stream-name filter instead.
                print(f"\n🔗 View detailed logs in CloudWatch:")
                log_group = "/aws/sagemaker/TrainingJobs"
                log_stream = training_job_name
                logs_url = (
                    f"https://console.aws.amazon.com/cloudwatch/home?region={region}"
                    f"#logsV2:log-groups/log-group/{log_group.replace('/', '$252F')}"
                    f"/log-events/{log_stream.replace('/', '$252F')}"
                )
                print(f"   {logs_url}")

            elif current_status == 'Stopped':
                print("\n⚠️  Training job was stopped.\n")
                print("The training job was manually stopped before completion.")

            # Exit monitoring loop
            break

        # Display progress indicator (dots) for in-progress status
        if current_status == 'InProgress':
            print('.', end='', flush=True)
            dots_printed += 1

            # New line every 60 dots (60 minutes)
            if dots_printed >= 60:
                print()
                dots_printed = 0

        # Wait 60 seconds before next poll
        time.sleep(60)

    except ClientError as e:
        error_code = e.response['Error']['Code']
        error_msg = e.response['Error']['Message']
        print(f"\n\u274c Error monitoring training job: {error_code}")
        print(f"   Job name: {training_job_name}")
        print(f"   Error: {error_msg}")

        if error_code == 'ValidationException':
            print(f"\n🔍 The training job may not exist or the name is incorrect.")
            print(f"   Verify job name: {training_job_name}")

        raise

    except KeyboardInterrupt:
        # Interrupting the kernel only stops polling; the job itself is untouched.
        print("\n\n⚠️  Monitoring interrupted by user.")
        print(f"   Training job is still running: {training_job_name}")
        print(f"   You can resume monitoring by re-running this cell.")
        break

    except Exception as e:
        print(f"\n\u274c Unexpected error: {str(e)}")
        print(f"   Training job: {training_job_name}")
        raise

## 6. Model Preparation for Compilation

Prepare the trained YOLO model for SageMaker Neo compilation by downloading, extracting, and repackaging the model artifacts.

### 6.1 Download Trained Model from S3

Download the model.tar.gz file from S3 to the local directory.

In [None]:
import tarfile
import torch
from pathlib import Path

# Download the trained model archive from S3 into a local working directory.
#
# Reads:  trained_model_s3_uri (set by the training monitoring cell; prompted
#         for when missing), s3_client and ClientError from earlier cells
# Writes: model_prep_dir, downloaded_model_dir, extracted_model_dir,
#         compiled_model_dir, model_bucket, model_key, local_model_path
# Raises: ValueError for a malformed S3 URI; Exception (chained to the
#         ClientError) when the S3 download fails; FileNotFoundError when the
#         file is missing after a seemingly successful download.

# Define local directories for model preparation
model_prep_dir = 'model-preparation'
downloaded_model_dir = os.path.join(model_prep_dir, 'downloaded')
extracted_model_dir = os.path.join(model_prep_dir, 'extracted')
compiled_model_dir = os.path.join(model_prep_dir, 'for-compilation')

# Create directories
os.makedirs(downloaded_model_dir, exist_ok=True)
os.makedirs(extracted_model_dir, exist_ok=True)
os.makedirs(compiled_model_dir, exist_ok=True)

print("📥 Downloading trained model from S3...\n")

# Check if trained_model_s3_uri is defined (from training monitoring cell)
if 'trained_model_s3_uri' not in locals():
    print("⚠️  Warning: trained_model_s3_uri not found.")
    print("   This variable should be set by the training monitoring cell.")
    print("   Please provide the S3 URI of your trained model:\n")

    # Example format for user to fill in; strip stray whitespace from pasting
    trained_model_s3_uri = input("Enter S3 URI (e.g., s3://bucket/path/model.tar.gz): ").strip()

# Validate the URI regardless of where it came from, so a bad value fails here
# with a clear message instead of later with an IndexError or an S3 error.
s3_scheme = 's3://'
if not trained_model_s3_uri or not trained_model_s3_uri.startswith(s3_scheme):
    raise ValueError(
        f"Invalid S3 URI provided: {trained_model_s3_uri}\n"
        f"Expected format: s3://bucket-name/path/to/model.tar.gz"
    )

# Parse S3 URI: everything up to the first '/' is the bucket, the rest is the key
model_bucket, _, model_key = trained_model_s3_uri[len(s3_scheme):].partition('/')

if not model_bucket or not model_key:
    raise ValueError(
        f"Invalid S3 URI provided: {trained_model_s3_uri}\n"
        f"The URI must contain both a bucket and an object key.\n"
        f"Expected format: s3://bucket-name/path/to/model.tar.gz"
    )

# Local path for downloaded model
local_model_path = os.path.join(downloaded_model_dir, 'model.tar.gz')

print(f"📍 Source:")
print(f"   S3 URI: {trained_model_s3_uri}")
print(f"   Bucket: {model_bucket}")
print(f"   Key: {model_key}")
print(f"\n📍 Destination:")
print(f"   Local path: {local_model_path}\n")

# Download model from S3
try:
    s3_client.download_file(
        Bucket=model_bucket,
        Key=model_key,
        Filename=local_model_path
    )

    # Verify download
    if os.path.exists(local_model_path):
        file_size_mb = os.path.getsize(local_model_path) / (1024 * 1024)
        print(f"✅ Model downloaded successfully")
        print(f"   File size: {file_size_mb:.2f} MB")
        print(f"   Location: {local_model_path}")
    else:
        raise FileNotFoundError(
            f"Model file not found after download.\n"
            f"Expected location: {local_model_path}"
        )

except ClientError as e:
    error_code = e.response['Error']['Code']
    error_msg = e.response['Error']['Message']
    print(f"❌ S3 download failed: {error_code}")
    print(f"   Bucket: {model_bucket}")
    print(f"   Key: {model_key}")
    print(f"   Error: {error_msg}")

    # download_file issues a HEAD request first; a missing object is then
    # reported with the bare HTTP code '404' rather than 'NoSuchKey'.
    if error_code in ('NoSuchKey', '404'):
        print(f"\n🔍 The model file does not exist at the specified S3 location.")
        print(f"   Verify the training job completed successfully.")
        print(f"   Check the S3 URI: {trained_model_s3_uri}")
    elif error_code == 'NoSuchBucket':
        print(f"\n🔍 The S3 bucket does not exist.")
        print(f"   Verify the bucket name: {model_bucket}")

    # Chain the original error so the full S3 response stays in the traceback
    raise Exception(
        f"Failed to download model from S3.\n"
        f"S3 URI: {trained_model_s3_uri}\n"
        f"Error: {error_code} - {error_msg}"
    ) from e

except Exception as e:
    print(f"❌ Unexpected error during download: {str(e)}")
    raise

### 6.2 Extract and Prepare Model for Compilation

Extract the tar.gz file, locate the YOLO model weights, read metadata, and create a compilation-ready package.

In [None]:
# Extract the downloaded archive, pick the YOLO weights file, determine the
# model input shape, repackage the weights alone and upload them for compilation.
#
# Reads:  local_model_path, extracted_model_dir, compiled_model_dir, s3_prefix,
#         default_bucket, s3_client, ClientError, and (optionally) hyperparameters
# Writes: extracted_files, model_file_path, input_shape, compilation_model_path,
#         compilation_model_s3_key, compilation_model_s3_uri, data_input_config
# Raises: Exception (chained) on extraction or upload failure;
#         FileNotFoundError when the archive is empty or has no known weights file.

print("📦 Extracting model archive...\n")

# Extract tar.gz file
try:
    with tarfile.open(local_model_path, 'r:gz') as tar:
        # Use the 'data' extraction filter where the interpreter provides it:
        # it rejects members that would land outside the destination directory.
        # Filter violations are TarError subclasses and are handled below.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=extracted_model_dir, filter='data')
        else:
            tar.extractall(path=extracted_model_dir)

    print(f"✅ Model archive extracted")
    print(f"   Location: {extracted_model_dir}\n")

except tarfile.TarError as e:
    print(f"❌ Failed to extract tar.gz file: {str(e)}")
    print(f"   File: {local_model_path}")
    raise Exception(
        f"Failed to extract model archive.\n"
        f"File: {local_model_path}\n"
        f"Error: {str(e)}\n"
        f"The file may be corrupted or not a valid tar.gz archive."
    ) from e

# List extracted files
print("📁 Extracted files:")
extracted_files = []
for root, _, files in os.walk(extracted_model_dir):
    for file in files:
        file_path = os.path.join(root, file)
        rel_path = os.path.relpath(file_path, extracted_model_dir)
        file_size_kb = os.path.getsize(file_path) / 1024
        extracted_files.append(rel_path)
        print(f"   {rel_path} ({file_size_kb:.1f} KB)")

if not extracted_files:
    raise FileNotFoundError(
        f"No files found in extracted archive.\n"
        f"Extraction directory: {extracted_model_dir}\n"
        f"The archive may be empty or extraction failed."
    )

print(f"\n🔍 Locating YOLO model weights...\n")

# Look for YOLO model files, in order of preference: best.pt, yolo.pt, last.pt.
# Record the first location of every candidate across the whole tree, then
# choose by preference, so a last.pt in an early directory cannot win over a
# best.pt that lives in a directory visited later.
model_candidates = ['best.pt', 'yolo.pt', 'last.pt']
candidate_locations = {}

for root, _, files in os.walk(extracted_model_dir):
    for candidate in model_candidates:
        if candidate in files and candidate not in candidate_locations:
            candidate_locations[candidate] = os.path.join(root, candidate)

model_file_path = None
for candidate in model_candidates:
    if candidate in candidate_locations:
        model_file_path = candidate_locations[candidate]
        print(f"✅ Found model weights: {candidate}")
        print(f"   Path: {model_file_path}")
        break

if not model_file_path:
    print(f"❌ Model weights file not found")
    print(f"   Searched for: {', '.join(model_candidates)}")
    print(f"   In directory: {extracted_model_dir}")
    print(f"   Available files: {', '.join(extracted_files)}")
    raise FileNotFoundError(
        f"YOLO model weights not found in extracted archive.\n"
        f"Expected files: {', '.join(model_candidates)}\n"
        f"Found files: {', '.join(extracted_files)}\n"
        f"The training output may not contain the expected model files."
    )

print(f"\n🔍 Reading model metadata...\n")

# Determine the input shape without unpickling the checkpoint.
# The shape only ever came from metadata.json or the img-size hyperparameter;
# the checkpoint was loaded just to test its type. Loading it could fail (for
# example when the classes it pickles are not importable in this kernel), and
# that failure silently discarded both sources in favour of 640x640.
input_shape = None

# 1) metadata.json at the root of the extracted archive, if present
metadata_path = os.path.join(extracted_model_dir, 'metadata.json')
if os.path.exists(metadata_path):
    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        input_shape = metadata.get('input_shape')
    except (OSError, ValueError, AttributeError) as e:
        # Unreadable, malformed or non-object metadata: fall through to the
        # next source instead of aborting the preparation.
        print(f"⚠️  Warning: Could not read model metadata: {str(e)}")

    if input_shape:
        print(f"✅ Found input shape in metadata.json: {input_shape}")

# 2) img-size training hyperparameter, when the hyperparameters are available
if not input_shape:
    try:
        # int() guards against the value having been stored as a string
        img_size = int(hyperparameters.get('img-size', 640))
    except (NameError, AttributeError, TypeError, ValueError):
        # hyperparameters is undefined (e.g. the S3 URI was entered manually
        # in a fresh kernel) or holds an unusable value
        img_size = None

    if img_size:
        input_shape = [1, 3, img_size, img_size]
        print(f"ℹ️  Using default YOLO input shape: {input_shape}")
        print(f"   (Based on img-size hyperparameter: {img_size})")

# 3) Standard YOLO default
if not input_shape:
    input_shape = [1, 3, 640, 640]
    print(f"⚠️  Could not determine input shape from model")
    print(f"   Using standard YOLO default: {input_shape}")

print(f"\n📊 Model Information:")
print(f"   Input shape: {input_shape}")
print(f"   Format: [batch_size, channels, height, width]")

print(f"\n📦 Creating compilation-ready package...\n")

# Create new tar.gz with only the model weights file
compilation_model_path = os.path.join(compiled_model_dir, 'model.tar.gz')

try:
    with tarfile.open(compilation_model_path, 'w:gz') as tar:
        # Add model file with arcname to place it at root of archive
        tar.add(model_file_path, arcname=os.path.basename(model_file_path))

    # Verify the new archive
    if os.path.exists(compilation_model_path):
        file_size_mb = os.path.getsize(compilation_model_path) / (1024 * 1024)
        print(f"✅ Compilation-ready package created")
        print(f"   Location: {compilation_model_path}")
        print(f"   Size: {file_size_mb:.2f} MB")
        print(f"   Contents: {os.path.basename(model_file_path)}")
    else:
        raise FileNotFoundError(
            f"Failed to create compilation package.\n"
            f"Expected location: {compilation_model_path}"
        )

except Exception as e:
    print(f"❌ Failed to create compilation package: {str(e)}")
    raise

print(f"\n📤 Uploading prepared model to S3...\n")

# upload_file reports transfer failures as S3UploadFailedError, not ClientError,
# so both must be caught for the diagnostics below to be shown.
from boto3.exceptions import S3UploadFailedError

# Upload prepared model to S3
compilation_model_s3_key = f"{s3_prefix}/models/model-for-compilation.tar.gz"
compilation_model_s3_uri = f"s3://{default_bucket}/{compilation_model_s3_key}"

try:
    s3_client.upload_file(
        compilation_model_path,
        default_bucket,
        compilation_model_s3_key
    )

    print(f"✅ Prepared model uploaded to S3")
    print(f"   S3 URI: {compilation_model_s3_uri}")

    # Create DataInputConfig for compilation (read by the compilation cells)
    data_input_config = json.dumps({"input_shape": input_shape})

    print(f"\n📊 Compilation Configuration:")
    print(f"   Model S3 URI: {compilation_model_s3_uri}")
    print(f"   DataInputConfig: {data_input_config}")
    print(f"   Framework: PYTORCH")
    print(f"   Framework Version: 2.0")

    print(f"\n✅ Model preparation complete! Ready for compilation.")

except (S3UploadFailedError, ClientError) as e:
    if isinstance(e, ClientError):
        error_code = e.response['Error']['Code']
        error_msg = e.response['Error']['Message']
    else:
        # S3UploadFailedError carries no structured response, only a message
        error_code = 'S3UploadFailedError'
        error_msg = str(e)
    print(f"❌ S3 upload failed: {error_code}")
    print(f"   File: {compilation_model_path}")
    print(f"   Destination: {compilation_model_s3_uri}")
    print(f"   Error: {error_msg}")
    raise Exception(
        f"Failed to upload prepared model to S3.\n"
        f"File: {compilation_model_path}\n"
        f"S3 URI: {compilation_model_s3_uri}\n"
        f"Error: {error_code} - {error_msg}"
    ) from e

## 7. Model Compilation

Compile trained models for target platforms using SageMaker Neo.

### 7.1 Compile Model for Jetson Xavier GPU

Create and submit a compilation job for Jetson Xavier hardware with GPU acceleration.

In [None]:
# Submit a SageMaker Neo compilation job targeting Jetson Xavier (ARM64 + NVIDIA).
#
# Reads:  compilation_model_s3_uri and data_input_config (prompted for or
#         defaulted when missing), s3_paths, role, region, sagemaker_client,
#         ClientError
# Writes: compilation_timestamp, compilation_job_name_xavier,
#         target_platform_xavier, compiler_opts, compiler_options_xavier,
#         compilation_output_s3_xavier, response_xavier,
#         compilation_job_arn_xavier, console_url
# Raises: ValueError for an invalid prompted URI; Exception (chained to the
#         ClientError) when the job cannot be created.

# Verify required variables are available
if 'compilation_model_s3_uri' not in locals():
    print("⚠️  Warning: compilation_model_s3_uri not found.")
    print("   This variable should be set by the model preparation cells.")
    print("   Please provide the S3 URI of your prepared model:\n")

    # Strip stray whitespace that commonly sneaks in when pasting a URI
    compilation_model_s3_uri = input("Enter S3 URI (e.g., s3://bucket/path/model-for-compilation.tar.gz): ").strip()

    if not compilation_model_s3_uri or not compilation_model_s3_uri.startswith('s3://'):
        raise ValueError(
            f"Invalid S3 URI provided: {compilation_model_s3_uri}\n"
            f"Expected format: s3://bucket-name/path/to/model-for-compilation.tar.gz"
        )

if 'data_input_config' not in locals():
    print("⚠️  Warning: data_input_config not found.")
    print("   Using default YOLO input shape: [1, 3, 640, 640]\n")
    data_input_config = json.dumps({"input_shape": [1, 3, 640, 640]})

print("🚀 Creating compilation job for Jetson Xavier GPU...\n")

# Generate unique compilation job name
compilation_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
compilation_job_name_xavier = f"yolo-cookie-xavier-{compilation_timestamp}"

# Define target platform for Jetson Xavier
target_platform_xavier = {
    'Os': 'LINUX',
    'Arch': 'ARM64',
    'Accelerator': 'NVIDIA'
}

# Compiler options for Jetson Xavier
# Xavier uses CUDA 10.2, GPU code sm_72, TensorRT 8.2.1, FP16 precision
# NOTE(review): 'precision_mode' and 'jetson-platform' are presumably not among
# the options Neo accepts for NVIDIA targets (cuda-ver, gpu-code, trt-ver) -
# confirm against the CreateCompilationJob documentation.
# Kept as a dict for display; the API takes the JSON-encoded string.
compiler_opts = {
    'cuda-ver': '10.2',
    'gpu-code': 'sm_72',
    'trt-ver': '8.2.1',
    'precision_mode': 'fp16',
    'jetson-platform': 'xavier'
}
compiler_options_xavier = json.dumps(compiler_opts)

# Output S3 location for compiled model
compilation_output_s3_xavier = f"{s3_paths['compilation_output']}/jetson-xavier/"

print("📊 Compilation Job Configuration:")
print(f"   Job name: {compilation_job_name_xavier}")
print(f"   Input model: {compilation_model_s3_uri}")
print(f"   Framework: PYTORCH 2.0")
print(f"   DataInputConfig: {data_input_config}")
print(f"\n🎯 Target Platform:")
print(f"   OS: {target_platform_xavier['Os']}")
print(f"   Architecture: {target_platform_xavier['Arch']}")
print(f"   Accelerator: {target_platform_xavier['Accelerator']}")
print(f"\n⚙️  Compiler Options:")
for key, value in compiler_opts.items():
    print(f"   {key}: {value}")
print(f"\n📍 Output Location:")
print(f"   S3 URI: {compilation_output_s3_xavier}")

# Create compilation job
print(f"\n🚀 Submitting compilation job...\n")

try:
    response_xavier = sagemaker_client.create_compilation_job(
        CompilationJobName=compilation_job_name_xavier,
        RoleArn=role,
        InputConfig={
            'S3Uri': compilation_model_s3_uri,
            'DataInputConfig': data_input_config,
            'Framework': 'PYTORCH',
            'FrameworkVersion': '2.0'
        },
        OutputConfig={
            'S3OutputLocation': compilation_output_s3_xavier,
            'TargetPlatform': target_platform_xavier,
            'CompilerOptions': compiler_options_xavier
        },
        StoppingCondition={
            'MaxRuntimeInSeconds': 900  # 15 minutes
        }
    )

    # Get compilation job ARN
    compilation_job_arn_xavier = response_xavier['CompilationJobArn']

    print("✅ Compilation job submitted successfully!\n")
    print("📊 Compilation Job Details:")
    print(f"   Job name: {compilation_job_name_xavier}")
    print(f"   Job ARN: {compilation_job_arn_xavier}")
    print(f"   Target: Jetson Xavier GPU")
    print(f"   Status: STARTING")

    print(f"\n🔗 View compilation job in SageMaker Console:")
    console_url = f"https://console.aws.amazon.com/sagemaker/home?region={region}#/compilation-jobs/{compilation_job_name_xavier}"
    print(f"   {console_url}")

    print(f"\nℹ️  Compilation job is running. Use the monitoring cell to track progress.")

except ClientError as e:
    error_code = e.response['Error']['Code']
    error_msg = e.response['Error']['Message']
    print(f"❌ Failed to create compilation job: {error_code}")
    print(f"   Job name: {compilation_job_name_xavier}")
    print(f"   Error: {error_msg}\n")

    print("🔍 Troubleshooting:")
    print("   1. Verify the prepared model exists in S3")
    print(f"      Model URI: {compilation_model_s3_uri}")
    print("   2. Check IAM role permissions for SageMaker Neo")
    print(f"      Role: {role}")
    print("   3. Verify DataInputConfig format")
    print(f"      Config: {data_input_config}")
    print("   4. Check compiler options compatibility")
    print(f"      Options: {compiler_options_xavier}")

    if error_code == 'ValidationException':
        print("\n⚠️  Validation Error: Check input parameters")
        print("   - Ensure model S3 URI is valid")
        print("   - Verify DataInputConfig JSON format")
        print("   - Check target platform configuration")

    # Chain the original error so the service response stays in the traceback
    raise Exception(
        f"Failed to create compilation job for Jetson Xavier.\n"
        f"Job name: {compilation_job_name_xavier}\n"
        f"Error: {error_code} - {error_msg}"
    ) from e

except Exception as e:
    print(f"❌ Unexpected error: {str(e)}")
    print(f"   Job name: {compilation_job_name_xavier}")
    raise

### 7.2 Compile Model for x86_64 CPU

Create and submit a compilation job for x86_64 CPU hardware (standard server infrastructure).

In [None]:
# Submit a SageMaker Neo compilation job targeting x86_64 CPU (no accelerator).
#
# Reads:  compilation_model_s3_uri and data_input_config (prompted for or
#         defaulted when missing), s3_paths, role, region, sagemaker_client,
#         ClientError
# Writes: compilation_timestamp_x86, compilation_job_name_x86,
#         target_platform_x86, compilation_output_s3_x86, response_x86,
#         compilation_job_arn_x86, console_url
# Raises: ValueError for an invalid prompted URI; Exception (chained to the
#         ClientError) when the job cannot be created.

# Verify required variables are available
if 'compilation_model_s3_uri' not in locals():
    print("⚠️  Warning: compilation_model_s3_uri not found.")
    print("   This variable should be set by the model preparation cells.")
    print("   Please provide the S3 URI of your prepared model:\n")

    # Strip stray whitespace that commonly sneaks in when pasting a URI
    compilation_model_s3_uri = input("Enter S3 URI (e.g., s3://bucket/path/model-for-compilation.tar.gz): ").strip()

    if not compilation_model_s3_uri or not compilation_model_s3_uri.startswith('s3://'):
        raise ValueError(
            f"Invalid S3 URI provided: {compilation_model_s3_uri}\n"
            f"Expected format: s3://bucket-name/path/to/model-for-compilation.tar.gz"
        )

if 'data_input_config' not in locals():
    print("⚠️  Warning: data_input_config not found.")
    print("   Using default YOLO input shape: [1, 3, 640, 640]\n")
    data_input_config = json.dumps({"input_shape": [1, 3, 640, 640]})

print("🚀 Creating compilation job for x86_64 CPU...\n")

# Generate unique compilation job name
compilation_timestamp_x86 = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
compilation_job_name_x86 = f"yolo-cookie-x86-cpu-{compilation_timestamp_x86}"

# Define target platform for x86_64 CPU (no 'Accelerator' key: CPU only)
target_platform_x86 = {
    'Os': 'LINUX',
    'Arch': 'X86_64'
}

# Output S3 location for compiled model
compilation_output_s3_x86 = f"{s3_paths['compilation_output']}/x86-64-cpu/"

print("📊 Compilation Job Configuration:")
print(f"   Job name: {compilation_job_name_x86}")
print(f"   Input model: {compilation_model_s3_uri}")
print(f"   Framework: PYTORCH 2.0")
print(f"   DataInputConfig: {data_input_config}")
print(f"\n🎯 Target Platform:")
print(f"   OS: {target_platform_x86['Os']}")
print(f"   Architecture: {target_platform_x86['Arch']}")
print(f"   Accelerator: None (CPU only)")
print(f"\n📍 Output Location:")
print(f"   S3 URI: {compilation_output_s3_x86}")

# Create compilation job
print(f"\n🚀 Submitting compilation job...\n")

try:
    response_x86 = sagemaker_client.create_compilation_job(
        CompilationJobName=compilation_job_name_x86,
        RoleArn=role,
        InputConfig={
            'S3Uri': compilation_model_s3_uri,
            'DataInputConfig': data_input_config,
            'Framework': 'PYTORCH',
            'FrameworkVersion': '2.0'
        },
        OutputConfig={
            'S3OutputLocation': compilation_output_s3_x86,
            'TargetPlatform': target_platform_x86
        },
        StoppingCondition={
            'MaxRuntimeInSeconds': 900  # 15 minutes
        }
    )

    # Get compilation job ARN
    compilation_job_arn_x86 = response_x86['CompilationJobArn']

    print("✅ Compilation job submitted successfully!\n")
    print("📊 Compilation Job Details:")
    print(f"   Job name: {compilation_job_name_x86}")
    print(f"   Job ARN: {compilation_job_arn_x86}")
    print(f"   Target: x86_64 CPU")
    print(f"   Status: STARTING")

    print(f"\n🔗 View compilation job in SageMaker Console:")
    console_url = f"https://console.aws.amazon.com/sagemaker/home?region={region}#/compilation-jobs/{compilation_job_name_x86}"
    print(f"   {console_url}")

    print(f"\nℹ️  Compilation job is running. Use the monitoring cell to track progress.")

except ClientError as e:
    error_code = e.response['Error']['Code']
    error_msg = e.response['Error']['Message']
    print(f"❌ Failed to create compilation job: {error_code}")
    print(f"   Job name: {compilation_job_name_x86}")
    print(f"   Error: {error_msg}\n")

    print("🔍 Troubleshooting:")
    print("   1. Verify the prepared model exists in S3")
    print(f"      Model URI: {compilation_model_s3_uri}")
    print("   2. Check IAM role permissions for SageMaker Neo")
    print(f"      Role: {role}")
    print("   3. Verify DataInputConfig format")
    print(f"      Config: {data_input_config}")
    print("   4. Check target platform configuration")
    print(f"      Platform: {target_platform_x86}")

    if error_code == 'ValidationException':
        print("\n⚠️  Validation Error: Check input parameters")
        print("   - Ensure model S3 URI is valid")
        print("   - Verify DataInputConfig JSON format")
        print("   - Check target platform configuration")

    # Chain the original error so the service response stays in the traceback
    raise Exception(
        f"Failed to create compilation job for x86_64 CPU.\n"
        f"Job name: {compilation_job_name_x86}\n"
        f"Error: {error_code} - {error_msg}"
    ) from e

except Exception as e:
    print(f"❌ Unexpected error: {str(e)}")
    print(f"   Job name: {compilation_job_name_x86}")
    raise

### 7.3 Compile Model for ARM64 CPU

Create and submit a compilation job for ARM64 CPU hardware (ARM-based edge devices without GPU).

In [None]:
# Submit a SageMaker Neo compilation job targeting ARM64 CPU (no accelerator).
#
# Reads:  compilation_model_s3_uri and data_input_config (prompted for or
#         defaulted when missing), s3_paths, role, region, sagemaker_client,
#         ClientError
# Writes: compilation_timestamp_arm64, compilation_job_name_arm64,
#         target_platform_arm64, compilation_output_s3_arm64, response_arm64,
#         compilation_job_arn_arm64, console_url
# Raises: ValueError for an invalid prompted URI; Exception (chained to the
#         ClientError) when the job cannot be created.

# Verify required variables are available
if 'compilation_model_s3_uri' not in locals():
    print("⚠️  Warning: compilation_model_s3_uri not found.")
    print("   This variable should be set by the model preparation cells.")
    print("   Please provide the S3 URI of your prepared model:\n")

    # Strip stray whitespace that commonly sneaks in when pasting a URI
    compilation_model_s3_uri = input("Enter S3 URI (e.g., s3://bucket/path/model-for-compilation.tar.gz): ").strip()

    if not compilation_model_s3_uri or not compilation_model_s3_uri.startswith('s3://'):
        raise ValueError(
            f"Invalid S3 URI provided: {compilation_model_s3_uri}\n"
            f"Expected format: s3://bucket-name/path/to/model-for-compilation.tar.gz"
        )

if 'data_input_config' not in locals():
    print("⚠️  Warning: data_input_config not found.")
    print("   Using default YOLO input shape: [1, 3, 640, 640]\n")
    data_input_config = json.dumps({"input_shape": [1, 3, 640, 640]})

print("🚀 Creating compilation job for ARM64 CPU...\n")

# Generate unique compilation job name
compilation_timestamp_arm64 = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
compilation_job_name_arm64 = f"yolo-cookie-arm64-cpu-{compilation_timestamp_arm64}"

# Define target platform for ARM64 CPU (no 'Accelerator' key: CPU only)
target_platform_arm64 = {
    'Os': 'LINUX',
    'Arch': 'ARM64'
}

# Output S3 location for compiled model
compilation_output_s3_arm64 = f"{s3_paths['compilation_output']}/arm64-cpu/"

print("📊 Compilation Job Configuration:")
print(f"   Job name: {compilation_job_name_arm64}")
print(f"   Input model: {compilation_model_s3_uri}")
print(f"   Framework: PYTORCH 2.0")
print(f"   DataInputConfig: {data_input_config}")
print(f"\n🎯 Target Platform:")
print(f"   OS: {target_platform_arm64['Os']}")
print(f"   Architecture: {target_platform_arm64['Arch']}")
print(f"   Accelerator: None (CPU only)")
print(f"\n📍 Output Location:")
print(f"   S3 URI: {compilation_output_s3_arm64}")

# Create compilation job
print(f"\n🚀 Submitting compilation job...\n")

try:
    response_arm64 = sagemaker_client.create_compilation_job(
        CompilationJobName=compilation_job_name_arm64,
        RoleArn=role,
        InputConfig={
            'S3Uri': compilation_model_s3_uri,
            'DataInputConfig': data_input_config,
            'Framework': 'PYTORCH',
            'FrameworkVersion': '2.0'
        },
        OutputConfig={
            'S3OutputLocation': compilation_output_s3_arm64,
            'TargetPlatform': target_platform_arm64
        },
        StoppingCondition={
            'MaxRuntimeInSeconds': 900  # 15 minutes
        }
    )

    # Get compilation job ARN
    compilation_job_arn_arm64 = response_arm64['CompilationJobArn']

    print("✅ Compilation job submitted successfully!\n")
    print("📊 Compilation Job Details:")
    print(f"   Job name: {compilation_job_name_arm64}")
    print(f"   Job ARN: {compilation_job_arn_arm64}")
    print(f"   Target: ARM64 CPU")
    print(f"   Status: STARTING")

    print(f"\n🔗 View compilation job in SageMaker Console:")
    console_url = f"https://console.aws.amazon.com/sagemaker/home?region={region}#/compilation-jobs/{compilation_job_name_arm64}"
    print(f"   {console_url}")

    print(f"\nℹ️  Compilation job is running. Use the monitoring cell to track progress.")

except ClientError as e:
    error_code = e.response['Error']['Code']
    error_msg = e.response['Error']['Message']
    print(f"❌ Failed to create compilation job: {error_code}")
    print(f"   Job name: {compilation_job_name_arm64}")
    print(f"   Error: {error_msg}\n")

    print("🔍 Troubleshooting:")
    print("   1. Verify the prepared model exists in S3")
    print(f"      Model URI: {compilation_model_s3_uri}")
    print("   2. Check IAM role permissions for SageMaker Neo")
    print(f"      Role: {role}")
    print("   3. Verify DataInputConfig format")
    print(f"      Config: {data_input_config}")
    print("   4. Check target platform configuration")
    print(f"      Platform: {target_platform_arm64}")

    if error_code == 'ValidationException':
        print("\n⚠️  Validation Error: Check input parameters")
        print("   - Ensure model S3 URI is valid")
        print("   - Verify DataInputConfig JSON format")
        print("   - Check target platform configuration")

    # Chain the original error so the service response stays in the traceback
    raise Exception(
        f"Failed to create compilation job for ARM64 CPU.\n"
        f"Job name: {compilation_job_name_arm64}\n"
        f"Error: {error_code} - {error_msg}"
    ) from e

except Exception as e:
    print(f"❌ Unexpected error: {str(e)}")
    print(f"   Job name: {compilation_job_name_arm64}")
    raise

### 7.4 Monitor Compilation Job

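Poll the compilation job status and display progress until completion. This cell monitors one compilation job per run: the ARM64 job if it exists, otherwise the x86_64 job, otherwise the Jetson Xavier job. If none of them exists, it prompts for a compilation job name.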

In [None]:
import time
from botocore.exceptions import ClientError

# Monitor a single SageMaker Neo compilation job until it reaches a terminal
# state (COMPLETED, FAILED or STOPPED), polling once per minute.
#
# Reads from earlier cells: sagemaker_client, region, and whichever of the
# compilation_job_name_* variables exist.
# On success, stores the compiled artifact URI in compiled_model_s3_uri and in
# the platform-specific compiled_model_<platform>_s3_uri variable.

# Determine which compilation job to monitor
# Priority: ARM64 > x86_64 > Xavier (monitor the most recently created)
# Only the first name found in this order is monitored per run.
compilation_job_to_monitor = None
platform_name = None

if 'compilation_job_name_arm64' in locals():
    compilation_job_to_monitor = compilation_job_name_arm64
    platform_name = "ARM64 CPU"
elif 'compilation_job_name_x86' in locals():
    compilation_job_to_monitor = compilation_job_name_x86
    platform_name = "x86_64 CPU"
elif 'compilation_job_name_xavier' in locals():
    compilation_job_to_monitor = compilation_job_name_xavier
    platform_name = "Jetson Xavier GPU"
else:
    print("⚠️  No compilation job found to monitor.")
    print("   Please run one of the compilation job cells above first.")
    print("   Or manually specify a compilation job name:\n")
    
    # Strip whitespace so a blank or space-only answer is rejected here rather
    # than surfacing later as an API validation error.
    compilation_job_to_monitor = input("Enter compilation job name: ").strip()
    platform_name = "Custom"
    
    if not compilation_job_to_monitor:
        raise ValueError("No compilation job name provided.")

print(f"🔍 Monitoring compilation job: {compilation_job_to_monitor}")
print(f"   Target platform: {platform_name}")
print("\nStatus updates every 60 seconds...\n")

# Initialize monitoring variables
previous_status = None
start_time = time.time()
dots_printed = 0
asterisks_printed = 0

while True:
    try:
        # Get compilation job status
        response = sagemaker_client.describe_compilation_job(
            CompilationJobName=compilation_job_to_monitor
        )
        
        current_status = response['CompilationJobStatus']
        
        # Display status change (timestamped with time elapsed since monitoring began)
        if current_status != previous_status:
            if previous_status is not None:
                print()  # New line after progress indicators
            
            elapsed_time = int(time.time() - start_time)
            elapsed_minutes = elapsed_time // 60
            elapsed_seconds = elapsed_time % 60
            
            print(f"[{elapsed_minutes:02d}:{elapsed_seconds:02d}] Status: {current_status}")
            previous_status = current_status
            dots_printed = 0
            asterisks_printed = 0
        
        # Check if compilation is complete
        if current_status in ['COMPLETED', 'FAILED', 'STOPPED']:
            print()  # New line after progress indicators
            
            # Display completion details
            if current_status == 'COMPLETED':
                print("\n✅ Compilation job completed successfully!\n")
                
                # Get compiled model artifact location
                compiled_model_s3_uri = response['ModelArtifacts']['S3ModelArtifacts']
                
                # Display compilation time (only when both timestamps are present)
                compilation_start = response.get('CompilationStartTime')
                compilation_end = response.get('CompilationEndTime')
                
                if compilation_start and compilation_end:
                    compilation_duration = (compilation_end - compilation_start).total_seconds()
                    compilation_minutes = int(compilation_duration // 60)
                    compilation_seconds = int(compilation_duration % 60)
                    
                    print("⏱️  Compilation Duration:")
                    if compilation_minutes > 0:
                        print(f"   {compilation_minutes}m {compilation_seconds}s")
                    else:
                        print(f"   {compilation_seconds}s")
                    print()
                
                # Display compiled model artifact location
                print(f"📦 Compiled Model Artifact:")
                print(f"   Platform: {platform_name}")
                print(f"   S3 URI: {compiled_model_s3_uri}")
                
                # Display target platform details
                if 'OutputConfig' in response:
                    output_config = response['OutputConfig']
                    if 'TargetPlatform' in output_config:
                        target = output_config['TargetPlatform']
                        print(f"\n🎯 Target Platform Details:")
                        print(f"   OS: {target.get('Os', 'N/A')}")
                        print(f"   Architecture: {target.get('Arch', 'N/A')}")
                        if 'Accelerator' in target:
                            print(f"   Accelerator: {target['Accelerator']}")
                
                # Store for later use (a "Custom" job is only kept in compiled_model_s3_uri)
                if platform_name == "Jetson Xavier GPU":
                    compiled_model_xavier_s3_uri = compiled_model_s3_uri
                elif platform_name == "x86_64 CPU":
                    compiled_model_x86_s3_uri = compiled_model_s3_uri
                elif platform_name == "ARM64 CPU":
                    compiled_model_arm64_s3_uri = compiled_model_s3_uri
                
            elif current_status == 'FAILED':
                print("\n❌ Compilation job failed!\n")
                
                # Get failure reason
                failure_reason = response.get('FailureReason', 'Unknown failure reason')
                print(f"🚫 Failure Reason:")
                print(f"   {failure_reason}\n")
                
                # Provide troubleshooting guidance
                print("🔍 Common Failure Causes:")
                print("   1. Model input shape mismatch")
                print("      - Verify DataInputConfig matches model's expected input")
                print("      - Check model metadata for correct input shape")
                print("   2. Framework version compatibility")
                print("      - Ensure PyTorch version is compatible with SageMaker Neo")
                print("      - Verify model was trained with supported PyTorch version")
                print("   3. Compiler options for target platform")
                print("      - Check CUDA version for GPU targets")
                print("      - Verify TensorRT version compatibility")
                print("      - Ensure compiler options match target hardware")
                print("   4. Model architecture not supported")
                print("      - Some YOLO operations may not be supported by Neo")
                print("      - Check Neo documentation for supported operations")
                
                # Display CloudWatch logs link; '/' is written as '$252F' in the
                # console URL fragment built below.
                print(f"\n🔗 View detailed logs in CloudWatch:")
                log_group = f"/aws/sagemaker/CompilationJobs"
                log_stream = compilation_job_to_monitor
                logs_url = (
                    f"https://console.aws.amazon.com/cloudwatch/home?region={region}"
                    f"#logsV2:log-groups/log-group/{log_group.replace('/', '$252F')}"
                    f"/log-events/{log_stream.replace('/', '$252F')}"
                )
                print(f"   {logs_url}")
                
            elif current_status == 'STOPPED':
                print("\n⚠️  Compilation job was stopped.\n")
                print("The compilation job was manually stopped before completion.")
            
            # Exit monitoring loop
            break
        
        # Display progress indicators based on status
        if current_status == 'STARTING':
            # Display asterisks for starting status
            print('*', end='', flush=True)
            asterisks_printed += 1
            
            # New line every 60 asterisks (60 minutes)
            if asterisks_printed >= 60:
                print()
                asterisks_printed = 0
                
        elif current_status == 'INPROGRESS':
            # Display dots for in-progress status
            print('.', end='', flush=True)
            dots_printed += 1
            
            # New line every 60 dots (60 minutes)
            if dots_printed >= 60:
                print()
                dots_printed = 0
        
        # Wait 60 seconds before next poll
        time.sleep(60)
        
    except ClientError as e:
        error_code = e.response['Error']['Code']
        error_msg = e.response['Error']['Message']
        print(f"\n❌ Error monitoring compilation job: {error_code}")
        print(f"   Job name: {compilation_job_to_monitor}")
        print(f"   Error: {error_msg}")
        
        if error_code == 'ValidationException':
            print(f"\n🔍 The compilation job may not exist or the name is incorrect.")
            print(f"   Verify job name: {compilation_job_to_monitor}")
            print(f"   Check SageMaker console for available compilation jobs.")
        
        # Chain the original ClientError so the traceback records it as the cause.
        raise Exception(
            f"Failed to monitor compilation job.\n"
            f"Job name: {compilation_job_to_monitor}\n"
            f"Error: {error_code} - {error_msg}"
        ) from e
    
    except KeyboardInterrupt:
        # Interrupting only stops this polling loop; no stop request is sent for the job.
        print("\n\n⚠️  Monitoring interrupted by user.")
        print(f"   Compilation job is still running: {compilation_job_to_monitor}")
        print(f"   You can resume monitoring by re-running this cell.")
        break
    
    except Exception as e:
        print(f"\n❌ Unexpected error: {str(e)}")
        print(f"   Compilation job: {compilation_job_to_monitor}")
        raise

## 8. Model Comparison

Compare YOLO results with Lookout for Vision baseline to validate the YOLO implementation and assess model performance differences.

### 8.1 Import Comparison Functions

Import the model comparison helper functions.

In [None]:
# Import comparison functions
from yolo_comparison import (
    load_test_images,
    run_yolo_inference,
    calculate_detection_metrics,
    calculate_segmentation_metrics,
    visualize_detections,
    visualize_segmentation,
    create_comparison_table
)

import matplotlib.pyplot as plt
import numpy as np

print("✅ Comparison functions imported successfully")

### 8.2 Load Test Images

Load a set of test images from the cookie dataset for model comparison.

In [None]:
# Directory that holds the images used for the comparison run
test_images_dir = 'cookie-dataset/dataset-files/training-images'

print(f"📂 Loading test images from: {test_images_dir}\n")

try:
    # load_test_images yields (filename, image) pairs; the result is kept in
    # test_images for the inference and visualization cells.
    test_images = load_test_images(test_images_dir)
    image_count = len(test_images)
    
    print(f"✅ Loaded {image_count} test images")
    
    # Preview at most the first five entries
    print("\n🖼️  Sample images:")
    for position, (image_name, image) in enumerate(test_images[:5], start=1):
        print(f"   {position}. {image_name} - Shape: {image.shape}")
    
    # Summarize whatever was not listed in the preview
    remaining = image_count - 5
    if remaining > 0:
        print(f"   ... and {remaining} more images")

except FileNotFoundError as e:
    # Print guidance, then re-raise so the cell still fails visibly
    print(f"❌ Error loading test images: {e}")
    print("\n🔍 Troubleshooting:")
    print(f"   - Verify the test images directory exists: {test_images_dir}")
    print("   - Ensure the dataset acquisition step completed successfully")
    raise

except ValueError as e:
    # Print guidance, then re-raise so the cell still fails visibly
    print(f"❌ Error: {e}")
    print("\n🔍 Troubleshooting:")
    print("   - Verify the directory contains valid image files")
    print("   - Supported formats: .jpg, .jpeg, .png, .bmp")
    raise

### 8.3 Run YOLO Inference

Perform inference using the trained YOLO model (detection or segmentation).

**Note:** You'll need to download the compiled model from S3 first, or use the trained model directly. The cell prompts for the local path to the model file (`.pt`).

In [None]:
import time

# Run YOLO inference over the loaded test images and report timing and
# detection statistics.
#
# Reads from earlier cells: test_images, run_yolo_inference, os, np.
# Produces: yolo_predictions (indexed by filename; each entry is read via its
# 'boxes' and 'confidences' keys below) and avg_inference_time_ms.

# Define model path
# Option 1: Use the trained model from training (before compilation)
# model_path = 'model-preparation/extracted/best.pt'

# Option 2: Specify a local model path
# Pasted paths often carry surrounding whitespace or quotes; remove them so
# the existence check below tests the real path.
model_path = input("Enter path to YOLO model file (.pt): ").strip().strip('"\'')

# Verify model exists
if not os.path.exists(model_path):
    print(f"❌ Model file not found: {model_path}")
    print("\n🔍 Options:")
    print("   1. Download the trained model from S3")
    print("   2. Use a pre-trained YOLO model")
    print("   3. Extract the model from training artifacts")
    raise FileNotFoundError(
        f"Model file not found: {model_path}\n"
        f"Expected location: {os.path.abspath(model_path)}"
    )

# Fail early with a clear message: an empty image set would otherwise divide
# by zero in the per-image average and be reported as an inference failure.
if not test_images:
    raise ValueError(
        "No test images loaded. Run the 'Load Test Images' cell before running inference."
    )

# Define task type (detect or segment)
task_type = 'detect'  # Change to 'segment' for segmentation models
conf_threshold = 0.25  # Confidence threshold for predictions

print(f"\n🤖 Running YOLO inference...")
print(f"   Model: {model_path}")
print(f"   Task: {task_type}")
print(f"   Confidence threshold: {conf_threshold}")
print(f"   Test images: {len(test_images)}\n")

try:
    # Run inference; perf_counter is monotonic, so the measured duration is
    # not affected by system clock adjustments.
    start_time = time.perf_counter()
    
    yolo_predictions = run_yolo_inference(
        model_path=model_path,
        images=test_images,
        task=task_type,
        conf_threshold=conf_threshold
    )
    
    inference_time = time.perf_counter() - start_time
    avg_inference_time_ms = (inference_time / len(test_images)) * 1000
    
    print(f"✅ Inference completed successfully")
    print(f"   Total time: {inference_time:.2f} seconds")
    print(f"   Average time per image: {avg_inference_time_ms:.2f} ms")
    
    # Count detections
    total_detections = sum(len(pred['boxes']) for pred in yolo_predictions.values())
    images_with_detections = sum(1 for pred in yolo_predictions.values() if len(pred['boxes']) > 0)
    
    print(f"\n📊 Detection Statistics:")
    print(f"   Total detections: {total_detections}")
    print(f"   Images with detections: {images_with_detections}/{len(test_images)}")
    print(f"   Images without detections: {len(test_images) - images_with_detections}/{len(test_images)}")
    
    # Display sample predictions (first three entries only)
    print("\n🔍 Sample Predictions:")
    for i, (filename, pred) in enumerate(list(yolo_predictions.items())[:3]):
        num_boxes = len(pred['boxes'])
        if num_boxes > 0:
            avg_conf = np.mean(pred['confidences'])
            print(f"   {filename}: {num_boxes} detection(s), avg confidence: {avg_conf:.3f}")
        else:
            print(f"   {filename}: No detections")
    
except FileNotFoundError as e:
    print(f"❌ Error: {str(e)}")
    raise

except ImportError as e:
    print(f"❌ Error: {str(e)}")
    print("\n🔍 Install Ultralytics YOLO:")
    print("   pip install ultralytics")
    raise

except Exception as e:
    print(f"❌ Inference failed: {str(e)}")
    print("\n🔍 Troubleshooting:")
    print("   - Verify the model file is a valid YOLO model (.pt)")
    print("   - Check that the task type matches the model (detect vs segment)")
    print("   - Ensure sufficient memory is available")
    raise

### 8.4 Load Lookout for Vision Reference Results (Optional)

Load reference results from Lookout for Vision if available for comparison.

**Note:** This step is optional. Skip if you don't have LFV reference results.

In [None]:
# Placeholder for Lookout for Vision (LFV) reference results.
# Nothing is loaded here: both variables keep their "not available" values
# until loading is implemented for your specific LFV output format.

lfv_metrics = None
lfv_results_available = False

# Explain the situation and how to add an LFV comparison
print("\n".join((
    "ℹ️  Lookout for Vision reference results not available",
    "   Comparison will show YOLO metrics only",
    "\n📝 To add LFV comparison:",
    "   1. Export LFV inference results",
    "   2. Load results in the same format as YOLO predictions",
    "   3. Calculate metrics using calculate_detection_metrics()",
)))

# Example structure for LFV metrics (if you have them):
# lfv_metrics = {
#     'precision': 0.89,
#     'recall': 0.85,
#     'f1_score': 0.87,
#     'avg_inference_time_ms': 120.5
# }

### 8.5 Calculate and Display Metrics

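Calculate detection metrics for YOLO predictions and display the results. The ground truth in this example is built from the manifest's image-level anomaly labels with a placeholder bounding box, so the reported metrics are illustrative only.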

In [None]:
# Score the YOLO predictions against ground truth derived from the manifest.
# The ground truth is a stand-in: every anomalous image gets one fixed
# placeholder box, so the numbers only demonstrate the workflow.

print("📊 Calculating metrics...\n")

# Read the manifest records (production use should rely on a separate test set)
from yolo_format_converter import read_manifest

manifest_path = 'cookie-dataset/dataset-files/manifests/output.manifest'
manifest_records = read_manifest(manifest_path)

# Build one ground-truth entry per manifest record, keyed by image filename
# and shaped like a prediction entry (boxes / classes / confidences).
ground_truth = {}

for entry in manifest_records:
    image_name = os.path.basename(entry['source-ref'])
    is_anomaly = entry['anomaly-label'] == 1
    
    # Anomalous images receive a single placeholder box; normal images get
    # empty lists. Replace with real bounding-box annotations in production.
    ground_truth[image_name] = {
        'boxes': [[100, 100, 200, 200]] if is_anomaly else [],
        'classes': [1] if is_anomaly else [],
        'confidences': [1.0] if is_anomaly else [],
    }

# Compare predictions with ground truth at an IoU threshold of 0.5
yolo_metrics = calculate_detection_metrics(
    predictions=yolo_predictions,
    ground_truth=ground_truth,
    iou_threshold=0.5,
)

# Attach the timing measured in the inference cell
yolo_metrics['avg_inference_time_ms'] = avg_inference_time_ms

# Report the results
print("✅ Metrics calculated successfully\n")
print("📊 YOLO Model Performance:")
print(f"   Precision: {yolo_metrics['precision']:.4f}")
print(f"   Recall: {yolo_metrics['recall']:.4f}")
print(f"   F1 Score: {yolo_metrics['f1_score']:.4f}")
print(f"   True Positives: {yolo_metrics['true_positives']}")
print(f"   False Positives: {yolo_metrics['false_positives']}")
print(f"   False Negatives: {yolo_metrics['false_negatives']}")
print(f"   Avg Inference Time: {yolo_metrics['avg_inference_time_ms']:.2f} ms")

print("\n⚠️  Note: Ground truth is simplified for this example.")
print("   For accurate metrics, use a separate test set with proper annotations.")

### 8.6 Create Side-by-Side Visualizations

Visualize YOLO detections on up to three test images that have at least one detection.

In [None]:
# Render YOLO detections for a small sample of test images that have at least
# one detection.
# Reads from earlier cells: yolo_predictions, test_images,
# visualize_detections, plt.
print("🖼️  Creating visualizations...\n")

# Upper bound on how many annotated images are rendered
max_visualizations = 3

# Index the images by filename once, so each prediction is matched with a dict
# lookup instead of rescanning test_images. setdefault keeps the first image
# seen for a filename, which is what a first-match linear search would return.
images_by_filename = {}
for img_filename, img in test_images:
    images_by_filename.setdefault(img_filename, img)

# Select images with detections for visualization
images_to_visualize = []
for filename, pred in yolo_predictions.items():
    if len(pred['boxes']) == 0:
        continue
    
    # Predictions without a matching loaded image are skipped
    if filename in images_by_filename:
        images_to_visualize.append((filename, images_by_filename[filename], pred))
    
    # Stop once the rendering limit is reached
    if len(images_to_visualize) >= max_visualizations:
        break

if not images_to_visualize:
    print("⚠️  No detections found in test images")
    print("   Try:")
    print("   - Lowering the confidence threshold")
    print("   - Using a different model")
    print("   - Checking if test images contain defects")
else:
    print(f"✅ Visualizing {len(images_to_visualize)} images with detections\n")
    
    # Create visualizations
    for filename, img, pred in images_to_visualize:
        fig = visualize_detections(
            image=img,
            boxes=pred['boxes'],
            classes=pred['classes'],
            confidences=pred['confidences'],
            class_names=['normal', 'anomaly'],
            title=f"YOLO Detections - {filename}"
        )
        plt.show()
        
    print("\n✅ Visualizations complete")

### 8.7 Display Comparison Table

Create and display a comparison table showing YOLO metrics (and LFV metrics if available).

In [None]:
# Print the YOLO (and, when present, LFV) comparison table, followed by a run
# summary and a verdict based on the F1 score.
print("📊 Model Comparison Table\n")

# lfv_metrics is None when no LFV reference results were loaded
comparison_table = create_comparison_table(yolo_metrics=yolo_metrics, lfv_metrics=lfv_metrics)
print(comparison_table)

# Run summary
print("\n📝 Summary:")
print(f"   YOLO Model: {hyperparameters['model-size']}")
print(f"   Task: {hyperparameters['task']}")
print(f"   Test Images: {len(test_images)}")
print(f"   Detections: {sum(len(pred['boxes']) for pred in yolo_predictions.values())}")

# Choose the verdict text from the F1 score: >= 0.8 good, >= 0.6 moderate,
# anything else needs improvement.
f1 = yolo_metrics['f1_score']
if f1 >= 0.8:
    verdict_lines = ["\n✅ Model performance is good (F1 ≥ 0.8)"]
elif f1 >= 0.6:
    verdict_lines = [
        "\n⚠️  Model performance is moderate (0.6 ≤ F1 < 0.8)",
        "   Consider:",
        "   - Training for more epochs",
        "   - Using a larger model (yolov8s or yolov8m)",
        "   - Adjusting confidence threshold",
    ]
else:
    verdict_lines = [
        "\n❌ Model performance needs improvement (F1 < 0.6)",
        "   Recommendations:",
        "   - Increase training epochs",
        "   - Use a larger model",
        "   - Review training data quality",
        "   - Adjust hyperparameters",
    ]
print("\n".join(verdict_lines))

print("\n🎉 Model comparison complete!")