# Dependent Types Implementation

In [4]:
from dataclasses import dataclass

### Defining classes

@dataclass
class Var:
  """A reference to a variable, identified purely by its name."""
  name: str

@dataclass
class Star:
  """The sort of types ("*"); get_type assigns Star() the type Star() itself."""

@dataclass
class Pi:
  """Dependent function type: `name` of type `t1` is in scope inside `t2`."""
  name: str
  t1: object  # domain type
  t2: object  # codomain type; may mention `name`

@dataclass
class Lambda:
  """Function abstraction binding `name` (annotated with type `t1`) in body `e2`."""
  name: str
  t1: object  # type annotation of the bound variable
  e2: object  # body, where `name` is in scope

@dataclass
class App:
  """Application of the function `e1` to the argument `e2`."""
  e1: object  # function (operator)
  e2: object  # argument (operand)

@dataclass
class Natural:
  """The type of natural numbers, whose values are built from Zero() and Succ(...)."""

@dataclass
class Zero:
  """The natural number 0."""

@dataclass
class Succ:
  """Successor of a natural number: Succ(e) denotes e + 1."""
  e: object  # the predecessor

@dataclass
class ElimNat:
  """Eliminator (recursor) for natural numbers."""
  e1: object  # motive: a family indexed by Natural()
  e2: object  # base case, result when the target is Zero()
  e3: object  # step: takes a predecessor and the recursive result
  e4: object  # target natural number being eliminated


### definition of Expr: the union of every AST node class defined above,
### used as the annotation for expression arguments and results below.
Expr = (
    Var | Star | Pi | Lambda | App
    | Natural | Zero | Succ | ElimNat
)


def alpha_equivalence(expr1: Expr, expr2: Expr) -> bool:
  """checks if two expressions are equivalent under alpha-conversion

    Args:
        expr1: first expression
        expr2: second expression

    Returns:
        true if equivalent, false otherwise
  """
  if isinstance(expr1, Var) and isinstance(expr2, Var):
    return expr1.name == expr2.name
  elif isinstance(expr1, Star) and isinstance(expr2, Star):
    return True
  elif isinstance(expr1, Pi) and isinstance(expr2, Pi):
    return expr1.name == expr2.name and alpha_equivalence(expr1.t1, expr2.t1) and alpha_equivalence(expr1.t2, expr2.t2)
  elif isinstance(expr1, Lambda) and isinstance(expr2, Lambda):
    return alpha_equivalence(expr1.t1, expr2.t1) and alpha_equivalence(rename(expr1.name, expr2.name, expr1.e2), expr2.e2)
  elif isinstance(expr1, App) and isinstance(expr2, App):
    return alpha_equivalence(expr1.e1, expr2.e1) and alpha_equivalence(expr1.e2, expr2.e2)
  elif isinstance(expr1, Natural) and isinstance(expr2, Natural):
    return True
  elif isinstance(expr1, Zero) and isinstance(expr2, Zero):
    return True
  elif isinstance(expr1, Succ) and isinstance(expr2, Succ):
    return alpha_equivalence(expr1.e, expr2.e)
  elif isinstance(expr1, ElimNat) and isinstance(expr2, ElimNat):
    return alpha_equivalence(expr1.e1, expr2.e1) and alpha_equivalence(expr1.e2, expr2.e2) and alpha_equivalence(expr1.e3, expr2.e3) and alpha_equivalence(expr1.e4, expr2.e4)
  else:
    return False


def rename(old_name: str, new_name: str, expr: Expr) -> Expr:
  """renames all occurrences of old_name in an expression with new_name
     i think this is a little bit broken? i'm not sure i handle the edge cases correctly
     future work would to be fix this w de bruijn indicies

    Args:
        old_name: name to replace
        new_name: name to replace with
        expr: expression to rename

    Returns:
        renamed expression
  """
  match expr:
    case Var(name):
      if name == old_name:
        return Var(new_name)
      else:
        return Var(name)
    case Star():
      return Star()
    case Pi(name, t1, t2):
      if name == old_name:
        return Pi(new_name, t1, t2)
      else:
        return Pi(name, t1, t2)
    case Lambda(name, t1, e2):
      if name == old_name:
        return Lambda(new_name, t1, rename(old_name, new_name, e2))
      else:
        return Lambda(name, t1, rename(old_name, new_name, e2))
    case App(e1, e2):
      return App(rename(old_name, new_name, e1), rename(old_name, new_name, e2))
    case Natural():
      return Natural()
    case Zero():
      return Zero()
    case Succ(e):
      return Succ(rename(old_name, new_name, e))
    case ElimNat(e1, e2, e3, e4):
      return ElimNat(
          rename(old_name, new_name, e1),
          rename(old_name, new_name, e2),
          rename(old_name, new_name, e3),
          rename(old_name, new_name, e4)
      )

  return expr


