# Visualising MCC Exploration

This notebook logs exploratory results from adding teleportation to MountainCarContinuous (MCC), with state-coverage visualisation.

21/01/2024
- Naive teleportation to argmax works
- Longer episodes are better than shorter
- Different intrinsic rewards show significantly different behavior
- Even naively, general improvement over pure intrinsic
- Fails to beat intrinsic + extrinsic: perhaps because the negative extrinsic reward reveals information about the target? The two settings are not directly comparable, and I think this is fully explored in the reward-shifting paper
- Keeps teleporting to the same target
- This may be a problem with DDPG


TODO:
- Confidence Bounds
- Termination as an action
- Epsilon greedy
- Time aware exploration

In [None]:
# Define Imports and shared training information

import copy
import math

import torch
from tqdm import tqdm
from omegaconf import DictConfig
import numpy as np
from matplotlib import pyplot as plt
import gymnasium as gym

from curiosity.experience import Transition
from curiosity.experience.collector import GymCollector
from curiosity.policy import ColoredNoisePolicy
from curiosity.experience.memory import ReplayBuffer
from curiosity.experience.util import build_replay_buffer, build_collector
from curiosity.util.util import global_seed, build_intrinsic, build_rl

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Used both as cfg.train.total_frames (number of training-loop iterations)
# and as the replay-buffer capacity in the experiment classes below.
COLLECTION_STEPS = 20000
# Gymnasium environment id. visualise_memory_mcc only draws data for the two
# ids listed here.
ENV = "MountainCarContinuous-v0"
#ENV = "Pendulum-v1"

# Shared experiment configuration. Experiment.__init__ deep-copies it, so each
# experiment instance holds its own copy.
cfg = DictConfig({
    # Passed to global_seed (and, in CATS, to np.random.default_rng).
    "seed": 0,
    "env": {
        # Environment id handed to gym.make.
        "name": ENV,
    },
    # NOTE(review): no code visible in this notebook reads this section;
    # build_replay_buffer is called with only env, capacity and device.
    "memory": {
        "type": "experience_replay"
    #    "type": "prioritized_experience_replay",
    #    "alpha": 0.6,
    #    "epsilon": 0.1,
    #    "beta_0": 0.4,
    #    "beta_annealing_steps": COLLECTION_STEPS
    },
    "train": {
        # Transitions gathered by collector.early_start before training; the
        # same number is sampled to initialise the intrinsic module.
        "initial_collection_size": 500,
        # Number of iterations of the main loop (one collected step each).
        "total_frames": COLLECTION_STEPS,
        # Transitions sampled from the replay buffer per update.
        "minibatch_size": 128   
    },
    # Passed as a whole to build_rl.
    "algorithm": {
        "type": "ddpg",
        "gamma": 0.99,
        "tau": 0.005,
        "lr": 0.001,
        "update_frequency": 1,
        "clip_grad_norm": 1,
        "actor": {
            "features": 128
        },
        "critic": {
            "features": 128
        }
    },
    # Passed as a whole to build_intrinsic.
    "intrinsic": {
        "type": "rnd",
        "encoding_size": 32,
        "lr": 0.0003,
        "int_coef": 1, 
        "ext_coef": 2,
        "obs_normalisation": True,
        "reward_normalisation": True,
        "normalised_obs_clip": 5
    },
    # Unpacked as keyword arguments into ColoredNoisePolicy.
    "noise": {
        "scale": 0.1,
        "beta": 0
    }
})

