In [None]:
from collections import namedtuple

import jax
import jax.numpy as jnp
import numpy as np

from scipy.integrate import odeint
import matplotlib.pyplot as plt

from cgc.graph import ComputationalGraph, derivative
from cgc.types import LearnableParameter, ConstantParameter


# Mass-Spring System

## Data Generation

In [None]:
def ms_system_ode(pq, t):
    """Right-hand side of the mass-spring ODE, in the callback form used by `odeint`.

    Parameters
    ----------
    pq : sequence of two numbers
        Current state, unpacked as ``(p, q)``.
    t : float
        Time; part of the callback signature but not used by the dynamics.

    Returns
    -------
    list
        ``[dp/dt, dq/dt] = [-2 * q, 2 * p]``.
    """
    momentum, position = pq
    dp_dt = -2 * position
    dq_dt = 2 * momentum
    return [dp_dt, dq_dt]

# 400 evenly spaced time samples on [0, 80]; the mass-spring ODE is then
# integrated from the initial state (p, q) = (-0.1, -0.1).
t = np.linspace(start=0, stop=80, num=400)
pq = odeint(func=ms_system_ode, y0=[-0.1, -0.1], t=t)

In [None]:
# Split the integrated trajectory into its momentum and position columns.
p, q = pq.T

# Energy of the simulated system. `ms_system_ode` integrates dp/dt = -2q and
# dq/dt = 2p, which are Hamilton's equations (dp/dt = -dH/dq, dq/dt = dH/dp)
# for H = p^2 + q^2. Using that H keeps the ground-truth energy column
# consistent with the trajectory and with the sign/scale convention of the
# graph constraints (p_dot + dH/dq = 0, q_dot - dH/dp = 0).
H = p ** 2 + q ** 2

# Sanity plot: p and q oscillate, H should stay flat along the trajectory.
plt.plot(t, p, label="$p$")
plt.plot(t, q, label="$q$")
plt.plot(t, H, label="H")
plt.legend()

In [None]:
# Ground-truth table, one row per time sample, columns (t, p, q, H).
X_true = np.column_stack((t, pq, H))

# Boolean mask with the same shape as X_true; entries that are True are the
# ones copied from X_true into X below.
M = np.ones(X_true.shape, dtype=bool)
M[200:, 1:3] = False  # p and q are masked out from row 200 onward
M[:, 3] = False       # H is masked out on every row

# Input table: zeros everywhere except the unmasked entries, after which the
# whole H column is overwritten with the constant 0.1.
X = np.zeros_like(X_true)
X[M] = X_true[M]
X[:, 3] = 0.1

In [None]:
# Computational graph for the mass-spring system. The observable order is the
# same as the column order of X / X_true built above: (t, p, q, H).
ms_graph = ComputationalGraph(observables_order=["t", "p", "q", "H"])

ms_graph.add_observable("t")
# p and q are registered as unknown functions of t. gamma is wrapped in
# LearnableParameter(1.0) — presumably a kernel parameter tuned during
# `complete(..., learn_parameters=True)`; confirm against the cgc docs.
ms_graph.add_unknown_fn("t", "p", alpha=0.01, gamma=LearnableParameter(1.0))
ms_graph.add_unknown_fn("t", "q", alpha=0.01, gamma=LearnableParameter(1.0))

# Derivative nodes of p and q, plus a sign-flipped copy of p_dot.
ms_graph.add_known_fn("p", "p_dot", derivative)
ms_graph.add_known_fn("q", "q_dot", derivative)
ms_graph.add_known_fn("p_dot", "-p_dot", lambda p_dot: -p_dot)

# (q_dot, -p_dot) is the node named as `observations` for H below.
ms_graph.add_aggregator(["q_dot", "-p_dot"], "qp_dot")

# H is an unknown function of the aggregated (p, q) node. With
# linear_functional=jax.jacobian and observations="qp_dot" this presumably
# ties the Jacobian of H to (q_dot, -p_dot), i.e. Hamilton's equations —
# NOTE(review): confirm the exact semantics in the cgc documentation.
ms_graph.add_aggregator(["p", "q"], "pq")
ms_graph.add_unknown_fn("pq", "H", linear_functional=jax.jacobian, observations="qp_dot", alpha=0.01, gamma=1.0)
ms_graph.add_known_fn("H", "grad_H", derivative)

