In [None]:
import gymnasium as gym
import numpy as np
import random
from collections import deque


import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Build the Taxi environment and read the sizes of its two discrete spaces.
env = gym.make("Taxi-v3")
n_states, n_actions = (
    env.observation_space.n,   # 500
    env.action_space.n,        # 6
)

In [None]:
class ReplayBuffer:
    """Bounded FIFO store of transitions used for experience replay.

    Transitions live in a ``deque`` created with ``maxlen=max_size``, so once
    the buffer is full every new transition evicts the oldest one.
    """

    def __init__(self,max_size):
        """Create an empty buffer that keeps at most ``max_size`` transitions."""
        self.buffer = deque(maxlen=max_size)

    def add(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A tuple ``(states, actions, rewards, next_states, dones)`` of
            tensors whose leading dimension is ``batch_size``. ``actions`` is
            ``torch.long``; every other tensor is ``torch.float32``.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                is larger than the number of stored transitions.
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Collapse each field into ONE ndarray before handing it to torch.
        # Passing a tuple of ndarrays to torch.tensor converts element by
        # element, which is far slower and triggers a PyTorch warning.
        return (
            torch.from_numpy(np.asarray(states, dtype=np.float32)),
            torch.from_numpy(np.asarray(actions, dtype=np.int64)),
            torch.from_numpy(np.asarray(rewards, dtype=np.float32)),
            torch.from_numpy(np.asarray(next_states, dtype=np.float32)),
            torch.from_numpy(np.asarray(dones, dtype=np.float32)),
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
    
def one_hot(state, num_states=None):
    """Encode a discrete state index as a one-hot ``float32`` vector.

    Args:
        state: Integer index of the state to encode.
        num_states: Length of the returned vector. When omitted, the
            notebook-level ``n_states`` is used, so ``one_hot(state)`` keeps
            its original behaviour.

    Returns:
        A 1-D ``np.ndarray`` of dtype ``float32`` that is 1.0 at ``state``
        and 0.0 everywhere else.

    Raises:
        IndexError: if ``state`` is outside ``[0, num_states)``. Negative
            indices are rejected explicitly because NumPy would otherwise
            wrap them around and silently encode the wrong state.
    """
    size = n_states if num_states is None else num_states
    if not 0 <= state < size:
        raise IndexError(f"state {state} is out of range for {size} states")
    vec = np.zeros(size, dtype=np.float32)
    vec[state] = 1.0
    return vec

class QNetwork(nn.Module):
    """Fully connected network that outputs one Q-value per action.

    The layer stack is ``Linear(state_dim, 256) -> ReLU -> Linear(256, 128)
    -> ReLU -> Linear(128, action_dim)``, stored in ``self.net``.

    Args:
        state_dim: Width of the input vector. Defaults to the notebook-level
            ``n_states`` so that ``QNetwork()`` builds the same model as before.
        action_dim: Number of outputs. Defaults to the notebook-level
            ``n_actions``.
    """

    def __init__(self, state_dim=None, action_dim=None):
        super().__init__()
        # Fall back to the notebook globals only when no size is supplied, so
        # the class can also be built for other spaces or in isolation.
        in_features = n_states if state_dim is None else state_dim
        out_features = n_actions if action_dim is None else action_dim
        self.net = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, out_features)
        )

    def forward(self, x):
        """Return the Q-values produced by ``self.net`` for input ``x``."""
        return self.net(x)


In [31]:
# -------------------- HYPERPARAMETERS --------------------
# Optimisation
gamma = 0.99   # discount applied to the bootstrapped next-state value
lr = 0.001     # learning rate handed to the Adam optimizer

# Exploration: epsilon is multiplied by epsilon_decay after every episode
# and never drops below min_epsilon.
min_epsilon = 0.05
epsilon_decay = 0.99
epsilon = 1.0

# Replay buffer and target network
batch_size = 64
buffer_size = 50_000
learning_starts = 2000        # gradient steps begin once the buffer holds this many transitions
target_update_episodes = 5    # target network is re-synced every this many episodes

# Training length
num_episodes = 1000

In [32]:
# -------------------- INIT --------------------
# Seed every random source this run draws from, so that Restart & Run All
# reproduces the same weights, exploration choices and start states.
SEED = 42
random.seed(SEED)              # epsilon test and replay sampling
np.random.seed(SEED)
torch.manual_seed(SEED)        # network weight initialisation
env.action_space.seed(SEED)    # env.action_space.sample() during exploration
env.reset(seed=SEED)           # environment RNG used by later env.reset() calls

# Online network (trained) and target network (periodically synced copy).
q_net = QNetwork()
target_net = QNetwork()
target_net.load_state_dict(q_net.state_dict())
target_net.eval()

optimizer = optim.Adam(q_net.parameters(), lr=lr)
loss_fn = nn.MSELoss()

replay_buffer = ReplayBuffer(buffer_size)
global_step = 0   # incremented once per environment step by the training loop

