In [1]:
""" Filter based on h3 hexagons mapped to polygons

aws emr add-steps --cluster-id <Your EMR cluster id> --steps Type=spark,Name=TestJob,Args=[--deploy-mode,cluster,--master,yarn,--conf,spark.yarn.submit.waitAppCompletion=true,s3a://your-source-bucket/code/pythonjob.py,s3a://your-source-bucket/data/data.csv,s3a://your-destination-bucket/test-output/],ActionOnFailure=CONTINUE
"""

from collections import namedtuple
import logging
import sys

from geopy.distance import great_circle
import pandas as pd
import geopandas as gpd
import os

from datetime import timedelta, date, datetime
from statistics import *

from pyspark import SparkContext

from pyspark.sql import SQLContext, SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    LongType,
    StructField,
    IntegerType,
    StringType,
    DoubleType,
    TimestampType,
    ArrayType
)
from pyspark.sql.functions import (
    from_utc_timestamp,
    to_utc_timestamp,
    dayofyear,
    col,
    unix_timestamp,
    monotonically_increasing_id,
    pandas_udf,
    PandasUDFType,
    col,
    asc,
    lit,
    countDistinct,
)
import pyspark.sql.functions as F
from math import *
import time

from shapely import wkt

# Obtain the Spark session for this notebook (getOrCreate reuses an existing one if present).
# The application name is a plain string: the original f-string prefix had no placeholders.
spark = SparkSession.builder.appName("demographics_ev").getOrCreate()
# Distribute the shared helper module so that `utils` can be imported right after this line.
spark.sparkContext.addPyFile("s3://ipsos-dvd/scripts/utils.py")
from utils import *
import h3_pyspark as h3s
import h3pandas 
import h3 as h3

# Input bucket: the "places/*.gz" and "spend/*.gz" CSV dumps are loaded from here.
data_dir = "s3://external-safegraph/"
# Working area: intermediate and result parquet datasets (spend_places, poly_h3, caids_shops/...) live here.
data_dyn = "s3://ipsos-dvd/dyn/data/"
# Root of the daily ping data; a "YYYY/MM/DD" suffix is appended for each processed day.
data_veraset = "s3://external-veraset-data-us-west-2/us/"


@F.udf("boolean")
def pip_filter(poly_wkt, point_x, point_y):
    """Spark UDF: exact point-in-polygon test.

    Args:
        poly_wkt: polygon geometry as a WKT string.
        point_x: x coordinate of the point (longitude when the WKT is lon/lat).
        point_y: y coordinate of the point (latitude when the WKT is lon/lat).

    Returns:
        True when the polygon contains the point; False otherwise, including
        when any of the inputs is null.
    """
    # The decorator is referenced through the explicitly imported `F` module
    # instead of a bare `udf` name that no visible import provides.
    # Imports stay local to the function, as in the original
    # (presumably so they are resolved inside the executor's Python worker).
    from shapely import wkt
    from shapely import geometry

    # A null geometry or coordinate cannot be tested; exclude the row rather
    # than failing the whole Spark task inside wkt.loads / Point.
    if poly_wkt is None or point_x is None or point_y is None:
        return False

    polygon = wkt.loads(poly_wkt)
    point = geometry.Point(point_x, point_y)
    # Coerce to a plain Python bool so the value always maps onto BooleanType.
    return bool(polygon.contains(point))

# Struct of two non-nullable string arrays, "dirty" and "hexes".
schema = StructType(
    [
        StructField(field_name, ArrayType(StringType()), False)
        for field_name in ("dirty", "hexes")
    ]
)


def parse_dates(x):
    """Expand a date spec into a list of ``YYYY/MM/DD`` strings.

    ``x`` is either a single day written ``YYYY-MM-DD`` or an inclusive range
    written ``YYYY-MM-DD/YYYY-MM-DD``. A range yields one entry per calendar
    day from the start date to the end date (both included); a single day
    yields a one-element list with the dashes replaced by slashes.
    """
    # Single day: no parsing, just switch the separator.
    if "/" not in x:
        return [x.replace("-", "/")]

    bounds = x.split("/")
    first_day = datetime.strptime(bounds[0], "%Y-%m-%d")
    last_day = datetime.strptime(bounds[1], "%Y-%m-%d")
    span_days = (last_day - first_day).days

    # +1 so that the end date itself is part of the result.
    return [
        (first_day + timedelta(days=offset)).strftime("%Y/%m/%d")
        for offset in range(span_days + 1)
    ]


VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
3,application_1678151895459_0004,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## Join Places and Spend and Filter on Grocery Stores

In [12]:
# Expected columns of the places dump, all declared as nullable strings.
# NOTE: this schema is NOT passed to the reader below (column types are inferred
# instead); it is kept as a record of the expected layout and so that `schema`
# remains defined exactly as before.
schema = StructType([
    StructField(column_name, StringType(), True)
    for column_name in (
        "placekey",
        "parent_placekey",
        "safegraph_brand_ids",
        "location_name",
        "brands",
        "store_id",
        "top_category",
        "sub_category",
        "naics_code",
        "latitude",
        "longitude",
        "street_address",
        "city",
        "region",
        "postal_code",
        "open_hours",
        "category_tags",
        "opened_on",
        "closed_on",
        "tracking_closed_since",
        "websites",
        "geometry_type",
        "polygon_wkt",
        "polygon_class",
        "enclosed",
        "phone_number",
        "is_synthetic",
        "includes_parking_lot",
        "iso_country_code",
        "wkt_area_sq_meters",
    )
])

# Read through the `spark` session created at the top of the notebook rather than
# an implicit `sqlContext` global, and use the built-in "csv" source (the
# 'com.databricks.spark.csv' name is only a legacy alias for it).
places = (
    spark.read.format("csv")
    .options(header='true', inferschema='true')
    # Use '"' as the escape character so that doubled quotes inside quoted fields
    # (which may also contain commas) are parsed correctly; Spark's default escape is '\'.
    .option('escape', '"')
    .load(data_dir + "places/*.gz")
)
places = places.cache()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [13]:
# Expected columns of the spend dump, all declared as nullable strings.
# NOTE: this schema is NOT passed to the reader below (column types are inferred
# instead); it is kept as a record of the expected layout and so that `schema`
# remains defined exactly as before.
schema = StructType([
    StructField(column_name, StringType(), True)
    for column_name in (
        "placekey",
        "safegraph_brand_ids",
        "brands",
        "spend_date_range_start",
        "spend_date_range_end",
        "raw_total_spend",
        "raw_num_transactions",
        "raw_num_customers",
        "median_spend_per_transaction",
        "median_spend_per_customer",
        "spend_per_transaction_percentiles",
        "spend_by_day",
        "spend_per_transaction_by_day",
        "spend_by_day_of_week",
        "day_counts",
        "spend_pct_change_vs_prev_month",
        "spend_pct_change_vs_prev_year",
        "online_transactions",
        "online_spend",
        "transaction_intermediary",
        "spend_by_transaction_intermediary",
        "bucketed_customer_frequency",
        "mean_spend_per_customer_by_frequency",
        "bucketed_customer_incomes",
        "mean_spend_per_customer_by_income",
        "customer_home_city",
        "related_cross_shopping_physical_brands_pct",
        "related_cross_shopping_online_merchants_pct",
        "related_cross_shopping_same_category_brands_pct",
        "related_cross_shopping_local_brands_pct",
        "related_wireless_carrier_pct",
        "related_streaming_cable_pct",
        "related_delivery_service_pct",
        "related_rideshare_service_pct",
        "related_buynowpaylater_service_pct",
        "related_payment_platform_pct",
    )
])