# Inputs of the two constraint functions defined next in this cell.
ms_graph.add_aggregator(["q_dot", "grad_H"], "(q_dot, grad_H)")
ms_graph.add_aggregator(["p_dot", "grad_H"], "(p_dot, grad_H)")

def p_dot_constraint(p_dot_grad_H):
    """Residual ``p_dot + grad_H[1]`` for the mass-spring graph.

    The 2-D input is read column-wise: column 0 is p_dot and the remaining
    columns are grad_H, so the second gradient entry is overall column 2.
    Returns one residual per row.
    """
    momentum_rate = p_dot_grad_H[:, 0]
    second_grad_entry = p_dot_grad_H[:, 2]
    return momentum_rate + second_grad_entry

def q_dot_constraint(q_dot_grad_H):
    """Residual ``q_dot - grad_H[0]`` for the mass-spring graph.

    The 2-D input is read column-wise: column 0 is q_dot and the remaining
    columns are grad_H, so the first gradient entry is overall column 1.
    Returns one residual per row.
    """
    position_rate = q_dot_grad_H[:, 0]
    first_grad_entry = q_dot_grad_H[:, 1]
    return position_rate - first_grad_entry

# Attach the two residual functions to their aggregated input nodes under the
# names "W1" and "W2".
ms_graph.add_constraint("(p_dot, grad_H)", "W1", p_dot_constraint)
ms_graph.add_constraint("(q_dot, grad_H)", "W2", q_dot_constraint)

In [None]:
# Set the constraints loss multiplier to 10000 (presumably its weight relative
# to the other loss terms — confirm the defaults in cgc).
ms_graph.set_loss_multipliers(constraints_loss_multiplier=10000)

In [None]:
# Run graph completion on the masked table X with mask M. Z is indexed below
# with the same column layout as X (columns 1 and 2 are plotted as p and q).
Z = ms_graph.complete(X, M, optimizer="l-bfgs-b", learn_parameters=True, n_rounds=10)

In [None]:
# Report the kernel parameters of the graph after completion.
ms_graph.report_kernel_params()

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(30, 20))

# Time of the last observed (p, q) sample, read from the mask instead of being
# hard-coded: column 1 of M is True exactly on the rows where p was observed,
# so the marker stays correct if the masking cutoff or the time grid changes.
t_last_observed = t[M[:, 1]][-1]

# One panel per state variable: completed column of Z against the true curve.
for ax, column, truth, title in zip(axes, (1, 2), (p, q), ("$p$", "$q$")):
    ax.plot(t, Z[:, column], label="Prediction")
    ax.plot(t, truth, label="Truth")
    ax.set_title(title)
    ax.axvline(t_last_observed, label="End-of-Observation", linestyle="--", c='black')
    ax.legend()

# Two-Mass-Three-Springs System

## Data Generation

In [None]:
def m2s3_system_ode(pq, t):
    """Right-hand side of the two-mass-three-springs ODE for `odeint`.

    Parameters
    ----------
    pq : sequence of four numbers
        Current state, unpacked as ``(p1, p2, q1, q2)``.
    t : float
        Time; part of the callback signature but not used by the dynamics.

    Returns
    -------
    list
        ``[dp1/dt, dp2/dt, dq1/dt, dq2/dt]``.
    """
    p1, p2, q1, q2 = pq
    # Relative displacement of the two masses; it enters both momentum
    # equations with opposite signs.
    stretch = q2 - q1
    return [
        -q1 + stretch,
        -q2 - stretch,
        p1,
        p2,
    ]

# 400 evenly spaced time samples on [0, 80]; the system is integrated from
# the initial state (p1, p2, q1, q2) = (0.1, -0.1, 0.2, -0.1).
t = np.linspace(start=0, stop=80, num=400)
pq = odeint(func=m2s3_system_ode, y0=[0.1, -0.1, 0.2, -0.1], t=t)

In [None]:
# Columns of the trajectory, in the state order used by `m2s3_system_ode`.
p1, p2, q1, q2 = pq.T

# Energy along the trajectory: half the sum of the squares of
# q1, q2, (q2 - q1), p1 and p2.
H = 0.5 * (
    q1 ** 2
    + q2 ** 2
    + (q2 - q1) ** 2
    + p1 ** 2
    + p2 ** 2
)

# Overlay all four state variables and the energy on a single axis.
plt.plot(t, p1, label="$p_1$")
plt.plot(t, p2, label="$p_2$")
plt.plot(t, q1, label="$q_1$")
plt.plot(t, q2, label="$q_2$")
plt.plot(t, H, label="H")
plt.legend()


