# Notebook di Inferenza con Homomorphic Encryption per Modelli CNN

Questo notebook dimostra l'implementazione di un sistema di inferenza utilizzando Homomorphic Encryption (HE) su un modello CNN addestrato con NVFlare per la classificazione di immagini radiografiche per la diagnosi della polmonite. Il notebook si articola in diverse sezioni principali:

## Componenti Principali

1. **Architettura del Modello**
   - Implementazione di una CNN moderata (ModerateCNN) con tre blocchi convoluzionali
   - Architettura che include layer di BatchNorm, ReLU, MaxPooling e Dropout
   - Classificatore finale con output binario per la classificazione della polmonite

2. **Gestione dei Dati**
   - Implementazione della classe `CustomsDataset` per il caricamento delle immagini radiografiche
   - Pipeline di trasformazione delle immagini con ridimensionamento, conversione in scala di grigi e normalizzazione

3. **Processo di Decriptazione**
   - Utilizzo della libreria TenSEAL per la gestione della crittografia omomorfa
   - Caricamento e decriptazione dei pesi del modello crittografato
   - Salvataggio del modello decriptato per l'inferenza

4. **Valutazione del Modello**
   - Test del modello decriptato su un set di validazione
   - Calcolo dell'accuratezza della classificazione
   - Visualizzazione delle predizioni confrontate con le etichette reali

## Caratteristiche Tecniche
- Utilizzo del framework PyTorch per l'implementazione del modello
- Utilizzo dello schema CKKS (tramite la libreria TenSEAL) per la crittografia omomorfa
- Pipeline di preprocessing delle immagini ottimizzata per radiografie toraciche
- Sistema di logging dettagliato per il monitoraggio del processo di decriptazione

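Questo notebook rappresenta un esempio pratico di come eseguire l'inferenza in un contesto medico a partire da un modello i cui pesi sono protetti con crittografia omomorfa: i pesi vengono decriptati localmente con la chiave privata del client e la classificazione avviene poi in chiaro sul modello decriptato. La crittografia omomorfa tutela quindi la riservatezza dei pesi del modello fino al momento della decriptazione, mentre le immagini di test vengono elaborate localmente e non in forma cifrata.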

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ModerateCNN(nn.Module):
    """Three-stage convolutional classifier for single-channel images.

    ``features`` stacks three Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d ->
    Dropout2d stages with 32, 64 and 128 output channels. ``classifier``
    adaptively pools the feature map to 7x7, flattens it and maps it through a
    512-unit hidden layer to two output logits.

    Args:
        dropout_rate: Dropout probability applied in the classifier head.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()

        # (in_channels, out_channels, spatial-dropout probability) per stage.
        stage_specs = ((1, 32, 0.2), (32, 64, 0.3), (64, 128, 0.4))

        feature_layers = []
        for in_channels, out_channels, drop_p in stage_specs:
            feature_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
                nn.Dropout2d(drop_p),
            ])
        # Unpacking a flat list keeps the integer sub-module names
        # (features.0 ... features.14) that state_dict keys are built from.
        self.features = nn.Sequential(*feature_layers)

        hidden_units = 512
        pooled_side = 7
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((pooled_side, pooled_side)),
            nn.Flatten(),
            nn.Linear(128 * pooled_side * pooled_side, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_units, 2),
        )

    def forward(self, x):
        """Return the two class logits for a batch of input images."""
        return self.classifier(self.features(x))

In [2]:
from torch.utils.data import Dataset
import os
from PIL import Image

class CustomsDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``data_folder`` per class.

    Class indices follow the alphabetical order of the sub-directory names.

    Args:
        data_folder: Directory whose sub-directories each hold one class.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
    """

    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        self.transform = transform

        # Only directories are classes. A stray file next to the class folders
        # (e.g. ".DS_Store") would otherwise make os.listdir() below raise
        # NotADirectoryError.
        self.class_names = sorted(
            entry
            for entry in os.listdir(data_folder)
            if os.path.isdir(os.path.join(data_folder, entry))
        )
        self.class_to_idx = {
            class_name: idx for idx, class_name in enumerate(self.class_names)
        }
        self.image_paths = []  # path of every sample, in index order
        self.labels = []       # class index of every sample
        # Pixel arrays of every sample, decoded eagerly. __getitem__ does not
        # read this list; it is kept because callers may rely on the attribute.
        self.data = []

        for class_name in self.class_names:
            class_folder = os.path.join(data_folder, class_name)
            class_label = self.class_to_idx[class_name]
            # Sorted so sample indices do not depend on filesystem order.
            for filename in sorted(os.listdir(class_folder)):
                img_path = os.path.join(class_folder, filename)
                if not os.path.isfile(img_path):
                    continue  # nested directories are not images
                # The context manager releases the file handle as soon as the
                # pixels have been copied into the array.
                with Image.open(img_path) as image:
                    pixels = np.array(image)
                self.image_paths.append(img_path)
                self.labels.append(class_label)
                self.data.append(pixels)

    def __len__(self):
        """Return the number of samples found under ``data_folder``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is re-read from disk and passed through ``transform`` when
        one was given; otherwise the loaded PIL image is returned.
        """
        # load() decodes the pixels while the file is open, so the image stays
        # usable after the handle is closed on leaving the with-block.
        with Image.open(self.image_paths[idx]) as image:
            image.load()
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [6]:
from torchvision import transforms

# Validation-time preprocessing; the steps run from top to bottom.
transform_valid = transforms.Compose(
    [
        # Tensor -> PIL round trip before resizing.
        # NOTE(review): presumably mirrors the training pipeline - confirm
        # before removing it.
        transforms.ToTensor(),
        transforms.ToPILImage(),
        # Fixed 224x224 single-channel input.
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1),
        # Back to a tensor, then normalise the single channel.
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ]
)

In [7]:
import torch
import tenseal as ts
import numpy as np
import os

# Decrypt the encrypted global model, load it into a fresh ModerateCNN and save
# a plaintext checkpoint. Failures are reported with print() rather than
# raised, so the cell always finishes; `model` is reused by the later cells.
try:
    # Load TenSEAL context
    # TenSEAL context is present on the clients, e.g. "/site-1/startup/client_context.tenseal"
    with open("./tenseal_context/client_context.tenseal", "rb") as f:
        context_bytes = f.read()
    tenseal_context = ts.Context.load(context_bytes)

    print("TenSEAL context loaded with private key:", tenseal_context.is_private())

    # Load encrypted model
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code - only load checkpoints that come from a trusted source.
    model_path = './models/encrypted_FL_global_model.pt'
    encrypted_state_dict = torch.load(model_path)

    # Print model structure info
    print("\nModel structure:")
    for key, value in encrypted_state_dict['model'].items():
        print(f"{key}: {type(value)}")
        if isinstance(value, bytes):
            print(f"  First 20 bytes: {value[:20].hex()}")

    print("\nEncryption metadata:", encrypted_state_dict['meta_props'])

    # Create model instance to get the expected shape and dtype of every entry
    model = ModerateCNN()
    reference_state = model.state_dict()
    expected_shapes = {name: param.shape for name, param in reference_state.items()}

    # Decrypt each layer
    decrypted_weights = {}
    for key, encrypted_value in encrypted_state_dict['model'].items():
        try:
            if isinstance(encrypted_value, bytes):
                if key not in reference_state:
                    # Report the skip instead of decrypting and then silently
                    # discarding the result.
                    print(f"Skipped {key}: not present in the ModerateCNN state_dict")
                    continue

                # Create CKKS vector and decrypt
                enc_vector = ts.CKKSVector.load(tenseal_context, encrypted_value)
                dec_data = enc_vector.decrypt()

                # Reshape to the expected shape and cast to the expected dtype.
                # CKKS yields approximate floating-point values, so integer
                # buffers (e.g. BatchNorm's num_batches_tracked) are rounded to
                # the nearest integer rather than truncated by the cast.
                target = reference_state[key]
                dec_array = np.asarray(dec_data, dtype=np.float64).reshape(tuple(target.shape))
                if not target.dtype.is_floating_point:
                    dec_array = np.rint(dec_array)
                decrypted_weights[key] = torch.from_numpy(dec_array).to(target.dtype)
                print(f"Successfully decrypted {key}")
            else:
                decrypted_weights[key] = encrypted_value
                print(f"Copied unencrypted {key}")

        except Exception as e:
            # Keep going so every failing key is listed; a missing key is then
            # reported by load_state_dict below.
            print(f"Failed to decrypt {key}: {e}")

    if decrypted_weights:
        try:
            # Load decrypted weights into model
            model.load_state_dict(decrypted_weights)
            model.eval()

            # Save decrypted model
            torch.save({
                'model_weights': model.state_dict(),
                'meta_props': encrypted_state_dict['meta_props'],
                'train_conf': encrypted_state_dict['train_conf']
            }, './models/decrypted_model.pt')

            print("\nModel successfully decrypted and saved!")

            # Print sample weights
            print("\nSample of decrypted weights:")
            for name, param in list(model.named_parameters())[:3]:
                print(f"\n{name}:")
                print(f"Shape: {param.shape}")
                print(f"First few values: {param.data.flatten()[:5]}")

        except Exception as e:
            print(f"\nError loading decrypted weights: {e}")
    else:
        print("\nNo weights were successfully decrypted")

except Exception as e:
    print(f"Error: {e}")
    print("\nFull traceback:")
    import traceback
    traceback.print_exc()

TenSEAL context loaded with private key: True


  encrypted_state_dict = torch.load(model_path)



Model structure:
features.0.weight: <class 'bytes'>
  First 20 bytes: 0a02a002128081085ea110040102000080000200
features.0.bias: <class 'bytes'>
  First 20 bytes: 0a0120128081085ea11004010200008000020000
features.1.weight: <class 'bytes'>
  First 20 bytes: 0a0120128081085ea11004010200008000020000
features.1.bias: <class 'bytes'>
  First 20 bytes: 0a0120128081085ea11004010200008000020000
features.1.running_mean: <class 'bytes'>
  First 20 bytes: 0a0120128081085ea11004010200008000020000
features.1.running_var: <class 'bytes'>
  First 20 bytes: 0a0120128081085ea11004010200008000020000
features.1.num_batches_tracked: <class 'bytes'>
  First 20 bytes: 0a0101128081085ea11004010200008000020000
features.5.weight: <class 'bytes'>
  First 20 bytes: 0a0a80208020802080208010128081085ea11004
features.5.bias: <class 'bytes'>
  First 20 bytes: 0a0140128081085ea11004010200008000020000
features.6.weight: <class 'bytes'>
  First 20 bytes: 0a0140128081085ea11004010200008000020000
features.6.bias: <class 

In [8]:
# Display the layers of the `features` sub-module of the model built above.
model.features

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Dropout2d(p=0.2, inplace=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU(inplace=True)
  (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (9): Dropout2d(p=0.3, inplace=False)
  (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU(inplace=True)
  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (14): Dropout2d(p=0.4, inplace=False)
)

In [None]:
from torch.utils.data import DataLoader
import os

# Score `model` on the images under ./test/, one image per batch, printing
# every prediction next to its true label and the overall accuracy at the end.
test_data_folder = "./test/"
test_dataset = CustomsDataset(test_data_folder, transform=transform_valid)
# Shuffling plus the 1000-image cap below means each run scores a different
# random subset whenever the folder holds more than 1000 images.
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=1)

model.eval()
total_predictions = 0
correct_predictions = 0

with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader):
        # Stop once 1000 images have been scored.
        if i >= 1000:
            break

        outputs = model(inputs)
        # torch.max over dim 1 returns (values, indices); the index is the class.
        predicted = torch.max(outputs, 1)[1]

        predicted_label = predicted.item()
        true_label = labels.item()
        predicted_class_name = test_dataset.class_names[predicted_label]
        true_class_name = test_dataset.class_names[true_label]

        print(f"predicted: {predicted_label} - {predicted_class_name}, actual: {true_label} - {true_class_name}")

        total_predictions += 1
        correct_predictions += int(predicted_label == true_label)

accuracy = correct_predictions / total_predictions
print(f"\nAccuracy: {accuracy*100:.2f}%")

predicted: 0 - NORMAL, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 0 - NORMAL
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 0 - NORMAL
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 0 - NORMAL, actual: 0 - NORMAL
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 0 - NORMAL, actual: 0 - NORMAL
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 0 - NORMAL, actual: 0 - NORMAL
predicted: 0 - NORMAL, actual: 0 - NORMAL
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA, actual: 1 - PNEUMONIA
predicted: 1 - PNEUMONIA,