In [6]:
from gbmi.utils import ein

In [7]:
import torch

A = torch.rand(4, 5)
B = torch.rand(5, 6)

In [8]:
from typing import List

graph = None


def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    global graph
    print("custom backend called with FX graph:")
    graph = gm
    return gm.forward


# Reset since we are using a different backend.
torch._dynamo.reset()

In [9]:
def model(x):
    return A @ x


opt_model = torch.compile(model, backend=custom_backend)

In [10]:
opt_model(B)

custom backend called with FX graph:


tensor([[0.7709, 1.8965, 1.5195, 1.8628, 1.0387, 1.9042],
        [0.6951, 1.5685, 0.9689, 1.2673, 0.9519, 1.4195],
        [0.6140, 1.7981, 1.1005, 1.3455, 1.0335, 1.4959],
        [0.9975, 1.7626, 1.0959, 1.4170, 0.8062, 1.4095]])

In [11]:
graph.graph.node

AttributeError: 'Graph' object has no attribute 'node'

In [None]:
ein.array(lambda i: torch.exp(i) + B[i])

In [12]:
from torch import Tensor

In [13]:
class ConstraintTrackingTensor(torch.Tensor):
    @staticmethod
    def add_constraint(tensor, size):
        if hasattr(tensor, "_constraints"):
            tensor._constraints.add(size)
        else:
            tensor._constraints = {size}

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func.__name__ == "__getitem__":
            for size, index in zip(args[0].shape, args[1]):
                ConstraintTrackingTensor.add_constraint(index, size)
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

In [14]:
idx = ConstraintTrackingTensor(torch.tensor(0))
(lambda j)(idx)
idx._constraints

SyntaxError: invalid syntax (3027892808.py, line 2)

In [15]:
idx._constraints

NameError: name 'idx' is not defined

In [16]:
class MetadataTensor(object):
    def __init__(self, data, metadata=None, **kwargs):
        self._t = torch.as_tensor(data, **kwargs)
        self._metadata = metadata

    def __repr__(self):
        return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        print("---")
        print(cls, func)
        print("===")
        for i in args:
            print("+++")
            print(i)
        if kwargs is None:
            kwargs = {}
        args_flat = torch.utils._pytree.tree_flatten(args)[0]
        print("flat", args_flat)
        metadatas = tuple(a._metadata for a in args_flat if hasattr(a, "_metadata"))
        print("m", metadatas)
        args = torch.utils._pytree.tree_map(lambda x: getattr(x, "_t", x), args)
        ret = func(*args, **kwargs)
        return MetadataTensor(ret, metadata=metadatas[0])

In [17]:
metadata = {"owner": "Ministry of Silly Walks"}
m = MetadataTensor(1, metadata=metadata)

In [147]:
(lambda i: torch.exp(i) + B[i])(MetadataTensor(torch.tensor(0), metadata="x"))

---
<class '__main__.MetadataTensor'> <built-in method exp of type object at 0x118d3a9b8>
===
+++
Metadata:
x

data:
0
flat [Metadata:
x

data:
0]
m ('x',)
---
<class '__main__.MetadataTensor'> <slot wrapper '__getitem__' of 'torch._C.TensorBase' objects>
===
+++
tensor([[4.3073e-01, 1.0403e-01, 2.3399e-01, 6.8633e-01, 5.2392e-01, 4.5291e-01],
        [1.5897e-01, 9.5502e-01, 4.0753e-01, 2.8613e-01, 4.4343e-01, 8.0154e-01],
        [3.4549e-01, 3.7140e-01, 1.0447e-01, 8.2847e-01, 3.9709e-04, 9.5640e-01],
        [7.1253e-01, 3.9568e-01, 2.2937e-01, 4.5929e-01, 5.2071e-01, 8.7025e-01],
        [2.3285e-01, 6.2969e-02, 8.1768e-01, 3.2309e-01, 6.9715e-03, 9.8854e-01]])
