In [1]:
# Load IPython's autoreload extension and reload modules before each cell runs,
# so edits to the project source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [108]:
from birl.mdp import MDP, GridWorld

## Pure MDPs

![Figure 1 of paper BIRL](images/figure1_paper.png)


In [109]:
# States and actions of the example MDP.
states = ["S_0", "S_1", "S_2", "S_3"]
actions = ["a_1", "a_2"]
# Create an MDP with states and actions (by default, transition probabilities are all 0 and gamma=0.9).
mdp = MDP(states, actions)

# Transition probabilities as (state, action, next_state, probability) tuples.
# a_1 is stochastic only in S_0; a_2 deterministically returns to S_0 from every state.
t_prob = [("S_0", "a_1", "S_1", 0.4), ("S_0", "a_1", "S_2", 0.3), ("S_0", "a_1", "S_3", 0.3), ("S_0", "a_2", "S_0", 1),
          ("S_1", "a_1", "S_2", 1), ("S_1", "a_2", "S_0", 1),
          ("S_2", "a_1", "S_3", 1), ("S_2", "a_2", "S_0", 1),
          # The last entry used to repeat ("S_0", "a_2", "S_0", 1), leaving (S_3, a_2) with no
          # outgoing transition (all-zero row). NOTE(review): confirm S_3 --a_2--> S_0 against the figure.
          ("S_3", "a_1", "S_1", 1), ("S_3", "a_2", "S_0", 1)]
mdp.set_transition_probabilities(t_prob)

In [111]:
# Look up a few transition probabilities; each query is a (state, action, next_state) triple.
queries = [
    ("S_0", "a_1", "S_1"),
    ("S_0", "a_1", "S_0"),
    ("S_0", "a_2", "S_0"),
]
for s, a, s_ in queries:
    print(f"Taking action {a} in {s} and arriving to {s_} has a probability of:", mdp.get_transition_prob(s, a, s_))

Taking action a_1 in S_0 and arriving to S_1 has a probability of: 0.4
Taking action a_1 in S_0 and arriving to S_0 has a probability of: 0.0
Taking action a_2 in S_0 and arriving to S_0 has a probability of: 1.0


## Environments
An environment is essentially an MDP with a specific model of transition probabilities, states, actions, etc. Environments are used in the policy walk algorithm.
### GridWorld

In [117]:
# Grid with 3 rows and 5 columns; by default noise is 0.2.
grid_shape = (3, 5)
trap_cells = [(1, 1)]      # (row, col) of each trap
terminal_cells = [(0, 4)]  # (row, col) of each terminal cell

gw = GridWorld(grid_shape)
gw.set_traps(trap_cells)
gw.set_terminals(terminal_cells)

In [118]:
# Print an ASCII rendering of the grid; in the output, '*' marks the trap cell and 'T' the terminal cell.
gw.show()

     0    1    2    3    4 
  . ---- ---- ---- ---- ---.
0 |    |    |    |    |   T|
  |----|----|----|----|----|
1 |    |   *|    |    |    |
  |----|----|----|----|----|
2 |    |    |    |    |    |
  |----|----|----|----|----|


In [128]:
# Query the transition distribution for taking "right" from cell (0, 0).
# The result has 15 entries, one per cell of the 3x5 grid
# (presumably flattened row by row -- TODO confirm against GridWorld).
gw.get_transition_prob((0, 0), "right")

array([0.1, 0.8, 0. , 0. , 0. , 0.1, 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. ])