In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

# Seed NumPy's global RNG so that any later use of np.random is reproducible
# (no cell shown here draws random numbers yet).
np.random.seed(0)

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
def get_tuple(iterable):
    """Return the items of ``iterable`` as a sorted tuple.

    This gives a canonical, hashable form of an unordered collection (such
    as a set of cow positions), so equal collections always produce the same
    dictionary key regardless of iteration order.
    """
    items = list(iterable)
    items.sort()
    return tuple(items)

In [3]:
def diff(s, e):
    """Return the elements of ``s`` other than ``e``, as a sorted tuple.

    ``e`` does not have to be in ``s``; if it is absent, all of ``s`` is
    returned. Because ``s`` is converted to a set first, duplicates collapse.
    """
    remaining = set(s) - {e}
    return tuple(sorted(remaining))

In [4]:
def subsets(s):
    """Return every subset of ``s`` as a list of lists (the power set).

    Subsets are grouped by size, from the empty subset up to all of ``s``.
    Within one size group the subsets are sorted. Each subset's elements keep
    the order in which ``s`` yields them (arbitrary for a set), so callers
    should canonicalise a subset before using it as a key.
    """
    result = []
    for size in range(len(s) + 1):
        same_size = sorted(list(combo) for combo in combinations(s, size))
        result.extend(same_size)
    return result

In [5]:
# Hyperparameters.
# ϵ and α are not used by the value-iteration cells shown here; presumably
# they are the exploration rate / learning rate for a later method — TODO confirm.
ϵ, α = 0.1, 1
γ = 1     # discount factor applied to the successor state's value
θ = 0.02  # stop iterating once no value changes by θ or more in a sweep

In [6]:
# Grid size: N rows (index i) by M columns (index j).
N, M = 10, 10

In [7]:
cows = {(0, 0), (0, M - 1)}  # cow positions (row, col): top-left and top-right corners

In [8]:
class State:
    """A grid cell combined with the cows that still have to be collected.

    Attributes (all set in ``__init__``):
        i, j     -- row and column of the agent.
        cows     -- positions of the cows not yet collected; it is used as
                    part of a dictionary key, so it should be hashable
                    (e.g. a sorted tuple of (row, col) pairs).
        actions  -- legal moves: 'r', 'l', 'u', 'd' for moves that stay
                    inside the n x m grid, or ['.'] once made terminal.
        value    -- current value estimate, initially 0.
        policy   -- action currently chosen in this state.
        terminal -- True only for the goal state (see ``make_terminal``).
    """

    def __init__(self, i, j, cows, n=N, m=M):
        self.i = i
        self.j = j
        self.cows = cows
        self.actions = []
        self.value = 0
        self.terminal = False

        # Only offer moves that keep the agent inside the n x m grid.
        if j < m-1:
            self.actions.append('r')
        if j > 0:
            self.actions.append('l')
        if i > 0:
            self.actions.append('u')
        if i < n-1:
            self.actions.append('d')

        # Start from the first legal move. On a 1 x 1 grid no move is legal;
        # fall back to '.' (stay put) instead of raising IndexError.
        self.policy = self.actions[0] if self.actions else '.'

    def __repr__(self):
        return "({i}, {j}, {c}), v: {v}, p: {p}".format(
            i=self.i, j=self.j, c=len(self.cows), v=self.value, p=self.policy)

    def make_terminal(self):
        """Turn this state into the goal: value 0 and the single action '.' (stay)."""
        self.terminal = True
        self.actions = ['.']
        self.policy = '.'
        self.value = 0

    def next_state(self, action):
        """Return the (i, j) position reached by taking ``action`` from here.

        'r'/'l' move one column right/left and 'u'/'d' one row up/down; any
        other action (e.g. '.') leaves the position unchanged. No bounds
        check is done, so only actions from ``self.actions`` should be passed.
        """
        i, j = self.i, self.j

        if action == 'r': j += 1
        elif action == 'l': j -= 1
        elif action == 'u': i -= 1
        elif action == 'd': i += 1

        return i, j

In [9]:
def new_grid():
    """Build the full state space: one State per (cell, remaining cows) pair.

    Returns a dict keyed by ``(i, j, cows_tuple)``, where ``cows_tuple`` is a
    sorted tuple of the cows still to be collected. Pairs in which the agent
    would stand on a cow that has not been collected yet are left out. The
    bottom-left cell (N-1, 0) with no cows left is made the terminal state.
    """
    herds = [get_tuple(c) for c in subsets(cows)]

    grid = {}
    for i in range(N):
        for j in range(M):
            for herd in herds:
                if (i, j) not in herd:
                    grid[i, j, herd] = State(i, j, herd)

    grid[N-1, 0, ()].make_terminal()
    return grid

In [10]:
def test_policy(grid, max_steps=None):
    """Follow the current policy from the start state and return the total reward.

    The episode starts in the bottom-left cell (N-1, 0) with every cow still
    to be collected. It then follows each state's ``policy`` until it reaches
    the terminal state. Every step costs 2, except the step that enters the
    terminal state, which earns 100.

    Args:
        grid: dict mapping (i, j, cows) to State, as produced by new_grid().
        max_steps: upper bound on the episode length. Defaults to len(grid).
            A deterministic policy that reaches the goal never revisits a
            state, so a longer episode means the policy is stuck in a cycle.

    Returns:
        The undiscounted return of the episode (0 if the start state is
        already terminal).

    Raises:
        RuntimeError: if the terminal state is not reached within max_steps.
    """
    if max_steps is None:
        max_steps = len(grid)

    s = grid[N-1, 0, get_tuple(cows)]
    R = 0
    steps = 0
    while not s.terminal:
        if steps >= max_steps:
            raise RuntimeError(
                "policy did not reach the terminal state within {} steps "
                "(it is probably cycling)".format(max_steps))
        steps += 1

        i, j = s.next_state(s.policy)
        # Stepping onto a cow collects it, removing it from the remaining herd.
        cows_left = diff(s.cows, (i, j)) if (i, j) in cows else s.cows
        s = grid[i, j, cows_left]

        R += 100 if s.terminal else -2

    return R

In [11]:
grid = new_grid()

# Value iteration with in-place sweeps. Each non-terminal state's value is
# replaced by the Bellman optimality backup
#     V(s) = max_a [ R(s, a) + γ V(s') ]
# and its policy by a maximising action. Iteration stops once no value
# changes by θ or more during a full sweep.
#
# The backup deliberately does NOT keep the old value when every action
# scores lower. Values start at 0, and clamping at that initial guess would
# leave any state whose optimal return is negative (e.g. on a larger grid)
# stuck at 0 with its arbitrary default policy.
it = 0

while True:
    it += 1
    Δ = 0

    for s in grid.values():
        if s.terminal or not s.actions:
            # The goal's value stays fixed at 0 (set by make_terminal).
            continue

        old_value = s.value
        best_value, best_action = None, s.policy
        for a in s.actions:
            i, j = s.next_state(a)
            # Stepping onto a cow collects it, removing it from the remaining herd.
            s_ = grid[i, j, diff(s.cows, (i, j)) if (i, j) in cows else s.cows]

            # Every step costs 2, except the step reaching the goal, which earns 100.
            R = 100 if s_.terminal else -2
            q = R + γ * s_.value
            if best_value is None or q > best_value:
                best_value, best_action = q, a

        s.value, s.policy = best_value, best_action
        Δ = max(Δ, abs(old_value - best_value))

    if Δ < θ:
        break

In [12]:
it  # number of value-iteration sweeps run before the largest change Δ fell below θ

15

In [13]:
test_policy(grid)  # total reward of following the computed policy from the start state

30