# Testing Individual components of the FV HE scheme

In [21]:
import random

from syft.frameworks.torch.he.fv.modulus import CoeffModulus
from syft.frameworks.torch.he.fv.encryption_params import EncryptionParams
from syft.frameworks.torch.he.fv.context import Context
from syft.frameworks.torch.he.fv.integer_encoder import IntegerEncoder
from syft.frameworks.torch.he.fv.key_generator import KeyGenerator
from syft.frameworks.torch.he.fv.encryptor import Encryptor
from syft.frameworks.torch.he.fv.decryptor import Decryptor
from syft.frameworks.torch.he.fv.integer_encoder import IntegerEncoder
from syft.frameworks.torch.he.fv.modulus import SeqLevelType
from syft.frameworks.torch.he.fv.evaluator import Evaluator

## Keygeneration

In [22]:
poly_modulus = 64
bit_sizes= [40]
plain_modulus = 64
ctx = Context(EncryptionParams(poly_modulus, CoeffModulus().create(poly_modulus, bit_sizes), plain_modulus))
keygenerator = KeyGenerator(ctx)
sk, pk = keygenerator.keygen()

In [23]:
print(ctx.param.coeff_modulus)

[1099511623297]


In [24]:
# print(len(sk.data))
print('secret key values : ', sk.data)

secret key values :  [[0, 1099511623296, 1, 1, 1099511623296, 1, 0, 1099511623296, 1, 1099511623296, 1099511623296, 1099511623296, 1099511623296, 1099511623296, 0, 0, 1, 1099511623296, 0, 0, 0, 1099511623296, 1, 1099511623296, 1, 1, 1, 0, 1099511623296, 1099511623296, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1099511623296, 1099511623296, 1, 1, 1, 1, 1099511623296, 1099511623296, 0, 0, 1099511623296, 1, 1, 0, 1, 1099511623296, 1, 0, 1, 1099511623296, 0, 0, 1099511623296, 0, 1099511623296]]


In [25]:
# print(pk.data)
# print('public key values : ', pk.data)

## Integer Encoder
Encodes Integer values to Plaintext object

In [26]:
int_encoder = IntegerEncoder(ctx)
ri1 = random.randint(0,10)
ri2 = random.randint(0,10)
ri3 = random.randint(0,10)
pt1 = int_encoder.encode(ri1)
pt2 = int_encoder.encode(ri2)
pt3 = int_encoder.encode(ri3)
print(pt1.data,"   ", pt2.data, "   ", pt3.data)
# print('plaintext data',plaintext.data)

[0, 1]     [0, 1, 0, 1]     [0, 0, 1]


### Decodes back to Integer

In [27]:
print(int_encoder.decode(pt1))
print(int_encoder.decode(pt2))
print(int_encoder.decode(pt3))

2
10
4


## Encrypter
Encrypt Plaintext to ciphertext using public_key

In [28]:
encrypter = Encryptor(ctx, pk)

In [29]:
ct1 = encrypter.encrypt(pt1)
ct2 = encrypter.encrypt(pt2)
ct3 = encrypter.encrypt(pt3)

Encrypt Plaintext to ciphertext using secret_key

## Decryptor
Decrypt Ciphertext to Plaintext using secret_key

In [30]:
decrypter = Decryptor(ctx, sk)

In [31]:
dec1 = decrypter.decrypt(ct1)
dec2 = decrypter.decrypt(ct2)
dec3 = decrypter.decrypt(ct3)

In [32]:
print(int_encoder.decode(dec1), "   ", int_encoder.decode(dec2), "   ", int_encoder.decode(dec3))

2     10     4


## Evaluator

In [33]:
eval = Evaluator(ctx)

In [34]:
cc12 = eval.add(ct1, ct2)
cc12 = decrypter.decrypt(cc12)
print(int_encoder.decode(cc12))

12


In [35]:
pc12 = eval.add(pt1, ct2)
pc12 = decrypter.decrypt(pc12)
print(int_encoder.decode(pc12))

12


In [36]:
pp12 = eval.add(pt1, pt2)
print(int_encoder.decode(pp12))

12


### Verify result

In [37]:
assert int_encoder.decode(cc12) == int_encoder.decode(pc12) == int_encoder.decode(pp12) == ri1+ri2

In [38]:
result = eval._mul_cipher_cipher(ct1, ct2)
print("\n\nct1 :",ct1.data)
print("\n\nct2 :",ct2.data)
print('\n\n')

result = decrypter.decrypt(result)
result = int_encoder.decode(result)

print('final result: ', result)



ct1 : [[[1042665945075, 916558144260, 1044174421481, 210629152726, 847821787156, 534894345004, 763475050322, 1033922928488, 532772495248, 351443634171, 1093570136397, 1006841470961, 377936268283, 147046430775, 970183333355, 847329897793, 667471202702, 176662347509, 919648546066, 796441831579, 275081728359, 407687223886, 344513077501, 899118994863, 860389120643, 467753443738, 119725953496, 894760788805, 851090554632, 823548090674, 351534884471, 50687760242, 595463970022, 434506787772, 586276619829, 337725279644, 735579946766, 511011770164, 450013503941, 542766713743, 311432052482, 979771104334, 194974774539, 760908664549, 229372331842, 437253620955, 435574411594, 937959646039, 982777761329, 770521129800, 6214731693, 76094852900, 131833917135, 649660765117, 1200241420, 428350364244, 920846207780, 192005048959, 603237170, 574887659510, 310476840958, 316729173050, 720651025894, 57302224595]], [[787708916333, 612305453753, 503055916075, 265330060304, 967037136757, 1084209027336, 852537584

In [39]:
print(ri1 * ri2, "    ", result)
assert ri1 * ri2 == result

20      20
