# Visualising MCC Exploration

This notebook logs exploratory results from adding teleportation to MountainCarContinuous (MCC), using state-coverage visualisations.

21/01/2024
- Naive teleportation to the argmax-value state works
- Longer episodes are better than shorter
- Different intrinsic rewards show significantly different behavior
- Even naively, general improvement over pure intrinsic
- Fails to beat intrinsic + extrinsic reward: perhaps because the negative extrinsic reward leaks information about the target? The settings are not directly comparable, and I think this is explored fully in the reward-shifting paper
- Keeps teleporting to the same target
- This may be a problem with DDPG


28/01/2024
- Probabilistic teleportation works well
- Environment reset stochasticity is important
- Time limit aware Q functions are difficult to train!
- Proposal: Dynamic Truncation!


TODO:
- Confidence Bounds
- Termination as an action
- Epsilon greedy
- Time aware exploration

In [None]:
# Define Imports and shared training information

import copy
import math
from typing import Optional

from omegaconf import DictConfig
import numpy as np
import torch
from tqdm import tqdm

import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.cm as cm

import gymnasium as gym
from gymnasium.wrappers import TimeAwareObservation
from curiosity.experience import Transition
from curiosity.experience.collector import GymCollector
from curiosity.policy import ColoredNoisePolicy
from curiosity.experience.memory import ReplayBuffer
from curiosity.experience.util import build_replay_buffer
from curiosity.util.util import global_seed, build_intrinsic, build_rl

# Shared run settings: compute device, run length and environment id.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COLLECTION_STEPS = 10000  # environment steps per run; also used as buffer capacity
ENV = "MountainCarContinuous-v0"
# ENV = "Pendulum-v1"

# Full experiment configuration, read by Experiment / CATS below.
cfg = DictConfig(dict(
    seed=1,
    env=dict(name=ENV),
    # Alternative: dict(type="experience_replay").
    # NOTE(review): Experiment._build_data calls build_replay_buffer without
    # this section, so these settings may be unused -- confirm.
    memory=dict(
        type="prioritized_experience_replay",
        alpha=0.6,
        epsilon=0.1,
        beta_0=0.4,
        beta_annealing_steps=COLLECTION_STEPS,
    ),
    train=dict(
        initial_collection_size=500,
        total_frames=COLLECTION_STEPS,
        minibatch_size=128,
    ),
    algorithm=dict(
        type="ddpg",
        gamma=0.99,
        tau=0.005,
        lr=0.001,
        update_frequency=1,
        clip_grad_norm=1,
        actor=dict(features=128),
        critic=dict(features=128),
    ),
    intrinsic=dict(
        type="rnd",
        encoding_size=32,
        lr=0.0003,
        int_coef=1,
        ext_coef=2,
        obs_normalisation=True,
        reward_normalisation=True,
        normalised_obs_clip=5,
    ),
    noise=dict(scale=0.1, beta=0),
))


