# Programming with Continuations

So far we have seen continuation passing style (CPS) and trampolines. We presented them chiefly as a means to remove non-tail recursion systematically by _packaging_ the work that remains as a continuation (or as a _thunk_, in the case of trampolines). However, this is perhaps not their most interesting application.
Continuations and trampolines are also useful for implementing features that would be hard to express any other way in an ordinary program.

In these notes, we will show two such applications:
  - systematically solving constraint problems using various forms of state-space search. 
  - coroutines that can perform "asynchronous operations"
 
 
We assume you are familiar with ideas such as 
  - Depth first search
  - Breadth first search
  - Iterative deepening search
  
You will encounter these in a basic _Artificial Intelligence_ class such as CSCI 3202.

## Send More Money

This is a standard puzzle: 

 $$\begin{array}{ccccc r}
    & S & E & N & D  & + \\
    & M & O & R & E \\ 
    \hline
    M & O & N & E & Y \\ 
    \end{array}$$
    
Here, each of $S, E, N, D, M, O, R, Y$ stands for a digit in $\{0, \ldots, 9\}$, such that the sum of the top row and the middle row equals the bottom row once each letter is replaced by its digit.
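To make the arithmetic concrete, the small helper below reads a word as a base-10 number under an assignment of digits to letters and checks one assignment that satisfies the puzzle. `wordValue` is a hypothetical name introduced only for this worked check; it is not a search and is not used elsewhere in these notes.

```scala
// Hypothetical helper: read a word as a base-10 number, letter by letter.
def wordValue(word: String, assign: Map[Char, Int]): Int =
  word.foldLeft(0)((acc, ch) => 10 * acc + assign(ch))

// One assignment that satisfies SEND + MORE = MONEY.
val assign = Map('s' -> 9, 'e' -> 5, 'n' -> 6, 'd' -> 7,
                 'm' -> 1, 'o' -> 0, 'r' -> 8, 'y' -> 2)

val send  = wordValue("send", assign)  // 9567
val more  = wordValue("more", assign)  // 1085
val money = wordValue("money", assign) // 10652
println(s"$send + $more = ${send + more} (money = $money)")
```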

## Boolean Satisfiability Problem

We have a boolean formula such as 

$$ (p \& q \& !r) | ( p \& (!q | !r) \& !s) | (!p \& (!s | q) ) $$

Find truth values for the propositions $p, q, r, s$ that make the formula above true, or conclude that the formula is unsatisfiable.
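With only four propositions, we could simply try all $2^4 = 16$ truth assignments. The sketch below does exactly that with plain Scala Booleans (`phi` is a name chosen here for the formula). It is the kind of exhaustive search that the rest of these notes organize in a more flexible way.

```scala
// The formula above, written with Scala's Boolean operators.
def phi(p: Boolean, q: Boolean, r: Boolean, s: Boolean): Boolean =
  (p && q && !r) || (p && (!q || !r) && !s) || (!p && (!s || q))

val bools = List(true, false)

// Exhaustive search: keep every assignment that makes the formula true.
val satisfying = for {
  p <- bools
  q <- bools
  r <- bools
  s <- bools
  if phi(p, q, r, s)
} yield (p, q, r, s)

println(s"${satisfying.length} of 16 assignments satisfy the formula")
```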

## Other Examples

Other examples include solving a sudoku or three-coloring a graph. These are all problems where we need to use explicit search.

# Programming Search

Programming search poses no technical difficulty in itself. Let us focus on the send-more-money problem.
We can always program the search explicitly ourselves using recursion.

In [8]:
/* Explicit Search 
    lst: list of letters that have not yet been assigned.
    dict: Dictionary holding the letters that have been assigned.
*/
def solve_send_more_money(lst:List[String], dict: Map[String, Int]): Boolean = {
    if (lst.length == 0){
        /*- all letters have been assigned: check if send + more == money -*/
        val send = 1000*dict("s") + 100*dict("e") + 10 * dict("n") + dict("d") // send
        val more = 1000*dict("m") + 100 * dict("o") + 10 * dict("r") + dict("e") //more 
        val money = 10000*dict("m") + 1000 * dict("o")+ 100*dict("n") + 10 * dict("e") + dict("y") // money
        if (send + more == money) {
            println("Found Solution")
            println("--------------")
            List("s","e","n","d","m","o","r","y"). foreach {  // Print the solution.
                case chr => println(s"${chr} --> ${dict(chr)}")
            }
            true // "return"/evaluate to true
        } else {
            false // "return"/evaluate to false
        }
    } else {
        // Run through the elements 0 to 9 inclusive
        (0 to 9).foldLeft[Boolean] (false)  // Accumulator starts with false
         { 
            (acc, i:Int) => { 
                if (!acc){ // If current accumulator is false
                    val head_chr:String = lst.head // take the first unassigned char
                    val new_dict:Map[String,Int] = dict + (head_chr -> i) // assign it to i
                    solve_send_more_money(lst.tail, new_dict) // recursively solve the rest of the problem
                } else {
                    true // this means, we already found a solution.. just skip to the end.
                }
            }
        }
    }
}


In [9]:
solve_send_more_money(List("s","e","n","d","m","o","r","y"), Map.empty)

Found Solution
--------------
s --> 0
e --> 0
n --> 0
d --> 0
m --> 0
o --> 0
r --> 0
y --> 0

res8: Boolean = true

That was blazingly fast. But wait a minute: we forgot a key constraint. Leading digits cannot be 0, so `s` and `m` must range from 1 to 9. Let's modify the code.

In [10]:
/* Explicit Search 
    lst: list of letters that have not yet been assigned.
    dict: Dictionary holding the letters that have been assigned.
    But, s, m cannot be assigned 0.
    Identical to previous code except for the condition above that was added.
*/
def solve_smm_2(lst:List[String], dict: Map[String, Int]): Boolean = {
    if (lst.length == 0){
        /*- all letters have been assigned -*/
        val send = 1000*dict("s") + 100*dict("e") + 10 * dict("n") + dict("d")
        val more = 1000*dict("m") + 100 * dict("o") + 10 * dict("r") + dict("e")
        val money = 10000*dict("m") + 1000 * dict("o")+ 100*dict("n") + 10 * dict("e") + dict("y")
        if (dict("s") != 0 && dict("m") != 0 && send + more == money) {
            println("Found Solution")
            println("--------------")
            List("s","e","n","d","m","o","r","y"). foreach { 
                case chr => println(s"${chr} --> ${dict(chr)}")
            }
            true
        } else {
            false
        }
    } else {
        (0 to 9).foldLeft[Boolean] (false) {
            (acc, i:Int) => {
                if (!acc){
                    val head_chr:String = lst.head
                    val new_dict:Map[String,Int] = dict + (head_chr -> i)
                    solve_smm_2(lst.tail, new_dict)
                } else {
                    true
                }
            }
        }
    }
}


In [11]:
solve_smm_2(List("s","e","n","d","m","o","r","y"), Map.empty)

Found Solution
--------------
s --> 9
e --> 0
n --> 0
d --> 0
m --> 1
o --> 0
r --> 0
y --> 0

res10: Boolean = true

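Success!! However, it took a lot longer (more than 15 seconds), since the trivial solution is no longer available: the search must try every assignment with `s` from 0 to 8, none of which works, before it reaches `s = 9`.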

The bigger problem is that we had to write a lot of code, and it is not easy to adapt it to a different puzzle or to try a different search scheme.

## Domain Specific Language (DSL) for programming search.

Ideally, as a user, we would like to write something like this.

~~~
   val s = choose(1 to 9)
   val e = choose(0 to 9)
   val n = choose(0 to 9)
   val d = choose(0 to 9)
   val m = choose(1 to 9)
   val o = choose(0 to 9)
   val r = choose(0 to 9)
   val y = choose(0 to 9)
   assert_sol( d + 10 * n + 100 * e + 1000 * s + 
               e + 10 * r + 100 * o + 1000 * m == 
               y + 10*e + 100 * n + 1000 * o + 10000*m
             ) 
~~~

Ideally, the DSL should find and set the appropriate values for the variables to make the assertion at the end go through. However, we are not interested in parsing our own language such as Lettuce. We wish to write code that is going to compile in Scala and just define the appropriate API functions `choose` and `assert_sol` in this case. 

At first glance, this seems impossible:
  - Every time we call `choose`, it returns a single value. How can we go back to an earlier choice and backtrack? We do not control how Scala compiles the code or how the JVM executes the compiled bytecode.
  
The answer, of course, is __continuation passing style__. We will see that CPS gives our API the unique ability to
"hijack" execution and thus do things like backtracking without having to rewrite the Scala compiler ourselves :-)

Let us take a simpler problem of the same type:
  $$\begin{array}{cccc}
    & A & B & + \\
    &  B & A & \\ 
    \hline
   C & B & C \\ 
    \end{array}$$
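For a puzzle this small, Scala's for-comprehensions already give a compact exhaustive search, as in the sketch below. The drawback is that the search order is fixed by how the generators are written, so trying a different strategy means rewriting the loop; that is the flexibility the DSL approach is after.

```scala
// Exhaustive search over a, b, c in 1..9 for  AB + BA = CBC.
val solutions = for {
  a <- 1 to 9
  b <- 1 to 9
  c <- 1 to 9
  if (10 * a + b) + (10 * b + a) == 100 * c + 10 * b + c
} yield (a, b, c)

println(solutions) // the only solution is 92 + 29 = 121
```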

We wish to express this code:
~~~
val a = choose(1 to 9)
val b = choose(1 to 9)
val c = choose(1 to 9)
assert_sol(a + 10 * b + 10 * a + b == 100 * c + 10 * b + c)
~~~

Let's write it in continuation passing style.

~~~
choose(1 to 9, a => {
  choose(1 to 9, b => {
    choose(1 to 9, c => {
      assert_sol(a + 10 * b + 10 * a + b == 100 * c + 10 * b + c)
    })
  })
})
~~~

In other words, let's make the function `choose` take a continuation as an argument.

In [12]:
/* -- for convenience, we will write this in an imperative style -- */

/* function choose: make choices from the provided list one by one and 
   call the continuation for each choice */
def choose(choices: List[Int],  k: Int => Boolean): Boolean = {
    /* -- iterate through each choice --*/
    for (i <- choices){
        /* -- call continuation with current choice --*/
        if (k(i)){ return true } // If executing the rest of the continuation with choice i returns true
    }
    false // we tried all choices and the continuation returned false for all of them.
}

/* When the choices are done, check if we got a solution*/
def assert_sol(d: Map[String, Int], b: Boolean): Boolean = {
    if (b) {
        println("Solution Found")
        println("--------------")
        for ((s,i) <- d){
            println(s"${s} --> ${i}")
        }
        true
    } else {
        false
    }
}



In [13]:
val lst = (1 to 9).toList
choose(lst, a => { // choose a 
    choose(lst, b => { // choose b
        choose(lst, c => { // choose c
            // assert the solution condition 
            // for convenience, we pack up all the choices into a dictionary for pretty printing.
            assert_sol(Map("a" -> a, "b" -> b, "c"-> c), a + 10 * b + 10 * a + b == 100 * c + 10 * b + c  )
        })
    })
})

Solution Found
--------------
a --> 9
b --> 2
c --> 1

lst: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9)
res12_1: Boolean = true

This solution works, but it is based on depth-first search and relies on non-tail recursion. For some search problems, DFS can fail to terminate (for example, when it descends into an infinite branch of the search space) and never find a solution. It would be good to try other search methods as well.

Let's also make it more user-friendly, so that
  - the user can place labels on choices,
  - the search can range over choices of any type, and
  - the code does not employ non-tail recursion.

In [17]:
/* class dfs searcher is generic over type T, which is the type of choices we are going to make */
/* for the send+more = money problem, we will have T = Int */
class DFSSearcher[T] { 
    /* Let us actually store all the continuations. 
       Although for a DFS, we do not need to do this,
       it will help us for implementing other algorithms like BFS or randomized DFS */
    var continuations: List[Continuation] = Nil /* Store a list of unexecuted continuations */
    var cur_path: List[(String,T)] = Nil /* Store the choices we have made so far */
    
    /*- 
        Useful class for continuations.
    
        Store the label for the choice, the option we have chosen and 
        the continuation to be executed.
        --*/
    class Continuation(val label: String, val opt: T, val kont: T => Boolean){
        val my_path = cur_path // Store a copy of the search path
        /* Function doit --> this will update the cur_path in the outside class and call the continuation */
        def doit: Boolean = {
            cur_path = (label, opt)::my_path
            kont(opt)
        }
    }
 
    /* - Function choose -*/
    def choose(label: String, options: List[T], k: T => Boolean): Boolean = {
        assert (options.length >= 1)
        /* -- iterate through all options.
              create a continuation class for each object and 
              push them to the front of the stack of continuations we are 
              storing --*/
        for (t <- options) {
            continuations = (new Continuation(label, t, k)) :: continuations
        }
        
        /* -- pop the first continuation and execute it -- */
        val head_cont = continuations.head
        continuations = continuations.tail
        head_cont.doit // run the first continuation we popped off the stack.
    }
  
    /*- Function assert_sol -*/
    def assert_sol(b: Boolean): Boolean = {
        if (b) { // Solution found.
            println("Found Solution") // Print the solution
            println("---------------")
            cur_path.foreach {
                case (label, t) => println(s"\t %s --> %s".format(label, t.toString))
            }
            println("---------------"); 
            true // And bail out with true value
        } else {
            /*-- This path is not a solution */
            if (continuations.length >= 1){ /*- Do we still have unexecuted continuations left ? -*/
                val c: Continuation = continuations.head  /* If yes, then take the first one from the list */
                continuations = continuations.tail /* pop it off */
                c.doit /* and execute it */
            } else {
                println("All options exhausted. No solution found");
                false /* bail out with false */
            }
        }
    }

}


In [18]:
def solve_simple_puzzle() = {
    val dfs = new DFSSearcher[Int]()
    val lst = List(1, 2, 3, 4, 5, 6, 7, 8, 9)
    dfs.choose("a", lst, a => {
        dfs.choose("b", lst, b => {
            dfs.choose("c", lst, c => {
                dfs.assert_sol(a + 10 * b + 10 * a + b == 100 * c + 10 * b + c )
            })
        })
    })
}


In [19]:
solve_simple_puzzle()

Found Solution
---------------
	 c --> 1
	 b --> 2
	 a --> 9
---------------

res18: Boolean = true

In [20]:
/*-- Let's try the larger puzzle --*/

def solve_puzzle_1() = {
    val lst = List(0,1,2,3,4,5,6,7,8,9)
    val d = new DFSSearcher[Int]()
    d.choose("s", lst, s => {
            d.choose("e", lst, e => {
                d.choose("n", lst, n => {
                    d.choose("d", lst, d1 => {
                        d.choose("m", lst, m => {
                        d.choose("o", lst, o => {
                            d.choose("r", lst, r => {
                                 d.choose("y", lst, y => {
                                    d.assert_sol(s >= 1 && m >= 1 && d1 + n*10 + e*100 + s*1000 + e + 10 * r + 100 * o + 1000 * m == y + 10*e + 100 * n + 1000 * o + 10000*m)
                                    })
                            })
                        })
                    })
                })
            })
        })
    })
}

solve_puzzle_1()


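Unfortunately, the lack of general tail-call elimination on the JVM (Scala only optimizes self-recursive tail calls) now bites us. Every failed `assert_sol` runs the next continuation on top of the current call stack instead of returning, so the stack grows with every assignment we try and eventually overflows. Instead of solving the problem, we have made it worse. No fear: let's trampoline our work.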
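Before applying this to the searcher, here is a minimal, self-contained illustration of why a trampoline fixes the stack problem. It repeats the `Trampoline` definitions given next so that it runs on its own; `isEven`/`isOdd` are a standard toy example, not part of the searcher. Each step returns a `More` thunk instead of making the next call itself, so the `@tailrec` loop in `run` performs all the calls with constant stack depth.

```scala
import scala.annotation.tailrec

sealed trait Trampoline[A]
case class Done[A](value: A) extends Trampoline[A]
case class More[A](call: () => Trampoline[A]) extends Trampoline[A]

// Keep forcing thunks until we reach a final value; compiled to a loop.
@tailrec
def run[A](t: Trampoline[A]): A = t match {
  case More(f) => run(f())
  case Done(r) => r
}

// Mutually recursive even/odd: each step hands back a thunk instead of recursing.
def isEven(n: Int): Trampoline[Boolean] =
  if (n == 0) Done(true) else More(() => isOdd(n - 1))
def isOdd(n: Int): Trampoline[Boolean] =
  if (n == 0) Done(false) else More(() => isEven(n - 1))

println(run(isEven(1000000))) // a million steps, no stack overflow
```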

In [21]:
sealed trait Trampoline[A]
case class Done[A](value: A) extends Trampoline[A]
case class More[A](call: () => Trampoline[A]) extends Trampoline[A]


In [22]:
import scala.annotation.tailrec 

@tailrec
def run[A](t: Trampoline[A]):A = t match {
    case More(f) => {
        val t1 = f()
        run(t1)
    }
    case Done(r) => r
}

In [23]:
/*-- This is very similar to the original DFSSearcher class but now we will 
     define it as a trampoline --*/

class DFSSearcherTramp[T] {
    var continuations: List[Continuation] = Nil
    var cur_path: List[(String,T)] = Nil

    class Continuation(val label: String, val opt: T, val kont: T => Trampoline[Boolean]){
        val my_path = cur_path
        def doit: Trampoline[Boolean] = {
            cur_path = (label, opt)::my_path
            More( () => kont(opt) )
        }
    }
 
    def choose(label: String, options: List[T], k: T => Trampoline[Boolean]): Trampoline[Boolean] = {
        assert (options.length >= 1)
        for (t <- options) {
            continuations = (new Continuation(label, t, k)) :: continuations
        }
        val fst = continuations.head
        continuations = continuations.tail
        fst.doit
    }
  
    def assert_sol(b: Boolean): Trampoline[Boolean] = {
        if (b) {
            println("Found Solution")
            println("---------------")
            cur_path.foreach {
                case (label, t) => println(s"\t %s --> %s".format(label, t.toString))
            }
            println("---------------");
            Done(true)
        } else {
            if (continuations.length >= 1){
                val c: Continuation = continuations.head 
                continuations = continuations.tail 
                c.doit
            } else {
                println("All options exhausted. No solution found");
                Done(false)
            }
        }
    }

}


In [24]:
def solve_puzzle_2() = {
    val lst = List(0,1,2,3,4,5,6,7,8,9)
    val d = new DFSSearcherTramp[Int]()
    val prob = d.choose("s", lst.tail, s => {
                    d.choose("e", lst, e => {
                        d.choose("n", lst, n => {
                            d.choose("d", lst, d1 => {
                                d.choose("m", lst.tail, m => {
                                    d.choose("o", lst, o => {
                                        d.choose("r", lst, r => {     
                                            d.choose("y", lst, y => {
                                                d.assert_sol(s != 0 && m != 0 && d1 + n* 10 + e * 100 + s *1000 +  e + 10 * r + 100 * o + 1000 * m == y + 10*e + 100 * n + 1000 * o + 10000*m)
                                        })
                                    })
                                })
                            })
                        })
                })
        })
    })
    run(prob)
}

solve_puzzle_2()

Found Solution
---------------
	 y --> 9
	 r --> 0
	 o --> 0
	 m --> 1
	 d --> 0
	 n --> 9
	 e --> 9
	 s --> 9
---------------

res23_1: Boolean = true

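Success!! Plus the code executed blazingly fast: it took less than 400 milliseconds, whereas the recursive search we wrote at the beginning required nearly 16 seconds, roughly 40x slower. Much of this difference likely comes from the search order rather than from trampolining itself: `choose` pushes each option onto the front of the continuation stack, so the options are tried from last to first (9 down to 0), and every solution of this puzzle has `s = 9`.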
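The search order of `DFSSearcherTramp` is easy to see in isolation: `choose` pushes the options onto the front of the list of continuations one at a time, which reverses them, so for `List(0, ..., 9)` the option 9 is tried first. The snippet below reproduces just that loop.

```scala
// Push options 0..9 onto the front of a list, one at a time, as `choose` does.
var stack: List[Int] = Nil
for (t <- (0 to 9).toList) {
  stack = t :: stack
}
println(stack.head) // 9 is tried first
println(stack)      // List(9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
```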

In [25]:
abstract class BasicSearcher[T] {
    var continuations: List[Continuation] = Nil
    var cur_path: List[(String,T)] = Nil
    class Continuation(val label: String, val opt: T, val kont: T => Trampoline[Boolean]){
        val my_path: List[(String, T)] = cur_path // This is initialized at construction time. 
        def doit: Trampoline[Boolean] = {
            cur_path = (label, opt)::my_path
            More( () => kont(opt) )
        }
    }
 
    def choose(label: String, options: List[T], k: T => Trampoline[Boolean]): Trampoline[Boolean]
    def assert_sol(b: Boolean): Trampoline[Boolean]
}


In [26]:
class BFSSearcher[T] extends BasicSearcher[T] {
    def choose(label: String, options: List[T], k: T => Trampoline[Boolean]): Trampoline[Boolean] = {
        assert (options.length >= 1)
        val new_continuations = options.map(t => new Continuation(label, t, k))
        continuations = continuations ++ new_continuations // Append to the very end (FIFO queue)
        val c = continuations.head // Take the oldest pending continuation
        continuations = continuations.tail // and remove it, so that it is not executed again
        c.doit
    }
    
     def assert_sol(b: Boolean): Trampoline[Boolean] = {
        if (b) {
            println("Found Solution")
            println("---------------")
            cur_path.foreach {
                case (label, t) => println(s"\t %s --> %s".format(label, t.toString))
            }
            println("---------------");
            Done(true)
        } else {
            if (continuations.length >= 1){
                val c: Continuation = continuations.head 
                continuations = continuations.tail 
                c.doit
            } else {
                println("All options exhausted. No solution found");
                Done(false)
            }
        }
    }
    
}


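Warning: for a problem like this, BFS is very expensive in both memory and time, because the queue must hold every partial assignment at one level of the search tree before moving on to the next.
The code below is therefore commented out; it will take a long time to run.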

In [26]:
/* def solve_puzzle_3() = {
    val lst = List(0,1,2,3,4,5,6,7,8,9)
    val bfs = new BFSSearcher[Int]()
    val prob = bfs.choose("s", lst.tail, s => {
                    bfs.choose("e", lst, e => {
                        bfs.choose("n", lst, n => {
                            bfs.choose("d", lst, d1 => {
                                bfs.choose("m", lst.tail, m => {
                                    bfs.choose("o", lst, o => {
                                        bfs.choose("r", lst, r => {     
                                            bfs.choose("y", lst, y => {
                                                bfs.assert_sol(s != 0 && m != 0 && d1 + n* 10 + e * 100 + s *1000 +  e + 10 * r + 100 * o + 1000 * m == y + 10*e + 100 * n + 1000 * o + 10000*m)
                                        })
                                    })
                                })
                            })
                        })
                })
        })
    })
    run(prob)
}

solve_puzzle_3()
*/

We can also implement a randomized DFS that explores the options in a random order by shuffling them at each choice point.
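One practical note: `scala.util.Random.shuffle`, as used in the searcher below, draws from a shared global generator, so each run may explore a different order. If reproducible runs are needed (for debugging, say), one option is to shuffle with a `scala.util.Random` instance created from a fixed seed, as sketched here; the seed value 42 is arbitrary.

```scala
import scala.util.Random

// A generator with a fixed seed produces the same sequence of shuffles on every run.
val rng = new Random(42)
val order = rng.shuffle((0 to 9).toList)
println(order) // some permutation of 0..9, identical across runs
```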

In [27]:
import scala.util.Random.shuffle 

class RandomizedDFSSearcher[T] extends BasicSearcher[T] {
    
    def choose(label: String, options: List[T], k: T => Trampoline[Boolean]): Trampoline[Boolean] = {
        assert (options.length >= 1)
        val new_continuations = options.map(t => new Continuation(label, t, k))
        /*-- Shuffle the new continuations and push them onto the front of the stack --*/
        continuations = shuffle(new_continuations) ++ continuations
        val c = continuations.head // pop the first continuation
        continuations = continuations.tail // so that it is not executed twice
        c.doit
    }
    
     def assert_sol(b: Boolean): Trampoline[Boolean] = {
        if (b) {
            println("Found Solution")
            println("---------------")
            cur_path.foreach {
                case (label, t) => println(s"\t %s --> %s".format(label, t.toString))
            }
            println("---------------");
            Done(true)
        } else {
            if (continuations.length >= 1){
                val c: Continuation = continuations.head 
                continuations = continuations.tail 
                c.doit
            } else {
                println("All options exhausted. No solution found");
                Done(false)
            }
        }
    }
    
}

In [28]:
def solve_puzzle_4() = {
    val lst = List(0,1,2,3,4,5,6,7,8,9)
    val d = new RandomizedDFSSearcher[Int]()
    val prob = d.choose("s", lst.tail, s => {
                    d.choose("e", lst, e => {
                        d.choose("n", lst, n => {
                            d.choose("d", lst, d1 => {
                                d.choose("m", lst.tail, m => {
                                    d.choose("o", lst, o => {
                                        d.choose("r", lst, r => {     
                                            d.choose("y", lst, y => {
                                                d.assert_sol(s != 0 && m != 0 && d1 + n* 10 + e * 100 + s *1000 +  e + 10 * r + 100 * o + 1000 * m == y + 10*e + 100 * n + 1000 * o + 10000*m)
                                        })
                                    })
                                })
                            })
                        })
                })
        })
    })
    run(prob)
}

solve_puzzle_4()

Found Solution
---------------
	 y --> 6
	 r --> 0
	 o --> 0
	 m --> 1
	 d --> 3
	 n --> 3
	 e --> 3
	 s --> 9
---------------

res27_1: Boolean = true

We can program more sophisticated search using this approach. Let's build a Boolean SAT solver.

### Example

We have a boolean formula such as 

$$ (p \& q \& !r) | ( p \& (!q | !r) \& !s) | (!p \& (!s | q) ) $$

Find truth values for the propositions $p, q, r, s$ that make the formula above true, or conclude that the formula is unsatisfiable.

Let's begin by defining Boolean formulas.

In [29]:
sealed trait Formula
case object True extends Formula
case object False extends Formula
case class Var(id: String) extends Formula 
case class And(lst: List[Formula]) extends Formula
case class Not(f: Formula) extends Formula
case class Or (lst: List[Formula]) extends Formula 



In [30]:
val (p, q, r, s) = (Var("p"), Var("q"), Var("r"), Var("s"))
val my_formula = Or(List(
    And(List(p, q, Not(r))),
    And(List(p, Or( List( Not(q), Not(r))), Not(s))),
    And(List(Not(p), Or(List(Not(s), q))))
))

p: Var = Var("p")
q: Var = Var("q")
r: Var = Var("r")
s: Var = Var("s")
my_formula: Or = Or(
  List(
    And(List(Var("p"), Var("q"), Not(Var("r")))),
    And(List(Var("p"), Or(List(Not(Var("q")), Not(Var("r")))), Not(Var("s")))),
    And(List(Not(Var("p")), Or(List(Not(Var("s")), Var("q")))))
  )
)

We will now implement a function that substitutes a truth value `b` for the variable `id`.
It takes an input formula and returns the transformed, simplified formula. For instance, if we take the formula


$$ (p \& q \& !r) | ( p \& (!q | !r) \& !s) | (!p \& (!s | q) ) $$

and substitute $p = true$, we obtain, after substitution and simplification:

$$ (q \& !r) |  ( (!q | !r) \& !s) $$

In [31]:
def substituteAndSimplify(f: Formula, id: String, b: Boolean) : Formula = f match {
    case True => True
    case False => False 
    case Var(s) if s == id => { if (b) { True} else {False } }
    case Var(_) => f 
    case And(lst) => {
        val new_lst = lst.map(substituteAndSimplify(_, id, b))
                         .filter( _ != True)
        if (new_lst.length == 0){
            True
        } else {
            if (new_lst contains False){
                False
            } else {
                if (new_lst.length == 1){
                    new_lst.head
                } else {
                    And(new_lst)
                }
            }
        }
    }
    
    case Or(lst) => {
        val new_lst = lst.map(substituteAndSimplify(_, id, b))
                         . filter(_ != False)
        if (new_lst.length == 0){
            False
        } else {
            if (new_lst contains True ){
                True
            } else{
                if (new_lst.length == 1){
                    new_lst.head
                } else {
                    Or(new_lst)
                }
            }
        }
    }
    
    case Not(g) => {
        val f1 = substituteAndSimplify(g, id, b)
        f1 match {
            case True => False
            case False => True
            case _ => Not(f1)
        }
    }
}


In [32]:
println(substituteAndSimplify(my_formula, "p", false))
println(substituteAndSimplify(And(List(Var("p"), Var("q"), Not(Var("r")))), "p", false))

Or(List(Not(Var(s)), Var(q)))
False


In [33]:
/*-- Let's create a search for solving SAT
    props: remaining propositional variables,
    formula: Boolean formula.
    --*/
def solve_sat(d: DFSSearcherTramp[Boolean], props: List[String], formula: Formula): Trampoline[Boolean] = {
    if (props.length == 0){
        d.assert_sol(formula == True) // We are out of propositions. Formula better be true/false.
    } else {
        val x = props.head // Take the first proposition from the list.
        val choices = List(true, false) // Available choices are true/false
        d.choose( x, choices, p => { // Continuation starts here
            val new_formula = substituteAndSimplify(formula, x, p) //substitute for the proposition
            if (new_formula == True) { // If formula simplifies to True
                d.assert_sol(true) // we are done
            } else if (new_formula == False){
                d.assert_sol(false) // If it simplifies to false, we can stop this branch as well
            } else { 
                solve_sat(d, props.tail, new_formula) // Otherwise, solve sat problem with remaining propositions and new formula
            }
        }) // Continuation ends here.
    }
}


In [34]:
val props = List("p", "q", "r", "s")
val t = solve_sat(new DFSSearcherTramp[Boolean](), props, my_formula)
run(t)

Found Solution
---------------
	 s --> false
	 r --> false
	 q --> false
	 p --> false
---------------

props: List[String] = List("p", "q", "r", "s")
res33_2: Boolean = true

In [35]:
val my_formula2 = And(List(
    Or(List(p, q, Not(r))),
    Or(List(p, And( List( Not(q), Not(r))), Not(s))),
    Or(List(Not(p), And(List(Not(s), q))))
))

my_formula2: And = And(
  List(
    Or(List(Var("p"), Var("q"), Not(Var("r")))),
    Or(List(Var("p"), And(List(Not(Var("q")), Not(Var("r")))), Not(Var("s")))),
    Or(List(Not(Var("p")), And(List(Not(Var("s")), Var("q")))))
  )
)

In [36]:
val props = List("p", "q", "r", "s")
val t = solve_sat(new DFSSearcherTramp[Boolean](), props, my_formula2)
run(t)

Found Solution
---------------
	 r --> false
	 q --> false
	 p --> false
---------------

props: List[String] = List("p", "q", "r", "s")
res35_2: Boolean = true
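Notice that `s` does not appear in this solution: once `p`, `q` and `r` are set to false, `my_formula2` already simplifies to `True`, so `solve_sat` stops and `s` can take either value. The self-contained sketch below checks this with a direct evaluator. `evalFormula` is a helper introduced here, not part of the notes, and the block repeats the `Formula` definitions so that it runs on its own.

```scala
sealed trait Formula
case object True extends Formula
case object False extends Formula
case class Var(id: String) extends Formula
case class And(lst: List[Formula]) extends Formula
case class Not(f: Formula) extends Formula
case class Or(lst: List[Formula]) extends Formula

// Evaluate a formula under an assignment that covers all of its variables.
def evalFormula(f: Formula, env: Map[String, Boolean]): Boolean = f match {
  case True     => true
  case False    => false
  case Var(id)  => env(id)
  case And(lst) => lst.forall(evalFormula(_, env))
  case Or(lst)  => lst.exists(evalFormula(_, env))
  case Not(g)   => !evalFormula(g, env)
}

val (p, q, r, s) = (Var("p"), Var("q"), Var("r"), Var("s"))
val my_formula2 = And(List(
  Or(List(p, q, Not(r))),
  Or(List(p, And(List(Not(q), Not(r))), Not(s))),
  Or(List(Not(p), And(List(Not(s), q))))
))

// The solver assigned only p, q and r; check both possible values of s.
val partial = Map("p" -> false, "q" -> false, "r" -> false)
val worksForAnyS = List(true, false).forall(b => evalFormula(my_formula2, partial + ("s" -> b)))
println(worksForAnyS)
```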

## Coroutines

Coroutines are a programming pattern that allows us to pause the execution of a program in the middle of doing something, wait for something to happen, and then resume. They are supported (in slightly different forms) by numerous languages:

  - [Scala Coroutines](https://scala-coroutines.github.io/coroutines/docs/0.6/101/)
  - [Coroutines in C++](https://en.cppreference.com/w/cpp/language/coroutines)
  - Do a google search and find out about coroutines in Kotlin and other languages.
  
  
### What is a co-routine? 

Suppose we are writing a program with two parts: one part performs data acquisition (the data producer) while the other does some processing (the data consumer). Suppose further that data arrives intermittently and the producer places it into a buffer as it arrives. One option is parallelism: the consumer and producer run in separate threads or processes and use shared memory to store and process the data. Another pattern is to allow functions to suspend and resume execution, as in the following pseudocode.

~~~
def producer() = {
   var num_data_produced = 0
   while(true){
    val v = read_data() // Get some data
    println("I produced data")
    num_data_produced = num_data_produced + 1
    yield(v) // This sends the value v to a different coroutine.
   }
}

def consumer() = {
   var data_count = 0
   var data_sum = 0
   while(true){ 
    val v = receive() // Suspends until another coroutine yields a value.
    process_new_data(v)
    data_count = data_count + 1
    data_sum = data_sum + v 
   }
}
~~~

Note here that `producer` and `consumer` are not "normal" functions. When the producer _yields_ a value 
_v_ to consumer, we would like the consumer to be able to _receive_ and process this value _v_. 
The two should run in lock-step: each value the producer yields is received and processed by the consumer.

Also notice that `consumer` has local variables, `data_count` and `data_sum`, that it updates on every iteration. With ordinary functions, we cannot suspend `consumer` in the middle of its loop and later resume it with these variables intact, unless each function runs in its own thread. Coroutines are designed for exactly this situation. Scala has a coroutines library, but we will show how to support coroutines using continuations.
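For comparison, the thread-based alternative can be sketched with the JDK's `java.util.concurrent.LinkedBlockingQueue` as the shared buffer. This is our illustration rather than part of the notes: the data is a fixed range instead of `read_data()`, and `None` is an assumed end-of-stream marker. Each loop keeps its local state alive simply because it runs on its own thread; coroutines aim for the same effect without threads.

```scala
import java.util.concurrent.LinkedBlockingQueue

// Shared buffer. Some(v) carries a data item; None is a hypothetical end-of-stream marker.
val buffer = new LinkedBlockingQueue[Option[Int]]()

var dataCount = 0
var dataSum = 0

val producerThread = new Thread(() => {
  for (v <- 1 to 10) buffer.put(Some(v)) // hand each item to the consumer thread
  buffer.put(None)                       // tell the consumer that no more data will come
})

val consumerThread = new Thread(() => {
  var done = false
  while (!done) {
    buffer.take() match {                // take() blocks until an item is available
      case Some(v) =>
        dataCount += 1
        dataSum += v
      case None => done = true
    }
  }
})

producerThread.start()
consumerThread.start()
producerThread.join()
consumerThread.join()                    // join() makes the consumer's updates visible here
println(s"count = $dataCount, sum = $dataSum") // prints "count = 10, sum = 55"
```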

### Coroutines

Coroutines are functions, each identified by a name or handle, that support the following actions.
 - Coroutines can `yield` a value to another coroutine. For simplicity, we assume that all messages are _broadcast_ to every other coroutine. We can provide names/handles to coroutines if need be to control which coroutine receives a message. 
 - Coroutines can `receive` values that are sent by another coroutine through a yield. 
 
Coroutines can support other operations, such as blocking and resuming, but we will focus on implementing a simple coroutine DSL that supports yielding and receiving values. For simplicity, every value exchanged must have the type `Message`, which for now has two cases: `Num(d)`, carrying an integer, and `Terminate`, which signals that no more data will follow.

In [37]:
sealed trait Message
case class Num(d: Int) extends Message
case object Terminate extends Message

defined [32mtrait[39m [36mMessage[39m
defined [32mclass[39m [36mNum[39m
defined [32mobject[39m [36mTerminate[39m

We will use continuations to implement coroutines. The idea is that our coroutine library has a dispatcher that keeps track of the continuations of all coroutines waiting for a message. We proceed by rewriting the code we wrote for the producer and consumer as follows.


~~~
def producer() = {
   var num_data_produced = 0
   def producer_loop(): Unit = {
        val v = read_data()
        println("I produced data")
        num_data_produced = num_data_produced + 1
        dispatch.yield_val(Num(v), producer_loop) // producer_loop is the continuation
    }
    producer_loop()
}

def consumer() = {
   var data_count = 0
   var data_sum = 0
   def consumer_loop(): Unit = {
       dispatch.receive({ // the case block is the continuation that handles the message
           case Num(v) => {
             process_new_data(v)
             data_count = data_count + 1
             data_sum = data_sum + v
             consumer_loop()
           }
       })
   }
   consumer_loop()
}
~~~

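In particular, we write the `yield` and `receive` parts in CPS: `yield_val` takes the message along with a continuation describing what the producer does next, and `receive` takes a continuation describing what the consumer does with the message it gets. We also replace each loop with a recursive call.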

In [38]:
class Dispatcher {

    // Continuations of the consumers that are currently waiting for a message.
    var waiting_consumers: List[Message => Unit] = Nil

    // Yield a value: deliver msg to every waiting consumer, then resume the sender.
    def yield_val(msg: Message, kont: () => Unit): Unit = {
        val current_consumers = waiting_consumers // Take all the waiting consumers
        waiting_consumers = Nil
        for (k <- current_consumers) {
            k(msg) // Call the continuation of each waiting consumer with the current message
        }
        // Done sending the message to all current consumers
        kont() // resume my own continuation.
    }

    // Register a continuation to be called with the next message that is yielded.
    def receive(kont: Message => Unit): Unit = {
        waiting_consumers = kont :: waiting_consumers
    }
}

defined [32mclass[39m [36mDispatcher[39m

In [40]:
import scala.util.Random

val dispatch = new Dispatcher()
val rand = new Random()
def read_data(): Int = { // simulate reading data by producing a random number
    rand.nextInt()
}

def producer() = { // This is the producer loop
    var num_data_produced = 0
    def producer_loop(): Unit = {
        if (num_data_produced < 100) {
            val v = read_data()
            println(s"Producer read data $v")
            num_data_produced = num_data_produced + 1
            dispatch.yield_val(Num(v), producer_loop) // producer_loop is the continuation
        } else {
            dispatch.yield_val(Terminate, () => ()) // Avoid an infinite loop by terminating after 100 values.
        }
    }
    producer_loop()
}

def process_new_data(v: Int) = {
    println(s"Consumer received data $v") // simulate the data processing -- here simply printing the value.
}

def consumer() = {
    var data_count = 0
    var data_sum = 0
    def consumer_loop(): Unit = {
        dispatch.receive({
            case Num(v) => {
                process_new_data(v)
                data_count = data_count + 1
                data_sum = data_sum + v
                consumer_loop() // wait for the next message
            }
            case Terminate => {
                println("Consumer received termination message -- stopping!")
                () // Stop looping: do not register for further messages.
            }
        })
    }
    consumer_loop()
}

// Let's start things off. The consumer starts first: a message yielded while
// no consumer is waiting is simply dropped by the dispatcher.
consumer()
producer()

Producer read data -1261617271
Consumer received data -1261617271
Producer read data 62695230
Consumer received data 62695230
Producer read data -1083586563
Consumer received data -1083586563
Producer read data 327135605
Consumer received data 327135605
Producer read data 921078928
Consumer received data 921078928
Producer read data -314295507
Consumer received data -314295507
Producer read data 643966052
Consumer received data 643966052
Producer read data 394086693
Consumer received data 394086693
Producer read data -545057123
Consumer received data -545057123
Producer read data -663574379
Consumer received data -663574379
Producer read data -2051213023
Consumer received data -2051213023
Producer read data -695426150
Consumer received data -695426150
Producer read data -108438896
Consumer received data -108438896
Producer read data -710328312
Consumer received data -710328312
Producer read data 1651490200
Consumer received data 1651490200
Producer read data -719123972
Consumer receive

[32mimport [39m[36mscala.util.Random

[39m
[36mdispatch[39m: [32mDispatcher[39m = ammonite.$sess.cmd37$Helper$Dispatcher@691d635e
[36mrand[39m: [32mRandom[39m = scala.util.Random@6ca34c08
defined [32mfunction[39m [36mread_data[39m
defined [32mfunction[39m [36mproducer[39m
defined [32mfunction[39m [36mprocess_new_data[39m
defined [32mfunction[39m [36mconsumer[39m

## Advanced Topic # 1: Control Operators 

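Control operators such as `call/cc` or `shift`/`reset` let a program capture its current continuation directly, instead of having us write the code in CPS by hand. This topic is moot here since Scala does not support them natively.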



## Advanced Topic # 2: Monads

Starting out we wanted to have the following syntax: 
~~~
   val s = choose(1 to 9);
   val e = choose(0 to 9);
   val n = choose(0 to 9);
   val d = choose(0 to 9);
   val m = choose(1 to 9);
   val o = choose(0 to 9);
   val r = choose(0 to 9);
   val y = choose(0 to 9);
   assert_sol( d + 10 * n + 100 * e + 1000 * s + e + 10 * r + 100 * o + 1000 * m == y + 10*e + 100 * n + 1000 * o + 10000*m) 
~~~

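but we had to settle for writing our code in CPS. This is a bummer. Can we get back to our initial idea?
To an extent, yes, using the concept of a monad and Scala's built-in support for monads. We will briefly cover monads in a future lecture; they are sometimes called "programmable semicolons". The key fact we need now is that Scala's `for` comprehensions are syntactic sugar: the compiler rewrites each generator `x <- e` into a call `e.flatMap(x => ...)` (or `e.map(x => ...)` for the last generator) whose function argument is the rest of the comprehension, and each `if` guard into a `withFilter` call. That function argument is exactly the continuation we used to write by hand, so a class that implements `flatMap`, `map`, and `withFilter` lets us write the search in direct style. What we will do is rework things slightly.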
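The rewriting can be seen with the standard `List` type, which already has `flatMap`, `map`, and `withFilter`. The sketch below is our own example: it writes the small puzzle $ab + ba = cbc$ (digits $1, \ldots, 9$) both as a `for` comprehension and as the explicit calls the compiler produces for it. Unlike a depth-first searcher that stops at the first solution, `List` explores every option and collects all solutions.

```scala
val lst = (1 to 9).toList

// Direct style: a for comprehension with a guard.
val viaFor = for {
  a <- lst
  b <- lst
  c <- lst
  if a + 10 * b + b + 10 * a == 100 * c + 10 * b + c
} yield (a, b, c)

// What the compiler rewrites it into: generators become flatMap/map, the guard
// becomes withFilter, and each function argument is the rest of the comprehension,
// i.e., its continuation.
val viaCalls = lst.flatMap(a =>
  lst.flatMap(b =>
    lst.withFilter(c => a + 10 * b + b + 10 * a == 100 * c + 10 * b + c)
       .map(c => (a, b, c))))

println(viaFor) // prints "List((9,2,1))"
```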

   

In [41]:
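class DFSSearcherMonad[T] {
    // Current search path: (label, value) pairs, most recent choice first.
    var cur_path: List[(String, T)] = Nil

    // A choice point. A for comprehension calls flatMap/map/withFilter on it and
    // passes the rest of the comprehension as the function argument (the continuation).
    class Chooser(label: String, options: List[T]) {
        // Used for every generator but the last: try each option in turn,
        // restoring the path before each try, until the continuation succeeds.
        // An empty option list simply fails, which lets the caller backtrack.
        def flatMap(k: T => Boolean): Boolean = {
            val path_when_called = cur_path
            for (t <- options) {
                cur_path = (label, t) :: path_when_called
                if (k(t)) {
                    return true
                }
            }
            false
        }

        // Used for the last generator: b is the body of the yield, i.e., the solution test.
        def map(b: T => Boolean): Boolean = {
            val path_when_called = cur_path
            for (t <- options) {
                cur_path = (label, t) :: path_when_called
                if (b(t)) {
                    println("Found Solution")
                    println("---------------")
                    cur_path.foreach {
                        case (label, t) => println(s"\t $label --> $t")
                    }
                    println("---------------")
                    return true
                }
            }
            false
        }

        // Used for `if` guards: prune the options before trying them.
        def withFilter(p: T => Boolean): Chooser = new Chooser(label, options.filter(p))
    }

    def choose(label: String, options: List[T]) = new Chooser(label, options)
}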

defined [32mclass[39m [36mDFSSearcherMonad[39m

In [42]:
val d = new DFSSearcherMonad[Int]()
val lst = (1 to 9).toList

for {
    a: Int <- d.choose("a", lst)
    b: Int <- d.choose("b", lst)
    c: Int <- d.choose("c", lst)
} yield (a + 10 * b + b + 10 * a == 100 *c + 10 *b + c)


Found Solution
---------------
	 c --> 1
	 b --> 2
	 a --> 9
---------------


[36md[39m: [32mDFSSearcherMonad[39m[[32mInt[39m] = ammonite.$sess.cmd40$Helper$DFSSearcherMonad@13133bc6
[36mlst[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m1[39m, [32m2[39m, [32m3[39m, [32m4[39m, [32m5[39m, [32m6[39m, [32m7[39m, [32m8[39m, [32m9[39m)
[36mres41_2[39m: [32mBoolean[39m = true

In [43]:
val dm = new DFSSearcherMonad[Int]()
val lst = (0 to 9).toList

for {
    s : Int <- dm.choose("s", lst) if s != 0
    e : Int <- dm.choose("e", lst)
    n : Int <- dm.choose("n", lst)
    d : Int <- dm.choose("d", lst)
    m : Int <- dm.choose("m", lst) if m != 0
    o : Int <- dm.choose("o", lst)
    r : Int <- dm.choose("r", lst)
    y : Int <- dm.choose("y", lst)
    
} yield d + 10 * n + 100 * e + 1000 * s + e + 10 * r + 100 * o + 1000 * m == y + 10*e + 100 * n + 1000 * o + 10000*m


Found Solution
---------------
	 y --> 0
	 r --> 0
	 o --> 0
	 m --> 1
	 d --> 0
	 n --> 0
	 e --> 0
	 s --> 9
---------------


[36mdm[39m: [32mDFSSearcherMonad[39m[[32mInt[39m] = ammonite.$sess.cmd40$Helper$DFSSearcherMonad@63202fd7
[36mlst[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m0[39m, [32m1[39m, [32m2[39m, [32m3[39m, [32m4[39m, [32m5[39m, [32m6[39m, [32m7[39m, [32m8[39m, [32m9[39m)
[36mres42_2[39m: [32mBoolean[39m = true

That's all folks!