In [24]:
# Define basic syntax terms.

# Expressions, made up of atoms, variables, applications, and abstractions.
class Atom():
    def __init__(self, str):
        self.str = str

    def __str__(self):
        return "A(" + str(self.str) + ")"
    __repr__ = __str__

    def __eq__(self, other):
        return type(self) == type(other) and self.str == other.str

    def __lt__(self, other):
        return type(self) == type(other) and self.str < other.str
    
    def __hash__(self):
        return hash((type(self), self.str))

class Var():
    def __init__(self, str):
        self.str = str

    def __str__(self):
        return "V(" + str(self.str) + ")"
    __repr__ = __str__

    def __eq__(self, other):
        return type(self) == type(other) and self.str == other.str

    def __lt__(self, other):
        return type(self) == type(other) and self.str < other.str
    
    def __hash__(self):
        return hash((type(self), self.str))
    

class App():
    def __init__(self, expr1, expr2):
        self.expr1 = expr1
        self.expr2 = expr2

    def __str__(self):
        return "App(" + str(self.expr1) + ", " + str(self.expr2) + ")"
    __repr__ = __str__

    def __eq__(self, other):
        return type(self) == type(other) and self.expr1 == other.expr1 and self.expr2 == other.expr2
    
    def __hash__(self):
        return hash((type(self), self.expr1, self.expr2))

class Abs():
    def __init__(self, atom, expr):
        self.atom = atom
        self.expr = expr

    def __str__(self):
        return "Abs(" + str(self.atom) + ", " + str(self.expr) + ")"
    __repr__ = __str__

    def __eq__(self, other):
        return type(self) == type(other) and self.expr == other.expr and self.expr == other.expr
    
    def __hash__(self):
        return hash((type(self), self.expr1, self.expr2))

def expElim(expr, atomElim, varElim, appElim, absElim):
    if type(expr) == Atom:
        return atomElim(expr)
    elif type(expr) == Var:
        return varElim(exp)
    elif type(expr) == App:
        return appElim(expr.expr1, expr.expr1)
    elif type(expr) == Abs:
        return absElim(expr.atom, expr.expr)
    else:
        raise Exception(str(expr) + 'is not an expression')

        
# Equations
class Eq():
    def __init__(self, expr1, expr2):
        self.expr1 = expr1
        self.expr2 = expr2

    def __str__(self):
        return "(" + str(self.expr1) + " = " + str(self.expr2) + ")"
    __repr__ = __str__

# Bound and free terms
class Bound():
    def __init__(self, atom, index):
        self.atom = atom
        self.index = index

class Free():
    def __init__(self, atom):
        self.atom = atom

def BoundnessElim(bness, boundElim, freeElim):
    if type(bness) == Bound:
        return boundElim(expr.atom, expr.index)
    elif type(bness) == Free:
        return freeElim(expr.atom)
    else:
        raise Exception(str(bness) + ' is not a boundness')

# Maps of binders.
class BinderMap():
    def __init__(self, a2i, i2a):
        self.a2i = a2i # Dict Atom Int
        self.i2a = i2a # Dict Int Atom

    def __str__(self):
        return "BM(" + str(self.a2i) + ", " + str(self.i2a) + ")"
    __repr__ = __str__
    
def extend(atom, binderMap):
    """ Add atom to map of binders.
    """
    binderMap.a2i[atom] = len(binderMap.a2i)
    binderMap.i2a[len(binderMap.a2i)] = atom
    
    return binderMap

def lookupAtom(atom, binderMap):
    index = binderMap.a2i.get(atom, -1)
    
    if index == -1:
        return Free(atom)
    
    return Bound(atom, index)

# def lookupIdx :: Int -> BinderMap -> Maybe Atom
def lookupIdx(index, binderMap):
    # j = len(binderMap.a2i) - index
    a = binderMap.i2a.get(index, -1)
    indexp = binderMap.a2i.get(a, -2)
    if index == indexp:
        return a
    
    raise NoMatchingBinderError(str(index) + "\n" + str(binderMap))

def emptyBinderMap():
    return BinderMap(dict([]), dict([]))

In [25]:
extend(Atom('cool'), emptyBinderMap())

BM({A(cool): 0}, {1: A(cool)})

In [26]:
# Constraints
class Closure():
    def __init__(self, av, binderMap):
        self.av = av
        self.binderMap = binderMap
    
    def __str__(self):
        return "Clo(" + str(self.av) + ", " + str(self.binderMap) + ")"
    __repr__ = __str__

