In [83]:
class Term:
    """Base class for every term node in this file.

    The ``is_*`` predicates let callers dispatch on the kind of node without
    isinstance checks. Each one answers False here; a subclass overrides the
    predicate that describes it. ``__str__`` and ``debug_str`` are placeholders
    that return None.
    """

    def _never(self):
        # Shared default implementation for all node-kind predicates.
        return False

    is_var = _never
    is_lambda = _never
    is_app = _never
    is_lambdaDB = _never

    def __str__(self):
        # Placeholder only; str() raises TypeError when __str__ yields None.
        return None

    def debug_str(self):
        # Placeholder only.
        return None


class Var(Term):
    """Variable occurrence.

    ``var`` holds a name (str) in terms produced by the parser, or an int
    index in terms produced by deBruijn.
    """

    def __init__(self, var):
        self.var = var

    def is_var(self):
        return True

    def __str__(self):
        return "{}".format(self.var)

    def debug_str(self):
        return "Var {}".format(self.var)


class Lambda(Term):
    """Named abstraction ``λvar. t`` as produced by the parser."""

    def __init__(self, var: str, t: Term):
        self.var, self.t = var, t

    def is_lambda(self):
        return True

    def __str__(self):
        return "λ{}. {}".format(self.var, self.t)

    def debug_str(self):
        return "Lambda {}. ({})".format(self.var, self.t.debug_str())


class LambdaDB(Term):
    """Nameless (de Bruijn) abstraction; its binder is index 0 inside ``t``.

    Answers True to both is_lambdaDB() and is_lambda().
    """

    def __init__(self, t: Term):
        self.t = t

    def is_lambda(self):
        return True

    def is_lambdaDB(self):
        return True

    def __str__(self):
        return "λ. " + str(self.t)

    def debug_str(self):
        return "Lambda . ({})".format(self.t.debug_str())


class App(Term):
    """Application of ``t1`` to ``t2``."""

    def __init__(self, t1: Term, t2: Term):
        self.t1 = t1
        self.t2 = t2

    def is_app(self):
        return True

    def __str__(self):
        # Bare variables print as-is; every other sub-term is parenthesised.
        pieces = [
            str(side) if side.is_var() else "({})".format(side)
            for side in (self.t1, self.t2)
        ]
        return " ".join(pieces)

    def debug_str(self):
        return "App ({}), ({})".format(self.t1.debug_str(), self.t2.debug_str())


In [85]:
def deBruijn(term: Term):
    """Convert a named lambda term (Var / App / Lambda) into de Bruijn form.

    Bound variables become the distance to their binder (0 = innermost
    enclosing lambda). Free variables are numbered in order of first
    occurrence; a free variable numbered k that occurs under d lambdas
    becomes index k + d, i.e. it points past every binder in scope.

    Returns:
        (converted_term, free_vars), where free_vars maps each free variable
        name to its number.

    Raises:
        TypeError: if the term contains anything other than Var, App or a
            named Lambda (e.g. an already-converted LambdaDB or a sugar node).
    """
    free_vars = {}

    def convert(t, bound_vars):
        # bound_vars lists binder names innermost first, so list.index gives
        # the de Bruijn index of the nearest binder with that name.
        if t.is_var():
            try:
                return Var(bound_vars.index(t.var))
            except ValueError:
                # Not bound here: assign (or reuse) a free-variable number.
                number = free_vars.setdefault(t.var, len(free_vars))
                return Var(number + len(bound_vars))
        if t.is_app():
            return App(convert(t.t1, bound_vars), convert(t.t2, bound_vars))
        # LambdaDB also answers is_lambda() but has no .var to bind.
        if t.is_lambda() and not t.is_lambdaDB():
            return LambdaDB(convert(t.t, [t.var] + bound_vars))
        raise TypeError(f"deBruijn: unsupported term type {type(t).__name__}")

    return convert(term, []), free_vars


# Example: "lambda k. k x y z" has one bound variable (k) and free x, y, z.
# NOTE(review): parse_term is defined in a later cell of this file; running
# the file top to bottom as a plain script raises NameError here.
term = parse_term("lambda k. k x y z")
db_term, free_vars = deBruijn(term)

