In [1]:
import gym
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque
import pygame
from IPython import display
import matplotlib


class OurModel(nn.Module):
    """Fully connected Q-network with an optional dueling head.

    Architecture: Linear(obs, 512) -> ReLU -> Linear(512, 256) -> ReLU -> Linear(256, 64) -> ReLU,
    followed by either a plain linear Q head or a dueling head.

    Args:
        input_shape: tuple whose first element is the observation size.
        action_space: number of discrete actions (width of the Q-value output).
        dueling: if True, Q(s, a) = V(s) + A(s, a) - mean_a A(s, a); otherwise a single
            linear layer maps features straight to Q-values.
    """

    def __init__(self, input_shape, action_space, dueling):
        super().__init__()
        self.dueling = dueling

        self.fc1 = nn.Linear(input_shape[0], 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 64)

        if dueling:
            self.state_value = nn.Linear(64, 1)
            self.action_advantage = nn.Linear(64, action_space)
        else:
            self.output = nn.Linear(64, action_space)

        self.init_weights()

    def init_weights(self):
        """Kaiming-uniform weights and a small constant (0.01) bias for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.01)

    def forward(self, x):
        """Return Q-values for ``x``.

        ``x`` may be a single state ``(obs,)`` or a batch ``(..., obs)``; the output has the
        same leading shape with the last dimension replaced by ``action_space``.
        """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))

        if not self.dueling:
            return self.output(x)

        state_value = self.state_value(x)
        action_advantage = self.action_advantage(x)
        # Average over the action axis (last dim) so unbatched inputs work too; subtracting
        # the mean makes the V/A decomposition identifiable.
        action_advantage_mean = action_advantage.mean(dim=-1, keepdim=True)
        return state_value + action_advantage - action_advantage_mean

class D3QN:
    """Dueling double DQN agent for a discrete-action gym environment (old gym API).

    Feature flags set in ``__init__``:
      * ``ddqn``           -- pick the next action with the online net, evaluate it with the
                              target net (otherwise: max over the target net).
      * ``Soft_Update``    -- Polyak-average the target net at rate ``TAU`` instead of
                              hard-copying it.
      * ``dueling``        -- build ``OurModel`` with the dueling head.
      * ``epsilot_greedy`` -- (sic) use an exponential epsilon schedule driven by
                              ``decay_step``; when False epsilon is decayed multiplicatively
                              on every ``act`` call.

    The constructor creates the environment, seeds it with 0 and opens a pygame display.
    """

    # Checkpoint written at the end of ``run`` and read back by ``test``.
    CHECKPOINT_PATH = "/content/D3QN.pt"

    def __init__(self, env_name):
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.env.seed(0)
        # Overrides the TimeLimit wrapper's per-episode step cap.
        self.env._max_episode_steps = 4000
        self.state_size = self.env.observation_space.shape[0]
        self.action_size = self.env.action_space.n

        self.EPISODES = 500
        # Replay-buffer capacity (this value used to be declared but ignored in favour of 2000).
        self.memory_size = 10000
        self.memory = deque(maxlen=self.memory_size)
        self.gamma = 0.95  # discount factor

        pygame.init()
        self.screen = pygame.display.set_mode((640, 480))
        self.clock = pygame.time.Clock()
        self.metadata = {"render_fps": 30}

        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.0005
        self.batch_size = 32

        self.ddqn = True
        self.Soft_Update = False
        self.dueling = True
        self.epsilot_greedy = False

        self.TAU = 0.1  # soft-update rate, only used when Soft_Update is True

        self.Save_Path = 'Models'
        if not os.path.exists(self.Save_Path): os.makedirs(self.Save_Path)
        self.scores, self.episodes, self.average = [], [], []

        # NOTE: not used by run()/test(), which use CHECKPOINT_PATH.
        self.Model_name = os.path.join(self.Save_Path, self.env_name + "_e_greedy.pth")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = OurModel(input_shape=(self.state_size,), action_space=self.action_size, dueling=self.dueling).to(self.device)
        self.target_model = OurModel(input_shape=(self.state_size,), action_space=self.action_size, dueling=self.dueling).to(self.device)
        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=0.01)
        # Start with identical online/target weights regardless of the soft-update flag.
        self.update_target_model(soft=False)

    def update_target_model(self, soft=None):
        """Sync the target network with the online network.

        Args:
            soft: True -> target = TAU * online + (1 - TAU) * target (Polyak averaging);
                False -> hard copy. None (default) follows ``self.Soft_Update``.
        """
        if soft is None:
            soft = self.Soft_Update
        if not soft:
            self.target_model.load_state_dict(self.model.state_dict())
            return
        with torch.no_grad():
            for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):
                target_param.mul_(1.0 - self.TAU).add_(param, alpha=self.TAU)

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the bounded replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state, decay_step):
        """Epsilon-greedy action selection.

        Args:
            state: observation reshaped to ``(1, state_size)`` by the caller.
            decay_step: global step counter, used only by the exponential schedule.

        Returns:
            ``(action, explore_probability)``.
        """
        if self.epsilot_greedy:
            explore_probability = self.epsilon_min + (self.epsilon - self.epsilon_min) * np.exp(-self.epsilon_decay * decay_step)
        else:
            if self.epsilon > self.epsilon_min:
                self.epsilon *= (1 - self.epsilon_decay)
            explore_probability = self.epsilon

        if explore_probability > np.random.rand():
            return random.randrange(self.action_size), explore_probability

        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            action_values = self.model(state_t)
        return torch.argmax(action_values).item(), explore_probability

    def replay(self):
        """One gradient step on a random minibatch (no-op until the buffer holds a batch).

        The TD target is built without autograd so no gradient ever reaches the target
        network. Only the taken action's entry differs from the prediction, so the loss
        (MSE over the full Q matrix) matches the original formulation.
        """
        if len(self.memory) < self.batch_size:
            return

        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)
        state = torch.as_tensor(np.vstack(states), dtype=torch.float32, device=self.device)
        action = torch.as_tensor(np.array(actions), dtype=torch.long, device=self.device)
        reward = torch.as_tensor(np.array(rewards), dtype=torch.float32, device=self.device)
        next_state = torch.as_tensor(np.vstack(next_states), dtype=torch.float32, device=self.device)
        done = torch.as_tensor(np.array(dones), dtype=torch.float32, device=self.device)

        q_pred = self.model(state)

        with torch.no_grad():
            q_target_next = self.target_model(next_state)
            if self.ddqn:
                # Double DQN: online net chooses, target net evaluates.
                best_next = torch.argmax(self.model(next_state), dim=1, keepdim=True)
                next_value = q_target_next.gather(1, best_next).squeeze(1)
            else:
                next_value = q_target_next.max(dim=1).values
            # Terminal transitions bootstrap nothing: target = reward.
            td_target = reward + self.gamma * next_value * (1.0 - done)

            target = q_pred.detach().clone()
            target[torch.arange(self.batch_size, device=self.device), action] = td_target

        self.optimizer.zero_grad()
        loss = nn.MSELoss()(q_pred, target)
        loss.backward()
        self.optimizer.step()

    def load(self, name):
        """Load online-network weights, remapping tensors onto ``self.device``."""
        self.model.load_state_dict(torch.load(name, map_location=self.device))

    def save(self, name):
        """Save the online network's ``state_dict`` to ``name``."""
        torch.save(self.model.state_dict(), name)

    def plot_model(self, score, episode):
        """Record an episode score, save a score/50-episode-average plot, return the average.

        Returns:
            The latest moving average as a string truncated to 5 characters.
        """
        self.scores.append(score)
        self.episodes.append(episode)
        recent = self.scores[-50:]
        self.average.append(sum(recent) / len(recent))

        fig, ax = plt.subplots()
        ax.plot(self.episodes, self.average, 'r')
        ax.plot(self.episodes, self.scores, 'b')
        ax.set_ylabel('Score', fontsize=18)
        ax.set_xlabel('Episode', fontsize=18)

        dqn = 'DDQN_' if self.ddqn else 'DQN_'
        softupdate = '_soft' if self.Soft_Update else ''
        dueling = '_Dueling' if self.dueling else ''
        greedy = '_Greedy' if self.epsilot_greedy else ''
        fig.savefig(dqn + self.env_name + softupdate + dueling + greedy + ".png")
        plt.close(fig)
        return str(self.average[-1])[:5]

    def run(self):
        """Train for ``EPISODES`` episodes.

        Every environment step: act, store the transition, do one replay update. At the end
        of each episode: update the target net, refresh the plot and print progress; after
        the last episode the weights are saved to ``CHECKPOINT_PATH``. An episode that
        terminates before the step limit has its final reward replaced with -100.
        """
        decay_step = 0
        for e in range(self.EPISODES):
            state = np.reshape(self.env.reset(), [1, self.state_size])
            done = False
            i = 0
            while not done:
                decay_step += 1
                action, explore_probability = self.act(state, decay_step)
                next_state, reward, done, _ = self.env.step(action)
                next_state = np.reshape(next_state, [1, self.state_size])
                # Penalise failure, but not reaching the step cap.
                if done and i != self.env._max_episode_steps - 1:
                    reward = -100
                self.remember(state, action, reward, next_state, done)
                state = next_state
                i += 1
                if done:
                    self.update_target_model()
                    average = self.plot_model(i, e)
                    print(f"episode: {e}/{self.EPISODES}, score: {i}, e: {explore_probability:.2}, average: {average}")
                    if e == (self.EPISODES - 1):
                        print("Saving trained model")
                        self.save(self.CHECKPOINT_PATH)
                        break
                self.replay()
        self.env.close()

    def test(self):
        """Load the checkpoint and play ``EPISODES`` greedy episodes with inline rendering.

        Each frame is drawn into a per-episode matplotlib figure with an overlay showing the
        chosen action, step reward, episode index and running score, and is pushed to the
        notebook output through ``IPython.display``.
        """
        self.load(self.CHECKPOINT_PATH)
        for e in range(self.EPISODES):
            state = np.reshape(self.env.reset(), [1, self.state_size])
            done = False
            i = 0
            # Fresh figure per episode so image/text artists don't pile up on one axes.
            fig, ax = plt.subplots()
            img = ax.imshow(self.env.render(mode='rgb_array'))
            text_action = ax.text(0, -10, '', fontsize=12, color='red')
            text_reward = ax.text(100, -10, '', fontsize=12, color='blue')
            text_step = ax.text(300, -10, '', fontsize=12, color='green')
            text_score = ax.text(500, -10, '', fontsize=12, color='blue')

            while not done:
                self.env.render()
                img.set_data(self.env.render(mode='rgb_array'))

                state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
                with torch.no_grad():
                    action = int(torch.argmax(self.model(state_t)).item())
                next_state, reward, done, _ = self.env.step(action)
                i += 1

                text_action.set_text(f'Action: {action}')
                text_reward.set_text(f'Reward: {reward}')
                text_step.set_text(f'Episode: {e}')
                # Previously formatted the builtin ``int`` type instead of the score.
                text_score.set_text(f'Score: {i}')

                state = np.reshape(next_state, [1, self.state_size])

                display.display(fig)
                display.clear_output(wait=True)

            plt.close(fig)
            print("episode: {}/{}, score: {}".format(e, self.EPISODES, i))


In [None]:
# Train a fresh agent; the checkpoint is written when the final episode ends.
ENV_NAME = 'CartPole-v1'
agent = D3QN(env_name=ENV_NAME)
agent.run()


In [None]:
# Replay the saved checkpoint greedily with inline rendering; relies on `agent` from the training cell.
agent.test()