# State machine proof of concept

In [244]:
import copy
import logging
from asyncio import Protocol
from dataclasses import dataclass, field
from typing import Optional, Any, List
_logger = logging.getLogger(__name__)

The model we choose is one where the states are "static" – nothing changes when we are in a state. As such, the state has only one attribute: a name.

In [245]:
@dataclass(eq=True, frozen=True)
class State:
    """An immutable, named state.

    Being a frozen dataclass, two states compare (and hash) equal exactly
    when their names are equal.
    """

    name: str

Before we can take a transition, we need to check whether it is allowed. For each state transition, there may be a condition. If the user doesn't want to specify a transition condition, then the transition is always allowed.

In [246]:
import typing


class TransitionCondition(typing.Protocol):
    """Checks whether a transition is allowed.

    Structural type (``typing.Protocol``, not ``asyncio.Protocol``): any
    callable taking ``(current_state, data)`` and returning a bool conforms,
    including plain functions and lambdas.
    """

    def __call__(self, current_state: "State", data: Optional[Any]) -> bool: ...


def default_transition_condition(current_state, data):
    """Condition used when none is supplied: the transition is always allowed."""
    return True

The transitions are dynamic: during a transition, the state changes and the data may change too. We therefore allow a transition callback function which can operate on some data provided by the state machine. Again, if the user doesn't want to specify a transition callback, the data are returned unchanged.

In [247]:
import typing


class TransitionCallback(typing.Protocol):
    """Transforms the data when running a transition.

    Structural type (``typing.Protocol``, not ``asyncio.Protocol``): any
    callable taking the current data and returning the new data conforms,
    including plain functions and lambdas.
    """

    def __call__(self, data: Any) -> Any: ...


def default_transition_callback(data):
    """Callback used when none is supplied: return the data unchanged."""
    return data

Now we can define the transition itself. It has a starting state, ending state, the `TransitionCondition` function to check whether it is allowed and the `TransitionCallback` function to update the data.

In [248]:
@dataclass(eq=True, frozen=True)
class Transition:
    """An immutable edge of the state machine, from `state1` to `state2`.

    `allowed(current_state, data)` decides whether the transition may be taken;
    `callback(data)` returns the data to store once it is taken. When omitted,
    they default to `default_transition_condition` (always allowed) and
    `default_transition_callback` (data unchanged).
    """

    name: str
    state1: State  # state the machine must be in for this transition to apply
    state2: State  # state the machine is in after taking this transition
    allowed: TransitionCondition = default_transition_condition
    callback: TransitionCallback = default_transition_callback

So that the state machine can be run in parallel without the risk of data being unexpectedly overwritten, the data, if provided, should be frozen (immutable). The machine does not enforce this; it is a convention the user's data types should follow. Users can define their own data type to hold their application results; we provide only a minimal object as an example.

In [249]:
@dataclass(eq=True, frozen=True)
class ImmutableData:
    """Empty, frozen placeholder for machine data.

    Applications are expected to define their own frozen dataclass holding
    real fields; this one only serves as the default.
    """

We define the state machine to have:
- states,
- transitions,
- a current state,
- and some arbitrary data.

The following functionalities are defined as methods:
- We can print a representation of the machine, its transitions and its current state.
- We can register and unregister states and transitions to change the way the machine works.
- We can set the current state.
- We can set the data.
- We can take a "step" to find the first allowed transition from the current state and carry it out, executing the callback function.

In [251]:
@dataclass
class StateMachine:
    """A minimal state machine holding states, transitions, a current state and data.

    Transitions are tried in registration order: `step` takes the first one whose
    `state1` equals the current state and whose `allowed` condition returns True.
    """

    states: List[State] = field(default_factory=list)
    transitions: List[Transition] = field(default_factory=list)
    # None until a state is set via the constructor or `set_current_state`.
    current_state: Optional[State] = None
    # Any data object may be stored; the examples use frozen dataclasses.
    data: Any = field(default_factory=ImmutableData)

    def register_state(self, state: State) -> None:
        """Append `state` to the known states (duplicates are not checked)."""
        self.states.append(state)

    def unregister_state(self, state: State) -> None:
        """Remove the first state equal to `state`; raises ValueError if absent."""
        self.states.remove(state)

    def register_transition(self, transition: Transition) -> None:
        """Append `transition`; earlier registrations take priority in `step`."""
        self.transitions.append(transition)

    def unregister_transition(self, transition: Transition) -> None:
        """Remove the first transition equal to `transition`; raises ValueError if absent."""
        self.transitions.remove(transition)

    def set_current_state(self, state: State) -> None:
        """Set the state the next `step` starts from."""
        self.current_state = state

    def set_data(self, data: Any) -> None:
        """Replace the machine's data."""
        self.data = data

    def step(self) -> None:
        """Take the first allowed transition from the current state.

        Runs the transition's callback on the current data, stores its result as
        the new data, then moves to the transition's `state2`.

        Raises:
            StopIteration: if no registered transition starts at the current
                state with a condition that allows it.
        """
        # Lazy %-style arguments: the reprs are only built if INFO is enabled.
        _logger.info("self.current_state=%r", self.current_state)

        transition = next(
            (
                candidate
                for candidate in self.transitions
                if candidate.state1 == self.current_state
                and candidate.allowed(self.current_state, self.data)
            ),
            None,
        )
        if transition is None:
            # Raised directly rather than from inside `except StopIteration`,
            # so the traceback carries no implicit "During handling..." chain.
            raise StopIteration(f"There are no more allowed transitions from {self.current_state=} with {self.data=}")

        _logger.info("next transition=%r", transition)
        assert isinstance(transition, Transition), f"{transition} is not a Transition – something's gone wrong."

        _logger.info("running transition.callback=%r", transition.callback)
        new_data = transition.callback(self.data)

        _logger.info("updating data with new_data=%r", new_data)
        self.data = new_data

        _logger.info("moving from transition.state1=%r to transition.state2=%r", transition.state1, transition.state2)
        self.current_state = transition.state2

## Example usage 0: a trivial two state stopping machine
We initialize an empty state machine, and add two states with a single transition.

In [252]:
# Two states joined by a single unconditional transition.
start, end = State("start"), State("end")
transition_start_end = Transition("start-end", start, end)

trivial_machine = StateMachine()
trivial_machine.register_state(start)
trivial_machine.register_state(end)
trivial_machine.register_transition(transition_start_end)
trivial_machine.set_current_state(start)

# Last expression of the cell, so the notebook displays the machine.
trivial_machine

StateMachine(states=[State(name='start'), State(name='end')], transitions=[Transition(name='start-end', state1=State(name='start'), state2=State(name='end'), allowed=<function default_transition_condition at 0x1600295a0>, callback=<function default_transition_callback at 0x16002a050>)], current_state=State(name='start'), data=ImmutableData())

The machine is in the starting state:

In [253]:
# Show the state the machine starts in.
print(trivial_machine.current_state)

State(name='start')


Now we can run one step:

In [254]:
# One step: the only transition out of "start" leads to "end".
trivial_machine.step()
print(trivial_machine.current_state)

State(name='end')


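If we try to step once more, the machine raises a `StopIteration` exception, because no transition leaves the `end` state (mirroring how Python iterators signal that they are exhausted).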

In [255]:
# No transition leaves "end", so `step` raises StopIteration; show its message.
try:
    trivial_machine.step()
except StopIteration as exception:
    print(f"{exception=}")

exception=StopIteration("There are no more allowed transitions from self.current_state=State(name='end') with self.data=ImmutableData()")


## Example usage 1: a simple three state cycling machine
We initialize an empty state machine, to which states and transitions can be added.

In [256]:
# A fresh machine with no states or transitions yet.
machine = StateMachine()

In this machine, we have three states which we register to the machine.

In [257]:
# Create each state and register it immediately (same registration order).
state_new_theory = State("new_theory")
machine.register_state(state_new_theory)

state_new_experiment = State("new_experiment")
machine.register_state(state_new_experiment)

state_new_data = State("new_data")
machine.register_state(state_new_data)

There are three transitions, which we also register to the machine.

In [258]:
# Positional arguments are (name, state1, state2): new_data -> new_theory
# -> new_experiment -> new_data closes the cycle.
transition_theorist = Transition("theorist", state_new_data, state_new_theory)
transition_experimentalist = Transition("experimentalist", state_new_theory, state_new_experiment)
transition_experiment_runner = Transition("experiment_runner", state_new_experiment, state_new_data)

machine.register_transition(transition_theorist)
machine.register_transition(transition_experimentalist)
machine.register_transition(transition_experiment_runner)

We set the current state of the machine once.

In [259]:
# Start at "new_experiment"; the first step therefore moves to "new_data".
machine.set_current_state(state_new_experiment)

The machine looks as follows:

In [260]:
# Bare expression: the notebook displays the machine's dataclass repr.
machine

StateMachine(states=[State(name='new_theory'), State(name='new_experiment'), State(name='new_data')], transitions=[Transition(name='theorist', state1=State(name='new_data'), state2=State(name='new_theory'), allowed=<function default_transition_condition at 0x1600295a0>, callback=<function default_transition_callback at 0x16002a050>), Transition(name='experimentalist', state1=State(name='new_theory'), state2=State(name='new_experiment'), allowed=<function default_transition_condition at 0x1600295a0>, callback=<function default_transition_callback at 0x16002a050>), Transition(name='experiment_runner', state1=State(name='new_experiment'), state2=State(name='new_data'), allowed=<function default_transition_condition at 0x1600295a0>, callback=<function default_transition_callback at 0x16002a050>)], current_state=State(name='new_experiment'), data=ImmutableData())

Now we can run a few steps, and see what happens to the machine.

In [261]:
# Ten steps around the three-state cycle, printing the state reached each time.
for i in range(10):
    machine.step()
    print(machine.current_state)

State(name='new_data')
State(name='new_theory')
State(name='new_experiment')
State(name='new_data')
State(name='new_theory')
State(name='new_experiment')
State(name='new_data')
State(name='new_theory')
State(name='new_experiment')
State(name='new_data')


As we see, it cycles through the states as expected.

## Add an additional state and new transitions to the simple machine


In [262]:
# Copy the machine, giving the copy its own `states` and `transitions` lists.
# `copy.copy` alone is shallow: both machines would share the same list
# objects, so the (un)register calls below would silently modify `machine`
# as well. The elements themselves (frozen states/transitions, frozen data)
# are immutable, so sharing them is fine.
machine_with_additional_step = copy.copy(machine)
machine_with_additional_step.states = list(machine.states)
machine_with_additional_step.transitions = list(machine.transitions)

state_validated_experiment = State("validated_experiment")

# new_experiment -> validated_experiment -> new_data replaces the direct
# new_experiment -> new_data ("experiment_runner") transition.
transition_experimental_validation = Transition(
    state1=state_new_experiment,
    name="experimental_validation",
    state2=state_validated_experiment,
)
transition_run_validated_experiment = Transition(
    state1=state_validated_experiment,
    name="run_validated_experiment",
    state2=state_new_data,
)

machine_with_additional_step.register_state(state_validated_experiment)

machine_with_additional_step.unregister_transition(transition_experiment_runner)
machine_with_additional_step.register_transition(transition_experimental_validation)
machine_with_additional_step.register_transition(transition_run_validated_experiment)

In [263]:
# Ten steps of the modified cycle, which now passes through "validated_experiment".
for i in range(10):
    machine_with_additional_step.step()
    print(machine_with_additional_step.current_state)

State(name='new_theory')
State(name='new_experiment')
State(name='validated_experiment')
State(name='new_data')
State(name='new_theory')
State(name='new_experiment')
State(name='validated_experiment')
State(name='new_data')
State(name='new_theory')
State(name='new_experiment')


## Add additional data / metadata to a simple machine

In [264]:
@dataclass(frozen=True)
class PlusMinusMachineData:
    """Frozen data for the plus/minus example: a single integer value."""

    value: int


state_A, state_B = State("A"), State("B")
initial_data = PlusMinusMachineData(1)

trivial_machine = StateMachine()
trivial_machine.register_state(state_A)
trivial_machine.register_state(state_B)
trivial_machine.set_data(initial_data)
trivial_machine.set_current_state(state_A)

# A -> B adds 2 to the value; B -> A subtracts 1. Each callback builds a new
# frozen object instead of mutating the old one.
transition_2 = Transition(
    "A -> B",
    state_A,
    state_B,
    callback=lambda data: PlusMinusMachineData(data.value + 2),
)
transition_1 = Transition(
    "B -> A",
    state_B,
    state_A,
    callback=lambda data: PlusMinusMachineData(data.value - 1),
)
# B -> A is registered first; since each state has exactly one outgoing
# transition, registration order does not affect which one `step` picks.
trivial_machine.register_transition(transition_1)
trivial_machine.register_transition(transition_2)

In the starting state, the value in the data is 1.

In [265]:
# The data before any step has been taken.
print(f"{trivial_machine.data=}")

trivial_machine.data=PlusMinusMachineData(value=1)


At each step, the callback of the transition being taken is called.
The first, from state A to state B, adds 2 to the value.

In [266]:
# A -> B: the callback adds 2.
trivial_machine.step()
print(f"{trivial_machine.data=}")

trivial_machine.data=PlusMinusMachineData(value=3)


Then the second, from state B to state A, subtracts 1...

In [267]:
# B -> A: the callback subtracts 1.
trivial_machine.step()
print(f"{trivial_machine.data=}")

trivial_machine.data=PlusMinusMachineData(value=2)


... and then the cycle repeats, going from state A to state B again.

In [268]:
# Nine more steps: +2 then -1 alternately, so the value rises by 1 per round trip.
for i in range(9):
    trivial_machine.step()
    print(f"{trivial_machine.data=}")

trivial_machine.data=PlusMinusMachineData(value=4)
trivial_machine.data=PlusMinusMachineData(value=3)
trivial_machine.data=PlusMinusMachineData(value=5)
trivial_machine.data=PlusMinusMachineData(value=4)
trivial_machine.data=PlusMinusMachineData(value=6)
trivial_machine.data=PlusMinusMachineData(value=5)
trivial_machine.data=PlusMinusMachineData(value=7)
trivial_machine.data=PlusMinusMachineData(value=6)
trivial_machine.data=PlusMinusMachineData(value=8)


The initial value and the results at each stage are immutable (frozen dataclasses), so they can be passed around and used without the risk that the machine will inadvertently change them.

In [269]:
# `initial_data` is unchanged: the callbacks build new objects instead of mutating it.
print(f"{initial_data=}")

initial_data=PlusMinusMachineData(value=1)


In [270]:
# Keep a reference to the current (frozen) data object and display it.
the_data_after_all_the_iterations_so_far = trivial_machine.data
the_data_after_all_the_iterations_so_far

PlusMinusMachineData(value=8)

In [271]:
n = 1000
for i in range(n):
    trivial_machine.step()

# One line per argument; output is identical to a single "\n"-joined f-string.
print(
    "the data we just saved:",
    f"{the_data_after_all_the_iterations_so_far=}",
    f"the current value after {n=} more iterations:",
    f"{trivial_machine.data=}",
    sep="\n",
)

the data we just saved:
the_data_after_all_the_iterations_so_far=PlusMinusMachineData(value=8)
the current value after n=1000 more iterations:
trivial_machine.data=PlusMinusMachineData(value=508)


## Machine with a stopping condition
Sometimes, the machine will need to stop conditionally based on some computed value.
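Here we start from the value -1234 and stop once the value is no longer negative. The C -> D transition is only allowed while the value is below 0. As soon as the machine is back in state C with a value of 0 or more, no transition is allowed and `step` raises `StopIteration`.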

In [272]:
state_C, state_D = State("C"), State("D")
initial_data = PlusMinusMachineData(-1234)

stopping_machine = StateMachine()
stopping_machine.register_state(state_C)
stopping_machine.register_state(state_D)
stopping_machine.set_data(initial_data)
stopping_machine.set_current_state(state_C)

# C -> D (+2) is only allowed while the value is negative; D -> C (-1) is
# always allowed. Once the machine is back in C with a non-negative value,
# no transition applies and `step` raises StopIteration.
transition_CD = Transition(
    "C -> D",
    state_C,
    state_D,
    allowed=lambda _state, data: data.value < 0,
    callback=lambda data: PlusMinusMachineData(data.value + 2),
)
transition_DC = Transition(
    "D -> C",
    state_D,
    state_C,
    callback=lambda data: PlusMinusMachineData(data.value - 1),
)
stopping_machine.register_transition(transition_CD)
stopping_machine.register_transition(transition_DC)


In [273]:
print(f"started with: {stopping_machine.data=}")

# Step until the machine reports (via StopIteration) that nothing is allowed;
# `i` counts only the steps that succeeded.
i = 0
while True:
    try:
        stopping_machine.step()
    except StopIteration:
        print(f"machine stopped after {i=} iterations")
        break
    else:
        i += 1

print(f"finished with: {stopping_machine.data=}")

started with: stopping_machine.data=PlusMinusMachineData(value=-1234)
machine stopped after i=2468 iterations
finished with: stopping_machine.data=PlusMinusMachineData(value=0)


## Machine with a stopping state
Alternatively, we can stop by having a stopping state from which there is no transition. This would allow doing "special" steps at the end of the cycle.
The transfer to the stopping state is the first (highest priority) transition, but is only allowed when a condition is met.

In [274]:
stopped_state = State("stopped")
initial_data = PlusMinusMachineData(-1234)

machine_with_stopped_state = StateMachine()
machine_with_stopped_state.register_state(state_C)
machine_with_stopped_state.register_state(state_D)
machine_with_stopped_state.register_state(stopped_state)
machine_with_stopped_state.set_data(initial_data)
machine_with_stopped_state.set_current_state(state_C)

# "C -> stopped" is registered first, so it wins over "C -> D" whenever its
# condition (a non-negative value) holds. No transition leaves "stopped".
transition_Cstopped = Transition(
    "C -> stopped",
    state_C,
    stopped_state,
    allowed=lambda _state, data: data.value >= 0,
)
transition_CD = Transition(
    "C -> D",
    state_C,
    state_D,
    callback=lambda data: PlusMinusMachineData(data.value + 2),
)
transition_DC = Transition(
    "D -> C",
    state_D,
    state_C,
    callback=lambda data: PlusMinusMachineData(data.value - 1),
)
machine_with_stopped_state.register_transition(transition_Cstopped)
machine_with_stopped_state.register_transition(transition_CD)
machine_with_stopped_state.register_transition(transition_DC)

In [275]:
print(f"started with: {machine_with_stopped_state.data=}")

# Step until no transition is allowed (here: once the "stopped" state is
# reached); `i` counts only the steps that succeeded.
i = 0
while True:
    try:
        machine_with_stopped_state.step()
    except StopIteration:
        print(f"machine stopped after {i=} iterations")
        break
    else:
        i += 1

print(f"finished with: {machine_with_stopped_state.data=}")

started with: machine_with_stopped_state.data=PlusMinusMachineData(value=-1234)
machine stopped after i=2469 iterations
finished with: machine_with_stopped_state.data=PlusMinusMachineData(value=0)
