In [1]:
!pip install tenseal

Collecting tenseal
  Downloading tenseal-0.3.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.2 kB)
Downloading tenseal-0.3.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tenseal
Successfully installed tenseal-0.3.14


In [2]:
import torch.nn as nn
import numpy as np
import time
import torch
import copy
import torch.nn.functional as F
from collections import OrderedDict
import matplotlib.pyplot as plt
import tenseal as ts

In [3]:
import tenseal as ts

def context():
    context = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60])
    context.global_scale = pow(2, 52)
    context.generate_galois_keys()
    return context

context = context()

# Serialize context to a file
with open("context.tenseal", "wb") as f:
    f.write(context.serialize())

# Make context public and serialize the public context
context.make_context_public()
with open("public_context.tenseal", "wb") as f:
    f.write(context.serialize())

# Deserialize context from a file
with open("context.tenseal", "rb") as f:
    context = ts.context_from(f.read())

# Deserialize public context from a file
with open("public_context.tenseal", "rb") as f:
    public_context = ts.context_from(f.read())

# Optionally, you can make the public context again
public_context.make_context_public()

In [4]:
# Encrypt data with the public context
def encrypt_data(public_context, data):
    enc_data = ts.ckks_vector(public_context, data)
    return enc_data

data = [1.0, 2.0, 3.0, 4.0]
enc_data = encrypt_data(public_context, data)

In [5]:
import tenseal as ts
import numpy as np
import pickle

# Function to create context with secret key
def create_context():
    context = ts.context(ts.SCHEME_TYPE.CKKS, 8192, coeff_mod_bit_sizes=[60, 40, 40, 60])
    context.global_scale = pow(2, 52)
    context.generate_galois_keys()
    return context

# Create and save context
context = create_context()

# Serialize and save context with secret key using pickle
with open("context.pkl", "wb") as f:
    pickle.dump(context.serialize(save_secret_key=True), f)

# Make the context public and save it using pickle
context.make_context_public()
with open("public_context.pkl", "wb") as f:
    pickle.dump(context.serialize(), f)

# Function to encrypt data
def encrypt_data(context, data):
    return ts.ckks_vector(context, data)

# Function to decrypt data
def decrypt_data(context, enc_data):
    enc_data.link_context(context)
    dec_data = enc_data.decrypt()
    return dec_data

# Encrypt some data
data = [1.0, 2.0, 3.0, 4.0]
enc_data = encrypt_data(context, data)

# Save the encrypted data using pickle
with open("encrypted_data.pkl", "wb") as f:
    pickle.dump(enc_data.serialize(), f)

# Load context with secret key using pickle
with open("context.pkl", "rb") as f:
    context = ts.context_from(pickle.load(f))

# Load encrypted data using pickle
with open("encrypted_data.pkl", "rb") as f:
    enc_data = ts.ckks_vector_from(context, pickle.load(f))

# Decrypt the data
dec_data = decrypt_data(context, enc_data)
print("Decrypted data:", dec_data)


Decrypted data: [0.999999999999791, 2.0000000000001656, 3.0000000000006466, 3.999999999999131]