def free_variables(expr: Expr) -> set:
  """returns the set of free variables in an expression

    Args:
        expr: expression to check

    Returns:
        set of free variables
  """
  match expr:
    case Var(name):
      return set(name)
    case Pi(name, t1, t2):
      return free_variables(t1).union(free_variables(t2) - set(name))
    case Lambda(name, t1, e2):
      return free_variables(t1).union(free_variables(e2) - set(name))
    case App(e1, e2):
      return free_variables(e1).union(free_variables(e2))
    case Succ(e):
      return free_variables(e)
    case ElimNat(e1, e2, e3, e4):
      return free_variables(e1).union(free_variables(e2)).union(free_variables(e3)).union(free_variables(e4))

  return set()


def subsitution(expr1: Expr, expr2: Expr, x: str) -> Expr:
  """substitutes all occurrences of x in expr1 with expr2

    Args:
        expr1: expression to substitute in
        expr2: expression to substitute with
        x: variable to substitute

    Returns:
        substituted expression
  """
  match expr1:
    case Var(name):
      if name == x:
        return expr2
    case Pi(name, t1, t2):
      if name not in free_variables(expr2).union(set(x)):
        return Pi(name, subsitution(t1, expr2, x), subsitution(t2, expr2, x))
    case Lambda(name, t1, e3):
      if name not in free_variables(expr2).union(set(x)):
        return Lambda(name, subsitution(t1, expr2, x), subsitution(e3, expr2, x))
    case App(e3, e4):
      return App(subsitution(e3, expr2, x), subsitution(e4, expr2, x))
    case Succ(e3):
      return Succ(subsitution(e3, expr2, x))
    case ElimNat(e3, e4, e5, e6):
      return ElimNat(
          subsitution(e3, expr2, x),
          subsitution(e4, expr2, x),
          subsitution(e5, expr2, x),
          subsitution(e6, expr2, x)
      )

  return expr1


def eval_step(expr: Expr) -> Expr:
  """Perform one (parallel) reduction step.

    An application whose head is a Lambda is beta-reduced. ElimNat is unfolded
    when its target is Zero() or Succ(...). Otherwise every sub-term is stepped
    once, so several congruence rules can fire in one call, including under
    binders. A Lambda's type annotation is left as-is, while a Pi's domain and
    codomain are both stepped.

    Args:
        expr: expression to evaluate

    Returns:
        the expression after one step (equal to expr when nothing reduces)
  """
  if isinstance(expr, App):
    fn, arg = expr.e1, expr.e2
    if isinstance(fn, Lambda):
      # eval-app (beta reduction)
      return subsitution(fn.e2, arg, fn.name)
    # eval-rator and eval-rand
    return App(eval_step(fn), eval_step(arg))

  if isinstance(expr, ElimNat):
    motive, base, step, target = expr.e1, expr.e2, expr.e3, expr.e4
    if isinstance(target, Zero):
      # eval-elimnat-0
      return base
    if isinstance(target, Succ):
      # eval-elimnat-succ: step(pred, ElimNat(..., pred))
      pred = target.e
      return App(App(step, pred), ElimNat(motive, base, step, pred))
    # congruence: step every argument of the stuck eliminator
    return ElimNat(*(eval_step(part) for part in (motive, base, step, target)))

  if isinstance(expr, Pi):
    # eval-pi-domain and eval-pi-codomain
    return Pi(expr.name, eval_step(expr.t1), eval_step(expr.t2))

  if isinstance(expr, Lambda):
    # eval-lambda-body (annotation deliberately not reduced)
    return Lambda(expr.name, expr.t1, eval_step(expr.e2))

  if isinstance(expr, Succ):
    # eval-succ
    return Succ(eval_step(expr.e))

  # Var, Star, Natural, Zero are already values.
  return expr


