This notebook explores the use of the Chinese remainder theorem (CRT) for operations on multi-precision integers.

Note that throughout we don't perform the modular reduction by default, since it is often preferable to do so lazily. In fact, below we pick our moduli such that we can compute a dot product between two vectors of dimension ~1000 holding ~120-bit numbers while only performing a reduction at the end.

TODOs (besides those mentioned in the code):
- anything to gain from Barrett reductions?

In [1]:
import random
from functools import reduce
import numpy as np
from datetime import datetime
from math import log, ceil

def log2(x):
    """ Base-2 logarithm; works for arbitrarily large Python ints. """
    return log(x)/log(2)

def prod(xs):
    """ Product of the items of `xs`; the empty product is 1 (e.g. for the 0-d shape `()`). """
    return reduce(lambda x, y: x * y, xs, 1)

# Number theory

All we really need here is the ability to find (ring) inverses.

In [2]:
def egcd(a, b):
    """ Extended Euclidean algorithm.

    Returns a triple (g, x, y) with a*x + b*y == g, where g == gcd(a, b) for
    non-negative a and b. Written iteratively so that very large inputs cannot
    hit Python's recursion limit. It produces the same coefficients as the
    classic recursive formulation.
    """
    # (s, t) and (u, v) are the two coefficient rows updated by every division
    # step; the result is read from (u, v) once the remainder reaches zero
    s, t = 1, 0
    u, v = 0, 1
    while a != 0:
        q = b // a
        a, b = b % a, a
        s, t, u, v = u - q * s, v - q * t, s, t
    return (b, u, v)

def gcd(a, b):
    """ Greatest common divisor of non-negative a and b (computed via egcd). """
    g, _, _ = egcd(a, b)
    return g

def inverse(a, m):
    """ Return the multiplicative inverse of `a` modulo `m` (for m > 0), in [0, m).

    Raises ValueError if no inverse exists, i.e. if gcd(a, m) != 1.
    """
    # Reduce first so egcd sees non-negative inputs and returns g == gcd(a, m).
    # With a negative `a`, egcd can end on g == -1, and the coefficient would
    # then be the inverse of -a rather than of a.
    g, b, _ = egcd(a % m, m)
    if g != 1:
        raise ValueError('{} is not invertible modulo {}'.format(a, m))
    return b % m

# Ring

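Define the ring in which we're working. Products of two residues are accumulated lazily, e.g. summed over the 1024 terms of a dot product. For these sums to fit into 64-bit signed words we need 2·log2(mi) + log2(1024) < 63, so each modulus can be at most ~26 bits.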

In [3]:
# Five ~26-bit moduli (pairwise coprimality is asserted when the CRT helpers are built).
ms = [
    89702869,
    78489023,
    69973811,
    70736797,
    79637461,
]
# Headroom check: a product of two residues (2 * log2(mi) bits) summed over
# 1024 terms (log2(1024) = 10 more bits) must stay below 2**63 to fit an int64.
for mi in ms:
    assert log2(mi) * 2 + log2(1024) < 63, mi

# Size of the ring; it must offer at least 120 bits.
M = prod(ms)
assert log2(M) >= 120

We also fix a truncation amount in anticipation of fixed-point arithmetic.

In [17]:
# truncation divides by K, i.e. drops 16 low-order bits
K = 1 << 16

# Scalars

We introduce these mostly because the ideas are perhaps easier to follow here. We are not sure about either their utility or any speed-ups.

## Natural number representation

Natural (built-in) multi-precision integers mod `M`, used for performance comparisons.

In [5]:
class NaturalScalar:
    """ Scalar backed by one built-in Python int: the baseline representation """

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return 'NaturalScalar({})'.format(self.unwrap())

    def unwrap(self):
        return self.value

    # Arithmetic is lazy: results are NOT reduced modulo M until `reduce`.

    def __add__(x, y):
        total = x.value + y.value
        return NaturalScalar(total)

    def __sub__(x, y):
        difference = x.value - y.value
        return NaturalScalar(difference)

    def __mul__(x, y):
        product = x.value * y.value
        return NaturalScalar(product)

    def reduce(x):
        """ Canonical representative modulo the ring size M """
        return NaturalScalar(x.value % M)

    def mod(x):
        """ Residue modulo K of the raw (possibly unreduced) value """
        return NaturalScalar(x.value % K)

    @staticmethod
    def sample():
        """ Uniformly random element of [0, M) """
        return NaturalScalar(random.randrange(M))


# sanity-check NaturalScalar against plain Python integer arithmetic
a, b = 2**120, 2**110
x = NaturalScalar(a)
print(x)
y = NaturalScalar(b)
print(y)
z = (x + y).reduce()
assert z.unwrap() == (a + b) % M, z
z = (x - y).reduce()
assert z.unwrap() == (a - b) % M, z
z = (x * y).reduce()
assert z.unwrap() == (a * b) % M, z
z = x.mod()
assert z.unwrap() == a % K, z

NaturalScalar(1329227995784915872903807060280344576)
NaturalScalar(1298074214633706907132624082305024)


## CRT number representation

Alternative representation where numbers are split into several parts (their residues modulo each modulus in `ms`) that each fit into a 64-bit signed word.

In [6]:
def gen_crt():
    """ Build the conversions between integers and their CRT representation.

    Uses the module-level moduli `ms` and their product `M`. Returns a pair
    (decompose, recombine):

    - decompose(x) gives the residues [x % mi for mi in ms]; `%` also applies
      element-wise when x is a NumPy array.
    - recombine(xs) gives the integer in [0, M) congruent to xs[i] modulo
      ms[i] for every i; the parts do not need to be reduced beforehand.
    """

    # make sure the values in ms are pairwise coprime, otherwise the CRT
    # representation is not unique
    for i, mi in enumerate(ms):
        for mj in ms[i+1:]:
            assert gcd(mi, mj) == 1, '{} and {} are not coprime'.format(mi, mj)

    def decompose(x):
        return [ x % mi for mi in ms ]

    # precomputation for recombine: with Mi = M // mi, li = Mi * (Mi^-1 mod mi)
    # satisfies li % mi == 1 and li % mj == 0 for every other modulus mj
    Mis = ( M // mi for mi in ms )
    ls = [ Mi * inverse(Mi, mi) % M for Mi, mi in zip(Mis, ms) ]

    def recombine(xs):
        return sum( xi * li for xi, li in zip(xs, ls) ) % M

    return decompose, recombine

# module-level CRT conversion functions used by everything below
decompose, recombine = gen_crt()

# a round trip of a value below M must be the identity
assert 123456789 == recombine(decompose(123456789))

In [7]:
def gen_mod():
    """ Build `mod`, which reduces a CRT-represented number modulo K.

    mod(xs) takes residues xs (one per modulus in `ms`, not necessarily
    reduced) and returns decompose(r) for some r in [0, K), without ever
    reconstructing the full integer. With Mi = M // mi and qi = Mi^-1 mod mi,
    the number equals sum(ti * Mi) - alpha * M where ti = xi * qi mod mi;
    alpha is estimated in floating point, and every Mi and M is replaced by
    its residue mod K.

    Because alpha is estimated with `round` (not `floor`), r is the residue
    mod K of the *centred* representative. Let x' be the value mod M in
    [0, M). Then r == x' % K when x' < M/2, and r == (x' - M) % K when
    x' > M/2, up to floating-point error near that boundary. This only agrees
    with a plain `% K` on the lower half of the ring.
    NOTE(review): presumably intended for signed fixed-point truncation — confirm.
    """
    
    # precomputation for mod
    # qi = (M/mi)^-1 mod mi
    qs = [ inverse(M//mi, mi) for mi in ms ]
    # M and every Mi = M/mi reduced mod K
    B = M % K
    bs = [ (M//mi) % K for mi in ms ]

    def mod(xs):
        # ti in [0, mi), chosen so that sum(ti * Mi) is congruent to the value mod M
        ts = [ (xi * qi) % mi for xi, qi, mi in zip(xs, qs, ms) ]
        # sum(ti / mi) == sum(ti * Mi) / M == alpha + x' / M; round() yields alpha,
        # or alpha + 1 when x' / M > 1/2 (see docstring)
        alpha = round(sum( float(ti) / float(mi) for ti, mi in zip(ts, ms) ))
        # sum(ti * Mi) - alpha * M, evaluated with Mi and M replaced by their residues mod K
        v = int( sum( ti * bi for ti, bi in zip(ts, bs) ) - B * alpha )
        
        assert abs(v) < K * sum(ms) # TODO express in bit length
        
        return decompose(v % K) # TODO inline decompose?
    
    return mod

# module-level reduction modulo K for CRT-represented values
mod = gen_mod()

# for a small value, reducing in CRT form must agree with plain `% K`
assert decompose(123456789 % K) == mod(decompose(123456789))

In [8]:
class CrtScalar:
    """ Scalar stored as its residues (`parts`), one per modulus in `ms` """

    def __init__(self, value, parts=None):
        # an explicit value takes precedence over `parts`: split it into residues
        self.parts = parts if value is None else decompose(value)

    def __repr__(self):
        return 'CrtScalar({})'.format(self.parts)

    def unwrap(self):
        # back to a single integer via the Chinese remainder theorem
        return recombine(self.parts)

    # The ring operations act on each residue independently, so the lanes could
    # run in parallel. Results are left unreduced until `reduce` is called.

    def __add__(x, y):
        return CrtScalar(None, list(map(lambda xi, yi: xi + yi, x.parts, y.parts)))

    def __sub__(x, y):
        return CrtScalar(None, list(map(lambda xi, yi: xi - yi, x.parts, y.parts)))

    def __mul__(x, y):
        return CrtScalar(None, list(map(lambda xi, yi: xi * yi, x.parts, y.parts)))

    def reduce(x):
        """ Bring every residue back into [0, mi) """
        return CrtScalar(None, list(map(lambda xi, mi: xi % mi, x.parts, ms)))

    def mod(x):
        """ Reduction modulo K via the module-level `mod`, computed on the parts """
        return CrtScalar(None, mod(x.parts))

    @staticmethod
    def sample():
        """ Random element given by an independent random residue per modulus """
        return CrtScalar(None, list(map(random.randrange, ms)))


# sanity-check CrtScalar against plain Python integer arithmetic
a, b = 2**120, 2**110
x = CrtScalar(a)
print(x)
y = CrtScalar(b)
print(y)
z = (x + y).reduce()
assert z.unwrap() == (a + b) % M, z
z = (x - y).reduce()
assert z.unwrap() == (a - b) % M, z
z = (x * y).reduce()
assert z.unwrap() == (a * b) % M, z
z = x.mod()
assert z.unwrap() == a % K, z

CrtScalar([32730343, 65319507, 11926796, 58726713, 67363725])
CrtScalar([49526222, 70198023, 21605128, 12422474, 62826948])


## Performance tests

Simple benchmarks between `NaturalScalar` and `CrtScalar`. Note that the time for `CrtScalar` is measured on a single core, so it would ideally be reduced by a factor of 5 (one core per modulus) on a multi-core device. Scalars are not the main focus though, and it is unclear how relevant this is.

In [9]:
# Benchmark: 100000 multiplications of a 120-bit value for each scalar
# representation, without modular reduction. datetime.now() is a wall-clock
# timer, so expect some run-to-run noise.
for scalar_type in [NaturalScalar, CrtScalar]:
    
    x = scalar_type(2**120)
    
    start = datetime.now()
    for _ in range(100000):
        y = x
        z = y * y
        #z.reduce() # roughly doubles the execution time for both (bit less for Typical, bit more for Crt)
    end = datetime.now()
    print('{:15}: {}'.format(scalar_type.__name__, end - start))

NaturalScalar  : 0:00:00.086124
CrtScalar      : 0:00:00.173483


# Private scalar

Again, mostly because the scalar setting is simpler, we here do secret sharing on top of the two different scalar representations. Note that only multiplication by a public constant is given here (to avoid bringing in multiplication triples).

In [10]:
def gen_private_scalar(scalar_type):
    """ Create a class of scalars additively secret-shared in two shares of `scalar_type`.

    A secret v is stored as (share0, share1) with share0 = scalar_type.sample()
    and share1 = v - share0, so that share0 + share1 reconstructs v modulo M.
    Supported operations: +, -, multiplication by a public scalar, reduce,
    and truncation by K (whose result may be off by one).
    """
    
    # precomputation for truncation
    K_inv = scalar_type(inverse(K, M))
    M_wrapped = scalar_type(M)
    def raw_truncate(x):
        # x - (x mod K) is congruent mod M to a multiple of K, so multiplying by
        # K^-1 mod M divides it by K within the ring
        y = x - x.mod()
        return y * K_inv
    
    class AbstractPrivateScalar:

        def __init__(self, value, share0=None, share1=None):
            if value is not None:
                value = scalar_type(value)
                share0 = scalar_type.sample()
                share1 = value - share0
            self.share0 = share0
            self.share1 = share1

        def __repr__(self):
            return 'PrivateScalar({})'.format(self.unwrap())
        
        def unwrap(self):
            return self.reconstruct().unwrap()
        
        def reconstruct(self):
            return (self.share0 + self.share1).reduce()
        
        def __add__(x, y):
            # component-wise operation that can be done in parallel
            return AbstractPrivateScalar(None,
                share0 = x.share0 + y.share0,
                share1 = x.share1 + y.share1
            )
        
        def __sub__(x, y):
            # component-wise operation that can be done in parallel
            return AbstractPrivateScalar(None,
                share0 = x.share0 - y.share0,
                share1 = x.share1 - y.share1
            )
        
        def __mul__(x, k):
            # component-wise operation that can be done in parallel;
            # k is a public scalar_type value
            return AbstractPrivateScalar(None,
                share0 = x.share0 * k,
                share1 = x.share1 * k
            )
        
        def reduce(x):
            return AbstractPrivateScalar(None,
                share0 = x.share0.reduce(),
                share1 = x.share1.reduce()
            )
        
        def truncate(x):
            # Each party truncates its own share locally. Party 1 truncates the
            # negation M - share1 and negates the result back.
            # Both shares are reduced first. For NaturalScalar, `mod` works on
            # the raw value, so a share >= M (e.g. share0 of a sum) would
            # otherwise be divided as is. That yields floor(share0 / K) instead
            # of floor((share0 mod M) / K), and the two differ modulo M.
            return AbstractPrivateScalar(None,
                share0 = raw_truncate(x.share0.reduce()),
                share1 = M_wrapped - raw_truncate((M_wrapped - x.share1).reduce())
            )
            

    return AbstractPrivateScalar

# sanity-check secret-shared scalars on top of both representations
for scalar_type in [NaturalScalar, CrtScalar]:

    PublicScalar = scalar_type
    PrivateScalar = gen_private_scalar(scalar_type)

    a, b = 2**120, 2**110
    x = PrivateScalar(a)
    y = PrivateScalar(b)
    k = PublicScalar(b)

    z = (x + y).reduce()
    assert z.unwrap() == (a + b) % M, z
    z = (x - y).reduce()
    assert z.unwrap() == (a - b) % M, z
    z = (x * k).reduce()
    assert z.unwrap() == (a * b) % M, z
    # truncation may be off by one, so accept either neighbour
    z = x.truncate()
    assert z.unwrap() in [a // K, a // K + 1], z

## Performance tests

A few simple tests. Again, the cost for `CrtScalar` would ideally be cut by a factor of five on a multi-core device.

In [11]:
# Benchmark: 10000 multiplications of a secret-shared 120-bit value by a
# public 110-bit constant, per scalar representation, without reduction.
for scalar_type in [NaturalScalar, CrtScalar]:
    
    PublicScalar  = scalar_type
    PrivateScalar = gen_private_scalar(scalar_type)
    
    x = PrivateScalar(2**120)
    k = PublicScalar(2**110)
    
    start = datetime.now()
    for _ in range(10000):
        z = x * k
        #z.reduce() # roughly doubles the execution time for both (bit less for Typical, bit more for Crt)
    end = datetime.now()
    print('{:15}: {}'.format(scalar_type.__name__, end - start))

NaturalScalar  : 0:00:00.029406
CrtScalar      : 0:00:00.041221


# Tensors

The ideas above, carried over to the tensor setting using NumPy.

## Natural tensors

Tensors backed by a NumPy array of numbers of type `object`.

In [12]:
class NaturalTensor:
    """ Uses the typical built-in representation of numbers: `values` is a
    NumPy array (normally of dtype object) holding Python ints """
    
    def __init__(self, values):
        self.values = values
        
    def __repr__(self):
        return 'NaturalTensor({})'.format(self.unwrap())
    
    @property
    def shape(self):
        return self.values.shape
    
    def unwrap(self):
        return self.values
    
    # Arithmetic is lazy: results are not reduced modulo M until `reduce`.
    
    def __add__(x, y):
        return NaturalTensor(x.values + y.values)
    
    def __sub__(x, y):
        return NaturalTensor(x.values - y.values)
    
    def __mul__(x, y):
        return NaturalTensor(x.values * y.values)
    
    def dot(x, y):
        return NaturalTensor(x.values.dot(y.values))
    
    def reduce(x):
        return NaturalTensor(x.values % M)
    
    def mod(x):
        return NaturalTensor(x.values % K)
    
    @staticmethod
    def sample(shape):
        """ Tensor of the given shape with independent entries uniform in [0, M) """
        count = int(np.prod(shape))  # np.prod(()) == 1, so 0-d shapes work too
        entries = [ random.randrange(M) for _ in range(count) ]
        # Pin dtype=object. If left to inference, NumPy builds an int64 array
        # (whose arithmetic then overflows silently) whenever every sampled
        # entry happens to fit in 64 bits, e.g. for a small M or a tiny shape.
        return NaturalTensor(np.array(entries, dtype=object).reshape(shape))
    
# sanity-check NaturalTensor against NumPy object-array arithmetic on Python ints
a = np.array([2**120] * 1024)
b = np.array([2**110] * 1024)
x = NaturalTensor(a)
y = NaturalTensor(b)
z = (x + y).reduce()
assert (z.unwrap() == (a + b) % M).all(), z
z = (x - y).reduce()
assert (z.unwrap() == (a - b) % M).all(), z
z = (x * y).reduce()
assert (z.unwrap() == (a * b) % M).all(), z
z = x.dot(y).reduce()
assert z.unwrap() == a.dot(b) % M, z
z = y.mod()
assert (z.unwrap() == b % K).all(), z

## CRT tensors

Tensors backed by one NumPy array of type `int64` per modulus.

In [13]:
class CrtTensor:
    """ Uses the CRT representation of numbers: `parts` holds one int64 NumPy
    array of residues per modulus in `ms` """
    
    def __init__(self, values, parts=None):
        if values is not None:
            parts = [ part.astype(np.int64) for part in decompose(values) ]
        self.parts = parts
        
    def __repr__(self):
        return 'CrtTensor({})'.format(self.parts)
    
    @property
    def shape(self):
        return self.parts[0].shape
    
    def unwrap(self):
        # Recombination multiplies every residue by an ~M-sized constant, far
        # beyond int64. Lift the residues to Python ints (dtype=object) first.
        # Under NumPy 2's promotion rules (NEP 50), int64 * <huge Python int>
        # raises OverflowError instead of falling back to object arithmetic.
        return recombine([ np.asarray(part, dtype=object) for part in self.parts ])
    
    def __add__(x, y):
        # component-wise operation that can be done in parallel
        return CrtTensor(None, [ 
            (xi + yi) for xi, yi in zip(x.parts, y.parts)
        ])
    
    def __sub__(x, y):
        # component-wise operation that can be done in parallel
        return CrtTensor(None, [ 
            (xi - yi) for xi, yi in zip(x.parts, y.parts)
        ])

    def __mul__(x, y):
        # component-wise operation that can be done in parallel
        return CrtTensor(None, [ 
            (xi * yi) for xi, yi in zip(x.parts, y.parts)
        ])
    
    def dot(x, y):
        # int64 accumulation per modulus; assumes reduced residues and at most
        # ~1024 terms so the sums stay below 2**63 (the headroom the moduli were
        # chosen for) — larger inputs may overflow silently
        return CrtTensor(None, [ 
            xi.dot(yi) for xi, yi in zip(x.parts, y.parts)
        ])
    
    def reduce(x):
        return CrtTensor(None, [
            xi % mi for xi, mi in zip(x.parts, ms)
        ])
    
    # TODO: straight-forward from scalar types, just need to take NumPy into account
#     def mod(x):
#         return CrtTensor(None, mod(x.parts))
    
    @staticmethod
    def sample(shape):
        # Pin int64: np.random.randint otherwise uses the platform C long,
        # which is 32-bit on Windows with NumPy < 2, so products would overflow.
        return CrtTensor(None, [
            np.random.randint(mi, size=shape, dtype=np.int64) for mi in ms
        ])


# sanity-check CrtTensor against NumPy object-array arithmetic on Python ints
a = np.array([2**120] * 1024)
b = np.array([2**110] * 1024)
x = CrtTensor(a)
y = CrtTensor(b)
z = (x + y).reduce()
assert (z.unwrap() == (a + b) % M).all(), z
z = (x - y).reduce()
assert (z.unwrap() == (a - b) % M).all(), z
z = (x * y).reduce()
assert (z.unwrap() == (a * b) % M).all(), z
z = x.dot(y).reduce()
assert z.unwrap() == a.dot(b) % M, z
# pending CrtTensor.mod:
# z = y.mod(); assert (z.unwrap() == b % K).all(), z

## Performance tests

We can finally benchmark dot products. `CrtTensor` is already ~10x faster on a single-core device.

In [14]:
# Benchmark: 10000 dot products of two length-1024 vectors (120-bit and
# 110-bit entries), each followed by a reduce whose result is discarded.
for tensor_type in [NaturalTensor, CrtTensor]:
    
    x = tensor_type(np.array([ 2**120 for _ in range(1024) ]))
    k = tensor_type(np.array([ 2**110 for _ in range(1024) ]))
    
    start = datetime.now()
    for _ in range(10000):
        z = x.dot(k)
        z.reduce() # only little effect compared to dot
    end = datetime.now()
    print('{:15}: {}'.format(tensor_type.__name__, end - start))

NaturalTensor  : 0:00:01.153654
CrtTensor      : 0:00:00.116858


# Private tensors

Again, we bring in secret sharing just to illustrate that it works.

In [15]:
def gen_private_tensor(tensor_type):
    """ Create a class of tensors additively secret-shared in two shares of `tensor_type`.

    A secret is held as shares0 + shares1, where shares0 is drawn from
    tensor_type.sample; every operation acts on each share independently.
    """

    class AbstractPrivateTensor:

        def __init__(self, values, shares0=None, shares1=None):
            if values is not None:
                # mask the public input with a random share
                values = tensor_type(values)
                shares0 = tensor_type.sample(values.shape)
                shares1 = values - shares0
            self.shares0 = shares0
            self.shares1 = shares1

        def __repr__(self):
            return 'PrivateTensor({})'.format(self.unwrap())

        def unwrap(self):
            return self.reconstruct().unwrap()

        def reconstruct(self):
            # adding the two shares (then reducing) reveals the secret
            combined = self.shares0 + self.shares1
            return combined.reduce()

        # Every operation below touches each share on its own, so the two
        # parties can run them locally and in parallel.

        def __add__(x, y):
            new0 = x.shares0 + y.shares0
            new1 = x.shares1 + y.shares1
            return AbstractPrivateTensor(None, shares0=new0, shares1=new1)

        def __sub__(x, y):
            new0 = x.shares0 - y.shares0
            new1 = x.shares1 - y.shares1
            return AbstractPrivateTensor(None, shares0=new0, shares1=new1)

        def __mul__(x, k):
            # k is a public tensor, multiplied into both shares
            new0 = x.shares0 * k
            new1 = x.shares1 * k
            return AbstractPrivateTensor(None, shares0=new0, shares1=new1)

        def dot(x, k):
            # k is a public tensor
            new0 = x.shares0.dot(k)
            new1 = x.shares1.dot(k)
            return AbstractPrivateTensor(None, shares0=new0, shares1=new1)

        def reduce(x):
            new0 = x.shares0.reduce()
            new1 = x.shares1.reduce()
            return AbstractPrivateTensor(None, shares0=new0, shares1=new1)

        # TODO: truncation needs `mod` on tensors first. A sketch, assuming
        # tensor counterparts of raw_truncate and M_wrapped:
        # def truncate(x):
        #     return AbstractPrivateTensor(None,
        #         shares0 = raw_truncate(x.shares0.reduce()),
        #         shares1 = M_wrapped - raw_truncate((M_wrapped - x.shares1).reduce())
        #     )

    return AbstractPrivateTensor

# sanity-check secret-shared tensors on top of both representations
for tensor_type in [NaturalTensor, CrtTensor]:

    PublicTensor = tensor_type
    PrivateTensor = gen_private_tensor(tensor_type)

    a = np.array([2**120] * 1024)
    b = np.array([2**110] * 1024)
    x = PrivateTensor(a)
    y = PrivateTensor(b)
    k = PublicTensor(b)

    z = (x + y).reduce()
    assert (z.unwrap() == (a + b) % M).all(), z
    z = (x - y).reduce()
    assert (z.unwrap() == (a - b) % M).all(), z
    z = (x * k).reduce()
    assert (z.unwrap() == (a * b) % M).all(), z
    z = x.dot(k).reduce()
    assert z.unwrap() == a.dot(b) % M, z
    # pending truncate on private tensors:
    # z = y.truncate(); assert z.unwrap() in [b // K, b // K + 1], z

## Performance tests

We also see a ~10x speedup on a single-core device here.

In [16]:
# Benchmark: 10000 dot products of a secret-shared length-1024 vector with a
# public one, per tensor representation, without reduction.
for tensor_type in [NaturalTensor, CrtTensor]:
    
    PublicTensor  = tensor_type
    PrivateTensor = gen_private_tensor(tensor_type)
    
    a = np.array([ 2**120 for _ in range(1024) ])
    b = np.array([ 2**110 for _ in range(1024) ])
    x = PrivateTensor(a)
    k = PublicTensor(b)
    
    start = datetime.now()
    for _ in range(10000):
        z = x.dot(k)
        #z.reduce() # only little effect compared to dot
    end = datetime.now()
    print('{:15}: {}'.format(tensor_type.__name__, end - start))

NaturalTensor  : 0:00:02.330989
CrtTensor      : 0:00:00.183384
