# Acoustic Shield - Training Data Pipeline

This notebook orchestrates the complete training-data pipeline:
1. Extract crash hotspots from GeoJSON
2. Enrich with weather data (Open-Meteo API)
3. Synthesize risk events
4. Build audio generation recipes
5. Save the risk events and recipes to S3
6. Run a SageMaker Processing job to generate WAV files
7. Validate outputs

## Setup and Configuration

In [None]:
import sys
import json
import logging
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

import boto3
from sagemaker import Session, image_uris
from sagemaker.processing import ScriptProcessor, ProcessingInput, ProcessingOutput

from data_pipeline import (
    S3Client,
    HotspotExtractor,
    WeatherEnricher,
    RiskEventSynthesizer,
    RecipeBuilder
)

# Configure the root logger once so INFO-level records are emitted with a
# timestamp, logger name and level prefix.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Notebook-level logger used by the pipeline cells below.
logger = logging.getLogger(__name__)

print("✓ Imports complete")

In [None]:
# Configuration: everything a reader might tune lives here. The AWS region is
# deliberately not configured by hand; it is looked up from the raw bucket below.
RAW_BUCKET = 'acousticshield-raw'
ML_BUCKET = 'acousticshield-ml'
CRASH_FILE_KEY = 'crash_hotspots/sanjose_crashes.geojson'
SAGEMAKER_ROLE = 'role-sagemaker-processing'  # IAM role *name*; its ARN is resolved in Step 6

# Dataset size knobs. Each hotspot is expanded into EVENTS_PER_HOTSPOT risk
# events, and each event into RECIPES_PER_EVENT audio recipes, so the run targets
# TOP_N_HOTSPOTS * EVENTS_PER_HOTSPOT * RECIPES_PER_EVENT audio files
# (50 * 5 * 4 = 1000 with the values below).
TOP_N_HOTSPOTS = 50
EVENTS_PER_HOTSPOT = 5
RECIPES_PER_EVENT = 4

# Derive the region from wherever the raw bucket lives.
s3_client = S3Client()
REGION = s3_client.get_bucket_region(RAW_BUCKET)

config_report = (
    ("Raw bucket", RAW_BUCKET),
    ("ML bucket", ML_BUCKET),
    ("Region", REGION),
    ("Crash file", f"s3://{RAW_BUCKET}/{CRASH_FILE_KEY}"),
    ("Top hotspots", TOP_N_HOTSPOTS),
    ("Events per hotspot", EVENTS_PER_HOTSPOT),
    ("Recipes per event", RECIPES_PER_EVENT),
    ("Expected total files", TOP_N_HOTSPOTS * EVENTS_PER_HOTSPOT * RECIPES_PER_EVENT),
)

print("Configuration:")
for label, value in config_report:
    print(f"  {label}: {value}")

## Step 1: Extract Crash Hotspots

In [None]:
# Pull the raw crash GeoJSON from S3 and rank locations by crash count.
logger.info(f"Loading crash data from s3://{RAW_BUCKET}/{CRASH_FILE_KEY}")
crash_data = s3_client.read_json(RAW_BUCKET, CRASH_FILE_KEY)

extractor = HotspotExtractor(crash_data)
hotspots = extractor.extract_top_hotspots(top_n=TOP_N_HOTSPOTS)
stats = extractor.get_summary_stats()

# Summary statistics reported by the extractor.
print("\n📊 Crash Data Summary:")
for label, stat_key in (
    ("Total crashes", 'total_crashes'),
    ("Total injuries", 'total_injuries'),
    ("Total fatalities", 'total_fatalities'),
):
    print(f"  {label}: {stats[stat_key]}")
print(f"  Speeding involved: {stats['speeding_involved_pct']:.1f}%")

# Preview the highest-ranked hotspots.
print("\n🎯 Top 5 Hotspots:")
for hotspot in hotspots[:5]:
    print(f"  {hotspot['rank']}. {hotspot['location_name']}: {hotspot['crash_count']} crashes")

## Step 2: Enrich with Weather Data

In [None]:
# Attach weather conditions from the Open-Meteo API to every hotspot.
# rate_limit_delay=0.5 is presumably the pause (seconds) between API calls.
logger.info("Fetching weather data for hotspots...")
enricher = WeatherEnricher()
enriched_hotspots = enricher.enrich_hotspots(hotspots, rate_limit_delay=0.5)

# Spot-check the first enriched hotspot.
print("\n🌤️  Sample Enriched Hotspot:")
sample = enriched_hotspots[0]
print(f"  Location: {sample['location_name']}")
print(f"  Crashes: {sample['crash_count']}")
print("  Weather:")
weather = sample['weather']
for label, reading_key, unit in (
    ("Temperature", 'temperature_c', "°C"),
    ("Rain", 'rain_mm', "mm"),
    ("Wind", 'wind_speed_kmh', " km/h"),
):
    print(f"    {label}: {weather[reading_key]:.1f}{unit}")
