# 蛇棋环境搭建

In [1]:
import numpy as np
import gym
from gym.spaces import Discrete

from contextlib import contextmanager
import time

@contextmanager
def timer(name):
    """Context manager that prints how long its ``with`` body took.

    Args:
        name: Label printed in front of the measured duration.

    On exit, ``'<name> COST:<seconds>'`` is printed to stdout. The report is
    emitted even when the body raises; the exception then propagates as usual.
    """
    # perf_counter is monotonic, so the interval cannot be distorted by
    # system clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    try:
        yield
    finally:
        # Report in ``finally`` so a failing body still prints its timing
        # instead of silently skipping the report.
        end = time.perf_counter()
        print('{} COST:{}'.format(name, end-start))

In [2]:
class SnakeEnv(gym.Env):
    """Snakes-and-ladders board exposed through the gym ``Env`` interface.

    The player starts on square 1 and wins on reaching square ``SIZE``
    exactly; overshooting bounces the player back from ``SIZE``. Each action
    selects one of the available dice. Ladders are bidirectional: every
    ladder is stored in both directions in ``self.ladders``.
    """
    SIZE=100

    def __init__(self, ladder_num, dices):
        """Create a board with randomly placed ladders.

        Args:
            ladder_num: Number of ladders (pairs of linked squares).
            dices: Sequence of maximum face values, one per action; action
                ``a`` rolls uniformly in ``1..dices[a]``.

        Raises:
            ValueError: If the ladders need more distinct squares than the
                ``SIZE - 1`` squares available in ``1..SIZE-1``.
        """
        # Without this guard the sampling loop below can never collect
        # enough distinct squares and would spin forever.
        if ladder_num*2 > self.SIZE - 1:
            raise ValueError(
                'ladder_num={} needs {} distinct squares but only {} exist'.format(
                    ladder_num, ladder_num*2, self.SIZE - 1))

        self.ladder_num = ladder_num
        self.dices = dices
        self.observation_space = Discrete(self.SIZE+1)
        self.action_space = Discrete(len(dices))

        if ladder_num == 0:
            self.ladders = {0:0}
        else:
            # Draw 2*ladder_num distinct squares in 1..SIZE-1: sample a
            # batch, then top up one at a time until duplicates are replaced.
            ladders = set(np.random.randint(1, self.SIZE, size=self.ladder_num*2))
            while len(ladders) < self.ladder_num*2:
                ladders.add(np.random.randint(1, self.SIZE))

            # Shuffle and pair the squares up: each row is one ladder.
            ladders = np.array(list(ladders))
            np.random.shuffle(ladders)
            ladders = ladders.reshape((self.ladder_num,2))

            # Store plain Python ints so positions keep one type and the
            # printed mapping does not depend on numpy's scalar repr.
            pairs = [(int(a), int(b)) for a, b in ladders]
            # Reverse direction first, then forward, so each ladder can be
            # travelled from either end.
            self.ladders = {b: a for a, b in pairs}
            self.ladders.update(pairs)
        print(f'ladders info:{self.ladders} dice ranges:{self.dices}')
        self.pos = 1

    def reset(self):
        """Put the player back on square 1 and return that position."""
        self.pos = 1
        return self.pos

    def step(self, a):
        """Roll dice ``a`` and move the player.

        Args:
            a: Index into ``self.dices`` selecting which dice to roll.

        Returns:
            Tuple ``(position, reward, done, info)``: reward is 100 with
            ``done == 1`` when square ``SIZE`` is hit exactly, otherwise
            reward is -1 with ``done == 0``. ``info`` is always ``{}``.
        """
        step = np.random.randint(1, self.dices[a]+1)
        self.pos += step
        if self.pos == self.SIZE:
            return self.SIZE, 100, 1, {}
        elif self.pos > self.SIZE:
            # Overshoot: bounce back from the final square.
            self.pos = 2*self.SIZE - self.pos

        # Landing on either end of a ladder moves the player to the other end.
        if self.pos in self.ladders:
            self.pos = self.ladders[self.pos]
        return self.pos, -1, 0, {}

    def reward(self, s):
        """Return the reward for being in state ``s``: 100 on the final square, else -1."""
        if s == self.SIZE:
            return 100
        else:
            return -1

    # No rendering is implemented.
    def render(self):
        pass

# 智能体构建

