In [None]:
import os
import pickle as pkl
from copy import deepcopy

# Training
import numpy as np
import torch
from hydra import initialize, compose
from tqdm import tqdm

# Evaluation
import seaborn as sns
from matplotlib import pyplot as plt

sns.set_theme()

from cats.evaluation import *
from cats.run import run

# Torch device handle: CUDA when torch reports it as available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed for the generator that draws the per-run seeds (see generate_random_seeds).
MASTER_SEED = 235790
# Value substituted into the `train.total_frames` config override below.
TOTAL_FRAMES = 10000

def generate_random_seeds(n: int):
    """Return ``n`` reproducible integer seeds derived from ``MASTER_SEED``.

    A fresh generator is built from ``MASTER_SEED`` on every call, so two calls
    with the same ``n`` return the same list. Values are drawn from the
    half-open range ``[0, 2**32 - 1)`` and converted to built-in ``int``.
    """
    generator = np.random.default_rng(MASTER_SEED)
    drawn = generator.integers(0, 2**32-1, size=(n, ))
    return [int(value) for value in drawn]

# One seed per independent run; the list is deterministic given MASTER_SEED.
seeds = generate_random_seeds(20)
print(seeds)

# Environment under study; swap the assignments to run the alternative task.
# ENV = "Pendulum-v1"
ENV = "MountainCarContinuous-v0"

# Compose the shared base configuration that every experiment variant copies.
with initialize(version_base=None, config_path="cats/config"):
    base_cfg = compose(
        config_name="defaults_off_policy.yaml",
        overrides=[
            "intrinsic=disagreement",
            f"env.name={ENV}",                               # Environment Selection
            "env.max_episode_steps=200",
            f"train.total_frames={TOTAL_FRAMES}",            # Collection frames
        ],
    )

# Pickle file holding the results for the selected environment.
path = os.path.join("evaluate/data/cats-reset", f"{ENV}.pkl")

In [None]:
def _run_sweep(configure, label):
    """Run one experiment per entry in ``seeds`` and collect the results.

    Args:
        configure: Callable taking a config object and mutating it in place to
            select the experiment variant. It receives a private deep copy of
            ``base_cfg`` whose ``seed`` field has already been set.
        label: Text shown next to the progress bar for this sweep.

    Returns:
        A list with the value returned by ``run`` for each seed, in the same
        order as ``seeds``.
    """
    experiments = []
    for seed in tqdm(seeds, desc=label):
        cfg = deepcopy(base_cfg)  # keep the shared base config untouched
        cfg.seed = seed
        configure(cfg)
        experiments.append(run(cfg, save=False))
    return experiments


def _configure_death_end(cfg):
    """Baseline variant: ``cats.death_not_end`` switched off."""
    cfg.cats.death_not_end = False


def _configure_death_cont(cfg):
    """"Death is not the end" variant: ``cats.death_not_end`` switched on."""
    cfg.cats.death_not_end = True


def _make_reset_action_configure(inject_critic):
    """Build a configure function for the reset-action variants.

    Args:
        inject_critic: Value assigned to ``cats.reset_inject_critic``.

    Returns:
        A callable that applies the reset-action settings to a config.
    """
    def configure(cfg):
        cfg.cats.death_not_end = True
        cfg.cats.reset_action.enable = True
        cfg.cats.reset_inject_critic = inject_critic
        cfg.noise.scale = [0.1, 0.01]
    return configure


# Baseline
death_end_baseline = _run_sweep(_configure_death_end, "death_end_baseline")

# Death is not the end
death_cont_baseline = _run_sweep(_configure_death_cont, "death_cont_baseline")

# Reset Action (Standard)
reset_1 = _run_sweep(_make_reset_action_configure(False), "reset_1")

# Reset Action (Injection)
reset_2 = _run_sweep(_make_reset_action_configure(True), "reset_2")

In [None]:
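# NOTE: the loading cell further down reads the pickle at `path`. Uncomment and
# run this block once after training to create (or overwrite) that file.
# data = {
#     "death_end_baseline": death_end_baseline,
#     "death_cont_baseline": death_cont_baseline,
#     "reset_1": reset_1,
#     "reset_2": reset_2
# }

# with open(path, "wb") as f:
#     pkl.dump(data, f)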

# Figures

In [None]:
# Read the pickled results stored at `path` back into memory.
# NOTE(review): unpickling can execute arbitrary code; only load files that
# were produced by this project.
with open(path, "rb") as f:
    data = pkl.load(f)

# Human-readable label for each experiment group identifier.
key_map = dict(
    death_end_baseline="Death End",
    death_cont_baseline="Death Cont.",
    reset_1="Reset Action Learnt",
    reset_2="Reset Action Injection",
)

In [None]:
# Final disagreement and memory entropy, one array per experiment group
# with one entry per stored experiment.
final_disagreement, final_entropy = {}, {}

for group_name, experiments in data.items():
    disagreement_values = [evaluate_disagreement(exp) for exp in experiments]
    final_disagreement[group_name] = np.array(disagreement_values)

    entropy_values = [entropy_memory(exp.memory.rb) for exp in experiments]
    final_entropy[group_name] = np.array(entropy_values)

In [None]:
def mu_var(data: dict):
    """Print the mean and a 95% confidence half-width for every entry.

    For each ``key -> values`` pair one line ``key mean bound`` is printed,
    where ``bound = 1.96 * s / sqrt(n)`` and ``s`` is the sample standard
    deviation (``n - 1`` in the denominator).

    Args:
        data: Mapping from a label to an array of samples that supports
            ``len``, ``.mean()`` and element-wise arithmetic (NumPy arrays in
            this notebook).

    Returns:
        None. Results are written to standard output only.

    Entries with fewer than two samples have no defined sample variance, so
    their bound is printed as ``nan`` rather than dividing by zero; an empty
    entry also gets a ``nan`` mean.
    """
    z_95 = 1.96  # two-sided 95% quantile of the standard normal distribution
    for k, v in data.items():
        n = len(v)
        mu = v.mean() if n > 0 else float("nan")
        if n < 2:
            # (n - 1) would be zero or negative: the bound is undefined.
            confidence_bound = float("nan")
        else:
            confidence_bound = (((v-mu)**2).sum() / (n-1))**0.5 / (n**0.5) * z_95
        print(k, mu, confidence_bound)

# Print the per-group mean and confidence bound for each metric in turn.
for heading, metric in (("Entropy", final_entropy), ("Disagreement", final_disagreement)):
    print(heading)
    mu_var(metric)

In [None]:
# Bar chart of the final disagreement for each experiment group. Group
# identifiers are swapped for their display labels from `key_map` (falling
# back to the raw identifier when no label is defined).
fig, ax = plt.subplots()
sns.barplot(
    {key_map.get(name, name): values for name, values in final_disagreement.items()},
    ax=ax,
)
ax.set_ylabel("Final disagreement")
ax.set_title("Final disagreement per experiment variant");  # `;` hides the text repr

In [None]:
# Environment check: display the matplotlib version used by this kernel.
import matplotlib
matplotlib.__version__