In [86]:
def is_free_var(var):
    """Tell whether a Var payload is a string name rather than an int index."""
    named = isinstance(var, str)
    return named


def eval_krivine(term: "Term"):
    """Reduce a de Bruijn term to weak head normal form with a Krivine machine.

    Machine state: the current term, its environment (a list of closures
    indexed by de Bruijn index, index 0 first), and a stack of pending
    argument closures. A closure is a ``(term, env)`` pair.

    Returns:
        (term, env): a LambdaDB with no pending arguments, or a str-named
        free variable with no pending arguments, together with its env.

    Raises:
        Exception("Not implemented"): a free variable is applied to arguments.
        TypeError: the current term is not an application, a LambdaDB or a Var
            (for example a named Lambda or an un-desugared node).
    """
    stack = []
    env = []

    while True:
        if term.is_app():
            # Delay the argument as a closure over the current environment.
            stack.append((term.t2, env))
            term = term.t1

        elif term.is_lambdaDB():
            if not stack:
                return term, env
            # The most recent pending argument becomes index 0.
            env = [stack.pop()] + env
            term = term.t

        elif term.is_var():
            if is_free_var(term.var):
                if not stack:
                    return term, env
                raise Exception("Not implemented")
            # NOTE(review): deBruijn encodes free variables as int indices past
            # the enclosing binders; those are looked up here like bound ones
            # and raise IndexError when out of range.
            term, env = env[term.var]

        else:
            # No transition applies; fail instead of looping forever.
            raise TypeError(
                f"eval_krivine: cannot evaluate {type(term).__name__} nodes"
            )


In [87]:
# [n -> s]t: capture-avoiding substitution on de Bruijn terms, the index
# shifting it relies on, and two normalisers built on top of it.


def _shift(t, d, cutoff=0):
    """Return t with d added to every int de Bruijn index >= cutoff.

    Indices below ``cutoff`` are bound inside the shifted term and stay as
    they are; string-named variables are left untouched.
    """
    if t.is_var():
        if isinstance(t.var, int) and t.var >= cutoff:
            return Var(t.var + d)
        return t
    if t.is_app():
        return App(_shift(t.t1, d, cutoff), _shift(t.t2, d, cutoff))
    if t.is_lambdaDB():
        return LambdaDB(_shift(t.t, d, cutoff + 1))
    raise TypeError(f"_shift: unsupported term type {type(t).__name__}")


def substitution(t, n, s):
    """Return [n -> s]t: t with every occurrence of index n replaced by s.

    Under each binder the target index grows by one and s is shifted up by
    one, so the free variables of s keep pointing at the same binders.

    Raises:
        TypeError: if t contains a node other than Var, App or LambdaDB.
    """
    if t.is_var():
        return s if t.var == n else t
    if t.is_app():
        return App(substitution(t.t1, n, s), substitution(t.t2, n, s))
    if t.is_lambdaDB():
        return LambdaDB(substitution(t.t, n + 1, _shift(s, 1)))
    raise TypeError(f"substitution: unsupported term type {type(t).__name__}")


def normalize_subst_cbn(t):
    """Reduce t to weak head normal form with call-by-name beta reduction.

    Only the head of an application spine is reduced; arguments and lambda
    bodies are returned unevaluated.
    """
    if t.is_app():
        head = normalize_subst_cbn(t.t1)
        if head.is_lambdaDB():
            # Beta step (λ.body) arg -> shift(-1, [0 -> shift(+1, arg)] body):
            # lift arg over the removed binder, substitute, then drop the binder
            # so the body's references to outer binders move down by one.
            lifted_arg = _shift(t.t2, 1)
            beta_reduced = _shift(substitution(head.t, 0, lifted_arg), -1)
            return normalize_subst_cbn(beta_reduced)
        return App(head, t.t2)

    return t


def normalize_subst_no(t):
    """Reduce t to beta normal form using normal-order reduction.

    First reduce to weak head normal form (call-by-name), then normalise
    recursively under the resulting lambda, or inside both sides of the
    remaining stuck application.
    """
    whnf = normalize_subst_cbn(t)

    if whnf.is_lambdaDB():
        return LambdaDB(normalize_subst_no(whnf.t))
    if whnf.is_app():
        return App(normalize_subst_no(whnf.t1), normalize_subst_no(whnf.t2))

    return whnf

