In [None]:
import jdc
import gc

import numpy as np
import tensorflow as tf
import gym
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Classic-gym CartPole (reset() -> obs, step() -> 4-tuple); the training cells below read this global `env`.
env = gym.make("CartPole-v1")
env.reset()

## Fully-Connected Q Network


In [None]:
class FCQ:
    """Fully-connected Q-network: maps a state vector to one Q-value per action.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of a single state, e.g. ``(4,)`` for CartPole.
    layers : sequence of int
        Units of each Dense layer, in order. Every layer except the last uses
        ``activation_fc``; the last one is linear, so its width is the number of
        actions. ``(4, 512, 128, 2)`` builds exactly the network this class used
        to hard-code (previously this argument was ignored).
    activation_fc : callable
        Activation applied to every non-output layer.
    learning_rate : float
        RMSprop learning rate (default matches the previous hard-coded value).
    """

    def __init__(self, input_shape, layers, activation_fc=tf.nn.relu, learning_rate=0.0005) -> None:
        if len(layers) == 0:
            raise ValueError("layers must list at least the output layer width")
        self.activation_fc = activation_fc
        *hidden_units, n_outputs = layers
        self.model = tf.keras.Sequential(
            [tf.keras.Input(shape=input_shape)]
            + [tf.keras.layers.Dense(units, activation=activation_fc) for units in hidden_units]
            # Linear output head: raw Q-values, one per action.
            + [tf.keras.layers.Dense(n_outputs)]
        )
        self.optimizer = tf.keras.optimizers.RMSprop(learning_rate)
        self.loss = tf.keras.losses.MeanSquaredError()

    def fit(self, states, q_targets, masks):
        """Take one optimizer step on the MSE between Q(s, a_taken) and ``q_targets``.

        ``masks`` is a one-hot mask over actions; multiplying the Q-values by it
        and summing over the action axis keeps only the Q-value of the action
        actually taken, giving a ``(batch, 1)`` prediction to compare against
        ``q_targets``.

        Returns the scalar loss tensor so callers can log it if they want to
        (it used to be printed on every step, flooding the notebook output).
        """
        with tf.GradientTape() as tape:
            q_values = self.model(states)
            q_actions = tf.reduce_sum(q_values * masks, axis=1, keepdims=True)
            loss_value = self.loss(q_targets, q_actions)

        grads = tape.gradient(loss_value, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        return loss_value

    def predict(self, x):
        """Return the network's Q-values for a batch of states as a NumPy array.

        Calls the model directly rather than ``Model.predict``: ``predict`` sets up
        a batched input pipeline on every call, which dominates the cost when it
        is invoked once per environment step on a single state.
        """
        return self.model(x, training=False).numpy()


In [None]:
# Build a throwaway network only to print its layer stack and parameter counts.
FCQ(input_shape=(4,), layers=(4, 512, 128, 2)).model.summary()

In [None]:
class NFQ:
    """NFQ agent: holds the Q-network and the hyper-parameters used to train it.

    Only configuration lives here; ``train``, the policies and ``optimize`` are
    attached to the class by the ``%%add_to NFQ`` cells.
    """

    def __init__(self, fcq, epsilon=0.1, gamma=1, batch_size=1024, epochs=40) -> None:
        # Network queried for actions and fitted on collected experience.
        self.fcq = fcq
        # Exploration probability and discount factor.
        self.epsilon, self.gamma = epsilon, gamma
        # Experiences gathered before each fitting round, and gradient steps per round.
        self.batch_size, self.epochs = batch_size, epochs


In [None]:
%%add_to NFQ

def train(self, n_episodes, environment=None):
    """Collect experience with the epsilon-greedy policy and fit the Q-network.

    Experiences accumulate across episodes. Once ``self.batch_size`` of them are
    available, they are stacked column-wise and handed to ``self.optimize``. They
    are then discarded, so every fitting round sees a fresh batch.

    Parameters
    ----------
    n_episodes : int
        Number of episodes to run.
    environment : optional
        Environment using the classic gym API (``reset() -> obs``,
        ``step(a) -> (obs, reward, done, info)``). Defaults to the
        notebook-global ``env`` so existing ``agent.train(n)`` calls still work.

    Returns
    -------
    np.ndarray
        Sum of rewards of each episode, shape ``(n_episodes,)``.
    """
    if environment is None:
        environment = env
    experiences = []
    rewards = np.zeros(n_episodes)
    for i in tqdm(range(n_episodes)):
        state = environment.reset()
        while True:
            action = self.epsilon_greedy_policy(state)
            state_p, reward, done, info = environment.step(action)
            rewards[i] += reward
            # A time-limit cut-off is not a real failure: only genuine
            # terminations are flagged so the bootstrap target can be zeroed.
            is_truncated = bool(info.get('TimeLimit.truncated', False))
            is_failure = done and not is_truncated
            experiences.append((state, action, reward, state_p, float(is_failure)))

            if len(experiences) >= self.batch_size:
                # Transpose the list of experience tuples into per-field columns;
                # vstack turns each into a 2-D array ((N, obs_dim) or (N, 1)).
                self.optimize([np.vstack(column) for column in zip(*experiences)])
                experiences = []

            if done:
                break

            state = state_p
    return rewards


In [None]:
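# Reference training loop from another NFQ implementation, kept for comparison only (not executed).
# for episode in range(1, max_episodes + 1):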
#     episode_start = time.time()
    
#     state, is_terminal = env.reset(), False
#     self.episode_reward.append(0.0)
#     self.episode_timestep.append(0.0)
#     self.episode_exploration.append(0.0)

#     for step in count():
#         action = self.training_strategy.select_action(self.online_model, state)
#         new_state, reward, is_terminal, info = env.step(action)
#         is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated']
#         is_failure = is_terminal and not is_truncated
#         experience = (state, action, reward, new_state, float(is_failure))

#         self.experiences.append(experience)
        
#         if len(self.experiences) >= self.batch_size:
#             experiences = np.array(self.experiences)
#             batches = [np.vstack(sars) for sars in experiences.T]
#             experiences = self.online_model.load(batches)
#             for _ in range(self.epochs):
#                 self.optimize_model(experiences)
#             self.experiences.clear()
        
#         if is_terminal:
#             gc.collect()
#             break

In [None]:
%%add_to NFQ

def epsilon_greedy_policy(self, state):
    """Choose an action epsilon-greedily from the network's Q-values.

    With probability ``self.epsilon`` a uniformly random action index is
    returned; otherwise the action with the highest predicted Q-value.
    """
    batch_of_one = np.expand_dims(state, axis=0)
    action_values = self.fcq.predict(batch_of_one)[0]
    n_actions = len(action_values)
    # rand() is uniform on [0, 1), so this is True with probability epsilon.
    explore = np.random.rand() < self.epsilon
    return np.random.randint(n_actions) if explore else np.argmax(action_values)

def greedy_policy(self, state):
    """Pick the action whose predicted Q-value is largest, with no exploration."""
    # The network takes a batch, so the state gets a leading axis of size 1;
    # argmax over the flattened (1, n_actions) output is then the action index.
    return np.argmax(self.fcq.predict(np.expand_dims(state, axis=0)))


In [None]:
%%add_to NFQ

def optimize(self, batches):
    """Fit the Q-network to bootstrapped targets on one batch, ``self.epochs`` times.

    Parameters
    ----------
    batches : sequence of 5 arrays
        ``(states, actions, rewards, states_p, is_terminals)`` as stacked by
        ``train``: one row per experience.

    Each epoch recomputes the targets
    ``r + gamma * max_a' Q(s', a') * (1 - is_terminal)`` with the *current*
    network, then takes one gradient step through ``self.fcq.fit``.
    """
    states, actions, rewards, states_p, is_terminals = batches

    # Loop-invariant inputs: cast once to the model's float32 instead of mixing
    # float64 NumPy arrays into float32 tensor maths on every epoch.
    states = tf.cast(states, tf.float32)
    states_p = tf.cast(states_p, tf.float32)
    rewards = tf.cast(rewards, tf.float32)
    not_terminal = 1.0 - tf.cast(is_terminals, tf.float32)
    action_ids = np.asarray(actions).ravel().astype(np.int64)

    for _ in range(self.epochs):
        q_states_p = self.fcq.model(states_p)
        # Targets are treated as constants: no gradient flows through max Q(s', .).
        max_q_states_p = tf.stop_gradient(tf.reduce_max(q_states_p, axis=1, keepdims=True))
        q_targets = rewards + self.gamma * max_q_states_p * not_terminal
        # One-hot mask selecting Q(s, a_taken); its width follows the network's
        # output size instead of a hard-coded 2.
        masks = tf.one_hot(action_ids, q_states_p.shape[-1])
        self.fcq.fit(states, q_targets, masks)

In [None]:
# Seed NumPy (epsilon-greedy exploration) and TensorFlow (weight initialisation)
# so reruns are comparable.
# NOTE(review): the gym environment's own RNG (initial states) is not seeded here.
SEED = 42
N_EPISODES = 1000
np.random.seed(SEED)
tf.random.set_seed(SEED)

fcq = FCQ((4,), (4, 512, 128, 2))
agent = NFQ(fcq, epsilon=0.2)
rewards = agent.train(N_EPISODES)

In [None]:
# Per-episode sum of rewards collected during training (the array returned by train).
fig, ax = plt.subplots()
ax.plot(rewards)
ax.set(title="NFQ on CartPole-v1: reward per training episode",
       xlabel="Episode", ylabel="Total reward")
plt.show()