In [None]:
# Ground-truth table, one row per time sample, columns (t, p1, p2, q1, q2, H).
X_true = np.column_stack((t, pq, H))

# Boolean mask with the same shape as X_true; entries that are True are the
# ones copied from X_true into X below.
M = np.ones(X_true.shape, dtype=bool)
M[200:, 1:5] = False  # the four state columns are masked from row 200 onward
M[:, 5] = False       # H is masked out on every row

# Input table: zeros everywhere except the unmasked entries.
X = np.zeros_like(X_true)
X[M] = X_true[M]

In [None]:
# Computational graph for the two-mass-three-springs system. The observable
# order is the same as the column order of X / X_true built above.
m2s3_graph = ComputationalGraph(observables_order=["t", "p1", "p2", "q1", "q2", "H"])

m2s3_graph.add_observable("t")
# Each state variable is registered as an unknown function of t, with gamma
# wrapped in LearnableParameter(1.5) (presumably a kernel parameter tuned
# during `complete` — confirm against the cgc docs).
m2s3_graph.add_unknown_fn("t", "p1", alpha=0.01, gamma=LearnableParameter(1.5))
m2s3_graph.add_unknown_fn("t", "p2", alpha=0.01, gamma=LearnableParameter(1.5))
m2s3_graph.add_unknown_fn("t", "q1", alpha=0.01, gamma=LearnableParameter(1.5))
m2s3_graph.add_unknown_fn("t", "q2", alpha=0.01, gamma=LearnableParameter(1.5))

# Derivative node for every state variable.
m2s3_graph.add_known_fn("p1", "p1_dot", derivative)
m2s3_graph.add_known_fn("p2", "p2_dot", derivative)
m2s3_graph.add_known_fn("q1", "q1_dot", derivative)
m2s3_graph.add_known_fn("q2", "q2_dot", derivative)

# Group the derivatives into q_dot and p_dot, and add a sign-flipped p_dot.
m2s3_graph.add_aggregator(["q1_dot", "q2_dot"], "q_dot")
m2s3_graph.add_aggregator(["p1_dot", "p2_dot"], "p_dot")
m2s3_graph.add_known_fn("p_dot", "-p_dot", lambda p_dot: -p_dot)

# (q_dot, -p_dot) is the node named as `observations` for H below; pq
# aggregates the state in (p1, p2, q1, q2) order.
m2s3_graph.add_aggregator(["q_dot", "-p_dot"], "qp_dot")
m2s3_graph.add_aggregator(["p1", "p2", "q1", "q2"], "pq")

# H is an unknown function of pq. With linear_functional=jax.jacobian and
# observations="qp_dot" this presumably ties the Jacobian of H to
# (q_dot, -p_dot) — NOTE(review): confirm the semantics in the cgc docs.
m2s3_graph.add_unknown_fn("pq", "H", linear_functional=jax.jacobian, observations="qp_dot", alpha=0.01, gamma=1)
m2s3_graph.add_known_fn("H", "grad_H", derivative)

# Input node of the p_dot constraint defined next in this cell.
m2s3_graph.add_aggregator(["p_dot", "grad_H"], "(p_dot, grad_H)")
def p_dot_constraint(p_dot_grad_H):
    """Residual ``p_dot + grad_H[:, 2:]`` for a two-degree-of-freedom graph.

    The 2-D input is read column-wise: columns 0-1 are p_dot and the
    remaining columns are grad_H, so the gradient entries from index 2
    onward start at overall column 4. Returns one residual row per input row.
    """
    momentum_rates = p_dot_grad_H[:, :2]
    trailing_grad = p_dot_grad_H[:, 4:]
    return momentum_rates + trailing_grad

# Input node of the q_dot constraint defined next in this cell.
m2s3_graph.add_aggregator(["q_dot", "grad_H"], "(q_dot, grad_H)")
def q_dot_constraint(q_dot_grad_H):
    """Residual ``q_dot - grad_H[:, :2]`` for a two-degree-of-freedom graph.

    The 2-D input is read column-wise: columns 0-1 are q_dot and the
    remaining columns are grad_H, so the first two gradient entries are
    overall columns 2-3. Returns one residual row per input row.
    """
    position_rates = q_dot_grad_H[:, :2]
    leading_grad = q_dot_grad_H[:, 2:4]
    return position_rates - leading_grad

