In [1]:
import random
import time
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy.optimize import bisect
from torchdiffeq import odeint

from lib import AnalyticRetardation, load_data, Initialize, Flux_Kernels

In [2]:
def caps_calculation(network_preds: dict[str, Any], c_up, c_down, Y, verbose=0):
    """Caps calculations for a single quantile: interval coverage (PICP) and width (MPIW).

    For every element the prediction interval is
    ``[mean - c_down * down, mean + c_up * up]``.

    Args:
        network_preds: dict with keys "mean", "up" and "down" holding array-likes
            (CPU torch tensors without grad, or numpy arrays) that broadcast together.
        c_up: scale factor applied to the "up" prediction.
        c_down: scale factor applied to the "down" prediction.
        Y: observed targets; flattened and compared element-wise with the flattened bounds.
        verbose: print diagnostics when > 0.

    Returns:
        (PICP, MPIW): fraction (0-1) of targets strictly inside the interval, and the
        mean interval width.
    """

    if verbose > 0:
        print("--- Start caps calculations for SINGLE quantile ---")
        print("**************** For Training data *****************")

    # Always flatten so full fields (any rank) line up with the flattened bounds.
    Y = np.asarray(Y).flatten()

    mean = network_preds["mean"]
    # np.asarray accepts both numpy arrays and CPU tensors.
    bound_up = np.asarray(mean + c_up * network_preds["up"]).flatten()
    bound_down = np.asarray(mean - c_down * network_preds["down"]).flatten()

    y_U_cap = bound_up > Y  # below the upper bound
    y_L_cap = bound_down < Y  # above the lower bound

    # A target is covered only if it satisfies BOTH caps.
    y_all_cap = np.logical_and(y_U_cap, y_L_cap)
    PICP = np.count_nonzero(y_all_cap) / y_all_cap.shape[0]  # 0-1
    MPIW = np.mean(bound_up - bound_down)
    if verbose > 0:
        print(f"Num of train in y_U_cap: {np.count_nonzero(y_U_cap)}")
        print(f"Num of train in y_L_cap: {np.count_nonzero(y_L_cap)}")
        print(f"Num of train in y_all_cap: {np.count_nonzero(y_all_cap)}")
        print(f"np.sum results(train): {np.sum(y_all_cap)}")
        print(f"PICP: {PICP}")
        print(f"MPIW: {MPIW}")

    return (
        PICP,
        MPIW,
    )


def optimize_bound(
    *,
    mode: str,
    y_train: np.ndarray,
    pred_mean: np.ndarray,
    pred_std: np.ndarray,
    num_outliers: int,
    c0: float = 0.0,
    c1: float = 1e5,
    maxiter: int = 1000,
    verbose=0,
):
    """Find the scale factor ``c`` of one interval bound by bisection.

    For ``mode="up"`` the bound is ``pred_mean + c * pred_std`` and the objective is
    ``count(y_train >= bound) - num_outliers``. For ``mode="down"`` the bound is
    ``pred_mean - c * pred_std`` and the objective is ``count(y_train <= bound) - num_outliers``.
    Counts are taken element-wise over the broadcast arrays.

    Args:
        mode: "up" or "down".
        y_train: observed targets.
        pred_mean: mean prediction, broadcastable against ``y_train``.
        pred_std: spread prediction for the chosen side, broadcastable against ``y_train``.
        num_outliers: number of elements allowed beyond the bound.
        c0, c1: bracket for the bisection; the objective must change sign on it.
        maxiter: maximum bisection iterations.
        verbose: print diagnostics when > 0.

    Returns:
        The ``c`` returned by ``scipy.optimize.bisect``.

    Raises:
        ValueError: if ``mode`` is not "up"/"down", or if bisection fails (e.g. no sign
            change on ``[c0, c1]``).
    """

    def count_exceeding_upper_bound(c: float):
        bound = pred_mean + c * pred_std
        return np.count_nonzero(y_train >= bound) - num_outliers

    def count_exceeding_lower_bound(c: float):
        bound = pred_mean - c * pred_std
        return np.count_nonzero(y_train <= bound) - num_outliers

    if mode == "up":
        objective_function = count_exceeding_upper_bound
    elif mode == "down":
        objective_function = count_exceeding_lower_bound
    else:
        # Previously any unknown mode silently ran the lower-bound search.
        raise ValueError(f"mode must be 'up' or 'down', got {mode!r}")

    if verbose > 0:
        print(f"Initial bounds: [{c0}, {c1}]")

    try:
        optimal_c = bisect(objective_function, c0, c1, maxiter=maxiter)
    except ValueError as e:
        if verbose > 0:
            print(f"Bisect method failed: {e}")
        raise  # bare raise keeps the original traceback intact

    if verbose > 0:
        final_count = objective_function(optimal_c)
        print(f"Optimal c: {optimal_c}, Final count: {final_count}")
    return optimal_c


