In [1]:
import tensorwrap as tf
from tensorflow.experimental import numpy as np

The goal is to transform functions themselves, rather than only their inputs/outputs, by expressing them in terms of primitives. This will let us escape the limits of purely numerical evaluation and develop a proper system for integrating XLA.

In [2]:
# Defining the primitives:
from typing import NamedTuple

# A primitive is identified solely by its name. Being a NamedTuple, it gets
# value equality, hashing and a readable repr without any extra code.
Primitive = NamedTuple("Primitive", [("name", str)])
Primitive.__doc__ = "A named primitive operation (for example 'add' or 'sin')."

In [3]:
# Basic primitives, one module-level constant per operation.
# Arithmetic:
add_p, mul_p, neg_p = Primitive("add"), Primitive("mul"), Primitive("neg")
# Trigonometric:
sin_p, cos_p = Primitive("sin"), Primitive("cos")
# Comparisons:
greater_p, less_p = Primitive("greater"), Primitive("less")
# Reduction and shape manipulation:
reduce_sum_p = Primitive("reduce_sum")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")

In [4]:
# Wrapper Functions:
def bind1(prim, *args, **params):
    """Bind ``prim`` to ``args``/``params`` and return its single output.

    ``bind`` yields an iterable of outputs. This helper expects exactly one
    output and returns it; unpacking raises ``ValueError`` otherwise.
    """
    (single_output,) = bind(prim, *args, **params)
    return single_output

# Thin wrappers: each forwards its operands (and any keyword params) to bind1
# together with the matching primitive.
def add(x, y):
    """Bind the ``add`` primitive to ``x`` and ``y``."""
    return bind1(add_p, x, y)


def mul(x, y):
    """Bind the ``mul`` primitive to ``x`` and ``y``."""
    return bind1(mul_p, x, y)


def neg(x):
    """Bind the ``neg`` primitive to ``x``."""
    return bind1(neg_p, x)


def sin(x):
    """Bind the ``sin`` primitive to ``x``."""
    return bind1(sin_p, x)


def cos(x):
    """Bind the ``cos`` primitive to ``x``."""
    return bind1(cos_p, x)


def greater(x, y):
    """Bind the ``greater`` primitive to ``x`` and ``y``."""
    return bind1(greater_p, x, y)


def less(x, y):
    """Bind the ``less`` primitive to ``x`` and ``y``."""
    return bind1(less_p, x, y)


def transpose(x, perm):
    """Bind the ``transpose`` primitive to ``x``; ``perm`` is passed as a keyword param."""
    return bind1(transpose_p, x, perm=perm)


def broadcast(x, shape, axes):
    """Bind the ``broadcast`` primitive to ``x``; ``shape``/``axes`` are keyword params."""
    return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
    """Sum ``x`` over ``axis`` by binding the ``reduce_sum`` primitive.

    Args:
        x: Array-like operand (anything ``np.ndim`` accepts).
        axis: ``None`` to reduce over every axis of ``x``, a single integer
            axis (Python or NumPy integer), or an iterable of integer axes.

    Returns:
        The single output of binding ``reduce_sum_p`` with ``axis`` normalized
        to a tuple.
    """
    import numbers

    if axis is None:
        # Default: reduce over all dimensions of x. (The previous
        # `axis = None or ...` always took this path and discarded `axis`.)
        axis = tuple(range(np.ndim(x)))
    elif isinstance(axis, numbers.Integral):
        # A bare axis index; numbers.Integral also covers NumPy integer scalars.
        axis = (int(axis),)
    else:
        # Lists and other iterables become a tuple, matching the other branches.
        axis = tuple(axis)
    return bind1(reduce_sum_p, x, axis=axis)

So far we have created wrapper functions that simply call `bind` with the primitive and its keyword parameters, so that the flow of each operation can be intercepted.

Next, we represent the active interpreters as a stack of interpreters (in practice, just a list).

In [None]:
# Creating Interpreter stack:
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Optional, Any

# One entry on the interpreter stack. Fields:
#   level       - index of this entry in trace_stack at the time it was pushed
#   trace_type  - the Trace subclass implementing this interpreter
#   global_data - optional extra data attached to this interpreter
TraceInterpreter = NamedTuple(
    "TraceInterpreter",
    [
        ("level", int),
        ("trace_type", type['Trace']),
        ("global_data", Optional[Any]),
    ],
)

# Active interpreters, most recently pushed last (new_main pushes and pops).
trace_stack: list[TraceInterpreter] = list()
# Placeholder; will be used later.
dynamic_trace: Optional[TraceInterpreter] = None

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
    """Push a new interpreter onto ``trace_stack`` for the duration of a ``with`` block.

    The interpreter's level is the stack depth at the moment it is pushed. The
    new ``TraceInterpreter`` is yielded to the ``with`` body and popped off the
    stack on exit, even if the body raises.
    """
    interpreter = TraceInterpreter(len(trace_stack), trace_type, global_data)
    trace_stack.append(interpreter)
    try:
        yield interpreter
    finally:
        # pop() removes the top entry; this assumes nested new_main blocks
        # exit in LIFO order.
        trace_stack.pop()

With `new_main`, we can push an interpreter onto the stack; it is popped off again when the `with` block exits.
Once all the interpreters are stacked and optimized, we can simply call and evaluate the interpreter.

In [None]:
class Trace:
    """Base class for a trace bound to one interpreter-stack entry.

    Each instance stores its ``TraceInterpreter`` as ``main``. Subclasses must
    override ``pure``, ``lift`` and ``process_primitive``.
    """
    main: TraceInterpreter

    def __init__(self, main: TraceInterpreter) -> None:
        self.main = main

    # Marked for override. These raise NotImplementedError instead of using
    # `assert False`: assert statements are stripped under `python -O`, which
    # would make the base methods silently return None.
    def pure(self, val):
        raise NotImplementedError(f"{type(self).__name__}.pure")

    def lift(self, val):
        raise NotImplementedError(f"{type(self).__name__}.lift")

    def process_primitive(self, primitive, tracers, params):
        raise NotImplementedError(f"{type(self).__name__}.process_primitive")