def sameClo(clo1, clo2):
    l1 = lookupAtom(clo1.av, clo1.binderMap)
    l2 = lookupAtom(clo2.av, clo2.binderMap)
    
    if type(l1) == Free and type(l2) == Free:
        return clo1.av == clo2.av
    elif type(l1) == Bound and type(l2) == Bound:
        return l1.index == l2.index
    else:
        return False

class NuEquation():
    def __init__(self, clo1, clo2):
        self.clo1 = clo1 # Clo Atom
        self.clo2 = clo2
        self.var = type(clo2.av) == Var
        # If self.var is true, then self.clo2 will be a closure over
        # a Var, otherwise it's a closure over an Atom.
    
    def __str__(self):
        if self.var:
            return "AV(" + str(self.clo1) + ", " + str(self.clo2) + ")"
        else:
            return "AA(" + str(self.clo1) + ", " + str(self.clo2) + ")"
    __repr__ = __str__

# A NuProblem is a list of NuEquations

class DeltaEquation():
    def __init__(self, clo1, clo2, var):
        self.clo1 = clo1 # Clo Var
        self.clo2 = clo2 # Clo Var
    
    def __str__(self):
        return "VV(" + str(self.clo1) + ", " + str(self.clo2) + ")"
    __repr__ = __str__

# A DeltaProblem is a list of DeltaEquations

class MultiEquation():
    def __init__(self, clo1, clo2, var):
        self.clo1 = clo1 # Clo Expr
        self.clo2 = clo2 # Clo Expr
    
    def __str__(self):
        return "EE(" + str(self.clo1) + ", " + str(self.clo2) + ")"
    __repr__ = __str__

# A RhoProblem is a list of MultiEquations

class Substitution():
    def __init__(self, v2e):
        self.v2e = v2e
    
    def __str__(self):
        return "Subst(" + str(self.v2e) + ")"
    __repr__ = __str__

def idSubst():
    return Substitution(dict([]))

def dom(sub):
    """ Get the domain of a substitution.
    """
    return sub.v2e.keys

def subst(expr, sub):
    if type(expr) == Var:
        m = sub.v2e
        e = m.get(expr, -1)
        if e == -1:
            return expr
        return subst(e, sub)
    elif type(expr) == Abs:
        expr.expr = subst(expr.expr, sub)
        return expr
    return expr

def extendSubst(var, expr, sub):
    sub.v2e[var] = expr
    return sub

# Unification related exceptions.
class AAMismatchError(Exception): # (Clo Atom) (Clo Atom)
   """Raised when ..."""
   pass

class NameCaptureError(Exception): # Atom BinderMap
   """Raised when ..."""
   pass

class NoMatchingBinderError(Exception): # Int BinderMap
   """Raised when ..."""
   pass

class EEMismatchError(Exception): # (Clo Expr) (Clo Expr)
   """Raised when ..."""
   pass


In [29]:
from pymonad import *
from functools import reduce

In [31]:
# foldr : (a -> b -> b) -> b -> [a] -> b
def foldr(f, e, l):
    if len(l) == 0:
        return e
    return f(l[0], foldr(f, e, l[1:]))

In [32]:
foldr(lambda x, y: x+y, 0, [1,2,3])

6

In [40]:
# modify : (s -> s) -> State s ()
def modify(f):
    return State(lambda x: ((), f(x)))

# get : State s s
def get():
    return State(lambda x: (x, x))

# put : a -> State a () 
def put(a):
    return State(lambda x: (a, ()))

# mapMU : (Foldable t, Monad m) => (a -> m b) -> t a -> m ()
def mapMU(f, l):
    return foldr(lambda x, y: f(x).bind(lambda _: y), State.unit(()), l)

# foldM : (Foldable t, Monad m) => (b -> a -> m b) -> b -> t a -> m b
def foldM(f, z0, xs):
    def fp(x, k):
        return lambda z: f(z, x).bind(k)
    
    return foldr(fp, State.unit, xs)(z0)

In [49]:
l = [3,4]
def f(num):
    return State(lambda s: ((), s ** num))
# Should be 4096
mapMU(f,l).getState(2)

4096

In [47]:
def testf(b, a):
    return State(lambda s: (b ** a, s))
# Should be 4096
foldM(testf, 2, [3,4]).getResult(0)

4096

In [103]:
# Nu Machines : State Substitution a
def runNuMachine(sub, nuM):
    return nuM.getState(sub)

