In [45]:
def if_(test, result, alternative):
    if test:
        return result
    else:
        return alternative
import operator

def vector_add(a, b):
    return tuple(map(operator.add, a, b))

In [46]:
def update(x, **entries):
    if isinstance(x, dict):
        x.update(entries)
    else:
        x.__dict__.update(entries)
    return x


In [47]:
def turn_left(direction):
    return directions[(directions.index(direction) + 1) % len(directions)]

def turn_right(direction):
    return directions[directions.index(direction) - 1]

In [48]:
directions = [(1, 0), (0,1), (-1, 0), (0, -1)]

In [49]:
class MDP:
    def __init__(self, actlist, init, terminals, discount=0.9):
        self.init = init
        self.actlist = actlist
        self.terminals = terminals
        self.discount = discount
        self.states = set()
        self.reward = {}

    def R(self, state):
        return self.reward[state]

    def T(self, state):
        print("Here we initialize the probability")

    def action(self, state):
        if state in self.terminals:
            return [None]
        else:
            return self.actlist

class GridMDP(MDP):
    def __init__(self, grid, terminals, init=(0, 0), discount=0.9):
        directions = [(1, 0), (-1, 0), (0, 1), (0, -1)]  # Assuming four possible directions (up, down, right, left)
        grid.reverse()
        super().__init__(actlist=directions, init=init, terminals=terminals, discount=discount)
        self.grid = grid
        self.rows = len(grid)
        self.cols = len(grid[0])
        for x in range(self.cols):
            for y in range(self.rows):
                self.reward[x, y] = grid[y][x]
                if grid[y][x] is not None:
                    self.states.add((x, y))
    def display(self,mapping):
        return list(reversed([[mapping.get((x,y),None) for x in range(self.cols)] for y in range(self.rows)]))
    def go(self, state, function):
        state1 = vector_add(state, function)
        return if_(state1 in self.states, state1, state)

    def T(self, state, action):
        if action is None:
            return [(0.0, state)]
        else:
            return [(0.8, self.go(state, action)), (0.1, self.go(state, turn_left(action))), (0.1, self.go(state, turn_right(action)))]


In [50]:
living_reward = 0
mdp1 = GridMDP([[living_reward, living_reward, living_reward, +1],
                [living_reward, None, living_reward, -1],
                [living_reward, living_reward, living_reward, living_reward]],
               terminals=[(3, 2), (3, 1)])

In [61]:
def optimal_value_iteration(mdp, eplison_value=0.001):
   # u_over_time = []
    u1 = {s: 0 for s in mdp.states}
    R, T, discount = mdp.R, mdp.T, mdp.discount
    total_count=0
    while True:
    #for _ in range(iterations):
        delta_val=0
        total_count+=1
        u = u1.copy()
        for s in mdp.states:
            if s in mdp.terminals:
                u1[s] = R(s)
            else:
                u1[s] = max([sum([p * (R(s) + discount * u[s1]) for (p, s1) in T(s, a)]) for a in mdp.action(s)])
            delta_val=max(delta_val,abs(u1[s]-u[s]))
        
        if delta_val<eplison_value:
            print("optimized value",total_count)
            return u

In [62]:
optimze_result = optimal_value_iteration(mdp1)

optimized value 16


In [63]:
print(*mdp1.display(optimze_result),sep="\n")

[0.6449472665878168, 0.7443794501753827, 0.8477661188408148, 1]
[0.5662439726531641, None, 0.5718585829794186, -1]
[0.4902998037829586, 0.42987207149041445, 0.4752173643973898, 0.27686990880812257]


In [64]:
def argmax(seq,fn):
    best=seq[0]
    bestscore=fn(best)
    for x in seq:
        x_score=fn(x)
        if x_score>bestscore:
            best=x
            bestscore=x_score
    return best

In [65]:
argmax(["123","four","two"] ,lambda a: len(a))

'four'

In [66]:
def exp_minmax(s,a,mdp,u):
    R,T,discount=mdp.R , mdp.T ,mdp.discount
    li=[]
    for (p,s1) in T(s,a):
        temp=p*(R(s)+discount*u[s1])
        li.append(temp)
        return sum(li)

In [67]:
def policy_extraction(mdp,u):
     R,T,discount=mdp.R , mdp.T ,mdp.discount
     pi={s:'x' for s in mdp.states}
     for s in mdp.states:
        if s in mdp.terminals:
            pi[s]='x'
        else:
            temp=argmax(mdp.action(s),lambda a: exp_minmax(s,a,mdp,u))
            pi[s]=arrows_direction[temp]
     return pi

In [68]:
print(policy_extraction(mdp1,optimze_result))

{(0, 1): '>', (1, 2): '<', (2, 1): '>', (0, 0): '>', (3, 1): 'x', (2, 0): '>', (3, 0): '^', (0, 2): '<', (2, 2): '<', (1, 0): '^', (3, 2): 'x'}


In [71]:
p=policy_extraction(mdp1,optimze_result)

In [72]:
print(*mdp1.display(p),sep="\n")

['>', '>', '>', 'x']
['^', None, '^', 'x']
['^', '<', '^', '<']


In [70]:
arrows_direction={(1, 0):'>', (0,1):'^', (-1, 0):'<', (0, -1):'v'}