In [237]:
import numpy as np
from collections import defaultdict

class game():
    """Stochastic grid-world MDP on a 3x4 board with one blocked cell.

    States are ``(x, y)`` tuples with ``x`` in 0..2 and ``y`` in 0..3, excluding
    the blocked cell ``(1, 1)``, plus the absorbing pseudo-state ``'end'``.
    ``'up'``/``'down'`` change ``x`` by +1/-1 and ``'left'``/``'right'`` change
    ``y`` by -1/+1.  An action moves in the intended direction with probability
    0.8 and in each of the two perpendicular directions with probability 0.1;
    a move whose target is not a state leaves the agent where it is.
    ``(1, 3)`` (reward -1) and ``(2, 3)`` (reward +1) lead to ``'end'`` under
    every action, and ``'end'`` (reward 0) loops onto itself.

    Attributes:
        s: current state.
        nrow, ncol: board dimensions (3 and 4).
        states: set of all states.
        actions: dict mapping action name -> ``(dx, dy)`` offset.
        R: dict mapping state -> reward returned when an action is taken from it.
        T: ``T[state, action]`` is a dict mapping next_state -> probability.
    """

    def __init__(self, c_, r_, R_):
        """Build the reward and transition tables and set the start state.

        Args:
            c_: first coordinate of the start state (the one 'up'/'down' change, 0..2).
            r_: second coordinate of the start state (the one 'left'/'right' change, 0..3).
            R_: reward of every state other than (1, 3), (2, 3) and 'end'.

        Raises:
            AssertionError: if ``(c_, r_)`` is not one of the grid states.
        """
        self.s = (c_, r_)
        self.nrow = 3
        self.ncol = 4

        self.R = {}
        self.T = defaultdict(lambda: defaultdict(int))
        self.actions = {'up':(1,0),'down':(-1,0),'left':(0,-1),'right':(0,1)}

        self.states = set([(0,0),(0,1),(0,2),(0,3),(1,0),(1,2),(1,3),(2,0),(2,1),(2,2),(2,3),'end'])

        # Validate against the actual state set.  The previous bounds check
        # compared the first coordinate with ncol and the second with nrow, so
        # it rejected valid starts such as (0, 3) and accepted non-states such
        # as (3, 0) or the blocked cell (1, 1).
        assert self.s in self.states, "invalid initial position"

        for s in self.states:
            self.R[s] = R_
        self.R[(1,3)] = -1
        self.R[(2,3)] = 1
        self.R['end'] = 0

        terminal = {(1, 3), (2, 3), 'end'}
        for s in self.states:
            if s in terminal:
                # Terminal cells (and 'end' itself) go to 'end' whatever the action.
                for a in self.actions:
                    self.T[s, a]['end'] = 1
                continue
            x, y = s
            for a, (dx, dy) in self.actions.items():
                # Intended move first, then the two perpendicular slips; the
                # order fixes the key order of T[s, a], which move() samples from.
                outcomes = (((dx, dy), 0.8), ((dy, dx), 0.1), ((-dy, -dx), 0.1))
                for (ox, oy), prob in outcomes:
                    target = (x + ox, y + oy)
                    if target not in self.states:
                        target = s  # off the board or into the blocked cell: stay put
                    self.T[s, a][target] += prob

    def move(self, action):
        """Take ``action`` from the current state and sample the next state.

        Args:
            action: one of the keys of ``self.actions``.

        Returns:
            The reward of the state the agent was in *before* moving.

        Raises:
            ValueError: if ``action`` is not a known action.
        """
        # Reject unknown actions explicitly; indexing the defaultdict with one
        # would silently add an empty row to T before numpy raised.
        if action not in self.actions:
            raise ValueError(f"unknown action: {action!r}")

        reward = self.R[self.s]

        outcomes = self.T[self.s, action]
        next_states = list(outcomes.keys())
        probs = list(outcomes.values())
        i = np.random.choice(len(next_states), p=probs)
        self.s = next_states[i]

        return reward

    def reset_pos(self):
        """Set the current state to one drawn uniformly from ``self.states``.

        The draw includes the terminal cells and ``'end'``.
        """
        self.s = list(self.states)[np.random.randint(len(self.states))]

