https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html  
https://github.com/rtqichen/torchdiffeq/blob/master/examples/cnf.py

In [54]:
import torch
import numpy as np
import argparse
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torchdiffeq import odeint, odeint_adjoint
from torchdiffeq import odeint_event

torch.set_default_dtype(torch.float64)

In [64]:
class soil_moisture_flux_ode(nn.Module):
    
    def __init__(self, cfe_state=None, reservoir=None):
        super().__init__()
        self.cfe_state = cfe_state
        self.reservoir = reservoir

    def forward(self, t, states):
        
        S = states[0]
            
        storage_above_threshold_m = S - reservoir['storage_threshold_primary_m']
        storage_diff = reservoir['storage_max_m'] - reservoir['storage_threshold_primary_m']
        storage_ratio = torch.minimum(storage_above_threshold_m / storage_diff, torch.tensor([1.0]))

        perc_lat_switch = torch.multiply(S - reservoir['storage_threshold_primary_m'] > 0, 1)
        ET_switch = torch.multiply(S - reservoir['wilting_point_m'] > 0, 1)

        storage_above_threshold_m_paw = S - reservoir['wilting_point_m']
        storage_diff_paw = reservoir['storage_threshold_primary_m'] - reservoir['wilting_point_m']
        storage_ratio_paw = torch.minimum(storage_above_threshold_m_paw / storage_diff_paw, torch.tensor([0.3])) # Equation 11 (Ogden's document)
        dS_dt = cfe_state['infiltration_depth_m'] -1 * perc_lat_switch * (reservoir['coeff_primary'] + reservoir['coeff_secondary']) * storage_ratio - ET_switch * cfe_state['reduced_potential_et_m_per_timestep'] * storage_ratio_paw
        
        return (dS_dt)

In [65]:

# Initialization
y0 = torch.tensor([0.3])
t = torch.tensor([0, 0.05, 0.15, 0.3, 0.6, 1.0]) # ODE time descritization of one time step
cfe_state = {}
cfe_state['infiltration_depth_m'] = torch.tensor([0.1])
cfe_state['reduced_potential_et_m_per_timestep'] = torch.tensor([0.003])
reservoir = {}
reservoir['storage_threshold_primary_m'] = torch.tensor([0.2])
reservoir['storage_max_m'] = torch.tensor([0.4])
reservoir['wilting_point_m'] = torch.tensor([0.1])
reservoir['coeff_primary'] = torch.tensor([0.4])
reservoir['coeff_secondary'] = torch.tensor([0.4])

# Pass parameters beforehand
device = 'cpu'
func = soil_moisture_flux_ode(cfe_state=cfe_state, reservoir=reservoir).to(device)

# Solve and ODE
sol = odeint(
    func,
    y0,
    t,
    atol=1e-5,
    rtol=1e-5,
    # adjoint_params=()
)



In [85]:
sol.shape

torch.Size([6, 1])

In [93]:
# Finalize results
ts_concat = t
ys_concat = sol.squeeze()
t_proportion = torch.diff(ts_concat, dim=0) # ts_concat[1:] - ts_concat[:-1]
ys_concat

tensor([0.3000, 0.2864, 0.2661, 0.2474, 0.2316, 0.2262])

In [112]:
import torch.nn.functional as F

# Create the kernel tensor with torch.ones
kernel = torch.ones(2)

# Get the moving average y values in between the time intervals
convolved = F.conv1d(ys_concat.unsqueeze(0).unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0), padding=1).squeeze()
# Divide by 2 to match np.convolve
ys_avg_ = convolved / 2
ys_avg = ys_avg_[1:-1]
ys_avg
# Original numpy method
# ys_avg = np.convolve(ys_concat, np.ones(2), 'valid') / 2

tensor([0.2932, 0.2762, 0.2567, 0.2395, 0.2289])

In [111]:
import math
lateral_flux = torch.zeros(ys_avg.shape)
perc_lat_switch = ys_avg - reservoir['storage_threshold_primary_m'] > 0
lateral_flux[perc_lat_switch] = reservoir['coeff_secondary'] * torch.minimum(
    (ys_avg[perc_lat_switch] - reservoir['storage_threshold_primary_m']) / (
                reservoir['storage_max_m'] - reservoir['storage_threshold_primary_m']), torch.tensor([1.0]))
lateral_flux_frac = lateral_flux * t_proportion

perc_flux = torch.zeros(ys_avg.shape)
perc_flux[perc_lat_switch] = reservoir['coeff_primary'] * torch.minimum(
    (ys_avg[perc_lat_switch] - reservoir['storage_threshold_primary_m']) / (
                reservoir['storage_max_m'] - reservoir['storage_threshold_primary_m']), torch.tensor([1.0]))
perc_flux_frac = perc_flux * t_proportion

et_from_soil = torch.zeros(ys_avg.shape)
ET_switch = ys_avg - reservoir['wilting_point_m'] > 0
et_from_soil[ET_switch] = cfe_state['reduced_potential_et_m_per_timestep'] * torch.minimum(
    (ys_avg[ET_switch] - reservoir['wilting_point_m']) / (reservoir['storage_threshold_primary_m'] - reservoir['wilting_point_m']), torch.tensor([1.0]))
et_from_soil_frac = et_from_soil * t_proportion

infilt_to_soil = cfe_state['infiltration_depth_m'].repeat(ys_avg.shape)
infilt_to_soil_frac = infilt_to_soil * t_proportion

# Scale fluxes (Since the sum of all the estimated flux above usually exceed the input flux because of calculation errors, scale it
# The more finer ODE time descritization you use, the less errors you get, but the more calculation time it takes 
sum_outflux = lateral_flux_frac + perc_flux_frac + et_from_soil_frac
if sum_outflux.any() == 0:
    flux_scale = 0
else:
    flux_scale = torch.zeros(infilt_to_soil_frac.shape)
    flux_scale[sum_outflux != 0] = (torch.diff(-ys_concat, dim=0)[sum_outflux != 0] + infilt_to_soil_frac[
        sum_outflux != 0]) / sum_outflux[sum_outflux != 0]
    flux_scale[sum_outflux == 0] = 0
scaled_lateral_flux = lateral_flux_frac * flux_scale
scaled_perc_flux = perc_flux_frac * flux_scale
scaled_et_flux = et_from_soil_frac * flux_scale

# Pass the results
# ? Do these all gets tracked? 
primary_flux_m = math.fsum(scaled_perc_flux)
secondary_flux_m = math.fsum(scaled_lateral_flux)
actual_et_from_soil_m_per_timestep = math.fsum(scaled_et_flux)
reservoir['storage_m'] = ys_concat[-1]

print(f'primary_flux_m: {primary_flux_m}')
print(f'secondary_flux_m: {secondary_flux_m}')
print(f'actual_et_from_soil_m_per_timestep: {actual_et_from_soil_m_per_timestep}')
print(f'reservoir["storage_m"]: {reservoir["storage_m"]}')

# cfe_state.primary_flux_m = math.fsum(scaled_perc_flux)
# cfe_state.secondary_flux_m = math.fsum(scaled_lateral_flux)
# cfe_state.actual_et_from_soil_m_per_timestep = math.fsum(scaled_et_flux)
# # reservoir['storage_m'] = ys_concat[-1]
# cfe_state.soil_reservoir['storage_m'] = ys_concat[-1]

primary_flux_m: 0.08548009648611235
secondary_flux_m: 0.08548009648611235
actual_et_from_soil_m_per_timestep: 0.0028869109589852553
reservoir["storage_m"]: 0.22615289606879002
