In [None]:
from helpers.util_functions import central_difference
from lodegp.LODEGP import LODEGP
import torch
import gpytorch


In [None]:
# Synthetic training data: the closed-form solution of each of the three channels,
# evaluated on train_x and perturbed by Gaussian noise whose scale is noise_level times
# that channel's maximum (noise_level = 0 keeps the data noise-free).
# NOTE(review): torch.linspace(1, 2, steps=1) returns the single point x = 1.0; confirm
# that a one-point training set is intended.
train_x = torch.linspace(1r, 2r, 1r)
noise_level = 0.0r


def y0_func(x):
    return float(781/8000)*torch.sin(x)/x - float(1/20)*torch.cos(x)/x**2 + float(1/20)*torch.sin(x)/x**3


def y1_func(x):
    return float(881/8000)*torch.sin(x)/x - float(1/40)*torch.cos(x)/x**2 + float(1/40)*torch.sin(x)/x**3


def y2_func(x):
    return (float(688061/800000)*torch.sin(x)/x
            - float(2543/4000)*torch.cos(x)/x**2
            + float(1743/4000)*torch.sin(x)/x**3
            - float(3/5)*torch.cos(x)/x**4
            + float(3/5)*torch.sin(x)/x**5)


# All clean channels are evaluated first (the tuple is built eagerly), then noise is
# drawn channel by channel in the order y0, y1, y2.
y0, y1, y2 = (
    clean + torch.randn_like(train_x)*(torch.max(clean)*noise_level)
    for clean in (y0_func(train_x), y1_func(train_x), y2_func(train_x))
)
train_y = torch.stack((y0, y1, y2), dim=-1r)



In [None]:
# Multitask likelihood and LODEGP model for the "Heating" ODE system with three outputs.
# A single raw Python int (3r) is used for both constructors, so the model does not
# receive a preparsed Sage Integer while the likelihood receives a plain int.
NUM_TASKS = 3r
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=NUM_TASKS)
model = LODEGP(train_x, train_y, likelihood, NUM_TASKS, ODE_name="Heating")



In [None]:
def extract_differential_polynomial_terms(expr, diff_var, var_dict):
    """
    Parse a polynomial in the differential-operator variable (e.g. x) and return a
    dict mapping derivative order to coefficient.

    The expression is round-tripped through ``str`` and ``sage_eval`` so that its
    free names are re-bound to the Sage objects given in ``var_dict``.

    Parameters:
    - expr: expression whose ``str`` Sage can evaluate (e.g. x^2 + 981/100)
    - diff_var: the Sage variable representing the differential operator (e.g. x)
    - var_dict: mapping name -> Sage object, used as ``locals`` when re-parsing ``expr``

    Returns:
    - dict: {order: coefficient}  (e.g. {0: 981/100, 2: 1}).
      If the re-parsed value is not a Sage symbolic expression (e.g. a plain
      number), it is returned as the order-0 coefficient.
    """
    parsed = sage_eval(str(expr), locals=var_dict)
    if not isinstance(parsed, sage.symbolic.expression.Expression):
        return {0: parsed}
    # coefficients(diff_var) yields [coefficient, degree] pairs
    return {degree: coeff for coeff, degree in parsed.coefficients(diff_var)}


In [None]:
# Verify that the functions satisfy the given differential equation
def calculate_differential_equation_error_symbolic(functions, differential_eq, sage_locals, **kwargs):
    """
    Apply one row of an ODE operator matrix to a list of symbolic functions.

    Each entry of ``differential_eq`` is a polynomial in the differential variable.
    A term ``coeff * dx^k`` in entry ``i`` contributes
    ``coeff * functions[i].diff(dx, k)`` to the result. The sum over all entries is
    returned and is expected to vanish when the functions satisfy the equation.

    Parameters:
    - functions: sequence of Sage symbolic expressions, indexed like ``differential_eq``
      (one per channel; the channel count equals the number of tasks)
    - differential_eq: sequence of polynomial entries, one per channel
    - sage_locals: name -> Sage object mapping used to re-parse each entry
    - diff_var (keyword, optional): differentiation variable; defaults to ``var("x")``

    Returns:
    - the symbolic residual (``0`` when there are no terms)
    """
    # Build the default lazily: var() also (re)binds a global symbol of that name,
    # which should not happen when the caller already supplies diff_var.
    dx = kwargs["diff_var"] if "diff_var" in kwargs else var("x")
    residual = 0
    for channel, entry in enumerate(differential_eq):
        # {order: coeff} for the polynomial applied to this channel
        coeff_dict = extract_differential_polynomial_terms(entry, dx, sage_locals)
        for order, coeff in coeff_dict.items():
            residual += coeff*functions[channel].diff(dx, int(order))
    return residual


