# Bayesian hierarchical model Implementation

This notebook demonstrates the implementation of a Bayesian hierarchical model (BHM) for extreme annual and seasonal precipitation.
The code related to extrapolation at ungauged basins is not included here but can be requested by contacting atakallou@crimson.ua.edu.
In addition, this notebook presents the implementation of the stationary BHM; for the non-stationary BHM implementation, please contact the author (Ali Takallou).

### Import Libraries

In [1]:
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as random
from scipy.stats import genextreme
from scipy.spatial.distance import pdist, squareform
import numpyro
import numpyro.distributions as dist
from numpyro.infer import (
    MCMC,
    NUTS,
    init_to_feasible,
    init_to_median,
    init_to_sample,
    init_to_uniform,
    init_to_value,
)
import matplotlib.pyplot as plt
import seaborn as sns
from numpyro.diagnostics import hpdi, summary
import arviz as az
from scipy.stats import pearsonr
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from tensorflow_probability.substrates.jax import distributions as tfd
from numpyro.infer import Predictive
import sys

#### Load the data

In [2]:
## Get data function
def get_data(season, sigma_obs=0.15, seed=44, test_frac=0.2, base_dir="/mh1/Atakallou/HBM"):
    """Load covariates, precipitation maxima and ML GEV estimates, then split by gauge.

    Parameters
    ----------
    season : str
        Label used to build the input file names
        (``Ind_data/{season}_data_Ind.csv``, ``ML_est/{season}_Parm_ML.csv``
        and ``Covariates/CI/{season}.csv``).
    sigma_obs : float, optional
        Not used by this function; kept so existing calls keep working.
    seed : int, optional
        Seeds NumPy's global RNG and the shuffle of the station table.
    test_frac : float, optional
        Fraction of the (shuffled) stations placed in the test split.
    base_dir : str, optional
        Root folder holding the ``Covariates``, ``Ind_data`` and ``ML_est``
        sub-folders. The default reproduces the previously hard-coded paths.

    Returns
    -------
    tuple
        ``(X_train, X_test, X_T, Y_train, Y_test, loc_train, loc_test,
        scale_train, scale_test, shape_train, shape_test, X_dist, Xds_train,
        Xds_test, Xds_join, cov_train, cov_test, ML_train, ML_test)``.
        ``X_T`` is not split; ``X_dist`` is the full station-by-station
        distance matrix and ``Xds_join`` its test-rows / train-columns block.
    """
    np.random.seed(seed)  # global side effect preserved for existing callers
    data_folder = os.path.join(base_dir, "Covariates")

    # Station table, shuffled reproducibly so the positional split below is random.
    csv_data = pd.read_csv(os.path.join(data_folder, "COV.csv"))
    csv_data = csv_data.sample(frac=1, random_state=seed).reset_index(drop=True)
    N = len(csv_data)
    X_dist = get_distance_matrix(csv_data[["Est", "Nor"]].to_numpy())

    # Static covariates: min-max scale each column, then append a column of ones.
    # NOTE(review): the scaling uses all stations (train and test together).
    X = csv_data[["LONGITUDE", "LATITUDE", "ELEVATION"]].to_numpy()
    X = (X - X.min(0)) / (X.max(0) - X.min(0))
    ones = np.ones((X.shape[0], 1))
    X = np.hstack([X, ones])

    # 1. Load PRCP data for selected season
    y_data = pd.read_csv(os.path.join(base_dir, "Ind_data", f"{season}_data_Ind.csv"))
    # 2. Filter years between 1951 and 2020
    y_data = y_data[(y_data['Year'] >= 1951) & (y_data['Year'] <= 2020)]
    # 3. Keep only stations present in csv_data
    y_data = y_data[y_data['STATION'].isin(csv_data['STATION'])]
    # 4. Pivot to station-year PRCP matrix
    prcp_matrix = y_data.pivot_table(index='STATION', columns='Year', values='PRCP')
    # 5. Reorder columns (years) and rows (stations, in the shuffled csv_data order)
    prcp_matrix = prcp_matrix.reindex(columns=sorted(prcp_matrix.columns))
    prcp_matrix = prcp_matrix.reindex(csv_data['STATION'])
    # 6. Interpolate missing values across years (axis=1)
    # NOTE(review): a station with no PRCP rows at all would stay all-NaN here.
    prcp_matrix_filled = prcp_matrix.interpolate(method='linear', axis=1, limit_direction='both')
    # 7. Convert to NumPy array
    prcp_array = prcp_matrix_filled.to_numpy()
    Y = prcp_array.copy()

    # Per-station maximum-likelihood GEV estimates, aligned to the csv_data order.
    ML_parm = pd.read_csv(os.path.join(base_dir, "ML_est", f"{season}_Parm_ML.csv"))
    ML_parm = ML_parm.set_index("station").reindex(csv_data["STATION"]).reset_index()
    loc = ML_parm['location_est'].to_numpy()
    scale = ML_parm['scale_est'].to_numpy()
    shape = ML_parm['shape_est'].to_numpy()

    # --- Generate dynamic features ---
    # Every column after the first is read as a feature; after the transpose each
    # row is one feature. Rows are standardised and a 0..1 time row is appended.
    ci_data = pd.read_csv(os.path.join(data_folder, "CI", f"{season}.csv"))
    X_T = ci_data.iloc[:, 1:].to_numpy().T
    X_T = (X_T - X_T.mean(axis=1, keepdims=True)) / (X_T.std(axis=1, keepdims=True))
    time_row = np.linspace(0, 1, X_T.shape[1])
    X_T = np.vstack([X_T, time_row])

    # --- Split over the gauge axis (spatial) into training and test parts ---
    split_idx = int((1 - test_frac) * N)
    X_train, X_test = X[:split_idx], X[split_idx:]
    Y_train, Y_test = Y[:split_idx], Y[split_idx:]
    Xds_train = X_dist[:split_idx, :split_idx]
    Xds_test = X_dist[split_idx:, split_idx:]
    Xds_join = X_dist[split_idx:, :split_idx]
    cov_train, cov_test = csv_data[:split_idx], csv_data[split_idx:]
    ML_train, ML_test = ML_parm[:split_idx], ML_parm[split_idx:]
    loc_train, loc_test = loc[:split_idx], loc[split_idx:]
    scale_train, scale_test = scale[:split_idx], scale[split_idx:]
    shape_train, shape_test = shape[:split_idx], shape[split_idx:]

    return (X_train, X_test, X_T, Y_train, Y_test,
            loc_train, loc_test,
            scale_train, scale_test,
            shape_train, shape_test,
            X_dist, Xds_train, Xds_test, Xds_join,
            cov_train, cov_test,
            ML_train, ML_test)







