In [1]:
""" Filter based on h3 hexagons mapped to polygons

aws emr add-steps --cluster-id <Your EMR cluster id> --steps Type=spark,Name=TestJob,Args=[--deploy-mode,cluster,--master,yarn,--conf,spark.yarn.submit.waitAppCompletion=true,s3a://your-source-bucket/code/pythonjob.py,s3a://your-source-bucket/data/data.csv,s3a://your-destination-bucket/test-output/],ActionOnFailure=CONTINUE
"""

from collections import namedtuple
import logging
import sys

from sedona.register import SedonaRegistrator  
from sedona.utils import SedonaKryoRegistrator, KryoSerializer
from pyspark.sql.functions import udf
from sedona.utils.adapter import Adapter
#from sedona.core.formatMapper.geojsonReader import GeoJsonReader
from sedona.core.formatMapper.shapefileParser import ShapefileReader
from sedona.core.SpatialRDD import PointRDD, SpatialRDD, CircleRDD
from sedona.sql.types import GeometryType
from sedona.core.enums import GridType
from sedona.core.spatialOperator import JoinQueryRaw
from sedona.core.spatialOperator import JoinQuery
from sedona.core.enums import IndexType
from sedona.core.formatMapper.disc_utils import load_spatial_rdd_from_disc, GeoType
from sedona.core.formatMapper import WktReader, GeoJsonReader



import pyproj

from geopy.distance import great_circle
import pandas as pd
import geopandas as gpd
import os

from datetime import timedelta, date, datetime
from statistics import *

from pyspark import SparkContext

from pyspark.sql import SQLContext, SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    LongType,
    StructField,
    IntegerType,
    StringType,
    DoubleType,
    TimestampType,
    ArrayType
)
from pyspark.sql.functions import (
    from_utc_timestamp,
    to_utc_timestamp,
    dayofyear,
    col,
    unix_timestamp,
    monotonically_increasing_id,
    pandas_udf,
    PandasUDFType,
    col,
    asc,
    lit,
    countDistinct,
)
import pyspark.sql.functions as F
from math import *
import time

from shapely.wkt import loads as wkt_loads
from shapely.geometry import Point, Polygon, shape
from shapely.ops import transform
import shapely

#from timezonefinder import TimezoneFinder

# Build the Spark session with Sedona's Kryo serializer/registrator configured.
# getOrCreate() hands back an already-running session when one exists.
# NOTE(review): if the session was started by the notebook kernel, the settings
# below may not be applied to it -- confirm.
spark = (
    SparkSession.builder
    .appName("sedona")
    .config("spark.serializer", KryoSerializer.getName)
    .config("spark.kryo.registrator", SedonaKryoRegistrator.getName)
    .config("spark.driver.maxResultSize", "3g")
    .getOrCreate()
)

# Register Sedona UDTs and UDFs (this is what makes ST_Point etc. usable in SQL).
SedonaRegistrator.registerAll(spark)

# Ship the shared helper module through the SparkContext; it is imported right after.
spark.sparkContext.addPyFile("s3://ipsos-dvd/scripts/utils.py")
from utils import *
import h3_pyspark as h3s
#import h3pandas 
import h3 as h3



@udf("boolean")
def pip_filter(poly_wkt, point_x, point_y):
    """Point-in-polygon test exposed as a boolean Spark UDF.

    Args:
        poly_wkt: polygon encoded as a WKT string.
        point_x: x coordinate of the point (first argument to ``Point``).
        point_y: y coordinate of the point (second argument to ``Point``).

    Returns:
        True when the polygon contains the point, False when it does not,
        and None (SQL NULL) when any of the three inputs is NULL.
    """
    # Imported inside the function so the UDF body carries its own dependencies.
    from shapely import wkt
    from shapely import geometry

    # A NULL in any input column would otherwise raise inside shapely and fail
    # the whole task; propagate NULL instead.
    if poly_wkt is None or point_x is None or point_y is None:
        return None

    polygon = wkt.loads(poly_wkt)
    point = geometry.Point(point_x, point_y)
    return polygon.contains(point)

