In [None]:
import pandas as pd
import geopandas as gpd
import contextily as ctx
import matplotlib.pyplot as plt
import h3
from libpysal.weights import W

In [None]:
def plot_df(df, column=None, ax=None, add_basemap=True):
    """Draw the `geometry` column of a GeoPandas dataframe on a map.

    df          -- GeoDataFrame to draw; it is reprojected on a copy, the input is untouched.
    column      -- optional column name used to colour the geometries (treated as categorical).
    ax          -- matplotlib axes to draw on; a new 8x8 inch figure is created when omitted.
    add_basemap -- when true, a CartoDB Positron basemap is drawn underneath via contextily.
    """
    mercator_df = df.copy().to_crs(epsg=3857)  # web mercator

    if ax is None:
        ax = plt.subplots(figsize=(8,8))[1]

    # Hide both coordinate axes; the map itself provides the spatial context.
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)

    plot_options = {
        'ax': ax,
        'alpha': 0.25,
        'edgecolor': 'k',
        'column': column,
        'categorical': True,
        'legend': True,
        'legend_kwds': {'loc': 'upper left'},
    }
    mercator_df.plot(**plot_options)

    if not add_basemap:
        return
    ctx.add_basemap(ax, crs=mercator_df.crs, source=ctx.providers.CartoDB.Positron)
def plot_shape(shape, ax=None, add_basemap=True):
    """Plot a single geometry by wrapping it in a one-row GeoDataFrame (EPSG:4326).

    `ax` and `add_basemap` are passed straight through to `plot_df`.
    """
    single_row = gpd.GeoDataFrame({'geometry': [shape]}, crs='EPSG:4326')
    plot_df(single_row, ax=ax, add_basemap=add_basemap)
def plot_cell(cell, ax=None):
    """Plot one H3 cell: convert it to a shape and hand it to `plot_shape`."""
    plot_shape(h3.cells_to_h3shape([cell]), ax=ax)
    
def plot_cells(cells, ax=None):
    """Plot a group of H3 cells: the combined shape first, then every cell on its own.

    cells -- iterable of H3 cell ids.
    ax    -- matplotlib axes to draw on; a new 8x8 inch figure is created only when
             omitted (previously the argument was ignored and always overwritten).
    """
    # Materialise once: `cells` is consumed twice below, which would silently
    # skip the per-cell pass if an iterator were passed in.
    cells = list(cells)

    if ax is None:
        _, ax = plt.subplots(figsize=(8,8))

    # Combined shape, drawn together with the basemap.
    plot_shape(h3.cells_to_h3shape(cells), ax=ax, add_basemap=True)

    # Individual cells on the same axes; the basemap is already there.
    for single_cell in cells:
        single_shape = h3.cells_to_h3shape([single_cell])
        plot_shape(single_shape, ax=ax, add_basemap=False)

def plot_cell_area(cells):
    """Plot the shape built from all `cells` on a fresh 8x8 inch figure with a basemap."""
    _, area_ax = plt.subplots(figsize=(8,8))
    plot_shape(h3.cells_to_h3shape(cells), ax=area_ax, add_basemap=True)


In [None]:
# DD trips; the first CSV column is used as the index.
filename_DD = '../../data/nextbike/trips_DD_with_small_hexids_res10_2025-04-21_11-55-31.csv'
df_DD = pd.read_csv(filename_DD, index_col=0)

# FB trips, loaded the same way.
filename_FB = '../../data/nextbike/trips_FB_with_small_hexids_res10_2025-04-21_11-55-31.csv'
df_FB = pd.read_csv(filename_FB, index_col=0)

In [None]:
# Display the DD trips table as loaded from the CSV.
df_DD

