In [None]:
from copy import deepcopy

import numpy as np
from hydra import initialize, compose
from tqdm import tqdm

from kitten.nn import Value, ClassicalValue

from cats.agent.policy import ExplorationPolicy
from cats.on_policy_experiment import OnlineExperiment
from cats.evaluation import *

DEVICE = "cpu"
MASTER_SEED = 235790


def generate_random_seeds(n: int):
    """Return ``n`` integer seeds drawn from a generator seeded with MASTER_SEED.

    A new generator is built on every call, so calling this twice with the
    same ``n`` yields the same array. Values lie in ``[0, 2**32 - 1)``.
    """
    seed_source = np.random.default_rng(MASTER_SEED)
    upper_bound = 2**32 - 1  # exclusive upper bound for Generator.integers
    return seed_source.integers(0, upper_bound, size=(n,))


# One seed per experiment repetition; printed so the run is traceable.
seeds = generate_random_seeds(1)
print(seeds)

class ClassicControlDiscreteExperiment(OnlineExperiment):
    """OnlineExperiment that uses an ExplorationPolicy and a ClassicalValue."""

    def _build_policy(self) -> None:
        """Run the parent's policy setup, then set ``_policy`` to an ExplorationPolicy."""
        super()._build_policy()
        env = self.env
        generator = self.rng.build_generator()
        # The config value ``policy.p`` is forwarded as the repeat probability.
        self._policy = ExplorationPolicy(
            env, generator, repeat_probability=self.cfg.policy.p
        )

    @property
    def policy(self):
        """The policy object stored by ``_build_policy``."""
        return self._policy

    def _build_value(self) -> Value:
        """Build a ClassicalValue for the environment and move it to ``self.device``."""
        value_fn = ClassicalValue(self.env)
        return value_fn.to(self.device)


In [None]:
TOTAL_FRAMES = 50000

# Overrides applied on top of the classic-control default configuration.
config_overrides = [
    f"train.total_frames={TOTAL_FRAMES}",   # Collection frames
    "env.max_episode_steps=10000",          # Disable environment truncation
    "algorithm.collection_batch=100",       # Truncation / Teleport on 100 steps
    "policy.p=0",                           # Uncorrelated random actions for policy
]

# Shared base configuration; each experiment cell deep-copies it before editing.
with initialize(version_base=None, config_path="cats/config"):
    base_cfg = compose(
        config_name="defaults_on_policy_classic_control.yaml",
        overrides=config_overrides,
    )

In [None]:
# UCB
# Runs one experiment per seed with the "ucb" teleport type and keeps each
# finished experiment in `ucb`.
ucb = []
for seed in tqdm(seeds):
    cfg = deepcopy(base_cfg)
    cfg.seed = int(seed)

    teleport_cfg = cfg.cats.teleport
    teleport_cfg.enable = True
    teleport_cfg.type = "ucb"
    teleport_cfg.kwargs = {"c": 1}

    experiment = ClassicControlDiscreteExperiment(cfg, device=DEVICE)
    experiment.run()
    ucb.append(experiment)

In [None]:
# Value-loss curve of the most recent UCB run.
# `ucb` is indexed explicitly instead of reading the leaked loop variable
# `experiment`, so the plot does not depend on which experiment cell ran last.
ucb_run = ucb[-1]

fig, ax = plt.subplots()
# NOTE(review): reads the logger's private `_engine`; switch to a public
# accessor if the logger offers one.
ax.plot(ucb_run.logger._engine.results['train/value_loss'])
ax.set_xlabel("Logged entry index")
ax.set_ylabel("train/value_loss")
ax.set_title("UCB teleport: value loss")
ax.set_yscale("log")

In [None]:
# Pass the first UCB run to the classic-control results visualiser.
first_ucb_run = ucb[0]
visualise_classic_control_results(first_ucb_run)

In [None]:
# Disabled
# Baseline: one experiment per seed with teleporting switched off.
disabled = []
for seed in tqdm(seeds):
    cfg = deepcopy(base_cfg)
    cfg.seed = int(seed)

    teleport_cfg = cfg.cats.teleport
    teleport_cfg.enable = False

    experiment = ClassicControlDiscreteExperiment(cfg, device=DEVICE)
    experiment.run()
    disabled.append(experiment)

