In [None]:
import torch
import copy
from hydra import initialize, compose
from omegaconf import DictConfig
import gymnasium as gym

from kitten.common.rng import global_seed, Generator
from kitten.policy import Policy
from kitten.experience.collector import GymCollector
from kitten.experience.util import build_transition_from_list, build_replay_buffer
from kitten.experience import Transitions, AuxiliaryMemoryData
from kitten.rl.common import td_lambda

import numpy as np

from matplotlib import pyplot as plt

from cats.agent.minigrid.value import MinigridValue
from cats.agent.experiment import ExperimentBase, build_env

# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def visualise(env, batch: Transitions | list[np.ndarray], log_scale: bool = False):
    """Overlay agent visitation frequency on the environment's initial layout.

    Resets ``env`` to get a reference observation. Walls (object index 2) are
    painted light grey and the goal (object index 8) green. A red layer is then
    added, proportional to how many states in ``batch`` have the agent (object
    index 10) in each cell, normalised so the most-visited cell is full red.
    Draws with ``plt.imshow`` and returns nothing.

    Args:
        env: environment whose ``reset()`` returns ``(obs, info)``; channel 0 of
            ``obs`` (``obs[:, :, 0]``) holds the object index.
        batch: a ``Transitions`` (its ``s_0`` is used) or a list of observations.
            States are assumed channel-last like ``obs``: ``[batch, a, b, c]``.
        log_scale: if True, visit counts are compressed with ``log1p`` before
            normalisation so rarely visited cells remain visible.
    """
    obs, _ = env.reset()

    draw_map = np.zeros(obs.shape)
    white = np.ones(obs.shape)
    red = np.zeros(obs.shape)
    red[:, :, 0] = 1
    green = np.zeros(obs.shape)
    green[:, :, 1] = 1

    # Walls (object index 2) in light grey.
    has_wall = np.expand_dims(obs[:, :, 0] == 2, axis=-1)
    draw_map += white * has_wall * 0.8
    # Goal (object index 8) in green.
    has_goal = np.expand_dims(obs[:, :, 0] == 8, axis=-1)
    draw_map += green * has_goal

    # Agent visitation (object index 10) in red.
    if isinstance(batch, Transitions):
        states = batch.s_0.detach().cpu().numpy()  # [batch, a, b, c]
    else:
        states = np.array(batch)
    visits = np.sum(states[:, :, :, 0] == 10, axis=0).astype(float)
    if log_scale:
        visits = np.log1p(visits)
    peak = visits.max()
    # Guard against 0/0 -> NaN when no state in the batch contains the agent.
    if peak > 0:
        visits /= peak
    draw_map += red * np.expand_dims(visits, axis=-1)

    # Overlapping layers (e.g. goal + agent) can exceed 1; clip explicitly so
    # imshow does not emit a clipping warning.
    plt.imshow(np.clip(draw_map, 0.0, 1.0))

In [None]:
class ExplorationPolicy(Policy):
    """Sticky random exploration.

    Draws actions uniformly from the environment's action space. On every call
    after the first, the previous action is re-issued unless a uniform draw
    exceeds ``repeat_probability``, in which case a fresh action is sampled.
    A fresh sample may coincide with the previous action, so the effective
    repeat rate is slightly above ``repeat_probability``.
    """

    def __init__(
        self, env: gym.Env, rng: Generator, repeat_probability: float = 0.9
    ) -> None:
        super().__init__(fn=None)
        # Private copy of the action space so seeding it leaves env.action_space untouched.
        space = copy.deepcopy(env.action_space)
        self.action_space = space
        space.seed(int(rng.numpy.integers(2**32 - 1)))
        self.rng = rng
        self.p = repeat_probability
        self.previous_action = None

    def __call__(self, obs):
        """Return an action for ``obs`` (the observation itself is ignored)."""
        # The random draw only happens once a previous action exists (short-circuit).
        draw_new = self.previous_action is None or self.rng.numpy.random() > self.p
        if draw_new:
            self.previous_action = self.action_space.sample()
        return self.previous_action

    def reset(self) -> None:
        """Forget the previous action so the next call samples a fresh one."""
        self.previous_action = None