In [6]:
from typing import Dict
class PartialModelHandler:
    def __init__(self, num_segments: int):
        self.num_segments = num_segments

    def flatten_resnet_parameters(self, state_dict):
        """
        Flatten and concatenate all parameters from a state_dict into a single vector.
        """
        flat_params = torch.cat([p.view(-1).float() for p in state_dict.values()])
        return flat_params

    def reconstruct_parameters(self, flat_params, shapes, sizes, trained_weights):
        """
        Reconstruct the original tensors from a flattened parameter tensor.
        """
        reconstructed_params = OrderedDict()
        offset = 0
        for key in shapes:
            num_elements = sizes[key]
            param_slice = flat_params[offset:offset + num_elements]
            reconstructed_params[key] = param_slice.view(shapes[key]).to(trained_weights[key].dtype)
            offset += num_elements
        return reconstructed_params

    def segment_resnet_parameters(self, flat_params):
        """
        Divide the flat parameters into equal segments, ensuring all elements are included.
        """
        total_len = len(flat_params)
        segment_size = total_len // self.num_segments
        remainder = total_len % self.num_segments

        segments = []
        start = 0
        for i in range(self.num_segments):
            end = start + segment_size + (1 if i < remainder else 0)
            segments.append(flat_params[start:end])
            start = end

        return segments

    def preprocess_weights(self, weights: Dict):
        """
        Preprocess the weights of a model to prepare for federated learning.
        """
        flat_params_n = {key: self.flatten_resnet_parameters(value) for key, value in weights.items()}
        segmented_params = {key: self.segment_resnet_parameters(flat_param) for key, flat_param in flat_params_n.items()}
        return flat_params_n, segmented_params

    def extract_shared_segments(self, clients_dict, client_shared_segments):
        """
        Extract segments from each client's data based on the segments they are supposed to share.
        """
        shared_data = {}
        for client_id, segment_id in client_shared_segments.items():
            shared_data[client_id] = {segment_id: clients_dict[client_id][segment_id]}
        return shared_data

    def aggregate_data_by_key(self, shared_data):
        """
        Aggregate data received from clients by key.
        """
        aggregation = {}
        count = {}

        for client_data in shared_data.values():
            for key, values in client_data.items():
                if key in aggregation:
                    aggregation[key] += values
                    count[key] += 1
                else:
                    aggregation[key] = values.clone()
                    count[key] = 1

        for key in aggregation.keys():
            aggregation[key] /= count[key]

        return aggregation

    def handle_partial_updates(self, client_weights, client_segment_map):
        """
        Handle partial updates from clients by averaging the weights.
        """
        segmented_params = self.segment_resnet_parameters(self.flatten_resnet_parameters(client_weights))
        segments_to_send = self.extract_shared_segments(segmented_params, client_segment_map)
        return segments_to_send

    def update_client_models(self, clients_segmented_params, aggregated_segments, client_segment_map):
        """
        Update client models with aggregated segments and average unshared segments.
        """
        updated_params = {}

        for client_id, segments in clients_segmented_params.items():
            updated_segments = []
            shared_segment_index = client_segment_map[client_id]

            for i, segment_data in enumerate(segments):
                if i == shared_segment_index:
                    updated_segments.append(aggregated_segments[shared_segment_index])
                else:
                    updated_segments.append(segment_data)

            updated_params[client_id] = updated_segments

        updated_client_segments = {client_id: torch.cat(segments) for client_id, segments in updated_params.items()}
        return updated_client_segments

    def rotate_shared_segments(self, client_segment_map):
        """
        Rotate the segments that each client shares for the next round.
        """
        for client in client_segment_map:
            client_segment_map[client] = (client_segment_map[client] + 1) % self.num_segments


In [7]:
#ResNet9
def conv_block(in_channels, out_channels, pool=False):
    layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
              nn.BatchNorm2d(out_channels),
              nn.ReLU(inplace=True)]
    if pool: layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

class ResNet9(nn.Module):
    def __init__(
            self,
            input_dim,
            hidden_dims,
            num_classes):
        super(ResNet9, self).__init__()

        self.conv1 = conv_block(input_dim, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))

        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))

        self.classifier = nn.Sequential(nn.AdaptiveMaxPool2d((1,1)),
                                        nn.Flatten(),
                                        nn.Dropout(0.2),
                                        nn.Linear(512, num_classes))

    def forward(self, xb):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.classifier(out)
        return out


In [8]:
import torch

def segment_resnet_parameters(flat_params, num_segments):
    """
    Divide the flat parameters into equal segments, ensuring all elements are included.
    """
    total_len = len(flat_params)
    segment_size = total_len // num_segments
    remainder = total_len % num_segments

    segments = []
    start = 0
    for i in range(num_segments):
        end = start + segment_size + (1 if i < remainder else 0)
        segments.append(flat_params[start:end])
        start = end

    return segments

