# Hybridizing process-based models with ML using Jax and Flax 

Arpit Kapoor 2023

In this notebook, we first train a process-based hydrological model (Simple AWBM) on synthetic streamflow data and then use a multi-layer perceptron (MLP) to train a hybrid streamflow prediction model.

You can pull the docker image from docker hub with the following command:
```bash
docker pull jsimdare/darenumpyro
```

We start by importing the necessary libraries.

In [None]:
# Utility packages
import os
import pandas as pd
import matplotlib.pyplot as plt

from typing import Tuple

import jax
import jax.numpy as jnp                # For numpy operations in jax
import optax                           # Optimization package for jax

from tqdm import tqdm                  # To print progress bar

from flax import linen as nn
from clu import metrics                # To keep track of training metrics
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses


## 1. Read synthetic streamflow data from file

In [None]:
# Load the synthetic rainfall-runoff series; the first column becomes the
# index and is parsed as dates.
raw_data = pd.read_csv(os.path.join('project', 'data', 'rr_Data.csv'),
                       index_col=0,
                       parse_dates=True)
raw_data.head()

We extract the features and targets from the dataframe as JAX arrays.

In [None]:
# Pull each series out of the dataframe as a JAX array.
prec = jnp.array(raw_data['prec'].to_numpy())              # precipitation
etp = jnp.array(raw_data['etp'].to_numpy())                # evaporation
q_obs = jnp.array(raw_data['qobs'].to_numpy())[:, None]   # observed discharge as an (N, 1) column
q_date = raw_data.index.to_numpy()                         # dates (from the parsed index)

In [None]:
# Checks: print the array shapes (q_obs was reshaped to a column vector).
print('Precipitation: ', prec.shape)
print('Evaporation: ', etp.shape)
print('Observed discharge: ', q_obs.shape)

Create covariate and target vectors

In [None]:
# Inputs: one column per driver (precipitation, evaporation).
covariates = jnp.column_stack((prec, etp))
# Target: the observed discharge column vector.
targets = q_obs

In [None]:
# Key generators: derive two PRNG keys from a fixed seed. They are passed to
# create_train_state below as `main_key` and `params_key`.
root_key = jax.random.PRNGKey(1)
main_key, params_key = jax.random.split(root_key, 2)

## 2. Instantiate and train the Process-based model (AWBM)

In [None]:
from model.awbm import SimpleAWBM

# Process-based model. S_init / B_init are passed straight to the SimpleAWBM
# constructor (presumably initial values of its store/baseflow parameters —
# confirm in model/awbm.py).
awbm = SimpleAWBM(S_init=10.0, B_init=10.0)

Define dataclass for keeping track of metrics during the model training

