# Higher-Order Types
Types are a useful tool to keep programs organized and avoid certain kinds of errors.

Higher-order types extend the idea of types in a fairly intuitive way given the first-class nature of functions in functional programming, but they can be confusing to look at and take some getting used to.

We will be looking at types in both Scala and Lettuce, starting simple and working our way up to complex, higher-order types.

First, let's just get a type checker for Lettuce going. Don't worry about how this code works for now.

In [5]:
// The types the Lettuce type checker can assign to an expression.
sealed trait Type
// Type of numeric expressions (constants, Plus, Minus).
case object NumType extends Type
// Type of comparison results such as Geq, and of if-then-else conditions.
case object BoolType extends Type
// Function type: t1 is the type of the argument, t2 the type of the result.
case class FunType(t1: Type, t2: Type) extends Type

// Abstract syntax of Lettuce: a Program wraps a single top-level expression.
sealed trait Program
sealed trait Expr

// Numeric literal.
case class Const(f: Double) extends Expr
// Reference to an identifier; its type is looked up in the typing environment.
case class Ident(s: String) extends Expr
// Arithmetic on two numeric operands, yielding a number.
case class Plus(e1: Expr, e2: Expr) extends Expr
case class Minus(e1: Expr, e2: Expr) extends Expr
// Comparison of two numeric operands, yielding a boolean.
case class Geq(e1: Expr, e2: Expr) extends Expr
// Conditional: e1 is the condition, e2 and e3 are the two branches.
case class IfThenElse(e1: Expr, e2: Expr, e3: Expr) extends Expr
// Binds x, declared with type xType, to e1 for use inside e2.
case class Let(x: String, xType: Type, e1: Expr, e2: Expr) extends Expr
// Anonymous function with one parameter `id` of declared type `idType` and body `e`.
case class FunDef(id: String, idType: Type, e: Expr) extends Expr
// Application of the function expression `calledFun` to the argument `argExpr`.
case class FunCall(calledFun: Expr, argExpr: Expr) extends Expr
// A complete program consisting of the expression e.
case class TopLevel(e: Expr) extends Program

// Decides whether two Lettuce types are the same. Every Type is a case object or
// a case class, so == compares them structurally (FunTypes are compared component-wise).
def typeEquals(t1: Type, t2: Type): Boolean = t1 == t2
// Thrown by the type checker when an expression is ill-typed.
// The description `s` is forwarded to the Exception constructor so that
// getMessage (and therefore any error report built from it) carries the text;
// a bare `extends Exception` would leave the message null.
case class TypeErrorException(s: String) extends Exception(s)

// Computes the type of the Lettuce expression `e` under the typing environment
// `alpha`, which maps every identifier in scope to its type.
// Throws TypeErrorException as soon as an ill-typed sub-expression is found.
def typeOf(e: Expr, alpha: Map[String, Type]): Type = {
    // Types `operand` and insists that the result is `expected`.
    def requireOperand(opName: String, operand: Expr, expected: Type): Unit = {
        val actual = typeOf(operand, alpha)
        if (!typeEquals(actual, expected))
            throw TypeErrorException(s"Type mismatch in arithmetic/comparison/bool op $opName, Expected type $expected, obtained $actual")
    }

    // Rule shared by all binary operators: the left operand is checked before
    // the right one, and the operator's result type is returned if both pass.
    def binaryOp(opName: String, lhs: Expr, rhs: Expr, operandType: Type, resType: Type): Type = {
        requireOperand(opName, lhs, operandType)
        requireOperand(opName, rhs, operandType)
        resType
    }

    e match {
        case Const(_) => NumType

        case Ident(name) =>
            alpha.getOrElse(name, throw TypeErrorException(s"Unknown identifier $name"))

        case Plus(lhs, rhs) => binaryOp("Plus", lhs, rhs, NumType, NumType)
        case Minus(lhs, rhs) => binaryOp("Minus", lhs, rhs, NumType, NumType)
        case Geq(lhs, rhs) => binaryOp("Geq", lhs, rhs, NumType, BoolType)

        case IfThenElse(cond, thenBranch, elseBranch) =>
            // The branches are only inspected once the condition is known to be boolean.
            val condType = typeOf(cond, alpha)
            if (condType != BoolType)
                throw TypeErrorException(s"If then else condition expression not boolean $condType")
            val thenType = typeOf(thenBranch, alpha)
            val elseType = typeOf(elseBranch, alpha)
            if (!typeEquals(thenType, elseType))
                throw TypeErrorException(s"If then else returns unequal types $thenType and $elseType")
            thenType

        case Let(name, declared, bound, body) =>
            // The bound expression is typed in the outer environment, so `name`
            // is not visible inside its own definition.
            val boundType = typeOf(bound, alpha)
            if (!typeEquals(boundType, declared))
                throw TypeErrorException(s"Let binding has type $declared whereas it is bound to expression of type $boundType")
            typeOf(body, alpha + (name -> declared))

        case FunDef(param, paramType, body) =>
            // The body is typed with the parameter added to the environment.
            FunType(paramType, typeOf(body, alpha + (param -> paramType)))

        case FunCall(fun, arg) =>
            typeOf(fun, alpha) match {
                case FunType(paramType, resultType) =>
                    val argType = typeOf(arg, alpha)
                    if (!typeEquals(argType, paramType))
                        throw TypeErrorException(s"Call to function with incompatible argument type. Expected $paramType, obtained $argType")
                    resultType
                case nonFun =>
                    throw TypeErrorException(s"Call to function but with a non function type $nonFun")
            }
    }
}

