In [13]:
import gymnasium
import numpy as np
import random
import tensorflow as tf

In [None]:
# Create the Seaquest Atari environment. render_mode="human" asks Gymnasium to
# draw the game on screen while it is being stepped.
ENV_ID = "ALE/Seaquest-v5"
env = gymnasium.make(ENV_ID, render_mode="human")

### Seeing the agent in action

First, let's take a moment to visualize how the agent performs in a completely untrained state, where it simply takes random actions. This baseline makes it clear just how much the agent improves with training.

##### Untrained Model

In [None]:

# Baseline rollout: drive the environment with uniformly random actions so the
# untrained behaviour can be compared with the trained agent later on.

# reset(seed=...) only seeds the environment itself; the action sampler has its
# own random state and must be seeded separately for the run to be repeatable.
env.action_space.seed(1)

try:
    observation, info = env.reset(seed=1)
    for _ in range(1000):
        action = env.action_space.sample()
        observation, reward, terminated, truncated, info = env.step(action)

        # Start a new episode as soon as the current one ends.
        if terminated or truncated:
            observation, info = env.reset()
finally:
    # Close the environment even if the loop raises or the kernel is interrupted.
    env.close()

##### Trained Model

Let's create a model and load the weights from our trained model. The weights can be saved with TensorFlow's `save_weights()` during training and restored later with `load_weights()`, which is what we will do now.

In [None]:
def create_model(input_shape, output_classes):
    """
    Build the convolutional network used by the agent.

    Architecture (every layer except the last uses a ReLU activation):
        - conv1: 32 filters, 8x8 kernel, stride 4
        - conv2: 64 filters, 4x4 kernel, stride 2
        - conv3: 64 filters, 3x3 kernel, stride 1
        - flatten
        - dense: 512 units
        - output: dense layer with `output_classes` units and no activation

    Args:
        input_shape: shape of a single input, forwarded to the first Conv2D
            layer (the original notes list an 84x84x4 input).
        output_classes: number of units in the final dense layer
            (the size of the action space).

    Returns:
        An uncompiled `tf.keras` Sequential model.
    """
    keras_layers = tf.keras.layers

    # The first convolution also declares the input shape of the network.
    stack = [
        keras_layers.Conv2D(32, (8, 8), strides=4, activation='relu', input_shape=input_shape),
    ]

    # Remaining convolutions, given as (filters, kernel size, stride).
    for filters, kernel, stride in ((64, (4, 4), 2), (64, (3, 3), 1)):
        stack.append(keras_layers.Conv2D(filters, kernel, strides=stride, activation='relu'))

    stack.append(keras_layers.Flatten())
    stack.append(keras_layers.Dense(512, activation='relu'))
    # Final layer has no activation: one raw output per class.
    stack.append(keras_layers.Dense(output_classes))

    return tf.keras.models.Sequential(stack)

In [None]:
# Rebuild the network with one output per environment action, then restore the
# weights that were saved during training.
# NOTE(review): the checkpoint name suggests the target network after episode 40 - confirm.
n_actions = env.action_space.n
model = create_model(input_shape=(84, 84, 1), output_classes=n_actions)
model.load_weights('./saved_models/double_dqn/sequest_target_action_value_network_ep40')

##### Visualising the trained agent

In [None]:


# NOTE(review): this cell is a disabled draft and will not work if simply
# uncommented:
#   - `env.close()` has already been called at the end of the untrained rollout
#     cell, so the environment must be recreated before stepping it again.
#   - `action = env.action_space.n` assigns the number of actions, not an
#     action. The action should come from `model` (presumably the argmax of
#     its outputs for the current observation, after the same preprocessing
#     used in training to reach the (84, 84, 1) input shape) - TODO implement.
#   - `env.close` on the last line is missing its parentheses, so it would
#     not actually close the environment.
#
# observation, info = env.reset(seed=1)
# for _ in range(1000):
#     action = env.action_space.n
#     next_state, reward, terminated, truncated, info = env.step(action)
    
#     if terminated or truncated:
#         observation, info = env.reset()
        
# env.close