In [1]:
import pickle
import tenseal as ts
import socket
import struct
from collections import OrderedDict
import copy
import torch

In [2]:
# Load the pickled, serialized TenSEAL context (presumably raw bytes; it is
# passed to ts.context_from in the next cell).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load public_context.pkl from a trusted source.
with open('public_context.pkl', mode='rb') as context_file:
    public_context_bin = pickle.load(context_file)

In [3]:
# Rebuild the TenSEAL context from the serialized bytes loaded in the previous
# cell. recv_msg uses this context to deserialize incoming CKKS tensors.
# NOTE(review): the name suggests this context carries no secret key, so the
# server can compute on ciphertexts but not decrypt them — confirm with the
# key holder.
public_context = ts.context_from(public_context_bin)

In [4]:
def send_msg(sock, msg):
    """Serialize ``msg`` and send it over ``sock`` as one length-prefixed frame.

    Frame layout: a 4-byte unsigned big-endian (network byte order) length
    header, followed by the bytes returned by ``msg.serialize()``.

    Args:
        sock: A connected stream socket.
        msg: Any object exposing ``serialize()`` that returns bytes
            (here, TenSEAL tensors).
    """
    payload = msg.serialize()
    header = struct.pack('>I', len(payload))
    sock.sendall(header + payload)
    
