In [1]:
# import all the modules
import sympy
import argparse
import numpy as np
import abc
from scipy.interpolate import BSpline

import equations
import data
from gp_utils import run_gp_ode
import pickle
import os
import time
import basis
from sympy import *

from sklearn.linear_model import LassoCV
from scipy.linalg import lstsq
from tvregdiff.tvregdiff import TVRegDiff
import gppca
from config import get_interpolation_config
from integrate import generate_grid
from derivative import dxdt


In [2]:
# Print the n-th derivative of f(x)**n for a generic function f
def function_expansion(derivative):
    """Print the symbolic ``derivative``-th derivative of ``f(x)**derivative``.

    ``f`` is an undefined sympy function of the symbol ``x``, so the printed
    expression is the generic product/chain-rule expansion in terms of
    ``f(x)`` and its derivatives.

    Args:
        derivative: used both as the power applied to ``f(x)`` and as the
            number of times the result is differentiated with respect to
            ``x``.

    Returns:
        None. The expansion is written to standard output.
    """
    var = sympy.Symbol('x')
    generic_f = sympy.Function("f")(var)
    powered = generic_f ** derivative
    print(sympy.diff(powered, var, derivative))


# Noise free updated
def get_ode_data_noise_free_updated(yt, x_id, dg, ode, derivative):
    """Build the ``run_gp_ode`` inputs directly from noise-free trajectories.

    The observations are used as-is (no smoothing). The basis functions are
    evaluated on the solver time grid and combined with quadrature weights
    to form the matrix ``c``.

    Args:
        yt: trajectories, indexed as ``yt[:, :, x_id]``; the first axis must
            have the same length as ``dg.solver.t``. Assumed layout is
            (time, sample, dimension) -- TODO confirm against DataGenerator.
        x_id: index (last axis of ``yt``) of the state variable of interest.
        dg: data generator; ``dg.solver.t`` and ``dg.T`` are read.
        ode: ODE object forwarded to ``get_interpolation_config``.
        derivative: highest derivative order requested from the basis and
            scalar factor applied to ``c``. Must be at least 3 because the
            expansion below is hard-coded to third order.

    Returns:
        Tuple ``(ode_data, X_ph, y_ph, t_new)`` where ``ode_data`` is a dict
        with keys ``'x_hat'``, ``'g'``, ``'c'`` and ``'weights'``; ``X_ph``
        is a zero array of shape ``(yt.shape[1], yt.shape[2])``; ``y_ph`` is
        a zero array of shape ``(yt.shape[1],)``; ``t_new`` is
        ``dg.solver.t``.

    Raises:
        ValueError: if ``derivative`` is smaller than 3.
    """
    # The expansion reads design matrices of order 0..3; fail early with a
    # clear message instead of an opaque IndexError after the work is done.
    if derivative < 3:
        raise ValueError(
            "derivative must be >= 3: the hard-coded third-order expansion "
            "needs basis derivatives up to order 3, got {}".format(derivative)
        )

    t_new = dg.solver.t

    # Quadrature weights: end points count half, then everything is rescaled
    # so that the weights sum to dg.T.
    weight = np.ones_like(t_new)
    weight[0] /= 2
    weight[-1] /= 2
    weight = weight / weight.sum() * dg.T

    X_sample = yt

    # NOTE(review): the config is always read for dimension 0 here, whereas
    # get_ode_data_updated reads it for x_id -- confirm this is intended.
    config = get_interpolation_config(ode, 0)
    basis_cls = config['basis']
    basis_func = basis_cls(dg.T, config['n_basis'])

    # Design matrices for derivative orders 0..derivative on the solver grid.
    derivative_list = [
        basis_func.design_matrix(t_new, order) for order in range(derivative + 1)
    ]

    Xi = X_sample[:, :, x_id]

    g, g_dot1, g_dot2, g_dot3 = derivative_list[:4]

    # Third-order expansion: derivative * (g^2 g''' + 6 g g' g'' + 2 g'^3),
    # which for derivative == 3 equals the third derivative of g**3.
    # For a first-order version, replace the right-hand factor by g_dot1.
    c = derivative*(Xi * weight[:, None]).T @ (g*g*g_dot3 + 6*g*g_dot1*g_dot2 + 2*g_dot1**3)

    ode_data = {
        'x_hat': X_sample,
        'g': g,
        'c': c,
        'weights': weight
    }
    # Zero-filled placeholders returned alongside ode_data.
    X_ph = np.zeros((X_sample.shape[1], X_sample.shape[2]))
    y_ph = np.zeros(X_sample.shape[1])

    return ode_data, X_ph, y_ph, t_new