def compute_boundary_factors(
    *, y_train: np.ndarray, network_preds: dict[str, Any], quantile: float, verbose=0
):
    """Compute the upper and lower bound scale factors for a single quantile.

    ``(1 - quantile) / 2`` of all target elements are allowed outside each side of the
    interval; ``optimize_bound`` then searches the matching scale factor per side.

    Args:
        y_train: observed targets (any shape).
        network_preds: dict with numpy arrays under "mean", "up" and "down".
        quantile: nominal coverage in (0, 1].
        verbose: print diagnostics when > 0 (also forwarded to ``optimize_bound``).

    Returns:
        (c_up, c_down)

    Raises:
        ValueError: if ``quantile`` is outside (0, 1], or if a bisection fails.
    """
    if not 0.0 < quantile <= 1.0:
        raise ValueError(f"quantile must be in (0, 1], got {quantile}")

    # optimize_bound counts outliers per element, so the budget is based on the total
    # number of elements, not only on the first axis (which differ for full fields).
    n_train = np.size(y_train)
    num_outlier = int(n_train * (1 - quantile) / 2)

    if verbose > 0:
        print(
            "--- Start boundary optimizations for SINGLE quantile: {}".format(quantile)
        )
        print(
            "--- Number of outlier based on the defined quantile: {}".format(
                num_outlier
            )
        )

    c_up, c_down = [
        optimize_bound(
            y_train=y_train,
            pred_mean=network_preds["mean"],
            pred_std=network_preds[mode],
            mode=mode,
            num_outliers=num_outlier,
            verbose=verbose,
        )
        for mode in ["up", "down"]
    ]

    if verbose > 0:
        print("--- c_up: {}".format(c_up))
        print("--- c_down: {}".format(c_down))

    return c_up, c_down


def create_PI_training_data(
    network_mean, X, Y, threshold: float = 40.0, verbose: int = 1
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
    """Generate up and down training data from the mean model's squared residuals.

    For every sample i (first axis of ``Y``) the residual
    ``diff_i = sum((Y_i - network_mean(X)_i) ** 2)`` is summed over all remaining axes.
    Samples with ``diff_i > threshold`` form the up set (target ``diff_i``); every other
    sample, including ``diff_i == threshold``, forms the down set (target ``-diff_i``).

    Args:
        network_mean: callable mean model; ``network_mean(X)`` must broadcast against
            ``Y`` and keep the sample axis first.
        X: inputs, indexed along their first axis with the per-sample masks.
        Y: targets with the sample axis first and any number of trailing axes.
        threshold: residual level separating the two sets (TODO: tune; default 40).
        verbose: print the shapes of the resulting sets when > 0.

    Returns:
        ((X_up, Y_up), (X_down, Y_down))
    """
    with torch.no_grad():
        residual_sq = (Y - network_mean(X)) ** 2
        # Reduce every axis except the first; works for 1-D targets and full fields alike.
        diff_train = residual_sq.reshape(residual_sq.shape[0], -1).sum(dim=1)
        up_idx = diff_train > threshold
        # Use the complement so a residual exactly at the threshold is not dropped.
        down_idx = ~up_idx

        X_up = X[up_idx]
        Y_up = diff_train[up_idx]

        X_down = X[down_idx]
        Y_down = -1.0 * diff_train[down_idx]
        if verbose > 0:
            print(X_down.shape, X_up.shape)
            print(Y_down.shape, Y_up.shape)

    return ((X_up, Y_up), (X_down, Y_down))


def eval_networks(
    networks: dict[str, Any], x, as_numpy: bool = False
) -> dict[str, Any]:
    """Evaluate every network on the same input without tracking gradients.

    Args:
        networks: mapping from a name (e.g. "mean", "up", "down") to a callable model.
        x: input passed unchanged to each network.
        as_numpy: if True, convert each output to a numpy array. Outputs are copied to
            host memory first, so models living on an accelerator work too.

    Returns:
        dict with the same keys as ``networks`` holding each network's output.
    """
    with torch.no_grad():
        preds = {name: network(x) for name, network in networks.items()}
    if as_numpy:
        preds = {name: out.detach().cpu().numpy() for name, out in preds.items()}
    return preds


class UQ_Net_mean(nn.Module):
    def __init__(
        self, num_neurons: list[int], num_inputs, num_outputs, activation="relu"
    ):
        super(UQ_Net_mean, self).__init__()
        self.activation_fun = torch.relu if activation == "relu" else torch.tanh

        self.inputLayer = nn.Linear(num_inputs, num_neurons[0])
        self.fcs = nn.ModuleList()
        for i in range(len(num_neurons) - 1):
            self.fcs.append(nn.Linear(num_neurons[i], num_neurons[i + 1]))
        self.outputLayer = nn.Linear(num_neurons[-1], num_outputs)

        # Initialize weights with a mean of 0.1 and stddev of 0.1
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.1, std=0.1)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.activation_fun(self.inputLayer(x))
        for i in range(len(self.fcs)):
            x = self.activation_fun(self.fcs[i](x))
        x = self.outputLayer(x)
        # TODO: Maybe use sigmoid here since we learn inv ret
        return x


