## Loading the shared context

In [1]:
import pickle
import dill
import tenseal as ts
import torch
from torch import nn
import socket
import struct
from torchvision import models
import numpy as np

In [2]:
# Read the serialized TenSEAL context bytes (presumably written by the key-owning party's notebook).
# NOTE(review): pickle.load can execute arbitrary code - only load a file you produced/trust.
with open('shared_context.pkl', 'rb') as ctx_file:
    shared_context_bin = pickle.load(ctx_file)

In [3]:
# Rebuild the live TenSEAL context object from the serialized bytes loaded above.
shared_context = ts.context_from(shared_context_bin)

In [4]:
# Extract the secret key used later to decrypt tensors. Presumably this raises if the
# pickled context was made public (no secret key) - confirm against the TenSEAL docs.
sk = shared_context.secret_key()

In [5]:
# Sanity check that a key object was obtained; this prints only the object's repr, not key material.
print(sk)

<tenseal.enc_context.SecretKey object at 0x7779a0679000>


In [6]:
# Load SqueezeNet 1.1 with its pretrained ImageNet weights. `weights=True` is the deprecated
# legacy spelling (torchvision >= 0.13 emits a deprecation warning and maps it to the default
# weights); the explicit weights enum is the supported form and selects the same weights.
model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.DEFAULT)



In [7]:
# Re-head the network for `num_classes` outputs: drop the classifier's last three layers and
# append a fresh 1x1 conv -> ReLU -> global average pool. The new conv ends up at index 1,
# which is what the later `classifier.1.*` state_dict lookups rely on.
num_classes = 3
in_ftrs = model.classifier[1].in_channels  # input channels of the original 1x1 classifier conv
kept_layers = list(model.classifier.children())[:-3]
features = kept_layers + [
    nn.Conv2d(in_ftrs, num_classes, kernel_size=1),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
]
model.classifier = nn.Sequential(*features)

In [8]:
# Parameters of the newly added classifier conv (weight, then bias) - the only part that gets encrypted.
model_state = model.state_dict()
last_layer_list = [
    model_state['classifier.1.weight'],
    model_state['classifier.1.bias'],
]
last_layer_list

[tensor([[[[-0.0351]],
 
          [[ 0.0412]],
 
          [[ 0.0206]],
 
          ...,
 
          [[-0.0113]],
 
          [[ 0.0220]],
 
          [[ 0.0247]]],
 
 
         [[[-0.0062]],
 
          [[ 0.0232]],
 
          [[ 0.0230]],
 
          ...,
 
          [[ 0.0226]],
 
          [[ 0.0009]],
 
          [[-0.0377]]],
 
 
         [[[ 0.0094]],
 
          [[-0.0060]],
 
          [[ 0.0355]],
 
          ...,
 
          [[-0.0034]],
 
          [[ 0.0231]],
 
          [[-0.0250]]]]),
 tensor([ 0.0307, -0.0359, -0.0383])]

In [9]:
%%time
# Encrypt each last-layer parameter as a CKKS tensor under the shared context,
# preserving the order of `last_layer_list` (weight first, then bias).
encrypted_lll = [
    ts.ckks_tensor(shared_context, ts.plain_tensor(param))
    for param in last_layer_list
]
ten_shapes = []  # NOTE(review): never populated in this cell; kept in case a later cell reads it

CPU times: user 12.5 s, sys: 1.34 s, total: 13.8 s
Wall time: 1.52 s


In [10]:
# Round-trip check: decrypt each encrypted parameter with the secret key and rebuild a torch tensor.
decrypted_lll = [
    torch.tensor(enc_param.decrypt(sk).tolist()) for enc_param in encrypted_lll
]
print(decrypted_lll)

[tensor([[[[-0.0351]],

         [[ 0.0412]],

         [[ 0.0206]],

         ...,

         [[-0.0113]],

         [[ 0.0220]],

         [[ 0.0247]]],


        [[[-0.0062]],

         [[ 0.0232]],

         [[ 0.0230]],

         ...,

         [[ 0.0226]],

         [[ 0.0009]],

         [[-0.0377]]],


        [[[ 0.0094]],

         [[-0.0060]],

         [[ 0.0355]],

         ...,

         [[-0.0034]],

         [[ 0.0231]],

         [[-0.0250]]]]), tensor([ 0.0307, -0.0359, -0.0383])]


In [11]:
def send_msg(sock, msg):
    """Serialize ``msg`` and send it over ``sock`` as a single length-prefixed frame.

    Frame layout: a 4-byte unsigned big-endian (network order) length header,
    followed by the serialized payload bytes.
    """
    payload = msg.serialize()
    header = struct.pack('>I', len(payload))
    sock.sendall(header + payload)
    
def recv_msg(sock, context=None):
    """Receive one length-prefixed frame from ``sock`` and deserialize it as a CKKS tensor.

    Frame layout matches ``send_msg``: a 4-byte big-endian length header, then the payload.

    Args:
        sock: connected socket to read from.
        context: TenSEAL context used for deserialization. Defaults to the
            notebook-global ``shared_context`` (the original behavior).

    Returns:
        The deserialized tensor, or ``None`` if the peer closed the connection before a
        complete frame (header or payload) arrived.
    """
    raw_msglen = recvall(sock, 4)
    if raw_msglen is None:
        return None
    (msglen,) = struct.unpack('>I', raw_msglen)

    payload = recvall(sock, msglen)
    if payload is None:
        # Truncated frame: report EOF instead of handing None to the deserializer.
        return None

    if context is None:
        context = shared_context
    return ts.ckks_tensor_from(context, payload)

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Returns:
        The ``n`` bytes read (as ``bytes``), or ``None`` if the peer closed the
        connection (``recv`` returned an empty result) before ``n`` bytes arrived.
        ``n == 0`` returns ``b''`` without calling ``recv``.
    """
    # Accumulate into a bytearray: ``bytes += bytes`` copies the whole buffer on every
    # chunk, which is quadratic when a large payload arrives in many chunks.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)

In [12]:
port = 10080
# Resolve this machine's own address; the server is assumed to listen on the same host.
host = socket.gethostbyname(socket.gethostname())
print(host)

192.168.0.116


In [13]:
# Open an IPv4 TCP (stream) connection to the aggregation server at (host, port).
client_soc = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
client_soc.connect(
    (host, port)
)

In [14]:
# Send every encrypted parameter tensor, then read back the averaged tensors.
# `with` closes the socket even if sending/receiving raises; previously a failure
# (e.g. the NameError below) left the connection open for the rest of the session.
with client_soc:
    for tens in encrypted_lll:
        send_msg(client_soc, tens)
    avg_weights = []
    # Assumes the server replies with one averaged tensor per tensor sent
    # (was a hard-coded 2, i.e. weight + bias).
    for _ in range(len(encrypted_lll)):
        weight = recv_msg(client_soc)
        if weight is None:
            raise ConnectionError(
                f"server closed the connection after {len(avg_weights)} of "
                f"{len(encrypted_lll)} averaged tensors"
            )
        avg_weights.append(weight)

NameError: name 'public_context' is not defined

In [None]:
# Decrypt the server-averaged parameters locally with the secret key.
decrypted_avg_weights = [
    torch.tensor(avg_param.decrypt(sk).tolist()) for avg_param in avg_weights
]
print(decrypted_avg_weights)