class TeleportStrategyExperiment(ExperimentBase):
    """Random-exploration experiment that fits a value function to intrinsic reward.

    Each iteration of :meth:`run` does the following:

    1. Collects a short rollout with :class:`ExplorationPolicy`.
    2. Trains the intrinsic-reward model on that rollout.
    3. Replaces the rollout's rewards with the intrinsic reward.
    4. Regresses ``self.value`` onto TD(lambda) targets computed from those rewards.
    """

    def __init__(
        self,
        cfg: DictConfig,
        deprecated_testing_flag: bool = False,
        device: str = "cpu",
    ) -> None:
        super().__init__(
            cfg,
            normalise_obs=False,
            # Forward the caller's flag (it used to be hard-coded to False).
            deprecated_testing_flag=deprecated_testing_flag,
            device=device,
        )
        self._gamma: float = cfg.algorithm.gamma
        self._lmbda: float = cfg.algorithm.lmbda

    def _build_policy(self) -> None:
        """Create the exploration policy, the value network and its optimiser."""
        # Random Policy
        self._policy = ExplorationPolicy(
            self.env, self.rng.build_generator(), repeat_probability=self.cfg.policy.p
        )
        # Build Value Estimator.
        # NOTE(review): placed on the module-level DEVICE rather than the `device`
        # constructor argument; the two agree for this notebook's callers.
        self.value = MinigridValue().to(DEVICE)
        self.optim_v = torch.optim.Adam(params=self.value.parameters())

    @property
    def policy(self):
        return self._policy

    @property
    def value_container(self):
        return self.value

    def run(
        self,
        rollout_length: int = 100,
        intrinsic_updates: int = 4,
        value_updates: int = 4,
    ) -> None:
        """Alternate short rollouts with intrinsic-model and value-function updates.

        Runs until ``cfg.train.total_frames`` environment steps have been taken.

        Args:
            rollout_length: maximum environment steps per rollout. A rollout also
                ends early on termination or truncation, and is capped so the
                total never exceeds the frame budget.
            intrinsic_updates: number of ``self.intrinsic.update`` calls per rollout.
            value_updates: number of Adam steps on the value network per rollout.

        Raises:
            ValueError: if ``rollout_length`` < 1 (the loop would never advance).
        """
        if rollout_length < 1:
            raise ValueError(f"rollout_length must be >= 1, got {rollout_length}")

        self.tm.reset(self.collector.env, self.collector.env)
        step, steps = 0, self.cfg.train.total_frames
        while step < steps:
            # Collect a rollout, capped at the remaining frame budget.
            rollout = []
            goal_step = min(step + rollout_length, steps)
            while step < goal_step:
                step += 1
                data = self.collector.collect(n=1)[-1]
                rollout.append(data)
                self.tm.update(self.collector.env, obs=data[0])
                if data[-2] or data[-1]:  # Terminated or truncated
                    break
            batch = build_transition_from_list(rollout, device=DEVICE)

            # Reshape for ConvNet: channel-last -> channel-first
            # (assumes s_0/s_1 arrive as [B, H, W, C] — TODO confirm).
            batch.s_0 = batch.s_0.permute(0, 3, 1, 2)
            batch.s_1 = batch.s_1.permute(0, 3, 1, 2)

            # Intrinsic Update
            for _ in range(intrinsic_updates):
                self.intrinsic.update(batch, aux=AuxiliaryMemoryData.placeholder(batch), step=step)

            # Override rewards with the intrinsic reward before computing targets.
            _, _, r_i = self.intrinsic.reward(batch)
            batch.r = r_i
            # Targets are held fixed (semi-gradient TD). detach() is a no-op if
            # td_lambda already returns a graph-free tensor.
            value_targets = td_lambda(batch, self._lmbda, self._gamma, self.value).detach()
            for _ in range(value_updates):
                self.optim_v.zero_grad()
                value_loss = ((value_targets - self.value.v(batch.s_0)) ** 2).mean()
                value_loss.backward()
                self.optim_v.step()

            # TODO: Always reset? Or only reset on truncation?
            # Passes fields 0 and 4 of the final transition — meaning of index 4
            # is not visible here; confirm against _reset's signature.
            last = rollout[-1]
            self._reset(last[0], last[4])