In [295]:
def value_iteration(game, max_iteration=1000, beta=1):
    """Solve an MDP by value iteration.

    Args:
        game: object exposing ``states`` (iterable of states), ``actions``
            (iterable of actions), ``R`` (state -> reward) and ``T``
            (``T[(state, action)]`` -> dict next_state -> probability).
        max_iteration: maximum number of sweeps over the state set.
        beta: discount factor applied to the expected next-state value.

    Returns:
        ``(V, p)`` where ``V`` maps each state to its value and ``p`` maps each
        state to a greedy action with respect to ``V``.
    """
    actions = list(game.actions)

    def expected_value(s, a, V):
        # Expected value of the successor of s under action a.
        return sum(prob * V[s_next] for s_next, prob in game.T[(s, a)].items())

    # initialize
    V = {s: 0 for s in game.states}

    for n in range(max_iteration):
        V_new = {
            s: game.R[s] + beta * max(expected_value(s, a, V) for a in actions)
            for s in game.states
        }
        converged = all(abs(V_new[s] - V[s]) < 1e-8 for s in game.states)
        # Adopt the new iterate before testing for the break so the function
        # returns the converged values rather than the previous sweep's.
        V = V_new
        if converged:
            print(f"value iteration converged at {n}")
            break

    # Greedy policy.  Iterating the actions in reverse makes max() return the
    # *last* action among equally good ones, which is what the earlier
    # value-keyed dict lookup produced.
    p = {
        s: max(reversed(actions), key=lambda a: game.R[s] + beta * expected_value(s, a, V))
        for s in game.states
    }
    return V, p

def policy_iteration(game, max_iteration=1000, beta=1):
    """Solve an MDP by policy iteration with iterative policy evaluation.

    Starts from a uniformly random policy, then alternates between evaluating
    the current policy (repeated sweeps until the values change by less than
    1e-8 or ``max_iteration`` sweeps are done) and making the policy greedy
    with respect to those values, until the policy stops changing.

    Args:
        game: object exposing ``states`` (iterable of states), ``actions``
            (iterable of actions), ``R`` (state -> reward) and ``T``
            (``T[state, action]`` -> dict next_state -> probability).
        max_iteration: cap on both the number of improvement rounds and the
            number of evaluation sweeps per round.
        beta: discount factor applied to the expected next-state value.

    Returns:
        ``(V, p)``: the evaluated state values and the final policy
        (state -> action).
    """
    # Use the actions of the game that was passed in, not of a notebook global.
    actions = list(game.actions)

    def q_value(s, a, V):
        # One-step look-ahead value of taking a in s and then following V.
        return game.R[s] + beta * sum(prob * V[s_next] for s_next, prob in game.T[s, a].items())

    p = {s: actions[np.random.randint(len(actions))] for s in game.states}
    V = {s: 0 for s in game.states}

    for n in range(max_iteration):
        # Policy evaluation, warm-started from the previous round's values.
        for m in range(max_iteration):
            V_new = {s: q_value(s, p[s], V) for s in game.states}
            converged = all(abs(V_new[s] - V[s]) < 1e-8 for s in game.states)
            V = V_new  # keep the latest sweep even when it is the converged one
            if converged:
                break

        # Policy improvement.  Reversing the actions makes max() return the
        # last action among ties, matching the earlier value-keyed dict lookup
        # and keeping the tie-break stable between rounds.
        p_new = {
            s: max(reversed(actions), key=lambda a: q_value(s, a, V))
            for s in game.states
        }
        if all(p_new[s] == p[s] for s in game.states):
            print(f"policy iteration converged at {n}")
            break
        p = p_new

    return V, p