# bind : Var -> Exp -> NuMachine ()
def bind(var, expr):
    return modify(lambda sub: extendSubst(var, expr, sub))

# step : NuEquation -> NuMachine ()
def step(nuEq):
    clo1 = nuEq.clo1
    clo2 = nuEq.clo2
    if nuEq.var:
        res1 = lookupAtom(clo1.av, clo1.binderMap)
        if type(res1) == Free:
            res2 = lookupAtom(clo1.av, clo2.binderMap)
            if type(res2) == Free:
                 return bind(clo2.av, res2)
            else:
                raise NameCaptureError(str(clo1.av) + "\n" + str(clo1.binderMap))
        elif type(res1) == Bound:
            res2 = lookupIdx(clo1.index, clo2.binderMap)
            return bind(clo2.av, res2)
    else:
        if sameClo(clo1, clo2):
            State.unit(())
        else:
            raise AAMismatchError(str(clo1) + "\n" + str(clo2))

# eval : NuProblem -> NuMachine ()
def evalNu(nu):
    return mapMU(step, nu)


In [116]:
# Delta Machines : a (it's simply a wrapper)

# occursInEq : Var -> DeltaEquation -> Bool
def occursInEq(var, deq):
    return (var == deq.clo1.av) or (var == deq.clo2.av)

# partition : (a -> Bool) -> [a] -> ([a], [a])
def partition(f, x):
    return (filter(f, x), filter(lambda z: not f(z), x))

# evalDelta : Substitution -> DeltaProblem -> [Var] -> DeltaMachine (Substitution, DeltaProblem)
def evalDelta(s, p, xs):
    if len(xs) == 0:
        return (s, p)
    elif len(p) == 0:
        return (s, p)
    else:
        w, wo = partition(lambda z: occursInEq(xs[0], z), p)
        sp, xsp = pull(s, xs[1:], w)
        return evalDelta(sp, wo, xsp)

# pull : Substitution -> [Var] -> DeltaProblem -> DeltaMachine (Substitution, [Var])
def pull(s, xs, p):
    if len(p) == 0:
        return (s, xs)
    else:
        x1 = p[0].clo1.av
        bm1 = p[0].clo1.binderMap
        x2 = p[0].clo2.av
        bm2 = p[0].clo2.binderMap
        pp = p[1:]
        
        x1p = subst(x1, s)
        x2p = subst(x2, s)
        
        if type(x1p) == Atom:
            if type(x2p) == Atom:
                    if not sameClo(p[0].clo1, p[0].clo2):
                        raise AAMismatchError(str(p[0].clo1) + "\n" + str(p[0].clo2))

                    pull(s, xs, pp)
            elif type(x2p) == Var:
                sp = findSubstClo(s, x2p, bm2, Closure(a1, bm1))
                pull(sp, xs.copy().append(x2p), pp)
        elif type(x1p) == Var:
            sp = findSubstClo(s, x1p, bm1, Closure(a2, bm2))
            pull(sp, xs.copy().append(x1p), pp)

# findSubstClo : Substitution -> Var -> BinderMap -> Clo Atom -> DeltaMachine Substitution
def findSubstClo(s, x, bmx, clo):
    a = clo.av
    bma = clo.binderMap
    
    res1 = lookupAtom(a, bma)
    if type(res1) == Free:
        res2 = lookupAtom(a, bmx)
        if type(res2) == Free:
            return extendSubst(x, a, s)
        else:
            raise NameCapture(str(a) + "\n" + str(bmx))
    elif type(res1) == Bound:
        i = res1.index
        res2 = lookupIdx(i, bmx)
        return extendSubst(x, res2, s)

In [80]:
# Rho Machines : State Int a

# runRhoMachine : RhoMachine a -> Except UnificationError a
def runRhoMachine(m):
    return m.getState(0)

# freshVar : RhoMachine Var
def freshVar():
    return State(lambda n: (Var("$X" + str(n)), n + 1))

# freshAtom : RhoMachine Atom
def freshAtom():
    return State(lambda n: (Atom("$a" + str(n)), n + 1))

# evalRho : NuProblem -> DeltaProblem -> Substitution -> RhoProblem
# -> RhoMachine (NuProblem, DeltaProblem, Substitution)
def evalRho(np,dp,s,rp):
    return foldM(rhoStep, (np, dp, s), rp)

