In [65]:
import queue
from collections import deque

import numpy as np

In [110]:
# Problem size and iteration budget (N_ITERS is not referenced by the cells shown here).
N_BLOCKS, N_ITERS = 3, 10

In [117]:
# State arrays have one row per block (row i - 1 <-> block i).  Following the
# `atom2idx` layout, columns 0..N_BLOCKS-1 hold P(on(i, j)), the second-to-last
# column holds P(clear(i)) and the last column holds P(ontable(i)).
#
# Deterministic state: block 1 on block 2, block 2 on the table,
# block 3 on block 1 and clear.
test_prob = np.array(
    [
        [0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [1, 0, 0, 1, 0]
    ],
    dtype=float,  # probabilities: an int array would truncate values written into it
)
# Uncertain state: no `on` atoms are set; each block only has a probability
# of being clear and a probability of being on the table.
test_prob2 = np.array(
    [
        [0, 0, 0, .9, .8],
        [0, 0, 0, .78, .98],
        [0, 0, 0, .8, .9]
    ]
)

In [111]:
def atom2idx(atom):
    """Map a Blocksworld atom to its ``(row, column)`` index in a state array.

    Block numbers are 1-based (``list_action`` enumerates ``range(1, n + 1)``),
    so block ``i`` lives in row ``i - 1``.  Column layout:

    * ``['on', a, b]``    -> ``(a - 1, b - 1)``
    * ``['clear', a]``    -> ``(a - 1, -2)``   (second-to-last column)
    * ``['ontable', a]``  -> ``(a - 1, -1)``   (last column)

    The atom name is matched case-insensitively.

    Raises:
        ValueError: if any block number is below 1.  Without this check a
            block number of 0 would become index -1, which NumPy silently
            wraps to the last row (or, for the second block of ``on``, to the
            ``ontable`` column), so the wrong atom would be read or written.
        NotImplementedError: if the atom name is not one of the above.
    """
    atom_type = atom[0].lower()

    def to_index(block):
        # Convert a 1-based block number to a 0-based row/column index.
        if block < 1:
            raise ValueError(
                'block numbers start at 1, got {!r} in atom {!r}'.format(block, atom)
            )
        return block - 1

    if atom_type == 'clear':
        return (to_index(atom[1]), -2)
    if atom_type == 'ontable':
        return (to_index(atom[1]), -1)
    if atom_type == 'on':
        return (to_index(atom[1]), to_index(atom[2]))
    raise NotImplementedError('unsupported atom type: {!r}'.format(atom[0]))

In [112]:
# Index of ontable(1), then the value stored there in test_prob.
print(atom2idx(['OnTable', 1]), test_prob[atom2idx(['OnTable', 1])], sep='\n')

(0, -1)
0


In [113]:
def get_atom_prob(state, atom):
    """Return the probability stored in ``state`` for ``atom`` (located via ``atom2idx``)."""
    row, col = atom2idx(atom)
    return state[row, col]

In [114]:
# Block numbers are 1-based (see `atom2idx`); 0 is not a valid block number.
get_atom_prob(test_prob2, ['OnTable', 1])

1

In [115]:
def get_applicability(state, action):
    """Probability that a ``['put', a, b]`` action can be applied in ``state``.

    Computed as P(clear(a)) * P(clear(b)); multiplying the two treats the
    clear atoms as independent.
    """
    clear_a, clear_b = (
        get_atom_prob(state, ['Clear', block]) for block in (action[1], action[2])
    )
    return clear_a * clear_b

In [116]:
# Block numbers are 1-based, matching `list_action`; 0 is not a valid block.
print(get_applicability(test_prob2, ['put', 1, 2]))
print(get_applicability(test_prob2, ['put', 2, 1]))

1
1


In [74]:
def get_effective_set(action):
    """Return the ``(add_atoms, delete_atoms)`` lists for a ``['put', a, b]`` action.

    Putting ``a`` on ``b`` adds ``on(a, b)`` and deletes ``clear(b)``.
    """
    # NOTE(review): atoms such as ontable(a) or on(a, x) are left untouched by
    # this action model -- confirm that is intended.
    top, bottom = action[1], action[2]
    add_atoms = [['on', top, bottom]]
    delete_atoms = [['clear', bottom]]
    return add_atoms, delete_atoms

In [75]:
# Add/delete lists for putting block 1 on block 2.
print(get_effective_set(action=['put', 1, 2]))

([['on', 1, 2]], [['clear', 2]])


In [99]:
def apply_action(state, action):
    """Return a new state with ``action`` applied probabilistically.

    With ``p = get_applicability(state, action)`` and ``old`` the value in the
    input ``state``:

    * each atom in the add list becomes ``p + (1 - p) * old``;
    * each atom in the delete list becomes ``old - p``;

    both clipped to ``[0, 1]``.  All updates read from the input ``state``, so
    they do not see each other; if an atom were in both lists, the delete
    (written last) wins.  The input ``state`` is not modified.
    """
    state = np.asarray(state)
    # Work in floating point.  ``state.copy()`` kept the input dtype, so an
    # integer state (e.g. a 0/1 array) silently truncated every update to 0 or 1.
    out_dtype = state.dtype if np.issubdtype(state.dtype, np.floating) else np.float64
    new_state = state.astype(out_dtype, copy=True)
    positive_set, negative_set = get_effective_set(action)
    applicability = get_applicability(state, action)
    for positive_atom in positive_set:
        atom_idx = atom2idx(positive_atom)
        new_state[atom_idx] = np.clip(applicability + (1 - applicability) * state[atom_idx], 0, 1)
    for negative_atom in negative_set:
        atom_idx = atom2idx(negative_atom)
        new_state[atom_idx] = np.clip(state[atom_idx] - applicability, 0, 1)
    return new_state

In [77]:
# State before, then after putting block 2 on block 1.
print(test_prob2, apply_action(test_prob2, ['put', 2, 1]), sep='\n')

[[0.  0.1 0.9 0.8]
 [0.1 0.  0.8 0.9]]
0.7200000000000001
[[0.    0.1   0.18  0.8  ]
 [0.748 0.    0.8   0.9  ]]


In [118]:
def list_action(n_blocks):
    """Enumerate every ``['put', a, b]`` action over blocks ``1..n_blocks`` with ``a != b``.

    Actions are ordered by ``a`` first, then ``b``.
    """
    blocks = range(1, n_blocks + 1)
    return [['put', top, bottom] for top in blocks for bottom in blocks if top != bottom]

In [119]:
# Every ['put', i, j] action with i != j for the configured number of blocks.
ACTIONS = list_action(n_blocks=N_BLOCKS)
print(ACTIONS)

[['put', 1, 2], ['put', 1, 3], ['put', 2, 1], ['put', 2, 3], ['put', 3, 1], ['put', 3, 2]]


In [131]:
def evaluate_state(curr, goal):
    """Distance between two states: the L2 (Frobenius) norm of their difference.

    The last column is excluded from the comparison.
    """
    # NOTE(review): the last column is the `ontable` column in the atom2idx
    # layout, and the action effects in this notebook only touch `on`/`clear`
    # atoms -- presumably why it is ignored; confirm.
    curr_trimmed = curr[..., :-1]
    goal_trimmed = goal[..., :-1]
    return np.linalg.norm(curr_trimmed - goal_trimmed)

In [132]:
# Distance between the two test states.
print(evaluate_state(curr=test_prob, goal=test_prob2))

1.8596773913773323


In [136]:
def naive_bfs(state, goal, max_len, eps=1e-5, actions=None):
    """Breadth-first search; return the first action of the best plan found.

    Nodes are ``(state, plan, error)`` triples, where ``error`` is
    ``evaluate_state(state, goal)``.  Nodes are expanded in FIFO order.
    Expanding a node applies every action to its state and appends each
    result to the frontier.

    Expansion stops when any of these holds:

    * the frontier holds at least ``max_len`` nodes;
    * the frontier becomes empty;
    * an expansion has just produced a state with error below ``eps``.

    The frontier node with the lowest error is then selected; on ties, the
    earliest-enqueued node wins.

    Args:
        state: start state array.
        goal: goal state array.
        max_len: frontier size at which expansion stops.
        eps: error threshold that ends the search early.
        actions: actions tried at every node; defaults to the global
            ``ACTIONS`` (looked up at call time).

    Returns:
        The first action of the selected node's plan, or ``None`` if the
        frontier ended up empty or the selected plan is empty (e.g. with
        ``max_len <= 1`` the start node is never expanded).
    """
    if actions is None:
        actions = ACTIONS
    # A plain single-threaded FIFO.  The previous queue.Queue.get() blocked
    # forever once the frontier drained (e.g. with an empty action list).
    frontier = deque([(state, [], evaluate_state(state, goal))])
    while 0 < len(frontier) < max_len:
        curr_state, plan, _ = frontier.popleft()
        min_error = float('inf')
        for action in actions:
            next_state = apply_action(curr_state, action)
            error = evaluate_state(next_state, goal)
            frontier.append((next_state, plan + [action], error))
            min_error = min(min_error, error)
        if min_error < eps:
            break
    if not frontier:
        return None
    # NOTE(review): only unexpanded frontier nodes are candidates.  Nodes that
    # were already expanded (including the start) are discarded even if they
    # were closer to the goal -- confirm this is intended.
    _, best_plan, _ = min(frontier, key=lambda node: node[-1])
    return best_plan[0] if best_plan else None

In [137]:
# First action of the selected plan for moving from test_prob2 toward test_prob.
print(naive_bfs(state=test_prob2, goal=test_prob, max_len=50))

(array([[0.       , 0.7229196, 0.       , 0.18     , 0.8      ],
       [0.       , 0.       , 0.       , 0.0078   , 0.98     ],
       [0.72     , 0.       , 0.       , 0.8      , 0.9      ]]), [['put', 1, 2], ['put', 1, 2], ['put', 3, 1]], 0.4771104568799136)
['put', 1, 2]
