In [None]:
from enum import Enum, auto
import random
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Optional, Set, Tuple

# Percept Definition

In [None]:
class Percept:
    """
    Bundle of everything the agent perceives after a single step.

    Attributes:
        time_step (int): Current time step in the episode.
        bump (bool): True if the agent bumped into a wall on this step.
        breeze (bool): True if the agent is adjacent to a pit.
        stench (bool): True if the agent is adjacent to a live Wumpus.
        scream (bool): True if the agent hears a scream (Wumpus killed).
        glitter (bool): True if there is gold in the agent's current square.
        reward (int): Reward obtained for the last action.
        done (bool): True if the episode has terminated.
    """

    # Type annotations
    time_step: int
    bump: bool
    breeze: bool
    stench: bool
    scream: bool
    glitter: bool
    reward: int
    done: bool

    def __init__(self, time_step: int, bump: bool, breeze: bool, stench: bool,
                 scream: bool, glitter: bool, reward: int, done: bool):
        """Store every percept field exactly as given."""
        # Tuple assignment binds left-to-right, so attribute order is unchanged.
        self.time_step, self.bump, self.breeze, self.stench = (
            time_step, bump, breeze, stench
        )
        self.scream, self.glitter, self.reward, self.done = (
            scream, glitter, reward, done
        )

    def __str__(self) -> str:
        """Return a one-line summary listing only the signals that are active."""
        labelled_flags = (
            (self.bump, "Bump"),
            (self.breeze, "Breeze"),
            (self.stench, "Stench"),
            (self.scream, "Scream"),
            (self.glitter, "Glitter"),
        )
        active = [label for flag, label in labelled_flags if flag]

        # An empty join yields "", which falls through to the "None" placeholder.
        signals = ", ".join(active) or "None"
        return "Percept(t={}, Signals=[{}], Reward={}, Done={})".format(
            self.time_step, signals, self.reward, self.done
        )


In [None]:
# Verify that the Percept class works: build one percept and print its summary.
p = Percept(time_step=1, bump=False, breeze=True, stench=False,
            scream=False, glitter=True, reward=-1, done=False)
print(p)

Percept(t=1, Signals=[Breeze, Glitter], Reward=-1, Done=False)


# Action Definition

In [None]:
class Action(Enum):
    """The six actions the agent can issue, each with a stable integer value."""

    LEFT = 0     # Turn left (rotate 90° counter-clockwise)
    RIGHT = 1    # Turn right (rotate 90° clockwise)
    FORWARD = 2  # Move one cell forward in the current orientation
    GRAB = 3     # Pick up gold if present in the current square
    SHOOT = 4    # Fire the arrow (only once) in the current facing direction
    CLIMB = 5    # Climb out of the cave (only valid from the start cell)

In [None]:
# Each enum member exposes its identifier via .name and its integer via .value.
print(Action.GRAB.name)   # 'GRAB'
print(Action.GRAB.value)  # 3

GRAB
3


# Orientation Definition

In [None]:
class Orientation(Enum):
    """
    Compass heading of the agent.

    Values are assigned in clockwise order (E=0, S=1, W=2, N=3) so that a
    90° turn is simply +1 or -1 modulo 4.
    """
    E = 0
    S = 1
    W = 2
    N = 3

    def letter(self) -> str:
        """Return the single-letter code representing this orientation."""
        # Previously this lived in a first `symbol()` definition that was
        # silently overwritten by the arrow-returning `symbol()` below.
        return self.name  # e.g., Orientation.E -> "E"

    def symbol(self) -> str:
        """Return a visual arrow symbol representing this orientation."""
        arrows = {
            Orientation.E: "→",
            Orientation.S: "↓",
            Orientation.W: "←",
            Orientation.N: "↑",
        }
        return arrows[self]

    def turn_right(self) -> 'Orientation':
        """Return a new orientation turned 90° clockwise."""
        # Clockwise rotation: E → S → W → N → E
        return Orientation((self.value + 1) % 4)

    def turn_left(self) -> 'Orientation':
        """Return a new orientation turned 90° counter-clockwise."""
        # Counterclockwise rotation: E → N → W → S → E
        # (Python's % always yields a non-negative result, so -1 % 4 == 3.)
        return Orientation((self.value - 1) % 4)

In [None]:
# Quick check of the Orientation helpers, starting from East.
o = Orientation.E
print(o.symbol())          # → (arrow glyph for East)
print(o.turn_right())      # Orientation.S
print(o.turn_left())       # Orientation.N
print(o.turn_left().name)  # N


→
Orientation.S
Orientation.N
N


# Location Definition

In [None]:
class Location:
    """
    A mutable (x, y) cell position on the Wumpus World grid.

    Coordinates are 1-based:
        - (1,1) is the bottom-left corner (the agent's starting cell).
        - x grows towards the East (right).
        - y grows towards the North (up).
    """
    x: int
    y: int

    def __init__(self, x: int, y: int):
        """Create a location at column `x`, row `y`."""
        self.x = x
        self.y = y

    def __str__(self):
        """Format as '(x, y)', e.g. '(2, 3)'."""
        return '({}, {})'.format(self.x, self.y)

    # ---------------------------------------------------------------------
    # RELATIVE POSITION CHECKS
    # ---------------------------------------------------------------------
    def is_left_of(self, location: 'Location') -> bool:
        """
        True when this cell sits directly WEST of `location`:
        same row, x exactly one smaller.
        """
        return (self.x, self.y) == (location.x - 1, location.y)

    def is_right_of(self, location: 'Location') -> bool:
        """
        True when this cell sits directly EAST of `location`:
        same row, x exactly one larger.
        """
        return (self.x, self.y) == (location.x + 1, location.y)

    def is_above(self, location: 'Location') -> bool:
        """
        True when this cell sits directly NORTH of `location`:
        same column, y exactly one larger.
        """
        return (self.x, self.y) == (location.x, location.y + 1)

    def is_below(self, location: 'Location') -> bool:
        """
        True when this cell sits directly SOUTH of `location`:
        same column, y exactly one smaller.
        """
        return (self.x, self.y) == (location.x, location.y - 1)

    def neighbours(self, width: int = 4, height: int = 4) -> List['Location']:
        """
        List the cardinal neighbours of this cell that lie inside a
        width×height grid, in the fixed order East, West, North, South.
        """
        inside: List['Location'] = []
        for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            cx, cy = self.x + dx, self.y + dy
            # Drop anything outside 1..width × 1..height
            if 1 <= cx <= width and 1 <= cy <= height:
                inside.append(Location(cx, cy))
        return inside

    # ---------------------------------------------------------------------
    # LOCATION COMPARISON
    # ---------------------------------------------------------------------
    def is_location(self, location: 'Location') -> bool:
        """
        True when both coordinates match those of `location`.

        Example:
            Location(2, 3).is_location(Location(2, 3)) → True
            Location(2, 3).is_location(Location(3, 3)) → False
        """
        return (self.x, self.y) == (location.x, location.y)

    # ---------------------------------------------------------------------
    # EDGE DETECTION (used for wall/boundary logic)
    # ---------------------------------------------------------------------
    def at_left_edge(self) -> bool:
        """True when the cell is in the westernmost column (x == 1)."""
        return 1 == self.x

    def at_right_edge(self, width: int = 4) -> bool:
        """True when the cell is in the easternmost column (x == width)."""
        return width == self.x

    def at_top_edge(self, height: int = 4) -> bool:
        """True when the cell is in the northernmost row (y == height)."""
        return height == self.y

    def at_bottom_edge(self) -> bool:
        """True when the cell is in the southernmost row (y == 1)."""
        return 1 == self.y

    # ---------------------------------------------------------------------
    # MOVE FORWARD OPERATION
    # ---------------------------------------------------------------------
    def forward(self, orientation: 'Orientation',
                width: int = 4, height: int = 4) -> bool:
        """
        Try to advance one cell in the direction `orientation` faces.

        The direction is read from `orientation.name` ("E", "W", "N"; any
        other name is treated as South).

        Returns:
            True  – the target lies outside the width×height grid; the
                    position is left untouched (a "bump").
            False – the move succeeded and x/y were updated in place.
        """
        # Per-direction offsets; South doubles as the fallback.
        offsets = {"E": (1, 0), "W": (-1, 0), "N": (0, 1)}
        dx, dy = offsets.get(orientation.name, (0, -1))
        target_x, target_y = self.x + dx, self.y + dy

        inside = 1 <= target_x <= width and 1 <= target_y <= height
        if inside:
            self.x, self.y = target_x, target_y
        return not inside

    # ---------------------------------------------------------------------
    # SETTER / COPY UTILITIES
    # ---------------------------------------------------------------------
    def set_to(self, location: 'Location'):
        """Overwrite this object's coordinates with those of `location` (in place)."""
        self.x = location.x
        self.y = location.y

    # ---------------------------------------------------------------------
    # LINEAR INDEX CONVERSIONS (row-major from bottom row)
    # ---------------------------------------------------------------------
    @staticmethod
    def from_linear(n: int, width: int = 4, height: int = 4) -> 'Location':
        """
        Map a 0-based linear index (0..width*height-1) to 1-based (x, y).

        Row-major order, bottom row first:
            0 → (1,1), 1 → (2,1), ..., (width-1) → (width,1),
            width → (1,2), ..., (width*height-1) → (width,height)

        Raises:
            ValueError: if `n` is outside 0..width*height-1.
        """
        if n < 0 or n >= width * height:
            raise ValueError(f"Linear index out of bounds (0..{width*height-1}).")

        row, col = divmod(n, width)
        return Location(col + 1, row + 1)

    def to_linear(self, width: int = 4) -> int:
        """
        Map this (x, y) to its 0-based row-major index (inverse of from_linear):
            (1,1) → 0, (2,1) → 1, ..., (width,1) → (width-1),
            (1,2) → width, ...
        """
        row = self.y - 1
        col = self.x - 1
        return row * width + col

    # ---------------------------------------------------------------------
    # RANDOM SAMPLING ON THE GRID
    # ---------------------------------------------------------------------
    @staticmethod
    def random(width: int = 4, height: int = 4) -> 'Location':
        """
        Draw a uniformly random cell on a width×height grid (1-based).
        x is drawn before y, using the global `random` module state.
        """
        from random import randint
        col = randint(1, width)
        row = randint(1, height)
        return Location(col, row)


