<a href="https://colab.research.google.com/github/YI-CHENG-SHIH645/ML/blob/master/RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from IPython.display import display, HTML
from scipy.stats import norm
from itertools import product
import numpy as np
import pandas as pd

def pretty_print(df):
    """Display a DataFrame as an HTML table with line breaks in its labels.

    The HTML produced by ``df.to_html()`` contains the two-character sequence
    backslash + ``n`` wherever a label holds a newline; each occurrence is
    turned into ``<br>`` so multi-line headers render on separate lines.
    """
    markup = df.to_html()
    markup = markup.replace("\\n", "<br>")
    rendered = HTML(markup)
    return display(rendered)

# Multi-armed Bandit 
\\
每台拉霸機的 payoff 都是高斯分佈

In [None]:
# Mean payoff of each lever; every lever pays out from a unit-variance Gaussian.
levers_mu = [1.2, 1.0, 0.8, 1.4]
# One frozen normal distribution per lever, in the same order as `levers_mu`.
payoffs = list(map(lambda mu: norm(loc=mu, scale=1.0), levers_mu))

Q-value: 平均獎勵 \\
$ Q^{new}_k = Q^{old}_k + \frac{1}{n}(R_n - Q^{old}_k) $

In [None]:
# Trial numbers at which a row of the results table is recorded.
snap_shot_at = [1, 2, 3, 4, 50, 100, 500, 1000, 5000]

# Two-level column header: per-trial info under an empty top label, then a
# (Q-val, Nobs) pair for each of the four levers, then the running average gain.
cols = pd.MultiIndex.from_tuples(
    [('', label) for label in ('Trial', 'Decision', 'Lever\nChosen', 'Payoff')]
)
multi = pd.MultiIndex.from_tuples(
    [
        (f'Lever {i}(stats)', stat)
        for i in range(1, 5)
        for stat in ('Q-val', 'Nobs')
    ]
)
cols = cols.append(multi)
cols = cols.append(pd.MultiIndex.from_tuples([('', 'Avg Gain\nper trial')]))

# Empty table with one row per snapshot; preview the first rows.
results = pd.DataFrame(columns=cols, index=range(len(snap_shot_at)))
pretty_print(results.head(3))

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Lever 1(stats),Lever 1(stats),Lever 2(stats),Lever 2(stats),Lever 3(stats),Lever 3(stats),Lever 4(stats),Lever 4(stats),Unnamed: 13_level_0
Unnamed: 0_level_1,Trial,Decision,Lever Chosen,Payoff,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Avg Gain per trial
0,,,,,,,,,,,,,
1,,,,,,,,,,,,,
2,,,,,,,,,,,,,


In [None]:
def simulate(eps_scheduler, n_trials=5000):
    """Run one epsilon-greedy multi-armed-bandit experiment.

    On each trial a lever is chosen at random with probability
    ``eps_scheduler(i)`` (explore); otherwise the lever with the highest
    current Q-value is chosen (exploit). The chosen lever's Q-value is then
    updated with the incremental mean ``Q += (payoff - Q) / n``.

    Args:
        eps_scheduler: callable taking the 1-based trial number and returning
            the exploration probability for that trial.
        n_trials: number of lever pulls to simulate.

    Returns:
        A new DataFrame with the layout of the module-level ``results``
        template, holding one filled row for every trial number in
        ``snap_shot_at`` that was reached. The template itself is not
        modified, so tables from different runs do not overwrite each other.
    """
    n_levers = len(payoffs)
    Q = np.zeros(n_levers)
    nobs = np.zeros(n_levers, dtype=int)
    # Fill a private copy so each call returns an independent table.
    table = results.copy()

    for i in range(1, n_trials + 1):
        if np.random.random() < eps_scheduler(i):
            lever = np.random.randint(1, n_levers + 1)
            decision = 'Explore'
        else:
            # Ties resolve to the lowest-numbered lever (argmax returns the first maximum).
            lever = np.argmax(Q) + 1
            decision = 'Exploit'
        payoff = payoffs[lever-1].rvs()
        nobs[lever-1] += 1
        # Incremental running mean of the payoffs observed for this lever.
        Q[lever-1] = Q[lever-1] + 1/nobs[lever-1] * (payoff - Q[lever-1])
        if i in snap_shot_at:
            row = snap_shot_at.index(i)
            table.loc[row, ('', 'Trial')] = i
            table.loc[row, ('', 'Decision')] = decision
            table.loc[row, ('', 'Lever\nChosen')] = lever
            table.loc[row, ('', 'Payoff')] = np.round(payoff, 3)
            for j in range(n_levers):
                table.loc[row, (f'Lever {j+1}(stats)', 'Q-val')] = Q[j].round(3)
            for j in range(n_levers):
                table.loc[row, (f'Lever {j+1}(stats)', 'Nobs')] = nobs[j]
            # Q * nobs is each lever's total payoff, so this is the mean payoff over all trials so far.
            table.loc[row, ('', 'Avg Gain\nper trial')] = ((Q * nobs).sum()/nobs.sum()).round(3)
    return table

