# Setup

from src.agents.trajectory_constrained_ql_agent_v2 import TrajectoryConstrainedQLV2
from src.agents.succ_feature_ql_agent import SuccessorFeatureQL
from src.agents.safety_feature_ql_agent import SafetyFeatureQL
from src.agents.ql_agent import QL
from src.features.tabular_succ_features import TabularSuccessorFeatures
from src.features.tabular_safety_features import TabularSafetyFeatures
from src.utils import create_dirs, init_logger, read_config, generate_four_rooms, MeanVar, plot_mean_var

In [4]:
# Load every configuration file used by the experiments below.
training_params = read_config("training.cfg")
env_params = read_config("four_rooms.cfg")
agent_params = read_config("agents.cfg")
feature_params = read_config("features.cfg")

# Experiment sizes shared by every comparison, taken from the "general"
# entry of the training config.
general_params = training_params["general"]
n_trials = general_params["n_trials"]
n_samples = general_params["n_samples"]
n_tasks = general_params["n_tasks"]

# Safety Feature QL vs. Successor Feature QL (with plain QL as a baseline)

In [None]:
# Build the agents under comparison. Each feature-based agent receives its own
# tabular feature object; all agents also get the shared "agent" settings.
succ_feature_ql = SuccessorFeatureQL(
    TabularSuccessorFeatures(**feature_params["tabular_sf"]),
    **agent_params["succ_feat_ql"],
    **agent_params["agent"],
)
safe_feature_ql = SafetyFeatureQL(
    TabularSafetyFeatures(**feature_params["tabular_safety"]),
    **agent_params["safe_feat_ql"],
    **agent_params["agent"],
)
ql = QL(**agent_params["ql"], **agent_params["agent"])

# Display label -> agent. Insertion order fixes the order of the stats lists
# and of the plot legends below.
agents = {
    "SuccessorFeatureQL": succ_feature_ql,
    "SafetyFeatureQL": safe_feature_ql,
    "QL": ql,
}

# One MeanVar accumulator per agent, positionally aligned with `agents`.
task_return_hist = [MeanVar() for _ in agents]
task_cost_hist = [MeanVar() for _ in agents]
# Agents exposing a `threshold` attribute are treated as constrained; only they
# contribute constraint-violation statistics.
constrained_agents = {
    label: agent for label, agent in agents.items() if hasattr(agent, "threshold")
}
task_constraint_violate_hist = [MeanVar() for _ in constrained_agents]

# Train: every trial resets all agents, then trains each of them on the same
# sequence of tasks (one `task` object shared by all agents per index).
for trial in range(n_trials):
    for agent in agents.values():
        agent.initialize()
    for i in range(n_tasks):
        task = generate_four_rooms(env_params["env"]["maze"], i)
        for name, agent in agents.items():
            print(f"trial {trial}, task {i}, solving with {name}")
            agent.train_on_task(task, n_samples)
    # Fold this trial's per-agent histories into the running statistics.
    for agent, return_stats, cost_stats in zip(
        agents.values(), task_return_hist, task_cost_hist
    ):
        return_stats.update(agent.reward_hist)
        cost_stats.update(agent.cost_hist)
    for agent, violation_stats in zip(
        constrained_agents.values(), task_constraint_violate_hist
    ):
        violation_stats.update(agent.violation_hist)

# Plot performance stats; the same display flags are passed to every plot.
plot_flags = {"save_fig": False, "show_fig": True}
plot_mean_var(task_return_hist, agents.keys(), n_samples, n_tasks, "reward", "return_comparison", **plot_flags)
plot_mean_var(task_cost_hist, agents.keys(), n_samples, n_tasks, "cost", "cost_comparison", **plot_flags)
plot_mean_var(task_constraint_violate_hist, constrained_agents.keys(), n_samples, n_tasks, "violations", "violation_comparison", **plot_flags)


