In [1]:
import xarray as xr
import pandas as pd
import numpy as np
import h3


# List of Zarr paths
zarr_paths = ["data/test/hm.zarr", "data/test/elevation.zarr", "data/test/gdp.zarr"]

#TODO:
#1. convert all vars except z to float32
#2. check z to int32


## check if coord encoding works

In [2]:

import torch
from pe.spherical_harmonics import SphericalHarmonics

lonlat = torch.tensor([[51.9,5.6]], dtype=torch.float32)

sh = SphericalHarmonics(legendre_polys=4)

# use lonlat coordinates, as above
embedded_lonlat = sh(lonlat) 

In [3]:
# Define a function to check and harmonize coordinates
def harmonize_coords(datasets, coord_names=["latitude", "longitude"], decimal_places=5):
    """
    Harmonize specified coordinates across a list of xarray datasets.
    
    Parameters:
        datasets (list): List of xarray datasets to harmonize.
        coord_names (list): Names of coordinates to harmonize (default: latitude, longitude).
        decimal_places (int): Decimal places for rounding coordinates during validation (default: 5).
    
    Returns:
        List of harmonized xarray datasets.
    
    Raises:
        ValueError: If coordinates differ significantly after rounding.
    """
    # Use the coordinates from the first dataset as the reference
    ref_ds = datasets[0]
    
    # Round the reference coordinates for comparison
    ref_coords = {
        coord: ref_ds[coord].round(decimal_places).values
        for coord in coord_names if coord in ref_ds.coords
    }
    
    # Harmonize coordinates across all datasets
    for i, ds in enumerate(datasets):
        for coord in coord_names:
            if coord in ds.coords:
                # Round the dataset's coordinates for comparison
                rounded_ds_coords = ds[coord].round(decimal_places).values
                
                # Check if the rounded coordinates match exactly
                if not np.array_equal(ref_coords[coord], rounded_ds_coords):
                    raise ValueError(
                        f"Mismatch in {coord} coordinates for dataset {i}. "
                        f"Reference: {ref_coords[coord][0]} to {ref_coords[coord][-1]}, "
                        f"Dataset: {rounded_ds_coords[0]} to {rounded_ds_coords[-1]}"
                    )
                
                # Replace the coordinates with the reference (ensures exact matching)
                ds = ds.assign_coords({coord: ref_ds[coord].values})
        
        # After handling both coordinates, assign the modified dataset back to the list
        datasets[i] = ds
    
    return datasets


In [4]:
# Open Zarr datasets into xarray objects with specified chunk sizes
chunks = {'latitude': 100, 'longitude': 100, 'time': 1}
datasets = []
for path in zarr_paths:
    # Open dataset with default chunks first to get dimensions
    ds = xr.open_zarr(path)
    # Set chunk size -1 for any dimensions not explicitly specified
    for dim in ds.dims:
        if dim not in chunks:
            chunks[dim] = -1
    # Reopen with specified chunk sizes and drop spatial_ref if present
    ds = xr.open_zarr(path, chunks=chunks).drop_vars('spatial_ref', errors='ignore')
    datasets.append(ds)


In [34]:

# Harmonize coordinates across all datasets
harmonized_datasets = harmonize_coords(datasets)

#merge
ds = xr.merge(harmonized_datasets, join='outer')
# Check that coordinates match first dataset
if (len(ds.latitude) != len(harmonized_datasets[0].latitude) or 
        len(ds.longitude) != len(harmonized_datasets[0].longitude)):
        raise ValueError("Coordinate dimensions do not match reference dataset")

In [35]:
ds = ds.stack(z=('latitude','longitude'))
#create a new coordinate
#we will use this to join back to xr
ds = ds.assign_coords(v=('z', range(len(ds.z))))


In [36]:
#save for joining later
dv = ds.v.reset_index('z')
dv.to_zarr('data/test/hm_coord.zarr')
#dv = dv.set_index(z=['latitude', 'longitude'])
#dv = dv.unstack('z')
#dv

<xarray.backends.zarr.ZarrStore at 0x2d9b38550>

In [37]:
ds = ds.set_index(z='v')

Unnamed: 0,Array,Chunk
Bytes,40.26 MiB,297.42 kiB
Shape,"(9, 9, 130284)","(1, 9, 8460)"
Dask graph,189 chunks in 19 graph layers,189 chunks in 19 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 40.26 MiB 297.42 kiB Shape (9, 9, 130284) (1, 9, 8460) Dask graph 189 chunks in 19 graph layers Data type float32 numpy.ndarray",130284  9  9,