In [None]:
def visualise_memory_mcc(env: gym.Env, *memories: "tuple[ReplayBuffer, str]"):
    """Scatter-plot the ``s_0`` observations held in one or more replay buffers.

    Points are coloured by their index in the buffer storage, used here as a
    proxy for collection order (NOTE(review): only valid while the buffer has
    not wrapped around -- confirm how ReplayBuffer stores data once full).

    Args:
        env: Environment whose ``spec.id`` selects the axes: position vs
            velocity for ``MountainCarContinuous-v0``, angle vs angular
            velocity for ``Pendulum-v1``. Any other id leaves the axes empty.
        *memories: ``(memory, label)`` pairs; ``label`` names the series in
            the legend.
    """
    fig, ax = plt.subplots()
    ax.set_title("State Space Coverage")

    env_id = env.spec.id
    if env_id == "MountainCarContinuous-v0":
        ax.set_xlim(env.observation_space.low[0], env.observation_space.high[0])
        ax.set_xlabel("Position")
        ax.set_ylim(env.observation_space.low[1], env.observation_space.high[1])
        ax.set_ylabel("Velocity")
    elif env_id == "Pendulum-v1":
        ax.set_xlim(-math.pi, math.pi)
        ax.set_xlabel("Theta")
        ax.set_ylim(env.observation_space.low[2], env.observation_space.high[2])
        ax.set_ylabel("Angular Velocity")

    mappable = None
    for memory, name in memories:
        s = Transition(*memory.storage).s_0.cpu().numpy()

        # Colour each point by its storage index (0 .. len(s) - 1).
        norm = mpl.colors.Normalize(vmin=0, vmax=len(s) - 1)
        mappable = cm.ScalarMappable(norm=norm, cmap=cm.viridis)
        colors = mappable.to_rgba(np.linspace(0, len(s) - 1, len(s)))

        if env_id == "MountainCarContinuous-v0":
            ax.scatter(s[:, 0], s[:, 1], s=1, label=name, c=colors)
        elif env_id == "Pendulum-v1":
            # Columns 0/1 are treated as (cos theta, sin theta); arctan2 recovers theta.
            ax.scatter(np.arctan2(s[:, 1], s[:, 0]), s[:, 2], s=1, label=name, c=colors)

    # Without any memories there is no colour scale and nothing to label, so
    # skip both instead of failing. The colour bar reflects the last buffer.
    if mappable is not None:
        fig.colorbar(mappable, ax=ax)
        ax.legend()

In [None]:
class Experiment:
    """Baseline control class for intrinsic exploration experiments.

    Core assumptions
    - DDPG (or a similar off-policy algorithm trained from a replay buffer).

    This notebook primarily explores the concept of "teleportation"; subclasses
    change how data is collected, while this class trains on intrinsic reward
    with ordinary environment resets.
    """
    def __init__(self,
                 cfg,
                 max_episode_steps: Optional[int] = None,
                 death_is_not_the_end: bool = True,
                 fixed_reset: bool = True):
        """Build the environment, agent, replay buffer and intrinsic reward.

        Args:
            cfg: Experiment configuration. A deep copy is kept in ``self.cfg``,
                and every method reads from that copy.
            max_episode_steps: Forwarded to ``gym.make`` as the episode limit.
            death_is_not_the_end: If True, the done flags of every sampled
                minibatch are replaced with False before the updates.
            fixed_reset: If True, ``run`` re-resets the environment with
                ``cfg.seed`` whenever an episode ends.
        """
        self.cfg = copy.deepcopy(cfg)
        self.death_is_not_the_end = death_is_not_the_end
        self.fixed_reset = fixed_reset

        self._build_env(max_episode_steps)
        self._build_policy()
        self._build_data()
        self.intrinsic = build_intrinsic(self.env, self.cfg.intrinsic, device=DEVICE)

        self.log = {
            "total_intrinsic_reward": 0  # An ideal exploration algorithm maximises the received intrinsic reward
        }

    def _build_env(self, max_episode_steps=None):
        """Create and reset the environment, then seed global RNGs from ``cfg.seed``."""
        self.env = gym.make(
            self.cfg.env.name,
            render_mode="rgb_array",
            max_episode_steps=max_episode_steps
        )
        self.env.reset()
        self.rng = global_seed(self.cfg.seed, self.env)

    def _build_policy(self):
        """Build the RL agent and a coloured-noise exploration policy around its actor."""
        self.ddpg = build_rl(self.env, self.cfg.algorithm, device=DEVICE)
        self.policy = ColoredNoisePolicy(
            self.ddpg.actor,
            self.env.action_space,
            self.env.spec.max_episode_steps,
            rng=self.rng,
            device=DEVICE,
            **self.cfg.noise
        )

    def _build_data(self):
        """Create the replay buffer and the collector.

        The collector is built without a memory: transitions are appended
        explicitly through ``_update_memory`` for finer control.
        """
        self.memory = build_replay_buffer(self.env, capacity=COLLECTION_STEPS, device=DEVICE)
        self.collector = GymCollector(self.policy, self.env, device=DEVICE)

    def _update_memory(self, obs, action, reward, n_obs, terminated, truncated):
        """Store one transition; ``truncated`` is accepted but not stored."""
        self.memory.append((obs, action, reward, n_obs, terminated))

    def early_start(self):
        """Collect ``self.cfg.train.initial_collection_size`` warm-up transitions into memory."""
        early_start_transitions = self.collector.early_start(self.cfg.train.initial_collection_size)
        for t in early_start_transitions:
            self._update_memory(*t)

    def V(self, s) -> torch.Tensor:
        """Estimate state values with the target networks.

        Computes ``critic.target([s, actor.target(s)])``, concatenating along
        dim 1. Runs under ``torch.no_grad()``: the result is only used for
        selection and plotting, so no autograd graph is kept.
        """
        with torch.no_grad():
            target_action = self.ddpg.actor.target(s)
            return self.ddpg.critic.target(torch.cat((s, target_action), 1))

    def run(self):
        """Default experiment run: warm up, then collect one step and update per frame."""
        self.early_start()

        batch, aux = self.memory.sample(self.cfg.train.initial_collection_size)
        self.intrinsic.initialise(Transition(*batch), aux)

        for step in tqdm(range(1, self.cfg.train.total_frames + 1)):
            # Collect Data
            obs, action, reward, n_obs, terminated, truncated = self.collector.collect(n=1)[-1]
            if (terminated or truncated) and self.fixed_reset:
                reset_obs, _ = self.collector.env.reset(seed=self.cfg.seed)
                # Keep the collector's cached observation in step with the
                # freshly reset env. NOTE(review): assumes GymCollector acts
                # from its ``obs`` attribute -- confirm.
                self.collector.obs = reset_obs
            self._update_memory(obs, action, reward, n_obs, terminated, truncated)

            batch, aux = self.memory.sample(self.cfg.train.minibatch_size)
            batch = Transition(*batch)
            if self.death_is_not_the_end:
                # Treat every transition as non-terminal.
                batch = Transition(batch.s_0, batch.a, batch.r, batch.s_1, torch.zeros(batch.d.shape, device=DEVICE).bool())
            # Intrinsic Reward Calculation
            _, _, r_i = self.intrinsic.reward(batch)
            self.intrinsic.update(batch, aux, step=step)
            # RL Update on intrinsic reward only
            batch = Transition(batch.s_0, batch.a, r_i, batch.s_1, batch.d)
            self.ddpg.update(batch, aux, step=step)

            # Log
            self.log["total_intrinsic_reward"] += r_i.mean().item()

