<a href="https://colab.research.google.com/github/JeremieGauthier/AI_Exercices/blob/master/Atari_Breakout.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import cv2
import gym
import gym.spaces
import numpy as np
import collections


class FireResetEnv(gym.Wrapper):
    """Wrapper that takes the FIRE action right after every reset.

    Some games stay idle until FIRE is pressed; this wrapper performs that
    press (action 1) followed by action 2 so the episode is already running
    when the first observation is handed to the agent.
    """

    def __init__(self, env=None):
        """For environments where the user need to press FIRE for the game to start."""
        super(FireResetEnv, self).__init__(env)
        # Action index 1 must be FIRE, and action index 2 must exist because
        # reset() steps with both of them.
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action):
        """Forward the action to the wrapped environment unchanged."""
        return self.env.step(action)

    def reset(self):
        """Reset the wrapped env, press FIRE, and return the resulting observation.

        If either warm-up step ends the episode the environment is reset
        again. When that happens after the second step, the observation
        returned is the one produced by that final reset (previously the
        stale frame of the already finished episode was returned).
        """
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            # The episode that produced `obs` is over; use the fresh frame.
            obs = self.env.reset()
        return obs


class MaxAndSkipEnv(gym.Wrapper):
    """Repeat every action `skip` times and max-pool the two latest raw frames."""

    def __init__(self, env=None, skip=4):
        """Wrap `env` so that only every `skip`-th frame is returned."""
        super(MaxAndSkipEnv, self).__init__(env)
        self._skip = skip
        # Holds the two most recent raw observations; step() takes their
        # element-wise maximum.
        self._obs_buffer = collections.deque(maxlen=2)

    def step(self, action):
        """Apply `action` up to `skip` times.

        Returns the element-wise max over the buffered frames, the summed
        reward, and the `done` / `info` values of the last inner step.
        Stops early as soon as the inner environment reports `done`.
        """
        done = None
        reward_sum = 0.0
        for _repeat in range(self._skip):
            frame, step_reward, done, info = self.env.step(action)
            self._obs_buffer.append(frame)
            reward_sum += step_reward
            if done:
                break
        pooled = np.stack(self._obs_buffer).max(axis=0)
        return pooled, reward_sum, done, info

    def reset(self):
        """Empty the frame buffer, reset the inner env and buffer its first frame."""
        self._obs_buffer.clear()
        first_frame = self.env.reset()
        self._obs_buffer.append(first_frame)
        return first_frame


class ProcessFrame84(gym.ObservationWrapper):
    """Convert raw RGB frames into 84x84x1 uint8 grayscale images."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        """Run the static preprocessing pipeline on one observation."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Grayscale, downsample and crop a single frame.

        `frame` must contain exactly 210*160*3 or 250*160*3 values; any
        other size trips the "Unknown resolution." assertion. The result is
        a uint8 array of shape (84, 84, 1).
        """
        n_values = frame.size
        if n_values == 210 * 160 * 3:
            raw_shape = [210, 160, 3]
        elif n_values == 250 * 160 * 3:
            raw_shape = [250, 160, 3]
        else:
            assert False, "Unknown resolution."
        rgb = np.reshape(frame, raw_shape).astype(np.float32)

        # Weighted sum of the three colour channels -> single gray channel.
        gray = rgb[:, :, 0] * 0.299 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114

        # cv2 takes (width, height): shrink to 84 wide x 110 tall, then keep
        # rows 18..101 to end up with a square 84x84 image.
        shrunk = cv2.resize(gray, (84, 110), interpolation=cv2.INTER_AREA)
        cropped = shrunk[18:102, :]
        return np.reshape(cropped, [84, 84, 1]).astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis of image observations to the front."""

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        src_shape = self.observation_space.shape
        # New layout: (last axis, first axis, second axis) of the old shape.
        channels_first = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=channels_first, dtype=np.float32)

    def observation(self, observation):
        """Return `observation` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Cast observations to float32 and divide them by 255."""

    def observation(self, obs):
        """Return a float32 copy of `obs` scaled by 1/255."""
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0


