In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations
from queue import PriorityQueue

np.random.seed(0)

%matplotlib inline

In [2]:
def get_tuple(iterable):
    """Return the items of `iterable` as a sorted tuple (a canonical, hashable key)."""
    ordered = sorted(iterable)
    return tuple(ordered)

def diff(s, e):
    """Return the elements of `s` other than `e`, as a sorted tuple.

    `e` does not need to be present in `s`; in that case the result is just
    the sorted contents of `s`.
    """
    remaining = set(s) - {e}
    return tuple(sorted(remaining))

def subsets(s):
    """Return every subset of `s` as a list.

    Subsets are grouped by size (empty set first). Within each size they are
    sorted lexicographically.
    """
    return [
        list(combo)
        for size in range(len(s) + 1)
        for combo in sorted(combinations(s, size))
    ]

In [3]:
def dict_max(d):
    """Return the largest value stored in `d`, or None when `d` is empty."""
    return max(d.values(), default=None)
            
def dict_argmax(d):
    """Return the key of `d` whose value is largest.

    Ties go to the key that comes first in iteration order. For the Q-tables in
    this notebook, that makes greedy tie-breaking follow the action order.

    Raises:
        ValueError: if `d` is empty.
    """
    if not d:
        raise ValueError("dict_argmax() arg is an empty dict")
    # max() only replaces its running best on a strictly greater value, so the
    # earliest maximal key wins.
    return max(d, key=d.__getitem__)

In [4]:
# Learning hyper-parameters.
ϵ, γ, θ, κ = (
    0.1,   # ϵ: total probability of choosing a non-greedy action in State.get_action
    1,     # γ: discount factor in the Q backups (1 = undiscounted)
    0.02,  # θ: minimum priority (|TD error|) for enqueueing in prioritized sweeping
    0.01,  # κ: NOTE(review): not referenced anywhere in this notebook; presumably a Dyna-Q+ bonus weight — TODO confirm or drop
)

In [5]:
# Board size: N rows (index i) by M columns (index j).
N, M = 10, 10

In [6]:
# Cow positions (i, j): row 0, in the first and last columns.
cows = {(0, 0), (0, M - 1)}

In [7]:
class State:
    """A grid-world state: a cell (i, j) plus the cows not yet collected.

    Attributes:
        i, j: row and column of the cell.
        cows: cow positions still on the board (grid keys use sorted tuples).
        actions: legal moves from this cell, in the fixed order 'r', 'l', 'u', 'd'.
        actions_num: number of legal actions; always equal to ``len(self.q)``.
        model: action -> (reward, (i, j, cows)) of the successor. It is pre-filled
            with a -2 step into the neighbouring cell with the same cow tuple,
            and the learning loops overwrite entries with observed transitions.
        q: action -> action-value estimate, initialised to 0.
        terminal: True once ``make_terminal`` has been called.
        taken_actions: actions actually tried from this state (used in planning).
    """

    # Row/column offset of each move. The insertion order fixes the order of
    # `actions`, `model` and `q`. That order drives argmax tie-breaking and the
    # index -> action mapping in `get_action`, so do not reorder it.
    _MOVES = {'r': (0, 1), 'l': (0, -1), 'u': (-1, 0), 'd': (1, 0)}

    def __init__(self, i, j, cows, n=N, m=M):
        self.i = i
        self.j = j
        self.cows = cows
        self.actions = []
        self.model = {}
        self.q = {}
        self.terminal = False

        # Which moves keep the agent on the n x m board.
        legal = {'r': j < m - 1, 'l': j > 0, 'u': i > 0, 'd': i < n - 1}
        for a, (di, dj) in self._MOVES.items():
            if legal[a]:
                self.actions.append(a)
                self.model[a] = (-2, (i + di, j + dj, cows))
                self.q[a] = 0

        self.actions_num = len(self.actions)
        self.taken_actions = []

    def __repr__(self):
        return "({i}, {j}, {c}), q: {q}, a: {a}".format(i=self.i, j=self.j, q=self.q, a=self.actions, c=len(self.cows))

    def make_terminal(self):
        """Turn this state into an absorbing terminal with the single no-op action '.'."""
        self.terminal = True
        self.actions = ['.']
        # Keep the action count in sync with `actions`/`q`; a stale count made
        # get_action build a probability vector of the wrong length.
        self.actions_num = 1
        self.q = {'.': 0}
        self.model = {'.': (0, (self.i, self.j, self.cows))}

    def get_action(self):
        """Pick an action ϵ-greedily from the current Q-values.

        The greedy action (``dict_argmax(self.q)``) is chosen with probability
        1 - ϵ. Each of the other k - 1 actions gets ϵ / (k - 1). A state with a
        single action returns it directly, without drawing a random number.
        """
        items = list(self.q.items())
        count = len(items)
        if count == 1:
            # Avoids ϵ / 0 in the distribution below.
            return items[0][0]
        greedy = dict_argmax(self.q)
        p = np.ones(count) * ϵ / (count - 1)
        for idx, (action, _) in enumerate(items):
            if action == greedy:
                p[idx] = 1 - ϵ
        return items[np.random.choice(np.arange(count), p=p)][0]

    def next_state(self, action):
        """Return the (i, j) reached by `action`; unknown actions such as '.' stay put."""
        di, dj = self._MOVES.get(action, (0, 0))
        return self.i + di, self.j + dj

    def __lt__(self, other):
        """Always False, so (priority, (State, action)) tuples can be compared on ties.

        PriorityQueue orders entries with ``<``. When two priorities are equal,
        the tuple comparison falls through to the States. Without this method,
        that comparison would raise TypeError.
        """
        return False