#### Distance Matrix Function Generator

In [3]:
# Calculate the pairwise distance matrix from eastings and northings.
# Distances are divided by 10,000 (i.e. expressed in units of 10 km) for numerical stability of the BHM.

def get_distance_matrix(X):
    """Return the pairwise Euclidean distance matrix of the points in ``X``.

    The first column of ``X`` is read as eastings and the second as northings;
    no coordinate transformation is applied. The square distance matrix is
    converted to a JAX array and divided by 10000.
    """
    # Stack the two coordinate columns into an (n_points, 2) array.
    coords = np.column_stack((X[:, 0], X[:, 1]))
    # Condensed pairwise distances, then expanded to the full symmetric matrix.
    condensed = pdist(coords, metric='euclidean')
    square = squareform(condensed)
    return jnp.array(square) / 10000

#### Exponential Kernel for GP Covariance

In [4]:
#exponential kernel 
def kernel(dist, pho, alpha):
    """Exponential covariance kernel: ``alpha**2 * exp(-dist / pho)``.

    ``dist`` holds the distances, ``pho`` is the length-scale that divides them
    and ``alpha`` is the amplitude, which is squared.
    """
    variance = jnp.power(alpha, 2)
    decay = jnp.exp(-(dist / pho))
    return variance * decay


#### Bayesian Hierarchical Model

