In [None]:
'''
DQN agent for Atari Pong (gym 'Pong-v0').

Written to run in Google Colab.
'''

import gym
import cv2
from collections import deque, namedtuple
import numpy as np
import random
from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense
from keras.optimizers import Adam

# Description of one convolutional layer: number of filters, kernel size, stride.
conv = namedtuple('Conv', ['filter', 'kernel', 'stride'])


class Buffer:
    """Fixed-capacity replay memory of ``(s, a, r, s2, terminal)`` transitions.

    Transitions are kept newest-first (left end of the deque). Once ``size``
    transitions are stored, adding another evicts the oldest one.
    """

    def __init__(self, size):
        if size < 1:
            raise ValueError("Buffer size must be a positive integer, got {}".format(size))
        self.size = size
        # With maxlen set, appendleft() discards from the right (oldest) end once
        # the deque is full -- the same eviction the manual pop()/appendleft() did.
        self.buffer = deque(maxlen=size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    @staticmethod
    def _stack_frames(frames):
        """Stack the first four 2-D frames channels-last into one (H, W, 4) array."""
        return np.stack((frames[0], frames[1], frames[2], frames[3]), axis=2)

    def add(self, s, a, r, s2, t):
        """Store one transition.

        Args:
            s: sequence of four frames observed before the action.
            a: index of the action taken.
            r: reward received for the action.
            s2: sequence of four frames observed after the action.
            t: terminal flag for the transition.
        """
        # NOTE(review): frames are stored at the dtype produced by the caller
        # (float64 from downsample), which is memory-heavy for large buffers.
        self.buffer.appendleft((self._stack_frames(s), a, r, self._stack_frames(s2), t))

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored.
        """
        return random.sample(self.buffer, batch_size)


class DQN:
    """DQN learner: an online Q-network plus a target network for bootstrapping.

    Args:
        buff: replay memory exposing a sized ``.buffer`` and ``.sample(n)``.
        batch_size: transitions per training step.
        min_buff: minimum stored transitions before training starts.
        gamma: discount factor for bootstrapped targets.
        learning_rate: Adam learning rate for both networks.
    """

    def __init__(self, buff, batch_size=32, min_buff=10000, gamma=0.99, learning_rate=2.5e-4):
        self.buffer = buff
        self.min_buffer = min_buff
        self.batch_size = batch_size
        self.gamma = gamma

        # The online network is trained; the target network only supplies
        # next-state values and is refreshed by copy_network().
        self.model = create_network(learning_rate)
        self.target_model = create_network(learning_rate)
        self.copy_network()

    def train(self):
        """Run one gradient step on a random minibatch (no-op until warmed up)."""
        if len(self.buffer.buffer) < self.min_buffer:
            return
        batch = self.buffer.sample(self.batch_size)
        states, actions, rewards, next_states, terminal = map(np.array, zip(*batch))

        # verbose=0: newer Keras prints a progress bar per predict() call otherwise.
        next_state_values = np.max(self.target_model.predict(next_states, verbose=0), axis=1)
        # 0/1 continuation mask. np.invert is a bitwise NOT and is only correct for
        # bool arrays (for integer flags ~1 == -2), so convert explicitly.
        not_terminal = 1.0 - terminal.astype(np.float64)

        # Only the taken action's Q-value gets a new target; the others keep the
        # network's current prediction so they contribute zero MSE.
        targets = self.model.predict(states, verbose=0)
        targets[np.arange(len(batch)), actions] = rewards + self.gamma * next_state_values * not_terminal
        self.model.train_on_batch(states, targets)

    def copy_network(self):
        """Copy all weights from the online network into the target network."""
        self.target_model.set_weights(self.model.get_weights())

    def predict(self, x):
        """Return Q-values (shape (1, n_actions)) for one state of four frames."""
        stacked = np.stack((x[0], x[1], x[2], x[3]), axis=2)
        return self.model.predict(stacked[np.newaxis], verbose=0)


def create_network(learning_rate, conv_info=None, dense_info=None,
                   input_size=(80, 80, 4), n_actions=6):
    """Build and compile a convolutional Q-network.

    Args:
        learning_rate: Adam learning rate.
        conv_info: sequence of ``conv(filter, kernel, stride)`` layer specs;
            defaults to ``[conv(32, 8, 4), conv(64, 4, 2), conv(64, 3, 1)]``.
        dense_info: hidden Dense layer widths; defaults to ``[512]``.
        input_size: input shape ``(height, width, channels)``.
        n_actions: number of outputs (one Q-value per action); default 6.

    Returns:
        A compiled Keras ``Sequential`` model trained with MSE loss.
    """
    # None sentinels instead of list defaults, so calls never share one list.
    if conv_info is None:
        conv_info = [conv(32, 8, 4), conv(64, 4, 2), conv(64, 3, 1)]
    if dense_info is None:
        dense_info = [512]

    model = Sequential()
    for i, layer in enumerate(conv_info):
        # Only the first layer declares the input shape; Keras infers the rest.
        extra = {"input_shape": input_size} if i == 0 else {}
        model.add(Conv2D(layer.filter, layer.kernel, padding="same", strides=layer.stride,
                         activation="relu", **extra))
    model.add(Flatten())
    for units in dense_info:
        model.add(Dense(units, activation="relu"))
    # Linear output head: unbounded Q-value estimates.
    model.add(Dense(n_actions))
    # `lr` is a legacy alias that current Keras rejects; `learning_rate` is the
    # supported keyword.
    model.compile(loss='mse', optimizer=Adam(learning_rate=learning_rate))
    return model


class Pong:
    """Epsilon-greedy DQN training driver for gym's 'Pong-v0' environment."""

    def __init__(self):
        self.env = gym.make('Pong-v0')
        # Exploration rate: starts fully random, decays by eps_step per env step
        # down to a floor of 0.1 (see play_one_episode).
        self.epsilon = 1
        self.buffer = Buffer(50000)
        self.dqn = DQN(self.buffer)
        self.copy_period = 40000  # env steps between target-network syncs
        self.itr = 0  # total env steps taken across all episodes
        self.eps_step = 0.0000009

    def sample_action(self, s):
        """Return a random action with probability epsilon, else the greedy action for ``s``."""
        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.dqn.predict(s)[0])

    def _step(self, action):
        """Step the env, accepting both the old and the new gym step APIs.

        Returns:
            ``(observation, reward, terminated, done)``. ``terminated`` marks a
            true episode end (no bootstrapping past it); ``done`` also includes
            time-limit truncation and ends the episode loop.
        """
        step_result = self.env.step(action)
        if len(step_result) == 5:
            # gym>=0.26: (obs, reward, terminated, truncated, info)
            observation, reward, terminated, truncated, _ = step_result
            return observation, reward, bool(terminated), bool(terminated or truncated)
        # older gym: (obs, reward, done, info)
        observation, reward, done, _ = step_result
        return observation, reward, bool(done), bool(done)

    def play_one_episode(self):
        """Play one full episode, storing transitions and training; return its total reward."""
        observation = self.env.reset()
        state = []
        update_state(state, observation)
        total_reward = 0
        done = False

        while not done:
            # Act randomly until four frames are available to form a network input.
            if len(state) < 4:
                action = self.env.action_space.sample()
            else:
                action = self.sample_action(state)

            # Snapshot of the frame stack the action was chosen from.
            prv_state = list(state)

            # Update the state using the step result
            observation, reward, terminated, done = self._step(action)
            update_state(state, observation)
            if len(state) == 4 and len(prv_state) == 4:
                self.buffer.add(prv_state, action, reward, state, terminated)
            total_reward += reward

            self.itr += 1
            if self.itr % 4 == 0:
                self.dqn.train()
            self.epsilon = max(0.1, self.epsilon - self.eps_step)
            if self.itr % self.copy_period == 0:
                self.dqn.copy_network()

        # Return only after the loop: inside it, the episode would end after one step.
        return total_reward



