# DQN

In [1]:
import collections
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import gym
import math
import matplotlib.pyplot as plt
import numpy as np
import random
import torch

In [None]:
# Scratch check: cast a one-element float32 tensor to int64 (result is only displayed).
a = torch.tensor([1.0], dtype=torch.float32)
a.long()


In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for experience replay.

    Backed by ``collections.deque(maxlen=capacity)``: once ``capacity``
    transitions are held, each new push evicts the oldest one.

    Args:
        capacity: Maximum number of transitions kept (``None`` = unbounded).
        min_size: :attr:`threshold` becomes true once the number of stored
            transitions *exceeds* this value. Defaults to 300, the value that
            used to be hard-coded. It is clamped to ``capacity - 1`` so that a
            full buffer always counts as ready.
    """

    def __init__(self, capacity, min_size=300):
        self.buffer = collections.deque(maxlen=capacity)
        self.capacity = capacity
        if capacity is not None:
            # The deque can never hold more than `capacity` items. A larger
            # min_size would keep `threshold` False forever, so training
            # would silently never start.
            min_size = min(min_size, max(capacity - 1, 0))
        self.min_size = min_size

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            An iterator over five tuples (states, actions, rewards,
            next_states, dones), each of length ``batch_size``.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        return zip(*batch)

    def __len__(self):
        return len(self.buffer)

    @property
    def threshold(self):
        """True once more than ``min_size`` transitions are stored."""
        return len(self.buffer) > self.min_size


class TriMLP(nn.Module):
    """Three linear layers: two ReLU-activated hidden layers of equal width,
    then a linear output layer with no activation.

    Args:
        n_inputs: Size of the input's last dimension.
        n_hiddens: Width of both hidden layers.
        n_outputs: Size of the output's last dimension.
    """

    def __init__(self, n_inputs, n_hiddens, n_outputs):
        super().__init__()
        # Names fc1/fc2/fc3 (and their creation order) fix the state_dict keys
        # and the parameter initialisation order.
        self.fc1 = nn.Linear(n_inputs, n_hiddens)
        self.fc2 = nn.Linear(n_hiddens, n_hiddens)
        self.fc3 = nn.Linear(n_hiddens, n_outputs)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class DQN:
    """Deep Q-Network agent with a policy network and a target network.

    ``policy_net`` is trained with Adam on the mean-squared TD error. The
    bootstrap value comes from ``target_net``, which starts as a copy of
    ``policy_net``; keeping it in sync afterwards is the caller's job
    (e.g. ``target_net.load_state_dict(policy_net.state_dict())``).

    Args:
        n_states: Network input size (observation vector length).
        n_hiddens: Hidden-layer width of both networks.
        n_actions: Network output size (number of discrete actions).
        lr: Adam learning rate.
        gamma: Discount factor in the TD target.
        device: Torch device for both networks. ``None`` picks ``'cuda'``
            when available, else ``'cpu'``.
        epsilon_decay: Rate of the exponential epsilon schedule, per call
            to :meth:`take_action`.
    """

    def __init__(self, n_states, n_hiddens, n_actions, lr=3e-4,
                 gamma=0.99, device=None, epsilon_decay=3e-4):
        if device is None:
            # Previously hard-coded to 'cuda', which fails on CPU-only hosts.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.n_states = n_states
        self.n_hiddens = n_hiddens
        self.n_actions = n_actions
        self.frame_cnt = 0

        self.policy_net = TriMLP(n_states, n_hiddens, n_actions).to(device)
        self.target_net = TriMLP(n_states, n_hiddens, n_actions).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.loss = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr)

        self.lr = lr
        self.gamma = gamma
        self.device = device
        self.epsilon_decay = epsilon_decay

    def take_action(self, state):
        """Choose an action epsilon-greedily and advance the frame counter.

        The counter is incremented before epsilon is read, so the first call
        already uses the schedule value for frame 1.

        Returns:
            An int in ``range(self.n_actions)``.
        """
        self.frame_cnt += 1
        if np.random.random() > self.epsilon:
            state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            # Pure inference: no autograd graph is needed for action selection.
            with torch.no_grad():
                return self.policy_net(state).max(1)[1].item()
        return random.randrange(0, self.n_actions)

    def update(self, batch):
        """Run one optimisation step on a batch of transitions.

        Args:
            batch: Iterable of five sequences (states, actions, rewards,
                next_states, dones), e.g. what ``ReplayBuffer.sample``
                returns. ``dones`` must be true only for real terminal
                states, since it masks out the bootstrap term.
        """
        states, actions, rewards, next_states, dones = (
            torch.tensor(np.asarray(x), dtype=torch.float32, device=self.device)
            for x in batch
        )
        # The shared float conversion also hit the actions; gather needs int64 of shape (B, 1).
        actions = actions.long().unsqueeze(1)

        # squeeze(1), not squeeze(): a batch of one must keep shape (1,) to match q_target.
        q_values = self.policy_net(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
            q_target = rewards + self.gamma * next_q_values * (1 - dones)

        loss = self.loss(q_values, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        # Element-wise clamp of every gradient to [-1, 1]; skips params without a grad.
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

    @property
    def epsilon(self):
        """Exploration rate for the current ``frame_cnt``.

        Held at 0.5 while ``exp(-frame_cnt * epsilon_decay) >= 0.5``, then
        follows that exponential, and is fixed at 0.005 once the exponent
        exceeds 5.3 (``exp(-5.3)`` is about 0.005, so the floor meets the curve).
        """
        if self.frame_cnt * self.epsilon_decay > 5.3:
            return 0.005
        return min(0.5, math.exp(-self.frame_cnt * self.epsilon_decay))


def train(env, agent, buffer, batch_size, epochs, target_update):
    """Train ``agent`` on ``env`` for ``epochs`` episodes, then plot the returns.

    Every environment step is pushed into ``buffer``. Once
    ``buffer.threshold`` is true, each step is followed by one
    ``agent.update`` on a sampled batch. The target network is synced from
    the policy network every ``target_update`` episodes.

    Returns:
        ``(rewards, ma_rewards)``: the per-episode returns and their
        exponential moving average (weight 0.1 on the newest episode).
    """
    rewards, ma_rewards = [], []
    # Progress-bar length: the env's step limit if it advertises one
    # (this replaces the hard-coded 500 used for CartPole-v1).
    max_steps = getattr(getattr(env, 'spec', None), 'max_episode_steps', None)

    for i in range(epochs):
        state, _ = env.reset()
        done = False
        ep_reward = 0

        with tqdm(total=max_steps) as pbar:
            pbar.set_description(f'Round {i+1:3d}')
            while not done:
                action = agent.take_action(state)
                next_state, reward, terminated, truncated, _ = env.step(action)
                ep_reward += reward
                # Only a real termination is stored as "done". A time-limit
                # truncation is not a terminal state, so the TD target must
                # still bootstrap from next_state.
                buffer.push(state, action, reward, next_state, terminated)
                done = terminated or truncated
                state = next_state

                if buffer.threshold:
                    agent.update(buffer.sample(batch_size))

                pbar.update(1)
                pbar.set_postfix({'Reward': f'{ep_reward:3.0f}'})

        if (i + 1) % target_update == 0:
            agent.target_net.load_state_dict(agent.policy_net.state_dict())

        rewards.append(ep_reward)
        if ma_rewards:
            ma_rewards.append(ma_rewards[-1] * 0.9 + ep_reward * 0.1)
        else:
            ma_rewards.append(ep_reward)

    plt.plot(rewards, label='Reward')
    plt.plot(ma_rewards, label='Moving Average Reward')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.legend()
    plt.show()
    return rewards, ma_rewards
    

In [None]:
# Run configuration for CartPole.
batch_size, capacity = 64, 10000
epochs, target_update = 300, 10

# Agent: 4 inputs, 128 hidden units, 2 actions.
env = gym.make("CartPole-v1")
agent = DQN(n_states=4, n_hiddens=128, n_actions=2)
buffer = ReplayBuffer(capacity=capacity)

In [None]:
train(env, agent, buffer, batch_size=batch_size, epochs=epochs, target_update=target_update)

In [None]:
def test_policy(env, policy_net, device, episodes=5, render=True):
    for i in range(episodes):
        state, _ = env.reset()
        done = False
        total_reward = 0

        while not done:
            if render:
                env.render()

            state_tensor = torch.tensor([state], dtype=torch.float32, device=device)
            with torch.no_grad():
                q_values = policy_net(state_tensor)
                action = q_values.max(1)[1].item()

            next_state, reward, done, truncated, _ = env.step(action)
            total_reward += reward
            done = done or truncated
            state = next_state

        print(f"Episode {i + 1}: Total Reward = {total_reward}")
    env.close()

# Watch the trained greedy policy in a rendered CartPole window.
env = gym.make('CartPole-v1', render_mode="human")
# Use the agent's own device instead of a hard-coded 'cuda', so the input
# tensors always live where the network does.
test_policy(env, agent.policy_net, agent.device, episodes=5, render=True)


# Double DQN