In [2]:
from torchrl.data.replay_buffers import (
    TensorDictReplayBuffer,
    LazyTensorStorage,
    PrioritizedSampler,
)
from tensordict import TensorDict

import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import gymnasium as gym
import numpy as np

from rlarcworld.arc_dataset import ArcDataset, ArcSampleTransformer
from rlarcworld.enviroments.arc_batch_grid_env import ArcBatchGridEnv
from rlarcworld.enviroments.wrappers.rewards import PixelAwareRewardWrapper
from rlarcworld.agent.actor import ArcActorNetwork
from rlarcworld.agent.critic import ArcCriticNetwork

from rlarcworld.algorithms.d4pg import D4PG

In [2]:
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

In [3]:
grid_size = 30
color_values = 11
batch_size = 2
n_steps = 100  # For testing

## Dataset
dataset = ArcDataset(
    "./dataset/training",
    keep_in_memory=True,
    transform=ArcSampleTransformer((grid_size, grid_size), examples_stack_dim=10),
)
train_samples = DataLoader(dataset=dataset, batch_size=batch_size)
## Environment
env = ArcBatchGridEnv(size=grid_size, color_values=color_values)
env = PixelAwareRewardWrapper(env)

## The atoms are essentially the "bins" or "categories"
## into which the possible range of returns is divided.
## Consequently depends of the reward behaviour
n_atoms = {"pixel_wise": 100, "binary": 3}

## Networks
actor = ArcActorNetwork(size=grid_size, color_values=color_values)
critic = ArcCriticNetwork(size=grid_size, color_values=color_values,n_atoms=n_atoms)
## Target Networks
actor_target = actor
critic_target = critic

In [7]:
D4PG().train_d4pg(
    actor,
    critic,
    actor_target,
    critic_target,
    replay_buffer=rb,
    actor_optimizer=torch.optim.Adam(actor.parameters(), lr=1e-4),
    critic_optimizer=torch.optim.Adam(critic.parameters(), lr=1e-4),
    gamma=0.99,
    num_atoms=100,
    v_min=-1.0,
    v_max=1.0,
    batch_size=batch_size,
    target_update_freq=10,
    steps=n_steps,
)

torch.Size([10, 2, 30, 30])


In [5]:
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(batch_size),
    sampler=PrioritizedSampler(max_capacity=batch_size, alpha=1.0, beta=1.0),
    priority_key="priority",
)

**Pseudocode Q-Learning Algorithm for Categorical Actions**

- Get the current observations in the state (s); Includes the different kind of grids and values managed by the environment.
- Pass the *s* to the actor to get the action *a*, reward *r*, done flag *d* and next state *s'*.
- Compute the categorical distribution over returns *Z_pi* for the next state *s'* using the critic network.
- Get the best action *a'\** using the target actor on next state *s'*


In [9]:
for episode, samples in enumerate(train_samples):
    observation, information = env.reset(
        options={"batch": samples["task"], "examples":samples["examples"]}, seed=episode
    )
    for step in range(n_steps):
        init_state = env.get_wrapper_attr("state")
        actions = actor(init_state)
        obs, reward, terminated, truncated, info = env.step(actor.get_discrete_actions(actions))
        final_state = env.get_wrapper_attr("state")
        final_state.update({"actions": actions})
        critic.predict(final_state)
        if terminated or truncated:
            break

RuntimeError: Input type (long int) and bias type (float) should be the same

In [7]:
from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict
rb = TDRB(
    storage=LazyTensorStorage(10),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
    priority_key="priority",  # This kwarg isn't present in regular RBs
)
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
data = torch.stack([data_0, data_1])
rb.extend(data)
rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
sample, info = rb.sample(10, return_info=True)
print(sample['index'])  # The index is packed with the tensordict
torch.tdensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

TypeError: ReplayBuffer.update_priority() missing 1 required positional argument: 'priority'

In [None]:
import torch
import torch.nn.functional as F

