In [None]:
import os

import geopandas as gpd
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
from attrs import define

## Get lake and basin geometries

In [None]:
# Basin polygons; the columns used below are HYBAS_ID, NEXT_DOWN, UP_AREA and
# PFAF_3. The default is a machine-specific absolute path, so allow it to be
# overridden via the BASINS_PATH environment variable for reproducible runs.
BASINS_PATH = os.environ.get("BASINS_PATH", "/home/chris/Temp/basins.gpkg")
basins = gpd.read_file(BASINS_PATH)
basins.head(1)

In [None]:
# Reservoir polygons (one feature per lake, identified by its "name" column).
lake_geoms_path = "./h2ox-dash/geoms/h2ox-reservoirs.geojson"
lake_geoms = gpd.read_file(lake_geoms_path)
lake_geoms.iloc[:1]

## Spatially join each lake to its basin

In [None]:
lakes_basins = (
    lake_geoms.sjoin(basins, how="left", predicate="intersects")
    # A lake can intersect several basins; keep only the furthest-downstream
    # one, i.e. the basin with the largest upstream area (UP_AREA).
    # NOTE(review): deduplicating on "name" assumes lake names are unique in
    # lake_geoms — two distinct lakes sharing a name would be merged here.
    .sort_values(by="UP_AREA", ascending=False)
    .drop_duplicates(subset="name", keep="first")
)
# The left join keeps lakes that intersect no basin, with a missing HYBAS_ID.
# Such a lake cannot be found in the drainage graph later, so fail here with
# a readable message instead of inside the upstream traversal.
unmatched = lakes_basins.loc[lakes_basins["HYBAS_ID"].isna(), "name"].tolist()
assert not unmatched, f"lakes intersecting no basin: {unmatched}"

## Create the graph

In [None]:
# The full basin layer is very large (1 million+ features), so keep only the
# basins whose PFAF_3 code matches that of at least one matched lake before
# building the drainage graph.
# NOTE(review): any basin upstream of a lake but in a different PFAF_3 unit is
# dropped here, which would truncate that lake's catchment — confirm the
# PFAF_3 units in use always contain each lake's full upstream area.
pfaf3 = lakes_basins.PFAF_3.unique()
in_lake_regions = basins["PFAF_3"].isin(pfaf3)
basins_india = basins[in_lake_regions]

In [None]:
# Directed drainage network: one edge per basin row, HYBAS_ID -> NEXT_DOWN,
# so edges point downstream and predecessors of a node are upstream basins.
G = nx.from_pandas_edgelist(
    basins_india, "HYBAS_ID", "NEXT_DOWN", create_using=nx.DiGraph
)

## Get the upstream catchment of each lake and remove areas already captured
E.g. if a lake is downstream of another lake, remove the upstream lake's catchment from the catchment recorded for the downstream lake.

In [None]:
@define
class Lake:
    """A lake matched to a single basin, plus the basin IDs of its catchment.

    ``basins`` starts empty and is filled in after construction with the
    HYBAS_IDs making up the lake's (upstream) catchment.
    """

    name: str
    hybas_id: int  # HYBAS_ID of the basin the lake was matched to
    # UP_AREA of that basin; NOTE(review): annotated int, but the source
    # column may well hold floats — attrs does not enforce this, confirm.
    up_area: int
    basins: list[int]

    @classmethod
    def from_row(cls, row):
        """Build a Lake from a joined lake/basin row with an empty ``basins`` list.

        ``row`` must provide the "name", "HYBAS_ID" and "UP_AREA" entries.
        """
        fields = (row["name"], row["HYBAS_ID"], row["UP_AREA"])
        return cls(*fields, [])

In [None]:
# One Lake record per joined row; each starts with an empty `basins` list.
lakes = [Lake.from_row(row) for _, row in lakes_basins.iterrows()]

In [None]:
# Full upstream catchment of every lake: a BFS over reversed (upstream) edges
# starting at the lake's own basin, which is included in the result.
for lake in lakes:
    upstream_tree = nx.traversal.bfs_tree(G, lake.hybas_id, reverse=True)
    lake.basins = list(upstream_tree)

In [None]:
# Remove from each lake's catchment the catchments of all lakes with a smaller
# upstream area. A lake with a smaller upstream area cannot be downstream of
# lake_a; if it is upstream, its catchment is nested inside lake_a's, and since
# each basin has a single NEXT_DOWN, unrelated lakes have disjoint catchments
# (subtracting them is a no-op).
# Snapshot the full catchments first so the result does not depend on the
# order lakes are processed in, and each list is turned into a set only once.
# NOTE(review): lakes matched to the same basin have equal up_area and never
# subtract each other, so they both keep the full, overlapping catchment.
full_catchments = [set(lake.basins) for lake in lakes]
for lake_a, catchment_a in zip(lakes, full_catchments):  # downstream lake
    upstream = [
        catchment_b
        for lake_b, catchment_b in zip(lakes, full_catchments)  # upstream lake
        if lake_a.up_area > lake_b.up_area
    ]
    if upstream:
        lake_a.basins = list(catchment_a.difference(*upstream))

## Collect the catchment geometries into a GeoDataFrame

In [None]:
# Dissolve each lake's remaining catchment basins into one geometry, in the
# same order as `lakes`.
# NOTE(review): `unary_union` is deprecated from geopandas 1.0 in favour of
# `union_all()`; switch once the environment is known to be >= 1.0.
geoms = [
    basins_india.loc[basins_india.HYBAS_ID.isin(lake.basins)].unary_union
    for lake in lakes
]

In [None]:
# One row per lake: its name and dissolved catchment geometry. The geometries
# are unions of basins_india polygons, so label them with that layer's CRS
# instead of assuming EPSG:4326 (kept only as a fallback when the source layer
# carries no CRS).
output_crs = basins_india.crs if basins_india.crs is not None else 4326
output_basins = gpd.GeoDataFrame(
    {"name": [lake.name for lake in lakes]},
    geometry=geoms,
    crs=output_crs,
)
# WKT text copy of each geometry, used for the CSV export.
output_basins["WKT"] = output_basins.geometry.to_wkt()
# output_basins.to_file("done.geojson")

In [None]:
# One colour per catchment so adjacent catchments are distinguishable. Sample
# the colormap at exactly as many points as there are rows, instead of a
# hard-coded count, so the colour list always matches output_basins.
n_catchments = len(output_basins)
colors = [
    mpl.colors.rgb2hex(c) for c in plt.get_cmap("tab20", n_catchments).colors
]
output_basins.explore(color=colors)

In [None]:
# Export the lake name and WKT geometry as plain CSV (pandas also writes the index).
output_basins.loc[:, ["name", "WKT"]].to_csv("wkt.csv")