In [None]:
# Exploration schedules to compare: three constant rates and one that decays
# geometrically with the trial number.
schedulers = (
    lambda n: 0.1,
    lambda n: 0.01,
    lambda n: 0.5,
    lambda n: 0.995**(n-1),
)

for eps_scheduler in schedulers:
    snapshot_table = simulate(eps_scheduler)
    pretty_print(snapshot_table)
    print('\n\n')

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Lever 1(stats),Lever 1(stats),Lever 2(stats),Lever 2(stats),Lever 3(stats),Lever 3(stats),Lever 4(stats),Lever 4(stats),Unnamed: 13_level_0
Unnamed: 0_level_1,Trial,Decision,Lever Chosen,Payoff,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Avg Gain per trial
0,1,Exploit,1,1.344,1.344,1,0.0,0,0.0,0,0.0,0,1.344
1,2,Exploit,1,0.552,0.948,2,0.0,0,0.0,0,0.0,0,0.948
2,3,Exploit,1,0.743,0.88,3,0.0,0,0.0,0,0.0,0,0.88
3,4,Exploit,1,1.209,0.962,4,0.0,0,0.0,0,0.0,0,0.962
4,50,Exploit,1,-0.624,1.267,47,0.301,2,-0.485,1,0.0,0,1.193
5,100,Explore,2,0.926,1.351,93,0.509,3,0.719,2,1.066,2,1.308
6,500,Exploit,4,0.69,1.266,394,0.95,11,0.75,13,1.678,82,1.313
7,1000,Exploit,4,0.984,1.281,410,1.135,21,0.787,24,1.389,545,1.325
8,5000,Exploit,4,2.809,1.294,526,0.993,117,0.706,126,1.39,4231,1.353







Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Lever 1(stats),Lever 1(stats),Lever 2(stats),Lever 2(stats),Lever 3(stats),Lever 3(stats),Lever 4(stats),Lever 4(stats),Unnamed: 13_level_0
Unnamed: 0_level_1,Trial,Decision,Lever Chosen,Payoff,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Avg Gain per trial
0,1,Exploit,1,2.825,2.825,1,0.0,0,0.0,0,0.0,0,2.825
1,2,Exploit,1,1.856,2.341,2,0.0,0,0.0,0,0.0,0,2.341
2,3,Exploit,1,0.163,1.615,3,0.0,0,0.0,0,0.0,0,1.615
3,4,Exploit,1,1.397,1.56,4,0.0,0,0.0,0,0.0,0,1.56
4,50,Exploit,1,-0.081,1.193,49,1.103,1,0.0,0,0.0,0,1.191
5,100,Exploit,1,0.741,1.307,88,1.103,1,0.0,0,1.23,11,1.296
6,500,Exploit,4,0.319,1.271,299,0.868,2,0.0,0,1.387,199,1.315
7,1000,Exploit,4,2.787,1.266,304,0.868,2,-0.155,1,1.336,693,1.312
8,5000,Exploit,4,3.285,1.26,320,0.845,12,0.905,13,1.383,4655,1.372







Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Lever 1(stats),Lever 1(stats),Lever 2(stats),Lever 2(stats),Lever 3(stats),Lever 3(stats),Lever 4(stats),Lever 4(stats),Unnamed: 13_level_0
Unnamed: 0_level_1,Trial,Decision,Lever Chosen,Payoff,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Avg Gain per trial
0,1,Exploit,1,-0.212,-0.212,1,0.0,0,0.0,0,0.0,0,-0.212
1,2,Explore,3,1.688,-0.212,1,0.0,0,1.688,1,0.0,0,0.738
2,3,Explore,4,2.571,-0.212,1,0.0,0,1.688,1,2.571,1,1.349
3,4,Explore,1,0.724,0.256,2,0.0,0,1.688,1,2.571,1,1.193
4,50,Explore,1,2.501,1.474,19,0.908,6,0.71,8,0.991,17,1.12
5,100,Explore,1,4.731,1.312,33,0.595,11,0.892,14,1.483,42,1.246
6,500,Explore,4,2.039,1.134,92,0.945,58,0.947,67,1.492,283,1.289
7,1000,Exploit,4,2.952,1.19,154,1.075,121,0.872,129,1.456,596,1.294
8,5000,Explore,1,0.773,1.202,659,0.991,623,0.812,621,1.399,3097,1.249







Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Lever 1(stats),Lever 1(stats),Lever 2(stats),Lever 2(stats),Lever 3(stats),Lever 3(stats),Lever 4(stats),Lever 4(stats),Unnamed: 13_level_0
Unnamed: 0_level_1,Trial,Decision,Lever Chosen,Payoff,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Q-val,Nobs,Avg Gain per trial
0,1,Explore,4,3.218,0.0,0,0.0,0,0.0,0,3.218,1,3.218
1,2,Explore,3,-0.248,0.0,0,0.0,0,-0.248,1,3.218,1,1.485
2,3,Explore,2,0.705,0.0,0,0.705,1,-0.248,1,3.218,1,1.225
3,4,Explore,3,0.255,0.0,0,0.705,1,0.003,2,3.218,1,0.983
4,50,Exploit,4,0.637,1.051,9,1.166,10,0.957,14,1.63,17,1.245
5,100,Explore,3,2.399,1.341,16,1.314,24,0.941,25,1.402,35,1.256
6,500,Exploit,4,0.136,1.076,40,1.168,79,0.774,50,1.484,331,1.33
7,1000,Exploit,4,2.547,1.075,43,1.15,83,0.829,53,1.39,821,1.327
8,5000,Exploit,4,1.725,1.067,44,1.15,83,0.811,54,1.406,4819,1.392







