In [None]:
import gym
import random
import numpy as np
import os
import pylab
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
import cv2
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
from google.colab.patches import cv2_imshow
import pygame

class OurModel(nn.Module):
    """Fully connected Q-network: two hidden layers of 24 ReLU units.

    Args:
        input_shape: sequence whose first element is the observation size.
        action_space: number of discrete actions (width of the output layer).
    """

    def __init__(self, input_shape, action_space):
        super().__init__()
        in_features = input_shape[0]
        hidden = 24
        # Layers are created in the same order as before so random init is identical,
        # and the attribute names (fc1..fc3) keep the saved state_dict keys stable.
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_space)

    def forward(self, x):
        """Map a batch of observations to one Q-value per action."""
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.relu(hidden_layer(x))
        return self.fc3(x)

class DQNAgent:
    """DQN / Double-DQN agent for a gym environment with a discrete action space.

    Holds an online Q-network (``self.model``) and a target network
    (``self.target_model``), an epsilon-greedy policy, and a bounded replay
    buffer that ``replay`` samples minibatches from.

    Args:
        env_name: environment id passed to ``gym.make`` (e.g. ``'CartPole-v1'``).
    """

    def __init__(self, env_name):
        self.env_name = env_name
        self.env = gym.make(env_name)
        # Per-episode step cap; run() reads it back when shaping the terminal reward.
        self.env._max_episode_steps = 4000
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.n

        self.EPISODES = 100
        self.memory = deque(maxlen=2000)  # replay buffer of (state, action, reward, next_state, done)

        self.gamma = 0.95            # discount factor
        self.epsilon = 1.0           # exploration rate, decayed in remember()
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.999
        self.batch_size = 32
        self.train_start = 1000      # replay() is a no-op until this many transitions are stored

        # Frame-stack settings used by GetImage()/reset(); run() trains on the raw
        # observation vector instead.
        self.ROWS = 160
        self.COLS = 240
        self.REM_STEP = 4
        self.image_memory = np.zeros((self.REM_STEP, self.ROWS, self.COLS))

        pygame.init()
        self.screen = pygame.display.set_mode((640, 480))
        self.clock = pygame.time.Clock()
        self.metadata = {"render_fps": 30}

        self.ddqn = True             # use Double-DQN targets in replay()
        self.Soft_Update = False     # blend target weights with TAU instead of copying them

        self.TAU = 0.1               # soft-update mixing factor

        self.Save_Path = 'Models'
        self.scores, self.episodes, self.average = [], [], []

        if self.ddqn:
            print("----------Double DQN--------")
            self.Model_name = os.path.join(self.Save_Path, "DDQN_" + self.env_name + ".pt")
        else:
            print("-------------DQN------------")
            self.Model_name = os.path.join(self.Save_Path, "DQN_" + self.env_name + ".pt")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Both networks live on self.device; every tensor fed to them is moved there too.
        self.model = OurModel(input_shape=(self.state_size,), action_space=self.action_size).float().to(self.device)
        self.target_model = OurModel(input_shape=(self.state_size,), action_space=self.action_size).float().to(self.device)
        # Start the target network from the online weights rather than an unrelated random init.
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.MSELoss()

    def update_target_model(self):
        """Synchronise the target network with the online network.

        Copies the weights outright, or blends them in with factor ``self.TAU``
        when ``self.Soft_Update`` is set. This is done for both DQN and
        Double DQN, because replay() reads bootstrap values from the target
        network in either mode.
        """
        if not self.Soft_Update:
            self.target_model.load_state_dict(self.model.state_dict())
            return
        with torch.no_grad():
            for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
                target_param.copy_(target_param * (1.0 - self.TAU) + param * self.TAU)

    def remember(self, state, action, reward, next_state, done):
        """Store one transition and decay epsilon once training has started."""
        self.memory.append((state, action, reward, next_state, done))
        if len(self.memory) > self.train_start and self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def act(self, state):
        """Choose an action epsilon-greedily.

        Args:
            state: observation array; run() passes shape (1, state_size).
        Returns:
            An action index in ``range(self.action_size)``.
        """
        if np.random.random() <= self.epsilon:
            return random.randrange(self.action_size)
        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            q_values = self.model(state_t)
        # argmax over the flattened output, as np.argmax did; .item() also works for GPU tensors.
        return int(torch.argmax(q_values).item())

    def replay(self):
        """Run one gradient step on a random minibatch from replay memory.

        Does nothing until ``self.train_start`` transitions have been stored.
        """
        if len(self.memory) < self.train_start:
            return

        minibatch = random.sample(self.memory, min(len(self.memory), self.batch_size))
        n = len(minibatch)  # may be smaller than batch_size when memory is short
        states, actions, rewards, next_states, dones = zip(*minibatch)

        state = torch.from_numpy(np.asarray(states, dtype=np.float32).reshape(n, self.state_size)).to(self.device)
        next_state = torch.from_numpy(np.asarray(next_states, dtype=np.float32).reshape(n, self.state_size)).to(self.device)
        action = torch.from_numpy(np.asarray(actions, dtype=np.int64)).to(self.device).unsqueeze(1)
        reward = torch.from_numpy(np.asarray(rewards, dtype=np.float32)).to(self.device)
        done = torch.from_numpy(np.asarray(dones, dtype=np.float32)).to(self.device)

        q_values = self.model(state).gather(1, action).squeeze(1)

        # Bootstrap targets are constants: no gradient should flow through them.
        with torch.no_grad():
            if self.ddqn:
                # Double DQN: the online net selects the next action, the target net evaluates it.
                next_action = self.model(next_state).argmax(dim=1, keepdim=True)
                next_value = self.target_model(next_state).gather(1, next_action).squeeze(1)
            else:
                # Standard DQN: max over the target net's next-state values.
                next_value = self.target_model(next_state).max(1)[0]
            target = reward + (1 - done) * self.gamma * next_value

        loss = self.criterion(q_values, target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def load(self, name):
        """Load online-network weights from ``name`` onto ``self.device``."""
        self.model.load_state_dict(torch.load(name, map_location=self.device))

    def save(self, name):
        """Save the online network's state_dict to ``name``."""
        torch.save(self.model.state_dict(), name)

    def step(self, action):
        """Advance the environment one step.

        Returns:
            The environment's (next_state, reward, done, info) tuple.
        """
        next_state, reward, done, info = self.env.step(action)
        return next_state, reward, done, info

    def reset(self, action=None, next_state=None):
        """Reset the environment and refill the frame stack from fresh renders.

        ``action`` and ``next_state`` are unused. They default to ``None`` so
        the method can also be called as ``reset()``.

        Returns:
            The value returned by the last GetImage() call (None if REM_STEP is 0).
        """
        self.env.reset()
        state = None
        for _ in range(self.REM_STEP):
            state = self.GetImage()
        return state

    def imshow(self, image, rem_step=0):
        """Show frame ``image[rem_step]`` inline via Colab's cv2_imshow."""
        cv2_imshow(image[rem_step, ...])
        if cv2.waitKey(25) & 0xFF == ord("q"):
            cv2.destroyAllWindows()

    def GetImage(self):
        """Render a frame, binarise it, and push it onto the frame stack.

        Returns:
            The frame stack with a leading axis added, shape
            (1, REM_STEP, ROWS, COLS); stack index 0 holds the newest frame.
        """
        img = self.env.render(mode='rgb_array')

        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        gray = cv2.resize(gray, (self.COLS, self.ROWS), interpolation=cv2.INTER_CUBIC)
        # Keep only pure-white pixels: they become 1.0, everything else 0.0.
        gray[gray < 255] = 0
        frame = gray / 255

        # Shift older frames back one slot and store the new frame at index 0.
        self.image_memory = np.roll(self.image_memory, 1, axis=0)
        self.image_memory[0, :, :] = frame

        # cv2_imshow casts to uint8, so scale the [0, 1] stack back to [0, 255] for display.
        self.imshow(self.image_memory * 255, 0)

        return np.expand_dims(self.image_memory, axis=0)

    def PlotModel(self, score, episode):
        """Record an episode score, redraw the score/average curves and save them.

        Returns:
            The running mean of all recorded scores, as a string cut to 5 characters.
        """
        self.scores.append(score)
        self.episodes.append(episode)
        self.average.append(sum(self.scores) / len(self.scores))
        # Redraw from scratch; otherwise every call stacks another copy of each full curve.
        pylab.clf()
        pylab.plot(self.episodes, self.average, 'r')
        pylab.plot(self.episodes, self.scores, 'b')
        pylab.ylabel('Score', fontsize=18)
        pylab.xlabel('Episode', fontsize=18)
        dqn = 'DDQN_' if self.ddqn else 'DQN_'
        softupdate = '_soft' if self.Soft_Update else ''
        try:
            pylab.savefig(dqn + self.env_name + softupdate + ".png")
        except OSError as err:
            # Saving the plot is best-effort; report the failure instead of hiding it.
            print("Could not save plot: {}".format(err))

        return str(self.average[-1])[:5]

    def run(self):
        """Train for ``self.EPISODES`` episodes, then save the online weights."""
        for e in range(self.EPISODES):
            state = np.reshape(self.env.reset(), [1, self.state_size])
            done = False
            i = 0
            while not done:
                action = self.act(state)
                next_state, reward, done, info = self.step(action)
                next_state = np.reshape(next_state, [1, self.state_size])
                # Penalise termination unless the episode reached the step cap.
                reward = reward if not done or i == self.env._max_episode_steps - 1 else -100
                self.remember(state, action, reward, next_state, done)
                state = next_state
                i += 1
                if done:
                    self.update_target_model()
                    average = self.PlotModel(i, e)
                    print("episode: {}/{}, score: {}, e: {:.2}, average: {}".format(e, self.EPISODES, i, self.epsilon, average))
                    if e == self.EPISODES - 1:
                        print("Saving trained model as cartpole-ddqn.pt")
                        self.save('/content/cartpole-ddqn.pt')
                self.replay()

    def test(self):
        """Load the saved weights and play ``self.EPISODES`` greedy episodes.

        Each step redraws the rendered frame inline, with the chosen action,
        reward, step index and episode index overlaid.
        """
        self.load("/content/cartpole-ddqn.pt")
        for e in range(self.EPISODES):
            state = self.env.reset()
            done = False
            i = 0
            # Clear the figure so images and labels from earlier episodes do not pile up.
            plt.clf()
            img = plt.imshow(self.env.render(mode='rgb_array'))
            text_action = plt.text(0, -10, '', fontsize=12, color='red')
            text_reward = plt.text(100, -10, '', fontsize=12, color='blue')
            text_step = plt.text(300, -10, '', fontsize=12, color='green')
            text_episode = plt.text(500, -10, '', fontsize=12, color='blue')

            while not done:
                self.env.render()
                img.set_data(self.env.render(mode='rgb_array'))

                # Move every state to the model's device, not just the first one.
                state_t = torch.as_tensor(np.reshape(state, [1, self.state_size]), dtype=torch.float32, device=self.device)
                with torch.no_grad():
                    action = int(torch.argmax(self.model(state_t)).item())
                next_state, reward, done, _ = self.env.step(action)
                text_action.set_text(f'Action: {action}')
                text_reward.set_text(f'Reward: {reward}')
                text_step.set_text(f'Step: {i}')
                text_episode.set_text(f'Episode: {e}')

                state = next_state
                i += 1

                display.display(plt.gcf())
                display.clear_output(wait=True)

            print("episode: {}/{}, score: {}".format(e, self.EPISODES, i))


In [None]:
# Build the CartPole agent and train it; run() saves the final weights to
# /content/cartpole-ddqn.pt, which agent.test() loads in the next cell.
agent = DQNAgent(env_name='CartPole-v1')
agent.run()


In [None]:
# Replay the saved policy (loaded from /content/cartpole-ddqn.pt) with inline rendering.
agent.test()