## Loading the shared context

In [1]:
# import bz2
# import lzma
# import zlib
# import sys

# # Sample data (replace with your actual data)
# data = b"This is some data that we want to compress significantly."

# # Original data size
# original_size = sys.getsizeof(data)
# original_size = len(data)


# # Compress using bz2
# compressed_bz2 = bz2.compress(data)

# # Compress using lzma (LZMA algorithm)
# compressed_lzma = lzma.compress(data)

# compressed_zlib = zlib.compress(data)

# # Compressed data size
# # compressed_size_bz2 = sys.getsizeof(compressed_bz2)
# # compressed_size_lzma = sys.getsizeof(compressed_lzma)
# # compressed_size_zlib = sys.getsizeof(compressed_zlib)
# compressed_size_bz2 = len(compressed_bz2)
# compressed_size_lzma = len(compressed_lzma)
# compressed_size_zlib = len(compressed_zlib)


# # Calculate compression ratios
# compression_ratio_bz2 = (original_size - compressed_size_bz2) / original_size * 100
# compression_ratio_lzma = (original_size - compressed_size_lzma) / original_size * 100
# compression_ratio_zlib = (original_size - compressed_size_zlib) / original_size * 100


# print("Original data size:", original_size, "bytes")
# print("Compressed size (bz2):", compressed_size_bz2, "bytes")
# print("Compression ratio (bz2):", compression_ratio_bz2, "%")
# print("Compressed size (lzma):", compressed_size_lzma, "bytes")
# print("Compression ratio (lzma):", compression_ratio_lzma, "%")
# print("Compressed size (zlib):", compressed_size_zlib, "bytes")
# print("Compression ratio (zlib):", compression_ratio_zlib, "%")

In [2]:
import pickle
import tenseal as ts
import torch
from torch import nn
import socket
import struct
from torchvision import models
import numpy as np
import zlib
import gzip
import sys

In [3]:
with open('shared_context.pkl', 'rb') as inp:
    shared_context_bin = pickle.load(inp)

In [4]:
shared_context = ts.context_from(shared_context_bin)

In [5]:
sk = shared_context.secret_key()

In [6]:
print(sk)

<tenseal.enc_context.SecretKey object at 0x7a19b3a630d0>


In [7]:
model = models.squeezenet1_1(weights=True)



In [8]:
num_classes = 2
in_ftrs = model.classifier[1].in_channels
features = list(model.classifier.children())[:-3] # Remove last 3 layers
features.extend([nn.Conv2d(in_ftrs, num_classes, kernel_size=1)]) # Add
features.extend([nn.ReLU(inplace=True)]) # Add
features.extend([nn.AdaptiveAvgPool2d(output_size=(1,1))]) # Add
model.classifier = nn.Sequential(*features)

In [9]:
last_layer_list = [model.state_dict()['classifier.1.weight'], model.state_dict()['classifier.1.bias']]
last_layer_list

[tensor([[[[-0.0134]],
 
          [[ 0.0313]],
 
          [[-0.0142]],
 
          ...,
 
          [[-0.0437]],
 
          [[ 0.0002]],
 
          [[ 0.0244]]],
 
 
         [[[ 0.0061]],
 
          [[-0.0251]],
 
          [[ 0.0036]],
 
          ...,
 
          [[-0.0420]],
 
          [[-0.0269]],
 
          [[ 0.0079]]]]),
 tensor([0.0145, 0.0152])]

In [10]:
for tens in last_layer_list:
    print(tens.dtype)

torch.float32
torch.float32


In [11]:
%%time
encrypted_lll = []

for param in last_layer_list:
    plain_ten = ts.plain_tensor(param)
    encrypted_ten = ts.ckks_tensor(shared_context, plain_ten)
    encrypted_lll.append(encrypted_ten)

CPU times: user 5.4 s, sys: 405 ms, total: 5.81 s
Wall time: 562 ms


In [12]:
def print_tens_size(tens_list):
    total_size = 0
    for param in tens_list:
        ten_size = sys.getsizeof(param)
        # ten_size = len(param) # TypeError: object of type 'CKKSTensor' has no len()
        print(f'size of tensor: {ten_size}')
        total_size += ten_size
    
    print(f'total size of list: {total_size}')

In [13]:
type(encrypted_lll[0])

tenseal.tensors.ckkstensor.CKKSTensor

In [14]:
print_tens_size(encrypted_lll)

size of tensor: 48
size of tensor: 48
total size of list: 96


In [15]:
# serialized_list = []
# for param in encrypted_lll:
#     serialized_list.append(param.serialize())

In [16]:
# type(serialized_list[0])

In [17]:
# print_tens_size(serialized_list) # using len()

In [18]:
# print_tens_size(serialized_list) # using sys.getsizeof()

In [19]:
# %%time
# compressed_list = []
# for param in serialized_list:
#     compressed_list.append(zlib.compress(param))
    

In [20]:
# print_tens_size(compressed_list)

In [21]:
decrypted_lll = []
for param in encrypted_lll:
    decrypted_lll.append(torch.tensor(param.decrypt(sk).tolist()))
print(decrypted_lll)

[tensor([[[[-0.0136]],

         [[ 0.0306]],

         [[-0.0152]],

         ...,

         [[-0.0438]],

         [[-0.0008]],

         [[ 0.0243]]],


        [[[ 0.0059]],

         [[-0.0252]],

         [[ 0.0039]],

         ...,

         [[-0.0427]],

         [[-0.0266]],

         [[ 0.0073]]]]), tensor([0.0135, 0.0147])]


In [22]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = msg.serialize()
    msg = zlib.compress(msg)
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)
    
def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    
    msglen = struct.unpack('>I', raw_msglen)[0]

    # read the message data
    msg =  recvall(sock, msglen)
    msg = zlib.decompress(msg)
    msg = ts.ckks_tensor_from(shared_context, msg)
    
    return msg

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
  
    return data

In [23]:
host = socket.gethostbyname(socket.gethostname())
port = 10080
print(host)

192.168.0.245


In [24]:
client_soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client_soc.connect((host, port))

In [25]:
for tens in encrypted_lll:
    send_msg(client_soc, tens)
avg_weights = []
for i in range(2):
    weight = recv_msg(client_soc)
    avg_weights.append(weight)
client_soc.close()

In [26]:
decrypted_avg_weights = []
for param in avg_weights:
    decrypted_avg_weights.append(torch.tensor(param.decrypt(sk).tolist()))
print(decrypted_avg_weights)

[tensor([[[[-0.0168]],

         [[ 0.0446]],

         [[-0.0057]],

         ...,

         [[-0.0585]],

         [[-0.0152]],

         [[ 0.0124]]],


        [[[ 0.0136]],

         [[ 0.0088]],

         [[ 0.0309]],

         ...,

         [[-0.0260]],

         [[-0.0102]],

         [[ 0.0132]]]]), tensor([ 0.0213, -0.0144])]