上例中 \\
環境不會有變化，我們也就不需要狀態 \\

若今天環境會隨時間、行為而有所變化，就需要考慮當前狀態 \\
Q 值本來是所有動作都會紀錄一個，變成所有列得出的(狀態, 動作)都會紀錄一個 \\
Q(a) -> Q(s, a)

環境因為動作而有所變化，所處的環境又影響能夠達成的目標 \\
=> 規劃目標 \\
=> 最後獎勵總和最多 \\
=> $ G = \sum_{k=t+1}^{T} R_k $（從時間 t 之後到回合結束 T 為止的獎勵總和） \\
=> 不會只著眼於當前 \\

$ Q^{new}(s, a) = Q^{old}(s, a) + \frac{1}{n}(G - Q^{old}(s, a)) $ \\
=> trial 越多，Q 收斂到 "future expected reward"

$ Q^{new}(s, a) = Q^{old}(s, a) + \alpha(G - Q^{old}(s, a)), \quad 0 < \alpha < 1 $（以固定的學習率 $\alpha$ 取代 $\frac{1}{n}$）

# The Game of Nim

In [None]:
class Nim:
    """Single-pile Nim in which whoever picks up the last match loses.

    The caller plays one side through `step`; the other side is a built-in
    opponent that replies with a uniformly random legal move.
    """

    def __init__(self, matches: int):
        """Start a game with `matches` matches on the table."""
        self.init_matches = matches
        self.current_matches = matches

    def reset(self):
        """Restore the initial pile and return the number of matches."""
        self.current_matches = self.init_matches
        return self.current_matches

    def max_action(self):
        """Return the largest move currently allowed: the pile size clipped to [1, 3]."""
        return np.clip(self.current_matches, 1, 3)

    def step(self, action: int):
        """Play `action`, then let the opponent reply if the game is not over.

        Returns:
            (matches left, action played, reward, done). The reward is -1 when
            the caller's move empties the pile, +1 when the opponent's reply
            empties it, and 0 otherwise.

        Raises:
            AssertionError: if `action` is not 1, 2 or 3, or exceeds the pile.
        """
        assert action in [1, 2, 3]
        assert action <= self.current_matches

        # Caller's move.
        self.current_matches -= action
        done = self.current_matches == 0
        if done:
            return self.current_matches, action, -1, done

        # Opponent's reply: a uniformly random legal number of matches.
        reply = np.random.randint(1, self.max_action() + 1)
        self.current_matches -= reply
        done = self.current_matches == 0
        reward = 1 if done else 0
        return self.current_matches, action, reward, done