In [None]:
# Intrinsic-only baseline: the agent is updated on intrinsic reward alone,
# with 300-step episodes, done flags ignored and no fixed-seed reset.
baseline = Experiment(
    cfg,
    max_episode_steps=300,
    death_is_not_the_end=True,
    fixed_reset=False,
)
baseline.run()

# Plot the stored observations to inspect state-space coverage.
visualise_memory_mcc(baseline.env, (baseline.memory, "intrinsic"))

In [None]:
class CATS(Experiment):
    """Exploration with teleportation between episodes.

    Every collected step records the observation and a ``copy.deepcopy``
    snapshot of the environment taken before that step. When an episode ends,
    a restart index is sampled from the trajectory with probability
    proportional to ``V(s)**2``. Unless ``fixed_reset`` is set, a fresh
    environment reset may replace that target (see ``_reset``). The collector
    then continues from the chosen snapshot.
    """

    def __init__(self,
                 cfg,
                 max_episode_steps: Optional[int] = None,
                 death_is_not_the_end: bool = True,
                 fixed_reset: bool = False,
                 epsilon: float = 0.2):
        """Same arguments as ``Experiment``, plus ``epsilon``.

        Args:
            epsilon: Probability of taking the best fresh reset even when its
                value is below that of the sampled teleport target.
        """
        super().__init__(cfg, max_episode_steps, death_is_not_the_end, fixed_reset)

        # Observations of the most recently explored trajectory, one row per time index
        self.trajectory = torch.zeros((self.env.spec.max_episode_steps, *self.env.observation_space.shape), device=DEVICE)
        # Current time step
        self.trajectory_index = 0
        # Target time step chosen at the last episode end
        self.teleport_index = 0
        # Reset epsilon
        self.epsilon = epsilon

        # Environment deepcopies: current pre-step snapshot, and one per time index
        self.state = None
        self.quicksaves = [None for _ in range(self.env.spec.max_episode_steps)]

        # RNG
        self.np_rng = np.random.default_rng(self.cfg.seed)

        self.log["teleport_targets"] = []
        self.log["teleport_targets_observations"] = []
        self.log["latest_trajectory"] = []

    def _build_env(self, max_episode_steps=None):
        """Build the base environment unchanged.

        Time-aware observations (wrapping with ``TimeAwareObservation`` and
        raising the last observation bound to ``max_episode_steps``) and a
        "learn to truncate" variant were tried here. Both are disabled; the
        notebook's note was that they add complexity and may be hard to learn.
        """
        super()._build_env(max_episode_steps)

    def _teleport_selection(self, V):
        """Sample a trajectory index with probability proportional to ``V**2``.

        Args:
            V: Value estimates for the trajectory states. Flattened here, so
                ``(T,)`` or ``(T, 1)`` both work.

        Returns:
            An int in ``[0, T)``. Exactly one uniform draw is taken from
            ``self.np_rng``. If the squared values sum to zero or to a
            non-finite number, the index is sampled uniformly instead.

        Raises:
            ValueError: If ``V`` is empty.
        """
        weights = V.detach().flatten().double() ** 2
        if weights.numel() == 0:
            raise ValueError("cannot select a teleport target from an empty trajectory")
        total = weights.sum()
        if not torch.isfinite(total) or total <= 0:
            weights = torch.ones_like(weights)
            total = weights.sum()
        cdf = torch.cumsum(weights, 0) / total
        u = self.np_rng.random()
        # First index whose cumulative probability reaches u; the clamp guards
        # against the cdf ending fractionally below 1 after rounding.
        index = int(torch.searchsorted(cdf, torch.tensor([u], dtype=cdf.dtype, device=cdf.device)).item())
        # TODO: Upper Confidence bound
        return min(index, len(weights) - 1)

    def _reset(self, V):
        """Possibly replace the selected teleport target with a fresh reset.

        Resets the current env 10 times with seeds drawn from ``self.np_rng``
        and snapshots each reset. Each candidate is scored with ``self.V``.
        The best-scoring reset is installed at trajectory index 0 (with
        ``teleport_index`` set to 0) when an ``epsilon`` coin flip succeeds,
        or when its value is at least ``V[self.teleport_index]``. Does nothing
        when ``fixed_reset`` is set.

        Args:
            V: Value estimates for the current trajectory, indexed like it.
        """
        if self.fixed_reset:
            return

        candidates = []
        candidate_obs = []
        for _ in range(10):
            obs, _info = self.collector.env.reset(seed=int(self.np_rng.integers(65536)))
            candidates.append(copy.deepcopy(self.collector.env))
            candidate_obs.append(obs)
        candidate_obs = torch.tensor(np.array(candidate_obs, dtype=np.float32), device=DEVICE)
        with torch.no_grad():
            V_r = self.V(candidate_obs)
        best = torch.argmax(V_r).item()
        # The coin flip is evaluated first, so one draw is consumed on every call.
        if self.np_rng.random() < self.epsilon or V_r[best] >= V[self.teleport_index]:
            self.collector.env = candidates[best]
            self.quicksaves[0] = self.collector.env
            self.trajectory[0] = candidate_obs[best]
            self.teleport_index = 0

    def _teleport(self):
        """At an episode end, choose a restart point and move the collector there."""
        with torch.no_grad():
            V = self.V(self.trajectory[:self.trajectory_index])
            self.teleport_index = self._teleport_selection(V)
            self._reset(V)

        # The snapshot object becomes the live env. Its quicksave slot is
        # overwritten with a fresh copy at the next step, before it is reused.
        self.collector.env = self.quicksaves[self.teleport_index]
        self.collector.obs = self.trajectory[self.teleport_index].cpu().numpy()
        # Replace the RNG carried over by deepcopy with a new one drawn from self.np_rng.
        self.collector.env.np_random = np.random.default_rng(self.np_rng.integers(65536))
        self.trajectory_index = self.teleport_index
        self.state = copy.deepcopy(self.collector.env)

        self.log["teleport_targets"].append(self.teleport_index)
        self.log["teleport_targets_observations"].append(self.collector.obs)
        self.log["latest_trajectory"] = self.trajectory.cpu().numpy()

    def run(self):
        """Warm up, then train while teleporting at every episode end."""
        for transition in self.collector.early_start(self.cfg.train.initial_collection_size):
            self._update_memory(*transition)

        batch, aux = self.memory.sample(self.cfg.train.initial_collection_size)
        self.intrinsic.initialise(Transition(*batch), aux)

        # Start the first episode from a reset and keep the collector's cached
        # observation consistent with the snapshot stored at index 0.
        if self.fixed_reset:
            start_obs, _ = self.collector.env.reset(seed=self.cfg.seed)
        else:
            start_obs, _ = self.collector.env.reset()
        self.collector.obs = start_obs
        self.state = copy.deepcopy(self.collector.env)

        for step in tqdm(range(1, self.cfg.train.total_frames + 1)):
            obs, action, reward, n_obs, terminated, truncated = self.collector.collect(n=1)[-1]

            # Record the pre-step snapshot and observation for this time index.
            self.quicksaves[self.trajectory_index] = self.state
            self.trajectory[self.trajectory_index] = torch.from_numpy(obs).to(DEVICE)
            self.state = copy.deepcopy(self.collector.env)
            self.trajectory_index += 1

            if truncated or terminated:
                self._teleport()

            self._update_memory(obs, action, reward, n_obs, terminated, truncated)

            # Remaining RL update (same as Experiment.run)
            batch, aux = self.memory.sample(self.cfg.train.minibatch_size)
            batch = Transition(*batch)
            if self.death_is_not_the_end:
                batch = Transition(batch.s_0, batch.a, batch.r, batch.s_1, torch.zeros(batch.d.shape, device=DEVICE).bool())

            _, _, r_i = self.intrinsic.reward(batch)
            self.intrinsic.update(batch, aux, step=step)

            batch = Transition(batch.s_0, batch.a, r_i, batch.s_1, batch.d)
            self.ddpg.update(batch, aux, step=step)

            self.log["total_intrinsic_reward"] += r_i.mean().item()
        

