In [23]:
from typing import Dict, Optional, Tuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## Replay buffer

Reference: adapted from the replay buffer in [OpenAI Spinning Up's SAC implementation](https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py#L10).

In [24]:
class ReplayBuffer:
    """A simple FIFO experience replay buffer backed by fixed-size NumPy arrays.

    Transitions are written into ring buffers of length ``size``; once the
    buffer is full, each new transition overwrites the oldest one.

    Args:
        obs_dim: length of a (flat) observation vector.
        act_dim: number of entries stored per action (1 for a single discrete
            action index, as DQN uses).
        size: maximum number of transitions kept.
        batch_size: number of transitions returned by ``sample_batch``.
    """

    def __init__(self, obs_dim: int, act_dim: int, size: int, batch_size: int = 32):
        self.obs1_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.obs2_buf = np.zeros([size, obs_dim], dtype=np.float32)
        self.acts_buf = np.zeros([size, act_dim], dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # ptr: next slot to write; size: number of valid transitions stored
        self.ptr, self.size, self.max_size, self.batch_size = 0, 0, size, batch_size

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return self.size

    def store(self, obs: np.ndarray, act: int, rew: float,
              next_obs: np.ndarray, done: bool):
        """Append one transition, overwriting the oldest once the buffer is full.

        Args:
            obs: observation before the action.
            act: action taken; a scalar index is broadcast into the
                ``act_dim``-wide row, an array must have ``act_dim`` entries.
            rew: reward received.
            next_obs: observation after the action.
            done: whether the episode ended with this transition.
        """
        self.obs1_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.acts_buf[self.ptr] = act
        self.rews_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self) -> Dict[str, np.ndarray]:
        """Sample ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Dict with keys ``obs1``, ``obs2``, ``acts``, ``rews`` and ``done``,
            each holding ``batch_size`` rows.

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored
                (sampling is without replacement).
        """
        if self.size < self.batch_size:
            raise ValueError(
                "cannot sample %d transitions: buffer holds only %d"
                % (self.batch_size, self.size))
        idxs = np.random.choice(self.size, size=self.batch_size, replace=False)
        return dict(obs1=self.obs1_buf[idxs],
                    obs2=self.obs2_buf[idxs],
                    acts=self.acts_buf[idxs],
                    rews=self.rews_buf[idxs],
                    done=self.done_buf[idxs])

## Device

In [25]:
# Run on the first CUDA GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Environment (CartPole)

In [26]:
# Gym id of the CartPole balancing task this notebook trains on.
env_id = "CartPole-v0"
env = gym.make(id=env_id)

## Network

In [27]:
class Network(nn.Module):
    """Fully connected Q-network: observation -> one Q-value per discrete action.

    Two hidden ReLU layers of width ``hidden_dim``.

    Args:
        in_dim: observation size; defaults to ``env.observation_space.shape[0]``
            of the notebook-level ``env``.
        out_dim: number of actions; defaults to ``env.action_space.n``.
        hidden_dim: width of both hidden layers (default 128).
    """

    def __init__(self, in_dim: Optional[int] = None, out_dim: Optional[int] = None,
                 hidden_dim: int = 128):
        # Zero-argument super(): the class is Network, not DQN.
        super().__init__()
        if in_dim is None:
            in_dim = env.observation_space.shape[0]
        if out_dim is None:
            out_dim = env.action_space.n

        self.layers = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values; the last dimension of the result has size ``out_dim``."""
        return self.layers(x)

## DQN Agent

In [29]:
class DQNAgent:
    """DQN agent interacting with a Gym environment.

    Attributes:
        env (gym.Env): environment the agent acts in
        memory (ReplayBuffer): replay memory that training batches are sampled from
        dqn (nn.Module): online Q-network; selects actions and is trained
        dqn_target (nn.Module): target Q-network used to build TD targets
        optimizer (Optimizer): Adam optimizer over the ``dqn`` parameters
        epsilon (float): exploration probability of the epsilon-greedy policy
        update_period (int): number of gradient updates between target syncs
        gamma (float): discount factor
        is_test (bool): when True, ``select_action`` always acts greedily
        cnt_update (int): number of gradient updates performed so far
    """

    def __init__(self, env: gym.Env, memory: ReplayBuffer, update_period: int,
                 epsilon: float = 1.0, gamma: float = 0.99, lr: float = 1e-4,
                 is_test: bool = False):
        """Initialization.

        Args:
            env (gym.Env): OpenAI Gym environment
            memory (ReplayBuffer): replay memory to sample training batches from
            update_period (int): copy dqn into dqn_target every this many updates
            epsilon (float): initial exploration probability
            gamma (float): discount factor
            lr (float): Adam learning rate
            is_test (bool): disable exploration (always pick the greedy action)
        """
        self.env = env
        self.memory = memory
        self.epsilon = epsilon
        self.update_period = update_period
        self.gamma = gamma
        self.is_test = is_test

        # networks: dqn, dqn_target (the target starts as an exact copy)
        self.dqn = Network().to(device)
        self.dqn_target = Network().to(device)
        self.dqn_target.load_state_dict(self.dqn.state_dict())

        # optimizer: use the constructor's `lr` so the parameter takes effect
        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)

        # counter
        self.cnt_update = 0

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action with an epsilon-greedy policy.

        With probability ``epsilon`` (and only when not ``is_test``) a random
        action is sampled from the env's action space; otherwise the argmax of
        the online network's Q-values is returned as a 0-d NumPy array.
        """
        if not self.is_test and self.epsilon > np.random.random():
            selected_action = self.env.action_space.sample()
        else:
            # inference only: no autograd graph is needed for action selection
            with torch.no_grad():
                state_t = torch.FloatTensor(state).to(device)
                selected_action = self.dqn(state_t).argmax()
            selected_action = selected_action.cpu().numpy()
        return selected_action

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool]:
        """Take an action and return (next_state, reward, done) from the env."""
        next_state, reward, done, _ = self.env.step(action)

        return next_state, reward, done

    def compute_dqn_loss(self, samples: Dict[str, np.ndarray]) -> torch.Tensor:
        """Return the DQN smooth-L1 (Huber) TD loss for a sampled batch.

        Args:
            samples: batch dict as produced by ``ReplayBuffer.sample_batch``
                (keys ``obs1``, ``obs2``, ``acts``, ``rews``, ``done``).
                Assumes one discrete action index per row (``act_dim == 1``).
        """
        state = torch.as_tensor(samples["obs1"], dtype=torch.float32, device=device)
        next_state = torch.as_tensor(samples["obs2"], dtype=torch.float32, device=device)
        # actions are stored as float32; gather() needs int64 indices shaped (batch, 1)
        action = torch.as_tensor(samples["acts"], device=device).long().reshape(-1, 1)
        # reshape to (batch, 1) so they align with curr_q_value instead of
        # broadcasting to (batch, batch)
        reward = torch.as_tensor(samples["rews"], dtype=torch.float32, device=device).reshape(-1, 1)
        done = torch.as_tensor(samples["done"], dtype=torch.float32, device=device).reshape(-1, 1)

        curr_q_value = self.dqn(state).gather(1, action)

        # G_t   = r + gamma * max_a' Q_target(s_{t+1}, a')  if state != Terminal
        #       = r                                         otherwise
        with torch.no_grad():
            next_q_value = self.dqn_target(next_state).max(dim=1, keepdim=True)[0]
            mask = 1 - done
            target = reward + self.gamma * next_q_value * mask

        # calculate dq loss
        dq_loss = F.smooth_l1_loss(curr_q_value, target)

        return dq_loss

    def update_model(self) -> float:
        """Run one gradient step on a sampled batch and return the loss value.

        Every ``update_period`` calls, the target network is synced with the
        online network.
        """
        self.cnt_update += 1
        samples = self.memory.sample_batch()

        loss = self.compute_dqn_loss(samples)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update target network
        if self.cnt_update % self.update_period == 0:
            self.update_target_net()

        return loss.item()

    def update_target_net(self):
        """Copy the online network's weights into the target network."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

## Training

In [30]:
def plot(frame_idx, rewards, losses):
    """Clear the cell output and redraw the training dashboard.

    Left panel: episode rewards, titled with the current frame and the mean of
    the last 10 episodes. Middle panel: recorded losses.

    NOTE(review): uses ``plt`` (matplotlib.pyplot) and ``clear_output``
    (IPython.display); neither is imported in this notebook's import cell —
    confirm they are imported before the training loop calls this.
    """
    clear_output(True)
    plt.figure(figsize=(20, 5))

    # episode rewards, titled with a 10-episode running mean
    plt.subplot(131)
    last_ten_mean = np.mean(rewards[-10:])
    plt.title('frame %s. reward: %s' % (frame_idx, last_ten_mean))
    plt.plot(rewards)

    # loss history
    plt.subplot(132)
    plt.title('loss')
    plt.plot(losses)

    plt.show()

In [None]:
num_frames = 10000
batch_size = 32
gamma      = 0.99
memory_size = 1000           # replay capacity (transitions)
target_update_period = 100   # gradient updates between target-network syncs

# exponential epsilon decay from epsilon_start towards epsilon_final
epsilon_start, epsilon_final, epsilon_decay = 1.0, 0.01, 500


def epsilon_by_frame(frame_idx):
    """Exploration probability for a given (1-based) frame index."""
    return epsilon_final + (epsilon_start - epsilon_final) * np.exp(-frame_idx / epsilon_decay)


obs_dim = env.observation_space.shape[0]
# DQN stores a single discrete action index per transition -> act_dim = 1
memory = ReplayBuffer(obs_dim, 1, memory_size, batch_size)
agent = DQNAgent(env, memory, target_update_period, gamma=gamma)

losses = []
all_rewards = []
episode_reward = 0

# NOTE: follows the pre-0.26 gym API used by DQNAgent.step
# (reset() returns the observation, step() returns a 4-tuple).
state = env.reset()
for frame_idx in range(1, num_frames + 1):
    agent.epsilon = epsilon_by_frame(frame_idx)
    action = agent.select_action(state)

    next_state, reward, done = agent.step(action)
    memory.store(state, action, reward, next_state, done)

    state = next_state
    episode_reward += reward

    if done:
        state = env.reset()
        all_rewards.append(episode_reward)
        episode_reward = 0

    # start learning once there are more transitions than one batch
    if memory.size > batch_size:
        losses.append(agent.update_model())

    if frame_idx % 200 == 0:
        plot(frame_idx, all_rewards, losses)