In [20]:
# -*- coding: utf-8 -*-

# Author: Kyle Kastner
# License: BSD 3-Clause
# Implementing http://mnemstudio.org/path-finding-q-learning-tutorial.htm
# Q-learning formula from http://sarvagyavaish.github.io/FlappyBirdRL/
# Visualization based on code from Gael Varoquaux gael.varoquaux@normalesup.org
# http://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


In [21]:
# Defines the reward/connection graph.
# Rows are the current state and columns are the action taken; in this
# environment an action is "move to state <column>".
#   -1  -> no connection between the two states (move not allowed)
#    0  -> allowed move with no reward
#   100 -> allowed move that yields the goal reward
r = np.array([[-1, -1, -1, -1,  0,  -1],
              [-1, -1, -1,  0, -1, 100],
              [-1, -1, -1,  0, -1,  -1],
              [-1,  0,  0, -1,  0,  -1],
              [ 0, -1, -1,  0, -1, 100],
              [-1,  0, -1, -1,  0, 100]]).astype("float32")
# Q table: same shape and dtype as r, initialised to all zeros.
q = np.zeros_like(r)
# Sanity check of one entry; the call form of print works on Python 2 and 3.
print(r[0][0])

-1.0


In [22]:
# Update the Q table.
# Meaning of the arguments: taking `action` in `state` leads to `next_state`;
# from these the new Q value of (state, action) is computed and written back
# into the Q table.
def update_q(state, next_state, action, alpha, gamma):
    """Apply one Q-learning update to the module-level table ``q``.

    Parameters
    ----------
    state : int
        Row index of the state the action was taken from.
    next_state : int
        Row index of the state reached; its best Q value is used as the
        bootstrap term.
    action : int
        Column index of the action taken.
    alpha : float
        Step size of the update.
    gamma : float
        Discount applied to the best Q value of ``next_state``.

    Returns
    -------
    The reward ``r[state, action]`` read from the module-level table ``r``.

    Side effects
    ------------
    Writes the new value into ``q[state, action]`` and then rescales the
    strictly positive entries of row ``q[state]`` so that they sum to 1.
    """
    reward = r[state, action]
    old_value = q[state, action]
    # Best value currently known for the successor state (read before the write
    # below, which matters when next_state == state).
    best_next = max(q[next_state])
    td_target = reward + gamma * best_next
    q[state, action] = old_value + alpha * (td_target - old_value)

    # Renormalize the row: positive entries are divided by their sum, entries
    # that are zero or negative are left as they are.
    row = q[state]  # view into q, so assigning through it updates the table
    positive = row > 0
    row[positive] = row[positive] / np.sum(row[positive])
    return reward

In [23]:
# Plan a greedy path for every possible starting state.
def show_traverse(goal_state=5, max_steps=20):
    """Print the greedy traversal from each state of the module-level ``q``.

    For every row index of ``q`` the walk repeatedly moves to the column with
    the largest Q value (``np.argmax``) until it reaches ``goal_state`` or has
    taken ``max_steps`` moves, then prints the visited states joined by
    `` -> ``.

    Parameters
    ----------
    goal_state : int, optional
        State at which a traversal stops. Defaults to 5, the value that was
        previously hard-coded.
    max_steps : int, optional
        Upper bound on the number of moves per traversal, so that a greedy
        policy that cycles still terminates. Defaults to 20, the value that
        was previously hard-coded.
    """
    for start in range(len(q)):
        path = [start]
        current_state = start
        n_steps = 0
        while current_state != goal_state and n_steps < max_steps:
            # An action index is also the index of the state it leads to.
            current_state = np.argmax(q[current_state])
            path.append(current_state)
            n_steps += 1
        print("Greedy traversal for starting state %i" % start)
        print(" -> ".join("%i" % visited for visited in path))
        print("")


