In [None]:
# %%capture
# !apt-get update
# !apt-get install -y xvfb python-opengl ffmpeg
# !pip install pyglet==1.3.2
# !pip install gym pyvirtualdisplay
# !pip install torch
# !pip install xvfbwrapper
# pip install moviepy
import gym
from gym.wrappers.record_video import RecordVideo
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import math
import time
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display
# Start an invisible (visible=0) 1400x900 virtual display through pyvirtualdisplay.
# NOTE(review): presumably required so frames can be rendered on a headless
# machine (e.g. a hosted notebook) -- confirm against the target runtime.
display = Display(visible=0, size=(1400, 900))
display.start()

Construimos el agente DQN y algunas funciones auxiliares

In [None]:
# Create the CartPole environment and read the problem dimensions from its
# observation space (length of the state vector) and action space (number of
# discrete actions).
env = gym.make('CartPole-v0')
num_features = env.observation_space.shape[0]
num_actions = env.action_space.n
print(
    'Number of state features: {}'.format(num_features),
    'Number of possible actions: {}'.format(num_actions),
    sep='\n',
)

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network with two hidden layers of 32 ReLU units.

    Maps state vectors to one Q-value per action.
    """

    def __init__(self, num_inputs, num_actions):
        """Create the three linear layers.

        Args:
            num_inputs: Number of features in a state vector.
            num_actions: Number of outputs (one Q-value per action).
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features=num_inputs, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=num_actions)

    def forward(self, states):
        """Compute the Q-values for `states`.

        Args:
            states: Tensor whose last dimension has `num_inputs` features.

        Returns:
            Tensor with `num_actions` values along the last dimension.
        """
        hidden = torch.relu(self.fc1(states))
        hidden = torch.relu(self.fc2(hidden))
        return self.out(hidden)

# Online ("main") network: the one optimised below and used to choose actions.
main_nn = DQN(num_features, num_actions).to(device)
# Target network: evaluated when building the bootstrap target in train_step.
target_nn = DQN(num_features, num_actions).to(device)
# Start the target network as an exact copy of the online network. Without this
# the two networks hold unrelated random weights, so the TD targets are noise
# until the first periodic weight copy in the training loop.
target_nn.load_state_dict(main_nn.state_dict())

