In [1]:
from pathlib import Path
import sys

sys.path.append(str(Path.cwd().parent))

import copy
import gc
import uuid
from datetime import datetime
from typing import Tuple, List, Type

from bandit_lib.runner import batch_train
from bandit_lib.agents import (
    BaseRewardStates,
    Metrics,
    MetricsConfig,
    GreedyConfig,
    GreedyAgent,
    AlgorithmConfig,
)
from bandit_lib.env import EnvConfig
from bandit_lib.agents.base import Agent_T
from bandit_lib.utils import (
    save_process_data,
    ProcessDataDump,
    save_meta_data,
    MetaDataDump,
    plot_metrics_history,
    get_metric_labels,
    plot_comparison,
)

In [2]:
# constants
ROOT = Path.cwd().parent  # project root: parent of the notebook's working directory
DATE = datetime.now().strftime("%Y%m%d_%H%M%S")  # names this session's output sub-folders
# Experiment id: recorded in every run's metadata and used to name the comparison figure.
# Outputs are written to results/logs/<DATE>/ and results/figures/<DATE>/.
EXP_ID = uuid.uuid4().hex
METRIC_LABELS = get_metric_labels()

# settings
# environment
ENABLE_DYNAMIC = True
RANDOM_WALK_ARM_NUM = 1
# NOTE(review): name mirrors EnvConfig.random_walk_internal; presumably the
# interval (in steps) between random-walk updates — confirm in bandit_lib.env.
RANDOM_WALK_INTERNAL = 1
RANDOM_WALK_STD = 0.01
ARM_NUM = 10

# experiment
REPEAT_TIMES = 500  # Independent repetitions per configuration; 500 keeps the averaged results consistent and reliable
STEP_NUM = 100_000
WORKER_NUM = 10
BASE_SEED = 42  # Shared by every configuration; incremented by one for each replication, hence "base" seed

# algorithm
OPTIMISTIC_INITIALIZATION_ENABLED = True
OPTIMISTIC_INITIALIZATION_VALUE = 1

# configs
ENV_CONFIG = EnvConfig(
    enable_dynamic=ENABLE_DYNAMIC,
    random_walk_internal=RANDOM_WALK_INTERNAL,
    random_walk_arm_num=RANDOM_WALK_ARM_NUM,
    random_walk_std=RANDOM_WALK_STD,
)
METRICS_CONFIG = MetricsConfig(metrics_history_size=500)
GREEDY_CONFIG = GreedyConfig(
    optimistic_initialization_enabled=OPTIMISTIC_INITIALIZATION_ENABLED,
    optimistic_initialization_value=OPTIMISTIC_INITIALIZATION_VALUE,
)

# results: one (process data, trained agents) entry per run, appended by the run cells
RESULTS_LIST: List[Tuple[ProcessDataDump, List[GreedyAgent]]] = []

# experiment variables
LEARN_RATE = [
    # 0.001,
    # 0.005,
    # 0.01,
    0.05,
    # 0.1,
    # 0.3,
    # 0.5,
    # 0.7,
    # 0.9,
    # 0.99,
    # 0.999,
    # 0.9999,
    # 0.99999,
]

# In a dynamic environment (arms drift via random walk, so presumably the
# optimal arm changes over time) convergence_rate is excluded from the plots.
# Build a new list instead of list.remove(): no in-place mutation of whatever
# get_metric_labels() returned, and no ValueError if the label is absent.
if ENABLE_DYNAMIC:
    METRIC_LABELS = [label for label in METRIC_LABELS if label != "convergence_rate"]