In [None]:
def experiment_random(steps: int = 1000, seed: int = 0):
    """Collect ``steps`` transitions with a sticky random policy (repeat probability 0.5).

    Returns:
        ``(env, batch)``: the environment built from the composed config and the
        collected transitions, built on ``DEVICE``.
    """
    overrides = [
        f"seed={seed}",
        "cats.fixed_reset=true",
        "cats.teleport.enable=true",
    ]
    with initialize(version_base=None, config_path="cats/config"):
        cfg = compose(config_name="defaults_online.yaml", overrides=overrides)
    env = build_env(cfg)
    seeded = global_seed(seed=seed)
    explorer = ExplorationPolicy(env, rng=seeded.build_generator(), repeat_probability=0.5)
    transitions = GymCollector(explorer, env).collect(steps)
    return env, build_transition_from_list(transitions, DEVICE)


def experiment_teleport(steps: int = 10000, seed: int = 0):
    """Build a TeleportStrategyExperiment, run it for ``steps`` frames, and return it.

    ``env.max_steps`` is also set to ``steps`` (presumably so episodes are not
    truncated before the frame budget is spent — confirm against the env config).
    """
    overrides = [
        f"seed={seed}",
        f"train.total_frames={steps}",
        "cats.fixed_reset=true",
        "cats.teleport.enable=true",
        f"env.max_steps={steps}",
    ]
    with initialize(version_base=None, config_path="cats/config"):
        cfg = compose(config_name="defaults_online.yaml", overrides=overrides)
    experiment = TeleportStrategyExperiment(cfg, device=DEVICE)
    experiment.run()
    return experiment

In [None]:
experiment = experiment_teleport(steps=10000)
# Request as many samples as the replay memory holds, then plot agent visitation.
env = experiment.env
batch = experiment.memory.sample(len(experiment.memory))[0]
visualise(env, batch)

In [None]:
# NOTE(review): reaches into the logger's private `_engine`; a public accessor would be sturdier.
reset_observations = experiment.logger._engine.results['reset_obs']
visualise(experiment.env, batch=reset_observations)

In [None]:
def grid(env):
    """Enumerate every single-agent placement of ``env``'s initial observation.

    Resets ``env``, removes the agent from a copy of the observation (object
    index 10 -> 1), then for each cell ``(x, y)`` produces a copy with the agent
    written at that cell.

    Args:
        env: environment whose ``reset()`` returns ``(obs, info)``; ``obs`` is an
            array of shape ``(H, W, C)`` whose channel 0 holds the object index.

    Returns:
        Array of shape ``(H, W, H, W, C)`` with the same dtype as ``obs``, where
        ``result[x, y]`` is the observation with the agent at ``(x, y)``. The
        agent is written to every cell, including walls and the goal.
    """
    obs, _ = env.reset()
    # Work on a copy so the array returned by the environment is never mutated.
    base = np.array(obs, copy=True)
    height, width = base.shape[:2]
    # Clear the original agent position (10 -> 1, i.e. the old "subtract 9").
    base[base[:, :, 0] == 10, 0] = 1
    # One independent copy of the base observation per candidate agent cell;
    # height and width are handled separately so non-square grids work.
    states = np.broadcast_to(base, (height, width, *base.shape)).copy()
    for x in range(height):
        for y in range(width):
            states[x, y, x, y, 0] = 10
    return states

experiment_grid = grid(experiment.env)
# Grid size comes from the observation itself instead of a hard-coded 19.
n_rows, n_cols = experiment_grid.shape[:2]
values = np.zeros((n_rows, n_cols))
# Pure inference: skip building autograd graphs for every cell.
with torch.no_grad():
    for x in range(n_rows):
        for y in range(n_cols):
            s = torch.tensor(experiment_grid[x, y], device=DEVICE).to(torch.float32)
            # NOTE(review): `s` is an unbatched channel-last state, whereas training
            # fed `value.v` batched channel-first states (permute(0, 3, 1, 2)).
            # Confirm MinigridValue.v accepts this layout; otherwise use
            # `s.unsqueeze(0).permute(0, 3, 1, 2)` (as the intrinsic model needs).
            v = experiment.value.v(s)
            values[x, y] = v.item()
fig, ax = plt.subplots()
im = ax.imshow(values)
ax.set_title("Value estimate by agent position")
fig.colorbar(im, ax=ax);