Unnamed: 0,Array,Chunk
Bytes,40.26 MiB,297.42 kiB
Shape,"(9, 9, 130284)","(1, 9, 8460)"
Dask graph,189 chunks in 19 graph layers,189 chunks in 19 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.99 MiB,66.09 kiB
Shape,"(130284,)","(8460,)"
Dask graph,21 chunks in 4 graph layers,21 chunks in 4 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 0.99 MiB 66.09 kiB Shape (130284,) (8460,) Dask graph 21 chunks in 4 graph layers Data type float64 numpy.ndarray",130284  1,

Unnamed: 0,Array,Chunk
Bytes,0.99 MiB,66.09 kiB
Shape,"(130284,)","(8460,)"
Dask graph,21 chunks in 4 graph layers,21 chunks in 4 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8.95 MiB,66.09 kiB
Shape,"(9, 130284)","(1, 8460)"
Dask graph,189 chunks in 4 graph layers,189 chunks in 4 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 8.95 MiB 66.09 kiB Shape (9, 130284) (1, 8460) Dask graph 189 chunks in 4 graph layers Data type float64 numpy.ndarray",130284  9,

Unnamed: 0,Array,Chunk
Bytes,8.95 MiB,66.09 kiB
Shape,"(9, 130284)","(1, 8460)"
Dask graph,189 chunks in 4 graph layers,189 chunks in 4 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


## create static var df

In [45]:
#keep only time=0 and drop hm overall
hm_s = ds.isel(time=0).sel(band=[b for b in ds.band.values if b != 'AA'])
#turn hm pressues into a dataset
hm_band =hm_s['hm'].to_dataset(dim='band')
#drop hm, and temporal vars from hm_o
hm_s = hm_s.drop_vars(['hm','band','gdp'])

#merge hm pressues with static covariables only if exact match
hm_s = hm_s.merge(hm_band,join='exact')
hm_s = hm_s.drop_vars(['time']).to_dataframe().reset_index()
#hm_s.drop_vars(['time']).to_dask_dataframe()

In [46]:
#encode lon lat

# Initialize SphericalHarmonics
sh = SphericalHarmonics(legendre_polys=3)

# Function to apply SphericalHarmonics to a row
def encode_lat_lon(row):
    lonlat = torch.tensor([[row['latitude'], row['longitude']]], dtype=torch.float32)
    embedded = sh(lonlat).numpy().flatten()
    return pd.Series(embedded, index=[f'sh_{i}' for i in range(len(embedded))])

# Apply the function to each row and join the results back to the original DataFrame
hm_s = hm_s.join(hm_s.apply(encode_lat_lon, axis=1))

#add h3 ind
# Convert points to H3 indices
hm_s['h3_index'] = [h3.latlng_to_cell(lat, lng, 7) 
                    for lat, lng in zip(hm_s['latitude'], hm_s['longitude'])]

hm_s

Unnamed: 0,z,elevation,latitude,longitude,AG,BU,HI,PO,EX,FR,...,sh_0,sh_1,sh_2,sh_3,sh_4,sh_5,sh_6,sh_7,sh_8,h3_index
0,0,0.0,-32.2335,17.5815,,,,,,,...,0.282095,0.248433,-0.147588,-0.393994,-0.447948,-0.167799,-0.229061,0.266116,0.213977,87ad148a0ffffff
1,1,0.0,-32.2335,17.5905,,,,,,,...,0.282095,0.248421,-0.147661,-0.393974,-0.447904,-0.167874,-0.228976,0.266234,0.213956,87ad148a0ffffff
2,2,0.0,-32.2335,17.5995,,,,,,,...,0.282095,0.248408,-0.147735,-0.393955,-0.447859,-0.167949,-0.228890,0.266353,0.213935,87ad148a0ffffff
3,3,0.0,-32.2335,17.6085,,,,,,,...,0.282095,0.248396,-0.147808,-0.393935,-0.447815,-0.168024,-0.228804,0.266472,0.213913,87ad148a6ffffff
4,4,0.0,-32.2335,17.6175,,,,,,,...,0.282095,0.248384,-0.147881,-0.393915,-0.447770,-0.168099,-0.228718,0.266590,0.213892,87ad148a6ffffff
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
130279,130279,0.0,-34.9965,21.3435,,,,,,,...,0.282095,0.261007,-0.177831,-0.372805,-0.445311,-0.212417,-0.190056,0.303402,0.162142,87bc6e312ffffff
130280,130280,0.0,-34.9965,21.3525,,,,,,,...,0.282095,0.260991,-0.177903,-0.372782,-0.445257,-0.212489,-0.189955,0.303506,0.162122,87bc6e312ffffff
130281,130281,0.0,-34.9965,21.3615,,,,,,,...,0.282095,0.260975,-0.177974,-0.372760,-0.445202,-0.212562,-0.189854,0.303609,0.162102,87bc6e3a5ffffff
130282,130282,0.0,-34.9965,21.3705,,,,,,,...,0.282095,0.260959,-0.178045,-0.372737,-0.445147,-0.212634,-0.189753,0.303712,0.162082,87bc6e3a5ffffff