In [None]:
# Random
# One experiment per seed with the "e_greedy" teleport type and e = 1.0.
# NOTE(review): the list name `random` matches the stdlib module name; it is
# kept because later cells index it.
random = []
for seed in tqdm(seeds):
    cfg = deepcopy(base_cfg)
    cfg.seed = int(seed)

    teleport_cfg = cfg.cats.teleport
    teleport_cfg.enable = True
    teleport_cfg.type = "e_greedy"
    teleport_cfg.kwargs.e = 1.0

    experiment = ClassicControlDiscreteExperiment(cfg, device=DEVICE)
    experiment.run()
    random.append(experiment)

In [None]:
# e_greedy
# One experiment per seed with the "e_greedy" teleport type and e = 0.1.
e_greedy = []
for seed in tqdm(seeds):
    cfg = deepcopy(base_cfg)
    cfg.seed = int(seed)

    teleport_cfg = cfg.cats.teleport
    teleport_cfg.enable = True
    teleport_cfg.type = "e_greedy"
    teleport_cfg.kwargs.e = 0.1

    experiment = ClassicControlDiscreteExperiment(cfg, device=DEVICE)
    experiment.run()
    e_greedy.append(experiment)

In [None]:
# boltzmann
# One experiment per seed with the "boltzmann" teleport type and alpha = 2.
boltzmann = []
for seed in tqdm(seeds):
    cfg = deepcopy(base_cfg)
    cfg.seed = int(seed)

    teleport_cfg = cfg.cats.teleport
    teleport_cfg.enable = True
    teleport_cfg.type = "boltzmann"
    teleport_cfg.kwargs = {"alpha": 2}

    experiment = ClassicControlDiscreteExperiment(cfg, device=DEVICE)
    experiment.run()
    boltzmann.append(experiment)

In [None]:
def entropy_memory_with_confidence(
    experiments: "list[ClassicControlDiscreteExperiment]",
    z: float = 1.96,
) -> "tuple[float, float]":
    """Mean memory entropy across experiments with a normal-approximation interval.

    Args:
        experiments: Finished experiments; ``entropy_memory`` is evaluated on
            each one's ``memory.rb``.
        z: Critical value multiplying the standard error. The default 1.96
            corresponds to a two-sided 95% interval under a normal approximation.

    Returns:
        ``(mean, half_width)``, where ``half_width`` is
        ``z * sample_std / sqrt(n)`` using the unbiased (``n - 1``) sample
        variance. With a single experiment the spread cannot be estimated and
        ``half_width`` is ``nan``.

    Raises:
        ValueError: If ``experiments`` is empty.
    """
    # assumes entropy_memory returns a scalar per experiment -- TODO confirm
    entropies = np.asarray(
        [float(entropy_memory(experiment.memory.rb)) for experiment in experiments],
        dtype=float,
    )
    n = entropies.size
    if n == 0:
        raise ValueError("entropy_memory_with_confidence needs at least one experiment")

    mu_hat = float(entropies.mean())
    if n < 2:
        # The n - 1 denominator would be zero; report an undefined interval.
        return mu_hat, float("nan")

    std_hat = float(entropies.std(ddof=1))
    return mu_hat, z * std_hat / n**0.5

In [None]:
# Mountain Car Visual Confirmation
# One panel per teleport strategy; each panel is drawn by visualise_memory
# from the first run of that strategy.
panels = [
    (disabled, "Disabled"),
    (random, "Random"),
    (e_greedy, "e_greedy"),
    (boltzmann, "boltzmann"),
    (ucb, "ucb"),
]

# The panel count follows the table above, so adding a strategy needs one new row.
fig, axs = plt.subplots(1, len(panels))
fig.set_size_inches(24, 3)
fig.subplots_adjust(wspace=0.5)
for panel_ax, (runs, panel_title) in zip(axs, panels):
    visualise_memory(runs[0], fig, panel_ax)
    panel_ax.set_title(panel_title)