Experiment of Encryption for Brain Tumor using CKKS and BFV algorithms

In [3]:
# pip install tenseal

In [4]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import tenseal as ts
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import time

In [6]:
import os
DATA_DIR = r"/home/amint/HEP/Brain Tumor MRI"

# If you're running inside WSL, use the Linux mount path instead:
# DATA_DIR = "/mnt/d/Ascl_Mimic_Data/Brain Tumor MRI"

train_dir = os.path.join(DATA_DIR, "Training")
test_dir  = os.path.join(DATA_DIR, "Testing")


In [7]:
# Transformations
transform = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load dataset
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Split training data into 3 clients
total_size = len(train_dataset)
client_size = total_size // 3
client_datasets = random_split(train_dataset, [client_size, client_size, total_size - 2*client_size])

print(f"Total samples: {total_size}")
print(f"Client 1: {len(client_datasets[0])}, Client 2: {len(client_datasets[1])}, Client 3: {len(client_datasets[2])}")


Total samples: 5712
Client 1: 1904, Client 2: 1904, Client 3: 1904


Define CNN model before starting the training

In [8]:
class BrainTumorCNN(nn.Module):
    def __init__(self, num_classes=4):
        super(BrainTumorCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*16*16, 128)
        self.fc2 = nn.Linear(128, num_classes)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (32,32,32)
        x = self.pool(F.relu(self.conv2(x)))  # (64,16,16)
        x = x.view(-1, 64*16*16)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


Create Homomorphic Encryption Contexts (CKKS & BFV)

In [9]:
# CKKS (supports floats, good for NN weights)
ckks_context = ts.context(ts.SCHEME_TYPE.CKKS, poly_modulus_degree=8192, coeff_mod_bit_sizes=[60, 40, 40, 60])
ckks_context.global_scale = 2**40
ckks_context.generate_galois_keys()

# BFV (supports integers)
bfv_context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=8192, plain_modulus=1032193)
bfv_context.generate_galois_keys()
bfv_context.generate_relin_keys()


##Federated Client Definition

###Clients train locally, encrypt model updates, and send to the server.

In [10]:
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_data, test_data, encryption_scheme="ckks"):
        self.model = model
        self.train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
        self.test_loader = DataLoader(test_data, batch_size=32)
        self.encryption_scheme = encryption_scheme
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

    def get_parameters(self, config):
        params = [val.cpu().numpy() for val in self.model.state_dict().values()]
        # Encrypt parameters
        if self.encryption_scheme == "ckks":
            return [ts.ckks_vector(ckks_context, p.flatten().tolist()).serialize() for p in params]
        else:  # BFV
            return [ts.bfv_vector(bfv_context, p.flatten().astype(int).tolist()).serialize() for p in params]

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        new_state_dict = {}
        for (k, v), param in zip(state_dict.items(), parameters):
            arr = np.array(param).reshape(v.shape)
            new_state_dict[k] = torch.tensor(arr, dtype=v.dtype)
        self.model.load_state_dict(new_state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.model.train()
        for epoch in range(1):  # 1 epoch per round
            for data, target in self.train_loader:
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
        return self.get_parameters(config={}), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        loss, correct = 0, 0
        with torch.no_grad():
            for data, target in self.test_loader:
                output = self.model(data)
                loss += self.criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(self.test_loader.dataset)
        return float(loss), len(self.test_loader.dataset), {"accuracy": accuracy}


Server Strategy with Decryption and Aggregation

In [11]:
class EncryptedFedAvg(fl.server.strategy.FedAvg):
    def __init__(self, encryption_scheme="ckks"):
        super().__init__()
        self.encryption_scheme = encryption_scheme

    def aggregate_fit(self, rnd, results, failures):
        aggregated = []
        for serialized_params, num_examples, _ in results:
            if self.encryption_scheme == "ckks":
                decrypted_params = [ts.ckks_vector_from(ckks_context, p).decrypt() for p in serialized_params]
            else:
                decrypted_params = [ts.bfv_vector_from(bfv_context, p).decrypt() for p in serialized_params]
            aggregated.append((decrypted_params, num_examples))
        # Standard FedAvg aggregation
        return super().aggregate_fit(rnd, results, failures)


Run Federated Learning Simulation

In [13]:
# # Partition test dataset equally among clients
# test_size = len(test_dataset) // 3
# test_clients = random_split(test_dataset, [test_size, test_size, len(test_dataset) - 2*test_size])

# def client_fn(cid: str):
#     model = BrainTumorCNN()
#     return FLClient(model, client_datasets[int(cid)], test_clients[int(cid)], encryption_scheme="ckks")

# # Run Flower simulation
# strategy = EncryptedFedAvg(encryption_scheme="ckks")

# fl.simulation.start_simulation(
#     client_fn=client_fn,
#     num_clients=3,
#     config=fl.server.ServerConfig(num_rounds=5),
#     strategy=strategy,
# )


Simulation without Flower

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import tenseal as ts
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
import time
import os
DATA_DIR = r"/home/amint/HEP/Brain Tumor MRI"

In [12]:
train_dir = os.path.join(DATA_DIR, "Training")
test_dir  = os.path.join(DATA_DIR, "Testing")

In [13]:
# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load dataset (adjust path to your dataset)
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Split into 3 clients
total_size = len(train_dataset)
client_size = total_size // 3
client_datasets = random_split(train_dataset, [client_size, client_size, total_size - 2*client_size])

print(f"Total samples: {total_size}")
print(f"Client 1: {len(client_datasets[0])}, Client 2: {len(client_datasets[1])}, Client 3: {len(client_datasets[2])}")


Total samples: 5712
Client 1: 1904, Client 2: 1904, Client 3: 1904


Define the CNN model

In [14]:
class BrainTumorCNN(nn.Module):
    def __init__(self, num_classes=4):
        super(BrainTumorCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*16*16, 128)
        self.fc2 = nn.Linear(128, num_classes)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (32,32,32)
        x = self.pool(F.relu(self.conv2(x)))  # (64,16,16)
        x = x.view(-1, 64*16*16)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


Create Homomorphic Encryption context

In [15]:
# CKKS context (for floats, good for CNN weights)
ckks_context = ts.context(ts.SCHEME_TYPE.CKKS, poly_modulus_degree=8192,
                          coeff_mod_bit_sizes=[60, 40, 40, 60])
ckks_context.global_scale = 2**40
ckks_context.generate_galois_keys()

# BFV context (for integers)
bfv_context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=8192, plain_modulus=1032193)
bfv_context.generate_galois_keys()


