# 导入必要库

In [1]:
from pathlib import Path
import gc
import time

from core import RLEnv
from core import BaseAgent
from core.schemas import PiecewizeMethod  # noqa
from greedy import EpsilonDecreasingConfig, GreedyAgent, GreedyAlgorithm, GreedyType
from ucb1 import UCBAgent, UCB1Algorithm
from thompson_sampling import TSAgent, TSAlgorithm

from train import batch_train
from utils import plot_metrics_history, save_experiment_data, ProcessDataLogger

In [2]:
# --- Experiment scale -------------------------------------------------------
SEED: int = 42
STEPS: int = 100_000  # passed as `steps` to batch_train and `total_steps` to the logger
RUN_COUNT: int = 5  # passed as `count` to batch_train
MACHINE_COUNT: int = 10  # number of machines the environment is built with
GRID_SIZE: int = 500  # passed as `grid_size` to ProcessDataLogger

# --- Convergence settings forwarded to every agent factory ------------------
CONVERGENCE_THRESHOLD: float = 0.9
CONVERGENCE_MIN_STEPS: int = 100

# --- Optimistic initialisation ----------------------------------------------
# Forwarded to batch_train only by the greedy training cell; OPTIMISTIC_TIMES
# is additionally embedded in every result file name as `Q_0=...`.
ENABLE_OPTIMISTIC: bool = True
OPTIMISTIC_TIMES: int = 1

# --- Output location --------------------------------------------------------
# Resolved against the directory the kernel was started in.
EXPERIMENT_DATA_DIR: Path = Path.cwd().joinpath("experiment_data")

# --- Shared environment and epsilon schedule --------------------------------
# A single environment instance is handed to every training cell below.
ENV: RLEnv = RLEnv(
    machine_count=MACHINE_COUNT,
    random_walk_internal=1,
    random_walk_machine_num=1,
    # Alternative environment settings, currently disabled:
    # piecewise_internal=1,
    # piecewize_method=PiecewizeMethod.UPSIDE_DOWN,
    seed=SEED,
)
EPSILON_CONFIG: EpsilonDecreasingConfig = EpsilonDecreasingConfig()

# 工厂函数

In [3]:
def get_run_id(agent_name: str) -> str:
    """Return ``"<agent_name>_<timestamp>"`` for labelling one experiment run.

    The timestamp is the string form of ``time.time()`` taken at call time.
    """
    timestamp = str(time.time())
    return "_".join((agent_name, timestamp))

In [4]:
def create_greedy_agent(
    env: RLEnv,
    epsilon_config: EpsilonDecreasingConfig,
    optimistic_init: bool,
    optimistic_times: int,
    convergence_threshold: float,
    convergence_min_steps: int,
    seed: int,
) -> BaseAgent:
    """Build a ``GreedyAgent`` configured with ``GreedyType.GREEDY``.

    ``optimistic_init`` and ``optimistic_times`` are forwarded to
    ``GreedyAlgorithm``; every other argument is forwarded unchanged to
    ``GreedyAgent``. The agent's ``name`` is the ``GreedyType.GREEDY`` member.
    """
    kind = GreedyType.GREEDY
    algorithm = GreedyAlgorithm(
        greedy_type=kind,
        optimistic_init=optimistic_init,
        optimistic_times=optimistic_times,
    )
    return GreedyAgent(
        name=kind,
        env=env,
        algorithm=algorithm,
        epsilon_config=epsilon_config,
        convergence_threshold=convergence_threshold,
        convergence_min_steps=convergence_min_steps,
        seed=seed,
    )


