In [None]:
import gymnasium as gym
import numpy as np
from datetime import datetime, timedelta
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Config:
    """Hyper-parameters for the CartPole DQN experiment, grouped by concern.

    Every value is a plain attribute so the notebook can read (or override)
    it after construction, e.g. ``config.batch_size``.
    """

    def __init__(self):
        # --- environment ---------------------------------------------------
        self.env_name = "CartPole-v1"
        self.gamma = 0.99          # discount factor
        self.num_action = 2        # size of the discrete action space
        self.state_dim = 4         # length of the observation vector

        # --- replay buffer -------------------------------------------------
        self.buffer_size = 100_000
        self.batch_size = 64

        # --- optimisation / schedule ---------------------------------------
        self.total_episodes = 2_000
        self.learning_rate = 2.3e-3
        self.weight_decay = 1e-4
        self.start_training_step = 1_000   # env steps collected before the first update
        self.train_frequency = 4           # run an update every N env steps
        self.epochs = 1                    # gradient steps per update
        self.test_frequency = 10           # evaluate every N training episodes
        self.num_test_episodes = 10        # evaluation episodes per evaluation round
        self.save_frequency = 1_000        # checkpoint every N training episodes
        self.save_path = 'best_model.pth'
        gpu_available = torch.cuda.is_available()
        self.device = "cuda" if gpu_available else "cpu"

        # --- target network ------------------------------------------------
        self.use_soft_update = True
        self.update_frequency = 100        # hard-update period (env steps)
        self.tau = 0.005                   # soft-update mixing coefficient

        # --- epsilon-greedy exploration -------------------------------------
        self.init_epsilon = 1.0
        self.end_epsilon = 0.04
        self.exploration_fraction = 0.16
        # Number of episodes over which epsilon is annealed to `end_epsilon`.
        self.decay_step = self.total_episodes * self.exploration_fraction


# Single shared configuration instance used by the cells below.
config = Config()

In [None]:
class Replay_Buffer():
    """Fixed-capacity circular store of (state, action, reward, next_state, done) transitions.

    Once `buffer_size` transitions have been added, new ones overwrite the
    oldest. Storage uses float32 / int64 so that sampled batches already have
    the dtypes the networks expect and no float64 leaks into the loss.
    """

    def __init__(self, buffer_size, state_dim):
        """Allocate storage for `buffer_size` transitions with `state_dim`-long states."""
        self.buffer_size = buffer_size
        self.real_size = 0   # number of valid transitions currently stored
        self.index = 0       # slot the next `add` will write to

        self.states = np.zeros((buffer_size, state_dim), dtype=np.float32)
        self.actions = np.zeros((buffer_size,), dtype=np.int64)
        self.rewards = np.zeros((buffer_size,), dtype=np.float32)
        self.dones = np.zeros((buffer_size,), dtype=bool)
        self.next_states = np.zeros((buffer_size, state_dim), dtype=np.float32)

    def add(self, state, action, reward, next_state, done):
        """Write one transition at the current slot, overwriting the oldest when full."""
        self.states[self.index] = state
        self.actions[self.index] = action
        self.rewards[self.index] = reward
        self.next_states[self.index] = next_state
        self.dones[self.index] = done

        self.real_size = min(self.real_size+1, self.buffer_size)
        self.index = (self.index+1) % self.buffer_size

    def sample(self, batchsize):
        """Draw `batchsize` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of CPU
            tensors: states / next_states are float32 ``(batchsize, state_dim)``;
            actions are int64, rewards float32 and dones float32 (1.0 = done),
            each of shape ``(batchsize,)``.

        Raises:
            ValueError: if fewer than `batchsize` transitions are stored.
        """
        if batchsize > self.real_size:
            raise ValueError(
                f"Cannot sample {batchsize} transitions: buffer only holds {self.real_size}"
            )
        idxs = np.random.choice(self.real_size, batchsize, replace=False).astype(np.int64)
        # Fancy indexing copies, so the returned tensors never alias the storage.
        states = torch.tensor(self.states[idxs]).float()
        actions = torch.tensor(self.actions[idxs]).long().reshape(-1)
        # Cast rewards to float32: a float64 reward tensor would promote the
        # TD targets (and the loss) to double precision.
        rewards = torch.tensor(self.rewards[idxs]).float().reshape(-1)
        next_states = torch.tensor(self.next_states[idxs]).float()
        dones = torch.tensor(self.dones[idxs].astype(np.float32)).reshape(-1)
        return states, actions, rewards, next_states, dones


