# 贪婪算法的不同和优化

In [23]:
from dataclasses import dataclass
from typing import Callable, List

class SlotMachine:
    """A single bandit arm: each pull pays 1 with probability ``reward_probility``, otherwise 0."""

    def __init__(self, reward_probility: float) -> None:
        # The attribute keeps its original (misspelled) name because callers read it.
        self.reward_probility = reward_probility

    def pull(self) -> int:
        """Pull the lever once and return the reward as an int (1 on a win, 0 otherwise)."""
        import random
        won = random.random() < self.reward_probility
        return int(won)
    
class RLEnv:
    """Bandit environment holding several slot machines (10 by default).

    Machine ``i`` pays out with probability ``(i + 1) / (machine_count + 1)``,
    so payout probability increases with the index and the last machine is the best.
    """

    def __init__(self, machine_count: int = 10) -> None:
        self.machines: List[SlotMachine] = [
            SlotMachine(reward_probility=(i + 1) / (machine_count + 1))
            for i in range(machine_count)
        ]

    def pull(self, machine_id: int) -> int:
        """Pull machine ``machine_id`` and return the reward produced by its ``pull()``.

        Raises:
            IndexError: if ``machine_id`` is outside ``[0, len(self.machines))``.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, after which a negative id would silently index from the end.
        if not 0 <= machine_id < len(self.machines):
            raise IndexError(
                f"machine_id {machine_id} out of range [0, {len(self.machines)})"
            )
        return self.machines[machine_id].pull()

@dataclass
class EpsilonDecreasingState:
    """Mutable exploration-rate state for ε-decreasing greedy selection.

    In this file, ``epsilon_decreasing_greedy`` reads ``epsilon`` as the chance of
    exploring. After each choice it sets ``epsilon = max(min_epsilon, epsilon * decay)``.
    The defaults match the values ``GreedyAgent`` uses for its initial state.
    """

    epsilon: float = 1.0  # current probability of picking a random machine
    decay: float = 0.995  # multiplicative factor applied to epsilon after each choice
    min_epsilon: float = 0.01  # floor that epsilon is clamped to

class GreedyAgent:
    """Bandit agent whose machine choice is delegated to a pluggable selection function.

    Attributes:
        name: label used when reporting results.
        rewords: cumulative reward per machine, one slot per machine in ``env``
            (attribute name kept as-is because callers index it directly).
        greedy_algorithm: callable invoked as ``greedy_algorithm(rewords, **kwargs)``
            that returns the index of the machine to pull.
        env: environment whose machines are pulled.
        episode_state: ε-decay state initialised to epsilon=1, decay=0.995,
            min_epsilon=0.01; callers may pass it on to the selection function.
    """

    def __init__(self, name: str, env: RLEnv, greedy_algorithm: Callable[..., int]) -> None:
        self.name = name
        # One zero-initialised reward counter per machine.
        self.rewords = [0 for _ in env.machines]
        self.greedy_algorithm = greedy_algorithm
        self.env = env
        self.episode_state = EpsilonDecreasingState(epsilon=1, decay=0.995, min_epsilon=0.01)

    def act(self, **options) -> int:
        """Return the index of the machine to pull next.

        Calls ``greedy_algorithm`` with the current per-machine rewards and
        forwards every keyword in ``options`` unchanged.
        """
        chooser = self.greedy_algorithm
        return chooser(self.rewords, **options)

    def _pull_machine(self, machine_id: int) -> int:
        """Pull ``machine_id`` in the environment and return the reward (it is not recorded here)."""
        return self.env.pull(machine_id)
    
def greedy_normal(rewords: List[int], **_) -> int:
    """Pure greedy choice: the index of the machine with the highest cumulative reward.

    Ties go to the lowest index, so with all-zero rewards this returns 0.
    Extra keyword arguments are accepted and ignored.
    """
    # max() returns the first maximal item, which preserves lowest-index tie-breaking.
    return max(range(len(rewords)), key=rewords.__getitem__)

def epsilon_greedy(rewords: List[int], epsilon: float = 0.1, **_) -> int:
    """ε-greedy choice over cumulative rewards.

    With probability ``epsilon``, return a uniformly random machine index.
    Otherwise, return the index of the highest cumulative reward (lowest index on ties).
    Extra keyword arguments are accepted and ignored.
    """
    import random
    explore = random.random() < epsilon
    if not explore:
        return rewords.index(max(rewords))
    return random.randint(0, len(rewords) - 1)
    
def epsilon_decreasing_greedy(rewords: List[int], epsilon_state: EpsilonDecreasingState, **_) -> int:
    """ε-decreasing greedy choice.

    Acts like ε-greedy with ``epsilon_state.epsilon`` as the exploration rate,
    then decays that rate in place: ``epsilon = max(min_epsilon, epsilon * decay)``.
    Extra keyword arguments are accepted and ignored.
    """
    import random
    state = epsilon_state
    explore = random.random() < state.epsilon
    chosen = random.randint(0, len(rewords) - 1) if explore else rewords.index(max(rewords))

    # Decay only after choosing, so this call still used the pre-decay rate.
    decayed = state.epsilon * state.decay
    state.epsilon = max(state.min_epsilon, decayed)
    return chosen
    
# One environment shared by all agents. SlotMachine.pull only reads its fixed
# probability, and each agent keeps its own ``rewords`` list.
env = RLEnv(machine_count=10)

# One agent per selection strategy, built from a (name, algorithm) table.
nomal_greedy_agent, epsilon_greedy_agent, epsilon_decreasing_greedy_agent = (
    GreedyAgent(agent_name, env, greedy_algorithm=strategy)
    for agent_name, strategy in (
        ("normal_greedy", greedy_normal),
        ("epsilon_greedy", epsilon_greedy),
        ("epsilon_decreasing_greedy", epsilon_decreasing_greedy),
    )
)

agents: List[GreedyAgent] = [
    nomal_greedy_agent,
    epsilon_greedy_agent,
    epsilon_decreasing_greedy_agent,
]

def train(agent: GreedyAgent, episodes: int = 1000, epsilon: float = 0.1) -> GreedyAgent:
    """Run ``episodes`` pulls for ``agent`` and print a summary of the results.

    In each episode the agent picks a machine through ``agent.act``, pulls it,
    and adds the reward to ``agent.rewords[action]``.

    Args:
        agent: the agent to train. Its ``rewords`` (and, for ε-decreasing agents,
            ``episode_state``) are mutated in place.
        episodes: number of pulls to perform.
        epsilon: fixed exploration rate forwarded as the ``epsilon`` keyword. In this
            file only ``epsilon_greedy`` reads it; the other algorithms ignore it via ``**_``.

    Returns:
        The same ``agent`` object, for chaining.
    """
    half_epsilon_reported = False
    for i in range(episodes):
        action = agent.act(epsilon_state=agent.episode_state, epsilon=epsilon)
        reward = agent._pull_machine(action)
        agent.rewords[action] += reward

        # Among this file's algorithms only epsilon_decreasing_greedy decays
        # episode_state, so this one-time message is expected only for that agent.
        if agent.episode_state.epsilon <= 0.5 and not half_epsilon_reported:
            print(f"当前 epsilon 已经降到 0.5 了， 回合：{i}")
            half_epsilon_reported = True

    total_rewards = sum(agent.rewords)

    print(f"Name: {agent.name} \nTotal rewards: {total_rewards} \nRewards per machine: {agent.rewords}")
    if agent.name == "epsilon_decreasing_greedy":
        print(f"Final epsilon: {agent.episode_state.epsilon:.4f}")
    print("-" * 50)

    return agent

# Train every agent for 1000 episodes; train() prints each agent's summary.
# NOTE: ``agent`` remains bound at module level to the last agent after the loop.
for agent in agents:
    train(agent, episodes=1000)

Name: normal_greedy 
Total rewards: 95 
Rewards per machine: [95, 0, 0, 0, 0, 0, 0, 0, 0, 0]
--------------------------------------------------
Name: epsilon_greedy 
Total rewards: 532 
Rewards per machine: [0, 3, 8, 6, 3, 482, 4, 4, 10, 12]
--------------------------------------------------
当前 epsilon 已经降到 0.5 了， 回合：138
Name: epsilon_decreasing_greedy 
Total rewards: 826 
Rewards per machine: [4, 1, 5, 12, 8, 11, 14, 15, 22, 734]
Final epsilon: 0.0100
--------------------------------------------------
