In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np

from tqdm.notebook import tqdm

In [None]:
# Parameters come from Snakemake when the notebook is executed by the workflow;
# otherwise fall back to local defaults for interactive development.
if "snakemake" not in locals():
    seed = 0
    zone_attribute = "sector_index"
    input_passegers_path = "../../../results/belgium/population/discretized_population.parquet"
    input_spatial_path = "../../../results/belgium/census/spatial.parquet"
    output_path = "../../../results/belgium/population/localized_population.parquet"
else:
    seed = snakemake.params["seed"]
    zone_attribute = snakemake.params["zone_attribute"]
    input_passegers_path = snakemake.input["passengers"]
    input_spatial_path = snakemake.input["spatial"]
    output_path = snakemake.output[0]

In [None]:
# Read the inputs: the zone geometries (as a GeoDataFrame) and the passenger
# table, whose `zone_attribute` column links each row to a zone.
df_sectors = gpd.read_parquet(input_spatial_path)
df_passengers = pd.read_parquet(input_passegers_path)

In [None]:
# Draw random points inside every sector that hosts at least one passenger.
# One point more than the number of passengers is drawn per sector so that the
# sampler always yields a MultiPoint, even for a single-passenger sector (a lone
# sample would come back as a plain Point). The surplus point is not used later.
passenger_counts = (
    df_passengers
    .groupby(zone_attribute)
    .size()
    .rename("count")
    .reset_index()
)
df_locations = df_sectors.merge(passenger_counts, on = zone_attribute, how = "inner")
df_locations["geometry"] = df_locations.sample_points(df_locations["count"] + 1, rng = seed)

In [None]:
# Sanity-check the sampled locations before they are assigned to passengers.
# Every sector must carry a MultiPoint with exactly `count + 1` points (one
# surplus point per sector is drawn on purpose). Converting its parts with
# numpy must give a flat 1-D array with one entry per point.
for count, item in zip(df_locations["count"].values, df_locations["geometry"].values):
    # Compare the geometry type by name instead of a fragile str(type(...)) repr.
    assert item.geom_type == "MultiPoint", "Expected a MultiPoint, got {}".format(item.geom_type)

    # A point count that differs from count + 1 would leave passengers without
    # a location, or misalign which point goes to which passenger.
    assert len(item.geoms) == count + 1, "Expected {} points, got {}".format(count + 1, len(item.geoms))

    locations = np.array(item.geoms)
    assert locations.ndim == 1

In [None]:
# Assign locations: within each zone, the i-th passenger (in table order)
# receives the (i + 1)-th sampled point. The first point of every zone is the
# surplus sample and is discarded. A single merge on (zone, rank) replaces one
# `.loc` write per zone, which would scan the whole non-unique, unsorted
# passenger index once for every zone.
rank_column = "_location_rank"

# One row per sampled point; `explode` keeps the parts in MultiPoint order.
df_points = df_locations[[zone_attribute, "geometry"]].explode(index_parts = False, ignore_index = True)
df_points[rank_column] = df_points.groupby(zone_attribute).cumcount() - 1
df_points = df_points[df_points[rank_column] >= 0]

# Rank passengers within their zone, keeping their original row order.
# The zone column is placed first to keep the output's column layout.
# Any previous geometry column is dropped so the cell can be re-run.
other_columns = [column for column in df_passengers.columns if column not in (zone_attribute, "geometry")]
df_ranked = df_passengers[[zone_attribute] + other_columns].reset_index(drop = True)
df_ranked[rank_column] = df_ranked.groupby(zone_attribute).cumcount()

# A left merge keeps every passenger, in order. Passengers of zones without a
# sector end up with a missing geometry.
df_passengers = pd.merge(df_ranked, df_points, on = [zone_attribute, rank_column], how = "left")
df_passengers = df_passengers.drop(columns = rank_column)

# Every passenger of a sampled zone must have received a point.
is_sampled = df_passengers[zone_attribute].isin(df_locations[zone_attribute])
assert df_passengers.loc[is_sampled, "geometry"].notna().all(), "Some passengers did not receive a location"

df_passengers = gpd.GeoDataFrame(df_passengers, crs = df_sectors.crs, geometry = "geometry")

In [None]:
# Persist the localized passengers (a GeoDataFrame with one point per row) as Parquet.
df_passengers.to_parquet(path = output_path)