# Noise updated
def get_ode_data_updated(yt, x_id, dg, ode, config_n_basis, config_basis, derivative):
    """Build the ``run_gp_ode`` inputs from (possibly noisy) trajectories.

    When ``dg.noise_sigma == 0`` the call is delegated to
    ``get_ode_data_noise_free_updated``. Otherwise a ``gppca.GPPCA0`` object
    is built for every state dimension and its ``get_predictive`` output on
    the grid from ``generate_grid(dg.T, freq_int)`` becomes ``x_hat``. The
    matrix ``c`` is integrated on a finer grid, ``generate_grid(dg.T, 1000)``.

    Args:
        yt: observed trajectories, indexed as ``yt[:, :, d]``. Assumed layout
            is (time, sample, dimension) -- TODO confirm against
            DataGenerator.
        x_id: index of the state variable of interest.
        dg: data generator; ``dg.solver.t``, ``dg.noise_sigma``, ``dg.freq``
            and ``dg.T`` are read.
        ode: ODE object; ``ode.std_base`` and ``ode.positive`` are read and
            the object is forwarded to ``get_interpolation_config``.
        config_n_basis: number of basis functions, or ``None`` to use
            ``config['n_basis']``.
        config_basis: callable invoked as ``config_basis(dg.T, n_basis)``, or
            ``None`` to use ``config['basis']``.
        derivative: highest derivative order requested from the basis and
            scalar factor applied to ``c``. Must be at least 3 because the
            expansion below is hard-coded to third order.

    Returns:
        Tuple ``(ode_data, X_ph, y_ph, t_new)`` where ``ode_data`` is a dict
        with keys ``'x_hat'``, ``'g'``, ``'c'`` and ``'weights'``; ``X_ph``
        and ``y_ph`` are zero arrays shaped from ``x_hat``; ``t_new`` is the
        interpolation grid.

    Raises:
        ValueError: if ``derivative`` is smaller than 3.
    """
    # The expansion reads design matrices of order 0..3; fail early with a
    # clear message instead of an opaque IndexError after the expensive part.
    if derivative < 3:
        raise ValueError(
            "derivative must be >= 3: the hard-coded third-order expansion "
            "needs basis derivatives up to order 3, got {}".format(derivative)
        )

    t = dg.solver.t
    noise_sigma = dg.noise_sigma
    freq = dg.freq

    if noise_sigma == 0:
        # NOTE(review): config_n_basis / config_basis are not forwarded, so
        # the noise-free path always uses the values from the config.
        return get_ode_data_noise_free_updated(yt, x_id, dg, ode, derivative)

    X_sample_list = []
    pca_list = []
    # for each dimension
    assert yt.shape[-1] > 0
    for d in range(yt.shape[-1]):
        config = get_interpolation_config(ode, d)
        Y = yt[:, :, d]
        r = config['r']
        if r < 0:
            # A negative r in the config means "use Y.shape[1]".
            r = Y.shape[1]

        if 'sigma_in_mul' in config:
            sigma_in = config['sigma_in_mul'] / freq
        else:
            sigma_in = config['sigma_in']
        freq_int = config['freq_int']

        pca = gppca.GPPCA0(r, Y, t, noise_sigma, sigma_out=ode.std_base, sigma_in=sigma_in)

        # t_new / weight from the last dimension are the ones returned.
        t_new, weight = generate_grid(dg.T, freq_int)
        X_sample = pca.get_predictive(new_sample=1, t_new=t_new)
        X_sample = X_sample.reshape(len(t_new), X_sample.size // len(t_new), 1)
        X_sample_list.append(X_sample)
        pca_list.append(pca)

    X_sample = np.concatenate(X_sample_list, axis=-1)
    # Clamp values at or below 1e-6 when the ODE is flagged as positive.
    if ode.positive:
        X_sample[X_sample <= 1e-6] = 1e-6

    config = get_interpolation_config(ode, x_id)
    n_basis = config['n_basis'] if config_n_basis is None else config_n_basis
    basis_cls = config['basis'] if config_basis is None else config_basis

    basis_func = basis_cls(dg.T, n_basis)

    # 'g' is returned together with 'x_hat' and 'weights', so it is evaluated
    # on the same grid (t_new), as in the noise-free path. Evaluating it on
    # the fine grid used for c would give it a different number of rows.
    g = basis_func.design_matrix(t_new, 0)

    # compute c using a much larger grid
    t_new_c, weight_c = generate_grid(dg.T, 1000)
    Xi = pca_list[x_id].get_predictive(new_sample=1, t_new=t_new_c)
    Xi = Xi.reshape(len(t_new_c), Xi.size // len(t_new_c))

    # Design matrices for derivative orders 0..derivative on the fine grid.
    derivative_list = [
        basis_func.design_matrix(t_new_c, order) for order in range(derivative + 1)
    ]
    g_c, g_c_dot1, g_c_dot2, g_c_dot3 = derivative_list[:4]

    # Third-order expansion: derivative * (g^2 g''' + 6 g g' g'' + 2 g'^3),
    # which for derivative == 3 equals the third derivative of g**3.
    # For a first-order version, replace the right-hand factor by g_c_dot1.
    c = derivative*(Xi * weight_c[:, None]).T @ (
        g_c*g_c*g_c_dot3 + 6*g_c*g_c_dot1*g_c_dot2 + 2*g_c_dot1**3
    )

    ode_data = {
        'x_hat': X_sample,
        'g': g,
        'c': c,
        'weights': weight
    }
    # Zero-filled placeholders returned alongside ode_data.
    X_ph = np.zeros((X_sample.shape[1], X_sample.shape[2]))
    y_ph = np.zeros(X_sample.shape[1])

    return ode_data, X_ph, y_ph, t_new

In [3]:
# main function
def run(ode_name, ode_param, x_id, freq, n_sample, noise_ratio, seed, n_seed, n_basis, basis_str, derivative):
    """Generate data for one ODE and run the symbolic search for several seeds.

    For every seed in ``range(seed, seed + n_seed)`` the expression returned
    by ``run_gp_ode`` is compared symbolically with the true right-hand side
    and the outcome is printed.

    Args:
        ode_name: name passed to ``equations.get_ode``.
        ode_param: parameters passed to ``equations.get_ode``.
        x_id: index of the state variable whose equation is searched for.
        freq: sampling frequency passed to ``data.DataGenerator``.
        n_sample: number of samples passed to ``data.DataGenerator``.
        noise_ratio: multiplied by ``ode.std_base`` to obtain the noise
            sigma.
        seed: first seed handed to ``run_gp_ode``.
        n_seed: number of consecutive seeds to run.
        n_basis: number of basis functions forwarded to
            ``get_ode_data_updated``.
        basis_str: ``'sine'`` selects ``basis.FourierBasis``; any other value
            selects ``basis.CubicSplineBasis``.
        derivative: derivative order forwarded to ``get_ode_data_updated``.

    Returns:
        None. Results are printed.
    """
    # Data generation always uses this fixed NumPy seed; the `seed` argument
    # only varies the value handed to run_gp_ode below.
    np.random.seed(999)

    ode = equations.get_ode(ode_name, ode_param)
    T = ode.T
    init_low = 0
    init_high = ode.init_high

    basis_obj = basis.FourierBasis if basis_str == 'sine' else basis.CubicSplineBasis

    noise_sigma = ode.std_base * noise_ratio

    dg = data.DataGenerator(ode, T, freq, n_sample, noise_sigma, init_low, init_high)
    yt = dg.generate_data()

    # just pass in the derivative
    ode_data, X_ph, y_ph, t_new = get_ode_data_updated(yt, x_id, dg, ode, n_basis, basis_obj, derivative)

    for s in range(seed, seed + n_seed):
        print(' ')
        print('Running with seed {}'.format(s))

        f_hat, _ = run_gp_ode(ode_data, X_ph, y_ph, ode, x_id, s)

        f_true = ode.get_expression()[x_id]
        # A tuple lists several acceptable forms of the true expression;
        # matching any one of them counts as correct.
        candidates = f_true if isinstance(f_true, tuple) else (f_true,)
        correct = any(sympy.simplify(f_hat - f) == 0 for f in candidates)

        print("Model found", f_hat)
        print("True model", f_true)
        print("Correct", correct)

In [None]:
# Experiment: LogisticODE, noise-free data, seeds 0..9.
ode_name = 'LogisticODE'   # looked up through equations.get_ode inside run()
ode_param = None           # forwarded unchanged to equations.get_ode
x_id = 0                   # state variable whose equation is searched for
freq = 10                  # sampling frequency given to the data generator
n_sample = 100             # number of samples given to the data generator
noise_ratio = 0.0          # multiplies ode.std_base; 0.0 selects the noise-free path
seed = 0                   # first seed handed to run_gp_ode
n_seed = 10                # number of consecutive seeds
n_basis = 50               # number of basis functions requested
basis_str = 'sine'         # 'sine' -> basis.FourierBasis, otherwise CubicSplineBasis
derivative = 3             # derivative order used when building c

run(
    ode_name=ode_name,
    ode_param=ode_param,
    x_id=x_id,
    freq=freq,
    n_sample=n_sample,
    noise_ratio=noise_ratio,
    seed=seed,
    n_seed=n_seed,
    n_basis=n_basis,
    basis_str=basis_str,
    derivative=derivative,
)

In [None]:
# Experiment: GompertzODE, noise-free data, seeds 0..9.
ode_name = 'GompertzODE'   # looked up through equations.get_ode inside run()
ode_param = None           # forwarded unchanged to equations.get_ode
x_id = 0                   # state variable whose equation is searched for
freq = 10                  # sampling frequency given to the data generator
n_sample = 100             # number of samples given to the data generator
noise_ratio = 0.0          # multiplies ode.std_base; 0.0 selects the noise-free path
seed = 0                   # first seed handed to run_gp_ode
n_seed = 10                # number of consecutive seeds
n_basis = 50               # number of basis functions requested
basis_str = 'sine'         # 'sine' -> basis.FourierBasis, otherwise CubicSplineBasis
derivative = 3             # derivative order used when building c

run(
    ode_name=ode_name,
    ode_param=ode_param,
    x_id=x_id,
    freq=freq,
    n_sample=n_sample,
    noise_ratio=noise_ratio,
    seed=seed,
    n_seed=n_seed,
    n_basis=n_basis,
    basis_str=basis_str,
    derivative=derivative,
)