In [9]:
net = ResNet9(3, [], 10)

In [10]:
from prettytable import PrettyTable
def count_parameters(model):
    table = PrettyTable(["SI No", "Layer Name", "Parameters Listed"])
    t_params = 0
    for si_no, (name, parameter) in enumerate(model.named_parameters()):
        if not parameter.requires_grad: continue
        param = parameter.numel()
        table.add_row([si_no, name, param])
        t_params+=param
    print(table)
    print(f"Sum of trained paramters: {t_params}")
    return t_params

In [11]:
count_parameters(net)

+-------+---------------------+-------------------+
| SI No |      Layer Name     | Parameters Listed |
+-------+---------------------+-------------------+
|   0   |    conv1.0.weight   |        1728       |
|   1   |     conv1.0.bias    |         64        |
|   2   |    conv1.1.weight   |         64        |
|   3   |     conv1.1.bias    |         64        |
|   4   |    conv2.0.weight   |       73728       |
|   5   |     conv2.0.bias    |        128        |
|   6   |    conv2.1.weight   |        128        |
|   7   |     conv2.1.bias    |        128        |
|   8   |   res1.0.0.weight   |       147456      |
|   9   |    res1.0.0.bias    |        128        |
|   10  |   res1.0.1.weight   |        128        |
|   11  |    res1.0.1.bias    |        128        |
|   12  |   res1.1.0.weight   |       147456      |
|   13  |    res1.1.0.bias    |        128        |
|   14  |   res1.1.1.weight   |        128        |
|   15  |    res1.1.1.bias    |        128        |
|   16  |   

6575370

In [12]:
# Calculate the total number of parameters
total_params = sum(p.numel() for p in net.parameters())
total_params

6575370

In [13]:
# Calculate the total number of parameters, including BatchNorm running mean and variance
state_dict = net.state_dict()
total_params = sum(p.numel() for p in state_dict.values())
print(f"Total number of parameters in ResNet9 (including BatchNorm running mean and variance): {total_params}")

Total number of parameters in ResNet9 (including BatchNorm running mean and variance): 6579858


In [14]:
params = net.state_dict()

In [15]:
def calculate_computation(num_partition):
  handler = PartialModelHandler(num_partition)
  num_parameters = len(handler.segment_resnet_parameters(handler.flatten_resnet_parameters(params))[0])


  time_init_s = time.time()
  enc_data = encrypt_data(context, handler.segment_resnet_parameters(handler.flatten_resnet_parameters(params))[0])
  dec_data = decrypt_data(context, enc_data)
  # Save the encrypted data using pickle
  with open(f'resnet9_seg{num_partition}_encrypted_data.pkl', "wb") as f:
      pickle.dump(enc_data.serialize(), f)
  time_init_e = time.time()
  computation_time = time_init_e-time_init_s

  return num_parameters, computation_time

In [16]:
def calculate_communication_bytes(file_path):

  # Check the size of the file
  file_size = os.path.getsize(file_path)

  return file_size/(1024*1024)

In [18]:
import os

partition_sizes = [1,2,3,4,5]
for num in partition_sizes:
  num_params, comp_time = calculate_computation(num)
  print(f'For ParMS {num}, \n \t The total number of parameters is: {num_params}, \n\t The computation time is: {comp_time} seconds')
  cipher_text_size = calculate_communication_bytes(f'resnet9_seg{num}_encrypted_data.pkl')
  print(f'\t The cipher text size is: {cipher_text_size}')
  print(f'\t The plain text size is {25.08/num}')

The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
For ParMS 1, 
 	 The total number of parameters is: 6579858, 
	 The computation time is: 25.21634602546692 seconds
	 The cipher text size is: 512.2990703582764
	 The plain text size is 25.08
The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
For ParMS 2, 
 	 The total number of parameters is: 3289929, 
	 The computation time is: 12.560456037521362 seconds
	 The cipher text size is: 256.30858612060547
	 The plain text size is 12.54
The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your in