In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import gymnasium as gym
import numpy as np
import random
import os
import pandas as pd

In [24]:
# Test if interpolaiton is differentiable in jax.numpy
t = np.arange(10)
y = t**2 + 3*t

interp = jnp.interp(9.5, t, y) 

f = lambda x: jnp.interp(x, t, y)
grad_fcn = jax.value_and_grad(f)
grad_fcn(0.)
grad_fcn(0.5)


(Array(2., dtype=float32), Array(4., dtype=float32, weak_type=True))

In [None]:

class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(2)(x)
        return x

class EnvModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(3)(x)  # 3 outputs: 2 states [Tz for next step and power for next step] (although we have simple relationship between power and control action), and reward
        return x

class LSSM(nn.Module):
    pass 

In [None]:
import env
# RC model parameters
rc_params = [6.9789902e+03, 2.1591113e+04, 1.8807944e+05, 3.4490612e+00, 4.9556872e-01, 9.8289281e-02, 4.6257420e+00]
x0 = np.array([20, 35.8, 26.])
x_high = np.array([40., 80., 40.])
x_low = np.array([10., 10., 10.])
n_actions = 101
u_high = [0]
u_low = [-10.0] # -12

# load disturbances
file_path = os.path.abspath('')
parent_path = os.path.dirname(file_path)
data_path = os.path.join(parent_path, 'data/disturbance_1min.csv')
data = pd.read_csv(data_path, index_col=[0])
# assign time index
t_base = 181*24*3600 # 7/1
n = len(data)
index = range(t_base, t_base + n*60, 60)
data.index = index

