# C51 (Categorical 51)

C51 (Categorical DQN with 51 atoms) is a distributional reinforcement learning algorithm that discretizes the return distribution into a fixed set of evenly spaced support atoms between $V_{\min}$ and $V_{\max}$. The original paper used 51 atoms, hence the name. A deep neural network estimates, for each action, the probability mass function of the return over these atoms. Training minimizes the cross-entropy between this prediction and the distributional Bellman target, after the target has been projected back onto the fixed support. Modeling the whole return distribution rather than only its mean gives the agent a richer learning signal. This can lead to improved performance and, in some settings, better exploration. It also lets the agent capture complex and multimodal return distributions, facilitating robust and efficient learning in challenging environments.

In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling.

    Transitions are kept in a circular list: once ``capacity`` entries are
    stored, each new ``push`` overwrites the oldest one.

    Args:
        capacity: maximum number of transitions kept; must be positive.

    Raises:
        ValueError: if ``capacity`` is not positive.
    """

    def __init__(self, capacity):
        # A non-positive capacity would otherwise only fail on the first push
        # (IndexError / ZeroDivisionError), far from the real mistake.
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Saves a transition, overwriting the oldest one when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.memory) < self.capacity:
            # While filling up, the write position always equals len(memory).
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions
                (propagated from ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored (at most ``capacity``)."""
        return len(self.memory)

class CategoricalDQN(nn.Module):
    """MLP mapping an observation to one categorical distribution per action.

    Two ReLU hidden layers feed a final linear layer with
    ``n_actions * n_atoms`` outputs. Those logits are reshaped to
    ``(batch, n_actions, n_atoms)`` and softmax-normalized over the atom axis,
    so each ``[b, a, :]`` slice is a probability vector.

    Args:
        obs_size: length of the (flat) observation vector.
        hidden_size: width of both hidden layers.
        n_actions: number of discrete actions.
        n_atoms: number of support atoms per distribution.
        v_min, v_max: support bounds; stored for callers, not used in ``forward``.
    """

    def __init__(self, obs_size, hidden_size, n_actions, n_atoms, v_min, v_max):
        super().__init__()
        self.obs_size, self.n_actions, self.n_atoms = obs_size, n_actions, n_atoms
        self.v_min, self.v_max = v_min, v_max

        # Creation order matters: it fixes parameter initialization under a
        # given seed and the state_dict layout.
        self.fc1 = nn.Linear(obs_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, n_actions * n_atoms)

    def forward(self, x):
        """Return per-action atom probabilities of shape ``(batch, n_actions, n_atoms)``.

        A single un-batched observation yields a leading batch dimension of 1
        because of the ``view(-1, ...)`` reshape.
        """
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        logits = self.fc3(features).view(-1, self.n_actions, self.n_atoms)
        return logits.softmax(dim=-1)