class UQ_Net_std(nn.Module):
    """Fully connected network producing a strictly positive spread estimate.

    Layout: input layer -> hidden layers given by ``num_neurons`` -> linear output,
    then a learnable scalar offset and the smooth positive map
    ``sqrt(x**2 + smooth_eps)``, so every output is at least ``sqrt(smooth_eps)``.
    """

    def __init__(
        self,
        num_neurons: list[int],
        num_inputs,
        num_outputs,
        net=None,
        bias=None,
        activation="relu",
        smooth_eps: float = 0.2,
    ):
        """
        Args:
            num_neurons: widths of the hidden layers (at least one entry).
            num_inputs: number of input features.
            num_outputs: number of output features.
            net: label such as "up" or "down"; accepted for call-site readability only
                and not used by this class.
            bias: initial value of the learnable output offset; defaults to 3.0.
                Integers are accepted and converted to float.
            activation: "relu" selects ReLU; any other value selects tanh.
            smooth_eps: positive constant added under the square root (default 0.2).
        """
        super(UQ_Net_std, self).__init__()

        self.activation_fun = torch.relu if activation == "relu" else torch.tanh
        self.smooth_eps = smooth_eps
        self.inputLayer = nn.Linear(num_inputs, num_neurons[0])
        self.fcs = nn.ModuleList()
        for i in range(len(num_neurons) - 1):
            self.fcs.append(nn.Linear(num_neurons[i], num_neurons[i + 1]))
        self.outputLayer = nn.Linear(num_neurons[-1], num_outputs)

        # Initialize weights with a mean of 0.1 and stddev of 0.1
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.1, std=0.1)
                nn.init.zeros_(m.bias)

        # Custom bias. float() matters: torch.tensor([3]) would be int64, and integer
        # tensors cannot be wrapped in a Parameter (they cannot require gradients).
        initial_bias = 3.0 if bias is None else float(bias)
        self.custom_bias = torch.nn.Parameter(torch.tensor([initial_bias]))

    def forward(self, x):
        x = self.activation_fun(self.inputLayer(x))
        for layer in self.fcs:
            x = self.activation_fun(layer(x))
        x = self.outputLayer(x)
        x = x + self.custom_bias
        # Smooth, strictly positive transform (lower bound sqrt(smooth_eps)).
        x = torch.sqrt(torch.square(x) + self.smooth_eps)
        return x

