# Import Dependencies

In [None]:
# Only needed when running on a Colab TPU runtime; skip this cell otherwise.
# NOTE(review): `jax.tools.colab_tpu` only exists in older JAX releases —
# confirm it is available in the installed JAX version before running.
from jax.tools import colab_tpu
colab_tpu.setup_tpu()

In [None]:
!pip install diffrax optax
import os
from functools import partial 

import jax
from jax import jit, lax
import jax.numpy as jnp
import jax.random as jr
from diffrax import diffeqsolve, ODETerm, Euler, ImplicitEuler, Kvaerno3, Dopri5, SaveAt, PIDController, NewtonNonlinearSolver

import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
import time 
import pandas as pd 

# Enable 64-bit (float64/int64) arrays. This only works on startup, so it must
# run before any other JAX computation in the session.
# NOTE(review): `from jax.config import config` was removed in newer JAX
# releases; `jax.config.update(...)` is the portable spelling — confirm
# against the installed JAX version.
from jax.config import config
config.update("jax_enable_x64", True)
# Optional debugging switches; uncomment as needed.
#config.update("jax_debug_nans", True)
#config.update("jax_disable_jit", True)

# Limit ourselves to single-threaded jax/xla operations to avoid thrashing. See
# https://github.com/google/jax/issues/743.
#os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
#                           "intra_op_parallelism_threads=1")

# Report how many local devices JAX can see.
n_devices = jax.local_device_count()
print(n_devices)

# Define ODE System

RC (resistance–capacitance) models are typically represented as an explicit ODE system in the general state-space form below:

$$
\begin{aligned}
\dot x &= Ax + Bu \\
y &= Cx
\end{aligned}
$$

For the RC model used here, the individual quantities are defined as follows:

*State Variables*

$
x = \begin{bmatrix}
    T_{a,i} \\
    T_{w,e} \\
    T_{w,i}
\end{bmatrix}
$

*Disturbances*

$
u = \begin{bmatrix}
    T_{a,o} \\
    q_{con,i} \\
    q_{hvac} \\
    q_{rad,e} \\
    q_{rad,i}
    \end{bmatrix}
$

*Output*

$
y = \begin{bmatrix}
    T_{a,i}
    \end{bmatrix}
$

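The time-invariant coefficients $A$, $B$, and $C$ are derived from energy-conservation (heat-balance) equations using building parameters such as the thermal resistances $R$ and thermal capacitances $C$ (the capacitances, which carry subscripts, are not to be confused with the output matrix $C$).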

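$
A = \begin{bmatrix}
    -\frac{1}{C_{a,i}}\left(\frac{1}{R_g} + \frac{1}{R_i}\right) & 0 & \frac{1}{C_{a,i}R_{i}} \\
    0 & -\frac{1}{C_{w,e}}\left(\frac{1}{R_e} + \frac{1}{R_w}\right) & \frac{1}{C_{w,e}R_{w}} \\
    \frac{1}{C_{w,i}R_{i}} & \frac{1}{C_{w,i}R_{w}} & -\frac{1}{C_{w,i}}\left(\frac{1}{R_w} + \frac{1}{R_i}\right)
    \end{bmatrix}
$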

$
B = \begin{bmatrix}
    \frac{1}{C_{a,i}R_{g}} & \frac{1}{C_{a,i}} & \frac{1}{C_{a,i}} & 0 & 0 \\
    \frac{1}{C_{w,e}R_{e}} & 0 & 0 & \frac{1}{C_{w,e}} & 0 \\
    0 & 0 & 0 & 0 & \frac{1}{C_{w,i}}  
    \end{bmatrix}
$

$
C = \begin{bmatrix}
    1 & 0 & 0
    \end{bmatrix}
$

The following code builds these matrices. It also returns the feedthrough term $D = 0$, since the inputs do not act on the output directly.

In [None]:
@jax.jit
def get_ABCD(Cai, Cwe, Cwi, Re, Ri, Rw, Rg):
    """Build the state-space matrices (A, B, C, D) of the 3-state RC model.

    Args:
        Cai, Cwe, Cwi: capacitances of the three state nodes (rows 0, 1, 2).
        Re, Ri, Rw, Rg: resistances coupling the nodes to each other and to
            the inputs.

    Returns:
        A: (3, 3) state matrix.
        B: (3, 5) input matrix.
        C: (1, 3) output matrix selecting the first state.
        D: the scalar 0 (no direct feedthrough).
    """

    def scatter(shape, entries):
        # Place {(row, col): value} into a zero matrix; unlisted entries stay 0.
        # Values are cast to the zero matrix's (default float) dtype.
        rows, cols = zip(*entries)
        values = jnp.stack([entries[key] for key in entries])
        return jnp.zeros(shape).at[jnp.array(rows), jnp.array(cols)].set(values)

    state_matrix = scatter((3, 3), {
        (0, 0): -1 / Cai * (1 / Rg + 1 / Ri),
        (0, 2): 1 / (Cai * Ri),
        (1, 1): -1 / Cwe * (1 / Re + 1 / Rw),
        (1, 2): 1 / (Cwe * Rw),
        (2, 0): 1 / (Cwi * Ri),
        (2, 1): 1 / (Cwi * Rw),
        (2, 2): -1 / Cwi * (1 / Rw + 1 / Ri),
    })

    input_matrix = scatter((3, 5), {
        (0, 0): 1 / (Cai * Rg),
        (0, 1): 1 / Cai,
        (0, 2): 1 / Cai,
        (1, 0): 1 / (Cwe * Re),
        (1, 3): 1 / Cwe,
        (2, 4): 1 / Cwi,
    })

    # The single output observes only the first state.
    output_matrix = scatter((1, 3), {(0, 0): 1})

    return state_matrix, input_matrix, output_matrix, 0

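Together with a Kalman filter, the ODE becomes a continuous-time (Kalman–Bucy) filter that propagates both the state estimate $\hat x$ and its error covariance $P$:

$$
\begin{aligned}
K &= P C^\top R^{-1} \\
\dot{\hat x} &= A\hat x + Bu + K\,(z - C\hat x) \\
\dot P &= AP + PA^\top + Q - P C^\top R^{-1} C P
\end{aligned}
$$

Here $z$ is the measurement, $Q$ is the covariance of the modeling error, and $R$ is the covariance of the measurement noise (not the thermal resistances above).

TODO: add more documentation about this.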

In [None]:
@jit
def continuous_kmf(t, xP, A, B, C, Q, R, u, z):
    """Right-hand side of the continuous-time Kalman filter ODE.

    Propagates the state estimate ``x`` and its error covariance ``P``:

        K     = P C^T R^{-1}
        dx/dt = A x + B u + K (z - C x)
        dP/dt = A P + P A^T + Q - K C P

    Args:
        t: current time (unused; kept for the ODE vector-field signature).
        xP: tuple ``(x, P)`` of state estimate and covariance.
        A, B, C: system matrices.
        Q: covariance of the modeling error.
        R: covariance of the measurement noise (must be invertible).
        u: inputs/disturbances at this time.
        z: measurement at this time.

    Returns:
        Tuple ``(dxdt, dPdt)`` with the same structure as ``xP``.

    NOTE(review): diffrax's ``ODETerm`` calls its vector field as
    ``f(t, y, args)``; because this function takes the matrices as separate
    positional arguments, callers presumably wrap it (e.g.
    ``lambda t, y, args: continuous_kmf(t, y, *args)``) — confirm.

    TODO: Kalman filter needs significant tuning otherwise it would lead to
    unrealistic big dxdt if K is too big.
    """
    x, P = xP

    # P C^T is shared by the gain and the Riccati correction term.
    PCt = P @ C.transpose()

    # eq 3.22 of Ref [1]: K = P C^T R^{-1}.  Solve K R = P C^T, i.e.
    # R^T K^T = (P C^T)^T, instead of forming an explicit inverse of R.
    K = jnp.linalg.solve(R.transpose(), PCt.transpose()).transpose()

    # eq 3.21 of Ref [1]; P C^T R^{-1} C P == K C P, so reuse the gain.
    dPdt = A @ P + P @ A.transpose() + Q - K @ C @ P

    # eq 3.23 of Ref [1]: model prediction plus innovation correction.
    dxdt = A @ x + B @ u + K @ (z - C @ x)

    return (dxdt, dPdt)

# Define ODE System Forward Calls
Now let's define a forward simulation function that solves the above ODE system over time using a numerical integrator.

In [None]:
# Forward simulation: step the ODE over a fixed time grid with lax.scan,
# updating the disturbance/measurement row at every grid point.
@partial(jit, static_argnums=(0, 1, 2, 3, 5,))
def forward(func, ts, te, dt, xP0, solver, args):
    """Simulate the ODE system on a fixed grid from ``ts`` to ``te``.

    The grid is ``ts, ts + dt, ts + 2*dt, ...`` up to and including the last
    point that does not exceed ``te``.  At grid index ``i`` the inputs
    ``u[i, :]`` and measurements ``z[i, :]`` are held constant for one
    solver step from ``t`` to ``min(t + dt, te)``.

    Args:
        func: vector field, wrapped in ``ODETerm`` (static).
        ts, te, dt: start time, end time and step size (static Python numbers).
        xP0: initial solver state, e.g. ``(x0, P0)``.
        solver: solver object providing ``init`` and ``step`` (static).
        args: ``(A, B, C, Q, R, u, z)`` where
            A, B, C: system dynamics coefficients,
            Q, R: variance of modeling errors and measurement noise,
            u: control inputs and disturbances, one row per grid point,
            z: measurements, one row per grid point.

    Returns:
        ``(time_steps, xP_all)``: the time grid and the stacked solver state,
        where ``xP_all[k]`` is the state at ``time_steps[k]``.

    Raises:
        ValueError: if ``dt`` is not positive.
    """
    import math

    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    A, B, C, Q, R, u, z = args

    # ode formulation
    term = ODETerm(func)

    # Number of whole dt-steps that fit in [ts, te].  Snap to the nearest
    # integer when float round-off makes an exact multiple look slightly
    # short (e.g. 0.3 / 0.1 == 2.9999...).
    n_intervals = (te - ts) / dt
    nearest = round(n_intervals)
    if math.isclose(n_intervals, nearest):
        n_intervals = nearest
    else:
        n_intervals = math.floor(n_intervals)
    time_steps = ts + dt * jnp.arange(max(n_intervals + 1, 0))

    # Initialize the solver with the first row of inputs/measurements.
    args0 = (A, B, C, Q, R, u[0, :], z[0, :])
    state0 = solver.init(term, ts, ts + dt, xP0, args0)

    def step_at_t(carry, t):
        # One lax.scan iteration: advance the state by one (clipped) step.
        xP, state, i = carry
        step_args = (A, B, C, Q, R, u[i, :], z[i, :])
        # Never integrate past te.
        tnext = jnp.minimum(t + dt, te)
        # The solver's error estimate, dense info and result code are unused.
        xP_next, _, _, state_next, _ = solver.step(
            term, t, tnext, xP, step_args, state, made_jump=False)
        # Emit the state *before* this step so output k matches time_steps[k].
        return (xP_next, state_next, i + 1), xP

    _, xP_all = lax.scan(step_at_t, init=(xP0, state0, 0), xs=time_steps)

    return time_steps, xP_all

# Prepare Data