In [14]:
import torch as th
from functools import reduce

prod = lambda xs: reduce(lambda x, y: x * y, xs)

def _egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = _egcd(b % a, a)
        return (g, x - (b // a) * y, y)
    
def _inverse(a, m):
    _, b, _ = _egcd(a, m)
    return b % m

moduli = [1999703, 1990007, 1996949, 1925899, 1816117]
modulus = prod(moduli)
moduli_inverses = [_inverse(modulus // mi, mi) for mi in moduli]

class CrtTensor(object):

    def __init__(self, values, residues=None):
        if values is not None:
            residues = [ values % mi for mi in moduli ]
        self.residues = residues

    @staticmethod
    def sample_uniform(shape):
        return CrtTensor(None, [
            th.randint(0, mi, shape).type(th.LongTensor)
            for mi in moduli
        ])

    def recombine(self, bound=2**31):
        return self._explicit_crt(bound)

    def __add__(self, other):
        return CrtTensor(None, [
            (xi + yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])
    
    def __sub__(self, other):
        return CrtTensor(None, [
            (xi - yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])

    def __mul__(self, other):
        return CrtTensor(None, [
            (xi * yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])

    def matmul(self, other):
        return CrtTensor(None, [
            th.matmul(xi, yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])

    def __mod__(self, k):
        return CrtTensor(self._explicit_crt(k))

    def _explicit_crt(self, bound):
        def sum(xs):
            return th.cat(xs).view(len(moduli), *xs[0].shape).sum(0)

        t = [
            (xi * qi) % mi
            for xi, qi, mi in zip(self.residues, moduli_inverses, moduli)
        ]
        alpha = sum(tuple(
            ti.type(th.DoubleTensor) / float(mi)
            for ti, mi in zip(t, moduli)
        ))

        b = [(modulus // mi) % bound for mi in moduli]
        u = sum(tuple(
            ti * bi
            for ti, bi in zip(t, b)
        ))

        B = modulus % bound
        v = th.round(alpha).type(th.LongTensor) * B
        w = u.type(th.LongTensor) - v

        return w % bound


x = CrtTensor(th.LongTensor([100000, 200000, 300000, 400000]).view(2,2))
y = CrtTensor(th.LongTensor([100000, 200000, 300000, 400000]).view(2,2))
z = x.matmul(y)
print(z.recombine(2**62))

tensor([[ 70000000000, 100000000000],
        [150000000000, 220000000000]])


In [51]:
bound = 2**40

t = [
            (xi * qi) % mi
            for xi, qi, mi in zip(z.residues, moduli_inverses, moduli)
        ]

alpha = sum(tuple(
            ti.type(th.DoubleTensor) / float(mi)
            for ti, mi in zip(t, moduli)
        ))

b = [(modulus // mi) % bound for mi in moduli]

u = sum(tuple(
            ti * bi
            for ti, bi in zip(t, b)
        ))

B = modulus % bound

v = th.round(alpha).type(th.LongTensor) * B

w = u.type(th.LongTensor) - v

for each in w.view(-1):
    print(each)
print()
for each in (w % 2**40).view(-1):
    print(each)

tensor(2480347669169650688)
tensor(3836811895822018560)
tensor(3788346572781280256)
tensor(3586661025875073024)

tensor(70000000000)
tensor(100000000000)
tensor(150000000000)
tensor(220000000000)


In [49]:
w

tensor([[2480347669169650688, 3836811895822018560],
        [3788346572781280256, 3586661025875073024]])

In [38]:
alpha

tensor([[2.0000, 3.0000],
        [3.0000, 3.0000]], dtype=torch.float64)

In [36]:
tolist(alpha)

TypeError: cat(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

In [31]:
th.FloatTensor(tolist(alpha)) == \
th.FloatTensor([680188,
 1257369,
 886202,
 1566390,
 1602043,
 1151486,
 1727229,
 1339265,
 733872,
 1333667,
 1002026,
 1735898,
 355455,
 1883435,
 1862203,
 291759,
 549828,
 266578,
 399867,
 949695])

NameError: name 'alpha' is not defined

In [28]:
x2 = th.FloatTensor([680188,
 1257369,
 886202,
 1566390,
 1602043,
 1151486,
 1727229,
 1339265,
 733872,
 1333667,
 1002026,
 1735898,
 355455,
 1883435,
 1862203,
 291759,
 549828,
 266578,
 399867,
 949695])

In [25]:
def tolist(t):

    vs = list()
    for v in th.cat(t).view(-1):
        vs.append(int(v))
    return vs

In [3]:
a = th.LongTensor([100000, 200000, 300000, 400000]).view(2, 2)
b = th.LongTensor([100000, 200000, 300000, 400000]).view(2, 2)

In [4]:
a.mm(b)

tensor([[ 70000000000, 100000000000],
        [150000000000, 220000000000]])

In [None]:
print(z.recombine(2 ** 40))

x = CrtTensor.sample_uniform((2, 2))
print(x.residues)