In [89]:
class Nat(Term):
    """Natural-number literal; the constructor coerces its argument with int()."""

    def __init__(self, n):
        self.n = int(n)

    def __str__(self):
        return "%d" % self.n

class Add(Term):
    """``add n m``: sum of two sub-terms."""

    def __init__(self, n, m):
        self.n, self.m = n, m

    def __str__(self):
        return "add ({}) ({})".format(self.n, self.m)

class Mul(Term):
    """``mul n m``: product of two sub-terms."""

    def __init__(self, n, m):
        self.n, self.m = n, m

    def __str__(self):
        return "mul ({}) ({})".format(self.n, self.m)

class Sub(Term):
    """``sub n m``: difference of two sub-terms."""

    def __init__(self, n, m):
        self.n, self.m = n, m

    def __str__(self):
        return "sub ({}) ({})".format(self.n, self.m)

class Eq(Term):
    """``eq n m``: equality test between two sub-terms."""

    def __init__(self, n, m):
        self.n, self.m = n, m

    def __str__(self):
        return "eq ({}) ({})".format(self.n, self.m)

class Tru(Term):
    """Boolean constant ``true``."""

    def __str__(self):
        literal = "true"
        return literal

class Fls(Term):
    """Boolean constant ``false``."""

    def __str__(self):
        literal = "false"
        return literal

class If(Term):
    """Conditional with a condition and two branches.

    NOTE(review): __str__ prints ``if (c) then (t) else (f)``, while the
    parser reads ``if c t e``, so the printed form does not re-parse.
    """

    def __init__(self, cond, if_true, if_false):
        self.cond, self.if_true, self.if_false = cond, if_true, if_false

    def __str__(self):
        return "if ({}) then ({}) else ({})".format(
            self.cond, self.if_true, self.if_false
        )
    

class Fix(Term):
    """``fix f``: fixed-point (recursion) operator applied to ``f``."""

    def __init__(self, f):
        self.f = f

    def __str__(self):
        return "fix ({})".format(self.f)

class Pair(Term):
    """``pair first second``: pair constructor."""

    def __init__(self, first, second):
        self.first, self.second = first, second

    def __str__(self):
        return "pair ({}) ({})".format(self.first, self.second)

class Fst(Term):
    """``fst pair``: first projection."""

    def __init__(self, pair):
        self.pair = pair

    def __str__(self):
        return "fst ({})".format(self.pair)

class Snd(Term):
    """``snd pair``: second projection."""

    def __init__(self, pair):
        self.pair = pair

    def __str__(self):
        return "snd ({})".format(self.pair)

class Nil(Term):
    """Empty-list constant ``nil``."""

    def __str__(self):
        return "nil"

class Cons(Term):
    """``cons head tail``: list constructor."""

    def __init__(self, head, tail):
        self.head, self.tail = head, tail

    def __str__(self):
        return "cons ({}) ({})".format(self.head, self.tail)

class Head(Term):
    """``head l``: first element of a list; the operand is stored as ``list``."""

    def __init__(self, t):
        self.list = t

    def __str__(self):
        return "head ({})".format(self.list)

class Tail(Term):
    """``tail l``: list without its first element; the operand is stored as ``list``."""

    def __init__(self, t):
        self.list = t

    def __str__(self):
        return "tail ({})".format(self.list)

class IsNil(Term):
    """``isnil l``: emptiness test; the operand is stored as ``list``."""

    def __init__(self, t):
        self.list = t

    def __str__(self):
        return "isnil ({})".format(self.list)


In [90]:
import re


def is_var(token):
    """Return a match object when *token* is all lowercase ASCII letters, else None."""
    identifier = r"^[a-z]+$"
    return re.match(identifier, token)


def is_nat(token):
    """Return a match object when *token* is a run of digits, else None.

    The pattern is written as a raw string. In a plain literal the digit
    class is an invalid string escape, which Python 3.12+ reports as a
    SyntaxWarning.
    """
    return re.match(r"^\d+$", token)