# Struct with two non-nullable array-of-string fields: "dirty" and "hexes".
schema = StructType(
    [
        StructField(field_name, ArrayType(StringType()), False)
        for field_name in ("dirty", "hexes")
    ]
)


# NOTE(review): `parse_dates` is defined a second time further down in this cell
# with the same logic; the later definition is the one that stays bound.
def parse_dates(x):
    """Expand a date or date range into a list of ``YYYY/MM/DD`` strings.

    Args:
        x: either a single date (returned with every "-" replaced by "/"),
            or a range written as ``"YYYY-MM-DD/YYYY-MM-DD"``.

    Returns:
        A list of strings. For a range, one entry per day from the start to
        the end date inclusive (empty when the end precedes the start). For a
        single date, a one-element list.

    Raises:
        ValueError: if either side of a range is not in ``%Y-%m-%d`` format.
    """
    if "/" not in x:
        return [x.replace("-", "/")]

    pieces = x.split("/")
    first_day = datetime.strptime(pieces[0], "%Y-%m-%d")
    last_day = datetime.strptime(pieces[1], "%Y-%m-%d")
    n_days = (last_day - first_day).days + 1
    return [
        (first_day + timedelta(days=offset)).strftime("%Y/%m/%d")
        for offset in range(n_days)
    ]
    
def create_point(longitude: float, latitude: float):
    """Build a shapely ``Point`` with x = ``longitude`` and y = ``latitude``."""
    xy = (longitude, latitude)
    return Point(*xy)

# Spark UDF wrapper whose declared return type is Sedona's GeometryType.
create_point_udf = udf(f=create_point, returnType=GeometryType())


def create_polygon(wkt: str):
    """Parse a WKT string into a shapely geometry via ``shapely.wkt.loads``."""
    parsed_geometry = wkt_loads(wkt)
    return parsed_geometry

# Spark UDF wrapper whose declared return type is Sedona's GeometryType.
create_polygon_udf = udf(f=create_polygon, returnType=GeometryType())

def transform_geometry(geom, crs_from = 'EPSG:4326', crs_to = 'EPSG:9311'):
    """Reproject a geometry from ``crs_from`` to ``crs_to``.

    Args:
        geom: a shapely geometry, or any object ``shapely.geometry.shape``
            accepts (it is converted when it is not already a shapely geometry).
        crs_from: CRS identifier of the input coordinates.
        crs_to: CRS identifier of the output coordinates.

    Returns:
        The reprojected shapely geometry.
    """
    # A new transformer is built on every call; always_xy=True fixes the axis
    # order to (x, y) regardless of the CRS definition.
    source_crs = pyproj.CRS(crs_from)
    target_crs = pyproj.CRS(crs_to)
    reproject = pyproj.Transformer.from_crs(source_crs, target_crs, always_xy=True).transform

    # Ensure that the input geometry is a shapely geometry object
    shapely_bases = (
        shapely.geometry.base.BaseGeometry,
        shapely.geometry.base.BaseMultipartGeometry,
    )
    if isinstance(geom, shapely_bases):
        shapely_geom = geom
    else:
        shapely_geom = shape(geom)

    return transform(reproject, shapely_geom)

# Spark UDF wrapper; only `geom` is supplied per row unless extra columns are passed.
transform_geometry_udf = udf(f=transform_geometry, returnType=GeometryType())


def shared_polygon(long, lat, radius=35):
    """Return a circular buffer polygon around the point (``long``, ``lat``).

    Args:
        long: x coordinate of the centre.
        lat: y coordinate of the centre.
        radius: buffer distance passed to ``Point.buffer``. Defaults to 35,
            the value that used to be hard-coded.

    Returns:
        The buffered shapely geometry.
    """
    # NOTE(review): the distance is applied in the coordinates' own units. If the
    # inputs are lon/lat degrees, 35 is far larger than 35 metres -- confirm the
    # inputs are projected before relying on this.
    return Point(long, lat).buffer(radius)
shared_polygon_udf = udf(shared_polygon, GeometryType())