In [None]:
class Model(nn.Module):
    """Q-network with a shared two-layer trunk and separate value / advantage heads.

    The heads are combined as ``Q = V + A - mean(A)``, where the mean is taken
    over the action dimension.
    """

    def __init__(self, state_dim=4, num_action=2):
        super().__init__()
        hidden_units = 64
        # Layers are created in the same order as before so that seeded
        # initialisation yields the same weights.
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.V = nn.Linear(hidden_units, 1)
        self.A = nn.Linear(hidden_units, num_action)

    def forward(self, state):
        """Map a batch of states to one Q-value per action."""
        features = torch.relu(self.fc2(torch.relu(self.fc1(state))))
        advantage = self.A(features)
        value = self.V(features)
        # Centre the advantages over the action dimension before adding V.
        return value + advantage - advantage.mean(dim=-1, keepdim=True)

In [None]:
class Agent():
    """Double-DQN agent: epsilon-greedy acting, replay-based training and periodic evaluation.

    The online network (`model`) picks the greedy next action and the target
    network (`target_model`) evaluates it. The target network tracks the
    online one either by Polyak averaging (`use_soft_update=True`, coefficient
    `tau`) or by a hard copy every `update_frequency` environment steps.
    """

    def __init__(self, model, target_model, env, env_name, replay_buffer, num_action, batch_size, learning_rate, weight_decay, gamma, total_episodes,
                 train_frequency, epochs, test_frequency, save_frequency, num_test_episodes, start_training_step, save_path, use_soft_update, 
                 tau, init_epsilon, decay_step, end_epsilon, update_frequency, device):
        """Store the collaborators and hyper-parameters and build the Adam optimizer.

        `env` is the training environment; `env_name` is used to create a fresh
        environment for every evaluation episode. `replay_buffer` must provide
        ``add(state, action, reward, next_state, done)`` and ``sample(batch_size)``.
        """
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        self.target_model = target_model
        self.use_soft_update = use_soft_update
        self.tau = tau
        self.batch_size = batch_size
        self.update_frequency = update_frequency

        self.replay_buffer = replay_buffer

        self.env = env
        self.env_name = env_name
        self.gamma = gamma
        self.num_action = num_action

        self.total_episodes = total_episodes
        self.train_frequency = train_frequency
        self.epochs = epochs
        self.test_frequency = test_frequency
        self.save_frequency = save_frequency
        self.num_test_episodes = num_test_episodes
        self.save_path = save_path
        self.start_training_step = start_training_step

        self.init_epsilon = init_epsilon
        self.decay_step = decay_step
        self.end_epsilon = end_epsilon

        self.device = device

        # Best single-episode evaluation return so far. Initialised here so
        # that test() can be called before learn() without an AttributeError.
        self.max_test_rewards = 0.

    def update_weights(self, force_hard=False):
        """Move the target network towards the online network.

        Args:
            force_hard: copy the online weights verbatim even when soft updates
                are enabled (used for the initial synchronisation).
        """
        soft = self.use_soft_update and not force_hard
        with torch.no_grad():
            for target_param, online_param in zip(self.target_model.parameters(), self.model.parameters()):
                if soft:
                    target_param.copy_(self.tau * online_param + (1.0 - self.tau) * target_param)
                else:
                    target_param.copy_(online_param)

    def select_action(self, state, epsilon):
        """Return a random action with probability `epsilon`, else the greedy one.

        Passing a negative `epsilon` always yields the greedy action.
        """
        if np.random.rand() > epsilon:
            # Force float32 so float64 observations do not clash with the
            # network's float32 weights.
            state = torch.tensor(np.array(state), dtype=torch.float32).unsqueeze(0).to(self.device)
            with torch.no_grad():
                Qs = self.model(state)[0]
                action = Qs.argmax().item()
        else:
            action = int(np.random.choice(self.num_action, 1)[0])
        return action
    
    def decay_epsilon(self, step):
        """Linearly anneal epsilon from `init_epsilon` to `end_epsilon` over `decay_step` steps."""
        if step < self.decay_step:
            epsilon = self.init_epsilon - (step / self.decay_step) * (self.init_epsilon - self.end_epsilon)
        else:
            epsilon = self.end_epsilon
        return epsilon
    
    def train(self):
        """Run one double-DQN gradient step on a sampled batch and return the loss as a float."""
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        # Normalise device and dtype once; this also guards against buffers
        # that hand back float64 rewards.
        states = states.to(self.device).float()
        next_states = next_states.to(self.device).float()
        actions = actions.to(self.device).long().reshape(-1, 1)
        rewards = rewards.to(self.device).float().reshape(-1)
        dones = dones.to(self.device).float().reshape(-1)

        with torch.no_grad():
            # Double DQN: the online net chooses the action, the target net scores it.
            next_actions = self.model(next_states).argmax(-1, keepdim=True)
            next_Qs = self.target_model(next_states).gather(1, next_actions).reshape(-1)
            targets = rewards + self.gamma * (1 - dones) * next_Qs
        Qs = self.model(states).gather(1, actions).reshape(-1)
        loss = ((targets - Qs)**2).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
    
    def test(self):
        """Play one greedy episode in a fresh environment.

        Saves the online weights to `save_path` whenever the episode return
        beats the best evaluation return seen so far.

        Returns:
            (total_rewards, step): the episode return and its length in steps.
        """
        env = gym.make(self.env_name, render_mode="rgb_array")
        total_rewards = 0.
        step = 0
        try:
            state, info = env.reset()
            done = False
            while not done:
                action = self.select_action(state, -1)
                next_state, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                total_rewards += reward
                state = next_state
                step += 1
        finally:
            # Release the environment even if the rollout raised.
            env.close()

        if total_rewards > self.max_test_rewards:
            self.max_test_rewards = total_rewards
            torch.save(self.model.state_dict(), self.save_path)
        return total_rewards, step
    
    def learn(self):
        """Train for `total_episodes` episodes with periodic evaluation and checkpointing.

        Populates `episode_rewards`, `episode_steps`, `test_episode_rewards`,
        `test_episode_steps` and `losses` on the agent.
        """
        # Start from identical networks. A soft update here would leave the
        # target at (almost) its own unrelated random initialisation.
        self.update_weights(force_hard=True)
        self.current_step = 0
        epsilon = self.init_epsilon

        self.max_test_rewards = 0.
        self.episode_rewards = []
        self.episode_steps = []
        self.max_episode_rewards = 0.
        self.test_episode_rewards = []
        self.test_episode_steps = []
        self.losses = []

        for episode in range(1, self.total_episodes+1):
            start_time = datetime.now()
            state, info = self.env.reset()
            done = False
            total_rewards = 0.
            episode_step = 0
            while not done:
                self.current_step += 1
                episode_step += 1
                action = self.select_action(state, epsilon)
                next_state, reward, terminated, truncated, info = self.env.step(action)
                total_rewards += reward
                done = terminated or truncated
                # Store only true termination: a time-limit truncation is not
                # a terminal state, so its target must still bootstrap.
                self.replay_buffer.add(state, action, reward, next_state, terminated)
                state = next_state
                if self.current_step % self.train_frequency == 0 and self.current_step >= self.start_training_step:
                    for _ in range(self.epochs):
                        loss = self.train()
                        self.losses.append(loss)
                if self.current_step % self.update_frequency == 0 or self.use_soft_update:
                    self.update_weights()

            self.episode_rewards.append(total_rewards)
            self.episode_steps.append(episode_step)
            self.max_episode_rewards = max(self.max_episode_rewards, total_rewards)
            if episode % 100 == 0:
                mean_rewards = np.mean(self.episode_rewards[-100:])
                # No gradient step may have run yet; avoid averaging an empty list.
                mean_loss = np.mean(self.losses[-1000:]) if self.losses else float("nan")
                print(f"Step {self.current_step}, Episode {episode}: Steps: {episode_step}, Rewards: {total_rewards}, "
                      f"Mean_Rewards: {mean_rewards:.4f}, "
                      f"Max_Rewards: {self.max_episode_rewards}, Loss: {mean_loss:.4f}, "
                      f"Duration: {datetime.now() - start_time}, epsilon: {epsilon:.6f}")        

            if episode % self.test_frequency == 0:
                start_time = datetime.now()
                for _ in range(self.num_test_episodes):
                    test_rewards, test_steps = self.test()
                    self.test_episode_rewards.append(test_rewards)
                    self.test_episode_steps.append(test_steps)
                if episode % 100 == 0:
                    print(f"Test Episode: {len(self.test_episode_rewards)} Mean Rewards: {np.array(self.test_episode_rewards[-self.num_test_episodes:]).mean():.4f}, "
                          f"Max_Test_Rewards: {self.max_test_rewards}, Duration: {datetime.now() - start_time}")
            if episode % self.save_frequency == 0:
                torch.save(self.model.state_dict(), f"{episode}.pth")
            # Epsilon is annealed once per episode, not per environment step.
            epsilon = self.decay_epsilon(episode)