def eval_to_completion(expr: Expr, max_steps: int | None = 100_000) -> Expr:
  '''Repeatedly apply eval_step until the expression stops changing.

    Implemented as a loop rather than recursion, so long reductions do not hit
    Python's recursion limit.

    Args:
        expr: expression to evaluate
        max_steps: maximum number of eval_step calls before giving up; None
            means unbounded (a divergent term then loops forever)

    Returns:
        evaluated expression (a fixed point of eval_step)

    Raises:
        RuntimeError: if no fixed point is reached within max_steps steps
  '''
  current = expr
  steps_taken = 0
  while max_steps is None or steps_taken < max_steps:
    stepped = eval_step(current)
    if stepped == current:
      return stepped
    current = stepped
    steps_taken += 1
  raise RuntimeError(f'Error: no normal form reached within {max_steps} steps')


def get_type(expr: Expr, env: dict) -> Expr:
  '''returns the type of an expression

    Args:
        expr: expression to check
        env: environment

    Returns:
        type of expression if no errors, raises TypeError otherwise
  '''
  match expr:
    case Var(name):
      # type-var-ref
      if name not in env:
        raise TypeError(f'Error: variable {name} not defined')

      return env[name]

    case Star():
      # type-star
      return Star()

    case Pi(name, t1, t2):
      # type-pi
      type_1 = get_type(t1, env)
      if type_1 != Star():
        raise TypeError('Error: t1 in Pi is not a type (Star())')

      new_env = env.copy()
      new_env[name] = type_1

      type_2 = get_type(t2, new_env)
      if type_2 != Star():
        raise TypeError('Error: t2 in Pi is not a type (Star())')
      return Star()

    case Lambda(name, t1, e2):
      # type-lambda
      type_1 = eval_to_completion(get_type(t1, env))
      if type_1 != Star():
        raise TypeError('Error: t1 in Lambda is not a type (Star())')

      new_env = env.copy()
      new_env[name] = t1
      e2_type = get_type(e2, new_env)
      type_2 = get_type(e2_type, new_env)
      if type_2 != Star():
        raise TypeError('Error: e2 in Lambda is not typed correctly')

      return Pi(name, t1, e2_type)

    case App(e1, e2):
      # type-app
      e1_type = get_type(e1, env)
      if not isinstance(e1_type, Pi):
        raise TypeError('Error: e1 in App is not of type Pi()')

      e2_type = get_type(e2, env)
      if e2_type != e1_type.t1:
        raise TypeError('Error: e2 in App is not the type of the input of e1')

      return subsitution(e1_type.t2, e2, e1_type.name)

    case Natural():
      # type-N
      return Star()

    case Zero():
      # type-0
      return Natural()

    case Succ(e):
      e_type = get_type(e, env)
      if e_type != Natural():
        raise TypeError('Error: Succ() should take an argument of Natural() type as an input')

      return Natural()

    case ElimNat(e1, e2, e3, e4):
      type_1 = get_type(e1, env)
      if not isinstance(type_1, Pi) or type_1.t1 != Natural() or type_1.t2 != Star():
        raise TypeError('Error: e1 in ElimNat() has incorrect type')

      type_2 = get_type(e2, env)
      if not isinstance(type_2, App) or type_2.e1 != e1 or type_2.e2 != Zero() :
        raise TypeError('Error: e2 in ElimNat() has incorrect type')

      type_3 = get_type(e3, env)
      e3_expected_type = Pi('x', Natural(), Pi('y', App(e1, Var('x')), App(e1, Succ(Var('x')))))
      if not alpha_equivalence(type_3, e3_expected_type):
        raise TypeError('Error: e3 in ElimNat() has incorrect type')

      type_4 = get_type(e4, env)
      if type_4 != Natural():
        raise TypeError('Error: e4 in ElimNat() should be an argument of Natural() type')

      return App(e1, e4)


def natural_to_int(expr: Expr) -> int:
  '''Convert a numeral Succ(...Succ(Zero())...) to a Python int.

    Walks the Succ chain iteratively, so large numerals do not hit the
    recursion limit.

    Args:
        expr: natural to convert

    Returns:
        integer representation of natural

    Raises:
        TypeError: if expr is not built solely from Succ and Zero (e.g. a
            stuck or not fully evaluated term)
  '''
  count = 0
  node = expr
  while isinstance(node, Succ):
    count += 1
    node = node.e
  if not isinstance(node, Zero):
    raise TypeError(f'Error: expected a numeral ending in Zero(), found {type(node).__name__}')
  return count