In [None]:

    # Build RL Structure
# Teleporting agent with the same episode length and reset settings as the baseline.
cats = CATS(
    cfg,
    max_episode_steps=300,
    death_is_not_the_end=True,
    fixed_reset=False,
)
cats.run()

# Plot the observations gathered with teleportation enabled.
visualise_memory_mcc(cats.env, (cats.memory, "teleport"))

In [None]:
# Display the object one wrapper level below cats.env (the cell's output).
cats.env.env

In [None]:
# Visualise V
# For mountain car: evaluate the target-network value estimate on a 100 x 100
# grid spanning the (position, velocity) observation bounds.

obs_space = cats.env.observation_space
X = torch.linspace(obs_space.low[0], obs_space.high[0], 100)
Y = torch.linspace(obs_space.low[1], obs_space.high[1], 100)
grid_X, grid_Y = torch.meshgrid(X, Y, indexing="ij")

# Give the states exactly as many columns as the observation space, since the
# networks were built for that size. Any extra dimensions (e.g. a time feature
# from a TimeAwareObservation wrapper) are filled with 1.
columns = [grid_X.flatten(), grid_Y.flatten()]
columns += [torch.ones_like(columns[0]) for _ in range(obs_space.shape[0] - 2)]
states = torch.stack(columns, dim=1).to(DEVICE)

# Move to CPU NumPy before plotting; matplotlib cannot take GPU tensors or
# tensors that require grad.
values = cats.V(states).detach().cpu().numpy().reshape(-1)
states = states.cpu().numpy()

fig, ax = plt.subplots()
ax.set_title("Value Function Visualisation")
points = ax.scatter(states[:, 0], states[:, 1], c=values, cmap=cm.viridis)
fig.colorbar(points, ax=ax)

print(f"Minimum Value {values.min()} | Maximum Value {values.max()}")

In [None]:
# Display the experiment log: total intrinsic reward, teleport targets and
# their observations, and the latest trajectory.
cats.log