# Monkey Patching Tensor

I want tf tensors to know whether they have a special first "batch" dimension, so I want every tensor to carry a flag that tells me whether it has one. I will write special functions to set the flag to true/false, and I want every operation to preserve this flag correctly.

In [1]:
class Tensor:
    """Toy tensor that wraps a single value.

    Supports addition and greater-than comparison against another Tensor,
    and prints as its wrapped value.
    """

    def __init__(self, v):
        self.value = v

    def __repr__(self):
        """Show the wrapped value rather than the default object repr."""
        return str(self.value)

    def __add__(self, other):
        """Return a new Tensor holding the sum of both wrapped values."""
        print('adding')
        total = self.value + other.value
        return Tensor(total)

    def __gt__(self, other):
        """Compare the wrapped values; the result is whatever `>` yields for them."""
        return self.value > other.value

    def speak(self):
        """Return a fixed greeting (a method whose result is a native str)."""
        return "hi!"

In [2]:
# Just adding a class-level `.is_batched` default doesn't solve it: results of
# operations would not pick up the flag from their inputs.
Tensor.is_batched = False

# The desired behavior is that any operation involving any batched tensor
# will naively return a batched tensor.
# The only way to put a batched tensor in and get a non-batched tensor out is
# to call remove_batch.
#
# These symbols are deliberately left unwrapped.
excluded_symbols = ['__class__', '__init__', '__new__', '__repr__', '__str__', '__getattribute__', '__setattr__']


def _intercept(name, func):
    """Wrap `func` so that its result inherits `is_batched` from its arguments.

    Args:
        name: Attribute name of the wrapped callable; only used for the
            debug output.
        func: The original callable found on the class.

    Returns:
        A function that forwards all arguments to `func`, then sets
        `result.is_batched = True` if any positional or keyword argument has
        a truthy `is_batched` attribute. The (possibly flagged) result is
        returned unchanged otherwise.
    """
    print(type(name))

    def intercept(*args, **kwargs):
        print(name, args)
        result = func(*args, **kwargs)
        # Inspect the arguments only after the call, positional and keyword
        # values alike; objects without the attribute count as not batched.
        any_args_are_batched = any(
            getattr(a, 'is_batched', False) for a in (*args, *kwargs.values())
        )
        print(f'any batched args = {any_args_are_batched}')
        if any_args_are_batched:
            try:
                result.is_batched = True
            except (AttributeError, TypeError):
                # The result is a Python native object (e.g. bool, str) or an
                # immutable built-in type, which you can't add attributes to,
                # so instead we just do nothing.
                pass
        return result

    # Marker used below to recognise callables that are already wrapped.
    intercept._is_batch_interceptor = True
    return intercept


for symbol_name in dir(Tensor):
    original_symbol = getattr(Tensor, symbol_name)
    if symbol_name in excluded_symbols or not callable(original_symbol):
        continue
    if getattr(original_symbol, '_is_batch_interceptor', False):
        # Already patched (e.g. this cell was re-run); wrapping again would
        # stack interceptors and duplicate every print and flag update.
        continue
    print(f"will intercept calls to {symbol_name}")
    setattr(Tensor, symbol_name, _intercept(symbol_name, original_symbol))

def add_batch(t: Tensor):
    """Mark `t` as batched by setting its `is_batched` flag to True (in place)."""
    setattr(t, 'is_batched', True)
    
def remove_batch(t: Tensor):
    """Mark `t` as not batched by setting its `is_batched` flag to False (in place)."""
    setattr(t, 'is_batched', False)
    
    

In [3]:
# List the attribute names of the Tensor class.
dir(Tensor)

In [4]:
# Two sample tensors; neither has been marked as batched.
a = Tensor(2)
b = Tensor(3)

In [5]:
# List the class attribute names again after creating instances.
dir(Tensor)

In [6]:
# Before add_batch both tensors report is_batched == False.
print(a, a.is_batched)
print(b, b.is_batched)

# Flag only `a` as batched.
add_batch(a)

# `a` now reports True; `b` is unchanged.
print(a, a.is_batched)
print(b, b.is_batched)

2 False
3 False
2 True
3 False


In [7]:
# `a` is batched, so the intercepted __add__ reports a batched argument.
a + b

__add__ (2, 3)
adding
any batched args = True


5

In [8]:
# Keep the sum so its flag can be inspected in the next cell.
d=a+b

__add__ (2, 3)
adding
any batched args = True


In [9]:
# The result of adding a batched tensor is itself flagged as batched.
d.is_batched

True

In [10]:
# `d` is batched and is passed as `self`, so a batched argument is reported;
# the str result cannot carry the flag and is returned as-is.
d.speak()

speak (5,)
any batched args = True


'hi!'

In [12]:
# The comparison returns a plain bool, which cannot carry the flag.
a > b

__gt__ (2, 3)
any batched args = True


False