# PT-JPL-SM Sensitivity Analysis with ECOv002 Cal-Val

This notebook performs a sensitivity analysis of the PT-JPL-SM model using ECOSTRESS Collection 2 Cal-Val data. It loads the input data, runs it through the PT-JPL-SM model, and visualizes how perturbing each input changes the modeled daily evapotranspiration (`ET_daily_kg`). It also generates figures for publication and explores the relationship between surface temperature and evapotranspiration.

## Import Required Libraries and Functions

This cell imports all necessary libraries and functions for data processing, model execution, statistical analysis, and plotting. It includes custom modules for the PT-JPL-SM model, sensitivity analysis, and net radiation calculations, as well as standard scientific Python libraries.

In [1]:
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname("__file__"), '..')))

In [None]:
from typing import Callable
from os import makedirs
from os.path import join
import numpy as np
import pandas as pd
from verma_net_radiation import verma_net_radiation_table
from PTJPLSM import process_PTJPLSM_table, load_ECOv002_calval_PTJPLSM_inputs
from monte_carlo_sensitivity import perturbed_run, sensitivity_analysis, divide_absolute_by_unperturbed
import matplotlib.pyplot as plt
from scipy.stats import mstats
import seaborn as sns
from matplotlib.ticker import FuncFormatter

## Set Normalization Function

This cell selects `divide_absolute_by_unperturbed`, imported from `monte_carlo_sensitivity`, as the normalization function used in the sensitivity analysis to compare perturbed model outputs to the unperturbed baseline.

In [None]:
normalization_function = divide_absolute_by_unperturbed
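
The implementation of `divide_absolute_by_unperturbed` is not shown in this notebook. As a rough illustration only, a normalization with this name plausibly divides the absolute change by the unperturbed value. The function `divide_absolute_by_unperturbed_sketch` below is a hypothetical stand-in with an assumed signature, and the real function may differ, for example in how it treats zero or missing baselines.

```python
import numpy as np

def divide_absolute_by_unperturbed_sketch(perturbed, unperturbed):
    """Hypothetical stand-in: absolute change relative to the baseline."""
    perturbed = np.asarray(perturbed, dtype=float)
    unperturbed = np.asarray(unperturbed, dtype=float)
    # Silence divide-by-zero warnings; zero baselines yield inf or NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.abs(perturbed - unperturbed) / unperturbed

baseline = np.array([2.0, 4.0, 5.0])     # made-up unperturbed outputs
perturbed = np.array([2.5, 3.0, 5.0])    # made-up perturbed outputs
normalized = divide_absolute_by_unperturbed_sketch(perturbed, baseline)
print(normalized.tolist())  # [0.25, 0.25, 0.0]
```

Under this assumption, a value of 0.25 means the output moved by 25% of its baseline, regardless of direction.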

## Load and Filter Input Data

This cell loads the ECOSTRESS Cal-Val input data using a custom loader function and filters out rows where `fAPARmax` is zero. The resulting DataFrame is displayed for inspection.

In [None]:
input_df = load_ECOv002_calval_PTJPLSM_inputs()
input_df = input_df[input_df.fAPARmax != 0]
input_df

## Check Minimum fAPARmax Value

This cell computes the minimum value of `fAPARmax` in the filtered input data to verify the filtering step and ensure no zero values remain.

In [None]:
np.nanmin(input_df.fAPARmax)
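
One subtlety of this check is that a `!= 0` comparison keeps missing values, because NaN compares as unequal to everything, and `np.nanmin` then skips them. The sketch below shows this with made-up values in a plain NumPy array; pandas comparisons follow the same NaN semantics.

```python
import numpy as np

# Hypothetical fAPARmax values: one zero, one missing, two valid.
fAPARmax = np.array([0.0, 0.45, np.nan, 0.8])

# NaN != 0 evaluates to True, so the missing value passes the filter.
keep = fAPARmax != 0
filtered = fAPARmax[keep]

# nanmin ignores the NaN and reports the smallest finite value.
minimum = np.nanmin(filtered)
print(keep.tolist(), minimum)
```

A positive minimum therefore confirms that no zeros remain, but it says nothing about whether missing values are still present.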

## Define Processing Function for Model Table

This cell defines a helper function that processes the input DataFrame through the Verma net radiation calculation and then the PT-JPL-SM model. This function is used in later analysis steps.

In [None]:
def process_verma_PTJPLSM_table(input_df: pd.DataFrame) -> pd.DataFrame:
    return process_PTJPLSM_table(verma_net_radiation_table(input_df), upscale_to_daily=True)

## Process Input Data Through Model

This cell applies the processing function to the filtered input data, running it through the Verma net radiation and PT-JPL-SM model, and displays the resulting DataFrame.

In [None]:
processed = process_verma_PTJPLSM_table(input_df)
processed

In [None]:
# import matplotlib.pyplot as plt
# import matplotlib.lines as mlines
# import numpy as np
# from scipy.stats import linregress, pearsonr

# colors = {
#     'CRO': '#FFEC8B', 'CSH': '#AB82FF', 'CVM': '#8B814C', 
#     'DBF': '#98FB98', 'EBF': '#7FFF00', 'ENF': '#006400', 
#     'GRA': '#FFA54F', 'MF': '#8FBC8F', 'OSH': '#FFE4E1', 
#     'SAV': '#FFD700', 'WAT': '#98F5FF', 'WET': '#4169E1', 
#     'WSA': '#CDAA7D'
# }

# one2one = np.arange(-250, 1200, 5)

# # Use the processed DataFrame for plotting
# df = processed.copy()
# mean_LE = df['ETcorr50daily'].mean()

# # New y positions for the text to shift it down
# x_pos = 0.05  # Position the text just inside the right edge of the plot
# y_pos = [0.9, 0.85, 0.8, 0.75, 0.7]  # Different y positions for each line of text

# # Function to calculate metrics using standard libraries
# def calculate_metrics(y, x):
#     # Remove NaNs for paired arrays
#     mask = ~np.isnan(x) & ~np.isnan(y)
#     x_valid = x[mask]
#     y_valid = y[mask]
#     n_points = len(x_valid)
#     # RMSE
#     rmse = np.sqrt(np.mean((y_valid - x_valid) ** 2)) if n_points > 0 else np.nan
#     # R2
#     r2 = pearsonr(y_valid, x_valid)[0] ** 2 if n_points > 1 else np.nan
#     # Linear regression
#     slope, intercept, _, _, _ = linregress(x_valid, y_valid) if n_points > 1 else (np.nan, np.nan, np.nan, np.nan, np.nan)
#     # Bias
#     bias = np.mean(y_valid - x_valid) if n_points > 0 else np.nan
#     # Absolute bias
#     abs_bias = np.mean(np.abs(y_valid - x_valid)) if n_points > 0 else np.nan
#     return rmse, r2, slope, intercept, bias, n_points, abs_bias

# # Function to plot PT-JPL subplot only
# def plot_ptjpl_subplot(ax, x, y, title, metrics):
#     rmse, r2, slope, intercept, bias, n_points, abs_bias = metrics
#     scatter_colors = [colors.get(veg, 'gray') for veg in df['vegetation']] if 'vegetation' in df else 'black'
#     ax.scatter(x, y, c=scatter_colors, marker='o', s=14, zorder=4)
#     ax.set_title(title, fontsize=16)
#     ax.set_ylim([-1, 12.5])
#     ax.set_xlim([-1, 12.5])
#     ax.set_ylabel('Model ET [mm day$^-$$^1$]', fontsize=14)
#     ax.set_xlabel('Flux Tower ET [mm day$^-$$^1$]', fontsize=14)
#     ax.plot(one2one, one2one, '--', c='k')
#     ax.plot(one2one, one2one * slope + intercept, '--', c='gray')
#     ax.text(x_pos, y_pos[0], f'y = {slope:.2f}x + {intercept:.2f}', transform=ax.transAxes, fontsize=12, color='black')
#     ax.text(x_pos, y_pos[1], f'RMSE: {rmse:.2f} mm day$^-$$^1$', transform=ax.transAxes, fontsize=12, color='black')
#     ax.text(x_pos, y_pos[2], f'bias: {bias:.2f} mm day$^-$$^1$', transform=ax.transAxes, fontsize=12, color='black')
#     ax.text(x_pos, y_pos[3], f'R$^2$: {r2:.2f}', transform=ax.transAxes, fontsize=12, color='black')
#     ax.text(x_pos, y_pos[4], f'N= {n_points}', transform=ax.transAxes, fontsize=12, color='black')
#     ax.text(-0.1, 1.05, 'a)', transform=ax.transAxes, fontsize=16, fontweight='bold', va='top', ha='right')

# # Prepare data and metrics for PT-JPL only
# x20 = df['ETcorr50daily'].to_numpy()
# metrics_pt_jpl = calculate_metrics(df['ETdaily_L3_ET_PT-JPL'].to_numpy(), x20)

# # Set up the figure and single subplot
# fig, ax = plt.subplots(figsize=(6, 6))
# plot_ptjpl_subplot(ax, x20, df['ETdaily_L3_ET_PT-JPL'].to_numpy(), 'PT-JPL ET$_{daily}$ C1', metrics_pt_jpl)

# # Create legend for vegetation types if available
# if 'vegetation' in df and 'colors' in globals():
#     scatter_handles = [mlines.Line2D([0], [0], marker='o', color='w', label=veg, markerfacecolor=color, markersize=8) 
#                        for veg, color in colors.items()]
#     fig.legend(handles=scatter_handles, loc='lower center', bbox_to_anchor=(0.5, 0), ncol=7, title='Vegetation Type', fontsize=10)

# plt.tight_layout()
# fig.subplots_adjust(bottom=0.15)  # Increase the bottom margin to make room for the legend

# print(f'PT-JPL bias = {metrics_pt_jpl[4]}')
# print(f'PT-JPL abs bias = {metrics_pt_jpl[6]}')
# print(f'PT-JPL n = {metrics_pt_jpl[5]}')

## Plot Unperturbed Comparison of Surface Temperature to Latent Heat Flux

This cell creates a scatter plot comparing ECOSTRESS surface temperature to PT-JPL-SM daily evapotranspiration for the unperturbed data. The plot is saved as both JPEG and SVG files for publication or further analysis.

In [None]:
# Use NaN-aware min/max so missing temperatures cannot break the tick range
plt.xticks(range(int(np.nanmin(processed.ST_C)), int(np.nanmax(processed.ST_C)) + 1, 5))
plt.scatter(x=processed.ST_C, y=processed.ET_daily_kg, color='black', s=10, zorder=5)
plt.grid(True, zorder=0)
plt.xlabel("ECOSTRESS Surface Temperature (°C)")
plt.ylabel("PT-JPL-SM Daily ET (mm/day)")
plt.title("Unperturbed Comparison\nof ECOSTRESS Surface Temperature\nto PT-JPL-SM Evapotranspiration")

plt.savefig("Unperturbed Comparison of ECOSTRESS Surface Temperature to PT-JPL-SM Evapotranspiration.jpeg", format='jpeg', bbox_inches='tight')
plt.savefig("Unperturbed Comparison of ECOSTRESS Surface Temperature to PT-JPL-SM Evapotranspiration.svg", format='svg', bbox_inches='tight')

plt.show()

## Further Filter Input Data and Check Temperature Range

This cell reloads the input data and applies stricter filters, keeping only rows with `fAPARmax` above 0.001 and `NDVI` above 0.05. If the table has a `Ta` column but no `Ta_C` column, `Ta` is copied to `Ta_C`. The cell then reports the minimum and maximum surface temperature in the filtered dataset.

In [None]:
# input_df = pd.read_csv(input_filename)
input_df = load_ECOv002_calval_PTJPLSM_inputs()

if "Ta" in input_df and "Ta_C" not in input_df:
    # input_df.rename({"Ta": "Ta_C"}, inplace=True)
    input_df["Ta_C"] = input_df["Ta"]

input_df = input_df[input_df.fAPARmax > 0.001]
input_df = input_df[input_df.NDVI > 0.05]

np.nanmin(input_df.ST_C), np.nanmax(input_df.ST_C)

## Check Number of Valid Input Rows

This cell displays the number of rows remaining in the input DataFrame after the `fAPARmax` and `NDVI` filters. Later cells drop rows with missing values, so the dataset used for the analysis may be smaller.

In [None]:
len(input_df)

In [None]:
input_df.columns

In [None]:
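# Work on a copy so assigning to the filtered frame does not trigger SettingWithCopyWarning
input_df = input_df.copy()
# Add `SWin_Wm2` and `emissivity` columns as copies of `SW_IN` and `EmisWB`
input_df["SWin_Wm2"] = input_df["SW_IN"]
input_df["emissivity"] = input_df["EmisWB"]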

In [None]:
# Ensure all relevant columns are numeric before sensitivity analysis
numeric_columns = ["albedo", "SWin_Wm2", "ST_C", "emissivity", "Ta_C", "RH"]
for col in numeric_columns:
    if col in input_df.columns:
        input_df[col] = pd.to_numeric(input_df[col], errors="coerce")
# Drop rows with NaN in any relevant column
input_df = input_df.dropna(subset=numeric_columns)

In [None]:
input_df = input_df.dropna()
input_df

## Run Perturbed Model Analysis

This cell sets up the input and output variables for the sensitivity analysis and runs the `perturbed_run` function, which perturbs the input variable and observes the effect on the output variable using the PT-JPL-SM model. The results are displayed for further analysis.

In [None]:
input_variable = "ST_C"
output_variable = "ET_daily_kg"

results = perturbed_run(
    input_df=input_df, 
    input_variable=input_variable, 
    output_variable=output_variable, 
    forward_process=process_verma_PTJPLSM_table,
    normalization_function=normalization_function
)

results
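
The internals of `perturbed_run` are not shown in this notebook. In general terms, a perturbation experiment of this kind adds random noise to one input, reruns the forward model, and records how much the output moves. The sketch below illustrates that idea with a hypothetical linear forward model, `toy_forward_process`. The Gaussian noise scaled by the input's standard deviation is an assumption made for illustration, not a description of the library.

```python
import numpy as np

def toy_forward_process(st_c):
    """Hypothetical forward model: output falls linearly with temperature."""
    return 10.0 - 0.2 * st_c

rng = np.random.default_rng(seed=0)
st_c = np.array([20.0, 25.0, 30.0, 35.0])   # made-up unperturbed input
unperturbed_output = toy_forward_process(st_c)

# Draw one random perturbation per row, scaled by the spread of the input.
input_perturbation = rng.normal(loc=0.0, scale=np.std(st_c), size=st_c.shape)
perturbed_output = toy_forward_process(st_c + input_perturbation)

# The change in output attributable to the change in input.
output_perturbation = perturbed_output - unperturbed_output
print(np.column_stack([input_perturbation, output_perturbation]))
```

For this linear toy model every output perturbation is exactly -0.2 times the input perturbation; a nonlinear model such as PT-JPL-SM would generally show scatter instead.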

## Filter Out NaN Results

This cell removes any rows with missing values from the perturbed results to ensure only valid data points are used in subsequent analysis and plotting.

In [None]:
filtered_results = results.dropna()
filtered_results

## Plot Change in Surface Temperature vs Evapotranspiration

This cell generates a scatter plot showing the relationship between changes in surface temperature and changes in PT-JPL-SM evapotranspiration due to input perturbations. The plot is saved as JPEG and SVG files.

In [None]:
# plt.xticks(range(int(min(processed.ST_C)), int(max(processed.ST_C)) + 1, 5))
# plt.ylim(-350, 150)
plt.scatter(x=results.input_perturbation, y=results.output_perturbation, color='black', s=10, zorder=5)  # Adjust the 's' parameter to make dots thinner
plt.grid(True, zorder=0)
plt.xlabel("Change in Surface Temperature (°C)")
plt.ylabel("Change in PT-JPL-SM Evapotranspiration (mm/day)")
plt.title("Comparison of Change\nin Surface Temperature\nto PT-JPL-SM Evapotranspiration")

plt.savefig("Comparison of Change in Surface Temperature to PT-JPL-SM Evapotranspiration.jpeg", format='jpeg', bbox_inches='tight')
plt.savefig("Comparison of Change in Surface Temperature to PT-JPL-SM Evapotranspiration.svg", format='svg', bbox_inches='tight')

plt.show()

## Compute Correlation Between Input and Output Perturbations

This cell calculates the Pearson correlation coefficient between the standardized input and output perturbations, quantifying the strength of their linear relationship.

In [None]:
correlation = mstats.pearsonr(
    np.array(filtered_results.input_perturbation_std).astype(np.float64), 
    np.array(filtered_results.output_perturbation_std).astype(np.float64)
)[0]

correlation
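
For reference, the Pearson coefficient is the covariance of the two series divided by the product of their standard deviations. The sketch below computes it directly with NumPy on made-up data; `pearson_r` is an illustrative helper, not part of this notebook's libraries.

```python
import numpy as np

def pearson_r(x, y):
    """Pearson correlation from centered sums of products and squares."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    numerator = (x_centered * y_centered).sum()
    denominator = np.sqrt((x_centered ** 2).sum() * (y_centered ** 2).sum())
    return float(numerator / denominator)

x = np.array([-1.0, 0.0, 1.0, 2.0])
r_negative = pearson_r(x, -3.0 * x + 1.0)                 # exact linear decrease
r_partial = pearson_r(x, np.array([0.0, 1.0, 0.0, 3.0]))  # noisy increase
print(r_negative, r_partial)
```

Because the coefficient is unchanged when either series is shifted or rescaled by a positive factor, the standardized perturbations give the same correlation as the raw ones would.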

## Run Full Sensitivity Analysis for Multiple Inputs

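This cell performs a comprehensive sensitivity analysis by perturbing several input variables (surface temperature, NDVI, albedo, air temperature, relative humidity) and measuring their effect on daily evapotranspiration (`ET_daily_kg`). The call returns two DataFrames: `perturbation_df` with the perturbations, and `sensitivity_metrics_df` with the summary metrics, which is displayed.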

In [None]:
input_variables = ["ST_C", "NDVI", "albedo", "Ta_C", "RH"]
output_variables = ["ET_daily_kg"]

perturbation_df, sensitivity_metrics_df = sensitivity_analysis(
    input_df=input_df,
    input_variables=input_variables,
    output_variables=output_variables,
    forward_process=process_verma_PTJPLSM_table,
    normalization_function=normalization_function
)

sensitivity_metrics_df
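
The plotting cells later in this notebook filter the metrics table by `input_variable` and `metric`, which suggests a long format with one row per combination. The sketch below builds such a table for a hypothetical two-input model, `toy_model`. The input names `a` and `b`, the Gaussian noise, and the metric formulas are assumptions for illustration, and `sensitivity_analysis` may compute its metrics differently.

```python
import numpy as np

def toy_model(inputs):
    """Hypothetical forward model with a strong driver (a) and a weak one (b)."""
    return 5.0 + 0.5 * inputs["a"] + 0.01 * inputs["b"]

rng = np.random.default_rng(seed=1)
inputs = {"a": rng.uniform(1.0, 2.0, 200), "b": rng.uniform(1.0, 2.0, 200)}
baseline = toy_model(inputs)

records = []  # long format: one row per (input_variable, metric)
for name in inputs:
    # Perturb one input at a time while holding the other at its baseline.
    noise = rng.normal(0.0, np.std(inputs[name]), size=200)
    perturbed = toy_model({**inputs, name: inputs[name] + noise})
    change = perturbed - baseline
    records.append({"input_variable": name, "metric": "mean_normalized_change",
                    "value": float(np.mean(np.abs(change) / baseline))})
    records.append({"input_variable": name, "metric": "correlation",
                    "value": float(np.corrcoef(noise, change)[0, 1])})

for row in records:
    print(row)
```

In this toy case both inputs correlate almost perfectly with the output change, yet `a` moves the output far more than `b`. Magnitude and correlation answer different questions, which is why the notebook plots them separately.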

## Plot Sensitivity Magnitude Bar Chart

This cell creates a bar chart showing the average percent change in daily evapotranspiration for each perturbed input variable, visualizing the magnitude of model sensitivity to each input. The plot is saved as JPEG and SVG files with the PT-JPL-SM label in the file name.

In [None]:
df = sensitivity_metrics_df
df = df[(df.output_variable == "ET_daily_kg") & (df.metric == "mean_normalized_change")]
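# Map column names to display labels so each bar is labeled by its own variable, whatever the row order
labels = {"ST_C": "Surface\nTemperature", "NDVI": "NDVI", "albedo": "Albedo",
          "Ta_C": "Air\nTemperature", "RH": "Relative\nHumidity"}
ax = sns.barplot(x=df.input_variable.map(labels), y=df.value * 100, color='black')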
plt.xlabel("Input Variable")
plt.ylabel("Average Percent Change in Output Perturbation")
plt.title("PT-JPL-SM Evapotranspiration Sensitivity Magnitude")
# plt.ylim(0, 160)  # Set y-axis range from 0 to 160
plt.grid(axis='y', color='lightgray', linestyle='-', linewidth=0.5)  # Add light gray horizontal gridlines only

# Add percent sign to y-axis tick labels
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{int(y)}%'))

plt.savefig("PT-JPL-SM Evapotranspiration Sensitivity Magnitude.jpeg", format='jpeg', bbox_inches='tight')
plt.savefig("PT-JPL-SM Evapotranspiration Sensitivity Magnitude.svg", format='svg', bbox_inches='tight')

plt.show()
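
A small detail of the tick formatter above is that `int(y)` truncates toward zero rather than rounding. The sketch below reproduces that formatting as a plain function next to a rounding alternative; both function names are illustrative only.

```python
def truncating_percent(y, _pos=None):
    """Same formatting as the lambda above: int() truncates toward zero."""
    return f"{int(y)}%"

def rounding_percent(y, _pos=None):
    """Alternative that rounds to the nearest whole percent."""
    return f"{y:.0f}%"

print(truncating_percent(2.9), rounding_percent(2.9))  # 2% 3%
```

The difference only shows when the axis ticks fall on non-integer values; with integer ticks the two formatters agree.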

## Plot Sensitivity Correlation Bar Chart

This cell generates a bar chart showing the correlation between input perturbations and output perturbations for each input variable, highlighting which inputs are most strongly linearly related to changes in daily evapotranspiration. The plot is saved as JPEG and SVG files with the PT-JPL-SM label in the file name.

In [None]:
df = sensitivity_metrics_df
df = df[(df.output_variable == "ET_daily_kg") & (df.metric == "correlation")]
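# Map column names to display labels so each bar is labeled by its own variable, whatever the row order
labels = {"ST_C": "Surface\nTemperature", "NDVI": "NDVI", "albedo": "Albedo",
          "Ta_C": "Air\nTemperature", "RH": "Relative\nHumidity"}
ax = sns.barplot(x=df.input_variable.map(labels), y=df.value, color='black')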
# plt.ylim(0, 0.8)
plt.xlabel("Input Variable")
plt.ylabel("Correlation of Input Perturbation to Output Perturbation")
plt.title("PT-JPL-SM Evapotranspiration Sensitivity Correlation")
plt.grid(axis='y')  # Add horizontal gridlines

plt.savefig("PT-JPL-SM Evapotranspiration Sensitivity Correlation.jpeg", format='jpeg', bbox_inches='tight')
plt.savefig("PT-JPL-SM Evapotranspiration Sensitivity Correlation.svg", format='svg', bbox_inches='tight')

plt.show()

## Summary and Next Steps

This notebook demonstrated a full sensitivity analysis workflow for the PT-JPL-SM model using ECOSTRESS Cal-Val data. The two bar charts compare the five perturbed inputs (surface temperature, NDVI, albedo, air temperature, relative humidity) by the magnitude of their effect on daily evapotranspiration and by the correlation between input and output perturbations. Next steps could include further exploration of model parameters, additional visualizations, or application to other datasets.