trial 0, task 0, solving with SuccessorFeatureQL
trial 0, task 0, solving with SafetyFeatureQL
trial 0, task 0, solving with QL
trial 0, task 1, solving with SuccessorFeatureQL
trial 0, task 1, solving with SafetyFeatureQL
trial 0, task 1, solving with QL
trial 0, task 2, solving with SuccessorFeatureQL
trial 0, task 2, solving with SafetyFeatureQL
trial 0, task 2, solving with QL
trial 0, task 3, solving with SuccessorFeatureQL
trial 0, task 3, solving with SafetyFeatureQL
trial 0, task 3, solving with QL
trial 0, task 4, solving with SuccessorFeatureQL
trial 0, task 4, solving with SafetyFeatureQL
trial 0, task 4, solving with QL
trial 0, task 5, solving with SuccessorFeatureQL
trial 0, task 5, solving with SafetyFeatureQL
trial 0, task 5, solving with QL
trial 0, task 6, solving with SuccessorFeatureQL
trial 0, task 6, solving with SafetyFeatureQL
trial 0, task 6, solving with QL
trial 0, task 7, solving with SuccessorFeatureQL
trial 0, task 7, solving with SafetyFeatureQL
trial 0, 

trial 2, task 13, solving with SafetyFeatureQL
trial 2, task 13, solving with QL
trial 2, task 14, solving with SuccessorFeatureQL
trial 2, task 14, solving with SafetyFeatureQL
trial 2, task 14, solving with QL
trial 2, task 15, solving with SuccessorFeatureQL
trial 2, task 15, solving with SafetyFeatureQL
trial 2, task 15, solving with QL
trial 2, task 16, solving with SuccessorFeatureQL
trial 2, task 16, solving with SafetyFeatureQL
trial 2, task 16, solving with QL
trial 2, task 17, solving with SuccessorFeatureQL
trial 2, task 17, solving with SafetyFeatureQL
trial 2, task 17, solving with QL
trial 2, task 18, solving with SuccessorFeatureQL
trial 2, task 18, solving with SafetyFeatureQL


# Safety Feature QL vs. Trajectory Constrained QL

In [None]:
# Build the two agents under comparison; both also get the shared "agent"
# settings.
trajectory_constrained_ql_v2 = TrajectoryConstrainedQLV2(
    **agent_params["trajectory_constrained_ql_v2"],
    **agent_params["agent"],
)
safe_feature_ql = SafetyFeatureQL(
    TabularSafetyFeatures(**feature_params["tabular_safety"]),
    **agent_params["safe_feat_ql"],
    **agent_params["agent"],
)

# Display label -> agent. Insertion order fixes the order of the stats lists
# and of the plot legends below.
agents = {
    "TrajectoryConstrainedQL": trajectory_constrained_ql_v2,
    "SafetyFeatureQL": safe_feature_ql,
}

# One MeanVar accumulator per agent, positionally aligned with `agents`.
# NOTE(review): unlike the previous comparison, violations are tracked for
# every agent here, so each agent must expose `violation_hist`.
task_return_hist = [MeanVar() for _ in agents]
task_cost_hist = [MeanVar() for _ in agents]
task_constraint_violate_hist = [MeanVar() for _ in agents]

# Train: every trial resets all agents, then trains each of them on the same
# sequence of tasks (one `task` object shared by all agents per index).
for trial in range(n_trials):
    for agent in agents.values():
        agent.initialize()
    for i in range(n_tasks):
        task = generate_four_rooms(env_params["env"]["maze"], i)
        for name, agent in agents.items():
            print(f"trial {trial}, task {i}, solving with {name}")
            agent.train_on_task(task, n_samples)
    # Fold this trial's per-agent histories into the running statistics.
    for agent, return_stats, cost_stats, violation_stats in zip(
        agents.values(),
        task_return_hist,
        task_cost_hist,
        task_constraint_violate_hist,
    ):
        return_stats.update(agent.reward_hist)
        cost_stats.update(agent.cost_hist)
        violation_stats.update(agent.violation_hist)

# Plot performance stats; the same display flags are passed to every plot.
plot_flags = {"save_fig": False, "show_fig": True}
plot_mean_var(task_return_hist, agents.keys(), n_samples, n_tasks, "reward", "return_comparison", **plot_flags)
plot_mean_var(task_cost_hist, agents.keys(), n_samples, n_tasks, "cost", "cost_comparison", **plot_flags)
plot_mean_var(task_constraint_violate_hist, agents.keys(), n_samples, n_tasks, "violations", "violation_comparison", **plot_flags)