def buffer(geom, meters):
    """Return ``geom`` grown by ``meters`` using the geometry's own ``buffer`` method.

    NOTE(review): the distance is applied in the units of ``geom``'s coordinates;
    it is only metres if the geometry is in a metre-based projection -- confirm.
    """
    grown = geom.buffer(meters)
    return grown
buffer_udf = udf(f=buffer, returnType=GeometryType())

# NOTE(review): `parse_dates` is also defined earlier in this cell with the same
# logic; this later definition is the one that stays bound. One copy can be removed.
def parse_dates(x):
    """Expand a date or date range into a list of ``YYYY/MM/DD`` strings.

    Args:
        x: either a single date (returned with every "-" replaced by "/"),
            or a range written as ``"YYYY-MM-DD/YYYY-MM-DD"``.

    Returns:
        A list of strings. For a range, one entry per day from the start to
        the end date inclusive (empty when the end precedes the start). For a
        single date, a one-element list.

    Raises:
        ValueError: if either side of a range is not in ``%Y-%m-%d`` format.
    """
    if "/" not in x:
        return [x.replace("-", "/")]

    pieces = x.split("/")
    first_day = datetime.strptime(pieces[0], "%Y-%m-%d")
    last_day = datetime.strptime(pieces[1], "%Y-%m-%d")
    n_days = (last_day - first_day).days + 1
    return [
        (first_day + timedelta(days=offset)).strftime("%Y/%m/%d")
        for offset in range(n_days)
    ]


def time_zone(long, lat):
    """Look up the time-zone name for a coordinate with ``TimezoneFinder``.

    Args:
        long: longitude, passed as ``lng``.
        lat: latitude, passed as ``lat``.

    Returns:
        Whatever ``TimezoneFinder.timezone_at`` returns for the coordinate,
        or None when either coordinate is NULL.
    """
    # The module-level `from timezonefinder import TimezoneFinder` is commented
    # out, so the name is imported here to keep this function from depending on
    # a star-import happening to provide it.
    from timezonefinder import TimezoneFinder

    # NULL coordinates cannot be looked up; propagate NULL.
    if long is None or lat is None:
        return None

    #return tzwhere.tzwhere().tzNameAt(lat, long)
    # NOTE(review): a new TimezoneFinder is constructed on every call; if this
    # UDF is used on large data, consider reusing one instance per worker.
    tzf = TimezoneFinder()
    return tzf.timezone_at(lng=long, lat=lat)

time_zone_udf = udf(time_zone, StringType())

#------
# parameters
#------

# S3 prefix named for SafeGraph data (not referenced by the cells shown here).
data_dir = "s3://external-safegraph/"
# Project working prefix: the cells below write pings/ and pings_homes/ under it
# and read census/census_wrangle.csv.gz from it.
data_dyn = "s3://ipsos-dvd/dyn/data/"
# Prefix for the raw ping parquet files; a "YYYY/MM/DD" fragment is appended to it.
data_veraset = "s3://external-veraset-data-us-west-2/us/"




VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
9,application_1679926457970_0010,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
# Single-day ping preparation. `date` is the day's path fragment ("YYYY/MM/DD")
# appended to both the raw input prefix and the cached output prefix.
# NOTE(review): this rebinds `date`, hiding the `datetime.date` class imported
# in the first cell.
date = "2022/06/01"
path = data_veraset + date
# When True, rebuild the cached point parquet for `date` from the raw files;
# when False, the cached copy under data_dyn + "pings/" is assumed to exist.
redo_points = False