In [47]:

#write to parquet
hm_s.to_parquet("data/static", engine="pyarrow")
#read parquet into pandas df
hm_s = pd.read_parquet("data/static")

hm_s


Unnamed: 0,z,elevation,latitude,longitude,AG,BU,HI,PO,EX,FR,...,sh_0,sh_1,sh_2,sh_3,sh_4,sh_5,sh_6,sh_7,sh_8,h3_index
0,0,0.0,-32.2335,17.5815,,,,,,,...,0.282095,0.248433,-0.147588,-0.393994,-0.447948,-0.167799,-0.229061,0.266116,0.213977,87ad148a0ffffff
1,1,0.0,-32.2335,17.5905,,,,,,,...,0.282095,0.248421,-0.147661,-0.393974,-0.447904,-0.167874,-0.228976,0.266234,0.213956,87ad148a0ffffff
2,2,0.0,-32.2335,17.5995,,,,,,,...,0.282095,0.248408,-0.147735,-0.393955,-0.447859,-0.167949,-0.228890,0.266353,0.213935,87ad148a0ffffff
3,3,0.0,-32.2335,17.6085,,,,,,,...,0.282095,0.248396,-0.147808,-0.393935,-0.447815,-0.168024,-0.228804,0.266472,0.213913,87ad148a6ffffff
4,4,0.0,-32.2335,17.6175,,,,,,,...,0.282095,0.248384,-0.147881,-0.393915,-0.447770,-0.168099,-0.228718,0.266590,0.213892,87ad148a6ffffff
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
130279,130279,0.0,-34.9965,21.3435,,,,,,,...,0.282095,0.261007,-0.177831,-0.372805,-0.445311,-0.212417,-0.190056,0.303402,0.162142,87bc6e312ffffff
130280,130280,0.0,-34.9965,21.3525,,,,,,,...,0.282095,0.260991,-0.177903,-0.372782,-0.445257,-0.212489,-0.189955,0.303506,0.162122,87bc6e312ffffff
130281,130281,0.0,-34.9965,21.3615,,,,,,,...,0.282095,0.260975,-0.177974,-0.372760,-0.445202,-0.212562,-0.189854,0.303609,0.162102,87bc6e3a5ffffff
130282,130282,0.0,-34.9965,21.3705,,,,,,,...,0.282095,0.260959,-0.178045,-0.372737,-0.445147,-0.212634,-0.189753,0.303712,0.162082,87bc6e3a5ffffff


## create target df

In [37]:
#select only hm variable and AA band
hm_pd = ds.sel(band='AA')[['hm']].drop_vars(['longitude','latitude','band'])
#convert to df
#hm_pd = hm_pd.to_dataframe().reset_index()

#write to parquet
hm_pd.to_dask_dataframe().to_parquet("data/hm", engine="pyarrow")
#read parquet into pandas df
hm_pd = pd.read_parquet("data/hm")

hm_pd

Unnamed: 0_level_0,time,z,hm
__null_dask_index__,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,1990-01-01,0,
1,1990-01-01,1,
2,1990-01-01,2,
3,1990-01-01,3,
4,1990-01-01,4,
...,...,...,...
648454,2010-01-01,127318,
648455,2010-01-01,127319,
648456,2010-01-01,127320,
648457,2010-01-01,127321,


In [31]:
hm_pd

Unnamed: 0,time,z,hm
0,1990-01-01,0,
1,1990-01-01,1,
2,1990-01-01,2,
3,1990-01-01,3,
4,1990-01-01,4,
...,...,...,...
1172551,2030-01-01,130279,
1172552,2030-01-01,130280,
1172553,2030-01-01,130281,
1172554,2030-01-01,130282,


## create dynamiv var df

In [36]:
#select only hm variable and AA band
hm_dyn = ds['gdp'].drop_vars(['longitude','latitude'])
#convert to df
#hm_dyn = hm_dyn.to_dataframe().reset_index()

#write to parquet
hm_dyn.to_dask_dataframe().to_parquet("data/dynamic", engine="pyarrow")
#read parquet into pandas df
hm_dyn = pd.read_parquet("data/dynamic")

hm_dyn


Unnamed: 0,time,z,gdp
0,1990-01-01,0,0.0
1,1990-01-01,1,0.0
2,1990-01-01,2,0.0
3,1990-01-01,3,0.0
4,1990-01-01,4,0.0
...,...,...,...
1172551,2030-01-01,130279,0.0
1172552,2030-01-01,130280,0.0
1172553,2030-01-01,130281,0.0
1172554,2030-01-01,130282,0.0