Local Traininig Function

In [None]:
def train_local(model, dataset, epochs=10, lr=0.001):
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    model.train()
    for _ in range(epochs):
        for data, target in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    return model


Encrypt Parameters

In [17]:
def encrypt_parameters(params, scheme="ckks"):
    encrypted = []
    for p in params:
        flat = p.flatten()
        if scheme == "ckks":
            encrypted.append(ts.ckks_vector(ckks_context, flat).serialize())
        else:
            encrypted.append(ts.bfv_vector(bfv_context, flat.astype(int)).serialize())
    return encrypted

def decrypt_parameters(encrypted, shapes, scheme="ckks"):
    decrypted = []
    for enc, shape in zip(encrypted, shapes):
        if scheme == "ckks":
            dec = ts.ckks_vector_from(ckks_context, enc).decrypt()
        else:
            dec = ts.bfv_vector_from(bfv_context, enc).decrypt()
        decrypted.append(np.array(dec).reshape(shape))
    return decrypted


Federated Averaging

In [18]:
def federated_averaging(encrypted_updates, shapes, scheme="ckks"):
    num_clients = len(encrypted_updates)
    
    # Decrypt all client updates
    all_updates = []
    for update in encrypted_updates:
        dec_params = decrypt_parameters(update, shapes, scheme)
        all_updates.append(dec_params)
    
    # Average parameters
    avg_params = []
    for i in range(len(all_updates[0])):
        stacked = np.stack([client[i] for client in all_updates])
        avg_params.append(np.mean(stacked, axis=0))
    
    return avg_params


Run Mannual Federated Learning 

In [20]:
start_time = time.time()
# Initialize global model
global_model = BrainTumorCNN()
shapes = [p.shape for p in global_model.state_dict().values()]

# Convert model params to numpy
def get_model_params(model):
    return [val.detach().cpu().numpy() for val in model.state_dict().values()]

def set_model_params(model, params):
    state_dict = model.state_dict()
    new_state_dict = {}
    for (k, v), p in zip(state_dict.items(), params):
        new_state_dict[k] = torch.tensor(p, dtype=v.dtype)
    model.load_state_dict(new_state_dict, strict=True)

# Run federated rounds
num_rounds = 5
scheme = "ckks"  # or "bfv"

for rnd in range(1, num_rounds+1):
    print(f"\n--- Federated Round {rnd} ---")
    client_updates = []
    
    for cid, dataset in enumerate(client_datasets):
        local_model = BrainTumorCNN()
        set_model_params(local_model, get_model_params(global_model))
        
        # Train locally
        local_model = train_local(local_model, dataset, epochs=10, lr=0.01)
        
        # Encrypt update
        params = get_model_params(local_model)
        encrypted = encrypt_parameters(params, scheme=scheme)
        client_updates.append(encrypted)
    
    # Aggregate on server
    avg_params = federated_averaging(client_updates, shapes, scheme=scheme)
    set_model_params(global_model, avg_params)
    
    print(f"Round {rnd} completed.")
end_time = time.time()
elapsed_time = end_time - start_time
print(f"\nTotal training time: {elapsed_time:.2f} seconds")



--- Federated Round 1 ---
The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, try increasing the poly_modulus parameter, to fit your input.
The following operations are disabled in this setup: matmul, matmul_plain, enc_matmul_plain, conv2d_im2col.
If you need to use those operations, 

In [29]:
test_loader = DataLoader(test_dataset, batch_size=32)
criterion = nn.CrossEntropyLoss()

global_model.eval()
loss, correct = 0, 0
all_preds, all_targets = [], []

with torch.no_grad():
    for data, target in test_loader:
        output = global_model(data)
        loss += criterion(output, target).item()
        preds = output.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(target.cpu().numpy())
        correct += preds.eq(target).sum().item()

accuracy = correct / len(test_loader.dataset)
print(f"Test Loss: {loss/len(test_loader):.4f}, Accuracy: {accuracy:.4f}")
print("\nClassification Report:\n", classification_report(all_targets, all_preds))


Test Loss: 0.4835, Accuracy: 0.8108

Classification Report:
               precision    recall  f1-score   support

           0       0.91      0.59      0.72       300
           1       0.65      0.69      0.67       306
           2       0.87      0.95      0.91       405
           3       0.83      0.97      0.90       300

    accuracy                           0.81      1311
   macro avg       0.81      0.80      0.80      1311
weighted avg       0.82      0.81      0.81      1311

