In [49]:
import copy
from dataclasses import dataclass
from collections import namedtuple, defaultdict
from msdm.core.mdp import TabularMarkovDecisionProcess
from msdm.core.pomdp import TabularPOMDP
from msdm.core.distributions import DictDistribution

# Agent position on the grid: x is the column index, y is the row index
# (row 0 is the top line of the grid string, so dy=-1 moves "up").
State = namedtuple("State", ["x", "y"])
# Movement delta added to a State's coordinates.
Action = namedtuple("Action", ["dx", "dy"])
# Position reported back to the agent after a transition.
Observation = namedtuple("Observation", ["x", "y"])

class KeysAndDoors(TabularPOMDP):
    """
    Grid-world navigation task with walls and doors.

    The agent starts on the cell marked `s` and must reach a target `t`.
    Every transition yields `step_cost`; a transition into a target cell
    additionally yields `target_reward`, and target cells are absorbing.
    Moves off the grid, into walls (`#`), closed doors (`d`) or locked
    doors (`l`) leave the agent where it was. Observations report the
    agent's resulting position exactly, so although this is posed as a
    POMDP it is effectively fully observable.

    NOTE(review): despite the name, no key mechanics are implemented yet;
    closed and locked doors are both simply impassable.
    """

    # Cell features that cannot be entered (off-grid cells are treated as '#').
    _BLOCKING_FEATURES = frozenset('#dl')

    def __init__(
        self,
        coherence=.95,
        discount_rate=.95,
        step_cost=-1,
        target_reward=50,
        grid=None
    ):
        """
        Parameters
        ----------
        :coherence:       Stored as `self.coherence` but not used by any
                          method of this class (a leftover from the
                          Heaven-or-Hell domain of Bonet and Geffner, 1998).
        :discount_rate:   Stored as `self.discount_rate`.
        :step_cost:       Reward added on every transition (negative = cost).
        :target_reward:   Extra reward for a transition into a target cell.
        :grid:            A multiline string describing the map. Blank lines
                          are ignored and each row is stripped of surrounding
                          whitespace, so the string may be indented.
                          `s` is the initial state,
                          `#` are walls,
                          `t` is the target,
                          `d` are closed doors (impassable),
                          `o` are open doors (passable),
                          `l` are locked doors (impassable),
                          any other character (e.g. `.`) is free floor.
        """
        if grid is None:
            grid = \
            """
            t....
            ##.##
            .....
            ##s..
            """
        grid = [list(r.strip()) for r in grid.split('\n') if len(r.strip()) > 0]
        self.grid = grid
        # (x, y) -> feature character; x is the column, y the row index.
        self.loc_features = {}
        # feature character -> list of (x, y) locations in row-major order.
        self.features_loc = defaultdict(list)
        for y, row in enumerate(grid):
            for x, f in enumerate(row):
                self.loc_features[(x, y)] = f
                self.features_loc[f].append((x, y))
        self.coherence = coherence
        self.discount_rate = discount_rate
        self.step_cost = step_cost
        self.target_reward = target_reward

    def initial_state_dist(self):
        """
        Deterministic start at the first `s` cell (row-major order).

        Raises ValueError if the grid contains no `s` cell.
        """
        # .get avoids inserting an empty entry into the defaultdict.
        starts = self.features_loc.get('s')
        if not starts:
            raise ValueError("grid has no start cell 's'")
        x, y = starts[0]
        return DictDistribution({
            State(x=x, y=y): 1.0,
        })

    def actions(self, s):
        """Up, down, left, right, and stay — available in every state."""
        return (
            Action(0, -1),
            Action(0, 1),
            Action(-1, 0),
            Action(1, 0),
            Action(0, 0),
        )

    def is_absorbing(self, s):
        """Target cells end the episode."""
        loc = (s.x, s.y)
        return self.loc_features[loc] == 't'

    def next_state_dist(self, s, a):
        """
        Deterministic move by (a.dx, a.dy). The agent stays in place when the
        destination is off the grid, a wall, a closed door, or a locked door.
        """
        nx, ny = s.x + a.dx, s.y + a.dy
        # Off-grid destinations default to '#', i.e. they block like walls.
        if self.loc_features.get((nx, ny), '#') in self._BLOCKING_FEATURES:
            nx, ny = s.x, s.y
        return DictDistribution({
            State(x=nx, y=ny): 1.0
        })

    def reward(self, s, a, ns):
        """`step_cost` every step, plus `target_reward` when `ns` is a target."""
        r = 0
        r += self.step_cost
        if self.loc_features[(ns.x, ns.y)] == 't':
            r += self.target_reward
        return r

    def observation_dist(self, a, ns):
        """The agent observes its new position exactly."""
        return DictDistribution({
                Observation(x=ns.x, y=ns.y): 1.0
        })

    def state_string(self, s):
        """Render the grid as text with the agent's cell replaced by '@'."""
        # Rows hold single characters, so a per-row copy is a full copy.
        rows = [list(row) for row in self.grid]
        if 0 <= s.y < len(rows) and 0 <= s.x < len(rows[s.y]):
            rows[s.y][s.x] = '@'
        return '\n'.join(''.join(row) for row in rows)


In [52]:
from msdm.algorithms import PointBasedValueIteration

hh = KeysAndDoors(
    coherence=.9,  # NOTE: KeysAndDoors does not currently use `coherence`
    grid=
        """
        t....
        ##.##
        .....
        ##s..
        """,
    discount_rate=.9
)

# Plan with point-based value iteration. On failure, report the error and
# re-raise it: swallowing it would leave `pbvi_res` undefined or stale, and
# the rollout cell below would then fail with a misleading error.
try:
    print("Starting planning process...")
    pbvi_res = PointBasedValueIteration(
        min_belief_expansions=1,
        max_belief_expansions=20,
    ).plan_on(hh)
    print("Planning successful!")
except Exception as e:
    print(f"Error during planning: {type(e).__name__}: {str(e)}")
    raise
    

Starting planning process...
Planning successful!


In [53]:
# Roll out the planned policy once and print each step's rendered state
# ('@' marks the agent), the action taken, and the resulting observation.
# In the recorded run, the final step (at the target) has no action or
# observation and prints None for both.
traj = pbvi_res.policy.run_on(hh)
for t, step in enumerate(traj):
    sstr = hh.state_string(step.state)
    print(f"state {t}: \n", sstr, sep="")
    print(step.action)
    print(step.observation)
    print()

state 0: 
t....
##.##
.....
##@..
Action(dx=0, dy=-1)
Observation(x=2, y=2)

state 1: 
t....
##.##
..@..
##s..
Action(dx=0, dy=-1)
Observation(x=2, y=1)

state 2: 
t....
##@##
.....
##s..
Action(dx=0, dy=-1)
Observation(x=2, y=0)

state 3: 
t.@..
##.##
.....
##s..
Action(dx=-1, dy=0)
Observation(x=1, y=0)

state 4: 
t@...
##.##
.....
##s..
Action(dx=-1, dy=0)
Observation(x=0, y=0)

state 5: 
@....
##.##
.....
##s..
None
None

