In [None]:
import os
import numpy as np
from CMF_GPU.utils.Preprocesser import binread, read_map, process_basins,verify_basin_integrity, assign_basins_to_gpus, find_indices_in, compute_runoff_id
from omegaconf import OmegaConf
# Load the experiment configuration and keep a handle on its runtime flags,
# which later cells extend (e.g. with the per-GPU split indices).
CONFIG_PATH = r"../configs/test1-glb_15min.yaml"
config = OmegaConf.load(CONFIG_PATH)
runtime_flags = config.runtime_flags

In [None]:
# Dimensions of the river-network grid (x, then y).
nx = 1440
ny = 720

# Read the downstream-pointer map: for each (x, y) cell the last axis holds
# the (x, y) of the next cell downstream, stored as little-endian int32.
map_dir = config.map_dir
nextxy_path = os.path.join(map_dir, "nextxy.bin")
nextxy_data = binread(nextxy_path, (nx, ny, 2), dtype_str="<i4")

In [None]:
# Select every grid cell that belongs to the river network; -9999 in the
# downstream-x layer is treated as "no catchment here".
catchment_x, catchment_y = np.where(nextxy_data[:, :, 0] != -9999)

# Downstream cell coordinates, shifted by one (presumably from the 1-based
# indices stored in the file to 0-based NumPy indices — TODO confirm).
next_catchment_x = nextxy_data[catchment_x, catchment_y, 0] - 1
next_catchment_y = nextxy_data[catchment_x, catchment_y, 1] - 1

# Flatten (x, y) pairs into a single integer id per cell.
grid_shape = nextxy_data.shape[:2]
catchment_id = np.ravel_multi_index((catchment_x, catchment_y), grid_shape)

# Cells whose shifted downstream coordinates are negative have no downstream
# cell inside the grid; they keep the id -1 and are flagged as river mouths.
next_catchment_id = np.full_like(next_catchment_x, -1, dtype=int)
valid_next = (next_catchment_x >= 0) & (next_catchment_y >= 0)
next_catchment_id[valid_next] = np.ravel_multi_index(
    (next_catchment_x[valid_next], next_catchment_y[valid_next]),
    grid_shape,
)

# Number of catchments. Every per-catchment array below is indexed with
# (catchment_x, catchment_y), so the count must be exactly the number of
# selected cells. Comparing the already-shifted downstream x against the raw
# -9999 sentinel (as was done before) is not a meaningful missing-value test.
NSEQMAX = catchment_x.size
is_river_mouth = next_catchment_id < 0

In [None]:
# Topologically sort the network into basins so that whole basins can be
# distributed across multiple workers.
basins = process_basins(catchment_id, next_catchment_id, is_river_mouth)

# NOTE(review): the result is stored but never checked in this notebook —
# confirm whether verify_basin_integrity raises on failure or whether this
# flag should be asserted.
verification_passed = verify_basin_integrity(basins, catchment_id, next_catchment_id)

# Each basin entry unpacks as (river mouth, <rest>); river mouths must be unique.
river_mouths = {mouth for mouth, _ in basins}
assert len(river_mouths) == len(basins), "Duplicate river mouths found!"

# One partition per configured device; remember the split points so the
# runtime can slice the per-catchment arrays the same way.
num_gpus = len(runtime_flags["device_indices"])
gpu_basin_ids, split_indices = assign_basins_to_gpus(basins, num_gpus=num_gpus)
runtime_flags["split_indices"] = split_indices

In [None]:
# Position of every GPU-ordered catchment id inside the original catchment_id
# array; -1 marks an id that was not found, which must never happen here.
loc = find_indices_in(gpu_basin_ids, catchment_id)
assert np.all(loc != -1)

# Permute all per-catchment arrays into the GPU ordering.
# Note: this rebinds the arrays, so re-running the cell permutes them again.
catchment_x, catchment_y, catchment_id, next_catchment_id, is_river_mouth = [
    per_catchment[loc]
    for per_catchment in (catchment_x, catchment_y, catchment_id, next_catchment_id, is_river_mouth)
]

