# Alt ML Pipeline 1 - SageMaker Training Notebook

This notebook trains a YOLOv8 segmentation model (Ultralytics) on AWS SageMaker, using a GPU when one is available.

## Prerequisites

**Before running this notebook:**

1. Upload your data to S3:
   ```bash
   aws s3 sync ml_pipeline/data/processed/yolo_dataset/ s3://your-bucket/alt-pipeline-1/yolo_dataset/
   ```

2. Set `S3_BUCKET` (and, if needed, the other values) in the configuration cell below.

3. Run this notebook in a SageMaker notebook instance or SageMaker Studio whose execution role can read from and write to that bucket.

## What This Notebook Does

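1. Downloads the dataset from S3
2. Builds leave-one-sequence-out cross-validation splits and prepares the selected fold
3. Trains the YOLO model (on the GPU if available)
4. Evaluates performance on the fold's validation sequence
5. Saves the model weights and metrics
6. Uploads the results back to S3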

## 1. Configuration & Setup

In [None]:
# ============================================================================
# CONFIGURATION - UPDATE THESE VALUES
# ============================================================================

# --- S3 locations -----------------------------------------------------------
# Bucket that holds the dataset and receives the results (REQUIRED - update this!)
S3_BUCKET = "your-bucket-name"  # ← CHANGE THIS
S3_DATA_PREFIX = "alt-pipeline-1/yolo_dataset"  # key prefix the dataset was synced to
S3_OUTPUT_PREFIX = "alt-pipeline-1/training-output"  # key prefix for training results

# --- Training ---------------------------------------------------------------
# FOLD_IDX picks one leave-one-sequence-out fold. Valid values run from 0 to
# (number of sequences - 1); the split cell prints the available folds.
FOLD_IDX = 0
NUM_EPOCHS = 100
BATCH_SIZE = 16  # can usually be raised on a larger GPU
IMAGE_SIZE = 640  # square input size passed to Ultralytics as `imgsz`

# --- Model ------------------------------------------------------------------
MODEL_NAME = "yolov8n-seg.pt"  # alternatives: yolov8s-seg.pt, yolov8m-seg.pt

# Class names, listed in the order of the integer class ids in the label files.
CLASS_NAMES = [
    "planktonic",
    "single_dispersed",
    "hyphae",
    "clump_dispersed",
    "yeast",
    "biofilm",
    "pseudohyphae",
]

print("✓ Configuration set")
for _label, _value in (
    ("S3 Bucket", S3_BUCKET),
    ("Training Fold", FOLD_IDX),
    ("Model", MODEL_NAME),
    ("Epochs", NUM_EPOCHS),
):
    print(f"  {_label}: {_value}")

In [None]:
# ============================================================================
# INSTALL REQUIRED PACKAGES
# ============================================================================

# Install ultralytics and other dependencies if not already installed
!pip install -q ultralytics==8.0.227 opencv-python-headless PyYAML tqdm

print("✓ Packages installed")

In [None]:
# ============================================================================
# IMPORTS
# ============================================================================

import os
import sys
import json
import shutil
from collections import Counter
from datetime import datetime
from pathlib import Path

import yaml
import boto3
import numpy as np
import cv2
from tqdm import tqdm

# Experiment tracking is done manually, so WandB is switched off. The variables
# are set before the ultralytics import so they are already in the environment
# when it (and wandb, if installed) loads.
os.environ.update({'WANDB_DISABLED': 'true', 'WANDB_MODE': 'disabled'})

from ultralytics import YOLO, settings
import torch

# Also switch off the corresponding Ultralytics integration flags.
for _integration in ('wandb', 'mlflow'):
    settings[_integration] = False

_cuda_ok = torch.cuda.is_available()
print("✓ Imports loaded")
print(f"  PyTorch version: {torch.__version__}")
print(f"  CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"  GPU: {torch.cuda.get_device_name(0)}")

## 2. Download Data from S3

In [None]:
# ============================================================================
# DOWNLOAD DATA FROM S3
# ============================================================================
# Mirrors s3://S3_BUCKET/S3_DATA_PREFIX/ into LOCAL_DATA_DIR, keeping each
# object's key (minus the prefix) as its relative local path, then counts the
# images/labels found under images/{train,val} and labels/{train,val}.

# Create local data directory
LOCAL_DATA_DIR = Path("/tmp/yolo_dataset")
LOCAL_DATA_DIR.mkdir(parents=True, exist_ok=True)

print("Downloading data from S3...")
print(f"  Source: s3://{S3_BUCKET}/{S3_DATA_PREFIX}/")
print(f"  Destination: {LOCAL_DATA_DIR}")
print()

# Initialize S3 client
s3 = boto3.client('s3')

# List under "<prefix>/" (with the trailing slash) so that a sibling prefix such
# as "yolo_dataset_old/" is not matched as well. An empty prefix means the
# bucket root.
_prefix_root = S3_DATA_PREFIX.rstrip('/')
_s3_data_prefix = f"{_prefix_root}/" if _prefix_root else ""

paginator = s3.get_paginator('list_objects_v2')
pages = paginator.paginate(Bucket=S3_BUCKET, Prefix=_s3_data_prefix)

files_to_download = []
for page in pages:
    for obj in page.get('Contents', []):
        s3_key = obj['Key']
        # Keys ending in "/" are zero-byte "folder" markers (e.g. created via
        # the S3 console); they cannot be written as local files.
        if s3_key.endswith('/'):
            continue
        files_to_download.append(s3_key)

print(f"Found {len(files_to_download)} files to download")
print()

if not files_to_download:
    raise FileNotFoundError(
        f"No objects found under s3://{S3_BUCKET}/{_s3_data_prefix} - "
        "check S3_BUCKET / S3_DATA_PREFIX and that the dataset was uploaded."
    )

# Download files with progress bar
for s3_key in tqdm(files_to_download, desc="Downloading"):
    # Strip only the leading prefix; str.replace() would also rewrite any later
    # occurrence of the prefix text inside the key.
    rel_path = s3_key[len(_s3_data_prefix):]

    local_file = LOCAL_DATA_DIR / rel_path
    local_file.parent.mkdir(parents=True, exist_ok=True)
    s3.download_file(S3_BUCKET, s3_key, str(local_file))

print("\n✓ Data downloaded successfully")

# Verify directory structure. Results are sorted so downstream cells see the
# same order on every run, independent of the filesystem's listing order.
train_images = sorted((LOCAL_DATA_DIR / 'images' / 'train').glob('*.tif*'))
train_labels = sorted((LOCAL_DATA_DIR / 'labels' / 'train').glob('*.txt'))
val_images = sorted((LOCAL_DATA_DIR / 'images' / 'val').glob('*.tif*'))
val_labels = sorted((LOCAL_DATA_DIR / 'labels' / 'val').glob('*.txt'))

print("\nData verification:")
print(f"  Training images: {len(train_images)}")
print(f"  Training labels: {len(train_labels)}")
print(f"  Validation images: {len(val_images)}")
print(f"  Validation labels: {len(val_labels)}")

## 3. Analyze Dataset

In [None]:
# ============================================================================
# ANALYZE DATASET
# ============================================================================
# Prints image counts and the per-class annotation distribution over the
# train and val label files downloaded above.

print("=" * 60)
print("DATASET ANALYSIS")
print("=" * 60)
print()

# Count annotations per class: every non-empty label line is one annotation
# whose first token is its integer class id.
class_counts = Counter()
total_annotations = 0

for label_file in train_labels + val_labels:
    with open(label_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.split()
            if parts:
                class_counts[int(parts[0])] += 1
                total_annotations += 1

print(f"Total images: {len(train_images) + len(val_images)}")
print(f"  Training: {len(train_images)}")
print(f"  Validation: {len(val_images)}")
print()
print(f"Total annotations: {total_annotations}")
print()
print("Class distribution:")
for i, class_name in enumerate(CLASS_NAMES):
    count = class_counts.get(i, 0)
    percentage = (count / total_annotations * 100) if total_annotations > 0 else 0
    print(f"  {i}. {class_name:20s}: {count:5d} ({percentage:5.1f}%)")

# Class ids with no entry in CLASS_NAMES do not appear in the table above (its
# percentages then sum to less than 100%), so list them explicitly.
unknown_class_ids = sorted(
    cid for cid in class_counts if not 0 <= cid < len(CLASS_NAMES)
)
if unknown_class_ids:
    print()
    print(f"⚠ Annotations with class ids outside 0-{len(CLASS_NAMES) - 1}:")
    for cid in unknown_class_ids:
        print(f"  id {cid}: {class_counts[cid]:5d}")

print("\n" + "=" * 60)
print()

## 4. Create Cross-Validation Splits

In [None]:
# ============================================================================
# CREATE CROSS-VALIDATION SPLITS (Leave-One-Sequence-Out)
# ============================================================================
# Each image is assigned to a sequence from its filename
# ("<sequence>_frame_<n>"). Fold i validates on one whole sequence and trains
# on all the others.

print("Creating cross-validation splits...")
print()

# Group images by sequence name. Iterating in filename order (not the
# filesystem's glob order) and sorting the sequence names makes fold numbering
# reproducible: a given FOLD_IDX always maps to the same validation sequence
# for the same dataset.
sequences = {}
unmatched_images = []
all_images = train_images + val_images

for img_path in sorted(all_images, key=lambda p: p.name):
    # Extract sequence name (e.g., "MattLines1" from "MattLines1_frame_0000.tif")
    parts = img_path.stem.split('_frame_')
    if len(parts) == 2:
        sequences.setdefault(parts[0], []).append(img_path)
    else:
        unmatched_images.append(img_path)

if not sequences:
    raise ValueError(
        "No images matched the '<sequence>_frame_<n>' naming pattern; "
        "cannot build cross-validation splits."
    )

sequence_names = sorted(sequences)

print(f"Found {len(sequences)} sequences:")
for seq_name in sequence_names:
    print(f"  {seq_name}: {len(sequences[seq_name])} frames")
print()

if unmatched_images:
    # These frames belong to no fold; report them rather than dropping silently.
    print(f"⚠ {len(unmatched_images)} images do not match '<sequence>_frame_<n>' "
          f"and are excluded from every fold (e.g. {unmatched_images[0].name})")
    print()

# Create splits: fold i holds out sequence_names[i] for validation.
splits = []
for i, val_seq in enumerate(sequence_names):
    fold_train_images = []
    for seq_name in sequence_names:
        if seq_name != val_seq:
            fold_train_images.extend(sequences[seq_name])

    splits.append({
        'fold': i,
        'val_sequence': val_seq,
        'train_images': fold_train_images,
        'val_images': sequences[val_seq],
    })

print(f"Created {len(splits)} folds:")
for split in splits:
    print(f"  Fold {split['fold']}: Val={split['val_sequence']}, "
          f"Train={len(split['train_images'])}, Val={len(split['val_images'])}")

print("\n✓ Splits created")

In [None]:
# ============================================================================
# PREPARE SELECTED FOLD
# ============================================================================
# Materializes fold FOLD_IDX as a YOLO dataset directory
# (images/{train,val}, labels/{train,val}) plus a data.yaml that points at it.

print(f"\nPreparing fold {FOLD_IDX}...")
print()

# Validate explicitly: a negative index would otherwise silently select a fold
# counted from the end of the list.
if not 0 <= FOLD_IDX < len(splits):
    raise IndexError(
        f"FOLD_IDX={FOLD_IDX} is out of range; valid folds are 0-{len(splits) - 1}"
    )

split = splits[FOLD_IDX]
print(f"Validation sequence: {split['val_sequence']}")
print(f"Training images: {len(split['train_images'])}")
print(f"Validation images: {len(split['val_images'])}")
print()

FOLD_DIR = Path(f"/tmp/fold_{FOLD_IDX}")

# Start from an empty directory. Files left over from an earlier run of this
# fold index (possibly built from a different split, or cache files written
# next to the labels) would otherwise be reused and could leak validation
# frames into training.
if FOLD_DIR.exists():
    shutil.rmtree(FOLD_DIR)

for subset in ('train', 'val'):
    (FOLD_DIR / 'images' / subset).mkdir(parents=True, exist_ok=True)
    (FOLD_DIR / 'labels' / subset).mkdir(parents=True, exist_ok=True)


def _find_source_label(img_path):
    """Return the downloaded label file for *img_path*, or None if none exists.

    A frame's label may live in either labels/train or labels/val of the
    original dataset; train is checked first.
    """
    label_name = img_path.stem + '.txt'
    for subset in ('train', 'val'):
        candidate = LOCAL_DATA_DIR / 'labels' / subset / label_name
        if candidate.exists():
            return candidate
    return None


def _copy_subset(image_paths, subset, desc):
    """Copy images and their labels into FOLD_DIR/{images,labels}/<subset>.

    Images without a label file are still copied. Returns how many of them
    had no label so the gap can be reported.
    """
    missing_labels = 0
    for img_path in tqdm(image_paths, desc=desc):
        shutil.copy(img_path, FOLD_DIR / 'images' / subset / img_path.name)
        src_label = _find_source_label(img_path)
        if src_label is None:
            missing_labels += 1
            continue
        shutil.copy(src_label, FOLD_DIR / 'labels' / subset / src_label.name)
    return missing_labels


print("Copying training data...")
missing_train = _copy_subset(split['train_images'], 'train', "Train")

print("Copying validation data...")
missing_val = _copy_subset(split['val_images'], 'val', "Val")

if missing_train or missing_val:
    print(f"⚠ Images without a label file: train={missing_train}, val={missing_val}")

# Create dataset YAML file
dataset_yaml = {
    'path': str(FOLD_DIR.absolute()),
    'train': 'images/train',
    'val': 'images/val',
    'nc': len(CLASS_NAMES),
    'names': CLASS_NAMES
}

yaml_path = FOLD_DIR / 'data.yaml'
with open(yaml_path, 'w') as f:
    yaml.dump(dataset_yaml, f)

print(f"\n✓ Fold {FOLD_IDX} prepared at: {FOLD_DIR}")
print(f"  Dataset YAML: {yaml_path}")

## 5. Train Model

In [None]:
# ============================================================================
# TRAIN YOLO MODEL
# ============================================================================
# Trains MODEL_NAME on the fold prepared above. Outputs go under
# project/name = /tmp/training_output/fold_<FOLD_IDX> (see train_params).

_use_gpu = torch.cuda.is_available()
_rule = "=" * 80

print(_rule)
print(f"TRAINING YOLO MODEL - Fold {FOLD_IDX}")
print(_rule)
print()
for _label, _value in (
    ("Model", MODEL_NAME),
    ("Epochs", NUM_EPOCHS),
    ("Batch size", BATCH_SIZE),
    ("Image size", IMAGE_SIZE),
):
    print(f"{_label}: {_value}")
print(f"Device: {'GPU' if _use_gpu else 'CPU'}")
print()

# Initialize model from the pretrained checkpoint name
model = YOLO(MODEL_NAME)

train_params = dict(
    data=str(yaml_path),
    epochs=NUM_EPOCHS,
    imgsz=IMAGE_SIZE,
    batch=BATCH_SIZE,
    device=0 if _use_gpu else 'cpu',  # first GPU when present
    patience=20,                      # early-stopping patience (epochs)
    save=True,
    project='/tmp/training_output',
    name=f'fold_{FOLD_IDX}',
    exist_ok=True,                    # reuse the run directory on reruns
    pretrained=True,
    optimizer='AdamW',
    lr0=0.001,
    weight_decay=0.0005,
    plots=True,
    save_period=10,                   # checkpoint every 10 epochs
    val=True,
    workers=8,
)

_divider = "-" * 80
print("Starting training...")
print(_divider)
print()

results = model.train(**train_params)

print()
print(_divider)
print("✓ Training complete!")
print()

## 6. Evaluate Model

In [None]:
# ============================================================================
# EVALUATE MODEL ON VALIDATION SET
# ============================================================================
# Validates the trained model on this fold's validation images, prints the
# metrics and writes them to metrics.json in the fold's training directory.

print("Evaluating model on validation set...")
print()

# Run validation. `model` is the object trained in the previous cell
# (presumably holding the best checkpoint afterwards - verify for the pinned
# Ultralytics version). The fold's dataset YAML is passed explicitly so the
# evaluation is tied to this fold's data.
metrics = model.val(data=str(yaml_path))


def _summarize(metric_group):
    """Return (mAP50, mAP50-95, mean precision, mean recall) as floats.

    *metric_group* is an Ultralytics per-task metric object (e.g. ``metrics.box``);
    when it is None, all four values are 0.0.
    """
    if metric_group is None:
        return 0.0, 0.0, 0.0, 0.0
    return (float(metric_group.map50), float(metric_group.map),
            float(metric_group.mp), float(metric_group.mr))


def _f1(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    total = precision + recall
    return 2 * precision * recall / total if total > 0 else 0.0


box_map50, box_map, box_precision, box_recall = _summarize(getattr(metrics, 'box', None))

results_dict = {
    'fold': FOLD_IDX,
    'validation_sequence': split['val_sequence'],
    'mAP50': box_map50,
    'mAP50-95': box_map,
    'precision': box_precision,
    'recall': box_recall,
    'timestamp': datetime.now().isoformat()
}

# F1 of the *mean* precision and *mean* recall (not the mean of per-class F1).
results_dict['f1'] = _f1(box_precision, box_recall)

# MODEL_NAME options are segmentation (*-seg) models. Their mask metrics (when
# the metrics object exposes `seg`) are recorded too, since the box metrics
# above say nothing about mask quality.
seg_group = getattr(metrics, 'seg', None)
if seg_group is not None:
    mask_map50, mask_map, mask_precision, mask_recall = _summarize(seg_group)
    results_dict['mask_mAP50'] = mask_map50
    results_dict['mask_mAP50-95'] = mask_map
    results_dict['mask_precision'] = mask_precision
    results_dict['mask_recall'] = mask_recall
    results_dict['mask_f1'] = _f1(mask_precision, mask_recall)

# Print results
print("=" * 80)
print("FINAL RESULTS")
print("=" * 80)
print(f"\nFold {FOLD_IDX} - Validation sequence: {split['val_sequence']}")
print()
for key, value in results_dict.items():
    if key not in ['fold', 'validation_sequence', 'timestamp']:
        print(f"  {key:15s}: {value:.4f}")
print()
print("=" * 80)
print()

# Save metrics to file (create the directory in case training wrote nothing there)
metrics_file = Path('/tmp/training_output') / f'fold_{FOLD_IDX}' / 'metrics.json'
metrics_file.parent.mkdir(parents=True, exist_ok=True)
with open(metrics_file, 'w') as f:
    json.dump(results_dict, f, indent=2)

print(f"✓ Metrics saved to: {metrics_file}")

## 7. Upload Results to S3

In [None]:
# ============================================================================
# UPLOAD RESULTS TO S3
# ============================================================================
# Copies metrics, weights, run artifacts and plots for this fold to
# s3://S3_BUCKET/S3_OUTPUT_PREFIX/fold_<FOLD_IDX>/<timestamp>/. Files that do
# not exist locally are reported and skipped.

print("Uploading results to S3...")
print()

# Paths
training_output_dir = Path('/tmp/training_output') / f'fold_{FOLD_IDX}'
best_model_path = training_output_dir / 'weights' / 'best.pt'
last_model_path = training_output_dir / 'weights' / 'last.pt'

# S3 prefix for this training run (local time, second resolution)
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
s3_output_prefix = f"{S3_OUTPUT_PREFIX}/fold_{FOLD_IDX}/{timestamp}"

# Files to upload. Besides metrics and weights, keep the per-epoch results,
# the recorded training arguments (if Ultralytics wrote them to the run
# directory) and this fold's dataset YAML, so the run can be inspected and
# reproduced after /tmp is gone.
files_to_upload = [
    (metrics_file, f"{s3_output_prefix}/metrics.json"),
    (best_model_path, f"{s3_output_prefix}/best.pt"),
    (last_model_path, f"{s3_output_prefix}/last.pt"),
    (training_output_dir / 'results.csv', f"{s3_output_prefix}/results.csv"),
    (training_output_dir / 'args.yaml', f"{s3_output_prefix}/args.yaml"),
    (Path(yaml_path), f"{s3_output_prefix}/data.yaml"),
]

# Upload files
for local_path, s3_key in files_to_upload:
    if local_path.exists():
        print(f"  Uploading {local_path.name} → s3://{S3_BUCKET}/{s3_key}")
        s3.upload_file(str(local_path), S3_BUCKET, s3_key)
    else:
        print(f"  ⚠ Skipping {local_path.name} (not found)")

# Upload training plots if they exist (sorted for a stable upload order)
for plot_file in sorted(training_output_dir.glob('*.png')):
    s3_key = f"{s3_output_prefix}/plots/{plot_file.name}"
    print(f"  Uploading {plot_file.name} → s3://{S3_BUCKET}/{s3_key}")
    s3.upload_file(str(plot_file), S3_BUCKET, s3_key)

print()
print("✓ Results uploaded to S3")
print(f"  Location: s3://{S3_BUCKET}/{s3_output_prefix}/")
print()
print("To download results:")
print(f"  aws s3 sync s3://{S3_BUCKET}/{s3_output_prefix}/ ./local_results/")

## 8. Summary & Next Steps

In [None]:
# ============================================================================
# TRAINING SUMMARY
# ============================================================================
# Recaps the fold, its headline metrics and the S3 locations of the weights.

_rule = "=" * 80
print(_rule)
print("TRAINING COMPLETE - SUMMARY")
print(_rule)
print()
print(f"Fold trained: {FOLD_IDX}")
print(f"Validation sequence: {split['val_sequence']}")
print()

print("Performance metrics:")
for _label, _key in (
    ("F1 Score:", 'f1'),
    ("mAP50:", 'mAP50'),
    ("Precision:", 'precision'),
    ("Recall:", 'recall'),
):
    # Labels are left-aligned to a 12-character column.
    print(f"  {_label:<12}{results_dict[_key]:.4f}")
print()

_run_uri = f"s3://{S3_BUCKET}/{s3_output_prefix}"
print("Model saved to:")
print(f"  Best:  {_run_uri}/best.pt")
print(f"  Last:  {_run_uri}/last.pt")
print()

print("Next steps:")
for _step_no, _step in enumerate(
    (
        "Review metrics above",
        "Train other folds (change FOLD_IDX and rerun)",
        "Download models from S3 for inference",
        "Compare performance across folds",
    ),
    start=1,
):
    print(f"  {_step_no}. {_step}")
print()
print(_rule)