## next steps

In [7]:
#steps:
#encode lat lon
#spatial index and summaries

#wrangle
#convert to df
#save



## assign h3 index

In [19]:
#query neightours in radius
def calculate_population_radius_h3(df, radius_km=10):
    """
    Calculate population within radius using H3 hexagons and append result to input df.
    
    Parameters:
    -----------
    df : pandas DataFrame
        DataFrame with columns 'latitude', 'longitude', and 'population'
    radius_km : float
        Approximate radius in kilometers
    
    Returns:
    --------
    pandas DataFrame
        Input DataFrame with added column 'hm_radius_mean' containing mean population within radius
    """
    # Choose H3 resolution based on radius
    # Resolution 9 ≈ 0.105 km²
    # Resolution 8 ≈ 0.737 km²
    # Resolution 7 ≈ 5.161 km²
    
    # Create mapping of H3 index to population
    h3_pop_map = pd.Series(df['hm'].values, index=df['h3_index'])
    h3_pop_map = h3_pop_map.groupby(level=0).mean()
    
    # Calculate k-rings for each point
    # k=2 at res 8 gives roughly 10km radius
    k=1
    
    total_pop = []
    for h3_idx in df['h3_index']:
        # Get neighboring hexagons
        neighbors = h3.grid_disk(h3_idx, k)
        # Sum population in all neighboring hexagons
        pop_sum = np.mean([h3_pop_map.get(n, 0) for n in neighbors])
        total_pop.append(pop_sum)
    
    # Add results as new column
    df['hm_radius_mean'] = total_pop
    return df

#or
#just assign to hex
def aggregate_to_h3(df,h3col = 'h3_index'):
    """
    Aggregate population data to H3 cells at specified resolution and add result back to input df.
    
    Parameters:
    -----------
    df : pandas DataFrame
        DataFrame with columns 'latitude', 'longitude', and 'population'

    Returns:
    --------
    pandas DataFrame
        Input DataFrame with added H3 aggregated values
    """
 
    # Group and calculate mean
    h3_means = df.groupby(h3col)['hm'].mean()

    # Map means back to original dataframe
    df[f'{h3col}_mean'] = df[h3col].map(h3_means)
    
    return df

In [11]:
%%time
aggregate_to_h3(hm_pd)

CPU times: user 122 ms, sys: 4.83 ms, total: 127 ms
Wall time: 128 ms


Unnamed: 0,band,time,hm,latitude,longitude,v,h3_index,hm_h3_mean
0,AA,1990-01-01,,-32.2335,17.5815,0,87ad148a0ffffff,
1,AA,1990-01-01,,-32.2335,17.5905,1,87ad148a0ffffff,
2,AA,1990-01-01,,-32.2335,17.5995,2,87ad148a0ffffff,
3,AA,1990-01-01,,-32.2335,17.6085,3,87ad148a6ffffff,
4,AA,1990-01-01,,-32.2335,17.6175,4,87ad148a6ffffff,
...,...,...,...,...,...,...,...,...
130279,AA,1990-01-01,,-34.9965,21.3435,130279,87bc6e312ffffff,
130280,AA,1990-01-01,,-34.9965,21.3525,130280,87bc6e312ffffff,
130281,AA,1990-01-01,,-34.9965,21.3615,130281,87bc6e3a5ffffff,
130282,AA,1990-01-01,,-34.9965,21.3705,130282,87bc6e3a5ffffff,


In [12]:
%%time
calculate_population_radius_h3(hm_pd)

CPU times: user 2.36 s, sys: 29.6 ms, total: 2.39 s
Wall time: 2.4 s


Unnamed: 0,band,time,hm,latitude,longitude,v,h3_index,hm_h3_mean,hm_radius_mean
0,AA,1990-01-01,,-32.2335,17.5815,0,87ad148a0ffffff,,0.0
1,AA,1990-01-01,,-32.2335,17.5905,1,87ad148a0ffffff,,0.0
2,AA,1990-01-01,,-32.2335,17.5995,2,87ad148a0ffffff,,0.0
3,AA,1990-01-01,,-32.2335,17.6085,3,87ad148a6ffffff,,0.0
4,AA,1990-01-01,,-32.2335,17.6175,4,87ad148a6ffffff,,0.0
...,...,...,...,...,...,...,...,...,...
130279,AA,1990-01-01,,-34.9965,21.3435,130279,87bc6e312ffffff,,0.0
130280,AA,1990-01-01,,-34.9965,21.3525,130280,87bc6e312ffffff,,0.0
130281,AA,1990-01-01,,-34.9965,21.3615,130281,87bc6e3a5ffffff,,0.0
130282,AA,1990-01-01,,-34.9965,21.3705,130282,87bc6e3a5ffffff,,0.0
