## Loading the shared context

In [1]:
import pickle
import tenseal as ts
import torch
from torch import nn
import socket
import struct
from torchvision import models
import numpy as np

In [2]:
# Read the serialized shared TenSEAL context from disk.
# NOTE(review): pickle.load can execute arbitrary code while unpickling — only
# load a shared_context.pkl that was produced by a trusted party.
with open('shared_context.pkl', 'rb') as context_file:
    shared_context_bin = pickle.load(context_file)

In [3]:
# Rebuild the TenSEAL context object from the serialized bytes unpickled above;
# the same context is later used to encrypt parameters and deserialize replies.
shared_context = ts.context_from(shared_context_bin)

In [4]:
# Secret key used below to decrypt both the local round-trip check and the
# tensors returned by the server.
# NOTE(review): this only works if the loaded context is private (carries the secret key).
sk = shared_context.secret_key()

In [5]:
# Prints only the SecretKey object's repr (type and memory address), not key
# material — a quick confirmation that a key object was obtained.
print(sk)

<tenseal.enc_context.SecretKey object at 0x730de8f84820>


In [6]:
# ImageNet-pretrained SqueezeNet 1.1 backbone. The `weights` argument takes a
# SqueezeNet1_1_Weights enum member (or its name as a string, or None), not a
# bool; DEFAULT selects the best available pretrained weights.
model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.DEFAULT)



In [7]:
# Replace the 1000-class ImageNet head with a `num_classes`-way head.
# torchvision's SqueezeNet classifier is Dropout -> Conv2d(1x1) -> ReLU -> AdaptiveAvgPool2d,
# so dropping the last 3 modules keeps only the Dropout and the new 1x1 conv lands at
# index 1 (read back below as 'classifier.1.weight' / 'classifier.1.bias').
num_classes = 3
in_ftrs = model.classifier[1].in_channels  # channels feeding the original final conv
features = list(model.classifier.children())[:-3]  # drop final conv, ReLU and pooling
features.extend([
    nn.Conv2d(in_ftrs, num_classes, kernel_size=1),  # new, randomly initialised head conv
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
])
model.classifier = nn.Sequential(*features)
# Keep the model's stored class count consistent with the rebuilt head.
model.num_classes = num_classes

In [8]:
# Collect the new head conv's parameters (weight, then bias) — these are what get encrypted.
model_state = model.state_dict()
last_layer_list = [model_state[f'classifier.1.{part}'] for part in ('weight', 'bias')]
last_layer_list

[tensor([[[[ 0.0228]],
 
          [[ 0.0031]],
 
          [[-0.0245]],
 
          ...,
 
          [[-0.0427]],
 
          [[-0.0325]],
 
          [[ 0.0223]]],
 
 
         [[[ 0.0228]],
 
          [[-0.0250]],
 
          [[-0.0148]],
 
          ...,
 
          [[-0.0296]],
 
          [[-0.0300]],
 
          [[ 0.0411]]],
 
 
         [[[ 0.0044]],
 
          [[-0.0181]],
 
          [[ 0.0328]],
 
          ...,
 
          [[-0.0306]],
 
          [[ 0.0186]],
 
          [[-0.0204]]]]),
 tensor([ 0.0204, -0.0089, -0.0177])]

In [9]:
# Show the dtype of each head parameter that is about to be CKKS-encoded.
for dtype in (tensor.dtype for tensor in last_layer_list):
    print(dtype)

torch.float32
torch.float32


In [10]:
%%time
encrypted_lll = []
ten_shapes = []
for param in last_layer_list:
    plain_ten = ts.plain_tensor(param)
    encrypted_ten = ts.ckks_tensor(shared_context, plain_ten)
    encrypted_lll.append(encrypted_ten)

CPU times: user 8 s, sys: 644 ms, total: 8.65 s
Wall time: 842 ms


In [11]:
# Round-trip check: decrypt the freshly encrypted parameters with the local secret key.
# The values should match last_layer_list up to CKKS approximation error.
decrypted_lll = [torch.tensor(ciphertext.decrypt(sk).tolist()) for ciphertext in encrypted_lll]
print(decrypted_lll)

