## Somewhat Homomorphic Encryption Toy
- SWHE: somewhat homomorphic encryption over the integers (a toy version of Gentry's scheme)
- Reference: "Computing Arbitrary Functions of Encrypted Data" by Craig Gentry

In [1]:
import random


In [2]:
def keygen(noise, modulus):
    """Generate the secret key for the toy somewhat homomorphic scheme.

    As in Gentry's toy scheme, the key ``p`` is an odd integer with exactly
    ``noise ** 2`` bits (its top bit is forced to 1). This guarantees that
    ``p`` is strictly larger than twice any fresh ``noise``-bit mask produced
    by ``encrypt``, so a freshly encrypted value always decrypts correctly.

    Args:
        noise: noise/security parameter; the key has ``noise ** 2`` bits.
        modulus: plaintext modulus. Accepted for interface compatibility;
            the key size does not depend on it.

    Returns:
        An odd int in ``[2 ** (noise**2 - 1), 2 ** (noise**2))``.

    Raises:
        ValueError: if ``noise < 2`` (no key can then exceed the mask).
    """
    if noise < 2:
        raise ValueError("noise must be >= 2 so the key can exceed the encryption mask")
    key_bits = noise ** 2
    # Force the top bit (exact bit length) and the low bit (odd key).
    top_bit = 1 << (key_bits - 1)
    return random.getrandbits(key_bits) | top_bit | 1

In [3]:
def encrypt(noise, a_key, a_bit, modulus):
    """Encrypt the integer plaintext ``a_bit`` under the secret key ``a_key``.

    The ciphertext is ``c = m' + a_key * q``, where:
      * ``m'`` is a uniformly random ``noise``-bit mask with
        ``m' % modulus == a_bit % modulus``, and
      * ``q`` is a random integer of at most ``noise ** 5`` bits.

    Args:
        noise: mask size in bits.
        a_key: secret key from ``keygen``.
        a_bit: plaintext; only its residue modulo ``modulus`` is encrypted.
        modulus: plaintext modulus.

    Returns:
        The ciphertext as a (typically very large) int.

    Raises:
        ValueError: if no ``noise``-bit mask is congruent to the plaintext
            (``a_bit % modulus >= 2 ** noise``). Previously this case looped
            forever.
    """
    a_bit_remainder = a_bit % modulus
    mask_limit = 1 << noise  # masks are drawn from [0, 2 ** noise)
    if a_bit_remainder >= mask_limit:
        raise ValueError(
            "plaintext residue %d does not fit in a %d-bit mask; increase noise"
            % (a_bit_remainder, noise)
        )
    # Valid masks are r, r + M, r + 2M, ... below mask_limit; pick one
    # uniformly (the same distribution rejection sampling produced).
    n_candidates = (mask_limit - 1 - a_bit_remainder) // modulus + 1
    a_mask = a_bit_remainder + modulus * random.randrange(n_candidates)
    return a_mask + (a_key * random.getrandbits(noise ** 5))

In [4]:
def decrypt(a_key, a_bit, modulus):
    """Decrypt ciphertext ``a_bit`` with secret key ``a_key``.

    Computes ``(c mod p) mod modulus``. Following Gentry's scheme, ``c mod p``
    is the *centered* residue in ``(-p/2, p/2]`` rather than ``[0, p)``.
    Centering lets ciphertexts whose noise term went negative (e.g. the
    difference of two ciphertexts with m1 < m2) decrypt correctly.

    Args:
        a_key: secret key ``p``.
        a_bit: ciphertext (may be negative).
        modulus: plaintext modulus.

    Returns:
        The plaintext residue in ``[0, modulus)`` (for positive ``modulus``).
    """
    residue = a_bit % a_key
    # Map [0, p) onto the centered interval (-p/2, p/2].
    if residue > a_key // 2:
        residue -= a_key
    return residue % modulus

## Evaluation

In [5]:
def multiplication_netpie(key, c1, c2):
    """Homomorphically multiply two ciphertexts.

    The key is not used: the scheme is homomorphic over plain integer
    arithmetic, so the ciphertext product is just ``c1 * c2``.
    """
    product = c1 * c2
    return product

In [6]:
def additive_netpie(key, c1, c2):
    """Homomorphically add two ciphertexts.

    The key is not used; the ciphertext sum is just ``c1 + c2``.
    """
    total = c1 + c2
    return total

In [7]:
def minus_netpie(key, c1, c2):
    """Homomorphically subtract ciphertext ``c2`` from ``c1``.

    The key is not used; the result is ``c1 - c2`` and may be negative.
    """
    difference = c1 - c2
    return difference

## Test functions

In [8]:
def multiplication_example(modulus=16, noise=5):
    """Demo: encrypt two random 2-bit plaintexts, multiply the ciphertexts, decrypt.

    Args:
        modulus: plaintext modulus (default 16, as before).
        noise: mask size in bits passed to ``keygen``/``encrypt`` (default 5).

    Prints the plaintexts, the ciphertexts, the product ciphertext ``c``, its
    decryption ``d``, and the expected value ``(a * b) % modulus`` so the
    reader can check the result at a glance.
    """
    a_key = keygen(noise, modulus=modulus)
    a_p = random.getrandbits(2)
    b_p = random.getrandbits(2)
    a_c = encrypt(noise, a_key, a_p, modulus=modulus)
    b_c = encrypt(noise, a_key, b_p, modulus=modulus)
    # Use the same evaluation helper as the rest of the notebook.
    c = multiplication_netpie(a_key, a_c, b_c)
    d = decrypt(a_key, c, modulus=modulus)
    expected = (a_p * b_p) % modulus
    print("multiplication_example()\n-------------------------")
    print("a (p): %d\nb (p): %d\n" % (a_p, b_p))
    print("a (c): %d\nb (c): %d\n" % (a_c, b_c))
    print("c: %d\nd: %d\nexpected: %d\n\n" % (c, d, expected))

In [9]:
def addition_example(modulus=8, noise=6):
    """Demo: encrypt two random 2-bit plaintexts, add the ciphertexts, decrypt.

    Args:
        modulus: plaintext modulus (default 8, as before).
        noise: mask size in bits passed to ``keygen``/``encrypt`` (default 6).

    Prints the plaintexts, the ciphertexts, the sum ciphertext ``c``, its
    decryption ``d``, and the expected value ``(a + b) % modulus``.
    """
    a_key = keygen(noise, modulus=modulus)
    a_p = random.getrandbits(2)
    b_p = random.getrandbits(2)
    a_c = encrypt(noise, a_key, a_p, modulus=modulus)
    b_c = encrypt(noise, a_key, b_p, modulus=modulus)
    # Use the same evaluation helper as the rest of the notebook.
    c = additive_netpie(a_key, a_c, b_c)
    d = decrypt(a_key, c, modulus=modulus)
    expected = (a_p + b_p) % modulus
    print("addition_example()\n------------------")
    print("a (p): %d\nb (p): %d\n" % (a_p, b_p))
    print("a (c): %d\nb (c): %d\n" % (a_c, b_c))
    print("c: %d\nd: %d\nexpected: %d\n\n" % (c, d, expected))

In [10]:
def compair(dec_c, origin_m, modulus=None):
    """Print whether a decrypted value matches the expected plaintext result.

    Args:
        dec_c: decrypted value returned by ``decrypt``.
        origin_m: expected plaintext result, e.g. ``m1 * m2`` or ``m1 - m2``.
        modulus: optional plaintext modulus. When given, ``origin_m`` is
            reduced modulo it before comparing, because ``decrypt`` returns a
            residue (``3 * 3 = 9`` decrypts to ``1`` mod 8, and ``1 - 3``
            decrypts to ``6``). Omit it to compare the values directly.

    Prints "Correct answer" on a match, otherwise "Error".
    """
    expected = origin_m if modulus is None else origin_m % modulus
    print("Correct answer" if dec_c == expected else "Error")

In [11]:
# Run the multiplication demo once (fresh random key and 2-bit plaintexts on every run).
multiplication_example()

multiplication_example()
-------------------------
a (p): 2
b (p): 3

a (c): 58848653489687450831615301199009589418096803205193849313906531116889216240466535862607789274860924235933239548968316093868452626938659790496087950084256511756072082128277899448667079401493948786673759049443283218558769800708891745746984918918612302917012785728205465087852175264690133531385199049233881242955491606937654486317023183433506451597271166380588418348192060276889762461566482280801528536445149271170254483090956768618041643297479533939000855821711255303958872942764135573906137244652646725650969693909909331516281341969006615959699538102169580061247165341390692865579165116050750084667048989252760263098795503452452214298688975825151893641906332919495767983864031039221098404301573049449508843426144050397818889544531465250161918244295571409779840886023287412274606276137105430193101473511106238666629868191079031999185769002810705049400725223699740625185299622772938946457845947642075643330906773871165848204540

In [12]:
# Run the addition demo once (fresh random key and 2-bit plaintexts on every run).
addition_example()

addition_example()
------------------
a (p): 2
b (p): 2

a (c): 194808570994221921231206521100104004341278509023592782079350268951372656005254077218752970352077555027780369802950411386793236844842311658274307742896267534506913095401995040115358424029126059113671102095126787727454985720217836892475141734190329805614855428270040094177083098880433795922938939994451626581477596555987765106445659010786906979746603321750920206895728056008445820819260169708189052409409482337149565059653543120702808550822759714744522109676801360297712646038800316812064135263476739384788535947461289934764503440792744601002846397801147889366725757052531415503429163435153157784763815685053709573766821909808836345142507031375014313037429300235627447598052110152193786447628640317607060460960288286634371903460541920443563373690198083615367959269844797105722765473330887726376579089575789469047667886985211368328143381139737902766834343746429374925646085859796479330151691044053554654105123243489176707425804837088965592

## Example of SWHE

In [13]:
# Parameters: 2-bit masks and plaintexts reduced modulo 8.
# With noise = 2 the encryption mask has only 2 bits, so only plaintexts whose
# residue mod `modulus` is below 4 can be encrypted.
noise, modulus = 2, 8

# Symmetric secret key shared by encrypt/decrypt.
key = keygen(noise, modulus)
print("Initial Key = ", key)

# Redraw until the key exceeds the modulus; tiny keys break decryption.
while not key > modulus:
    key = keygen(noise, modulus=modulus)
    print("New Key = ", key)

m1, m2 = 3, 1
c1 = encrypt(noise, key, m1, modulus=modulus)
c2 = encrypt(noise, key, m2, modulus=modulus)

Initial Key =  11


In [14]:
# Display the two ciphertexts produced by the setup cell.
print("Ciphertext : ")
print('c1 =',c1)
print('c2 =', c2)

Ciphertext : 
c1 = 16197670921
c2 = 30763138671


In [15]:
# Homomorphically multiply the ciphertexts (the key argument is unused by the helper).
c = multiplication_netpie(key, c1, c2)
print("Multiplication ciphertext result: ")
print('c = ', c)

Multiplication ciphertext result: 
c =  498291196689947285991


In [16]:
# Decrypt the homomorphic product and check it against m1 * m2.
print("Decrypted multiplicative result: ")
m = decrypt(key, c, modulus=modulus)
print('m =', m)
# decrypt() yields a residue mod `modulus`, so compare against the product
# reduced the same way (e.g. 3 * 3 = 9 decrypts to 1 mod 8).
compair(m, (m1 * m2) % modulus)

Decrypted multiplicative result: 
m = 3
Correct answer


In [17]:
# Homomorphically add the ciphertexts (the key argument is unused by the helper).
c_add = additive_netpie(key, c1, c2)
print("Additive ciphertext result: ")
print('c = ', c_add)

Additive ciphertext result: 
c =  46960809592


In [18]:
# Decrypt the homomorphic sum and check it against m1 + m2.
print("Decrypted additive result: ")
m_add = decrypt(key, c_add, modulus=modulus)
print('m_add = ',m_add)
# decrypt() yields a residue mod `modulus`, so reduce the expected sum too.
compair(m_add, (m1 + m2) % modulus)

Decrypted additive result: 
m_add =  4
Correct answer


In [19]:
# Homomorphically subtract the ciphertexts; the difference may be negative.
c_minus = minus_netpie(key, c1, c2)
print("Minus ciphertext result: ")
print('c = ', c_minus)

Minus ciphertext result: 
c =  -14565467750


In [20]:
# Decrypt the homomorphic difference and check it against m1 - m2.
print("Decrypted minus result: ")
m_minus = decrypt(key, c_minus, modulus=modulus)
print('m_minus = ', m_minus)
# m1 - m2 can be negative, while decrypt() yields a residue mod `modulus`.
compair(m_minus, (m1 - m2) % modulus)

Decrypted minus result: 
m_add =  2
Correct answer