# Attach the two residual functions to their aggregated input nodes under the
# names "W1" and "W2".
m2s3_graph.add_constraint("(p_dot, grad_H)", "W1", p_dot_constraint)
m2s3_graph.add_constraint("(q_dot, grad_H)", "W2", q_dot_constraint)

In [None]:
# Set the constraints loss multiplier to 10000 (presumably its weight relative
# to the other loss terms — confirm the defaults in cgc).
m2s3_graph.set_loss_multipliers(constraints_loss_multiplier=10000)

In [None]:
# Run graph completion on the masked table X with mask M. Z is indexed below
# with the same column layout as X (columns 1-4 are the states, column 5 is H).
Z = m2s3_graph.complete(X, M, optimizer="l-bfgs-b", learn_parameters=True, n_rounds=20)

In [None]:
# Report the kernel parameters of the graph after completion.
m2s3_graph.report_kernel_params()

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(30, 30))

# One record per panel: grid position, column of Z, true series, title.
PlotData = namedtuple("PlotData", ["axes", "index", "truth", "label"])

plots_data = [
    PlotData((0, 0), 1, p1, "$p_1$"),
    PlotData((0, 1), 2, p2, "$p_2$"),
    PlotData((1, 0), 3, q1, "$q_1$"),
    PlotData((1, 1), 4, q2, "$q_2$")
]

# Time of the last observed state sample, read from the mask instead of being
# hard-coded: column 1 of M is True exactly on the rows where p1 was observed,
# so the marker stays correct if the masking cutoff or the time grid changes.
t_last_observed = t[M[:, 1]][-1]

for data in plots_data:
    i, j = data.axes

    axes[i, j].plot(t, Z[:, data.index], label="Predictions")
    axes[i, j].plot(t, data.truth, label="Truth")
    axes[i, j].axvline(t_last_observed, label="End-of-Observations", linestyle='--', c='black')
    axes[i, j].set_title(data.label)
    axes[i, j].legend()

In [None]:
# Column 5 of the completed table (the "H" observable) over time.
plt.plot(t, Z[:, 5])

# Hénon–Heiles System

## Data Generation

In [None]:
def hh_system_ode(pq, t):
    """Right-hand side of the Hénon–Heiles ODE for `odeint`.

    Parameters
    ----------
    pq : sequence of four numbers
        Current state, unpacked as ``(p1, p2, q1, q2)``.
    t : float
        Time; part of the callback signature but not used by the dynamics.

    Returns
    -------
    list
        ``[dp1/dt, dp2/dt, dq1/dt, dq2/dt]``.
    """
    p1, p2, q1, q2 = pq
    p1_rate = -q1 - 2 * q1 * q2
    p2_rate = -q2 - q1 ** 2 + q2 ** 2
    # The position rates are simply the momenta.
    return [p1_rate, p2_rate, p1, p2]

# 400 evenly spaced time samples on [0, 80]; the system is integrated from
# the initial state (p1, p2, q1, q2) = (0.1, -0.1, 0.2, -0.1).
t = np.linspace(start=0, stop=80, num=400)
pq = odeint(func=hh_system_ode, y0=[0.1, -0.1, 0.2, -0.1], t=t)


In [None]:
# Columns of the trajectory, in the state order used by `hh_system_ode`.
p1, p2, q1, q2 = pq.T

# Energy along the trajectory: the quadratic part plus the cubic terms
# q2 * q1^2 - q2^3 / 3.
H = (
    0.5 * (q1 ** 2 + q2 ** 2 + p1 ** 2 + p2 ** 2)
    + q2 * q1 ** 2
    - (1 / 3) * q2 ** 3
)

# Overlay all four state variables and the energy on a single axis.
plt.plot(t, p1, label="$p_1$")
plt.plot(t, p2, label="$p_2$")
plt.plot(t, q1, label="$q_1$")
plt.plot(t, q2, label="$q_2$")
plt.plot(t, H, label="H")
plt.legend()


In [None]:
# Ground-truth table, one row per time sample, columns (t, p1, p2, q1, q2, H).
X_true = np.column_stack((t, pq, H))

# Boolean mask with the same shape as X_true; entries that are True are the
# ones copied from X_true into X below.
M = np.ones(X_true.shape, dtype=bool)
M[200:, 1:5] = False  # the four state columns are masked from row 200 onward
M[:, 5] = False       # H is masked out on every row