# Environment Definition

In [None]:
class Environment:
    """
    Simulator for one Wumpus World episode on a width×height grid.

    Usage: create an instance, call `init()` to (re)start an episode and get
    the initial Percept, then call `step(action)` repeatedly until the
    returned percept has `done == True`.
    """
    # State populated by init(); declared here for readability.
    wumpus_location: Location
    wumpus_alive: bool
    agent_location: Location
    agent_orientation: Orientation
    agent_has_arrow: bool
    agent_has_gold: bool
    game_over: bool
    gold_location: Optional[Location]  # becomes None once the gold is grabbed
    pit_locations: List[Location]
    time_step: int
    WIDTH: int
    HEIGHT: int
    allow_climb_without_gold: bool
    pit_prob: float

    # ---------------------------------------------------------------------
    # EPISODE INITIALIZATION
    # ---------------------------------------------------------------------
    def init(self, width: int = 4, height: int = 4,
             pit_prob: float = 0.2, allow_climb_without_gold: bool = True) -> Percept:
        """
        Reset the world and start a new episode.

        World layout (width×height, 1-based):
          - Agent starts at (1,1), facing East, with one arrow and no gold.
          - Place exactly one Wumpus (not at start), alive.
          - Place exactly one Gold (not at start).
          - For each non-start cell, place a Pit with probability `pit_prob`.
            (Overlaps with Wumpus/Gold are allowed; start is always safe.)

        Returns:
            The initial Percept (time_step 0, reward 0, no bump/scream).

        Raises:
            ValueError: if the grid has no cell other than (1,1), because
                the Wumpus and Gold could then never be placed.
        """
        # Fail fast before touching any state; otherwise the placement
        # loops below would spin forever on e.g. a 1×1 grid.
        self._require_non_start_cell(width, height)

        # Store config/state flags
        self.WIDTH = width
        self.HEIGHT = height
        self.pit_prob = pit_prob
        self.allow_climb_without_gold = allow_climb_without_gold

        # Agent state
        self.agent_location = Location(1, 1)
        self.agent_orientation = Orientation.E
        self.agent_has_arrow = True
        self.agent_has_gold = False

        # Episode flags
        self.game_over = False
        self.time_step = 0

        # World objects (order matters for reproducibility under a fixed seed)
        self.make_wumpus()
        self.make_gold()
        self.make_pits(self.pit_prob)

        # Initial percept (no action taken yet → reward=0; bump/scream=False)
        return self._build_percept(bump=False, scream=False, reward=0)

    # ---------------------------------------------------------------------
    # RANDOM PLACEMENT HELPERS
    # ---------------------------------------------------------------------
    @staticmethod
    def _require_non_start_cell(width: int, height: int) -> None:
        """Raise ValueError unless the grid contains a cell besides (1,1)."""
        if width < 1 or height < 1 or width * height < 2:
            raise ValueError(
                f"Grid {width}x{height} has no cell other than the start (1,1); "
                "cannot place the Wumpus and the Gold."
            )

    def _random_non_start_location(self) -> Location:
        """Rejection-sample a uniformly random cell that is not the start (1,1)."""
        self._require_non_start_cell(self.WIDTH, self.HEIGHT)
        start = Location(1, 1)
        while True:
            loc = Location.random(self.WIDTH, self.HEIGHT)
            if not loc.is_location(start):
                return loc

    def make_wumpus(self):
        """
        Choose a random location for the Wumpus (not the start) and set alive=True.
        Overlap with pits/gold is allowed.

        Raises:
            ValueError: if the grid has no non-start cell.
        """
        self.wumpus_location = self._random_non_start_location()
        self.wumpus_alive = True

    def make_gold(self):
        """
        Choose a random location for the Gold (not the start).
        Overlap with pits/Wumpus is allowed.

        Raises:
            ValueError: if the grid has no non-start cell.
        """
        self.gold_location = self._random_non_start_location()

    def make_pits(self, pit_prob: float):
        """
        For every non-start cell, independently place a Pit with probability `pit_prob`.
        """
        pits: List[Location] = []
        start = Location(1, 1)
        for n in range(self.WIDTH * self.HEIGHT):
            cell = Location.from_linear(n, self.WIDTH, self.HEIGHT)
            if cell.is_location(start):
                continue  # start is always safe
            if random.random() < pit_prob:
                pits.append(cell)
        self.pit_locations = pits

    # ---------------------------------------------------------------------
    # LOCATION QUERIES (safe, explicit comparisons w/o relying on __eq__/__hash__)
    # ---------------------------------------------------------------------
    def is_pit_at(self, location: Location) -> bool:
        """Return True if there is a Pit at `location`."""
        return any(p.is_location(location) for p in self.pit_locations)

    def is_pit_adjacent_to_agent(self) -> bool:
        """
        Return True if a Pit is in any cardinally adjacent cell to the agent
        (or same cell—though if same cell, agent is dying/just died).
        """
        here = self.agent_location
        if self.is_pit_at(here):
            return True
        return any(
            self.is_pit_at(n) for n in here.neighbours(self.WIDTH, self.HEIGHT)
        )

    def is_wumpus_adjacent_to_agent(self) -> bool:
        """
        Return True if the (alive) Wumpus is in a cardinal neighbor (or same cell).
        """
        if not self.wumpus_alive:
            return False
        here = self.agent_location
        if self.is_wumpus_at(here):
            return True
        return any(
            self.is_wumpus_at(n) for n in here.neighbours(self.WIDTH, self.HEIGHT)
        )

    def is_agent_at_hazard(self) -> bool:
        """
        Return True if the agent is on a Pit or on the (alive) Wumpus.
        Used immediately after a successful Forward to check death.
        """
        return self.is_pit_at(self.agent_location) or (
            self.is_wumpus_at(self.agent_location) and self.wumpus_alive
        )

    def is_wumpus_at(self, location: Location) -> bool:
        """Return True if the Wumpus is at `location` (alive or dead)."""
        return self.wumpus_location is not None and self.wumpus_location.is_location(location)

    def is_agent_at(self, location: Location) -> bool:
        """Return True if the agent is at `location`."""
        return self.agent_location.is_location(location)

    def is_gold_at(self, location: Location) -> bool:
        """Return True if the Gold is at `location` (i.e., not yet grabbed)."""
        # gold_location is set to None by GRAB, hence the explicit None check.
        return self.gold_location is not None and self.gold_location.is_location(location)

    # ---------------------------------------------------------------------
    # PERCEPT QUERIES (Breeze / Stench / Glitter)
    # ---------------------------------------------------------------------
    def is_glitter(self) -> bool:
        """Return True if the agent is in the same cell as the Gold."""
        return self.is_gold_at(self.agent_location)

    def is_breeze(self) -> bool:
        """
        Return True if a Pit is adjacent (or same cell).
        Note: if agent is in a Pit, they will die that step; including 'same cell'
        here makes the percept logic monotone and easy to read.
        """
        return self.is_pit_adjacent_to_agent()

    def is_stench(self) -> bool:
        """
        Return True if the (alive) Wumpus is adjacent (or same cell).
        If the Wumpus is dead, there is no Stench.
        """
        return self.is_wumpus_adjacent_to_agent()

    def _build_percept(self, bump: bool, scream: bool, reward: int) -> Percept:
        """
        Assemble a Percept from the current world state plus the transient
        signals (`bump`, `scream`) and `reward` of the step just taken.
        """
        return Percept(
            time_step=self.time_step,
            bump=bump,
            breeze=self.is_breeze(),
            stench=self.is_stench(),
            scream=scream,
            glitter=self.is_glitter(),
            reward=reward,
            done=self.game_over
        )

    # ---------------------------------------------------------------------
    # FIRING LINE / COMBAT HELPERS
    # ---------------------------------------------------------------------
    def wumpus_in_line_of_fire(self) -> bool:
        """
        Return True if, from the agent’s current cell and orientation,
        the (alive) Wumpus lies strictly ahead in the same row or column.
        """
        if not (self.wumpus_alive and self.wumpus_location):
            return False

        ax, ay = self.agent_location.x, self.agent_location.y
        wx, wy = self.wumpus_location.x, self.wumpus_location.y

        if self.agent_orientation.name == "E":
            return wy == ay and wx > ax
        if self.agent_orientation.name == "W":
            return wy == ay and wx < ax
        if self.agent_orientation.name == "N":
            return wx == ax and wy > ay
        # SOUTH
        return wx == ax and wy < ay

    def kill_attempt(self) -> bool:
        """
        If the Wumpus is alive and in the line of fire, kill it and return True.
        Otherwise, return False.
        """
        if self.wumpus_alive and self.wumpus_in_line_of_fire():
            self.wumpus_alive = False
            return True
        return False

    # ---------------------------------------------------------------------
    # MAIN TRANSITION FUNCTION
    # ---------------------------------------------------------------------
    def step(self, action: Action) -> Percept:
        """
        Apply an action, update state, and return the resulting Percept.

        Reward components:
            - Base per-step cost:             -1  (always)
            - First SHOOT (arrow available): -10  (consumes arrow)
            - Death (Pit or live Wumpus):  -1000  (terminal)
            - CLIMB at (1,1) with Gold:   +1000  (terminal)
            - CLIMB at (1,1) without Gold:
                * if allow_climb_without_gold=True → terminal with only step cost
                * else ignored (no termination, still pays step cost)

        Raises:
            AssertionError: if called after the episode has finished.
        """
        assert not self.game_over, "Episode already finished. Call init() for a new one."
        self.time_step += 1

        # Transient signals for this step
        bump = False
        scream = False

        # Base step cost
        reward = -1

        # ---------------------------
        # Dispatch on action
        # ---------------------------
        if action == Action.LEFT:
            self.agent_orientation = self.agent_orientation.turn_left()

        elif action == Action.RIGHT:
            self.agent_orientation = self.agent_orientation.turn_right()

        elif action == Action.FORWARD:
            # Location.forward moves in place and returns True if bumped (no move)
            bump = self.agent_location.forward(self.agent_orientation, self.WIDTH, self.HEIGHT)
            # After a successful move, check for fatal hazards
            if not bump and self.is_agent_at_hazard():
                reward += -1000
                self.game_over = True

        elif action == Action.GRAB:
            # Pick up gold if present; None marks it as removed from the grid
            if self.is_glitter():
                self.agent_has_gold = True
                self.gold_location = None

        elif action == Action.SHOOT:
            # Only the first time with an arrow should cost -10 and attempt a kill
            if self.agent_has_arrow:
                self.agent_has_arrow = False
                reward += -10
                if self.kill_attempt():  # sets wumpus_alive=False if hit
                    scream = True
            # If no arrow, no extra penalty/effect

        elif action == Action.CLIMB:
            # Only meaningful at the start cell (1,1)
            if self.agent_location.is_location(Location(1, 1)):
                if self.agent_has_gold:
                    reward += 1000
                    self.game_over = True
                elif self.allow_climb_without_gold:
                    # End episode with just the step cost already applied
                    self.game_over = True
                # else: climbing without gold is ignored

        # ---------------------------
        # Build Percept for this step
        # ---------------------------
        return self._build_percept(bump=bump, scream=scream, reward=reward)

    # ---------------------------------------------------------------------
    # VISUALIZATION OF THE GAME STATE
    # ---------------------------------------------------------------------
    def visualize(self):
        """
        Print a simple text grid showing the current world state.

        Legend:
            A→ A← A↑ A↓ : Agent and its facing direction
            P           : Pit
            W / w       : Wumpus (alive/dead)
            G           : Gold

        Coordinate system:
            (1,1) is bottom-left; printed from the top row down to the bottom.
        """
        for y in range(self.HEIGHT, 0, -1):  # print rows top→bottom
            line = '|'
            for x in range(1, self.WIDTH + 1):  # columns left→right
                loc = Location(x, y)
                cell_symbols = []  # dynamic list for whatever is in this cell

                # Agent (shows letter A plus its facing arrow)
                if self.is_agent_at(loc):
                    cell_symbols.append('A' + self.agent_orientation.symbol())

                # Pit
                if self.is_pit_at(loc):
                    cell_symbols.append('P')

                # Wumpus (alive/dead)
                if self.is_wumpus_at(loc):
                    cell_symbols.append('W' if self.wumpus_alive else 'w')

                # Gold
                if self.is_gold_at(loc):
                    cell_symbols.append('G')

                # If cell empty, leave a few spaces for alignment
                cell_str = ''.join(cell_symbols) if cell_symbols else '   '

                line += f'{cell_str:4}|'  # pad each cell to uniform width
            print(line)


