In [None]:
import math
import os
import pickle as pkl
from copy import deepcopy

# Training
import numpy as np
import torch
from hydra import initialize, compose
from tqdm import tqdm

# Evaluation
import seaborn as sns
sns.set_theme()
from matplotlib import pyplot as plt

from cats.evaluation import *
from cats.run import run

# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fixed seed from which every per-run seed below is derived.
MASTER_SEED = 235790


def generate_random_seeds(n: int):
    """Return ``n`` integer seeds drawn from a generator seeded with ``MASTER_SEED``.

    A new generator is created on every call, so calling this twice with the
    same ``n`` yields the same list.

    Args:
        n: Number of seeds to draw.

    Returns:
        A list of ``n`` plain Python ``int`` values in ``[0, 2**32 - 1)``.
    """
    generator = np.random.default_rng(MASTER_SEED)
    drawn = generator.integers(0, 2**32 - 1, size=(n,))
    return [int(value) for value in drawn]

In [None]:
# Optimising exploration
#
# Two base configurations (baseline and teleport) are each run with policy
# sampling disabled and enabled, once per entry of `seeds`, giving four result
# lists: baseline, baseline_sampling, cats, cats_sampling.

TOTAL_FRAMES = 5000

with initialize(version_base=None, config_path="cats/config"):
    base_cfg_baseline = compose(
        config_name="defaults_off_policy.yaml",
        overrides=[
            "env.name=MountainCarContinuous-v0", # Environment as default on gymnasium
            "env.max_episode_steps=999",
            f"train.total_frames={TOTAL_FRAMES}",
            "intrinsic=disagreement",
            "cats.fixed_reset=true",
            "cats.death_not_end=true",
            "cats.enable_policy_sampling=false",
        ]
    )

    # The teleport variant starts from the baseline and only changes the fields below.
    base_cfg_teleport = deepcopy(base_cfg_baseline)
    base_cfg_teleport.env.max_episode_steps = math.inf  # replaces the 999-step cap set above
    base_cfg_teleport.cats.teleport.enable = True
    base_cfg_teleport.cats.teleport_interval_enable = True # No reset as an action
    base_cfg_teleport.cats.teleport.memory = {
        "type": "fifo",
        "capacity": TOTAL_FRAMES  # same size as the total frame budget
    }

    # base_cfg_teleport.cats.teleport.type = "ucb"    # UCB teleportation
    # base_cfg_teleport.cats.teleport.kwargs = {"c": 1}


def run_experiments(base_cfg, seeds, policy_sampling=False):
    """Run one experiment per entry of ``seeds`` from a copy of ``base_cfg``.

    Args:
        base_cfg: Config to start from; it is deep-copied for every run and
            never modified.
        seeds: One entry per repetition.
        policy_sampling: Value written to ``cfg.cats.enable_policy_sampling``
            before each run.

    Returns:
        A list with the return value of ``run(cfg)`` for each repetition, in
        the order of ``seeds``.
    """
    results = []
    # NOTE(review): the seed value is never written into `cfg`, so every
    # repetition runs with an identical configuration and `seeds` only sets the
    # number of repetitions. If the config has a seed field it should be
    # assigned here - TODO confirm the key name against the config schema.
    for _seed in tqdm(seeds):
        cfg = deepcopy(base_cfg)
        cfg.cats.enable_policy_sampling = policy_sampling
        results.append(run(cfg))
    return results


seeds = generate_random_seeds(20)

# Baseline
baseline = run_experiments(base_cfg_baseline, seeds)

# Baseline + policy sampling
baseline_sampling = run_experiments(base_cfg_baseline, seeds, policy_sampling=True)

# CATS
cats = run_experiments(base_cfg_teleport, seeds)

# CATS + policy sampling
cats_sampling = run_experiments(base_cfg_teleport, seeds, policy_sampling=True)

In [None]:
# Bundle the four result lists under the names of their configurations.
data = {
    "baseline": baseline,
    "baseline_sampling": baseline_sampling,
    "cats": cats,
    "cats_sampling": cats_sampling
}

# Create the output directory first: open(..., "wb") does not create missing
# parent directories, and a failure here would discard the results of all runs.
out_dir = "evaluate/data/cats-exploration"
os.makedirs(out_dir, exist_ok=True)

path = os.path.join(out_dir, "mcc_policy_sampling.pkl")
with open(path, "wb") as f:
    pkl.dump(data, f)