In [None]:
import torch
from torch import Tensor, nn
from torchvision.transforms import ToTensor

import gymnasium as gym
from minigrid.wrappers import FullyObsWrapper, ImgObsWrapper, ReseedWrapper, RGBImgObsWrapper

from kitten.common.rng import global_seed, Generator
from kitten.nn import Value
from kitten.policy import Policy
from kitten.experience.collector import GymCollector
from kitten.experience.util import build_transition_from_list, build_replay_buffer
from kitten.experience import Transitions, AuxiliaryMemoryData
from kitten.intrinsic.rnd import RandomNetworkDistillation
from kitten.rl.common import td_lambda, monte_carlo_return

import numpy as np

from tqdm import tqdm
from matplotlib import pyplot as plt

from cats.teleport import *
from cats.cats import TeleportationResetModule
from cats.reset import *

# Device used for every module and tensor in this notebook: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def build_env():
    """Build the MiniGrid FourRooms environment used throughout this notebook.

    The base environment is wrapped, in order, with ``FullyObsWrapper``,
    ``ImgObsWrapper`` and ``ReseedWrapper`` (the latter with its default
    arguments).

    Returns:
        The wrapped environment, after one call to ``reset()``.
    """
    env = gym.make("MiniGrid-FourRooms-v0")
    env = FullyObsWrapper(env)
    env = ImgObsWrapper(env)
    env = ReseedWrapper(env)
    # Reset once before handing the environment out; the returned observation
    # is not needed here.
    env.reset()
    return env

def visualise(env, batch: Transitions):
    """Plot the grid layout overlaid with a heat-map of the agent's positions.

    Walls are drawn in light grey, the goal in green, and every cell the agent
    occupied in ``batch.s_0`` in red. The red intensity is ``log(1 + visits)``
    normalised by its maximum over the grid.

    Args:
        env: Environment whose ``reset()`` returns ``(obs, info)``, where
            ``obs`` is a 3-channel, channel-last array whose channel 0 holds
            the per-cell object code.
        batch: Transitions whose ``s_0`` is a tensor laid out as
            ``[batch, a, b, c]`` with the object code in channel 0.
    """
    obs, _ = env.reset()

    # Object codes read from channel 0 of the observation.
    wall_code, goal_code, agent_code = 2, 8, 10

    canvas = np.zeros(obs.shape)
    white = np.ones(obs.shape)
    red = np.zeros(obs.shape)
    red[:, :, 0] = 1
    green = np.zeros(obs.shape)
    green[:, :, 1] = 1

    # Draw walls
    has_wall = np.expand_dims(obs[:, :, 0] == wall_code, axis=-1)
    canvas += white * has_wall * 0.8
    # Draw goal
    has_goal = np.expand_dims(obs[:, :, 0] == goal_code, axis=-1)
    canvas += green * has_goal

    # Draw agent exploration
    states = batch.s_0.detach().cpu().numpy()  # [batch, a, b, c]
    visits = np.sum(states[:, :, :, 0] == agent_code, axis=0)
    # log1p keeps unvisited cells at exactly 0 without ever evaluating log(0),
    # and keeps cells visited only once visible (a plain log maps 1 -> 0).
    heat = np.log1p(visits)
    peak = heat.max()
    if peak > 0:
        # Only normalise when something was visited, to avoid 0/0 -> NaN.
        heat = heat / peak
    canvas += red * np.expand_dims(heat, axis=-1)

    plt.imshow(canvas)



