# Advanced Apache Sedona Examples

This notebook demonstrates complex spatial analytics scenarios using Apache Sedona with a complete data science stack.

## 🎯 Prerequisites Verified:
- ✅ **Apache Spark 3.4.0** with optimized configuration  
- ✅ **Apache Sedona 1.4.1** with GeoTools wrapper (fixes FactoryException)
- ✅ **Advanced Analytics**: PySpark ML, clustering, and feature engineering
- ✅ **Visualization Stack**: Matplotlib, Seaborn, Folium for interactive maps
- ✅ **Geospatial Libraries**: GeoPandas, Shapely for geometry operations
- ✅ **Data Processing**: Pandas, NumPy for data manipulation

## 🚀 Complex Use Cases Covered:
1. **Spatial ETL Pipeline** - Processing large spatial datasets (NYC Taxi data simulation)
2. **Geofencing & Location Intelligence** - Multi-zone analysis with complex polygons
3. **Spatial Clustering** - Grid-based clustering for hotspot detection
4. **Route Optimization** - Network analysis and efficiency metrics
5. **Heatmap Generation** - Interactive spatial density visualizations
6. **Multi-scale Spatial Joins** - Performance optimization techniques
7. **Spatial Machine Learning** - Predictive spatial modeling with K-means clustering

## 📊 Performance Features:
- Optimized Spark configuration for spatial operations
- Broadcast join strategies for large-scale spatial queries
- Caching strategies for repeated spatial computations
- Advanced indexing techniques for spatial performance

In [1]:
# Advanced imports for complex spatial operations
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

# Import ML components
try:
    from pyspark.ml.clustering import KMeans
    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.evaluation import ClusteringEvaluator
    print("✅ PySpark ML components imported")
except ImportError as e:
    print(f"⚠️  PySpark ML import warning: {e}")

# Import Sedona components
from sedona.register import SedonaRegistrator
from sedona.utils import SedonaKryoRegistrator, KryoSerializer

# Import data science libraries
import numpy as np
import pandas as pd

# Import visualization libraries
try:
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib.patches import Polygon as MPLPolygon
    print("✅ Matplotlib and Seaborn imported")
except ImportError as e:
    print(f"⚠️  Matplotlib/Seaborn import warning: {e}")

try:
    import folium
    from folium.plugins import HeatMap
    print("✅ Folium imported")
except ImportError as e:
    print(f"⚠️  Folium import warning: {e}")

try:
    import geopandas as gpd
    from shapely.geometry import Point, Polygon
    print("✅ GeoPandas and Shapely imported")
except ImportError as e:
    print(f"⚠️  GeoPandas/Shapely import warning: {e}")

# Standard library imports
import json
import random
from datetime import datetime, timedelta

# Set random seed for reproducibility
np.random.seed(42)
random.seed(42)

print("🎯 All imports completed successfully!")

✅ PySpark ML components imported
✅ Matplotlib and Seaborn imported
✅ Folium imported
✅ GeoPandas and Shapely imported
🎯 All imports completed successfully!


In [2]:
# Initialize Spark with optimized configuration for spatial operations
spark = SparkSession.builder \
    .appName("AdvancedSedonaExamples") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.kryo.registrator", "org.apache.sedona.core.serde.SedonaKryoRegistrator") \
    .config("spark.sql.extensions", "org.apache.sedona.viz.sql.SedonaVizExtensions,org.apache.sedona.sql.SedonaSqlExtensions") \
    .config("spark.driver.memory", "4g") \
    .config("spark.executor.memory", "4g") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .getOrCreate()

spark.sparkContext.setLogLevel("WARN")
SedonaRegistrator.registerAll(spark)

print("✅ Advanced Sedona environment initialized!")
print(f"Spark Version: {spark.version}")
print(f"Available cores: {spark.sparkContext.defaultParallelism}")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/10/25 12:48:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
  SedonaRegistrator.registerAll(spark)


✅ Advanced Sedona environment initialized!
Spark Version: 3.4.0
Available cores: 14


  cls.register(spark)
25/10/25 12:48:44 WARN UDTRegistration: Cannot register UDT for org.locationtech.jts.geom.Geometry, which is already registered.
25/10/25 12:48:44 WARN UDTRegistration: Cannot register UDT for org.locationtech.jts.index.SpatialIndex, which is already registered.
25/10/25 12:48:44 WARN UDTRegistration: Cannot register UDT for org.geotools.coverage.grid.GridCoverage2D, which is already registered.
25/10/25 12:48:44 WARN SimpleFunctionRegistry: The function st_union_aggr replaced a previously registered function.
25/10/25 12:48:44 WARN SimpleFunctionRegistry: The function st_envelope_aggr replaced a previously registered function.
25/10/25 12:48:44 WARN SimpleFunctionRegistry: The function st_intersection_aggr replaced a previously registered function.


In [3]:
# Quick verification that all packages work together
print("🧪 Testing package integration...")

# Test basic Spark functionality
try:
    test_df = spark.createDataFrame([(1, "test")], ["id", "value"])
    print(f"✅ Spark DataFrame: {test_df.count()} rows")
except Exception as e:
    print(f"❌ Spark test failed: {e}")

# Test Sedona spatial functions
try:
    test_spatial = spark.sql("SELECT ST_Point(1.0, 2.0) as point, ST_Distance(ST_Point(0.0, 0.0), ST_Point(3.0, 4.0)) as distance")
    result = test_spatial.collect()[0]
    print(f"✅ Sedona functions: distance = {result['distance']}")
except Exception as e:
    print(f"❌ Sedona test failed: {e}")

# Test visualization libraries
try:
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.plot([1, 2], [1, 2])
    plt.close(fig)
    print("✅ Matplotlib working")
except Exception as e:
    print(f"❌ Matplotlib test failed: {e}")

print("🎯 All integrations verified!")

🧪 Testing package integration...
✅ Spark DataFrame: 1 rows
✅ Sedona functions: distance = 5.0
✅ Matplotlib working
🎯 All integrations verified!


In [5]:
# 🚀 QUICK START: Run All Prerequisites
# Execute this cell to set up all required data for the advanced examples

print("🚀 Setting up all prerequisites for advanced spatial analytics...")
print("This may take 2-3 minutes for the complete data pipeline...")

