In [1]:
import numpy as np
import phe as paillier
import pickle, socket, threading, time

# Private matrix entries are drawn from [0, MAX_INPUT) (see User.__init__).
MAX_INPUT: int = 128

# Private matrix shapes as (rows, cols). Alice's column count must equal Bob's
# row count for the product Alice @ Bob to be defined.
ALICE_SIZE: tuple[int, int] = (5, 8)
BOB_SIZE: tuple[int, int] = (8, 4)

# Loopback endpoint for the socket test: Alice binds/listens here, Bob connects.
HOST: str = "127.0.0.1"
PORT: int = 65432


class User:
    """One party in a two-party Paillier matrix-multiplication protocol.

    The sender (``is_sender=True``) generates a Paillier key pair and encrypts
    its random private matrix element-wise into ``enc_matrix``. The receiver has
    no keys of its own: after ``load_pub`` it encodes its private matrix under
    the sender's public key as plaintext ``EncodedNumber``s. ``ext_matrix``
    holds whatever matrix was most recently received via ``load_matrix``.
    """

    __is_sender: bool = None
    pub_key: paillier.PaillierPublicKey = None
    priv_key: paillier.PaillierPrivateKey = None
    priv_matrix: np.ndarray = None
    # Sender: encrypted entries. Receiver: plaintext-encoded entries (see __loadKey).
    enc_matrix: list[list[paillier.EncryptedNumber]] = None
    # Matrix received from the other party (set by load_matrix).
    ext_matrix: list[list[paillier.EncryptedNumber]] = None

    def __init__(self, mat_size: tuple[int, int], is_sender: bool = False, key_size: int = 512,
                 max_input: int = MAX_INPUT) -> None:
        """Create a party holding a random private matrix.

        Args:
            mat_size: (rows, cols) shape of the private matrix.
            is_sender: if True, generate a key pair and encrypt the matrix now.
            key_size: key length in bits, passed as ``n_length`` (sender only).
            max_input: exclusive upper bound for the random matrix entries.
        """
        self.__is_sender = is_sender
        self.priv_matrix = np.random.randint(max_input, size=mat_size)

        if is_sender:
            self.pub_key, self.priv_key = paillier.generate_paillier_keypair(n_length=key_size)
            self.__loadKey()

    def __loadKey(self):
        """Fill ``enc_matrix`` from ``priv_matrix`` using ``pub_key``.

        The sender encrypts each entry. The receiver only encodes each entry
        under the sender's public key, without encrypting it.
        """
        if self.__is_sender:
            self.enc_matrix = [[self.pub_key.encrypt(i) for i in r] for r in self.priv_matrix.tolist()]
        else:
            self.enc_matrix = [[paillier.encoding.EncodedNumber.encode(self.pub_key, i) for i in r] for r in self.priv_matrix.tolist()]

    def ciphertext(self, dump_ext: bool = False) -> str:
        """Return the raw ciphertexts of an encrypted matrix as a string.

        Args:
            dump_ext: if True, show ``ext_matrix`` (the matrix received from the
                other party) instead of ``enc_matrix``. This mirrors ``dump_matrix``.

        A receiver's ``enc_matrix`` holds plaintext ``EncodedNumber``s, not
        ciphertexts, so a receiver always shows ``ext_matrix``. Returns
        ``"None"`` if the selected matrix has not been set yet.
        """
        use_ext = dump_ext or not self.__is_sender
        matrix = self.ext_matrix if use_ext else self.enc_matrix
        if matrix is None:
            return str(None)
        return str([[i.ciphertext() for i in r] for r in matrix])

    def __ciphertext(self) -> str:
        """Abbreviated view of this party's encrypted matrix: ciphertexts mod 100."""
        if self.__is_sender:
            return str([[i.ciphertext() % 100 for i in r] for r in self.enc_matrix])
        else:
            return str([[i.ciphertext() % 100 for i in r] for r in self.ext_matrix])

    def decrypt(self) -> list[list[int]]:
        """Decrypt ``ext_matrix`` element-wise with the private key.

        Returns None unless this user is the sender and a matrix has been
        received via ``load_matrix``.
        """
        if self.__is_sender and self.ext_matrix is not None:
            return [[self.priv_key.decrypt(i) for i in r] for r in self.ext_matrix]
        return None

    def dump_matrix(self, dump_ext: bool = False) -> bytes:
        """Pickle ``ext_matrix`` (if ``dump_ext``) or ``enc_matrix`` for sending."""
        return pickle.dumps(self.ext_matrix if dump_ext else self.enc_matrix)

    def load_matrix(self, bytes: bytes):
        """Unpickle a matrix sent by the other party into ``ext_matrix``.

        SECURITY: ``pickle.loads`` can execute arbitrary code. Only pass data
        from a trusted peer (acceptable for this local demo).
        """
        # The parameter name shadows the builtin ``bytes`` in this method; it is
        # kept for compatibility with keyword callers.
        self.ext_matrix = pickle.loads(bytes)

    def dump_pub(self) -> bytes:
        """Pickle the public key for sending to the receiver."""
        return pickle.dumps(self.pub_key)

    def load_pub(self, bytes: bytes):
        """Receiver only: unpickle the sender's public key and encode ``priv_matrix``.

        Ignored for the sender, which already owns its key pair. The same
        pickle security caveat as ``load_matrix`` applies.
        """
        if not self.__is_sender:
            self.pub_key = pickle.loads(bytes)
            self.__loadKey()

    def __str__(self) -> str:
        role = 'Sender and ' if self.priv_key is not None else ''
        return f"{role}Receiver\n{self.ciphertext()}\n"