class BufferWrapper(gym.ObservationWrapper):
    """Keep a rolling stack of the most recent observations along axis 0."""

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        inner_space = env.observation_space
        # Repeat the inner bounds n_steps times along axis 0 to describe the
        # stacked observation.
        stacked_low = inner_space.low.repeat(n_steps, axis=0)
        stacked_high = inner_space.high.repeat(n_steps, axis=0)
        self.observation_space = gym.spaces.Box(stacked_low, stacked_high, dtype=dtype)
        self.dtype = dtype

    def reset(self):
        """Zero the stack, reset the inner env and push its first observation."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        first_obs = self.env.reset()
        return self.observation(first_obs)

    def observation(self, observation):
        """Shift the stack by one slot, append `observation`, return the stack.

        The returned array is the internal buffer itself (not a copy), and
        it only exists after reset() has been called.
        """
        stack = self.buffer
        stack[:-1] = stack[1:]
        stack[-1] = observation
        return stack


def make_env(env_name):
    """Create the gym environment `env_name` wrapped for DQN training.

    Wrappers are applied innermost-first: frame skipping with max-pooling,
    FIRE on reset, 84x84 grayscale conversion, channel-first layout,
    a 4-step observation stack, and finally scaling to float32 / 255.
    """
    raw_env = gym.make(env_name)
    skipping = MaxAndSkipEnv(raw_env)
    auto_fire = FireResetEnv(skipping)
    gray_84 = ProcessFrame84(auto_fire)
    channels_first = ImageToPyTorch(gray_84)
    stacked = BufferWrapper(channels_first, 4)
    return ScaledFloatFrame(stacked)

In [0]:
import matplotlib.pyplot as plt
import numpy as np

def plot_learning_curve(x, scores, epsilons, filename):
    """Save a figure with epsilon and the running-average score over `x`.

    Epsilon is drawn as a line on the left axis. The score curve is the mean
    of each score together with up to 100 preceding ones, drawn as a scatter
    plot on an overlaid axis whose ticks and label sit on the right. The
    figure is written to `filename`.
    """
    # Running mean over the window scores[t-100 .. t], clipped at the start.
    n_scores = len(scores)
    running_avg = np.array(
        [np.mean(scores[max(0, t - 100):(t + 1)]) for t in range(n_scores)],
        dtype=np.float64,
    )

    fig = plt.figure()
    eps_axis = fig.add_subplot(111, label="1")
    score_axis = fig.add_subplot(111, label="2", frame_on=False)

    # Left axis: exploration rate.
    eps_axis.plot(x, epsilons, color="C0")
    eps_axis.set_xlabel("Training Step", color="C0")
    eps_axis.set_ylabel("Epsilon", color="C0")
    for axis_name in ("x", "y"):
        eps_axis.tick_params(axis=axis_name, color="C0")

    # Overlaid axis: averaged score, labelled on the right-hand side.
    score_axis.scatter(x, running_avg, color="C1")
    score_axis.axes.get_xaxis().set_visible(False)
    score_axis.yaxis.tick_right()
    score_axis.set_ylabel("Score", color="C1")
    score_axis.yaxis.set_label_position("right")
    score_axis.tick_params(axis='y', color="C1")

    plt.savefig(filename)

In [9]:
import gym
import time
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np

from collections import namedtuple, deque
from itertools import count


# One replay-memory transition: the state, the action taken in it, the reward
# received and the state that followed.
Experience = namedtuple(
    typename='Experience',
    field_names=['state', 'action', 'reward', 'next_state'],
)

class DQN(nn.Module):
    """Small convolutional Q-network: two conv layers, one hidden FC layer, one output layer.

    Args:
        num_actions: number of output units (one Q-value per action).
        lr: accepted for interface compatibility; not used inside the module.
        device: device the input is moved to in `forward`. The module itself
            is not moved here; callers do that with `.to(device)`.
    """

    def __init__(self, num_actions, lr, device):
        super(DQN, self).__init__()

        self.device = device

        self.conv1 = nn.Conv2d(
            in_channels=4, out_channels=16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=4, stride=2)

        # Spatial size after a conv layer: (W - K + 2P) / S + 1.
        # For 84x84 inputs: (84 - 8) / 4 + 1 = 20, then (20 - 4) / 2 + 1 = 9.
        self.fc = nn.Linear(in_features=32*9*9, out_features=256)
        self.out = nn.Linear(in_features=256, out_features=num_actions)

    def forward(self, state):
        """Return Q-values of shape (batch, num_actions) for a batch of states.

        `state` needs 4 channels and a spatial size that yields 9x9 feature
        maps after both convolutions (e.g. 84x84). The intermediate
        activations are kept on `self.layer1` .. `self.layer3`.
        """
        # (1) Hidden Conv. Layer
        # Use the device stored on the module, not a module-level global.
        self.layer1 = F.relu(self.conv1(state.to(self.device)))

        # (2) Hidden Conv. Layer
        self.layer2 = F.relu(self.conv2(self.layer1))

        # (3) Hidden Linear Layer
        # The ReLU is required: without it `fc` followed by `out` is just a
        # single linear map and the hidden layer adds no capacity.
        input_layer3 = self.layer2.reshape(-1, 32*9*9)
        self.layer3 = F.relu(self.fc(input_layer3))

        # (4) Output
        actions = self.out(self.layer3)

        return actions


class ReplayMemory():
    """Fixed-capacity ring buffer of experiences with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty memory that holds at most `capacity` experiences."""
        self.capacity = capacity
        self.memory = []
        # Total number of experiences ever added (keeps growing after the
        # buffer is full); used to locate the slot to overwrite.
        self.count = 0

    def add_to_memory(self, experience):
        """Store `experience`, overwriting the oldest entry once the buffer is full."""
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            # Ring-buffer index. The operands were previously swapped
            # (capacity % count), which yields `capacity` itself as soon as
            # count > capacity and raised IndexError.
            self.memory[self.count % self.capacity] = experience
        self.count += 1

    def extract_tensor(self, experiences):
        """Turn a list of Experience tuples into batched tensors.

        Returns `(states, actions, rewards, next_states)`. States, actions
        and next states are concatenated with `torch.cat`; rewards are
        packed into a new tensor with `torch.tensor`.
        """
        batch = Experience(*zip(*experiences))

        states = torch.cat(batch.state)
        actions = torch.cat(batch.action)
        rewards = torch.tensor(batch.reward)
        next_states = torch.cat(batch.next_state)

        return (states, actions, rewards, next_states)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored experiences chosen at random."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True when at least `batch_size` experiences are stored."""
        return len(self.memory) >= batch_size


