## Simple implementation of a Paillier cryptosystem

See: https://en.wikipedia.org/wiki/Paillier_cryptosystem 

Code cribbed from (in no particular order):

- https://github.com/data61/python-paillier [a]
- https://github.com/cgshep/paillier-lib [b]
- https://erev0s.com/blog/paillier-cryptosystem-python/ [c]



In [5]:
!pip install pycryptodome



In [6]:
# Note: "Crypto" here is pycryptodome
from Crypto import Random
from Crypto.Random import random
from Crypto.Util import number
import math

In [7]:
# math.lcm() requires python 3.9; on older interpreters install a
# compatible fallback so that PaillierPrivateKey can call math.lcm().
def lcm(*integers):
    '''
    Least common multiple of the arguments, mirroring math.lcm (3.9+).

    The result is always non-negative, lcm() == 1 (empty product), and any
    zero argument makes the result 0 (avoids a GCD(0, 0) division by zero).
    '''
    result = 1
    for value in integers:
        if value == 0:
            return 0
        # abs() keeps the result non-negative for negative inputs, as math.lcm does
        result = abs(result * value) // math.gcd(result, value)
    return result


if not hasattr(math, "lcm"):
    math.lcm = lcm

## Utils

In [8]:
def getRandomBitsOfLength(bits):
    '''
    Get a random value EXACTLY `bits` long, i.e. with its high bit set.

    Draws `bits`-bit candidates from `random` and rejects any whose top bit
    is clear; each draw is accepted with probability 1/2.

    Raises:
        ValueError: if bits < 1 (no positive integer has that length).
    '''
    if bits < 1:
        raise ValueError("bits must be >= 1, got {}".format(bits))
    while True:
        r = random.getrandbits(bits)
        # The highest bit is set exactly when the bit length equals `bits`.
        if r.bit_length() == bits:
            return r


## Paillier Code

In [9]:
class PaillierPublicKey(object):
    """
    Paillier public key.

    In all the papers, (g, n) is the public key, but since we define g = n + 1
    it's really just n.

    Attributes:
        n:  the public modulus.
        n2: n squared; all ciphertext arithmetic is done modulo n2.
        g:  the generator, fixed to n + 1.
    """
    def __init__(self, n):
        self.n = n
        self.n2 = n**2
        self.g = n + 1  # From Damgard,Jurik 2000.  Both [a] and [c] use this value. [b] does NOT

    def __repr__(self):
        return "<PaillierPublicKey(n, g): ({}, {})>".format(self.n, self.g)

    def _random_r(self):
        '''
        Draw the encryption nonce r uniformly from Z*_n: 1 <= r < n with
        gcd(r, n) == 1.

        r == 0 would make the ciphertext 0, and an r sharing a factor with n
        would make gcd(ciphertext, n) reveal that factor, so both are rejected.
        '''
        while True:
            r = random.randrange(1, self.n)
            if math.gcd(r, self.n) == 1:
                return r

    def encrypt(self, m, r=None):
        '''
        Encrypt integer m, returning c = g^m * r^n mod n^2.

        source [a], [c]

        Args:
            m: plaintext integer; it is effectively reduced mod n (negative
               values are encoded as m mod n).
            r: optional nonce. Drawn at random from Z*_n when omitted; pass an
               explicit value only for reproducible testing.
        '''
        if r is None:
            r = self._random_r()

        # Comment from [a]:
        #  "we chose g = n + 1, so that we can exploit the fact that
        #  (n+1)^plaintext = n*plaintext + 1 mod n^2"
        # By the binomial theorem every later term carries a factor n^2. It also
        # holds for negative m, since (1 + n)(1 - n) = 1 - n^2 = 1 mod n^2.
        gm = (self.n * m + 1) % self.n2
        x = pow(r, self.n, self.n2)

        cipher = gm * x % self.n2
        return cipher

    def encrypt_2(self, msg):
        '''
        Encrypt msg with the textbook formula c = g^msg * r^n mod n^2, using
        general modular exponentiation for g^msg (no g = n + 1 shortcut), so
        it can serve as a cross-check of encrypt().

        from [b]
        '''
        r = self._random_r()
        # Never print or log r: anyone who knows r can divide out r^n and
        # recover msg from the ciphertext.
        a = pow(self.g, msg, self.n2)
        b = pow(r, self.n, self.n2)
        c = (a * b) % self.n2
        return c
        

In [10]:

class PaillierPrivateKey(object):
    """
    Paillier private key: l = lambda = lcm(p-1, q-1) and u = mu = l^-1 mod n.

    Source [a] uses the Chinese remainder theorem in decrypt, the other 2 do not.
    I think that is because [a] and [c] always set g = n+1 per Damgard and
    Jurik 2000 ([b] uses a different g).

    Computing mu directly as l^-1 mod n relies on g = n + 1: in that case
    L(g^l mod n^2) = l mod n, so the general mu = L(g^l mod n^2)^-1 reduces to
    l^-1 mod n.
    """
    def __init__(self, p, q, pubkey):
        self.n = pubkey.n
        self.n2 = pubkey.n2
        self.l = math.lcm(p - 1, q - 1)
        self.u = number.inverse(self.l, self.n)  # valid because g = n + 1

    def __repr__(self):
        return "<PaillierPrivateKey(l, u): ({}, {})>".format(self.l, self.u)

    def _lFunc(self, x):
        '''Paillier's L function: L(x) = (x - 1) / n (floor division).'''
        return (x - 1) // self.n

    def decrypt(self, cipher):
        '''
        Decrypt cipher to its plaintext residue in the range [0, n).

        Source: [b]
        '''
        a = pow(cipher, self.l, self.n2)
        return self._lFunc(a) * self.u % self.n

    def decrypt_2(self, cipher):
        '''
        Decrypt cipher, mapping residues >= n // 2 to negative values so that
        negative plaintexts round-trip.

        Source [c]
        '''
        msg = self.decrypt(cipher)
        if msg >= self.n // 2:
            msg = msg - self.n
        return msg
    

In [11]:

def GeneratePaillierKeys(bits):
    '''
    Generate a random public key (n, with g = n + 1) and the matching private
    key (lambda, mu), where n is exactly `bits` long.

    returns (pubkey, privkey)

    From Damgard,Jurik: "A Generalisation, a Simplification and Some Applications of Paillier’s Probabilistic Public-Key System", 2000
    In the first section of Chapter 3:
        "In particular one may choose g = n + 1 always without degrading security. We do this in the following for simplicity,
        so that a public key consists only of the modulus n."
    '''
    # Split the bit budget so that odd `bits` is reachable: with both primes at
    # bits // 2 bits, p * q is at most bits - 1 long and the loop would never
    # end. For even `bits` both halves are bits // 2, as before.
    p_bits = (bits + 1) // 2
    q_bits = bits // 2
    while True:
        p = number.getPrime(p_bits)
        q = number.getPrime(q_bits)
        # p == q would make n = p^2, which shares the factor p with
        # (p-1)(q-1)*... and breaks decryption; redraw in that case.
        if p != q and number.size(p * q) == bits:
            break

    n = p * q
    pubkey = PaillierPublicKey(n)
    privkey = PaillierPrivateKey(p, q, pubkey)
    return (pubkey, privkey)
    

In [12]:
def bin_str(x):
    '''
    Return the binary digits of integer x, most significant bit first.

    x may be any int-convertible value (e.g. an mpz or Integer). Zero maps to
    the empty string (there are no set bits to emit); negative values get a
    leading '-', as bin() does.
    '''
    y = int(x)  # in case it's an mpz or Integer...
    if y < 0:
        # Emit the magnitude with a sign prefix; shifting a negative int right
        # never reaches 0, so it must be handled separately.
        return '-' + bin_str(-y)
    if y == 0:
        return ''
    return format(y, 'b')

In [13]:
def simple_paillier_test(msg, key_bits=256):
    '''
    Round-trip msg through a freshly generated key pair, printing each stage.

    msg should be a non-negative integer smaller than the generated modulus,
    since decrypt() returns values in [0, n). Note that the printout includes
    the private key.

    Returns:
        True if the decrypted value equals msg, False otherwise, so the check
        can also be used programmatically (e.g. ``assert simple_paillier_test(42)``).
    '''
    print(f"Msg: {msg}")
    pubk, privk = GeneratePaillierKeys(key_bits)
    print(pubk, privk)
    c = pubk.encrypt(msg)
    print(f"Cipher: {c}")
    d = privk.decrypt(c)
    print(f"Decrypted: {d}")
    return d == msg

In [14]:
# Banner printed once the definitions above have been executed.
print("BacisPaillier notebook loaded")

BacisPaillier notebook loaded


In [15]:
# simple_paillier_test(234321,192)