In [None]:
# Define Value Module, Transformations, Policy
class MinigridValue(Value):
    """State-value network for 3-channel grid observations.

    A small convolutional encoder followed by an MLP head that produces one
    scalar per observation.
    """

    def __init__(self) -> None:
        super().__init__()
        # We use the CNN from https://minigrid.farama.org/content/training/
        # with ReLU changed to LeakyReLU.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, (2, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, (2, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.LeakyReLU(),
            nn.Flatten(),
        )

        # Followed by a simple MLP. The lazy layers infer their input width on
        # the first forward pass, so the grid size is not hard-coded here.
        self.linear = nn.Sequential(
            nn.LazyLinear(128),
            nn.LeakyReLU(),
            nn.LazyLinear(128),
            nn.LeakyReLU(),
            nn.LazyLinear(1)
        )

    def forward(self, x) -> Tensor:
        """Run the CNN and MLP on channel-first input; output has a trailing axis of size 1."""
        return self.linear(self.cnn(x))

    @property
    def value(self) -> Value:
        """This module itself, exposed under the ``value`` attribute."""
        return self

    def v(self, s: Tensor) -> Tensor:
        """Estimate the value of one observation or a batch of observations.

        Args:
            s: Either a single 3-D observation or a 4-D batch. Input whose
               last axis has size 3 is treated as channel-last and permuted to
               channel-first before the convolution.

        Returns:
            A 0-d tensor for a single (3-D) observation, otherwise a 1-D tensor
            with one value per batch element. A batch of size one stays 1-D so
            callers can always index per transition.
        """
        unbatched = len(s.shape) == 3
        if unbatched:
            s = s.unsqueeze(0)
        if s.shape[-1] == 3:
            s = s.permute(0, 3, 1, 2)
        # Drop only the trailing feature axis; a bare squeeze() would also
        # collapse the batch axis whenever the batch holds a single element.
        values = self.forward(s).squeeze(-1)
        if unbatched:
            values = values.squeeze(0)
        return values