In [33]:
# -------------------- POLICY --------------------
def epsilon_greedy(state, epsilon):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    With probability ``epsilon`` a random action is sampled from the
    environment's action space. Otherwise the action with the largest
    Q-value predicted by ``q_net`` for the one-hot encoded state is returned.
    """
    explore = random.random() < epsilon
    if explore:
        return env.action_space.sample()

    # Greedy branch: add a batch dimension of 1 and score every action.
    encoded = one_hot(state)
    batch = torch.tensor(encoded).unsqueeze(0)
    with torch.no_grad():
        action_values = q_net(batch)
    best_action = action_values.argmax(dim=1)
    return best_action.item()


# -------------------- TRAINING --------------------
# Each episode: act epsilon-greedily, store every transition, and once the
# buffer is warm take one gradient step per environment step.
for episode in range(num_episodes):
    state, _ = env.reset()
    done = False
    episode_reward = 0
    losses = []
    while not done:
        global_step += 1

        action = epsilon_greedy(state, epsilon)
        next_state, reward, terminated, truncated, _ = env.step(action)
        # The episode loop stops on either signal...
        done = terminated or truncated

        # ...but only a true termination may cut off bootstrapping. A
        # time-limit truncation does not make next_state terminal, so storing
        # `done` here would wrongly force its target value to zero.
        replay_buffer.add(
            one_hot(state),
            action,
            reward,
            one_hot(next_state),
            terminated,
        )

        state = next_state
        episode_reward += reward

        # Learn: wait for the warm-up, and never sample more than is stored.
        if len(replay_buffer) >= max(learning_starts, batch_size):
            states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

            # Q(s, a) of the online network for the actions actually taken.
            q_values = q_net(states)
            q_selected = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # TD target from the target network; no gradient flows through it.
            with torch.no_grad():
                q_next = target_net(next_states)
                max_next_q = q_next.max(dim=1)[0]
                targets = rewards + gamma * max_next_q * (1 - dones)

            loss = loss_fn(q_selected, targets)

            optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to 10.0 before the optimizer step.
            torch.nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
            optimizer.step()

            losses.append(loss.item())

    # Hard-copy the online weights into the target network periodically.
    if episode % target_update_episodes == 0:
        target_net.load_state_dict(q_net.state_dict())
        print("Updated target network at episode:",episode)

    # Decay exploration once per episode, floored at min_epsilon.
    epsilon = max(min_epsilon, epsilon * epsilon_decay)
    avg_loss = np.mean(losses) if losses else 0.0

    print(
        f"Episode {episode:4d} | "
        f"Reward : {episode_reward:4.0f} | "
        f"Epsilon: {epsilon:.3f} | "
        f"Avg Loss: {avg_loss:.4f} | "
        f"Buffer: {len(replay_buffer)}"
    )

env.close()

Updated target network at episode: 0
Episode    0 | Reward : -668 | Epsilon: 0.990 | Avg Loss: 0.0000 | Buffer: 200
Episode    1 | Reward : -713 | Epsilon: 0.980 | Avg Loss: 0.0000 | Buffer: 400
Episode    2 | Reward : -767 | Epsilon: 0.970 | Avg Loss: 0.0000 | Buffer: 600
Episode    3 | Reward : -700 | Epsilon: 0.961 | Avg Loss: 0.0000 | Buffer: 790
Episode    4 | Reward : -866 | Epsilon: 0.951 | Avg Loss: 0.0000 | Buffer: 990
Updated target network at episode: 5
Episode    5 | Reward : -812 | Epsilon: 0.941 | Avg Loss: 0.0000 | Buffer: 1190
Episode    6 | Reward : -758 | Epsilon: 0.932 | Avg Loss: 0.0000 | Buffer: 1390
Episode    7 | Reward : -632 | Epsilon: 0.923 | Avg Loss: 0.0000 | Buffer: 1590
Episode    8 | Reward : -686 | Epsilon: 0.914 | Avg Loss: 0.0000 | Buffer: 1790
Episode    9 | Reward : -767 | Epsilon: 0.904 | Avg Loss: 0.0000 | Buffer: 1990
Updated target network at episode: 10
Episode   10 | Reward : -794 | Epsilon: 0.895 | Avg Loss: 7.5940 | Buffer: 2190
Episode   11 

In [45]:

# -------------------- EVALUATION --------------------
# Run 100 greedy episodes (no exploration) on a fresh environment and report
# the per-episode return followed by the mean.
# NOTE(review): `rewards` is also the name of the sampled batch tensor in the
# training loop; this cell rebinds it to a list of episode returns.
eval_env = gym.make('Taxi-v3')
rewards = []
try:
    for i in range(100):
        state, info = eval_env.reset()
        done = False
        reward_sum = 0
        while not done:
            state_vec = torch.tensor(one_hot(state)).unsqueeze(0)
            # Inference only: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                action = q_net(state_vec).argmax(dim=1).item()
            state, reward, terminated, truncated, info = eval_env.step(action)
            done = terminated or truncated
            reward_sum += reward
        print(reward_sum,end=" ")
        rewards.append(reward_sum)
finally:
    # Release the environment even if an episode raises.
    eval_env.close()
print("Average evaluation reward over 100 episodes:", np.mean(rewards))

9 9 8 11 7 12 13 5 6 10 7 11 4 5 5 9 7 8 8 7 4 11 12 5 10 9 14 9 7 7 8 10 8 5 11 6 7 7 10 5 11 8 9 7 9 6 7 5 5 8 4 6 7 5 4 9 7 3 8 3 4 10 9 8 8 7 4 11 9 5 12 9 7 10 14 8 4 5 6 5 13 8 13 4 7 8 6 10 5 11 11 7 7 4 7 7 11 8 10 9 Average evaluation reward over 100 episodes: 7.78
