In [68]:
import numpy as np
from mdptoolbox import mdp
from itertools import product, repeat
from functools import reduce
from operator import iconcat
from scipy.sparse import csr_matrix

from copy import deepcopy

In [69]:
# data
field_content = ['empty', 'white', 'blue', 'red']
task = ['store', 'restore']
item = ['white', 'blue', 'red']
task_and_item = [['store', 'white'], ['store', 'blue'], ['store', 'red'],
                 ['restore', 'white'], ['restore', 'blue'], ['restore', 'red']]
actions = [0, 1, 2, 3]

# probabilities
probs = {'white': 0.2, 'blue': 0.7, 'red': 0.1}

In [70]:
def all_repeat():
    """
    Computes all permutations of items with repetitions

    :return: permutations of items with repetitions
    """
    result = []
    for c in product(field_content, repeat=4):
        comb = [c[0], c[1], c[2], c[3]]
        result.append(comb)

    result = list(repeat(result, 6))
    result = list(map(list, zip(result, task_and_item)))

    for i in range(len(result)):
        result[i] = [elem + result[i][1] for elem in result[i][0]]

    result = reduce(iconcat, result, [])

    return result

In [71]:
states = all_repeat()
num_states = len(states)

# check if all elements are unique.
np.testing.assert_equal(len(list(set(map(tuple, states)))), num_states)

In [72]:
def field_content_equals(from_state: list, to_state: list) -> bool:
    return from_state[:4] == to_state[:4]

In [73]:
def transition_prob(action: int, from_state: list, to_state: list) -> float:
    copy_from_state = deepcopy(from_state)
    next_item = to_state[-1]

    if copy_from_state[4] == 'store':
        if copy_from_state[action] != 'empty':
            return 1 if from_state == to_state else 0

        copy_from_state[action] = copy_from_state[-1]
        return probs[next_item] / 2 if field_content_equals(copy_from_state, to_state) else 0
    elif copy_from_state[4] == 'restore':
        if copy_from_state[action] != copy_from_state[-1]:
            return 1 if from_state == to_state else 0

        copy_from_state[action] = 'empty'
        return probs[next_item] / 2 if field_content_equals(copy_from_state, to_state) else 0

    return 0

In [74]:
test_prob = transition_prob(0, ['empty', 'empty', 'empty', 'empty', 'restore', 'red'],
                            ['empty', 'empty', 'empty', 'empty', 'restore', 'red'])

np.testing.assert_equal(test_prob, 1)

In [75]:
def reward(action: int, last_prob: float) -> float:
    if last_prob == 1:
        return 0

    rewards_dict = {0: 4, 1: 2, 2: 2, 3: 0}
    return rewards_dict[action]

In [76]:
def transition_and_reward_matrix():
    transitions = []
    rewards = []

    for action in actions:
        row = []
        col = []
        data = []

        reward_vector = []

        for id_from, from_state in enumerate(states):
            for id_to, to_state in enumerate(states):
                p = transition_prob(action, from_state, to_state)

                if p > 0:
                    row.append(id_from)
                    col.append(id_to)
                    data.append(p)

            reward_vector.append(reward(action, data[-1]))

        transitions.append(csr_matrix((data, (row, col)), shape=(num_states, num_states)))
        rewards.append(reward_vector)

    return transitions, np.array(rewards).T

In [77]:
P, R = transition_and_reward_matrix()

In [79]:
ones = np.zeros((4, num_states, num_states))
for i, m in enumerate(P):
    ones[i] = np.sum(np.array(m.toarray()), axis=1)

np.testing.assert_array_equal(ones, np.ones_like(ones))

In [80]:
np.testing.assert_array_equal(R.shape, (num_states, 4))

In [81]:
pi = mdp.PolicyIteration(P, R, 0.9)
pi.run()
pi.policy

(0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 1,
 2,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 0,
 0,
 0,
 2,
 2,
 2,
 2,
 3,
 0,
 0,
 0,
 3,
 0,


In [82]:
policy = pi.policy
print(len(policy))



1536