print(f"    Risk: {enricher.categorize_weather_risk(weather)}")

## Step 3: Synthesize Risk Events

In [None]:
# Expand each enriched hotspot into EVENTS_PER_HOTSPOT synthetic risk events.
# seed=42 is passed so synthesis is repeatable (assuming the synthesizer draws
# all of its randomness from that seed).
logger.info("Synthesizing risk events...")
synthesizer = RiskEventSynthesizer(seed=42)
risk_events = synthesizer.synthesize_events(enriched_hotspots, events_per_hotspot=EVENTS_PER_HOTSPOT)
distribution = synthesizer.get_event_distribution(risk_events)

print("\n⚠️  Risk Event Distribution:")
print(f"  Total events: {distribution['total_events']}")
for heading, breakdown_key in (
    ("  By risk type:", 'risk_type_distribution'),
    ("  By weather risk:", 'weather_risk_distribution'),
):
    print(heading)
    for bucket_name, bucket_count in distribution[breakdown_key].items():
        print(f"    {bucket_name}: {bucket_count}")

# Dump one event in full so its structure is visible.
print("\n📋 Sample Risk Event:")
sample_event = risk_events[0]
print(json.dumps(sample_event, indent=2))

## Step 4: Build Audio Recipes

In [None]:
# Turn every risk event into RECIPES_PER_EVENT audio-generation recipe variants.
logger.info("Building audio recipes...")
builder = RecipeBuilder()
recipes = builder.build_recipes(risk_events, recipes_per_event=RECIPES_PER_EVENT)
summary = builder.get_recipe_summary(recipes)

print("\n🎵 Audio Recipe Summary:")
print(f"  Total recipes: {summary['total_recipes']}")
print(f"  Total audio duration: {summary['total_audio_duration_minutes']:.2f} minutes")
print("  By risk type:")
for risk_name, recipe_count in summary['risk_type_distribution'].items():
    print(f"    {risk_name}: {recipe_count} recipes")

# Dump one recipe in full so its structure is visible.
print("\n🎼 Sample Recipe:")
sample_recipe = recipes[0]
print(json.dumps(sample_recipe, indent=2))

## Step 5: Save Intermediate Data to S3

In [None]:
# Persist the intermediate artifacts to the raw bucket. The recipes object is
# what the processing job in Step 6 mounts as its input.
risk_events_key = 'risk_events/risk_events.json'
recipes_key = 'prompts/audio_recipes.json'

logger.info("Saving risk events to S3...")
s3_path = s3_client.write_json(risk_events, RAW_BUCKET, risk_events_key)
print(f"✓ Risk events saved to: {s3_path}")

logger.info("Saving recipes to S3...")
s3_path = s3_client.write_json(recipes, RAW_BUCKET, recipes_key)
print(f"✓ Recipes saved to: {s3_path}")

## Step 6: Run SageMaker Processing Job

In [None]:
# Pin both the boto3 and SageMaker sessions to the region discovered from the
# raw bucket, then resolve the processing role's ARN from its name via IAM.
boto_session = boto3.Session(region_name=REGION)
sagemaker_session = Session(boto_session=boto_session)

iam_client = boto_session.client('iam')
role_arn = iam_client.get_role(RoleName=SAGEMAKER_ROLE)['Role']['Arn']

print("SageMaker Session:")
print(f"  Region: {REGION}")
print(f"  Role: {role_arn}")

In [None]:
# Configure the processing job.
#
# The container image is resolved through the SageMaker SDK instead of a
# hard-coded ECR URI. The AWS Deep Learning Containers registry account
# (763104351884) is not the registry in every region (opt-in and China regions
# use different accounts), so a literal URI breaks the region-agnostic setup.
PROCESSING_INSTANCE_TYPE = 'ml.m5.xlarge'

processing_image_uri = image_uris.retrieve(
    framework='pytorch',
    region=REGION,
    version='2.0.1',
    py_version='py310',
    instance_type=PROCESSING_INSTANCE_TYPE,  # CPU instance -> CPU image variant
    image_scope='training',
)

processor = ScriptProcessor(
    role=role_arn,
    image_uri=processing_image_uri,
    command=['python3'],
    instance_count=1,
    instance_type=PROCESSING_INSTANCE_TYPE,
    base_job_name='acousticshield-audio-generation',
    sagemaker_session=sagemaker_session,
)

print("✓ Processor configured")
print(f"  Image: {processing_image_uri}")

In [None]:
# Launch the audio-generation job, blocking until it finishes (wait=True) and
# streaming its logs into this notebook (logs=True).
#
# Data flow: the recipes JSON saved in Step 5 is made available in the container
# input directory, and whatever augment.py writes to the container output
# directory is uploaded to s3://{ML_BUCKET}/train/. Nothing in this notebook
# clears that prefix first, so WAVs from an earlier run may remain there and be
# counted by the validation step.
logger.info("Starting SageMaker processing job...")