+++
(Metadata:
x

data:
0,)
flat [tensor([[4.3073e-01, 1.0403e-01, 2.3399e-01, 6.8633e-01, 5.2392e-01, 4.5291e-01],
        [1.5897e-01, 9.5502e-01, 4.0753e-01, 2.8613e-01, 4.4343e-01, 8.0154e-01],
        [3.4549e-01, 3.7140e-01, 1.0447e-01, 8.2847e-01, 3.9709e-04, 9.5640e-01],
        [7.1253e-01, 3.9568e-01, 2.2937e-01, 

TypeError: unsupported operand type(s) for +: 'MetadataTensor' and 'MetadataTensor'

In [122]:
fun.__name__

'__getitem__'

In [61]:
import torch


def f(i):
    x = torch.exp(i) + B[i]
    return x, i


#  To avoid dealing with prim::Bailout stuff
torch._C._jit_set_profiling_executor(False)

trace = torch.jit.trace(f, torch.tensor(1, dtype=torch.int))

In [62]:
trace.graph

graph(%i : Int(requires_grad=0, device=cpu)):
  %2 : int = aten::Int(%i)
  %1 : Float(requires_grad=0, device=cpu) = aten::exp(%i) # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %3 : Float(5, 6, strides=[6, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]() # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %4 : int = prim::Constant[value=0]() # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %5 : Float(6, strides=[1], requires_grad=0, device=cpu) = aten::select(%3, %4, %2) # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %6 : int = prim::Constant[value=1]() # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %7 : Float(6, strides=[1], requires_grad=0, device=cpu) = aten::add(%1, %5, %6) # /var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_21548/588144388.py:4:0
  %8 : (Float(6, stri

In [None]:
lambda f_plus: f_plus(5, 4)

In [60]:
import cloudpickle
import hashlib

In [69]:
f = lambda x, y: y

In [70]:
import functools

In [80]:
lambda_hash = lambda x: hashlib.md5(cloudpickle.dumps(x)).hexdigest()

In [151]:
lambda_hash(functools.partial(f, 2))

'a48b469abf481a30a9506aa7425741ea'

In [73]:
hashlib.md5(cloudpickle.dumps((lambda x: lambda y: 5)(4))).hexdigest()

'a4a99bc6fd85f16d5f49d1249a87d04f'

In [89]:
from typing import Generic, TypeVar

T = TypeVar("T")
from contextlib import contextmanager


class ContextualGlobal(Generic[T]):
    def __init__(self, val: T):
        self.vals = [val]

    @contextmanager
    def set(self, val: T):
        self.vals.append(val)
        yield
        self.vals.pop()

    def get(self) -> T:
        return self.vals[-1]

In [90]:
g = ContextualGlobal(1)

In [95]:
with g.set(2):
    with g.set(3):
        print(g.get())
    print(g.get())
print(g.get())

3
2
1


In [135]:
def function_contents(func):
    closure = (
        tuple(cell.cell_contents for cell in func.__closure__)
        if getattr(func, "__closure__", None)
        else ()
    )
    return (
        func.__name__,
        func.__defaults__,
        closure,
        func.__code__.co_code,
        func.__code__.co_consts,
    )

In [38]:
import dill

In [39]:
import dill
import pickle

In [112]:
import dill
import copyreg
import functorch.dim


class Pickler(dill.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, torch.Tensor):
            if "BatchedTensor" in repr(obj):
                return pickle.loads, (
                    pickle.dumps(torch._C._functorch.get_unwrapped(obj)),
                )
            return NotImplemented
        elif isinstance(obj, functorch.dim.Tensor):
            return pickle.loads, (
                pickle.dumps(
                    obj.order(*obj.dims),
                )
            )
        return NotImplemented


def dumps(obj):
    f = io.BytesIO()
    p = Pickler(f)
    p.dump(obj)
    return f.getvalue()

In [116]:
v = dumps(torch.rand(100000000))

tensor([0.8632, 0.3194, 0.7590,  ..., 0.3651, 0.4767, 0.4451])


In [90]:
from functorch.dim import dims

dim1, dim2 = dims(2)

In [100]:
t = torch.tensor([[[[1, 2]], [[3, 4]]]])[dim1, dim2]

In [96]:
type(t)

functorch.dim.Tensor

In [94]:
t.order(*t.dims)

tensor([[[[1, 2]],

         [[3, 4]]]])

In [67]:
torch._C._functorch.get_unwrapped(torch.tensor([1]))

RuntimeError: No wrappers present!

In [63]:
dumps(((lambda x: x + u), u))

b'\x80\x04\x95$\x01\x00\x00\x00\x00\x00\x00\x8c\ndill._dill\x94\x8c\x10_create_function\x94\x93\x94(h\x00\x8c\x0c_create_code\x94\x93\x94(C\x00\x94K\x01K\x00K\x00K\x01K\x02KCC\x08|\x00t\x00\x17\x00S\x00\x94N\x85\x94\x8c\x01u\x94\x85\x94\x8c\x01x\x94\x85\x94\x8cN/var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_42541/3750004189.py\x94\x8c\x08<lambda>\x94K\x01C\x02\x08\x00\x94))t\x94R\x94c__builtin__\n__main__\nh\rNNt\x94R\x94}\x94}\x94\x8c\x0f__annotations__\x94}\x94s\x86\x94b\x8c\x05torch\x94\x8c\x06Tensor\x94\x93\x94\x8a\x05 >\x04\x8b\x02\x85\x94R\x94\x86\x94.'

In [56]:
cloudpickle.Pickler

cloudpickle.cloudpickle.Pickler

In [57]:
cloudpickle.Pickler.dispatch_table[torch.Tensor] = lambda x: (torch.Tensor, (hash(x),))

In [58]:
x = torch.tensor([])

In [59]:
import io

with io.BytesIO() as file:
    cp = cloudpickle.Pickler(file, protocol=None, buffer_callback=None)
    cp.dump(lambda y: x + y)
    print(file.getvalue())

b'\x80\x05\x95\xf9\x01\x00\x00\x00\x00\x00\x00\x8c\x17cloudpickle.cloudpickle\x94\x8c\x0e_make_function\x94\x93\x94(h\x00\x8c\r_builtin_type\x94\x93\x94\x8c\x08CodeType\x94\x85\x94R\x94(K\x01K\x00K\x00K\x01K\x02KCC\x08t\x00|\x00\x17\x00S\x00\x94N\x85\x94\x8c\x01x\x94\x85\x94\x8c\x01y\x94\x85\x94\x8cN/var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_42541/2147222178.py\x94\x8c\x08<lambda>\x94K\x04C\x02\x08\x00\x94))t\x94R\x94}\x94(\x8c\x0b__package__\x94N\x8c\x08__name__\x94\x8c\x08__main__\x94uNNNt\x94R\x94h\x00\x8c\x12_function_setstate\x94\x93\x94h\x18}\x94}\x94(h\x15h\x0f\x8c\x0c__qualname__\x94h\x0f\x8c\x0f__annotations__\x94}\x94\x8c\x0e__kwdefaults__\x94N\x8c\x0c__defaults__\x94N\x8c\n__module__\x94h\x16\x8c\x07__doc__\x94N\x8c\x0b__closure__\x94N\x8c\x17_cloudpickle_submodules\x94]\x94\x8c\x0b__globals__\x94}\x94h\n\x8c\x05torch\x94\x8c\x06Tensor\x94\x93\x94\x8a\x05\x80)L\x8b\x02\x85\x94R\x94su\x86\x94\x86R0.'


In [241]:
cloudpickle._BUILTIN_TYPE_NAMES

AttributeError: module 'cloudpickle' has no attribute '_BUILTIN_TYPE_NAMES'

In [239]:
cp._BUILTIN_TYPE_NAMES

AttributeError: 'Pickler' object has no attribute '_BUILTIN_TYPE_NAMES'

In [19]:
import functorch

In [21]:
u = None


def func(x: torch.Tensor) -> torch.Tensor:
    global u
    # some function where we want to debug closely here
    y = 2 * x
    print(x, y)  # to show x & y are BatchedTensors
    print(x.__class__.__name__)
    u = x
    # try to save the tensors (the error occurs here)
    torch.save((x, y), "somefile.pt")
    return y


x = torch.randn((4, 5))
y = functorch.vmap(func)(x)

BatchedTensor(lvl=1, bdim=0, value=
    tensor([[ 0.5271,  0.5215,  0.9051,  0.2090, -1.5623],
            [ 0.3213, -0.0021,  0.1317, -0.9832, -0.6461],
            [-0.8264,  0.8084, -0.2842, -0.2385, -1.9357],
            [-0.1334,  0.9392,  0.5812, -1.5355, -1.1561]])
) BatchedTensor(lvl=1, bdim=0, value=
    tensor([[ 1.0543,  1.0431,  1.8101,  0.4179, -3.1245],
            [ 0.6425, -0.0042,  0.2634, -1.9664, -1.2922],
            [-1.6528,  1.6167, -0.5684, -0.4770, -3.8714],
            [-0.2669,  1.8783,  1.1624, -3.0711, -2.3121]])
)
Tensor
BatchedTensor(lvl=1, bdim=0, value=
    tensor([[ 0.5271,  0.5215,  0.9051,  0.2090, -1.5623],
            [ 0.3213, -0.0021,  0.1317, -0.9832, -0.6461],
            [-0.8264,  0.8084, -0.2842, -0.2385, -1.9357],
            [-0.1334,  0.9392,  0.5812, -1.5355, -1.1561]])
)


RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

In [22]:
u

BatchedTensor(lvl=1, bdim=0, value=
    tensor([[ 0.5271,  0.5215,  0.9051,  0.2090, -1.5623],
            [ 0.3213, -0.0021,  0.1317, -0.9832, -0.6461],
            [-0.8264,  0.8084, -0.2842, -0.2385, -1.9357],
            [-0.1334,  0.9392,  0.5812, -1.5355, -1.1561]])
)

In [23]:
u[1]

RuntimeError: Either your tensor may have escaped from inside a function being vmapped and this is a user error (see https://pytorch.org/functorch/stable/ux_limitations.html), or there is an internal functorch error in `gen_vmap_plumbing` Please file an issue if it looks like the latter

In [177]:
cp.dumps(lambda y: x + y)

AttributeError: 'Pickler' object has no attribute 'dumps'

In [150]:
cloudpickle.dumps(torch.tensor([1, 2, 3]))

b'\x80\x05\x95\x1f\x00\x00\x00\x00\x00\x00\x00\x8c\x05torch\x94\x8c\x06Tensor\x94\x93\x94\x8a\x05P\x18\x12l\x01\x85\x94R\x94.'

In [212]:
function_contents(lambda i: 1 + ein.array(lambda j: A[i, j]))

('<lambda>',
 None,
 (),
 b'd\x01t\x00\xa0\x01\x87\x00f\x01d\x02d\x03\x84\x08\xa1\x01\x17\x00S\x00',
 (None,
  1,
  <code object <lambda> at 0x16b294030, file "/var/folders/y1/33lbjdps12lf6x0csm_rvb5m0000gn/T/ipykernel_95829/1144719411.py", line 1>,
  '<lambda>.<locals>.<lambda>'))

In [139]:
ein.array(lambda i: 1 + ein.array(lambda j: A[i, j]))

TypeError: 'type' object is not subscriptable

In [140]:
f = lambda: ein.

ein.array(lambda i: f())

SyntaxError: invalid syntax (1606187086.py, line 1)

In [148]:
isinstance(functools.partial(lambda x: lambda y: x, 1), functools.partial)

True

In [146]:
functools.partial(lambda x: lambda y: x, 1)

functools.partial(<function <lambda> at 0x16c789990>, 1)

In [132]:
getattr(hash, "bleh", None)

In [44]:
import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

#  To avoid dealing with prim::Bailout stuff
torch._C._jit_set_profiling_executor(False)

inp = torch.rand(1, 3, 224, 224)
trace = torch.jit.trace(conv, inp)

RuntimeError: example_kwarg_inputs should be a dict