# Probabilistic Model Definition (Pomegranate)

## Helper Functions

In [None]:
!pip install "pomegranate==1.0.4"



In [None]:
from typing import Dict, Tuple, List

import numpy as np
import torch
from torch import masked

from pomegranate.distributions import Categorical, ConditionalCategorical
from pomegranate.bayesian_network import BayesianNetwork


In [None]:
class Predicate:
    """
    Boolean random variable helper, mirroring the one in the course notebook.

    The predicate holds with probability `p` and fails with probability
    `1 - p`. Its distribution is encoded in the order [P(False), P(True)].
    """

    def __init__(self, prob: float):
        """Remember the probability that the predicate is True."""
        self.p = prob

    def to_list(self) -> List[float]:
        """Return the two-entry distribution [P(False), P(True)]."""
        p_true = self.p
        p_false = 1.0 - p_true
        return [p_false, p_true]

    def to_categorical(self) -> Categorical:
        """Wrap the distribution as a single-row pomegranate Categorical."""
        distribution_rows = [self.to_list()]
        return Categorical(distribution_rows)


def all_cells(width: int, height: int) -> List[Tuple[int, int]]:
    """
    Enumerate every (x, y) cell of a 1-based width×height grid.

    Cells are produced column by column: x is the outer index, y the inner.
    """
    cells: List[Tuple[int, int]] = []
    for x in range(1, width + 1):
        for y in range(1, height + 1):
            cells.append((x, y))
    return cells


def adjacent_cells(x: int, y: int, width: int, height: int) -> List[Tuple[int, int]]:
    """
    Return the 4-connected neighbours of (x, y) that fall inside the grid.

    Neighbours are listed in the fixed order West, East, South, North.
    """
    candidates = ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1))
    return [
        (cx, cy)
        for cx, cy in candidates
        if 1 <= cx <= width and 1 <= cy <= height
    ]


def pit_var(x: int, y: int) -> str:
    """Name of the pit variable for cell (x, y), e.g. 'Pit_2_3'."""
    return "Pit_{}_{}".format(x, y)


def breeze_var(x: int, y: int) -> str:
    """Name of the breeze variable for cell (x, y), e.g. 'Breeze_1_1'."""
    return "Breeze_{}_{}".format(x, y)


def stench_var(x: int, y: int) -> str:
    """Name of the stench variable for cell (x, y), e.g. 'Stench_3_4'."""
    return "Stench_{}_{}".format(x, y)


def wumpus_loc_name() -> str:
    """Name of the single variable that encodes where the Wumpus is."""
    variable_name = "WumpusLocation"
    return variable_name


def build_or_cases(num_parents: int) -> List:
    """
    Build the nested probability table of a deterministic OR over
    `num_parents` binary parents (used for breeze-like predicates that are
    True iff any parent pit is True).

    The result is a nested list of shape (2, 2, ..., 2, 2):
      - one axis of size 2 per parent, indexed 0=False / 1=True
      - a final axis of size 2 holding [P(child=False), P(child=True)]
    Callers wrap it as [cases] to obtain the leading (1, ...) axis that
    ConditionalCategorical expects.
    """

    def expand(parents_left: int, seen_true: bool) -> List:
        # Leaf: all parents assigned; the child is True iff some parent was.
        if parents_left == 0:
            return [0.0, 1.0] if seen_true else [1.0, 0.0]
        # Branch on the next parent: index 0 keeps the flag, index 1 sets it.
        return [
            expand(parents_left - 1, seen_true),
            expand(parents_left - 1, True),
        ]

    return expand(num_parents, False)

## Pit / Breeze Model

In [None]:
def build_pit_breeze_model(
    width: int = 4, height: int = 4, pit_prob: float = 0.2
) -> BayesianNetwork:
    """
    Build the pit/breeze Bayesian network, in the style of
    'Working with Pomegranate v1.04.ipynb'.

    Structure:
      - One pit variable per grid cell EXCEPT the start (1,1), which is
        always safe (15 pits on the default 4×4 grid). Pits are independent:
            Pit(x,y) ~ Bernoulli(pit_prob)
      - One breeze variable per grid cell (16 on the default grid).
        Breeze(x,y) is True iff any adjacent pit variable is True, encoded
        as a ConditionalCategorical with OR semantics over its parent pits.

    The returned network carries two extra attributes:
      - variable_indices: variable name → position in the variable list
        (all pits first, then all breezes, both in all_cells() order)
      - num_variables: total number of variables
    """
    start_cell = (1, 1)
    grid = all_cells(width, height)

    # 1) Pit priors: every cell except the start gets its own Categorical.
    pit_cells = [cell for cell in grid if cell != start_cell]
    pit_dists: Dict[Tuple[int, int], Categorical] = {
        cell: Predicate(pit_prob).to_categorical() for cell in pit_cells
    }

    # 2) Breeze CPTs, remembering each breeze's parent pits for step 4.
    breeze_dists = {}
    parents_of: Dict[Tuple[int, int], List[Tuple[int, int]]] = {}
    for (x, y) in grid:
        parents = [
            neighbour
            for neighbour in adjacent_cells(x, y, width, height)
            if neighbour in pit_dists
        ]
        parents_of[(x, y)] = parents

        if parents:
            # Wrap as [cases] to give shape (1, 2, 2, ..., 2)
            or_table = build_or_cases(len(parents))
            breeze_dists[(x, y)] = ConditionalCategorical([or_table])
        else:
            # Degenerate case (e.g., 1x1 world) → no breeze ever.
            breeze_dists[(x, y)] = Predicate(0.0).to_categorical()

    # 3) Variable list and name → index map: pits first, then breezes.
    variables = [pit_dists[cell] for cell in pit_cells]
    variables.extend(breeze_dists[cell] for cell in grid)

    index_by_name: Dict[str, int] = {}
    for position, (x, y) in enumerate(pit_cells):
        index_by_name[pit_var(x, y)] = position
    for position, (x, y) in enumerate(grid, start=len(pit_cells)):
        index_by_name[breeze_var(x, y)] = position

    # 4) Edges Pit -> Breeze, in the same parent order used for the CPTs.
    edges = [
        (pit_dists[parent], breeze_dists[cell])
        for cell in grid
        for parent in parents_of[cell]
    ]

    pits_model = BayesianNetwork(variables, edges)

    # Attach helper metadata for later use in inference
    pits_model.variable_indices = index_by_name
    pits_model.num_variables = len(variables)

    return pits_model

## Wumpus / Stench Model

