In [None]:
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install --upgrade setuptools 2>&1
!pip install ez_setup > /dev/null 2>&1
!pip install gym[atari] > /dev/null 2>&1
!pip install wandb

Requirement already up-to-date: setuptools in /usr/local/lib/python3.7/dist-packages (56.1.0)


Imports

In [None]:
import torch
import math
import glob
import shutil
import io
import base64
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import gym
import os
import numpy as np
import collections
import cv2

from tqdm import trange
from gym import logger as gymlogger
from gym.wrappers import Monitor
from torch.nn.utils import clip_grad_value_
from collections import namedtuple, deque
from copy import deepcopy
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

#### Utilities for recording and showing Atari games

In [None]:
"""
Helpers to record a gym environment to video and show the recording inline.
Usage: wrap the environment first, e.g. ``env = wrap_env(env)``.
"""
# Invisible virtual X display (Xvfb) so that env.render() works on a headless machine.
# NOTE: the global name `display` shadows IPython's display(); this notebook calls
# `ipythondisplay.display` instead.
display = Display(size=(1400, 900), visible=0)
display.start()

def show_video(pattern='video/*.mp4'):
    """
    Embed the first recorded episode video matching ``pattern`` in the notebook output.

    :param pattern: glob pattern for the recorded mp4 files (defaults to the
                    ``./video`` directory that ``wrap_env`` records into).
    Prints "Could not find video" when nothing matches.
    """
    # Sort so the choice is deterministic (glob order is filesystem-dependent).
    mp4list = sorted(glob.glob(pattern))
    if not mp4list:
        print("Could not find video")
        return

    # Read-only binary mode; `with` guarantees the handle is closed.
    with open(mp4list[0], 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                    loop controls style="height: 400px;">
                    <source src="data:video/mp4;base64,{0}" type="video/mp4" />
                    </video>'''.format(encoded.decode('ascii'))))
    

def wrap_env(enviroment):
    """
    Wrap ``enviroment`` in a gym ``Monitor`` that records episodes into ``./video``.

    ``force=True`` makes the monitor clear previous recordings in that directory.
    """
    return Monitor(enviroment, './video', force=True)

In [None]:
def play_breakout_and_show_video(model, enviroment, action_wrapper=None):
    """
    Play one episode greedily with ``model`` while recording it, then show the video.

    :param model: object with ``choose_action(state) -> int`` (e.g. an ``ActionSelector``).
    :param enviroment: environment to play in; it is wrapped with ``wrap_env`` for recording.
    :param action_wrapper: optional object with ``action(a)`` that maps the model's action
                           index to an environment action. Pass the same wrapper used in
                           training (e.g. ``BreakoutFireDropActionWrapper()``); with ``None``
                           the model's action is sent to the env unchanged.
    :return: total (undiscounted) reward of the episode.
    """
    enviroment = wrap_env(enviroment)
    total = 0
    try:
        state = enviroment.reset()
        done = False
        while not done:
            enviroment.render()
            action = model.choose_action(state)
            if action_wrapper is not None:
                action = action_wrapper.action(action)
            state, reward, done, _ = enviroment.step(action)
            total += reward
    finally:
        # Closing the monitor finalizes the video file; do it even if playback fails.
        enviroment.close()
    show_video()
    return total

In [None]:
import wandb

# Human-readable label for this experiment in the W&B dashboard.
run_name = 'Dead lives=5, [Huber loss], buffer=80k, lr=1e-4'
wandb.init(
    project='Atari RL',
    entity='danielto1404',
    name=run_name,
)

[34m[1mwandb[0m: Currently logged in as: [33mdanielto1404[0m (use `wandb login --relogin` to force relogin)


#### Selector interface for choosing actions

In [None]:
class ActionSelector:
    """
    Greedy (argmax-Q) action selection for a trained Q-network, e.g. for evaluation or playback.

    :param model: callable mapping a state tensor to Q-values.
    :param atari_mode: if True, a leading batch dimension is added to every state
                       before the forward pass.
    :param device: device to move state tensors to; ``None`` leaves them on the
                   device ``torch.tensor`` creates them on.
    """

    def __init__(self, model, atari_mode=False, device=None):
        super(ActionSelector, self).__init__()
        self.model = model
        self.device = device
        self.atari_mode = atari_mode

    @torch.no_grad()
    def choose_action(self, state):
        """
        :param state: a single state (array-like accepted by ``torch.tensor``).
        :return: index of the action with the highest Q-value, as a Python ``int``.
        """
        tensor_state = torch.tensor(state)
        if self.device is not None:
            tensor_state = tensor_state.to(self.device)

        if self.atari_mode:
            tensor_state = tensor_state.unsqueeze(0)

        q_values = self.model(tensor_state)
        # argmax over the flattened output; with one state this is the best action index.
        return torch.argmax(q_values).item()

def load_selector(path, atari_mode=True) -> ActionSelector:
    """
    Load a whole pickled model (as written by ``torch.save(model, path)``) onto the CPU
    and wrap it in an ``ActionSelector``.

    NOTE(security): ``torch.load`` unpickles arbitrary objects; only load trusted files.
    """
    import inspect

    load_kwargs = {'map_location': torch.device('cpu')}
    # torch >= 2.6 defaults to weights_only=True, which refuses whole pickled nn.Modules;
    # older torch versions do not know the keyword at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    return ActionSelector(model=torch.load(path, **load_kwargs),
                          atari_mode=atari_mode)

#### Epsilon-greedy strategy class

In [None]:
class EpsilonStrategy:
    """
    Multiplicatively decaying exploration probability for epsilon-greedy action selection.

    The current probability is the plain attribute ``eps``. (A method of the same name
    could never be called, because the instance attribute set in ``__init__`` hides it.)

    :param start: initial epsilon.
    :param decay: factor ``eps`` is multiplied by on every ``decrease()``.
    :param min_eps: lower bound ``eps`` never drops below.
    """

    def __init__(self, start=1, decay=.999985, min_eps=0.02):
        self.eps = start
        self.decay = decay
        self.start = start
        self.min_eps = min_eps

    def decrease(self):
        """Apply one decay step, clamped at ``min_eps``."""
        self.eps = max(self.eps * self.decay, self.min_eps)

    def reset(self):
        """Restore ``eps`` to its starting value."""
        self.eps = self.start

    def check_random_prob(self):
        """Return True with probability ``eps``, i.e. when a random action should be taken."""
        return np.random.random() < self.eps

#### Experience buffer class

In [None]:
# One environment step as stored in the replay buffer.
Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'terminal'))

class ExperienceBuffer:
    def __init__(self, capacity=10_000, batch_size=32, start_sample_from=10_000):
        capacity = int(capacity)
        batch_size = int(batch_size)
        start_sample_from = int(start_sample_from)

        if batch_size > capacity:
            raise AssertionError('random sample size should be <= size')

        if batch_size > start_sample_from:
            raise AssertionError('start sample from should be >= batch_size')

        self.buffer = deque(maxlen=capacity)
        self.capacity = capacity
        self.batch_size = batch_size
        self.batch_indices = np.arange(self.batch_size)
        self.start_sample_from = start_sample_from

    def __len__(self):
        return len(self.buffer)

    def sample_batch(self):
        if not self.is_ready_for_sample():
            raise AssertionError('Buffer have is not ready for sample')

        indices = np.random.choice(len(self.buffer), self.batch_size, replace=False)
        states, next_states, actions, rewards, terminals = zip(*[self.buffer[i] for i in indices])
        return (np.array(states),
                np.array(next_states),
                np.array(actions, dtype=np.int64),
                np.array(rewards, dtype=np.float32),
                np.array(terminals, dtype=np.uint8))

    def store_transition(self, transition: Transition):
        self.buffer.append(transition)

    def is_ready_for_sample(self):
        return len(self.buffer) >= self.start_sample_from

#### Environment wrappers

In [None]:
class FireResetEnv(gym.Wrapper):
    """
    Presses FIRE (action 1) to launch the ball on reset and after every lost life,
    and can optionally report a lost life as the end of the episode.

    :param terminal_on_life_loss: if True, every step taken while fewer than ``lives``
                                  lives remain is reported as ``done``.
    :param lives: number of lives a fresh game starts with (assumed to match the game).
    """

    def __init__(self, env=None, terminal_on_life_loss=False, lives=5):
        super(FireResetEnv, self).__init__(env)
        self.terminal_on_life_loss = terminal_on_life_loss
        self.FIRE_ACTION = 1
        self.lives = lives
        self.initial_lives = lives
        assert env.unwrapped.get_action_meanings()[self.FIRE_ACTION] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action):
        """
        Step the wrapped env. When ``info['ale.lives']`` changes, also press FIRE unless
        the game has just ended (stepping a finished episode is invalid).
        """
        state, reward, done, info = self.env.step(action)
        lives = info.get('ale.lives', self.initial_lives)
        if lives != self.lives:
            self.lives = lives
            if not done:
                # Relaunch the ball; the FIRE step's own observation and reward are discarded.
                self.env.step(self.FIRE_ACTION)

        life_lost = self.terminal_on_life_loss and lives != self.initial_lives
        return state, reward, done or life_lost, info

    def reset(self):
        """
        Reset the wrapped env and press FIRE, then take 0-5 steps in one randomly chosen
        direction (action 2 or 3) to diversify the start state.

        If the episode ends during this warm-up, the whole procedure starts over, so the
        returned observation always belongs to a live episode.
        """
        while True:
            self.env.reset()
            # A new game starts with full lives (assumed == initial_lives); resync so that
            # step() does not mistake the new game for a life change.
            self.lives = self.initial_lives
            obs, _, done, _ = self.env.step(self.FIRE_ACTION)
            if done:
                continue

            way = np.random.choice([2, 3])
            for _ in range(np.random.randint(6)):
                obs, _, done, _ = self.step(way)
                if done:
                    break
            if not done:
                return obs


class MaxAndSkipEnv(gym.Wrapper):
    """
    Repeats each action ``skip`` times (stopping early on ``done``), sums the rewards,
    and returns the pixel-wise maximum over the last ``use_for_pool`` raw frames.
    """

    def __init__(self, env=None, skip=4, use_for_pool=2):
        super(MaxAndSkipEnv, self).__init__(env)
        self._skip = skip
        # Rolling window of the most recent raw frames that get max-pooled.
        self._obs_buffer = collections.deque(maxlen=use_for_pool)

    def step(self, action):
        total_reward, done, info = 0.0, None, None
        for _ in range(self._skip):
            frame, reward, done, info = self.env.step(action)
            self._obs_buffer.append(frame)
            total_reward += reward
            if done:
                break
        pooled = np.stack(self._obs_buffer).max(axis=0)
        return pooled, total_reward, done, info

    def reset(self):
        self._obs_buffer.clear()
        first_obs = self.env.reset()
        self._obs_buffer.append(first_obs)
        return first_obs


class ProcessFrame84(gym.ObservationWrapper):
    """Converts raw RGB Atari frames into 84x84x1 uint8 grayscale images."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0,
                                                high=255,
                                                shape=(84, 84, 1),
                                                dtype=np.uint8)

    def observation(self, obs):
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """
        Grayscale (Rec. 601 luma weights), area-resize to width 84 x height 110,
        then crop rows 18:102 to get an 84x84x1 uint8 image.

        :param frame: RGB frame with 210*160*3 or 250*160*3 elements.
        :raises AssertionError: for any other frame size.
        """
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            # Explicit raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would leave `img` unbound.
            raise AssertionError("Unknown resolution.")
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class BufferWrapper(gym.ObservationWrapper):
    """
    Keeps the last ``frames`` observations stacked along axis 0 (oldest first).

    NOTE: ``observation`` returns the internal buffer itself (the same array object on
    every call), so callers that keep observations must copy them.
    """

    def __init__(self, env, frames, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.frames = frames
        self.dtype = dtype
        self.buffer = None
        src_space = env.observation_space
        stacked_low = src_space.low.repeat(frames, axis=0)
        stacked_high = src_space.high.repeat(frames, axis=0)
        self.observation_space = gym.spaces.Box(stacked_low, stacked_high, dtype=dtype)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        first = self.env.reset()
        return self.observation(first)

    def observation(self, observation):
        # Shift every frame one slot towards the front, then write the newest frame last.
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer


class ImageToPyTorch(gym.ObservationWrapper):
    """
    Moves the channel axis first: (height, width, channels) -> (channels, height, width).

    NOTE: the declared space is float32 in [0, 1], although the frames coming out of
    ProcessFrame84 are uint8; they only reach [0, 1] after ScaledFloatFrame.
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        h, w, c = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=(c, h, w), dtype=np.float32)

    def observation(self, observation):
        return np.transpose(observation, (2, 0, 1))


class ScaledFloatFrame(gym.ObservationWrapper):
    """Returns a new float32 array with every value divided by 255 (pixel range assumed 0..255)."""

    def observation(self, observation):
        as_float = np.array(observation, dtype=np.float32)
        return as_float / 255.0


def make_env(env_name, terminal_on_life_loss=True):
    """
    Build the preprocessed environment used for training and evaluation.

    Wrapping order (innermost first): 4-frame skip with 2-frame max-pool, FIRE on
    reset/life loss, 84x84 grayscale, channels-first, 4-frame stack, scale to [0, 1].
    """
    pipeline = (
        lambda e: MaxAndSkipEnv(e, skip=4, use_for_pool=2),
        lambda e: FireResetEnv(e, terminal_on_life_loss=terminal_on_life_loss),
        ProcessFrame84,
        ImageToPyTorch,
        lambda e: BufferWrapper(e, frames=4),
        ScaledFloatFrame,
    )
    wrapped = gym.make(env_name)
    for apply_wrapper in pipeline:
        wrapped = apply_wrapper(wrapped)
    return wrapped

#### Action wrappers

In [None]:
class ActionWrapper:
    """Identity mapping from agent action indices to environment actions."""

    def action(self, a):
        mapped = a
        return mapped

class BreakoutFireDropActionWrapper(ActionWrapper):
    """
    Maps a 3-action agent space onto env actions without FIRE (env action 1):
    0 -> 0, 1 -> 2, 2 -> 3. FIRE is pressed automatically by ``FireResetEnv``.
    """

    def action(self, a):
        if a == 0:
            return 0
        return a + 1

#### DQN Agent class

In [None]:
class Agent:
    """
    Base DQN agent: an online ("policy") Q-network trained from an experience buffer
    against a target network that is periodically synced from it.

    Subclasses must override ``__get_model__`` to return a fresh ``nn.Module``; it is
    called once for the policy model and once for the target model in ``__init__``.
    """

    def __init__(self,
                 state_dim,
                 action_dim,
                 env,
                 lr=0.0003,
                 gamma=0.99,
                 loss_function=None,
                 update_model_frequency=1000,
                 experience_buffer=None,
                 eps_strategy=None,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        """
        :param state_dim: observation shape, available to ``__get_model__`` implementations.
        :param action_dim: number of discrete actions the agent chooses between.
        :param env: environment the trainer steps via ``agent.env``.
        :param lr: Adam learning rate for the policy model.
        :param gamma: discount factor of the TD target.
        :param loss_function: loss between predicted and target Q-values;
                              ``None`` creates a new ``nn.MSELoss()``.
        :param update_model_frequency: in ``learn``, the target model is synced whenever
                                       ``episode % update_model_frequency == 0`` (after an update).
        :param experience_buffer: replay buffer; ``None`` creates a new ``ExperienceBuffer()``.
        :param eps_strategy: exploration schedule; ``None`` creates a new
                             ``EpsilonStrategy(decay=5e-2)``.
        :param device: torch device for both networks and sampled batches.
        """
        # Stateful defaults are built per instance so that agents never share a buffer,
        # an epsilon schedule or a loss module.
        if loss_function is None:
            loss_function = nn.MSELoss()
        if experience_buffer is None:
            experience_buffer = ExperienceBuffer()
        if eps_strategy is None:
            eps_strategy = EpsilonStrategy(decay=5e-2)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.env = env

        # Buffer
        self.experience_buffer = experience_buffer

        # Constants
        self.eps_strategy = eps_strategy
        self.gamma = gamma

        self.update_model_frequency = update_model_frequency
        self.lr = lr

        # Network
        self.loss_function = loss_function
        self.device = device

        # Default values
        self.optimizer = None
        self.policy_model, self.target_model = None, None
        self.policy_action_selector = None

        self.set_policy_model(self.__get_model__())
        self.set_target_model(self.__get_model__())

    def choose_eps_action(self, state):
        """
        Epsilon-greedy selection: a uniformly random action with probability ``eps``,
        otherwise the policy model's greedy action.
        """
        if self.eps_strategy.check_random_prob():
            return np.random.randint(0, self.action_dim)
        return self.choose_action(state)

    @torch.no_grad()
    def choose_action(self, state, model=None):
        """
        Greedy action for a single (unbatched) state.

        :param model: network to query; defaults to the policy model.
        :return: best action index as a Python ``int``.
        """
        if model is None:
            model = self.policy_model

        # Add the batch dimension. np.asarray(...)[np.newaxis] replaces
        # np.array([state], copy=False), which raises on NumPy 2 because wrapping in a list
        # needs a copy. no_grad: action selection must not build an autograd graph.
        numpy_state = np.asarray(state)[np.newaxis]
        tensor_state = torch.tensor(numpy_state).to(self.device)
        q_values = model(tensor_state)
        return torch.argmax(q_values).item()

    def store_transition(self, state, next_state, action, reward, terminal):
        """Store one environment transition in the experience buffer."""
        transition = Transition(state=state,
                                next_state=next_state,
                                action=action,
                                reward=reward,
                                terminal=terminal)
        self.experience_buffer.store_transition(transition)

    def learn(self, episode, test_games=(0, 0), action_wrapper=None):
        """
        One DQN update of the policy network against the fixed target network.

        :param episode: global step counter; drives test-game and target-sync frequency.
        :param test_games: (frequency, number of games); when frequency != 0 and
                           n_games > 0, ``play_games`` is run every ``frequency`` steps and
                           the mean reward is logged to wandb.
        :param action_wrapper: optional agent-to-env action mapping used in test games;
                               pass the wrapper the agent is trained with.
        """
        test_games_freq, n_test_games = test_games

        if test_games_freq != 0 and n_test_games > 0 and episode % test_games_freq == 0:
            mean_reward = self.play_games(n_test_games, action_wrapper=action_wrapper)
            wandb.log({'test games mean reward' : mean_reward})

        if not self.experience_buffer.is_ready_for_sample():
            return

        states, next_states, actions, rewards, terminals = self.__sample_batch__()

        # Q(s, a) for the actions actually taken.
        q_eval = self.policy_model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            q_future = self.target_model(next_states).max(dim=1).values
            # No bootstrapping from terminal transitions.
            q_future[terminals] = 0.0
            q_target = rewards + self.gamma * q_future

        self.__fit_network__(q_eval, q_target)

        if episode % self.update_model_frequency == 0:
            self.target_model.load_state_dict(self.policy_model.state_dict())

    def play_games(self, n_games, action_wrapper=None, env_name='BreakoutNoFrameskip-v4'):
        """
        Play ``n_games`` greedy games with the policy model in fresh ``make_env(env_name)``
        environments (default ``terminal_on_life_loss=True``, so a game ends at the first
        lost life).

        :param action_wrapper: optional object whose ``action(a)`` maps agent actions to env
                               actions; without it the raw agent index is sent to the env.
        :return: mean total reward over the games (NaN when ``n_games <= 0``).
        """
        if n_games <= 0:
            return float('nan')

        selector = ActionSelector(self.policy_model,
                                  atari_mode=True,
                                  device=self.device)
        game_rewards = np.zeros(n_games)
        for i in range(n_games):
            game_env = make_env(env_name)
            try:
                state = game_env.reset()
                total, done = 0, False
                while not done:
                    action = selector.choose_action(state)
                    if action_wrapper is not None:
                        action = action_wrapper.action(action)
                    state, reward, done, _ = game_env.step(action)
                    total += reward
                game_rewards[i] = total
            finally:
                game_env.close()

        return game_rewards.mean()

    def set_policy_model(self, model):
        """Install ``model`` as the policy network and create a fresh Adam optimizer for it."""
        self.policy_model = model.to(self.device)
        self.optimizer = torch.optim.Adam(self.policy_model.parameters(),
                                          lr=self.lr)

    def set_target_model(self, model):
        self.target_model = model.to(self.device)

    def __get_model__(self):
        """Return a new Q-network; subclasses must override this."""
        raise NotImplementedError('Agent subclasses must implement __get_model__')

    def __sample_batch__(self):
        """
        :return: [states, next_states, actions, rewards, is_done] as tensors on ``self.device``.
        """
        states, next_states, actions, rewards, terminals = self.experience_buffer.sample_batch()
        return [
            torch.FloatTensor(states).to(self.device),
            torch.FloatTensor(next_states).to(self.device),
            torch.LongTensor(actions).to(self.device),
            torch.FloatTensor(rewards).to(self.device),
            torch.BoolTensor(terminals).to(self.device)
        ]

    def __fit_network__(self, q_eval, q_target):
        """One optimizer step on the loss, with gradients clipped element-wise to [-1, 1]."""
        loss = self.loss_function(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_value_(self.policy_model.parameters(), clip_value=1)
        self.optimizer.step()

#### Atari CNN

In [None]:
class AtariCNN(nn.Module):
    """
    Q-network for stacked Atari frames: three conv layers followed by a 512-unit dense
    layer and one linear output per action, with leaky-ReLU (slope 0.02) activations.

    :param input_shape: (frames, height, width) of one observation; only ``frames`` sets
                        the conv input channels, the full shape sizes the first dense layer.
    :param n_actions: number of Q-values produced.
    """

    def __init__(self, input_shape, n_actions):
        super(AtariCNN, self).__init__()

        frames = input_shape[0]

        self.conv1  = nn.Conv2d(frames, 32, kernel_size=(8, 8), stride=(4, 4))
        self.conv2  = nn.Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
        self.conv3  = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))

        self.dense1 = nn.Linear(self.__get_dense_shape__(input_shape), 512)
        self.dense2 = nn.Linear(512, n_actions)

        # NOTE(review): kaiming_normal_ uses its default a=0 here, not the 0.02 slope
        # used in forward(); left unchanged to keep training behaviour identical.
        for layer in (self.conv1, self.conv2, self.conv3, self.dense1, self.dense2):
            torch.nn.init.kaiming_normal_(layer.weight)

    def forward(self, x):
        """:param x: batch of observations shaped (batch, *input_shape)."""
        x = F.leaky_relu(self.conv1(x), 0.02)
        x = F.leaky_relu(self.conv2(x), 0.02)
        x = F.leaky_relu(self.conv3(x), 0.02)
        x = F.leaky_relu(self.dense1(torch.flatten(x, 1)), 0.02)
        x = self.dense2(x)
        return x

    def conv(self, x):
        """Conv stack without activations; used only to infer output sizes."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

    def __get_dense_shape__(self, shape):
        """Number of features the conv stack produces for one input of ``shape``."""
        # Shape probe only: no autograd graph is needed.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return int(o.numel())

#### Atari agent and trainer classes

In [None]:
class AtariAgent(Agent):
    """DQN agent whose Q-network is an ``AtariCNN`` sized by ``state_dim`` and ``action_dim``."""

    def __get_model__(self):
        network = AtariCNN(input_shape=self.state_dim,
                           n_actions=self.action_dim)
        return network

class AtariAgentTrainer:
    """Runs the epsilon-greedy interaction/learning loop for an ``Agent`` and checkpoints it."""

    def __init__(self,
                 agent: Agent,
                 episodes,
                 action_wrapper=ActionWrapper()):
        """
        :param agent: agent to train; its ``env`` is stepped directly.
        :param episodes: total number of environment steps (one ``learn`` call per step).
        :param action_wrapper: maps agent action indices to env actions before ``env.step``.
                               The default instance is shared between trainers, which is
                               harmless because ``ActionWrapper`` holds no state.
        """
        self.agent = agent
        self.episodes = episodes
        self.action_wrapper = action_wrapper

    def train(self, path='trained_models/atari', test_games=(0, 0), max_episodes_per_game=None):
        """
        path                 : path prefix for saved models: '{path}_best.pt' holds the model
                               with the best mean reward over the last 100 games, '{path}.pt'
                               the final model. Missing parent directories are created.
        test_games           : (test_games_frequency, amount of games), forwarded to ``learn``.
                               NOTE(review): ``self.action_wrapper`` is not forwarded to the
                               agent's test games.
        max_episodes_per_game: number of steps after which a long game is cut off (None = never).

        :return: list of per-game training rewards.
        """
        save_dir = os.path.dirname(path)
        if save_dir:
            # torch.save does not create directories.
            os.makedirs(save_dir, exist_ok=True)

        rewards = []
        game_reward, game_episode = 0, 0
        # -inf so that a best checkpoint is written even if mean rewards are never positive.
        best_mean_reward = float('-inf')
        state = self.agent.env.reset()
        progress = trange(self.episodes, desc='epochs')

        for episode in progress:
            self.agent.eps_strategy.decrease()
            game_episode += 1

            action = self.agent.choose_eps_action(state)
            wrapped_action = self.action_wrapper.action(action)
            next_state, reward, done, _ = self.agent.env.step(wrapped_action)

            game_reward += reward

            # Store the agent-space action (not the wrapped one): learn() indexes the
            # Q-network output with it. A cut-off game is stored with terminal=False.
            self.agent.store_transition(state=state,
                                        next_state=next_state,
                                        action=action,
                                        terminal=done,
                                        reward=reward)

            self.agent.learn(episode=episode, test_games=test_games)

            cut_off = max_episodes_per_game is not None and game_episode > max_episodes_per_game
            if done or cut_off:
                rewards.append(game_reward)
                game_reward, game_episode = 0, 0

                mean_reward = np.mean(rewards[-100:])

                if mean_reward > best_mean_reward:
                    best_mean_reward = mean_reward
                    torch.save(self.agent.policy_model, '{}_best.pt'.format(path))

                wandb.log({
                    'mean reward': mean_reward,
                    'epsilon' : self.agent.eps_strategy.eps,
                    'last game reward': rewards[-1]
                })

                progress_status = "last game reward: {} | games: {} | mean reward: {:.2f} | epsilon: {:.2f}".format(
                    rewards[-1], len(rewards), mean_reward, self.agent.eps_strategy.eps
                )
                progress.set_postfix_str(progress_status)

                state = self.agent.env.reset()
            else:
                state = next_state

        torch.save(self.agent.policy_model, '{}.pt'.format(path))
        return rewards

In [None]:
def preload_model(agent, path):
    """
    Load a whole pickled model from ``path`` into both the policy and target slots of ``agent``.

    The file is read twice so the two slots get independent module instances. Tensors are
    mapped to ``agent.device`` (when the agent has one), so a GPU checkpoint loads on CPU.

    NOTE(security): ``torch.load`` unpickles arbitrary objects; only load trusted checkpoints.
    """
    import inspect

    load_kwargs = {'map_location': getattr(agent, 'device', None)}
    # torch >= 2.6 defaults to weights_only=True, which refuses whole pickled nn.Modules
    # (the format torch.save(model, ...) writes); older torch lacks the keyword entirely.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    agent.set_policy_model(torch.load(path, **load_kwargs))
    agent.set_target_model(torch.load(path, **load_kwargs))

In [None]:
env_name = 'BreakoutNoFrameskip-v4'

# terminal_on_life_loss=False: an episode only ends at real game over.
env = make_env(env_name, terminal_on_life_loss=False)

atari_agent = AtariAgent(
    state_dim=(4, 84, 84),   # 4 stacked 84x84 frames, as produced by make_env
    action_dim=3,            # agent actions 0,1,2 -> env actions 0,2,3 via BreakoutFireDropActionWrapper
    lr=1e-4,
    env=env,
    update_model_frequency=1_000,
    loss_function=nn.SmoothL1Loss(),  # Huber loss
    experience_buffer=ExperienceBuffer(capacity=80_000, batch_size=32, start_sample_from=30_000),
    eps_strategy=EpsilonStrategy(start=1, decay=.999990, min_eps=0.1),
)

# To resume from a checkpoint instead of starting fresh, uncomment:
# preload_model(atari_agent, '/content/atari-xx_best.pt')

trainer = AtariAgentTrainer(
    agent=atari_agent,
    episodes=1_500_000,
    action_wrapper=BreakoutFireDropActionWrapper(),
)

#### Start training

In [None]:
# Note: `train` returns the per-game *training* rewards; the name is kept for later cells.
test_rewards = trainer.train(
    path='breakout-dead-lives-colab',
)

epochs:  91%|█████████ | 1367321/1500000 [4:58:02<29:13, 75.66it/s, last game reward: 53.0 | games: 3659 | mean reward: 30.030000 | epsilon: 0.100000]