# Spatial Joins and Analysis with Apache Sedona

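This notebook demonstrates spatial joins, buffer-based proximity analysis, and spatial aggregation with Apache Sedona, using a handful of sample US cities and simplified (rectangular) state boundaries.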

In [None]:
from sedona.spark import SedonaContext
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from typing import List, Tuple
import logging

# Setup logging: configure the root logger at INFO level (basicConfig leaves
# existing root handlers alone if some are already installed), then grab a
# logger named after this module for the cells below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

In [None]:
# Initialize Sedona session.
# Build (or reuse) a SparkSession configured for Kryo serialization with
# Sedona's registrator, then wrap it with SedonaContext so the ST_* SQL
# functions used in later cells are available through `sedona.sql`.
# NOTE(review): nothing here sets spark.jars.packages, so the Sedona jars are
# assumed to already be on the Spark classpath — confirm for your environment.
spark_session = (
    SparkSession.builder
    .appName("SpatialJoinsAnalysis")
    .config("spark.kryo.registrator", "org.apache.sedona.core.serde.SedonaKryoRegistrator")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .getOrCreate()
)
sedona = SedonaContext.create(spark_session)

logger.info("Sedona session initialized for spatial analysis")

In [None]:
# Sample city points as (name, longitude, latitude, rounded population).
cities_data = [
    ("NYC", -74.0059, 40.7128, 8500000),
    ("LA", -118.2437, 34.0522, 4000000),
    ("Chicago", -87.6298, 41.8781, 2700000),
    ("Houston", -95.3698, 29.7604, 2300000),
    ("Phoenix", -112.0740, 33.4484, 1700000),
]

# Explicit schema (all fields nullable) so column types do not depend on inference.
cities_schema = (
    StructType()
    .add("city_name", StringType(), True)
    .add("longitude", DoubleType(), True)
    .add("latitude", DoubleType(), True)
    .add("population", IntegerType(), True)
)

cities_df = sedona.createDataFrame(data=cities_data, schema=cities_schema)
# Register as a SQL view so later sedona.sql queries can refer to `cities`.
cities_df.createOrReplaceTempView("cities")

In [None]:
# Sample polygon regions: axis-aligned rectangles standing in for (very)
# simplified state boundaries, stored as WKT text and parsed later in SQL.
regions_data = [
    ("California", "POLYGON((-124.4 32.5, -114.1 32.5, -114.1 42.0, -124.4 42.0, -124.4 32.5))"),
    ("Texas", "POLYGON((-106.6 25.8, -93.5 25.8, -93.5 36.5, -106.6 36.5, -106.6 25.8))"),
    ("New York", "POLYGON((-79.8 40.5, -71.9 40.5, -71.9 45.0, -79.8 45.0, -79.8 40.5))"),
]

regions_schema = (
    StructType()
    .add("state_name", StringType(), True)
    .add("geometry_wkt", StringType(), True)
)

regions_df = sedona.createDataFrame(data=regions_data, schema=regions_schema)
# Register as a SQL view so later sedona.sql queries can refer to `regions`.
regions_df.createOrReplaceTempView("regions")

In [None]:
# Perform spatial join: cities within states.
# Geometries are built once per input row in CTEs instead of being rebuilt
# (and the WKT re-parsed) in every SELECT/WHERE expression, and the spatial
# predicate is stated as the join condition of an explicit inner JOIN.
# `is_within` is kept for schema compatibility; every row that survives the
# ST_Contains join condition is, by construction, within its state.
# With the sample data, Chicago and Phoenix fall outside every region and are
# dropped by the inner join.
spatial_join_df = sedona.sql("""
    WITH city_points AS (
        SELECT
            city_name,
            population,
            ST_Point(longitude, latitude) AS city_geometry
        FROM cities
    ),
    state_polygons AS (
        SELECT
            state_name,
            ST_GeomFromWKT(geometry_wkt) AS state_geometry
        FROM regions
    )
    SELECT
        c.city_name,
        c.population,
        s.state_name,
        c.city_geometry,
        s.state_geometry,
        TRUE AS is_within
    FROM city_points c
    JOIN state_polygons s
        ON ST_Contains(s.state_geometry, c.city_geometry)
""")

print("Cities within their respective states:")
spatial_join_df.show()

In [None]:
# Create buffers around cities and find which other cities fall inside them.
# Coordinates are longitude/latitude, so the radius is in degrees; a degree
# is not a fixed ground distance (a degree of longitude shrinks with latitude).
# With the sample data no pair is this close (the nearest pair, LA-Phoenix, is
# roughly 6.2 degrees apart), so the result is empty unless the radius is raised.
BUFFER_RADIUS_DEG = 2.0

buffer_analysis_df = sedona.sql(f"""
    WITH city_buffers AS (
        SELECT
            city_name,
            population,
            ST_Point(longitude, latitude) AS city_point,
            ST_Buffer(ST_Point(longitude, latitude), {BUFFER_RADIUS_DEG}) AS city_buffer
        FROM cities
    ),
    intersections AS (
        SELECT
            c1.city_name AS city1,
            c2.city_name AS city2,
            c1.population AS pop1,
            c2.population AS pop2,
            -- Point-in-buffer test: is city2 inside city1's buffer zone?
            -- (column name kept for compatibility with earlier output)
            ST_Intersects(c1.city_buffer, c2.city_point) AS buffers_intersect,
            ST_Distance(c1.city_point, c2.city_point) AS distance_degrees
        FROM city_buffers c1
        CROSS JOIN city_buffers c2
        -- `<` (not `!=`) reports each unordered pair once instead of twice;
        -- the buffers are centred, same-radius shapes, so the test is symmetric.
        WHERE c1.city_name < c2.city_name
    )
    SELECT *
    FROM intersections
    WHERE buffers_intersect
    ORDER BY distance_degrees
""")

print(f"Cities within {BUFFER_RADIUS_DEG:g}-degree buffer of each other:")
buffer_analysis_df.show()

In [None]:
# Advanced spatial aggregation: convex hull of major cities.
# ST_Collect is a scalar function in Sedona, so applying it to un-grouped
# columns next to COUNT/SUM fails Spark's analysis (non-aggregated column with
# no GROUP BY). ST_Union_Aggr is the row-aggregating counterpart: it merges all
# qualifying points into one geometry, whose convex hull is then taken.
MIN_POPULATION = 2_000_000

convex_hull_df = sedona.sql(f"""
    SELECT
        ST_ConvexHull(
            ST_Union_Aggr(ST_Point(longitude, latitude))
        ) AS cities_convex_hull,
        COUNT(*) AS total_cities,
        SUM(population) AS total_population
    FROM cities
    WHERE population > {MIN_POPULATION}
""")

print("Convex hull and statistics for major cities:")
convex_hull_df.show(truncate=False)

In [None]:
# Performance optimization: cache intermediate results
def cache_spatial_dataframe(df, name: str):
    """Mark a DataFrame for caching and materialise it right away.

    ``count()`` is an action, so it forces the cached plan to be evaluated
    now rather than on the first later use; the row total is logged.

    Args:
        df: DataFrame to cache (any object exposing ``cache()`` and ``count()``
            where ``cache()`` returns the same object, as PySpark's does).
        name: Label used in the log message.

    Returns:
        The same ``df`` object that was passed in.
    """
    row_total = df.cache().count()
    logger.info(f"Cached {name} with {row_total} rows")
    return df

# Cache (and eagerly materialise) the spatial join so follow-up analysis reuses it.
cached_spatial_join = cache_spatial_dataframe(df=spatial_join_df, name="spatial_join_result")

## Spatial Analysis Best Practices

1. **Spatial Joins**: Use `ST_Contains`, `ST_Intersects`, `ST_Within` for relationships
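2. **Buffering**: `ST_Buffer()` creates a polygonal approximation of a circular zone around a geometry; the radius is in the geometry's coordinate units (degrees for longitude/latitude data, which is not a uniform ground distance)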
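3. **Aggregations**: Combine geometries across rows with aggregate functions such as `ST_Union_Aggr` or `ST_Envelope_Aggr`, then derive summary shapes with `ST_ConvexHull` (scalar functions like `ST_Union` or `ST_Collect` do not aggregate over rows)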
4. **Performance**: Cache intermediate results for complex workflows
5. **Validation**: Check geometries with `ST_IsValid` (and repair them with `ST_MakeValid` where needed) before running spatial operations