# IQN (Implicit Quantile Network)

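IQN (Implicit Quantile Network; Dabney et al., 2018) is a distributional reinforcement-learning algorithm built on deep Q-networks. Rather than learning only the expected return, it learns the entire distribution of returns.

The network takes a quantile fraction τ ∈ (0, 1) as an extra input and outputs the corresponding quantile value of each action's return. The quantile function is therefore represented implicitly and can be queried at any τ. Averaging the quantile values over sampled fractions recovers ordinary Q-values for greedy action selection, while the quantiles themselves expose the uncertainty and risk associated with each action. This can make decision-making more robust in uncertain environments and can improve performance, especially in tasks with nonstationary and complex reward structures.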

In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class ReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling.

    Transitions live in a list used as a ring buffer: once ``capacity``
    entries exist, each new push overwrites the oldest one.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []  # (state, action, reward, terminated, next_state) tuples
        self.position = 0  # index the next push writes to

    def push(self, state, action, reward, terminated, next_state):
        """Save a transition, overwriting the oldest one when the buffer is full."""
        transition = (state, action, reward, terminated, next_state)
        if len(self.memory) < self.capacity:
            # While filling up, position == len(memory), so this is the same slot.
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class IQN(nn.Module):
    """Implicit Quantile Network (Dabney et al., 2018).

    Maps a batch of observations and a batch of quantile fractions ``taus`` to
    the corresponding quantile values of every action's return.

    Each fraction is embedded with cosine features ``cos(pi * i * tau)``
    (i = 0..n_cos-1) and projected to the state-feature width. The result is
    multiplied element-wise with the state features before the output layers.

    Args:
        obs_size: number of input features after flattening one observation.
        hidden_size1: width of the first state layer.
        hidden_size2: width of the state features, the tau embedding and the head.
        n_tau_samples: number of fractions callers usually sample. It is stored
            for reference only; ``forward`` accepts any number of fractions.
        n_actions: number of discrete actions.
        n_cos: number of cosine basis functions used to embed each fraction.
    """

    def __init__(self, obs_size, hidden_size1, hidden_size2, n_tau_samples, n_actions, n_cos=64):
        super().__init__()
        self.obs_size = obs_size
        self.n_tau_samples = n_tau_samples
        self.n_actions = n_actions
        self.n_cos = n_cos

        # State branch.
        self.fc1 = nn.Linear(obs_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        # Quantile-fraction branch: cosine features -> state-feature width.
        self.tau_embedding = nn.Linear(n_cos, hidden_size2)
        # Head applied to the combined (state x tau) features; one value per action.
        self.fc3 = nn.Linear(hidden_size2, hidden_size2)
        self.fc4 = nn.Linear(hidden_size2, n_actions)

    def forward(self, x, taus):
        """Return quantile values of shape ``(batch, n_taus, n_actions)``.

        Args:
            x: observations of shape ``(batch, ...)``; trailing dims are flattened
                and must multiply to ``obs_size``.
            taus: quantile fractions in (0, 1), reshapeable to ``(batch, n_taus)``.
        """
        batch_size = x.size(0)
        x = x.reshape(batch_size, -1)
        state_features = torch.relu(self.fc2(torch.relu(self.fc1(x))))  # (B, H2)

        taus = taus.reshape(batch_size, -1, 1).to(state_features.dtype)  # (B, N, 1)
        freqs = torch.arange(self.n_cos, device=taus.device, dtype=taus.dtype) * np.pi
        tau_features = torch.relu(self.tau_embedding(torch.cos(taus * freqs)))  # (B, N, H2)

        # Condition the state features on each fraction (Hadamard product).
        combined = state_features.unsqueeze(1) * tau_features  # (B, N, H2)
        combined = torch.relu(self.fc3(combined))
        return self.fc4(combined)  # (B, N, A)

class IQN_Agent:
    """Epsilon-greedy agent trained with the IQN quantile-regression update.

    The online network ``q_net`` and the target network ``target_net`` are
    ``IQN`` instances. Their forward pass is expected to return quantile values
    of shape ``(batch, n_taus, n_actions)``; Q-values are the mean over the
    quantile axis.

    Args:
        action_space: discrete action space exposing ``n`` and ``sample()``.
        observation_space: space exposing ``shape``; observations are flattened.
        n_tau_samples: number of quantile fractions sampled per forward pass.
        gamma: discount factor.
        batch_size: minibatch size. No learning happens until the buffer holds
            this many transitions.
        buffer_capacity: replay buffer capacity.
        update_target_every: copy the online weights into the target network
            every this many gradient steps.
        epsilon_start, decrease_epsilon_factor, epsilon_min: parameters of the
            exponential epsilon schedule over episodes (see ``decrease_epsilon``).
        learning_rate: Adam learning rate.
    """

    def __init__(self,
                 action_space,
                 observation_space,
                 n_tau_samples,
                 gamma,
                 batch_size,
                 buffer_capacity,
                 update_target_every,
                 epsilon_start,
                 decrease_epsilon_factor,
                 epsilon_min,
                 learning_rate,
                 ):
        self.action_space = action_space
        self.observation_space = observation_space
        self.n_tau_samples = n_tau_samples
        self.gamma = gamma

        self.batch_size = batch_size
        self.buffer_capacity = buffer_capacity
        self.update_target_every = update_target_every

        self.epsilon_start = epsilon_start
        self.decrease_epsilon_factor = decrease_epsilon_factor  # larger -> slower decay, more exploration
        self.epsilon_min = epsilon_min

        self.learning_rate = learning_rate

        self.reset()

    @staticmethod
    def _to_batch(state):
        """Return ``state`` as a flattened float32 tensor with a batch dim: (1, obs_size)."""
        return torch.tensor(state, dtype=torch.float32).flatten().unsqueeze(0)

    def get_action(self, state):
        """
        Return action according to an epsilon-greedy exploration policy
        """
        if np.random.rand() < self.epsilon:
            return self.action_space.sample()

        return self.get_best_action(state)

    def get_best_action(self, state):
        """Return the greedy action: argmax of the mean over sampled quantile values."""
        with torch.no_grad():
            taus = torch.rand(1, self.n_tau_samples)  # random quantile fractions
            quantiles = self.q_net(self._to_batch(state), taus)  # (1, N, A)
            q_values = quantiles.mean(dim=1)  # (1, A)
        return torch.argmax(q_values).item()

    def _quantile_huber_loss(self, current, target, taus):
        """Quantile Huber loss (kappa = 1) averaged over the batch.

        Args:
            current: (B, N) predicted quantile values at fractions ``taus``.
            target: (B, N') target quantile samples (no gradient).
            taus: (B, N) fractions used to produce ``current``.
        """
        # Pairwise errors delta_ij = target_j - current_i, shape (B, N, N').
        td = target.unsqueeze(1) - current.unsqueeze(2)
        # self.huber_loss is element-wise (reduction="none").
        huber = self.huber_loss(td, torch.zeros_like(td))
        # |tau_i - 1{delta_ij < 0}| makes the symmetric Huber loss a quantile-regression loss.
        weight = torch.abs(taus.unsqueeze(2) - (td < 0).float())
        # Sum over predicted quantiles i, average over target samples j, then over the batch.
        return (weight * huber).sum(dim=1).mean(dim=1).mean()

    def update(self, state, action, reward, terminated, next_state):
        """Store one transition and, once enough data is buffered, take one gradient step.

        Returns:
            ``np.inf`` while the buffer holds fewer than ``batch_size``
            transitions, otherwise the loss as a 0-d numpy array.
        """
        self.buffer.push(
            self._to_batch(state),
            torch.tensor([[action]], dtype=torch.int64),  # int64 for gather
            torch.tensor([float(reward)]),
            torch.tensor([float(terminated)]),
            self._to_batch(next_state),
        )

        # NOTE: this early return also skips the episode counter below, so
        # episodes that end before the buffer fills are not counted in n_eps.
        if len(self.buffer) < self.batch_size:
            return np.inf

        transitions = self.buffer.sample(self.batch_size)
        state_batch, action_batch, reward_batch, terminated_batch, next_state_batch = tuple(
            torch.cat(column) for column in zip(*transitions)
        )  # (B, obs), (B, 1), (B,), (B,), (B, obs)

        # Online quantile values of the actions actually taken: (B, N).
        taus = torch.rand(self.batch_size, self.n_tau_samples)
        quantiles = self.q_net(state_batch, taus)  # (B, N, A)
        action_index = action_batch.view(-1, 1, 1).expand(-1, quantiles.size(1), 1)
        current_quantiles = quantiles.gather(2, action_index).squeeze(2)

        with torch.no_grad():
            # Independent fractions tau' for the target distribution.
            target_taus = torch.rand(self.batch_size, self.n_tau_samples)
            next_quantiles = self.target_net(next_state_batch, target_taus)  # (B, N', A)
            # Greedy next actions w.r.t. the target net's mean quantile value.
            best_next_actions = next_quantiles.mean(dim=1).argmax(dim=1)  # (B,)
            next_index = best_next_actions.view(-1, 1, 1).expand(-1, next_quantiles.size(1), 1)
            next_best_quantiles = next_quantiles.gather(2, next_index).squeeze(2)  # (B, N')
            not_done = (1.0 - terminated_batch).unsqueeze(1)
            target_quantiles = reward_batch.unsqueeze(1) + self.gamma * not_done * next_best_quantiles

        loss = self._quantile_huber_loss(current_quantiles, target_quantiles, taus)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if not (self.n_steps + 1) % self.update_target_every:
            self.target_net.load_state_dict(self.q_net.state_dict())

        self.decrease_epsilon()

        self.n_steps += 1
        if terminated:
            self.n_eps += 1

        return loss.detach().numpy()

    def get_q(self, state):
        """
        Compute Q function for a state: the mean sampled quantile value per action.
        """
        with torch.no_grad():
            taus = torch.rand(1, self.n_tau_samples)  # random quantile fractions
            quantiles = self.q_net(self._to_batch(state), taus)  # (1, N, A)
            q_values = quantiles.mean(dim=1)  # (1, A)
        return q_values.squeeze().numpy()

    def decrease_epsilon(self):
        """Set epsilon = eps_min + (eps_start - eps_min) * exp(-n_eps / decrease_epsilon_factor)."""
        self.epsilon = self.epsilon_min + (self.epsilon_start - self.epsilon_min) * (
                        np.exp(-1. * self.n_eps / self.decrease_epsilon_factor))

    def reset(self):
        """(Re)create buffer, networks, loss and optimizer, and reset counters and epsilon."""
        hidden_size = 128
        obs_size = int(np.prod(self.observation_space.shape))  # observations are flattened
        n_actions = self.action_space.n
        self.buffer = ReplayBuffer(self.buffer_capacity)
        self.q_net = IQN(obs_size, hidden_size, hidden_size, self.n_tau_samples, n_actions)
        self.target_net = IQN(obs_size, hidden_size, hidden_size, self.n_tau_samples, n_actions)
        # Start the target network as an exact copy of the online network.
        self.target_net.load_state_dict(self.q_net.state_dict())
        # Element-wise Huber (kappa = 1); weighting and reduction happen in _quantile_huber_loss.
        self.huber_loss = nn.SmoothL1Loss(reduction="none")
        self.optimizer = optim.Adam(params=self.q_net.parameters(), lr=self.learning_rate)
        self.epsilon = self.epsilon_start
        self.n_steps = 0
        self.n_eps = 0