def build_rnd() -> RandomNetworkDistillation:
    """Create the Random Network Distillation module used for intrinsic reward.

    Two separate networks with the same architecture are built: two 2x2
    convolutions with LeakyReLU activations, a flatten, and a lazy linear layer
    producing 128 features. They are passed positionally, in construction
    order, to ``RandomNetworkDistillation`` with reward normalisation enabled.
    """
    def make_encoder():
        # A fresh set of layers per call, so the two networks share no weights.
        layers = [
            nn.Conv2d(3, 8, (2, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(8, 16, (2, 2)),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.LazyLinear(128),
        ]
        return nn.Sequential(*layers)

    first_net = make_encoder()
    second_net = make_encoder()
    return RandomNetworkDistillation(
        first_net,
        second_net,
        reward_normalisation=True,
    )

class ExplorationPolicy(Policy):
    """Random exploration with sticky actions.

    On every call the previously chosen action is repeated with probability
    ``repeat_probability``; otherwise a new action is sampled from the
    environment's action space. The observation is ignored.
    """
    def __init__(self, env: gym.Env, rng: Generator, repeat_probability: float=0.9) -> None:
        super().__init__(fn=None)
        self.env = env
        self.rng = rng
        self.p = repeat_probability
        self.previous_action = None
        # ``action_space.sample()`` draws from the space's own random state,
        # which the generator passed in does not control. Seed it from ``rng``
        # so the sampled actions are reproducible for a given seed.
        self.env.action_space.seed(int(self.rng.numpy.random() * (2**31 - 1)))

    def __call__(self, obs):
        """Return the action to take; ``obs`` is accepted but not used."""
        # Short-circuit: no random draw is consumed for the very first action.
        resample = self.previous_action is None or self.rng.numpy.random() > self.p
        if resample:
            self.previous_action = self.env.action_space.sample()
        return self.previous_action

    def reset(self) -> None:
        """Forget the previous action so the next call samples a fresh one."""
        self.previous_action = None

In [None]:
def experiment_random(steps: int = 1000, seed: int = 0):
    """Baseline run: collect transitions with the sticky random policy only.

    Args:
        steps: Number passed to the collector's ``collect``.
        seed: Seed handed to ``global_seed``.

    Returns:
        ``(env, batch)`` where ``batch`` is built from the collected
        transitions and placed on ``DEVICE``.
    """
    env = build_env()
    seeded = global_seed(seed=seed)
    explorer = ExplorationPolicy(
        env,
        rng=seeded.build_generator(),
        repeat_probability=0.5,
    )
    collected = GymCollector(explorer, env).collect(steps)
    return env, build_transition_from_list(collected, DEVICE)

def experiment_teleport(
    steps: int = 10000,
    seed: int = 0,
    lmbda: float = 0.9,
    gamma: float = 0.99,
    value_epochs: int = 8,
    repeat_probability: float = 0.5,
):
    """Random exploration with teleportation resets guided by a value function.

    Each outer iteration:
      1. collects one episode (until terminated or truncated),
      2. asks the teleportation reset module to pick the next start,
      3. updates the RND module on the episode and replaces the rewards with
         the intrinsic reward,
      4. fits the value network to TD(lambda) targets for ``value_epochs``
         gradient steps.

    Args:
        steps: Step budget. Collection stops at the first episode boundary at
            or beyond this many steps, so it may be slightly exceeded. Also
            used as the replay buffer capacity and the final sample size.
        seed: Seed handed to ``global_seed``.
        lmbda: Lambda passed to ``td_lambda``.
        gamma: Discount passed to ``td_lambda``.
        value_epochs: Gradient steps on the value network per episode.
        repeat_probability: Action repeat probability of the exploration policy.

    Returns:
        ``(env, batch, value, log)``: the environment, the first element of
        ``memory.sample(steps)``, the trained value network, and a dict with
        per-episode ``"value_loss"`` (summed over the epochs) and ``"return"``
        (the first TD(lambda) target of the episode).
    """
    env = build_env()
    rng = global_seed(seed=seed)
    policy = ExplorationPolicy(env, rng.build_generator(), repeat_probability=repeat_probability)
    intrinsic = build_rnd().to(DEVICE)
    value = MinigridValue().to(DEVICE)
    optim_value = torch.optim.Adam(value.parameters())
    memory, _ = build_replay_buffer(env, capacity=steps)
    collector = GymCollector(policy, env, memory=memory)

    # Teleportation components. e=0 presumably disables the epsilon-random
    # branch of the teleport strategy — TODO confirm against cats.teleport.
    tm = LatestEpisodeTeleportMemory(rng.build_generator(), device=DEVICE)
    rm = ResetMemory(env, capacity=1, rng=rng.build_generator(), device=DEVICE)
    ts = EpsilonGreedyTeleport(value, rng.build_generator(), e=0)
    trm = TeleportationResetModule(rm, tm, ts)

    log = {"value_loss": [], "return": []}

    tm.reset(collector.env, collector.obs)
    step = 0
    while step < steps:
        # Collection phase: gather transitions one at a time until the episode ends.
        batch = []
        while True:
            data = list(collector.collect(n=1)[-1])
            # NOTE(review): the collector was also constructed with this memory;
            # if it stores transitions itself, this append duplicates them — confirm.
            memory.append(data)
            batch.append(data)
            # NOTE(review): data[0] is the first field of the transition
            # (presumably the pre-step observation) while collector.env has
            # already stepped — confirm the pairing is intended.
            tm.update(collector.env, obs=data[0])
            if data[-2] or data[-1]:
                # Terminated or truncated
                break

        batch = build_transition_from_list(batch, device=DEVICE)
        step += batch.shape[0]

        # Manually move the channel axis first for the convolutional networks.
        batch.s_0 = batch.s_0.permute(0,3,1,2)
        batch.s_1 = batch.s_1.permute(0,3,1,2)

        # Reset: let the teleportation module choose the next starting point.
        trm.select(collector)

        # Intrinsic reward: update the module on this episode first...
        intrinsic.update(batch, aux=AuxiliaryMemoryData.placeholder(batch), step=step)
        # ...then override the reward with the third value returned by
        # ``reward`` (the first two are unused here).
        _, _, r_i = intrinsic.reward(batch)
        batch.r = r_i

        # Update value function against targets computed once per episode.
        value_targets = td_lambda(batch, lmbda, gamma, value)
        total_value_loss = 0
        for _ in range(value_epochs):
            optim_value.zero_grad()
            value_loss = ((value_targets - value.v(batch.s_0))**2).mean()
            total_value_loss += value_loss.item()
            value_loss.backward()
            optim_value.step()

        # Logging
        log["return"].append(value_targets[0].item())
        log["value_loss"].append(total_value_loss)

    return env, memory.sample(steps)[0], value, log

In [None]:
# Run the teleportation experiment with a budget of 10000 environment steps.
env, batch, value, log = experiment_teleport(10000)

In [None]:
# Plot the agent's positions from the sampled batch on top of the grid layout.
visualise(env, batch)