try:
    # Step 1: Generate taxi data if not exists
    if 'taxi_df' not in locals() or taxi_df is None:
        print("\n📊 Step 1: Generating NYC taxi trip data...")
        
        def generate_nyc_taxi_data(num_trips=50000):
            # NYC bounding box (approximate)
            nyc_bounds = {
                'min_lat': 40.4774, 'max_lat': 40.9176,
                'min_lon': -74.2591, 'max_lon': -73.7004
            }
            
            trips = []
            base_time = datetime(2024, 1, 1)
            
            for i in range(num_trips):
                # Pickup location (slightly clustered around Manhattan)
                pickup_lat = np.random.normal(40.7589, 0.05)
                pickup_lon = np.random.normal(-73.9851, 0.05)
                
                # Dropoff location (random within NYC)
                dropoff_lat = np.random.uniform(nyc_bounds['min_lat'], nyc_bounds['max_lat'])
                dropoff_lon = np.random.uniform(nyc_bounds['min_lon'], nyc_bounds['max_lon'])
                
                trip_time = base_time + timedelta(minutes=np.random.randint(0, 525600))
                fare = np.random.uniform(5.0, 50.0)
                distance = np.random.uniform(0.1, 20.0)
                
                trips.append({
                    'trip_id': f'trip_{i:06d}',
                    'pickup_datetime': trip_time.isoformat(),
                    'pickup_lat': pickup_lat,
                    'pickup_lon': pickup_lon,
                    'dropoff_lat': dropoff_lat,
                    'dropoff_lon': dropoff_lon,
                    'fare_amount': fare,
                    'trip_distance': distance,
                    'passenger_count': np.random.randint(1, 7)
                })
            
            return trips

        taxi_data = generate_nyc_taxi_data(25000)  # Reduced for faster execution
        
        taxi_schema = StructType([
            StructField("trip_id", StringType(), True),
            StructField("pickup_datetime", StringType(), True),
            StructField("pickup_lat", DoubleType(), True),
            StructField("pickup_lon", DoubleType(), True),
            StructField("dropoff_lat", DoubleType(), True),
            StructField("dropoff_lon", DoubleType(), True),
            StructField("fare_amount", DoubleType(), True),
            StructField("trip_distance", DoubleType(), True),
            StructField("passenger_count", IntegerType(), True)
        ])

        # --- Load Real Taxi Trip Data ---
        # Path to the real taxi trip data CSV
        # Using a sample for faster processing. Update the path to use other files.
        taxi_trips_path = "../taxi-trips/taxi_trips/2019_taxi_trips.csv" 

        print(f"Loading real taxi trips from: {taxi_trips_path}")

        # Load the data, inferring schema and using the header
        raw_taxi_df = spark.read.option("header", "true").option("inferSchema", "true").csv(taxi_trips_path)

        # Rename columns to match the notebook's expected schema and filter bad data
        # The original notebook expects: pickup_datetime, pickup_lon, pickup_lat, etc.
        taxi_df = raw_taxi_df.select(
            col("lpep_pickup_datetime").alias("pickup_datetime"),
            col("pickup_longitude").alias("pickup_lon"),
            col("pickup_latitude").alias("pickup_lat"),
            col("dropoff_longitude").alias("dropoff_lon"),
            col("dropoff_latitude").alias("dropoff_lat"),
            col("fare_amount"),
            col("trip_distance"),
            col("passenger_count")
        ).withColumn("trip_id", monotonically_increasing_id()) \
        .filter("pickup_lon IS NOT NULL AND pickup_lat IS NOT NULL AND dropoff_lon IS NOT NULL AND dropoff_lat IS NOT NULL")

        taxi_df.createOrReplaceTempView("taxi_trips")
        print(f"✅ Loaded and prepared {taxi_df.count()} real taxi trips.")
        taxi_df.show(5)
        print(f"    ✅ Generated {taxi_df.count()} taxi trips")
    else:
        print("    ✅ Taxi data already available")

    # Step 2: Create spatial trips
    if 'spatial_trips' not in locals() or spatial_trips is None:
        print("\n🗺️  Step 2: Processing spatial operations...")
        
        spatial_trips = spark.sql("""
            SELECT 
                trip_id,
                pickup_datetime,
                ST_Point(pickup_lon, pickup_lat) as pickup_point,
                ST_Point(dropoff_lon, dropoff_lat) as dropoff_point,
                ST_Distance(ST_Point(pickup_lon, pickup_lat), ST_Point(dropoff_lon, dropoff_lat)) as euclidean_distance,
                fare_amount,
                trip_distance,
                passenger_count,
                CASE 
                    WHEN HOUR(pickup_datetime) BETWEEN 7 AND 9 THEN 'Morning Rush'
                    WHEN HOUR(pickup_datetime) BETWEEN 17 AND 19 THEN 'Evening Rush'
                    WHEN HOUR(pickup_datetime) BETWEEN 22 AND 5 THEN 'Night'
                    ELSE 'Regular'
                END as time_period
            FROM taxi_trips
            WHERE pickup_lat BETWEEN 40.4 AND 41.0 
              AND pickup_lon BETWEEN -74.5 AND -73.5
              AND dropoff_lat BETWEEN 40.4 AND 41.0 
              AND dropoff_lon BETWEEN -74.5 AND -73.5
        """)
        
        spatial_trips.cache()
        print(f"    ✅ Created {spatial_trips.count()} spatial trips")
    else:
        print("    ✅ Spatial trips already available")

    # Step 3: Create zones
    if 'spatial_zones' not in locals() or spatial_zones is None:
        print("\n🏙️  Step 3: Creating NYC zones...")
        
        zones_data = [
            {
                'zone_id': 'manhattan_south', 
                'zone_name': 'Lower Manhattan',
                'polygon': 'POLYGON((-74.0479 40.6829, -73.9441 40.6829, -73.9441 40.7589, -74.0479 40.7589, -74.0479 40.6829))'
            },
            {
                'zone_id': 'manhattan_central', 
                'zone_name': 'Midtown Manhattan',
                'polygon': 'POLYGON((-74.0479 40.7589, -73.9441 40.7589, -73.9441 40.8176, -74.0479 40.8176, -74.0479 40.7589))'
            },
            {
                'zone_id': 'brooklyn_west', 
                'zone_name': 'West Brooklyn',
                'polygon': 'POLYGON((-74.0479 40.6000, -73.9000 40.6000, -73.9000 40.7000, -74.0479 40.7000, -74.0479 40.6000))'
            },
            {
                'zone_id': 'queens_central', 
                'zone_name': 'Central Queens',
                'polygon': 'POLYGON((-73.9000 40.7000, -73.7500 40.7000, -73.7500 40.8000, -73.9000 40.8000, -73.9000 40.7000))'
            }
        ]

            # --- Load Real Taxi Zone Data ---
        from sedona.read import ShapefileReader

        # Path to the taxi zones shapefile
        taxi_zones_path = "../taxi-trips/taxi_zones_map/taxi_zones_map_shapefiles/"

        print(f"Loading taxi zones from shapefile: {taxi_zones_path}")

        # Use Sedona's ShapefileReader to load the zone geometries
        zones_gdf = ShapefileReader.readToGeometryDF(spark.sparkContext, taxi_zones_path)

        # Convert to a Spark DataFrame and prepare for use
        spatial_zones = zones_gdf.select(
            col("LocationID").alias("zone_id"),
            col("zone").alias("zone_name"),
            col("geometry").alias("zone_geometry")
        ).withColumn("zone_area", ST_Area(col("zone_geometry")))

        spatial_zones.cache()
        spatial_zones.createOrReplaceTempView("spatial_zones")
        spatial_zones.show(5)
        print(f"    ✅ Created {spatial_zones.count()} spatial zones")
    else:
        print("    ✅ Spatial zones already available")

    print("\n🎉 All prerequisites completed successfully!")
    print("📝 You can now run any analysis cell in this notebook.")
    
