# Amazon SageMaker JumpStart Object Detection for Bird Species

## Introduction

Object detection is the process of identifying and localizing objects in an image. A typical object detection solution takes an image as input and provides a bounding box on the image where an object of interest is found. It also identifies what type of object the box encapsulates.

This notebook is an end-to-end example showing how Amazon SageMaker JumpStart can be used to train an object detection model on a custom dataset. We use the [Caltech Birds (CUB 200 2011)](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset, which contains images of 200 bird species with bounding box annotations.

### What is JumpStart?

JumpStart provides pre-trained models that you can fine-tune on your own data without writing training scripts. This notebook uses **Faster R-CNN** with a ResNet-50 backbone, a widely used two-stage object detection architecture known for its strong accuracy.

### What You'll Learn

- How to prepare custom data in COCO format for object detection
- How to train a JumpStart model on your own dataset
- How to deploy and test an object detection model
- How to visualize detection results

In [None]:
import os
import json
import time
from sagemaker.core.helper.session_helper import Session, get_execution_role
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import InputData
from sagemaker.core.jumpstart.configs import JumpStartConfig

sagemaker_session = Session()
bucket = sagemaker_session.default_bucket()
role = get_execution_role()

print(f'Bucket: {bucket}')
print(f'Role: {role}')

## 1. Download and Prepare CUB Dataset

The [Caltech Birds (CUB 200 2011)](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) dataset contains 11,788 images across 200 bird species. Each species comes with around 60 images, with a typical size of about 350 pixels by 500 pixels. Bounding boxes are provided for each bird in the image.

For this demonstration, we'll use a subset of 5 bird species to keep training time manageable. The same approach works for all 200 species.
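The notebook names the selected categories `bird_class_1` to `bird_class_5`. Once the dataset has been downloaded and extracted, you can see which species those are, and how many images each has, by reading CUB's `classes.txt` and `image_class_labels.txt`. The following is a minimal sketch, assuming the standard `CUB_200_2011` file layout; the helper name `summarize_species` is hypothetical.

```python
import os
from collections import Counter

def summarize_species(base_dir, num_classes=5):
    """Return (class_id, species_name, image_count) for CUB classes 1..num_classes."""
    # classes.txt lines look like "1 001.Black_footed_Albatross"
    names = {}
    with open(os.path.join(base_dir, 'classes.txt')) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines
            class_id, folder = line.strip().split(' ', 1)
            names[int(class_id)] = folder.split('.', 1)[1]  # drop the "001." prefix

    # image_class_labels.txt lines look like "<image_id> <class_id>"
    counts = Counter()
    with open(os.path.join(base_dir, 'image_class_labels.txt')) as f:
        for line in f:
            if not line.strip():
                continue
            _, class_id = line.split()
            counts[int(class_id)] += 1

    return [(c, names[c], counts[c]) for c in range(1, num_classes + 1)]
```

For example, `summarize_species('CUB_200_2011')` returns one `(class_id, species_name, image_count)` tuple per selected class.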

In [None]:
# Download CUB-200-2011 dataset
!wget -q https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
!tar -xzf CUB_200_2011.tgz
print('Dataset downloaded and extracted')

## 2. Create COCO Format Annotations

JumpStart object detection models expect data in COCO (Common Objects in Context) format, which is a standard format for object detection datasets.

### COCO Format Structure

A COCO dataset consists of:
- **images**: List of image metadata (id, filename, width, height)
- **annotations**: List of bounding boxes with category labels
- **categories**: List of object categories (id, name)
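To make the cross-references concrete, here is a tiny, made-up dataset with one image and one box (file name, size and coordinates are illustrative, not taken from CUB). Each annotation's `image_id` must match an entry in `images`, and its `category_id` must match an entry in `categories`. The box uses the corner layout `[x_min, y_min, x_max, y_max]` that this notebook writes.

```python
# A tiny, made-up dataset with one image containing one bird.
tiny_coco = {
    'images': [
        {'id': 1, 'file_name': 'example_bird.jpg', 'width': 500, 'height': 350},
    ],
    'annotations': [
        {
            'id': 1,
            'image_id': 1,                       # refers to images[0]['id']
            'category_id': 1,                    # refers to categories[0]['id']
            'bbox': [60.0, 27.0, 385.0, 292.0],  # [x_min, y_min, x_max, y_max]
            'area': 325.0 * 265.0,               # (x_max - x_min) * (y_max - y_min)
            'iscrowd': 0,
        },
    ],
    'categories': [
        {'id': 1, 'name': 'bird_class_1'},
    ],
}

# Check the cross-references between the three lists.
image_ids = {img['id'] for img in tiny_coco['images']}
category_ids = {cat['id'] for cat in tiny_coco['categories']}
for ann in tiny_coco['annotations']:
    assert ann['image_id'] in image_ids
    assert ann['category_id'] in category_ids
```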

### Important Format Requirements

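1. **Category IDs start at 1**: ID 0 is reserved for the background class (PyTorch/torchvision detection models such as Faster R-CNN treat label 0 as background), so your object categories should be 1, 2, 3, etc.
2. **Bounding box format**: Use corner coordinates `[x_min, y_min, x_max, y_max]`. Note that this differs from the standard COCO convention of `[x, y, width, height]`, which is also the layout of CUB's `bounding_boxes.txt`, so every box must be converted.
3. **Coordinate validation**: Ensure all bounding boxes lie within the image bounds and have positive width and height.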
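Rules 2 and 3 can be sketched for a single box as a small standalone helper, starting from CUB's `[x, y, width, height]` layout. The name `cub_box_to_corners` is hypothetical; it normalizes negative sizes, clamps both corners to the image, and returns `None` for boxes with no remaining area.

```python
def cub_box_to_corners(x, y, w, h, width, height):
    """Convert an [x, y, width, height] box to clamped
    [x_min, y_min, x_max, y_max] corners, or return None if no area remains."""
    # Normalize negative sizes so (x, y) is the top-left corner.
    if w < 0:
        x, w = x + w, -w
    if h < 0:
        y, h = y + h, -h
    # Clamp both corners to the image; clamping only (x, y) would leave
    # the right/bottom edge in the wrong place when x or y is negative.
    x_min, y_min = max(0.0, x), max(0.0, y)
    x_max, y_max = min(float(width), x + w), min(float(height), y + h)
    # Reject boxes that are empty after clamping.
    if x_max <= x_min or y_max <= y_min:
        return None
    return [x_min, y_min, x_max, y_max]
```

For instance, a box starting 5 pixels left of the image, `cub_box_to_corners(-5, 10, 100, 50, 500, 350)`, keeps its true right edge at x = 95 rather than extending it to 100.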

In [None]:
from PIL import Image

def create_coco_dataset(base_dir, num_classes=5):
    """Create COCO-style annotations with 1-indexed categories and
    corner-format boxes for the first `num_classes` CUB species."""
    # images.txt: "<image_id> <species_folder>/<file_name>"
    image_paths = {}
    with open(os.path.join(base_dir, 'images.txt')) as f:
        for line in f:
            img_id, img_path = line.strip().split(' ', 1)
            image_paths[img_id] = img_path

    # bounding_boxes.txt: "<image_id> <x> <y> <width> <height>"
    bboxes_dict = {}
    with open(os.path.join(base_dir, 'bounding_boxes.txt')) as f:
        for line in f:
            parts = line.strip().split()
            bboxes_dict[parts[0]] = [float(v) for v in parts[1:5]]

    # image_class_labels.txt: "<image_id> <class_id>" (class IDs are already 1-indexed)
    labels_dict = {}
    with open(os.path.join(base_dir, 'image_class_labels.txt')) as f:
        for line in f:
            img_id, class_id = line.strip().split()
            labels_dict[img_id] = int(class_id)

    valid_classes = set(range(1, num_classes + 1))  # 1-indexed: {1, 2, 3, 4, 5}

    # CUB's train/test split is ignored: all images go into one dataset,
    # and the JumpStart training script creates its own validation split.
    coco = {
        'images': [],
        'annotations': [],
        'categories': [{'id': i, 'name': f'bird_class_{i}'} for i in sorted(valid_classes)]
    }

    ann_id = 1  # COCO IDs conventionally start at 1; pycocotools' evaluation treats ID 0 as "unmatched"
    skipped = 0

    for img_id in sorted(image_paths, key=int):
        class_id = labels_dict[img_id]
        if class_id not in valid_classes:
            continue

        # Read the image size; the context manager closes the file handle
        with Image.open(os.path.join(base_dir, 'images', image_paths[img_id])) as img:
            width, height = img.size

        x, y, w, h = bboxes_dict[img_id]

        # Normalize negative sizes so (x, y) is the top-left corner
        if w < 0:
            x, w = x + w, -w
        if h < 0:
            y, h = y + h, -h

        # Clamp both corners to the image bounds
        x_min, y_min = max(0.0, x), max(0.0, y)
        x_max, y_max = min(float(width), x + w), min(float(height), y + h)

        # Skip boxes with no remaining area
        if x_max <= x_min or y_max <= y_min:
            print(f'Skipping image {img_id} with invalid bbox: [{x}, {y}, {w}, {h}]')
            skipped += 1
            continue

        # Add image and annotation
        coco['images'].append({
            'id': int(img_id),
            'file_name': os.path.basename(image_paths[img_id]),
            'width': width,
            'height': height
        })

        coco['annotations'].append({
            'id': ann_id,
            'image_id': int(img_id),
            'category_id': class_id,
            'bbox': [x_min, y_min, x_max, y_max],  # corner format [x_min, y_min, x_max, y_max] for PyTorch
            'area': (x_max - x_min) * (y_max - y_min),
            'iscrowd': 0
        })
        ann_id += 1

    os.makedirs('annotations', exist_ok=True)
    with open('annotations/combined.json', 'w') as f:
        json.dump(coco, f)

    print(f'Total: {len(coco["images"])} images, {len(coco["annotations"])} annotations, {skipped} skipped')
    return coco

coco_data = create_coco_dataset('CUB_200_2011', num_classes=5)

## 3. Prepare Flat Image Structure

The CUB dataset organizes images in nested folders by species (e.g., `001.Black_footed_Albatross/image1.jpg`). However, JumpStart expects all images in a single flat directory.

We'll copy the images into a single flat directory. This only works if every image filename is unique, because the COCO annotations file references each image by its base filename alone.
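Before flattening, it is worth confirming that no two species folders contain a file with the same name, since a later copy would silently overwrite an earlier one. A minimal sketch (the helper name `find_filename_collisions` is hypothetical):

```python
import os
from collections import Counter

def find_filename_collisions(relative_paths):
    """Return base filenames that occur more than once across subfolders.

    Any name returned here would be overwritten when the images are
    copied into a single flat directory.
    """
    counts = Counter(os.path.basename(p) for p in relative_paths)
    return sorted(name for name, n in counts.items() if n > 1)
```

For example, `find_filename_collisions([line.split(' ', 1)[1].strip() for line in open('CUB_200_2011/images.txt')])` returns an empty list when flattening is safe.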

In [None]:
import shutil

# Recreate the flat image directory from scratch
if os.path.exists('flat_images'):
    shutil.rmtree('flat_images')
os.makedirs('flat_images')

# Only copy images that appear in the annotations (the selected species),
# so we don't upload thousands of unused images to S3
needed = {img['file_name'] for img in coco_data['images']}
assert len(needed) == len(coco_data['images']), 'Duplicate file names would collide in a flat directory'

print('Creating flat image structure...')
for root, dirs, files in os.walk('CUB_200_2011/images'):
    for file in files:
        if file in needed:
            shutil.copy2(os.path.join(root, file), os.path.join('flat_images', file))

print(f'Copied {len(os.listdir("flat_images"))} images to flat directory')

## 4. Upload to S3

SageMaker training jobs read data from S3. We need to upload our prepared dataset in the following structure:

```
s3://bucket/prefix/train/
├── images/
│   ├── image1.jpg
│   ├── image2.jpg
│   └── ...
└── annotations.json
```

**Important**: JumpStart expects all data in a single `training` channel. The training script will automatically split it 80% for training and 20% for validation. Do not create separate train and validation folders.
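The notebook does not show how the JumpStart script performs that split. The following is only an illustration, assuming the split is made per image so that every box of an image ends up in the same subset; the function name, the seed and the shuffling method are assumptions, not the JumpStart implementation.

```python
import random

def split_coco_by_image(coco, train_fraction=0.8, seed=0):
    """Split a COCO-style dict into train/validation dicts at the image level.

    Illustrative only: this is not the JumpStart training script's code.
    """
    image_ids = sorted(img['id'] for img in coco['images'])
    random.Random(seed).shuffle(image_ids)  # deterministic shuffle for a given seed
    n_train = int(len(image_ids) * train_fraction)
    train_ids = set(image_ids[:n_train])

    def subset(keep):
        # Keep an image and all of its annotations together in the same split.
        return {
            'images': [img for img in coco['images'] if keep(img['id'])],
            'annotations': [a for a in coco['annotations'] if keep(a['image_id'])],
            'categories': coco['categories'],
        }

    return subset(lambda i: i in train_ids), subset(lambda i: i not in train_ids)
```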

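We use a timestamped prefix so that each run writes to a fresh S3 location. `aws s3 sync` does not delete existing objects, so reusing a prefix could mix stale images or annotations from earlier runs into the training data.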

In [None]:
# Use a timestamped prefix so data from earlier runs is never mixed in
prefix = f'jumpstart-od-birds-{int(time.time())}'
train_s3 = f's3://{bucket}/{prefix}/train'

print(f'Uploading to: {train_s3}')

# Upload images
!aws s3 sync flat_images {train_s3}/images/ --quiet

# Upload combined annotations
!aws s3 cp annotations/combined.json {train_s3}/annotations.json

print('Upload complete!')
print(f'\nData location: {train_s3}')

## 5. Verify Upload

Before starting a training job, it's good practice to verify that:
1. Data was uploaded successfully to S3
2. The annotations file is valid JSON
3. The number of images and annotations match expectations

This helps catch issues early before spending time and money on training.
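Counting entries catches missing uploads but not malformed boxes. A stricter check could look like the following minimal sketch (hypothetical helper, assuming the corner-format boxes and 1-indexed categories described in Section 2). It returns a list of problems, so an empty list means every check passed.

```python
def validate_coco(coco):
    """Return a list of problems in a COCO-style dict that uses corner-format
    boxes [x_min, y_min, x_max, y_max]. An empty list means all checks passed."""
    problems = []
    images = {img['id']: img for img in coco['images']}
    category_ids = {cat['id'] for cat in coco['categories']}

    if any(cid < 1 for cid in category_ids):
        problems.append('category IDs must start at 1 (0 is reserved for background)')

    for ann in coco['annotations']:
        img = images.get(ann['image_id'])
        if img is None:
            problems.append(f"annotation {ann['id']}: unknown image_id {ann['image_id']}")
            continue
        if ann['category_id'] not in category_ids:
            problems.append(f"annotation {ann['id']}: unknown category_id {ann['category_id']}")
        x_min, y_min, x_max, y_max = ann['bbox']
        # The box must have positive size and lie inside its image.
        if not (0 <= x_min < x_max <= img['width'] and 0 <= y_min < y_max <= img['height']):
            problems.append(f"annotation {ann['id']}: box {ann['bbox']} is empty or outside the image")
    return problems
```

For example, running `validate_coco(coco_data)` before uploading should return `[]`.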

In [None]:
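import boto3

# Verify S3 data
!aws s3 ls {train_s3}/

# Download the uploaded annotations and check the counts against what we generated locally.
# (Done in Python: IPython would try to expand the {...} braces of an inline shell one-liner.)
s3_client = boto3.client('s3', region_name=sagemaker_session.boto_region_name)
response = s3_client.get_object(Bucket=bucket, Key=f'{prefix}/train/annotations.json')
uploaded = json.loads(response['Body'].read())
print(f"Images: {len(uploaded['images'])}, Annotations: {len(uploaded['annotations'])}")
assert len(uploaded['images']) == len(coco_data['images'])
assert len(uploaded['annotations']) == len(coco_data['annotations'])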

## 6. Train the Model

Now we're ready to train our object detection model using JumpStart.

### Model Selection

We use `pytorch-od1-fasterrcnn-resnet50-fpn`, which is a Faster R-CNN model with a ResNet-50 backbone and Feature Pyramid Network (FPN). This architecture is well-suited for detecting objects of various sizes.

### Training Configuration

The `ModelTrainer.from_jumpstart_config()` method automatically configures:
- The training container image
- Default hyperparameters optimized for the model
- Instance type for training

We only need to provide:
- The model ID
- The S3 location of our training data
- A base name for the training job

The training process will:
1. Load the pre-trained Faster R-CNN model
2. Replace the final detection layer to match our 5 bird categories
3. Fine-tune the model on our bird images
4. Save the trained model artifacts to S3

In [None]:
# Select JumpStart model
model_id = 'pytorch-od1-fasterrcnn-resnet50-fpn'

js_config = JumpStartConfig(model_id=model_id)

# Create input data config - ONLY training channel
train_input = InputData(
    channel_name='training',
    data_source=train_s3
)

# Create trainer
trainer = ModelTrainer.from_jumpstart_config(
    jumpstart_config=js_config,
    base_job_name='jumpstart-od-birds',
    sagemaker_session=sagemaker_session,
    input_data_config=[train_input]  # Only training channel
)

print('Trainer created')
print(f'Hyperparameters: {trainer.hyperparameters}')
print(f'Input data: {train_input}')

In [None]:
# Start training
trainer.train()

In [None]:
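import boto3

# Inspect the training job's status and the tail of its CloudWatch logs.
# This is mainly useful for diagnosing a failed training job.
training_job = trainer._latest_training_job
training_job.refresh()  # fetch the latest status from SageMaker
print(f"Training job name: {training_job.training_job_name}")
print(f"Training job status: {training_job.training_job_status}")
print(f"\nFailure reason: {training_job.failure_reason}")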

# Get CloudWatch logs
logs_client = boto3.client('logs', region_name=sagemaker_session.boto_region_name)

log_group = '/aws/sagemaker/TrainingJobs'

try:
    # List log streams for this job
    streams = logs_client.describe_log_streams(
        logGroupName=log_group,
        logStreamNamePrefix=training_job.training_job_name
    )
    
    if streams['logStreams']:
        stream_name = streams['logStreams'][0]['logStreamName']
        print(f"\nLog stream: {stream_name}")
        
        # Get last 100 log events
        events = logs_client.get_log_events(
            logGroupName=log_group,
            logStreamName=stream_name,
            limit=100,
            startFromHead=False
        )
        
        print("\n=== Last 100 log lines ===")
        for event in events['events']:
            print(event['message'])
except Exception as e:
    print(f"Could not fetch logs: {e}")

## 7. Deploy the Model

After training completes, we need to deploy the model to an endpoint for real-time inference.

### Deployment Process

Deployment involves three steps:

1. **Create a Model**: Defines the model artifacts and inference container image
2. **Create an Endpoint Configuration**: Specifies the instance type and count
3. **Create an Endpoint**: Deploys the model to a running instance
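These three steps map one-to-one onto SageMaker API operations. As an illustration, the following sketch builds the keyword arguments for boto3's `create_model`, `create_endpoint_config` and `create_endpoint` calls without contacting AWS; the helper name and naming scheme are hypothetical. A single timestamp keeps the three names consistent, and each request refers to the previous one by name.

```python
import time

def build_deployment_requests(base_name, role_arn, image_uri, model_data_url,
                              instance_type='ml.g4dn.xlarge', instance_count=1):
    """Build keyword arguments for CreateModel, CreateEndpointConfig and
    CreateEndpoint, linked to each other by name."""
    suffix = str(int(time.time()))  # one timestamp so all three names match
    model_name = f'{base_name}-{suffix}'
    config_name = f'{base_name}-config-{suffix}'
    endpoint_name = f'{base_name}-{suffix}'
    return {
        # 1. Model: artifacts plus inference container image
        'create_model': {
            'ModelName': model_name,
            'ExecutionRoleArn': role_arn,
            'PrimaryContainer': {'Image': image_uri, 'ModelDataUrl': model_data_url},
        },
        # 2. Endpoint configuration: which model, on which instances
        'create_endpoint_config': {
            'EndpointConfigName': config_name,
            'ProductionVariants': [{
                'VariantName': 'AllTraffic',
                'ModelName': model_name,
                'InstanceType': instance_type,
                'InitialInstanceCount': instance_count,
            }],
        },
        # 3. Endpoint: provisions instances using the configuration
        'create_endpoint': {
            'EndpointName': endpoint_name,
            'EndpointConfigName': config_name,
        },
    }
```

With a client `sm = boto3.client('sagemaker')`, the requests could then be sent in order with `for operation, kwargs in reqs.items(): getattr(sm, operation)(**kwargs)`.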

### Instance Selection

We use `ml.g4dn.xlarge`, a GPU instance. Faster R-CNN is computationally heavy, so GPU acceleration keeps per-image latency low; CPU instances are cheaper but noticeably slower for a model of this size.

The endpoint will remain running until you delete it, so remember to clean up when done to avoid charges.

In [None]:
# Get model artifacts from training job
training_job = trainer._latest_training_job
training_job.refresh()
model_data = training_job.model_artifacts.s3_model_artifacts

print(f'Model artifacts: {model_data}')

In [None]:
from sagemaker.core import image_uris

image = image_uris.retrieve(
    framework='pytorch',
    region=sagemaker_session.boto_region_name,
    image_scope='inference',
    instance_type='ml.g4dn.xlarge',
    version='1.8.1',
    py_version='py3',
)
print(image)

In [None]:
from sagemaker.core.resources import Model, EndpointConfig, Endpoint

# Create model
model = Model.create(
    model_name=f'jumpstart-od-birds-{int(time.time())}',
    execution_role_arn=role,
    primary_container={
        'image': image,
        'model_data_url': model_data
    }
)

# Create endpoint config
endpoint_config = EndpointConfig.create(
    endpoint_config_name=f'jumpstart-od-birds-config-{int(time.time())}',
    production_variants=[{
        'variant_name': 'AllTraffic',
        'model_name': model.model_name,
        'instance_type': 'ml.g4dn.xlarge',
        'initial_instance_count': 1
    }]
)

# Create endpoint
endpoint = Endpoint.create(
    endpoint_name=f'jumpstart-od-birds-{int(time.time())}',
    endpoint_config_name=endpoint_config.endpoint_config_name
)

endpoint.wait_for_status('InService')
print(f'Endpoint: {endpoint.endpoint_name}')

## 8. Test the Model

With our model deployed, we can test it by sending images and examining the predictions.

### How Inference Works

1. **Load an image** from the test set as raw bytes
2. **Send to endpoint** using the SageMaker Runtime client
3. **Parse the response** which contains detected objects

### Understanding the Response

The model returns three parallel arrays:
- `normalized_boxes`: Bounding box coordinates in [0, 1] range as `[x_min, y_min, x_max, y_max]`
- `classes`: Category IDs (1-5 for our bird species)
- `scores`: Confidence scores (0-1) for each detection

Each index represents one detected object. For example:
- `normalized_boxes[0]` = `[0.2, 0.3, 0.8, 0.9]` (bird occupies 20-80% width, 30-90% height)
- `classes[0]` = `1` (bird species 1)
- `scores[0]` = `0.95` (95% confident)
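Turning the parallel arrays into pixel-space detections means zipping them together and scaling each normalized coordinate by the image width or height. A minimal sketch (hypothetical helper, assuming the response keys described above); on a 500-by-350 image, the example box `[0.2, 0.3, 0.8, 0.9]` becomes `[100, 105, 400, 315]`.

```python
def to_pixel_detections(result, width, height, thresh=0.5):
    """Convert parallel `normalized_boxes`, `classes` and `scores` lists into
    (class_id, score, [x_min, y_min, x_max, y_max]) tuples in pixels,
    keeping only detections with score >= thresh."""
    detections = []
    for box, cls, score in zip(result['normalized_boxes'], result['classes'], result['scores']):
        if score < thresh:
            continue
        x_min, y_min, x_max, y_max = box
        # x coordinates scale with width, y coordinates with height
        pixel_box = [round(x_min * width), round(y_min * height),
                     round(x_max * width), round(y_max * height)]
        detections.append((int(cls), score, pixel_box))
    return detections
```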

Let's test with a sample bird image from our dataset.

In [None]:
def visualize_detection(img_file, dets, classes=None, thresh=0.5):
    """
    Visualize detections with bounding boxes.
    
    Parameters:
    - img_file: path to image
    - dets: detections as [[class, score, x_min, y_min, x_max, y_max], ...]
    - classes: list of class names
    - thresh: confidence threshold
    """
    import random
    import matplotlib.pyplot as plt
    import matplotlib.image as mpimg
    
    img = mpimg.imread(img_file)
    plt.imshow(img)
    height, width = img.shape[:2]
    colors = {}
    num_detections = 0
    
    for det in dets:
        klass, score, x0, y0, x1, y1 = det
        if score < thresh:
            continue
        num_detections += 1
        cls_id = int(klass)
        
        if cls_id not in colors:
            colors[cls_id] = (random.random(), random.random(), random.random())
        
        xmin = int(x0 * width)
        ymin = int(y0 * height)
        xmax = int(x1 * width)
        ymax = int(y1 * height)
        
        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, 
                           fill=False, edgecolor=colors[cls_id], linewidth=3.5)
        plt.gca().add_patch(rect)
        
        class_name = classes[cls_id - 1] if classes and 1 <= cls_id <= len(classes) else str(cls_id)
        print(f'{class_name}, {score:.3f}')
        plt.gca().text(xmin, ymin - 2, f'{class_name} {score:.3f}',
                      bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                      fontsize=12, color='white')
    
    print(f'Number of detections: {num_detections}')
    plt.show()

def predict_and_visualize(img_file, endpoint_name, thresh=0.5):
    """Run inference and visualize results."""
    runtime_client = sagemaker_session.sagemaker_runtime_client
    
    with open(img_file, 'rb') as f:
        img_bytes = f.read()
    
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/x-image',
        Body=img_bytes
    )
    
    result = json.loads(response['Body'].read())
    
    # Convert to detection format: [class, score, x_min, y_min, x_max, y_max].
    # Thresholding is left to visualize_detection so that the `thresh` argument takes effect.
    dets = []
    for bbox, cls, score in zip(result['normalized_boxes'], result['classes'], result['scores']):
        x_min, y_min, x_max, y_max = bbox
        dets.append([cls, score, x_min, y_min, x_max, y_max])
    
    class_names = [f'bird_class_{i}' for i in range(1, 6)]
    visualize_detection(img_file, dets, class_names, thresh)

### Download Test Images

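Let's download two bird images from the web that the model has not seen during training. Keep in mind that the model was fine-tuned only on the first five CUB species (three albatross species, the Groove-billed Ani and the Crested Auklet). For a goldfinch or a hummingbird it may still localize the bird, but any class label it assigns will necessarily be one of those five species.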

In [None]:
import urllib.request

# Download test images
test_image_urls = {
    'multi-goldfinch-1.jpg': 'https://t3.ftcdn.net/jpg/01/44/64/36/500_F_144643697_GJRUBtGc55KYSMpyg1Kucb9yJzvMQooW.jpg',
    'hummingbird-1.jpg': 'http://res.freestockphotos.biz/pictures/17/17875-hummingbird-close-up-pv.jpg'
}

for filename, url in test_image_urls.items():
    if not os.path.exists(filename):
        print(f'Downloading {filename}...')
        urllib.request.urlretrieve(url, filename)

print(f'{len(test_image_urls)} test images ready')

In [None]:
def test_model(endpoint_name):
    """Test model with downloaded bird images."""
    test_images = [
        'hummingbird-1.jpg',
        'multi-goldfinch-1.jpg',
    ]
    
    for img in test_images:
        if os.path.exists(img):
            print(f'\nTesting: {img}')
            predict_and_visualize(img, endpoint_name, thresh=0.4)

# Test the model
test_model(endpoint.endpoint_name)

## 9. Cleanup

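To avoid ongoing charges, delete the endpoint when you're done testing.

The endpoint runs on a GPU instance, which incurs hourly charges for as long as the endpoint exists, even when it receives no requests. The model and endpoint configuration do not incur instance charges, but you can delete them as well to keep your account tidy. The training data and model artifacts in S3 remain until you delete them. Always clean up resources after experimentation.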
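If the kernel has restarted and the `endpoint`, `endpoint_config` and `model` objects from Section 7 no longer exist, the same resources can be deleted by name with boto3's SageMaker client. A minimal sketch (the helper name is hypothetical; look up the resource names in the SageMaker console if you did not note them):

```python
def delete_endpoint_resources(sm_client, endpoint_name, endpoint_config_name, model_name):
    """Delete an endpoint, its endpoint configuration and its model, in that order.

    `sm_client` is expected to be a boto3 SageMaker client, e.g. boto3.client('sagemaker').
    """
    # The endpoint is what runs (and bills for) instances, so remove it first.
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    sm_client.delete_model(ModelName=model_name)
```

For example: `delete_endpoint_resources(boto3.client('sagemaker'), '<endpoint-name>', '<endpoint-config-name>', '<model-name>')`.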

In [None]:
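# Delete the endpoint first (it is what incurs instance charges),
# then the endpoint configuration and the model
endpoint.delete()
endpoint_config.delete()
model.delete()
print(f'Deleted endpoint: {endpoint.endpoint_name}')
print(f'Deleted endpoint config: {endpoint_config.endpoint_config_name}')
print(f'Deleted model: {model.model_name}')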