In [163]:
class MDP:
    def __init__(self, actlist, init, terminals, discount=0.9):
        self.init = init
        self.actlist = actlist
        self.terminals = terminals
        self.discount = discount
        self.states = set()
        self.reward = {}

    def R(self, state):
        return self.reward[state]

    def T(self, state):
        print("Here we initialize the probability")

    def action(self, state):
        if state in self.terminals:
            return [None]
        else:
            return self.actlist

class GridMDP(MDP):
    def __init__(self, grid, terminals, init=(0, 0), discount=0.9):
        directions = [(1, 0), (-1, 0), (0, 1), (0, -1)]  # Assuming four possible directions (up, down, right, left)
        grid.reverse()
        super().__init__(actlist=directions, init=init, terminals=terminals, discount=discount)
        self.grid = grid
        self.rows = len(grid)
        self.cols = len(grid[0])
        for x in range(self.cols):
            for y in range(self.rows):
                self.reward[x, y] = grid[y][x]
                if grid[y][x] is not None:
                    self.states.add((x, y))
    def display(self,mapping):
        return list(reversed([[mapping.get((x,y),None) for x in range(self.cols)] for y in range(self.rows)]))
    def go(self, state, function):
        state1 = vector_add(state, function)
        return if_(state1 in self.states, state1, state)

    def T(self, state, action):
        if action is None:
            return [(0.0, state)]
        else:
            return [(0.8, self.go(state, action)), (0.1, self.go(state, turn_left(action))), (0.1, self.go(state, turn_right(action)))]


In [164]:
def update(x, **entries):
    if isinstance(x, dict):
        x.update(entries)
    else:
        x.__dict__.update(entries)
    return x


In [165]:
def turn_left(direction):
    return directions[(directions.index(direction) + 1) % len(directions)]

def turn_right(direction):
    return directions[directions.index(direction) - 1]

In [166]:
directions.index((1,0))+1

1

In [167]:
grid = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]

In [168]:
turn_right((1,0))

(0, -1)

In [169]:
    grid.reverse()
    print(*grid,sep="\n")

[9, 10, 11, 12]
[5, 6, 7, 8]
[1, 2, 3, 4]


In [170]:
    grid[1][3]

8

In [171]:
directions = [(1, 0), (0,1), (-1, 0), (0, -1)]

In [172]:
import operator

def vector_add(a, b):
    return tuple(map(operator.add, a, b))

In [173]:
vector_add((2,2),(1,0))

(3, 2)

In [174]:
grid[2][1]

2

In [175]:
grid[1][2]

7

In [176]:
def if_(test, result, alternative):
    if test:
        return result
    else:
        return alternative

In [177]:
living_reward = 0
mdp1 = GridMDP([[living_reward, living_reward, living_reward, +1],
                [living_reward, None, living_reward, -1],
                [living_reward, living_reward, living_reward, living_reward]],
               terminals=[(3, 2), (3, 1)])

In [178]:
print(mdp1.states)

{(0, 1), (1, 2), (2, 1), (0, 0), (3, 1), (2, 0), (3, 0), (0, 2), (2, 2), (1, 0), (3, 2)}


In [179]:
for a in mdp1.action((2,2)):
    print(a)

(1, 0)
(-1, 0)
(0, 1)
(0, -1)


In [180]:
u1 = {s: 0 for s in mdp1.states}
u = u1.copy()
print(u1)
print(u)

for a in mdp1.action((2, 2)):
    for (p, s1) in mdp1.T((2, 2), a):
        print(u1[s1])

{(0, 1): 0, (1, 2): 0, (2, 1): 0, (0, 0): 0, (3, 1): 0, (2, 0): 0, (3, 0): 0, (0, 2): 0, (2, 2): 0, (1, 0): 0, (3, 2): 0}
{(0, 1): 0, (1, 2): 0, (2, 1): 0, (0, 0): 0, (3, 1): 0, (2, 0): 0, (3, 0): 0, (0, 2): 0, (2, 2): 0, (1, 0): 0, (3, 2): 0}
0
0
0
0
0
0
0
0
0
0
0
0


In [192]:
def value_iteration(mdp, iterations=9):
    u_over_time = []
    u1 = {s: 0 for s in mdp.states}
    R, T, discount = mdp.R, mdp.T, mdp.discount

    for _ in range(iterations):
        u = u1.copy()
        for s in mdp.states:
            if s in mdp.terminals:
                u1[s] = R(s)
            else:
                u1[s] = max([sum([p * (R(s) + discount * u[s1]) for (p, s1) in T(s, a)]) for a in mdp.action(s)])
        u_over_time.append(u)
    return u_over_time

In [193]:
u_over_time = value_iteration(mdp1)

In [194]:
print(mdp1.T((2,2),(1,0)))

[(0.8, (3, 2)), (0.1, (2, 2)), (0.1, (2, 1))]


In [195]:
print(u_over_time[2])

{(0, 1): 0.0, (1, 2): 0.0, (2, 1): 0.0, (0, 0): 0.0, (3, 1): -1, (2, 0): 0.0, (3, 0): 0.0, (0, 2): 0.0, (2, 2): 0.7200000000000001, (1, 0): 0.0, (3, 2): 1}


In [196]:
print(u_over_time[1])

{(0, 1): 0.0, (1, 2): 0.0, (2, 1): 0.0, (0, 0): 0.0, (3, 1): -1, (2, 0): 0.0, (3, 0): 0.0, (0, 2): 0.0, (2, 2): 0.0, (1, 0): 0.0, (3, 2): 1}


In [199]:
print(*mdp1.display(u_over_time[4]),sep="\n")

[0.3732480000000001, 0.6583680000000002, 0.8291880000000001, 1]
[0.0, None, 0.5136120000000002, -1]
[0.0, 0.0, 0.30844800000000006, 0.0]


In [187]:
for y in range(mdp1.cols):
    for x in range(mdp1.rows):
        print(u_over_time[1].get((x,y),None))

0.0
0.0
0.0
0.0
None
0.0
0.0
0.0
0.0
None
None
None