def create_epsilon_agent(
    env: RLEnv,
    epsilon_config: EpsilonDecreasingConfig,
    optimistic_init: bool,
    optimistic_times: int,
    convergence_threshold: float,
    convergence_min_steps: int,
    seed: int,
) -> BaseAgent:
    """Build a ``GreedyAgent`` configured with ``GreedyType.EPSILON``.

    ``optimistic_init`` and ``optimistic_times`` are forwarded to
    ``GreedyAlgorithm``; every other argument is forwarded unchanged to
    ``GreedyAgent``. The agent's ``name`` is the ``GreedyType.EPSILON`` member.
    """
    kind = GreedyType.EPSILON
    algorithm = GreedyAlgorithm(
        greedy_type=kind,
        optimistic_init=optimistic_init,
        optimistic_times=optimistic_times,
    )
    return GreedyAgent(
        name=kind,
        env=env,
        algorithm=algorithm,
        epsilon_config=epsilon_config,
        convergence_threshold=convergence_threshold,
        convergence_min_steps=convergence_min_steps,
        seed=seed,
    )


def create_decreasing_agent(
    env: RLEnv,
    epsilon_config: EpsilonDecreasingConfig,
    optimistic_init: bool,
    optimistic_times: int,
    convergence_threshold: float,
    convergence_min_steps: int,
    seed: int,
) -> BaseAgent:
    """Build a ``GreedyAgent`` configured with ``GreedyType.EPSILON_DECREASING``.

    ``optimistic_init`` and ``optimistic_times`` are forwarded to
    ``GreedyAlgorithm``; every other argument is forwarded unchanged to
    ``GreedyAgent``. The agent's ``name`` is the
    ``GreedyType.EPSILON_DECREASING`` member.
    """
    kind = GreedyType.EPSILON_DECREASING
    algorithm = GreedyAlgorithm(
        greedy_type=kind,
        optimistic_init=optimistic_init,
        optimistic_times=optimistic_times,
    )
    return GreedyAgent(
        name=kind,
        env=env,
        algorithm=algorithm,
        epsilon_config=epsilon_config,
        convergence_threshold=convergence_threshold,
        convergence_min_steps=convergence_min_steps,
        seed=seed,
    )

In [5]:
def create_ucb1_agent(
    env: RLEnv,
    convergence_threshold: float,
    convergence_min_steps: int,
    seed: int,
) -> BaseAgent:
    """Build a ``UCBAgent`` named ``"UCB1"`` backed by a default ``UCB1Algorithm``.

    ``env``, ``convergence_threshold``, ``convergence_min_steps`` and ``seed``
    are forwarded unchanged to ``UCBAgent``.
    """
    algorithm = UCB1Algorithm()
    agent = UCBAgent(
        name="UCB1",
        env=env,
        algorithm=algorithm,
        convergence_threshold=convergence_threshold,
        convergence_min_steps=convergence_min_steps,
        seed=seed,
    )
    return agent

In [6]:
def create_ts_agent(
    env: RLEnv,
    convergence_threshold: float,
    convergence_min_steps: int,
    seed: int,
) -> BaseAgent:
    """Build a ``TSAgent`` named ``"Thompson Sampling"`` with a default ``TSAlgorithm``.

    ``env``, ``convergence_threshold``, ``convergence_min_steps`` and ``seed``
    are forwarded unchanged to ``TSAgent``.
    """
    algorithm = TSAlgorithm()
    agent = TSAgent(
        name="Thompson Sampling",
        env=env,
        algorithm=algorithm,
        convergence_threshold=convergence_threshold,
        convergence_min_steps=convergence_min_steps,
        seed=seed,
    )
    return agent

# 训练

## 普通贪婪算法

In [7]:
# Run id and base output path shared by every artefact written in this cell.
run_id = get_run_id(GreedyType.GREEDY)
file_name: Path = EXPERIMENT_DATA_DIR.joinpath(
    f"{run_id}_T={STEPS}_K={MACHINE_COUNT}_Q_0={OPTIMISTIC_TIMES}.png"
)
process_logger = ProcessDataLogger(
    grid_size=GRID_SIZE,
    total_steps=STEPS,
    run_id=run_id,
)

# Batch-train with the plain greedy factory on the shared environment.
agents, reward, metrics = batch_train(
    agent_factory=create_greedy_agent,
    count=RUN_COUNT,
    steps=STEPS,
    env=ENV,
    seed=SEED,
    epsilon_config=EPSILON_CONFIG,
    optimistic_init=ENABLE_OPTIMISTIC,
    optimistic_times=OPTIMISTIC_TIMES,
    convergence_threshold=CONVERGENCE_THRESHOLD,
    convergence_min_steps=CONVERGENCE_MIN_STEPS,
    process_logger=process_logger,
)
print(metrics)
print(reward)

