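## Loading the shared context

The first cell is a standalone sanity check of `bz2`, `lzma` and `zlib` on a tiny byte string (all three make it larger). Loading the shared TenSEAL context starts after the second imports cell.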

In [1]:
import bz2
import lzma
import zlib
import sys

# Sample data (replace with your actual data)
data = b"This is some data that we want to compress significantly."

# Measure the payload with len(): sys.getsizeof() would also count the bytes
# object's own header and inflate every size reported below.
original_size = len(data)


def space_savings(original, compressed):
    """Return the percentage of bytes saved by compression.

    Args:
        original: Size of the uncompressed payload in bytes.
        compressed: Size of the compressed payload in bytes.

    Returns:
        ``(original - compressed) / original * 100``. Negative when the
        compressed output is larger than the input (as it is for every codec
        on this short sample). Returns ``0.0`` for an empty payload, where the
        ratio is undefined.
    """
    if original == 0:
        return 0.0
    return (original - compressed) / original * 100


# Compress the same payload with each codec
compressed_bz2 = bz2.compress(data)
compressed_lzma = lzma.compress(data)  # LZMA algorithm
compressed_zlib = zlib.compress(data)

# Compressed data size (payload bytes only)
compressed_size_bz2 = len(compressed_bz2)
compressed_size_lzma = len(compressed_lzma)
compressed_size_zlib = len(compressed_zlib)

# Percentage of space saved; negative means the output grew
compression_ratio_bz2 = space_savings(original_size, compressed_size_bz2)
compression_ratio_lzma = space_savings(original_size, compressed_size_lzma)
compression_ratio_zlib = space_savings(original_size, compressed_size_zlib)


print("Original data size:", original_size, "bytes")
print("Compressed size (bz2):", compressed_size_bz2, "bytes")
print("Compression ratio (bz2):", compression_ratio_bz2, "%")
print("Compressed size (lzma):", compressed_size_lzma, "bytes")
print("Compression ratio (lzma):", compression_ratio_lzma, "%")
print("Compressed size (zlib):", compressed_size_zlib, "bytes")
print("Compression ratio (zlib):", compression_ratio_zlib, "%")

Original data size: 57 bytes
Compressed size (bz2): 86 bytes
Compression ratio (bz2): -50.877192982456144 %
Compressed size (lzma): 116 bytes
Compression ratio (lzma): -103.50877192982458 %
Compressed size (zlib): 61 bytes
Compression ratio (zlib): -7.017543859649122 %


In [1]:
import pickle
import tenseal as ts
import torch
from torch import nn
import socket
import struct
from torchvision import models
import numpy as np
import zlib
import gzip
import sys

In [2]:
# Load the serialized TenSEAL context from disk; it is turned back into a
# context object in the next cell.
# NOTE(review): pickle.load can execute arbitrary code, so only load a
# shared_context.pkl that comes from a trusted party.
with open('shared_context.pkl', 'rb') as inp:
    shared_context_bin = pickle.load(inp)

In [3]:
# Rebuild the TenSEAL context object from the loaded data
# (presumably the output of Context.serialize(); confirm against the producer).
shared_context = ts.context_from(shared_context_bin)

In [4]:
# Secret key used below to decrypt local and averaged tensors.
# NOTE(review): this call succeeds, so the loaded context carries the secret
# key. Make sure any copy of the context handed to the aggregation server
# has the secret key removed.
sk = shared_context.secret_key()

In [5]:
# Sanity check that a key object exists; this prints only the object's repr
# (type and address), not the key material.
print(sk)

<tenseal.enc_context.SecretKey object at 0x778010250340>


In [6]:
# Pretrained SqueezeNet 1.1 feature extractor. Use the weights enum:
# `weights=True` is a deprecated legacy spelling that torchvision (>=0.13)
# maps, with a warning, to these same default ImageNet weights.
model = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.DEFAULT)



In [7]:
# Replace the classifier's final three layers with a fresh 1x1-conv head that
# outputs `num_classes` channels, followed by ReLU and global average pooling.
num_classes = 3
in_ftrs = model.classifier[1].in_channels  # channels entering the old 1x1 conv
features = list(model.classifier.children())[:-3] + [
    nn.Conv2d(in_ftrs, num_classes, kernel_size=1),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
]
model.classifier = nn.Sequential(*features)

In [8]:
# Weight and bias of the new head's conv (classifier index 1), in that order;
# these are the tensors encrypted and sent to the server below.
last_layer_list = [
    model.state_dict()[f'classifier.1.{suffix}'] for suffix in ('weight', 'bias')
]
last_layer_list

[tensor([[[[-0.0006]],
 
          [[ 0.0043]],
 
          [[ 0.0222]],
 
          ...,
 
          [[-0.0114]],
 
          [[ 0.0281]],
 
          [[-0.0097]]],
 
 
         [[[-0.0382]],
 
          [[-0.0212]],
 
          [[-0.0065]],
 
          ...,
 
          [[ 0.0197]],
 
          [[-0.0307]],
 
          [[ 0.0123]]],
 
 
         [[[-0.0120]],
 
          [[ 0.0236]],
 
          [[-0.0177]],
 
          ...,
 
          [[ 0.0085]],
 
          [[ 0.0441]],
 
          [[-0.0287]]]]),
 tensor([-0.0339,  0.0431,  0.0423])]

In [9]:
# Show each parameter's dtype before CKKS encoding (float32 in the run shown).
for tens in last_layer_list:
    print(tens.dtype)

torch.float32
torch.float32


In [10]:
%%time
encrypted_lll = []

for param in last_layer_list:
    plain_ten = ts.plain_tensor(param)
    encrypted_ten = ts.ckks_tensor(shared_context, plain_ten)
    encrypted_lll.append(encrypted_ten)

CPU times: user 14.4 s, sys: 1.39 s, total: 15.8 s
Wall time: 1.73 s


In [16]:
def print_tens_size(tens_list, size_fn=sys.getsizeof):
    """Print the size of each item in ``tens_list``, then their total.

    Args:
        tens_list: Iterable of objects to measure (CKKS tensors, serialized
            ``bytes`` or compressed ``bytes`` in this notebook).
        size_fn: Callable returning the size in bytes of one item. Defaults to
            ``sys.getsizeof``, which for a CKKSTensor reports only the small
            Python object (48 bytes each in this notebook's output) and for
            ``bytes`` adds Python object overhead. Pass ``len`` to measure the
            actual payload of ``bytes`` items; CKKSTensor has no ``len()``.

    Returns:
        None; the sizes are printed.
    """
    total_size = 0
    for param in tens_list:
        ten_size = size_fn(param)
        print(f'size of tensor: {ten_size}')
        total_size += ten_size

    print(f'total size of list: {total_size}')

In [12]:
# Confirm what ts.ckks_tensor returned (a TenSEAL CKKSTensor in the run shown).
type(encrypted_lll[0])

tenseal.tensors.ckkstensor.CKKSTensor

In [19]:
# sys.getsizeof appears to measure only the Python-side CKKSTensor object
# (48 bytes each below), not the ciphertext data. Compare with the serialized
# sizes (hundreds of MB) printed further down, which are the meaningful ones.
print_tens_size(encrypted_lll)

size of tensor: 48
size of tensor: 48
total size of list: 96


In [13]:
# Serialize each ciphertext to bytes; send_msg transmits this same form.
serialized_list = [param.serialize() for param in encrypted_lll]

In [14]:
# serialize() yields plain bytes (see output), so len() is meaningful here.
type(serialized_list[0])

bytes

In [15]:
# NOTE(review): print_tens_size as defined above uses sys.getsizeof, not len().
# The output below came from an earlier len()-based version: each value is
# 33 bytes smaller than the sys.getsizeof output of the next cell. Re-running
# this cell now prints the getsizeof sizes.
print_tens_size(serialized_list) # output below: len()-based sizes

size of tensor: 424486569
size of tensor: 829142
total size of list: 425315711


In [17]:
# Same data measured with sys.getsizeof: each value here is 33 bytes larger
# than the len()-based output of the previous cell (Python bytes-object overhead).
print_tens_size(serialized_list) # using sys.getsizeof()

size of tensor: 424486602
size of tensor: 829175
total size of list: 425315777


In [31]:
%%time
compressed_list = []
for param in serialized_list:
    compressed_list.append(zlib.compress(param))
    

CPU times: user 21 s, sys: 460 ms, total: 21.4 s
Wall time: 21.4 s


In [33]:
# Compare with the sys.getsizeof output for serialized_list: in the run below
# the total grew by about 52 KB, so zlib does not shrink these ciphertexts
# (and the compress step took ~21 s).
print_tens_size(compressed_list)

size of tensor: 424538367
size of tensor: 829153
total size of list: 425367520


In [12]:
# Decrypt the local ciphertexts with the secret key and convert back to torch.
# NOTE(review): the printed values differ from the last_layer_list output
# above; the execution counts suggest the cells ran out of order. Re-run top
# to bottom to compare plaintext and decrypted values.
decrypted_lll = [
    torch.tensor(param.decrypt(sk).tolist()) for param in encrypted_lll
]
print(decrypted_lll)

[tensor([[[[ 0.0058]],

         [[ 0.0290]],

         [[-0.0363]],

         ...,

         [[-0.0277]],

         [[-0.0359]],

         [[-0.0076]]],


        [[[-0.0031]],

         [[ 0.0136]],

         [[ 0.0031]],

         ...,

         [[ 0.0419]],

         [[ 0.0116]],

         [[-0.0305]]],


        [[[-0.0187]],

         [[ 0.0256]],

         [[-0.0127]],

         ...,

         [[ 0.0164]],

         [[-0.0255]],

         [[ 0.0386]]]]), tensor([-0.0091,  0.0322,  0.0281])]


In [13]:
def send_msg(sock, msg):
    """Serialize, zlib-compress and send one tensor as a length-framed message.

    Wire format: a 4-byte big-endian unsigned length (``struct`` format
    ``'>I'``), followed by that many bytes of zlib-compressed
    ``msg.serialize()`` output. ``recv_msg`` reads the same framing.

    Args:
        sock: A connected stream socket.
        msg: Object whose ``serialize()`` returns ``bytes`` (a TenSEAL
            CKKSTensor in this notebook).

    Raises:
        struct.error: If the compressed payload is too large (>= 4 GiB) for
            the 4-byte length prefix.
        OSError: If sending on the socket fails.
    """
    payload = zlib.compress(msg.serialize())
    # Send the prefix and payload separately: concatenating them would copy
    # the (hundreds of MB) payload once more just to prepend 4 bytes.
    # The bytes on the wire are unchanged.
    sock.sendall(struct.pack('>I', len(payload)))
    sock.sendall(payload)
    
def recv_msg(sock, context=None):
    """Receive one length-framed message and rebuild the CKKS tensor it carries.

    Reads a 4-byte big-endian length prefix (the framing ``send_msg``
    writes), then that many bytes of zlib-compressed serialized tensor data.

    Args:
        sock: A connected stream socket.
        context: TenSEAL context used to deserialize the tensor. Defaults to
            the module-level ``shared_context``, looked up at call time.

    Returns:
        The deserialized CKKS tensor, or ``None`` if the peer closed the
        connection before a complete length prefix arrived.

    Raises:
        ConnectionError: If the connection closes after the length prefix
            but before the full payload arrived.
        zlib.error: If the payload is not valid zlib data.
    """
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None

    (msglen,) = struct.unpack('>I', raw_msglen)

    # read the message data
    payload = recvall(sock, msglen)
    if payload is None:
        # Without this check a truncated message would reach
        # zlib.decompress(None) and surface as an unrelated TypeError.
        raise ConnectionError(
            f'connection closed before the full {msglen}-byte message arrived'
        )

    if context is None:
        context = shared_context
    return ts.ckks_tensor_from(context, zlib.decompress(payload))

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Args:
        sock: A connected stream socket (any object with ``recv``).
        n: Number of bytes to read.

    Returns:
        ``bytes`` of length ``n``, or ``None`` if the peer closes the
        connection (``recv`` returns ``b''``) before ``n`` bytes have arrived.
        ``n == 0`` returns ``b''`` without reading.
    """
    # Accumulate into a bytearray (amortized O(1) appends). `bytes += packet`
    # copies the whole buffer on every recv, which is quadratic for the
    # ~400 MB ciphertext messages this notebook transfers.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)

In [14]:
# The server is presumably running on this same machine: resolve this host's
# own IPv4 address. The port is fixed and must match the server's listener.
port = 10080
host = socket.gethostbyname(socket.gethostname())
print(host)

192.168.0.245


In [15]:
# Open a TCP connection to the server. socket.create_connection closes the
# socket itself if the connect fails, so a refused connection does not leave
# an open, unconnected socket behind in the kernel (host is an IPv4 literal,
# so this is an AF_INET / SOCK_STREAM socket as before).
client_soc = socket.create_connection((host, port))

In [16]:
# Send every encrypted parameter, then read back one tensor per parameter sent
# (presumably the server-side average). The with-block closes the socket even
# if a send or receive fails, so a failed exchange does not leak the connection.
with client_soc:
    for tens in encrypted_lll:
        send_msg(client_soc, tens)

    avg_weights = []
    # One reply per sent tensor, instead of a hard-coded count of 2.
    for _ in range(len(encrypted_lll)):
        weight = recv_msg(client_soc)
        if weight is None:
            # Fail here rather than appending None and crashing later at decrypt.
            raise ConnectionError(
                f'server closed the connection after {len(avg_weights)} '
                f'of {len(encrypted_lll)} tensors'
            )
        avg_weights.append(weight)

In [17]:
# Decrypt the tensors returned by the server with the local secret key and
# convert them back to torch tensors (same order as sent: weight, then bias).
decrypted_avg_weights = [
    torch.tensor(param.decrypt(sk).tolist()) for param in avg_weights
]
print(decrypted_avg_weights)

[tensor([[[[-0.0037]],

         [[ 0.0061]],

         [[-0.0246]],

         ...,

         [[-0.0045]],

         [[-0.0204]],

         [[ 0.0091]]],


        [[[-0.0063]],

         [[ 0.0285]],

         [[ 0.0082]],

         ...,

         [[ 0.0300]],

         [[ 0.0202]],

         [[-0.0135]]],


        [[[ 0.0027]],

         [[ 0.0333]],

         [[ 0.0139]],

         ...,

         [[ 0.0109]],

         [[-0.0266]],

         [[ 0.0146]]]]), tensor([0.0111, 0.0140, 0.0011])]
