In [1]:
from abc import abstractmethod, ABC
from typing import Callable, Any
from dataclasses import dataclass, replace
from donotation import do

# Initial state shared by the examples below: each one passes it to
# `run_state` and reads its 'init' entry.
state = {'init': 3}

# State Monad Derivation

This document demonstrates how the concept of a state monad, originally defined in Haskell, can be progressively translated into Python.


## Translating from Haskell to Python

### Monad

We begin by converting the Monad type class definition from [Haskell's Monad documentation](https://wiki.haskell.org/Monad) into Python.
The initial Python code (shown below) defines the placeholder functions `map` and `flat_map`, which currently return `None` and do not perform any operations.
These placeholders are designed to demonstrate the conversion process from Haskell to Python.

In [2]:
# type class surrogate
class Monad[U]:
    pass

def map[U, V](m: Monad[U], fn: Callable[[U], V]) -> Monad[V]:
    ...

def flat_map[U, V](m: Monad[U], fn: Callable[[U], Monad[V]]) -> Monad[V]:
    ...

class return_[V]:
    def __init__(self, value: V) -> Monad[V]: ...

A more Pythonic, object-oriented approach is to implement these functions as methods of the abstract `Monad` class.

In [3]:
class Monad[U](ABC):
    @abstractmethod
    def flat_map[V](self, fn: Callable[[U], V]):
        ...

    @abstractmethod
    def map[V](self, fn: Callable[[U], Monad[V]]):
        ...
    
    @abstractmethod
    def return_(self, val: U):
        ...

### State Monad

In Haskell, a Monad type class can be implemented for various types `U`.
Specifically, the [State Monad](https://wiki.haskell.org/State_Monad) is defined for the type `U := Callable[[State], tuple[State, V]]`, where `State` represents the state and `V` represents the value produced.

In Python, this concept is translated by defining a `StateMonad` class that inherits from the `Monad` class and holds a function `run_state`, which computes the value given some state.

<!-- Implement a concrete Monad: The StateMonad. [https://wiki.haskell.org/State_Monad](https://wiki.haskell.org/State_Monad). -->

In [4]:
@dataclass(frozen=True)
class StateMonad[State, U](Monad[Callable[[State], tuple[State, U]]]):

    # monad data
    # ##########
    
    run_state: Callable[[State], tuple[State, U]]

    # monad methods
    # #############
    
    def flat_map(self, fn):
        def run_state(state: State):
            n_state, val = self.run_state(state)
            return fn(val).run_state(n_state)

        return replace(self, run_state=run_state)
    
    def map(self, fn):
        def run_state(state: State):
            n_state, val = self.run_state(state)
            return n_state, fn(val)

        return replace(self, run_state=run_state)

    @staticmethod
    def return_(val):
        return StateMonad(run_state=lambda s: (s, val))
    
    # state monad methods
    # ###################

    @staticmethod
    def get():
        def run_state(state):
            return state, state

        return StateMonad(run_state=run_state)
    
    def put(self, state):
        def run_state(state_):
            _, val = self.run_state(state_)
            return state, val

        return replace(self, run_state=run_state)

# Produce the value 5, then read the state and add its 'init' entry to it.
expr = StateMonad.return_(5).flat_map(
    lambda five: StateMonad.get().map(lambda s: s['init'] + five)
)

# Nothing has run so far; evaluation happens here.
result = expr.run_state(state)

print(f'{expr=}')
print(f'{result=}')

expr=StateMonad(run_state=<function StateMonad.flat_map.<locals>.run_state at 0x000002A463DEE8E0>)
result=({'init': 3}, 8)


The expression `expr`, created with the state monad operators `return_`, `flat_map`, `map`, and `get`, can be rewritten using the `donotation` Python library as follows.

In [5]:
# Same computation as the `expr` built by hand above, written with the `do`
# decorator from `donotation`: each `yield from` binds the value of a monad.
@do()
def create_expression():
    # bind the value produced by return_(5)
    v = yield from StateMonad.return_(5)
    # bind the current state (this local `state` shadows the module-level one)
    state = yield from StateMonad.get()
    # final monad: the state's 'init' entry plus v
    return StateMonad.return_(state['init'] + v)

# evaluate against the module-level initial state
create_expression().run_state(state)

({'init': 3}, 8)

### State Monad Operations as classes

To organize the `StateMonad` implementation, we decompose the state monad operations into multiple dataclasses.

In [6]:
class StateMonadNode(ABC):
    """Abstract base for the nodes of a state monad expression tree."""

    @abstractmethod
    def run_state(self, state):
        """Evaluate this node for ``state``.

        The implementations in this file return a ``(state, value)`` pair.
        """

@dataclass(frozen=True)
class FromImpl(StateMonadNode):
    """Leaf node that yields a fixed value and leaves the state untouched."""

    # Constant produced by this node. Typed as ``Any`` because the value is
    # returned as-is and is not required to be an integer.
    value: Any

    def run_state(self, state):
        """Return ``(state, self.value)``; the incoming state passes through."""
        return state, self.value

@dataclass(frozen=True)
class MapImpl(StateMonadNode):
    """Node that applies a plain function to the value produced by ``child``."""

    # Node evaluated first.
    child: StateMonadNode
    # Function applied to the child's value. Its input and output are typed
    # as ``Any`` because the function is applied to whatever value the child
    # produces, not only to integers.
    map_func: Callable[[Any], Any]

    def run_state(self, state):
        """Run ``child``, then return its state with the mapped value."""
        next_state, val = self.child.run_state(state)
        return next_state, self.map_func(val)

@dataclass(frozen=True)
class FlatMapImpl(StateMonadNode):
    """Node that chains ``child`` with a node built from the child's value."""

    # Node evaluated first.
    child: StateMonadNode
    # Receives the child's value (of any type) and returns the node to run next.
    fmap_func: Callable[[Any], StateMonadNode]

    def run_state(self, state):
        """Run ``child``, then run the node ``fmap_func`` builds from its value.

        The second node starts from the state that ``child`` produced.
        """
        next_state, val = self.child.run_state(state)
        return self.fmap_func(val).run_state(next_state)

@dataclass(frozen=True)
class GetImpl(StateMonadNode):
    """Leaf node whose value is the current state itself."""

    def run_state(self, state):
        """Return the unchanged state twice: as the state and as the value."""
        current = state
        return current, current

@dataclass(frozen=True)
class PutImpl(StateMonadNode):
    """Node that runs ``child`` and then replaces the state with a fixed one."""

    # Node evaluated first; only its value is kept.
    child: StateMonadNode
    # State returned in place of whatever state the child produced.
    state: Any

    def run_state(self, state):
        """Run ``child`` on ``state`` and return ``(self.state, child_value)``."""
        _child_state, child_value = self.child.run_state(state)
        return self.state, child_value

Each dataclass represents a node in a tree-like structure, collectively forming the state monad expression.

In [7]:
# Node tree for: produce 5, then read the state and add its 'init' entry.
expr = FlatMapImpl(
    child=FromImpl(value=5),
    fmap_func=lambda v: MapImpl(
        child=GetImpl(),
        map_func=lambda current: current['init'] + v,
    ),
)

# Walk the tree starting from the module-level initial state.
result = expr.run_state(state)

print(f'{expr=}')
print(f'{result=}')

expr=FlatMapImpl(child=FromImpl(value=5), fmap_func=<function <lambda> at 0x000002A463DEFC40>)
result=({'init': 3}, 8)


Defining the expression `expr` directly can be complex. Instead, we use method chaining to construct the expression more elegantly.

In [8]:
@dataclass(frozen=True)
class StateMonad(StateMonadNode):
    """Fluent wrapper that builds a tree of ``StateMonadNode`` objects.

    Every method returns a copy of this wrapper (made with
    ``dataclasses.replace``) whose ``data`` is a new node placed on top of
    the current one. Nothing is evaluated until ``run_state`` is called.
    """

    # Root node of the expression built so far.
    data: StateMonadNode

    def run_state(self, state):
        """Evaluate the wrapped node tree for ``state``."""
        return self.data.run_state(state)

    def map(self, map_func):
        """Apply ``map_func`` to the value of the current expression."""
        return replace(
            self,
            data=MapImpl(child=self.data, map_func=map_func),
        )

    def flat_map(self, fmap_func):
        """Chain the node returned by ``fmap_func`` after the current expression."""
        return replace(
            self,
            data=FlatMapImpl(child=self.data, fmap_func=fmap_func),
        )

    def get(self):
        """Run the current expression, then yield the resulting state as value.

        The read is sequenced after ``self.data``, like ``map``, ``flat_map``
        and ``put``. Replacing ``self.data`` with a bare ``GetImpl()`` would
        drop the preceding computation, so ``m.put(s).get()`` would not see
        the state ``s``. The value of the preceding expression is ignored.
        """
        return replace(
            self,
            data=FlatMapImpl(child=self.data, fmap_func=lambda _: GetImpl()),
        )

    def put(self, state):
        """Run the current expression, then replace the state with ``state``."""
        return replace(
            self,
            data=PutImpl(child=self.data, state=state),
        )

# could be implemented as a classmethod as well
def from_(val: Any):
    """Wrap the plain value ``val`` in a ``StateMonad``.

    The resulting computation yields ``val`` and leaves the state unchanged.
    ``val`` may be any object; this file also calls ``from_(None)``.
    """
    return StateMonad(data=FromImpl(value=val))
    
# Method chaining example; `expr` is rebound further down before it is used.
expr = from_(5).map(lambda v: 2*v)


# The same expression as in the previous examples, written with `do`.
@do()
def create_expression():
    # bind the value produced by from_(5)
    v = yield from from_(5)
    # `from_(None)` only serves as a receiver for `.get()`; its value is unused
    state = yield from from_(None).get()
    # final monad: the state's 'init' entry plus v
    return from_(state['init'] + v)

expr = create_expression()

# evaluate against the module-level initial state
result = expr.run_state(state)

print(f'{expr=}')
print(f'{result=}')

expr=StateMonad(data=FlatMapImpl(child=FromImpl(value=5), fmap_func=<function create_expression.<locals>._donotation_flatmap_func_0 at 0x000002A463E30860>))
result=({'init': 3}, 8)


## Zip Operator

Implement the `zip` operator for the `StateMonad` class. 

In [9]:
@dataclass(frozen=True)
class ZipImpl(StateMonadNode):
    """Node that runs two nodes in sequence and pairs their values."""

    # Evaluated first, on the incoming state.
    left: StateMonadNode
    # Evaluated second, on the state produced by ``left``.
    right: StateMonadNode

    def run_state(self, state):
        """Return the final state and the ``(left_value, right_value)`` tuple."""
        mid_state, left_value = self.left.run_state(state)
        end_state, right_value = self.right.run_state(mid_state)
        return end_state, (left_value, right_value)

class ZipStateMonad(StateMonad):
    """``StateMonad`` extended with a ``zip`` operator."""

    def zip(self, other: StateMonad):
        """Pair the value of this expression with the value of ``other``.

        ``other`` is evaluated after this expression; the result is a copy of
        ``self`` wrapping a ``ZipImpl`` node.
        """
        paired = ZipImpl(left=self.data, right=other.data)
        return replace(self, data=paired)

def zip_(monads):
    """Combine an iterable of monads into one monad yielding a tuple.

    The monads are zipped left to right, so the resulting value is the tuple
    of their values in iteration order. Each element must provide ``map``
    and ``zip`` (as ``ZipStateMonad`` does).

    For an empty iterable the result is ``from_(())``, a monad that yields
    the empty tuple.
    """
    monads_iterator = iter(monads)

    # Only the `next` call may signal emptiness, so keep the `try` minimal.
    try:
        first = next(monads_iterator)
    except StopIteration:
        return from_(())

    # Start with a 1-tuple so that every later step can append to a tuple.
    current = first.map(lambda v: (v,))
    for other in monads_iterator:
        # v is (tuple_so_far, next_value); flatten it into one longer tuple
        current = current.zip(other).map(lambda v: v[0] + (v[1],))
    return current

def zip_from(val: int):
    """Wrap ``val`` in a ``ZipStateMonad`` that yields it unchanged."""
    leaf = FromImpl(value=val)
    return ZipStateMonad(data=leaf)

# Zip five constant monads, then add the state's 'init' entry to each value
# and sum the results.
@do()
def create_expression():
    # bind the tuple of values produced by zipping zip_from(0) .. zip_from(4)
    values = yield from zip_(zip_from(i) for i in range(5))
    # `from_(None)` only serves as a receiver for `.get()`; its value is unused
    state = yield from from_(None).get()
    v_add_init = sum(state['init'] + v for v in values)
    return from_(v_add_init)

expr = create_expression()

# evaluate against the module-level initial state
result = expr.run_state(state)

print(f'{expr=}')
print(f'{result=}')

expr=ZipStateMonad(data=FlatMapImpl(child=MapImpl(child=ZipImpl(left=MapImpl(child=ZipImpl(left=MapImpl(child=ZipImpl(left=MapImpl(child=ZipImpl(left=MapImpl(child=FromImpl(value=0), map_func=<function zip_.<locals>.<lambda> at 0x000002A463E31620>), right=FromImpl(value=1)), map_func=<function zip_.<locals>.<lambda> at 0x000002A463E316C0>), right=FromImpl(value=2)), map_func=<function zip_.<locals>.<lambda> at 0x000002A463E31760>), right=FromImpl(value=3)), map_func=<function zip_.<locals>.<lambda> at 0x000002A463E31800>), right=FromImpl(value=4)), map_func=<function zip_.<locals>.<lambda> at 0x000002A463E31940>), fmap_func=<function create_expression.<locals>._donotation_flatmap_func_0 at 0x000002A463E313A0>))
result=({'init': 3}, 25)