# Same plot twice: first with x_log=False, then with x_log=True.
for x_log in (False, True):
    plot_metrics_history(agents, "贪婪算法", file_name, x_log=x_log)
save_experiment_data(reward, metrics, file_name)
process_logger.save(file_name.with_stem(f"{file_name.stem}process"), total_steps=STEPS)
dump = process_logger.export(total_steps=STEPS)
# NOTE(review): `keys` is not read again in this cell — confirm a later cell needs it.
keys = [*dump.points[0].data.keys()]

# Show the environment's reference reward and each machine's reward probability.
print(ENV.best_reward(1000))
for m in ENV.machines:
    print(m.reward_probability)

# Drop the large per-run objects before the next experiment starts.
del agents, reward, metrics, process_logger, dump
gc.collect()

达到收敛时的步数: 390
达到收敛时的步数: 280
达到收敛时的步数: 520
达到收敛时的步数: 484
达到收敛时的步数: 952
avg_regret=24927.064841829404 avg_regret_rate=0.25573600015019016 avg_total_reward=72544.8 avg_optimal_rate=4.3995600439956004e-05 avg_convergence_steps=525.2 avg_convergence_rate=1.0
values=[0.8, 2.4, 3746.4, 12537.2, 1.2, 28973.4, 0.4, 7971.4, 10395.4, 8916.2] counts=[1.8, 4.4, 5396.8, 16279.4, 2.6, 37366.6, 1.4, 10187.8, 16626.4, 14132.8]
✅ 字体文件 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/assets/微软雅黑.ttf 已加载
✅ 图表已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/greedy_1758166898.7507033_T=100000_K=10_Q_0=1.png
✅ 字体文件 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/assets/微软雅黑.ttf 已加载
✅ 图表已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/greedy_1758166898.7507033_T=100000_K=10_Q_0=1_x_log.png
✅ 实验结果数据已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/greedy_1758166898.7507033_T=100000_K=10_Q_0=1.json
✅ 过程数据已保存至 /home/Jese__Ki/

26008

## UCB1 算法

In [8]:
# Run id and base output path shared by every artefact written in this cell.
run_id = get_run_id("ucb1")
file_name: Path = EXPERIMENT_DATA_DIR.joinpath(
    f"{run_id}_T={STEPS}_K={MACHINE_COUNT}_Q_0={OPTIMISTIC_TIMES}.png"
)
process_logger = ProcessDataLogger(
    grid_size=GRID_SIZE,
    total_steps=STEPS,
    run_id=run_id,
)

# Batch-train with the UCB1 factory on the shared environment.
# NOTE(review): the recorded output of this cell shows
# counts=[100000.0, 0.0, ...] and avg_optimal_rate=0.0, i.e. only the first
# machine was ever pulled. That looks like a missing exploration step inside
# UCB1Algorithm rather than a real result — verify before comparing algorithms.
agents, reward, metrics = batch_train(
    agent_factory=create_ucb1_agent,
    count=RUN_COUNT,
    steps=STEPS,
    env=ENV,
    seed=SEED,
    convergence_threshold=CONVERGENCE_THRESHOLD,
    convergence_min_steps=CONVERGENCE_MIN_STEPS,
    process_logger=process_logger,
)
print(metrics)
print(reward)

# Same plot twice: first with x_log=False, then with x_log=True.
for x_log in (False, True):
    plot_metrics_history(agents, "UCB1 算法", file_name, x_log=x_log)
save_experiment_data(reward, metrics, file_name)
process_logger.save(file_name.with_stem(f"{file_name.stem}process"), total_steps=STEPS)
dump = process_logger.export(total_steps=STEPS)
# NOTE(review): `keys` is not read again in this cell — confirm a later cell needs it.
keys = [*dump.points[0].data.keys()]

# Show the environment's reference reward and each machine's reward probability.
print(ENV.best_reward(1000))
for m in ENV.machines:
    print(m.reward_probability)