In [None]:
# Build the replay buffer, the online/target networks, the training
# environment and the agent from the shared `config`.
replay_buffer = Replay_Buffer(config.buffer_size, config.state_dim)
# Pass num_action explicitly so the networks follow the config instead of
# silently relying on Model's default of 2.
model = Model(config.state_dim, config.num_action).to(config.device)
target_model = Model(config.state_dim, config.num_action).to(config.device)
env = gym.make(config.env_name, render_mode="rgb_array")

# Keyword arguments keep the 25-parameter constructor call from silently
# mis-ordering hyper-parameters.
agent = Agent(
    model=model,
    target_model=target_model,
    env=env,
    env_name=config.env_name,
    replay_buffer=replay_buffer,
    num_action=config.num_action,
    batch_size=config.batch_size,
    learning_rate=config.learning_rate,
    weight_decay=config.weight_decay,
    gamma=config.gamma,
    total_episodes=config.total_episodes,
    train_frequency=config.train_frequency,
    epochs=config.epochs,
    test_frequency=config.test_frequency,
    save_frequency=config.save_frequency,
    num_test_episodes=config.num_test_episodes,
    start_training_step=config.start_training_step,
    save_path=config.save_path,
    use_soft_update=config.use_soft_update,
    tau=config.tau,
    init_epsilon=config.init_epsilon,
    decay_step=config.decay_step,
    end_epsilon=config.end_epsilon,
    update_frequency=config.update_frequency,
    device=config.device,
)