In [3]:
class ConcentrationPredictor(nn.Module):
    def __init__(self, u0: torch.Tensor, cfg: Initialize, ret_inv_funs=None):
        """Wrap the FINN right-hand side (ConcentrationChangeRatePredictor) in an ODE solver.

        Args:
            u0 (tensor): initial condition; first axis indexes the variables
                (documented as [num_features, Nx]; the notebook passes [num_vars, Nx, 1]).
            cfg (Initialize): model configuration; read for device, retardation
                parameters, loss multipliers, epoch count and model output path.
            ret_inv_funs (list | None): one entry per variable of ``u0`` — a learnable
                inverse-retardation module or None. Defaults to None for every variable.
        """
        super(ConcentrationPredictor, self).__init__()
        if ret_inv_funs is None:
            ret_inv_funs = [None] * len(u0)

        self.cfg = cfg
        self.u0 = u0
        self.dudt_fun = ConcentrationChangeRatePredictor(
            u0, cfg, ret_inv_funs=ret_inv_funs
        )

    def forward(self, t):
        """Predict the concentration profile at given time steps from an initial condition using the FINN method.

        Args:
            t (tensor): time steps

        Returns:
            tensor: Full field solution of concentration at given time steps
            (one entry along the first axis per time step).
        """

        ode_pred = odeint(self.dudt_fun, self.u0, t, rtol=1e-5, atol=1e-6)
        return ode_pred

    def run_training(self, t: torch.Tensor, u_full_train: torch.Tensor):
        """Train to predict the concentration from the given full field training data (LBFGS).

        Before training, analytic reference retardation curves (linear, Freundlich,
        Langmuir) on 100 concentrations in [0, 1] are saved to ``data_out/``. After
        each epoch the learned retardation on that grid is saved to
        ``cfg.model_path / f"retPred_{epoch}.npy"``.

        Args:
            t (tensor): time steps for integration, dim: [Nt,]
            u_full_train (tensor): full field solution at each time step; compared
                element-wise (with broadcasting) against ``self.forward(t)``.
        """
        out_dir = Path("data_out")
        out_dir.mkdir(exist_ok=True, parents=True)

        def to_numpy(x):
            # np.save needs host memory; tensors may live on cfg.device.
            return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

        optimizer = torch.optim.LBFGS(self.parameters(), lr=0.1)

        u_ret = torch.linspace(0.0, 1.0, 100).view(-1, 1).to(self.cfg.device)
        # TODO: Should not be here
        ret_linear = AnalyticRetardation.linear(
            u_ret, por=self.cfg.por, rho_s=self.cfg.rho_s, Kd=self.cfg.Kd
        )
        ret_freundlich = AnalyticRetardation.freundlich(
            u_ret,
            por=self.cfg.por,
            rho_s=self.cfg.rho_s,
            Kf=self.cfg.Kf,
            nf=self.cfg.nf,
        )
        ret_langmuir = AnalyticRetardation.langmuir(
            u_ret,
            por=self.cfg.por,
            rho_s=self.cfg.rho_s,
            smax=self.cfg.smax,
            Kl=self.cfg.Kl,
        )
        np.save(out_dir / "u_ret.npy", to_numpy(u_ret))
        np.save(out_dir / "retardation_linear.npy", to_numpy(ret_linear))
        np.save(out_dir / "retardation_freundlich.npy", to_numpy(ret_freundlich))
        np.save(out_dir / "retardation_langmuir.npy", to_numpy(ret_langmuir))

        def compute_loss():
            """Data misfit plus the physical monotonicity penalty."""
            ode_pred = self.forward(t)  # aka. y_pred
            # TODO: mean instead of sum?
            loss = self.cfg.error_mult * torch.sum((u_full_train - ode_pred) ** 2)

            # Physical regularization: value of the retardation factor should decrease
            # with increasing concentration, i.e. the inverse must not decrease.
            ret_inv_pred = self.retardation_inv_scaled(u_ret)
            loss += self.cfg.phys_mult * torch.sum(
                torch.relu(ret_inv_pred[:-1] - ret_inv_pred[1:])
            )  # TODO: mean instead of sum?
            return loss

        # LBFGS evaluates the objective several times per step, so it needs a closure
        # that resets the gradients, recomputes the loss and backpropagates.
        def closure():
            self.train()
            optimizer.zero_grad()
            loss = compute_loss()
            loss.backward()
            return loss

        # Iterate until maximum epoch number is reached (epochs are 1-based).
        for epoch in range(1, self.cfg.epochs + 1):
            start = time.time()
            optimizer.step(closure)
            # Loss after the step, for logging only: no graph and no extra backward pass.
            with torch.no_grad():
                loss = compute_loss()
            dt = time.time() - start

            print(
                f"Training: Epoch [{epoch}/{self.cfg.epochs}], "
                f"Training Loss: {loss.item():.4f}, Runtime: {dt:.4f} secs"
            )

            ret_pred_path = self.cfg.model_path / f"retPred_{epoch}.npy"
            np.save(ret_pred_path, to_numpy(self.retardation(u_ret)))

    def retardation_inv_scaled(self, u):
        """Evaluate the inverse-retardation function of the first flux kernel at ``u``.

        Only ``flux_modules[0]`` is used.
        """
        return self.dudt_fun.flux_modules[0].ret_inv_fun(u)

    def retardation(self, u):
        """Retardation ``1 / (ret_inv_fun(u) * 10 ** p_exp)`` of the first flux kernel.

        NOTE(review): ``p_exp`` presumably rescales the network output — confirm in Flux_Kernels.
        """
        return (
            1.0
            / self.dudt_fun.flux_modules[0].ret_inv_fun(u)
            / 10 ** self.dudt_fun.flux_modules[0].p_exp
        )


