# Type Inference by hand

Using the Lettuce grammar and functions from the Type Inference Using Unification lecture, we gain an intuition for the mechanics of type inference by manually solving simple examples.


## Lettuce Grammar

We will use the following grammar that does not include recursion. Note that we do not allow type annotations in our version.

$$\begin{array}{rcll}
\mathbf{Program} & \rightarrow & TopLevel(\mathbf{Expr}) \\[5pt]
\mathbf{Expr} & \rightarrow & Const(\mathbf{Number}) \\
 & | & Ident(\mathbf{Identifier}) \\
 & | & Plus(\mathbf{Expr}, \mathbf{Expr}) \\
 & | & Mult(\mathbf{Expr}, \mathbf{Expr}) \\
 & | & Eq(\mathbf{Expr}, \mathbf{Expr}) \\
 & | & Geq (\mathbf{Expr}, \mathbf{Expr}) \\
 & | & IfThenElse(\mathbf{Expr}, \mathbf{Expr}, \mathbf{Expr}) \\
 & | & Let( \mathbf{Identifier}, \mathbf{Expr}, \mathbf{Expr}) \\
 & | & FunDef( \mathbf{Identifier}, \mathbf{Expr}) \\ 
 & | & FunCall(\mathbf{Expr}, \mathbf{Expr})\\
\end{array}$$


In [33]:
sealed trait Expr

case class Const(f: Double) extends Expr
case class Ident(x: String) extends Expr
case class Plus(e1: Expr, e2: Expr) extends Expr
case class Geq (e1: Expr, e2: Expr) extends Expr
case class And(e1: Expr, e2: Expr) extends Expr
case class Not(e: Expr) extends Expr
case class IfThenElse(e: Expr, e1: Expr, e2: Expr) extends Expr
case class Let(x: String, e1: Expr, e2: Expr) extends Expr
case class FunDef(param: String, body: Expr) extends Expr
case class FunCall(e1: Expr, e2: Expr) extends Expr