if redo_points:
    # Columns declared for the raw parquet read; only four of them are selected below.
    essential_fields = [
            StructField("utc_timestamp",LongType(),False),
            StructField("caid",StringType(),False),
            StructField("latitude",DoubleType(),False),
            StructField("longitude",DoubleType(),False),
            StructField("altitude",DoubleType(),False),
    ]
    raw_schema = StructType(
        essential_fields + [
            StructField("id_type",StringType(),False),
            StructField("geo_hash",StringType(),False),
            StructField("horizontal_accuracy",DoubleType(),False),
            StructField("ip_address",StringType(),False),
            #StructField("altitude",DoubleType(),False),
            StructField("iso_country_code",StringType(),False)]
    )
    pings = spark.read.schema(raw_schema).parquet(path).select("latitude", "longitude", "caid", "utc_timestamp")# .limit(1000)
    # Expose the DataFrame to Spark SQL under the view name "pings".
    pings.createOrReplaceTempView("pings")

    # Query the temp view registered above, turning longitude/latitude into a
    # Sedona point geometry column named `point`.
    pings = spark.sql(
          """SELECT ST_Point(cast(pings.longitude as Decimal(24,20)), cast(pings.latitude as Decimal(24,20))) AS point, 
          utc_timestamp, caid
          FROM pings;
          """
    )
    # Overwrite the cached per-day point parquet.
    pings.write.mode("overwrite").parquet(data_dyn + "pings/" + date)


# Load the cached per-day points and hash-partition them by `caid` into 10000 partitions.
pings = spark.read.parquet(data_dyn + "pings/" + date) #.limit(1000000)
pings = pings.repartition(10000, "caid")


#                 .withColumn("date", # filter on nighttime hours
#                     from_utc_timestamp(
#                         col("utc_timestamp").cast(dataType=TimestampType()),
#                         col('UTC_OFFSET'),
#                     )
#                 )
#                 .withColumn('hour', F.hour(col('date')))
#                 .filter((col('hour').between(20,24)) | (col('hour').between(0,7))) # between 8pm and 7am
#                 )
    

In [5]:
# --- Time-zone polygons -----------------------------------------------------
# Read the shapefile with geopandas, reproject to EPSG:4326, keep only the
# `utc` and `geometry` columns, then convert to a Spark DataFrame.
fn = "s3://ipsos-dvd/data/Time_Zones/"
test = gpd.read_file(fn)
test = test.to_crs("epsg:4326")[['utc', 'geometry']]
shpf = spark.createDataFrame(test)
# Prefix the `utc` value with the literal "UTC".
shpf = shpf.withColumn("utc", F.concat(lit("UTC"), F.col("utc")))
shpf = shpf.cache()

# --- Census block group (CBG) polygons ---------------------------------------
# Use the session's own SparkContext instead of relying on a notebook-injected
# `sc` global, so the cell also runs where `sc` is not predefined.
fn = "s3://ipsos-dvd/ev/data/2020_cbgs/"
polygons_rdd = ShapefileReader.readToGeometryRDD(spark.sparkContext, fn)
# Despite the name, from here on `polygons_rdd` is a DataFrame.
polygons_rdd = (Adapter.toDf(polygons_rdd, polygons_rdd.fieldNames, spark))

polygons_rdd = polygons_rdd.repartition(5000)

# filter out zero population CBGs
# The CSV is read without a schema, so `pop` is a string column and the `> 0`
# comparison below relies on Spark's implicit cast.
pop = spark.read.csv(os.path.join(data_dyn, 'census', 'census_wrangle.csv.gz'), header = True).select("census_block_group", "pop")
polygons_rdd = polygons_rdd.join(pop, pop.census_block_group == polygons_rdd.CensusBloc, how = "left")
# Rows with no census match have a NULL `pop` and are dropped by this filter too.
polygons_rdd = polygons_rdd.filter(F.col("pop") > 0)
polygons_rdd = polygons_rdd.drop("census_block_group", "pop")#.createOrReplaceTempView("polygons_rdd")
polygons_rdd = polygons_rdd.cache()



VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
# Per-day spatial join of sampled pings to census-block-group polygons.
# For every day in the range: load the cached point parquet, take a seeded
# sample, spatially partition points and polygons the same way, join them and
# write the result to data_dyn/pings_homes/<YYYY/MM/DD>.

# ---- configuration ------------------------------------------------------------
# use spatial partitioning
grid_type = GridType.KDBTREE

START_DATE = "2022-06-02"
END_DATE = "2022-06-30"
dates = [START_DATE + "/" + END_DATE]

