## Temporal Difference: Q-learning

In [1]:
# Look up what kind of cell sits at a grid position.
def get_state(row, col):
  """Classify the cell at (row, col) as 'ground', 'trap' or 'terminal'.

  Every cell outside row 3 is ground. Within row 3, column 0 is ground,
  column 11 is the terminal cell, and every other column is a trap.
  """
  on_bottom_row = (row == 3)

  if on_bottom_row and col == 11:
    return 'terminal'
  if not on_bottom_row or col == 0:
    return 'ground'

  return 'trap'

get_state(0, 0)

'ground'

In [2]:
# Apply one action from a given cell.
def move(row, col, action):
  """Take `action` from cell (row, col) and return (new_row, new_col, reward).

  Actions: 0 = up, 1 = down, 2 = left, 3 = right. Any other value leaves
  the position unchanged. A cell that is already a trap or the terminal
  cell cannot be left: the same position is returned with reward 0.
  Otherwise the move is clamped to the 4 x 12 grid and the reward is
  -100 when the resulting cell is a trap and -1 in every other case.
  """
  # No action can be taken once on a trap or on the terminal cell.
  if get_state(row, col) in ('trap', 'terminal'):
    return row, col, 0

  # (row delta, col delta) for each action: up, down, left, right.
  d_row, d_col = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}.get(action, (0, 0))

  # Walking off the map is not allowed, so clamp to the grid bounds.
  new_row = min(3, max(0, row + d_row))
  new_col = min(11, max(0, col + d_col))

  # Stepping onto a trap costs -100; every other step costs -1.
  reward = -100 if get_state(new_row, new_col) == 'trap' else -1

  return new_row, new_col, reward
  
move(0, 0, 3)



(0, 1, -1)

In [3]:
import numpy as np

# Score table: one value per (row, col, action) for the 4 x 12 grid and its
# 4 actions. Everything starts at 0 because there is no prior knowledge.
Q = np.zeros(shape=(4, 12, 4))  # module-level global

## new code begin here
## new code end here

Q.shape

(4, 12, 4)

In [4]:
import random

# Choose an action for a state with an epsilon-greedy policy.
def get_action(row, col, epsilon=0.1):
  """Pick an action (0..3) for cell (row, col) from the global Q table.

  Args:
    row, col: grid position used to index Q.
    epsilon: probability of exploring, i.e. of returning a uniformly
      random action instead of the greedy one. Defaults to 0.1, the value
      that used to be hard-coded. 0 gives a purely greedy policy and 1 a
      purely random one.

  Returns:
    The chosen action as a plain Python int. The greedy branch converts
    the numpy integer from argmax so both branches return the same type.
  """
  # With small probability, explore: pick any of the four actions.
  if random.random() < epsilon:
    return random.choice(range(4))

  # Otherwise exploit: take the action with the highest score
  # (argmax returns the first index when several scores tie).
  return int(Q[row, col].argmax())

get_action(0, 0)

0

In [5]:
# new code begin here
def get_update(row, col, action, reward, next_row, next_col, gamma=0.9, alpha=0.1):
  """Compute the Q-learning increment for Q[row, col, action].

  The target is `reward + gamma * max(Q[next_row, next_col])`: it uses the
  best score available in the next cell, regardless of which action will
  actually be taken there. The returned value is
  `alpha * (target - Q[row, col, action])`. The caller adds it to the
  table; this function only reads Q.

  Args:
    row, col, action: the state-action pair being updated.
    reward: reward received for this step.
    next_row, next_col: the cell reached after the step.
    gamma: discount factor applied to the next cell's best score
      (default 0.9, the value that used to be hard-coded).
    alpha: learning rate scaling the temporal-difference error
      (default 0.1, the value that used to be hard-coded).

  Returns:
    The scaled temporal-difference error.
  """
  # Target: this step's reward plus the discounted best score of the next cell.
  target = gamma * Q[next_row, next_col].max()
  target += reward

  # Current estimate for the state-action pair.
  value = Q[row, col, action]

  # Temporal-difference error (the closer to 0, the better the estimate),
  # scaled by the learning rate.
  return (target - value) * alpha

get_update(0, 0, 3, -1, 0, 1)

-0.1