# Read through the `spark` session created at the top of the notebook rather than
# an implicit `sqlContext` global, and use the built-in "csv" source (the
# 'com.databricks.spark.csv' name is only a legacy alias for it).
spend = (
    spark.read.format("csv")
    .options(header='true', inferschema='true')
    # Use '"' as the escape character so that doubled quotes inside quoted fields
    # (which may also contain commas) are parsed correctly; Spark's default escape is '\'.
    .option('escape', '"')
    .load(data_dir + "spend/*.gz")
)
spend = spend.cache()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [14]:
# Keep the join key plus every places column that spend does not already carry,
# so the join below does not produce duplicated column names.
colskeep = [
    name
    for name in places.columns
    if name == "placekey" or name not in spend.columns
]
places = places.select(*colskeep)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [17]:
# Left join: every spend row is kept and place attributes are attached where the placekey matches.
spend = spend.join(places, how="left", on=["placekey"])

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [23]:
# Keep only rows whose top_category is one of the grocery-type categories below.
category_list = [
    "Grocery Stores",
    'Grocery and Related Product Merchant Wholesalers',
    "General Merchandise Stores, including Warehouse Clubs and Supercenters",
]

spend = spend.where(F.col("top_category").isin(*category_list))

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [28]:
# Persist the filtered spend/places join, replacing any previous copy.
spend.write.parquet(os.path.join(data_dyn, "spend_places"), mode="overwrite")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## Find Veraset devices visiting shops

In [2]:
# H3 resolution used both when mapping store polygons to hexagons and when indexing pings.
resolution = 11

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [4]:
# Reload the stored spend/places join from the working area.
spend = spark.read.format("parquet").load(os.path.join(data_dyn, "spend_places"))

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
# One row per distinct polygon geometry; when several placekeys share a polygon,
# only one of them is retained.
polys = spend.select("placekey", "polygon_wkt").drop_duplicates(subset=["polygon_wkt"])

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [7]:
# Smoke test of the h3 binding: index a single coordinate pair at resolution 12
# (note: not the `resolution` value used for the actual processing).
test = h3.geo_to_h3(37.3615593, -122.0553238, 12)
test

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

'8c283470d921dff'

