<center>

<h1 style="text-align:center"> Monads </h1>
</center>


* Monads
  + Dealing with **effects** in a **pure** setting

## Whence Monads

* The term "monad" comes from **Category Theory**
  + Category theory is the study of mathematical abstractions.
  + Out of scope for this course.
  + We will focus on **programming with monads**.

## Monads for programming

* Monads were popularized by the Haskell programming language
  + Haskell is a **purely functional** programming language.
  + Unlike OCaml, Haskell separates pure code from side-effecting code through the use of monads.
* Monads are a way to *simulate* and *encapsulate* effects in a pure setting
  + ... similar to how we simulated advanced language features in lambda calculus encodings.
* A monad is an _idiom_ / _a design pattern_
  + not a primitive language feature

## What is a Monad?

A monad is any implementation that satisfies the following signature:

In [None]:
module type Monad = sig
  type 'a t                                 (* computation *)
  val return : 'a -> 'a t                   (* lift a value to a computation *)
  val bind   : 'a t -> ('a -> 'b t) -> 'b t (* sequence two computations *)
end

and the **monad laws**.
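The signature alone is easy to satisfy. As a minimal sketch, here is the simplest instance we can think of, which we call `Identity` (a name introduced only for this illustration): a computation is just its value, and `bind` is plain function application.

```ocaml
module type Monad = sig
  type 'a t                                 (* computation *)
  val return : 'a -> 'a t                   (* lift a value to a computation *)
  val bind   : 'a t -> ('a -> 'b t) -> 'b t (* sequence two computations *)
end

(* Hypothetical Identity monad: a computation of type 'a t is a value of type 'a *)
module Identity : (Monad with type 'a t = 'a) = struct
  type 'a t = 'a
  let return v = v
  let bind m f = f m   (* "sequencing" is just applying f to the value *)
end

let seven = Identity.bind (Identity.return 3) (fun x -> Identity.return (x + 4))
```

Here `seven` is `7`. This instance models no effect at all, which is exactly why the examples that follow are more interesting.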

## Is that it?

* All of this seems **very abstract** (as many FP concepts are).
* An example will help us see the pattern.
* Let's write an interpreter for arithmetic expressions

## Interpreting arithmetic expressions

In [1]:
type expr = 
| Val of int 
| Plus of expr * expr 
| Div of expr * expr

type expr = Val of int | Plus of expr * expr | Div of expr * expr


## Interpreting arithmetic expressions

* Our goal is to make the interpreter a **total function**.
  + Produces a **value** for every arithmetic expression.

In [2]:
let rec eval e = match e with
  | Val v -> v
  | Plus (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 + v2
  | Div (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 / v2

val eval : expr -> int = <fun>


## Interpreting arithmetic expressions : examples

In [3]:
eval (Plus (Div (Val 4, Val 2), Val 7)) (* 4 / 2 + 7 *)

- : int = 9


## Division by zero

This looks fine. But what happens if the denominator of the division is 0?

In [4]:
eval (Div (Val 1, Val 0))

Exception: Division_by_zero.

* Recall that our goal is to make the interpreter a **total function**.
  + Because `eval` can raise the `Division_by_zero` exception, it is not total.

How can we avoid this?

## Interpreting Arithmetic Expressions: Take 2

* Rewrite the `eval` function to have the type `expr -> int option`
  + Return `None` for division by zero.

In [5]:
let rec eval e = match e with
  | Val v -> Some v
  | Plus (e1,e2) ->
      begin match eval e1 with 
      | None -> None
      | Some v1 -> 
          match eval e2 with
          | None -> None 
          | Some v2 -> Some (v1 + v2)
      end
  | Div (e1,e2) ->
      begin match eval e1 with 
      | None -> None
      | Some v1 -> 
          match eval e2 with
          | None -> None 
          | Some v2 -> if v2 = 0 then None else Some (v1 / v2)
      end

val eval : expr -> int option = <fun>


## Interpreting Arithmetic Expressions: Take 2

In [6]:
eval (Plus (Div (Val 4, Val 2), Val 7)) (* 4 / 2 + 7 *)

- : int option = Some 9


In [7]:
eval (Div (Val 1, Val 0)) (* 1 / 0 *)

- : int option = None


## Abstraction

* The interpreter above repeats one pattern: match on an `option`, propagate `None`, and continue with the value inside `Some`.
  + Factor out common code.

In [8]:
let return v = Some v

val return : 'a -> 'a option = <fun>


In [9]:
let bind m f = match m with
  | None -> None 
  | Some v -> f v

val bind : 'a option -> ('a -> 'b option) -> 'b option = <fun>


**Convention:** We use the names `return` and `bind` to match the `Monad` signature, but any other names would work just as well.
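For comparison, OCaml's standard library (since 4.08) provides `Option.bind`, with the same type and argument order as our `bind`. The sketch below uses a hypothetical helper `safe_div` to check on a few inputs that the two agree.

```ocaml
(* Our hand-written bind for options *)
let bind m f = match m with
  | None -> None
  | Some v -> f v

(* Hypothetical helper: integer division that fails on a zero divisor *)
let safe_div x y = if y = 0 then None else Some (x / y)

(* Does our bind agree with Stdlib's Option.bind on input m? *)
let agree m = bind m (safe_div 100) = Option.bind m (safe_div 100)
```

For instance, `bind (Some 10) (safe_div 100)` is `Some 10`, and `agree` holds for `None`, `Some 0` and `Some 10`.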

## Abstraction 

Let's rewrite the interpreter using these functions.

```ocaml
let return v = Some v

let bind m f = match m with
  | None -> None 
  | Some v -> f v
```

In [10]:
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      bind (eval e1) (fun v1 -> 
      bind (eval e2) (fun v2 ->
      return (v1+v2)))
  | Div (e1,e2) ->
      bind (eval e1) (fun v1 -> 
      bind (eval e2) (fun v2 ->
      if v2 = 0 then None else return (v1 / v2)))

val eval : expr -> int option = <fun>


The code is laid out suggestively: each `bind (...) (fun x ->` line reads like a `let x = ... in` binding, which leads to nicer syntax.

## Infix bind operation

Usually `bind` is defined as the infix operator `>>=`.

In [11]:
let (>>=) = bind

val ( >>= ) : 'a option -> ('a -> 'b option) -> 'b option = <fun>


In [12]:
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      eval e1 >>= fun v1 -> 
      eval e2 >>= fun v2 ->
      return (v1+v2)
  | Div (e1,e2) ->
      eval e1 >>= fun v1 -> 
      eval e2 >>= fun v2 ->
      if v2 = 0 then None else return (v1 / v2)

val eval : expr -> int option = <fun>


## `let*` syntax extension 

Since OCaml 4.08 (released in June 2019), *binding operators* such as `let*` make it easier to write monadic programs.

In [13]:
let ( let* ) = bind

val ( let* ) : 'a option -> ('a -> 'b option) -> 'b option = <fun>
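The compiler treats `let* x = e1 in e2` as sugar for `( let* ) e1 (fun x -> e2)`. To make the correspondence visible, the sketch below writes one computation (a hypothetical `add_opt`) both ways.

```ocaml
let bind m f = match m with
  | None -> None
  | Some v -> f v

let ( let* ) = bind

(* Written with the binding operator *)
let add_opt a b =
  let* x = a in
  let* y = b in
  Some (x + y)

(* The same function, desugared by hand *)
let add_opt_desugared a b =
  ( let* ) a (fun x ->
  ( let* ) b (fun y ->
  Some (x + y)))
```

Both versions return `Some 3` for `Some 1` and `Some 2`, and `None` as soon as either argument is `None`.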


## `let*` syntax extension 


In [14]:
let rec eval e = match e with
  | Val v -> return v
  | Plus (e1,e2) ->
      let* v1 = eval e1 in
      let* v2 = eval e2 in
      return (v1+v2)
  | Div (e1,e2) ->
      let* v1 = eval e1 in 
      let* v2 = eval e2 in
      if v2 = 0 then None 
      else return (v1 / v2)

val eval : expr -> int option = <fun>


## Compare this to our initial take

```ocaml
let rec eval e = match e with
  | Val v -> v
  | Plus (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 + v2
  | Div (e1,e2) -> 
    let v1 = eval e1 in
    let v2 = eval e2 in
    v1 / v2
```

The monadic version adds `return` and `let*`, but the overall structure remains the same.

## Modularise

* The `return` and `let*` we defined for the interpreter work for any computation of option type.
  + Putting them in a module gives us the **Option Monad**.
* The option monad _simulates_ **exceptions**.

In [15]:
module type MONAD = sig
  type 'a t                                  (* computation *)
  val return  : 'a -> 'a t                   (* lift a value to a computation *)
  val (let*)  : 'a t -> ('a -> 'b t) -> 'b t (* sequence two computations *)
end

module OptionMonad : (MONAD with type 'a t = 'a option) = struct
  type 'a t = 'a option
  let return v = Some v
  let (let*) m f = match m with
  | Some v -> f v
  | None -> None
end

module type MONAD =
  sig
    type 'a t
    val return : 'a -> 'a t
    val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
  end


module OptionMonad :
  sig
    type 'a t = 'a option
    val return : 'a -> 'a t
    val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
  end
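One payoff of packaging `return` and `let*` behind `MONAD` is that generic code can be written once for every monad. As a sketch, the hypothetical functor `Derived` below defines `map` and `pair` using only the signature, then instantiates them for `OptionMonad`.

```ocaml
module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
end

module OptionMonad : (MONAD with type 'a t = 'a option) = struct
  type 'a t = 'a option
  let return v = Some v
  let ( let* ) m f = match m with
    | Some v -> f v
    | None -> None
end

(* Hypothetical helpers derived from return and let* alone *)
module Derived (M : MONAD) = struct
  open M
  (* apply a pure function to the result of a computation *)
  let map f m =
    let* x = m in
    return (f x)
  (* run two computations in sequence and pair their results *)
  let pair m1 m2 =
    let* a = m1 in
    let* b = m2 in
    return (a, b)
end

module OptionDerived = Derived (OptionMonad)
```

`OptionDerived.map succ (Some 1)` is `Some 2`; the same `Derived` functor would work unchanged for any other module matching `MONAD`.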


## Monad Laws

The monad laws constrain what `return` and `>>=` can do.

Any implementation of the monad signature must satisfy the following laws:


```ocaml
1. return v >>= f   ≡  f v   (* Left Identity *)
2. v >>= return     ≡  v     (* Right Identity *)
3. (m >>= f) >>= g  ≡  m >>= (fun x -> f x >>= g) 
                             (* Associativity *)
```
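Laws are proved, not tested, but spot-checking them on concrete inputs can build intuition. The sketch below does so for the option `return` and `>>=` defined earlier, using two hypothetical functions `f` and `g`; passing these checks is evidence, not a proof.

```ocaml
let return v = Some v
let ( >>= ) m f = match m with
  | None -> None
  | Some v -> f v

(* Hypothetical sample functions, each of which may fail *)
let f x = if x > 0 then Some (x * 2) else None
let g x = if x mod 3 = 0 then None else Some (x + 1)

(* Each law as a boolean check on one input *)
let left_identity v = (return v >>= f) = f v
let right_identity m = (m >>= return) = m
let associativity m = ((m >>= f) >>= g) = (m >>= (fun x -> f x >>= g))
```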

## Option monad satisfies the monad laws

**Left Identity**: `return v >>= f  ≡  f v`

```ocaml
  return v >>= f
≡ (Some v) >>= f  (* by definition of return *)
≡ match Some v with 
  | None   -> None 
  | Some v -> f v 
                  (* by definition of >>= *)
≡ f v             (* by beta reduction *)
```

**Exercise:** Prove the other laws.

## Simulating state

* Recall, monads simulate **effects** in a **pure** setting.
  + **option** monad simulates **exceptions**
* How can we simulate **mutability**?
  + For a start, a single, typed, mutable location in the whole program.
  + Operations to `get` the current state and `put` a new state.

**Idea:** _Thread_ the state through the program.

_Threading_ the state means passing the state as an additional argument to _every_ function and returning the new state along with the function result.

## Threading the state

What does threading the state look like? 

The usual Fibonacci function looks like:

In [16]:
let rec fib n = 
  if n < 2 then 1 
  else fib (n-1) + fib (n-2)

val fib : int -> int = <fun>


## Threading the state

Here is a Fibonacci function that threads the state through. It

* takes the state as an additional last argument, and
* returns a pair of the new state and the result of the function.

In [17]:
let rec fib n (s (* threaded state *)) = 
  if n < 2 then (s, 1) 
  else 
    let (s1, v1) = fib (n-1) s in
    let (s2, v2) = fib (n-2) s1 in
    (s2, v1 + v2)


val fib : int -> 'a -> 'a * int = <fun>


The function above neither reads nor writes the state; it only passes it along.
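For contrast, here is a sketch of a threaded Fibonacci that does use the state. The hypothetical `fib_count` treats an `int` state as a counter and increments it on every call, so the final state records how many calls were made.

```ocaml
(* Threaded Fibonacci whose int state counts the calls made so far *)
let rec fib_count n s =
  let s = s + 1 in            (* write: record this call *)
  if n < 2 then (s, 1)
  else
    let (s1, v1) = fib_count (n - 1) s in
    let (s2, v2) = fib_count (n - 2) s1 in
    (s2, v1 + v2)
```

Starting from state `0`, `fib_count 5 0` evaluates to `(15, 8)`: 15 calls in total, and the result 8.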

## Remove tedium

It is quite tedious to write functions that explicitly thread the state through, especially when they do not even touch it.

**Note:** We write `state` for the (abstract) type of the state.

```ocaml
val fib : int -> state -> state * int

let rec fib n (s (* threaded state *)) = 
  if n < 2 then (s, 1) 
  else 
    let (s1, v1) = fib (n-1) s in
    let (s2, v2) = fib (n-2) s1 in
    (s2, v1 + v2)
```



## Remove tedium

Look at the types:
```ocaml
type state
val fib : int -> state -> state * int
```
Identify the monadic pattern:
```ocaml
type state
type 'a t (* computation type *) = state -> state * 'a
val fib : int -> int t
```
`'a` is the return type of the computation.
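As a quick check that the threaded `fib` really has type `int -> int t`, here is a sketch that fixes `state` to `int` (an assumption made only so the code compiles; the function never inspects the state) and annotates the result type as `int t`.

```ocaml
type state = int                 (* assumption: any concrete type would do *)
type 'a t = state -> state * 'a  (* computation *)

(* The threaded fib, with its type written as int -> int t *)
let rec fib (n : int) : int t = fun s ->
  if n < 2 then (s, 1)
  else
    let (s1, v1) = fib (n - 1) s in
    let (s2, v2) = fib (n - 2) s1 in
    (s2, v1 + v2)
```

Running `fib 10 0` gives `(0, 89)`: the state comes back unchanged.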

## `bind` computations

How can we make this better?

```ocaml
...
    let (s1, v1) = fib (n-1) s in
    let (s2, v2) = fib (n-2) s1 in
    (s2, v1 + v2)
```    

Use `bind` to forward the state to the subsequent computation. 


```ocaml
type state
type 'a t = state -> state * 'a (* computation *)

let bind (m : 'a t) (f : 'a -> 'b t) : 'b t = 
  fun s (* current state *) ->
    let ((s' : state), (v : 'a)) = m s in
    let ((s'' : state), (res : 'b)) = f v s' in
    (s'' (* resultant state *), res)

let return (v : 'a) : 'a t = fun s -> (s, v)
```

## `bind` computations

With `let (let*) = bind`, we get:

```ocaml
...
    let* v1 = fib (n-1) in
    let* v2 = fib (n-2) in
    return (v1 + v2)
```

👍
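Putting the pieces together, here is a runnable sketch of the monadic Fibonacci built on the `bind` and `return` above. As before, `state` is fixed to `int` as an assumption so that the code compiles.

```ocaml
type state = int                  (* assumption: state fixed to int *)
type 'a t = state -> state * 'a   (* computation *)

let bind (m : 'a t) (f : 'a -> 'b t) : 'b t =
  fun s ->
    let (s', v) = m s in   (* run m in the current state *)
    f v s'                 (* continue in the updated state *)

let return (v : 'a) : 'a t = fun s -> (s, v)

let ( let* ) = bind

let rec fib n =
  if n < 2 then return 1
  else
    let* v1 = fib (n - 1) in
    let* v2 = fib (n - 2) in
    return (v1 + v2)
```

Applying `fib 10` to an initial state of `42` yields `(42, 89)`; no state is threaded by hand anywhere in `fib`.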

## Manipulating the state

In order to read and write the state, we implement the following functions.

```ocaml
type state
type 'a t = state -> state * 'a (* computation *)

let get : state t = fun s -> (s, s)        (* result is the current state *)

let put (ns : state) : unit t = fun _ -> (ns, ())   (* replace the state by ns *)

```
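With `get` and `put`, a computation can read the state and write a new one. As a sketch (again fixing `state` to `int`), the hypothetical `tick` below returns the current counter and stores its successor; `two_ticks` runs it twice.

```ocaml
type state = int                  (* assumption: state fixed to int *)
type 'a t = state -> state * 'a

let return (v : 'a) : 'a t = fun s -> (s, v)
let ( let* ) (m : 'a t) (f : 'a -> 'b t) : 'b t =
  fun s -> let (s', v) = m s in f v s'

let get : state t = fun s -> (s, s)
let put (ns : state) : unit t = fun _ -> (ns, ())

(* Hypothetical example: return the old counter, store old + 1 *)
let tick : int t =
  let* n = get in
  let* () = put (n + 1) in
  return n

(* Two ticks in a row see two successive counter values *)
let two_ticks : (int * int) t =
  let* a = tick in
  let* b = tick in
  return (a, b)
```

`tick 5` is `(6, 5)`, and `two_ticks 0` is `(2, (0, 1))`.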

## State Monad

What we've defined is a state monad.

* A State Monad introduces a **single, typed mutable cell**.
* It offers
  + `get` and `put` functions for reading and writing the state, and
  + a `run_state` function for running a computation from an initial state.

## State Monad

In [40]:
module type STATE = sig
  type state
  include MONAD
  val get : state t
  val put : state -> unit t
  val run_state : 'a t 
                -> state  (* initial state *)
                -> state (* final state *) * 'a
end

module type STATE =
  sig
    type state
    type 'a t
    val return : 'a -> 'a t
    val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
    val get : state t
    val put : state -> unit t
    val run_state : 'a t -> state -> state * 'a
  end


## State Monad

Here's an implementation of the `STATE` monad, parameterised by the type of the state:

In [41]:
module State (S : sig type t end) 
  : STATE with type state = S.t = struct

  type state = S.t
  type 'a t = state -> state * 'a (* computation *)

  let return v = fun s -> (s, v)

  let (let*) m f = fun s -> 
    let (s', a) = m s in 
    f a s'

  let get = fun s -> (s, s)

  let put s' = fun _ -> (s', ())

  let run_state m init = m init
end

module State :
  functor (S : sig type t end) ->
    sig
      type state = S.t
      type 'a t
      val return : 'a -> 'a t
      val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
      val get : state t
      val put : state -> unit t
      val run_state : 'a t -> state -> state * 'a
    end


## Using State Monad

In [42]:
module IntState = State (struct type t = int end)
open IntState 

(* [inc v] increments the state by [v] *)
let inc v = 
  let* s = get in 
  put (s+v)

(* [dec v] decrements the state by [v] *)
let dec v = 
  let* s = get in
  put (s-v)

(* [double] doubles the state *)
let double =
  let* s = get in
  put (s*2)

module IntState :
  sig
    type state = int
    type 'a t
    val return : 'a -> 'a t
    val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
    val get : state t
    val put : state -> unit t
    val run_state : 'a t -> state -> state * 'a
  end


val inc : int -> unit IntState.t = <fun>


val dec : int -> unit IntState.t = <fun>


val double : unit IntState.t = <abstr>


## Using State Monad

In [43]:
let comp = 
  let* _ = inc 20 in
  let* _ = double in
  dec 10
in

IntState.run_state comp 10

- : IntState.state * unit = (50, ())
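`inc`, `dec` and `double` all follow the same read-then-write pattern. The sketch below factors it out into a hypothetical `modify` helper (not part of the `STATE` signature above). So that the block compiles on its own, the `State` functor is repeated in a compact, unsealed form.

```ocaml
(* Compact, unsealed version of the State functor *)
module State (S : sig type t end) = struct
  type 'a t = S.t -> S.t * 'a
  let return (v : 'a) : 'a t = fun s -> (s, v)
  let ( let* ) (m : 'a t) (f : 'a -> 'b t) : 'b t =
    fun s -> let (s', a) = m s in f a s'
  let get : S.t t = fun s -> (s, s)
  let put (s' : S.t) : unit t = fun _ -> (s', ())
  let run_state (m : 'a t) (init : S.t) = m init
end

module IntState = State (struct type t = int end)
open IntState

(* Hypothetical helper: replace the state s by f s *)
let modify f =
  let* s = get in
  put (f s)

let inc v = modify (fun s -> s + v)
let dec v = modify (fun s -> s - v)
let double = modify (fun s -> s * 2)

let comp =
  let* () = inc 20 in
  let* () = double in
  dec 10
```

As with the original definitions, `run_state comp 10` gives `(50, ())`.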


In [44]:
module FloatState = State (struct type t = float end)
open FloatState

let comp = 
  let* v = get in 
  let* _ = put (v +. 1.0) in
  return "Hello, world"
;;

run_state  comp 5.4

module FloatState :
  sig
    type state = float
    type 'a t
    val return : 'a -> 'a t
    val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
    val get : state t
    val put : state -> unit t
    val run_state : 'a t -> state -> state * 'a
  end


val comp : string FloatState.t = <abstr>


- : FloatState.state * string = (6.4, "Hello, world")


## Fibonacci, again (in a monad)

In [45]:
open State (struct type t = int end)

let rec fib n = 
  if n < 2 then return 1
  else
    let* v1 = fib (n-1) in
    let* v2 = fib (n-2) in
    return (v1 + v2)

let fib_state = 
  let* n = get in
  let* r = fib n in
  put r
;;
  
run_state fib_state 10 

val fib : int -> int t = <fun>


val fib_state : unit t = <abstr>


- : state * unit = (89, ())
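Here the state only carries `n` in and the result out. As a further sketch, the hypothetical `fib_count` below uses the state as a call counter, the monadic counterpart of threading a counter by hand. The `State` functor is repeated in a compact, unsealed form so that the block is self-contained.

```ocaml
(* Compact, unsealed version of the State functor *)
module State (S : sig type t end) = struct
  type 'a t = S.t -> S.t * 'a
  let return (v : 'a) : 'a t = fun s -> (s, v)
  let ( let* ) (m : 'a t) (f : 'a -> 'b t) : 'b t =
    fun s -> let (s', a) = m s in f a s'
  let get : S.t t = fun s -> (s, s)
  let put (s' : S.t) : unit t = fun _ -> (s', ())
  let run_state (m : 'a t) (init : S.t) = m init
end

module IntState = State (struct type t = int end)
open IntState

(* Count every call in the state while computing Fibonacci *)
let rec fib_count n =
  let* c = get in
  let* () = put (c + 1) in
  if n < 2 then return 1
  else
    let* v1 = fib_count (n - 1) in
    let* v2 = fib_count (n - 2) in
    return (v1 + v2)
```

`run_state (fib_count 5) 0` evaluates to `(15, 8)`: 15 calls, result 8.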


## State monad satisfies monad laws

**Right Identity**: `v >>= return  ≡  v`

```ocaml
  v >>= return
≡ fun s -> 
    let (s', a) = v s in 
    return a s'                          (* by definition of >>= *)
≡ fun s -> 
    let (s', a) = v s in 
    (fun x s -> (s, x)) a s'             (* by definition of return *)
≡ fun s -> let (s', a) = v s in (s', a)  (* by beta reduction *)
≡ fun s -> v s       (* v s is a pair; taking it apart and rebuilding it is the identity *)
≡ v                  (* by eta reduction *)
```
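State computations are functions, so they cannot be compared with `=` directly. One way to spot-check the laws is to run both sides on a few initial states. The sketch below does this with `state` fixed to `int` and hypothetical sample computations `m`, `f` and `g`; it is a sanity check that complements the proofs, not a replacement for them.

```ocaml
type state = int                  (* assumption: state fixed to int *)
type 'a t = state -> state * 'a

let return (v : 'a) : 'a t = fun s -> (s, v)
let ( >>= ) (m : 'a t) (f : 'a -> 'b t) : 'b t =
  fun s -> let (s', a) = m s in f a s'

(* Two computations are deemed equal if they agree on these initial states *)
let same (c1 : 'a t) (c2 : 'a t) =
  List.for_all (fun s -> c1 s = c2 s) [-3; 0; 7]

(* Hypothetical sample computations *)
let m : int t = fun s -> (s + 1, s * 10)
let f (x : int) : int t = fun s -> (s * 2, x + s)
let g (y : int) : string t = fun s -> (s - 1, string_of_int (y * y))
```

With these, `same (m >>= return) m`, `same (return 4 >>= f) (f 4)` and `same ((m >>= f) >>= g) (m >>= (fun x -> f x >>= g))` all hold.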

**Exercise**: Prove other laws.

## Exercise: Writer Monad

Suppose we want to maintain a log of all function calls made during a computation.

In [None]:
let inc x = x + 1;;
let dec x = x - 1;;
let id x = 
  let v1 = inc x in
  let v2 = dec v1 in
  v2

We first introduce a `log` function:

In [None]:
let log name f x = ("Called " ^ name ^ " on " ^ string_of_int x ^ ";",f x)

In [None]:
log "inc" inc 5
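For reference, `log` pairs a message with the result of the call. Here is a quick check of the cell above, with `inc` as defined earlier:

```ocaml
let inc x = x + 1

(* log name f x returns the pair (message, f x) *)
let log name f x = ("Called " ^ name ^ " on " ^ string_of_int x ^ ";", f x)

let result = log "inc" inc 5
```

Here `result` is `("Called inc on 5;", 6)`.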

Design a WRITER MONAD module such that the following computation:
```ocaml
let id x = 
  let* v1 = log "inc" inc x in
  let* v2 = log "dec" dec v1 in
  return v2;;
  
id 5  
```
generates the result

```ocaml
("Called inc on 5;Called dec on 6;",5)
```