class Alice(User):
    """Sending party: owns the Paillier key pair and an ALICE_SIZE private matrix."""

    def __init__(self, key_size: int) -> None:
        """Create Alice, generating a ``key_size``-bit key pair and encrypting her matrix."""
        super().__init__(mat_size=ALICE_SIZE, is_sender=True, key_size=key_size)


class Bob(User):
    """Receiving party: holds a BOB_SIZE private matrix and generates no keys."""

    def __init__(self) -> None:
        """Create Bob. ``enc_matrix`` stays unset until ``load_pub`` is called."""
        super().__init__(mat_size=BOB_SIZE)


def mat_mult(a, b):
    """Multiply two matrices given as sequences of rows.

    Elements only need to support ``*`` with each other and ``+`` (``sum``
    starts from the int ``0``). The protocol uses this to multiply encrypted
    entries by plaintext-encoded entries as well as plain ints.

    Args:
        a: m x k matrix; every row must have exactly k entries.
        b: k x n matrix; every row must have the same length n.

    Returns:
        The m x n product as a list of lists.

    Raises:
        ValueError: if ``b`` is ragged or a row of ``a`` does not have
            ``len(b)`` entries. Without this check, ``zip`` would silently
            truncate and return a wrong product.
    """
    rows_a = list(a)
    rows_b = list(b)
    inner = len(rows_b)

    if rows_b:
        width = len(rows_b[0])
        for j, row in enumerate(rows_b):
            if len(row) != width:
                raise ValueError(f"b is ragged: row {j} has {len(row)} entries, expected {width}")

    for i, row in enumerate(rows_a):
        if len(row) != inner:
            raise ValueError(f"shape mismatch: row {i} of a has {len(row)} entries but b has {inner} rows")

    cols_b = list(zip(*rows_b))  # transpose once so each column is reused for every row of a
    return [[sum(x * y for x, y in zip(row_a, col_b)) for col_b in cols_b] for row_a in rows_a]


In [2]:
def local_test(key_size: int = 512):
    """Run the Alice/Bob protocol in-process (no sockets) and check the result.

    Alice's matrix is encrypted. Bob multiplies it by his plaintext-encoded
    matrix without decrypting it, and Alice decrypts the product. The decrypted
    product and the plaintext product are printed and asserted equal.

    Args:
        key_size: key size in bits passed to ``Alice``.

    Raises:
        AssertionError: if the decrypted product differs from the plaintext product.
    """
    # Initialize both users
    alice = Alice(key_size)
    bob = Bob()

    # Bob receives Alice's public key and her encrypted matrix
    bob.load_pub(alice.dump_pub())
    bob.load_matrix(alice.dump_matrix())

    # Bob computes Enc(A) @ B on the encrypted entries
    bob.ext_matrix = mat_mult(bob.ext_matrix, bob.enc_matrix)

    # Alice receives the encrypted product back and decrypts it
    alice.load_matrix(bob.dump_matrix(True))
    decrypted = alice.decrypt()
    expected = mat_mult(alice.priv_matrix.tolist(), bob.priv_matrix.tolist())

    print(decrypted)
    print(expected)
    assert decrypted == expected, "decrypted homomorphic product does not match plaintext product"


local_test()


[[46908, 31768, 23393, 38100], [40768, 21478, 14525, 32899], [38043, 28440, 19323, 37125], [45199, 28599, 25852, 36922], [33377, 26998, 27210, 35091]]
[[46908, 31768, 23393, 38100], [40768, 21478, 14525, 32899], [38043, 28440, 19323, 37125], [45199, 28599, 25852, 36922], [33377, 26998, 27210, 35091]]


In [3]:
# socket testing
#
# Wire protocol (Alice = server, Bob = client) on a single TCP connection:
#   Bob   -> Alice : 1 byte                   (sync so prints don't interleave)
#   Alice -> Bob   : msg(pickled public key)
#   Alice -> Bob   : msg(pickled encrypted matrix)
#   Bob   -> Alice : msg(pickled encrypted product)
#   Bob   -> Alice : 1 byte                   (sync so prints don't interleave)
# where msg(x) is len(x) as a little-endian _LEN_PREFIX_BYTES-byte integer, followed by x.

_LEN_PREFIX_BYTES = 8
_CONNECT_TIMEOUT = 30.0  # seconds Bob keeps retrying while Alice generates keys


