In [1]:
import copy
import logging

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchrl.data.replay_buffers import (
    TensorDictReplayBuffer,
    LazyTensorStorage,
    PrioritizedSampler,
)

from rlarcworld.agent.actor import ArcActorNetwork
from rlarcworld.agent.critic import ArcCriticNetwork
from rlarcworld.arc_dataset import ArcDataset, ArcSampleTransformer
from rlarcworld.enviroments.arc_batch_grid_env import ArcBatchGridEnv
from rlarcworld.enviroments.wrappers.rewards import PixelAwareRewardWrapper

# Notebook-level logger, named after the running module (`__name__`).
# The level is lowered to DEBUG so debug records pass this logger's own filter.
# NOTE(review): no handler is configured in the visible cells, so DEBUG/INFO
# records may still not be displayed — confirm whether logging.basicConfig(...)
# (or an explicit handler) is set up elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

In [2]:
# --- Run configuration -------------------------------------------------------
grid_size = 30  # side length passed to the transformer, environment and actor
color_values = 11  # number of colour values passed to the environment and actor
batch_size = 2  # samples per DataLoader batch (one batch per episode below)
n_steps = 100  # For testing

## Dataset
dataset = ArcDataset(
    "./dataset/training",
    keep_in_memory=True,
    transform=ArcSampleTransformer((grid_size, grid_size), examples_stack_dim=10),
)
train_samples = DataLoader(dataset=dataset, batch_size=batch_size)

## Environment
# The base batched grid environment is wrapped with the pixel-aware reward.
env = ArcBatchGridEnv(size=grid_size, color_values=color_values)
env = PixelAwareRewardWrapper(env)

## The atoms are essentially the "bins" or "categories"
## into which the possible range of returns is divided.
## Consequently, their number depends on the behaviour of each reward.
n_atoms = {"pixel_wise": 100, "binary": 3}

## Networks
actor = ArcActorNetwork(size=grid_size, color_values=color_values)
critic = ArcCriticNetwork(n_atoms=n_atoms)

## Target Networks
# Targets must be independent copies of the online networks. Binding the same
# object (`target_actor = actor`) would make every update of the online network
# an update of its target as well, so the target would never lag behind.
target_actor = copy.deepcopy(actor)
target_critic = copy.deepcopy(critic)

In [3]:
# Prioritized replay buffer: storage and sampler share the same capacity, and
# sampling priorities are read from the "priority" entry of stored tensordicts.
# NOTE(review): the capacity equals `batch_size`, so the buffer holds a single
# batch — presumably a placeholder while testing; confirm the intended size.
rb_capacity = batch_size
rb_storage = LazyTensorStorage(rb_capacity)
rb_sampler = PrioritizedSampler(max_capacity=rb_capacity, alpha=1.0, beta=1.0)
rb = TensorDictReplayBuffer(
    storage=rb_storage,
    sampler=rb_sampler,
    priority_key="priority",
)

In [4]:
# Rollout loop: each DataLoader batch is played as one batched episode of at
# most `n_steps` environment steps.
for episode, samples in enumerate(train_samples):
    # The episode index doubles as the reset seed so runs are repeatable.
    observation, information = env.reset(
        options={"batch": samples["task"], "examples": samples["examples"]},
        seed=episode,
    )
    for step in range(n_steps):
        # State seen by the actor before acting.
        init_state = env.get_wrapper_attr("state")
        actions = actor.predict(init_state)
        obs, reward, terminated, truncated, info = env.step(
            actor.get_discrete_actions(actions)
        )

        # TODO: push this transition into the replay buffer `rb` (e.g. with
        # `rb.extend(...)`) once the transition schema (state, actions, reward,
        # next state, priority) is settled. Nothing is stored yet.

        # NOTE(review): if `get_wrapper_attr("state")` returns the environment's
        # live state object rather than a copy, `init_state` and `final_state`
        # are the same object and the update below writes "actions" into the
        # environment's own state — confirm this is intended.
        final_state = env.get_wrapper_attr("state")
        final_state.update({"actions": actions})
        critic.predict(final_state)

        # NOTE(review): assumes `terminated` / `truncated` are scalar-like; if
        # they are per-sample tensors with more than one element, this
        # truth test raises and needs `.all()` / `.any()` — confirm.
        if terminated or truncated:
            break

In [None]:
# Stand-alone demonstration of a prioritized TensorDict replay buffer.
# It uses its own `demo_rb` so the notebook's training buffer `rb` is not
# overwritten. `torch` comes from the notebook's import cell.
from torchrl.data.replay_buffers import TensorDictReplayBuffer as TDRB, LazyTensorStorage, PrioritizedSampler
from tensordict import TensorDict

