In [None]:
# Project 3 - MDP + Value Iteration for Parallel Parking

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Tuple, List, Iterable, Optional

# ============================================================
# 1. TYPES AND GLOBAL CONFIG
# ============================================================

State = Tuple[int, int]      # (col, row), 1-based indexing
Action = str                 # "N", "NE", "E", "SE", "S", "SW", "W", "NW"

# 8-directional movement
ACTION_TO_DELTA: Dict[Action, Tuple[int, int]] = {
    "N":  (0, 1),
    "NE": (1, 1),
    "E":  (1, 0),
    "SE": (1, -1),
    "S":  (0, -1),
    "SW": (-1, -1),
    "W":  (-1, 0),
    "NW": (-1, 1),
}
ALL_ACTIONS: List[Action] = list(ACTION_TO_DELTA.keys())

# Hazards: row 1, col 1 and row 5, col 1 → (1,1), (1,5)
HAZARD_STATES: Tuple[State, ...] = (
    (1, 1),
    (1, 5),
)

# Start states S1..S7 (top row, columns 2–8)
START_STATES: List[State] = [
    (2, 5),  # S1
    (3, 5),  # S2
    (4, 5),  # S3
    (5, 5),  # S4
    (6, 5),  # S5
    (7, 5),  # S6
    (8, 5),  # S7
]

# Goal: row 3, col 1 → (1,3)
GOAL_STATE: State = (1, 3)

# R is 1000: -R for hazards and +R for goal
GOAL_REWARD: float = 1000.0
HAZARD_REWARD: float = -1000.0

# Noise and discount
DISCOUNT_FACTOR: float = 0.9
NOISE: float = 0.1


# ============================================================
# 2. UTILITY HELPERS
# ============================================================

def is_opposite(a1: Action, a2: Action) -> bool:
    """Return True if a2 is the 180° opposite direction of a1."""
    dc1, dr1 = ACTION_TO_DELTA[a1]
    dc2, dr2 = ACTION_TO_DELTA[a2]
    return (dc1 + dc2 == 0) and (dr1 + dr2 == 0)


# ============================================================
# 3. GRIDWORLD MDP
# ============================================================

@dataclass
class GridWorldMDP:
    width: int = 8
    height: int = 5
    hazard_states: Tuple[State, ...] = HAZARD_STATES
    hazard_reward: float = HAZARD_REWARD
    goal_state: State = GOAL_STATE
    goal_reward: float = GOAL_REWARD
    live_reward: float = -1.0          # r in [-20, 0], will be varied
    discount: float = DISCOUNT_FACTOR
    noise: float = NOISE
    obstacle_states: Tuple[State, ...] = ()   # for R2: add (4, 3)

    # -------- basic state helpers --------

    def get_states(self) -> List[State]:
        return [
            (c, r)
            for c in range(1, self.width + 1)
            for r in range(1, self.height + 1)
        ]

    def in_bounds(self, s: State) -> bool:
        c, r = s
        return 1 <= c <= self.width and 1 <= r <= self.height

    def is_obstacle(self, s: State) -> bool:
        return s in self.obstacle_states

    def is_terminal(self, s: State) -> bool:
        return s == self.goal_state or s in self.hazard_states

    def reward(self, s: State) -> float:
        if s == self.goal_state:
            return self.goal_reward
        if s in self.hazard_states:
            return self.hazard_reward
        # everything else uses live-in reward r
        return self.live_reward

    # -------- transition model --------

    def get_possible_actions(self, s: State) -> Iterable[Action]:
        if self.is_terminal(s) or self.is_obstacle(s):
            return []
        return ALL_ACTIONS

    def _is_blocked(self, dest: State) -> bool:
        """
        Blocked if:
        - dest is off-grid
        - dest is an obstacle
        (Opposite direction is handled by never including it as a candidate.)
        """
        if not self.in_bounds(dest):
            return True
        if self.is_obstacle(dest):
            return True
        return False

    def get_transition_probs(self, s: State, a: Action) -> List[Tuple[State, float]]:
        """
        Implements the TransProb(s, a) logic from Tutorial 5:
        - Intended state gets probability 0.9 (1 - noise) if not blocked
        - Blocked intended move: stay put (prob 1.0)
        - Other non-blocked destinations share noise / (n_dest(s) - 1)
        - Opposite direction never taken
        """
        if self.is_terminal(s):
            return [(s, 1.0)]

        # Intended destination
        dc, dr = ACTION_TO_DELTA[a]
        intended_dest = (s[0] + dc, s[1] + dr)

        # Collect all candidate destinations (non-blocked, excluding opposite)
        destinations: List[State] = []
        for other_action, (odc, odr) in ACTION_TO_DELTA.items():
            if is_opposite(a, other_action):
                continue  # never move opposite to intended direction
            dest = (s[0] + odc, s[1] + odr)
            if not self._is_blocked(dest):
                destinations.append(dest)

        if self._is_blocked(intended_dest):
            # If intended is blocked, agent remains in current state (p=1)
            return [(s, 1.0)]

        # n_dest(s): number of (non-blocked) destinations
        n_dest = len(destinations)
        if n_dest == 0:
            # no possible moves; stay put
            return [(s, 1.0)]

        transitions: Dict[State, float] = {}

        # Intended move gets 0.9
        intended_prob = 1.0 - self.noise
        transitions[intended_dest] = transitions.get(intended_dest, 0.0) + intended_prob

        # Other moves share noise
        other_dests = [d for d in destinations if d != intended_dest]
        if len(other_dests) > 0:
            per_unintended = self.noise / len(other_dests)
            for d in other_dests:
                transitions[d] = transitions.get(d, 0.0) + per_unintended
        else:
            # only intended destination exists
            transitions[intended_dest] += self.noise

        # Normalize for safety
        total = sum(transitions.values())
        for st in list(transitions.keys()):
            transitions[st] /= total

        return list(transitions.items())


