In [8]:
from syft.core.tensor.passthrough import PassthroughTensor
from typing import Union,Optional,Any,Tuple
import numpy as np
from syft.core.tensor.config import DEFAULT_INT_NUMPY_TYPE
from syft.core.tensor.config import DEFAULT_FLOAT_NUMPY_TYPE


In [20]:
class FixedPrecisionTensor(PassthroughTensor):
    """Tensor that stores real numbers as scaled integers.

    A value ``v`` is stored in ``self.child`` as ``int(v * base**precision)``
    using ``DEFAULT_INT_NUMPY_TYPE``; ``decode`` converts the stored integers
    back to ``DEFAULT_FLOAT_NUMPY_TYPE``.
    """

    def __init__(
        self,
        value: Optional[Union[int, float, np.ndarray]] = None,
        base: int = 2,
        precision: int = 16,
    ) -> None:
        """Create a tensor, encoding ``value`` when one is supplied.

        Args:
            value: Number or array to encode. When ``None`` the child is left
                as ``None`` (used internally to build result tensors whose
                child is assigned afterwards).
            base: Base of the scaling factor.
            precision: Exponent of the scaling factor.
        """
        self._base = base
        self._precision = precision
        self._scale = base**precision
        if value is not None:
            super().__init__(self.encode(value))
        else:
            super().__init__(None)

    @property
    def base(self) -> int:
        """Base of the scaling factor."""
        return self._base

    @property
    def precision(self) -> int:
        """Exponent of the scaling factor."""
        return self._precision

    @property
    def scale(self) -> int:
        """Scaling factor, ``base ** precision``."""
        return self._scale

    def encode(self, value: Union[int, float, np.ndarray]) -> np.ndarray:
        """Multiply ``value`` by the scale and cast to the integer dtype."""
        encoded_value = np.array(self._scale * value, DEFAULT_INT_NUMPY_TYPE)
        return encoded_value

    @property
    def dtype(self) -> np.dtype:
        """dtype of the child, or ``None`` if the child has no ``dtype``."""
        return getattr(self.child, "dtype", None)

    @property
    def shape(self) -> Optional[Tuple[int, ...]]:
        """Shape of the child, or ``None`` if the child has no ``shape``."""
        return getattr(self.child, "shape", None)

    def decode(self) -> Any:
        """Convert the stored scaled integers back to floating point values.

        The integer and fractional parts are computed separately and summed.
        """
        value = self.child

        # 1 where the encoded value is negative, 0 elsewhere.
        correction = (value < 0).astype(DEFAULT_INT_NUMPY_TYPE)

        # `%` yields a non-negative remainder, so for negative values the
        # integer part is shifted down by one to compensate.
        dividend = np.trunc(value / self._scale - correction)
        remainder = value % self._scale
        # A negative exact multiple of the scale has remainder 0; add one full
        # scale back so that dividend + remainder / scale is still correct.
        remainder += (
            (remainder == 0).astype(DEFAULT_INT_NUMPY_TYPE) * self._scale * correction
        )
        value = (
            dividend.astype(DEFAULT_FLOAT_NUMPY_TYPE)
            + remainder.astype(DEFAULT_FLOAT_NUMPY_TYPE) / self._scale
        )
        return value

    def sanity_check(
        self, other: Union["FixedPrecisionTensor", int, float, np.ndarray]
    ) -> "FixedPrecisionTensor":
        """Validate ``other`` and return it as a compatible FixedPrecisionTensor.

        Args:
            other: Operand of an arithmetic operation. Plain numbers, NumPy
                scalars and arrays are encoded with this tensor's base and
                precision.

        Returns:
            ``other`` itself when it already is a FixedPrecisionTensor,
            otherwise a new tensor wrapping the encoded value.

        Raises:
            ValueError: If ``other`` is a FixedPrecisionTensor with a different
                base or precision, or is of an unsupported type.
        """
        # isinstance needs the class object; a string here raises TypeError.
        if isinstance(other, FixedPrecisionTensor):
            if self.base != other.base or self.precision != other.precision:
                raise ValueError(
                    f"Base:{self.base,other.base} and Precision: "
                    + f"{self.precision, other.precision} should be same for "
                    + "computation on FixedPrecisionTensor"
                )
        elif isinstance(other, (int, float, np.integer, np.floating, np.ndarray)):
            other = FixedPrecisionTensor(
                value=other, base=self.base, precision=self.precision
            )
        else:
            raise ValueError(f"Invalid type for FixedPrecisionTensor: {type(other)}")

        return other

    def __add__(self, other: Any) -> "FixedPrecisionTensor":
        """Return ``self + other`` as a new tensor with the same base/precision."""
        other = self.sanity_check(other)
        res = FixedPrecisionTensor(base=self._base, precision=self._precision)
        res.child = self.child + other.child
        return res

    def __sub__(self, other: Any) -> "FixedPrecisionTensor":
        """Return ``self - other`` as a new tensor with the same base/precision."""
        other = self.sanity_check(other)
        res = FixedPrecisionTensor(base=self._base, precision=self._precision)
        res.child = self.child - other.child
        return res

    def __mul__(self, other: Any) -> "FixedPrecisionTensor":
        """Return ``self * other`` as a new tensor with the same base/precision."""
        other = self.sanity_check(other)
        res = FixedPrecisionTensor(base=self._base, precision=self._precision)
        res.child = self.child * other.child
        # Both operands carry one factor of the scale, so the raw product
        # carries two; divide one out to get back to the encoded domain.
        res = res / self.scale
        return res

    def __truediv__(
        self, other: Union[int, np.integer, "FixedPrecisionTensor"]
    ) -> "FixedPrecisionTensor":
        """Divide the encoded values by a public number.

        Args:
            other: Divisor applied directly to the stored integers.

        Raises:
            ValueError: If ``other`` is a FixedPrecisionTensor.
        """
        if isinstance(other, FixedPrecisionTensor):
            raise ValueError("We do not support Private Division yet.")

        res = FixedPrecisionTensor(base=self._base, precision=self._precision)
        if isinstance(self.child, np.ndarray) or np.isscalar(self.child):
            # Truncate toward zero and return to the integer dtype.
            res.child = np.trunc(self.child / other).astype(DEFAULT_INT_NUMPY_TYPE)
        else:
            # Any other child type handles the division itself.
            res.child = self.child / other
        return res

        

In [97]:
# Sample input: a single float wrapped in a one-element array.
float_val = 0.5654564
val = np.array([float_val])


In [98]:
# Encode with the defaults (base=2, precision=16), i.e. a scale of 2**16 = 65536.
a = FixedPrecisionTensor(val)

In [99]:
# Show the tensor; its child holds the encoded integer.
a

FixedPrecisionTensor(child=[37057])

In [100]:
# Decoding gives 37057 / 65536, which is close to but not exactly float_val:
# encoding to an integer discards anything finer than 1 / 65536.
a.decode()

array([0.56544495])