class NimSimulator:
    """Tabular Q-value learner for a Nim environment.

    The Q-table is a DataFrame with one row per action (1, 2, 3 matches picked
    up) and one column per state (number of matches left). Training uses an
    epsilon-greedy policy with either Monte-Carlo or temporal-difference updates.

    Args:
        nim_instance: game object providing `init_matches`, `reset()`,
            `max_action()` and `step(action)` as used below.
        epsilon_scheduler: callable taking the 1-based episode number and
            returning the exploration probability for that episode.
    """

    def __init__(self, nim_instance, epsilon_scheduler):
        self.nim_game = nim_instance
        self.epsilon_scheduler = epsilon_scheduler
        self.state_col = 'State (= number of matches left)'
        self.Q_table = self.create_Q_table(self.nim_game.init_matches)
        # (state, action) pairs and rewards of the episode in progress (Monte-Carlo only).
        self.sa, self.r_history = [], []

    def create_Q_table(self, matches):
        """Build the Q-table for states 1..matches with unset (NaN) Q-values.

        Returns:
            DataFrame indexed by action 1..3 with a label column
            ``('', 'Matches\\npicked up')`` followed by one float column per
            state under ``self.state_col``.
        """
        cols = pd.MultiIndex.from_tuples(list(product([''], ['Matches\npicked up'])))
        multi = pd.MultiIndex.from_tuples(list(product([self.state_col], range(1, matches+1))))
        cols = cols.append(multi)
        Q_table = pd.DataFrame(columns=cols, index=range(1, 4))
        Q_table.loc[:, ('', 'Matches\npicked up')] = [1, 2, 3]
        # A frame created without data has object dtype; give the Q-value
        # columns an explicit float dtype so numeric operations on them do not
        # depend on pandas' dtype inference.
        for col in multi:
            Q_table[col] = np.nan

        return Q_table

    def _reset_Q_values(self):
        """Set every Q-value to 0.0, keeping the state columns float64."""
        for col in self.Q_table.columns:
            if col[0] == self.state_col:
                # Whole-column assignment replaces the column (and its dtype).
                self.Q_table[col] = 0.0

    def _Q_values(self, s, k):
        """Return the Q-values of actions 1..k in state `s` as a float array."""
        return self.Q_table.loc[:k, (self.state_col, s)].to_numpy(dtype=float)

    def simulate(self, n_run=5000, alpha=0.05, method='MC'):
        """Reset the Q-table and train it for `n_run` episodes.

        Args:
            n_run: number of episodes to play.
            alpha: step size of the Q-value update.
            method: 'MC' updates once per episode from the recorded rewards;
                any other value updates after every step with `_td`.
        """
        self._reset_Q_values()
        # Drop anything left over from an episode that did not finish
        # (e.g. an interrupted earlier run) so it cannot leak into this one.
        self.sa.clear()
        self.r_history.clear()

        for i in range(1, n_run+1):
            # Number of matches currently on the table.
            s = self.nim_game.reset()

            # Whether the game is over.
            done = False

            while not done:
                # Only actions 1..k are legal in the current state.
                k = self.nim_game.max_action()
                # exploration
                if np.random.random() < self.epsilon_scheduler(i):
                    a = np.random.randint(1, k+1)
                # exploitation (ties go to the smallest action)
                else:
                    a = int(np.argmax(self._Q_values(s, k))) + 1
                s_next, a, r, done = self.nim_game.step(a)

                if method == 'MC':
                    self._mc('record', s, a, r)
                else:
                    self._td(s, a, r, s_next, alpha=alpha)
                s = s_next

            if method == 'MC':
                self._mc('update', alpha=alpha)

    def _mc(self, state, s=None, a=None, r=None, alpha=0.05):
        """Monte-Carlo bookkeeping.

        With ``state == 'record'`` the visited (s, a) pair and its reward are
        stored. With any other value the stored episode is used to update the
        Q-table and the buffers are cleared.
        """
        if state == 'record':
            # The visited pairs are only known to need updating once the
            # episode is over, so they are remembered while it runs.
            self.sa.append((s, a))
            self.r_history.append(r)
        else:
            # G[idx] is the sum of the rewards from step idx to the end of the episode.
            G = np.array(self.r_history)[::-1].cumsum()[::-1]
            for idx, (s, a) in enumerate(self.sa):
                Gt = G[idx]
                Q_old = self.Q_table.loc[a, (self.state_col, s)]
                self.Q_table.loc[a, (self.state_col, s)] = Q_old + alpha * (Gt - Q_old)
            self.sa.clear()
            self.r_history.clear()

    def _td(self, s, a, r, s_next, alpha=0.05):
        """One-step update of Q(s, a) towards ``r + max_a' Q(s_next, a')``.

        The bootstrap term is 0 when `s_next` is 0 (no matches left).
        """
        Q_old = self.Q_table.loc[a, (self.state_col, s)]
        # Called right after `step`, so the game's current pile equals s_next
        # and max_action() gives the legal actions in that next state.
        k = self.nim_game.max_action()
        Q_next = np.max(self._Q_values(s_next, k)) if s_next > 0 else 0
        self.Q_table.loc[a, (self.state_col, s)] = Q_old + alpha * (r + Q_next - Q_old)

