<center>

<h1 style="text-align:center"> Monads </h1>
<h2 style="text-align:center"> CS3100 Fall 2019 </h2>
</center>

## Review

### Previously

* Streams, laziness and memoization

### This lecture

* Monads
  + Dealing with **effects** in a **pure** setting

## Whence Monads

* The term "monad" comes from **Category Theory**
  + Category theory is the study of mathematical abstractions
  + Out of scope for this course
  + We will focus on **programming with monads**.
* Monads were popularized by the Haskell programming language
  + Haskell is a **purely functional** programming language.
  + Unlike OCaml, Haskell separates pure code from side-effecting code through the use of monads. 

## What is a Monad?

A monad is any implementation that satisfies the following signature:

```ocaml
module type Monad = sig
  type 'a t
  val return : 'a -> 'a t
  val bind   : 'a t -> ('a -> 'b t) -> 'b t
end
```

```
module type Monad =
  sig
    type 'a t
    val return : 'a -> 'a t
    val bind : 'a t -> ('a -> 'b t) -> 'b t
  end
```


and the **monad laws**.

## Example: Interpreter

* All of this seems very abstract (as many FP concepts are).
  + A monad is a design pattern rather than a language feature.
* An example will help us see the pattern.
  + Over time, you'll spot monads everywhere.
* Let's write an interpreter for arithmetic expressions.

## Interpreting arithmetic expressions

```ocaml
type expr = Val of int | Plus of expr * expr | Div of expr * expr
```


* Our goal is to make the interpreter a **total function**.
  + It should produce a **value** for every arithmetic expression.

```ocaml
let rec eval e = match e with
  | Val v -> v
  | Plus (e1,e2) -> eval e1 + eval e2
  | Div (e1,e2) -> eval e1 / eval e2
```

```
val eval : expr -> int = <fun>
```


## Division by zero

This looks fine. But what happens if the denominator in the division is 0?

```ocaml
eval (Div (Val 1, Val 0))
```

```
Exception: Division_by_zero.
```

How can we avoid this?

## Interpreting Arithmetic Expressions: Take 2

* Rewrite the `eval` function to have the type `expr -> int option`
  + Return `None` for division by zero.

```ocaml
let rec eval e = match e with
  | Val v -> Some v
  | Plus (e1,e2) ->
      begin match eval e1 with
      | None -> None
      | Some v1 ->
          match eval e2 with
          | None -> None
          | Some v2 -> Some (v1 + v2)
      end
  | Div (e1,e2) ->
      match eval e1 with
      | None -> None
      | Some v1 ->
          match eval e2 with
          | None -> None
          | Some v2 -> if v2 = 0 then None else Some (v1 / v2)
```

```
val eval : expr -> int option = <fun>
```


```ocaml
eval (Div (Val 1, Val 0))
```

```
- : int option = None
```


## Abstraction

* There is a lot of repeated code in the interpreter above.
  + Factor out common code.

```ocaml
let return v = Some v
```

```
val return : 'a -> 'a option = <fun>
```


```ocaml
let bind m f = match m with
  | None -> None
  | Some v -> f v
```

```
val bind : 'a option -> ('a -> 'b option) -> 'b option = <fun>
```


## Abstraction 

Let's rewrite the interpreter using these functions.

```ocaml
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      bind (eval e1) (fun v1 ->
      bind (eval e2) (fun v2 ->
      return (v1+v2)))
  | Div (e1,e2) ->
      bind (eval e1) (fun v1 ->
      bind (eval e2) (fun v2 ->
      if v2 = 0 then None else return (v1 / v2)))
```

```
val eval : expr -> int option = <fun>
```


## Infix bind operation

Usually `bind` is written as the infix operator `>>=`.

```ocaml
let (>>=) = bind
```

```
val ( >>= ) : 'a option -> ('a -> 'b option) -> 'b option = <fun>
```


```ocaml
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      eval e1 >>= fun v1 ->
      eval e2 >>= fun v2 ->
      return (v1+v2)
  | Div (e1,e2) ->
      eval e1 >>= fun v1 ->
      eval e2 >>= fun v2 ->
      if v2 = 0 then None else return (v1 / v2)
```

```
val eval : expr -> int option = <fun>
```


## Modularise

* The `return` and `>>=` we defined for the interpreter work for any computation on the option type.
  + Putting them in a module gives us the Option Monad.

```ocaml
module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val (>>=)  : 'a t -> ('a -> 'b t) -> 'b t
end

module OptionMonad : (MONAD with type 'a t = 'a option) = struct
  type 'a t = 'a option
  let return v = Some v
  let (>>=) m f = match m with
  | Some v -> f v
  | None -> None
end
```

```
module type MONAD =
  sig
    type 'a t
    val return : 'a -> 'a t
    val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
  end

module OptionMonad :
  sig
    type 'a t = 'a option
    val return : 'a -> 'a t
    val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
  end
```
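As a sketch of how the module is used, the interpreter can be written against `OptionMonad` by opening it locally. The block below repeats `expr`, `MONAD` and `OptionMonad` so that it compiles on its own; because the signature exposes `'a t = 'a option`, the `Div` case can still return `None` directly.

```ocaml
type expr = Val of int | Plus of expr * expr | Div of expr * expr

module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val (>>=)  : 'a t -> ('a -> 'b t) -> 'b t
end

module OptionMonad : (MONAD with type 'a t = 'a option) = struct
  type 'a t = 'a option
  let return v = Some v
  let (>>=) m f = match m with
    | Some v -> f v
    | None -> None
end

(* The interpreter, using return and >>= from OptionMonad *)
let rec eval e =
  let open OptionMonad in
  match e with
  | Val v -> return v
  | Plus (e1, e2) ->
      eval e1 >>= fun v1 ->
      eval e2 >>= fun v2 ->
      return (v1 + v2)
  | Div (e1, e2) ->
      eval e1 >>= fun v1 ->
      eval e2 >>= fun v2 ->
      (* None is usable here because 'a t = 'a option is visible *)
      if v2 = 0 then None else return (v1 / v2)

let () =
  assert (eval (Plus (Val 1, Div (Val 6, Val 3))) = Some 3);
  assert (eval (Div (Val 1, Val 0)) = None)
```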


## Monad Laws

Any implementation of the monad signature must satisfy the following laws:


```ocaml
1. return v >>= k  ≡  k v (* Left Identity *)
2. v >>= return  ≡  v (* Right Identity *)
3. (m >>= f) >>= g  ≡  m >>= (fun x -> f x >>= g) (* Associativity *)
```

## Option monad satisfies monad laws

**Left Identity**: `return v >>= k  ≡  k v`

```ocaml
  return v >>= k
≡ (Some v) >>= k (* by definition of return *)
≡ match Some v with None -> None | Some v -> k v (* by definition of >>= *)
≡ k v (* by beta reduction *)
```

**Exercise:** Prove the other laws.
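Alongside the proofs, the laws can be spot-checked on concrete values. The sketch below does this for the option monad; the functions `k`, `f`, `g` and the list `samples` are arbitrary choices made for illustration, and passing tests are evidence, not a proof.

```ocaml
let return v = Some v
let (>>=) m f = match m with
  | None -> None
  | Some v -> f v

(* Arbitrary option-returning functions, used only for testing *)
let k x = if x > 0 then Some (x * 2) else None
let f x = Some (x + 1)
let g x = if x mod 2 = 0 then Some (x / 2) else None

let samples = [ None; Some (-3); Some 0; Some 4; Some 7 ]

let () =
  (* Left identity: return v >>= k  ≡  k v *)
  List.iter (fun v -> assert ((return v >>= k) = k v)) [ -3; 0; 4; 7 ];
  (* Right identity: m >>= return  ≡  m *)
  List.iter (fun m -> assert ((m >>= return) = m)) samples;
  (* Associativity: (m >>= f) >>= g  ≡  m >>= (fun x -> f x >>= g) *)
  List.iter
    (fun m -> assert (((m >>= f) >>= g) = (m >>= fun x -> f x >>= g)))
    samples
```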

## State Monad

* Each monad implementation typically extends the signature with additional operations.
* A State Monad introduces a **single, typed mutable cell**.
* Here's a signature for dealing with mutable state, which adds
  + `get` and `put` functions for reading and writing the state, and 
  + a `run_state` function for actually running computations.

```ocaml
module type STATE = sig
  type state
  include MONAD
  val get : state t
  val put : state -> unit t
  val run_state : 'a t -> init:state -> state * 'a
end
```

```
module type STATE =
  sig
    type state
    type 'a t
    val return : 'a -> 'a t
    val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
    val get : state t
    val put : state -> unit t
    val run_state : 'a t -> init:state -> state * 'a
  end
```


## State Monad

The idea of a state monad is to simulate a single, typed mutable location in the program. Values can be `put` into this location and read from this location using `get`. How might we implement such a feature without using references?

We can *thread* the state through every computation in the monad. Suppose you want to implement a successor function in the state monad. This function does not read or write the state, but simply passes it through.

The usual successor function is:

```ocaml
let succ x = x + 1
```

The successor function that would pass the state through would be:

```ocaml
let succ_st x s = (s, x+1)
```

The extra argument `s` is the state passed to this function. Unlike the usual successor function, `succ_st` returns a pair of the new state (which happens to be the same state that was passed in) and the result `x+1`.

Now, a function `get` which only reads the current state can be written as:

```ocaml
let get s = (s,s)
```

`get` does not modify the current state, and hence returns the pair `(s,s)`, where the first component is the new state (the same as the one passed in) and the second component is the result of `get`.

A function `put s'` which updates the state can be written as:

```ocaml
let put s' s = (s',())
```

`put` updates the state to `s'` and the result of put is `()`.

Observe that the last argument of each of the functions `put`, `get` and `succ_st` is the previous state, and they all return a pair of the new state and the result of the computation. If the type of the state is `state`, we can assign the following types to the functions:

```ocaml
val put     : state -> state -> state * unit
val get     : state -> state * state
val succ_st : int -> state -> state * int
```

We can make these types more readable with an abbreviation for state-passing computations:

```ocaml
type 'a t = state -> state * 'a
val put     : state -> unit t
val get     : state t
val succ_st : int -> int t
```

How do we build larger programs out of these individual functions? For example, we can write a function that

1. puts 10,
2. gets the current state, and
3. returns the successor of the current state

as follows:

```ocaml
let p s0 (* initial state *) =
  let (s1,())     = put 10 s0 in
  let (s2,s)      = get s1 in
  let (s3,result) = succ_st s s2 in
  (s3, result)
```
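The explicit threading above runs as ordinary OCaml once `state` is fixed. The sketch below assumes `state = int` (an assumption made only for this example) and checks that `p` leaves `10` in the cell and returns `11`.

```ocaml
(* Assumption: the state is an int, for concreteness *)
type state = int

(* Each function takes the previous state last and returns (new state, result) *)
let succ_st (x : int) (s : state) : state * int = (s, x + 1)
let get (s : state) : state * state = (s, s)
let put (s' : state) (_ : state) : state * unit = (s', ())

(* Thread the state by hand through put, get and succ_st *)
let p (s0 : state) : state * int =
  let (s1, ()) = put 10 s0 in
  let (s2, s) = get s1 in
  let (s3, result) = succ_st s s2 in
  (s3, result)

let () = assert (p 0 = (10, 11))
```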

Rather than explicitly threading the state through, which is tedious, we can use a monad to hide the tedious bits. The subsequent computation may also use results from the previous computation (as in the case of `get` and `succ_st` in the example above). So we define a function:

```ocaml
let (>>=) (m : state -> state * 'a) (f : 'a -> state -> state * 'b) : state -> state * 'b =
  fun s (* previous state *) ->
    let ((s' : state), (v : 'a)) = m s in       (* run m on the previous state *)
    let ((s'' : state), (res : 'b)) = f v s' in (* run f on m's result and the updated state *)
    (s'', res)
```

Recall that `(>>=)` is an infix function. Using this function, we can write the program `p` as:

```ocaml
let p s0 =
  let computation = 
    put 10 >>= (fun () ->
    get >>= (fun s ->
    succ_st s))
  in
  computation s0
```
We can drop the extra parentheses to get

```ocaml
let p s0 =
  let computation = 
    put 10 >>= fun () ->
    get >>= fun s ->
    succ_st s
  in
  computation s0
```

Using the abbreviation `'a t`, we can also write `(>>=)` more concisely as:

```ocaml
let (>>=) (m : 'a t) (f : 'a -> 'b t) : 'b t = 
  fun s (* previous state *) ->
    let ((s' : state), (v : 'a)) = m s in
    f v s'
```
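Putting the pieces together, here is a minimal self-contained sketch of the monadic version, again assuming `state = int`. Note that `(>>=)` passes the updated state `s'`, not the original `s`, to `f`.

```ocaml
type state = int  (* assumption: integer state, for concreteness *)
type 'a t = state -> state * 'a

let return (v : 'a) : 'a t = fun s -> (s, v)

(* Run m, then feed its result and the updated state to f *)
let (>>=) (m : 'a t) (f : 'a -> 'b t) : 'b t =
  fun s ->
    let (s', v) = m s in
    f v s'

let get : state t = fun s -> (s, s)
let put (s' : state) : unit t = fun _ -> (s', ())
let succ_st (x : int) : int t = fun s -> (s, x + 1)

let p (s0 : state) : state * int =
  let computation =
    put 10 >>= fun () ->
    get >>= fun s ->
    succ_st s
  in
  computation s0

let () = assert (p 0 = (10, 11))
```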

## State Monad

Here's an implementation of `State`, parameterised by the type of the state:

```ocaml
module State (S : sig type t end)
  : STATE with type state = S.t = struct
  type state = S.t
  type 'a t = state -> state * 'a
  let return v = fun s -> (s, v)
  let (>>=) m f = fun s ->
    let (s', a) = m s in
    f a s'
  let get s = (s, s)
  let put s' _ = (s', ())
  let run_state m ~init = m init
end
```

```
module State :
  functor (S : sig type t end) ->
    sig
      type state = S.t
      type 'a t
      val return : 'a -> 'a t
      val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
      val get : state t
      val put : state -> unit t
      val run_state : 'a t -> init:state -> state * 'a
    end
```


## Using State Monad

```ocaml
module IntState = State (struct type t = int end)
open IntState

let inc v =
  get >>= fun s ->
  put (s+v)

let dec v =
  get >>= fun s ->
  put (s-v)

let double =
  get >>= fun s ->
  put (s*2)
```

```
module IntState :
  sig
    type state = int
    type 'a t
    val return : 'a -> 'a t
    val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
    val get : state t
    val put : state -> unit t
    val run_state : 'a t -> init:state -> state * 'a
  end
val inc : int -> unit IntState.t = <fun>
val dec : int -> unit IntState.t = <fun>
val double : unit IntState.t = <abstr>
```


## Using State Monad

```ocaml
IntState.run_state ~init:10 (
  inc 5 >>= fun () ->
  dec 10 >>= fun () ->
  double)
```

```
- : IntState.state * unit = (10, ())
```


```ocaml
let module FloatState = State (struct type t = float end) in
let open FloatState in
FloatState.run_state ~init:5.4 (
  get >>= fun v ->
  put (v +. 1.0))
```

```
- : float * unit = (6.4, ())
```
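The examples so far only return `()`. A computation can also return a value while updating the state. The sketch below repeats `MONAD`, `STATE` and `State` for self-containment and defines two hypothetical helpers, `fresh`, which returns the current counter and increments it, and `label`, which numbers the elements of a list; `Counter` is likewise an illustrative name.

```ocaml
module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val (>>=)  : 'a t -> ('a -> 'b t) -> 'b t
end

module type STATE = sig
  type state
  include MONAD
  val get : state t
  val put : state -> unit t
  val run_state : 'a t -> init:state -> state * 'a
end

module State (S : sig type t end) : STATE with type state = S.t = struct
  type state = S.t
  type 'a t = state -> state * 'a
  let return v = fun s -> (s, v)
  let (>>=) m f = fun s -> let (s', a) = m s in f a s'
  let get s = (s, s)
  let put s' _ = (s', ())
  let run_state m ~init = m init
end

module Counter = State (struct type t = int end)
open Counter

(* Hypothetical helper: return the current counter value, then increment it *)
let fresh : int t =
  get >>= fun n ->
  put (n + 1) >>= fun () ->
  return n

(* Hypothetical helper: pair each list element with a fresh number *)
let rec label (xs : 'a list) : ('a * int) list t =
  match xs with
  | [] -> return []
  | x :: rest ->
      fresh >>= fun n ->
      label rest >>= fun labelled ->
      return ((x, n) :: labelled)

let () =
  assert (run_state (label [ "a"; "b"; "c" ]) ~init:0
          = (3, [ ("a", 0); ("b", 1); ("c", 2) ]))
```

The final state `3` is the next unused number, so a later call to `fresh` would continue from there.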


## State monad satisfies monad laws

**Right Identity**: `v >>= return  ≡  v`

```ocaml
  v >>= return
≡ fun s -> let (s', a) = v s in return a s' (* by definition of >>= *)
≡ fun s -> let (s', a) = v s in (fun v s -> (s,v)) a s' (* by definition of return *)
≡ fun s -> let (s', a) = v s in (s',a) (* by beta reduction *)
≡ fun s -> (fun (s', a) -> (s', a)) (v s) (* rewrite `let` to `fun` *)
≡ fun s -> v s (* the function is the identity on pairs *)
≡ v (* by eta reduction *)
```

**Exercise**: Prove other laws.
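As with the option monad, the state monad laws can be spot-checked by running both sides of each law from the same initial states and comparing the results. The computations `m`, `f` and `g` and the list of initial states are arbitrary choices for this sketch; agreement on a few inputs is not a proof.

```ocaml
(* Self-contained: the state-passing monad with an int state *)
type 'a t = int -> int * 'a

let return v : 'a t = fun s -> (s, v)
let (>>=) (m : 'a t) (f : 'a -> 'b t) : 'b t =
  fun s -> let (s', a) = m s in f a s'
let get : int t = fun s -> (s, s)
let put s' : unit t = fun _ -> (s', ())

(* Arbitrary computations chosen for testing *)
let m : int t = get >>= fun s -> put (s * 3) >>= fun () -> return (s + 1)
let f x : int t = get >>= fun s -> put (s + x) >>= fun () -> return (x * 2)
let g x : string t = get >>= fun s -> return (string_of_int (x - s))

(* Two computations are compared by running them from the same states *)
let same (a : 'a t) (b : 'a t) =
  List.for_all (fun s0 -> a s0 = b s0) [ -2; 0; 1; 5; 10 ]

let () =
  (* Left identity *)
  assert (same (return 7 >>= f) (f 7));
  (* Right identity *)
  assert (same (m >>= return) m);
  (* Associativity *)
  assert (same ((m >>= f) >>= g) (m >>= fun x -> f x >>= g))
```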

## Type of State

* State in the state monad is of a single type
  + In our example, the state was of `int` type
* *Can we change type of state as the computation evolves?*

## Parameterised monads

* Parameterised monads add two additional type parameters to `t` representing the start and end states of a computation.
* A computation of type `('p, 'q, 'a) t` has 
  + *precondition* (or starting state) `'p`
  + *postcondition* (or ending state) `'q`
  + *produces a result* of type `'a`.

i.e. `('p, 'q, 'a) t` is a kind of Hoare triple `{P} M {Q}`.



## Parameterised monads

Here's the parameterised monad signature:

```ocaml
module type PARAMETERISED_MONAD =
sig
  type ('s,'t,'a) t
  val return : 'a -> ('s,'s,'a) t
  val (>>=) : ('r,'s,'a) t ->
       ('a -> ('s,'t,'b) t) ->
              ('r,'t,'b) t
end
```

```
module type PARAMETERISED_MONAD =
  sig
    type ('s, 't, 'a) t
    val return : 'a -> ('s, 's, 'a) t
    val ( >>= ) : ('r, 's, 'a) t -> ('a -> ('s, 't, 'b) t) -> ('r, 't, 'b) t
  end
```


## Parameterised state monad

Here's a parameterised monad version of the `STATE` signature, using the extra parameters to represent the type of the reference cell.

```ocaml
module type PSTATE =
sig
 include PARAMETERISED_MONAD
 val get : ('s,'s,'s) t
 val put : 's -> (_,'s,unit) t
 val runState : ('s,'t,'a) t -> init:'s -> 't * 'a
end
```

```
module type PSTATE =
  sig
    type ('s, 't, 'a) t
    val return : 'a -> ('s, 's, 'a) t
    val ( >>= ) : ('r, 's, 'a) t -> ('a -> ('s, 't, 'b) t) -> ('r, 't, 'b) t
    val get : ('s, 's, 's) t
    val put : 's -> ('a, 's, unit) t
    val runState : ('s, 't, 'a) t -> init:'s -> 't * 'a
  end
```


## Parameterised state monad


Here's an implementation of `PSTATE`.

```ocaml
module PState : PSTATE =
struct
  type ('s, 't, 'a) t = 's -> 't * 'a
  let return v s = (s, v)
  let (>>=) m k s = let t, a = m s in k a t
  let put s _ = (s, ())
  let get s = (s, s)
  let runState m ~init = m init
end
```

```
module PState : PSTATE
```


## Computation with changing state

```ocaml
open PState

let inc v = get >>= fun s -> put (s+v)
let dec v = get >>= fun s -> put (s-v)
let double = get >>= fun s -> put (s*2)

let to_string = get >>= fun i -> put (string_of_int i)
let of_string = get >>= fun s -> put (int_of_string s)
```

```
val inc : int -> (int, int, unit) PState.t = <fun>
val dec : int -> (int, int, unit) PState.t = <fun>
val double : (int, int, unit) PState.t = <abstr>
val to_string : (int, string, unit) PState.t = <abstr>
val of_string : (string, int, unit) PState.t = <abstr>
```


## Computation with changing state

```ocaml
let foo = inc 5 >>= fun () -> to_string
let bar = get >>= fun s -> put (s ^ "00")

let baz = foo >>= fun () -> bar
let quz = bar >>= fun () -> foo
```

```
val foo : (int, string, unit) PState.t = <abstr>
val bar : (string, string, unit) PState.t = <abstr>
val baz : (int, string, unit) PState.t = <abstr>
```

The definition of `quz` is rejected by the type checker: `bar` leaves a `string` in the cell, but `foo` expects the cell to hold an `int`.
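Running `baz` shows the cell changing type from `int` to `string`. The sketch below repeats a flattened version of `PSTATE` (without `include PARAMETERISED_MONAD`) and `PState` so that it compiles on its own. Starting from `1`, `inc 5` gives `6`, `to_string` gives `"6"`, and `bar` appends `"00"`.

```ocaml
module type PSTATE = sig
  type ('s, 't, 'a) t
  val return : 'a -> ('s, 's, 'a) t
  val (>>=) : ('r, 's, 'a) t -> ('a -> ('s, 't, 'b) t) -> ('r, 't, 'b) t
  val get : ('s, 's, 's) t
  val put : 's -> (_, 's, unit) t
  val runState : ('s, 't, 'a) t -> init:'s -> 't * 'a
end

module PState : PSTATE = struct
  type ('s, 't, 'a) t = 's -> 't * 'a
  let return v s = (s, v)
  let (>>=) m k s = let (t, a) = m s in k a t
  let put s _ = (s, ())
  let get s = (s, s)
  let runState m ~init = m init
end

open PState

let inc v = get >>= fun s -> put (s + v)
let to_string = get >>= fun i -> put (string_of_int i)

let foo = inc 5 >>= fun () -> to_string
let bar = get >>= fun s -> put (s ^ "00")
let baz = foo >>= fun () -> bar

(* The cell goes 1 -> 6 -> "6" -> "600" *)
let () = assert (runState baz ~init:1 = ("600", ()))
```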

## A well-typed stack machine

* Let's build a tiny stack machine with 3 instructions:
  + `push_const` pushes a constant onto the stack. The constant can be of any type.
  + `add` pops the top two integers on the stack and pushes their sum.
  + `_if_` expects `[b;v1;v2] @ rest_of_stack` on top of the stack.
      * If `b` is true, the resulting stack is `v1::rest_of_stack`;
      * otherwise, it is `v2::rest_of_stack`.
* Well-typed programs for our stack machine will not get stuck!
  + Recall the definition from the lambda calculus lectures.
* WebAssembly (WASM) is specified in a similar style, with instructions typed by their effect on the operand stack!

## Stack operations

* Because our stack will hold values of different types, we encode it using nested pairs.
  + `[]` will be `()`
  + `[1;2;3]` will be `(1, (2, (3, ())))`
  + `[1;true;3]` (which is not a well-typed OCaml expression) will be `(1, (true, (3, ())))`
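This encoding is ordinary OCaml data, and the type of each stack records the type of every element. The bindings below (`empty`, `ints`, `mixed` and so on) are illustrative names only.

```ocaml
(* Illustrative bindings: stacks as nested pairs ending in () *)
let empty : unit = ()
let ints : int * (int * (int * unit)) = (1, (2, (3, ())))

(* Mixed element types are fine: each position has its own type *)
let mixed : int * (bool * (int * unit)) = (1, (true, (3, ())))

(* The top of the stack is the first component *)
let top_of_mixed : int = fst mixed
let second_of_mixed : bool = fst (snd mixed)

let () = assert (empty = () && top_of_mixed = 1 && second_of_mixed)
```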

## Stack Operations

```ocaml
module type STACK_OPS =
sig
  type ('s,'t,'a) t
  val add : unit -> (int * (int * 's),
                     int * 's,
                     unit) t
  val _if_ : unit -> (bool * ('a * ('a * 's)),
                      'a * 's,
                      unit) t
  val push_const : 'a -> ('s,
                          'a * 's,
                          unit) t
end
```

```
module type STACK_OPS =
  sig
    type ('s, 't, 'a) t
    val add : unit -> (int * (int * 's), int * 's, unit) t
    val _if_ : unit -> (bool * ('a * ('a * 's)), 'a * 's, unit) t
    val push_const : 'a -> ('s, 'a * 's, unit) t
  end
```


## Stack Machine

We can combine the stack operations with the parameterised monad signature to
build a signature for a stack machine:

```ocaml
module type STACKM = sig
 include PARAMETERISED_MONAD
 include STACK_OPS
   with type ('s,'t,'a) t := ('s,'t,'a) t
 val execute : ('s,'t,'a) t -> 's -> 't * 'a
end
```

```
module type STACKM =
  sig
    type ('s, 't, 'a) t
    val return : 'a -> ('s, 's, 'a) t
    val ( >>= ) : ('r, 's, 'a) t -> ('a -> ('s, 't, 'b) t) -> ('r, 't, 'b) t
    val add : unit -> (int * (int * 's), int * 's, unit) t
    val _if_ : unit -> (bool * ('a * ('a * 's)), 'a * 's, unit) t
    val push_const : 'a -> ('s, 'a * 's, unit) t
    val execute : ('s, 't, 'a) t -> 's -> 't * 'a
  end
```


## Stack Machine

Here is the implementation of the stack machine:

```ocaml
module StackM : STACKM =
struct
  include PState

  let add () =
    get >>= fun (x,(y,s)) ->
    put (x+y,s)

  let _if_ () =
    get >>= fun (c,(t,(e,s))) ->
    put ((if c then t else e),s)

  let push_const k =
    get >>= fun s ->
    put (k, s)

  let execute c s = runState ~init:s c
end
```

```
module StackM : STACKM
```


## Using the stack machine

```ocaml
let program = let open StackM in
  push_const 4 >>= fun () ->
  push_const 5 >>= fun () ->
  push_const true >>= fun () ->
  _if_ () >>= fun () ->
  add ()
```

```
val program : (int * '_weak1, int * '_weak1, unit) StackM.t = <abstr>
```


```ocaml
StackM.execute program (20,(10,()))
```

```
- : (int * (int * unit)) * unit = ((25, (10, ())), ())
```
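As one more sketch, the program below starts from an empty stack `()`, computes `1 + 2`, and then uses `_if_` to choose between `100` and the sum. It repeats `PState`, `STACKM` and `StackM` (with `PState` left unsealed for brevity) so that it compiles on its own; `program` is a hypothetical name, and the type annotation avoids a weak type variable.

```ocaml
module PState = struct
  type ('s, 't, 'a) t = 's -> 't * 'a
  let return v s = (s, v)
  let (>>=) m k s = let (t, a) = m s in k a t
  let put s _ = (s, ())
  let get s = (s, s)
  let runState m ~init = m init
end

module type STACKM = sig
  type ('s, 't, 'a) t
  val return : 'a -> ('s, 's, 'a) t
  val (>>=) : ('r, 's, 'a) t -> ('a -> ('s, 't, 'b) t) -> ('r, 't, 'b) t
  val add : unit -> (int * (int * 's), int * 's, unit) t
  val _if_ : unit -> (bool * ('a * ('a * 's)), 'a * 's, unit) t
  val push_const : 'a -> ('s, 'a * 's, unit) t
  val execute : ('s, 't, 'a) t -> 's -> 't * 'a
end

module StackM : STACKM = struct
  include PState
  let add () = get >>= fun (x, (y, s)) -> put (x + y, s)
  let _if_ () = get >>= fun (c, (t, (e, s))) -> put ((if c then t else e), s)
  let push_const k = get >>= fun s -> put (k, s)
  let execute c s = runState ~init:s c
end

(* Stack: () -> [1] -> [2;1] -> [3] -> [100;3] -> [false;100;3] -> [3] *)
let program : (unit, int * unit, unit) StackM.t =
  let open StackM in
  push_const 1 >>= fun () ->
  push_const 2 >>= fun () ->
  add () >>= fun () ->
  push_const 100 >>= fun () ->
  push_const false >>= fun () ->
  _if_ ()

let () = assert (StackM.execute program () = ((3, ()), ()))
```

Replacing `push_const 100` with, say, `push_const "x"` should be rejected by the type checker, because `_if_` requires both candidate values to have the same type.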


## Using the stack machine

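```ocaml
StackM.execute (StackM._if_ ()) (false,(10,()))
```

This is rejected by the type checker: `_if_` needs a boolean and two values of the same type on the stack, but the stack `(false,(10,()))` holds only two values.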

```ocaml
StackM.execute (StackM.add ()) ()
```

This is also rejected at compile time: `add` needs two integers on the stack, but the stack `()` is empty.

<center>

<h1 style="text-align:center"> Fin. </h1>
</center>