In [6]:
import copy
import os
import random

import numpy as np
import torch
from gym import make
from torch import nn
from torch.optim import Adam
from tqdm import tqdm

from train import set_seed, DQN, evaluate_policy, DeepQNetworkModel

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
INITIAL_STEPS = 1024
TRANSITIONS = 500_000
LEARNING_RATE = 5e-4
HID_DIM = 64

# Train

In [2]:
set_seed(SEED)

env = make("LunarLander-v2")
dqn = DQN(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n, hid_dim=HID_DIM)
eps = 0.1
state = env.reset()

for _ in range(INITIAL_STEPS):
    action = env.action_space.sample()

    next_state, reward, done, *_ = env.step(action)
    dqn.consume_transition((state, action, next_state, reward, done))

    state = next_state if not done else env.reset()

best_avg_rewards = -np.inf
# pbar = tqdm(total=TRANSITIONS)
for i in range(TRANSITIONS):
    # Epsilon-greedy policy
    if random.random() < eps:
        action = env.action_space.sample()
    else:
        action = dqn.act(state)

    next_state, reward, done, *_ = env.step(action)
    dqn.update((state, action, next_state, reward, done))

    state = next_state if not done else env.reset()

    # pbar.update(1)

    if (i + 1) % (TRANSITIONS // 100) == 0:
        rewards = evaluate_policy(dqn, 5)
        avg_reward = np.mean(rewards)
        # pbar.set_description(
        #     f"Best reward mean: {best_avg_rewards:.2f}, Reward mean: {avg_reward:.2f}, Reward std: {np.std(rewards):.2f}"
        # )
        print(f"Step: {i + 1}/{TRANSITIONS}, Best reward mean: {best_avg_rewards:.2f}, Reward mean: {avg_reward:.2f}, Reward std: {np.std(rewards):.2f}")
        if avg_reward > best_avg_rewards:
            best_avg_rewards = avg_reward
            dqn.save()

  deprecation(
  deprecation(


Step: 5000/500000, Best reward mean: -inf, Reward mean: -108.50, Reward std: 64.67
Step: 10000/500000, Best reward mean: -108.50, Reward mean: -186.47, Reward std: 53.22
Step: 15000/500000, Best reward mean: -108.50, Reward mean: -204.06, Reward std: 142.76
Step: 20000/500000, Best reward mean: -108.50, Reward mean: 3.61, Reward std: 136.35
Step: 25000/500000, Best reward mean: 3.61, Reward mean: -164.20, Reward std: 158.91
Step: 30000/500000, Best reward mean: 3.61, Reward mean: -163.38, Reward std: 66.36
Step: 35000/500000, Best reward mean: 3.61, Reward mean: -86.18, Reward std: 37.95
Step: 40000/500000, Best reward mean: 3.61, Reward mean: -43.23, Reward std: 78.89
Step: 45000/500000, Best reward mean: 3.61, Reward mean: -44.61, Reward std: 154.63
Step: 50000/500000, Best reward mean: 3.61, Reward mean: -98.33, Reward std: 138.40
Step: 55000/500000, Best reward mean: 3.61, Reward mean: -58.85, Reward std: 6.04
Step: 60000/500000, Best reward mean: 3.61, Reward mean: -205.78, Reward

# Inference

In [7]:
class Agent:
    def __init__(self):
        self.model = DeepQNetworkModel(8, 4, 64)
        weights = torch.load("agent.pth")
        self.model.load_state_dict(weights)
        self.model.to(DEVICE)
        self.model.eval()

    def act(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            action = np.argmax(self.model(state).cpu().numpy())
        return action

In [22]:
set_seed(SEED)

rewards = evaluate_policy(Agent(), 50)
print(np.mean(rewards))

  deprecation(
  deprecation(


241.78670184807845


# Video

In [21]:
import glob
import io
import base64
from gym.wrappers.monitoring import video_recorder
from IPython import display 

def show_video(env_name):
    mp4list = glob.glob('*.mp4')
    if len(mp4list) > 0:
        mp4 = f'{env_name}.mp4'
        video = io.open(mp4, 'r+b').read()
        encoded = base64.b64encode(video)
        display.display(display.HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))
    else:
        print("Could not find video")
        
def render_video_of_model(agent, env_name):
    env = make(env_name)
    vid = video_recorder.VideoRecorder(env, path=f"{env_name}.mp4")
    state = env.reset()
    done = False
    while not done:
        frame = env.render(mode='rgb_array')
        vid.capture_frame()
        
        action = agent.act(state)

        state, reward, done, _ = env.step(action)        
    env.close()


set_seed(SEED)
agent = Agent()
render_video_of_model(agent, 'LunarLander-v2')
show_video('LunarLander-v2')

  deprecation(
  deprecation(
  logger.deprecation(
  logger.deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