In [None]:
# Train on an 8-match game with Monte-Carlo updates; the exploration rate
# decays geometrically with the episode number.
nim_game = Nim(8)
simulator = NimSimulator(nim_game, lambda n: 0.9995**(n-1))
simulator.simulate(method='MC')

mc_Q_table = simulator.Q_table.round(3)
pretty_print(mc_Q_table)

Unnamed: 0_level_0,Unnamed: 1_level_0,State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left)
Unnamed: 0_level_1,Matches picked up,1,2,3,4,5,6,7,8
1,1,-1.0,1.0,-0.166,0.343,-0.103,0.733,0,0.746
2,2,0.0,-1.0,1.0,-0.027,0.526,0.085,0,0.713
3,3,0.0,0.0,-1.0,1.0,0.074,0.342,0,0.819


In [None]:
# Retrain the same simulator with temporal-difference updates
# (simulate() resets the Q-values first) and show the result.
simulator.simulate(method='TD')
td_Q_table = simulator.Q_table.round(3)
pretty_print(td_Q_table)

Unnamed: 0_level_0,Unnamed: 1_level_0,State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left),State (= number of matches left)
Unnamed: 0_level_1,Matches picked up,1,2,3,4,5,6,7,8
1,1,-1.0,1.0,-0.024,0.731,0.195,0.999,0,0.794
2,2,0.0,-1.0,1.0,0.247,0.381,0.466,0,0.837
3,3,0.0,0.0,-1.0,1.0,0.039,0.483,0,1.0
