In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from syft.core.tensor.tensor import Tensor
from syft.core.adp.entity import Entity
import numpy as np
import torch as th
import copy

# Two fixed 3x3 float matrices; `b` is `a` flipped along every axis
# (np.flip with no axis), copied so it does not share memory with `a`.
a = np.array([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
b = np.array(np.flip(a))

# The same data wrapped for each backend under comparison.
sy_a, sy_b = (Tensor(arr).autograd(requires_grad=True) for arr in (a, b))
th_a, th_b = th.tensor(a), th.tensor(b)

# Index-aligned: types[i] is the label used in error messages for pairs[i].
types = ["syft", "np", "torch"]
pairs = [(sy_a, sy_b), (a, b), (th_a, th_b)]

In [4]:
def torch_divmod(t, d):
    """Python-style ``divmod`` for torch tensors.

    Returns ``(q, r)`` with ``q = floor(t / d)`` and ``r = t - q * d``, so that
    ``q * d + r == t`` holds element-wise, matching ``numpy.divmod``.

    Args:
        t: dividend tensor.
        d: divisor (tensor or Python number).

    Returns:
        Tuple ``(quotient, remainder)`` of tensors.
    """
    # Explicit floor rounding: in torch releases before 1.13, floor_divide
    # rounded toward zero, which disagrees with `remainder` (sign follows the
    # divisor) for negative operands and breaks q * d + r == t.
    q = th.div(t, d, rounding_mode="floor")
    r = t.remainder(d)
    return (q, r)

In [5]:
def torch_rlshift(t, d):
    """Reflected left shift for torch: computes ``d << t`` (``t.__rlshift__(d)``).

    Args:
        t: the right-hand tensor (shift amounts).
        d: the left-hand operand; a tensor, array-like, or Python int.

    Returns:
        Tensor holding ``d << t``.
    """
    # Convert d first: if d is a Python int, int.__lshift__(tensor) returns
    # NotImplemented instead of shifting, so `5 << t` would fail.
    return th.as_tensor(d, device=t.device).__lshift__(t)

In [6]:
def torch_rrshift(t, d):
    """Reflected right shift for torch: computes ``d >> t`` (``t.__rrshift__(d)``).

    Args:
        t: the right-hand tensor (shift amounts).
        d: the left-hand operand; a tensor, array-like, or Python int.

    Returns:
        Tensor holding ``d >> t``.
    """
    # Convert d first: if d is a Python int, int.__rshift__(tensor) returns
    # NotImplemented instead of shifting, so `16 >> t` would fail.
    return th.as_tensor(d, device=t.device).__rshift__(t)

In [7]:
def torch_rmatmul(t, d):
    """Reflected matrix multiply for torch: computes ``d @ t`` (``t.__rmatmul__(d)``).

    Args:
        t: the right-hand tensor.
        d: the left-hand operand; a tensor or array-like.

    Returns:
        Tensor holding ``d @ t``.
    """
    # Call matmul directly. Delegating to d.__rmatmul__(t) would compute t @ d
    # (wrong order). Once this function is installed as th.Tensor.__rmatmul__,
    # it would also call itself with swapped arguments forever (RecursionError).
    return th.matmul(th.as_tensor(d, device=t.device), t)

In [8]:
# Install the torch shims defined in the cells above onto th.Tensor, so the
# torch column of the comparison can call the same dunders as numpy.
# NOTE(review): some torch versions already define a few of these
# (e.g. __rmatmul__); these assignments replace them globally for the kernel.
th.Tensor.__rmatmul__ = torch_rmatmul
th.Tensor.__rrshift__ = torch_rrshift
th.Tensor.__rlshift__ = torch_rlshift
th.Tensor.__divmod__ = torch_divmod

In [9]:
# handle converting collections of th.Tensor to numpy
def th_as_np(t):
    """Convert a torch result (or a container of results) to numpy for comparison.

    - Python scalars (int/bool/float) and numpy values (ndarray / numpy scalar)
      are returned unchanged.
    - A ``th.Tensor`` is detached and moved to CPU before ``.numpy()``. This
      keeps tensors with ``requires_grad=True`` or on a GPU working.
    - Tuples, lists and sets are converted element-wise and recursively. This
      handles mixed or nested results such as ``(tensor, 3)``. The same
      container type (plain tuple/list/set) is returned.

    Raises:
        Exception: if ``t`` (or an element inside a container) has an
            unsupported type.
    """
    if isinstance(t, (int, bool, float, np.ndarray, np.generic)):
        return t
    if isinstance(t, th.Tensor):
        return t.detach().cpu().numpy()
    if isinstance(t, tuple):
        return tuple(th_as_np(x) for x in t)
    if isinstance(t, list):
        return [th_as_np(x) for x in t]
    if isinstance(t, set):
        # NOTE: ndarrays are unhashable, so a set containing tensors still
        # cannot be represented; sets of plain scalars convert fine.
        return {th_as_np(x) for x in t}

    raise Exception(f"unknown type {type(t)} {t}")

In [10]:
# Per-op argument overrides: when an op appears here, this value replaces the
# pair's second operand as the argument passed to the op.
#
# TODO:
#   - move pairs a, b into separate inputs / args for each test
#   - flatten args are slightly different in pytorch
#   - repeat needs two inputs for torch
#   - what to do about resize and references?
#   - fix torch sort: compare against .values
#   - use np.expand_dims for squeeze
#   - transpose needs both args in torch
test_params = dict(
    __getitem__=0,
    argmax=0,
    argmin=0,
    argsort=-1,
    clip=2,
    cumprod=0,
    cumsum=0,
    diagonal=0,
    min=None,
    max=None,
    repeat=1,
    reshape=-1,
    resize=1,
    take=0,
)

In [11]:
# Ops compared with np.allclose instead of exact equality, because their
# results are not always bit-identical across backends (presumably
# floating-point rounding in division).
methods_close = [
    "__truediv__",
    "__rtruediv__",
]

In [12]:
def test_op(op, pairs, zero_args, alt_name, properties, test_params, methods_close):
    """Run ``op`` on every (lhs, rhs) pair and check syft agrees with numpy and torch.

    ``pairs`` is indexed like the global ``types`` list. Entry 0 is the syft
    pair, entry 1 the numpy pair, and entry 2 the torch pair.

    Args:
        op: attribute/method name to exercise (e.g. ``"__add__"``).
        pairs: sequence of ``(lhs, rhs)`` operand pairs, one per backend.
        zero_args: ops called with no argument.
        alt_name: fallback attribute names, used when ``op`` is missing on an
            operand (e.g. ``copy`` -> ``clone`` for torch).
        properties: ops read as attributes rather than called.
        test_params: per-op argument that replaces ``rhs``.
        methods_close: ops compared with ``np.allclose`` instead of exact equality.

    Raises:
        Exception: if ``op`` (and any alternative name) is missing on a backend.
        AssertionError: if the syft result differs from numpy or torch.
    """
    results = []
    for i, (lhs, rhs) in enumerate(pairs):
        # Shallow copies so in-place ops cannot leak into the shared fixtures.
        lhs = copy.copy(lhs)
        rhs = copy.copy(rhs)
        if op in test_params:
            rhs = test_params[op]

        real_op = getattr(lhs, op, None)
        if real_op is None and op in alt_name:
            real_op = getattr(lhs, alt_name[op], None)
        if real_op is None:
            raise Exception(f"Op doesnt exist on {types[i]}", op)

        if op in properties:
            # getattr already evaluated the property; there is nothing to call.
            res = real_op
        elif op in zero_args:
            res = real_op()
        else:
            res = real_op(rhs)
        results.append(res)

    # Unwrap the syft tensor down to its bottom data level.
    data_tensor = results[0]
    while hasattr(data_tensor, "child"):
        data_tensor = data_tensor.child

    # Some results (e.g. division) are not bit-identical across backends.
    comp_method = np.allclose if op in methods_close else np.array_equal
    assert comp_method(data_tensor, results[1]), f"{op}: syft result differs from numpy"
    assert comp_method(data_tensor, th_as_np(results[2])), f"{op}: syft result differs from torch"

In [13]:
# Every op we want the syft Tensor to support, each checked against numpy and
# torch: operator/protocol dunders first, then named ndarray methods and
# attributes.
desired_ops = [
    f"__{name}__"
    for name in (
        "abs add divmod eq floordiv ge getitem gt index invert iter le len "
        "lshift lt matmul mul ne neg pow radd repr rfloordiv rlshift rmatmul "
        "rmul rpow rrshift rshift rsub rtruediv sizeof str sub truediv"
    ).split()
] + (
    "argmax argmin argsort choose clip copy cumprod cumsum diagonal dot flat "
    "flatten item itemset itemsize max mean min ndim prod repeat reshape resize "
    "sort squeeze std sum swapaxes T take transpose"
).split()

In [14]:
# Ops that are attributes, not methods: their value is read, never called.
properties = ["T", "ndim"]

# Ops invoked without any argument (includes the properties above).
zero_args = (
    "copy __neg__ __abs__ T __index__ __invert__ __len__ "
    "flatten mean min max ndim prod "
    "sort std sum"
).split()

# Fallback method names for backends that spell an op differently (torch: clone).
alt_name = dict(copy="clone")

In [17]:
# test_op("__getitem__", pairs, zero_args, alt_name, properties, test_params, methods_close)




In [18]:
# Try every desired op. Best-effort by design: a failure is reported and
# skipped so the survey covers all ops; only ops that pass land in `working`.
working = []
for op in desired_ops:
    print(op)
    try:
        test_op(op, pairs, zero_args, alt_name, properties, test_params, methods_close)
    except Exception as err:
        print(f"{op} Failed.", err)
    else:
        working.append(op)

__abs__

__add__

__divmod__

__eq__

__floordiv__

__ge__

__getitem__

__gt__

__index__
__index__ Failed. only integer scalar arrays can be converted to a scalar index
__invert__
__invert__ Failed. ufunc 'invert' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
__iter__
__iter__ Failed. ('Op doesnt exist on syft', '__iter__')
__le__

__len__

__lshift__
__lshift__ Failed. ufunc 'left_shift' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
__lt__

__matmul__

__mul__

__ne__

__neg__

__pow__

__radd__

__repr__
__repr__ Failed. __repr__() takes 1 positional argument but 2 were given
__rfloordiv__

__rlshift__
__rlshift__ Failed. ufunc 'left_shift' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
__rmatmul__
__rmatm

In [19]:
# Ops for which test_op raised nothing, i.e. syft matched both numpy and torch.
working

['__abs__',
 '__add__',
 '__divmod__',
 '__eq__',
 '__floordiv__',
 '__ge__',
 '__getitem__',
 '__gt__',
 '__le__',
 '__len__',
 '__lt__',
 '__matmul__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__pow__',
 '__radd__',
 '__rfloordiv__',
 '__rmul__',
 '__rsub__',
 '__sub__',
 'argmax',
 'argmin',
 'argsort',
 'clip',
 'copy',
 'cumprod',
 'cumsum',
 'diagonal',
 'flatten',
 'max',
 'mean',
 'min',
 'ndim',
 'prod',
 'reshape',
 'T']

In [20]:
# Desired ops that have not (yet) passed the comparison.
todo = set(desired_ops).difference(working)

In [19]:
# Desired ops whose test_op run raised (missing op, error, or result mismatch).
todo

{'__index__',
 '__invert__',
 '__iter__',
 '__lshift__',
 '__repr__',
 '__rlshift__',
 '__rmatmul__',
 '__rpow__',
 '__rrshift__',
 '__rshift__',
 '__sizeof__',
 '__str__',
 'choose',
 'dot',
 'flat',
 'item',
 'itemset',
 'itemsize',
 'repeat',
 'resize',
 'sort',
 'squeeze',
 'std',
 'sum',
 'swapaxes',
 'take',
 'transpose'}