In [1]:
from torchrl.data.replay_buffers import (
    TensorDictReplayBuffer,
    LazyTensorStorage,
    PrioritizedSampler,
)
from tensordict import TensorDict

import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import gymnasium as gym
import numpy as np

from rlarcworld.arc_dataset import ArcDataset, ArcSampleTransformer
from rlarcworld.enviroments.arc_batch_grid_env import ArcBatchGridEnv
from rlarcworld.enviroments.wrappers.rewards import PixelAwareRewardWrapper
from rlarcworld.agent.actor import ArcActorNetwork
from rlarcworld.agent.critic import ArcCriticNetwork

In [2]:
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

In [4]:
grid_size = 30
color_values = 11
batch_size = 2
n_steps = 100  # For testing

## Dataset
dataset = ArcDataset(
    "./dataset/training",
    keep_in_memory=True,
    transform=ArcSampleTransformer((grid_size, grid_size), examples_stack_dim=10),
)
train_samples = DataLoader(dataset=dataset, batch_size=batch_size)
## Environment
env = ArcBatchGridEnv(size=grid_size, color_values=color_values)
env = PixelAwareRewardWrapper(env)

## The atoms are essentially the "bins" or "categories"
## into which the possible range of returns is divided.
## Consequently depends of the reward behaviour
n_atoms = {"pixel_wise": 100, "binary": 3}

## Networks
actor = ArcActorNetwork(size=grid_size, color_values=color_values)
critic = ArcCriticNetwork(size=grid_size, color_values=color_values,n_atoms=n_atoms)
## Target Networks
actor_target = actor
critic_target = critic

In [7]:
for i in dataset:
    print(i["examples"].shape)
    break

torch.Size([10, 2, 30, 30])


In [5]:
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(batch_size),
    sampler=PrioritizedSampler(max_capacity=batch_size, alpha=1.0, beta=1.0),
    priority_key="priority",
)

**Pseudocode Q-Learning Algorithm for Categorical Actions**

- Get the current observations in the state (s); Includes the different kind of grids and values managed by the environment.
- Pass the *s* to the actor to get the action *a*, reward *r*, done flag *d* and next state *s'*.
- Compute the categorical distribution over returns *Z_pi* for the next state *s'* using the critic network.
- Get the best action *a'\** using the target actor on next state *s'*


In [9]:
for episode, samples in enumerate(train_samples):
    observation, information = env.reset(
        options={"batch": samples["task"], "examples":samples["examples"]}, seed=episode
    )
    for step in range(n_steps):
        init_state = env.get_wrapper_attr("state")
        actions = actor(init_state)
        obs, reward, terminated, truncated, info = env.step(actor.get_discrete_actions(actions))
        final_state = env.get_wrapper_attr("state")
        final_state.update({"actions": actions})
        critic.predict(final_state)
        if terminated or truncated:
            break

RuntimeError: Input type (long int) and bias type (float) should be the same

In [7]:
from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict
rb = TDRB(
    storage=LazyTensorStorage(10),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
    priority_key="priority",  # This kwarg isn't present in regular RBs
)
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
data = torch.stack([data_0, data_1])
rb.extend(data)
rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor
sample, info = rb.sample(10, return_info=True)
print(sample['index'])  # The index is packed with the tensordict
torch.tdensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

TypeError: ReplayBuffer.update_priority() missing 1 required positional argument: 'priority'

In [None]:
import gymnasium as gym
size = 30
color_values = 11
action_space = gym.spaces.Dict(
            {
                "x_location": gym.spaces.Sequence(
                    gym.spaces.Discrete(size), stack=True
                ),
                "y_location": gym.spaces.Sequence(
                    gym.spaces.Discrete(size), stack=True
                ),
                "color_values": gym.spaces.Sequence(
                    gym.spaces.Discrete(color_values), stack=True
                ),
                "submit": gym.spaces.Sequence(gym.spaces.Discrete(2), stack=True),
            }

        )