# Testing Individual components of the FV HE scheme

In [81]:
import random

from syft.frameworks.torch.he.fv.modulus import CoeffModulus
from syft.frameworks.torch.he.fv.encryption_params import EncryptionParams
from syft.frameworks.torch.he.fv.context import Context
from syft.frameworks.torch.he.fv.integer_encoder import IntegerEncoder
from syft.frameworks.torch.he.fv.key_generator import KeyGenerator
from syft.frameworks.torch.he.fv.encryptor import Encryptor
from syft.frameworks.torch.he.fv.decryptor import Decryptor
from syft.frameworks.torch.he.fv.integer_encoder import IntegerEncoder
from syft.frameworks.torch.he.fv.modulus import SeqLevelType
from syft.frameworks.torch.he.fv.evaluator import Evaluator
from syft.frameworks.torch.he.fv.relin_keys import RelinKeys

## Keygeneration

In [82]:
poly_modulus = 128
bit_sizes= [30]
plain_modulus = 64
ctx = Context(EncryptionParams(poly_modulus, CoeffModulus().create(poly_modulus, bit_sizes), plain_modulus))
keygenerator = KeyGenerator(ctx)
sk, pk = keygenerator.keygen()

In [83]:
# print(ctx.param.coeff_modulus)

In [84]:
# print(pk.data)
# print('public key values : ', pk.data)

## Integer Encoder
Encodes Integer values to Plaintext object

In [85]:
int_encoder = IntegerEncoder(ctx)
ri1 = random.randint(0,10)
ri2 = random.randint(0,10)
ri3 = random.randint(0,10)
pt1 = int_encoder.encode(ri1)
pt2 = int_encoder.encode(ri2)
pt3 = int_encoder.encode(ri3)
# print(pt1.data,"   ", pt2.data, "   ", pt3.data)
# print('plaintext data',plaintext.data)

### Decodes back to Integer

In [86]:
# print(int_encoder.decode(pt1))
# print(int_encoder.decode(pt2))
# print(int_encoder.decode(pt3))

## Encrypter
Encrypt Plaintext to ciphertext using public_key

In [87]:
encrypter = Encryptor(ctx, pk)

In [88]:
ct1 = encrypter.encrypt(pt1)
ct2 = encrypter.encrypt(pt2)
ct3 = encrypter.encrypt(pt3)

Encrypt Plaintext to ciphertext using secret_key

## Decryptor
Decrypt Ciphertext to Plaintext using secret_key

In [89]:
decrypter = Decryptor(ctx, sk)

In [90]:
dec1 = decrypter.decrypt(ct1)
dec2 = decrypter.decrypt(ct2)
dec3 = decrypter.decrypt(ct3)

In [91]:
# print(int_encoder.decode(dec1), "   ", int_encoder.decode(dec2), "   ", int_encoder.decode(dec3))

## Evaluator

In [92]:
eval = Evaluator(ctx)

In [93]:
cc12 = eval.add(ct1, ct2)
cc12 = decrypter.decrypt(cc12)
# print(int_encoder.decode(cc12))

In [94]:
pc12 = eval.add(pt1, ct2)
pc12 = decrypter.decrypt(pc12)
# print(int_encoder.decode(pc12))

In [95]:
pp12 = eval.add(pt1, pt2)
# print(int_encoder.decode(pp12))

### Verify result

In [96]:
assert int_encoder.decode(cc12) == int_encoder.decode(pc12) == int_encoder.decode(pp12) == ri1+ri2

In [97]:
first_prod = eval._mul_cipher_cipher(ct1, ct2)

In [98]:

relin_prod = eval.relin(first_prod, keygenerator.get_relin_keys())


In [99]:
final_prod = eval._mul_cipher_cipher(relin_prod, ct3)
result = decrypter.decrypt(final_prod)
result = int_encoder.decode(result)

In [100]:
print(ri1 * ri2 * ri3, "    ", result)
assert ri1 * ri2 == result

0      -4313233534013225645650638308439700547724


AssertionError: 