In [26]:
import networkx as nx
import matplotlib.pyplot as plt
from typing import Set, Dict, Tuple, Union, Optional

State = Union[str, Tuple]  # A state is a string or a tuple, e.g. (location, environment)
Action = str  # Actions are represented as strings
Transition = Tuple[State, Action, State]  # (source_state, action, target_state)
LabelingMap = Dict[State, Set[str]]  # Maps states to the atomic propositions that hold in them


class TransitionSystem:
    """
    A Transition System (TS) representation ``(S, Act, ->, I, AP, L)``.

    Attributes:
        S (Set[State]): The set of all states (strings or tuples).
        Act (Set[Action]): The set of all possible actions.
        Transitions (Set[Transition]): The set of transitions, each represented as (state_origin, action, state_target).
        I (Set[State]): The set of initial states.
        AP (Set[str]): The set of atomic propositions.
        _L (LabelingMap): A dictionary mapping states to their respective atomic propositions.

    Because states may themselves be tuples, methods that accept "a state or a collection
    of states" treat only ``set``, ``frozenset`` and ``list`` values as collections; any
    other value (including a tuple) is treated as a single state.
    """

    def __init__(
        self,
        states: Optional[Set[State]] = None,
        actions: Optional[Set[Action]] = None,
        transitions: Optional[Set[Transition]] = None,
        initial_states: Optional[Set[State]] = None,
        atomic_props: Optional[Set[str]] = None,
        labeling_map: Optional[LabelingMap] = None,
    ) -> None:
        """
        Initializes the Transition System.

        The given collections are shallow-copied. No consistency checks are performed
        here; the ``add_*`` methods validate their arguments.

        :param states: A set of states (each a string or a tuple). Defaults to an empty set.
        :param actions: A set of actions. Defaults to an empty set.
        :param transitions: A set of transitions, each as (state_origin, action, state_target). Defaults to an empty set.
        :param initial_states: A set of initial states. Defaults to an empty set.
        :param atomic_props: A set of atomic propositions. Defaults to an empty set.
        :param labeling_map: A dictionary mapping states to sets of atomic propositions. Defaults to an empty dictionary.
        """
        self.S: Set[State] = set(states) if states is not None else set()
        self.Act: Set[Action] = set(actions) if actions is not None else set()
        self.Transitions: Set[Transition] = set(transitions) if transitions is not None else set()
        self.I: Set[State] = set(initial_states) if initial_states is not None else set()
        self.AP: Set[str] = set(atomic_props) if atomic_props is not None else set()
        self._L: LabelingMap = dict(labeling_map) if labeling_map is not None else {}

    def add_state(self, *states: State) -> "TransitionSystem":
        """
        Adds one or more states to the transition system.

        :param states: One or more states (strings or tuples) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        self.S.update(states)
        return self

    def add_action(self, *actions: Action) -> "TransitionSystem":
        """
        Adds one or more actions to the transition system.

        :param actions: One or more actions (strings) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        self.Act.update(actions)
        return self

    def add_transition(self, *transitions: Transition) -> "TransitionSystem":
        """
        Adds one or more transitions to the transition system.

        Each transition must be provided as a tuple of the form `(state_from, action, state_to)`, where:
        - `state_from` is the source state.
        - `action` is the action performed.
        - `state_to` is the resulting state.

        All transitions are validated before any is added, so a call that raises leaves
        the system unchanged.

        :param transitions: One or more transitions, each as a tuple `(state_from, action, state_to)`.
        :raises ValueError:
            - If a transition is not a tuple of length 3.
            - If `state_from` or `state_to` does not exist in `self.S`.
            - If `action` is not in `self.Act`.
        :return: The `TransitionSystem` instance (for method chaining).
        """
        for transition in transitions:
            if not isinstance(transition, tuple) or len(transition) != 3:
                raise ValueError("Each transition must be a tuple of length 3.")
            state_from, action, state_to = transition
            if state_from not in self.S:
                raise ValueError(f"State {state_from} is not in the transition system.")
            if state_to not in self.S:
                raise ValueError(f"State {state_to} is not in the transition system.")
            if action not in self.Act:
                raise ValueError(f"Action {action} is not in the transition system.")
        self.Transitions.update(transitions)
        return self

    def add_initial_state(self, *states: State) -> "TransitionSystem":
        """
        Adds one or more states to the set of initial states.

        :param states: One or more states to be marked as initial.
        :raises ValueError: If any state does not exist in the system (nothing is added then).
        :return: The TransitionSystem instance (for method chaining).
        """
        for state in states:
            if state not in self.S:
                raise ValueError(f"Initial state {state} must be in the transition system.")
        self.I.update(states)
        return self

    def add_atomic_proposition(self, *props: str) -> "TransitionSystem":
        """
        Adds one or more atomic propositions to the transition system.

        :param props: One or more atomic propositions (strings) to be added.
        :return: The TransitionSystem instance (for method chaining).
        """
        self.AP.update(props)
        return self

    def add_label(self, state: State, *labels: str) -> "TransitionSystem":
        """
        Adds one or more atomic propositions to a given state.

        :param state: The state to label.
        :param labels: One or more atomic propositions to be assigned to the state.
        :raises ValueError: If the state is not in the system or if any label is not a valid
            atomic proposition. In that case the labeling is left untouched.
        :return: The TransitionSystem instance (for method chaining).
        """
        if state not in self.S:
            raise ValueError(f"Cannot set labels for {state}. State is not in the transition system.")
        invalid = set(labels) - self.AP
        if invalid:
            raise ValueError(
                f"Cannot assign labels {invalid}. They are not in the set of atomic propositions (AP)."
            )
        # Create the entry only after validation so a failed call leaves no trace.
        self._L.setdefault(state, set()).update(labels)
        return self

    def L(self, state: State) -> Set[str]:
        """
        Retrieves the set of atomic propositions that hold in a given state.

        :param state: The state whose atomic propositions are being retrieved.
        :raises ValueError: If the state is not in the transition system.
        :return: A new set with the atomic propositions of the state (empty if unlabeled).
        """
        if state not in self.S:
            raise ValueError(f"State {state} is not in the transition system.")
        # Return a copy: mutating the internal set would bypass add_label's AP check.
        return set(self._L.get(state, ()))

    @staticmethod
    def _as_state_set(S: Union[State, Set[State]]) -> Set[State]:
        """Normalize a single state or a collection of states into a set of states."""
        # Tuples are valid states (e.g. (location, environment)), so only explicit
        # collection types are expanded; everything else is one state.
        if isinstance(S, (set, frozenset, list)):
            return set(S)
        return {S}

    def pre(self, S: Union[State, Set[State]], action: Optional[Action] = None) -> Set[State]:
        """
        Computes the set of predecessor states from which a given state or set of states can be reached.

        :param S: A single state (string/tuple) or a collection of states (set, frozenset or list).
        :param action: (Optional) If provided, filters only the transitions that use this action.
        :return: A set of predecessor states.
        """
        targets = self._as_state_set(S)
        return {
            state_from
            for state_from, action_taken, state_to in self.Transitions
            if state_to in targets and (action is None or action_taken == action)
        }

    def post(self, S: Union[State, Set[State]], action: Optional[Action] = None) -> Set[State]:
        """
        Computes the set of successor states reachable from a given state or a collection of states.

        :param S: A single state (string/tuple) or a collection of states (set, frozenset or list).
        :param action: (Optional) Filters transitions by this action.
        :return: A set of successor states.
        """
        sources = self._as_state_set(S)
        return {
            state_to
            for state_from, action_taken, state_to in self.Transitions
            if state_from in sources and (action is None or action_taken == action)
        }

    def reach(self) -> Set[State]:
        """
        Computes the set of all reachable states from the initial states (breadth-first).

        :return: A set of reachable states.
        """
        reachable: Set[State] = set(self.I)
        frontier: Set[State] = set(self.I)
        while frontier:
            # Expand only states not seen before; otherwise any cycle keeps the
            # frontier non-empty forever.
            frontier = self.post(frontier) - reachable
            reachable |= frontier
        return reachable

    def is_action_deterministic(self) -> bool:
        """
        Checks whether the transition system is action-deterministic.

        A transition system is action-deterministic if:
        - It has at most one initial state.
        - For each state and action, there is at most one successor state.

        Only transitions whose source is in ``S`` and whose action is in ``Act`` are considered.

        :return: True if the transition system is action-deterministic, False otherwise.
        """
        if len(self.I) > 1:
            return False
        successors = {}
        for state_from, action, state_to in self.Transitions:
            if state_from in self.S and action in self.Act:
                successors.setdefault((state_from, action), set()).add(state_to)
        return all(len(targets) <= 1 for targets in successors.values())

    def is_label_deterministic(self) -> bool:
        """
        Checks whether the transition system is label-deterministic (AP-deterministic).

        A transition system is label-deterministic if:
        - It has at most one initial state.
        - For each state, no two of its successor states carry the same label set.

        :raises ValueError: If a successor state is not in ``S`` (raised by ``L``).
        :return: True if the transition system is label-deterministic, False otherwise.
        """
        if len(self.I) > 1:
            return False
        successors = {}
        for state_from, _, state_to in self.Transitions:
            successors.setdefault(state_from, set()).add(state_to)
        for state in self.S:
            targets = successors.get(state, set())
            label_sets = {frozenset(self.L(s)) for s in targets}
            if len(label_sets) != len(targets):
                return False
        return True

    def __repr__(self) -> str:
        """
        Returns a string representation of the Transition System.

        :return: A formatted string representation of the TS.
        """
        return (
            f"TransitionSystem(\n"
            f"  States: {self.S}\n"
            f"  Actions: {self.Act}\n"
            f"  Transitions: {len(self.Transitions)}\n"
            f"  Initial States: {self.I}\n"
            f"  Atomic Propositions: {self.AP}\n"
            f"  Labels: {self._L}\n"
            f")"
        )


    def plot(self, title: str = "Transition System", figsize: Tuple[int, int] = (10, 6)) -> None:
        """
        Plots the Transition System as a directed graph.

        Initial states are drawn in blue and all other states in yellow. Each node shows the
        state and its (sorted) labels. Each edge shows every action of the transitions
        between its two states. ``plt.show()`` is not called; the figure is left open.

        :param title: Title of the plot.
        :param figsize: Figure size for the plot.
        """
        G = nx.DiGraph()

        # Add nodes (states)
        for state in self.S:
            state_labels = self.L(state)
            text = f"{state}\n{' '.join(sorted(state_labels))}" if state_labels else str(state)
            G.add_node(state, label=text, color="blue" if state in self.I else "yellow")

        # A DiGraph keeps one edge per (source, target) pair, so gather all actions
        # between the same two states instead of letting the last one overwrite the rest.
        edge_actions = {}
        for state_from, action, state_to in self.Transitions:
            edge_actions.setdefault((state_from, state_to), set()).add(str(action))
        for (state_from, state_to), actions in edge_actions.items():
            G.add_edge(state_from, state_to, label=", ".join(sorted(actions)))

        plt.figure(figsize=figsize)
        pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout

        # Draw nodes (states added only through transitions get the non-initial color)
        node_colors = [G.nodes[n].get("color", "yellow") for n in G.nodes]
        nx.draw(G, pos, with_labels=True, labels=nx.get_node_attributes(G, "label"), node_color=node_colors, edgecolors="black", node_size=2000, font_size=10)

        # Draw edge labels (actions)
        edge_labels = nx.get_edge_attributes(G, "label")
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=9)

        plt.title(title)

# HW 2: Program Graph $\rightarrow$ Transition System

In this exercise, we will implement the formal concept of a Program Graph (PG) as we defined in class using Python classes. We will also implement the `to_transition_system()` method to convert a program graph into a transition system.


In [27]:
import sys
import otter
# Requires the `otter-grader` package; if `import otter` fails, install it first
# with `%pip install otter-grader`.
grader = otter.Notebook("HW2.ipynb")

## Challenge 1: Implementing the `ProgramGraph` class

### **Formal Definition**
A **Program Graph** (PG) is a tuple: $(Loc, Act, Effect, \rightarrow, Loc_0, g_0)$
where:
- $Loc$ is a finite set of program locations (control points).
- $Act$ is a set of actions (instructions that modify variables).
- $Effect$ is a function that applies an action to an environment and produces a new environment.
- $\rightarrow \; \subseteq Loc \times Cond \times Act \times Loc$ is the set of transitions (edges). Each edge $(\ell, g, \alpha, \ell')$ consists of:
    - a condition (guard) $g \in Cond$ that must hold for the transition to be taken;
    - an action $\alpha \in Act$ that modifies the variables;
    - and represents moving from the **old location** $\ell$ to the **new location** $\ell'$ under condition $g$ and action $\alpha$.
- $Loc_0 \subseteq Loc$ is a finite set of initial locations.
- $g_0$ is an initial condition on the program variables.
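- In addition to the components of the tuple, $Eval$ is a function that evaluates conditions on the variables. Together with $Effect$, it determines how the program graph behaves.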


### Example

Let’s define a **simple program graph** for a counter that increments from 0 to 2.

- $Loc = \{L_0, L_1, L_2\}$
- $Loc_0 = \{L_0\}$
- $Act = \{x += 1\}$
- $\rightarrow = \{(L_0, x < 2, x += 1, L_1), (L_1, x < 2, x += 1, L_2)\}$
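- $Effect(x += 1, \eta) = \eta[x \mapsto \eta(x) + 1]$ (increment $x$, leave all other variables unchanged)
- $g_0 = (x = 0)$, i.e. every run starts with $x = 0$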

### Use `HashableDict`


In this exercise we will represent environments (variable assignments) as part of the states in a **Transition System**, which requires them to be **hashable** (so they can be used in sets and dictionaries). However, **Python dictionaries are mutable and not hashable by default**.

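By defining `HashableDict`, which inherits from `dict` and implements `__hash__` using `frozenset(self.items())`, we create a **hashable representation** of environments while preserving dictionary-like behavior for easy lookups. Note that it is *not* truly immutable: it still inherits `dict`'s mutating methods. An environment must therefore not be modified after it has been placed in a set or used as a dictionary key; copy it first.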


In [28]:
class HashableDict(dict):
    """
    A ``dict`` subclass that can be hashed, so environments can be part of set-based states.

    The hash is built from the (key, value) pairs, so every value must itself be hashable.
    The mapping remains mutable: do not modify an instance after it has been put in a set
    or used as a dictionary key, or later lookups will silently fail.
    """

    def __hash__(self):
        item_set = frozenset(self.items())
        return hash(item_set)

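Your task is to implement the `ProgramGraph` class in Python. The class should have the following methods: `add_location`, `add_action`, `add_transition`, `add_initial_location`, `set_eval_fn`, `set_effect_fn`, `eval`, `effect`, `valid_transitions`, `to_transition_system`, and `plot`.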

In [29]:
from typing import Callable, Set, Tuple, Dict, List, Union, Optional
import matplotlib.pyplot as plt
import networkx as nx

# Add your imports here, don't forget to include your TransitionSystem class
# from transition_system import TransitionSystem
from itertools import product
from collections import deque
from copy import deepcopy

Location = Union[str, Tuple]  # A program location: a string (e.g. 'L0') or a tuple
Action = str  # Actions are statements represented as strings (e.g. 'x += 1')
Condition = str  # Conditions (guards) are expressions represented as strings (e.g. 'x < 2')
# NOTE: this rebinds the module-level `Transition` alias (a 3-tuple in the TransitionSystem cell)
# to the 4-tuple program-graph edge form.
Transition = Tuple[Location, Condition, Action, Location]  # (source_location, condition, action, target_location)
Environment = Dict[str, Union[str, bool, int, float]]  # Variable assignments: variable name -> value


class ProgramGraph:
    """
    A Program Graph (PG) ``(Loc, Act, Effect, ->, Loc0, g0)``.

    Conditions and actions are opaque strings. Their meaning is supplied by two user
    functions: ``eval_fn(condition, env) -> bool`` and ``effect_fn(action, env) -> env``.

    Attributes:
        Loc: Set of program locations.
        Loc0: Set of initial locations.
        Act: Set of actions.
        Transitions: Set of edges ``(loc_from, condition, action, loc_to)``.
        eval_fn: Function evaluating a condition string in an environment.
        effect_fn: Function applying an action string to an environment.
        g0: Initial condition that every initial environment must satisfy.
    """

    def __init__(
        self,
        locations: Set[Location],
        initial_locations: Set[Location],
        actions: Set[Action],
        transitions: Set[Transition],
        eval_fn: Callable[[Condition, Environment], bool],
        effect_fn: Callable[[Action, Environment], Environment],
        g0: Condition
    ):
        """
        A representation of a Program Graph.

        The given collections are copied. No consistency checks are performed here;
        the ``add_*`` methods validate their arguments.

        :param locations: Set of all program locations (Loc).
        :param initial_locations: Set of initial locations (Loc0).
        :param actions: Set of possible actions (Act).
        :param transitions: Set of transitions in the form (loc_from, condition, action, loc_to).
        :param eval_fn: Function to evaluate a condition string in an environment.
        :param effect_fn: Function to compute the new environment after applying an action.
        :param g0: Initial condition string for filtering valid starting environments.
        """
        self.Loc = set(locations)
        self.Loc0 = set(initial_locations)
        self.Act = set(actions)
        self.Transitions = set(transitions)
        self.eval_fn = eval_fn
        self.effect_fn = effect_fn
        self.g0 = g0

    def add_location(self, *locations: Location) -> "ProgramGraph":
        """Add one or more locations to the program graph."""
        self.Loc.update(locations)
        return self

    def add_action(self, *actions: Action) -> "ProgramGraph":
        """Add one or more actions to the program graph."""
        self.Act.update(actions)
        return self

    def add_transition(self, *transitions: Transition) -> "ProgramGraph":
        """
        Add one or more transitions to the program graph.

        Each transition must be a tuple: (loc_from, condition, action, loc_to).
        All transitions are validated before any is added, so a call that raises leaves
        the program graph unchanged.

        :raises ValueError: If a transition is malformed or references an unknown
            location or action.
        """
        for transition in transitions:
            if not isinstance(transition, tuple) or len(transition) != 4:
                raise ValueError(f"Invalid transition format: {transition}. Expected (loc_from, cond, action, loc_to).")
            loc_from, _cond, action, loc_to = transition
            if loc_from not in self.Loc:
                raise ValueError(f"Location {loc_from} is not in the program graph.")
            if loc_to not in self.Loc:
                raise ValueError(f"Location {loc_to} is not in the program graph.")
            if action not in self.Act:
                raise ValueError(f"Action {action} is not in the program graph.")
        self.Transitions.update(transitions)
        return self

    def add_initial_location(self, *locations: Location) -> "ProgramGraph":
        """
        Add one or more initial locations to the program graph.

        :raises ValueError: If any location is not in ``Loc`` (nothing is added then).
        """
        for loc in locations:
            if loc not in self.Loc:
                raise ValueError(f"Cannot set initial location {loc}. Location is not in the set of locations.")
        self.Loc0.update(locations)
        return self

    def set_eval_fn(self, eval_fn: Callable[[Condition, Environment], bool]) -> "ProgramGraph":
        """Set the function used to evaluate conditions."""
        self.eval_fn = eval_fn
        return self

    def set_effect_fn(self, effect_fn: Callable[[Action, Environment], Environment]) -> "ProgramGraph":
        """Set the function used to apply actions to environments."""
        self.effect_fn = effect_fn
        return self

    def eval(self, condition: Condition, env: Environment) -> bool:
        """Evaluate a condition string in the given environment."""
        return self.eval_fn(condition, env)

    def effect(self, action: Action, env: Environment) -> Environment:
        """Apply an action to the environment and return the new environment."""
        return self.effect_fn(action, env)


    def valid_transitions(self, loc: Location, env: Environment, action: Action) -> List[Tuple[Location, Action, Location]]:
        """
        Return the transitions leaving ``loc`` that use ``action`` and whose condition holds in ``env``.

        :return: A list of ``(loc, action, loc_to)`` triples.
        """
        return [
            (src, act, dst)
            for (src, cond, act, dst) in self.Transitions
            if src == loc and act == action and self.eval(cond, env)
        ]

    def _outgoing_by_location(self) -> Dict[Location, List[Tuple[Condition, Action, Location]]]:
        """Index the edges by source location, keeping only edges whose action is in ``Act``."""
        outgoing: Dict[Location, List[Tuple[Condition, Action, Location]]] = {}
        for loc_from, cond, action, loc_to in self.Transitions:
            if action in self.Act:
                outgoing.setdefault(loc_from, []).append((cond, action, loc_to))
        return outgoing

    @staticmethod
    def _all_environments(domains: Dict[str, Set[Union[str, bool, int, float]]]) -> List[Environment]:
        """Enumerate every variable assignment (as a HashableDict) in the Cartesian product of ``domains``."""
        names = list(domains)
        return [
            HashableDict(zip(names, values))
            for values in product(*(domains[name] for name in names))
        ]

    def to_transition_system(self, vars: Dict[str, Set[Union[str, bool, int, float]]], labels: Set[Condition]) -> TransitionSystem:
        """
        Construct and return the Transition System induced by the program graph.

        States are pairs ``(location, environment)``, with the environment stored as a
        ``HashableDict``. Initial states pair every initial location with every
        environment from the Cartesian product of ``vars`` that satisfies ``g0``.

        Further states are discovered breadth-first, so only reachable states are added.
        A state ``(loc, env)`` gets a transition labelled ``action`` to
        ``(loc_to, effect(action, env))`` for every edge ``(loc, cond, action, loc_to)``
        whose action is in ``Act`` and whose condition holds in ``env``. Successor
        environments are not restricted to the domains in ``vars``; whatever
        ``effect_fn`` returns becomes part of the state.

        Each state is labelled with its location plus every condition in ``labels``
        that holds in its environment.

        :param vars: A dictionary mapping variable names to their finite sets of possible values.
        :param labels: A set of atomic proposition strings to be used for labeling states.
        :return: A TransitionSystem instance corresponding to the program graph.
        """
        # Materialize once: `labels` is iterated twice below, which would break for a generator.
        labels = set(labels)
        ts = TransitionSystem()

        ts.add_action(*self.Act)
        # Atomic propositions: the user-supplied condition labels plus one per location name.
        ts.add_atomic_proposition(*labels)
        ts.add_atomic_proposition(*self.Loc)

        # Index edges by source once instead of rescanning all edges per state and action.
        outgoing = self._outgoing_by_location()

        # Evaluate g0 once per environment, independently of the initial location.
        initial_envs = [env for env in self._all_environments(vars) if self.eval(self.g0, env)]

        queue = deque()
        for loc in self.Loc0:
            for env in initial_envs:
                state = (loc, env)
                ts.add_state(state)
                ts.add_initial_state(state)
                queue.append(state)

        visited = set()
        while queue:
            current = queue.popleft()
            if current in visited:
                continue
            visited.add(current)

            loc, env = current
            ts.add_label(current, loc)
            for label in labels:
                if self.eval(label, env):
                    ts.add_label(current, label)

            for cond, action, dst in outgoing.get(loc, ()):
                if not self.eval(cond, env):
                    continue
                # deepcopy guards against effect functions that mutate their input in place.
                new_env = HashableDict(self.effect(action, deepcopy(env)))
                new_state = (dst, new_env)
                ts.add_state(new_state)
                ts.add_transition((current, action, new_state))
                if new_state not in visited:
                    queue.append(new_state)

        return ts

    def __repr__(self):
        return (
            f"ProgramGraph(\n"
            f"  Loc: {self.Loc}\n"
            f"  Act: {self.Act}\n"
            f"  Transitions: {self.Transitions}\n"
            f"  Loc0: {self.Loc0}\n"
            f"  g0: {self.g0}\n"
            f")"
        )

    def plot(self, title: str = "Program Graph", figsize: Tuple[int, int] = (10, 6)) -> None:
        """
        Visualize the program graph as a directed graph using networkx and matplotlib.

        Nodes are locations (initial locations in light blue, others in white). Edges are
        labeled ``condition / action``. When several edges join the same pair of
        locations, their labels are stacked on separate lines.

        :param title: Title of the plot.
        :param figsize: Size of the new figure the graph is drawn in.
        """
        G = nx.MultiDiGraph()

        # Add nodes
        for loc in self.Loc:
            G.add_node(loc, color='lightblue' if loc in self.Loc0 else 'white')

        # Add edges with (condition, action) labels
        for (loc_from, cond, action, loc_to) in self.Transitions:
            G.add_edge(loc_from, loc_to, label=f"{cond} / {action}")

        # Draw into a fresh figure rather than whichever figure happens to be current.
        plt.figure(figsize=figsize)
        pos = nx.spring_layout(G, seed=42)  # consistent layout

        node_colors = [G.nodes[n].get('color', 'white') for n in G.nodes]
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, edgecolors='black')
        nx.draw_networkx_labels(G, pos)
        nx.draw_networkx_edges(G, pos, arrows=True)

        # Parallel edges share one (u, v) label slot; merge their texts instead of
        # letting the last edge overwrite the others.
        grouped = {}
        for u, v, data in G.edges(data=True):
            grouped.setdefault((u, v), []).append(data['label'])
        edge_labels = {edge: "\n".join(sorted(texts)) for edge, texts in grouped.items()}
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=8)

        plt.title(title)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

## Example: Program Graph to Transition System Conversion

In the cell below, we define the helper functions `eval_fn` and `effect_fn` using Python's built-in `eval` and `exec`. These functions are used to evaluate conditions and apply actions within a given environment. We then create a simple program graph for a counter that increments from 0 to 2 and convert it to a transition system.

In [30]:
# Define helper functions for evaluation and effects
def eval_fn(cond, env):
    """
    Evaluate a condition string in a given environment dictionary.

    NOTE: uses Python's ``eval``; only pass trusted, model-authored condition strings.

    :param cond: A Python expression over the environment's variables, e.g. ``'x < 2'``.
    :param env: Mapping from variable names to values, used as the local namespace.
    :return: The truth value of ``cond`` as a ``bool``; ``False`` if it cannot be
             evaluated (e.g. it is malformed or uses a variable missing from ``env``).
    """
    try:
        return bool(eval(cond, {}, env))
    except Exception:
        # Best effort: an unevaluable guard simply does not hold. Catching Exception
        # (not a bare `except:`) lets KeyboardInterrupt/SystemExit propagate.
        return False

def effect_fn(action, env):
    """
    Apply an action (a statement that modifies variables) and return a new environment.

    NOTE: uses Python's ``exec``; only pass trusted, model-authored action strings.

    :param action: A Python statement such as ``'x += 1'`` or ``'tmp_1 = x'``.
    :param env: Mapping from variable names to values; it is not modified.
    :return: A copy of ``env`` with the action's assignments applied. If the action
             raises (e.g. it reads an undefined variable), an unchanged copy is returned.
    """
    new_env = env.copy()
    try:
        exec(action, {}, new_env)
    except Exception:
        # Best effort: a failing action leaves the environment unchanged. Return a fresh
        # copy so a partially executed multi-statement action does not leak through.
        return env.copy()
    return new_env

# Variable domain and the guards we want to observe as atomic propositions.
vars = {'x': {0, 1, 2}}  # x ranges over {0, 1, 2}
labels = {'x < 2', 'x == 2'}

# Counter program graph: starting from x == 0, each guarded edge (x < 2) increments x,
# walking through the locations L0 -> L1 -> L2.
pg = ProgramGraph(
    locations={'L0', 'L1', 'L2'},
    initial_locations={'L0'},
    actions={'x += 1'},
    transitions={
        ('L0', 'x < 2', 'x += 1', 'L1'),
        ('L1', 'x < 2', 'x += 1', 'L2'),
    },
    eval_fn=eval_fn,
    effect_fn=effect_fn,
    g0='x == 0',
)

# Unfold the program graph into its transition system and show the result.
ts = pg.to_transition_system(vars, labels)
print("States:", ts.S)
print("\nTransitions:")
for transition in ts.Transitions:
    print(transition)

States: {('L2', {'x': 2}), ('L1', {'x': 1}), ('L0', {'x': 0})}

Transitions:
(('L0', {'x': 0}), 'x += 1', ('L1', {'x': 1}))
(('L1', {'x': 1}), 'x += 1', ('L2', {'x': 2}))


In [31]:
# Run the otter-grader check for question 1 (the ProgramGraph implementation).
grader.check("q1")

## 🧠 Question 2: Modeling a Multi-Threaded Counter with a Program Graph

### 🔍 Objective
In this task, you will implement a function that constructs a **Program Graph** representing a simplified **multi-threaded counter protocol**.

Each of the `n_threads` threads executes a fixed four-step sequence to increment a shared variable `x`. The increment operation is broken into atomic actions using a temporary variable `tmp_i` and a counter variable `count_i` for each thread:

1. `count_i += 1` – mark the start of a new iteration
2. `tmp_i = x` – read
3. `tmp_i += 1` – increment
4. `x = tmp_i` – write

The goal is to repeat this protocol **up to `n_repeats` times per thread**. Once a thread completes all 4 steps, it returns to step 1. It may start another iteration **only if `count_i < n_repeats`**.



### **🧩 Task 1: Program Graph Definition**

You are required to implement the function:

```python
def counter_program_graph(n_threads: int, n_repeats: int) -> ProgramGraph:
    ...
```

The function should return a valid `ProgramGraph` object with the following properties:

#### 🔹 Locations (`Loc`)
- Each location is a string of length `n_threads`, where each character is in `{1, 2, 3, 4}`.
- The `i`-th character indicates the program counter of thread `i`, representing its progress in the 4-step protocol.
- All combinations of `{1, 2, 3, 4}^n_threads` should be included.

#### 🔹 Initial Location (`Loc0`)
- All threads start at step 1: `'1' * n_threads`

#### 🔹 Actions (`Act`)
Each thread has exactly 4 actions:
1. `count_i += 1` (track iteration count)
2. `tmp_i = x` (read)
3. `tmp_i += 1` (increment)
4. `x = tmp_i` (write)

Total actions: `4 * n_threads`

#### 🔹 Transitions (`Transitions`)
- If a thread is at step 1:
  - It may perform `count_i += 1` and move to step 2 only **if `count_i < n_repeats`** holds. The guard is checked before the increment.
- If a thread is at step 2 or 3:
  - It progresses unconditionally using its corresponding action.
- If a thread is at step 4:
  - It always performs `x = tmp_i` and then loops back to step 1.

Each transition updates the corresponding thread's program counter by one (or loops to 1 after step 4).

#### 🔹 Initial Condition (`g0`)
All variables are initialized to 0:
```
x == 0 and tmp_1 == 0 and count_1 == 0 and tmp_2 == 0 and count_2 == 0 ... and tmp_n == 0 and count_n == 0
```

---



### 🔧 Example

For `n_threads = 2` and `n_repeats = 3`, the constructed program graph will include:

#### Locations
All combinations of two digits from `{1, 2, 3, 4}`:
```
{'11', '12', '13', '14', '21', '22', '23', '24', '31', '32', '33', '34', '41', '42', '43', '44'}
```

Each digit represents the program counter (PC) of a thread.
For example, `'24'` means:
- Thread 1 is at PC = 2
- Thread 2 is at PC = 4

#### Actions
Each thread has 4 actions. For 2 threads:
```
[
  'count_1 += 1', 'tmp_1 = x', 'tmp_1 += 1', 'x = tmp_1',
  'count_2 += 1', 'tmp_2 = x', 'tmp_2 += 1', 'x = tmp_2'
]
```

#### Initial Location
```python
'11'
```
Both threads start at PC = 1.

#### Transitions
Each thread moves forward in its 4-step sequence or loops back to 1:

- `'11'` → `'21'` via action `'count_1 += 1'` **if `count_1 < 3`**
- `'21'` → `'31'` via action `'tmp_1 = x'`
- `'31'` → `'41'` via action `'tmp_1 += 1'`
- `'41'` → `'11'` via action `'x = tmp_1'`

These transitions apply similarly for thread 2. All transitions are independent and interleaved.

#### Initial Condition
All variables initialized to 0:
```python
x == 0 and tmp_1 == 0 and count_1 == 0 and tmp_2 == 0 and count_2 == 0
```



In [32]:
from itertools import product

def eval_fn(cond, env):
    """
    Evaluate a condition string in a given environment dictionary.

    NOTE: this redefines the identically named helper from the earlier example cell.
    It uses Python's ``eval``; only pass trusted, model-authored condition strings.

    :param cond: A Python expression over the environment's variables, e.g. ``'count_1 < 3'``.
    :param env: Mapping from variable names to values, used as the local namespace.
    :return: The truth value of ``cond`` as a ``bool``; ``False`` if it cannot be
             evaluated (e.g. it is malformed or uses a variable missing from ``env``).
    """
    try:
        return bool(eval(cond, {}, env))
    except Exception:
        # Best effort: an unevaluable guard simply does not hold. Catching Exception
        # (not a bare `except:`) lets KeyboardInterrupt/SystemExit propagate.
        return False

def effect_fn(action, env):
    """
    Apply an action (a statement that modifies variables) and return a new environment.

    NOTE: this redefines the identically named helper from the earlier example cell.
    It uses Python's ``exec``; only pass trusted, model-authored action strings.

    :param action: A Python statement such as ``'tmp_1 = x'`` or ``'x = tmp_1'``.
    :param env: Mapping from variable names to values; it is not modified.
    :return: A copy of ``env`` with the action's assignments applied. If the action
             raises (e.g. it reads an undefined variable), an unchanged copy is returned.
    """
    new_env = env.copy()
    try:
        exec(action, {}, new_env)
    except Exception:
        # Best effort: a failing action leaves the environment unchanged. Return a fresh
        # copy so a partially executed multi-statement action does not leak through.
        return env.copy()
    return new_env

def counter_program_graph(n_threads: int, n_repeats: int) -> ProgramGraph:
    """
    Build the program graph of `n_threads` threads that each increment a shared
    variable `x` `n_repeats` times via a non-atomic read / increment / write.

    A location is a string of `n_threads` digits from '1'-'4'; digit t is the
    program counter of thread t + 1. Each thread i cycles through four steps:

        1 --[count_i < n_repeats]  count_i += 1 --> 2
        2 --[True]                 tmp_i = x    --> 3
        3 --[True]                 tmp_i += 1   --> 4
        4 --[True]                 x = tmp_i    --> 1

    Every transition advances exactly one thread's digit and leaves the other
    digits unchanged, so the threads' steps are interleaved.

    :param n_threads: Number of threads (length of every location string).
    :param n_repeats: Number of loop iterations each thread performs.
    :return: ProgramGraph whose transitions are (loc, guard, action, loc_to)
             tuples, with initial location '1' * n_threads and initial condition
             x == 0 and tmp_i == 0 and count_i == 0 for every thread i.
    """
    # Locations: every string of length n_threads over the step digits '1'..'4'.
    locations = {''.join(p) for p in product('1234', repeat=n_threads)}
    initial_locations = {'1' * n_threads}

    actions = set()
    transitions = set()

    for t in range(n_threads):
        i = t + 1  # variables are 1-indexed: count_1, tmp_1, ...
        # Program counter -> (guard, action, next program counter).
        # Action strings follow the task specification ('count_1 += 1', ...).
        steps = {
            '1': (f"count_{i} < {n_repeats}", f"count_{i} += 1", '2'),
            '2': ("True", f"tmp_{i} = x", '3'),
            '3': ("True", f"tmp_{i} += 1", '4'),
            '4': ("True", f"x = tmp_{i}", '1'),
        }
        for loc in locations:
            guard, action, next_pc = steps[loc[t]]
            # Only thread t's digit changes; the other threads stay put.
            loc_to = loc[:t] + next_pc + loc[t + 1:]
            transitions.add((loc, guard, action, loc_to))
            actions.add(action)

    # Initial condition: x == 0 and, for every thread, tmp_i == 0 and count_i == 0.
    g0 = " and ".join(
        ["x == 0"]
        + [f"tmp_{i} == 0 and count_{i} == 0" for i in range(1, n_threads + 1)]
    )

    return ProgramGraph(
        locations=locations,
        initial_locations=initial_locations,
        actions=actions,
        transitions=transitions,
        eval_fn=eval_fn,
        effect_fn=effect_fn,
        g0=g0
    )


## 🧠 Task 2: Determining Final Values of `x`

### 🔍 Objective

Your task is to analyze the transition system generated from the program graph (created in Task 1) and determine **all possible values** the variable `x` may have when the program **terminates**.

A program execution is considered **terminated** when **all threads have returned to the first step** of their 4-step protocol and **no further transitions are possible** (i.e., the guard `count_i < n_repeats` is false for every thread `i`).

---

### 🎯 Goal

Implement the following function:

```python
def final_x_values(ts: TransitionSystem) -> List[int]:
    """
    Given a transition system created from the counter program graph,
    return a sorted list of all possible final values of x at program termination.
    """
```

The function should return a **sorted list** (from lowest to highest) of all values `x` can take in **terminal states**.

---

### 📌 Terminal State Criteria

A state `(location, environment)` is terminal if:

- The location is `'1' * n_threads` (i.e., all threads are at step 1), **and**
- For every thread `i`, the condition `count_i < n_repeats` evaluates to **False** in the given environment.

You should extract the value of `x` from each such terminal environment.

---

### 🧪 Example

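If `n_threads = 2` and `n_repeats = 2`, different interleavings of the two threads lead to different final values of `x`:

```python
[2, 3, 4]
```

This means the program can terminate with `x` equal to 2, 3, or 4, even though both threads always complete exactly `2` iterations. Termination requires `count_i == n_repeats` for every thread, so there is no early exit.
The smaller values come from *lost updates*: a thread writes back `tmp_i + 1` computed from a stale read of `x`, overwriting the other thread's increments. A final value of 1 is impossible for the following reasons:
- Every write stores a value ≥ 1.
- The thread that writes last performed its final read after its own first write, so that read returned `x ≥ 1`.
- That thread therefore writes `x ≥ 2`.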


In [33]:
def final_x_values(ts: TransitionSystem) -> List[int]:
    """
    Collect every value `x` can hold when the counter program terminates.

    A reachable state (location, env) is treated as terminal when:
    - its location is a string made only of '1's (every thread is back at
      step 1), and
    - ts.post(state) is empty, i.e. no thread can take another step.
    States whose location is not a string are skipped.

    :param ts: Transition system unfolded from the counter program graph.
    :return: The distinct `x` values over all terminal states, ascending.
    """
    def _is_terminal(location, state) -> bool:
        # One digit per thread, so '1' * len(location) == '1' * n_threads.
        return (
            isinstance(location, str)
            and location == '1' * len(location)
            and len(ts.post(state)) == 0
        )

    return sorted({
        environment["x"]
        for state in ts.reach()
        for location, environment in [state]
        if _is_terminal(location, state)
    })

In [34]:
def list_x_values(n_threads, n_repeats):
    """
    Build the counter program graph, unfold it into a transition system and
    return all final values of `x` (computed by final_x_values).

    :param n_threads: Number of concurrent threads.
    :param n_repeats: Number of increments each thread performs.
    :return: Sorted list of the values `x` can hold when the program terminates.
    """
    pg = counter_program_graph(n_threads=n_threads, n_repeats=n_repeats)

    # Each write stores (a previously read x) + 1 and there are at most
    # n_threads * n_repeats writes, so x and every tmp_i stay within
    # 0..n_threads * n_repeats.
    value_domain = range(n_threads * n_repeats + 1)
    # Named `var_domains` (not `vars`) to avoid shadowing the builtin vars().
    var_domains = {"x": set(value_domain)}
    for i in range(1, n_threads + 1):
        var_domains[f"tmp_{i}"] = set(value_domain)
        var_domains[f"count_{i}"] = set(range(n_repeats + 1))

    # No atomic propositions are needed for this analysis.
    ts = pg.to_transition_system(vars=var_domains, labels=set())

    return final_x_values(ts)

In [35]:
# Run the autograder checks for question "q2" (Task 2 above).
grader.check("q2")

## Submission

Make sure you have run all cells in your notebook in order before running the cell below, so that all images/graphs appear in the output. The cell below will generate a zip file for you to submit. **Please save before exporting!**

In [36]:
# Save your notebook first, then run this cell to export your submission.
# Generates the submission zip; pdf=False presumably skips PDF generation.
grader.export(pdf=False)