class CategoricalDQN_Agent:
    """Epsilon-greedy C51 agent that trains a ``CategoricalDQN`` one transition at a time.

    The agent uses a fixed, evenly spaced support of ``n_atoms`` return values
    in ``[v_min, v_max]``. An action's value is the expectation of its
    predicted distribution over that support. ``update`` takes one gradient
    step toward the C51 distributional Bellman target (Bellemare et al., 2017).

    NOTE(review): the bootstrap target is computed with the online network
    itself; the reference C51 algorithm uses a separate, periodically synced
    target network.

    Args:
        action_space: object exposing ``n`` (number of discrete actions) and
            ``sample()``; presumably a Gym-style ``Discrete`` space.
        observation_space: object exposing ``shape``; ``shape[0]`` is the
            network input size.
        hidden_size: width of the network's hidden layers.
        gamma: discount factor.
        epsilon_start, epsilon_end, epsilon_decay: parameters of the
            exponential schedule in ``get_epsilon``.
        learning_rate: Adam learning rate.
        batch_size: stored for callers; ``update`` trains on a single transition.
        n_atoms: number of support atoms; must be at least 2.
        v_min, v_max: support bounds; must satisfy ``v_min < v_max``.

    Raises:
        ValueError: if ``n_atoms < 2`` or ``v_max <= v_min``.
    """

    def __init__(self,
                 action_space,
                 observation_space,
                 hidden_size,
                 gamma,
                 epsilon_start,
                 epsilon_end,
                 epsilon_decay,
                 learning_rate,
                 batch_size,
                 n_atoms,
                 v_min,
                 v_max,
                 ):
        # The projection divides by the atom spacing, so a degenerate support
        # (a single atom, or an empty/inverted range) cannot be trained.
        if n_atoms < 2:
            raise ValueError(f"n_atoms must be at least 2, got {n_atoms}")
        if v_max <= v_min:
            raise ValueError(f"v_max ({v_max}) must be greater than v_min ({v_min})")

        self.action_space = action_space
        self.observation_space = observation_space
        self.hidden_size = hidden_size
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max

        # Fixed return support z_i = v_min + i * delta_z, for i = 0 .. n_atoms - 1.
        self.support = torch.linspace(v_min, v_max, n_atoms)
        self.delta_z = (v_max - v_min) / (n_atoms - 1)

        self.q_net = CategoricalDQN(observation_space.shape[0], hidden_size, action_space.n, n_atoms, v_min, v_max)
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=learning_rate)
        # KL(target || prediction). The target is constant w.r.t. the
        # parameters, so its gradient equals that of the C51 cross-entropy.
        self.loss_function = nn.KLDivLoss(reduction='batchmean')

    def get_action(self, state, epsilon):
        """Pick an action epsilon-greedily with respect to expected returns.

        Args:
            state: a single observation.
            epsilon: probability of returning ``action_space.sample()``.

        Returns:
            A random action with probability ``epsilon``. Otherwise, the
            index of the action whose predicted distribution has the highest
            mean over the support.
        """
        if np.random.rand() < epsilon:
            return self.action_space.sample()
        # Action selection needs no autograd graph.
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            probabilities = self.q_net(state).squeeze(0)  # (n_actions, n_atoms)
            q_values = torch.sum(probabilities * self.support, dim=-1)
        return torch.argmax(q_values).item()

    def _project_distribution(self, next_dist, reward, done):
        """Project the Bellman-shifted distribution back onto ``self.support``.

        Each atom z_j carrying mass ``next_dist[j]`` moves to
        ``reward + (1 - done) * gamma * z_j``, clipped to ``[v_min, v_max]``.
        Its mass is then split linearly between the two neighbouring support
        atoms.

        Args:
            next_dist: 1-D tensor of length ``n_atoms`` (probabilities).
            reward: scalar reward.
            done: 1.0 for a terminal transition, 0.0 otherwise.

        Returns:
            1-D tensor of length ``n_atoms`` whose total mass equals that of
            ``next_dist``.
        """
        tz = (reward + (1.0 - done) * self.gamma * self.support).clamp(self.v_min, self.v_max)
        b = (tz - self.v_min) / self.delta_z  # fractional atom index
        # Clamping guards against float error pushing b just outside [0, n_atoms-1].
        lower = b.floor().long().clamp(0, self.n_atoms - 1)
        upper = b.ceil().long().clamp(0, self.n_atoms - 1)
        # When b is integral (lower == upper) the upper weight is 0 and the
        # lower weight is 1, so no probability mass is dropped.
        upper_weight = b - lower.to(b.dtype)
        lower_weight = 1.0 - upper_weight

        target = torch.zeros(self.n_atoms, dtype=next_dist.dtype)
        target.index_add_(0, lower, next_dist * lower_weight)
        target.index_add_(0, upper, next_dist * upper_weight)
        return target

    def update(self, state, action, reward, next_state, done):
        """Take one C51 gradient step on a single transition.

        Args:
            state, next_state: single observations.
            action: index of the action taken in ``state``.
            reward: scalar reward received.
            done: truthy if ``next_state`` is terminal (no bootstrapping).

        Returns:
            The scalar loss value as a Python float.
        """
        state = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        next_state = torch.as_tensor(next_state, dtype=torch.float32).unsqueeze(0)
        action = int(action)

        # Predicted distribution for the action actually taken: (n_atoms,).
        current_dist = self.q_net(state)[0, action]
        # The network already outputs probabilities; take their log directly.
        # Applying log_softmax here would normalize them a second time.
        # The clamp avoids log(0) = -inf.
        current_log_probs = torch.log(current_dist.clamp(min=1e-8))

        with torch.no_grad():
            next_probs = self.q_net(next_state)[0]  # (n_actions, n_atoms)
            next_q_values = torch.sum(next_probs * self.support, dim=-1)
            next_action = torch.argmax(next_q_values)
            target_dist = self._project_distribution(next_probs[next_action], float(reward), float(done))

        loss = self.loss_function(current_log_probs.unsqueeze(0), target_dist.unsqueeze(0))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def get_epsilon(self, current_step):
        """Return the exploration rate for ``current_step``.

        Decays exponentially from ``epsilon_start`` (at step 0) toward
        ``epsilon_end``, with time constant ``epsilon_decay`` steps.
        """
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(-1.0 * current_step / self.epsilon_decay)