In [None]:
def transform_df(df_input):
    """Count weekday rentals per start hexagon and hour of day.

    Parameters
    ----------
    df_input : pandas.DataFrame
        Trips with a `datetime_rent` column (anything `pd.to_datetime` can parse)
        and a `small_hex_id_rent` column. The frame is not modified.

    Returns
    -------
    pandas.DataFrame
        One row per `small_hex_id_rent`, one column per hour that occurs in the
        weekday trips (columns axis named `hour_interval`), holding integer
        rental counts with 0 where a hexagon has no rentals in that hour.
    """
    # Parse the timestamps once instead of once per derived column.
    rent_time = pd.to_datetime(df_input["datetime_rent"])
    is_weekday = rent_time.dt.weekday <= 4  # Monday=0 ... Friday=4

    # `hour_interval` is currently the raw hour; coarser bins
    # (0-6, 7-11, 12-14, 15-19, 20-23) were tried earlier via pd.cut.
    # `assign` returns a new frame, so the filtered result is never a view that
    # gets written to (the original raised SettingWithCopyWarning here).
    weekday_trips = df_input.assign(hour_interval=rent_time.dt.hour).loc[is_weekday]

    counts = weekday_trips.groupby(["small_hex_id_rent", "hour_interval"]).size()
    return counts.unstack("hour_interval", fill_value=0).astype(int)
    


In [None]:
def min_max_scale(x):
    """Rescale `x` linearly so that its minimum maps to 0 and its maximum to 1.

    `x` is anything with `.min()`, `.max()` and elementwise arithmetic
    (e.g. a pandas Series). When all values are equal there is no range to
    scale by; the result is then all zeros, so the output always stays inside
    [0, 1]. (Previously a constant input was returned unscaled.)
    """
    low = x.min()
    span = x.max() - low
    if span == 0:
        # Constant input: shifting by the minimum yields zeros of the same shape.
        return x - low
    return (x - low) / span


In [None]:
# Weekday rental counts per hexagon (rows) and hour (columns) for DD.
df_DD_grouped = transform_df(df_DD)

# Scale every hexagon's hourly profile on its own (row-wise) ...
df_DD_grouped_scaled = df_DD_grouped.apply(min_max_scale, axis=1)

# ... and add the hexagon's total rentals, scaled across all hexagons.
df_DD_grouped_scaled["total_count"] = min_max_scale(df_DD_grouped.sum(axis=1))

In [None]:
# Weekday rental counts per hexagon (rows) and hour (columns) for FB.
df_FB_grouped = transform_df(df_FB)

# Scale every hexagon's hourly profile on its own (row-wise) ...
df_FB_grouped_scaled = df_FB_grouped.apply(min_max_scale, axis=1)

# ... and add the hexagon's total rentals, scaled across all hexagons.
df_FB_grouped_scaled["total_count"] = min_max_scale(df_FB_grouped.sum(axis=1))

In [None]:
# Display the FB frame: row-wise scaled hourly counts plus the scaled total_count column.
df_FB_grouped_scaled

In [None]:
# Number of rows (hexagons) currently in the scaled DD frame.
len(df_DD_grouped_scaled)

# add missing hex IDs and fill them with 0

In [None]:
def _append_zero_rows(df, new_ids):
    """Return `df` with one all-zero row appended for every id in `new_ids`."""
    if not new_ids:
        # Nothing to add; also avoids concatenating an empty frame.
        return df
    # Build the zeros directly instead of filling an object-dtype NaN frame in place.
    # Sorting makes the row order reproducible (set order is not).
    zero_rows = pd.DataFrame(0, index=sorted(new_ids), columns=df.columns)
    return pd.concat([df, zero_rows], axis=0)