def categorical_projection(distribution, reward, v_min, v_max, num_atoms, done):
    """Projects a distribution onto a categorical distribution.

    Args:
        distribution (torch.Tensor): The distribution to project (batch_size, num_atoms).
        reward (torch.Tensor): The target values (batch_size,).
        v_min (float): Minimum value of the support.
        v_max (float): Maximum value of the support.
        num_atoms (int): Number of atoms in the categorical distribution.
        done (torch.Tensor): Done flags from the batch.

    Returns:
        torch.Tensor: Projected distribution (batch_size, num_atoms).
    """

    delta_z = (v_max - v_min) / (num_atoms - 1)
    z = torch.linspace(v_min, v_max, num_atoms).to(reward.device)  # Atom values

    # Compute the expected action distribution
    next_q_values = (distribution * z).sum(
        dim=-1
    )  # Shape: (batch_size, num_actions)
    print(next_q_values.shape)
    best_next_action = torch.argmax(next_q_values, dim=-1)  # Shape: (batch_size,)
    best_next_q_dist = distribution[
        torch.arange(distribution.size(0)), best_next_action
    ]  # Shape: (batch_size, num_atoms)

    # Bellman backup for distribution
    tz = reward.unsqueeze(-1) + 0.99 * (1 - done.unsqueeze(-1)) * z.unsqueeze(0)
    tz = tz.clamp(v_min, v_max)
    b = (tz - v_min) / delta_z
    l, u = b.floor().long(), b.ceil().long()
    l = l.clamp(0, num_atoms - 1)
    u = u.clamp(0, num_atoms - 1)

    # Distribute probability mass
    target_dist = torch.zeros_like(best_next_q_dist)
    target_dist.scatter_add_(
        dim=-1, index=l, src=best_next_q_dist * (u.float() - b)
    )
    target_dist.scatter_add_(
        dim=-1, index=u, src=best_next_q_dist * (b - l.float())
    )

    return target_dist

# Example 1: Rewards 0 or 1
batch_size = 32
num_atoms = 4
v_min = -1.0
v_max = 2.0
distribution = torch.randn(batch_size, num_atoms).softmax(dim=-1)  # Normalize the distribution
print("Original distribution:", distribution[0])
reward = torch.randint(int(v_min), int(v_max), (batch_size,)).float()
done = torch.randint(0, 2, (batch_size,)).float()
projected_dist = categorical_projection(distribution, reward, v_min, v_max, num_atoms,done)
print("Projected distribution shape (0-1 rewards):", projected_dist.shape)
print("Example projected distribution (0-1 rewards):\n", projected_dist[0])

# Example 2: Rewards -900, -2, -1, or 0
v_min = -1000.0
v_max = 2.0
num_atoms = 1000
reward = torch.randint(int(v_min), int(v_max), (batch_size,)).float()
distribution = torch.randn(batch_size, num_atoms).softmax(dim=-1)  # Normalize the distribution
print("Original distribution:", distribution[0])
projected_dist = categorical_projection(distribution, reward, v_min, v_max, num_atoms,done)
print("\nProjected distribution shape (-900 to 0 rewards):", projected_dist.shape)
print("Example projected distribution (-900 to 0 rewards):\n", projected_dist[0])

Original distribution: tensor([0.5012, 0.3661, 0.0283, 0.1044])
torch.Size([32])


IndexError: index 15 is out of bounds for dimension 1 with size 4

In [18]:
def categorical_projection(next_q_dist, rewards, dones, gamma, v_min, v_max, num_atoms):
    """
    Projects the target distribution using the Bellman update.

    Args:
        next_q_dist (torch.Tensor): Next state Q-distribution (batch_size, num_atoms).
        rewards (torch.Tensor): Rewards from the batch (batch_size, 1).
        dones (torch.Tensor): Done flags from the batch (batch_size, 1).
        gamma (float): Discount factor.
        v_min (float): Minimum value for value distribution.
        v_max (float): Maximum value for value distribution.
        num_atoms (int): Number of atoms in the distribution.

    Returns:
        torch.Tensor: Projected target distribution (batch_size, num_atoms).
    """
    delta_z = (v_max - v_min) / (num_atoms - 1)
    z = torch.linspace(v_min, v_max, num_atoms).to(rewards.device)  # Atom values

    # Compute the target distribution support
    tz = rewards.unsqueeze(-1) + gamma * (1 - dones.unsqueeze(-1)) * z.unsqueeze(0)
    tz = tz.clamp(v_min, v_max)

    # Map values to categorical bins
    b = (tz - v_min) / delta_z
    l, u = b.floor().long(), b.ceil().long()
    l = l.clamp(0, num_atoms - 1)
    u = u.clamp(0, num_atoms - 1)

    # Distribute probability mass
    projected_dist = torch.zeros_like(next_q_dist)
    projected_dist.scatter_add_(dim=-1, index=l, src=next_q_dist * (u.float() - b))
    projected_dist.scatter_add_(dim=-1, index=u, src=next_q_dist * (b - l.float()))
    
    return projected_dist


tensor([0.5824, 0.4319, 0.4766, 0.7789, 0.7979, 0.4295, 0.6930, 0.3343, 0.4795,
        0.7065, 0.4422, 0.4728, 0.6305, 0.3167, 0.6947, 0.3992, 0.4781, 0.6013,
        0.4806, 0.6819, 0.4991, 0.4540, 0.3671, 0.7688, 0.7215, 0.4388, 0.5435,
        0.3948, 0.5157, 0.5096, 0.4434, 0.6468])