# Input table: zeros everywhere except the unmasked entries.
X = np.zeros_like(X_true)
X[M] = X_true[M]

In [None]:
# Computational graph for the Hénon–Heiles system. The observable order is the
# same as the column order of X / X_true built above.
hh_graph = ComputationalGraph(observables_order=["t", "p1", "p2", "q1", "q2", "H"])

hh_graph.add_observable("t")
# Each state variable is registered as an unknown function of t, with gamma
# wrapped in LearnableParameter(1.2) (presumably a kernel parameter tuned
# during `complete` — confirm against the cgc docs).
hh_graph.add_unknown_fn("t", "p1", alpha=0.01, gamma=LearnableParameter(1.2))
hh_graph.add_unknown_fn("t", "p2", alpha=0.01, gamma=LearnableParameter(1.2))
hh_graph.add_unknown_fn("t", "q1", alpha=0.01, gamma=LearnableParameter(1.2))
hh_graph.add_unknown_fn("t", "q2", alpha=0.01, gamma=LearnableParameter(1.2))

# Derivative node for every state variable.
hh_graph.add_known_fn("p1", "p1_dot", derivative)
hh_graph.add_known_fn("p2", "p2_dot", derivative)
hh_graph.add_known_fn("q1", "q1_dot", derivative)
hh_graph.add_known_fn("q2", "q2_dot", derivative)

# Group the derivatives into q_dot and p_dot, and add a sign-flipped p_dot.
hh_graph.add_aggregator(["q1_dot", "q2_dot"], "q_dot")
hh_graph.add_aggregator(["p1_dot", "p2_dot"], "p_dot")
hh_graph.add_known_fn("p_dot", "-p_dot", lambda p_dot: -p_dot)

# (q_dot, -p_dot) is the node named as `observations` for H below; pq
# aggregates the state in (p1, p2, q1, q2) order.
hh_graph.add_aggregator(["q_dot", "-p_dot"], "qp_dot")
hh_graph.add_aggregator(["p1", "p2", "q1", "q2"], "pq")

# H is an unknown function of pq. With linear_functional=jax.jacobian and
# observations="qp_dot" this presumably ties the Jacobian of H to
# (q_dot, -p_dot) — NOTE(review): confirm the semantics in the cgc docs.
hh_graph.add_unknown_fn("pq", "H", linear_functional=jax.jacobian, observations="qp_dot", alpha=0.01, gamma=1.0)
hh_graph.add_known_fn("H", "grad_H", derivative)

# Input node of the p_dot constraint defined next in this cell.
hh_graph.add_aggregator(["p_dot", "grad_H"], "(p_dot, grad_H)")
def p_dot_constraint(p_dot_grad_H):
    """Residual ``p_dot + grad_H[:, 2:]`` for the Hénon–Heiles graph.

    The 2-D input is read column-wise: columns 0-1 are p_dot and the
    remaining columns are grad_H, so the gradient entries from index 2
    onward start at overall column 4. Returns one residual row per input row.
    """
    momentum_rates = p_dot_grad_H[:, :2]
    trailing_grad = p_dot_grad_H[:, 4:]
    return momentum_rates + trailing_grad

# Input node of the q_dot constraint defined next in this cell.
hh_graph.add_aggregator(["q_dot", "grad_H"], "(q_dot, grad_H)")
def q_dot_constraint(q_dot_grad_H):
    """Residual ``q_dot - grad_H[:, :2]`` for the Hénon–Heiles graph.

    The 2-D input is read column-wise: columns 0-1 are q_dot and the
    remaining columns are grad_H, so the first two gradient entries are
    overall columns 2-3. Returns one residual row per input row.
    """
    position_rates = q_dot_grad_H[:, :2]
    leading_grad = q_dot_grad_H[:, 2:4]
    return position_rates - leading_grad

# Attach the two residual functions to their aggregated input nodes under the
# names "W1" and "W2".
hh_graph.add_constraint("(p_dot, grad_H)", "W1", p_dot_constraint)
hh_graph.add_constraint("(q_dot, grad_H)", "W2", q_dot_constraint)

In [None]:
# Set the constraints loss multiplier to 10000 (presumably its weight relative
# to the other loss terms — confirm the defaults in cgc).
hh_graph.set_loss_multipliers(constraints_loss_multiplier=10000)

In [None]:
# Run graph completion on the masked table X with mask M. Z is indexed below
# with the same column layout as X (columns 1-4 are the states, column 5 is H).
Z = hh_graph.complete(X, M, optimizer="l-bfgs-b", learn_parameters=True, n_rounds=20)