In [None]:
# Per-catchment 2-D parameter maps. VarList[i] is the model variable name and
# FileName[i] the stem of the binary map file it is read from. The flood
# height map ("fldhgt") is 3-D and is read separately below, so it does not
# belong in this list (it used to be a stray 8th entry silently dropped by zip).
VarList = ["river_length", "river_width", "river_height", "river_manning", "catchment_elevation", "catchment_area", "downstream_distance"]
FileName = ["rivlen", "rivwth_gwdlr", "rivhgt", "rivman", "elevtn", "ctmare", "nxtdst"]
assert len(VarList) == len(FileName), "VarList and FileName must pair up one-to-one"

Precision = "<f4"
params = {}
for var, fname in zip(VarList, FileName):
    # Read the full (nx, ny) map, then keep only the catchment cells.
    params[var] = read_map(os.path.join(map_dir, f"{fname}.bin"), (nx, ny), precision=Precision)[catchment_x, catchment_y]

# Number of floodplain levels stored per cell in fldhgt.bin.
NLFP = 10
flood_depth_table = read_map(os.path.join(map_dir, "fldhgt.bin"), (nx, ny, NLFP), precision=Precision)[catchment_x, catchment_y, :]

# Pad the table per catchment with: -river_height, then 0, then the NLFP
# flood heights, then +inf as an open upper bound.
flood_depth_table = np.hstack([
    -params["river_height"].reshape(-1, 1),
    np.zeros((NSEQMAX, 1), dtype=np.float32),
    flood_depth_table,
    np.full((NSEQMAX, 1), np.inf, dtype=np.float32),
])


# Initialize river depth

In [None]:

river_depth_init = np.zeros(NSEQMAX)
# Index (into the current catchment ordering) of each catchment's downstream
# neighbour; negative when the downstream id is not found (river mouths).
next_id_map = find_indices_in(next_catchment_id, catchment_id)

# Convenience aliases into params (same array objects, not copies).
river_height = params["river_height"]
river_length = params["river_length"]
river_width = params["river_width"]
catchment_area = params["catchment_area"]
catchment_elevation = params["catchment_elevation"]
downstream_distance = params["downstream_distance"]
river_elevation = catchment_elevation - river_height

# Initialize river depth, walking the topological sequence from back to front
# so that a catchment's downstream depth is already set when it is visited
# (assumes downstream catchments come later in the ordering — TODO confirm).
for ii in range(NSEQMAX - 1, -1, -1):
    jj = next_id_map[ii]
    if jj < 0 or jj == ii:
        # No downstream catchment (or it points at itself): start bank-full.
        depth = river_height[ii]
    else:
        # Match the downstream water surface elevation, but never go negative.
        depth = max(
            river_depth_init[jj] + river_elevation[jj] - river_elevation[ii],
            0.0
        )
    # Depth is capped at the channel height.
    river_depth_init[ii] = min(depth, river_height[ii])

# River mouths get a fixed downstream distance. This writes through the alias,
# so params["downstream_distance"] is updated as well.
river_mouth_distance = 10000.0
downstream_distance[is_river_mouth] = river_mouth_distance

# Initial channel storage = width * depth * length.
river_storage = river_width * river_depth_init * river_length

# Create runoff input matrix 

In [None]:
hires_map_dir = config.hires_map_dir
# NOTE(review): location.txt is resolved from hires_map_dir alone, whereas the
# binary maps below are joined onto map_dir first. Both agree only if
# hires_map_dir is absolute — confirm against the config.
location_file = os.path.join(hires_map_dir, "location.txt")

with open(location_file, "r") as f:
    lines = f.readlines()

# The third line carries the region description as whitespace-separated
# fields: [2:6] = west, east, south, north; [6:8] = nx, ny; [8] = cell size.
data = lines[2].split()
West = float(data[2])
East = float(data[3])
South = float(data[4])
North = float(data[5])
Nx = int(data[6])
Ny = int(data[7])
Csize = float(data[8])

# Cell-centre coordinates; latitude runs from north to south.
hires_lon = np.linspace(West  + 0.5 * Csize, East  - 0.5 * Csize, Nx)
hires_lat = np.linspace(North - 0.5 * Csize, South + 0.5 * Csize, Ny)
lon2D, lat2D = np.meshgrid(hires_lon, hires_lat)
# Transpose to (Nx, Ny) so the coordinate grids line up with the maps below.
hires_lon_2D, hires_lat_2D = lon2D.T, lat2D.T

# High-resolution cell areas, scaled by 1e6 (presumably km^2 -> m^2 — TODO confirm).
grid_area_path = os.path.join(map_dir, hires_map_dir, "1min.grdare.bin")
HighResGridArea = read_map(grid_area_path, (Nx, Ny), precision="<f4") * 1E6

