### GPU Memory optimization

DiffRoute accelerates LTI routing by recasting the computation in an embarrassingly parallel formulation.
This formulation of the routing procedure induces memory overhead and redundant computations.

In technical terms, DiffRoute explicitly instantiates a routing kernel along the transitive closure of the river network.
For large river networks, the memory footprint of this kernel becomes a bottleneck.
To handle such large workloads within memory bounds, a number of helper functionalities are provided.
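
To make the size of this kernel concrete, the sketch below counts the (upstream, downstream) pairs in the transitive closure of a toy network and of a 1000-reach chain, then converts the count into a rough memory figure. The helper names (`closure_size`, `kernel_bytes`) are hypothetical, and the dense float32 layout with `max_delay` coefficients per pair is an assumption made for illustration, not a description of DiffRoute's internal storage.

```python
def closure_size(downstream):
    """Count (source, target) pairs where target is the source itself
    or any reach located downstream of it."""
    total = 0
    for node in downstream:
        current = node
        while current is not None:
            total += 1
            current = downstream[current]
    return total


def kernel_bytes(n_pairs, max_delay, bytes_per_value=4):
    """Rough estimate, ASSUMING max_delay float32 coefficients per pair."""
    return n_pairs * max_delay * bytes_per_value


# Toy network: reaches a and b join at c; c and d join at the outlet e.
toy = {"a": "c", "b": "c", "c": "e", "d": "e", "e": None}
toy_pairs = closure_size(toy)

# Worst case: a single chain of 1000 reaches (reach i drains into reach i + 1).
n = 1000
chain = {i: (i + 1 if i + 1 < n else None) for i in range(n)}
chain_pairs = closure_size(chain)

print(f"toy network : {toy_pairs} pairs")
print(f"chain of {n}: {chain_pairs} pairs, "
      f"~{kernel_bytes(chain_pairs, max_delay=32) / 1e6:.1f} MB at max_delay=32")
```

For a chain, the pair count grows quadratically with the number of reaches, which is why the kernel footprint can dominate on large networks.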

This notebook discusses the GPU memory bottleneck that arises when routing through large river networks and presents the different techniques available to work around this limitation.

There are in fact two distinct limitations:
(i) the memory footprint of the routing kernel, and
(ii) the memory footprint of the input/output time series.

In [1]:
from pathlib import Path

import pandas as pd
import xarray as xr

from diffhydro import (LTIRouter, LTIStagedRouter, DataTensor, StagedCatchmentInterpolator)
from diffhydro.io import read_rapid_graph

### Parameters

In [2]:
# Routing model parameters
max_delay = 32
dt = 1/24
# Experiment paths and variables
device = "cuda:1"
root = Path("./data") # Set a data root path with enough disk space (>100MB)
root = Path("../../../DiffHydro/examples/data") # Set a data root path with enough disk space (>100MB)
rapid_path = root / "geoglows" / "rapid_config" / "305"

### Download data if necessary

Downloads can take time. Only execute this cell if you have not already downloaded the data.

In [3]:
from utils.download import download_single_vpu_data
download_single_vpu_data(root) # Should format to RAPID.

Downloading 305_daily_sparse_runoff.feather:   0%|          | 0.00/46.0M [00:00<?, ?B/s]

Downloading 305_interp_weight.feather:   0%|          | 0.00/465k [00:00<?, ?B/s]

AttributeError: 'DataTensor' object has no attribute 'dtype'

### Load input runoffs

In [None]:
runoff_xr = xr.open_dataarray(rapid_path / "runoff.nc")
runoff_xr = runoff_xr.rename({"river_id":"spatial"}).expand_dims("batch").transpose("batch", "spatial", "time")
runoff = DataTensor.from_dataarray(runoff_xr).to(device)

### 0. Basic Routing

This section recaps the "normal" routing procedure from the previous notebook.

In [None]:
# Load graph
g = read_rapid_graph(rapid_path).to(device)
# Instantiate the routing model
model = LTIRouter(max_delay=max_delay, dt=dt).to(device)
# Execute routing
discharges = model(runoff, g)

### 1. Kernel memory footprint reduction with staged routing

In [None]:
# graph partitioning parameters
plength_thr = 10**4
node_thr = 10**4

In [None]:
g = read_rapid_graph(rapid_path, 
                     plength_thr=plength_thr, 
                     node_thr=node_thr).to(device)
g

In [None]:
model = LTIStagedRouter(max_delay=max_delay, dt=dt).to(device)
discharges = model(runoff, g)
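
A back-of-the-envelope calculation suggests why partitioning the graph helps. The sketch below assumes the worst case of a single chain of reaches, and assumes that each cluster gets its own kernel covering at most `max_cluster_nodes` reaches. Both function names are hypothetical, and the exact semantics of `node_thr` and `plength_thr` in `read_rapid_graph` may differ from this simplification.

```python
def chain_closure_pairs(n_nodes):
    """Transitive-closure size of a chain of n_nodes reaches (self-pairs included)."""
    return n_nodes * (n_nodes + 1) // 2


def staged_closure_pairs(n_nodes, max_cluster_nodes):
    """Total closure size when the chain is cut into clusters of bounded size."""
    full, rest = divmod(n_nodes, max_cluster_nodes)
    return full * chain_closure_pairs(max_cluster_nodes) + chain_closure_pairs(rest)


n_nodes = 100_000
single = chain_closure_pairs(n_nodes)
staged = staged_closure_pairs(n_nodes, max_cluster_nodes=10**4)

print(f"single kernel: {single:,} pairs")
print(f"staged       : {staged:,} pairs ({single / staged:.1f}x fewer)")
```

Under these assumptions, the total kernel size shrinks roughly in proportion to the number of clusters, and the peak footprint is that of a single cluster if clusters are processed one at a time.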

### 2. Time series footprint reduction

#### 2.1 Chunk routing in time

In [None]:
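def time_chunk_generator(datatensor, chunk_size, init_window):
    """
    Placeholder (not implemented yet).

    Assumed intent: yield successive chunks of `datatensor` along the time axis,
    each `chunk_size` steps long and preceded by `init_window` warm-up steps.
    """
    raise NotImplementedError

def coalesce_time_chunks(chunk_seq, init_window):
    """
    Placeholder (not implemented yet).

    Assumed intent: drop the `init_window` warm-up steps of each chunk in
    `chunk_seq` and concatenate the remainders along the time axis.
    """
    raise NotImplementedError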
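
One way such helpers could work is sketched below on plain Python lists, with a small causal filter standing in for the router. Each chunk is extended backwards by `init_window` time steps of context, routed, and the warm-up outputs are discarded before concatenation. The names `causal_filter`, `iter_time_chunks` and `route_in_chunks` are hypothetical and are not part of the diffhydro API; the sketch assumes that an output at time t depends only on inputs from the preceding `len(kernel) - 1` steps.

```python
def causal_filter(x, kernel):
    """Stand-in for routing: y[t] = sum_k kernel[k] * x[t - k], with x[t] = 0 for t < 0."""
    return [
        sum(kernel[k] * x[t - k] for k in range(len(kernel)) if t - k >= 0)
        for t in range(len(x))
    ]


def iter_time_chunks(x, chunk_size, init_window):
    """Yield (chunk, n_warmup); each chunk carries up to init_window preceding steps."""
    for start in range(0, len(x), chunk_size):
        lo = max(0, start - init_window)
        yield x[lo:start + chunk_size], start - lo


def route_in_chunks(x, kernel, chunk_size):
    """Route chunk by chunk and drop the warm-up outputs before concatenating."""
    init_window = len(kernel) - 1  # memory of the filter
    out = []
    for chunk, n_warmup in iter_time_chunks(x, chunk_size, init_window):
        y = causal_filter(chunk, kernel)
        out.extend(y[n_warmup:])
    return out


x = [float(t % 7) for t in range(50)]   # toy runoff series
kernel = [0.5, 0.3, 0.2]                # toy impulse response
full = causal_filter(x, kernel)
chunked = route_in_chunks(x, kernel, chunk_size=8)
print("chunked routing matches one-shot routing:", chunked == full)
```

Choosing `init_window` at least as long as the memory of the kernel is what makes the chunked result match the one-shot result; only one chunk needs to be held in memory at a time.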

#### 2.2 Chunk routing in space

#### 2.3 On-the-fly catchment aggregation

#### 2.4 Manage memory transfers

### Kernel memory reduction: Route with sub-cluster stages

In [16]:
# Load the routing graph defined as a RAPID project into a RivTree structure
g = read_rapid_graph(rapid_path,
                     plength_thr=plength_thr, 
                     node_thr=node_thr).to(device)

# Load input runoff.
# Here the data is provided as pixel-wise runoffs.
pixel_runoff = pd.read_feather(runoff_path) / (3600. * 24) # Convert per-day values to m3/s
pixel_runoff = TimeSeriesThDF.from_pandas(pixel_runoff).to(device) # Convert pandas DataFrame to TimeSeriesThDF

# Interpolate the pixel-wise runoffs onto the graph catchments 
interp_df = pd.read_feather(interp_weight_path)
cat = StagedCatchmentInterpolator(g, pixel_runoff, interp_df).to(device) 

model = LTIStagedRouter(max_delay=max_delay, dt=dt).to(device)

#### Upstream stats computations ... ####


Computing breakpoints:   0%|          | 0/15562 [00:00<?, ?it/s]

#### Segmentation into subgraphs ... ####
Removing edges...


  0%|          | 0/15562 [00:00<?, ?it/s]

Segment graph into connected components....
Build subgraphs for each cluster and node-cluster map...


  0%|          | 0/207 [00:00<?, ?it/s]

Establish dependencies between clusters...


  0%|          | 0/153 [00:00<?, ?it/s]

#### Grouping subgraphs to cluster and infering dependencies ... ####
Initialize dependencies...
Associate clusters for remaining subgraphs...


0it [00:00, ?it/s]

Merging graphs...


  0%|          | 0/19 [00:00<?, ?it/s]

Computing merged graphs node idxs...


  0%|          | 0/19 [00:00<?, ?it/s]

Match breakpoint nodes across clusters...


  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:00<?, ?it/s]

In [17]:
runoff = cat.interpolate_all_runoff(pixel_runoff)
discharges = model(runoff, g)

### Time series memory reduction: Iterative routing of sub-clusters with on-the-fly catchment interpolation

In [32]:
# Here we create a generator of per-cluster input runoff time series
runoff = cat.yield_all_runoffs(pixel_runoff)
type(runoff)

generator

In [33]:
# Here we create a generator of per-cluster output discharge time series
discharges = model.route_all_clusters_yield(runoff, g)
type(discharges)

generator

In [34]:
# Looping over the output discharges triggers both the catchment interpolation and routing execution.
for i,d in enumerate(discharges):
    print(f"Cluster {i} contains river channels {d.columns[:5]}")

Cluster 0 contains river channels [360038401 360049008 360056494 360060030 360060238]
Cluster 1 contains river channels [360038400 360042353 360053793 360054001 360054209]
Cluster 2 contains river channels [360050178 360046227 360049556 360054758 360053511]
Cluster 3 contains river channels [360046093 360046301 360046509 360052124 360052332]
Cluster 4 contains river channels [360029711 360032415 360028048 360028256 360025138]
Cluster 5 contains river channels [360028673 360024096 360030126 360030334 360036574]
Cluster 6 contains river channels [360023043 360028868 360035526 360035734 360035942]
Cluster 7 contains river channels [360030211 360044355 360052051 360052259 360055796]
Cluster 8 contains river channels [360069121 360069329 360071203 360071411 360071619]
Cluster 9 contains river channels [360030208 360030416 360029585 360056418 360047683]
Cluster 10 contains river channels [360061444 360061652 360073092 360071221 360062070]
Cluster 11 contains river channels [360061455 3600616
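
The lazy behaviour described above comes from chaining Python generators. The following toy sketch uses hypothetical stand-ins (`yield_runoffs`, `route_clusters`) rather than the diffhydro methods, and records the order in which work actually happens.

```python
log = []


def yield_runoffs(cluster_names):
    """Stand-in for per-cluster catchment interpolation."""
    for name in cluster_names:
        log.append(f"interpolate {name}")
        yield name, [1.0, 2.0, 3.0]


def route_clusters(runoffs):
    """Stand-in for per-cluster routing; consumes the runoff generator lazily."""
    for name, runoff in runoffs:
        log.append(f"route {name}")
        yield name, [2 * v for v in runoff]


runoffs = yield_runoffs(["c0", "c1"])
discharges = route_clusters(runoffs)
log_before_loop = list(log)  # still empty: creating generators runs nothing

for name, q in discharges:
    log.append(f"consume {name}")

print(log_before_loop)
print(log)
```

Because each cluster is interpolated, routed and consumed before the next one starts, only one cluster's time series needs to be resident at a time, provided the loop body does not keep references to previous results.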

### Further time-series memory reduction: keeping runoff time series on GPU

### Further time-series memory reduction: tiling in time

Finally, if the cost of routing remains too high, even with small clusters and dedicated CPU-GPU memory transfers, the time series can be split into chunks along the time axis and routed chunk by chunk.

We will provide an example of this shortly.

If you need this functionality, please reach out to us so we can adapt an implementation that suits your needs.