In [5]:
#Stationary model for BHM
#Non-Stationary model can be requested by contacting atakallou@crimson.ua.edu

def ST_model(X, X_T, D, Y = None):
    """
    Stationary hierarchical GEV model with Gaussian-process location and scale.

    X      : (N, P)   static covariates, one row per gauge; the weight vectors
                      WsL1 / WsL2 have length P = X.shape[1]
    X_T    : (.., T)  dynamic covariates; only X_T.shape[1] (= T) is read here
    D      : distance matrix between gauges, passed to `kernel`
    Y      : (N, T)   block-maxima observations, or None to draw "Y_pred"

    Location and log-scale each get a multivariate-normal field whose mean is a
    linear function of X and whose covariance is kernel(D, rho, alpha). A single
    shape parameter with a Beta(2.5, 7.5) prior is shared by all gauges.
    """
    n_gauges = X.shape[0]
    n_years = X_T.shape[1]
    n_weights = X.shape[1]

    # ----- Location ---------------------------------------------------------
    w_loc = numpyro.sample("WsL1", dist.Normal(0.0, 15.0).expand([n_weights]))  # static weights
    mean_loc = jnp.dot(X, w_loc)
    numpyro.deterministic("m_q_loc", mean_loc)
    amp_loc = numpyro.sample("alpha_loc", dist.Gamma(1.25, 0.25))  # amplitude
    len_loc = numpyro.sample("rho_loc", dist.Gamma(30.0, 0.2))     # length-scale
    cov_loc = kernel(D, len_loc, amp_loc)
    numpyro.deterministic("K_loc", cov_loc)
    field_loc = numpyro.sample(
        "u_q_loc",
        dist.MultivariateNormal(loc=mean_loc, covariance_matrix=cov_loc),
    )
    # Same location for every year: repeat the per-gauge value along the T axis.
    loc_total = jnp.tile(field_loc[:, None], (1, n_years))
    numpyro.deterministic("loc", loc_total)

    # ----- Scale (modelled on the log scale) ---------------------------------
    w_scale = numpyro.sample("WsL2", dist.Normal(0.0, 1.0).expand([n_weights]))
    amp_scale = numpyro.sample("alpha_scale", dist.Gamma(1.25, 0.25))  # amplitude
    len_scale = numpyro.sample("rho_scale", dist.Gamma(30.0, 0.2))     # length-scale
    cov_scale = kernel(D, len_scale, amp_scale)
    numpyro.deterministic("K_sig", cov_scale)
    mean_scale = jnp.dot(X, w_scale)
    numpyro.deterministic("m_q_scale", mean_scale)
    field_scale = numpyro.sample(
        "u_q_scale",
        dist.MultivariateNormal(loc=mean_scale, covariance_matrix=cov_scale),
    )
    # Exponentiate so the GEV scale is positive, then repeat along the T axis.
    scale_total = jnp.tile(jnp.exp(field_scale[:, None]), (1, n_years))
    numpyro.deterministic("scale", scale_total)

    # ----- Shape: one value shared across gauges -----------------------------
    shape = numpyro.sample("shape", dist.Beta(2.5, 7.5))

    # ----- Likelihood: GEV (TFP-JAX) ------------------------------------------
    gev_dist = tfd.GeneralizedExtremeValue(loc_total, scale_total, shape)
    # "Y_obs" conditions on the data; "Y_pred" is drawn when no data are given.
    site_name = "Y_obs" if Y is not None else "Y_pred"
    with numpyro.plate("gauges", n_gauges, dim=-2), numpyro.plate("samples", n_years, dim=-1):
        numpyro.sample(site_name, gev_dist, obs=Y)