def _recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly ``n`` bytes from ``sock``.

    ``socket.recv(n)`` may return fewer than ``n`` bytes, so loop until done.

    Raises:
        ConnectionError: if the peer closes the connection first.
    """
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError(f"connection closed after {len(buf)} of {n} bytes")
        buf += chunk
    return bytes(buf)


def _send_msg(sock: socket.socket, payload: bytes) -> None:
    """Send ``payload`` preceded by its length so the receiver knows where it ends."""
    sock.sendall(len(payload).to_bytes(_LEN_PREFIX_BYTES, "little"))
    sock.sendall(payload)


def _recv_msg(sock: socket.socket) -> bytes:
    """Receive one length-prefixed message written by ``_send_msg``."""
    size = int.from_bytes(_recv_exact(sock, _LEN_PREFIX_BYTES), "little")
    return _recv_exact(sock, size)


def _connect_with_retry(address, timeout: float = _CONNECT_TIMEOUT, interval: float = 0.05) -> socket.socket:
    """Connect to ``address``, retrying while the server is not listening yet.

    Alice only binds and listens after generating her key pair, so Bob's thread
    can try to connect first.

    Raises:
        ConnectionRefusedError: if still refused after ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        try:
            return socket.create_connection(address)
        except ConnectionRefusedError:
            if time.monotonic() >= deadline:
                raise
            time.sleep(interval)


def run_alice(bits: int):
    """Alice's side: generate keys, send the encrypted matrix, then decrypt Bob's result.

    Args:
        bits: key size in bits passed to ``Alice``.
    """
    alice = Alice(bits)

    print(f"Alice Private Matrix\n{alice.priv_matrix}")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Let back-to-back socket_test runs rebind PORT without waiting out TIME_WAIT.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((HOST, PORT))
        server.listen()
        conn, _addr = server.accept()
        with conn:
            _recv_exact(conn, 1)  # Only used to ensure print does not overlap
            print("Alice: Sending input")
            _send_msg(conn, alice.dump_pub())
            _send_msg(conn, alice.dump_matrix())
            # SECURITY: load_matrix unpickles peer data; acceptable only for a trusted local peer.
            alice.load_matrix(_recv_msg(conn))
            _recv_exact(conn, 1)  # Only used to ensure print does not overlap
            print("Alice: Final Ciphertext")
            # Show the received product (ext_matrix). alice.ciphertext() would
            # print Alice's own encrypted input instead.
            print(str([[c.ciphertext() for c in row] for row in alice.ext_matrix]))
            print("Alice: Final Result")
            print(alice.decrypt())
            print("Alice: Finished")


def run_bob():
    """Bob's side: receive Alice's key and encrypted matrix, return Enc(A) @ B."""
    bob = Bob()

    print(f"Bob Private Matrix\n{bob.priv_matrix}")

    with _connect_with_retry((HOST, PORT)) as s:
        print("Bob: Receiving input")
        s.sendall(bytes(1))  # Only used to ensure print does not overlap
        pub = _recv_msg(s)
        print(f"Bob: Pub size: {len(pub)}")
        # SECURITY: load_pub/load_matrix unpickle peer data; trusted local peer only.
        bob.load_pub(pub)
        mat = _recv_msg(s)
        print(f"Bob: Matrix size: {len(mat)}")
        bob.load_matrix(mat)
        print("Bob: Computing result")
        bob.ext_matrix = mat_mult(bob.ext_matrix, bob.enc_matrix)
        print("Bob: Sending result")
        _send_msg(s, bob.dump_matrix(True))
        print("Bob: Finished")
        s.sendall(bytes(1))  # Only used to ensure print does not overlap


def socket_test(bits: int):
    """Run Alice (server) and Bob (client) concurrently in two threads over a local TCP socket.

    Args:
        bits: key size in bits handed to ``run_alice``.
    """
    print(f"\n---[ Socket Test | Bits:{bits} ]---\n")
    workers = [
        threading.Thread(target=run_alice, args=[bits]),
        threading.Thread(target=run_bob),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()


for key_bits in (512, 1024):
    socket_test(key_bits)



---[ Socket Test | Bits:512 ]---

Bob Private Matrix
[[105 126  60 123]
 [ 82  63  19  24]
 [ 19  79  98  46]
 [ 75 121  91 115]
 [ 49  23  68  33]
 [ 16  11   5  85]
 [ 83  33  26  70]
 [ 98  20  13   4]]
Alice Private Matrix
[[108   6  35  83 109  52  23  59]
 [  5  47 111  64  32 127  70  75]
 [ 21  63 110  39  73   1  21  24]
 [ 95  29  51  80  70  64   2   9]
 [123 127  68  98  73  92  70 104]]
Bob: Receiving input
Alice: Sending input
Bob: Pub size: 415
Bob: Matrix size: 6658
Bob: Computing result
Bob: Sending result
Bob: Finished
Alice: Final Ciphertext
[[7569420083534309627912391436277885549318437507884768651601723056385143712743027002097158532706296532861643535088722533899178176652626034179100819166826834269940279366685131495777799074675442377685972687962895306175422831798030858972230732547451077499621471853534369824020476301939407853885249835743783468950, 1617146420359752273134618913667365262295118241907567878227396321641825804637742720501966745069361518959368543386297540061