In [None]:
# Report the kernel parameters of the graph after completion.
hh_graph.report_kernel_params()

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(30, 30))

# One record per panel: grid position, column of Z, true series, title.
PlotData = namedtuple("PlotData", ["axes", "index", "truth", "label"])

plots_data = [
    PlotData((0, 0), 1, p1, "$p_1$"),
    PlotData((0, 1), 2, p2, "$p_2$"),
    PlotData((1, 0), 3, q1, "$q_1$"),
    PlotData((1, 1), 4, q2, "$q_2$")
]

# Time of the last observed state sample, read from the mask instead of being
# hard-coded: column 1 of M is True exactly on the rows where p1 was observed,
# so the marker stays correct if the masking cutoff or the time grid changes.
t_last_observed = t[M[:, 1]][-1]

for data in plots_data:
    i, j = data.axes

    axes[i, j].plot(t, Z[:, data.index], label="Predictions")
    axes[i, j].plot(t, data.truth, label="Truth")
    axes[i, j].axvline(t_last_observed, label="End-of-Observations", linestyle='--', c='black')
    axes[i, j].set_title(data.label)
    axes[i, j].legend()

In [None]:
# Column 5 of the completed table (the "H" observable) over time.
plt.plot(t, Z[:, 5])


# [WIP] Periodically driven pendulum

## Data Generation

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from scipy.integrate import odeint
import matplotlib.pyplot as plt

# Model constants: nu scales the pendulum potential, lam scales the
# time-dependent driving term.
nu = 5
lam = 1

def H(p, q, t):
    """Hamiltonian of the periodically driven pendulum.

    H = p^2 / 2 - nu^2 cos(q) - lam * p * q * (0.3 sin(2 t) + 0.7 sin(3 t))

    The driving frequencies (2 and 3) are the same as in the hand-written
    partial derivatives `dq_H` and `dp_H`, which generate the trajectory, so
    this function is the energy of the system that is actually simulated.
    Written with jax.numpy so it can be differentiated with `jax.grad`.
    """
    return 0.5 * p ** 2 - (nu ** 2) * jnp.cos(q) - lam * (0.3 * p * q * jnp.sin(2 * t) + 0.7 * p * q * jnp.sin(3 * t))

def dq_H(p, q, t):
    """Partial derivative of `H` with respect to q (plain numpy, for `odeint`)."""
    return (nu ** 2) * np.sin(q) - lam * p * (0.3 * np.sin(2 * t) + 0.7 * np.sin(3 * t))

def dp_H(p, q, t):
    """Partial derivative of `H` with respect to p (plain numpy, for `odeint`)."""
    return p - lam * q * (0.3 * np.sin(2 * t) + 0.7 * np.sin(3 * t))

#dq_H = jax.vmap(
#        jax.vmap(
#            jax.vmap(jax.grad(H, argnums=1), in_axes=(None, None, 0)),
#            in_axes=(None, 0, None)
#        ),
#        in_axes=(0, None, None)
#)
#dq_H = jax.vmap(jax.grad(H, argnums=1), (0, 0, 0))
#dp_H = jax.vmap(jax.grad(H, argnums=0), (0, 0, 0))

def p_dot(p, q, t):
    """Momentum rate: minus the q-derivative of the Hamiltonian."""
    return -1 * dq_H(p, q, t)

def q_dot(p, q, t):
    """Position rate: the p-derivative of the Hamiltonian."""
    return dp_H(p, q, t)

def system_ode(pq, t):
    """Right-hand side of the driven-pendulum ODE for `odeint`.

    `pq` is unpacked as ``(p, q)``; the result is the list
    ``[p_dot(p, q, t), q_dot(p, q, t)]``.
    """
    p, q = pq
    return [p_dot(p, q, t), q_dot(p, q, t)]

# 400 evenly spaced time samples on [0, 40]; the pendulum is integrated from
# the initial state (p, q) = (0.1, 0.1).
t = np.linspace(start=0, stop=40, num=400)
pq = odeint(func=system_ode, y0=[0.1, 0.1], t=t)


In [None]:
# Evaluate the Hamiltonian along the integrated trajectory (column 0 of pq is
# p, column 1 is q) and display its mean.
h_values = H(pq[:, 0], pq[:, 1], t)
np.mean(h_values)