def downsample(observation, crop_top=30, size=(80, 80)):
    """Convert one raw color frame into a small normalized grayscale image.

    Args:
        observation: a single ``(H, W, 3)`` frame; gym/ALE Atari frames are RGB.
        crop_top: rows removed from the top before resizing (default 30).
        size: output ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        A float array of shape ``(height, width)``, scaled by 1/255 (values lie
        in [0, 1] for uint8 input).
    """
    # Atari frames from gym are RGB; BGR2GRAY would swap the red and blue
    # luma weights.
    gray = cv2.cvtColor(observation[crop_top:, :, :], cv2.COLOR_RGB2GRAY)
    gray = cv2.resize(gray, size, interpolation=cv2.INTER_AREA)
    return gray / 255.0




def update_state(state, observation):
    """Append the preprocessed frame to ``state`` (in place), keeping the last 4.

    Args:
        state: list used as a rolling frame stack; mutated in place.
        observation: either a raw frame array, or an ``(obs, info)`` tuple as
            returned by ``env.reset()`` in gym>=0.26.
    """
    # Only unwrap tuples. A bare frame (every step() observation, and reset()
    # on old gym) indexed with [0] would yield a single pixel row instead.
    if isinstance(observation, tuple):
        observation = observation[0]
    state.append(downsample(observation))
    # Drop everything but the four most recent frames.
    del state[:-4]


# Training entry point. In a notebook/Colab kernel __name__ is "__main__", so this
# still runs there, but importing this file as a module no longer starts training.
if __name__ == "__main__":
    NUM_EPISODES = 100000
    SAVE_EVERY = 100  # checkpoint after episode 0 and every SAVE_EVERY episodes

    p = Pong()
    for i in range(NUM_EPISODES):
        total_reward = p.play_one_episode()
        print("episode total reward:", total_reward)
        if i % SAVE_EVERY == 0:
            print("Saving the model")
            p.dqn.model.save("model-{}.h5".format(i))