# Custom Tensor Example

In this example, we see how to create a custom tensor type called PlusIsMinusTensor, which changes the addition operation so that it actually performs subtraction. Note that any function not directly addressed in the PlusIsMinusTensor class simply falls back to the default behavior of the underlying Torch type.

In [1]:
import syft as sy
import torch
from syft.core.frameworks.torch import utils

hook = sy.TorchHook(verbose=True)

# The local worker is marked as a non-client worker, matching the virtual
# workers created below.
me = hook.local_worker
me.is_client_worker = False

# Three in-process virtual workers, created in the order bob, alice, james.
bob, alice, james = (
    sy.VirtualWorker(id=worker_id, hook=hook, is_client_worker=False)
    for worker_id in ("bob", "alice", "james")
)

# Fully connect the workers: each one is told about every other worker,
# listed in the fixed order me, bob, alice, james (skipping itself).
everyone = [me, bob, alice, james]
for worker in everyone:
    worker.add_workers([peer for peer in everyone if peer is not worker])

In [2]:
# Two ordinary FloatTensors, created x first and then y.
x, y = sy.FloatTensor([5, 6]), sy.FloatTensor([3, 4])

In [3]:
# Print the chain of tensor types that make up x (output shown below).
utils.chain_print(x)

FloatTensor > _LocalTensor


In [4]:
class _PlusIsMinusTensor(sy._SyftTensor):
    """
    Example of a custom overloaded _SyftTensor.

    Role:
    Converts all add operations into sub/minus ones. The ``add`` method
    below calls ``self.sub`` instead, and the ``torch.add`` function
    overload forwards to that method. Anything not overridden here falls
    back to the default behavior.
    """

    def __init__(self, *args, **kwargs):
        # No custom state yet; this is the place to add some if needed.
        super().__init__(*args, **kwargs)

    # The table of commands you want to replace (presumably
    # 'original command' -> 'replacement command').
    # NOTE(review): 'torch.add' maps to itself, so this entry appears to
    # substitute nothing; the add -> sub redirection is done by the overloads
    # below.
    substitution_table = {
        'torch.add': 'torch.add'
    }

    class overload_functions:
        """
        Put here the functions you want to overload.
        Beware of recursion errors.
        """
        @staticmethod
        def add(x, y):
            # Forward the function form to the method form. When x is a
            # _PlusIsMinusTensor, this lands on the overloaded ``add`` method,
            # which subtracts.
            return x.add(y)

        @classmethod
        def get(cls, attr):
            """
            Return the overload for a command name.

            Only the last dotted component is used, so 'torch.add' resolves
            to ``add``. The lookup is done on this nested class itself rather
            than via ``sy._PlusIsMinusTensor``. The class defined here is not
            attached to the ``sy`` module, so that path would raise or pick
            up a different class's overloads.

            Raises:
                AttributeError: if no overload with that name exists.
            """
            return getattr(cls, attr.split('.')[-1])

    # Put here all the methods you want to overload
    def add(self, arg):
        """
        Overload ``add`` so that it subtracts ``arg`` instead of adding it.
        """
        return self.sub(arg)

    def abs(self):
        """
        Overload the abs() method and delegate to the ``torch.abs`` function.
        """
        return torch.abs(self)

In [5]:
# Wrap each tensor with its own fresh _PlusIsMinusTensor (x first, then y).
x, y = _PlusIsMinusTensor().on(x), _PlusIsMinusTensor().on(y)

In [6]:
# Print x's chain again; _PlusIsMinusTensor now appears in it (output below).
utils.chain_print(x)

FloatTensor > _PlusIsMinusTensor > _LocalTensor


In [7]:
# Display x; its values (5, 6) are unchanged by the wrapping.
x

[Head of chain]

 5
 6
[syft.core.frameworks.torch.tensor.FloatTensor of size 2]