# ============================================================
# 4. VALUE ITERATION
# ============================================================

def value_iteration(
    mdp: GridWorldMDP,
    theta: float = 1e-3,
    max_iterations: int = 500,
) -> Tuple[Dict[State, float], Dict[State, Optional[Action]]]:
    """
    Standard Bellman optimality value iteration.
    U(s) = max_a Σ_s' P(s'|s,a)[R(s') + γ U(s')]
    """
    states = mdp.get_states()
    U: Dict[State, float] = {s: 0.0 for s in states}
    policy: Dict[State, Optional[Action]] = {s: None for s in states}

    for _ in range(max_iterations):
        delta = 0.0
        U_new = U.copy()

        for s in states:
            if mdp.is_terminal(s) or mdp.is_obstacle(s):
                U_new[s] = mdp.reward(s)
                policy[s] = None
                continue

            actions = list(mdp.get_possible_actions(s))
            if not actions:
                U_new[s] = mdp.reward(s)
                policy[s] = None
                continue

            best_value = float("-inf")
            best_action: Optional[Action] = None

            for a in actions:
                expected_util = 0.0
                for s_prime, prob in mdp.get_transition_probs(s, a):
                    r_sp = mdp.reward(s_prime)
                    expected_util += prob * (r_sp + mdp.discount * U[s_prime])

                if expected_util > best_value:
                    best_value = expected_util
                    best_action = a

            U_new[s] = best_value
            policy[s] = best_action
            delta = max(delta, abs(U_new[s] - U[s]))

        U = U_new
        if delta < theta:
            break

    return U, policy


# ============================================================
# 5. POLICY SCORING FOR P1 (R1)
# ============================================================

def score_policy_P1(
    mdp: GridWorldMDP,
    U: Dict[State, float],
    policy: Dict[State, Optional[Action]],
    max_steps: int = 50,
) -> float:
    """
    Approximate the score of P1 as described in the spec:
    For each start state in S1..S7, follow the policy greedily and
    sum the utilities along the path until reaching a terminal state
    or max_steps.
    """
    total_score = 0.0

    for s0 in START_STATES:
        s = s0
        path_value = 0.0

        for _ in range(max_steps):
            path_value += U[s]
            if mdp.is_terminal(s):
                break

            a = policy.get(s)
            if a is None:
                break

            transitions = mdp.get_transition_probs(s, a)
            # follow the most probable next state
            s = max(transitions, key=lambda t: t[1])[0]

        total_score += path_value

    return total_score


# ============================================================
# 6. VISUALIZATION HELPERS
# ============================================================

def print_utility_grid(mdp: GridWorldMDP, U: Dict[State, float]) -> None:
    """
    Print utilities in grid form (row 5 down to 1, col 1 to 8).
    """
    for r in range(mdp.height, 0, -1):
        row_vals = []
        for c in range(1, mdp.width + 1):
            row_vals.append(f"{U[(c, r)]:7.1f}")
        print(" ".join(row_vals))
    print()


def print_policy_grid(mdp: GridWorldMDP, policy: Dict[State, Optional[Action]]) -> None:
    """
    Print policy arrows / labels for each cell.
    """
    for r in range(mdp.height, 0, -1):
        row_vals = []
        for c in range(1, mdp.width + 1):
            s = (c, r)
            if mdp.is_terminal(s):
                if s == mdp.goal_state:
                    row_vals.append(" G ")
                else:
                    row_vals.append(" H ")
            elif mdp.is_obstacle(s):
                row_vals.append(" X ")
            else:
                a = policy.get(s)
                row_vals.append(f"{a:>3}" if a is not None else " . ")
        print(" ".join(row_vals))
    print()


# ============================================================
# 7. R1: SEARCH OVER r IN [-20, 0]
# ============================================================

