In [14]:
import numpy as np
import pandas as pd
import time

class simple_qlearning(object):
    """Tabular Q-learning on a one-dimensional corridor.

    The corridor has ``state_num`` cells.  The agent ("O") starts in cell 0
    and the target ("T") sits in the last cell.  Action ``1`` moves one cell
    to the right; any other action is treated as a move to the left.  An
    episode ends as soon as the agent steps onto the target, which yields a
    reward of 10; every other move yields 0.
    """

    @staticmethod
    def initialize_q(state_num, action_num):
        """Return a zero-filled Q-table of shape ``(state_num, action_num)``."""
        return np.zeros(shape=(state_num, action_num))

    def __init__(self, state_num, action_num, seed, epoch_num=10, lr=0.1, dr=0.8, greedy=0.9):
        """Set up the agent.

        Args:
            state_num: number of cells in the corridor (rows of the Q-table).
            action_num: number of actions (columns of the Q-table).
            seed: value passed to ``np.random.seed``; note that this seeds
                NumPy's *global* random state.
            epoch_num: number of episodes run by :meth:`learn`.
            lr: learning rate applied to each temporal-difference update.
            dr: discount rate applied to the best next-state value.
            greedy: threshold for exploitation; a uniform draw above this
                value triggers a random (exploratory) action.
        """
        np.random.seed(seed)
        self.state_num = state_num
        self.action_num = action_num
        self.epoch_num = epoch_num
        self.greedy = greedy
        self.q_table = self.initialize_q(state_num, action_num)
        self.lr = lr
        self.dr = dr
        # Create the board up front so update_env() is usable before learn().
        self.init_env()

    def take_action(self, state_idx):
        """Choose an action for ``state_idx`` with an epsilon-greedy rule.

        A random action is returned when the uniform draw exceeds
        ``self.greedy`` or when nothing has been learned for this state yet
        (every Q-value in the row is zero).  Otherwise the action with the
        largest Q-value is returned.
        """
        q_values = self.q_table[state_idx]
        explore = np.random.rand() > self.greedy
        # `not q_values.any()` is True only when *every* entry is zero.
        # The former test `q_values.all() == 0` was True as soon as a single
        # entry was zero, which kept the agent acting randomly even after it
        # had learned a useful value for the state.
        if explore or not q_values.any():
            return np.random.choice(self.action_num)
        return np.argmax(q_values)

    def init_env(self):
        """Reset the board: agent in the first cell, target in the last."""
        self.env = ["O"] + ["_"] * (self.state_num - 2) + ["T"]

    def update_env(self, current_pos, action):
        """Move the agent marker on the board and redraw it in place.

        Args:
            current_pos: index of the cell the agent is leaving.
            action: ``1`` moves right; anything else moves left, except in
                cell 0 where a left move is turned into a move to the right.
        """
        if action == 1:
            next_pos = current_pos + 1
            self.env[current_pos] = "_"
            # Keep the target marker visible when the agent reaches it.
            if self.env[next_pos] != "T":
                self.env[next_pos] = "O"
        else:
            # There is no cell to the left of cell 0, so step right instead.
            next_pos = current_pos + 1 if current_pos == 0 else current_pos - 1
            self.env[current_pos] = "_"
            self.env[next_pos] = "O"

        # "\r" returns to the start of the line so the board is overwritten.
        print("\r" + "".join(self.env), end="")

    def update_q(self, current_pos, next_pos, action, r):
        """Apply one Q-learning update for the move that was just taken.

        The entry ``Q[current_pos, action]`` is moved by ``lr`` times
        ``r + dr * max(Q[next_pos]) - Q[current_pos, action]``.
        """
        best_next = self.q_table[next_pos].max()
        td_error = best_next * self.dr - self.q_table[current_pos][action] + r
        self.q_table[current_pos][action] += self.lr * td_error

    def learn(self):
        """Run ``epoch_num`` episodes, updating the Q-table after every move.

        Each episode starts in cell 0 and ends when the agent steps onto the
        target.  The step count printed at the end of an episode is the
        number of moves made *before* the final move onto the target.
        """
        for epoch in range(self.epoch_num):
            step = 0
            current_pos = 0
            self.init_env()

            while True:
                if current_pos == 0:
                    # Only a move to the right is possible from the first cell.
                    next_pos, action = 1, 1
                else:
                    action = self.take_action(current_pos)
                    if action:
                        next_pos = current_pos + action
                    else:
                        next_pos = current_pos - 1

                reached_target = self.env[next_pos] == "T"
                r = 10 if reached_target else 0
                self.update_q(current_pos, next_pos, action, r)
                self.update_env(current_pos, action)

                if reached_target:
                    print(f"Epoch: {epoch}, use step {step}....")
                    break

                current_pos = next_pos
                step += 1

In [12]:
# Corridor of six cells with two possible moves (0 = left, 1 = right).
state_num = 6
action_num = 2

# Build the agent with a fixed seed and train it.
ql = simple_qlearning(state_num=state_num, action_num=action_num, seed=2020)
ql.learn()

0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
_O___T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
__O__T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
___O_T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
____OT0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
___O_T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
__O__T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
___O_T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
__O__T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
___O_T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
__O__T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
_O___T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
__O__T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
___O_T0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
____OT0.0 [[0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]
 [0. 0.]]
___O_T0.0 [[0. 0.]
 [0. 0.]
 [0

In [5]:
probs = np.array([0, 0, 0])

# "Is every entry zero?"  `probs.all() == 0` is NOT that test: it is True as
# soon as a single entry is zero (e.g. for np.array([0, 5])), because
# `.all()` only reports whether all entries are non-zero.  `not probs.any()`
# is True only when every entry is zero.
not probs.any()

True