In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations

np.random.seed(0)

%matplotlib inline

In [2]:
def get_tuple(iterable):
    return tuple(sorted(iterable))

In [3]:
def diff(s, e):
    return get_tuple(set(s).difference([e]))

In [4]:
def subsets(s):
    res = []
    for i in range(len(s)+1):
        tmp = []
        for e in combinations(s, i):
            tmp += [list(e)]
        res += sorted(tmp)
        
    return res

In [5]:
ϵ = 0.1
α = 1
γ = 0.9
θ = 0.02

In [6]:
N = 3
M = 3

In [7]:
cows = set([(0,0),(0,M-1),(1,M-1)])

In [8]:
class State:
    def __init__(self, i, j, cows, n=N, m=M):
        self.i = i
        self.j = j
        self.cows = cows
        self.actions = []
        self.terminal = False
        
        if j < m-1:
            self.actions.append('r')
        if j > 0:
            self.actions.append('l')
        if i > 0:
            self.actions.append('u')
        if i < n-1:
            self.actions.append('d')
            
        self.actions_num = len(self.actions)
        self.value = np.zeros(self.actions_num)
        self.policy = np.ones(self.actions_num) / self.actions_num
        self.returns = [[] for _ in range(self.actions_num)]
        
    def __repr__(self):
        return "({i}, {j}, {c}), v: {v}, p: {p}, a: {a}".format(i=self.i, j=self.j, v=self.value, a=self.actions,
                                                                p=self.policy, c=len(self.cows))
        
    def make_terminal(self):
        self.terminal = True
        self.actions = ['.']
        self.value = np.array([0])
        self.policy = np.array([1])
        
    def get_action(self):
        return np.random.choice(np.arange(self.actions_num), p=self.policy)
        
    def next_state(self, action):        
        i, j = self.i, self.j
        
        if action == 'r': j += 1
        elif action == 'l': j -= 1
        elif action == 'u': i -= 1
        elif action == 'd': i += 1
        
        return i, j

In [9]:
def new_grid():
    grid = {}
    comb = subsets(cows)
    for i in range(N):
        for j in range(M):
            for c in comb:
                if (i, j) in c:
                    continue
                grid[i, j, get_tuple(c)] = State(i, j, get_tuple(c))
        
    grid[N-1, 0, ()].make_terminal()
    return grid

In [10]:
def test_policy(grid):
    s = grid[N-1, 0, get_tuple(cows)]
    R = 0
    while not s.terminal:
        R -= 2
        i, j = s.next_state(s.actions[np.argmax(s.policy)])
        if (i, j) in cows:
            s = grid[i, j, diff(s.cows, (i, j))]
        else:
            s = grid[i, j, s.cows]
            
    R += 102
    
    return R

In [17]:
max_iter = 50

In [18]:
grid = new_grid()

for it in range(max_iter):
    s = grid[N-1, 0, get_tuple(cows)]
    appearing_pairs = []
    
    while not s.terminal:
        a = s.get_action()
        appearing_pairs.append((s, a))
        i, j = s.next_state(s.actions[a])
        if (i, j) in cows:
            s = grid[i, j, diff(s.cows, (i, j))]
        else:
            s = grid[i, j, s.cows]
            
    for i, (s, a) in enumerate(appearing_pairs[:-1]):
        if (s, a) not in appearing_pairs[:i]:
            s.returns[a].append(100 - (len(appearing_pairs) - i + 1) * 2)
        if len(s.returns[a]) > 0:
            s.value[a] = np.array(s.returns[a]).mean()
        else:
            s.value[a] = 0
    
    s, a = appearing_pairs[-1]
    s.returns[a].append(100)
    s.value[a] = np.array(s.returns[a]).mean()

    for s, _ in appearing_pairs:
        a = np.argmax(s.value)
        s.policy = np.zeros_like(s.value) + ϵ / s.actions_num
        s.policy[a] = 1 - ϵ + ϵ / s.actions_num

In [19]:
test_policy(grid)

78