#### Get Data

In [6]:


# Labels accepted by get_data: the annual series plus four seasonal ones.
year = ["annual", "DJF", "MAM", "JJA", "SON"]
# To parallelize over different cores, select the label from the command line
# instead: season = year[int(sys.argv[1])]
season = year[0]
data_folder = "/mh1/Atakallou/HBM/Covariates"

# Load everything for the chosen label and unpack the train/test pieces.
(
    X_train,
    X_test,
    X_T,
    Y_train,
    Y_test,
    loc_train,
    loc_test,
    scale_train,
    scale_test,
    shape_train,
    shape_test,
    X_dist,
    Xds_train,
    Xds_test,
    Xds_join,
    cov_train,
    cov_test,
    ML_train,
    ML_test,
) = get_data(season)

# Report the shapes of the training inputs, one per line.
print(
    "\n".join(
        f"✓ {label} shape: {array.shape}"
        for label, array in (("X", X_train), ("Y", Y_train), ("X_T", X_T))
    )
)

✓ X shape: (132, 4)
✓ Y shape: (132, 70)
✓ X_T shape: (6, 70)


#### Run the NUTS inference

In [8]:
print("\n Model is running...")

# NUTS sampler for the stationary model, with a 0.95 target acceptance probability.
nuts_kernel = NUTS(model=ST_model, target_accept_prob=0.95)

# Four chains, each with 1000 warm-up and 1000 retained iterations.
mcmc_gp = MCMC(
    sampler=nuts_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=4,
    progress_bar=True,
)

# Model arguments after the PRNG key are (X, X_T, D, Y).
mcmc_gp.run(
    jax.random.PRNGKey(42),
    X_train,
    X_T,
    Xds_train,
    Y_train,
)

# Posterior draws from the finished run.
samples = mcmc_gp.get_samples()


 Model is running...


  mcmc_gp = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=4, progress_bar=True)
sample: 100%|██████████| 2000/2000 [01:28<00:00, 22.62it/s, 31 steps of size 1.38e-01. acc. prob=0.96] 
sample: 100%|██████████| 2000/2000 [01:28<00:00, 22.63it/s, 31 steps of size 1.48e-01. acc. prob=0.96] 
sample: 100%|██████████| 2000/2000 [01:26<00:00, 23.03it/s, 31 steps of size 1.51e-01. acc. prob=0.95] 
sample: 100%|██████████| 2000/2000 [01:24<00:00, 23.75it/s, 31 steps of size 1.42e-01. acc. prob=0.96] 


####  Model Summary

In [9]:
# Print the posterior summary table (mean, std, median, 5%/95% bounds, n_eff, r_hat) for each sampled site.
mcmc_gp.print_summary()


                    mean       std    median      5.0%     95.0%     n_eff     r_hat
       WsL1[0]      9.37     15.56      9.60    -14.10     37.04   8044.32      1.00
       WsL1[1]      3.11     14.97      2.96    -21.93     27.89   8996.76      1.00
       WsL1[2]    -10.48     14.80    -10.56    -32.94     15.71   8176.58      1.00
       WsL1[3]     13.85     14.91     13.76    -10.44     38.61   8485.71      1.00
       WsL2[0]      0.85      0.37      0.84      0.21      1.43   4955.81      1.00
       WsL2[1]     -0.40      0.31     -0.40     -0.90      0.07   4460.99      1.00
       WsL2[2]     -0.84      0.19     -0.84     -1.19     -0.55   3181.89      1.00
       WsL2[3]      4.72      0.34      4.74      4.19      5.28   3291.45      1.00
     alpha_loc    146.46      9.42    146.11    131.77    162.27   4918.59      1.00
   alpha_scale      0.34      0.05      0.34      0.26      0.42   1431.12      1.00
       rho_loc    215.55     23.42    214.28    176.89    252.91