# 🧠 HW 3: Interleaving

In [25]:
import sys
import otter
# try:
#   import otter
# except ImportError:
#     %pip install otter-grader
#     import otter

grader = otter.Notebook("HW3.ipynb")

## Question 1: Interleaving Program Graphs

### 🔍 Objective

In this assignment, you will extend the `ProgramGraph` class to support **interleaving** two concurrent program graphs. Interleaving allows you to model concurrent execution, where each program graph represents an independent thread or component.

---

### 🔧 Conditions in ProgramGraph

We have extended the `ProgramGraph` class constructor with a new parameter, `conditions`, which is stored in the attribute:

```python
Cond: Set[Condition]
```

This is the set of all conditions that may appear as guards on the graph's transitions.

---


### 🔧 Task: Implement `interleave(pg: ProgramGraph) -> ProgramGraph`

Add a method `interleave(pg: ProgramGraph)` that returns a **new program graph** representing the interleaving of `self` with another program graph `pg`.

---

### 🤖 Interleaving Semantics

- **Locations (`Loc`)**: Cartesian product of both graphs' locations
  `(loc1, loc2)`

- **Initial Locations (`Loc0`)**: Cartesian product of initial locations
  `(loc1, loc2) ∈ pg1.Loc0 × pg2.Loc0`

- **Actions (`Act`)**: Union of actions from both graphs

- **Conditions (`Cond`)**: Union of conditions from both graphs

- **Transitions**: For each transition in either graph:
  - If `(l1, cond, act, l1') ∈ pg1`, then
    `((l1, l2), cond, act, (l1', l2)) ∈ interleaved_pg`
  - If `(l2, cond, act, l2') ∈ pg2`, then
    `((l1, l2), cond, act, (l1, l2')) ∈ interleaved_pg`

  ✅ This means only one component moves at a time (classic interleaving)

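- **Effect Function**: Both components share a single environment; each action's effect is computed by the effect function of the program graph the action comes from

- **Eval Function**: Reuse the original `eval_fn`; both program graphs are assumed to evaluate conditions the same way, so either one can be used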
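The effect and eval bullets can be illustrated with a small sketch. It assumes each program graph's effect function takes an action string and an environment dict and returns a new dict, and that an action is routed to the graph whose action set contains it. The helper `make_combined_effect`, the two effect functions, and the action strings are hypothetical and only for illustration.

```python
from typing import Callable, Dict, Set, Union

Environment = Dict[str, Union[str, bool, int, float]]
EffectFn = Callable[[str, Environment], Environment]


def make_combined_effect(act1: Set[str], effect1: EffectFn,
                         act2: Set[str], effect2: EffectFn) -> EffectFn:
    """Hypothetical helper: route each action to the effect function of the graph that owns it."""
    def combined(action: str, env: Environment) -> Environment:
        if action in act1:
            return effect1(action, env)
        if action in act2:
            return effect2(action, env)
        raise ValueError(f"Unknown action: {action}")
    return combined


# Illustrative effect functions; each returns a new dict instead of mutating env
def effect_pg1(action: str, env: Environment) -> Environment:
    new_env = dict(env)
    if action == "x -= 1":
        new_env["x"] -= 1
    return new_env


def effect_pg2(action: str, env: Environment) -> Environment:
    new_env = dict(env)
    if action == "y += 1":
        new_env["y"] += 1
    return new_env


effect = make_combined_effect({"x -= 1"}, effect_pg1, {"y += 1"}, effect_pg2)

# Both components read and write the same (shared) environment
env = {"x": 1, "y": 0}
env = effect("x -= 1", env)  # pg1 moves
env = effect("y += 1", env)  # pg2 moves
print(env)  # {'x': 0, 'y': 1}
```

Because only one component moves per step, the shared environment simply accumulates the effects in the order the interleaved steps occur.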

---

### ⚠️ Note on Interleaving Multiple Systems

When interleaving more than two `TransitionSystem` or `ProgramGraph` instances, the resulting states or locations **must be represented as flat tuples**, not nested ones.

✅ For example, interleaving 3 systems should produce states like:

```python
("s0", "t1", "u2")
```

❌ And not nested tuples like:

```python
(("s0", "t1"), "u2")
```

This flattening keeps the state representation consistent and avoids errors in reachability analysis, labeling, and visualization.
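The difference is easy to see when a three-way product is built one pair at a time. The sketch below uses a `flatten` helper that always returns a tuple (non-tuple items become 1-tuples) and compares pairwise products with and without flattening; the location names are illustrative.

```python
import itertools
from typing import Any, Tuple


def flatten(item: Any) -> Tuple:
    """Recursively flatten nested tuples; non-tuple items become 1-tuples."""
    if not isinstance(item, tuple):
        return (item,)
    flat = []
    for element in item:
        flat.extend(flatten(element))
    return tuple(flat)


locs1, locs2, locs3 = ["s0"], ["t1"], ["u2"]

# Pairing the result of a previous product with a new component nests the first pair
nested = {(a, c) for a in itertools.product(locs1, locs2) for c in locs3}
print(nested)  # {(('s0', 't1'), 'u2')}

# Concatenating flattened components keeps every location flat
step = {flatten(a) + flatten(b) for a in locs1 for b in locs2}
flat = {a + flatten(c) for a in step for c in locs3}
print(flat)  # {('s0', 't1', 'u2')}

# itertools.product over all three sets at once yields the same flat tuples
assert flat == set(itertools.product(locs1, locs2, locs3))
```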


### 🧪 Example

If `pg1` has:
```python
Loc = {'A', 'B'}
Transitions = {('A', 'x > 0', 'x -= 1', 'B')}
```

And `pg2` has:
```python
Loc = {'L0', 'L1'}
Transitions = {('L0', 'y == 0', 'y += 1', 'L1')}
```

Then the interleaved program graph has locations:
```
{('A', 'L0'), ('B', 'L0'), ('A', 'L1'), ('B', 'L1')}
```

And the four transitions:
```
(('A', 'L0'), 'x > 0', 'x -= 1', ('B', 'L0'))
(('B', 'L0'), 'y == 0', 'y += 1', ('B', 'L1'))
(('A', 'L0'), 'y == 0', 'y += 1', ('A', 'L1'))
(('A', 'L1'), 'x > 0', 'x -= 1', ('B', 'L1'))
```
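The example can be checked mechanically by applying the two transition rules to plain sets. This is a standalone sketch on raw tuples, not a call to `ProgramGraph.interleave`; it confirms that when `pg1` moves from `A` to `B` while `pg2` sits at `L1`, the target is `('B', 'L1')`.

```python
from typing import Set, Tuple

Loc1 = {"A", "B"}
Trans1 = {("A", "x > 0", "x -= 1", "B")}
Loc2 = {"L0", "L1"}
Trans2 = {("L0", "y == 0", "y += 1", "L1")}

# Loc = Loc1 x Loc2
locs = {(l1, l2) for l1 in Loc1 for l2 in Loc2}

transitions: Set[Tuple] = set()
# pg1 moves, pg2 stays at l2
for src, cond, act, dst in Trans1:
    for l2 in Loc2:
        transitions.add(((src, l2), cond, act, (dst, l2)))
# pg2 moves, pg1 stays at l1
for src, cond, act, dst in Trans2:
    for l1 in Loc1:
        transitions.add(((l1, src), cond, act, (l1, dst)))

for t in sorted(transitions):
    print(t)
```

Each transition of one graph is copied once per location of the other graph, so the result has |Trans1|·|Loc2| + |Trans2|·|Loc1| = 2 + 2 = 4 transitions.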

---

### My Transition System:

In [26]:
from typing import Callable, Set, Tuple, Dict, List, Union, Optional
import matplotlib.pyplot as plt
import networkx as nx

# Add your imports here, don't forget to include your TransitionSystem class
from collections import deque
import itertools


#------------------------start of transition system class#-----------------------
State = Union[str, Tuple]  # A state can be a string or a tuple (location, environment)
Action = str  # Actions are represented as strings
Transition = Tuple[State, Action, State]  # (source_state, action, target_state)
LabelingMap = Dict[State, Set[str]]  # Maps states to atomic propositions


class TransitionSystem:
    """
    A Transition System (TS) representation.

    Attributes:
        S (Set[State]): The set of all states (strings or tuples).
        Act (Set[Action]): The set of all possible actions.
        Transitions (Set[Transition]): The set of transitions, each represented as (state_origin, action, state_target).
        I (Set[State]): The set of initial states.
        AP (Set[str]): The set of atomic propositions.
        _L (LabelingMap): A dictionary mapping states to their respective atomic propositions.
    """

    def __init__(
        self,
        states: Optional[Set[State]] = None,
        actions: Optional[Set[Action]] = None,
        transitions: Optional[Set[Transition]] = None,
        initial_states: Optional[Set[State]] = None,
        atomic_props: Optional[Set[str]] = None,
        labeling_map: Optional[LabelingMap] = None,
    ) -> None:
        """
        Initializes the Transition System.

        :param states: A set of states (each a string or a tuple). Defaults to an empty set.
        :param actions: A set of actions. Defaults to an empty set.
        :param transitions: A set of transitions, each as (state_origin, action, state_target). Defaults to an empty set.
        :param initial_states: A set of initial states. Defaults to an empty set.
        :param atomic_props: A set of atomic propositions. Defaults to an empty set.
        :param labeling_map: A dictionary mapping states to sets of atomic propositions. Defaults to an empty dictionary.
        """
        self.S: Set[State] = set(states) if states is not None else set()
        self.Act: Set[Action] = set(actions) if actions is not None else set()
        self.Transitions: Set[Transition] = set(transitions) if transitions is not None else set()
        self.I: Set[State] = set(initial_states) if initial_states is not None else set()
        self.AP: Set[str] = set(atomic_props) if atomic_props is not None else set()
        self._L: LabelingMap = dict(labeling_map) if labeling_map is not None else {}

    def add_state(self, *states: State) -> "TransitionSystem":
        """
        Adds one or more states to the transition system.

        :param states: One or more states (strings or tuples) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """

        # Validate input types: states must be strings or tuples
        if not all(isinstance(state, (str, tuple)) for state in states):
            raise ValueError("States must be strings or tuples.")
        
        self.S.update(states)
        return self

    def add_action(self, *actions: Action) -> "TransitionSystem":
        """
        Adds one or more actions to the transition system.

        :param actions: One or more actions (strings) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        # Validate input types: actions must be strings
        if not all(isinstance(action, str) for action in actions):
            raise ValueError("Actions must be strings.")

        self.Act.update(actions)
        return self

    def add_transition(self, *transitions: Transition) -> "TransitionSystem":
        """
        Adds one or more transitions to the transition system.
        Ensures that all involved states and actions exist before adding the transitions.

        Each transition must be provided as a tuple of the form `(state_from, action, state_to)`, where:
        - `state_from` is the source state.
        - `action` is the action performed.
        - `state_to` is the resulting state.

        :param transitions: One or more transitions, each as a tuple `(state_from, action, state_to)`.
        :raises ValueError:
            - If a transition is not a tuple of length 3.
            - If `state_from` or `state_to` does not exist in `self.S`.
            - If `action` is not in `self.Act`.
        :return: The `TransitionSystem` instance (for method chaining).
        """
        
        # Check if all states and actions in the transitions exist in the system
        for transition in transitions:
            if not isinstance(transition, tuple) or len(transition) != 3:
                raise ValueError(f"Invalid transition format: {transition}. Must be a tuple of (state_from, action, state_to).")
            
            # Unpack the transition tuple
            state_from, action, state_to = transition

            # Check if state_from, action, and state_to are valid
            if state_from not in self.S:
                raise ValueError(f"State {state_from} is not in the transition system.")
            if state_to not in self.S:
                raise ValueError(f"State {state_to} is not in the transition system.")
            if action not in self.Act:
                raise ValueError(f"Action {action} is not in the transition system.")

            self.Transitions.add(transition)

        return self

    def add_initial_state(self, *states: State) -> "TransitionSystem":
        """
        Adds one or more states to the set of initial states.

        :param states: One or more states to be marked as initial.
        :raises ValueError: If any state does not exist in the system.
        :return: The TransitionSystem instance (for method chaining).
        """
        # Check if all states are in the system
        for state in states:
            if state not in self.S:
                raise ValueError(f"Initial state {state} must be in the transition system.")
        
        self.I.update(states)
        return self

    def add_atomic_proposition(self, *props: str) -> "TransitionSystem":
        """
        Adds one or more atomic propositions to the transition system.

        :param props: One or more atomic propositions (strings) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        # Validate input types: atomic propositions must be strings
        if not all(isinstance(prop, str) for prop in props):
            raise ValueError("Atomic propositions must be strings.")
        
        self.AP.update(props)
        return self

    def add_label(self, state: State, *labels: str) -> "TransitionSystem":
        """
        Adds one or more atomic propositions to a given state.

        :param state: The state to label.
        :param labels: One or more atomic propositions to be assigned to the state.
        :raises ValueError: If the state is not in the system or if any label is not a valid atomic proposition.
        :return: The TransitionSystem instance (for method chaining).
        """
        # Check if the state exists in the system
        if state not in self.S:
            raise ValueError(f"Cannot set labels for {state}. State is not in the transition system.")
        
        # Check if all labels are valid atomic propositions
        invalid_labels = {label for label in labels if label not in self.AP}
        if invalid_labels:
            raise ValueError(f"Cannot assign labels {invalid_labels}. They are not in the set of atomic propositions (AP).")

        # Add labels to the state
        if state not in self._L:
            self._L[state] = set()
        self._L[state].update(labels)
        
        return self

    def L(self, state: State) -> Set[str]:
        """
        Retrieves the set of atomic propositions that hold in a given state.

        :param state: The state whose atomic propositions are being retrieved.
        :raises ValueError: If the state is not in the transition system.
        :return: A set of atomic propositions associated with the given state.
        """
        # Check if the state exists in the system
        if state not in self.S:
            raise ValueError(f"State {state} is not in the transition system.")
        
        # Return the labels for the state, or an empty set if none exist
        return self._L.get(state, set())

    def pre(self, S: Union[State, Set[State]], action: Optional[Action] = None) -> Set[State]:
        """
        Computes the set of predecessor states from which a given state or set of states can be reached.

        :param S: A single state (string/tuple) or a collection of states.
        :param action: (Optional) If provided, filters only the transitions that use this action.
        :return: A set of predecessor states.
        """

        # Check if S is a single state or a set of states
        if not isinstance(S, set):
            S = {S}

        predecessors = set()

        for state_from, act, state_to in self.Transitions:
            # Keep transitions that lead into S; if action is given, only transitions with that action count
            if state_to in S and (action is None or act == action):
                predecessors.add(state_from)
    
        return predecessors


    def post(self, S: Union[State, Set[State]], action: Optional[Action] = None) -> Set[State]:
        """
        Computes the set of successor states reachable from a given state or a collection of states.

        :param S: A single state or a collection of states.
        :param action: (Optional) Filters transitions by this action.
        :return: A set of successor states.
        """
        # Check if S is a single state or a set of states
        if not isinstance(S, set):
            S = {S}

        successors = set()

        for state_from, act, state_to in self.Transitions:
            if state_from in S and (action is None or act == action):
                successors.add(state_to)

        return successors

    def reach(self) -> Set[State]:
        """
        Computes the set of all reachable states from the initial states.

        :return: A set of reachable states.
        """

        # If there are no initial states, return an empty set
        if not self.I:
            return set()
    
        reachable = set(self.I)
        frontier = set(self.I)
        
        # Perform a breadth-first search (BFS) to find all reachable states
        while frontier:
            new_frontier = set()
            for state in frontier:
                successors = self.post(state)
                new_states = successors - reachable
                reachable.update(new_states)
                new_frontier.update(new_states)
            frontier = new_frontier
        
        return reachable

    def is_action_deterministic(self) -> bool:
        """
        Checks whether the transition system is action-deterministic.

        A transition system is action-deterministic if:
        - It has at most one initial state.
        - For each state and action, there is at most one successor state.

        :return: True if the transition system is action-deterministic, False otherwise.
        """
        # Check if there is at most one initial state
        if len(self.I) > 1:
            return False
    
        # For each state and action, check if there is at most one successor
        for state in self.S:
            for action in self.Act:
                if len(self.post(state, action)) > 1:
                    return False
                
        return True
        

    def is_label_deterministic(self) -> bool:
        """
        Checks whether the transition system is label-deterministic.

        A transition system is label-deterministic if:
        - It has at most one initial state.
        - For each state, no two successor states have the same label set
          (the number of successors equals the number of distinct label sets among them).

        :return: True if the transition system is label-deterministic, False otherwise.
        """
        # Check if there is at most one initial state
        if len(self.I) > 1:
            return False
        
        # For each state, check if all successor states have distinct labels
        for state in self.S:
            successors = self.post(state)
            if not successors:
                continue
            
            # Get the labels of all successor states
            successor_labels = [frozenset(self.L(s)) for s in successors]
            
            # Check if the number of distinct labels is equal to the number of successors
            if len(set(successor_labels)) != len(successors):
                return False
        
        return True

    def __repr__(self) -> str:
        """
        Returns a string representation of the Transition System.

        :return: A formatted string representation of the TS.
        """
        return (
            f"TransitionSystem(\n"
            f"  States: {self.S}\n"
            f"  Actions: {self.Act}\n"
            f"  Transitions: {len(self.Transitions)}\n"
            f"  Initial States: {self.I}\n"
            f"  Atomic Propositions: {self.AP}\n"
            f"  Labels: {self._L}\n"
            f")"
        )


    def plot(self, title: str = "Transition System", figsize: Tuple[int, int] = (10, 6)) -> None:
        """
        Plots the Transition System as a directed graph.

        :param title: Title of the plot.
        :param figsize: Figure size for the plot.
        """
        G = nx.DiGraph()

        # Add nodes (states)
        for state in self.S:
            label = f"{state}\n{' '.join(self.L(state))}" if self.L(state) else str(state)
            G.add_node(state, label=label, color="blue" if state in self.I else "yellow")

        # Add edges (transitions)
        for state_from, action, state_to in self.Transitions:
            G.add_edge(state_from, state_to, label=action)

        plt.figure(figsize=figsize)
        pos = nx.spring_layout(G)  # Positioning algorithm for layout

        # Draw nodes
        node_colors = [G.nodes[n]["color"] for n in G.nodes]
        nx.draw(G, pos, with_labels=True, labels=nx.get_node_attributes(G, "label"), node_color=node_colors, edgecolors="black", node_size=2000, font_size=10)

        # Draw edge labels (actions)
        edge_labels = {(u, v): d["label"] for u, v, d in G.edges(data=True)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9)

        plt.title(title)
        plt.show()



#------------------------end of transition system class#------------------------

### Helper Flatten Function:

In [27]:
def flatten(item):
    """
    Flatten a possibly nested tuple into a single flat tuple.

    Non-tuple items are wrapped in a 1-tuple, so the result is always a tuple and
    two flattened parts can be combined with `+`:
        flatten("s0")                 -> ("s0",)
        flatten((("s0", "t1"), "u2")) -> ("s0", "t1", "u2")
    """
    if not isinstance(item, tuple):
        return (item,)  # Wrap single items in a tuple for consistency
    flat_list = []
    for element in item:
        flat_list.extend(flatten(element))  # Recurse so any depth of nesting is handled
    return tuple(flat_list)

In [28]:
from typing import Callable, Set, Tuple, Dict, List, Union, Optional, Any
from collections import deque
import copy

# Add your imports here
import itertools as it
import networkx as nx
import matplotlib.pyplot as plt



Location = Union[str, Tuple]  # A location is a string or a flat tuple of component locations (after interleaving)
Action = str  # Actions are represented as strings
Condition = str # Conditions are represented as strings
Transition = Tuple[Location, Condition, Action, Location]  # (loc_from, condition, action, loc_to)
Environment = Dict[str, Union[str, bool, int, float]]      # Variable assignments

class ProgramGraph:
    def __init__(
        self,
        locations: Set[Location],
        initial_locations: Set[Location],
        actions: Set[Action],
        conditions: Set[Condition],
        transitions: Set[Transition],
        eval_fn: Callable[[Condition, Environment], bool],
        effect_fn: Callable[[Action, Environment], Environment],
        g0: Condition
    ):
        """
        A representation of a Program Graph.

        :param locations: Set of all program locations (Loc).
        :param initial_locations: Set of initial locations (Loc0).
        :param actions: Set of possible actions (Act).
        :param conditions: Set of possible conditions (Cond).
        :param transitions: Set of transitions in the form (loc_from, condition, action, loc_to).
        :param eval_fn: Function to evaluate a condition string in an environment.
        :param effect_fn: Function to compute the new environment after applying an action.
        :param g0: Initial condition string for filtering valid starting environments.
        """
        self.Loc = set(locations)
        self.Loc0 = set(initial_locations)
        self.Act = set(actions)
        self.Cond = set(conditions)
        self.Transitions = set(transitions)
        self.eval_fn = eval_fn
        self.effect_fn = effect_fn
        self.g0 = g0

    def add_location(self, *locations: Location) -> "ProgramGraph":
        """Add one or more locations to the program graph."""
        self.Loc.update(locations)
        return self

    def add_action(self, *actions: Action) -> "ProgramGraph":
        """Add one or more actions to the program graph."""
        self.Act.update(actions)
        return self

    def add_condition(self, *conditions: Condition) -> "ProgramGraph":
        """Add one or more conditions to the program graph."""
        self.Cond.update(conditions)
        return self

    def add_transition(self, *transitions: Transition) -> "ProgramGraph":
        """
        Add one or more transitions to the program graph.

        Each transition must be a tuple: (loc_from, condition, action, loc_to).
        """
        for transition in transitions:
            if not isinstance(transition, tuple) or len(transition) != 4:
                raise ValueError(f"Invalid transition format: {transition}. Expected (loc_from, cond, action, loc_to).")
            loc_from, cond, action, loc_to = transition
            if loc_from not in self.Loc:
                raise ValueError(f"Location {loc_from} is not in the program graph.")
            if loc_to not in self.Loc:
                raise ValueError(f"Location {loc_to} is not in the program graph.")
            if action not in self.Act:
                raise ValueError(f"Action {action} is not in the program graph.")
            if cond not in self.Cond:
                raise ValueError(f"Condition {cond} is not in the program graph.")
            self.Transitions.add(transition)
        return self

    def add_initial_location(self, *locations: Location) -> "ProgramGraph":
        """Add one or more initial locations to the program graph."""
        for loc in locations:
            if loc not in self.Loc:
                raise ValueError(f"Cannot set initial location {loc}. Location is not in the set of locations.")
            self.Loc0.add(loc)
        return self

    def set_eval_fn(self, eval_fn: Callable[[Condition, Environment], bool]) -> "ProgramGraph":
        """Set the function used to evaluate conditions."""
        self.eval_fn = eval_fn
        return self

    def set_effect_fn(self, effect_fn: Callable[[Action, Environment], Environment]) -> "ProgramGraph":
        """Set the function used to apply actions to environments."""
        self.effect_fn = effect_fn
        return self

    def eval(self, condition: Condition, env: Environment) -> bool:
        """Evaluate a condition string in the given environment."""
        return self.eval_fn(condition, env)

    def effect(self, action: Action, env: Environment) -> Environment:
        """Apply an action to the environment and return the new environment."""
        return self.effect_fn(action, env)

    def valid_transitions(self, loc: Location, env: Environment, action: Action) -> List[Tuple[Location, Action, Location]]:
        """
        Return a list of valid transitions from a given location using the provided environment and action.
        """
        # Paste your solution from previous exercise here
        valid = []
        for (l_from, cond, act, l_to) in self.Transitions:
            if l_from == loc and act == action and self.eval(cond, env):
                valid.append((l_from, act, l_to))
        return valid

    def to_transition_system(self, vars: Dict[str, Set[Union[str, bool, int, float]]], labels: Set[Condition]) -> TransitionSystem:
        """
        Construct and return a Transition System from the program graph.

        :param vars: A dictionary mapping variable names to their finite sets of possible values.
        :param labels: A set of atomic proposition strings to be used for labeling states.
        :return: A TransitionSystem instance corresponding to the program graph.
        """
        # Paste your solution from previous exercise here
        # Initialize the Transition System components
        ts_states: Set[Tuple[Location, Environment]] = set()
        ts_actions: Set[Action] = self.Act
        ts_transitions: Set[Tuple[Tuple[Location, Environment], Action, Tuple[Location, Environment]]] = set()
        ts_initial_states: Set[Tuple[Location, Environment]] = set()
        ts_atomic_props: Set[str] = set(labels)
        ts_labeling_map: Dict[Tuple[Location, Environment], Set[str]] = {}

        # Generate all possible initial environments based on vars and g0
        initial_envs: List[Environment] = []
        var_names = list(vars.keys())
        var_value_combinations = it.product(*(vars[name] for name in var_names))

        for values in var_value_combinations:
            env = dict(zip(var_names, values))
            if self.eval(self.g0, env):
                initial_envs.append(env)

        # Set initial states for the TS
        queue = deque()
        for loc0 in self.Loc0:
            for env0 in initial_envs:
                initial_state = (loc0, tuple(sorted(env0.items()))) # Use tuple(sorted(items)) for hashable env
                if initial_state not in ts_states:
                    ts_initial_states.add(initial_state)
                    ts_states.add(initial_state)
                    queue.append((loc0, env0)) # Use dict env in queue for modification

        # Explore reachable states using BFS
        visited = set(ts_initial_states)

        while queue:
            current_loc, current_env_dict = queue.popleft()
            current_state_tuple = (current_loc, tuple(sorted(current_env_dict.items())))

            # Label the state with every proposition from `labels` that holds in its environment
            state_labels = {ap for ap in labels if self.eval(ap, current_env_dict)}
            if state_labels:
                ts_labeling_map[current_state_tuple] = state_labels

            # Explore transitions whose guard holds in the current environment
            for l_from, cond, act, l_to in self.Transitions:
                if l_from == current_loc and self.eval(cond, current_env_dict):
                    # Pass a copy so an effect_fn that mutates its argument cannot corrupt the current state
                    next_env_dict = self.effect(act, dict(current_env_dict))
                    next_state_tuple = (l_to, tuple(sorted(next_env_dict.items())))

                    # Add the transition to the TS
                    ts_transitions.add((current_state_tuple, act, next_state_tuple))

                    # Enqueue states that have not been seen yet
                    if next_state_tuple not in visited:
                        visited.add(next_state_tuple)
                        ts_states.add(next_state_tuple)
                        queue.append((l_to, next_env_dict))


        # Ensure all states involved in transitions are in the state set
        for s_from, _, s_to in ts_transitions:
            ts_states.add(s_from)
            ts_states.add(s_to)

        # Ensure initial states are added
        ts_states.update(ts_initial_states)


        # TransitionSystem is defined earlier in this notebook, so no extra import is needed

        return TransitionSystem(
            states=ts_states,
            actions=ts_actions,
            transitions=ts_transitions,
            initial_states=ts_initial_states,
            atomic_props=ts_atomic_props,
            labeling_map=ts_labeling_map
        )

    def interleave(self, pg: "ProgramGraph") -> "ProgramGraph":
        """
        Return a new program graph representing the interleaving of self with pg.

        Locations are built as flat tuples: flatten() always returns a tuple, so
        interleaving an already-interleaved graph yields ("s0", "t1", "u2")
        instead of (("s0", "t1"), "u2").
        """
        # Flatten every location once (this also handles results of previous interleavings)
        loc1_list = [flatten(l) for l in self.Loc]
        loc2_list = [flatten(l) for l in pg.Loc]

        # Loc = Loc1 x Loc2 and Loc0 = Loc0_1 x Loc0_2 (tuple concatenation keeps them flat)
        new_locs = {l1 + l2 for l1 in loc1_list for l2 in loc2_list}
        new_loc0 = {flatten(l1) + flatten(l2) for l1 in self.Loc0 for l2 in pg.Loc0}

        # Act and Cond are the unions of both graphs' sets
        new_actions = self.Act | pg.Act
        new_conditions = self.Cond | pg.Cond

        new_trans = set()
        # self moves, pg stays: ((l1, l2), g, a, (l1', l2)) for every l2
        for l1_from, cond, act, l1_to in self.Transitions:
            for l2 in loc2_list:
                new_trans.add((flatten(l1_from) + l2, cond, act, flatten(l1_to) + l2))

        # pg moves, self stays: ((l1, l2), g, a, (l1, l2')) for every l1
        for l2_from, cond, act, l2_to in pg.Transitions:
            for l1 in loc1_list:
                new_trans.add((l1 + flatten(l2_from), cond, act, l1 + flatten(l2_to)))

        # Both graphs are assumed to evaluate conditions the same way, so self.eval_fn is reused
        combined_eval_fn = self.eval_fn

        # Apply each action with the effect function of the graph it belongs to, on the
        # shared environment (self takes precedence if both graphs define the action)
        def combined_effect_fn(action: Action, env: Environment) -> Environment:
            if action in self.Act:
                return self.effect_fn(action, env)
            return pg.effect_fn(action, env)

        # Only self.g0 is kept; conjoining it with pg.g0 (e.g. f"({self.g0}) and ({pg.g0})")
        # depends on the condition syntax eval_fn accepts, so it is left to the caller
        new_g0 = self.g0

        return ProgramGraph(
            locations=new_locs,
            initial_locations=new_loc0,
            actions=new_actions,
            conditions=new_conditions,
            transitions=new_trans,
            eval_fn=combined_eval_fn,
            effect_fn=combined_effect_fn,
            g0=new_g0
        )

    def __repr__(self):
        return (
            f"ProgramGraph(\n"
            f"  Loc: {self.Loc}\n"
            f"  Act: {self.Act}\n"
            f"  Cond: {self.Cond}\n"
            f"  Transitions: {self.Transitions}\n"
            f"  Loc0: {self.Loc0}\n"
            f"  g0: {self.g0}\n"
            f")"
        )


    def plot(self):
        """
        Visualize the program graph as a directed graph using networkx and matplotlib.

        Nodes are locations. Edges are labeled with (condition, action).
        """
        G = nx.MultiDiGraph()

        # Add nodes
        for loc in self.Loc:
            G.add_node(loc, color='lightblue' if loc in self.Loc0 else 'white')

        # Add edges with (condition, action) labels
        for (loc_from, cond, action, loc_to) in self.Transitions:
            label = f"{cond} / {action}"
            G.add_edge(loc_from, loc_to, label=label)

        pos = nx.spring_layout(G, seed=42)  # consistent layout

        # Draw nodes with color
        node_colors = [G.nodes[n].get('color', 'white') for n in G.nodes]
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, edgecolors='black')
        nx.draw_networkx_labels(G, pos)

        # Draw edges
        nx.draw_networkx_edges(G, pos, arrows=True)

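        # Draw edge labels; parallel edges between the same pair of locations share one label
        edge_labels = {}
        for u, v, d in G.edges(data=True):
            edge_labels.setdefault((u, v), []).append(d['label'])
        edge_labels = {uv: "\n".join(labels) for uv, labels in edge_labels.items()}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)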

        # Display
        plt.title("Program Graph")
        plt.axis('off')
        plt.tight_layout()
        plt.show()

In [29]:
grader.check("q1")

## 🔄 Question 2: Interleaving Transition Systems with Handshaking

### 🔍 Objective

You will now extend the `TransitionSystem` class to support **interleaving** two transition systems, optionally synchronizing on a set of **handshaking actions**.

This is useful for modeling **parallel systems** that may:
- Evolve independently (interleaving)
- Or synchronize on specific shared actions (handshake)

---

### 🧩 Task

Add a method to the `TransitionSystem` class:

```python
def interleave(self, ts: TransitionSystem, h: Set[Action]) -> TransitionSystem:
    ...
```

Where:
- `self` is the first transition system (`TS1`)
- `ts` is the second transition system (`TS2`)
- `h` is a set of **handshaking actions** ⊆ `Act1 ∩ Act2`

---

### 🤖 Interleaving Semantics

Let `s1`, `s1'` be states in `TS1`, and `s2`, `s2'` be states in `TS2`.

You should add the following transitions to the interleaved system:

#### 🔹 Independent Transitions

If `a ∈ Act1 \ h`:
```
((s1, s2), a, (s1', s2)) if (s1, a, s1') ∈ TS1
```

If `a ∈ Act2 \ h`:
```
((s1, s2), a, (s1, s2')) if (s2, a, s2') ∈ TS2
```

#### 🔸 Handshaking Transitions

If `a ∈ h`:
```
((s1, s2), a, (s1', s2')) if (s1, a, s1') ∈ TS1 and (s2, a, s2') ∈ TS2
```

---

### 🧱 Transition System Components

Your interleaved TS should include:

- `S = S1 × S2`
- `Act = Act1 ∪ Act2`
- `Transitions =` as defined above
- `I = I1 × I2`
- `AP = AP1 ∪ AP2`
- `L((s1, s2)) = L1(s1) ∪ L2(s2)`
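The non-transition components are plain set operations. The sketch below builds them for two small, hypothetical components (the states, initial states, and labels are made up for illustration and are not part of the assignment), using `itertools.product` for the Cartesian products. It keeps states as pairs; your `interleave` should additionally flatten them.

```python
from itertools import product

# Hypothetical toy components (not taken from the assignment)
S1, I1, AP1 = {"s0", "s1"}, {"s0"}, {"done1"}
S2, I2, AP2 = {"t0", "t1"}, {"t0"}, {"busy"}
L1 = {"s0": set(), "s1": {"done1"}}
L2 = {"t0": {"busy"}, "t1": set()}

# S = S1 x S2 and I = I1 x I2
S = set(product(S1, S2))
I = set(product(I1, I2))

# AP = AP1 ∪ AP2
AP = AP1 | AP2

# L((s1, s2)) = L1(s1) ∪ L2(s2)
L = {(s1, s2): L1[s1] | L2[s2] for (s1, s2) in S}

print(len(S), I)                 # 4 {('s0', 't0')}
print(sorted(L[("s1", "t0")]))   # ['busy', 'done1']
```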

---

### 🧪 Example

Suppose:

TS1:
```
S = {s0, s1}, Act = {a}, Transitions = {(s0, a, s1)}
```

TS2:
```
S = {t0, t1}, Act = {a}, Transitions = {(t0, a, t1)}
```

With `h = {'a'}`, the interleaved TS will have:
```
((s0, t0), a, (s1, t1))
```

If `h = ∅`, you'll get:
```
((s0, t0), a, (s1, t0))
((s0, t0), a, (s0, t1))
((s1, t0), a, (s1, t1))
((s0, t1), a, (s1, t1))
```
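The three transition rules can be checked against this example with a direct, unoptimized sketch over plain Python sets. `interleave_transitions` is a hypothetical helper, each system is assumed to be a `(states, transitions, actions)` tuple, and composite states are kept as pairs rather than flat tuples.

```python
from typing import Set, Tuple

Transition = Tuple[object, str, object]
System = Tuple[Set[str], Set[Transition], Set[str]]  # (states, transitions, actions)


def interleave_transitions(ts1: System, ts2: System, h: Set[str]) -> Set[Transition]:
    """Apply the independent and handshaking rules over S1 x S2 (hypothetical helper)."""
    states1, trans1, act1 = ts1
    states2, trans2, act2 = ts2
    result = set()
    for s1 in states1:
        for s2 in states2:
            # TS1 moves alone on a ∈ Act1 \ h
            for src, a, dst in trans1:
                if src == s1 and a in act1 - h:
                    result.add(((s1, s2), a, (dst, s2)))
            # TS2 moves alone on a ∈ Act2 \ h
            for src, a, dst in trans2:
                if src == s2 and a in act2 - h:
                    result.add(((s1, s2), a, (s1, dst)))
            # Both move together on a shared handshaking action a ∈ h
            for src1, a1, dst1 in trans1:
                for src2, a2, dst2 in trans2:
                    if src1 == s1 and src2 == s2 and a1 == a2 and a1 in h:
                        result.add(((s1, s2), a1, (dst1, dst2)))
    return result


ts1 = ({"s0", "s1"}, {("s0", "a", "s1")}, {"a"})
ts2 = ({"t0", "t1"}, {("t0", "a", "t1")}, {"a"})

print(interleave_transitions(ts1, ts2, {"a"}))      # {(('s0', 't0'), 'a', ('s1', 't1'))}
print(len(interleave_transitions(ts1, ts2, set())))  # 4
```

Looping over every transition for every composite state is quadratic; indexing transitions by `(source, action)` first, as a dictionary lookup, avoids rescanning them.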

---

In [30]:
import itertools as it
from collections import deque
from typing import Dict, List, Optional, Set, Tuple, Union

import networkx as nx
import matplotlib.pyplot as plt

State = Union[str, Tuple]  # A state can be a string or a tuple (location, environment)
Action = str  # Actions are represented as strings
Transition = Tuple[State, Action, State]  # (source_state, action, target_state)
LabelingMap = Dict[State, Set[str]]  # Maps states to atomic propositions


class TransitionSystem:
    """
    A Transition System (TS) representation.

    Attributes:
        S (Set[State]): The set of all states (strings or tuples).
        Act (Set[Action]): The set of all possible actions.
        Transitions (Set[Transition]): The set of transitions, each represented as (state_origin, action, state_target).
        I (Set[State]): The set of initial states.
        AP (Set[str]): The set of atomic propositions.
        _L (LabelingMap): A dictionary mapping states to their respective atomic propositions.
    """

    def __init__(
        self,
        states: Optional[Set[State]] = None,
        actions: Optional[Set[Action]] = None,
        transitions: Optional[Set[Transition]] = None,
        initial_states: Optional[Set[State]] = None,
        atomic_props: Optional[Set[str]] = None,
        labeling_map: Optional[LabelingMap] = None,
    ) -> None:
        """
        Initializes the Transition System.

        :param states: A set of states (each a string or a tuple). Defaults to an empty set.
        :param actions: A set of actions. Defaults to an empty set.
        :param transitions: A set of transitions, each as (state_origin, action, state_target). Defaults to an empty set.
        :param initial_states: A set of initial states. Defaults to an empty set.
        :param atomic_props: A set of atomic propositions. Defaults to an empty set.
        :param labeling_map: A dictionary mapping states to sets of atomic propositions. Defaults to an empty dictionary.
        """
        self.S: Set[State] = set(states) if states is not None else set()
        self.Act: Set[Action] = set(actions) if actions is not None else set()
        self.Transitions: Set[Transition] = set(transitions) if transitions is not None else set()
        self.I: Set[State] = set(initial_states) if initial_states is not None else set()
        self.AP: Set[str] = set(atomic_props) if atomic_props is not None else set()
        self._L: LabelingMap = dict(labeling_map) if labeling_map is not None else {}

    def add_state(self, *states: State) -> "TransitionSystem":
        """
        Adds one or more states to the transition system.

        :param states: One or more states (strings or tuples) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        self.S.update(states)
        return self

    def add_action(self, *actions: Action) -> "TransitionSystem":
        """
        Adds one or more actions to the transition system.

        :param actions: One or more actions (strings) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        self.Act.update(actions)
        return self

    def add_transition(self, *transitions: Transition) -> "TransitionSystem":
        """
        Adds one or more transitions to the transition system.
        Ensures that all involved states and actions exist before adding the transitions.

        Each transition must be provided as a tuple of the form `(state_from, action, state_to)`, where:
        - `state_from` is the source state.
        - `action` is the action performed.
        - `state_to` is the resulting state.

        :param transitions: One or more transitions, each as a tuple `(state_from, action, state_to)`.
        :raises ValueError:
            - If a transition is not a tuple of length 3.
            - If `state_from` or `state_to` does not exist in `self.S`.
            - If `action` is not in `self.Act`.
        :return: The `TransitionSystem` instance (for method chaining).
        """
        for transition in transitions:
            if not isinstance(transition, tuple) or len(transition) != 3:
                raise ValueError(f"Invalid transition format: {transition}. Expected (state_from, action, state_to).")

            state_from, action, state_to = transition

            if state_from not in self.S:
                raise ValueError(f"State {state_from} is not in the transition system.")
            if state_to not in self.S:
                raise ValueError(f"State {state_to} is not in the transition system.")
            if action not in self.Act:
                raise ValueError(f"Action {action} is not in the transition system.")

            self.Transitions.add(transition)
        return self

    def add_initial_state(self, *states: State) -> "TransitionSystem":
        """
        Adds one or more states to the set of initial states.

        :param states: One or more states to be marked as initial.
        :raises ValueError: If any state does not exist in the system.
        :return: The TransitionSystem instance (for method chaining).
        """
        for state in states:
            if state not in self.S:
                raise ValueError(f"Initial state {state} must be in the transition system.")
            self.I.add(state)
        return self

    def add_atomic_proposition(self, *props: str) -> "TransitionSystem":
        """
        Adds one or more atomic propositions to the transition system.

        :param props: One or more atomic propositions (strings) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        self.AP.update(props)
        return self

    def add_label(self, state: State, *labels: str) -> "TransitionSystem":
        """
        Adds one or more atomic propositions to a given state.

        :param state: The state to label.
        :param labels: One or more atomic propositions to be assigned to the state.
        :raises ValueError: If the state is not in the system or if any label is not a valid atomic proposition.
        :return: The TransitionSystem instance (for method chaining).
        """
        if state not in self.S:
            raise ValueError(f"Cannot set labels for {state}. State is not in the transition system.")

        invalid_labels = {label for label in labels if label not in self.AP}
        if invalid_labels:
            raise ValueError(f"Cannot assign labels {invalid_labels}. They are not in the set of atomic propositions (AP).")

        self._L.setdefault(state, set()).update(labels)
        return self

    def L(self, state: State) -> Set[str]:
        """
        Retrieves the set of atomic propositions that hold in a given state.

        :param state: The state whose atomic propositions are being retrieved.
        :raises ValueError: If the state is not in the transition system.
        :return: A set of atomic propositions associated with the given state.
        """
        if state not in self.S:
            raise ValueError(f"State {state} is not in the transition system.")
        return self._L.get(state, set())

    def pre(self, S: Union[State, Set[State]], action: Optional[Action] = None) -> Set[State]:
        """
        Computes the set of predecessor states from which a given state or set of states can be reached.

        :param S: A single state (string/tuple) or a collection of states.
        :param action: (Optional) If provided, filters only the transitions that use this action.
        :return: A set of predecessor states.
        """
        target_states = S if isinstance(S, (set, frozenset)) else {S}
        predecessors = set()

        for state_from, trans_action, state_to in self.Transitions:
            if state_to in target_states:
                if action is None or trans_action == action:
                    predecessors.add(state_from)
        return predecessors

    def post(self, S: Union[State, Set[State]], action: Optional[Action] = None) -> Set[State]:
        """
        Computes the set of successor states reachable from a given state or a collection of states.

        :param S: A single state or a collection of states.
        :param action: (Optional) Filters transitions by this action.
        :return: A set of successor states.
        """
        source_states = S if isinstance(S, (set, frozenset)) else {S}
        successors = set()

        for state_from, trans_action, state_to in self.Transitions:
            if state_from in source_states:
                if action is None or trans_action == action:
                    successors.add(state_to)
        return successors

    def reach(self) -> Set[State]:
        """
        Computes the set of all reachable states from the initial states.

        :return: A set of reachable states.
        """
        reachable_states = set(self.I)
        queue = deque(self.I)

        while queue:
            current_state = queue.popleft()
            successors = self.post(current_state)
            for succ_state in successors:
                if succ_state not in reachable_states:
                    reachable_states.add(succ_state)
                    queue.append(succ_state)
        return reachable_states

    def is_action_deterministic(self) -> bool:
        """
        Checks whether the transition system is action-deterministic.

        A transition system is action-deterministic if:
        - It has at most one initial state.
        - For each state and action, there is at most one successor state.

        :return: True if the transition system is action-deterministic, False otherwise.
        """
        if len(self.I) > 1:
            return False

        successors_map: Dict[Tuple[State, Action], State] = {}
        for state_from, action, state_to in self.Transitions:
            key = (state_from, action)
            if key in successors_map and successors_map[key] != state_to:
                return False  # Found a different successor for the same state-action pair
            successors_map[key] = state_to

        return True

    def is_label_deterministic(self) -> bool:
        """
        Checks whether the transition system is label-deterministic.

        A transition system is label-deterministic if:
        - It has at most one initial state.
        - For each state, the number of direct successor states equals the number of distinct label sets
          among those successors (no two successors share the same labels).

        :return: True if the transition system is label-deterministic, False otherwise.
        """
        if len(self.I) > 1:
            return False

        for state in self.S:
            successors = self.post(state)
            if not successors:
                continue # No successors, condition trivially holds for this state

            num_successors = len(successors)
            label_sets_of_successors = set()
            for succ in successors:
                # Create a frozenset of labels to make the set hashable
                label_set = frozenset(self.L(succ))
                label_sets_of_successors.add(label_set)

            num_unique_label_sets = len(label_sets_of_successors)

            if num_successors != num_unique_label_sets:
                return False

        return True

    def interleave(self, ts: "TransitionSystem", h: Set[Action] = frozenset()) -> "TransitionSystem":
        """
        Interleaves this transition system (TS1) with `ts` (TS2), synchronizing on the
        handshaking actions in `h` (expected to be a subset of Act1 ∩ Act2).

        Composite states are flat tuples, so repeated interleaving never nests tuples.

        :param ts: The second transition system.
        :param h: The set of handshaking actions. Defaults to the empty set (pure interleaving).
        :return: A new TransitionSystem representing TS1 ||_h TS2.
        """
        h = set(h)

        def parts(state: State) -> Tuple:
            # Flatten nested tuples and wrap plain states so they can be concatenated
            flat = flatten(state)
            return flat if isinstance(flat, tuple) else (flat,)

        # Map each composite (flat) state back to its pair of original states;
        # the originals are needed to look up labels and transitions.
        pair_of = {parts(s1) + parts(s2): (s1, s2) for s1 in self.S for s2 in ts.S}

        # S = S1 x S2 and I = I1 x I2
        new_S = set(pair_of)
        new_I = {parts(i1) + parts(i2) for i1 in self.I for i2 in ts.I}

        # Act = Act1 ∪ Act2 and AP = AP1 ∪ AP2
        new_Act = self.Act | ts.Act
        new_AP = self.AP | ts.AP

        # L((s1, s2)) = L1(s1) ∪ L2(s2); states without labels are left out of the map
        new_L = {}
        for new_state, (s1, s2) in pair_of.items():
            labels = self.L(s1) | ts.L(s2)
            if labels:
                new_L[new_state] = labels

        # Index the original transitions by (source, action) for fast lookup
        trans1, trans2 = {}, {}
        for s, a, t in self.Transitions:
            trans1.setdefault((s, a), set()).add(t)
        for s, a, t in ts.Transitions:
            trans2.setdefault((s, a), set()).add(t)

        new_Trans = set()
        for new_state, (s1, s2) in pair_of.items():
            # Independent moves of TS1 (a ∈ Act1 \ h): TS2 stays in s2
            for a in self.Act - h:
                for s1_next in trans1.get((s1, a), ()):
                    new_Trans.add((new_state, a, parts(s1_next) + parts(s2)))

            # Independent moves of TS2 (a ∈ Act2 \ h): TS1 stays in s1
            for a in ts.Act - h:
                for s2_next in trans2.get((s2, a), ()):
                    new_Trans.add((new_state, a, parts(s1) + parts(s2_next)))

            # Handshaking moves (a ∈ h): both components move together
            for a in h:
                for s1_next in trans1.get((s1, a), ()):
                    for s2_next in trans2.get((s2, a), ()):
                        new_Trans.add((new_state, a, parts(s1_next) + parts(s2_next)))

        return TransitionSystem(new_S, new_Act, new_Trans, new_I, new_AP, new_L)

    def __repr__(self) -> str:
        """
        Returns a string representation of the Transition System.

        :return: A formatted string representation of the TS.
        """
        return (
            f"TransitionSystem(\n"
            f"  States: {self.S}\n"
            f"  Actions: {self.Act}\n"
            f"  Transitions: {self.Transitions}\n"
            f"  Initial States: {self.I}\n"
            f"  Atomic Propositions: {self.AP}\n"
            f"  Labels: {self._L}\n"
            f")"
        )


    def plot(self, title: str = "Transition System", figsize: Tuple[int, int] = (10, 6)) -> None:
        """
        Plots the Transition System as a directed graph.

        :param title: Title of the plot.
        :param figsize: Figure size for the plot.
        """
        G = nx.DiGraph()

        # Add nodes (states)
        for state in self.S:
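            label = f"{state}\n{' '.join(sorted(self.L(state)))}" if self.L(state) else str(state)
            G.add_node(state, label=label, color="blue" if state in self.I else "yellow")

        # Add edges (transitions); a DiGraph keeps one edge per state pair,
        # so the actions of parallel transitions are merged into a single label
        for state_from, action, state_to in self.Transitions:
            if G.has_edge(state_from, state_to):
                G[state_from][state_to]["label"] += f", {action}"
            else:
                G.add_edge(state_from, state_to, label=action)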

        plt.figure(figsize=figsize)
        pos = nx.spring_layout(G)  # Positioning algorithm for layout

        # Draw nodes
        node_colors = [G.nodes[n]["color"] for n in G.nodes]
        nx.draw(G, pos, with_labels=True, labels=nx.get_node_attributes(G, "label"), node_color=node_colors, edgecolors="black", node_size=2000, font_size=10)

        # Draw edge labels (actions)
        edge_labels = {(u, v): d["label"] for u, v, d in G.edges(data=True)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9)

        plt.title(title)
        plt.show()


In [31]:
grader.check("q2")

## ♟️ Question 3: Solving N-Queens with Transition Systems


### 🧩 Task 1: Build the System

Your goal in this task is to model the **N-Queens problem** using `TransitionSystem` objects and interleaving.

Each queen must be placed in a unique row and must not threaten any other queen. You will model each queen as a separate transition system component and interleave them to explore all possible board configurations.

---

#### 🔧 Function 1: `queen_ts(row: int, n: int) -> TransitionSystem`

This function builds a transition system representing a **single queen in row `row`**. The queen can choose to be placed in any of the `n` columns.

- **States**: `"start"`, plus one state per column: `"q{row}{c}"`
- **Actions**: One per column: `"{row}->{c}"` (e.g., `"2->3"`)
- **Transitions**: From `"start"` to `"q{row}{c}"` via `"{row}->{c}"`
- **Atomic Propositions**: `"q{row}{c}"` for each column
- **Labeling**: `"q{row}{c}"` holds in state `"q{row}{c}"`
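The naming scheme concatenates the row and column digits without a separator, so names stay unique only while both indices are single digits. A quick check makes the collision concrete; `state_name` is a hypothetical helper that mirrors the f-string in the specification.

```python
def state_name(row: int, c: int) -> str:
    # Mirrors the "q{row}{c}" format from the specification (hypothetical helper)
    return f"q{row}{c}"


print(state_name(2, 3))                        # q23
print(state_name(1, 11) == state_name(11, 1))  # True: both are "q111"
```

Because each queen's row also equals its position in the interleaved state tuple, a parser can take the row from that position instead of guessing it from the name.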

---

#### 🔧 Function 2: `interleave_n_queens(n: int) -> TransitionSystem`

This function should:

- Construct `n` independent TSs (one per queen/row) using `queen_ts(...)`
- Interleave them into a single transition system using `.interleave(...)`
- No handshaking is required

The resulting transition system will explore all ways to place `n` queens on an `n×n` board, one per row.
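Without handshaking, every component moves independently, so every combination of component states is reachable: each queen is either at `"start"` or in one of its `n` columns. That gives (n+1)^n composite states, of which n^n are complete placements. The sketch below counts them with `itertools.product` on plain tuples, without building a `TransitionSystem`; `state_counts` is a hypothetical helper.

```python
from itertools import product


def state_counts(n: int):
    """Count composite states of the n-queen interleaving (hypothetical helper)."""
    # Each component is either "start" or one of the n column states "q{row}{c}"
    components = [["start"] + [f"q{row}{c}" for c in range(n)] for row in range(n)]
    all_states = list(product(*components))
    complete = [s for s in all_states if "start" not in s]
    return len(all_states), len(complete)


for n in range(1, 6):
    total, complete = state_counts(n)
    print(n, total, complete)  # e.g. n = 4 gives 625 total and 256 complete
```

The exponential growth is why building the full interleaving is practical only for small `n`.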

---

#### 🔧 Function 3: `extract_valid_n_queen_states(ts: TransitionSystem) -> List[Tuple[str]]`

After building the interleaved TS:

- Use `.reach()` to find all reachable states
- Keep only the states in which every queen has been placed (i.e., no component is `"start"`)
- For each state, extract column indices and check for diagonal or column conflicts using `is_safe(...)`

Return all valid board configurations.
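Extracting column indices can rely on the tuple position for the row, assuming the component order produced by interleaving queens `0..n-1` in sequence. A minimal sketch; `columns_of` is a hypothetical helper that returns `None` for incomplete or malformed states.

```python
from typing import List, Optional, Tuple


def columns_of(state: Tuple[str, ...]) -> Optional[List[int]]:
    """Return the column of each queen, or None if the state is incomplete or malformed."""
    columns = []
    for row, name in enumerate(state):
        prefix = f"q{row}"  # the queen at tuple position `row` sits in row `row`
        suffix = name[len(prefix):]
        if not name.startswith(prefix) or not suffix.isdigit():
            return None  # "start" or an unexpected name
        columns.append(int(suffix))
    return columns


print(columns_of(("q01", "q13", "q20", "q32")))    # [1, 3, 0, 2]
print(columns_of(("q01", "start", "q20", "q32")))  # None
```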

---

#### 🧠 Helper: `is_safe(positions: List[int]) -> bool`

This utility function receives a list of queen column positions, where `positions[r]` is the column of the queen in row `r`, and checks that:

- No queens are in the same column
- No queens are on the same diagonal

Use this to validate candidate states.
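Two queens at `(r1, c1)` and `(r2, c2)` share a diagonal exactly when `|r1 - r2| == |c1 - c2|`. A compact sketch of the check over all row pairs with `itertools.combinations`; `is_safe_sketch` is a hypothetical name chosen so it does not shadow your `is_safe`.

```python
from itertools import combinations
from typing import List


def is_safe_sketch(positions: List[int]) -> bool:
    """positions[r] is the column of the queen in row r (one queen per row by construction)."""
    # Same column: a repeated value in positions
    if len(set(positions)) != len(positions):
        return False
    # Same diagonal: row distance equals column distance
    return all(abs(r1 - r2) != abs(positions[r1] - positions[r2])
               for r1, r2 in combinations(range(len(positions)), 2))


print(is_safe_sketch([1, 3, 0, 2]))  # True
print(is_safe_sketch([0, 2, 3, 1]))  # False: rows 1 and 2 are diagonal neighbours
```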

---


### 💡 Example

For `n = 4`, your solution should yield exactly 2 valid solutions:

```python
[
    ('q01', 'q13', 'q20', 'q32'),
    ('q02', 'q10', 'q23', 'q31')
]
```
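To cross-check this expected output independently of transition systems, a plain brute force over column permutations works for small `n`. This swaps in a different technique (permutation search, not the interleaving the assignment asks for); `n_queens_solutions` is a hypothetical helper that formats results in the same `"q{row}{c}"` style.

```python
from itertools import combinations, permutations


def n_queens_solutions(n: int):
    """Brute force over column permutations (one queen per row and per column)."""
    solutions = []
    for cols in permutations(range(n)):
        # Permutations already rule out shared columns; only diagonals remain to check
        if all(abs(r1 - r2) != abs(cols[r1] - cols[r2])
               for r1, r2 in combinations(range(n), 2)):
            solutions.append(tuple(f"q{r}{c}" for r, c in enumerate(cols)))
    return sorted(solutions)


print(n_queens_solutions(4))
# [('q01', 'q13', 'q20', 'q32'), ('q02', 'q10', 'q23', 'q31')]
```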

In [32]:
def queen_ts(row: int, n: int) -> TransitionSystem:
    """
    Build a TS for a queen in row `row`, choosing a column from 0 to n-1.
    """
    states = {"start"}
    actions = set()
    transitions = set()
    initial_states = {"start"}
    atomic_props = set()
    labeling_map = {"start": set()} # Start state has no labels initially

    for c in range(n):
        state_name = f"q{row}{c}"
        action_name = f"{row}->{c}"
        ap_name = state_name # Atomic proposition is the state name

        states.add(state_name)
        actions.add(action_name)
        atomic_props.add(ap_name)

        # Transition from start to the chosen column state
        transitions.add(("start", action_name, state_name))

        # Label the state with its corresponding atomic proposition
        labeling_map[state_name] = {ap_name}

    return TransitionSystem(
        states=states,
        actions=actions,
        transitions=transitions,
        initial_states=initial_states,
        atomic_props=atomic_props,
        labeling_map=labeling_map
    )



def interleave_n_queens(n: int) -> TransitionSystem:
    """
    Interleave N TSs for N queens placed on an NxN board.
    Each TS represents a queen choosing one column.
    """
    if n <= 0:
        # No queens: return an empty TS
        return TransitionSystem()

    # Start with the first queen's TS (for n = 1 this is already the full system)
    interleaved_ts = queen_ts(0, n)

    # Interleave the remaining queens one by one
    for r in range(1, n):
        current_queen_ts = queen_ts(r, n)
        # Interleave with the result so far, no handshaking (h=empty set)
        interleaved_ts = interleaved_ts.interleave(current_queen_ts, h=frozenset())

    return interleaved_ts


def is_safe(positions: List[int]) -> bool:
    """
    Check if a list of queen column positions is a valid N-Queens solution.
    """
    n = len(positions)
    if len(set(positions)) != n:
        return False  # Check for column conflicts (duplicates)

    for r1 in range(n):
        for r2 in range(r1 + 1, n):
            c1 = positions[r1]
            c2 = positions[r2]
            # Check for diagonal conflicts
            if abs(r1 - r2) == abs(c1 - c2):
                return False
    return True



def extract_valid_n_queen_states(ts: TransitionSystem) -> List[Tuple[str]]:
    """
    Given the interleaved TS, return all reachable states that represent valid N-Queens solutions.

    Assumes the component order produced by `interleave_n_queens`: position r of a state tuple
    holds the state of the queen in row r, named "q{r}{c}".
    """
    valid_solutions = []

    for state in ts.reach():
        # A single queen (n = 1) has plain string states; treat them as 1-tuples
        state_tuple = state if isinstance(state, tuple) else (state,)

        # Keep only complete placements: every queen has left "start"
        if "start" in state_tuple:
            continue

        # Parse each queen's column. The row is known from the tuple position,
        # which keeps "q{row}{c}" unambiguous even with multi-digit indices.
        positions = []
        for row, queen_state in enumerate(state_tuple):
            prefix = f"q{row}"
            if not (isinstance(queen_state, str) and queen_state.startswith(prefix)):
                break  # Unexpected state format: skip this state
            try:
                positions.append(int(queen_state[len(prefix):]))
            except ValueError:
                break
        else:
            # Every row parsed; keep the state if no two queens attack each other
            if is_safe(positions):
                valid_solutions.append(state_tuple)

    # Sort the solutions for deterministic output
    return sorted(valid_solutions)

In [33]:
grader.check("q3")

## Submission

Make sure you have run all cells in your notebook in order before running the cell below, so that all images/graphs appear in the output. The cell below will generate a zip file for you to submit. **Please save before exporting!**

In [34]:
# Save your notebook first, then run this cell to export your submission.
grader.export(pdf=False)