# rhoStep : (NuProblem, DeltaProblem, Substitution) -> MultiEquation
# -> RhoMachine (NuProblem, DeltaProblem, Substitution)
def rhoStep(p, m):
    np, dp, s = p
    cl = m.clo1
    cr = m.clo2
    el = cl.av
    bml = cl.binderMap
    er = cr.av
    bmr = cr.binderMap
    
    res = expElim(el
    ,   lambda al: # atomElim
        expElim(er
        , lambda ar: # atomElim
                State.unit(np.copy().append(NuEquation(Closure(al, bml), Closure(ar, bmr))), dp, s)
        , lambda vr: # varElim
                State.unit(np.copy().append(NuEquation(Closure(al, bml), Closure(vr, bmr))), dp, s)
        , lambda _, __: False
                # raise EEMismatch(str(cl) + "\n" + str(cr))
        , lambda _, __: False
                # raise EEMismatch(str(cl) + "\n" + str(cr))
        )
    ,   lambda vl: # varElim
        expElim(er
        , lambda _: # atomElim
                rhoStep((np, dp, s), MultiEquation(cr, cl))
        , lambda vr: # varElim
                State.unit(np, dp.copy().append(DeltaEquation(Closure(vl, bml), Closure(vr, bmr))), s)
        , lambda r1, r2: # appElim
                freshVar().bind(
                    lambda v1: freshVar().bind(
                    lambda v2:
                        (sp := extendSubst(vl, App(Var(v1), Var(v2)), s),
                         rhoStep((np, dp, sp), MultiEquation(Closure(Var(v1), bml), Closure(r1, bmr))).bind(
                            lambda p: (
                                (npp := p[0],
                                 dpp := p[1],
                                 spp := p[2],
                                 step((npp, dpp, spp), MultiEquation(Closure(Var(v2), bml), Closure(r2, bmr)))
                                )[-1]
                            )
                         )
                        )[-1]
                ))
                  
        , lambda ar, br: # absElim
                freshAtom().bind(
                    lambda al:
                    freshVar().bind(
                    lambda vb:
                        ( sp := extendSubst(vl, Abs(al, Var(vb)), s),
                          bmlp := extend(al, bml),
                          bmrp := extend(ar, bmr),
                          rhoStep((np, dp, sp), MultiEquation(Closure(Var(vb), bmlp), Closure(br, bmrp)))
                        )[-1]
                ))
        )
    ,   lambda l1, l2: # appElim
        expElim(er
        , lambda _: False
                # raise EEMismatch(str(cl) + "\n" + str(cr))
        , lambda _: # varElim
                rhoStep((np, dp, s), MultiEquation(cr, cl))
        , lambda r1, r2: # appElim
                rhoStep((np, dp, s), MultiEquation(Closure(l1, bml), Closure(r1, bmr))).bind(
                    lambda p: (
                        npp := p[0],
                        dpp := p[1],
                        sp := p[2],
                        rhoStep((npp, dpp, sp), MultiEquation(Closure(l2, bml), Closure(r2, bmr)))
                    )[-1]
                )
        , lambda _, __: False
                    # raise EEMismatch(str(cl) + "\n" + str(cr))
        )
    ,   lambda al, bl: # absElim
        expElim(er
        , lambda _: False
                # raise EEMismatch(str(cl) + "\n" + str(cr))
        , lambda _: # varElim
                rhoStep((np, dp, s), MultiEquation(cr, cl))
        , lambda _, __: False
                # raise EEMismatch(str(cl) + "\n" + str(cr))
        , lambda ar, br: # absElim
                ( bmlp := extend(al, bml),
                  bmrp := extend(ar, bmr),
                  rhoStep((np, dp, s), MultiEquation(Closure(bl, bmlp), Closure(br, bmrp)))
                )[-1]
        )
    )
    
    if res == False:
        raise EEMismatch(str(cl) + "\n" + str(cr))
    return res


In [82]:
# Unify

# unify : Expr -> Expr -> Either UnificationError (Substitution, DeltaProblem)
def unify(l, r):
  return runRhoMachine(
    evalRho([], [], idSubst, [MultiEquation(Closure(l, emptyBinderMap), Closure(r, emptyBinderMap))])).bind(
    lambda p:
      ( np := p[0],
        dp := p[1],
        s1 := p[2],
        runNuMachine(s1, evalNu(np)).bind(
        lambda s2:
            runDeltaMachine(evalDelta(s2, dp, dom(s2)))
      ))
  )

In [30]:
reduce(ex, [1,2,3])

1

In [22]:
def ex(x,y): return x**y

In [19]:
f((1,2))

1

In [6]:
!python --version

Python 3.8.1
