In [3]:
from gbmi.utils import ein

In [5]:
import torch

A = torch.rand(4, 5)
B = torch.rand(5, 6)

In [7]:
from typing import List

graph = None


def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    global graph
    print("custom backend called with FX graph:")
    graph = gm
    return gm.forward


# Reset since we are using a different backend.
torch._dynamo.reset()

In [8]:
def model(x):
    return A @ x


opt_model = torch.compile(model, backend=custom_backend)

In [9]:
opt_model(B)

custom backend called with FX graph:


tensor([[1.3877, 1.6426, 1.3159, 1.8250, 1.5599, 1.2425],
        [0.8879, 0.8428, 0.5298, 0.8087, 0.6971, 0.9310],
        [0.5792, 0.7911, 0.6740, 0.8534, 0.6334, 0.5789],
        [1.3163, 1.2260, 1.3549, 1.6138, 1.0785, 1.0069]])

In [10]:
graph.graph.node

AttributeError: 'Graph' object has no attribute 'node'

In [11]:
ein.array(lambda i: torch.exp(i) + B[i])

ValueError: dimension dim is unbound

In [12]:
from torch import Tensor

In [13]:
import logging


def add_constraint(tensor, size):
    if hasattr(tensor, "_constraints"):
        tensor._constraints.add(size)
    else:
        tensor._constraints = {size}


class ConstraintTrackingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # NOTE: Logging calls Tensor.__repr__, so we can't log __repr__ without infinite recursion
        if func.__name__ == "__getitem__":
            for size, index in zip(args[0].shape, args[1]):
                add_constraint(index, size)
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

In [15]:
idx = ConstraintTrackingTensor(torch.tensor(0))
(lambda i: torch.exp(i) + B[i])(idx)
idx._constraints

{5}

In [194]:
idx._constraints

{5}

In [145]:
class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print("---")
        print(cls, func)
        print("===")
        for i in args:
            print("+++")
            print(i)
        if kwargs is None:
            kwargs = {}
        args_flat = torch.utils._pytree.tree_flatten(args)[0]
        print("flat", args_flat)
        metadatas = tuple(a._metadata for a in args_flat if hasattr(a, "_metadata"))
        print("m", metadatas)
        args = torch.utils._pytree.tree_map(lambda x: getattr(x, "_t", x), args)
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])

In [146]:
metadata = {"owner": "Ministry of Silly Walks"}
m = MetadataTensor(1, metadata=metadata)

In [147]:
(lambda i: torch.exp(i) + B[i])(MetadataTensor(torch.tensor(0), metadata="x"))

---
<class '__main__.MetadataTensor'> <built-in method exp of type object at 0x118d3a9b8>
===
+++
Metadata:
x

data:
0
flat [Metadata:
x

data:
0]
m ('x',)
---
<class '__main__.MetadataTensor'> <slot wrapper '__getitem__' of 'torch._C.TensorBase' objects>
===
+++
tensor([[4.3073e-01, 1.0403e-01, 2.3399e-01, 6.8633e-01, 5.2392e-01, 4.5291e-01],
        [1.5897e-01, 9.5502e-01, 4.0753e-01, 2.8613e-01, 4.4343e-01, 8.0154e-01],
        [3.4549e-01, 3.7140e-01, 1.0447e-01, 8.2847e-01, 3.9709e-04, 9.5640e-01],
        [7.1253e-01, 3.9568e-01, 2.2937e-01, 4.5929e-01, 5.2071e-01, 8.7025e-01],
        [2.3285e-01, 6.2969e-02, 8.1768e-01, 3.2309e-01, 6.9715e-03, 9.8854e-01]])
+++
(Metadata:
x

data:
0,)
flat [tensor([[4.3073e-01, 1.0403e-01, 2.3399e-01, 6.8633e-01, 5.2392e-01, 4.5291e-01],
        [1.5897e-01, 9.5502e-01, 4.0753e-01, 2.8613e-01, 4.4343e-01, 8.0154e-01],
        [3.4549e-01, 3.7140e-01, 1.0447e-01, 8.2847e-01, 3.9709e-04, 9.5640e-01],
        [7.1253e-01, 3.9568e-01, 2.2937e-01, 

TypeError: unsupported operand type(s) for +: 'MetadataTensor' and 'MetadataTensor'

In [122]:
fun.__name__

'__getitem__'

In [61]:
import torch


def f(i):
    x = torch.exp(i) + B[i]
    return x, i


#  To avoid dealing with prim::Bailout stuff
torch._C._jit_set_profiling_executor(False)

trace = torch.jit.trace(f, torch.tensor(1, dtype=torch.int))

In [62]:
trace.graph

graph(%i : Int(requires_grad=0, device=cpu)):
  %2 : int = aten::Int(%i)
  %1 : Float(requires_grad=0, device=cpu) = aten::exp(%i) # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %3 : Float(5, 6, strides=[6, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]() # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %4 : int = prim::Constant[value=0]() # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %5 : Float(6, strides=[1], requires_grad=0, device=cpu) = aten::select(%3, %4, %2) # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %6 : int = prim::Constant[value=1]() # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %7 : Float(6, strides=[1], requires_grad=0, device=cpu) = aten::add(%1, %5, %6) # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %8 : (Float(6, stri

In [44]:
import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

#  To avoid dealing with prim::Bailout stuff
torch._C._jit_set_profiling_executor(False)

inp = torch.rand(1, 3, 224, 224)
trace = torch.jit.trace(conv, inp)

RuntimeError: example_kwarg_inputs should be a dict

In [42]:
traced.graph

graph(%self : __torch__.torch.nn.modules.conv.Conv2d,
      %input : Float(1, 3, 224, 224, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu)):
  %bias : Tensor = prim::GetAttr[name="bias"](%self)
  %weight : Tensor = prim::GetAttr[name="weight"](%self)
  %6 : int = prim::Constant[value=1]() # /Users/eo/Library/Caches/pypoetry/virtualenvs/gbmi-UvbjekAV-py3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:456:0
  %7 : int = prim::Constant[value=1]() # /Users/eo/Library/Caches/pypoetry/virtualenvs/gbmi-UvbjekAV-py3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:456:0
  %8 : int[] = prim::ListConstruct(%6, %7)
  %9 : int = prim::Constant[value=0]() # /Users/eo/Library/Caches/pypoetry/virtualenvs/gbmi-UvbjekAV-py3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:456:0
  %10 : int = prim::Constant[value=0]() # /Users/eo/Library/Caches/pypoetry/virtualenvs/gbmi-UvbjekAV-py3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py:456:0
  %11 : int[] = pri