def visualise_memory_mcc(env: gym.Env, *memories: "tuple[ReplayBuffer, str]"):
    """Scatter-plot the start states held in one or more replay buffers.

    A new figure titled "State Space Coverage" is created and one scatter
    series is drawn per ``(memory, name)`` pair, labelled with ``name``.

    Supported environments (selected by ``env.spec.id``):

    - ``MountainCarContinuous-v0``: observation component 0 (position) against
      component 1 (velocity); axis limits come from the observation space.
    - ``Pendulum-v1``: ``arctan2(obs[1], obs[0])`` (theta) against component 2
      (angular velocity); the x axis spans ``[-pi, pi]``.

    For any other environment id the axes are created but nothing is drawn.

    Args:
        env: Environment whose spec id and observation space choose the
            projection and the axis limits.
        *memories: ``(memory, name)`` pairs. ``memory.storage`` is unpacked
            into a ``Transition`` and its ``s_0`` field is plotted.
    """
    _, ax = plt.subplots()
    ax.set_title("State Space Coverage")

    # Pick the axis set-up and the 2-D projection once, rather than
    # re-testing the environment id for every memory.
    env_id = env.spec.id
    if env_id == "MountainCarContinuous-v0":
        ax.set_xlim(env.observation_space.low[0], env.observation_space.high[0])
        ax.set_xlabel("Position")
        ax.set_ylim(env.observation_space.low[1], env.observation_space.high[1])
        ax.set_ylabel("Velocity")

        def project(s):
            return s[:, 0], s[:, 1]
    elif env_id == "Pendulum-v1":
        ax.set_xlim(-math.pi, math.pi)
        ax.set_xlabel("Theta")
        ax.set_ylim(env.observation_space.low[2], env.observation_space.high[2])
        ax.set_ylabel("Angular Velocity")

        def project(s):
            # The angle is recovered from observation components 0 and 1.
            return np.arctan2(s[:, 1], s[:, 0]), s[:, 2]
    else:
        project = None

    for memory, name in memories:
        batch = Transition(*memory.storage)
        s = batch.s_0.cpu().numpy()
        if project is not None:
            x, y = project(s)
            ax.scatter(x, y, s=1, label=name)

    ax.legend()

In [None]:
class Experiment:
    """Baseline control class for the intrinsic-exploration experiments.

    Core assumptions
    - DDPG (or a similar off-policy algorithm trained from a replay buffer)

    This notebook primarily explores the concept of "teleportation";
    subclasses change how data is collected, while this class provides the
    plain collect / sample / update loop.
    """

    def __init__(self, cfg, max_episode_steps=None, death_is_not_the_end=True):
        """Build the environment, policy, replay data and intrinsic module.

        Args:
            cfg: Experiment configuration; deep-copied so later edits to the
                caller's object do not affect this instance.
            max_episode_steps: Forwarded to ``gym.make``.
            death_is_not_the_end: When true, the done flags of every sampled
                batch are replaced with all-False before the updates.
        """
        self.cfg = copy.deepcopy(cfg)
        self.death_is_not_the_end = death_is_not_the_end

        self._build_env(max_episode_steps)
        self._build_policy()
        self._build_data()
        self.intrinsic = build_intrinsic(self.env, self.cfg.intrinsic, device=DEVICE)

        self.log = {}

    def _build_env(self, max_episode_steps=None):
        """Create the gym environment, reset it, and seed it via global_seed."""
        self.env = gym.make(
            self.cfg.env.name,
            render_mode="rgb_array",
            max_episode_steps=max_episode_steps
        )
        self.env.reset()
        self.rng = global_seed(self.cfg.seed, self.env)

    def _build_policy(self):
        """Build the RL algorithm and wrap its actor in a coloured-noise policy."""
        self.ddpg = build_rl(self.env, self.cfg.algorithm, device=DEVICE)
        self.policy = ColoredNoisePolicy(
            self.ddpg.actor,
            self.env.action_space,
            self.env.spec.max_episode_steps,
            rng=self.rng,
            device=DEVICE,
            **self.cfg.noise
        )

    def _build_data(self):
        """Create the replay buffer and a collector that writes into it."""
        # Capacity comes from the module-level COLLECTION_STEPS, not from cfg.
        self.memory = build_replay_buffer(self.env, capacity=COLLECTION_STEPS, device=DEVICE)
        self.collector = build_collector(self.policy, self.env, self.memory, device=DEVICE)

    def _update_models(self, step):
        """Sample one minibatch and update the intrinsic module and the agent.

        The agent is trained on the intrinsic reward only: the reward field of
        the sampled batch is replaced by ``r_i`` before ``ddpg.update``.
        """
        batch, aux = self.memory.sample(self.cfg.train.minibatch_size)
        batch = Transition(*batch)
        if self.death_is_not_the_end:
            # Treat every transition as non-terminal.
            no_done = torch.zeros(batch.d.shape, device=DEVICE).bool()
            batch = Transition(batch.s_0, batch.a, batch.r, batch.s_1, no_done)

        # Intrinsic reward calculation; only the third value is used here.
        _, _, r_i = self.intrinsic.reward(batch)
        self.intrinsic.update(batch, aux, step=step)

        # RL update
        batch = Transition(batch.s_0, batch.a, r_i, batch.s_1, batch.d)
        self.ddpg.update(batch, aux, step=step)

    def run(self):
        """Default experiment run.

        Collects ``cfg.train.initial_collection_size`` transitions, initialises
        the intrinsic module from a sample of that size, then for
        ``cfg.train.total_frames`` iterations collects one step and performs
        one model update.
        """
        # Read from this instance's config copy, not the notebook-global cfg,
        # so an experiment built from a different config behaves consistently.
        train_cfg = self.cfg.train

        self.collector.early_start(train_cfg.initial_collection_size)
        batch, aux = self.memory.sample(train_cfg.initial_collection_size)
        self.intrinsic.initialise(Transition(*batch), aux)

        for step in tqdm(range(1, train_cfg.total_frames + 1)):
            self.collector.collect(n=1)
            self._update_models(step)

