In [1]:
import copy
import logging

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchrl.data.replay_buffers import (
    TensorDictReplayBuffer,
    LazyTensorStorage,
    PrioritizedSampler,
)

from rlarcworld.arc_dataset import ArcDataset, ArcSampleTransformer
from rlarcworld.enviroments.arc_batch_grid_env import ArcBatchGridEnv
from rlarcworld.enviroments.wrappers.rewards import PixelAwareRewardWrapper
from rlarcworld.agent.actor import ArcActorNetwork
from rlarcworld.agent.critic import ArcCriticNetwork

# Notebook-level logger, opened up to DEBUG so every message emitted
# through it is let through by this logger's own level check.
logger = logging.getLogger(__name__)
logger.setLevel("DEBUG")

In [2]:
## Experiment configuration
grid_size = 30  # side length passed to the sample transformer, env and actor
color_values = 11  # number of color values passed to the env and actor
batch_size = 2
n_steps = 100  # For testing

## Dataset
dataset = ArcDataset(
    "./dataset/training",
    keep_in_memory=True,
    transform=ArcSampleTransformer((grid_size, grid_size), examples_stack_dim=10),
)
train_samples = DataLoader(dataset=dataset, batch_size=batch_size)
## Environment
env = ArcBatchGridEnv(size=grid_size, color_values=color_values)
env = PixelAwareRewardWrapper(env)

## The atoms are essentially the "bins" or "categories"
## into which the possible range of returns is divided.
## Consequently, their number depends on the behaviour of each reward.
n_atoms = {"pixel_wise": 100, "binary": 3}

## Networks
actor = ArcActorNetwork(size=grid_size, color_values=color_values)
critic = ArcCriticNetwork(n_atoms=n_atoms)
## Target Networks
## Targets must be independent copies: binding them to the very same objects
## (`target_actor = actor`) would make every update of the online networks
## change the targets as well.
## NOTE(review): assumes both networks are deep-copyable (e.g. torch.nn.Module) — confirm.
target_actor = copy.deepcopy(actor)
target_critic = copy.deepcopy(critic)

In [3]:
# Prioritized replay buffer. Storage size and sampler capacity are both
# tied to `batch_size`; priorities are read from the "priority" entry.
rb = TensorDictReplayBuffer(
    priority_key="priority",
    sampler=PrioritizedSampler(alpha=1.0, beta=1.0, max_capacity=batch_size),
    storage=LazyTensorStorage(batch_size),
)

In [4]:
# Roll the actor out in the environment: one episode per DataLoader batch,
# each capped at `n_steps` environment steps.
for episode, samples in enumerate(train_samples):
    # Reset the env with this batch's task and examples; the episode index
    # doubles as the reset seed.
    observation, information = env.reset(
        options={"batch": samples["task"], "examples":samples["examples"]}, seed=episode
    )
    for step in range(n_steps):
        # Read the env state before acting and let the actor propose actions.
        init_state = env.get_wrapper_attr("state")
        actions = actor.predict(init_state)
        # The env is stepped with the discretized form of the actor output.
        obs, reward, terminated, truncated, info = env.step(actor.get_discrete_actions(actions))
        # Re-read the state after the step and attach the raw actor output
        # under "actions" before handing it to the critic.
        # NOTE(review): if "state" is returned by reference, this update()
        # mutates the env's own state and `init_state` may alias
        # `final_state` — confirm against the env implementation.
        final_state = env.get_wrapper_attr("state")
        final_state.update({"actions": actions})
        # NOTE(review): the critic output is discarded and nothing is added
        # to the replay buffer `rb` in this loop — presumably a smoke test
        # of the forward passes; confirm.
        critic.predict(final_state)
        # NOTE(review): assumes `terminated`/`truncated` have an unambiguous
        # truth value (scalars); per-sample batched tensors would make this
        # `or` raise — confirm what the batch env returns.
        if terminated or truncated:
            break