# For each high-resolution pixel, the (x, y) of the catchment it drains to.
catchment_xy_path = os.path.join(map_dir, hires_map_dir, "1min.catmxy.bin")
HighResCatchmentId = read_map(catchment_xy_path, (Nx, Ny, 2), precision="<i2")

# Keep only pixels assigned to a catchment (positive x index).
valid_mask = HighResCatchmentId[..., 0] > 0
x_indices, y_indices = np.nonzero(valid_mask)

valid_x = HighResCatchmentId[x_indices, y_indices, 0] - 1  # 1-based to 0-based
valid_y = HighResCatchmentId[x_indices, y_indices, 1] - 1
valid_areas = HighResGridArea[x_indices, y_indices]

# Flat catchment id of every valid pixel, on the coarse (nx, ny) grid.
catchment_id_hires = np.ravel_multi_index((valid_x, valid_y), (nx, ny))

# Cell centres of the runoff forcing grid: 1-degree spacing, longitude
# -179.5..179.5 and latitude 89.5..-89.5.
ro_lon = np.arange(-179.5, 179.5 + 1, 1)
ro_lat = np.arange(89.5, -89.5 - 1, -1)
valid_lon = hires_lon_2D[x_indices, y_indices]
valid_lat = hires_lat_2D[x_indices, y_indices]

# Runoff-grid id for every valid high-resolution pixel.
runoff_ids = compute_runoff_id(ro_lon, ro_lat, valid_lon, valid_lat)


In [None]:
import importlib
import torch
# Instantiate the configured runoff dataset once, only to learn its grid size
# and (optional) validity mask.
ds_cls = getattr(importlib.import_module("CMF_GPU.utils.Dataloader"), config.runoff_dataset.class_name)
example_ds = ds_cls(
    **config.runoff_dataset.params
)
runoff_mask = example_ds.get_mask()

num_devices = len(split_indices)
num_runoff_cells = len(example_ds.lat) * len(example_ds.lon)

runoff_matrix_list = [None] * num_devices
# One independent mask per device. (`[arr] * n` would alias a single array, so
# marking cells for one device would mark them for all of them. The builtin
# `bool` replaces the deprecated/removed `np.bool` alias.)
runoff_mask_list = [np.zeros(num_runoff_cells, dtype=bool) for _ in range(num_devices)]

# Flattened runoff validity mask; identical for every device, so build it once.
if runoff_mask is not None:
    runoff_mask_row = np.ravel(runoff_mask, order="C").astype(bool)
else:
    runoff_mask_row = np.ones(num_runoff_cells, dtype=bool)

valid_count = 0
total_count = 0
for i, gpu_id in enumerate(np.split(gpu_basin_ids, split_indices[:-1])):
    # Row of each high-resolution pixel within this device's catchment list
    # (-1 when the pixel's catchment lives on another device).
    row_indices = find_indices_in(catchment_id_hires, gpu_id)
    on_this_gpu = row_indices != -1
    # Pixels on this gpu that also have valid runoff. The comparison is
    # evaluated separately: `a != -1 & b` parses as `a != (-1 & b)`.
    row_mask = on_this_gpu & runoff_mask_row[runoff_ids]

    # Remap the runoff ids that are actually used to a dense 0..k-1 range.
    selected_runoff_ids = runoff_ids[row_mask]
    unique_ids = np.unique(selected_runoff_ids)
    # unique_ids is sorted and contains every selected id, so searchsorted
    # returns each id's exact position.
    remapped_runoff_ids = np.searchsorted(unique_ids, selected_runoff_ids).astype(np.int32)

    runoff_mask_list[i][unique_ids] = True
    # Sparse (catchment row, dense runoff cell) -> contributing area; coalesce
    # sums duplicate entries.
    # NOTE(review): rows index into this device's catchments but the matrix is
    # sized with NSEQMAX rows — confirm this is what the runtime expects.
    runoff_matrix_list[i] = torch.sparse_coo_tensor(
        np.vstack((row_indices[row_mask], remapped_runoff_ids)),
        valid_areas[row_mask],
        (NSEQMAX, len(unique_ids)),
    ).coalesce()

    valid_count += len(np.unique(row_indices[row_mask]))
    total_count += len(np.unique(row_indices[on_this_gpu]))