# Only the online network's parameters are optimised.
optimizer = torch.optim.Adam(main_nn.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

In [None]:
class ReplayBuffer(object):
    """Experience replay buffer that samples uniformly."""

    def __init__(self, size, device="cpu"):
        """Initializes the buffer.

        Args:
            size: Maximum number of transitions kept. Once the buffer is full,
                appending a new transition discards the oldest one.
            device: Device on which the tensors returned by `sample` are created.
        """
        self.buffer = deque(maxlen=size)
        self.device = device

    def add(self, state, action, reward, next_state, done):
        """Stores one (state, action, reward, next_state, done) transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def __len__(self):
        """Returns the number of transitions currently stored."""
        return len(self.buffer)

    def sample(self, num_samples):
        """Draws `num_samples` transitions uniformly at random, with replacement.

        Args:
            num_samples: Number of transitions to draw.

        Returns:
            A tuple `(states, actions, rewards, next_states, dones)` of tensors
            on `self.device`, each with `num_samples` entries along the first
            dimension. `actions` is int64, `rewards` and `dones` are float32;
            `states` and `next_states` keep the dtype of the stored arrays.

        Raises:
            ValueError: If samples are requested from an empty buffer.
        """
        if num_samples > 0 and not self.buffer:
            raise ValueError("Cannot sample from an empty replay buffer.")
        idx = np.random.choice(len(self.buffer), num_samples)
        batch = [self.buffer[i] for i in idx]
        # np.asarray (rather than np.array(..., copy=False)) converts without an
        # unnecessary copy and also works on NumPy 2, where copy=False raises
        # whenever a copy cannot be avoided (e.g. for plain Python scalars).
        states = np.array([np.asarray(transition[0]) for transition in batch])
        # Explicit int64 so the actions can be used as indices / one-hot input
        # regardless of the platform's default integer width.
        actions = np.array([transition[1] for transition in batch], dtype=np.int64)
        rewards = np.array([transition[2] for transition in batch], dtype=np.float32)
        next_states = np.array([np.asarray(transition[3]) for transition in batch])
        dones = np.array([transition[4] for transition in batch], dtype=np.float32)
        return (
            torch.as_tensor(states, device=self.device),
            torch.as_tensor(actions, device=self.device),
            torch.as_tensor(rewards, device=self.device),
            torch.as_tensor(next_states, device=self.device),
            torch.as_tensor(dones, device=self.device),
        )

In [None]:
def select_epsilon_greedy_action(state, epsilon):
    """Take random action with probability epsilon, else take best action.

    Args:
        state: Input tensor for `main_nn`. The callers in this notebook pass a
            float32 tensor of shape (1, num_features) already moved to `device`.
        epsilon: Probability of taking a random action instead of the greedy one.

    Returns:
        An action index: a sample from `env.action_space` when exploring,
        otherwise the argmax over the Q-values predicted by `main_nn`.
    """
    if np.random.uniform() < epsilon:
        return env.action_space.sample()
    # Action selection is pure inference, so skip autograd bookkeeping instead
    # of building a graph and stripping it afterwards with `.data`.
    with torch.no_grad():
        qs = main_nn(state)
    return np.argmax(qs.cpu().numpy())

In [None]:
def train_step(states, actions, rewards, next_states, dones):
    """Perform one training iteration on a batch sampled from the replay buffer.

    Builds the one-step TD target `rewards + (1 - dones) * discount * max_a Q_target`
    from `target_nn`, regresses the Q-value `main_nn` assigns to the taken action
    towards it with `loss_fn`, and applies one `optimizer` step.

    Args:
        states: Batch of states fed to `main_nn`.
        actions: Batch of integer action indices, one per state.
        rewards: Batch of rewards.
        next_states: Batch of successor states fed to `target_nn`.
        dones: Batch of flags (1.0 where no bootstrapping should happen).

    Returns:
        The loss tensor for this batch.
    """
    # The target is a fixed regression label: compute it without tracking
    # gradients so no graph is built through the target network.
    with torch.no_grad():
        max_next_qs = target_nn(next_states).max(-1).values
        target = rewards + (1.0 - dones) * discount * max_next_qs
    qs = main_nn(states)
    # Select Q(s, a) for the action actually taken. The cast to int64 keeps this
    # working when the sampled actions arrive with a narrower integer dtype.
    taken_qs = qs.gather(-1, actions.long().unsqueeze(-1)).squeeze(-1)
    loss = loss_fn(taken_qs, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

Corremos el algoritmo DQN y observamos cómo aprende el algoritmo

In [None]:
# Hyperparameters.
num_episodes = 1000
epsilon = 1.0
batch_size = 32
discount = 0.99
buffer = ReplayBuffer(100000, device=device)
cur_frame = 0

# Start training: take one step in the environment, then train on one batch.
# A bounded deque keeps only the 100 most recent episode returns.
last_100_ep_rewards = deque(maxlen=100)
for episode in range(num_episodes+1):
    state = env.reset()[0].astype(np.float32)
    ep_reward, done = 0, False
    while not done:
        state_in = torch.from_numpy(np.expand_dims(state, axis=0)).to(device)
        action = select_epsilon_greedy_action(state_in, epsilon)
        # The step API used here (reset() returns a tuple) yields
        # (obs, reward, terminated, truncated, info). Slicing it with [:4] bound
        # `done` to `terminated` only, so a truncated episode (e.g. one that hit
        # a time limit) was never ended. Unpack all five values instead.
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_state = next_state.astype(np.float32)
        ep_reward += reward
        # Save the transition for experience replay. Only `terminated` is stored
        # as the done flag: a truncated episode did not reach a terminal state,
        # so its target should still bootstrap from next_state.
        buffer.add(state, action, reward, next_state, terminated)
        state = next_state
        cur_frame += 1
        # Copy the weights from main_nn to target_nn every 2000 frames.
        if cur_frame % 2000 == 0:
            target_nn.load_state_dict(main_nn.state_dict())

        # Train the neural network once the buffer holds more than one batch.
        if len(buffer) > batch_size:
            states, actions, rewards, next_states, dones = buffer.sample(batch_size)
            loss = train_step(states, actions, rewards, next_states, dones)

    # Linear epsilon decay: 0.001 per episode during the first 950 episodes.
    if episode < 950:
        epsilon -= 0.001

    last_100_ep_rewards.append(ep_reward)

    if episode % 50 == 0:
        print(f'Episode {episode}/{num_episodes}. Epsilon: {epsilon:.3f}.'
              f' Reward in last 100 episodes: {np.mean(last_100_ep_rewards):.2f}')

env.close()

Mostrar el resultado del agente DQN entrenado en el entorno CartPole

In [None]:
def show_video():
    """Show the first recorded MP4 found in the `video/` directory inline.

    Looks for files matching 'video/*.mp4'. If at least one exists, the first
    match is read, base64-encoded and embedded in an HTML <video> element that
    is displayed in the notebook. Otherwise "Video not found" is printed.
    This function only displays an existing recording; it does not record.
    """
    mp4list = glob.glob('video/*.mp4')
    if not mp4list:
        print("Video not found")
        return
    # Read-only binary mode inside a context manager: the file is never
    # modified, and the handle is closed even if reading fails.
    with open(mp4list[0], 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))

def wrap_env(env):
    """Return `env` wrapped in RecordVideo, writing to './video'.

    The episode trigger returns True for every episode number, so every
    episode is selected for recording.
    """
    def _record_every_episode(episode_number):
        return True

    return RecordVideo(env, './video', episode_trigger=_record_every_episode)

In [None]:
# Play one episode with the trained agent (almost greedy, epsilon=0.01), record
# it and display the resulting video.
# render_mode='rgb_array' is requested so that the RecordVideo wrapper has
# frames to capture. NOTE(review): RecordVideo is expected to grab a frame on
# every step, so no explicit env.render() call is made here -- confirm against
# the installed gym version.
env = wrap_env(gym.make('CartPole-v0', render_mode='rgb_array'))
state = env.reset()[0]
done = False
ep_rew = 0
while not done:
    state = state.astype(np.float32)
    state = torch.from_numpy(np.expand_dims(state, axis=0)).to(device)
    action = select_epsilon_greedy_action(state, epsilon=0.01)
    # env.step returns (obs, reward, terminated, truncated, info). Indexing the
    # result with [0] unpacked the observation vector itself into
    # state/reward/done/info, so unpack the full tuple instead.
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    ep_rew += reward
print('Return on this episode: {}'.format(ep_rew))
env.close()
show_video()