class EpsilonGreedyStrategy():
    """Exponentially decaying exploration-rate schedule."""

    def __init__(self, eps_start, eps_end, eps_decay):
        """Store the initial rate, the asymptotic rate and the decay constant."""
        self.eps_start, self.eps_end, self.eps_decay = eps_start, eps_end, eps_decay

    def get_exploration_rate(self, current_step):
        """Return eps_end + (eps_start - eps_end) * exp(-current_step * eps_decay)."""
        span = self.eps_start - self.eps_end
        decay_factor = math.exp(-1 * current_step * self.eps_decay)
        return self.eps_end + span * decay_factor


class Agent():
    """Epsilon-greedy action selector driven by an exploration-rate strategy."""

    def __init__(self, num_actions, strategy, device):
        self.num_actions = num_actions
        self.strategy = strategy
        self.device = device
        # Number of actions chosen so far; fed to the strategy as the step.
        self.current_step = 0

    def choose_action(self, state, policy_net):
        """Pick an action for `state` and return it as a tensor on `self.device`.

        The exploration rate for the current step is stored on
        `self.epsilon` and the step counter is advanced. With probability
        epsilon a uniformly random action index is returned; otherwise the
        argmax over dim 1 of `policy_net(state)` is returned.
        """
        step_now = self.current_step
        self.epsilon = self.strategy.get_exploration_rate(step_now)
        self.current_step = step_now + 1

        exploit = not (np.random.random() < self.epsilon)
        if exploit:
            with torch.no_grad():
                greedy = policy_net(state).argmax(dim=1)
            return greedy.to(self.device)

        random_action = random.randrange(self.num_actions)
        return torch.tensor([random_action]).to(self.device)