# sample
dt = 900
data = data.groupby([data.index // dt]).mean()
index_dt = range(t_base, t_base + len(data)*dt, dt)
data.index = index_dt 

# get disturbances for lssm
t_d = index_dt
disturbance_names = ['out_temp', 'qint_lump', 'qwin_lump', 'qradin_lump']
disturbance = data[disturbance_names].values

# RC Gym envionment
ts = 195*24*3600
ndays = 7
te = ndays*24*3600 + ts
weights = [100., 1., 0.] # for energy cost, dT, du

env = gym.make("R4C3Discrete-v0",
            rc_params = rc_params,
            x0 = x0,
            x_high = x_high,
            x_low = x_low,
            n_actions = n_actions,
            u_high = u_high,
            u_low = u_low,
            disturbances = (t_d, disturbance),
            ts = ts,
            te = te,
            dt = dt,
            weights = weights).env

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

In [34]:
random.seed(41)

# Hyperparameters
learning_rate = 1e-3
gamma = 0.99
epsilon = 1.0
epsilon_decay = 0.99
episodes = 500
batch_size = 64
planning_steps = 5

q_network = QNetwork()
env_model = EnvModel()

params = q_network.init(jax.random.PRNGKey(0), jnp.zeros((state_dim,)))
env_model_params = env_model.init(jax.random.PRNGKey(1), jnp.zeros((state_dim + 1,)))

optimizer = optax.adam(learning_rate)
env_model_optimizer = optax.adam(learning_rate)

opt_state = optimizer.init(params)
env_model_opt_state = env_model_optimizer.init(env_model_params)

@jax.jit
def q_learning_update(params, opt_state, state, action, reward, next_state, done):
    def loss_fn(params):
        q_values = q_network.apply(params, state)
        next_q_values = q_network.apply(params, next_state)
        target = reward + gamma * jnp.max(next_q_values, axis=1) * (1 - done)
        loss = jnp.mean((q_values[jnp.arange(q_values.shape[0]), action] - target) ** 2)
        return loss

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# env model update 
@jax.jit
def env_model_update(env_model_params, env_model_opt_state, state, action, next_state, reward):

    def env_model_loss_fn(env_model_params):

        state_action = jnp.hstack([state, action.reshape(-1,1)])#, axis=1
        # predictions of Tz and Power
        predictions = env_model.apply(env_model_params, state_action)
        # target is next Tz and power
        Tz_target = next_state[:,1]
        power_target = next_state[:,4]
        
        target = jnp.stack([Tz_target, power_target, reward], axis=1)
        print(target.shape, predictions.shape)
        print("in model update")
        env_model_loss = jnp.mean(jnp.square(predictions - target))
        return env_model_loss

    env_model_loss, env_model_grads = jax.value_and_grad(env_model_loss_fn)(env_model_params)
    env_model_updates, env_model_opt_state = env_model_optimizer.update(env_model_grads, env_model_opt_state)
    env_model_params = optax.apply_updates(env_model_params, env_model_updates)
    return env_model_params, env_model_opt_state

memory = []
reward_history = []
reward_threshold=175 # env.spec.reward_threshold
solved_window = 100

for episode in range(episodes):
    state, _ = env.reset(seed=1)
    done = False
    total_reward = 0
    step_in_episode = 0

    while not done:
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            q_values = q_network.apply(params, jnp.expand_dims(jnp.array(state), axis=0))
            action = jnp.argmax(q_values).item()

        next_state, reward, done, _, _ = env.step(action)

        memory.append((state, action, reward, next_state, done))
        state = next_state
        total_reward += reward
        step_in_episode += 1

        if len(memory) >= batch_size:
            batch = random.sample(memory, batch_size)

            state_batch = jnp.array([s for (s, _, _, _, _) in batch])
            action_batch = jnp.array([a for (_, a, _, _, _) in batch])
            reward_batch = jnp.array([r for (_, _, r, _, _) in batch])
            next_state_batch = jnp.array([ns for (_, _, _, ns, _) in batch])
            done_batch = jnp.array([d for (_, _, _, _, d) in batch], dtype=jnp.float32)

            params, opt_state = q_learning_update(params, opt_state, state_batch, action_batch, reward_batch, next_state_batch, done_batch)
            env_model_params, env_model_opt_state = env_model_update(env_model_params, env_model_opt_state, state_batch, action_batch, next_state_batch, reward_batch)

            for _ in range(planning_steps):
                planning_batch = random.sample(memory, batch_size)

                state_batch = jnp.array([s for (s, _, _, _, _) in planning_batch])
                action_batch = jnp.array([a for (_, a, _, _, _) in batch])

                state_action_batch = jnp.concatenate([state_batch, action_batch[:, np.newaxis]], axis=1)
                predictions = env_model.apply(env_model_params, state_action_batch)

                # replace Tz and power with predicted values for the next state
                Tz_next = predictions[:, 0]
                power_next = predictions[:, 1]
                next_state_batch = jnp.concatenate([Tz_next[:, np.newaxis], next_state_batch[:, 1:3], power_next[:, np.newaxis], next_state_batch[:, 4:]], axis=1)
                reward_batch = predictions[:, 2]
                done_batch = jnp.full(reward_batch.shape, False, dtype=jnp.float32)
                #done_batch = (jnp.abs(jnp.sum(next_state_batch - state_batch, axis=1)) > 0.5).astype(jnp.float32)

                params, opt_state = q_learning_update(params, opt_state, state_batch, action_batch, reward_batch, next_state_batch, done_batch)
        
        # episode stopping: NOT IMPLEMENTED.

    epsilon = max(epsilon * epsilon_decay, 0.01)
    print(f"Episode {episode + 1}, Total Reward: {total_reward}")

    # outputs
    reward_history.append(total_reward)

    # stop training if average reward reaches requirement
    # Calculate the average reward over the last 'solved_window' episodes
    if episode >= solved_window:
        avg_reward = np.mean(reward_history[-solved_window:])
        print(f'Episode: {episode}, Average Reward: {avg_reward}')

        if avg_reward >= reward_threshold:
            print(f"R4C3Discrete-v0 solved in {episode} episodes!")
            break



env is reset!
(64, 3) (64, 3)
in model update
Episode 1, Total Reward: -14.245113476420832
env is reset!


In [None]:
import matplotlib.pyplot as plt
# Plot the historical rewards
plt.plot(reward_history)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Historical Rewards for CartPole-v1")
plt.show()

In [None]:
# plot training 
def plot_moving_average_reward(episode_rewards, window_size=100):
    cumsum_rewards = np.cumsum(episode_rewards)
    moving_avg_rewards = (cumsum_rewards[window_size:] - cumsum_rewards[:-window_size]) / window_size

    plt.plot(moving_avg_rewards)
    plt.xlabel('Episode')
    plt.ylabel('Moving Average Reward')
    plt.title('Moving Average Reward over Episodes')
    plt.show()

plot_moving_average_reward(reward_history)

In [None]:
# need a virtual display for rendering in docker
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()
from IPython import display as ipythondisplay

# Test the trained agent

print("\nTesting the trained agent...")
env = gym.make("CartPole-v1",render_mode='rgb_array').env
state, _ = env.reset()
state = jnp.array(state, dtype=jnp.float32)

total_reward = 0
done = False
pre_screen = env.render()
step_in_episode = 0

while not done:
    q_values = q_network.apply(params, jnp.expand_dims(jnp.array(state), axis=0))
    action = jnp.argmax(q_values).item()
    #action = agent.act(state)
    next_state, reward, done, _, _ = env.step(action)
    next_state = jnp.array(next_state, dtype=jnp.float32)
    screen = env.render()
    state = next_state
    total_reward += reward
    step_in_episode += 1

    plt.imshow(screen)
    ipythondisplay.clear_output(wait=True)
    ipythondisplay.display(plt.gcf())

    # check if the max_episode_steps are met. if so, terminate this episode
    if step_in_episode >= max_episode_steps:
        print(f"Agent reached max_episode_steps in test.")
        break

    ipythondisplay.clear_output(wait=True)
    
print(f"Total Reward: {total_reward}")

env.close()


In [None]:

print(pre_screen)