def int_to_natural(i: int) -> Expr:
  '''Convert a non-negative integer to a numeral Succ(...Succ(Zero())...).

    Builds the numeral iteratively, so large values do not hit the recursion limit.

    Args:
        i: integer to convert (must be >= 0)

    Returns:
        natural representation of integer

    Raises:
        ValueError: if i is negative (previously this recursed without end)
  '''
  if i < 0:
    raise ValueError(f'Error: cannot convert negative integer {i} to a natural')
  result = Zero()
  for _ in range(i):
    result = Succ(result)
  return result

# Alpha equivalence examples

In [5]:
# Inspired by the lecture on de Bruijn indices.
# expr_1 and expr_3 differ only in the name of the inner bound variable (y vs z),
# so they are alpha-equivalent. expr_2 additionally turns the *free* y in the
# outer application into z, so it is not equivalent to expr_1.
expr_1 = Lambda('x', Natural(),
                App(Lambda('y', Natural(), App(Var('x'), Var('y'))),
                    App(Var('x'), Var('y'))))
expr_2 = Lambda('x', Natural(),
                App(Lambda('z', Natural(), App(Var('x'), Var('z'))),
                    App(Var('x'), Var('z'))))
expr_3 = Lambda('x', Natural(),
                App(Lambda('z', Natural(), App(Var('x'), Var('z'))),
                    App(Var('x'), Var('y'))))

print(alpha_equivalence(expr_1, expr_2), alpha_equivalence(expr_1, expr_3), sep='\n')

False
True


# Typechecking examples

In [6]:
# x is not in the (empty) environment, so this raises TypeError.
print(get_type(Var('x'), env={}))

TypeError: Error: variable x not defined

In [7]:
# A variable's type is whatever the environment assigns to it.
print(get_type(Var('x'), env={'x': Natural()}))

Natural()


In [8]:
# Here x is itself a type (x : Star).
print(get_type(Var('x'), env={'x': Star()}))

Star()


In [9]:
# The numeral 2 is a closed term of type Natural().
print(get_type(Succ(Succ(Zero())), env={}))

Natural()


In [10]:
# The outer environment's x is shadowed by the lambda's own binder x : Natural().
print(get_type(Lambda('x', Natural(), Succ(Var('x'))), env={'x': Natural()}))

Pi(name='x', t1=Natural(), t2=Natural())


In [11]:
# Inside the lambda x : Star(), so Succ(x) is rejected (Succ needs a Natural()).
print(get_type(Lambda('x', Star(), Succ(Var('x'))), env={'x': Star()}))

TypeError: Error: Succ() should take an argument of Natural() type as an input

# Mathematical operator examples

In [12]:
# add x y: eliminate x, starting from y and wrapping the recursive result in
# Succ once per predecessor. The motive is the constant family Natural(); the
# step ignores the predecessor ('_') and receives the recursive result as 'rec'.
add = Lambda('x', Natural(), Lambda('y', Natural(), ElimNat(
    Lambda('_', Natural(), Natural()),                                  # motive
    Var('y'),                                                           # base: 0 + y = y
    Lambda('_', Natural(), Lambda('rec', Natural(), Succ(Var('rec')))), # step
    Var('x'),                                                           # recurse on x
)))

In [13]:
num_4, num_5, num_6 = int_to_natural(4), int_to_natural(5), int_to_natural(6)

### commutativity of addition (spot check on concrete numerals, not a proof)
# 4 + 5
comm_1 = natural_to_int(eval_to_completion(App(App(add, num_4), num_5)))
# 5 + 4
comm_2 = natural_to_int(eval_to_completion(App(App(add, num_5), num_4)))

assert comm_1 == 9
assert comm_2 == 9
assert comm_1 == comm_2


### associativity of addition (spot check on concrete numerals, not a proof)
# 4 + (5 + 6): add(4, add(5, 6))
assoc_1 = natural_to_int(eval_to_completion(App(App(add, num_4), App(App(add, num_5), num_6))))
# (4 + 5) + 6: add(add(4, 5), 6)
assoc_2 = natural_to_int(eval_to_completion(App(App(add, App(App(add, num_4), num_5)), num_6)))

assert assoc_1 == 15
assert assoc_2 == 15
assert assoc_1 == assoc_2