In [None]:
# Run the full training loop; per-episode metrics are collected on `agent`
# (episode_rewards, test_episode_rewards, ...) and read by the plotting cell below.
agent.learn()

In [None]:
import matplotlib.pyplot as plt

# Learning curves: trailing means of training / evaluation returns, plus
# markers for evaluation episodes (and whole evaluation rounds) that reached
# the maximum return.
MAX_RETURN = 500           # return highlighted as "solved" (CartPole-v1 time limit)
TRAILING_WINDOW = 100      # episodes in the "[Last 100]" trailing means
num_test = agent.num_test_episodes   # evaluation episodes per evaluation round


def trailing_mean(values, window):
    """Mean of the last `window` entries ending at each index (fewer at the start)."""
    return [values[max(0, i - window + 1):i + 1].mean() for i in range(len(values))]


episode_rewards = np.array(agent.episode_rewards)
test_episode_rewards = np.array(agent.test_episode_rewards)

train_episodes = list(range(1, len(episode_rewards)+1))
test_episodes = list(range(1, len(test_episode_rewards)+1, 1))
# One x position per completed evaluation round: the index of its last test
# episode. The stride is the number of episodes per round, not test_frequency.
test_episodes_per_test = list(range(num_test, len(test_episode_rewards)+1, num_test))

# Windows are exactly `window` long (the previous slices were one element too
# wide, so a round also needed the preceding round's last episode at 500).
mean_rewards = trailing_mean(episode_rewards, TRAILING_WINDOW)
mean_test_rewards_l100 = trailing_mean(test_episode_rewards, TRAILING_WINDOW)
mean_test_rewards_per_test = [
    test_episode_rewards[last - num_test:last].mean() for last in test_episodes_per_test
]

idx_full_rewards = test_episode_rewards == MAX_RETURN
idx_full_mean_rewards = np.array(mean_test_rewards_per_test) == MAX_RETURN

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(train_episodes, mean_rewards, label=f'Train Rewards [Last {TRAILING_WINDOW}]')
ax.plot(test_episodes, mean_test_rewards_l100, label=f'Test Rewards [Last {TRAILING_WINDOW}]')
ax.scatter(np.array(test_episodes)[idx_full_rewards],
           test_episode_rewards[idx_full_rewards],
           label=f'Test Rewards = {MAX_RETURN}', alpha=0.6, s=5, color='green')
ax.scatter(np.array(test_episodes_per_test)[idx_full_mean_rewards],
           np.array(mean_test_rewards_per_test)[idx_full_mean_rewards],
           label=f'Test Rewards [Last {num_test}] = {MAX_RETURN}', alpha=0.6, s=50, color='red')

ax.set_xlabel("Episode")
ax.set_ylabel("Reward")
ax.set_title("D3QN")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()