In [None]:
def build_wumpus_stench_model(width: int = 4, height: int = 4) -> BayesianNetwork:
    """
    Construct the Bayesian network that links the Wumpus location to stenches.

    Structure:
      - One root Categorical, WumpusLocation, with one category per
        non-start cell (every cell except (1,1)) and a uniform prior.
        On the default 4x4 grid that is 15 categories.
      - One ConditionalCategorical Stench(x,y) per grid cell (16 on a 4x4
        grid), each a deterministic child of WumpusLocation: Stench(x,y)
        is True exactly when the Wumpus sits in a 4-neighbour of (x,y).

    Parameters:
      width, height: grid dimensions.

    Returns:
      The BayesianNetwork, with helper attributes attached for inference:
        - variable_indices: variable name -> position in the variable list
        - num_variables: number of variables in the network
        - wumpus_positions: list of (x,y); entry i is category i of
          WumpusLocation
    """
    # Candidate Wumpus cells; the list position doubles as the category index.
    wumpus_positions: List[Tuple[int, int]] = [
        (x, y) for (x, y) in all_cells(width, height) if (x, y) != (1, 1)
    ]
    num_locations = len(wumpus_positions)

    # Root node: uniform prior over all candidate locations.
    wumpus_dist = Categorical([[1.0 / num_locations] * num_locations])

    # WumpusLocation always occupies index 0; stenches follow in grid order.
    variables: List[Categorical | ConditionalCategorical] = [wumpus_dist]
    edges: List[Tuple[Categorical, ConditionalCategorical]] = []
    index_by_name: Dict[str, int] = {wumpus_loc_name(): 0}

    for x, y in all_cells(width, height):
        neighbours = set(adjacent_cells(x, y, width, height))

        # One row per Wumpus location: [P(Stench=False), P(Stench=True)].
        cpt_rows: List[List[float]] = [
            [0.0, 1.0] if position in neighbours else [1.0, 0.0]
            for position in wumpus_positions
        ]

        # Wrapping in a list gives the table shape (1, num_locations, 2).
        stench_dist = ConditionalCategorical([cpt_rows])

        index_by_name[stench_var(x, y)] = len(variables)
        variables.append(stench_dist)
        edges.append((wumpus_dist, stench_dist))

    wumpus_model = BayesianNetwork(variables, edges)

    # Metadata consumed by the inference helpers.
    wumpus_model.variable_indices = index_by_name
    wumpus_model.num_variables = len(variables)
    wumpus_model.wumpus_positions = wumpus_positions  # list of (x,y) in index order

    return wumpus_model

## Inference helpers for Pit/Breeze and Wumpus/Stench Models

In [None]:
def infer_pit_posteriors(
    pit_model: BayesianNetwork,
    width: int,
    height: int,
    breeze_history: Dict[Tuple[int, int], bool],
) -> Dict[Tuple[int, int], float]:
    """
    Compute P(Pit_x_y = True | observed breezes) for every grid cell.

    Parameters:
      pit_model: pit/breeze BayesianNetwork carrying the `variable_indices`
        and `num_variables` helper attributes.
      width, height: grid dimensions.
      breeze_history: maps an observed cell (x, y) to whether a breeze was
        perceived there.

    Returns:
      Dict mapping every cell (x, y) to its posterior pit probability.
      The start cell (1,1) is always reported as 0.0.

    Evidence encoding: a single-row tensor holds -1 for unobserved
    variables and 0/1 for observed breezes; a boolean mask marks the
    observed columns and the pair is passed to predict_proba as a
    MaskedTensor.
    """
    var_index = pit_model.variable_indices
    total_vars = pit_model.num_variables

    evidence = torch.full((1, total_vars), -1, dtype=torch.long)
    observed = torch.zeros((1, total_vars), dtype=torch.bool)

    # Breezes are the only evidence used; pit variables stay unobserved.
    for (x, y), has_breeze in breeze_history.items():
        column = var_index[breeze_var(x, y)]
        evidence[0, column] = 1 if has_breeze else 0
        observed[0, column] = True

    marginals = pit_model.predict_proba(
        masked.MaskedTensor(evidence, mask=observed)
    )

    def posterior_pit(x: int, y: int) -> float:
        # The start cell has no pit variable in the network.
        if (x, y) == (1, 1):
            return 0.0
        # Per-variable marginal is indexed [row, category]; category 1 = True.
        cell_marginal = marginals[var_index[pit_var(x, y)]]
        return float(cell_marginal[0, 1].item())

    return {(x, y): posterior_pit(x, y) for x, y in all_cells(width, height)}


def infer_wumpus_posteriors(
    wumpus_model: BayesianNetwork,
    width: int,
    height: int,
    stench_history: Dict[Tuple[int, int], bool],
    wumpus_dead: bool,
) -> Dict[Tuple[int, int], float]:
    """
    Compute P(Wumpus at (x,y) | observed stenches) for every grid cell.

    Parameters:
      wumpus_model: Wumpus/stench BayesianNetwork carrying the
        `variable_indices`, `num_variables` and `wumpus_positions`
        helper attributes.
      width, height: grid dimensions.
      stench_history: maps an observed cell (x, y) to whether a stench was
        perceived there.
      wumpus_dead: True once the Wumpus is known to be dead.

    Returns:
      Dict mapping each cell to the probability that the Wumpus is there.
      When `wumpus_dead` is True every cell maps to 0.0 and no inference
      is run. The start cell (1,1) is reported as 0.0.
    """
    # A dead Wumpus is no longer a hazard anywhere.
    if wumpus_dead:
        return {(x, y): 0.0 for x, y in all_cells(width, height)}

    var_index = wumpus_model.variable_indices
    total_vars = wumpus_model.num_variables

    # -1 marks "unobserved"; the mask flags the columns that carry evidence.
    evidence = torch.full((1, total_vars), -1, dtype=torch.long)
    observed = torch.zeros((1, total_vars), dtype=torch.bool)

    for (x, y), has_stench in stench_history.items():
        column = var_index[stench_var(x, y)]
        evidence[0, column] = 1 if has_stench else 0
        observed[0, column] = True

    marginals = wumpus_model.predict_proba(
        masked.MaskedTensor(evidence, mask=observed)
    )

    # First (only) row of the WumpusLocation marginal: one entry per location.
    location_posterior = marginals[var_index[wumpus_loc_name()]][0]

    # Category i of WumpusLocation corresponds to wumpus_positions[i].
    positions: List[Tuple[int, int]] = wumpus_model.wumpus_positions
    wumpus_probs: Dict[Tuple[int, int], float] = {
        (x, y): float(location_posterior[loc_idx].item())
        for loc_idx, (x, y) in enumerate(positions)
    }

    # The start cell is not a candidate location, so report it as 0.
    wumpus_probs.setdefault((1, 1), 0.0)

    return wumpus_probs


def compute_safety_from_hazards(
    pit_probs: Dict[Tuple[int, int], float],
    wumpus_probs: Dict[Tuple[int, int], float],
) -> Dict[Tuple[int, int], float]:
    """
    Merge per-cell pit and Wumpus probabilities into a safety score.

    For every cell that appears in either input:

      P(safe) = (1 - P(pit)) * (1 - P(wumpus))

    A cell missing from one of the inputs is treated as having
    probability 0.0 for that hazard.

    Returns:
      Dict mapping each such cell (x, y) to its safety score.
    """
    every_cell = set(pit_probs.keys()) | set(wumpus_probs.keys())
    return {
        cell: (1.0 - pit_probs.get(cell, 0.0)) * (1.0 - wumpus_probs.get(cell, 0.0))
        for cell in every_cell
    }

## PGM Sanity Testing


### Sanity-check the Pit/Breeze model (tiny world)

In [None]:
# Sanity check for the pit/breeze network.
# The model covers the full 4x4 grid, but the checks below only look at
# the start cell (1,1) and its two neighbours (1,2) and (2,1).
pits_model = build_pit_breeze_model(width=4, height=4, pit_prob=0.2)

# Header and name -> index mapping, one per line.
print("Variables in pits_model:", pits_model.variable_indices, sep="\n")


Variables in pits_model:
{'Pit_1_2': 0, 'Pit_1_3': 1, 'Pit_1_4': 2, 'Pit_2_1': 3, 'Pit_2_2': 4, 'Pit_2_3': 5, 'Pit_2_4': 6, 'Pit_3_1': 7, 'Pit_3_2': 8, 'Pit_3_3': 9, 'Pit_3_4': 10, 'Pit_4_1': 11, 'Pit_4_2': 12, 'Pit_4_3': 13, 'Pit_4_4': 14, 'Breeze_1_1': 15, 'Breeze_1_2': 16, 'Breeze_1_3': 17, 'Breeze_1_4': 18, 'Breeze_2_1': 19, 'Breeze_2_2': 20, 'Breeze_2_3': 21, 'Breeze_2_4': 22, 'Breeze_3_1': 23, 'Breeze_3_2': 24, 'Breeze_3_3': 25, 'Breeze_3_4': 26, 'Breeze_4_1': 27, 'Breeze_4_2': 28, 'Breeze_4_3': 29, 'Breeze_4_4': 30}


In [None]:
# Sanity check: posterior pit probability of the two neighbours of (1,1)
# under each possible breeze observation at the start cell.
breeze_history_cases = {
    "Q1_no_breeze": {(1, 1): False},
    "Q2_yes_breeze": {(1, 1): True},
    "Q3_unknown_breeze": {},  # no evidence
}

# Closed-form answers for the pit_prob = 0.2 model built above.
# Breeze(1,1) is the OR of Pit(1,2) and Pit(2,1), so:
#   no breeze   -> both neighbours are pit-free: 0
#   breeze      -> P(pit | at least one of the two is a pit)
#                  = 0.2 / (1 - 0.8**2) = 5/9
#   no evidence -> the prior, 0.2
expected_pit_posterior = {
    "Q1_no_breeze": 0.0,
    "Q2_yes_breeze": 0.2 / (1.0 - 0.8 ** 2),
    "Q3_unknown_breeze": 0.2,
}

for label, bh in breeze_history_cases.items():
    pit_probs = infer_pit_posteriors(pits_model, width=4, height=4, breeze_history=bh)
    print(f"\n{label}:")
    print("P(Pit_1_2 = True):", pit_probs[(1, 2)])
    print("P(Pit_2_1 = True):", pit_probs[(2, 1)])

    # Fail loudly if the model drifts from the analytic answer
    # (tolerance covers float32 rounding in the inference output).
    for checked_cell in ((1, 2), (2, 1)):
        assert abs(pit_probs[checked_cell] - expected_pit_posterior[label]) < 1e-4, (
            label, checked_cell, pit_probs[checked_cell]
        )



