In [None]:
import gymnasium as gym
import jax.numpy as jnp
import numpy as np
import jax.random as jrandom
from utils_v2 import KalmanFilter
from stable_baselines3 import PPO
from jax import grad, test_util

In [8]:
delta_t = 0.1
m = 0.1
M = 1
g = 9.8
le = 0.5
A = jnp.array(
    [
        [0, 1, 0, 0],
        [0, 0, m * g * delta_t / M, 0],
        [0, 0, 0, 1],
        [0, 0, (m + M) * g * delta_t / (M * le), 0],
    ]
)

# A=jnp.array([[1,delta_t,0,0],
#             [0,1,m*g*delta_t/M,0],
#             [0,0,1,delta_t],
#             [0,0,(m+M)*g*delta_t/(M*l),1]])

B = jnp.array([[0], [delta_t / M], [0], [delta_t / (M * le)]])

B_0 = jnp.ones((4, 1))


H = jnp.eye(4)

R = jnp.eye(H.shape[0]) * 0.005
C = H
Q = jnp.eye(A.shape[0]) * 0.005
mean = 0
std_dev = 0.005
# Q=jrandom.normal(key, shape=(A.shape[0],1)) * std_dev + mean     #FAILED
P = jnp.ones(A.shape[0])
key = jrandom.PRNGKey(42)
Z = jrandom.normal(key, shape=(C.shape[0], 1)) * std_dev + mean
w_k = jnp.ones((A.shape[0], 1)) * 0.005
# w_k=jrandom.normal(key, shape=(A.shape[0],1)) * std_dev + mean    #FAILED

In [None]:
class KFOptimiser(gym.Env):
    def __init__(self, control_model, lr=0.01, lr_decay=0.9995, min_lr=0.001):
        super(KFOptimiser, self).__init__()

        self.env = gym.make("CartPole-v1")
        self.env._max_episode_steps = 1000

        self.B_0 = jnp.zeros((4, 1))

        self.kf = KalmanFilter(
            jnp.expand_dims(self.env.reset()[0], axis=-1),
            A,
            self.B_0,
            H,
            R,
            C,
            Q,
            Z,
            w_k,
            P,
        )

        self.action_space = gym.spaces.Box(
            low=-1, high=1, shape=self.kf.B.flatten().shape
        )

        low = np.append(
            jnp.append(self.env.observation_space.low, self.env.observation_space.low),
            jnp.full(self.B_0.flatten().shape, -3),
        )
        self.observation_space = gym.spaces.Box(low, -low)

        self.control_model = control_model
        self.lr = lr
        self.lr_decay = lr_decay
        self.min_lr = min_lr

    def reset(self, seed=None, **kwargs):
        state = self.env.reset()[0]  # 4-dim
        self.kf.reset()  # 4-dim

        # approach 1: reset B at each episode
        self.kf.B = self.B_0  # 2

        # approach 2: don't reset at each episode
        # do nothing  #1

        # approach 3: Perturb B slightly at the start of each episode
        self.kf.B += np.random.normal(
            0, 0.3, size=self.kf.B.shape
        )  # 3 very bad results

        self.state = np.append(
            jnp.append(state, self.kf.x_0), self.kf.B
        )  # 12-dim (4x true state, 4x estimated state, 4x B)

        return self.state, {}  # 12-dim

    def loss_fn(self, B, true_state, kf_state):
        error = true_state - kf_state
        return jnp.mean(error**2)

    def step(self, action):
        action = jnp.expand_dims(action, axis=-1)  # + np.random.normal(0, 0.1)  # 4-dim

        u = self.control_model.predict(self.state[:4])[0]  # first 4x of true states
        true_state, _, done, truncuated, _ = self.env.step(u)

        u_k = jnp.array([[u if u == 1 else -1]])

        # true_B=true_state-self.kf.A @ self.state[:4] / u_k

        true_B = (
            true_state
            - self.state[:4]
            - delta_t * self.kf.A @ self.state[:4]
            - self.kf.w_k
        ) / (u_k * delta_t)

        self.kf.predict(u_k)

        self.state = np.append(
            jnp.append(true_state, self.kf.update(true_state)), self.kf.B
        )

        error = true_state - self.state[4:8]

        grad_loss_fn = grad(self.loss_fn)

        gradient = grad_loss_fn(self.kf.B, true_state, self.state[4:8])

        # b_error=b_error*jnp.array([[1],[1],[1],[5]])

        if self.lr > self.min_lr:
            self.lr *= self.lr_decay
            self.lr = max(self.min_lr, self.lr)

        # self.kf.B+=action*self.lr
        # + jnp.expand_dims(error,axis=-1)*self.beta
        # + b_error*self.beta
        previous_b = self.kf.B

        self.kf.B += self.lr * action - (1 - self.lr) * gradient

        self.kf.B = jnp.clip(
            self.kf.B,
            jnp.expand_dims(self.observation_space.low[8:], axis=-1),
            jnp.expand_dims(self.observation_space.high[8:], axis=-1),
        )
        B_change_penalty = jnp.linalg.norm(self.kf.B - previous_b)
        b_error = true_B - self.kf.B

        # error=error*jnp.array([[1],[1],[1],[3]])

        error = jnp.mean(error**2)

        b_error = jnp.mean(b_error.flatten() ** 2)

        return self.state, -error - 0.1 * B_change_penalty, done, truncuated, {}

