# Visualising MCC Exploration

This notebook logs exploratory results from adding teleportation to MountainCarContinuous (MCC), with state-coverage visualisation.

21/01/2024
- Naive teleportation to argmax works
- Longer episodes are better than shorter
- Different intrinsic rewards show significantly different behavior
- Even naively, general improvement over pure intrinsic
- Fails to beat intrinsic + extrinsic, perhaps because the negative extrinsic reward reveals information about the target? Not directly comparable, and I think this is fully explored in that reward shifting paper
- Keeps teleporting to same target
- This may be a problem with DDPG


28/01/2024
- Probabilistic teleportation works well
- Environment reset stochasticity is important
- Time limit aware Q functions are difficult to train!
- Proposal: Dynamic Truncation!
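
The probabilistic teleportation noted above can be sketched in isolation. This is a minimal illustration, assuming the teleport target is sampled with probability proportional to the squared value estimate of each state in the trajectory. `sample_teleport_index` and its `power` argument are names made up for this sketch and are not part of the `curiosity` package.

```python
import numpy as np

def sample_teleport_index(values, rng, power=2.0):
    """Sample a trajectory index with probability proportional to values**power.

    Assumption: an even power keeps the weights non-negative.
    """
    weights = np.asarray(values, dtype=np.float64) ** power
    total = weights.sum()
    if total <= 0 or not np.isfinite(total):
        # Degenerate weights: fall back to a uniform choice
        return int(rng.integers(len(weights)))
    return int(rng.choice(len(weights), p=weights / total))

rng = np.random.default_rng(0)
values = [0.1, 0.2, 3.0, 0.5]
# Empirical teleport frequencies over many draws
counts = np.bincount(
    [sample_teleport_index(values, rng) for _ in range(2000)],
    minlength=len(values),
)
print(counts)
```

Squaring treats a large negative value the same as a large positive one, so this weighting only makes sense when the value estimates are mostly non-negative (as with a pure intrinsic reward).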

4/02/2024
- ICM and RND lead to inherently different results - RND should be prioritised
- CATS fails to improve over baseline on RND with fixed reset, but does in ICM. After reset, the new trajectory follows the previous trajectory too closely, while resetting from the start leads to more divergence across the entire episode (and hence more exploration)
- Fixing the reset states leads to improved analysis
- Policy function gets stuck in the local minima of the Q function
- Analyse DQN instead? Skip parametrized policy function and use an approximator?? Maybe implement QT-opt https://arxiv.org/pdf/1806.10293.pdf. This may be important to obtain interesting experiment results, since on MCC the policy generally fails to follow the critic even on large learning rates (why??)
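
For reference, the idea of skipping the parametrised policy and maximising the critic directly can be sketched with a small cross-entropy method (CEM) loop. This is a stand-alone NumPy illustration under simplifying assumptions (single state, diagonal Gaussian proposal, toy Q function); `cem_argmax` and `toy_q` are hypothetical names and this is not the `cross_entropy_method` from `curiosity.rl.qt_opt`.

```python
import numpy as np

def cem_argmax(q_fn, low, high, n=64, m=6, n_iterations=2, rng=None):
    """Approximate argmax_a q_fn(a) over the box [low, high] for one state.

    q_fn maps an (n, action_dim) array of actions to n scores.
    """
    rng = np.random.default_rng() if rng is None else rng
    low = np.atleast_1d(np.asarray(low, dtype=np.float64))
    high = np.atleast_1d(np.asarray(high, dtype=np.float64))
    # Start from a wide Gaussian centred on the action box
    mean = (low + high) / 2.0
    std = (high - low) / 2.0
    for _ in range(n_iterations):
        samples = rng.normal(mean, std, size=(n, low.shape[0]))
        samples = np.clip(samples, low, high)
        scores = np.asarray(q_fn(samples)).reshape(n)
        # Keep the m best samples and refit the Gaussian to them
        elites = samples[np.argsort(scores)[-m:]]
        mean = elites.mean(axis=0)
        std = elites.std(axis=0) + 1e-6
    # argsort is ascending, so the last elite is the best sample found
    return elites[-1]

def toy_q(actions):
    # Toy critic with a single maximum at a = 0.5
    return -((actions[:, 0] - 0.5) ** 2)

best_action = cem_argmax(toy_q, low=[-1.0], high=[1.0], rng=np.random.default_rng(0))
print(best_action)
```

The defaults `n=64`, `m=6` and `n_iterations=2` mirror the CEM defaults used by `QTOptCats` in this notebook; everything else is an assumption of the sketch.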


TODO:
- Confidence Bounds (How? Without latent density estimator?)
- Termination as an action
- Epsilon greedy
- Time aware exploration

Known Failure Modes
- Teleporting to the end of the episode, and immediately truncating

Ideas
- Bootstrapped Q value estimate for confidence bound guided estimation?
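
A rough sketch of what the bootstrapped confidence-bound idea could look like, assuming each critic in the ensemble supplies one value estimate per trajectory state and that ensemble disagreement is a usable proxy for uncertainty. `ucb_teleport_index` and `beta` are hypothetical names for illustration only; nothing like this is implemented in the notebook yet.

```python
import numpy as np

def ucb_teleport_index(ensemble_values, beta=1.0):
    """Pick the state maximising mean + beta * std across the ensemble.

    ensemble_values has shape (n_critics, n_states).
    """
    q = np.asarray(ensemble_values, dtype=np.float64)
    mean = q.mean(axis=0)
    std = q.std(axis=0)  # disagreement between critics
    return int(np.argmax(mean + beta * std))

# Three critics, three candidate states: the critics agree on the first two
# states and disagree strongly on the last one.
ensemble_values = np.array([
    [1.0, 2.0, 1.5],
    [1.0, 2.0, 3.5],
    [1.0, 2.0, 0.5],
])
greedy_index = ucb_teleport_index(ensemble_values, beta=0.0)
optimistic_index = ucb_teleport_index(ensemble_values, beta=1.0)
print(greedy_index, optimistic_index)
```

With `beta=0` this reduces to argmax teleportation on the ensemble mean; larger `beta` favours states the critics disagree on.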

Interesting observations
- QT-Opt run directly on the critic, rather than on the target network, explores faster??

In [None]:
# Define Imports and shared training information

# Std
import copy
import math
from typing import Optional

# Training
import numpy as np
import torch
from tqdm import tqdm
from omegaconf import DictConfig
import gymnasium as gym
from gymnasium.wrappers import TimeAwareObservation

# Evaluation
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.cm as cm
from sklearn.neighbors import KernelDensity

# Curiosity
from curiosity.experience import Transition
from curiosity.experience.collector import GymCollector
from curiosity.policy import ColoredNoisePolicy, Policy
from curiosity.experience.memory import ReplayBuffer
from curiosity.experience.util import build_replay_buffer
from curiosity.util.util import *
from curiosity.rl.qt_opt import cross_entropy_method, QTOpt

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
COLLECTION_STEPS = 4096
MAX_EPISODE_STEPS = 400
ENV = "MountainCarContinuous-v0"
#ENV = "Pendulum-v1"

cfg = DictConfig({
    "seed": 6230,
    "env": {
        "name": ENV,
    },
    "memory": {
        "type": "experience_replay",
    #    "type": "prioritized_experience_replay",
    #    "alpha": 0.6,
    #    "epsilon": 0.1,
    #    "beta_0": 0.4,
    #    "beta_annealing_steps": COLLECTION_STEPS,
        "normalise_observation": True
    },
    "train": {
        "initial_collection_size": 256,
        "total_frames": COLLECTION_STEPS,
        "minibatch_size": 128
    },
    "algorithm": {
        "type": "qt_opt",
        "gamma": 0.99,
        "tau": 0.005,
        "lr": 0.01,
        "update_frequency": 1,
        "clip_grad_norm": 1,
        "ensemble_number": 5,
        "actor": {
            "features": 128
        },
        "critic": {
            "features": 128
        }
    },
    "intrinsic": {
        "type": "rnd",
        "encoding_size": 32,
        "lr": 0.0003,
        "int_coef": 1, 
        "ext_coef": 2,
        "reward_normalisation": True,
        "normalised_obs_clip": 5
    },
    "noise": {
        "scale": 0.1,
        "beta": 0
    }
})


In [None]:
from typing import Callable, Union, List, Tuple
import random

import numpy as np
from torch import Tensor
import torch.nn as nn

from curiosity.rl import Algorithm
from curiosity.rl.qt_opt import cross_entropy_method
from curiosity.experience import AuxiliaryMemoryData
from curiosity.nn import Critic, AddTargetNetwork

class QTOptCats(Algorithm):
    """ Q-learning for continuous actions with Cross-Entropy Maximisation

    With ensemble uncertainty estimation and dual critic

    Kalashnikov et al. QT-Opt: Scalable Deep Reinforcement Learning for Vision-Based Robotic Manipulation
    """
    def __init__(self,
                 build_critic: Callable[[], Critic],
                 obs_space: gym.spaces.Box,
                 action_space: gym.spaces.Box,
                 ensemble_number: int = 5,
                 gamma: float = 0.99,
                 lr: float = 1e-3,
                 tau: float = 0.005,
                 cem_n: int = 64,
                 cem_m: int = 6,
                 cem_n_iterations: int = 2,
                 clip_grad_norm: Optional[float] = 1,
                 update_frequency: int = 1,
                 device: str = "cpu",
                 **kwargs) -> None:
        super().__init__()

        self.critics = [AddTargetNetwork(build_critic(), device=device) for _ in range(ensemble_number)]

        self.device=device

        self._ensemble_number = ensemble_number
        self._gamma = gamma
        self._tau = tau
        self._env_action_scale = torch.tensor(action_space.high-action_space.low, device=device) / 2.0
        self._env_action_min = torch.tensor(action_space.low, dtype=torch.float32, device=device)
        self._env_action_max = torch.tensor(action_space.high, dtype=torch.float32, device=device)
        self._obs_space = obs_space
        self._action_space = action_space
        self._clip_grad_norm = clip_grad_norm
        self._update_frequency = update_frequency
        self._n = cem_n
        self._m = cem_m
        self._n_iterations = cem_n_iterations
        self._chosen_critic = 0

        self._optim_critic = torch.optim.Adam(
            params=torch.nn.ModuleList(self.critics).parameters(),
            lr=lr
        )

        self.loss_critic_value = 0
    
    @property
    def critic(self):
        return self.critics[self._chosen_critic]

    def policy_fn(self, s: Union[Tensor, np.ndarray], critic: Optional[Critic] = None) -> Tensor:
        if isinstance(s, np.ndarray):
            s = torch.tensor(s, device=self.device)
        if critic is None:
            critic = self.critic
        squeeze = False
        if self._obs_space.shape == s.shape:
            squeeze = True
            s = s.unsqueeze(0)
        result = cross_entropy_method(
            s_0=s,
            critic_network=critic,
            action_space=self._action_space,
            n=self._n,
            m=self._m,
            n_iterations=self._n_iterations,
            device=self.device
        )
        if squeeze:
            result = result.squeeze()
        return result

    def reset_critic(self):
        self._chosen_critic = random.randint(0,self._ensemble_number-1)

    def _critic_update(self, batch: Transition, aux: AuxiliaryMemoryData):
        x = [critic.q(batch.s_0, batch.a).squeeze() for critic in self.critics]        
        with torch.no_grad():
            # Implement a variant of Clipped Double-Q Learning
            # Randomly sample two networks
            sampled_critics = random.sample(self.critics, 2)
            a_1 = self.policy_fn(batch.s_1, critic=sampled_critics[0].target)
            a_2 = self.policy_fn(batch.s_1, critic=sampled_critics[1].target)
            target_max_1 = sampled_critics[0].target.q(batch.s_1, a_1).squeeze()
            target_max_2 = sampled_critics[1].target.q(batch.s_1, a_2).squeeze()
            y = batch.r + (~batch.d) * torch.minimum(target_max_1, target_max_2) * self._gamma

        losses = []
        for x_i in x:
            losses.append(torch.mean((aux.weights * (y-x_i)) ** 2))
        loss_critic = sum(losses)

        loss_value = loss_critic.item()
        self._optim_critic.zero_grad()
        loss_critic.backward()
        if self._clip_grad_norm is not None:
            for critic in self.critics:
                nn.utils.clip_grad_norm_(critic.net.parameters(), self._clip_grad_norm)
        self._optim_critic.step()

        return loss_value

    def update(self, batch: Transition, aux: AuxiliaryMemoryData, step: int):
        if step % self._update_frequency == 0:
            self.loss_critic_value =  self._critic_update(batch, aux)
            for critic in self.critics:
                critic.update_target_network(tau = self._tau)
        return self.loss_critic_value

    def get_log(self):
        return {
            "critic_loss": self.loss_critic_value,
        }

    def get_models(self) -> List[Tuple[nn.Module, str]]:
        return list(zip(self.critics, [f"critic_{i}" for i in range(len(self.critics))]))
 

In [None]:
def entropy_memory(memory: ReplayBuffer):
    # Construct a density estimator
    s = Transition(*memory.sample(len(memory))[0]).s_0.cpu().numpy()
    kde = KernelDensity(kernel="gaussian", bandwidth="scott").fit(s)
    log_likelihoods = kde.score_samples(kde.sample(n_samples=10000))
    return -log_likelihoods.mean()

def visualise_memory(env: gym.Env, *memories: ReplayBuffer):
    """ Visualise state space coverage for the given environment

    Each entry of memories is a (replay buffer, label) pair
    """
    
    fig, ax = plt.subplots()
    ax.set_title("State Space Coverage")

    if env.spec.id == "MountainCarContinuous-v0":
        ax.set_xlim(env.observation_space.low[0], env.observation_space.high[0])
        ax.set_xlabel("Position")
        ax.set_ylim(env.observation_space.low[1], env.observation_space.high[1])
        ax.set_ylabel("Velocity")
    elif env.spec.id == "Pendulum-v1":
        ax.set_xlim(-math.pi, math.pi)
        ax.set_xlabel("Theta")
        ax.set_ylim(env.observation_space.low[2], env.observation_space.high[2])
        ax.set_ylabel("Angular Velocity")

    for memory, name in memories:
        batch = Transition(*memory.storage)
        s = batch.s_0.cpu().numpy()

        # Colors based on time
        norm = mpl.colors.Normalize(vmin=0, vmax=len(s)-1)
        cmap = cm.viridis
        m = cm.ScalarMappable(norm=norm, cmap=cmap)

        colors = m.to_rgba(np.linspace(0, len(s)-1, len(s)))

        if env.spec.id == "MountainCarContinuous-v0":
            ax.scatter(s[:, 0], s[:, 1], s=1, label=name, c=colors)
        elif env.spec.id == "Pendulum-v1":
            ax.scatter(np.arctan2(s[:, 1], s[:, 0]), s[:, 2], s=1, label=name, c=colors)

    
    fig.colorbar(m, ax=ax)
    ax.legend()

# Visualise V
# For mountain car

def visualise_experiment_v(experiment):

    X = torch.linspace(experiment.env.observation_space.low[0], experiment.env.observation_space.high[0], 100)
    Y = torch.linspace(experiment.env.observation_space.low[1], experiment.env.observation_space.high[1], 100)
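    # Explicit "ij" indexing matches the legacy default and avoids the deprecation warning
    grid_X, grid_Y = torch.meshgrid(X, Y, indexing="ij")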

    # Time Aware
    #states = torch.stack((grid_X.flatten(), grid_Y.flatten(), 1+torch.zeros_like(grid_X.flatten())) ).T

    fig, axs = plt.subplots(1,3)
    fig.set_size_inches(20,5)
    ax = axs[0]
    states = torch.stack((grid_X.flatten(), grid_Y.flatten())).T
    states = states.to(DEVICE)
        # Observation Normalisation
    states_cpu = states.cpu()
    states = (states - experiment.policy.normalise_obs.mean) / experiment.policy.normalise_obs.std

    # V 
    values = experiment.V(states)
    norm = mpl.colors.Normalize(vmin=values.min(), vmax=values.max())
    cmap = cm.viridis
    m = cm.ScalarMappable(norm=norm, cmap=cmap)
    colors = m.to_rgba(values.detach().cpu())

    ax.set_title("Value Function Visualisation")
    ax.scatter(states_cpu[:,0], states_cpu[:, 1], c=colors)
    fig.colorbar(m, ax=ax)

    # Policy

    # Train new actor
    #
    # new_actor = build_actor(experiment.env, cfg.algorithm.actor.features).to(DEVICE)
    # optim_actor = torch.optim.Adam(params=new_actor.net.parameters(), lr=0.1)
    # X = torch.linspace(experiment.env.observation_space.low[0], experiment.env.observation_space.high[0], 10)
    # Y = torch.linspace(experiment.env.observation_space.low[1], experiment.env.observation_space.high[1], 10)
    # grid_X, grid_Y = torch.meshgrid((X, Y))
    # pseudo_states = torch.stack((grid_X.flatten(), grid_Y.flatten())).to(DEVICE)
    # pseudo_states = pseudo_states.T
    # pseudo_states = (pseudo_states - experiment.policy.normalise_obs.mean) / experiment.policy.normalise_obs.std
# 
    # for _ in range(200):
    #     # batch, aux = experiment.memory.sample(16)
    #     # batch = Transition(*batch)
    #     # s_0 = batch.s_0
    #     s_0 = pseudo_states
    #     desired_action = new_actor(s_0)
    #     loss = -experiment.algorithm.critic(torch.cat((s_0, desired_action), 1))
    #     loss = torch.mean(torch.mean(loss, dim=list(range(1, len(loss.shape)))))
    #     optim_actor.zero_grad()
    #     loss.backward()
    #     torch.nn.utils.clip_grad_norm_(new_actor.net.parameters(), 1)
    #     optim_actor.step()
# 
    #     # Evaluate differences between old and new actor
    # batch, aux = experiment.memory.sample(1024)
    # print("Experiment policy value", experiment.V(batch[0]).mean().detach().item())
# 
    # target_action = new_actor(batch[0])
    # V = experiment.algorithm.critic.target(torch.cat((batch[0], target_action), 1))
    # print("Retrained policy value ", V.mean().detach().item())
# 


    ax = axs[1]
    actions = experiment.algorithm.policy_fn(states)
    #actions = new_actor(states)
    norm = mpl.colors.Normalize(vmin=-1, vmax=1)
    cmap = cm.viridis
    m = cm.ScalarMappable(norm=norm, cmap=cmap)
    colors = m.to_rgba(actions.detach().cpu())
    ax.scatter(states_cpu[:,0], states_cpu[:, 1], c=colors)
    ax.set_title("Policy Actions")
    fig.colorbar(m, ax=ax)

    # Q Function Optimal
    ax = axs[2]

    batch_size = 1000
    actions = []
    for i in range(100*100 // batch_size): 
        batch = states[i*batch_size : (i+1)*batch_size]
        a=cross_entropy_method(
                batch,
                experiment.algorithm.critic,
                experiment.collector.env.action_space,
                device=DEVICE
            )
        actions.append(a)

    actions = torch.cat(actions)
    print(actions.shape)
    norm = mpl.colors.Normalize(vmin=-1, vmax=1)
    cmap = cm.viridis
    m = cm.ScalarMappable(norm=norm, cmap=cmap)
    colors = m.to_rgba(actions.detach().cpu())
    ax.scatter(states_cpu[:,0], states_cpu[:, 1], c=colors)
    ax.set_title("Cross Entropy Maximal Actions")
    fig.colorbar(m, ax=ax)

    print(f"Minimum Value {values.min()} | Maximum Value {values.max()}")

In [None]:
class Experiment:
    """ Baseline control class for intrinsic exploration experiment

    Core assumptions
    - DDPG (or a similar off-policy algorithm)

    This notebook primarily explores the concept of "teleportation"
    And the edits involve data collection
    """
    def __init__(self,
                 cfg,
                 max_episode_steps: Optional[int] = None,
                 death_is_not_the_end: bool = True,
                 fixed_reset: bool = True):
        self.cfg = copy.deepcopy(cfg)
        self.death_is_not_the_end = death_is_not_the_end
        self.fixed_reset = fixed_reset

        self._build_env(max_episode_steps)
        self._build_policy()
        self._build_data()
        self.intrinsic = build_intrinsic(self.env, self.cfg.intrinsic, device=DEVICE)

        self.log = {
            "total_intrinsic_reward": 0 # An ideal exploration algorithm maximises the received intrinsic reward
        }

    def _build_env(self, max_episode_steps=None):
        self.env = gym.make(
            self.cfg.env.name,
            render_mode="rgb_array",
            max_episode_steps=max_episode_steps
        )
        self.env.reset()
        self.rng = global_seed(self.cfg.seed, self.env)

    def _build_policy(self):
        self.algorithm = QTOptCats(
            build_critic=lambda : build_critic(self.env, **self.cfg.algorithm.critic).to(DEVICE),
            action_space=self.env.action_space,
            obs_space=self.env.observation_space,
            device=DEVICE,
            **self.cfg.algorithm
        )
        self.policy = ColoredNoisePolicy(
            self.algorithm.policy_fn,
            self.env.action_space,
            self.env.spec.max_episode_steps,
            rng=self.rng,
            device=DEVICE,
            **self.cfg.noise
        )

    def _build_data(self):
        self.memory = build_replay_buffer(self.env,capacity=COLLECTION_STEPS, device=DEVICE, normalise_observation=True)
            # Remove automatic memory addition for more control
        self.collector = GymCollector(self.policy, self.env, device=DEVICE)
        self.policy.normalise_obs = self.memory.rmv[0]

    def _update_memory(self, obs, action, reward, n_obs, terminated, truncated):
        self.memory.append((obs, action, reward, n_obs, terminated))

    def V(self, s) -> torch.Tensor:
        """Calculates the value function for states s
        """
        target_action = self.algorithm.policy_fn(s)
        V = self.algorithm.critic.target.q(s, target_action)
        return V
    
    def env_reset(self):
        #self.algorithm.reset_critic()
        if self.fixed_reset:
            o, i = self.collector.env.reset(seed=self.cfg.seed)
        else:
            o, i = self.collector.env.reset()
        self.collector.obs = o

    def early_start(self, n: int):
        """ Overrides early start for tighter control

        Args:
            n (int): number of steps
        """
        self.env_reset()    
        policy = self.collector.policy
        self.collector.set_policy(Policy(lambda _: self.env.action_space.sample(), transform_obs=False))
        for i in range(n):
            obs, action, reward, n_obs, terminated, truncated = self.collector.collect(n=1, early_start=True)[-1]
            if terminated or truncated:
                self.env_reset()
            self._update_memory(obs, action, reward, n_obs, terminated, truncated)
        self.collector.policy = policy

    def run(self):
        """Default Experiment run
        """
        self.early_start(self.cfg.train.initial_collection_size)
        self.env_reset()

        batch, aux = self.memory.sample(self.cfg.train.initial_collection_size)
        self.intrinsic.initialise(Transition(*batch), aux)

        self.newly_collected_intrinsic_reward = []

        for step in tqdm(range(1, self.cfg.train.total_frames+1)):
            # Collect Data
            obs, action, reward, n_obs, terminated, truncated = self.collector.collect(n=1)[-1]

            self.newly_collected_intrinsic_reward.append(self.intrinsic.reward(Transition(
                torch.tensor(obs, device=DEVICE).unsqueeze(0),
                torch.tensor(action, device=DEVICE).unsqueeze(0),
                torch.tensor(reward, device=DEVICE).unsqueeze(0),
                torch.tensor(n_obs, device=DEVICE).unsqueeze(0),
                torch.tensor(terminated, device=DEVICE).unsqueeze(0)
            ))[2].item())


            if terminated or truncated:
                self.env_reset() 

            self._update_memory(obs, action, reward, n_obs, terminated, truncated)                

            batch, aux = self.memory.sample(self.cfg.train.minibatch_size)
            batch = Transition(*batch)
            if self.death_is_not_the_end:
                batch = Transition(batch.s_0, batch.a, batch.r, batch.s_1, torch.zeros(batch.d.shape, device=DEVICE).bool())
            # Intrinsic Reward Calculation
            r_t, r_e, r_i = self.intrinsic.reward(batch)
            self.intrinsic.update(batch, aux, step=step)
            # RL Update            
            batch = Transition(batch.s_0, batch.a, r_i, batch.s_1, batch.d)
            self.algorithm.update(batch, aux, step=step)

            # Log
            self.log["total_intrinsic_reward"] += r_i.mean().item()


In [None]:
# Intrinsic Only

    # Build RL Structure
baseline = Experiment(
    cfg,
    max_episode_steps=MAX_EPISODE_STEPS,
    death_is_not_the_end=True,
    fixed_reset=True
)
baseline.run()

    # Visualise Training Information
visualise_memory(
     baseline.env,
    (baseline.memory, "intrinsic"),
)

In [None]:
visualise_experiment_v(baseline)
print("Entropy: ", entropy_memory(baseline.memory))

In [None]:
class CATS(Experiment):

    def __init__(self,
                 cfg,
                 max_episode_steps: Optional[int] = None,
                 death_is_not_the_end: bool = True,
                 fixed_reset: bool = False,
                 epsilon: float = 0.2):
        super().__init__(cfg, max_episode_steps, death_is_not_the_end, fixed_reset)

        # Recently explored trajectory
        self.trajectory = torch.zeros((self.env.spec.max_episode_steps, *self.env.observation_space.shape), device=DEVICE)
        # Current time step
        self.trajectory_index = 0
        # Target Timestep
        self.teleport_index = 0
        # Reset epsilon
        self.epsilon = epsilon
        
        # Environment deepcopies
        self.state = None
        self.quicksaves = [None for _ in range(self.env.spec.max_episode_steps)]

        # RNG
        self.np_rng = np.random.default_rng(self.cfg.seed)

        self.log["teleport_targets"] = []
        self.log["teleport_targets_observations"] = []
        self.log["latest_trajectory"] = []

    def _build_env(self, max_episode_steps=None):
        super()._build_env(max_episode_steps)
        
        # Meta-RL 

            # Wrap Time Aware Observations
            # This adds complexity to the environments, hard to learn?
        # self.env = TimeAwareObservation(self.env)
        #    # Adjust observation space limit
        # high = self.env.observation_space.high
        # high[-1] = self.env.spec.max_episode_steps
        # self.env.observation_space = gym.spaces.Box(
        #    self.env.observation_space.low,
        #    high,
        # )

            # Learn to Truncate
        # self.env = self.env.env

    def _teleport_selection(self, V):
        V = V**2
            # Argmax
        #teleport_index = torch.argmax(V).item()
            # Probability Matching
        p = V / V.sum()
        pt = self.np_rng.random()
        pc = 0
        for i, pi in enumerate(p):
            pc += pi
            if pc >= pt or i == len(p) - 1:
                teleport_index = i
                break
            # TODO: Upper Confidence bound
        return teleport_index
        
    def _reset(self, V):
        if self.fixed_reset:
            return
            # Epsilon Greedy Reset
        #if torch.rand(1) < epsilon:
        #    obs, infos = collector.env.reset()
        #    resets[0] = collector.env
        #    trajectory[0] = torch.tensor(obs, device=DEVICE)
        #    teleport_index = 0
            # Reset Buffer
        reset_buffer = []
        reset_buffer_obs = []
        for i in range(10):
            obs, info = self.collector.env.reset(seed=int(self.np_rng.integers(65536)))
            reset_buffer.append(copy.deepcopy(self.collector.env))
            reset_buffer_obs.append(obs)
        reset_buffer_obs = torch.tensor(np.array(reset_buffer_obs, dtype=np.float32), device=DEVICE)
        V_r = self.V(reset_buffer_obs)
        best_reset_index = torch.argmax(V_r).item()
        #if self.np_rng.random() < self.epsilon or V_r[best_reset_index] >= V[self.teleport_index]:
        if V_r[best_reset_index] >= V[self.teleport_index]:
            self.collector.env = reset_buffer[best_reset_index]
            self.quicksaves[0] = self.collector.env
            self.trajectory[0] = reset_buffer_obs[best_reset_index]
            self.teleport_index = 0

    def run(self):
        self.early_start(self.cfg.train.initial_collection_size)
        self.env_reset()

        batch, aux = self.memory.sample(self.cfg.train.initial_collection_size)
        self.intrinsic.initialise(Transition(*batch), aux)

        # Main Loop
        self.state = copy.deepcopy(self.collector.env)
        self.newly_collected_intrinsic_reward = []
        for step in tqdm(range(1, self.cfg.train.total_frames+1)):
            obs, action, reward, n_obs, terminated, truncated = self.collector.collect(n=1)[-1]

            self.newly_collected_intrinsic_reward.append(self.intrinsic.reward(Transition(
                torch.tensor(obs, device=DEVICE).unsqueeze(0),
                torch.tensor(action, device=DEVICE).unsqueeze(0),
                torch.tensor(reward, device=DEVICE).unsqueeze(0),
                torch.tensor(n_obs, device=DEVICE).unsqueeze(0),
                torch.tensor(terminated, device=DEVICE).unsqueeze(0)
            ))[2].item())

            # Time Limit Normalisation
            # obs, n_obs = copy.deepcopy(obs), copy.deepcopy(n_obs)
            # obs[-1] = obs[-1] / self.env.spec.max_episode_steps
            # n_obs[-1] = n_obs[-1] / self.env.spec.max_episode_steps

            # Update trajectory
            self.quicksaves[self.trajectory_index] = self.state
            self.trajectory[self.trajectory_index] = torch.from_numpy(obs).to(DEVICE)
            self.state = copy.deepcopy(self.collector.env)
            self.trajectory_index += 1

            # Manage Teleportation
            if truncated or terminated:
                # Calculate Value
                V = self.V(self.trajectory[:self.trajectory_index])
                # Teleportation Selection
                with torch.no_grad():
                    self.teleport_index = self._teleport_selection(V)
                #print(self.teleport_index)
                # Resets
                self._reset(V)

                # Teleport
                self.collector.env = self.quicksaves[self.teleport_index]
                self.collector.obs = self.trajectory[self.teleport_index].cpu().numpy()
                self.collector.env.np_random = np.random.default_rng(self.np_rng.integers(65536))
                self.trajectory_index = self.teleport_index
                self.state = copy.deepcopy(self.collector.env)

                # Log
                self.log["teleport_targets"].append(self.teleport_index)
                self.log["teleport_targets_observations"].append(self.collector.obs)
                self.log["latest_trajectory"] = self.trajectory.cpu().numpy()

                # Account for time
                # terminated = True
                # n_obs = self.collector.obs


            # Update Memory
            self._update_memory(obs, action, reward, n_obs, terminated, truncated)
            
            # Remaining RL Update
            batch, aux = self.memory.sample(self.cfg.train.minibatch_size)
            batch = Transition(*batch)
            if self.death_is_not_the_end:
                batch = Transition(batch.s_0, batch.a, batch.r, batch.s_1, torch.zeros(batch.d.shape, device=DEVICE).bool())


            r_t, r_e, r_i = self.intrinsic.reward(batch)
            self.intrinsic.update(batch, aux, step=step)

            batch = Transition(batch.s_0, batch.a, r_i, batch.s_1, batch.d)
            self.algorithm.update(batch, aux, step=step)

            # Log
            self.log["total_intrinsic_reward"] += r_i.mean().item()
        

In [None]:
cats = CATS(
    cfg,
    max_episode_steps=MAX_EPISODE_STEPS,
    death_is_not_the_end=True,
    fixed_reset=True
)
cats.run()

    # Visualise Training Information
visualise_memory(
      cats.env,
     (cats.memory, "teleport"),
)

In [None]:
plt.scatter(
    cats.log['latest_trajectory'][:,0],
    cats.log['latest_trajectory'][:,1]
)

In [None]:
plt.plot(cats.log['teleport_targets'])

In [None]:
visualise_experiment_v(cats)
print("Entropy: ", entropy_memory(cats.memory))

In [None]:
# Learn To Truncate

class CATS_Truncate(Experiment):

    def __init__(self,
                 cfg,
                 max_episode_steps: Optional[int] = None,
                 death_is_not_the_end: bool = True,
                 fixed_reset: bool = False,
                 storage_size: int = 100,
                 epsilon: float = 0.1):
        super().__init__(cfg, max_episode_steps, death_is_not_the_end, fixed_reset)

        self.storage_size = storage_size
        # Recently explored trajectory
        self.trajectory = []

        # Current time step
        self.trajectory_index = 0
        self.trajectory_counter = 0
        # Target Timestep
        self.teleport_index = 0
        # Reset epsilon
        self.epsilon = epsilon
        
        # Environment deepcopies
        self.state = None
        self.quicksaves = []

        # RNG
        self.np_rng = np.random.default_rng(self.cfg.seed)

        self.log["teleport_targets"] = []
        self.log["teleport_targets_observations"] = []
        self.log["latest_trajectory"] = []

    def _build_env(self, max_episode_steps=None):
        super()._build_env(max_episode_steps)

        # Learn to Truncate
        self.env = self.env.env

    def _teleport_selection(self, V):
        #V = V**2
            # Argmax
        teleport_index = torch.argmax(V).item()
            # Probability Matching
        #p = V / V.sum()
        # pt = self.np_rng.random()
        # pc = 0
        # for i, pi in enumerate(p):
        #     pc += pi
        #     if pc >= pt or i == len(p) - 1:
        #         teleport_index = i
        #         break
        #     # TODO: Upper Confidence bound
        return teleport_index

    def _reset(self, V):
        if self.fixed_reset:
            return
            # Epsilon Greedy Reset
        #if torch.rand(1) < epsilon:
        #    obs, infos = collector.env.reset()
        #    resets[0] = collector.env
        #    trajectory[0] = torch.tensor(obs, device=DEVICE)
        #    teleport_index = 0
            # Reset Buffer
        reset_buffer = []
        reset_buffer_obs = []
        for i in range(10):
            obs, info = self.collector.env.reset(seed=int(self.np_rng.integers(65536)))
            reset_buffer.append(copy.deepcopy(self.collector.env))
            reset_buffer_obs.append(obs)
        reset_buffer_obs = torch.tensor(np.array(reset_buffer_obs, dtype=np.float32), device=DEVICE)
        V_r = self.V(reset_buffer_obs)
        best_reset_index = torch.argmax(V_r).item()
        #if self.np_rng.random() < self.epsilon or V_r[best_reset_index] >= V[self.teleport_index]:
        if V_r[best_reset_index] >= V[self.teleport_index]:
            self.collector.env = reset_buffer[best_reset_index]
            self.quicksaves[0] = self.collector.env
            # Store a NumPy observation so the trajectory list stays homogeneous
            self.trajectory[0] = reset_buffer_obs[best_reset_index].cpu().numpy()
            self.teleport_index = 0

    def run(self):
        if self.fixed_reset:
            self.collector.env.reset(seed=self.cfg.seed)

        early_start_transitions = self.collector.early_start(self.cfg.train.initial_collection_size)
        for t in early_start_transitions:
            self._update_memory(*t)

        batch, aux = self.memory.sample(self.cfg.train.initial_collection_size)
        self.intrinsic.initialise(Transition(*batch), aux)

        # Main Loop
        if self.fixed_reset:
            self.collector.env.reset(seed=self.cfg.seed)
        else:
            self.collector.env.reset()

        self.state = copy.deepcopy(self.collector.env)
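        # Running total (printed after training) plus a per-step history,
        # which the final comparison plot reads as an attribute
        newly_collected_intrinsic_reward = 0
        self.newly_collected_intrinsic_reward = []
        for step in tqdm(range(1, self.cfg.train.total_frames+1)):
            obs, action, reward, n_obs, terminated, _ = self.collector.collect(n=1)[-1]

            step_intrinsic_reward = self.intrinsic.reward(Transition(
                torch.tensor(obs, device=DEVICE).unsqueeze(0),
                torch.tensor(action, device=DEVICE).unsqueeze(0),
                torch.tensor(reward, device=DEVICE).unsqueeze(0),
                torch.tensor(n_obs, device=DEVICE).unsqueeze(0),
                torch.tensor(terminated, device=DEVICE).unsqueeze(0)
            ))[2].item()
            newly_collected_intrinsic_reward += step_intrinsic_reward
            self.newly_collected_intrinsic_reward.append(step_intrinsic_reward)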

            # Truncate once the current segment has filled the storage window
            truncated = self.trajectory_counter == self.storage_size - 1
            # Time Limit Normalisation
            # obs, n_obs = copy.deepcopy(obs), copy.deepcopy(n_obs)
            # obs[-1] = obs[-1] / self.env.spec.max_episode_steps
            # n_obs[-1] = n_obs[-1] / self.env.spec.max_episode_steps

            # Update trajectory
            self.quicksaves.append(self.state)
            self.trajectory.append(obs)
            self.state = copy.deepcopy(self.collector.env)
            self.trajectory_index += 1
            self.trajectory_counter += 1

            # Manage Teleportation
            if truncated or terminated:
                # Calculate Value
                # Stack the list of observations first; building a tensor from a list of arrays is slow
                V = self.V(torch.tensor(np.array(self.trajectory), device=DEVICE))
                # Teleportation Selection
                with torch.no_grad():
                    self.teleport_index = self._teleport_selection(V)
                # Resets
                self._reset(V)

                # Teleport
                self.collector.env = self.quicksaves[self.teleport_index]
                self.collector.obs = self.trajectory[self.teleport_index]
                self.collector.env.np_random = np.random.default_rng(self.np_rng.integers(65536))

                self.trajectory_index = self.teleport_index
                self.trajectory_counter = 0
                self.trajectory = self.trajectory[:self.teleport_index+1]
                self.quicksaves = self.quicksaves[:self.teleport_index+1]
                self.state = copy.deepcopy(self.collector.env)

                # Log
                self.log["teleport_targets"].append(self.teleport_index)
                self.log["teleport_targets_observations"].append(self.collector.obs)
                self.log["latest_trajectory"] = self.trajectory

            # Update Memory
            self._update_memory(obs, action, reward, n_obs, terminated, truncated)
            
            # Remaining RL Update
            batch, aux = self.memory.sample(self.cfg.train.minibatch_size)
            batch = Transition(*batch)
            if self.death_is_not_the_end:
                batch = Transition(batch.s_0, batch.a, batch.r, batch.s_1, torch.zeros(batch.d.shape, device=DEVICE).bool())


            r_t, r_e, r_i = self.intrinsic.reward(batch)
            self.intrinsic.update(batch, aux, step=step)

            batch = Transition(batch.s_0, batch.a, r_i, batch.s_1, batch.d)
            self.algorithm.update(batch, aux, step=step)

            # Log
            self.log["total_intrinsic_reward"] += r_i.mean().item()

        print(newly_collected_intrinsic_reward)
        

In [None]:
cats_truncate = CATS_Truncate(
    cfg,
    max_episode_steps=9999,
    death_is_not_the_end=True,
    storage_size=400,
    fixed_reset=True
)
cats_truncate.run()

    # Visualise Training Information
visualise_memory(
      cats_truncate.env,
     (cats_truncate.memory, "teleport"),
)

In [None]:
visualise_experiment_v(cats_truncate)
print("Entropy: ", entropy_memory(cats_truncate.memory))

In [None]:
print(cats.log['total_intrinsic_reward'])
print(baseline.log['total_intrinsic_reward'])
print(cats_truncate.log['total_intrinsic_reward'])

In [None]:
fig, ax = plt.subplots()
ax.plot(baseline.newly_collected_intrinsic_reward[MAX_EPISODE_STEPS:], label="baseline")
ax.plot(cats.newly_collected_intrinsic_reward[MAX_EPISODE_STEPS:], label="cats")
ax.plot(cats_truncate.newly_collected_intrinsic_reward[MAX_EPISODE_STEPS:], label="cats_truncate")
fig.legend()