# **Drug_model_nn_JAX**

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib as mpl
import time
import pandas as pd
# import jaxopt
from scipy.integrate import odeint


def drug_model(
    t,
    kg = 0.72,
    kb = 0.15,
    G0 = 0.1,
):
    """Solve the three-compartment transfer model on the time points ``t``.

    The state vector is (G, B, U). Material moves G -> B with rate ``kg``
    and B -> U with rate ``kb``:

        dG/dt = -kg * G
        dB/dt =  kg * G - kb * B
        dU/dt =  kb * B

    Parameters
    ----------
    t : array_like
        1-D sequence of time points, passed unchanged to ``odeint``.
    kg : float
        First-order rate constant for the G -> B transfer.
    kb : float
        First-order rate constant for the B -> U transfer.
    G0 : float
        Initial value of G; B and U both start at zero.

    Returns
    -------
    ndarray of shape (len(t), 3)
        Columns are G, B and U evaluated at each entry of ``t``.
    """
    def rhs(state, _time):
        g, b = state[0], state[1]
        g_to_b = kg * g   # flux leaving G and entering B
        b_to_u = kb * b   # flux leaving B and entering U
        return [-g_to_b, g_to_b - b_to_u, b_to_u]

    initial_state = [G0, 0, 0]
    return odeint(rhs, initial_state, t)

# Dense evaluation grid: 501 evenly spaced points on [0, 50], kept as a
# (501, 1) column; the ODE solver receives the flattened 1-D version.
# NOTE(review): jnp arrays are float32 by default unless JAX x64 mode is
# enabled, so these time points carry float32 rounding — confirm intended.
t_dense = jnp.linspace(0, 50, 501).reshape(-1, 1)
y_dense = drug_model(np.asarray(t_dense).ravel())

In [3]:
#**Errors of the predicted f(t) for 10, 20, 50 and 100 data points**

import pandas as pd
import matplotlib.pyplot as plt


# -------Load the CSV file into a Pandas DataFrame
df_10 = pd.read_csv("pred_10.csv")
# t = df_10['t']
ft_10 = df_10['ft']

df_20 = pd.read_csv("pred_20.csv")
ft_20 = df_20['ft']

df_50 = pd.read_csv("pred_50.csv")
ft_50 = df_50['ft']

df_100 = pd.read_csv("pred_100.csv")
ft_100 = df_100['ft']
#-------------------------------------------------
kg = 0.72
kb = 0.15
G0 = 0.1

f_t_analytical = kg * y_dense[:, 0] - kb * y_dense[:, 1]

#-------------------------------------------------

def compute_errors(true_values, predicted_values):
    """Compare predictions with a reference and return (MAE, RMSE, RE).

    Parameters
    ----------
    true_values, predicted_values : array_like
        Reference and predicted samples (NumPy array, pandas Series, list, ...).
        Both are flattened, so an (N, 1) column and a length-N vector are
        compared element-wise instead of broadcasting to an N x N grid.

    Returns
    -------
    mae : float
        Mean absolute error, mean(|true - pred|).
    rmse : float
        Root mean squared error, sqrt(mean((true - pred)**2)).
    re : float
        Relative L2 error, ||true - pred||_2 / ||true||_2. Becomes inf/nan
        (with a NumPy RuntimeWarning) when ``true_values`` is all zeros.

    Raises
    ------
    ValueError
        If the inputs have different numbers of elements, or are empty.

    Notes
    -----
    Inputs are converted to plain float arrays, so any NaN propagates to the
    results. Operating on a pandas Series directly would route the reductions
    through pandas' NaN-skipping ``mean``/``sum`` and hide missing predictions.
    """
    true_arr = np.asarray(true_values, dtype=float).ravel()
    pred_arr = np.asarray(predicted_values, dtype=float).ravel()
    if true_arr.size != pred_arr.size:
        raise ValueError(
            f"true_values has {true_arr.size} elements but "
            f"predicted_values has {pred_arr.size}"
        )
    if true_arr.size == 0:
        raise ValueError("compute_errors needs at least one sample")

    residual = true_arr - pred_arr

    # Mean Absolute Error (MAE)
    mae = np.abs(residual).mean()

    # Root Mean Squared Error (RMSE)
    rmse = np.sqrt((residual ** 2).mean())

    # Relative L2 Error (RE)
    numerator = np.sqrt(np.sum(residual ** 2))
    denominator = np.sqrt(np.sum(true_arr ** 2))
    re = numerator / denominator

    return mae, rmse, re

# Report MAE, RMSE and relative L2 error (RE) of each predicted f(t) against
# the reference f_t_analytical (definitions in compute_errors).
# Each entry is (section header, predicted f(t)), in report order.
prediction_sets = [
    ("_____________10 data points______________", ft_10),
    ("_____________20 data points______________", ft_20),
    ("_____________50 data points______________", ft_50),
    ("_____________100 data points_____________", ft_100),
]

for header, ft_pred in prediction_sets:
    print(header)
    mae, rmse, re = compute_errors(f_t_analytical, ft_pred)
    print(f"(MAE): {mae:.2e}")
    print(f"(RMSE): {rmse:.2e}")
    print(f"(RE): {re:.2e}")

_____________10 data points______________
(MAE): 1.26e-04
(RMSE): 5.57e-04
(RE): 6.92e-02
_____________20 data points______________
(MAE): 1.09e-04
(RMSE): 4.82e-04
(RE): 5.99e-02
_____________50 data points______________
(MAE): 6.59e-05
(RMSE): 2.26e-04
(RE): 2.80e-02
_____________100 data points_____________
(MAE): 6.54e-05
(RMSE): 1.84e-04
(RE): 2.29e-02