// Types a whole Lettuce program: its top-level expression is checked starting
// from an empty typing environment. Any TypeErrorException from typeOf propagates.
def typeOfProgram(p: Program) = p match {
    case TopLevel(body) => typeOf(body, Map.empty[String, Type])
}

defined [32mtrait[39m [36mType[39m
defined [32mobject[39m [36mNumType[39m
defined [32mobject[39m [36mBoolType[39m
defined [32mclass[39m [36mFunType[39m
defined [32mtrait[39m [36mProgram[39m
defined [32mtrait[39m [36mExpr[39m
defined [32mclass[39m [36mConst[39m
defined [32mclass[39m [36mIdent[39m
defined [32mclass[39m [36mPlus[39m
defined [32mclass[39m [36mMinus[39m
defined [32mclass[39m [36mGeq[39m
defined [32mclass[39m [36mIfThenElse[39m
defined [32mclass[39m [36mLet[39m
defined [32mclass[39m [36mFunDef[39m
defined [32mclass[39m [36mFunCall[39m
defined [32mclass[39m [36mTopLevel[39m
defined [32mfunction[39m [36mtypeEquals[39m
defined [32mclass[39m [36mTypeErrorException[39m
defined [32mfunction[39m [36mtypeOf[39m
defined [32mfunction[39m [36mtypeOfProgram[39m

### Simple Types
Make some Scala and Lettuce values that have the following types.

In [6]:
import scala.reflect.runtime.universe.TypeTag
// Returns the textual form of the TypeTag the compiler supplies for A, i.e. the
// static type of the argument at the call site. The value `a` itself is not inspected.
def getType[A](a: A)(implicit evA: TypeTag[A]) = evA.toString

//Make a Double in Scala
val sv1 = 2.0

// The tag renders as "TypeTag[<type>]".
assert(getType(sv1) == "TypeTag[Double]")

//Make a Boolean in Scala
val sv2 = true

assert(getType(sv2) == "TypeTag[Boolean]")

[32mimport [39m[36mscala.reflect.runtime.universe.TypeTag
[39m
defined [32mfunction[39m [36mgetType[39m
[36msv1[39m: [32mDouble[39m = [32m2.0[39m
[36msv2[39m: [32mBoolean[39m = true

In [7]:
//A NumType in Lettuce: any numeric constant
val lv1 = Const(2.0)
assert(typeOfProgram(TopLevel(lv1)) == NumType)

//A BoolType in Lettuce: a comparison between two numbers
val lv2 = Geq(Const(1.0), Const(0.0))
assert(typeOfProgram(TopLevel(lv2)) == BoolType)

[36mlv1[39m: [32mConst[39m = [33mConst[39m([32m2.0[39m)
[36mlv2[39m: [32mGeq[39m = [33mGeq[39m([33mConst[39m([32m1.0[39m), [33mConst[39m([32m0.0[39m))

### Functions
Of course, we have functions in both our languages, and functions have types too.

Say we have a function that takes as input a double and yields that double plus one. We say that function has type $\mathbf{Double} \Rightarrow \mathbf{Double}$, or "double to double".

In [8]:
//A Double => Double in Scala: adds one to its argument
def sf1(x: Double): Double = x + 1

assert(getType(sf1(_)) == "TypeTag[Double => Double]")

//A Double => Double => Boolean in Scala: given x, returns a function that compares its argument with x
def sf2(x: Double): Double => Boolean =
    y => x == y

assert(getType(sf2(_)) == "TypeTag[Double => (Double => Boolean)]")

defined [32mfunction[39m [36msf1[39m
defined [32mfunction[39m [36msf2[39m

In [10]:
//A NumType => NumType in Lettuce: function (x: num) x + 1
val lf1 = FunDef("x", NumType, Plus(Ident("x"), Const(1.0)))

assert(typeOfProgram(TopLevel(lf1)) == FunType(NumType, NumType))

//A NumType => NumType => BoolType in Lettuce: function (x: num) function (y: num) x >= y
val lf2 = FunDef("x", NumType, FunDef("y", NumType, Geq(Ident("x"), Ident("y"))))

assert(typeOfProgram(TopLevel(lf2)) == FunType(NumType, FunType(NumType, BoolType)))

[36mlf1[39m: [32mFunDef[39m = [33mFunDef[39m([32m"x"[39m, NumType, [33mPlus[39m([33mIdent[39m([32m"x"[39m), [33mConst[39m([32m1.0[39m)))
[36mlf2[39m: [32mFunDef[39m = [33mFunDef[39m(
  [32m"x"[39m,
  NumType,
  [33mFunDef[39m([32m"y"[39m, NumType, [33mGeq[39m([33mIdent[39m([32m"x"[39m), [33mIdent[39m([32m"y"[39m)))
)

Let's see how the types of these functions reduce as we apply arguments to them.

In [11]:
// Applying sf1 to its only argument leaves a plain Double.
val sf1_ = sf1(1)
// Type of the function itself, then of the applied result.
println(getType(sf1(_)))
println(getType(sf1_))
println()

TypeTag[Double => Double]
TypeTag[Double]



[36msf1_[39m: [32mDouble[39m = [32m2.0[39m

In [12]:
// Each application consumes the leftmost argument type of sf2.
val sf2_ = sf2(1)
val sf2__ = sf2_(2)
// Types with zero, one and two arguments applied.
println(getType(sf2 (_)))
println(getType(sf2_))
println(getType(sf2__))
println()

TypeTag[Double => (Double => Boolean)]
TypeTag[Double => Boolean]
TypeTag[Boolean]



[36msf2_[39m: [32mDouble[39m => [32mBoolean[39m = ammonite.$sess.cmd7$Helper$$Lambda$2936/0x0000000801e53040@57a94cae
[36msf2__[39m: [32mBoolean[39m = false

In [13]:
// Applying lf1 to a number consumes its NumType argument.
val lf1_ = FunCall(lf1, Const(3.0))
// Print the type before and after the application.
List[Expr](lf1, lf1_).foreach(expr => println(typeOfProgram(TopLevel(expr))))
println()

FunType(NumType,NumType)
NumType



[36mlf1_[39m: [32mFunCall[39m = [33mFunCall[39m(
  [33mFunDef[39m([32m"x"[39m, NumType, [33mPlus[39m([33mIdent[39m([32m"x"[39m), [33mConst[39m([32m1.0[39m))),
  [33mConst[39m([32m3.0[39m)
)

In [14]:
// Each FunCall consumes the leftmost argument type of lf2.
val lf2_ = FunCall(lf2, Const(1.0))
val lf2__ = FunCall(lf2_, Const(1.0))
// Print the type with zero, one and two arguments applied.
List[Expr](lf2, lf2_, lf2__).foreach(expr => println(typeOfProgram(TopLevel(expr))))

FunType(NumType,FunType(NumType,BoolType))
FunType(NumType,BoolType)
BoolType


[36mlf2_[39m: [32mFunCall[39m = [33mFunCall[39m(
  [33mFunDef[39m([32m"x"[39m, NumType, [33mFunDef[39m([32m"y"[39m, NumType, [33mGeq[39m([33mIdent[39m([32m"x"[39m), [33mIdent[39m([32m"y"[39m)))),
  [33mConst[39m([32m1.0[39m)
)
[36mlf2__[39m: [32mFunCall[39m = [33mFunCall[39m(
  [33mFunCall[39m(
    [33mFunDef[39m([32m"x"[39m, NumType, [33mFunDef[39m([32m"y"[39m, NumType, [33mGeq[39m([33mIdent[39m([32m"x"[39m), [33mIdent[39m([32m"y"[39m)))),
    [33mConst[39m([32m1.0[39m)
  ),
  [33mConst[39m([32m1.0[39m)
)

### Higher-Order Functions
We've seen functions that can take functions as arguments. Let's take a closer look at how they are typed.

When a function $f1$ is given as an argument to another function $f2$, we will see the type of $f1$ expressed in $f2$'s type signature, contained in parens.

Say $f$ takes a function from double to double as one argument and a double as a second argument. It then applies its first argument to the second argument to produce a new double.

i.e.: f( (x: Double) => x+1, 2 ) == 2+1 == 3

The type of $f$ should look something like this: $(\mathbf{Double} \Rightarrow \mathbf{Double}) \Rightarrow \mathbf{Double} \Rightarrow \mathbf{Double}$

The first component, $(\mathbf{Double} \Rightarrow \mathbf{Double})$, represents the function given as an argument. The next component, $\mathbf{Double}$, is the second argument, just a simple double. The last component, $\mathbf{Double}$, is what gets produced by $f$ once it receives all its arguments.

In [15]:
//A (Double => Boolean) => Double => Boolean in Scala: applies the given predicate to a later argument
def sho1(f: Double => Boolean): Double => Boolean =
    y => f(y)

assert(getType(sho1(_)) == "TypeTag[(Double => Boolean) => (Double => Boolean)]")

//A (Double => Double => Boolean) => (Double => Double) => Boolean in Scala
//Only the type matters here: both function arguments are ignored and the answer is always true
def sho2(f1: Double => Double => Boolean): (Double => Double) => Boolean =
    _ => true

assert(getType(sho2(_)) == "TypeTag[(Double => (Double => Boolean)) => ((Double => Double) => Boolean)]")

defined [32mfunction[39m [36msho1[39m
defined [32mfunction[39m [36msho2[39m

In [18]:
//A (NumType => BoolType) => NumType => BoolType in Lettuce: function (f) function (x) f(x)
val lho1 = {
    val applyF = FunCall(Ident("f"), Ident("x"))
    FunDef("f", FunType(NumType, BoolType), FunDef("x", NumType, applyF))
}

assert(typeOfProgram(TopLevel(lho1)) == FunType(FunType(NumType, BoolType), FunType(NumType, BoolType)))

//A (NumType => NumType => BoolType) => (NumType => NumType) => BoolType in Lettuce
//The parameter types are spelled out; they are the types computed for lf2 and lf1 respectively.
//Both function arguments are ignored: the body is a constant comparison.
val lho2 = {
    val f1Type = FunType(NumType, FunType(NumType, BoolType))
    val f2Type = FunType(NumType, NumType)
    FunDef("f1", f1Type, FunDef("f2", f2Type, Geq(Const(1.0), Const(0.0))))
}

assert(typeOfProgram(TopLevel(lho2)) == FunType(FunType(NumType, FunType(NumType, BoolType)),FunType(FunType(NumType, NumType), BoolType)))

[36mlho1[39m: [32mFunDef[39m = [33mFunDef[39m(
  [32m"f"[39m,
  [33mFunType[39m(NumType, BoolType),
  [33mFunDef[39m([32m"x"[39m, NumType, [33mFunCall[39m([33mIdent[39m([32m"f"[39m), [33mIdent[39m([32m"x"[39m)))
)
[36mlho2[39m: [32mFunDef[39m = [33mFunDef[39m(
  [32m"f1"[39m,
  [33mFunType[39m(NumType, [33mFunType[39m(NumType, BoolType)),
  [33mFunDef[39m([32m"f2"[39m, [33mFunType[39m(NumType, NumType), [33mGeq[39m([33mConst[39m([32m1.0[39m), [33mConst[39m([32m0.0[39m)))
)

Again, let's see how the types of these functions reduce as we apply arguments to them.

In [16]:
// First application supplies the function argument, the second the Double.
val sho1_ = sho1((x:Double) => true)
val sho1__ = sho1_(2)
// Types with zero, one and two arguments applied.
println(getType(sho1(_)))
println(getType(sho1_))
println(getType(sho1__))
println()

TypeTag[(Double => Boolean) => (Double => Boolean)]
TypeTag[Double => Boolean]
TypeTag[Boolean]



[36msho1_[39m: [32mDouble[39m => [32mBoolean[39m = ammonite.$sess.cmd14$Helper$$Lambda$3162/0x0000000801f03840@a8dc7d0
[36msho1__[39m: [32mBoolean[39m = true

In [17]:
// Both arguments of sho2 are themselves functions; each application consumes one.
val sho2_ = sho2((x:Double) => ((y:Double) => true))
val sho2__ = sho2_((x:Double) => 2.0)
// Types with zero, one and two arguments applied.
println(getType(sho2(_)))
println(getType(sho2_))
println(getType(sho2__))
println()

TypeTag[(Double => (Double => Boolean)) => ((Double => Double) => Boolean)]
TypeTag[(Double => Double) => Boolean]
TypeTag[Boolean]



[36msho2_[39m: [32mDouble[39m => [32mDouble[39m => [32mBoolean[39m = ammonite.$sess.cmd14$Helper$$Lambda$3182/0x0000000801f0f840@647f2b5b
[36msho2__[39m: [32mBoolean[39m = true

In [19]:
// First supply a NumType => BoolType function, then the number it is applied to.
val lho1_ = FunCall(lho1, FunDef("x", NumType, Geq(Const(2.0), Const(2.0))))
val lho1__ = FunCall(lho1_, Const(2.0))
// Print the type with zero, one and two arguments applied.
List[Expr](lho1, lho1_, lho1__).foreach(expr => println(typeOfProgram(TopLevel(expr))))
println()

FunType(FunType(NumType,BoolType),FunType(NumType,BoolType))
FunType(NumType,BoolType)
BoolType



[36mlho1_[39m: [32mFunCall[39m = [33mFunCall[39m(
  [33mFunDef[39m(
    [32m"f"[39m,
    [33mFunType[39m(NumType, BoolType),
    [33mFunDef[39m([32m"x"[39m, NumType, [33mFunCall[39m([33mIdent[39m([32m"f"[39m), [33mIdent[39m([32m"x"[39m)))
  ),
  [33mFunDef[39m([32m"x"[39m, NumType, [33mGeq[39m([33mConst[39m([32m2.0[39m), [33mConst[39m([32m2.0[39m)))
)
[36mlho1__[39m: [32mFunCall[39m = [33mFunCall[39m(
  [33mFunCall[39m(
    [33mFunDef[39m(
      [32m"f"[39m,
      [33mFunType[39m(NumType, BoolType),
      [33mFunDef[39m([32m"x"[39m, NumType, [33mFunCall[39m([33mIdent[39m([32m"f"[39m), [33mIdent[39m([32m"x"[39m)))
    ),
    [33mFunDef[39m([32m"x"[39m, NumType, [33mGeq[39m([33mConst[39m([32m2.0[39m), [33mConst[39m([32m2.0[39m)))
  ),
  [33mConst[39m([32m2.0[39m)
)

In [20]:
// lho2 expects functions shaped like lf2 and lf1, so those are passed directly.
val lho2_ = FunCall(lho2, lf2)
val lho2__ = FunCall(lho2_, lf1)
// Print the type with zero, one and two arguments applied.
List[Expr](lho2, lho2_, lho2__).foreach(expr => println(typeOfProgram(TopLevel(expr))))

FunType(FunType(NumType,FunType(NumType,BoolType)),FunType(FunType(NumType,NumType),BoolType))
FunType(FunType(NumType,NumType),BoolType)
BoolType


[36mlho2_[39m: [32mFunCall[39m = [33mFunCall[39m(
  [33mFunDef[39m(
    [32m"f1"[39m,
    [33mFunType[39m(NumType, [33mFunType[39m(NumType, BoolType)),
    [33mFunDef[39m([32m"f2"[39m, [33mFunType[39m(NumType, NumType), [33mGeq[39m([33mConst[39m([32m1.0[39m), [33mConst[39m([32m0.0[39m)))
  ),
  [33mFunDef[39m([32m"x"[39m, NumType, [33mFunDef[39m([32m"y"[39m, NumType, [33mGeq[39m([33mIdent[39m([32m"x"[39m), [33mIdent[39m([32m"y"[39m))))
)
[36mlho2__[39m: [32mFunCall[39m = [33mFunCall[39m(
  [33mFunCall[39m(
    [33mFunDef[39m(
      [32m"f1"[39m,
      [33mFunType[39m(NumType, [33mFunType[39m(NumType, BoolType)),
      [33mFunDef[39m([32m"f2"[39m, [33mFunType[39m(NumType, NumType), [33mGeq[39m([33mConst[39m([32m1.0[39m), [33mConst[39m([32m0.0[39m)))
    ),
    [33mFunDef[39m([32m"x"[39m, NumType, [33mFunDef[39m([32m"y"[39m, NumType, [33mGeq[39m([33mIdent[39m([32m"x"[39m), [33mIdent[39m([32m"

### Map, Filter and Fold
Finally, let's look at the types for our favorite higher-order functions in order to understand them better.


In [23]:
// Applies f to every element of l, in order, and returns the list of results.
// The element type may change (A to B); the length never does.
def myMap[A,B](l:List[A], f:(A => B)): List[B] = {
    // Tail-recursive worker: each result is prepended to `done` and the list is
    // reversed once at the end, so long inputs cannot overflow the call stack
    // the way the direct `f(x) :: myMap(xs, f)` recursion does.
    @scala.annotation.tailrec
    def loop(rest: List[A], done: List[B]): List[B] =
        rest match {
            case Nil => done.reverse
            case x::xs => loop(xs, f(x) :: done)
        }

    loop(l, Nil)
}

myMap(List(1,2,3), (x:Int)=>x+1)
myMap(List(1,2,3), (x:Int)=>x>1)

// Keeps exactly those elements of l for which f returns true, preserving their order.
def myFilter[A](l:List[A], f:(A => Boolean)): List[A] = {
    // Tail-recursive worker: survivors are prepended to `kept` and the list is
    // reversed once at the end, so long inputs cannot overflow the call stack.
    @scala.annotation.tailrec
    def loop(rest: List[A], kept: List[A]): List[A] =
        rest match {
            case Nil => kept.reverse
            case x::xs => loop(xs, if (f(x)) x :: kept else kept)
        }

    loop(l, Nil)
}

myFilter(List(1,2,3), (x:Int)=>x>=2)

// Left fold: threads the accumulator through the list from the first element to
// the last, combining it with each element via f, and returns the final accumulator.
def myFold[A,B](l:List[A], acc:B, f:((B,A) => B)): B =
    l match {
        case head :: tail =>
            val nextAcc = f(acc, head)
            myFold(tail, nextAcc, f)
        case Nil => acc
    }

myFold(List(1,2,3), 0, (x:Int,y:Int)=>x+y)

// map expressed through myFold: the accumulator is the list of results so far.
// Results are prepended, so the accumulator is in reverse order and is flipped at the end.
def mapWithFold[A,B](l:List[A], f:(A => B)): List[B] = {
    val reversed = myFold[A, List[B]](l, List.empty[B], (acc, elem) => f(elem) :: acc)
    reversed.reverse
}

mapWithFold(List(1,2,3), (x:Int)=>x+1)
mapWithFold(List(1,2,3), (x:Int)=>x>1)

// filter expressed through myFold: the accumulator is the list of elements kept so far.
// Kept elements are prepended, so the accumulator is flipped at the end to restore order.
def filterWithFold[A](l:List[A], f:(A=>Boolean)): List[A] = {
    val kept = myFold[A, List[A]](l, List.empty[A], (acc, elem) => if (f(elem)) elem :: acc else acc)
    kept.reverse
}

filterWithFold(List(1,2,3), (x:Int)=>x>=2)

defined [32mfunction[39m [36mmyMap[39m
[36mres22_1[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m2[39m, [32m3[39m, [32m4[39m)
[36mres22_2[39m: [32mList[39m[[32mBoolean[39m] = [33mList[39m(false, true, true)
defined [32mfunction[39m [36mmyFilter[39m
[36mres22_4[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m2[39m, [32m3[39m)
defined [32mfunction[39m [36mmyFold[39m
[36mres22_6[39m: [32mInt[39m = [32m6[39m
defined [32mfunction[39m [36mmapWithFold[39m
[36mres22_8[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m2[39m, [32m3[39m, [32m4[39m)
[36mres22_9[39m: [32mList[39m[[32mBoolean[39m] = [33mList[39m(false, true, true)
defined [32mfunction[39m [36mfilterWithFold[39m
[36mres22_11[39m: [32mList[39m[[32mInt[39m] = [33mList[39m([32m2[39m, [32m3[39m)