In [8]:
def new_grid():
    """Build a fresh state table keyed by (i, j, cows).

    There is one State for every cell and every subset of the global `cows`
    (stored as a sorted tuple). Cells occupied by a still-uncollected cow are
    skipped. The cell (N-1, 0) with no cows left is made terminal.
    """
    cow_keys = [get_tuple(subset) for subset in subsets(cows)]
    grid = {}
    for i in range(N):
        for j in range(M):
            for key_cows in cow_keys:
                # The agent cannot stand on a cow that has not been collected yet.
                if (i, j) not in key_cows:
                    grid[i, j, key_cows] = State(i, j, key_cows)

    grid[N - 1, 0, ()].make_terminal()
    return grid

In [9]:
def test_policy(grid):
    """Roll out the greedy policy from the start state and return its total reward.

    The rollout starts at cell (N-1, 0) with every cow still on the board. It
    repeatedly takes ``dict_argmax(s.q)``. Stepping onto a cow's cell removes
    that cow from the state's cow tuple. Each move scores -2 and entering the
    terminal state scores +100, matching the rewards used during training.

    Args:
        grid: mapping (i, j, cows) -> State, as built by ``new_grid``.

    Returns:
        The undiscounted return of the rollout (0 if the start state is already
        terminal).

    Raises:
        RuntimeError: if the greedy policy revisits a state. Both the action
            choice and the transitions are deterministic, so a revisit means the
            rollout would otherwise loop forever.
    """
    s = grid[N-1, 0, get_tuple(cows)]
    R = 0
    visited = set()
    while not s.terminal:
        if s in visited:
            raise RuntimeError(
                "greedy policy cycles at state {!r}; it never reaches the terminal".format(s)
            )
        visited.add(s)
        R -= 2
        i, j = s.next_state(dict_argmax(s.q))
        if (i, j) in cows:
            s = grid[i, j, diff(s.cows, (i, j))]
        else:
            s = grid[i, j, s.cows]

    if not visited:
        # Already terminal: no move was made, so no reward was collected.
        return 0
    # The last move was charged -2 above but actually earns +100, so add 102.
    return R + 102

In [10]:
# Dyna-Q: each iteration takes one real ϵ-greedy step from a random non-terminal
# state, then performs `n` planning backups replayed from the learned model.
grid = new_grid()
keys = list(grid.keys())

n = 20            # planning backups per real step (also read by the prioritized-sweeping cell)
DYNA_STEPS = 3001 # number of real environment steps

observed_states = []   # indexable list, sampled uniformly during planning
observed_set = set()   # O(1) membership companion to `observed_states`