In [3]:
class ModelFreeAgent(object):
    """Tabular agent: a deterministic policy plus per state-action tables."""

    def __init__(self, env):
        """Size all tables from the environment's discrete spaces.

        Args:
            env: Object exposing ``observation_space.n`` and
                ``action_space.n``.
        """
        n_states = env.observation_space.n
        n_actions = env.action_space.n
        table_shape = (n_states, n_actions)

        self.s_len = n_states
        self.a_len = n_actions
        self.gamma = 0.8

        # pi[s] is the action chosen in state s when not exploring.
        self.pi = np.zeros(n_states, dtype=int)
        # Two float tables of identical shape, both starting at zero.
        self.value_q = np.zeros(table_shape)
        self.value_n = np.zeros(table_shape)

    def play(self, state, epsilon=0.0):
        """Pick an action for ``state``.

        Args:
            state: Index into the policy table.
            epsilon: Exploration probability. When a uniform draw falls
                below it, a uniformly random action is returned; otherwise
                the policy's action for ``state`` is returned.
        """
        explore = np.random.rand() < epsilon
        return np.random.randint(self.a_len) if explore else self.pi[state]

# 策略评估（reward计算）

In [4]:
def eval_game(env, agent):
    """Play a single episode and report how it went.

    Args:
        env: Environment providing ``reset()`` and ``step(action)``, where
            ``step`` returns ``(state, reward, done, info)``.
        agent: Object whose ``play(state)`` returns the action to take.

    Returns:
        Tuple ``(total_reward, state_action)`` where ``state_action`` lists
        every ``(state, action)`` pair in the order it was played.
    """
    trajectory = []
    score = 0
    observation = env.reset()
    finished = False
    while not finished:
        action = agent.play(observation)
        trajectory.append((observation, action))
        observation, reward, finished, _ = env.step(action)
        score += reward
    return score, trajectory

# 算法

## 1. Monte Carlo

In [5]:
class MonteCarlo(object):
    """Monte Carlo control: alternate episode-based evaluation and greedy improvement."""

    def __init__(self, epsilon=0.0):
        """Store ``epsilon``, the exploration probability handed to ``agent.play``."""
        self.epsilon = epsilon

    def monte_carlo_eval(self, agent, env):
        """Run one episode and fold its discounted returns into ``agent.value_q``.

        Every visit of a (state, action) pair increments ``agent.value_n``
        and moves ``agent.value_q`` towards the observed return as a running
        mean.
        """
        # Roll out one complete episode, remembering (state, action, reward).
        trajectory = []
        current = env.reset()
        finished = False
        while not finished:
            action = agent.play(current, self.epsilon)
            successor, reward, finished, _ = env.step(action)
            trajectory.append((current, action, reward))
            current = successor

        # Walk the episode backwards so each return is the step's reward plus
        # the discounted return of everything that followed it.
        returns = []
        discounted = 0
        for _, _, reward in reversed(trajectory):
            discounted = discounted*agent.gamma + reward
            returns.append(discounted)
        returns.reverse()

        # Apply the incremental-mean update in chronological order.
        for (visited, action, _), target in zip(trajectory, returns):
            agent.value_n[visited][action] += 1
            visits = agent.value_n[visited][action]
            agent.value_q[visited][action] += (target - agent.value_q[visited][action])/visits

    def policy_improve(self, agent):
        """Make the policy greedy with respect to ``agent.value_q``.

        State 0 is left at action 0. The comparison with the old policy is
        made only after every state has been processed.

        Returns:
            ``True`` if the policy changed (and was replaced), ``False`` if
            the greedy policy equals the current one.
        """
        greedy = np.zeros_like(agent.pi)
        for state in range(1, agent.s_len):
            greedy[state] = np.argmax(agent.value_q[state,:])

        if np.all(np.equal(greedy, agent.pi)):
            return False
        agent.pi = greedy
        return True

    def monte_carlo_opt(self, agent, env):
        """Repeat 100 evaluation episodes plus one improvement until the policy is stable."""
        rounds = 0
        changed = True
        while changed:
            rounds += 1
            for _ in range(100):
                self.monte_carlo_eval(agent, env)
            changed = self.policy_improve(agent)
        print('Monte Carlo: {} rounds'.format(rounds))

In [6]:
def monte_carlo_demo():
    """Train an agent with Monte Carlo control on a 10-ladder board and print the result.

    Prints the timing of the optimisation, the outcome of one evaluation
    game, and the learned policy table.
    """
    board = SnakeEnv(10, [3,6])
    learner = ModelFreeAgent(board)
    solver = MonteCarlo(0.05)

    with timer('Timer Monte Carlo Iter'):
        solver.monte_carlo_opt(learner, board)

    outcome = eval_game(board, learner)
    print('return_pi={}'.format(outcome))
    print('agent.pi={}'.format(learner.pi))

## 2. TD (Temporal Difference)