In [3]:
def get_run_id(name: str) -> str:
    """Return ``name`` suffixed with the current local time.

    The suffix has second resolution (``%Y%m%d_%H%M%S``), e.g.
    ``get_run_id("baseLine")`` -> ``"baseLine_20240101_120000"``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return "{}_{}".format(name, timestamp)


def run_agent(
    run_id: str,
    agent_type: Type[Agent_T],
    name: str,
    arm_num: int,
    env_config: EnvConfig,
    algorithm_config: AlgorithmConfig,
    repeat_times: int,
    step_num: int,
    base_seed: int,
    worker_num: int,
    metrics_config: MetricsConfig,
    metrics_to_plot: List[str],
) -> Tuple[ProcessDataDump, List[Agent_T]]:
    """Train a batch of agents, save their logs and plot their metrics.

    Files written (``ROOT``/``DATE`` are the notebook-level constants):

    * ``results/logs/<DATE>/meta_<run_id>.json``    – experiment metadata
    * ``results/logs/<DATE>/process_<run_id>.json`` – rewards and metrics history
    * ``results/figures/<DATE>/<run_id>.html`` and ``<run_id>.jpeg`` – metric plots

    Args:
        run_id: Identifier embedded in every output file name.
        agent_type: Agent class handed to ``batch_train``.
        name: Display name used in the plots.
        arm_num: Number of bandit arms.
        env_config: Environment configuration.
        algorithm_config: Algorithm configuration. A deep copy is used, so the
            caller may reuse and mutate its object for later runs.
        repeat_times: Number of independent agents to train.
        step_num: Number of steps per agent.
        base_seed: Seed handed to ``batch_train``.
        worker_num: Number of workers handed to ``batch_train``.
        metrics_config: Metrics configuration.
        metrics_to_plot: Metric labels to draw.

    Returns:
        ``(process_data, agents)``: the saved process dump and the trained agents.

    Raises:
        ValueError: if ``batch_train`` returns an empty averaged metrics history.
    """
    # Snapshot the config: the notebook cells reuse one config object across
    # runs, and nothing built from this run (metadata, agents) should change
    # if the caller mutates that object afterwards.
    algorithm_config = copy.deepcopy(algorithm_config)

    agents: List[Agent_T]
    rewards: BaseRewardStates
    metrics_history_avg: List[Metrics]
    agents, rewards, metrics_history_avg = batch_train(
        run_id=run_id,
        agent_type=agent_type,
        name=name,
        arm_num=arm_num,
        env_config=env_config,
        algorithm_config=algorithm_config,
        repeat_times=repeat_times,
        step_num=step_num,
        base_seed=base_seed,
        worker_num=worker_num,
        metrics_config=metrics_config,
    )

    log_dir = ROOT / "results" / "logs" / DATE
    figure_dir = ROOT / "results" / "figures" / DATE

    save_meta_data(
        meta_data=MetaDataDump(
            experiment_date=DATE,
            experiment_id=EXP_ID,
            agent_runs_num=repeat_times,
            total_steps=step_num,
            arm_num=arm_num,
            agent_seed=base_seed,
            algorithm=algorithm_config,
            metrics_config=metrics_config,
            env_config=env_config,
        ),
        path=log_dir / f"meta_{run_id}.json",
    )

    if not metrics_history_avg:
        # The last averaged entry is stored as the final metrics below.
        raise ValueError(f"batch_train returned no averaged metrics for run {run_id!r}")

    # Deep-copy each agent's per-step metrics so the dump is independent of the agents.
    metrics_history: List[List[Metrics]] = [
        [metric.model_copy(deep=True) for metric in agent.metrics] for agent in agents
    ]
    process_data = ProcessDataDump(
        run_id=run_id,
        create_at=datetime.now(),
        rewards=rewards,
        metrics=metrics_history_avg[-1],
        metrics_history=metrics_history,
        metrics_history_avg=metrics_history_avg,
    )
    save_process_data(
        process_data=process_data,
        path=log_dir / f"process_{run_id}.json",
    )

    # The same figure twice: HTML export, then a sized/scaled JPEG export.
    for suffix, image_options in (
        ("html", {}),
        ("jpeg", {"width": 1500, "height": 1000, "scale": 2}),
    ):
        plot_metrics_history(
            metrics_history=metrics_history_avg,
            agent_name=name,
            file_name=figure_dir / f"{run_id}.{suffix}",
            agents=agents,
            x_log=True,
            metrics_to_plot=metrics_to_plot,
            enable_statistical_credibility=True,
            **image_options,
        )
    return process_data, agents

In [4]:
# Train one batch of agents per learning rate. Results are appended to
# RESULTS_LIST, so re-running this cell without re-running the config cell
# adds duplicate runs to the comparison.
for learn_rate in LEARN_RATE:
    # Per-run copy instead of mutating the shared GREEDY_CONFIG: the global
    # keeps its configured values, and no object that may hold a reference to
    # an earlier run's config is altered by later iterations or cells.
    greedy_config = copy.deepcopy(GREEDY_CONFIG)
    greedy_config.learning_rate = learn_rate
    print(f"Running {GreedyAgent.__name__} with learn_rate={learn_rate}")
    _results = run_agent(
        run_id=get_run_id(name=f"learn_rate_{learn_rate}_{GreedyAgent.__name__}"),
        agent_type=GreedyAgent,
        name=f"{GreedyAgent.__name__} Learn rate={learn_rate}",
        arm_num=ARM_NUM,
        env_config=ENV_CONFIG,
        algorithm_config=greedy_config,
        repeat_times=REPEAT_TIMES,
        step_num=STEP_NUM,
        base_seed=BASE_SEED,
        worker_num=WORKER_NUM,
        metrics_config=METRICS_CONFIG,
        metrics_to_plot=METRIC_LABELS,
    )
    RESULTS_LIST.append(_results)
    print(f"Finished {GreedyAgent.__name__} with learn_rate={learn_rate}")
    # Reclaim unreachable objects before the next long run (the collected-object
    # count was debug noise and is no longer printed).
    gc.collect()

Running GreedyAgent with learn_rate=0.05


Training agents: 100%|██████████| 500/500 [01:26<00:00,  5.78it/s]


Finished GreedyAgent with learn_rate=0.05
3187


In [5]:
# Baseline run, appended to RESULTS_LIST for the comparison plot (re-running
# this cell appends a duplicate baseline).
print("run base line")
# Per-run copy so the shared GREEDY_CONFIG is not left with learning_rate=0.
baseline_config = copy.deepcopy(GREEDY_CONFIG)
# NOTE(review): presumably learning_rate=0 selects GreedyAgent's default /
# baseline update rule rather than a literal zero step size — confirm in bandit_lib.
baseline_config.learning_rate = 0
_results = run_agent(
    run_id=get_run_id(name=f"baseLine_{GreedyAgent.__name__}"),
    agent_type=GreedyAgent,
    name=f"{GreedyAgent.__name__} BaseLine",
    arm_num=ARM_NUM,
    env_config=ENV_CONFIG,
    algorithm_config=baseline_config,
    repeat_times=REPEAT_TIMES,
    step_num=STEP_NUM,
    base_seed=BASE_SEED,
    worker_num=WORKER_NUM,
    metrics_config=METRICS_CONFIG,
    metrics_to_plot=METRIC_LABELS,
)
RESULTS_LIST.append(_results)

run base line


Training agents: 100%|██████████| 500/500 [01:25<00:00,  5.87it/s]


In [6]:
# Comparison of every run in RESULTS_LIST, exported as HTML and PNG; the figure
# returned by the HTML export is the one displayed inline.
comparison_dir = ROOT / "results" / "figures" / DATE
comparison_options = dict(
    runs_data=RESULTS_LIST,
    metrics_to_plot=METRIC_LABELS,
    show_intersections=False,
    x_log=True,
    enable_statistical_credibility=True,
)
fig_comparison = plot_comparison(
    file_name=comparison_dir / f"{EXP_ID}_x_log.html",
    **comparison_options,
)
plot_comparison(
    file_name=comparison_dir / f"{EXP_ID}_x_log.png",
    **comparison_options,
)
fig_comparison.show()