[tensor([[[[ 0.0227]],

         [[ 0.0032]],

         [[-0.0246]],

         ...,

         [[-0.0426]],

         [[-0.0325]],

         [[ 0.0224]]],


        [[[ 0.0228]],

         [[-0.0249]],

         [[-0.0150]],

         ...,

         [[-0.0297]],

         [[-0.0299]],

         [[ 0.0411]]],


        [[[ 0.0045]],

         [[-0.0181]],

         [[ 0.0328]],

         ...,

         [[-0.0307]],

         [[ 0.0185]],

         [[-0.0203]]]]), tensor([ 0.0205, -0.0091, -0.0177])]


In [12]:
def send_msg(sock, msg):
    """Serialize ``msg`` and send it over ``sock`` as one length-prefixed frame.

    Frame layout: a 4-byte unsigned length in network (big-endian) byte order,
    followed by the bytes returned by ``msg.serialize()``.
    """
    payload = msg.serialize()
    header = struct.pack('>I', len(payload))
    sock.sendall(header + payload)
    
def recv_msg(sock, context=None):
    """Receive one length-prefixed CKKS tensor from ``sock``.

    Counterpart of ``send_msg``: reads a 4-byte big-endian length header, then
    exactly that many payload bytes, and deserializes the payload with
    ``ts.ckks_tensor_from``.

    Args:
        sock: connected stream socket to read from.
        context: TenSEAL context used to deserialize the tensor. Defaults to the
            notebook-level ``shared_context`` (looked up at call time).

    Returns:
        The deserialized CKKS tensor, or ``None`` if the peer closed the
        connection before a complete message (header or body) was received.
    """
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None

    msglen = struct.unpack('>I', raw_msglen)[0]

    # read the message data
    payload = recvall(sock, msglen)
    if payload is None:
        # Connection dropped mid-message: report EOF like a missing header
        # instead of handing None to the deserializer.
        return None

    if context is None:
        context = shared_context
    return ts.ckks_tensor_from(context, payload)

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Args:
        sock: connected stream socket (anything with a ``recv(bufsize)`` method).
        n: number of bytes to read.

    Returns:
        The ``n`` bytes as ``bytes``, or ``None`` if the peer closes the
        connection (``recv`` returns an empty chunk) before ``n`` bytes arrive.
    """
    # Accumulate into a mutable bytearray: ``bytes += chunk`` copies the whole
    # buffer on every chunk, which is quadratic for large payloads such as
    # serialized ciphertexts.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)

In [13]:
port = 10080  # TCP port the server listens on; must match the server side
# This machine's own address: the client connects to (host, port), so the
# server is expected to run on the same host.
host = socket.gethostbyname(socket.gethostname())
print(host)

192.168.0.116


In [14]:
# Open an IPv4 TCP connection to the server; it stays open for the exchange cell below.
server_address = (host, port)
client_soc = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
client_soc.connect(server_address)

In [15]:
# Send every encrypted head parameter to the server, then read back the
# server's tensors. The socket is closed even if a send, receive or
# deserialization step fails.
try:
    for tens in encrypted_lll:
        send_msg(client_soc, tens)
    avg_weights = []
    # NOTE(review): assumes the server replies with one tensor per parameter
    # sent (2 here: weight and bias) — the count was previously hard-coded.
    for _ in range(len(encrypted_lll)):
        weight = recv_msg(client_soc)
        if weight is None:
            raise ConnectionError(
                'server closed the connection before all weights were received'
            )
        avg_weights.append(weight)
finally:
    client_soc.close()

In [16]:
# Decrypt the tensors returned by the server (avg_weights — presumably the
# aggregate across clients) with the local secret key.
decrypted_avg_weights = list(
    map(lambda ciphertext: torch.tensor(ciphertext.decrypt(sk).tolist()), avg_weights)
)
print(decrypted_avg_weights)

[tensor([[[[ 0.0160]],

         [[ 0.0025]],

         [[-0.0252]],

         ...,

         [[-0.0331]],

         [[-0.0214]],

         [[ 0.0309]]],


        [[[ 0.0309]],

         [[-0.0063]],

         [[ 0.0139]],

         ...,

         [[-0.0302]],

         [[ 0.0053]],

         [[ 0.0249]]],


        [[[-0.0136]],

         [[-0.0204]],

         [[ 0.0294]],

         ...,

         [[ 0.0040]],

         [[ 0.0173]],

         [[-0.0326]]]]), tensor([ 0.0036,  0.0034, -0.0304])]