In [24]:
# Core algorithm: hyperparameters and the seeded random number generator.
gamma = 0.8          # discount factor handed to update_q
alpha = 1.0          # step size handed to update_q
n_episodes = 1e3     # number of episodes; kept as a float, cast with int() where used
n_states = n_actions = 6   # one action per state (an action is "go to state k")
epsilon = 0.05       # probability threshold for taking a random move
random_state = np.random.RandomState(seed=1999)  # fixed seed for repeatable runs


In [27]:
# Core algorithm: tabular Q-learning with an epsilon-greedy policy.
# The hyperparameters and the seeded generator are set in this cell so that
# running it always starts from the same random stream.
gamma = 0.8        # discount factor handed to update_q
alpha = 1.         # step size handed to update_q
n_episodes = 1E3   # number of training episodes (cast to int for the loop)
n_states = 6
n_actions = 6
epsilon = 0.05     # probability of taking a random valid move
random_state = np.random.RandomState(1999)


def _random_valid_action(valid_moves):
    """Pick a random action among those flagged True in ``valid_moves``.

    The valid action indices are shuffled with the shared ``random_state`` and
    the first one is returned, which keeps the random stream identical to the
    original inline code.
    """
    actions = np.arange(n_actions)[valid_moves]
    random_state.shuffle(actions)
    return actions[0]


# Monitoring interval: every tenth of the run, but never 0, so that the modulo
# below cannot divide by zero when n_episodes is smaller than 10.
monitor_every = max(1, int(n_episodes / 10.))

for e in range(int(n_episodes)):
    # Start each episode from a randomly chosen state.
    states = list(range(n_states))
    random_state.shuffle(states)
    current_state = states[0]
    goal = False

    # Hook for inspecting intermediate progress.
    if e > 0 and e % monitor_every == 0:
        # uncomment this to print the greedy paths at each monitoring step
        # show_traverse()
        pass

    # Within an episode, keep moving until the goal reward is collected.
    while not goal:
        # Moves with a non-negative reward are the ones the graph allows.
        valid_moves = r[current_state] >= 0

        # Epsilon-greedy: with probability epsilon explore with a random valid
        # move; otherwise exploit the Q table.
        explore = random_state.rand() < epsilon
        if not explore and np.sum(q[current_state]) > 0:
            # The row has learned values: take the best known action.
            action = np.argmax(q[current_state])
        else:
            # Either exploring, or nothing learned for this state yet. Fall
            # back on the reward table so invalid moves are never taken.
            action = _random_valid_action(valid_moves)

        # Action and next state coincide, e.g. taking action 4 leads to state 4.
        next_state = action
        reward = update_q(current_state, next_state, action,
                          alpha=alpha, gamma=gamma)

        # Goal state has reward 100
        if reward > 1:
            goal = True
        current_state = next_state


print(q)
print('-------')
show_traverse()

[[  0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
    1.00000000e+00   0.00000000e+00]
 [  0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
    0.00000000e+00   1.00000000e+00]
 [  0.00000000e+00   0.00000000e+00   0.00000000e+00   1.00000000e+00
    0.00000000e+00   0.00000000e+00]
 [  0.00000000e+00   7.99844801e-01   6.03799708e-03   0.00000000e+00
    1.94117308e-01   0.00000000e+00]
 [  0.00000000e+00   0.00000000e+00   0.00000000e+00   2.80259693e-45
    0.00000000e+00   1.00000000e+00]
 [  0.00000000e+00   3.90896834e-33   0.00000000e+00   0.00000000e+00
    0.00000000e+00   1.00000000e+00]]
-------
Greedy traversal for starting state 0
0 -> 4 -> 5

Greedy traversal for starting state 1
1 -> 5

Greedy traversal for starting state 2
2 -> 3 -> 1 -> 5

Greedy traversal for starting state 3
3 -> 1 -> 5

Greedy traversal for starting state 4
4 -> 5

Greedy traversal for starting state 5
5

