In [1]:
from collections import deque
from dataclasses import dataclass
from graphviz import Digraph
from itertools import permutations
from typing import NamedTuple

MAX_VOLUMES: dict[str, int] = {'x': 12, 'y': 8, 'z': 3}


class State(NamedTuple):
    x: int = 0
    y: int = 0
    z: int = 0

    def is_final(self) -> bool:
        """Check if any jug contains exactly 1L."""
        return any(value == 1 for value in self)

    def fill(self, field: str) -> 'State':
        """Fill completely the specified jug."""
        return self._replace(**{field: MAX_VOLUMES[field]})

    def empty(self, field: str) -> 'State':
        """Empty completely the specified jug."""
        return self._replace(**{field: 0})

    def pour(self, source: str, target: str) -> 'State':
        """Pour from source jug to target jug until source is empty or target is full."""
        source_vol = getattr(self, source)
        target_vol = getattr(self, target)
        target_capacity = MAX_VOLUMES[target]

        # Calculate how much can be poured
        pour_amount = min(source_vol, target_capacity - target_vol)

        return self._replace(
            **{source: source_vol - pour_amount, target: target_vol + pour_amount}
        )

    def expand(self) -> list[tuple['State', str]]:
        """Generate all valid children states with the action leading to that state."""
        children = []

        # Fill actions (only if not already full)
        children += [
            ((self.fill(field)), f'fill({field})')
            for field in self._fields
            if getattr(self, field) < MAX_VOLUMES[field]
        ]

        # Empty actions (only if not already empty)
        children += [
            (self.empty(field), f'empty({field})')
            for field in self._fields
            if getattr(self, field) > 0
        ]

        # Pour actions (only if source has water and target has space)
        children += [
            (self.pour(source=source, target=target), f'pour({source}, {target})')
            for source, target in permutations(self._fields, 2)
            if getattr(self, source) > 0 and getattr(self, target) < MAX_VOLUMES[target]
        ]

        return children

    def __str__(self) -> str:
        return f'({self.x}, {self.y}, {self.z})'

In [2]:
@dataclass(frozen=True)
class Node:
    state: 'State'
    parent: 'Node | None' = None
    action: str | None = None
    path_cost: int = 0

    def expand_node(self) -> list['Node']:
        """Generate all children for this node."""
        return [
            Node(state=state, parent=self, action=action, path_cost=self.path_cost + 1)
            for state, action in self.state.expand()
        ]

    def get_path(self) -> list[tuple[str | None, str]]:
        """Reconstruct the path from root to this node."""
        path = []
        node = self
        while node is not None:
            path.append((node.action, f'{node.state}'))
            node = node.parent
        return list(reversed(path))

    def __repr__(self) -> str:
        parent_state = str(self.parent.state) if self.parent else 'None'
        return (
            f'Node(action={self.action}, state={self.state}, '
            f'parent={parent_state}, path_cost={self.path_cost})'
        )

In [3]:
def bfs_set(initial_state: State) -> Node | None:
    root_node = Node(state=initial_state)
    frontier = deque([root_node])
    reached = set([initial_state])
    node_count = 1

    if root_node.state.is_final():
        return root_node

    while len(frontier) > 0:
        node = frontier.popleft()

        for child in node.expand_node():
            node_count += 1
            if child.state.is_final():
                print(f'Node generated: {node_count}')
                return child
            elif child.state not in reached:
                frontier.append(child)
                reached.add(child.state)

    print(f'Node generated: {node_count}')
    return None

In [4]:
res = bfs_set(State())
res

Node generated: 34


Node(action=pour(x, z), state=(1, 8, 3), parent=(4, 8, 0), path_cost=3)

In [5]:
res.get_path()

[(None, '(0, 0, 0)'),
 ('fill(x)', '(12, 0, 0)'),
 ('pour(x, y)', '(4, 8, 0)'),
 ('pour(x, z)', '(1, 8, 3)')]

In [None]:
def bfs_visualize(initial_state: State) -> Digraph:
    dot = Digraph('graph_space')
    dot.graph_attr['rankdir'] = 'TB'

    root_node = Node(state=initial_state)
    frontier = deque([root_node])
    reached = set([initial_state])
    node_count = 1

    while len(frontier) > 0:
        node = frontier.popleft()
        node_id = str(hash(node.state))
        node_count += 1

        dot.node(node_id, label=str(node.state))

        for child in node.expand_node():
            if child.state not in reached:
                frontier.append(child)
                reached.add(child.state)

                # Most edges to already generated nodes are missing
                child_id = str(hash(child.state))
                dot.edge(node_id, child_id, label=f'{child.action}')

    print(f'Node generated: {node_count}')
    return dot

In [7]:
graph_space = bfs_visualize(State())

graph_space.render(directory='dot-output', format='pdf').replace('\\', '/')
# graph_space

Node generated: 315


'dot-output/Digraph.gv.pdf'