assert total_count == len(gpu_basin_ids), "total count mismatch!"
if total_count != valid_count:
    print(
        f"Warning: {total_count - valid_count} catchment(s) will never receive valid runoff data "
        "because all their associated grid cells are invalid; their runoff input will always be 0. "
        "If there are many such catchments, this may indicate an issue with the input data or code logic."
    )

for i, mat in enumerate(runoff_matrix_list):
    print(f"(GPU{i}) Runoff Input Matrix Shape:", mat.shape)
    print(f"(GPU{i}) Nonzero Elements:", mat._nnz())

In [None]:
params["is_river_mouth"] = is_river_mouth
# Make river mouths point at themselves so that every catchment has a valid
# downstream index (modifies next_catchment_id in place).
next_catchment_id[is_river_mouth] = catchment_id[is_river_mouth]
downstream_idx = find_indices_in(next_catchment_id, catchment_id)
params["flood_depth_table"] = flood_depth_table
params["downstream_idx"] = downstream_idx
params["river_length"] = river_length
params["flood_manning"] = np.full(NSEQMAX, 0.1, dtype=np.float32)
params["log_buffer_size"] = 500
params["adaptation_factor"] = 0.7
params["num_catchments"] = NSEQMAX
# Use the same constant that sized flood_depth_table instead of repeating the
# literal, so the two can never drift apart.
params["num_flood_levels"] = NLFP
params["gravity"] = 9.81


# Initial model states: the river channel starts from the depth/storage
# computed earlier; every other state starts at zero.
init_states = {
    "river_storage": river_storage,
    "river_depth": river_depth_init,
}

for state in [
    "flood_storage",
    "river_outflow",
    "flood_depth",
    "flood_outflow",
    "river_cross_section_depth",
    "flood_cross_section_depth",
    "flood_cross_section_area",
]:
    # A fresh array per state so that no two states share storage.
    init_states[state] = np.zeros(NSEQMAX, dtype=np.float32)

In [None]:
import pickle
from CMF_GPU.utils.utils import snapshot_to_pkl
from CMF_GPU.utils.Preprocesser import save_coo_list_to_pkl
# Write every preprocessing product into the configured input directory.
inp_dir = config.inp_dir
os.makedirs(inp_dir, exist_ok=True)

# Per-device sparse runoff input matrices.
runoff_matrix_path = os.path.join(inp_dir, "runoff_input_matrix.pkl")
save_coo_list_to_pkl(runoff_matrix_list, runoff_matrix_path)

# Parameters and initial states, filtered by the enabled modules.
enabled_modules = runtime_flags["modules"]
parameters_path = os.path.join(inp_dir, "parameters.pkl")
snapshot_to_pkl(params, "param", enabled_modules, parameters_path, omit_hidden=True)
init_states_path = os.path.join(inp_dir, "init_states.pkl")
snapshot_to_pkl(init_states, "state", enabled_modules, init_states_path, omit_hidden=True)

# Per-device masks of the runoff cells that are actually used.
runoff_mask_path = os.path.join(inp_dir, "runoff_mask.pkl")
with open(runoff_mask_path, 'wb') as f:
    pickle.dump(runoff_mask_list, f)

# update runtime_flags and save the resulting config next to the inputs
config.runtime_flags.update(runtime_flags)
config_out_path = os.path.join(inp_dir, "config.yaml")
OmegaConf.save(config=config, f=config_out_path)



In [None]:
# import matplotlib.pyplot as plt
# if runoff_mask is not None:
#     plt.figure(figsize=(10, 6))
#     plt.imshow(runoff_mask, origin='upper')
#     plt.show()

# runoff_ids_2D = np.full((Ny, Nx), np.nan) 
# runoff_ids_2D[y_indices, x_indices] = runoff_ids

# plt.figure(figsize=(10, 6))
# plt.imshow(runoff_ids_2D[1000:5000,2000:7000], cmap='viridis', origin='upper')
# plt.title('Runoff IDs on High-Resolution Map')
# plt.xlabel('Longitude')
# plt.ylabel('Latitude')
# plt.show()

In [None]:
# pythonorder=(np.ravel_multi_index((catchment_y, catchment_x), (ny, nx))+1)
# np.savetxt('PythonOrder.csv', pythonorder, delimiter=',', fmt='%d')