class ConcentrationChangeRatePredictor(nn.Module):
    def __init__(self, u0, cfg, ret_inv_funs=None):
        """Build one Flux_Kernels module per state variable.

        Inputs:
            u0           : initial condition; its first axis indexes the variables
            cfg          : model configuration (boundary condition types and values,
                           learnable-parameter settings, coupling indices, ...)
            ret_inv_funs : optional per-variable inverse-retardation modules;
                           None means "no module" for every variable
        """
        super().__init__()
        per_var_funs = [None] * len(u0) if ret_inv_funs is None else ret_inv_funs

        self.cfg = cfg
        self.num_vars = u0.size(0)
        self.flux_modules = nn.ModuleList(
            Flux_Kernels(u0[idx], cfg, idx, ret_inv_fun=per_var_funs[idx])
            for idx in range(self.num_vars)
        )

    def forward(self, t, u):
        """Right-hand side du/dt handed to the ODE solver.

        Each variable's flux kernel receives the state slices selected by
        ``cfg.flux_calc_idx`` and ``cfg.flux_couple_idx`` plus the time ``t``;
        the per-variable results are stacked along a new first axis.

        Args:
            t (float): time point
            u (tensor): current state, first axis indexes the variables

        Returns:
            tensor: du/dt, one slice per variable along the first axis
        """
        calc_idx = self.cfg.flux_calc_idx
        couple_idx = self.cfg.flux_couple_idx
        rates = [
            self.flux_modules[k](u[calc_idx[k]], u[couple_idx[k]], t)
            for k in range(self.num_vars)
        ]
        return torch.stack(rates)

In [4]:
# Experiment configuration and reproducibility seeds (Python, NumPy, PyTorch).
configs = {"quantile": 0.95, "Max_iter": 5000}
SEED = 10
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

cfg = Initialize()
u0 = torch.zeros(cfg.num_vars, cfg.Nx, 1)


def _make_mean_ret_inv(is_fun):
    """Fresh UQ_Net_mean for variables with a learnable retardation, otherwise None."""
    if not is_fun:
        return None
    return UQ_Net_mean(
        num_neurons=[15, 15, 15],
        num_inputs=1,
        num_outputs=1,
        activation="tanh",
    ).to(cfg.device)


net_mean = ConcentrationPredictor(
    u0=u0.clone(),
    cfg=cfg,
    ret_inv_funs=[_make_mean_ret_inv(is_fun) for is_fun in cfg.is_retardation_a_func],
)

# Load the data and split it along the time axis: the first 51 time steps are
# used for training, the remaining ones for validation.
concentration_data = load_data(cfg)
t = torch.linspace(0.0, cfg.T, cfg.Nt)

train_split_index = 51
x_train, x_valid = t[:train_split_index], t[train_split_index:]
y_train, y_valid = (
    concentration_data[:train_split_index],
    concentration_data[train_split_index:],
)

In [5]:
# _ = create_PI_training_data(net_mean, X=x_train, Y=y_train)

### Train Mean Network

In [6]:
net_mean.run_training(t=x_train, u_full_train=y_train)

Training: Epoch [2/10], Training Loss: 1953.3463, Runtime: 0.8493 secs
Training: Epoch [3/10], Training Loss: 1953.3463, Runtime: 0.7534 secs
Training: Epoch [4/10], Training Loss: 1953.3463, Runtime: 0.7622 secs
Training: Epoch [5/10], Training Loss: 1953.3463, Runtime: 0.8253 secs
Training: Epoch [6/10], Training Loss: 1953.3463, Runtime: 0.7721 secs
Training: Epoch [7/10], Training Loss: 1953.3463, Runtime: 0.7673 secs
Training: Epoch [8/10], Training Loss: 1953.3463, Runtime: 0.8218 secs
Training: Epoch [9/10], Training Loss: 1953.3463, Runtime: 0.7712 secs
Training: Epoch [10/10], Training Loss: 1953.3463, Runtime: 0.7689 secs
Training: Epoch [11/10], Training Loss: 1953.3463, Runtime: 0.8819 secs