In [11]:
model = PPO.load("ppo_cartpole_100k_noKF")

In [12]:
# cpu: 50mins
# cuda: 58mins
try:
    env = KFOptimiser(model)
    log_dir = "logs"
    model1 = PPO("MlpPolicy", env, verbose=1, tensorboard_log=log_dir, device="cpu")
    model1.learn(total_timesteps=300_000, tb_log_name="PPO_grad_b_change_pen")
except KeyboardInterrupt:
    print("cut")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to logs\PPO_grad_b_change_pen_1


---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1e+03    |
|    ep_rew_mean     | -17.4    |
| time/              |          |
|    fps             | 104      |
|    iterations      | 1        |
|    time_elapsed    | 19       |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 1e+03        |
|    ep_rew_mean          | -16          |
| time/                   |              |
|    fps                  | 92           |
|    iterations           | 2            |
|    time_elapsed         | 44           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0050049694 |
|    clip_fraction        | 0.033        |
|    clip_range           | 0.2          |
|    entropy_loss         | -5.69        |
|    explained_variance   | -5.36        |
|    learning_r

In [13]:
from tqdm import tqdm

episodes = 10
x = []
for ep in tqdm(range(1, episodes + 1), ascii=True, unit="episodes"):
    state = env.reset()[0]
    while True:
        # env.render()
        action = model1.predict(state)[0]
        state, reward, done, truncuated, info = env.step(action)
        if done or truncuated:
            break
    h = env.kf.B - B
    x.append(h)
jnp.mean(jnp.array(x), axis=0)

100%|##########| 10/10 [01:46<00:00, 10.66s/episodes]


Array([[ 0.08263567],
       [-0.04436978],
       [-0.17011417],
       [-0.5063559 ]], dtype=float32)