def is_lambda(token):
    """Return a match object when *token* is a binder such as ``lambda x.``, else None."""
    binder = r"^lambda [a-z]+[.]$"
    return re.match(binder, token)


def is_term(token):
    """Tell whether *token* is already a parsed Term node rather than raw text."""
    return isinstance(token, (Term,))


def lambda_var(token):
    """Extract the bound name from a binder token: ``"lambda xy."`` -> ``"xy"``."""
    prefix_length = len("lambda ")
    return token[prefix_length:-1]

def tail(xs):
    """Return *xs* without its first item; an empty input yields ``[]``."""
    return xs[1:] if len(xs) else []

def parse_brackets(tokens):
    """Group a flat token list into nested lists according to parentheses.

    Each ``( ... )`` span becomes a sub-list. Afterwards, any list holding a
    single element is replaced by that element, so ``((x))`` collapses to
    ``"x"`` and a one-token input yields the bare token string.

    Raises:
        Exception("Parse error"): on an unmatched ")" or an unclosed "(".
    """
    # groups[-1] is the list currently being filled; groups[0] is the top level.
    groups = [[]]
    for token in tokens:
        if token == "(":
            groups.append([])
        elif token == ")":
            if len(groups) == 1:
                raise Exception("Parse error")
            finished = groups.pop()
            groups[-1].append(finished)
        else:
            groups[-1].append(token)

    if len(groups) > 1:
        raise Exception("Parse error")

    def collapse(node):
        # Unwrap singleton lists; strings (even one-character ones) are leaves.
        while not isinstance(node, str) and len(node) == 1:
            node = node[0]
        if isinstance(node, str):
            return node
        return [collapse(child) for child in node]

    return collapse(groups[0])


def parse_term(term_str: str):
    """Parse a term written in the surface syntax into a Term tree.

    A lambda is written ``lambda x. body``. Identifiers are lowercase ASCII
    letters, numerals are digit runs, and parentheses group sub-terms.

    Raises:
        Exception("Parse error"): on unbalanced parentheses or an empty term.
    """
    # Normalise whitespace so "lambda   x." still matches the lambda token.
    term_str = " ".join(term_str.split())

    # Raw string: a plain "\d" literal is an invalid escape (SyntaxWarning on 3.12+).
    token_pattern = r"lambda [a-z]+[.]|[a-z]+|\d+|[(]|[)]"
    # NOTE: characters matching none of the alternatives (uppercase letters,
    # punctuation other than parentheses, ...) are silently skipped by findall.
    tokens = re.findall(token_pattern, term_str)

    grouped = parse_brackets(tokens)

    # Wrap in a list so the whole input is parsed as a single group.
    term, _rest = parse([grouped])

    return term


def _fold_applications(term, rest):
    """Left-fold every remaining token group in *rest* onto *term* as applications."""
    while rest:
        arg, rest = parse(rest)
        term = App(term, arg)
    return term


def parse(tokens):
    """Parse one term from the front of a bracket-grouped token list.

    tokens: a list whose items are token strings or nested lists (one per
    parenthesised group), as built by parse_brackets.

    Returns:
        (term, remaining_tokens). A parenthesised group and a lambda body both
        absorb every following item of their list as left-associated
        application arguments.

    Raises:
        Exception("Parse error"): if tokens is empty or the leading token is
            not recognised.
    """
    try:
        hd = tokens[0]
    except (IndexError, TypeError):
        raise Exception("Parse error") from None

    if isinstance(hd, list):
        term, rest = parse(hd)
        return _fold_applications(term, rest), tail(tokens)

    if is_lambda(hd):
        var = lambda_var(hd)
        body, rest = parse(tail(tokens))
        # The body extends to the end of the enclosing group, so nothing remains.
        return Lambda(var, _fold_applications(body, rest)), []

    binary_ops = {
        "add": Add,
        "mul": Mul,
        "sub": Sub,
        "eq": Eq,
        "cons": Cons,
        "pair": Pair,
    }
    if hd in binary_ops:
        first, rest = parse(tail(tokens))
        second, rest = parse(rest)
        return binary_ops[hd](first, second), rest

    if hd == "if":
        cond, rest = parse(tail(tokens))
        if_true, rest = parse(rest)
        if_false, rest = parse(rest)
        return If(cond, if_true, if_false), rest

    constants = {"true": Tru, "false": Fls, "nil": Nil}
    if hd in constants:
        return constants[hd](), tail(tokens)

    unary_ops = {
        "fst": Fst,
        "snd": Snd,
        "head": Head,
        "tail": Tail,
        "isnil": IsNil,
        "fix": Fix,
    }
    if hd in unary_ops:
        operand, rest = parse(tail(tokens))
        return unary_ops[hd](operand), rest

    # Keywords were handled above, so any remaining identifier is a variable.
    if is_var(hd):
        return Var(hd), tail(tokens)
    if is_nat(hd):
        return Nat(hd), tail(tokens)

    # Previously fell through and returned None, which broke the caller's unpacking.
    raise Exception("Parse error")