### Create Up/Down Training Data

In [7]:
# Split the training time steps by the mean model's summed squared residual into
# "up" and "down" target sets for the interval networks.
data_train_up, data_train_down = create_PI_training_data(net_mean, x_train, y_train)
# TODO: Not used yet — validation split:
# data_val_up, data_val_down = create_PI_training_data(net_mean, x_valid, y_valid)

torch.Size([24]) torch.Size([27])
torch.Size([24]) torch.Size([27])


### Train Up/Down Networks

In [8]:
def _build_std_predictor(direction):
    """ConcentrationPredictor whose learnable retardation terms are UQ_Net_std networks.

    ``direction`` ("up" or "down") is passed as the ``net`` label of each UQ_Net_std.
    """
    return ConcentrationPredictor(
        u0=u0.clone(),
        cfg=cfg,
        ret_inv_funs=[
            (
                UQ_Net_std(
                    num_neurons=[15, 15, 15],
                    num_inputs=1,
                    num_outputs=1,
                    net=direction,
                    activation="tanh",
                ).to(cfg.device)
                if is_fun
                else None
            )
            for is_fun in cfg.is_retardation_a_func
        ],
    )


# Built in this order so random initialisation matches building them inline.
net_up = _build_std_predictor("up")
net_down = _build_std_predictor("down")

In [9]:
# NOTE(review): the targets from create_PI_training_data are per-time-step residual
# sums (1-D), whereas run_training compares its target with the full ODE field
# prediction — confirm this broadcasting is intended.
for interval_net, (t_subset, target) in (
    (net_up, data_train_up),
    (net_down, data_train_down),
):
    interval_net.run_training(t_subset, target)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
networks = {
    "mean": net_mean,
    "up": net_up,
    "down": net_down,
}

# Each evaluation integrates three ODE systems, so evaluate once per split and
# reuse the tensors; the numpy copy is only needed for the bisection search.
pred_train = eval_networks(networks, x_train)
pred_valid = eval_networks(networks, x_valid)

c_up, c_down = compute_boundary_factors(
    y_train=y_train.numpy(),
    network_preds={name: pred.numpy() for name, pred in pred_train.items()},
    quantile=configs["quantile"],
    verbose=1,
)

PICP_train, MPIW_train = caps_calculation(pred_train, c_up, c_down, y_train.numpy())
PICP_valid, MPIW_valid = caps_calculation(pred_valid, c_up, c_down, y_valid.numpy())
print(f"Train: PICP={PICP_train:.3f}, MPIW={MPIW_train:.4f}")
print(f"Valid: PICP={PICP_valid:.3f}, MPIW={MPIW_valid:.4f}")

# Flattened interval bounds and mean (same names and values as before for later cells).
y_U_PI_array_train = (pred_train["mean"] + c_up * pred_train["up"]).numpy().flatten()
y_L_PI_array_train = (
    (pred_train["mean"] - c_down * pred_train["down"]).numpy().flatten()
)
y_mean = pred_train["mean"].numpy().flatten()

t_train = x_train.numpy().flatten()
n_steps = t_train.shape[0]
sort_indices = np.argsort(t_train)


def _time_profile(values, n_steps, order):
    """Average everything but the leading time axis, giving one value per time step.

    Accepts full fields or their flattened form (n_steps * k entries); ``order``
    sorts the result by time.
    """
    return np.asarray(values).reshape(n_steps, -1).mean(axis=1)[order]


# The predictions are full fields with a leading time axis; ax.plot cannot draw
# those directly, so each quantity is reduced to one curve over time.
t_sorted = t_train[sort_indices]
fig, ax = plt.subplots()
ax.plot(t_sorted, _time_profile(y_train.numpy(), n_steps, sort_indices), ".", label="training data")
ax.plot(t_sorted, _time_profile(y_mean, n_steps, sort_indices), "-", label="mean prediction")
ax.plot(t_sorted, _time_profile(y_U_PI_array_train, n_steps, sort_indices), "-", label="upper bound")
ax.plot(t_sorted, _time_profile(y_L_PI_array_train, n_steps, sort_indices), "-", label="lower bound")
ax.set(
    xlabel="t",
    ylabel="concentration (mean over non-time entries)",
    title=f"{configs['quantile']:.0%} prediction interval (training split)",
)
ax.legend()
plt.show()