def add_missing_hex_ids(df_grouped_scaled_input, df_input):
    """Extend a per-hexagon frame with zero rows for hexagons it does not contain yet.

    Two groups of hexagon ids are appended, each as an all-zero row:

    1. ids found in `df_input.small_hex_id_rent` or `df_input.small_hex_id_return`
       (missing values ignored) that are not in the index of the input frame;
    2. ids in the first ring (`h3.grid_ring(cell, 1)`) around any hexagon present
       after step 1 that are not in the index yet.

    The row count is printed before step 1, after step 1 and after step 2.

    Parameters
    ----------
    df_grouped_scaled_input : pandas.DataFrame
        Frame indexed by hexagon id; it is copied, not modified.
    df_input : pandas.DataFrame
        Trips with `small_hex_id_rent` and `small_hex_id_return` columns.

    Returns
    -------
    pandas.DataFrame
        The extended copy: original rows first, then the rows of step 1,
        then the rows of step 2 (each group sorted by id).
    """
    df_tmp = df_grouped_scaled_input.copy()
    print(len(df_tmp))

    # Step 1: hexagons that occur in the trips but have no row yet.
    known_ids = set(df_tmp.index)
    trip_ids = set(df_input["small_hex_id_rent"].dropna()) | set(df_input["small_hex_id_return"].dropna())
    df_tmp = _append_zero_rows(df_tmp, trip_ids - known_ids)
    print(len(df_tmp))

    # Step 2: direct ring-1 neighbours of every hexagon known so far.
    # A set keeps the membership test constant-time (the list used before made
    # this quadratic in the number of hexagons).
    known_ids = set(df_tmp.index)
    ring_ids = {neighbor for cell in known_ids for neighbor in h3.grid_ring(cell, 1)}
    df_tmp = _append_zero_rows(df_tmp, ring_ids - known_ids)
    print(len(df_tmp))

    return df_tmp


In [None]:
# NOTE(review): the result overwrites its own input, so re-running this cell feeds the
# already extended frame back in and appends a further ring of neighbour hexagons.
df_DD_grouped_scaled = add_missing_hex_ids(df_DD_grouped_scaled, df_DD)

In [None]:
# NOTE(review): the result overwrites its own input, so re-running this cell feeds the
# already extended frame back in and appends a further ring of neighbour hexagons.
df_FB_grouped_scaled = add_missing_hex_ids(df_FB_grouped_scaled, df_FB)

In [None]:
# Plot the shape built from all hexagon ids in the FB frame's index, on top of a basemap.
plot_cell_area(df_FB_grouped_scaled.index.tolist())

In [None]:
# Plot the shape built from all hexagon ids in the DD frame's index, on top of a basemap.
plot_cell_area(df_DD_grouped_scaled.index.tolist())

# get cell neighbors

In [None]:
# h3.grid_ring('8928308280fffff', 1)

In [None]:
# plot_cell(("8928308280fffff"))

In [None]:
# plot_cells(h3.grid_ring('8928308280fffff', 1))

In [None]:
# Hexagon ids of the FB frame collected into a set.
existing_hex_ids = set(df_FB_grouped_scaled.index)

In [None]:
# Number of distinct hexagon ids.
len(existing_hex_ids)

In [None]:
# neighbors_dict = {
#     hex_id: [cell for cell in h3.grid_ring(hex_id, 1) if cell in existing_hex_ids]  for hex_id in existing_hex_ids
# }

In [None]:
# for key in neighbors_dict.keys():
#     assert len(neighbors_dict[key]) <= 6

In [None]:
# w = W(neighbors_dict, id_order=sorted(existing_hex_ids))

In [None]:
# check_cell = "8a1f80240937fff" 

In [None]:
# check_cell in existing_hex_ids

In [None]:
# for neighbor in h3.grid_ring(check_cell, 1):
#     print(neighbor in existing_hex_ids)

In [None]:
# len(neighbors_dict)

In [None]:
# neighbors_dict[check_cell]

In [None]:
# counter=0
# for cell in neighbors_dict.keys():
#     if len(neighbors_dict[cell]) ==0:
#         counter += 1

In [None]:
# counter

In [None]:
START_NEIGHBOURS_DISTANCE=5 # from the visual analysis of maps, because there are "islands" of several cells

In [None]:
# Rebuild the id set from the FB frame's index (same expression as in an earlier cell).
existing_hex_ids = set(df_FB_grouped_scaled.index)