In [None]:
# Spread of the Hamiltonian values along the trajectory.
np.std(h_values)

In [None]:
# Three side-by-side panels: p(t), q(t) and the Hamiltonian along the trajectory.
_, axes = plt.subplots(nrows=1, ncols=3, figsize=(30, 10))

axes[0].plot(t, pq[:, 0])                    # momentum p
axes[1].plot(t, pq[:, 1])                    # position q
axes[2].plot(t, H(pq[:, 0], pq[:, 1], t))    # H evaluated on (p, q, t)

In [None]:
# Ground-truth table, one row per time sample, columns (t, p, q, H).
X_true = np.concatenate(
    (
        t[:, np.newaxis],
        pq,
        H(pq[:, 0], pq[:, 1], t)[:, np.newaxis],
    ),
    axis=1,
)
n, p = X_true.shape

# Boolean mask with the same shape as X_true; entries that are True are the
# ones copied from X_true into X below.
M = np.ones((n, p), dtype=bool)
M[:, -1] = False      # last column (H) is masked out on every row
M[200:, 1:3] = False  # p and q are masked out from row 200 onward

# Input table: zeros everywhere except the unmasked entries.
X = np.zeros_like(X_true)
X[M] = X_true[M]


In [None]:
from cgc.graph import ComputationalGraph, derivative

# Computational graph for the driven pendulum. The observable order is the
# same as the column order of X / X_true built above: (t, p, q, H).
graph = ComputationalGraph(observables_order=["t", "p", "q", "H"])

graph.add_observable("t")
# p and q are registered as unknown functions of t, each with a derivative node.
graph.add_unknown_fn("t", "p", alpha=0.01)
graph.add_unknown_fn("t", "q", alpha=0.01)
graph.add_known_fn("p", "p_dot", derivative)
graph.add_known_fn("q", "q_dot", derivative)
# H is an unknown function of the aggregated (p, q, t) node, so grad_H is
# sliced as three columns by the constraints defined below in this cell.
graph.add_aggregator(["p", "q", "t"], "pqt")
graph.add_unknown_fn("pqt", "H", alpha=0.01)
graph.add_known_fn("H", "grad_H", derivative)

# Inputs of the two constraint functions; note that grad_H comes first here.
graph.add_aggregator(["grad_H", "p_dot"], "(grad_H, p_dot)")
graph.add_aggregator(["grad_H", "q_dot"], "(grad_H, q_dot)")

def q_dot_constraint(grad_H_q_dot):
    """Residual ``q_dot - grad_H[0]`` for the driven-pendulum graph.

    The 2-D input is read column-wise: columns 0-2 are grad_H and column 3 is
    q_dot; the first gradient entry is subtracted from q_dot. Returns one
    residual per row.
    """
    position_rate = grad_H_q_dot[:, 3]
    first_grad_entry = grad_H_q_dot[:, 0]
    return position_rate - first_grad_entry

def p_dot_constraint(grad_H_p_dot):
    """Residual ``p_dot + grad_H[1]`` for the driven-pendulum graph.

    The 2-D input is read column-wise: columns 0-2 are grad_H and column 3 is
    p_dot; the second gradient entry is added to p_dot. Returns one residual
    per row.
    """
    momentum_rate = grad_H_p_dot[:, 3]
    second_grad_entry = grad_H_p_dot[:, 1]
    return momentum_rate + second_grad_entry

# Attach the two residual functions to their aggregated input nodes under the
# names "W1" and "W2".
graph.add_constraint("(grad_H, p_dot)", "W1", p_dot_constraint)
graph.add_constraint("(grad_H, q_dot)", "W2", q_dot_constraint)

In [None]:
# Set the three loss multipliers explicitly (presumably the relative weights
# of the loss terms — confirm in cgc), then run graph completion with the
# default optimizer settings. Z is plotted below using the column layout of X.
graph.set_loss_multipliers(constraints_loss_multiplier=1000, data_compliance_loss_multiplier=1000, unknown_functions_loss_multiplier=10)
Z = graph.complete(X, M)

In [None]:
# Three side-by-side panels with columns 1-3 of the completed table, which
# follow the observable order (t, p, q, H).
_, axes = plt.subplots(nrows=1, ncols=3, figsize=(30, 10))

axes[0].plot(t, Z[:, 1])  # completed p
axes[1].plot(t, Z[:, 2])  # completed q
axes[2].plot(t, Z[:, 3])  # completed H