demo_rb = TDRB(
    storage=LazyTensorStorage(10),
    sampler=PrioritizedSampler(max_capacity=10, alpha=1.0, beta=1.0),
    priority_key="priority",  # This kwarg isn't present in regular RBs
)

# Two transitions: the first with priority 0, the second with priority 1000.
priority = torch.tensor([0, 1000])
data_0 = TensorDict({"reward": 0, "obs": [0], "action": [0], "priority": priority[0]}, [])
data_1 = TensorDict({"reward": 1, "obs": [1], "action": [2], "priority": priority[1]}, [])
data = torch.stack([data_0, data_1])

demo_rb.extend(data)
demo_rb.update_priority(data)  # Reads the "priority" key as indicated in the constructor

sample, info = demo_rb.sample(10, return_info=True)
print(sample["index"])  # The index is packed with the tensordict
# Expected output (the second item dominates because of its much larger priority):
# tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [None]:
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.objectives.utils import ValueEstimators
from torchrl.objectives.utils import default_value_kwargs
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


def _init(
    self,
    actor_network: TensorDictModule,
    value_network: TensorDictModule,
) -> None:
    """Initialise a loss object from an actor network and a value network.

    The function is written at module level and takes ``self`` explicitly,
    presumably so it can be attached to a loss class as its ``__init__``.

    Steps performed:

    * calls the parent initialiser;
    * registers both networks through ``self.convert_to_functional`` with
      ``create_target_params=True``; the value network is compared against
      the actor's parameters via ``compare_against``;
    * records the actor's input keys in ``self.actor_in_keys``;
    * builds ``self.actor_critic``, an ``ActorCriticWrapper`` combining both
      networks;
    * sets ``self.loss_function`` to ``"l2"``.

    Args:
        actor_network: policy module; must expose ``in_keys`` and
            ``parameters()``.
        value_network: value module paired with the actor.
    """
    # `ActorCriticWrapper` is not imported by any visible cell of the notebook,
    # so it is imported here to keep this function usable on a fresh kernel.
    from torchrl.modules import ActorCriticWrapper

    # The two-argument form is required: zero-argument `super()` only works in
    # functions defined inside a class body. NOTE(review): with
    # `type(self)`, subclassing the owning class would recurse — confirm the
    # class this is attached to is never subclassed.
    super(type(self), self).__init__()

    self.convert_to_functional(
        actor_network,
        "actor_network",
        create_target_params=True,
    )
    self.convert_to_functional(
        value_network,
        "value_network",
        create_target_params=True,
        compare_against=list(actor_network.parameters()),
    )

    self.actor_in_keys = actor_network.in_keys

    # Since the value we'll be using is based on the actor and value network,
    # we put them together in a single actor-critic container.
    actor_critic = ActorCriticWrapper(actor_network, value_network)
    self.actor_critic = actor_critic
    self.loss_function = "l2"

def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
    """Build the value estimator for ``value_type`` and store it on ``self``.

    Hyperparameters are resolved in increasing order of precedence: the
    defaults from ``default_value_kwargs(value_type)``, then ``self.gamma``
    when that attribute exists, then the caller's ``**hyperparams``.

    The estimator is created on top of ``self.actor_critic``, stored in
    ``self._value_estimator`` and told to use ``"state_action_value"`` as its
    value key.

    Args:
        value_type: which estimator to create (TD0, TD1 or TDLambda).
        **hyperparams: overrides forwarded to the estimator constructor.

    Raises:
        NotImplementedError: for ``ValueEstimators.GAE`` and for any value
            type that is not recognised.
    """
    # Resolve hyperparameters: defaults < self.gamma < explicit overrides.
    hp = dict(default_value_kwargs(value_type))
    gamma_override = {"gamma": self.gamma} if hasattr(self, "gamma") else {}
    hp.update(gamma_override)
    hp.update(hyperparams)

    # GAE is a known estimator type but is not supported here.
    if value_type == ValueEstimators.GAE:
        raise NotImplementedError(
            f"Value type {value_type} it not implemented for loss {type(self)}."
        )

    # Supported estimator types paired with the class implementing them.
    supported = (
        (ValueEstimators.TD1, TD1Estimator),
        (ValueEstimators.TD0, TD0Estimator),
        (ValueEstimators.TDLambda, TDLambdaEstimator),
    )
    estimator_cls = next(
        (cls for kind, cls in supported if value_type == kind),
        None,
    )
    if estimator_cls is None:
        raise NotImplementedError(f"Unknown value type {value_type}")

    self._value_estimator = estimator_cls(value_network=self.actor_critic, **hp)
    self._value_estimator.set_keys(value="state_action_value")