In [8]:
# Spread the polygons over many partitions: attach a throw-away random integer
# key in [0, 1,000,000), hash-partition on that key, drop it again, then cache
# the redistributed frame.
num_partitions = 10000
polys = (
    polys.withColumn("salt", (F.rand() * 1000000).cast("integer"))
    .repartition(num_partitions, "salt")
    .drop("salt")
    .cache()
)
polys.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+--------------------+
|           placekey|         polygon_wkt|
+-------------------+--------------------+
|223-223@5vg-3sp-w49|POLYGON ((-122.00...|
|222-222@63v-6px-2x5|POLYGON ((-81.839...|
|223-226@63g-5td-d35|POLYGON ((-80.034...|
|zzw-225@5pr-2zg-hdv|POLYGON ((-94.441...|
|226-222@5xd-qfh-z9f|POLYGON ((-122.99...|
|225-222@5py-swz-bc5|POLYGON ((-87.161...|
|zzw-222@8g8-3xy-5fz|POLYGON ((-84.684...|
|225-222@63s-9vc-k75|POLYGON ((-76.586...|
|222-222@63t-rpf-w8v|POLYGON ((-80.360...|
|zzw-222@62k-f4f-6rk|POLYGON ((-71.705...|
|zzw-222@5st-562-g49|POLYGON ((-103.81...|
|222-223@5q9-bqt-8jv|POLYGON ((-104.99...|
|222-226@63g-46t-435|POLYGON ((-80.087...|
|222-222@5qz-n2g-b8v|POLYGON ((-100.80...|
|222-223@8fz-5j5-f2k|POLYGON ((-80.235...|
|zzy-222@5py-882-8sq|POLYGON ((-85.710...|
|zzw-224@5px-7dy-p35|POLYGON ((-85.686...|
|229-222@5z8-p2f-dsq|POLYGON ((-119.02...|
|223-225@5z5-wkf-vcq|POLYGON ((-117.08...|
|zzy-222@8gg-sy3-xdv|POLYGON ((-82.257...|
+----------

In [8]:
# Inspect how many partitions the `polys` frame currently has.
polys.rdd.getNumPartitions()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

1000

In [5]:
# Number of polygon rows in `polys`.
polys.count()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

122333

In [3]:
# get h3 hexes for store polygons

# Switch for the expensive step below: True recomputes the polygon -> hexagon
# mapping and overwrites the stored "poly_h3" dataset; False loads the stored copy.
redo_h3 = False

if redo_h3:
    def getAllHexes(polygon, resolution):
        """Return the H3 cells that cover a polygon, flagged by containment.

        Args:
            polygon: polygon geometry as a WKT string (x = longitude, y = latitude).
            resolution: H3 resolution of the cells to generate.

        Returns:
            ``[dirty, hexes]`` - two parallel lists. ``hexes`` holds H3 cell ids;
            ``dirty[i]`` is False only when cell ``i`` lies completely inside
            the polygon, and True for cells that merely overlap it (points in
            those cells need an exact point-in-polygon check).
        """
        import h3pandas  # imported for its side effect: provides the `.h3` accessor used below
        import h3 as h3
        from shapely import wkt
        # Explicit import: `Polygon` is used below but no visible import provides it.
        from shapely.geometry import Polygon

        polygon = wkt.loads(polygon)

        # Cells whose centre falls inside the polygon, indexed by H3 id, with hexagon geometries.
        centers = gpd.GeoDataFrame(geometry = gpd.GeoSeries(polygon), crs = 'epsg:4326').h3.polyfill_resample(resolution)

        if not centers.empty:
            # A cell is "dirty" when its hexagon is not fully contained in the polygon.
            centers['dirty'] = list(~centers.apply(lambda x: polygon.contains(x['geometry']), axis = 1))
            queue = list(centers[centers.dirty == True].index)
            dirty = list(centers.dirty)
            hexes = list(centers.index)
        else:
            # No cell centre lies inside the polygon: fall back to the cell that
            # covers the polygon's centroid.
            # FIX: h3.geo_to_h3 expects (lat, lng). The centroid's x is the
            # longitude and y the latitude, so they must be passed as (y, x);
            # the previous (x, y) order indexed a point in the wrong place.
            # Mappings stored before this fix must be regenerated (redo_h3 = True).
            hexes = [h3.geo_to_h3(polygon.centroid.y, polygon.centroid.x, resolution)]
            queue = hexes.copy()
            dirty = [True]

        # Set mirror of `hexes` for O(1) membership tests (the list keeps output order).
        seen = set(hexes)

        # Grow outwards from the boundary cells: any neighbour that still touches
        # the polygon is added as a dirty cell and explored in turn.
        while queue:
            idx = queue.pop()
            for neighbor in h3.k_ring(idx, 1):
                if neighbor not in seen:
                    if polygon.intersects(Polygon(h3.h3_to_geo_boundary(neighbor, geo_json = True))):
                        seen.add(neighbor)
                        dirty.append(True)
                        hexes.append(neighbor)
                        queue.append(neighbor)
        return [dirty, hexes]

    # Return type of getAllHexes: two parallel arrays. Both are declared as string
    # arrays, so the `dirty` flags are stored as text (see the == "true" comparison
    # at the end of this cell).
    schema = StructType([
        StructField("dirty", ArrayType(StringType()), False),
        StructField("hexes", ArrayType(StringType()), False)
    ])
    getAllHexes_udf = F.udf(getAllHexes,schema)  

    # Compute the hexagon lists per polygon; the temporary `resolution` column only
    # exists to pass the resolution into the UDF.
    polys = polys.withColumn("resolution", lit(resolution).cast(IntegerType())).withColumn('new', getAllHexes_udf('polygon_wkt', 'resolution')).cache().drop("resolution")

    # Zip the two arrays element-wise and explode to one row per (polygon, hexagon),
    # then lift the struct fields into top-level `dirty` and `h3` columns.
    # NOTE(review): the zipped inputs are aliased 'dirty'/'h3' but read back as
    # new.dirty/new.hexes - confirm the struct field names on the Spark version in use.
    polys = (polys.withColumn('new', F.explode(F.arrays_zip(F.col('new.dirty').alias('dirty'), F.col('new.hexes').alias('h3'))))
               .withColumn('dirty', F.col('new.dirty')).withColumn('h3', F.col('new.hexes')).drop('new', 'index', 'geometry')
    )
    # Store the mapping so later runs can skip the computation.
    # NOTE(review): "header" looks like a CSV option; presumably it has no effect on parquet - confirm.
    polys.write.option("header","true").mode("overwrite").parquet(data_dyn + "poly_h3")

else:
    # Load the previously stored polygon -> hexagon mapping.
    polys = spark.read.parquet(data_dyn + "poly_h3")


# Turn the textual `dirty` flag into a real boolean and keep only the columns used downstream.
polys = polys.withColumn("dirty", col("dirty") == "true").select("polygon_wkt", "placekey", "dirty", "h3")



VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [4]:
# Build the list of daily path suffixes ("YYYY/MM/DD") to process.
# Each entry of `dates` is an inclusive "start/end" range.
# NOTE(review): the first range starts on 2022-06-02, while a later cell reads
# caids_shops/2022-06-01 - confirm the intended start day.
dates = [
    "/".join(("2022-06-02", "2022-06-30")),
    "/".join(("2022-09-01", "2022-09-30")),
]

datelist = []
for arg in dates:
    datelist += parse_dates(arg)



VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
# Cache the polygon -> hexagon lookup: it is joined against every day's pings below.
polys = polys.cache()

# The ping schema is identical for every day, so it is built once instead of on
# each loop iteration.
essential_fields = [
    StructField("utc_timestamp", LongType(), False),
    StructField("caid", StringType(), False),
    StructField("latitude", DoubleType(), False),
    StructField("longitude", DoubleType(), False),
    StructField("altitude", DoubleType(), False),
]
raw_schema = StructType(
    essential_fields + [
        StructField("id_type", StringType(), False),
        StructField("geo_hash", StringType(), False),
        StructField("horizontal_accuracy", DoubleType(), False),
        StructField("ip_address", StringType(), False),
        StructField("iso_country_code", StringType(), False),
    ]
)

# The loop variable is named `ping_date` (not `date`) so that it does not
# overwrite the `date` class imported from `datetime` at the top of the notebook.
for ping_date in datelist:

    print(ping_date)

    # Daily input folder ("YYYY/MM/DD" suffix) and the matching "YYYY-MM-DD" label.
    path = data_veraset + ping_date
    start_date = ping_date.replace("/", "-")

    # Only the coordinates and the device id are needed.
    pings = spark.read.schema(raw_schema).parquet(path).select("latitude", "longitude", "caid")

    pings = (
        # Index every ping into its H3 cell at the working resolution.
        pings.withColumn('resolution', lit(resolution))
        .withColumn('h3', h3s.geo_to_h3('latitude', 'longitude', "resolution"))
        .drop('resolution')
        # Candidate matches: pings that share a hexagon with a store polygon.
        .join(polys, on = "h3", how = "inner")
        # Hexagons fully inside a polygon (not dirty) match outright; for dirty
        # hexagons the ping must pass the exact point-in-polygon test.
        .where((~F.col("dirty")) | (pip_filter("polygon_wkt", "longitude", "latitude")))
        # One row per device seen in any shop on this day.
        .select("caid").dropDuplicates(["caid"])
        .withColumn("start_date", lit(start_date))
    )

    pings.write.mode("overwrite").parquet(os.path.join(data_dyn, "caids_shops/" + start_date))
    # The former `pings.unpersist()` call was dropped: `pings` is never cached, so
    # there was nothing to release. The unused `count` and `end_date` are gone too.
    
    

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [2]:
# Spot-check one day's output of the per-day device extraction.
# NOTE(review): 2022-06-01 is not part of the date ranges configured in this
# notebook (the June range starts on 2022-06-02) - confirm this is the intended day.
test = spark.read.parquet("s3://ipsos-dvd/dyn/data/caids_shops/2022-06-01/")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
# Preview the first rows of the spot-checked day.
test.show()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----+----------+
|caid|start_date|
+----+----------+
+----+----------+