except Exception as e:
    print(f"\n❌ Error during setup: {e}")
    print("Please check the error and try running individual cells.")
    import traceback
    traceback.print_exc()

🚀 Setting up all prerequisites for advanced spatial analytics...
This may take 2-3 minutes for the complete data pipeline...

📊 Step 1: Generating NYC taxi trip data...
Loading real taxi trips from: ../taxi-trips/taxi_trips/2019_taxi_trips.csv





❌ Error during setup: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `pickup_longitude` cannot be resolved. Did you mean one of the following? [`lpep_pickup_datetime`, `payment_type`, `tip_amount`, `trip_type`, `trip_distance`].;
'Project [lpep_pickup_datetime#111 AS pickup_datetime#148, 'pickup_longitude AS pickup_lon#149, 'pickup_latitude AS pickup_lat#150, 'dropoff_longitude AS dropoff_lon#151, 'dropoff_latitude AS dropoff_lat#152, fare_amount#119, trip_distance#118, passenger_count#117]
+- Relation [VendorID#110,lpep_pickup_datetime#111,lpep_dropoff_datetime#112,store_and_fwd_flag#113,RatecodeID#114,PULocationID#115,DOLocationID#116,passenger_count#117,trip_distance#118,fare_amount#119,extra#120,mta_tax#121,tip_amount#122,tolls_amount#123,improvement_surcharge#124,total_amount#125,payment_type#126,trip_type#127,congestion_surcharge#128] csv

Please check the error and try running individual cells.


Traceback (most recent call last):                                              
  File "/tmp/ipykernel_24/2985244721.py", line 75, in <module>
    taxi_df = raw_taxi_df.select(
  File "/usr/local/lib/python3.9/dist-packages/pyspark/sql/dataframe.py", line 3036, in select
    jdf = self._jdf.select(self._jcols(*cols))
  File "/usr/local/lib/python3.9/dist-packages/py4j/java_gateway.py", line 1322, in __call__
    return_value = get_return_value(
  File "/usr/local/lib/python3.9/dist-packages/pyspark/errors/exceptions/captured.py", line 175, in deco
    raise converted from None
pyspark.errors.exceptions.captured.AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `pickup_longitude` cannot be resolved. Did you mean one of the following? [`lpep_pickup_datetime`, `payment_type`, `tip_amount`, `trip_type`, `trip_distance`].;
'Project [lpep_pickup_datetime#111 AS pickup_datetime#148, 'pickup_longitude AS pickup_lon#149, 'pickup_latitude AS pickup_

## 1. Spatial ETL Pipeline: Processing NYC Taxi Data

Simulating processing of millions of taxi trips with spatial operations.

In [None]:
# Generate large-scale taxi trip data (simulating 1M+ trips)
def generate_nyc_taxi_data(num_trips=100000):
    # NYC bounding box (approximate)
    nyc_bounds = {
        'min_lat': 40.4774, 'max_lat': 40.9176,
        'min_lon': -74.2591, 'max_lon': -73.7004
    }
    
    # Generate trip data
    trips = []
    base_time = datetime(2024, 1, 1)
    
    for i in range(num_trips):
        # Pickup location (slightly clustered around Manhattan)
        pickup_lat = np.random.normal(40.7589, 0.05)  # Centered on Manhattan
        pickup_lon = np.random.normal(-73.9851, 0.05)
        
        # Dropoff location (random within NYC)
        dropoff_lat = np.random.uniform(nyc_bounds['min_lat'], nyc_bounds['max_lat'])
        dropoff_lon = np.random.uniform(nyc_bounds['min_lon'], nyc_bounds['max_lon'])
        
        # Trip details
        trip_time = base_time + timedelta(minutes=np.random.randint(0, 525600))  # Random time in year
        fare = np.random.uniform(5.0, 50.0)
        distance = np.random.uniform(0.1, 20.0)
        
        trips.append({
            'trip_id': f'trip_{i:06d}',
            'pickup_datetime': trip_time.isoformat(),
            'pickup_lat': pickup_lat,
            'pickup_lon': pickup_lon,
            'dropoff_lat': dropoff_lat,
            'dropoff_lon': dropoff_lon,
            'fare_amount': fare,
            'trip_distance': distance,
            'passenger_count': np.random.randint(1, 7)
        })
    
    return trips

print("Generating NYC taxi trip data...")
taxi_data = generate_nyc_taxi_data(50000)  # 50K trips for demo
print(f"Generated {len(taxi_data)} taxi trips")

# Convert to Spark DataFrame
taxi_schema = StructType([
    StructField("trip_id", StringType(), True),
    StructField("pickup_datetime", StringType(), True),
    StructField("pickup_lat", DoubleType(), True),
    StructField("pickup_lon", DoubleType(), True),
    StructField("dropoff_lat", DoubleType(), True),
    StructField("dropoff_lon", DoubleType(), True),
    StructField("fare_amount", DoubleType(), True),
    StructField("trip_distance", DoubleType(), True),
    StructField("passenger_count", IntegerType(), True)
])

taxi_df = spark.createDataFrame(taxi_data, schema=taxi_schema)
print(f"Created Spark DataFrame with {taxi_df.count()} records")

In [None]:
# Complex Spatial ETL Operations
taxi_df.createOrReplaceTempView("taxi_trips")

# 1. Create spatial geometries and calculate trip vectors
spatial_trips = spark.sql("""
    SELECT 
        trip_id,
        pickup_datetime,
        ST_Point(pickup_lon, pickup_lat) as pickup_point,
        ST_Point(dropoff_lon, dropoff_lat) as dropoff_point,
        ST_Distance(ST_Point(pickup_lon, pickup_lat), ST_Point(dropoff_lon, dropoff_lat)) as euclidean_distance,
        fare_amount,
        trip_distance,
        passenger_count,
        CASE 
            WHEN HOUR(pickup_datetime) BETWEEN 7 AND 9 THEN 'Morning Rush'
            WHEN HOUR(pickup_datetime) BETWEEN 17 AND 19 THEN 'Evening Rush'
            WHEN HOUR(pickup_datetime) BETWEEN 22 AND 5 THEN 'Night'
            ELSE 'Regular'
        END as time_period
    FROM taxi_trips
    WHERE pickup_lat BETWEEN 40.4 AND 41.0 
      AND pickup_lon BETWEEN -74.5 AND -73.5
      AND dropoff_lat BETWEEN 40.4 AND 41.0 
      AND dropoff_lon BETWEEN -74.5 AND -73.5
""")

spatial_trips.cache()
print(f"Processed {spatial_trips.count()} valid spatial trips")
spatial_trips.show(5)

In [None]:
# Data Pipeline Validation
print("🔍 Validating data pipeline...")

# Check if required DataFrames/tables exist
required_tables = ['taxi_trips', 'spatial_trips', 'spatial_zones']
missing_tables = []

for table in required_tables:
    try:
        # Check if table exists by trying to get its count
        if table == 'taxi_trips':
            count = taxi_df.count() if 'taxi_df' in locals() else 0
        elif table == 'spatial_trips':  
            count = spatial_trips.count() if 'spatial_trips' in locals() else 0
        elif table == 'spatial_zones':
            count = spatial_zones.count() if 'spatial_zones' in locals() else 0
        
        if count > 0:
            print(f"✅ {table}: {count} records")
        else:
            missing_tables.append(table)
            print(f"❌ {table}: Missing or empty")
            
    except Exception as e:
        missing_tables.append(table)
        print(f"❌ {table}: Not found - {str(e)[:50]}")

if missing_tables:
    print(f"\n⚠️  Missing tables: {missing_tables}")
    print("📝 Please run the previous cells in order to create the required data!")
    print("   1. Run taxi data generation cell (creates taxi_df)")
    print("   2. Run spatial processing cell (creates spatial_trips)")  
    print("   3. Run zone creation cell (creates spatial_zones)")
else:
    print("\n🎯 All required data available - ready for analysis!")

## 2. Advanced Geofencing: Multi-Zone Analysis

Creating complex geofences and analyzing spatial patterns.

In [None]:
# Create NYC borough-like zones using polygons
zones_data = [
    {
        'zone_id': 'manhattan_south', 
        'zone_name': 'Lower Manhattan',
        'polygon': 'POLYGON((-74.0479 40.6829, -73.9441 40.6829, -73.9441 40.7589, -74.0479 40.7589, -74.0479 40.6829))'
    },
    {
        'zone_id': 'manhattan_central', 
        'zone_name': 'Midtown Manhattan',
        'polygon': 'POLYGON((-74.0479 40.7589, -73.9441 40.7589, -73.9441 40.8176, -74.0479 40.8176, -74.0479 40.7589))'
    },
    {
        'zone_id': 'brooklyn_west', 
        'zone_name': 'West Brooklyn',
        'polygon': 'POLYGON((-74.0479 40.6000, -73.9000 40.6000, -73.9000 40.7000, -74.0479 40.7000, -74.0479 40.6000))'
    },
    {
        'zone_id': 'queens_central', 
        'zone_name': 'Central Queens',
        'polygon': 'POLYGON((-73.9000 40.7000, -73.7500 40.7000, -73.7500 40.8000, -73.9000 40.8000, -73.9000 40.7000))'
    }
]

zones_df = spark.createDataFrame(zones_data)
zones_df.createOrReplaceTempView("zones")

# Create spatial zones
spatial_zones = spark.sql("""
    SELECT 
        zone_id,
        zone_name,
        ST_GeomFromWKT(polygon) as zone_geometry,
        ST_Area(ST_GeomFromWKT(polygon)) as zone_area
    FROM zones
""")

spatial_zones.show()

In [None]:
# Complex spatial join: Assign pickup and dropoff zones
# SMART EXECUTION: This cell will create required data if missing!

print("🔍 Checking for required data...")

# Check if spatial_trips exists, if not run the Quick Start setup
if 'spatial_trips' not in locals() or spatial_trips is None:
    print("⚠️  spatial_trips not found. Running quick setup...")
    
    # Generate basic taxi data if needed
    if 'taxi_df' not in locals() or taxi_df is None:
        print("📊 Generating taxi data...")
        from datetime import datetime, timedelta
        
        taxi_data = []
        base_time = datetime(2024, 1, 1)
        
        for i in range(10000):  # Smaller dataset for faster execution
            pickup_lat = np.random.normal(40.7589, 0.05)
            pickup_lon = np.random.normal(-73.9851, 0.05)
            dropoff_lat = np.random.uniform(40.6, 40.9)
            dropoff_lon = np.random.uniform(-74.1, -73.7)
            
            taxi_data.append({
                'trip_id': f'trip_{i:06d}',
                'pickup_datetime': (base_time + timedelta(minutes=np.random.randint(0, 1440))).isoformat(),
                'pickup_lat': pickup_lat,
                'pickup_lon': pickup_lon,
                'dropoff_lat': dropoff_lat,
                'dropoff_lon': dropoff_lon,
                'fare_amount': np.random.uniform(5.0, 50.0),
                'trip_distance': np.random.uniform(0.1, 20.0),
                'passenger_count': np.random.randint(1, 7)
            })
        
        taxi_schema = StructType([
            StructField("trip_id", StringType(), True),
            StructField("pickup_datetime", StringType(), True),
            StructField("pickup_lat", DoubleType(), True),
            StructField("pickup_lon", DoubleType(), True),
            StructField("dropoff_lat", DoubleType(), True),
            StructField("dropoff_lon", DoubleType(), True),
            StructField("fare_amount", DoubleType(), True),
            StructField("trip_distance", DoubleType(), True),
            StructField("passenger_count", IntegerType(), True)
        ])
        
        taxi_df = spark.createDataFrame(taxi_data, schema=taxi_schema)
        taxi_df.createOrReplaceTempView("taxi_trips")
        print(f"    ✅ Created {taxi_df.count()} taxi trips")
    
    # Create spatial_trips
    spatial_trips = spark.sql("""
        SELECT 
            trip_id,
            pickup_datetime,
            ST_Point(pickup_lon, pickup_lat) as pickup_point,
            ST_Point(dropoff_lon, dropoff_lat) as dropoff_point,
            ST_Distance(ST_Point(pickup_lon, pickup_lat), ST_Point(dropoff_lon, dropoff_lat)) as euclidean_distance,
            fare_amount,
            trip_distance,
            passenger_count,
            CASE 
                WHEN HOUR(pickup_datetime) BETWEEN 7 AND 9 THEN 'Morning Rush'
                WHEN HOUR(pickup_datetime) BETWEEN 17 AND 19 THEN 'Evening Rush'
                WHEN HOUR(pickup_datetime) BETWEEN 22 AND 5 THEN 'Night'
                ELSE 'Regular'
            END as time_period
        FROM taxi_trips
        WHERE pickup_lat BETWEEN 40.4 AND 41.0 
          AND pickup_lon BETWEEN -74.5 AND -73.5
          AND dropoff_lat BETWEEN 40.4 AND 41.0 
          AND dropoff_lon BETWEEN -74.5 AND -73.5
    """)
    spatial_trips.cache()
    print(f"    ✅ Created {spatial_trips.count()} spatial trips")

# Check if spatial_zones exists, if not create it
if 'spatial_zones' not in locals() or spatial_zones is None:
    print("🏙️  Creating spatial zones...")
    
    zones_data = [
        {'zone_id': 'manhattan_south', 'zone_name': 'Lower Manhattan',
         'polygon': 'POLYGON((-74.0479 40.6829, -73.9441 40.6829, -73.9441 40.7589, -74.0479 40.7589, -74.0479 40.6829))'},
        {'zone_id': 'manhattan_central', 'zone_name': 'Midtown Manhattan',
         'polygon': 'POLYGON((-74.0479 40.7589, -73.9441 40.7589, -73.9441 40.8176, -74.0479 40.8176, -74.0479 40.7589))'},
        {'zone_id': 'brooklyn_west', 'zone_name': 'West Brooklyn',
         'polygon': 'POLYGON((-74.0479 40.6000, -73.9000 40.6000, -73.9000 40.7000, -74.0479 40.7000, -74.0479 40.6000))'},
        {'zone_id': 'queens_central', 'zone_name': 'Central Queens',
         'polygon': 'POLYGON((-73.9000 40.7000, -73.7500 40.7000, -73.7500 40.8000, -73.9000 40.8000, -73.9000 40.7000))'}
    ]

    zones_df = spark.createDataFrame(zones_data)
    zones_df.createOrReplaceTempView("zones")

    spatial_zones = spark.sql("""
        SELECT 
            zone_id,
            zone_name,
            ST_GeomFromWKT(polygon) as zone_geometry,
            ST_Area(ST_GeomFromWKT(polygon)) as zone_area
        FROM zones
    """)
    spatial_zones.cache()
    print(f"    ✅ Created {spatial_zones.count()} spatial zones")

# IMPORTANT: Create temp views BEFORE using them in SQL queries
try:
    spatial_trips.createOrReplaceTempView("spatial_trips")
    spatial_zones.createOrReplaceTempView("spatial_zones")
    print("✅ Created temporary views for spatial_trips and spatial_zones")
    
    # Now perform the spatial join
    trips_with_zones = spark.sql("""
        SELECT 
            t.trip_id,
            t.pickup_datetime,
            t.pickup_point,
            t.dropoff_point,
            t.euclidean_distance,
            t.fare_amount,
            t.trip_distance,
            t.passenger_count,
            t.time_period,
            pz.zone_id as pickup_zone,
            pz.zone_name as pickup_zone_name,
            dz.zone_id as dropoff_zone,
            dz.zone_name as dropoff_zone_name,
            CASE 
                WHEN pz.zone_id = dz.zone_id THEN 'Intra-zone'
                ELSE 'Inter-zone'
            END as trip_type
        FROM spatial_trips t
        LEFT JOIN spatial_zones pz ON ST_Within(t.pickup_point, pz.zone_geometry)
        LEFT JOIN spatial_zones dz ON ST_Within(t.dropoff_point, dz.zone_geometry)
    """)
    
    trips_with_zones.cache()
    # CRITICAL: Create temp view for trips_with_zones so other cells can use it
    trips_with_zones.createOrReplaceTempView("trips_with_zones")
    print(f"✅ trips_with_zones created with {trips_with_zones.count()} records")
    print("✅ Created temporary view 'trips_with_zones' for SQL queries")

    # Now we can run SQL queries on trips_with_zones
    zone_analysis = spark.sql("""
        SELECT 
            pickup_zone_name,
            dropoff_zone_name,
            trip_type,
            time_period,
            COUNT(*) as trip_count,
            AVG(fare_amount) as avg_fare,
            AVG(euclidean_distance) as avg_distance,
            SUM(passenger_count) as total_passengers
        FROM trips_with_zones
        WHERE pickup_zone IS NOT NULL AND dropoff_zone IS NOT NULL
        GROUP BY pickup_zone_name, dropoff_zone_name, trip_type, time_period
        ORDER BY trip_count DESC
    """)
    
    print("🎯 Zone Analysis Results:")
    zone_analysis.show(20)
    
except Exception as e:
    print(f"❌ Error creating spatial join: {e}")
    print("Try running the '🚀 QUICK START' cell at the top of the notebook first!")
    import traceback
    traceback.print_exc()

## ⚠️ **EXECUTION ORDER IMPORTANT**

**If you get `TABLE_OR_VIEW_NOT_FOUND` errors**, it means you need to run cells in order:

1. **First**: Run the **"🚀 QUICK START: Run All Prerequisites"** cell above (Cell 5)
2. **Then**: You can run any analysis cell below

**OR** run these cells in sequence:
- Cell 8: Generate taxi data → Creates `taxi_df` 
- Cell 9: Process spatial trips → Creates `spatial_trips`
- Cell 11: Create zones → Creates `spatial_zones`
- Cell 12: **This cell** → Creates `trips_with_zones`

💡 **Tip**: The Quick Start cell does everything at once!

## 💡 **Important Note**

**Execute cells in order!** This notebook builds a data pipeline where each cell depends on previous ones:

1. **Cell 2-3**: Initialize Spark and verify setup
2. **Cell 4-5**: Generate and process taxi trip data → Creates `taxi_df` and `spatial_trips`
3. **Cell 6**: Create spatial zones → Creates `spatial_zones`  
4. **Cell 7**: **This cell** → Creates `trips_with_zones` (depends on previous cells)
5. **Cell 8+**: Analysis cells that use `trips_with_zones`

If you get `TABLE_OR_VIEW_NOT_FOUND` errors, **run all previous cells first**!

## 3. Spatial Clustering: DBSCAN-like Analysis

Finding hotspots and clusters in pickup locations.

In [None]:
# Spatial clustering using grid-based approach (DBSCAN alternative for big data)

# Check if trips_with_zones exists
try:
    # Test if the table exists by trying to get its schema
    spark.table("trips_with_zones").schema
    print("✅ trips_with_zones table found")
except Exception:
    print("⚠️  trips_with_zones table not found. Using spatial_trips instead.")
    print("💡 Run the spatial join cell above first for complete analysis!")
    
    # Fallback to spatial_trips if available
    if 'spatial_trips' in locals() and spatial_trips is not None:
        spatial_trips.createOrReplaceTempView("trips_with_zones")
        print("✅ Using spatial_trips as fallback")
    else:
        print("❌ No spatial data available. Please run previous cells first!")
        raise Exception("Required data not available - run the Quick Start cell or previous cells in order")

def create_spatial_grid(table_name="trips_with_zones", grid_size=0.001):  # ~100m grid cells
    """
    Create a spatial grid for clustering analysis
    """
    try:
        grid_df = spark.sql(f"""
            SELECT 
                FLOOR(ST_X(pickup_point) / {grid_size}) * {grid_size} as grid_x,
                FLOOR(ST_Y(pickup_point) / {grid_size}) * {grid_size} as grid_y,
                COUNT(*) as point_count,
                AVG(fare_amount) as avg_fare,
                time_period,
                COLLECT_LIST(trip_id) as trip_ids
            FROM {table_name}
            WHERE pickup_point IS NOT NULL
            GROUP BY grid_x, grid_y, time_period
            HAVING point_count >= 3
            ORDER BY point_count DESC
        """)
        
        return grid_df
    except Exception as e:
        print(f"❌ Error creating spatial grid: {e}")
        return None

# Create hotspot analysis
hotspots = create_spatial_grid()

if hotspots is not None:
    hotspots.cache()
    hotspots.createOrReplaceTempView("hotspots")  # Create view for other cells
    print(f"✅ Identified {hotspots.count()} spatial hotspots")
    hotspots.show(10)
else:
    print("❌ Could not create hotspots analysis")

In [None]:
# Advanced hotspot analysis with density metrics
hotspots.createOrReplaceTempView("hotspots")

hotspot_analysis = spark.sql("""
    WITH hotspot_stats AS (
        SELECT 
            grid_x,
            grid_y,
            time_period,
            point_count,
            avg_fare,
            ST_Point(grid_x, grid_y) as grid_center,
            -- Calculate local density (points within 500m)
            (
                SELECT SUM(h2.point_count) 
                FROM hotspots h2 
                WHERE h2.time_period = h1.time_period
                  AND ST_Distance(ST_Point(h1.grid_x, h1.grid_y), ST_Point(h2.grid_x, h2.grid_y)) <= 0.005
            ) as neighborhood_density
        FROM hotspots h1
    ),
    ranked_hotspots AS (
        SELECT *,
            ROW_NUMBER() OVER (PARTITION BY time_period ORDER BY neighborhood_density DESC) as density_rank,
            CASE 
                WHEN neighborhood_density >= 100 THEN 'Super Hotspot'
                WHEN neighborhood_density >= 50 THEN 'Major Hotspot'
                WHEN neighborhood_density >= 20 THEN 'Minor Hotspot'
                ELSE 'Regular Area'
            END as hotspot_category
        FROM hotspot_stats
    )
    SELECT 
        time_period,
        hotspot_category,
        COUNT(*) as num_areas,
        AVG(point_count) as avg_pickups_per_area,
        AVG(avg_fare) as avg_fare_in_category,
        MAX(neighborhood_density) as max_density
    FROM ranked_hotspots
    GROUP BY time_period, hotspot_category
    ORDER BY time_period, max_density DESC
""")

print("Hotspot Analysis by Time Period:")
hotspot_analysis.show()

## 4. Route Optimization & Network Analysis

Analyzing optimal routes and identifying inefficient trips.

In [None]:
# Route efficiency analysis
trips_with_zones.createOrReplaceTempView("trips_analysis")

route_efficiency = spark.sql("""
    SELECT 
        trip_id,
        pickup_zone_name,
        dropoff_zone_name,
        euclidean_distance,
        trip_distance,
        fare_amount,
        time_period,
        -- Calculate efficiency metrics
        CASE 
            WHEN euclidean_distance > 0 THEN trip_distance / euclidean_distance 
            ELSE NULL 
        END as detour_ratio,
        
        CASE 
            WHEN trip_distance > 0 THEN fare_amount / trip_distance 
            ELSE NULL 
        END as fare_per_mile,
        
        -- Classify trip efficiency
        CASE 
            WHEN trip_distance / euclidean_distance <= 1.2 THEN 'Efficient'
            WHEN trip_distance / euclidean_distance <= 1.5 THEN 'Moderate'
            ELSE 'Inefficient'
        END as route_efficiency
    FROM trips_analysis
    WHERE euclidean_distance > 0.001  -- Filter out very short trips
      AND trip_distance > 0
      AND pickup_zone_name IS NOT NULL
      AND dropoff_zone_name IS NOT NULL
""")

route_efficiency.cache()
print(f"Route efficiency analysis for {route_efficiency.count()} trips")
route_efficiency.show(10)

In [None]:
# Advanced route analysis with corridor identification
route_efficiency.createOrReplaceTempView("route_efficiency")

corridor_analysis = spark.sql("""
    WITH route_corridors AS (
        SELECT 
            pickup_zone_name,
            dropoff_zone_name,
            time_period,
            COUNT(*) as trip_volume,
            AVG(detour_ratio) as avg_detour,
            AVG(fare_per_mile) as avg_fare_per_mile,
            AVG(euclidean_distance) as avg_distance,
            PERCENTILE_APPROX(detour_ratio, 0.95) as p95_detour,
            -- Efficiency score (lower is better)
            AVG(detour_ratio) * 100 + (1.0 / AVG(fare_per_mile)) as inefficiency_score
        FROM route_efficiency
        WHERE pickup_zone_name != dropoff_zone_name  -- Inter-zone trips only
        GROUP BY pickup_zone_name, dropoff_zone_name, time_period
        HAVING trip_volume >= 5  -- Focus on popular routes
    ),
    ranked_corridors AS (
        SELECT *,
            ROW_NUMBER() OVER (PARTITION BY time_period ORDER BY trip_volume DESC) as volume_rank,
            ROW_NUMBER() OVER (PARTITION BY time_period ORDER BY inefficiency_score DESC) as inefficiency_rank
        FROM route_corridors
    )
    SELECT 
        time_period,
        pickup_zone_name,
        dropoff_zone_name,
        trip_volume,
        ROUND(avg_detour, 2) as avg_detour_ratio,
        ROUND(avg_fare_per_mile, 2) as avg_fare_per_mile,
        ROUND(inefficiency_score, 2) as inefficiency_score,
        volume_rank,
        inefficiency_rank,
        CASE 
            WHEN inefficiency_rank <= 3 THEN 'Optimization Priority'
            WHEN volume_rank <= 5 THEN 'High Volume Corridor'
            ELSE 'Regular Route'
        END as route_category
    FROM ranked_corridors
    WHERE volume_rank <= 10 OR inefficiency_rank <= 5
    ORDER BY time_period, inefficiency_rank
""")

print("Top Route Corridors by Time Period:")
corridor_analysis.show(30, truncate=False)

## 5. Spatial Machine Learning: Demand Prediction

Using spatial features for predictive modeling.

In [None]:
# Create features for ML model
ml_features = spark.sql("""
    WITH spatial_features AS (
        SELECT 
            grid_x,
            grid_y,
            time_period,
            point_count as demand,
            avg_fare,
            -- Spatial features
            grid_x * 1000000 as x_scaled,  -- Scale coordinates
            grid_y * 1000000 as y_scaled,
            
            -- Time-based features
            CASE time_period 
                WHEN 'Morning Rush' THEN 1 
                WHEN 'Evening Rush' THEN 2
                WHEN 'Night' THEN 3
                ELSE 0 
            END as time_encoded,
            
            -- Distance from city center (approximate)
            SQRT(POWER(grid_x - (-73.9851), 2) + POWER(grid_y - 40.7589, 2)) as distance_from_center
        FROM hotspots
        WHERE point_count >= 3
    )
    SELECT *,
        -- Categorize demand levels for classification
        CASE 
            WHEN demand >= 20 THEN 2  -- High demand
            WHEN demand >= 10 THEN 1  -- Medium demand  
            ELSE 0                    -- Low demand
        END as demand_category
    FROM spatial_features
""")

ml_features.cache()
print(f"Created ML features for {ml_features.count()} spatial-temporal points")
ml_features.show(10)

In [None]:
# Prepare data for ML clustering (spatial demand patterns)
feature_cols = ['x_scaled', 'y_scaled', 'time_encoded', 'distance_from_center', 'avg_fare']

# Assemble features
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
ml_data = assembler.transform(ml_features)

# Apply K-means clustering to identify demand patterns
kmeans = KMeans(k=5, seed=42, featuresCol="features", predictionCol="cluster")
model = kmeans.fit(ml_data)
predictions = model.transform(ml_data)

# Analyze clusters
predictions.createOrReplaceTempView("ml_predictions")

cluster_analysis = spark.sql("""
    SELECT 
        cluster,
        COUNT(*) as cluster_size,
        AVG(demand) as avg_demand,
        AVG(avg_fare) as avg_fare_in_cluster,
        AVG(distance_from_center) as avg_distance_from_center,
        
        -- Most common time period in cluster
        MODE() WITHIN GROUP (ORDER BY time_period) as dominant_time_period,
        
        -- Demand characteristics
        MIN(demand) as min_demand,
        MAX(demand) as max_demand,
        STDDEV(demand) as demand_std
    FROM ml_predictions
    GROUP BY cluster
    ORDER BY cluster
""")

print("Spatial-Temporal Demand Clusters:")
cluster_analysis.show()

## 6. Performance Optimization Techniques

Demonstrating advanced Sedona performance optimization.

In [None]:
# Spatial indexing and partitioning strategies
import time

def benchmark_spatial_join(df1, df2, join_condition, description):
    """
    Benchmark different spatial join strategies
    """
    start_time = time.time()
    result_count = join_condition.count()
    end_time = time.time()
    
    print(f"{description}:")
    print(f"  - Result count: {result_count}")
    print(f"  - Execution time: {end_time - start_time:.2f} seconds")
    print(f"  - Partitions: {join_condition.rdd.getNumPartitions()}")
    return result_count, end_time - start_time

# Create test datasets
large_points = spark.sql("""
    SELECT 
        ST_Point(RAND() * 0.1 - 74.0, RAND() * 0.1 + 40.7) as point,
        CAST(RAND() * 1000 AS INT) as point_id
    FROM range(10000)
""")

test_polygons = spark.sql("""
    SELECT 
        ST_Buffer(ST_Point(RAND() * 0.05 - 73.98, RAND() * 0.05 + 40.75), 0.001) as polygon,
        CAST(RAND() * 100 AS INT) as poly_id
    FROM range(100)
""")

large_points.cache()
test_polygons.cache()

print(f"Created {large_points.count()} test points and {test_polygons.count()} test polygons")

In [None]:
# Benchmark different join strategies
large_points.createOrReplaceTempView("large_points")
test_polygons.createOrReplaceTempView("test_polygons")

# Strategy 1: Basic spatial join
basic_join = spark.sql("""
    SELECT p.point_id, pg.poly_id
    FROM large_points p, test_polygons pg
    WHERE ST_Within(p.point, pg.polygon)
""")

benchmark_spatial_join(large_points, test_polygons, basic_join, "Basic Spatial Join")

# Strategy 2: With spatial indexing hint
indexed_join = spark.sql("""
    SELECT /*+ BROADCAST(pg) */ p.point_id, pg.poly_id
    FROM large_points p, test_polygons pg
    WHERE ST_Within(p.point, pg.polygon)
""")

benchmark_spatial_join(large_points, test_polygons, indexed_join, "Broadcast Join Strategy")

# Performance tips summary
print("\n🚀 Performance Optimization Tips:")
print("1. Use broadcast joins for small polygon datasets (< 200MB)")
print("2. Partition data by spatial regions for large datasets")
print("3. Cache frequently accessed spatial DataFrames")
print("4. Use appropriate spatial predicates (ST_Within vs ST_Intersects)")
print("5. Consider spatial indexing for repeated queries")

## 7. Advanced Visualization: Interactive Heatmaps

Creating sophisticated spatial visualizations.

In [None]:
# Prepare data for visualization
viz_data = spark.sql("""
    SELECT 
        grid_x as longitude,
        grid_y as latitude,
        point_count as intensity,
        avg_fare,
        time_period
    FROM hotspots
    WHERE point_count >= 5
    ORDER BY point_count DESC
    LIMIT 500
""")

viz_pandas = viz_data.toPandas()
print(f"Prepared {len(viz_pandas)} points for visualization")

# Create multi-layer interactive map
def create_advanced_heatmap(df):
    # Center map on NYC
    center_lat = df['latitude'].mean()
    center_lon = df['longitude'].mean()
    
    m = folium.Map(
        location=[center_lat, center_lon],
        zoom_start=12,
        tiles='OpenStreetMap'
    )
    
    # Add different layers for different time periods
    time_periods = df['time_period'].unique()
    colors = ['red', 'blue', 'green', 'orange']
    
    for i, period in enumerate(time_periods):
        period_data = df[df['time_period'] == period]
        
        # Create heatmap data
        heat_data = [[row['latitude'], row['longitude'], row['intensity']] 
                    for idx, row in period_data.iterrows()]
        
        if heat_data:  # Only create layer if data exists
            HeatMap(
                heat_data,
                name=f'Demand - {period}',
                radius=15,
                blur=10,
                gradient={0.2: colors[i % len(colors)], 1.0: colors[i % len(colors)]}
            ).add_to(m)
    
    # Add layer control
    folium.LayerControl().add_to(m)
    
    return m

if len(viz_pandas) > 0:
    heatmap = create_advanced_heatmap(viz_pandas)
    print("✅ Interactive heatmap created! (Display in Jupyter)")
    # heatmap  # Uncomment to display in Jupyter
else:
    print("No data available for visualization")

In [None]:
# Advanced statistical analysis
summary_stats = spark.sql("""
    WITH overall_stats AS (
        SELECT 
            COUNT(DISTINCT trip_id) as total_trips,
            COUNT(DISTINCT pickup_zone_name) as unique_pickup_zones,
            COUNT(DISTINCT dropoff_zone_name) as unique_dropoff_zones,
            AVG(fare_amount) as avg_fare,
            AVG(euclidean_distance) as avg_distance,
            SUM(passenger_count) as total_passengers
        FROM trips_with_zones
        WHERE pickup_zone_name IS NOT NULL
    ),
    efficiency_stats AS (
        SELECT 
            route_efficiency,
            COUNT(*) as trip_count,
            AVG(detour_ratio) as avg_detour,
            AVG(fare_per_mile) as avg_fare_per_mile
        FROM route_efficiency
        GROUP BY route_efficiency
    )
    SELECT 
        'Overall Statistics' as metric_type,
        CAST(total_trips AS STRING) as value,
        'Total processed trips' as description
    FROM overall_stats
    
    UNION ALL
    
    SELECT 
        'Spatial Coverage' as metric_type,
        CAST(unique_pickup_zones AS STRING) as value,
        'Unique pickup zones covered' as description
    FROM overall_stats
    
    UNION ALL
    
    SELECT 
        'Route Efficiency' as metric_type,
        CONCAT(route_efficiency, ': ', CAST(trip_count AS STRING), ' trips') as value,
        CONCAT('Avg detour ratio: ', CAST(ROUND(avg_detour, 2) AS STRING)) as description
    FROM efficiency_stats
    ORDER BY metric_type, value
""")

print("\n📊 Advanced Spatial Analytics Summary:")
summary_stats.show(20, truncate=False)

In [None]:
# Cleanup
print("\n🧹 Cleaning up resources...")
spark.catalog.clearCache()
print("Cache cleared.")

print("\n🎯 Complex Spatial Analytics Completed!")
print("\nKey Capabilities Demonstrated:")
print("• Large-scale spatial ETL processing")
print("• Multi-zone geofencing analysis")
print("• Spatial clustering and hotspot detection")
print("• Route optimization analysis")
print("• Spatial machine learning integration")
print("• Performance optimization techniques")
print("• Advanced interactive visualizations")

# Uncomment to stop Spark (keep running for interactive use)
# spark.stop()