# Share of each day's pings that goes into the join, and the seed that makes the
# draw repeatable for the same input files.
SAMPLE_FRACTION = 0.25
SAMPLE_SEED = 42

# Number of partitions used when hash-partitioning pings by "caid".
num_partitions = 10000

# Set to True to rebuild data_dyn/pings/<date> from the raw parquet before joining.
redo_points = False

# Spatial index / join switches (the same values for every day).
build_on_spatial_partitioned_rdd = True ## Set to TRUE only if run join query
using_index = True
consider_boundary_intersection = True

# Schema declared for the raw parquet read; only used when redo_points is True.
essential_fields = [
    StructField("utc_timestamp", LongType(), False),
    StructField("caid", StringType(), False),
    StructField("latitude", DoubleType(), False),
    StructField("longitude", DoubleType(), False),
    StructField("altitude", DoubleType(), False),
]
raw_schema = StructType(
    essential_fields + [
        StructField("id_type", StringType(), False),
        StructField("geo_hash", StringType(), False),
        StructField("horizontal_accuracy", DoubleType(), False),
        StructField("ip_address", StringType(), False),
        StructField("iso_country_code", StringType(), False),
    ]
)

# set up date ranges in path form
datelist = []
for arg in dates:
    datelist.extend(parse_dates(arg))

# create list of s3 folders for dates to get
base_bucket = "s3://external-veraset-data-us-west-2/us/" # alternative: "s3://external-veraset-data-us-west-2/movement/"
#base_bucket = f"/home/antonvocalis/ipsos/data/in/veraset/"

# Earlier experiments that used to sit in this loop as commented-out code:
# salting the repartition key, broadcast ST_Within joins against `shpf` and
# `polygons_rdd`, and a local night-time (hour >= 20 or hour < 7) filter.

# ---- per-day processing ---------------------------------------------------------
# NOTE(review): the loop variable rebinds `date`, hiding the `datetime.date`
# class imported in the first cell.
for date in datelist:
    path = data_veraset + date

    if redo_points:
        pings = (
            spark.read.schema(raw_schema)
            .parquet(path)
            .select("latitude", "longitude", "caid", "utc_timestamp")
        )
        pings.createOrReplaceTempView("pings")

        # Query the temp view, turning longitude/latitude into a point geometry.
        pings = spark.sql(
              """SELECT ST_Point(cast(pings.longitude as Decimal(24,20)), cast(pings.latitude as Decimal(24,20))) AS point, 
              utc_timestamp, caid
              FROM pings;
              """
        )
        pings.write.mode("overwrite").parquet(data_dyn + "pings/" + date)

    # Load the cached points, sample them reproducibly, and hash-partition by caid.
    pings = (
        spark.read.parquet(data_dyn + "pings/" + date)
        .sample(fraction=SAMPLE_FRACTION, seed=SAMPLE_SEED)
        .repartition(num_partitions, "caid")
    )

    # convert to spatial rdds
    points_rdd = Adapter.toSpatialRdd(pings, "point")
    points_rdd.analyze()
    points_rdd.spatialPartitioning(grid_type)

    # Polygons must share the points' partitioner so matching cells are co-located.
    poly_rdd = Adapter.toSpatialRdd(polygons_rdd, "geometry")
    poly_rdd.analyze()
    poly_rdd.spatialPartitioning(points_rdd.getPartitioner())

    points_rdd.buildIndex(IndexType.QUADTREE, build_on_spatial_partitioned_rdd)

    # spatial join
    result = JoinQueryRaw.SpatialJoinQueryFlat(
        points_rdd, poly_rdd, using_index, consider_boundary_intersection
    )

    # TODO: this is currently slow, may need to remove unrealistically large polygons?
    (
        Adapter.toDf(result, poly_rdd.fieldNames, points_rdd.fieldNames, spark)
        .write.mode("overwrite")
        .parquet(os.path.join(data_dyn, 'pings_homes', date))
    )
    # `pings` is never cached in this loop, so there is nothing to unpersist.


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…