In [26]:
import torch as th
from functools import reduce

prod = lambda xs: reduce(lambda x, y: x * y, xs)

def _egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = _egcd(b % a, a)
        return (g, x - (b // a) * y, y)
    
def _inverse(a, m):
    _, b, _ = _egcd(a, m)
    return b % m

moduli = [1999703, 1990007, 1996949, 1925899, 1816117]
modulus = prod(moduli)
moduli_inverses = [_inverse(modulus // mi, mi) for mi in moduli]

class CrtTensor(object):

    def __init__(self, values, residues=None):
        if values is not None:
            residues = [ values % mi for mi in moduli ]
        self.residues = residues

    @staticmethod
    def sample_uniform(shape):
        return CrtTensor(None, [
            th.randint(0, mi, shape).type(th.LongTensor)
            for mi in moduli
        ])

    def recombine(self, bound=2**31):
        return self._explicit_crt(bound)

    def __add__(self, other):
        return CrtTensor(None, [
            (xi + yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])
    
    def __sub__(self, other):
        return CrtTensor(None, [
            (xi - yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])

    def __mul__(self, other):
        return CrtTensor(None, [
            (xi * yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])

    def matmul(self, other):
        return CrtTensor(None, [
            th.matmul(xi, yi) % mi
            for xi, yi, mi in zip(self.residues, other.residues, moduli)
        ])

    def __mod__(self, k):
        return CrtTensor(self._explicit_crt(k))

    def _explicit_crt(self, bound):
        def sum(xs):
            return th.cat(xs).view(len(moduli), *xs[0].shape).sum(0)

        t = [
            th.fmod((xi * qi), mi)
            for xi, qi, mi in zip(self.residues, moduli_inverses, moduli)
        ]
        alpha = sum(tuple(
            ti.type(th.DoubleTensor) / float(mi)
            for ti, mi in zip(t, moduli)
        ))

        b = [(modulus // mi) % bound for mi in moduli]
        u = sum(tuple(
            ti * bi
            for ti, bi in zip(t, b)
        ))

        B = modulus % bound
        v = th.round(alpha).type(th.LongTensor) * B
        w = u.type(th.LongTensor) - v

        return th.fmod(w,bound)


x = CrtTensor(th.LongTensor([100000, 200000, 300000, 400000]).view(2,2))
y = CrtTensor(th.LongTensor([100000, 200000, 300000, 400000]).view(2,2))
z = x.matmul(y)
print(z.recombine(2**40))


 7.0000e+10  1.0000e+11
 1.5000e+11  2.2000e+11
[torch.LongTensor of size 2x2]



In [2]:
bound = 2**40

t = [
            (xi * qi) % mi
            for xi, qi, mi in zip(z.residues, moduli_inverses, moduli)
        ]

alpha = sum(tuple(
            ti.type(th.DoubleTensor) / float(mi)
            for ti, mi in zip(t, moduli)
        ))

b = [(modulus // mi) % bound for mi in moduli]

u = sum(tuple(
            ti * bi
            for ti, bi in zip(t, b)
        ))

B = modulus % bound

v = th.round(alpha).type(th.LongTensor) * B

w = u.type(th.LongTensor) - v

for each in w.view(-1):
    print(each)
print()
for each in (w % 2**40).view(-1):
    print(each)

2480347669169650688
3836811895822018560
3788346572781280256
3586661025875073024

70000000000
100000000000
1249511627776
220000000000


In [3]:
import torch as th

In [4]:
x = th.LongTensor([5])

In [5]:
y = th.LongTensor([3])

In [10]:
x = th.LongTensor([3788346572781280256])
y = th.LongTensor([2**40])

In [16]:
th.fmod(x,y)


 1.5000e+11
[torch.LongTensor of size 1]

In [12]:
2**40

1099511627776

In [46]:
w


 2.4803e+18  3.8368e+18
 3.7883e+18  3.5867e+18
[torch.LongTensor of size 2x2]

In [22]:
alpha


 2.0000  3.0000
 3.0000  3.0000
[torch.DoubleTensor of size 2x2]

In [20]:
def tolist(t):

    vs = list()
    for v in th.cat(t).view(-1):
        vs.append(v)
    return vs

In [21]:
tolist(alpha)

[2.0, 2.9999999999999996, 3.0, 3.0]