In [None]:
class KFOptimiser_A(gym.Env):
    def __init__(self, control_model, lr=0.01, beta=0.001):
        super(KFOptimiser_A, self).__init__()

        self.env = gym.make("CartPole-v1")
        self.env._max_episode_steps = 1000

        self.A_0 = jnp.zeros((4, 4))

        self.kf = KalmanFilter(
            jnp.expand_dims(self.env.reset()[0], axis=-1),
            self.A_0,
            B,
            H,
            R,
            C,
            Q,
            Z,
            w_k,
            P,
        )

        self.action_space = gym.spaces.Box(
            low=-1, high=1, shape=self.A_0.flatten().shape
        )

        low = np.append(
            jnp.append(self.env.observation_space.low, self.env.observation_space.low),
            jnp.full(self.A_0.flatten().shape, -3),
        )
        self.observation_space = gym.spaces.Box(low, -low)

        self.control_model = control_model
        self.lr = lr
        self.beta = beta

    def reset(self, seed=None, **kwargs):
        state = self.env.reset()[0]  # 4-dim
        self.kf.reset()  # 4-dim

        # approach 1: reset B at each episode
        self.kf.A = self.A_0  # 2

        # approach 2: don't reset at each episode
        # do nothing  #1

        # approach 3: Perturb B slightly at the start of each episode
        self.kf.A += np.random.normal(0, 0.1, size=self.kf.A.shape)

        self.state = np.append(
            jnp.append(state, self.kf.x_0), self.kf.A.flatten()
        )  # 24-dim (4x true state, 4x estimated state, 16x A)

        return self.state, {}  # 24-dim

    def step(self, action):
        action = jnp.expand_dims(action, axis=-1) + np.random.normal(0, 0.1)  # 4-dim

        u = self.control_model.predict(self.state[:4])[0]  # fist 4x of true states

        true_state, _, done, truncuated, _ = self.env.step(u)

        u_k = jnp.array([[u if u == 1 else -1]])

        true_A = (true_state - self.state[:4] - delta_t * self.kf.B @ u_k) / self.state[
            :4
        ]

        self.kf.predict(u_k)

        self.state = np.append(
            jnp.append(true_state, self.kf.update(true_state)), self.kf.A.flatten()
        )

        error = self.state[:4] - self.state[4:8]

        a_error = true_A - self.kf.A

        # b_error=b_error*jnp.array([[1],[1],[1],[5]])

        update = self.kf.A.reshape((16, 1)) + action * self.lr
        +jnp.expand_dims(error, axis=-1) * self.beta
        +a_error.flatten() * self.beta

        self.kf.A = jnp.clip(
            update,
            jnp.expand_dims(self.observation_space.low[8:], axis=-1),
            jnp.expand_dims(self.observation_space.high[8:], axis=-1),
        ).reshape((4, 4))

        total_error = jnp.mean(error**2)

        # total_b_error=jnp.mean(b_error.flatten()**2) #B matrix error didn't work well

        return self.state, -total_error, done, truncuated, {}

In [None]:
env = KFOptimiser_A(model)
model2 = PPO("MlpPolicy", env, verbose=1, tensorboard_log=log_dir)
model2.learn(total_timesteps=300_000, tb_log_name="PPO_A")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to logs\PPO_A_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 1e+03    |
|    ep_rew_mean     | -11      |
| time/              |          |
|    fps             | 235      |
|    iterations      | 1        |
|    time_elapsed    | 8        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1e+03       |
|    ep_rew_mean          | -10.7       |
| time/                   |             |
|    fps                  | 221         |
|    iterations           | 2           |
|    time_elapsed         | 18          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008822694 |
|    clip_fraction        | 0.0828      |
|    clip_range           | 0.2       

<stable_baselines3.ppo.ppo.PPO at 0x1d427b3dbe0>

In [None]:
episodes = 10
x = []
for ep in range(episodes):
    rewards = 0
    state = env.reset()[0]
    while True:
        # env.render()
        action = model2.predict(state)[0]
        state, reward, done, truncuated, info = env.step(action)
        rewards += reward
        if done or truncuated:
            break
    h = env.kf.A - A
    x.append(h)

Array([[-1.5120112 , -3.0536647 , -2.5878775 ,  2.1101022 ],
       [-0.64700586,  2.9988332 , -2.6241379 , -1.1773218 ],
       [-0.66021204, -2.5274637 , -2.9931839 ,  0.14226285],
       [ 2.9931576 ,  0.360365  , -3.6396477 , -1.1661142 ]],      dtype=float32)