# Drop the large per-run objects before the next experiment starts.
del agents, reward, metrics, process_logger, dump
gc.collect()

达到收敛时的步数: 94622
达到收敛时的步数: 300
达到收敛时的步数: 9019
达到收敛时的步数: 1594
达到收敛时的步数: 2352
avg_regret=45172.51529205487 avg_regret_rate=0.45925736269624184 avg_total_reward=53187.4 avg_optimal_rate=0.0 avg_convergence_steps=21577.4 avg_convergence_rate=1.0
values=[53187.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] counts=[100000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
✅ 字体文件 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/assets/微软雅黑.ttf 已加载
✅ 图表已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/ucb1_1758166903.0016177_T=100000_K=10_Q_0=1.png
✅ 字体文件 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/assets/微软雅黑.ttf 已加载
✅ 图表已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/ucb1_1758166903.0016177_T=100000_K=10_Q_0=1_x_log.png
✅ 实验结果数据已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/ucb1_1758166903.0016177_T=100000_K=10_Q_0=1.json
✅ 过程数据已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data

23782

## Thompson Sampling 算法

In [9]:
# Run id and base output path shared by every artefact written in this cell.
run_id = get_run_id("thompson_sampling")
file_name: Path = EXPERIMENT_DATA_DIR.joinpath(
    f"{run_id}_T={STEPS}_K={MACHINE_COUNT}_Q_0={OPTIMISTIC_TIMES}.png"
)
process_logger = ProcessDataLogger(
    grid_size=GRID_SIZE,
    total_steps=STEPS,
    run_id=run_id,
)

# Batch-train with the Thompson Sampling factory on the shared environment.
agents, reward, metrics = batch_train(
    agent_factory=create_ts_agent,
    count=RUN_COUNT,
    steps=STEPS,
    env=ENV,
    seed=SEED,
    convergence_threshold=CONVERGENCE_THRESHOLD,
    convergence_min_steps=CONVERGENCE_MIN_STEPS,
    process_logger=process_logger,
)
print(metrics)
print(reward)

# Same plot twice: first with x_log=False, then with x_log=True.
for x_log in (False, True):
    plot_metrics_history(agents, "TS 算法", file_name, x_log=x_log)
save_experiment_data(reward, metrics, file_name)
process_logger.save(file_name.with_stem(f"{file_name.stem}process"), total_steps=STEPS)
dump = process_logger.export(total_steps=STEPS)
# NOTE(review): `keys` is not read again in this cell — confirm a later cell needs it.
keys = [*dump.points[0].data.keys()]

# Show the environment's reference reward and each machine's reward probability.
print(ENV.best_reward(1000))
for m in ENV.machines:
    print(m.reward_probability)

# Drop the large per-run objects before the next experiment starts.
del agents, reward, metrics, process_logger, dump
gc.collect()

达到收敛时的步数: 6200
达到收敛时的步数: 810
avg_regret=2758.2555826458674 avg_regret_rate=0.03404183984911322 avg_total_reward=78267.2 avg_optimal_rate=0.039290000000000005 avg_convergence_steps=1402.0 avg_convergence_rate=0.4
values=[17.8, 6909.6, 3092.8, 3869.6, 17491.6, 11548.0, 6013.0, 11290.8, 16827.6, 1206.4] counts=[33.0, 8496.6, 3929.0, 4908.2, 21346.2, 15154.6, 7595.6, 14809.8, 22178.0, 1549.0]
✅ 字体文件 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/assets/微软雅黑.ttf 已加载
✅ 图表已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/thompson_sampling_1758166906.808646_T=100000_K=10_Q_0=1.png
✅ 字体文件 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/assets/微软雅黑.ttf 已加载
✅ 图表已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/thompson_sampling_1758166906.808646_T=100000_K=10_Q_0=1_x_log.png
✅ 实验结果数据已保存至 /home/Jese__Ki/Projects/learn/Python/rl_atomic/bandit/experiment_data/thompson_sampling_1758166906.808646_T=100000_K=10_Q_0=1.json
✅ 过程数据已保存至 /home/J

22908