In [None]:
import torch
from hydra import initialize, compose

from kitten.nn import Value, ClassicalValue

from cats.agent.policy import ExplorationPolicy
from cats.online_experiment import OnlineExperiment
from cats.evaluation import *


class ClassicControlDiscreteExperiment(OnlineExperiment):
    """Online experiment that uses an ``ExplorationPolicy`` and a ``ClassicalValue``.

    Overrides the policy and value construction hooks of ``OnlineExperiment``.
    """

    def _build_policy(self) -> None:
        """Run the parent policy construction, then set ``self._policy``.

        The ``ExplorationPolicy`` is built from the experiment environment, a
        generator obtained from ``self.rng``, and ``cfg.policy.p`` passed as
        ``repeat_probability``.
        """
        super()._build_policy()
        env = self.env
        generator = self.rng.build_generator()
        repeat_probability = self.cfg.policy.p
        self._policy = ExplorationPolicy(
            env, generator, repeat_probability=repeat_probability
        )

    @property
    def policy(self):
        """The policy object stored in ``self._policy`` by ``_build_policy``."""
        return self._policy

    def _build_value(self) -> Value:
        """Create a ``ClassicalValue`` for the environment and move it to ``self.device``."""
        value = ClassicalValue(self.env)
        return value.to(self.device)
        
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:

# Values substituted into the `seed` and `train.total_frames` overrides below.
seed = 0
steps = 10000

# Hydra overrides applied on top of the classic-control defaults file.
hydra_overrides = [
    f"seed={seed}",
    f"train.total_frames={steps}",
    "env.max_episode_steps=10000",
    "cats.fixed_reset=true",
    "cats.teleport.enable=true",
]

# Compose the configuration inside the Hydra initialisation context.
with initialize(version_base=None, config_path="cats/config"):
    cfg = compose(
        config_name="defaults_online_classic_control.yaml",
        overrides=hydra_overrides,
    )

# Build the experiment on the selected device and run it.
experiment = ClassicControlDiscreteExperiment(cfg, device=DEVICE)
experiment.run()

# Visualise the results of the finished run.
visualise_classic_control_results(experiment)

In [None]:
# Values substituted into the `seed` and `train.total_frames` overrides below.
seed = 0
steps = 10000

# Hydra overrides applied on top of the classic-control defaults file.
hydra_overrides = [
    f"seed={seed}",
    f"train.total_frames={steps}",
    "env.max_episode_steps=10000",
    "cats.fixed_reset=true",
    "cats.teleport.enable=true",
]

with initialize(version_base=None, config_path="cats/config"):
    cfg = compose(
        config_name="defaults_online_classic_control.yaml",
        overrides=hydra_overrides,
    )
    # After composing, set the teleport type to "boltzmann" and its kwargs
    # to alpha=2 directly on the config object.
    cfg.cats.teleport.type = "boltzmann"
    cfg.cats.teleport.kwargs = {"alpha": 2}

# Build the experiment on the selected device and run it.
experiment = ClassicControlDiscreteExperiment(cfg, device=DEVICE)
experiment.run()

# Visualise the results of the finished run.
visualise_classic_control_results(experiment)