def Q_learning(game, max_iteration=10000, alpha=0.8, random_explore=0.05, beta=1):
    """Tabular Q-learning with epsilon-greedy exploration.

    Each episode starts with ``game.reset_pos()`` and runs until ``game.s`` is
    ``'end'``.  At every step the greedy action is replaced by a uniformly
    random one with probability ``random_explore``.

    Args:
        game: object exposing ``states``, ``actions``, the current state ``s``,
            ``reset_pos()`` and ``move(action)`` (which returns a reward and
            updates ``s``).
        max_iteration: number of episodes to run.
        alpha: learning rate of the Q update.
        random_explore: probability of taking a random action at a step.
        beta: discount factor applied to the best next-state Q value.

    Returns:
        ``(V, p)`` where ``V[s]`` is the largest learned Q value of ``s`` and
        ``p[s]`` an action attaining it.
    """
    # Use the actions of the game that was passed in, not of a notebook global.
    actions = list(game.actions)

    def greedy_action(s):
        # Reversing makes max() return the last action among ties, matching
        # the earlier value-keyed dict lookup.
        return max(reversed(actions), key=lambda a: Q[s][a])

    Q = {s: {a: 0 for a in actions} for s in game.states}

    for i in range(max_iteration):
        game.reset_pos()
        while game.s != 'end':
            curr_s = game.s
            next_a = greedy_action(curr_s)
            if np.random.random() < random_explore:
                next_a = actions[np.random.randint(len(actions))]
            R = game.move(next_a)
            # Move the estimate towards reward + discounted best value of the new state.
            Q[curr_s][next_a] = (1-alpha) * Q[curr_s][next_a] + alpha * (R + beta * max(Q[game.s].values()))

    p = {}
    V = {}
    for s in game.states:
        p[s] = greedy_action(s)
        V[s] = max(Q[s].values())
    return V, p

In [296]:
# Build the grid world starting at state (0, 0) with reward -0.04 for every
# ordinary (non-terminal) cell, then run tabular Q-learning on it.  The bare
# call is the cell's last expression, so its (V, policy) result is displayed.
grid_game = game(0,0,-0.04)
Q_learning(grid_game)

0 0.8320000000000001
1000 1.1471591471570615
2000 0.0
3000 0.13473277836906494
4000 0.05457824747646567
5000 0.0
6000 0.0
7000 0.0
8000 0.0
9000 0.1292776620297249


({(0, 1): 0.7559494871520556,
  (1, 2): 0.8448496478933267,
  (0, 0): 0.7886263247458721,
  (1, 3): -1.0,
  'end': 0,
  (2, 1): 0.9187199536992825,
  (2, 0): 0.8748266443254658,
  (2, 3): 1.0,
  (2, 2): 0.9462951710209209,
  (1, 0): 0.8345397682031598,
  (0, 2): 0.68102900008528,
  (0, 3): 0.28483994826770426},
 {(0, 1): 'left',
  (1, 2): 'left',
  (0, 0): 'up',
  (1, 3): 'right',
  'end': 'right',
  (2, 1): 'right',
  (2, 0): 'right',
  (2, 3): 'right',
  (2, 2): 'right',
  (1, 0): 'up',
  (0, 2): 'left',
  (0, 3): 'left'})

In [263]:
# Solve grid_game with value iteration; the cell displays the returned (V, policy) pair.
value_iteration(grid_game)

value iteration converged at 35


({(0, 1): 0.6553082174525494,
  (1, 2): 0.6602739726027294,
  (0, 0): 0.7053082186269344,
  (1, 3): -1,
  'end': 0,
  (2, 1): 0.867808219178066,
  (2, 0): 0.8115582191696421,
  (2, 3): 1,
  (2, 2): 0.9178082191780785,
  (1, 0): 0.7615582191497039,
  (0, 2): 0.6114155212421983,
  (0, 3): 0.38792490255093054},
 {(0, 1): 'left',
  (1, 2): 'up',
  (0, 0): 'up',
  (1, 3): 'right',
  'end': 'right',
  (2, 1): 'right',
  (2, 0): 'right',
  (2, 3): 'right',
  (2, 2): 'right',
  (1, 0): 'up',
  (0, 2): 'left',
  (0, 3): 'left'})

In [272]:
# Solve grid_game with policy iteration; the cell displays the returned (V, policy) pair.
policy_iteration(grid_game)

policy iteration converged at 3


({(0, 1): 0.6553082191745441,
  (1, 2): 0.6602739726027398,
  (0, 0): 0.7053082191769686,
  (1, 3): -1,
  'end': 0,
  (2, 1): 0.8678082191780823,
  (2, 0): 0.8115582191780686,
  (2, 3): 1,
  (2, 2): 0.9178082191780822,
  (1, 0): 0.7615582191780359,
  (0, 2): 0.6114155250915418,
  (0, 3): 0.3879249099008332},
 {(0, 1): 'left',
  (1, 2): 'up',
  (0, 0): 'up',
  (1, 3): 'right',
  'end': 'right',
  (2, 1): 'right',
  (2, 0): 'right',
  (2, 3): 'right',
  (2, 2): 'right',
  (1, 0): 'up',
  (0, 2): 'left',
  (0, 3): 'left'})