In [None]:
# Baseline run: intrinsic reward only, no teleportation.

# Train with episodes capped at 300 steps; sampled done flags are left as-is.
baseline = Experiment(cfg, max_episode_steps=300, death_is_not_the_end=False)
baseline.run()

# Plot the states stored in the baseline's replay buffer.
visualise_memory_mcc(baseline.env, (baseline.memory, "intrinsic"))

In [None]:
class CATS(Experiment):
    """Experiment variant that "teleports" back along the latest trajectory.

    Every step, a deep copy of the environment taken before that step and the
    matching observation are stored. When an episode terminates or is
    truncated, the target critic scores every stored state, one of them is
    sampled as the teleport target, and collection resumes from the stored
    environment copy instead of from a plain reset. A fresh reset is used
    only when its value estimate is at least that of the teleport target.
    """

    def __init__(self,
                 cfg,
                 max_episode_steps=None,
                 death_is_not_the_end=True,
                 epsilon: float = 0.1):
        """Set up the base experiment plus the trajectory bookkeeping.

        Args:
            cfg: Experiment configuration (deep-copied by the base class).
            max_episode_steps: Forwarded to ``gym.make`` by the base class.
            death_is_not_the_end: See ``Experiment``.
            epsilon: Stored for an epsilon-greedy reset; no code in this
                class reads it yet.
        """
        super().__init__(cfg, max_episode_steps, death_is_not_the_end)

        horizon = self.env.spec.max_episode_steps

        # Observations of the recently explored trajectory, indexed by step.
        self.trajectory = torch.zeros((horizon, *self.env.observation_space.shape), device=DEVICE)
        # Current time step within the trajectory.
        self.trajectory_index = 0
        # Time step chosen as the teleport target.
        self.teleport_index = 0
        # Reset epsilon (currently unused).
        self.epsilon = epsilon

        # Environment deep copies: ``state`` is the copy taken after the most
        # recent step, ``quicksaves[i]`` the copy from which step i was taken.
        self.state = None
        self.quicksaves = [None] * horizon

        # RNG for teleport sampling and for reseeding teleported environments.
        self.np_rng = np.random.default_rng(self.cfg.seed)

    def _build_data(self):
        """Create the replay buffer and a collector that does not write to it."""
        self.memory = build_replay_buffer(self.env, capacity=COLLECTION_STEPS, device=DEVICE)
        # Remove automatic memory addition for more control: transitions are
        # appended explicitly through _update_memory.
        self.collector = GymCollector(self.policy, self.env, device=DEVICE)

    def _update_memory(self, obs, action, reward, n_obs, terminated, truncated):
        """Append one transition to the replay buffer.

        ``truncated`` is accepted so a collector tuple can be splatted in, but
        only ``terminated`` is stored as the done flag.
        """
        self.memory.append((obs, action, reward, n_obs, terminated))

    def _teleport_selection(self, V):
        """Sample a trajectory index with probability proportional to V**2.

        Args:
            V: Value estimates for the stored trajectory states, one per
                state (assumed to hold a single value per state, e.g. shape
                (N, 1) from the critic; TODO confirm).

        Returns:
            The sampled index as a Python int. The last index is returned
            when no cumulative probability reaches the random draw (e.g. the
            weights sum to zero).

        Raises:
            ValueError: If ``V`` is empty.
        """
        # NOTE(review): squaring makes strongly negative values as likely to
        # be chosen as strongly positive ones; confirm this is intended.
        weights = V.detach().reshape(-1) ** 2
        if weights.numel() == 0:
            raise ValueError("cannot select a teleport target from an empty trajectory")

        # Probability matching via inverse-CDF sampling, done in one
        # vectorised pass instead of a per-element Python loop.
        # Alternatives: greedy argmax over V; TODO: upper confidence bound.
        cdf = torch.cumsum(weights / weights.sum(), dim=0)
        threshold = self.np_rng.random()
        hits = torch.nonzero(cdf >= threshold)
        if hits.numel() == 0:
            return cdf.numel() - 1
        return int(hits[0].item())

    def _reset(self, V, n_candidates: int = 10):
        """Restart from a fresh reset if one looks at least as good as the target.

        ``n_candidates`` resets are drawn and scored with the target actor and
        critic. If the best one scores at least ``V[self.teleport_index]``,
        it becomes quicksave/trajectory slot 0 and ``teleport_index`` is set
        to 0. Otherwise the teleport target is left unchanged.

        Args:
            V: Value estimates of the stored trajectory states.
            n_candidates: Number of fresh resets to evaluate.
        """
        # TODO: epsilon-greedy reset (accept a plain reset with probability
        # self.epsilon regardless of its value).
        candidate_envs = []
        candidate_obs = []
        for _ in range(n_candidates):
            obs, _ = self.collector.env.reset()
            candidate_envs.append(copy.deepcopy(self.collector.env))
            candidate_obs.append(obs)
        candidate_obs = torch.tensor(np.array(candidate_obs), device=DEVICE)

        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            target_action = self.ddpg.actor.target(candidate_obs)
            V_r = self.ddpg.critic.target(torch.cat((candidate_obs, target_action), 1))

        best_reset_index = torch.argmax(V_r).item()
        if V_r[best_reset_index] >= V[self.teleport_index]:
            self.collector.env = candidate_envs[best_reset_index]
            self.quicksaves[0] = self.collector.env
            self.trajectory[0] = candidate_obs[best_reset_index]
            self.teleport_index = 0

    def run(self):
        """Run collection with teleportation, updating the models every step."""
        # Read from this instance's config copy, not the notebook-global cfg.
        train_cfg = self.cfg.train

        for transition in self.collector.early_start(train_cfg.initial_collection_size):
            self._update_memory(*transition)

        batch, aux = self.memory.sample(train_cfg.initial_collection_size)
        self.intrinsic.initialise(Transition(*batch), aux)

        # Main loop. Start from a fresh episode and keep the collector's
        # current observation in sync with the reset environment, so the first
        # recorded transition and trajectory[0] match quicksaves[0].
        obs, _ = self.collector.env.reset()
        self.collector.obs = obs
        self.state = copy.deepcopy(self.collector.env)

        for step in tqdm(range(1, train_cfg.total_frames + 1)):
            obs, action, reward, n_obs, terminated, truncated = self.collector.collect(n=1)[-1]

            # Update trajectory: self.state is the environment copy taken
            # before this step, i.e. the one that produced ``obs``.
            self.quicksaves[self.trajectory_index] = self.state
            self.trajectory[self.trajectory_index] = torch.from_numpy(obs).to(DEVICE)
            self.state = copy.deepcopy(self.collector.env)
            self.trajectory_index += 1

            # Manage teleportation
            if truncated or terminated:
                # Score every stored state with the target networks; this is
                # evaluation only, so no autograd graph is built.
                with torch.no_grad():
                    visited = self.trajectory[:self.trajectory_index]
                    target_action = self.ddpg.actor.target(visited)
                    V = self.ddpg.critic.target(torch.cat((visited, target_action), 1))
                    self.teleport_index = self._teleport_selection(V)

                self._reset(V)

                # Teleport: resume from the stored environment copy, with a
                # freshly seeded environment RNG.
                self.collector.env = self.quicksaves[self.teleport_index]
                self.collector.obs = self.trajectory[self.teleport_index].cpu().numpy()
                self.collector.env.np_random = np.random.default_rng(self.np_rng.integers(65536))
                self.trajectory_index = self.teleport_index
                self.state = copy.deepcopy(self.collector.env)

            # Update memory
            self._update_memory(obs, action, reward, n_obs, terminated, truncated)

            # Remaining RL update
            batch, aux = self.memory.sample(train_cfg.minibatch_size)
            batch = Transition(*batch)
            if self.death_is_not_the_end:
                # Treat every transition as non-terminal.
                no_done = torch.zeros(batch.d.shape, device=DEVICE).bool()
                batch = Transition(batch.s_0, batch.a, batch.r, batch.s_1, no_done)

            _, _, r_i = self.intrinsic.reward(batch)
            self.intrinsic.update(batch, aux, step=step)

            # The agent is trained on the intrinsic reward only.
            batch = Transition(batch.s_0, batch.a, r_i, batch.s_1, batch.d)
            self.ddpg.update(batch, aux, step=step)
        

In [None]:

    # Build RL Structure
cats = CATS(cfg, max_episode_steps=300, death_is_not_the_end=False)
cats.run()

    # Visualise Training Information
visualise_memory_mcc(
     cats.env,
    (cats.memory, "intrinsic"),
)