In [None]:
# Verify that the given data satisfies the given differential equation
def calculate_differential_equation_error_numeric(differential_eq, sage_locals, data_generating_functions, data, **kwargs):
    """
    Numerically apply one row of an ODE operator matrix to data-generating functions.

    For each entry ``i`` of ``differential_eq`` (one per channel) and each term
    ``coeff * dx^k`` in it, the coefficient is evaluated via ``coeff.subs(locals_values)``.
    The k-th derivative of channel ``i`` is approximated by
    ``central_difference(data_generating_functions, data, order=k)[:, i]``.
    The products ``float(coeff) * derivative`` are summed.

    Parameters:
    - differential_eq: sequence of polynomial entries, one per output channel
    - sage_locals: name -> Sage symbol mapping used to re-parse the entries
    - data_generating_functions: callable handed to central_difference
    - data: evaluation points handed to central_difference
    - diff_var (keyword, optional): differential-operator variable; default ``var("x")``
    - locals_values (keyword, optional): symbol -> value mapping substituted into the
      coefficients; by default every symbol in ``sage_locals`` is set to 1.0

    Returns:
    - the accumulated residual, or None if no term could be evaluated.

    Evaluation is best-effort: a term that fails is skipped with a warning naming its
    channel and order, and the remaining terms are still accumulated. A typical
    failure is a coefficient that remains symbolic after substitution.
    """
    import warnings

    # Defaults are built lazily: var() rebinds a global symbol, and the default
    # substitution dict is only needed when the caller does not provide one.
    dx = kwargs["diff_var"] if "diff_var" in kwargs else var("x")
    if "locals_values" in kwargs:
        locals_values = kwargs["locals_values"]
    else:
        locals_values = {symbol: 1.0 for symbol in sage_locals.values()}

    residual = None
    for channel, entry in enumerate(differential_eq):
        # {order: coeff} for the polynomial applied to this channel
        coeff_dict = extract_differential_polynomial_terms(entry, dx, sage_locals)
        for order, coeff in coeff_dict.items():
            try:
                coeff = coeff.subs(locals_values)
                derivative = central_difference(data_generating_functions, data, order=order)[:, channel]
                term = float(coeff)*derivative
                residual = term if residual is None else residual + term
            except Exception as e:
                warnings.warn(
                    "Skipping term of order {} in channel {} (coefficient {}): {}".format(order, channel, coeff, e),
                    stacklevel=2,
                )
    return residual


# Verify that the model's output satisfies the given differential equation.
# target_row selects the row of the operator matrix model.A; target_col selects the
# column passed to prepare_symbolic_ode_satisfaction_check for the symbolic check.
target_row = 0
target_col = 1
model.eval()
likelihood.eval()
model_mean_generator = lambda x: model(x).mean
# Values substituted for the symbolic ODE coefficients, as prepared by the model.
locals_values = model.prepare_numeric_ode_satisfaction_check()
numeric_ode_residual = calculate_differential_equation_error_numeric(
    model.A[target_row], model.sage_locals, model_mean_generator, train_x, locals_values=locals_values)
# Only a cell's last expression is rendered, so show the numeric residual explicitly.
print("Numeric ODE residual:", numeric_ode_residual)

# Verify that the symbolic covariance functions satisfy the given differential equation:
# the operator variable x in each entry of the selected row is replaced by t1, and the
# kernel column is differentiated with respect to t1.
model_diffed_kernel_col = model.prepare_symbolic_ode_satisfaction_check(target_col)
diff_var = var("t1")
var("x")
differential_equation = [term.substitute(x=diff_var) for term in sage_eval(str(model.A[target_row]), locals=model.sage_locals)]
symbolic_ode_residual = calculate_differential_equation_error_symbolic(
    model_diffed_kernel_col, differential_equation, model.sage_locals, diff_var=diff_var)
# Evaluate the symbolic residual at a fixed point with all hyperparameters set to 1.
symbolic_ode_residual(t1=1, t2=1, signal_variance_2=1.0, lengthscale_2=1.0, a=1.0, b=1.0)

In [None]:
# Inspect the model parameter stored under the key "a".
model.model_parameters["a"]

In [None]:
# Show the name -> Sage object mapping the model supplies as `locals` when its ODE matrix is re-parsed with sage_eval.
model.sage_locals

In [None]:
# Rebuild the model for the "Bipendulum" ODE system and print each entry of its
# operator matrix A, followed by that entry's {order: coefficient} decomposition.
# NOTE(review): this reuses the likelihood object created for the Heating model (and
# already switched to eval mode); confirm the two models are meant to share it.
model = LODEGP(train_x, train_y, ODE_name="Bipendulum", likelihood=likelihood, num_tasks=3r)
operator_var = var("x")
for polynomial in (entry for matrix_row in model.A for entry in matrix_row):
    print(polynomial)
    print(extract_differential_polynomial_terms(polynomial, operator_var, model.sage_locals))