if __name__ == "__main__":
    # Train a DQN agent on Breakout and save the learning curve.

    lr = 0.001
    gamma = 0.99
    eps_start = 1
    eps_end = 0.01
    eps_decay = 0.001
    target_update = 10
    num_episodes = 1000
    batch_size = 256
    # NOTE(review): every stored transition holds two stacked float32 frame
    # tensors, so a buffer this large will not fit in memory if it ever
    # fills up - consider a smaller capacity.
    capacity = 1000000
    max_nb_elements = 4

    scores, eps_history = [], []


    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    env = make_env("Breakout-v0")

    strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)

    agent = Agent(env.action_space.n, strategy, device)
    memory = ReplayMemory(capacity)

    policy_network = DQN(env.action_space.n, lr, device).to(device)
    target_network = DQN(env.action_space.n, lr, device).to(device)

    # The target network starts as a copy of the policy network and is only
    # ever used for inference.
    target_network.load_state_dict(policy_network.state_dict())
    target_network.eval()

    optimizer = optim.Adam(params=policy_network.parameters(), lr=lr)

    for episode in range(num_episodes):
        # reset() already returns the fully wrapped observation; passing it
        # through env.observation() again would divide by 255 a second time.
        obs = env.reset()
        state = torch.tensor(obs).unsqueeze(dim=0)

        score = 0
        start = time.time()

        for timestep in count():
            action = agent.choose_action(state, policy_network)
            # The environment expects a plain int, not a (possibly CUDA) tensor.
            next_state, reward, done, _ = env.step(action.item())
            next_state = torch.tensor(next_state).unsqueeze(dim=0)
            memory.add_to_memory(Experience(state, action, reward, next_state))
            state = next_state

            score += reward

            if memory.can_provide_sample(batch_size):
                experiences = memory.sample(batch_size)
                states, actions, rewards, next_states = memory.extract_tensor(experiences)

                # Q(s, a) of the actions that were actually taken.
                action_index = actions.to(device).long().unsqueeze(dim=1)
                current_q_value = policy_network(states).gather(1, action_index).squeeze(dim=1)

                # The bootstrap target is a constant w.r.t. the optimised
                # parameters, so no autograd graph is needed for it.
                with torch.no_grad():
                    next_q_value = target_network(next_states)
                # NOTE(review): Experience carries no `done` flag, so the
                # bootstrap term is also added for terminal transitions.
                rewards = rewards.to(device=device, dtype=current_q_value.dtype)
                target_q_value = rewards + gamma * torch.max(next_q_value, dim=1)[0]

                loss = F.mse_loss(current_q_value, target_q_value)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if done:
                break

        scores.append(score)
        eps_history.append(agent.epsilon)

        # Sync the target network with the policy network every
        # `target_update` episodes.
        if episode % target_update == 0:
            target_network.load_state_dict(policy_network.state_dict())

        print("episode :", episode, "epsilon :", agent.epsilon, "score", score,
                "time :", time.time()-start)

        if episode % 20 == 0:
            avg_score = np.mean(scores[-20:])
            print("episode", episode, "score %.1f average score %.1f epsilon %.2f" %
               (score, avg_score, agent.epsilon))

    filename = 'Atari_Breakout_DQN.png'
    # One x value per recorded episode so x and scores always have equal length.
    x = [i+1 for i in range(len(scores))]
    plot_learning_curve(x, scores, eps_history, filename)




episode : 0 epsilon : 0.9507758838271028 score 1.0 time : 0.1622023582458496
episode 0 score 1.0 average score 1.0 epsilon 0.95
episode : 1 epsilon : 0.9075824147163817 score 0.0 time : 0.14638876914978027
episode : 2 epsilon : 0.8402318035037033 score 2.0 time : 0.2421424388885498
episode : 3 epsilon : 0.7522639763166509 score 4.0 time : 1.7442529201507568
episode : 4 epsilon : 0.7118396387690156 score 0.0 time : 2.3819708824157715
episode : 5 epsilon : 0.6540040037760834 score 3.0 time : 3.7752678394317627
episode : 6 epsilon : 0.5944697016008592 score 3.0 time : 4.254587888717651
episode : 7 epsilon : 0.5549559376949149 score 0.0 time : 3.017728328704834
episode : 8 epsilon : 0.5211713940817868 score 2.0 time : 2.763587236404419
episode : 9 epsilon : 0.4880450348059175 score 0.0 time : 2.891465425491333
episode : 10 epsilon : 0.4486513662013461 score 2.0 time : 3.715885639190674
episode : 11 epsilon : 0.4214571745478475 score 2.0 time : 2.801426887512207
episode : 12 epsilon : 0.390

KeyboardInterrupt: ignored