# Parser smoke test: "if 1 2 3 4" parses as If(1, 2, 3) applied to 4, because
# leftover items in a group are folded into left-associated applications.
term_str = "if 1 2 3 4"
t = parse_term(term_str)

# Surface grammar (the abstraction λx.e is written `lambda x. e` in the input):
# e ::= x | λx.e | e e |
#     n | add e e | mul e e | sub e e | eq e e |    (natural numbers)
#     true | false | if e e e |                      (boolean values)
#     fix e |                                        (recursion)
#     pair e e | fst e | snd e |                     (pairs)
#     nil | cons e e | head e | tail e | isnil e     (lists)


In [106]:
def nat_to_lambda(n):
    """Build the Church numeral for *n*: ``λs. λz. s (s (... (s z)))`` with n s's."""
    numeral = Var("z")
    successor_steps = range(n)
    for _step in successor_steps:
        numeral = App(Var("s"), numeral)
    inner = Lambda("z", numeral)
    return Lambda("s", inner)


def desugar(term: Term):
    """Translate an extended-syntax term into the pure lambda calculus.

    Var, Lambda and App are rebuilt with desugared children. Numerals become
    Church numerals (nat_to_lambda). Booleans, pairs and lists become their
    Church encodings, and each primitive operator becomes an application of a
    lambda term implementing it. The result contains only Var, Lambda and App
    nodes.

    Raises:
        TypeError: if ``term`` or one of its sub-terms is not a supported
            node type.
    """
    if isinstance(term, Var):
        return term
    elif isinstance(term, Lambda):
        return Lambda(term.var, desugar(term.t))
    elif isinstance(term, App):
        return App(desugar(term.t1), desugar(term.t2))
    elif isinstance(term, Nat):
        return nat_to_lambda(term.n)
    elif isinstance(term, Add):
        add = parse_term("lambda m. lambda n. lambda s. lambda z. m s (n s z)")
        return App(App(add, desugar(term.n)), desugar(term.m))
    elif isinstance(term, Mul):
        plus = "(lambda m. lambda n. lambda s. lambda z. m s (n s z))"
        # times m n = m (plus n) 0: add n to Church zero, m times.
        times = parse_term(f"lambda m. lambda n. m ({plus} n) (lambda s. lambda n. n)")
        return App(App(times, desugar(term.n)), desugar(term.m))
    elif isinstance(term, Sub):
        # Predecessor via pairs: iterate (a, b) -> (b, b + 1) from (0, 0);
        # after m steps the first component is m - 1 (or 0 when m = 0).
        zz = Pair(Nat(0), Nat(0))
        ss = Lambda("p", Pair(Snd(Var("p")), Add(Nat(1), Snd(Var("p")))))
        prd = Lambda("m", Fst(App(App(Var("m"), ss), zz)))
        # sub m n = n prd m: apply the predecessor n times to m (truncates at 0).
        sub = Lambda("m", Lambda("n", App(App(Var("n"), prd), Var("m"))))
        # Apply the operator to the operands; the helpers still contain sugar.
        return desugar(App(App(sub, term.n), term.m))
    elif isinstance(term, Eq):
        is_zero = parse_term("lambda n. n (lambda x. false) true")
        # leq m n = iszero (m - n), relying on truncated subtraction.
        leq = Lambda("m", Lambda("n", App(is_zero, Sub(Var("m"), Var("n")))))
        lambda_and = parse_term("lambda p. lambda q. p q p")
        m_leq_n = App(App(leq, Var("m")), Var("n"))
        n_leq_m = App(App(leq, Var("n")), Var("m"))
        eq = Lambda("m", Lambda("n", App(App(lambda_and, m_leq_n), n_leq_m)))
        # The helper terms still hold Sub/Tru/Fls nodes, so desugar the whole
        # application of eq to both operands.
        return desugar(App(App(eq, term.n), term.m))
    elif isinstance(term, Tru):
        return Lambda("t", Lambda("f", Var("t")))
    elif isinstance(term, Fls):
        return Lambda("t", Lambda("f", Var("f")))
    elif isinstance(term, If):
        # Church booleans select a branch themselves: cond if_true if_false.
        return App(App(desugar(term.cond), desugar(term.if_true)), desugar(term.if_false))
    elif isinstance(term, Fix):
        fix = parse_term("lambda f. (lambda x. f (lambda y. x x y)) (lambda x. f (lambda y. x x y))")
        return App(fix, desugar(term.f))
    elif isinstance(term, Pair):
        pair = parse_term("lambda f. lambda s. lambda b. b f s")
        return App(App(pair, desugar(term.first)), desugar(term.second))
    elif isinstance(term, Fst):
        fst = Lambda("p", App(Var("p"), Lambda("t", Lambda("f", Var("t")))))
        return App(fst, desugar(term.pair))
    elif isinstance(term, Snd):
        snd = Lambda("p", App(Var("p"), Lambda("t", Lambda("f", Var("f")))))
        return App(snd, desugar(term.pair))
    elif isinstance(term, Nil):
        return parse_term("lambda c. lambda n. n")
    elif isinstance(term, Cons):
        cons = parse_term("lambda h. lambda t. lambda c. lambda n. c h (t c n)")
        return App(App(cons, desugar(term.head)), desugar(term.tail))
    elif isinstance(term, Head):
        # Fold keeping only the first element; an empty list yields false.
        head = parse_term("lambda l. l (lambda h. lambda t. h) (lambda t. lambda f. f)")
        return desugar(App(head, term.list))
    elif isinstance(term, Tail):
        tru = "(lambda t. lambda f. t)"
        fls = "(lambda t. lambda f. f)"

        pair = "(lambda f. lambda s. lambda b. b f s)"
        fst = f"(lambda p. p {tru})"
        snd = f"(lambda p. p {fls})"

        # Fold carrying (tail-of-rest, rest) from (nil, nil); the first
        # component at the end is the tail. Named tail_fn so it does not
        # shadow the module-level tail() helper.
        tail_fn = parse_term(
            f"lambda l. {fst} (l (lambda x. lambda p. pair ({snd} p) (cons x ({snd} p))) ({pair} nil nil))"
        )
        return desugar(App(tail_fn, term.list))
    elif isinstance(term, IsNil):
        tru = "(lambda t. lambda f. t)"
        fls = "(lambda t. lambda f. f)"

        isnil = parse_term(f"lambda l. l (lambda h. lambda t. {fls}) {tru}")
        return App(isnil, desugar(term.list))

    raise TypeError(f"desugar: unsupported term type {type(term).__name__}")


# Demo of the full pipeline on `cons 5 nil`: parse, desugar to pure lambda
# terms, convert to de Bruijn form, then normalise, printing each stage.
t = parse_term("cons 5 nil")
desugar_t = desugar(t)
print(desugar_t)

db_t, frees = deBruijn(desugar_t)
print(db_t)
normal_t = normalize_subst_no(db_t)
print(normal_t)


((λh. λt. λc. λn. (c h) ((t c) n)) (λs. λz. s (s (s (s (s z)))))) (λc. λn. n)
((λ. λ. λ. λ. (1 3) ((2 1) 0)) (λ. λ. 1 (1 (1 (1 (1 0)))))) (λ. λ. 0)
λ. λ. (1 (λ. λ. 1 (1 (1 (1 (1 0)))))) 0