Q1_no_breeze:
P(Pit_1_2 = True): 0.0
P(Pit_2_1 = True): 0.0


  X_masked = masked.MaskedTensor(X, mask=mask)



Q2_yes_breeze:
P(Pit_1_2 = True): 0.5555555820465088
P(Pit_2_1 = True): 0.5555555820465088

Q3_unknown_breeze:
P(Pit_1_2 = True): 0.20000000298023224
P(Pit_2_1 = True): 0.20000000298023224


### Sanity-check the Wumpus/Stench model

In [None]:
# Build the Wumpus/stench network for the 4x4 grid and show its layout:
# the variable name -> index mapping and the category order of WumpusLocation.
wumpus_model = build_wumpus_stench_model(width=4, height=4)

print("Variables in wumpus_model:", wumpus_model.variable_indices, sep="\n")
print("Wumpus positions:", wumpus_model.wumpus_positions)

Variables in wumpus_model:
{'WumpusLocation': 0, 'Stench_1_1': 1, 'Stench_1_2': 2, 'Stench_1_3': 3, 'Stench_1_4': 4, 'Stench_2_1': 5, 'Stench_2_2': 6, 'Stench_2_3': 7, 'Stench_2_4': 8, 'Stench_3_1': 9, 'Stench_3_2': 10, 'Stench_3_3': 11, 'Stench_3_4': 12, 'Stench_4_1': 13, 'Stench_4_2': 14, 'Stench_4_3': 15, 'Stench_4_4': 16}
Wumpus positions: [(1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (2, 3), (2, 4), (3, 1), (3, 2), (3, 3), (3, 4), (4, 1), (4, 2), (4, 3), (4, 4)]


In [None]:
# No stenches observed -> the posterior must equal the uniform prior.
stench_history = {}
w_probs = infer_wumpus_posteriors(
    wumpus_model,
    width=4,
    height=4,
    stench_history=stench_history,
    wumpus_dead=False,
)

# Excluding (1,1), all others should be ~1/15
print("Sum of P(Wumpus):", sum(w_probs.values()))
print("Example cells:", {pos: w_probs[pos] for pos in list(w_probs.keys())[:5]})

# Turn the visual check into hard checks so a regression cannot go unnoticed
# (tolerance covers float32 rounding in the inference output).
assert abs(sum(w_probs.values()) - 1.0) < 1e-4, "posterior is not normalised"
assert w_probs[(1, 1)] == 0.0, "start cell must never hold the Wumpus"
assert all(
    abs(p - 1.0 / 15) < 1e-4 for pos, p in w_probs.items() if pos != (1, 1)
), "posterior without evidence should be uniform over the 15 non-start cells"


Sum of P(Wumpus): 0.9999999403953552
Example cells: {(1, 2): 0.06666666269302368, (1, 3): 0.06666666269302368, (1, 4): 0.06666666269302368, (2, 1): 0.06666666269302368, (2, 2): 0.06666666269302368}


  X_masked = masked.MaskedTensor(X, mask=mask)


In [None]:
# One stench at (2,1) -> only its neighbours can hold the Wumpus.
# (2,1) borders (1,1), (3,1) and (2,2); the start cell is never a candidate,
# so the mass must split evenly between (2,2) and (3,1).
stench_history = {(2, 1): True}
w_probs = infer_wumpus_posteriors(
    wumpus_model,
    width=4,
    height=4,
    stench_history=stench_history,
    wumpus_dead=False,
)

print("Cells with non-negligible probability:")
for (x, y), p in w_probs.items():
    if p > 1e-6:
        print((x, y), ":", p)

# Hard checks backing the printed output.
likely_cells = {pos: p for pos, p in w_probs.items() if p > 1e-6}
assert set(likely_cells) == {(2, 2), (3, 1)}, likely_cells
assert all(abs(p - 0.5) < 1e-4 for p in likely_cells.values()), likely_cells


Cells with non-negligible probability:
(2, 2) : 0.5
(3, 1) : 0.5


  X_masked = masked.MaskedTensor(X, mask=mask)


# Agent definition

## Planner Definition

In [None]:
# Planner search state: position plus heading. The heading index d selects
# an entry of DX/DY below (0:E, 1:S, 2:W, 3:N).
State = Tuple[int, int, int]  # (x, y, d) with x,y 1-based; d in {0..3}

# Per-heading unit step: a Forward move with heading d changes the position
# by (DX[d], DY[d]). North has DY = +1, so y grows toward the north.
DX = [1, 0, -1, 0]  # E, S, W, N
DY = [0, -1, 0, 1]  # E, S, W, N

# Helper functions for turning
def dir_left(d: int) -> int:
    """Heading index after a 90° counter-clockwise turn (E → N → W → S → E)."""
    # Adding 3 modulo 4 is the same as stepping back by one heading.
    return (d + 3) % 4

def dir_right(d: int) -> int:
    """Heading index after a 90° clockwise turn (E → S → W → N → E)."""
    next_heading = d + 1
    return next_heading % 4

def bfs_shortest_actions(
    start: State,
    goal_cell: Tuple[int, int],
    safe_cells: Set[Tuple[int, int]],
    width: int,
    height: int,
) -> Optional[List[str]]:
    """
    Breadth-first search for a shortest action sequence to `goal_cell`.

    The search runs over (x, y, heading) states, and every action costs one
    step. Available actions are "TurnLeft", "TurnRight" and "Forward".
    Forward is allowed only when the destination is inside the grid and
    listed in `safe_cells`.

    Parameters:
      start: (x, y, d) pose to start from.
      goal_cell: (x, y) target cell; any heading on arrival is accepted.
      safe_cells: cells the agent may step into.
      width, height: grid dimensions (cells are 1-based).

    Returns:
      - [] if `start` is already on `goal_cell`,
      - the list of action labels leading to `goal_cell`, or
      - None when no route through `safe_cells` exists.
    """
    sx, sy, sd = start
    if (sx, sy) == goal_cell:
        return []

    origin: State = (sx, sy, sd)
    parent: Dict[State, Tuple[State, str]] = {}
    seen: Set[State] = {origin}
    queue: Deque[State] = deque([origin])

    while queue:
        state = queue.popleft()
        x, y, d = state

        # Turns keep the agent on its cell. No goal test is needed here:
        # the start is not the goal and a Forward move into the goal returns
        # immediately, so no dequeued state ever sits on the goal cell.
        for label, heading in (
            ("TurnLeft", dir_left(d)),
            ("TurnRight", dir_right(d)),
        ):
            turned = (x, y, heading)
            if turned not in seen:
                seen.add(turned)
                parent[turned] = (state, label)
                queue.append(turned)

        # Forward: only into an in-bounds cell that is known to be safe.
        fx, fy = x + DX[d], y + DY[d]
        if 1 <= fx <= width and 1 <= fy <= height and (fx, fy) in safe_cells:
            moved = (fx, fy, d)
            if moved not in seen:
                seen.add(moved)
                parent[moved] = (state, "Forward")
                if (fx, fy) == goal_cell:
                    return _reconstruct_actions(parent, moved)
                queue.append(moved)

    return None  # no route through known-safe cells

def _reconstruct_actions(
    parent: Dict[State, Tuple[State, str]],
    goal: State
) -> List[str]:
    """
    Rebuild the action sequence that leads to `goal`.

    `parent` maps a state to (previous state, action taken to get here).
    The chain is followed backwards from `goal` until a state with no
    recorded parent is reached, then returned in forward order.
    """
    backwards: List[str] = []
    node: Optional[State] = goal
    while node in parent:
        node, label = parent[node]
        backwards.append(label)
    return backwards[::-1]

## NaiveAgent Definition

In [None]:
class NaiveAgent:
    """
    Baseline agent that picks every action uniformly at random.

    It drives an Environment through one episode and works with any grid
    size accepted by the environment's parameterized `init`.
    """

    def __init__(self, width: int = 4, height: int = 4,
                 pit_prob: float = 0.2, allow_climb_without_gold: bool = True,
                 seed: int = None):
        """
        Store the environment parameters and optionally seed the RNG.

        Parameters
        ----------
        width : int
            Width of the grid (default: 4)
        height : int
            Height of the grid (default: 4)
        pit_prob : float
            Probability that a non-start cell contains a pit (default: 0.2)
        allow_climb_without_gold : bool
            Whether climbing without gold ends the episode (default: True)
        seed : int, optional
            When given, seeds Python's module-level `random` generator
            (default: None, i.e. leave the generator untouched)
        """
        self.width, self.height = width, height
        self.pit_prob = pit_prob
        self.allow_climb_without_gold = allow_climb_without_gold

        # Note: this seeds the shared global generator, not a private one.
        if seed is not None:
            random.seed(seed)

    def choose_action(self):
        """Pick one member of the Action enum uniformly at random."""
        return random.choice(tuple(Action))

    def run(self):
        """Play one episode with random actions, printing every step."""
        env = Environment()
        total_reward = 0

        # Start the episode with this agent's grid and hazard settings.
        percept = env.init(
            width=self.width,
            height=self.height,
            pit_prob=self.pit_prob,
            allow_climb_without_gold=self.allow_climb_without_gold
        )

        while True:
            # Board and percept are shown for every state, the final one included.
            env.visualize()
            print('Percept:', percept)
            if percept.done:
                break

            action = self.choose_action()
            print('\nAction:', action, '\n')
            percept = env.step(action)
            total_reward += percept.reward

        print('Cumulative reward:', total_reward)


In [None]:
# Draw six random actions from a default NaiveAgent, one per line.
agent = NaiveAgent()
print(*(agent.choose_action() for _ in range(6)), sep="\n")


Action.RIGHT
Action.LEFT
Action.FORWARD
Action.CLIMB
Action.FORWARD
Action.SHOOT


## MovePlanningAgent Definition

In [None]:
@dataclass
class MovePlanningAgent:
    """
    Agent that explores at random but plans its way home once it has gold.

    Behaviour per step:
      1. On glitter (and no gold yet): Grab, then plan a shortest route to
         (1,1) through cells already visited.
      2. Holding gold on (1,1) with nothing left to execute: Climb.
      3. While a plan exists: execute its next action, climbing out when
         the plan ends on (1,1) with gold.
      4. Otherwise: take a random Forward/Left/Right/Shoot action.

    The agent tracks its own pose (x, y, d) from the actions it issues and
    the bump percept, and records every cell it has stood on alive in
    `visited_safe`.
    """
    width: int = 4
    height: int = 4
    allow_climb_without_gold: bool = True
    pit_prob: float = 0.2

    # runtime state
    x: int = 1
    y: int = 1
    d: int = 0  # 0:E, 1:S, 2:W, 3:N
    has_gold: bool = False
    visited_safe: Set[Tuple[int, int]] = field(default_factory=lambda: {(1, 1)})
    plan: Deque[str] = field(default_factory=deque)
    rng: random.Random = field(default_factory=random.Random)
    cumulative_reward: int = 0

    # environment is injected at run()
    env: object = None

    def run(self, Environment, Action):
        """
        Play one episode and return the cumulative reward.

        Parameters:
          Environment: class to instantiate. The instance must provide
            .init(width, height, pit_prob, allow_climb_without_gold),
            .step(action) and .visualize().
          Action: action enum exposing FORWARD, LEFT, RIGHT, GRAB, SHOOT
            and CLIMB members.
        """
        # initialize episode
        self.env = Environment()

        percept = self.env.init(
            width=self.width,
            height=self.height,
            pit_prob=self.pit_prob,
            allow_climb_without_gold=self.allow_climb_without_gold,
        )

        # Reset all per-episode state so the agent object can be reused.
        self.x, self.y, self.d = 1, 1, 0
        self.has_gold = False
        self.visited_safe = {(1, 1)}
        self.plan.clear()
        self.cumulative_reward = 0

        while not percept.done:
            # 1) Show board and current percept (same as NaiveAgent)
            self.env.visualize()
            print('Percept:', percept)

            # 2) Deterministic reaction to glitter
            if percept.glitter and not self.has_gold:
                print('\nAction:', Action.GRAB, '\n')
                percept = self.env.step(Action.GRAB)
                self.cumulative_reward += percept.reward
                self.has_gold = True
                if not percept.done:
                    self.visited_safe.add((self.x, self.y))
                if (self.x, self.y) == (1, 1):
                    # Already on the exit: there is nothing to plan.
                    self.plan = deque()
                else:
                    # plan shortest safe path to start
                    self.plan = deque(bfs_shortest_actions(
                        (self.x, self.y, self.d), (1, 1), self.visited_safe, self.width, self.height
                    ) or [])
                continue

            # 2b) Holding the gold on the exit with no plan left: climb out.
            #     Without this, a grab made on (1,1) (empty plan) would drop
            #     the agent back into random exploration while carrying gold.
            if self.has_gold and not self.plan and (self.x, self.y) == (1, 1):
                print('\nAction:', Action.CLIMB, '\n')
                percept = self._act_and_update(Action.CLIMB)
                continue

            # 3) If executing a plan, take the next planned action
            if self.plan:
                action = self._action_from_label(self.plan.popleft(), Action)
                print('\nAction:', action, '\n')
                # _act_and_update() will call env.step() and update pose/safe set
                percept = self._act_and_update(action)

                # If plan finished at start with gold, climb out
                if not self.plan and self.has_gold and (self.x, self.y) == (1, 1) and not percept.done:
                    print('\nAction:', Action.CLIMB, '\n')
                    percept = self._act_and_update(Action.CLIMB)
                continue

            # 4) Otherwise: explore (no random Grab/Climb)
            action = self.rng.choice([Action.FORWARD, Action.LEFT, Action.RIGHT, Action.SHOOT])
            print('\nAction:', action, '\n')
            percept = self._act_and_update(action)

        # Final board. Rendering is best-effort: a failure to draw the
        # terminal state must not lose the episode result.
        try:
            self.env.visualize()
        except Exception:
            pass

        print("Percept:", percept)
        print("Cumulative reward:", self.cumulative_reward)
        return self.cumulative_reward

    # ---- helpers ----
    def _act_and_update(self, action):
        """
        Send `action` to the environment and update the agent's bookkeeping.

        Adds the step reward to `cumulative_reward`, updates heading on
        LEFT/RIGHT, advances the position on FORWARD unless the percept
        reports a bump, and records the current cell as visited when the
        episode is still running. Returns the percept from the environment.
        """
        p = self.env.step(action)
        self.cumulative_reward += p.reward

        # Actions are matched by name so any enum-like object with a
        # `.name` works; plain values fall back to their string form.
        name = getattr(action, "name", str(action))
        if name == "LEFT":
            self.d = dir_left(self.d)
        elif name == "RIGHT":
            self.d = dir_right(self.d)
        elif name == "FORWARD":
            # Only advance on no-bump
            if not p.bump:
                self.x += DX[self.d]
                self.y += DY[self.d]

        # A cell counts as visited-safe only if the agent is still alive on it.
        if not p.done:
            self.visited_safe.add((self.x, self.y))
        return p

    @staticmethod
    def _action_from_label(label: str, Action):
        """Translate a planner label ("Forward"/"TurnLeft"/"TurnRight") into an Action member."""
        return {
            "Forward": Action.FORWARD,
            "TurnLeft": Action.LEFT,
            "TurnRight": Action.RIGHT,
        }[label]

## ProbAgent Definition

In [None]:
# Heading -> (dx, dy) unit step, keyed by the integer value of each
# Orientation member. ProbAgent.run indexes this mapping with the agent's
# heading index self.d to find the cell directly ahead.
# Orientation indices: 0 = E, 1 = S, 2 = W, 3 = N
# NOTE(review): the 0..3 values are taken from the comment above; the
# Orientation enum is defined elsewhere — confirm they match DX/DY.
# Assumption: x increases to the right (east), y increases upward (north).
ORIENTATION_DELTAS = {
    Orientation.E.value: (1, 0),   # east
    Orientation.S.value: (0, -1),  # south
    Orientation.W.value: (-1, 0),  # west
    Orientation.N.value: (0, 1),   # north
}

In [None]:
class ProbAgent(MovePlanningAgent):
    """
    Probabilistic Wumpus World agent.

    Extends MovePlanningAgent by:
      - Maintaining a belief state over pits and the Wumpus using
        Bayesian networks (pomegranate v1.0.4).
      - Using percept history (breezes, stenches, scream) to update
        hazard probabilities.
      - Choosing actions based on safest cells instead of random exploration.
    """

    def __init__(
        self,
        width: int = 4,
        height: int = 4,
        pit_prob: float = 0.2,
        allow_climb_without_gold: bool = True,
    ):
        """
        Set up the agent, its two Bayesian networks and its prior beliefs.

        Parameters:
          width, height: grid dimensions.
          pit_prob: prior pit probability used for the pit/breeze model.
          allow_climb_without_gold: forwarded to the environment at run().
        """
        # The dataclass-generated base initializer sets the shared fields
        # (pose, plan, visited cells, reward counter, ...).
        super().__init__(
            width=width,
            height=height,
            allow_climb_without_gold=allow_climb_without_gold,
            pit_prob=pit_prob,
        )

        # Grid and prior parameters, kept on the instance for convenience.
        self.width, self.height, self.pit_prob = width, height, pit_prob

        # The networks depend only on grid size and prior, so build them once.
        self.pit_model = build_pit_breeze_model(width, height, pit_prob)
        self.wumpus_model = build_wumpus_stench_model(width, height)

        # Evidence gathered so far: cell -> percept observed on that cell.
        self.breeze_history: Dict[Tuple[int, int], bool] = {}
        self.stench_history: Dict[Tuple[int, int], bool] = {}

        # Wumpus status flags.
        self.heard_scream: bool = False
        self.wumpus_dead: bool = False

        # Inventory: the agent starts with its arrow.
        self.has_arrow: bool = True

        # Evaluation statistic: did the agent shoot during this episode?
        self.arrow_used_this_episode: bool = False

        # Current beliefs; update_beliefs() replaces them after each percept.
        self.pit_probs: Dict[Tuple[int, int], float] = {}
        self.wumpus_probs: Dict[Tuple[int, int], float] = {}
        self.safety: Dict[Tuple[int, int], float] = {}

        # Until evidence arrives, beliefs equal the priors.
        self._initialize_prior_beliefs()

    # ------------------------------------------------------------------
    # Belief initialization and update
    # ------------------------------------------------------------------

    def _initialize_prior_beliefs(self) -> None:
        """
        Fill pit_probs, wumpus_probs and safety with their no-evidence values.

        Priors:
          - pits:   0 at (1,1), pit_prob on every other cell
          - Wumpus: 0 at (1,1), uniform over the remaining cells
        The existing pit/Wumpus dicts are updated in place; safety is
        recomputed from them.
        """
        start = (1, 1)
        # Number of cells that may hold the Wumpus (everything but the start).
        num_valid = self.width * self.height - 1

        for x, y in all_cells(self.width, self.height):
            at_start = (x, y) == start
            self.pit_probs[(x, y)] = 0.0 if at_start else self.pit_prob
            # The division is evaluated only for non-start cells.
            self.wumpus_probs[(x, y)] = 0.0 if at_start else 1.0 / num_valid

        # Combine into initial safety scores
        self.safety = compute_safety_from_hazards(self.pit_probs, self.wumpus_probs)

    def update_beliefs(self, percept: Percept) -> None:
        """
        Fold the latest percept into the evidence and recompute all beliefs.

        The breeze and stench flags are stored against the agent's tracked
        position (self.x, self.y); a scream marks the Wumpus as dead. Pit
        and Wumpus posteriors are then re-inferred from the full evidence
        history and combined into per-cell safety scores.

        Intended to be called after every environment step.
        """
        here = (self.x, self.y)

        # Store what was sensed on the current cell (overwrites older readings).
        self.breeze_history[here] = bool(percept.breeze)
        self.stench_history[here] = bool(percept.stench)

        # A scream means the arrow hit: the Wumpus is dead from now on.
        if percept.scream:
            self.heard_scream = True
            self.wumpus_dead = True

        # Posteriors are recomputed from scratch on the accumulated evidence.
        self.pit_probs = infer_pit_posteriors(
            self.pit_model, self.width, self.height, self.breeze_history
        )
        self.wumpus_probs = infer_wumpus_posteriors(
            self.wumpus_model,
            self.width,
            self.height,
            self.stench_history,
            self.wumpus_dead,
        )

        # Per-cell safety = (1 - P(pit)) * (1 - P(wumpus)).
        self.safety = compute_safety_from_hazards(self.pit_probs, self.wumpus_probs)

    # ------------------------------------------------------------------
    # Frontier selection and basic risk-based target choice
    # ------------------------------------------------------------------

    def _frontier_cells(self) -> list[Tuple[int, int]]:
        """
        Frontier = all cells in the grid that:
          - have NOT yet been visited (i.e., not in self.visited_safe), and
          - are 4-adjacent to at least one visited cell.

        Only these cells can be the next exploration target: the planner
        moves through visited cells plus the target itself, so an unvisited
        cell with no visited neighbour is unreachable. Leaving such cells in
        the frontier would also let their near-prior risk hide the fact that
        every reachable option is too dangerous.

        Returns:
          List of (x, y) frontier cells in all_cells() order.
        """
        visited = self.visited_safe
        return [
            (x, y)
            for x, y in all_cells(self.width, self.height)
            if (x, y) not in visited
            and any(
                neighbour in visited
                for neighbour in adjacent_cells(x, y, self.width, self.height)
            )
        ]

    @staticmethod
    def _manhattan_distance(a: Tuple[int, int], b: Tuple[int, int]) -> int:
        """Return the L1 (taxicab) distance between grid cells `a` and `b`."""
        ax, ay = a
        bx, by = b
        horizontal = abs(ax - bx)
        vertical = abs(ay - by)
        return horizontal + vertical

    def _select_best_frontier_cell(
        self,
        abort_risk_threshold: float = 0.5,
    ) -> Tuple[Tuple[int, int] | None, float | None]:
        """
        Pick the next frontier cell to explore from the current safety beliefs.

        Each frontier cell c is scored with risk(c) = 1 - safety[c]; a cell
        without a safety entry counts as safety 0.0. The cell with the
        lowest risk wins, and ties (within 1e-6) go to the cell closest to
        the agent in Manhattan distance, earliest candidate first.

        Returns:
          - (None, None)      when the frontier is empty,
          - (None, min_risk)  when even the lowest risk exceeds
                              `abort_risk_threshold` (signal to go home),
          - (cell, min_risk)  otherwise.
        """
        candidates = self._frontier_cells()
        if not candidates:
            return None, None  # nothing left to explore

        scored: list[Tuple[Tuple[int, int], float]] = [
            (cell, 1.0 - self.safety.get(cell, 0.0)) for cell in candidates
        ]
        min_risk = min(risk for _, risk in scored)

        # Even the safest option is too dangerous: tell the caller to bail.
        if min_risk > abort_risk_threshold:
            return None, min_risk

        # Cells whose risk matches the minimum up to a small tolerance.
        eps = 1e-6
        tied = [cell for cell, risk in scored if not abs(risk - min_risk) > eps]

        # min() keeps the first of equally distant cells, like a strict "<" scan.
        here = (self.x, self.y)
        nearest = min(tied, key=lambda cell: self._manhattan_distance(here, cell))

        return nearest, min_risk

    # ------------------------------------------------------------------
    # Main control loop
    # ------------------------------------------------------------------

    def run(self, Environment, Action):
        """
        Same overall structure as MovePlanningAgent.run, but:
          - Maintains and updates Bayesian beliefs after every action.
          - Uses risk-based planning instead of random exploration.
          - May decide to bail out and climb even without gold if all
            frontier cells are too risky.

        If self.quiet is True, board visualization and print statements
        are suppressed (useful for running many episodes).

        At the end of the episode, self.outcome is one of:
          - "success" : grabbed gold and climbed out at (1,1)
          - "death"   : died in a pit or by the Wumpus
          - "bail"    : climbed out at (1,1) without gold
        """
        # 1) Initialize episode and environment
        self.env = Environment()

        percept = self.env.init(
            width=self.width,
            height=self.height,
            pit_prob=self.pit_prob,
            allow_climb_without_gold=self.allow_climb_without_gold,
        )

        # Reset agent state (same as MovePlanningAgent)
        self.x, self.y, self.d = 1, 1, 0
        self.has_gold = False
        self.has_arrow = True
        self.visited_safe = {(1, 1)}
        self.plan.clear()
        self.cumulative_reward = 0
        self.outcome = "unknown"

        # Reset per-episode arrow tracking
        self.arrow_used_this_episode = False

        # Reset belief-related state
        self.breeze_history.clear()
        self.stench_history.clear()
        self.heard_scream = False
        self.wumpus_dead = False
        self._initialize_prior_beliefs()

        # Initial beliefs from first percept
        self.update_beliefs(percept)

        # 2) Main perception–action loop
        while not percept.done:
            # Show board and current percept (unless quiet)
            if not getattr(self, "quiet", False):
                self.env.visualize()
                print("Percept:", percept)

            # --- Step 1: deterministic reaction to glitter ---
            if percept.glitter and not self.has_gold:
                if not getattr(self, "quiet", False):
                    print("\nAction:", Action.GRAB, "\n")

                percept = self.env.step(Action.GRAB)
                self.cumulative_reward += percept.reward
                self.has_gold = True

                if not percept.done:
                    self.visited_safe.add((self.x, self.y))

                # Update beliefs at current location
                self.update_beliefs(percept)

                # Plan shortest safe path to start using visited_safe
                self.plan = deque(
                    bfs_shortest_actions(
                        (self.x, self.y, self.d),
                        (1, 1),
                        self.visited_safe,
                        self.width,
                        self.height,
                    )
                    or []
                )
                continue

            # --- Step 2: follow an existing plan, if any ---
            if self.plan:
                action = self._action_from_label(self.plan.popleft(), Action)
                if not getattr(self, "quiet", False):
                    print("\nAction:", action, "\n")

                # _act_and_update is inherited; it updates position,
                # cumulative_reward, and visited_safe as in MovePlanningAgent
                percept = self._act_and_update(action)
                self.update_beliefs(percept)

                # If plan finished at start, climb out (with or without gold).
                if (
                    not self.plan
                    and (self.x, self.y) == (1, 1)
                    and not percept.done
                ):
                    if not getattr(self, "quiet", False):
                        print("\nAction:", Action.CLIMB, "\n")

                    percept = self._act_and_update(Action.CLIMB)
                    self.update_beliefs(percept)
                continue

            # --- Step 3A: Shooting heuristic (before deciding next move) ---
            # If the Wumpus is not dead and we still have the arrow,
            # check whether the cell directly ahead has high probability of containing the Wumpus.

            if not self.wumpus_dead and self.has_arrow:

                # Compute the coordinates of the cell in front
                dx, dy = ORIENTATION_DELTAS[self.d]
                fx, fy = self.x + dx, self.y + dy

                # Check if forward cell is inside the grid
                if 1 <= fx <= self.width and 1 <= fy <= self.height:

                    # Probability Wumpus is there
                    p_wumpus_forward = self.wumpus_probs.get((fx, fy), 0.0)

                    # Threshold: shoot only if forward cell has high Wumpus probability
                    if p_wumpus_forward >= 0.45:  # you can tune this
                        if not getattr(self, "quiet", False):
                            print(
                                "\nAction:",
                                Action.SHOOT,
                                f"(p_wumpus={p_wumpus_forward:.2f})\n",
                            )

                        percept = self._act_and_update(Action.SHOOT)
                        self.has_arrow = False
                        self.arrow_used_this_episode = True
                        self.update_beliefs(percept)

                        # After shooting, skip movement decision and continue main loop
                        continue

            # --- Step 3: no plan, no glitter -> decide what to do next ---

            # Choose safest frontier cell, or decide to bail out
            target_cell, min_risk = self._select_best_frontier_cell(
                abort_risk_threshold=0.5  # can tune as hyperparameter
            )

            if target_cell is None:
                # Either no frontier or all options too risky => bail out
                if (self.x, self.y) == (1, 1):
                    # Already at start: climb and end episode
                    if not getattr(self, "quiet", False):
                        print("\nAction:", Action.CLIMB, "\n")

                    percept = self._act_and_update(Action.CLIMB)
                    self.update_beliefs(percept)
                    continue

                # Otherwise, plan a safe path back to (1,1) using visited_safe
                self.plan = deque(
                    bfs_shortest_actions(
                        (self.x, self.y, self.d),
                        (1, 1),
                        self.visited_safe,
                        self.width,
                        self.height,
                    )
                    or []
                )
                # Next loop iteration will execute the plan
                continue

            # We have a target frontier cell (x, y)
            tx, ty = target_cell
            safe_for_planning = set(self.visited_safe)
            safe_for_planning.add((tx, ty))

            plan_labels = bfs_shortest_actions(
                (self.x, self.y, self.d),
                (tx, ty),
                safe_for_planning,
                self.width,
                self.height,
            )

            if not plan_labels:
                # BFS failed (e.g., target not reachable via our current safe set).
                # As a conservative fallback, take a random local action like the
                # MovePlanningAgent.
                action = self.rng.choice(
                    [Action.FORWARD, Action.LEFT, Action.RIGHT, Action.SHOOT]
                )
                if not getattr(self, "quiet", False):
                    print("\nAction:", action, "\n")

                # If we randomly choose to shoot and still have the arrow,
                # mark it as used and clear has_arrow.
                if action == Action.SHOOT and self.has_arrow:
                    self.has_arrow = False
                    self.arrow_used_this_episode = True

                percept = self._act_and_update(action)
                self.update_beliefs(percept)
                continue

            # Store the planned route; next loop iteration will follow it.
            self.plan = deque(plan_labels)
            # loop continues; plan branch will run next

        # 3) Classify episode outcome for statistics
        if (self.x, self.y) == (1, 1):
            if self.has_gold:
                self.outcome = "success"   # grabbed gold and escaped
            else:
                self.outcome = "bail"      # climbed out without gold
        else:
            self.outcome = "death"         # died in pit or by Wumpus

        # Final board & summary
        if not getattr(self, "quiet", False):
            try:
                self.env.visualize()
            except Exception:
                pass

            print("Percept:", percept)
            print("Cumulative reward:", self.cumulative_reward)

        return self.cumulative_reward


# Visualization Of The Game State

## Sanity tests

In [None]:
# Build one world with the default pit density and draw its starting board.
env = Environment()

# Pit-free variant kept for reference (safe world, no pits):
# p0 = env.init(width=4, height=4, pit_prob=0.0, allow_climb_without_gold=True)

# p1 keeps whatever env.init returns (presumably the initial percept — confirm).
p1 = env.init(
    width=4,
    height=4,
    pit_prob=0.2,
    allow_climb_without_gold=True,
)
env.visualize()


|    |    |    |P   |
|W   |    |    |    |
|    |G   |    |    |
|A→  |    |    |P   |


In [None]:
# Path with pits disabled (should sometimes succeed quickly); left commented out:

# move_planning_agent_happy = MovePlanningAgent(pit_prob=0.0)
# move_planning_agent_happy.run(Environment, Action)

In [None]:
# Planner sanity (inline harness):
# With a straight corridor of safe cells and the agent already facing along
# it (heading 0), the shortest plan is three consecutive 'Forward' actions.
safe_line = {(1, 1), (2, 1), (3, 1), (4, 1)}
plan = bfs_shortest_actions(start=(1, 1, 0), goal_cell=(4, 1),
                            safe_cells=safe_line, width=4, height=4)
print(plan)  # expect ['Forward','Forward','Forward']

# Fail loudly instead of relying on someone eyeballing the printed plan.
assert plan == ['Forward', 'Forward', 'Forward'], f"unexpected plan: {plan}"

['Forward', 'Forward', 'Forward']


In [None]:
# Turn cost check
# Starting at (1,1) with heading 1, the goal (2,1) is one left turn away, so
# the shortest plan has exactly two actions. A plan made only of right turns
# would be longer, and a shortest-action search must not return it.
plan = bfs_shortest_actions(start=(1, 1, 1), goal_cell=(2, 1),
                            safe_cells={(1, 1), (2, 1)}, width=4, height=4)
print(plan)  # expect ['TurnLeft', 'Forward']

# Pin the result so a regression in turn handling is caught on Run All.
assert plan == ['TurnLeft', 'Forward'], f"unexpected plan: {plan}"


['TurnLeft', 'Forward']


## Playing Game

In [None]:
# Single-episode demo. Only one agent is active at a time; the other two
# variants are kept below, commented out, so they can be swapped in.

# --- Naive agent (disabled) ---
# naive_agent = NaiveAgent(width=4, height=4, pit_prob=0.2, allow_climb_without_gold=False)
# naive_agent.run()

# --- Move-planning agent (disabled) ---
# move_planning_agent = MovePlanningAgent(width=4, height=4, pit_prob=0.2, allow_climb_without_gold=False)
# move_planning_agent.run(Environment, Action)

# --- Probabilistic agent (active) ---
agent = ProbAgent(
    width=4,
    height=4,
    pit_prob=0.2,
    allow_climb_without_gold=True,
)

# Kept as the last expression so the cell displays the value returned by run().
agent.run(Environment, Action)



  X_masked = masked.MaskedTensor(X, mask=mask)
  X_masked = masked.MaskedTensor(X, mask=mask)


|    |    |WG  |    |
|    |    |    |    |
|P   |    |    |    |
|A→  |    |P   |    |
Percept: Percept(t=0, Signals=[Breeze], Reward=0, Done=False)

Action: Action.SHOOT 

|    |    |WG  |    |
|    |    |    |    |
|P   |    |    |    |
|A→  |    |P   |    |
Percept: Percept(t=1, Signals=[Breeze], Reward=-11, Done=False)

Action: Action.RIGHT 

|    |    |WG  |    |
|    |    |    |    |
|P   |    |    |    |
|A↓  |    |P   |    |
Percept: Percept(t=2, Signals=[Breeze], Reward=-1, Done=False)

Action: Action.RIGHT 

|    |    |WG  |    |
|    |    |    |    |
|P   |    |    |    |
|A←  |    |P   |    |
Percept: Percept(t=3, Signals=[Breeze], Reward=-1, Done=False)

Action: Action.SHOOT 

|    |    |WG  |    |
|    |    |    |    |
|P   |    |    |    |
|A←  |    |P   |    |
Percept: Percept(t=4, Signals=[Breeze], Reward=-1, Done=False)

Action: Action.SHOOT 

|    |    |WG  |    |
|    |    |    |    |
|P   |    |    |    |
|A←  |    |P   |    |
Percept: Percept(t=5, Signals=[Breeze

-1017

## Running game for 1000 episodes

In [None]:
def evaluate_agent(
    num_episodes=1000,
    verbose_every=100,
    width=4,
    height=4,
    pit_prob=0.2,
    allow_climb_without_gold=True,
):
    """
    Run ProbAgent for num_episodes games and report:
      - average reward
      - counts and rates of success / death / bailout
      - how many episodes the arrow was used at least once

    Args:
        num_episodes: Number of independent episodes to play. Must be >= 1.
        verbose_every: Print the running average reward every this many
            episodes. A falsy value (0 or None) disables progress lines.
        width: Forwarded to the ProbAgent constructor for every episode.
        height: Forwarded to the ProbAgent constructor for every episode.
        pit_prob: Forwarded to the ProbAgent constructor for every episode.
        allow_climb_without_gold: Forwarded to the ProbAgent constructor
            for every episode.

        The defaults of the four world parameters reproduce the setup that
        used to be hard-coded here (4x4 grid, pit_prob=0.2, climbing
        without gold allowed).

    Returns:
        A tuple (rewards, avg_reward, stats) where rewards is the list of
        per-episode values returned by agent.run, avg_reward is their mean,
        and stats is a dict with the keys "success", "death", "bail" and
        "shots_fired".

    Raises:
        ValueError: If num_episodes is less than 1. Without this guard the
            summary would fail with a ZeroDivisionError after the banner
            had already been printed.
    """
    if num_episodes < 1:
        raise ValueError(
            f"num_episodes must be at least 1, got {num_episodes!r}"
        )

    # Reward cut-offs used to bucket an episode.
    # NOTE(review): presumably a win pays about +1000 and a death costs about
    # -1000 before step penalties — confirm against Environment's rewards.
    success_min_reward = 500
    death_max_reward = -1000

    print(f"Starting evaluation for {num_episodes} episodes...")

    total_reward = 0.0
    rewards = []

    success_count = 0
    death_count = 0
    bailout_count = 0

    # counts how many *episodes* used the arrow at least once
    shots_fired = 0

    for ep in range(1, num_episodes + 1):
        agent = ProbAgent(
            width=width,
            height=height,
            pit_prob=pit_prob,
            allow_climb_without_gold=allow_climb_without_gold,
        )
        # Suppress the per-step board/action printing inside run().
        agent.quiet = True

        reward = agent.run(Environment, Action)
        rewards.append(reward)
        total_reward += reward

        # explicit per-episode arrow flag (absent attribute counts as "not fired")
        if getattr(agent, "arrow_used_this_episode", False):
            shots_fired += 1

        # classify outcome from reward
        if reward >= success_min_reward:
            success_count += 1
        elif reward <= death_max_reward:
            death_count += 1
        else:
            bailout_count += 1

        if verbose_every and ep % verbose_every == 0:
            print(f"Episode {ep}: running average reward = {total_reward / ep:.2f}")

    avg_reward = total_reward / num_episodes

    print("\n=== Evaluation summary ===")
    print(f"Episodes:               {num_episodes}")
    print(f"Average reward:         {avg_reward:.2f}")
    print(f"Min reward:             {min(rewards):.2f}")
    print(f"Max reward:             {max(rewards):.2f}")
    print()
    print(f"Successes (gold+escape): {success_count} ({success_count/num_episodes:.1%})")
    print(f"Deaths:                  {death_count} ({death_count/num_episodes:.1%})")
    print(f"Bailouts w/o gold:       {bailout_count} ({bailout_count/num_episodes:.1%})")
    print(f"Episodes where arrow was fired: {shots_fired} ({shots_fired/num_episodes:.1%})")

    return rewards, avg_reward, {
        "success": success_count,
        "death": death_count,
        "bail": bailout_count,
        "shots_fired": shots_fired,
    }


In [None]:
# Play 1000 quiet ProbAgent episodes, logging the running average every 500,
# then show the outcome counters returned by evaluate_agent.
rewards, avg_reward, outcome_stats = evaluate_agent(
    num_episodes=1000,
    verbose_every=500,
)
# print("Done. Final average reward:", avg_reward)
print("Outcome stats:", outcome_stats)


Starting evaluation for 1000 episodes...


  X_masked = masked.MaskedTensor(X, mask=mask)
  X_masked = masked.MaskedTensor(X, mask=mask)


Episode 500: running average reward = -160.63
Episode 1000: running average reward = -128.41

=== Evaluation summary ===
Episodes:               1000
Average reward:         -128.41
Min reward:             -1184.00
Max reward:             993.00

Successes (gold+escape): 434 (43.4%)
Deaths:                  534 (53.4%)
Bailouts w/o gold:       32 (3.2%)
Episodes where arrow was fired: 666 (66.6%)
Outcome stats: {'success': 434, 'death': 534, 'bail': 32, 'shots_fired': 666}
