# Dask Geohash Sorted

import logging
import time
from datetime import datetime
from pathlib import Path
from shapely.geometry import Polygon, box
from polygon_geohasher.polygon_geohasher import polygon_to_geohashes, geohashes_to_polygon
import geohash
from functools import reduce

import numpy as np
import pandas as pd
import geopandas as gpd
import dask.dataframe as dd
from distributed import LocalCluster, Client

# Start a local Dask cluster with two workers of two threads each and connect
# a client to it; the dashboard is served on port 8790.
cluster = LocalCluster(
    # silence_logs=logging.ERROR,
    n_workers=2,
    threads_per_worker=2,
    memory_limit='5 GB',
    dashboard_address=':8790',
)
client = Client(cluster)
# Bare expression: displays the client summary when run as a notebook cell.
client

# Root directory that every data path below is resolved against.
base_path = Path('../../')
# Bounding rectangle of the contiguous US; x values are longitudes and
# y values are latitudes.
contiguous_us_bounding_box = box(minx=-124.848974, miny=24.396308,
                                 maxx=-66.885444, maxy=49.384358)

# Open the contiguous-US point data as a Dask DataFrame. Judging by the file
# name it is presumably sorted/indexed by 4-character geohash -- the benchmark
# further down looks rows up with ``df.loc[<geohash>]``.
df = dd.read_parquet(base_path.joinpath('data/contiguous_us_sorted_geohash4.parquet'))
# Peek at the first two rows (notebook display).
df.head(2)

%%time
# Load the zip-code polygon subsets (files zips_1 ... zips_10000) and keep only
# the geometry column of each. The files are read in the order listed, and the
# five module-level names are bound exactly as before.
zips_1, zips_10, zips_100, zips_1000, zips_10000 = (
    gpd.read_file(base_path / f'data/zip_codes/zips_{count}.geojson').loc[:, ['geometry']]
    for count in (1, 10, 100, 1000, 10000)
)

# Point in Polygon Test

# filter function
def spatial_join(large_data_df, zip_codes_gdf):
    """Keep the points of one partition that lie within a zip-code polygon.

    The function is applied per partition through ``map_partitions``, so it
    must hand back an in-memory (geo)pandas object -- never a Dask collection.

    Args:
        large_data_df: DataFrame with ``longitude`` and ``latitude`` columns,
            used as the x and y coordinates of each point (EPSG:4326).
        zip_codes_gdf: GeoDataFrame of polygons the points are tested against.
            NOTE(review): assumed to be in EPSG:4326 as well -- confirm.

    Returns:
        GeoDataFrame with the rows of ``large_data_df`` whose point is within
        a polygon of ``zip_codes_gdf``, including the join's extra columns
        except ``index_right``. When nothing matches, the empty join result is
        returned unchanged so its columns are the same as for a non-empty one.
    """
    crs = "epsg:4326"
    points = gpd.points_from_xy(large_data_df.longitude, large_data_df.latitude)
    large_data_gdf = gpd.GeoDataFrame(large_data_df, geometry=points, crs=crs)
    try:
        joined = gpd.sjoin(large_data_gdf, zip_codes_gdf, how='inner', predicate='within')
    except TypeError:
        # Older geopandas releases only accept this keyword under its former
        # name ``op``; retry with it so both old and new versions work.
        joined = gpd.sjoin(large_data_gdf, zip_codes_gdf, how='inner', op='within')
    return joined.drop(['index_right'], axis=1)



# Benchmark: for zip-code sets of growing size, select candidate points by
# geohash lookup, refine them with an exact point-in-polygon join, and record
# the wall-clock time of each run.
num_partitions = df.npartitions
geohash_precision = 4
num_polygons = []
time_sec = []
num_result_points = []

# Row counts needed by the results table and the time projection below.
# They are computed before any timer starts so the counting pass is not
# charged to a run.
total_points = len(df)
if num_partitions < df.npartitions:
    num_points = len(df.partitions[:num_partitions])
else:
    # Every partition is searched, so the searched subset is the whole dataset.
    num_points = total_points

t00 = time.time()  # start of the whole sweep
for zip_gdf in [zips_1, zips_10, zips_100]:#, zips_1000, zips_10000, zips_all]:
    num_polygons.append(len(zip_gdf))
    t0 = time.time()
    # convert zip_codes to geohashes: one set of covering cells per polygon
    # (inner=False is passed through to polygon_to_geohashes), merged into a
    # single de-duplicated, sorted list
    cell_sets = zip_gdf.geometry.apply(polygon_to_geohashes,
                                       precision=geohash_precision,
                                       inner=False)
    geohashes = sorted(set().union(*cell_sets))

    # get points which match geohashes (the loop variable must not be called
    # ``geohash``: that would shadow the imported module of the same name)
    geohash_pts = dd.concat([df.loc[cell] for cell in geohashes], axis=0)

    # do point in polygon for exact_match
    rdf = geohash_pts.map_partitions(spatial_join, zip_codes_gdf=zip_gdf).compute()
    time_sec.append(time.time() - t0)

    num_result_points.append(len(rdf))
    print(f'num_polygons[-1]: {num_polygons[-1]}, time_sec[-1]: {time_sec[-1]:.0f} s')

results_df = pd.DataFrame({'num_polygons': num_polygons,
                           'num_points': num_points,
                           'num_result_points': num_result_points,
                           'sort_time_sec': 0,
                           'time_min': np.asarray(time_sec)/60,
                           'total_points': total_points})
# Scale each measured time from the searched subset up to the full dataset.
results_df['projected_total_time_hr'] = results_df.time_min*total_points/num_points/60
# NOTE(review): the file name says "unsorted" although this script works on the
# geohash-sorted data, and the timestamp contains spaces and colons; both are
# kept unchanged in case downstream tooling relies on the name -- confirm.
results_df.to_csv(f'{datetime.now()}_unsorted_results_df.csv')
results_df