container_input_dir = '/opt/ml/processing/input'
container_output_dir = '/opt/ml/processing/output'

recipe_input = ProcessingInput(
    source=f's3://{RAW_BUCKET}/{recipes_key}',
    destination=container_input_dir,
)
audio_output = ProcessingOutput(
    source=container_output_dir,
    destination=f's3://{ML_BUCKET}/train/',
    output_name='training_audio',
)

processor.run(
    code='../processing/augment.py',  # resolved relative to the notebook's working directory
    inputs=[recipe_input],
    outputs=[audio_output],
    arguments=[
        '--recipe-dir', container_input_dir,
        '--output-dir', container_output_dir,
    ],
    wait=True,
    logs=True,
)

print("\n✅ Processing job completed successfully!")

## Step 7: Validate Outputs

In [None]:
# List every generated WAV under the training prefix.
#
# The listing must not be capped (e.g. max_keys=100): this run is configured to
# produce ~1,000 files, so a cap under-reports the count, the preview and every
# downstream tally. S3 ListObjectsV2 also returns at most 1,000 keys per
# response, so page through the whole prefix.
logger.info("Listing generated audio files...")
s3_listing_client = boto3.client('s3', region_name=REGION)
listing_pages = s3_listing_client.get_paginator('list_objects_v2').paginate(
    Bucket=ML_BUCKET, Prefix='train/'
)
output_files = [
    obj['Key']
    for page in listing_pages
    for obj in page.get('Contents', [])  # 'Contents' is absent on empty pages
]

wav_files = [key for key in output_files if key.endswith('.wav')]

print("\n🎧 Generated Audio Files:")
print(f"  Total WAV files: {len(wav_files)}")
print("\n  First 10 files:")
for i, file_key in enumerate(wav_files[:10], start=1):
    print(f"    {i}. s3://{ML_BUCKET}/{file_key}")

if len(wav_files) > 10:
    print(f"    ... and {len(wav_files) - 10} more")

# One audio file is expected per recipe (hotspots x events x recipes, as computed
# in the configuration cell). A mismatch means the job skipped recipes, or stale
# files from an earlier run are still under train/.
expected_wav_count = len(recipes)
if len(wav_files) != expected_wav_count:
    logger.warning(
        "Expected %d WAV files (one per recipe) but found %d under s3://%s/train/",
        expected_wav_count, len(wav_files), ML_BUCKET,
    )

In [None]:
# Tally WAV files per risk type by searching each S3 key (case-insensitively)
# for a risk-type token. Files that match no token are reported separately, so
# a file-naming mismatch shows up here instead of silently vanishing from the
# totals.
# NOTE(review): tokens presumably mirror the risk types emitted by
# RiskEventSynthesizer, lowercased with separators removed — confirm against
# augment.py's output naming.
RISK_TYPE_TOKENS = ('normal', 'tireskid', 'emergencybraking', 'collisionimminent')

risk_type_counts = dict.fromkeys(RISK_TYPE_TOKENS, 0)
unclassified_files = []

for wav_file in wav_files:
    filename = wav_file.lower()
    # First matching token wins, in the tuple order above.
    matched_type = next((token for token in RISK_TYPE_TOKENS if token in filename), None)
    if matched_type is None:
        unclassified_files.append(wav_file)
    else:
        risk_type_counts[matched_type] += 1

print("\n📊 Audio Files by Risk Type:")
for risk_type, count in risk_type_counts.items():
    print(f"  {risk_type.title()}: {count} files")

if unclassified_files:
    print(
        f"  Unclassified: {len(unclassified_files)} files "
        f"(e.g. s3://{ML_BUCKET}/{unclassified_files[0]})"
    )

## Summary

In [None]:
# Final recap: how much data each stage produced and where it lives in S3.
divider = "=" * 70
summary_lines = [
    "\n" + divider,
    "🎉 ACOUSTIC SHIELD DATA PIPELINE COMPLETE",
    divider,
    f"\n📍 Crash Hotspots Analyzed: {len(hotspots)}",
    f"⚠️  Risk Events Generated: {len(risk_events)}",
    f"🎵 Audio Recipes Created: {len(recipes)}",
    f"🎧 WAV Files Generated: {len(wav_files)}",
    "\n💾 Data Locations:",
    f"  Risk Events: s3://{RAW_BUCKET}/{risk_events_key}",
    f"  Recipes: s3://{RAW_BUCKET}/{recipes_key}",
    f"  Training Audio: s3://{ML_BUCKET}/train/",
    "\n✅ Ready for model training!",
    divider,
]
print(*summary_lines, sep="\n")