def recv_msg(sock, context=None):
    """Receive one length-prefixed frame from ``sock`` and deserialize it as a CKKS tensor.

    Inverse of ``send_msg``. Reads a 4-byte big-endian length header, then
    exactly that many payload bytes, and rebuilds the tensor with
    ``ts.ckks_tensor_from``.

    Args:
        sock: A connected stream socket.
        context: TenSEAL context to bind the tensor to. Defaults to the
            module-level ``public_context``.

    Returns:
        The deserialized CKKS tensor, or ``None`` if the peer closed the
        connection before a complete frame (header or payload) arrived.
    """
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None

    msglen = struct.unpack('>I', raw_msglen)[0]
    print(f'msglen: {msglen}')

    # read the message data
    payload = recvall(sock, msglen)
    # EOF in the middle of the payload: report it the same way as EOF in the
    # header. Previously None was handed to ckks_tensor_from, which failed
    # with an opaque deserialization error.
    if payload is None:
        return None

    if context is None:
        context = public_context
    return ts.ckks_tensor_from(context, payload)

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Args:
        sock: A connected stream socket.
        n: Number of bytes to read (0 returns ``b''`` immediately).

    Returns:
        The ``n`` bytes as a ``bytes`` object, or ``None`` if the peer closed
        the connection (EOF) before ``n`` bytes arrived.
    """
    # Frames here can be hundreds of MB, so receive directly into a single
    # preallocated buffer. Growing an immutable ``bytes`` with ``+=`` copies
    # the whole prefix on every chunk, which is quadratic in n.
    buf = bytearray(n)
    view = memoryview(buf)
    received = 0
    while received < n:
        count = sock.recv_into(view[received:])
        if count == 0:
            # Peer closed the connection before the full message arrived.
            return None
        received += count
    return bytes(buf)

In [5]:
# How many clients the server waits for before aggregating.
client_num = 2
# Accepted client sockets, and for each client the list of weight tensors
# received from it (same order as client_list).
client_list, client_weight_list = [], []

In [6]:
# TCP port the server listens on.
# NOTE(review): presumably must match the port configured on the clients — confirm.
port = 10080
# Bind to whatever address this machine's hostname resolves to.
# NOTE(review): on some Linux setups the hostname resolves to 127.0.1.1,
# which remote clients cannot reach — confirm for the deployment host.
host = socket.gethostbyname(socket.gethostname())
print(host)

192.168.0.116


In [7]:
# Open the listening socket and block until `client_num` clients have connected.
server_soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow re-binding the port right after a previous run (e.g. after a kernel
# restart). Without SO_REUSEADDR, bind() fails with "Address already in use"
# while the old connections linger in TIME_WAIT.
server_soc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
    server_soc.bind((host, port))
    server_soc.listen()
except OSError:
    # Don't leave a half-initialised socket open in the kernel on failure.
    server_soc.close()
    raise
print("Listening")
for _ in range(client_num):
    client, addr = server_soc.accept()
    client_list.append(client)
    print('Conntected with', addr)

Listening
Conntected with ('192.168.0.116', 50282)
Conntected with ('192.168.0.116', 43288)


In [8]:
# Receive each client's encrypted weights: a fixed number of CKKS tensors per
# client, each sent as one length-prefixed frame.
tensors_per_client = 2  # must match how many tensors each client sends
try:
    for client_idx, client in enumerate(client_list):
        weights = []
        for _ in range(tensors_per_client):
            weight = recv_msg(client)
            if weight is None:
                # Fail here with context instead of appending None and
                # crashing later inside average_weights.
                raise ConnectionError(
                    f"client #{client_idx} closed the connection after "
                    f"{len(weights)} of {tensors_per_client} tensors")
            weights.append(weight)
        print(f"weights: {weights}")
        client_weight_list.append(weights)
finally:
    # No more clients will be accepted, so the listening socket can go.
    # The per-client sockets stay open so the averaged weights can be
    # sent back on them.
    server_soc.close()

msglen: 571637985
msglen: 1116735
weights: [<tenseal.tensors.ckkstensor.CKKSTensor object at 0x79a28c857910>, <tenseal.tensors.ckkstensor.CKKSTensor object at 0x79a28c8576d0>]
msglen: 571643625
msglen: 1116055
weights: [<tenseal.tensors.ckkstensor.CKKSTensor object at 0x79a28c857cd0>, <tenseal.tensors.ckkstensor.CKKSTensor object at 0x79a3903a6200>]


In [9]:
def average_weights(w, datasize):
    """
    Return the data-size-weighted average of several clients' weights (as in FedAvg).

    ``w[k]`` is the list of weight tensors sent by client ``k`` and
    ``datasize[k]`` is that client's sample count. For each position ``i``
    the result is ``sum_k(w[k][i] * datasize[k]) / sum(datasize)``.
    Only ``tensor * float`` and ``tensor + tensor`` are used, so any type
    supporting those works (e.g. TenSEAL CKKSTensor, or plain floats).

    The inputs are left unmodified.

    Args:
        w: Non-empty sequence (one entry per client) of equal-length lists
            of tensors.
        datasize: Sequence of per-client sample counts, aligned with ``w``.

    Returns:
        A new list holding one averaged tensor per position.

    Raises:
        ValueError: if ``w`` is empty, ``w`` and ``datasize`` differ in
            length, clients supply different numbers of tensors, or the
            total data size is zero.
    """
    if not w:
        raise ValueError("average_weights() needs at least one client's weights")
    if len(w) != len(datasize):
        raise ValueError(
            f"got weights for {len(w)} clients but {len(datasize)} data sizes")
    n_tensors = len(w[0])
    if any(len(client_w) != n_tensors for client_w in w):
        raise ValueError("all clients must provide the same number of weight tensors")
    total = float(sum(datasize))
    if total == 0:
        raise ValueError("total data size must be non-zero")
    scale = 1 / total

    # Build new objects with `*` / `+` instead of `*=` / `+=` so the
    # caller's client weights are untouched. Scaling them in place made a
    # re-run of the averaging cell scale them a second time. The operation
    # order (scale each client, sum, multiply by 1/total) matches the
    # original, so the CKKS multiplicative depth is unchanged.
    w_avg = []
    for i in range(n_tensors):
        acc = w[0][i] * float(datasize[0])
        for client_w, size in zip(w[1:], datasize[1:]):
            acc = acc + client_w[i] * float(size)
        w_avg.append(acc * scale)

    return w_avg

In [10]:
# Per-client sample counts used as averaging weights, in the same order as
# client_list. NOTE(review): presumably each client's training-set size —
# confirm. Exactly one entry per connected client is required.
client_datasizes = [87, 87]
assert len(client_datasizes) == len(client_weight_list), (
    f"{len(client_datasizes)} data sizes for {len(client_weight_list)} clients")
avg_weight = average_weights(client_weight_list, client_datasizes)

In [11]:
# Send the averaged (still encrypted) weights back to every client, then
# close its connection.
try:
    for client in client_list:
        # `with client:` closes this socket as soon as its sends finish
        # (same timing as before), even if a send raises.
        with client:
            for weight in avg_weight:
                send_msg(client, weight)
finally:
    # If a send failed part-way, the remaining clients were never reached;
    # close them too. close() on an already-closed socket is a no-op.
    for client in client_list:
        client.close()

In [12]:
# Redundant safety net: the listening socket was already closed at the end of
# the receive cell. Calling close() again on a closed socket is a harmless no-op.
server_soc.close()