In [None]:
def compute_critic_target_distribution(critic_target, actor_target, reward, next_state, done, gamma, num_atoms, v_min, v_max):
    """
    Computes the target distribution for the critic network.

    Args:
        critic_target (nn.Module): Target critic network.
        actor_target (nn.Module): Target actor network.
        reward (torch.Tensor): Rewards from the batch.
        next_state (TensorDict): TensorDict of next states.
        done (torch.Tensor): Done flags from the batch.
        gamma (float): Discount factor.
        num_atoms (int): Number of atoms for the categorical distribution.
        v_min (float): Minimum value for value distribution.
        v_max (float): Maximum value for value distribution.

    Returns:
        torch.Tensor: Projected target distribution (batch_size, num_atoms).
    """
    z = torch.linspace(v_min, v_max, num_atoms).to(reward.device)  # Atom values

    # Get next-action probabilities from the target actor
    next_action_probs = actor_target(next_state).get("action_probs")  # Shape: (batch_size, num_actions)
    next_q_dist = critic_target(next_state).get("q_dist")  # Shape: (batch_size, num_actions, num_atoms)
    
    # Compute the expected action distribution
    next_q_values = (next_q_dist * z).sum(dim=-1)  # Shape: (batch_size, num_actions)
    best_next_action = torch.argmax(next_q_values, dim=-1)  # Shape: (batch_size,)
    best_next_q_dist = next_q_dist[torch.arange(next_q_dist.size(0)), best_next_action]  # Shape: (batch_size, num_atoms)
    
    # Use the categorical projection function
    target_dist = categorical_projection(best_next_q_dist, reward, done, gamma, v_min, v_max, num_atoms)
    return target_dist


In [21]:
z = torch.linspace(0, 1, 10)
q = torch.randint(0, 10, (3,4,10)).float()
q_v = (q*z)
print(q_v.shape)
print(q_v)
print(q_v.sum(dim=-1).shape)
q_v.sum(dim=-1)

torch.Size([3, 4, 10])
tensor([[[0.0000, 0.4444, 1.7778, 2.0000, 1.3333, 3.3333, 4.0000, 0.0000,
          7.1111, 8.0000],
         [0.0000, 0.4444, 2.0000, 2.0000, 1.7778, 0.0000, 1.3333, 3.8889,
          2.6667, 2.0000],
         [0.0000, 0.1111, 2.0000, 2.3333, 2.2222, 1.1111, 0.0000, 0.0000,
          6.2222, 4.0000],
         [0.0000, 0.0000, 1.1111, 1.6667, 1.3333, 2.2222, 2.0000, 3.8889,
          3.5556, 5.0000]],

        [[0.0000, 0.7778, 1.3333, 2.6667, 3.5556, 1.6667, 5.3333, 2.3333,
          8.0000, 4.0000],
         [0.0000, 0.1111, 0.8889, 1.6667, 0.8889, 2.2222, 3.3333, 0.0000,
          6.2222, 8.0000],
         [0.0000, 0.2222, 0.8889, 2.0000, 0.0000, 3.8889, 2.6667, 3.8889,
          8.0000, 4.0000],
         [0.0000, 0.0000, 2.0000, 0.6667, 1.3333, 3.3333, 3.3333, 0.0000,
          2.6667, 1.0000]],

        [[0.0000, 0.4444, 2.0000, 3.0000, 1.3333, 0.0000, 6.0000, 6.2222,
          4.4444, 2.0000],
         [0.0000, 0.6667, 1.7778, 2.0000, 2.6667, 1.1111, 4.0000

tensor([[28.0000, 16.1111, 18.0000, 20.7778],
        [29.6667, 23.3333, 25.5556, 14.3333],
        [25.4444, 27.6667, 21.6667, 25.4444]])

In [11]:
import torch
from tensordict import TensorDict
actions = TensorDict(
    {
                "x_location": torch.randn(10, 30),
                "y_location": torch.randn(10, 30),
                "color_values": torch.randn(10, 11),
                "submit": torch.randn(10, 2),
            }
)

print(actions)

TensorDict(
    fields={
        color_values: Tensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
        submit: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        x_location: Tensor(shape=torch.Size([10, 30]), device=cpu, dtype=torch.float32, is_shared=False),
        y_location: Tensor(shape=torch.Size([10, 30]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)


In [24]:
torch.cat([torch.argmax(x, dim=-1).unsqueeze(-1) for x in actions.values()], dim=-1)

tensor([[24,  7,  5,  0],
        [24, 13,  7,  0],
        [22,  8,  1,  1],
        [ 4, 12, 10,  0],
        [17, 14,  9,  0],
        [29, 17,  0,  0],
        [ 0,  6,  0,  1],
        [17, 28,  0,  1],
        [ 9, 17,  3,  1],
        [ 0, 24,  1,  0]])