def find_best_r_for_original_env(
    r_min: int = -20,
    r_max: int = 0,
) -> Tuple[int, Dict[State, float], Dict[State, Optional[Action]]]:
    """
    R1: Search r in [r_min, r_max] and find the FIRST r that yields
    the best P1 according to the policy score.
    """
    best_score = float("-inf")
    best_r: Optional[int] = None
    best_U: Optional[Dict[State, float]] = None
    best_policy: Optional[Dict[State, Optional[Action]]] = None

    for r in range(r_min, r_max + 1):
        mdp = GridWorldMDP(live_reward=float(r))
        U, policy = value_iteration(mdp)
        score = score_policy_P1(mdp, U, policy)

        print(f"r = {r:3d}, policy score = {score:.2f}")

        if score > best_score:
            best_score = score
            best_r = r
            best_U = U
            best_policy = policy

    print("\n=== R1 RESULT ===")
    print(f"Best score across r in [{r_min}, {r_max}]: {best_score:.2f}")
    print(f"First r achieving this best score: r = {best_r}")

    return best_r, best_U, best_policy


# ============================================================
# 8. R2: COMPARE ORIGINAL ENVIRONMENT WITH OBSTACLE AT (4, 3)
# ============================================================

def compare_original_and_obstacle_env(best_r: int) -> None:
    """
    R2: Add an obstacle at row 3, col 4 → (4,3), recompute optimal P2,
    and compare it with P1.
    """
    mdp_original = GridWorldMDP(live_reward=float(best_r))
    U1, policy1 = value_iteration(mdp_original)

    mdp_obstacle = GridWorldMDP(
        live_reward=float(best_r),
        obstacle_states=((4, 3),)   # row 3, column 4
    )
    U2, policy2 = value_iteration(mdp_obstacle)

    same = True
    for s in mdp_original.get_states():
        if policy1.get(s) != policy2.get(s):
            same = False
            print(f"Policy differs at state {s}: P1={policy1.get(s)}, P2={policy2.get(s)}")

    if same:
        print("\nP2 is identical to P1.")
    else:
        print("\nP2 differs from P1 in the states listed above.")

    print("\n--- Original environment policy (P1) ---")
    print_policy_grid(mdp_original, policy1)

    print("\n--- New environment policy (P2) with obstacle at (4, 3) ---")
    print_policy_grid(mdp_obstacle, policy2)

In [None]:
# ============================================================
# 9. RUN R1
# ============================================================

best_r, best_U, best_policy = find_best_r_for_original_env()
mdp_best = GridWorldMDP(live_reward=float(best_r))
print("\nUtility grid for best r:")
print_utility_grid(mdp_best, best_U)
print("Policy grid for best r:")
print_policy_grid(mdp_best, best_policy)

r = -20, policy score = 49562.42
r = -19, policy score = 49620.64
r = -18, policy score = 49678.85
r = -17, policy score = 49737.07
r = -16, policy score = 49795.28
r = -15, policy score = 49853.50
r = -14, policy score = 49911.71
r = -13, policy score = 49969.92
r = -12, policy score = 50028.14
r = -11, policy score = 50086.35
r = -10, policy score = 50144.57
r =  -9, policy score = 50202.78
r =  -8, policy score = 50261.00
r =  -7, policy score = 50319.21
r =  -6, policy score = 50377.42
r =  -5, policy score = 50435.64
r =  -4, policy score = 50493.85
r =  -3, policy score = 50552.07
r =  -2, policy score = 50610.28
r =  -1, policy score = 50668.50
r =   0, policy score = 50726.71

=== R1 RESULT ===
Best score across r in [-20, 0]: 50726.71
First r achieving this best score: r = 0

Utility grid for best r:
-1000.0  1579.4  1603.7  1472.4  1314.3  1171.0  1043.2   936.1
 1867.7  1807.9  1659.0  1479.9  1318.5  1174.2  1045.8   936.3
 1000.0  1870.0  1663.1  1480.9  1318.7  1174.3  10

In [None]:
# ============================================================
# 9. RUN R2
# ============================================================
compare_original_and_obstacle_env(best_r)

Policy differs at state (4, 3): P1=W, P2=None
Policy differs at state (5, 2): P1=NW, P2=W
Policy differs at state (5, 3): P1=W, P2=SW
Policy differs at state (5, 4): P1=SW, P2=W
Policy differs at state (6, 1): P1=NW, P2=W
Policy differs at state (6, 5): P1=SW, P2=W

P2 differs from P1 in the states listed above.

--- Original environment policy (P1) ---
 H   SW  SW  SW  SW  SW   W  SW
  S  SW  SW  SW  SW  SW  SW  SW
 G    W   W   W   W   W   W   W
  N  NW  NW  NW  NW  NW  NW  NW
 H   NW  NW  NW  NW  NW   W  NW


--- New environment policy (P2) with obstacle at (4, 3) ---
 H   SW  SW  SW  SW   W   W  SW
  S  SW  SW  SW   W  SW  SW  SW
 G    W   W  X   SW   W   W   W
  N  NW  NW  NW   W  NW  NW  NW
 H   NW  NW  NW  NW   W   W  NW