defined [32mtrait[39m [36mExpr[39m
defined [32mclass[39m [36mConst[39m
defined [32mclass[39m [36mIdent[39m
defined [32mclass[39m [36mPlus[39m
defined [32mclass[39m [36mGeq[39m
defined [32mclass[39m [36mAnd[39m
defined [32mclass[39m [36mNot[39m
defined [32mclass[39m [36mIfThenElse[39m
defined [32mclass[39m [36mLet[39m
defined [32mclass[39m [36mFunDef[39m
defined [32mclass[39m [36mFunCall[39m

## Abstract Syntax Tree for Types

In lecture, we learned to infer types by performing two steps:

1. Generate a system of equations: these are our type constraints
2. Solve the system of equations using the unification approach.

Let us define an AST for our types along with some utilities.

In [34]:
sealed trait Type
case object NumType extends Type
case object BoolType extends Type
case class FunType(type1: Type, type2: Type) extends Type
case class TypeVar(name:String) extends Type

type TypeEnvironment = Map[String, Type]
type ListOfEquations = List[(Type, Type)]

object TypeVarGenerator {
    var counter = 0
    
    def getFreshTypeVariable(id: String): TypeVar = {
        val t = TypeVar("type_" + id + "_" + counter.toString)
        counter = counter + 1
        t
    }
    
    def resetCounter = {
        counter = 0
    }
}

defined [32mtrait[39m [36mType[39m
defined [32mobject[39m [36mNumType[39m
defined [32mobject[39m [36mBoolType[39m
defined [32mclass[39m [36mFunType[39m
defined [32mclass[39m [36mTypeVar[39m
defined [32mtype[39m [36mTypeEnvironment[39m
defined [32mtype[39m [36mListOfEquations[39m
defined [32mobject[39m [36mTypeVarGenerator[39m

## Generating Type Variables and Constraints

The following rules allow us to generate constraints:
    - `Const(_)`: 
        - return NumType, Nil
    - `Ident(s)`: 
        - look up s in our type environment and throw an error if it is not found
        - return type(s), Nil
    - `Plus(e1, e2)`:   
        - generate constraints for `e1` and add a constraint: `type(e1) == NumType`
        - generate constraints for `e2` and add a constraint: `type(e2) == NumType`
        - return NumType, constraints
    - `Geq(e1, e2)`:  
        - generate constraints for `e1` and add a constraint: `type(e1) == NumType`
        - generate constraints for `e2` and add a constraint: `type(e2) == NumType`
        - return BoolType, constraints
    - `And(e1, e2)`:  
        - generate constraints for `e1` and add a constraint: `type(e1) == BoolType`
        - generate constraints for `e2` and add a constraint: `type(e2) == BoolType`
        - return BoolType, constraints
    - `Not(e1)`:  
        - generate constraints for `e1` and add a constraint: `type(e1) == BoolType`
        - return BoolType, constraints
    - `IfThenElse(e, e1, e2)`:
        - generate constraints for `e` and add a constraint: `type(e) == BoolType`
        - generate constraints for `e1` and `e2` and add a constraint: `type(e1) == type(e2)`
        - return `type(e1)`, constraints
    - `Let(x, e1, e2)`:
        - create a new `TypeVar` for the argument x and call it `type_x_0`
        - generate constraints for `e1` and create a new type environment augmented with `x -> type_x_0`
        - generate constraints for `e2` using the new type environment
        - add the constraint `type_x_0 == type(e1)` 
        - return `type(e2)`, constraints
    - `FunDef(x, e1)`:
        - create a new `TypeVar` for the argument x and call it `type_x_0`
        - create a new type environment augmented with `x -> type_x_0`
        - generate constraints for `e1` using the new type environment 
        - return `FunType(type_x_0, type(e1))`, constraints
    - `FunCall(e1, e2)`:
        - create a new `TypeVar` for the function call and call it `type_fcall_0`
        - generate constraints for `e1` and `e2`
        - add the constraint `type(e1) == FunType(type(e2), type_fcall_0)`
        - return `type_fcall_0`, constraints

From the readings, we have code that will do this automatically. We will go through some examples and build up the constraints manually, then compare the results to the output of `generateEquations`.


In [35]:
case class ErrorException(s: String) extends Exception 

def generateEquations(e: Expr, alpha: TypeEnvironment): (Type, ListOfEquations) = 
    e match {
        case Const(f) => (NumType, Nil) // If expr is a constant, return num and Nil
        case Ident(id) => {
              if (alpha contains id){
                  (alpha(id), Nil)
              } else {
                  throw new ErrorException(s"Used undeclared identifier $id -- type error")
              }
        }

        case Plus(e1, e2) => {
             val (t1, lst1) = generateEquations(e1, alpha) // Gen. eqs for e1
             val (t2, lst2) = generateEquations(e2, alpha) // Gen. eqs for e2
             val combinedList = lst1 ++ lst2 ++ List( (t1, NumType), (t2, NumType) )
             (NumType, combinedList) // The overall type of Plus is a num
        }

        case Geq(e1, e2) => {
             val (t1, lst1) = generateEquations(e1, alpha) // Gen. eqs for e1
             val (t2, lst2) = generateEquations(e2, alpha) // Gen. eqs for e2
             val combinedList = lst1 ++ lst2 ++ List( (t1, NumType), (t2, NumType) )
             (BoolType, combinedList) // Overall type of Geq is a boolean
        }

        case And(e1, e2) => {
            val (t1, lst1) = generateEquations(e1, alpha) // Gen. eqs for e1
            val (t2, lst2) = generateEquations(e2, alpha) // Gen. eqs for e2
            val combinedList = lst1 ++ lst2 ++ List( (t1, BoolType), (t2, BoolType) )
            (BoolType, combinedList) // Overall type of And is a boolean
        }

        case Not(e1) => {
            val (t1, lst1) = generateEquations(e1, alpha) // Gen. eqs for e1
            val combinedList = lst1 ++ List( (t1, BoolType) )
            (BoolType, combinedList) // Overall type of And is a boolean
        }

        case IfThenElse(e, e1, e2) => {
             val (t0, lst0) = generateEquations(e, alpha) // Gen. eqs for e
             val (t1, lst1) = generateEquations(e1, alpha) // Gen. eqs for e1
             val (t2, lst2) = generateEquations(e2, alpha) // Gen. eqs for e2
             val combinedList = lst1 ++ lst2 ++ List( (t0, BoolType), (t1, t2) )
             (t1, combinedList) // Overall type of IfThenElse is t1
        }

        case Let(x, e1, e2) => {
           val tx = TypeVarGenerator.getFreshTypeVariable(x)
           val (te1, listE1) = generateEquations(e1, alpha)
           val newAlpha = alpha + (x -> tx)
           val (te2, listE2) = generateEquations(e2, newAlpha)
           val combinedList = listE1 ++ listE2 ++ List( (tx, te1) )
           (te2, combinedList)
        }  

        case FunDef(param, body) => {
            val tparam = TypeVarGenerator.getFreshTypeVariable(param) // Gen. fresh type variable 
            val newEnv = alpha + (param -> tparam)
            val (tbody, listBody) = generateEquations(body, newEnv)
            val fnType = FunType(tparam, tbody)
            (fnType, listBody)
        }

        case FunCall(e1, e2) => {
            val (te1, listE1) = generateEquations(e1, alpha)
            val (te2, listE2) = generateEquations(e2, alpha)
            val te = TypeVarGenerator.getFreshTypeVariable("fcall")
            val newTypeConstraint = (te1, FunType(te2, te) ) // te1 == te2 => te
            val combinedList = listE1 ++ listE2 ++ List(newTypeConstraint)
            (te, combinedList)
        }  
    }

defined [32mclass[39m [36mErrorException[39m
defined [32mfunction[39m [36mgenerateEquations[39m

In [36]:
def typeToString(t: Type) : String = t match {
    case NumType => "NumType"
    case BoolType => "BoolType"
    case FunType(t1, t2) => "FunType("+(typeToString(t1)) + " => " + (typeToString(t2)) +")"
    case TypeVar(str) => str
}

def prettyPrintTypeEqs (lst: List[(Type, Type)]): Unit = {
    lst.foreach {
        case (t1, t2) => {
            println(typeToString(t1) + " == " + typeToString(t2))
        }
    }
}

defined [32mfunction[39m [36mtypeToString[39m
defined [32mfunction[39m [36mprettyPrintTypeEqs[39m

## Examples


In [37]:
// let x = 2 in x + 1
val expr = Let("x", Const(2), Plus(Ident("x"), Const(1)))

// 1. The top level expression is a Let
//     - `Let(x, e1, e2)`:
//         - create a new `TypeVar` for the argument x and call it `type_x_0`
//         - generate constraints for `e1` and create a new type environment augmented with `x -> type_x_0`
//         - generate constraints for `e2` using the new type environment
//         - add the constraint `type_x_0 == type(e1)` 
//         - return `type(e2)`, constraints

// 2. e1 = Const(2), so type(e1) = NumType and we evaluate e2 using alpha ++ (x -> NumType)

// 3. e2 = Plus(Ident("x"), Const(1))
//     - `Plus(e1_hat, e2_hat)`:   
//         - generate constraints for `e1_hat` and add a constraint: `type(e1_hat) == NumType`
//         - generate constraints for `e2_hat` and add a constraint: `type(e2_hat) == NumType`
//         - return NumType, constraints

// 4. e1_hat = Ident("x") which evaluates to type_x_0, we add our first constraint type_x_0 == NumType.

// 5. e2_hat = Const(1) which evaluates to NumType, we add our second constraint NumType == NumType

// 6. type(e1) = NumType, so add the constraint `type_x_0 == NumType`

// 7. type(e2) = NumType, so we return NumType, constraints

val (typ, lstOfEqs) = generateEquations(
    expr,
    Map[String, Type]()
)
TypeVarGenerator.resetCounter

println("-- Generated Eqs --")
prettyPrintTypeEqs(lstOfEqs)
println("The overall program has type: " + typeToString(typ))



-- Generated Eqs --
type_x_0 == NumType
NumType == NumType
type_x_0 == NumType
The overall program has type: NumType


[36mexpr[39m: [32mLet[39m = [33mLet[39m([32m"x"[39m, [33mConst[39m([32m2.0[39m), [33mPlus[39m([33mIdent[39m([32m"x"[39m), [33mConst[39m([32m1.0[39m)))
[36mtyp[39m: [32mType[39m = NumType
[36mlstOfEqs[39m: [32mListOfEquations[39m = [33mList[39m(
  ([33mTypeVar[39m([32m"type_x_0"[39m), NumType),
  (NumType, NumType),
  ([33mTypeVar[39m([32m"type_x_0"[39m), NumType)
)

In [38]:
// if 5 > 2 then 3 else function (x) x + 1
val expr = IfThenElse(
                Geq(Const(5), Const(2)),
                Const(3),
                FunDef("x", Plus(Ident("x"), Const(1)))
           )

// 1. The top level expression is an IfThenElse
//     - `IfThenElse(e, e1, e2)`:
//         - generate constraints for `e` and add a constraint: `type(e) == BoolType`
//         - generate constraints for `e1` and `e2` and add a constraint: `type(e1) == type(e2)`
//         - return `type(e1)`, constraints

// 2. e = Geq(Const(5), Const(2))
//     - `Geq(e1_hat, e2_hat)`:  
//         - generate constraints for `e1_hat` and add a constraint: `type(e1_hat) == NumType`
//         - generate constraints for `e2_hat` and add a constraint: `type(e2_hat) == NumType`
//         - return BoolType, constraints

// 3. e1_hat = Const(5), so type(e1_hat) = NumType, add constraint NumType == NumType
//    e2_hat = Const(2), so type(e1_hat) = NumType, add constraint NumType == NumType

// 4. type(e) = BoolType, so add constraint BoolType == BoolType

// 5. e1 = Const(3), so type(e1) = NumType

// 6. e2 = FunDef("x", Plus(Ident("x"), Const(1)))
//     - `FunDef(x, e_prime)`:
//         - create a new `TypeVar` for the argument x and call it `type_x_0`
//         - create a new type environment augmented with `x -> type_x_0`
//         - generate constraints for `e_prime` using the new type environment 
//         - return `FunType(type_x_0, type(e_prime))`, constraints

// 7. e_prime = Plus(Ident("x"), Const(1))
//     - `Plus(e1_hat_hat, e2_hat_hat)`:   
//         - generate constraints for `e1_hat_hat` and add a constraint: `type(e1_hat_hat) == NumType`
//         - generate constraints for `e2_hat_hat` and add a constraint: `type(e2_hat_hat) == NumType`
//         - return NumType, constraints

// 8. e1_hat_hat = Ident("x") and so type(e1_hat_hat) = type_x_0, add constraint type_x_0 == NumType

// 9. e2_hat_hat = Const(1), add constraint NumType == NumType

// 10. type(e_prime) = NumType, so type(e2) = FunType(type_x_0, NumType), 
//     add constraint NumType == FunType(type_x_0, NumType) from step 1

// 11. type(e1) = NumType so we return NumType, constraints

val (typ, lstOfEqs) = generateEquations(
    expr,
    Map[String, Type]()
)
TypeVarGenerator.resetCounter

println("-- Generated Eqs --")
prettyPrintTypeEqs(lstOfEqs)
println("The overall program has type: " + typeToString(typ))



-- Generated Eqs --
type_x_0 == NumType
NumType == NumType
BoolType == BoolType
NumType == FunType(type_x_0 => NumType)
The overall program has type: NumType


[36mexpr[39m: [32mIfThenElse[39m = [33mIfThenElse[39m(
  [33mGeq[39m([33mConst[39m([32m5.0[39m), [33mConst[39m([32m2.0[39m)),
  [33mConst[39m([32m3.0[39m),
  [33mFunDef[39m([32m"x"[39m, [33mPlus[39m([33mIdent[39m([32m"x"[39m), [33mConst[39m([32m1.0[39m)))
)
[36mtyp[39m: [32mType[39m = NumType
[36mlstOfEqs[39m: [32mListOfEquations[39m = [33mList[39m(
  ([33mTypeVar[39m([32m"type_x_0"[39m), NumType),
  (NumType, NumType),
  (BoolType, BoolType),
  (NumType, [33mFunType[39m([33mTypeVar[39m([32m"type_x_0"[39m), NumType))
)

In [39]:
// let f = function (x)  x >= 5 in 
//   f(2)

val expr = Let("f", 
               FunDef("x", Geq(Ident("x"), Const(5))), 
               FunCall(Ident("f"), Const(0))
           )

// 1. The top level expression is a Let
//     - `Let(x, e1, e2)`:
//         - create a new `TypeVar` for the argument f and call it `type_f_0`
//         - generate constraints for `e1` and create a new type environment augmented with `x -> type_f_0`
//         - generate constraints for `e2` using the new type environment
//         - add the constraint `type_f_0 == type(e1)` 
//         - return `type(e2)`, constraints

// 2. e1 = FunDef("x", Geq(Ident("x"), Const(5)))
//     - `FunDef(x, e_prime)`:
//         - create a new `TypeVar` for the argument x and call it `type_x_1`
//         - create a new type environment augmented with `x -> type_x_1`
//         - generate constraints for `e_prime` using the new type environment 
//         - return `FunType(type_x_1, type(e_prime))` 

// 3. e_prime = Geq(Ident("x"), Const(5))
//     - `Geq(e1_hat, e2_hat)`:  
//         - generate constraints for `e1_hat` and add a constraint: `type(e1_hat) == NumType`
//         - generate constraints for `e2_hat` and add a constraint: `type(e2_hat) == NumType`
//         - return BoolType

// 4. e1_hat = Ident("x"), so type(e1_hat) = type_x_1, add constraint type_x_1 == NumType

// 5. e2_hat = Const(5), so type(e1_hat) = NumType, add constraint NumType == NumType

// 6. type(e_prime) = BoolType, so type(e1) = FunType(type_x_1, BoolType)

// 7. e2 = FunCall(Ident("f"), Const(0))
//     - `FunCall(e1_hat_hat, e2_hat_hat)`:
//         - create a new `TypeVar` for the function call and call it `type_fcall_2`
//         - generate constraints for `e1_hat_hat` and `e2_hat_hat`
//         - add the constraint `type(e1_hat_hat) == FunType(type(e2_hat_hat), type_fcall_2)`
//         - return `type_fcall_2`, constraints

// 8. e1_hat_hat = Ident("f"), so type(e1_hat_hat) = type_f_0 == FunType(type_x_1, BoolType)

// 9. e2_hat_hat = Const(0), so type(e2_hat_hat) = Numtype
//    add constraint type_f_0 == FunType(Numtype, type_fcall_2)

// 10. return type_fcall_2, constraints

val (typ, lstOfEqs) = generateEquations(
    expr,
    Map[String, Type]()
)
TypeVarGenerator.resetCounter

println("-- Generated Eqs --")
prettyPrintTypeEqs(lstOfEqs)
println("The overall program has type: " + typeToString(typ))



-- Generated Eqs --
type_x_1 == NumType
NumType == NumType
type_f_0 == FunType(NumType => type_fcall_2)
type_f_0 == FunType(type_x_1 => BoolType)
The overall program has type: type_fcall_2


[36mexpr[39m: [32mLet[39m = [33mLet[39m(
  [32m"f"[39m,
  [33mFunDef[39m([32m"x"[39m, [33mGeq[39m([33mIdent[39m([32m"x"[39m), [33mConst[39m([32m5.0[39m))),
  [33mFunCall[39m([33mIdent[39m([32m"f"[39m), [33mConst[39m([32m0.0[39m))
)
[36mtyp[39m: [32mType[39m = [33mTypeVar[39m([32m"type_fcall_2"[39m)
[36mlstOfEqs[39m: [32mListOfEquations[39m = [33mList[39m(
  ([33mTypeVar[39m([32m"type_x_1"[39m), NumType),
  (NumType, NumType),
  ([33mTypeVar[39m([32m"type_f_0"[39m), [33mFunType[39m(NumType, [33mTypeVar[39m([32m"type_fcall_2"[39m))),
  ([33mTypeVar[39m([32m"type_f_0"[39m), [33mFunType[39m([33mTypeVar[39m([32m"type_x_1"[39m), BoolType))
)

## Do all of these examples typecheck? If no, which ones do not and why?