In [14]:
# 改一处公式就是Q-Learning
class SARSA(object):
    """SARSA control: on-policy temporal-difference evaluation plus greedy improvement."""

    def __init__(self, epsilon=0.0):
        """Store ``epsilon``, the exploration probability handed to ``agent.play``."""
        self.epsilon = epsilon

    def sarsa_eval(self, agent, env):
        """Run one episode, updating ``agent.value_q`` after every transition.

        For each transition ``(state, act) -> (reward, next_state)`` the
        target is ``reward + gamma * Q(next_state, next_act)``, or just
        ``reward`` when the episode ended. The reward therefore belongs to
        the pair being updated, and the final pair of the episode is updated
        too. ``agent.value_n`` counts the updates and provides the ``1/n``
        step size.
        """
        state = env.reset()
        act = agent.play(state, self.epsilon)

        while True:
            next_state, reward, done, _ = env.step(act)
            if done:
                # No bootstrap from a terminal state.
                next_act = None
                return_val = reward
            else:
                # SARSA bootstraps from the action actually chosen next.
                # Q-learning differs only here: it would use
                # np.max(agent.value_q[next_state, :]) instead.
                next_act = agent.play(next_state, self.epsilon)
                return_val = reward + agent.gamma * agent.value_q[next_state][next_act]

            agent.value_n[state][act] += 1
            agent.value_q[state][act] += (return_val - agent.value_q[state][act]) / \
                                         agent.value_n[state][act]
            if done:
                break
            state = next_state
            act = next_act

    def policy_improve(self, agent):
        """Make the policy greedy with respect to ``agent.value_q``.

        State 0 is left at action 0.

        Returns:
            ``True`` if the policy changed (and was replaced), ``False`` if
            the greedy policy equals the current one.
        """
        new_policy = np.zeros_like(agent.pi)
        for i in range(1, agent.s_len):
            new_policy[i] = np.argmax(agent.value_q[i,:])
        if np.all(np.equal(new_policy, agent.pi)):
            return False
        else:
            agent.pi = new_policy
            return True

    def sarsa_opt(self, agent, env):
        """Repeat 100 evaluation episodes plus one improvement until the policy is stable."""
        iteration = 0
        while True:
            iteration += 1
            for i in range(100):
                self.sarsa_eval(agent, env)
            ret = self.policy_improve(agent)
            if not ret:
                break
        print('SARSA: {} rounds'.format(iteration))

In [17]:
def sarsa_demo():
    """Train an agent with SARSA on a 10-ladder board and print the result.

    Prints the timing of the optimisation, the outcome of one evaluation
    game, and the learned policy table.
    """
    board = SnakeEnv(10, [3,6])
    learner = ModelFreeAgent(board)
    solver = SARSA(0.05)

    with timer('Sarsa Iter'):
        solver.sarsa_opt(learner, board)

    outcome = eval_game(board, learner)
    print('return_pi={}'.format(outcome))
    print('agent.pi={}'.format(learner.pi))

In [20]:
# Script entry point: run the Monte Carlo demo first, then the SARSA demo.
# Each demo builds its own randomly generated board, so their results are
# not computed on the same ladders.
if __name__ == '__main__':
    monte_carlo_demo()
    sarsa_demo()

ladders info:{88: 94, 13: 41, 12: 82, 6: 84, 5: 78, 71: 50, 65: 93, 64: 10, 76: 61, 96: 9, 94: 88, 41: 13, 82: 12, 84: 6, 78: 5, 50: 71, 93: 65, 10: 64, 61: 76, 9: 96} dice ranges:[3, 6]
Monte Carlo: 14 rounds
Timer Monte Carlo Iter COST:0.28543591499328613
return_pi=(92, [(1, 1), (7, 0), (96, 1), (98, 0), (99, 0), (99, 0), (99, 0), (99, 0), (98, 0)])
agent.pi=[0 1 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 0 1 1 1 0 0 1 0 0 1 0
 1 1 0 1 1 0 0 0 1 1 0 1 1 1 0 0 0 0 0 1 0 1 0 0 1 0 0 0 0 0 1 1 1 1 1 1 1
 0 1 1 1 1 1 1 0 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0]
ladders info:{36: 51, 31: 34, 69: 83, 44: 91, 9: 53, 55: 19, 49: 47, 98: 4, 54: 67, 70: 22, 51: 36, 34: 31, 83: 69, 91: 44, 53: 9, 19: 55, 47: 49, 4: 98, 67: 54, 22: 70} dice ranges:[3, 6]
SARSA: 45 rounds
Sarsa Iter COST:1.673116683959961
return_pi=(99, [(1, 0), (98, 0)])
agent.pi=[0 0 0 0 1 1 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1
 1 1 1 1 0 0 0 1 1 1 1 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 