In [6]:
# Training loop.
def train(epochs=1500, log_every=100):
  """Run Q-learning episodes, updating the global Q table in place.

  Each episode starts in column 0 of a random row and runs until the
  current cell is a trap or the terminal cell. Actions come from
  get_action; each step's increment comes from get_update.

  Args:
    epochs: number of episodes to run (default 1500, the value that used
      to be hard-coded).
    log_every: print `epoch reward_sum` whenever `epoch % log_every == 0`
      (default 100, as before). Pass 0 or None to disable logging.

  Returns:
    None. The learned scores are left in Q.
  """
  for epoch in range(epochs):
    # Random starting row in the first column.
    row = random.choice(range(4))
    col = 0

    # First action of the episode.
    action = get_action(row, col)

    # Total reward of this episode. Rewards are negative, so this should
    # get closer to zero as the policy improves.
    reward_sum = 0

    # Step until a trap or the terminal cell is reached.
    while get_state(row, col) not in ('terminal', 'trap'):
      # Execute the action.
      next_row, next_col, reward = move(row, col, action)
      reward_sum += reward

      # Draw the action for the next step. It is chosen before the table
      # is updated and is only used to act, not passed to get_update.
      next_action = get_action(next_row, next_col)

      # Apply the increment to the current state-action score.
      Q[row, col, action] += get_update(row, col, action, reward, next_row,
                                        next_col)

      # Advance to the next cell.
      row, col, action = next_row, next_col, next_action

    if log_every and epoch % log_every == 0:
      print(epoch, reward_sum)
    
train()

0 -117
100 -123
200 -21
300 -23
400 -12
500 -110
600 -15
700 -17
800 -15
900 -13
1000 -12
1100 -15
1200 -105
1300 -13
1400 -15


In [7]:
# Print the board, handy for testing.
def show(row, col, action):
  """Print the 4 x 12 grid with an arrow for `action` drawn at (row, col).

  Ground cells print as '□', trap cells as '○' and the terminal cell as
  '❤'. `action` must be one of 0 (↑), 1 (↓), 2 (←), 3 (→); any other value
  raises KeyError.
  """
  arrow = {0: '↑', 1: '↓', 2: '←', 3: '→'}[action]

  # Rows 0-2 plus the start cell of row 3 are ground (37 cells), followed
  # by 10 traps and the terminal cell.
  cells = ['□'] * 37 + ['○'] * 10 + ['❤']
  cells[row * 12 + col] = arrow

  board = ''.join(cells)
  for line_no in range(4):
    print(board[line_no * 12:(line_no + 1) * 12])


show(1, 1, 0)

□□□□□□□□□□□□
□↑□□□□□□□□□□
□□□□□□□□□□□□
□○○○○○○○○○○❤


In [10]:
from IPython import display
import time

def test():
  """Play one greedy episode from a random start and animate it.

  Starts in column 0 of a random row and follows the highest-scoring
  action in Q for at most 200 steps, stopping early when the current cell
  is a trap or the terminal cell. Before each move the previous output is
  cleared and the board is redrawn with the chosen action.
  """
  # Starting cell.
  col = 0
  row = random.choice(range(4))

  # Play at most 200 steps; stop as soon as an ending cell is reached.
  steps_left = 200
  while steps_left > 0 and get_state(row, col) not in ('trap', 'terminal'):
    steps_left -= 1

    # Greedy choice: the best-scoring action for this cell.
    best = Q[row, col].argmax()

    # Redraw the board showing this action.
    display.clear_output(wait=True)
    time.sleep(0.1)
    show(row, col, best)

    # Execute the action; the reward is not needed here.
    row, col, _ = move(row, col, best)

test()
    

□□□□□□□□□□□□
□□□□□□□□□□□□
□□□□□□□□□□□↓
□○○○○○○○○○○❤


In [11]:
# Print the preferred (highest-scoring) action of every cell, one grid row per line.
for row in range(4):
  line = ''.join(
      {0: '↑', 1: '↓', 2: '←', 3: '→'}[Q[row, col].argmax()]
      for col in range(12))

  print(line)

→→→→→→↓→→→↓↓
→→→↓→↓→↓→→→↓
→→→→→→→→→→→↓
↑↑↑↑↑↑↑↑↑↑↑↑