for it in range(1, DYNA_STEPS + 1):

    # Sample a uniformly random non-terminal state.
    while True:
        ind = np.random.choice(np.arange(len(keys)))
        s = grid[keys[ind]]
        if not s.terminal:
            break

    if s not in observed_set:
        observed_set.add(s)
        observed_states.append(s)

    a = s.get_action()
    i, j = s.next_state(a)

    if a not in s.taken_actions:
        s.taken_actions.append(a)

    # Stepping onto a cow's cell removes that cow from the cow tuple.
    if (i, j) in cows:
        s_ = grid[i, j, diff(s.cows, (i, j))]
    else:
        s_ = grid[i, j, s.cows]

    # -2 per move, +100 for entering the terminal state.
    R = -2 if not s_.terminal else 100 if s != s_ else 0

    # Transitions are deterministic, so a full (step size 1) backup is used.
    s.q[a] = R + γ * dict_max(s_.q)

    s.model[a] = (R, (s_.i, s_.j, s_.cows))

    # Planning: replay random previously-seen (state, action) pairs through the model.
    # Every observed state has at least one taken action, added in the same iteration.
    for _ in range(n):
        ind = np.random.choice(np.arange(len(observed_states)))
        s = observed_states[ind]

        ind = np.random.choice(np.arange(len(s.taken_actions)))
        a = s.taken_actions[ind]

        R, (i, j, c) = s.model[a]
        s_ = grid[i, j, c]

        s.q[a] = R + γ * dict_max(s_.q)

In [11]:
# Return of the greedy policy learned by the Dyna-Q cell above.
test_policy(grid)

26

In [12]:
def get_predecessors(grid, s):
    """List every (state, action, reward) whose model transition leads into `s`.

    Scans each state's `model` (action -> (reward, (i, j, cows))) in `grid`. It
    keeps the entries whose successor cell and cow tuple equal those of `s`.
    The cost is linear in the total number of model entries in the grid.
    """
    target = (s.i, s.j, s.cows)
    return [
        (state, action, reward)
        for state in grid.values()
        for action, (reward, (ni, nj, nc)) in state.model.items()
        if (ni, nj, nc) == target
    ]

In [13]:
# Prioritized sweeping on a fresh grid. Each iteration takes one greedy real step
# from a random non-terminal state and stores the observed transition in the
# model. It enqueues (state, action) when its |TD error| exceeds θ. Then up to
# `n` queued pairs are backed up in priority order, and the predecessors of each
# updated state are enqueued with their own |TD error| as priority.
# NOTE(review): `n` is not defined in this cell; it comes from the Dyna-Q cell above.
grid = new_grid()
keys = list(grid.keys())
# PriorityQueue pops the smallest entry first, so priorities are stored negated.
pqueue = PriorityQueue()

it = 0

while True:

    it += 1

    # Sample a uniformly random non-terminal state.
    while True:
        ind = np.random.choice(np.arange(len(keys)))
        s = grid[keys[ind]]
        if not s.terminal:
            break

    # Purely greedy action here (no ϵ-exploration); the only exploration is the
    # random choice of state above.
    a = dict_argmax(s.q)
    i, j = s.next_state(a)

    # Stepping onto a cow's cell removes that cow from the cow tuple.
    if (i, j) in cows:
        s_ = grid[i, j, diff(s.cows, (i, j))]
    else:
        s_ = grid[i, j, s.cows]

    # -2 per move, +100 for entering the terminal state.
    R = -2 if not s_.terminal else 100 if s != s_ else 0

    # Record the observed transition. This overwrites any earlier entry for this
    # action, including the -2 placeholder pre-filled by State.__init__.
    s.model[a] = (R, (s_.i, s_.j, s_.cows))

    # Priority = |TD error| of (s, a); s.q[a] itself is only updated when popped below.
    P = abs(R + γ * dict_max(s_.q) - s.q[a])

    if P > θ:
        pqueue.put((-P, (s, a)))

    # Planning: back up at most n queued pairs, highest priority first.
    for _ in range(n):
        if pqueue.empty():
            break

        # Rebinding `_` does not affect the range() iteration. The same (s, a)
        # can be queued several times; a stale copy just repeats the backup.
        _, (s, a) = pqueue.get()

        R, (i, j, c) = s.model[a]
        s_ = grid[i, j, c]

        s.q[a] = R + γ * dict_max(s_.q)

        # get_predecessors scans every state's model, including the pre-filled
        # entries for actions never actually tried. `s_` is reused for the predecessor.
        for s_, a_, R_ in get_predecessors(grid, s):
            P = abs(R_ + γ * dict_max(s.q) - s_.q[a_])
            if P > θ:
                pqueue.put((-P, (s_, a_)))

    # Stops after 1501 real steps.
    if it > 1500:
        break

In [14]:
# Return of the greedy policy learned by the prioritized-sweeping cell above.
test_policy(grid)

26