In [None]:
# Candidate neighbours of each hexagon: the cells on the ring at
# START_NEIGHBOURS_DISTANCE that are themselves part of the data set.
# Keys follow the row order of df_FB_grouped_scaled instead of the iteration
# order of a set, which for strings can differ between interpreter runs.
# NOTE(review): h3.grid_ring yields only the cells at exactly that grid distance,
# not everything within it (h3.grid_disk) — confirm the hollow ring is intended.
neighbors_dict = {
    hex_id: [
        ring_cell
        for ring_cell in h3.grid_ring(hex_id, START_NEIGHBOURS_DISTANCE)
        if ring_cell in existing_hex_ids
    ]
    for hex_id in df_FB_grouped_scaled.index
}

In [None]:
def count_islands(neighbors_dict):
    """Return how many keys of `neighbors_dict` map to an empty neighbour list."""
    return sum(1 for ring in neighbors_dict.values() if len(ring) == 0)
    

# For every hexagon that still has no neighbour, look one ring further out per pass
# until each hexagon has at least one.
# The bound keeps the loop from running forever when a hexagon has no partner on
# any further ring (e.g. a data set with a single hexagon).
MAX_NEIGHBOURS_DISTANCE = 1000
neighbours_distance = START_NEIGHBOURS_DISTANCE
while True:
    islands = count_islands(neighbors_dict)
    print(f"{islands=}")
    if islands == 0:
        break
    if neighbours_distance >= MAX_NEIGHBOURS_DISTANCE:
        raise RuntimeError(
            f"{islands} hexagon(s) still have no neighbour at ring distance "
            f"{MAX_NEIGHBOURS_DISTANCE}; giving up instead of searching forever"
        )
    neighbours_distance += 1
    print(f"{neighbours_distance=}")
    # Only values of existing keys are replaced, so the dict size never changes
    # while it is being iterated.
    for hex_id, ring_neighbors in neighbors_dict.items():
        if len(ring_neighbors) == 0:
            neighbors_dict[hex_id] = [
                ring_cell
                for ring_cell in h3.grid_ring(hex_id, neighbours_distance)
                if ring_cell in existing_hex_ids
            ]

        
        


In [None]:
# Build a libpysal weights object from the neighbour lists.
w = W(neighbors_dict)

In [None]:
# Replace w with its symmetrized version
# (presumably every neighbour relation then holds in both directions — confirm in the libpysal docs).
w = w.symmetrize()

In [None]:
# Number of connected components reported by the weights object.
w.n_components

In [None]:
# Number of hexagons that have an entry in neighbors_dict.
len(neighbors_dict)

In [None]:
from spopt.region import RegionKMeansHeuristic

In [None]:
# Set up a region k-means model that groups the FB hexagons into 25 clusters using the weights in `w`.
# NOTE(review): presumably the rows of `data` must follow the same order as the ids in `w`
# — confirm that they do, and that hexagon-id strings are acceptable as weight ids.
model = RegionKMeansHeuristic(data=df_FB_grouped_scaled, n_clusters=25, w = w, drop_islands=True)

In [None]:
# Run the heuristic; no return value is captured here.
model.solve()

In [None]:
# df_DD_grouped = df_DD.groupby(["small_hex_id_rent", 'hour']).size()
# df_DD_grouped = df_DD_grouped.reset_index(name='count_rent')
# df_DD_grouped = df_DD_grouped.pivot(index='small_hex_id_rent', columns='hour', values='count_rent').fillna(0).astype(int)

In [None]:
# grouped = df_DD.groupby("hour").size().sort_values()
# grouped = grouped.reset_index(name='count_rent')
# grouped

In [None]:
# grouped.sort_values("hour").plot(x='hour', y='count_rent', kind='bar', figsize=(10, 5), title='Rentals per hour')

In [None]:
# defined intervals are: [0-6] [7-11] [12-13-14] [15,16,17,18, 19] [20-23]

In [None]:
# tmp = df_DD.groupby(["hour_interval", "hour"]).size().reset_index(name='count_rent')

In [None]:
# tmp.query("count_rent>0")

In [None]:
# df_DD_grouped_scaled