In [None]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Metrics to track during training.

    Each field is a clu ``Average`` fed from the keyword argument of the same
    name passed to ``single_from_model_output`` (see ``compute_metrics``).
    """
    # Running average of the reported loss values since the last reset.
    loss: metrics.Average.from_output('loss')
    # Running average of the reported NSE scores since the last reset.
    nse: metrics.Average.from_output('nse')

Flax keeps the model state (parameters, optimizer state, etc.) separate from the module object; this state is held in a `TrainState` object.

In [None]:
class TrainState(train_state.TrainState):
    """Training state: model parameters, optimizer state and extras.

    Extends Flax's ``train_state.TrainState`` (which already carries ``step``,
    ``apply_fn``, ``params``, ``tx`` and ``opt_state``) with the training
    metrics collection and a PRNG key.
    """
    # Metrics collection; written by compute_metrics and reset each epoch by train.
    metrics: Metrics
    # PRNG key carried with the state. ``jax.random.KeyArray`` was deprecated
    # and later removed from JAX; ``jax.Array`` is the supported annotation.
    key: jax.Array

def create_train_state(module: nn.Module, main_key, params_key, lr, n_features):
    """Create the initial ``TrainState`` for ``module``.

    Args:
        module: Flax module to train.
        main_key: PRNG key stored on the returned state.
        params_key: PRNG key used to initialise the module parameters.
        lr: Learning rate for the Adam optimizer.
        n_features: Number of input features (columns of the covariates).

    Returns:
        A ``TrainState`` with freshly initialised parameters, an Adam
        optimizer, an empty ``Metrics`` collection and ``main_key``.
    """
    # Initialise parameters from a dummy input holding a single timestep.
    # NOTE(review): assumes the module's parameter shapes do not depend on the
    # number of timesteps in the input — confirm against model/awbm.py and
    # model/hybrid.py.
    variables = module.init(params_key, jnp.ones((1, n_features)))
    params = variables['params']

    # Optimizer used by TrainState.apply_gradients.
    tx = optax.adam(learning_rate=lr)

    return TrainState.create(
        apply_fn=module.apply,
        params=params,
        tx=tx,
        metrics=Metrics.empty(),
        key=main_key,
    )


In [None]:
model_state = create_train_state(
    awbm,
    main_key=main_key,
    params_key=params_key,
    lr=1e-2,
    n_features=covariates.shape[-1],
)

# Display the initial AWBM parameters.
model_state.params

Next, we define the `train_step` function, which executes one training step on a single batch of data.

In [None]:
@jax.jit
def train_step(state, batch, targets):
    """Run one optimisation step on a single batch.

    Args:
        state: Current ``TrainState``.
        batch: Model inputs (the covariates).
        targets: Observed values the predictions are compared against.

    Returns:
        The updated ``TrainState`` after one gradient update.
    """

    def loss_fn(params):
        # Forward pass with the candidate parameters.
        preds = state.apply_fn({'params': params}, batch)
        # Mean L2 loss (optax.l2_loss is 0.5 * squared error); the same loss
        # that compute_metrics reports.
        loss = optax.l2_loss(preds, targets).mean()
        return loss

    # Function to compute gradients of the loss_fn w.r.t. the parameters.
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(state.params)

    # Apply the update via the state's optimizer; this also advances
    # state.step and the optimizer state.
    state = state.apply_gradients(grads=grads)

    return state

We use Nash–Sutcliffe Efficiency (NSE) coefficient to evaluate the streamflow predictions

$$
NSE = 1 - \frac{\sum_j (Q_j - \hat{Q}_j)^2}{\sum_j (Q_j - \bar{Q})^2}
$$

where $Q_j$ is the observed streamflow at timestep $j$, $\hat{Q}_j$ is the predicted streamflow, and $\bar{Q}$ is the mean of the observed streamflow.


In [None]:
@jax.jit
def nse(targets: jnp.ndarray, predictions: jnp.ndarray):
    """Return the Nash–Sutcliffe Efficiency of ``predictions`` against ``targets``.

    NSE = 1 - SSE / SST, where SSE is the sum of squared prediction errors and
    SST is the sum of squared deviations of ``targets`` from their mean. A
    perfect fit scores 1; predicting the mean of ``targets`` scores 0.

    The arguments are not interchangeable: the mean is taken over ``targets``.
    """
    squared_error_sum = jnp.square(targets - predictions).sum()
    squared_anomaly_sum = jnp.square(targets - targets.mean()).sum()
    return 1 - squared_error_sum / squared_anomaly_sum

@jax.jit
def compute_metrics(*, state: TrainState,
                    batch: jnp.ndarray,
                    targets: jnp.ndarray):
    """Compute loss and NSE for the current parameters and accumulate them.

    Args:
        state: Current ``TrainState``; its ``metrics`` collection is updated.
        batch: Model inputs (covariates).
        targets: Observed values.

    Returns:
        A copy of ``state`` whose ``metrics`` has the new values merged in.
    """
    # Generate model predictions with the current parameters.
    preds = state.apply_fn({'params': state.params}, batch)

    # Mean L2 loss (optax.l2_loss is 0.5 * squared error) and NSE score.
    loss = optax.l2_loss(preds, targets).mean()
    nse_score = nse(targets, preds)

    # Build a single-step metrics collection and merge it into the running
    # one held by the state.
    metric_updates = state.metrics.single_from_model_output(loss=loss,
                                                            nse=nse_score)
    merged_metrics = state.metrics.merge(metric_updates)

    # States are immutable: return an updated copy.
    state = state.replace(metrics=merged_metrics)

    return state

In [None]:

def plot_metrics(metrics_history: dict, figsize: Tuple[int, int]=(10, 4)):
    """Plot each tracked metric against the epoch index, one subplot per metric.

    Args:
        metrics_history: Mapping from metric name to its list of per-epoch values.
        figsize: Matplotlib figure size in inches.

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: If ``metrics_history`` is empty.
    """
    metric_names = list(metrics_history.keys())
    if not metric_names:
        raise ValueError("metrics_history is empty; nothing to plot.")

    # squeeze=False keeps `axes` 2-D even for a single metric; with the
    # default squeeze, one subplot comes back as a bare Axes and indexing fails.
    fig, axes = plt.subplots(1, len(metric_names), figsize=figsize,
                             squeeze=False)

    for ax, metric in zip(axes[0], metric_names):
        # Fetch the values of the current metric and plot them.
        metric_val = jnp.array(metrics_history[metric])
        ax.plot(metric_val, color='black')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)

    return fig

In [None]:
def train(state: TrainState,
          covariates: jnp.ndarray,
          targets: jnp.ndarray,
          n_epoch: int):
    """Full-batch training loop for Flax modules.

    Each epoch runs one ``train_step`` on the whole dataset, computes metrics
    with ``compute_metrics``, records them and then resets the state's metrics.

    Args:
        state: Initial ``TrainState``.
        covariates: Model inputs for the whole dataset.
        targets: Observed values for the whole dataset.
        n_epoch: Number of epochs (one gradient step each).

    Returns:
        Tuple of the final ``TrainState`` and a dict mapping ``'loss'`` and
        ``'nse'`` to lists of per-epoch values.
    """
    progress = tqdm(range(1, n_epoch+1))
    metrics_history = {'loss': [], 'nse': []}

    for epoch in progress:
        # One gradient update on the full dataset.
        state = train_step(state, covariates, targets)

        # Compute metrics for the updated parameters and store them on the state.
        state = compute_metrics(state=state, batch=covariates, targets=targets)

        # Record this epoch's values, then clear the metrics for the next epoch.
        for name, value in state.metrics.compute().items():
            metrics_history[name].append(value)
        state = state.replace(metrics=state.metrics.empty())

        progress.set_description(f"""Epoch {epoch}/{n_epoch} loss: {metrics_history['loss'][-1]:.4f} NSE: {metrics_history['nse'][-1]:.4f}""")

    return state, metrics_history


In [None]:
def evaluate(module: nn.Module,
             covariates: jnp.ndarray,
             targets: jnp.ndarray,
             state: TrainState=None,
             params: nn.FrozenDict=None):
    """Predict with ``module``, compute the NSE and plot the hydrograph.

    Args:
        module: Flax module used for the forward pass.
        covariates: Model inputs; column 0 is plotted as 'precip' and
            column 1 as 'etp'.
        targets: Observed streamflow.
        state: Train state to take parameters from when ``params`` is None.
        params: Explicit parameters; take precedence over ``state``.

    Raises:
        ValueError: If neither ``params`` nor ``state`` is given.
    """
    # Explicit params take precedence; otherwise fall back to the train state.
    if params is None:
        if state is None:
            raise ValueError("No params provided!")
        params = state.params

    # Forward pass through the model.
    preds = module.apply({'params': params}, covariates)

    # nse expects (targets, predictions): its denominator uses the mean of the
    # observations, so the argument order matters.
    nse_score = nse(targets, preds)

    # Plot hydrograph: drivers (dashed), observations and predictions.
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(covariates[:, 0], 'g--', label='precip', alpha=0.40)
    ax.plot(covariates[:, 1], 'y--', label='etp', alpha=0.30)
    ax.plot(targets, color='black', label='obs', alpha=1.0)
    ax.plot(preds, color='red', label='pred', alpha=0.75)

    ax.set_xlabel('Timestep')
    ax.set_ylabel('Flow (mm/day)')

    ax.annotate(f'NSE: {nse_score:.4f}',
            xy=(0.8, 0.95), xycoords='figure fraction',
            horizontalalignment='right', verticalalignment='top',
            fontsize=12)
    ax.set_title('Streamflow prediction')

    ax.legend()

## 3. Train the AWBM model on synthetic data

In [None]:
# Train the AWBM alone: 150 epochs, one full-batch update per epoch.
trained_state, metrics_history = train(
    model_state, covariates, targets, n_epoch=150
)

### Evaluate model performance (NSE) and plot hydrograph

In [None]:
# Per-epoch loss and NSE of the AWBM-only training run.
plot_metrics(metrics_history)

In [None]:
# Hydrograph of the trained AWBM vs. observations (q_obs is the array also bound to `targets`).
evaluate(awbm, covariates, q_obs, state=trained_state)

## 4. Train Hybrid Model

Now that we have optimised an AWBM model for this data, we look at integrating it with an MLP to create a hybrid model. This model takes the output of the AWBM and combines it with the covariates to predict streamflow at each timestep. We optimise the MLP weights and the AWBM parameters simultaneously.

In [None]:
from model.hybrid import HybridAWBM

# Hybrid AWBM + MLP. n_layers / n_features presumably configure the MLP
# (two layers of widths 8 and 1) — confirm in model/hybrid.py.
hybrid_awbm = HybridAWBM(
    S_init=10.,
    B_init=10.,
    n_layers=2,
    n_features=[8, 1],
)

Create train_state and train the model

In [None]:
hybrid_state = create_train_state(
    hybrid_awbm,
    main_key=main_key,
    params_key=params_key,
    lr=1e-2,
    n_features=covariates.shape[-1],
)

# Train all hybrid-model parameters together for 200 epochs.
trained_hybrid_state, hybrid_metrics_history = train(
    hybrid_state, covariates, targets, n_epoch=200
)

In [None]:
# Per-epoch loss and NSE of the hybrid training run.
plot_metrics(hybrid_metrics_history)

In [None]:
# Hydrograph of the trained hybrid model vs. observations.
evaluate(hybrid_awbm, covariates, targets, state=trained_hybrid_state)

### Compare the performance of the hybrid model with the AWBM model (NSE)

In [None]:
# Print the overall performance of the awbm vs hybrid model.
# Both values are the final-epoch NSE computed on the same data used for
# training (in-sample scores).
print('Compare Nash-Sutcliffe Efficiency (NSE) scores')
print(f"AWBM NSE: {metrics_history['nse'][-1]:.4f}")
print(f"Hybrid NSE: {hybrid_metrics_history['nse'][-1]:.4f}")

### View the parameters of the model 
Here we can see how the parameters have been optimised to different values in the AWBM-only model and in the hybrid model.

In [None]:
# AWBM model optimisation: parameters of the AWBM-only model after training
trained_state.params

In [None]:
# Hybrid model optimisation: parameters of the hybrid model after training
trained_hybrid_state.params

## That's it, folks!