## Imports

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from cryptography.cnn_ipfe import IPFE

In [2]:
base_model = 1  # suffix of the pretrained plain-CNN checkpoint whose weights IPFECNN reuses
model_path = "models/cnn_model_{}.pth".format(base_model)

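## Define Model

Only the first convolution runs under inner-product functional encryption (IPFE). Each 3×3 input patch is encrypted, and for every conv1 filter a functional key is derived from its integer-quantized weights. Decrypting with that key yields the patch–filter inner product, i.e. one conv1 output pixel. All later layers run on the resulting plaintext feature maps.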

In [16]:
class IPFECNN(nn.Module):
    """MNIST CNN whose first convolution can be evaluated on IPFE-encrypted patches.

    For every conv1 filter, the weights are multiplied by ``scale``, rounded to
    integers and turned into an IPFE functional key. Decrypting an encrypted
    patch with that key gives the scaled inner product <patch, filter>. Dividing
    by ``scale`` and adding the bias reproduces one conv1 output pixel. Every
    layer after conv1 runs on plaintext.

    Args:
        num_classes: Number of output classes.
        prime: Prime passed to ``IPFE``; patch values are reduced modulo ``prime - 1``.
        weights_path: Checkpoint of the trained plain CNN to copy weights from.
            Defaults to the notebook-level ``model_path``.
        scale: Fixed-point factor used both when quantizing conv1 weights and
            when rescaling decrypted inner products.
    """

    def __init__(self, num_classes=10, prime=4590007, weights_path=None, scale=10000):
        super().__init__()
        self.prime = prime
        self.scale = scale

        self._build_layers(num_classes)

        self.ipfe = IPFE(prime)
        # One ciphertext slot per conv1 weight: in_channels * kh * kw (1 * 3 * 3 = 9).
        self.encryption_length = (
            self.conv1.in_channels * self.conv1.kernel_size[0] * self.conv1.kernel_size[1]
        )
        self.ipfe.setup(self.encryption_length)
        print("IPFE setup done, with length:", self.encryption_length)

        # Copy weights from the trained plain CNN. load_state_dict copies values into
        # this module's own parameters, so loading onto the CPU is sufficient. Move
        # the model with .to(device) afterwards; no global `device` is needed here.
        if weights_path is None:
            weights_path = model_path
        self.load_state_dict(torch.load(weights_path, map_location="cpu"))
        print("weights copied from trained model")

        self._derive_keys()

    def _build_layers(self, num_classes):
        """Create the three conv blocks and the classifier head (names match the checkpoint)."""
        # First convolutional block - this one is evaluated with IPFE
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)  # also tried: stride=2, padding=0
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second convolutional block
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third convolutional block
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Fully connected layers: 28 -> 14 -> 7 -> 3 after three 2x2 pools
        # (64 * 1 * 1 would be needed for the stride-2 / padding-0 conv1 variant)
        self.fc1 = nn.Linear(64 * 3 * 3, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def _derive_keys(self):
        """Quantize conv1 filters into integer y vectors and derive one IPFE key per filter."""
        self.weights = self.conv1.weight.data
        # (out_channels, in_channels, kh, kw) -> one flattened integer vector per filter
        self.y_array = torch.round(self.weights.flatten(start_dim=1) * self.scale).long().tolist()
        print("weights converted to y vectors")
        # Assigning a Parameter to a Module attribute also registers it under the
        # name "biases" (it is the same tensor object as conv1.bias).
        self.biases = self.conv1.bias
        print("biases saved")
        self.sk_y_array = [self.ipfe.key_derive(y) for y in self.y_array]
        print("sk_ys created")

    def encrypt_data(self, test_set):
        """Encrypt every conv1-sized patch of each image in a batch.

        Args:
            test_set: Image batch of shape (B, C, H, W); pixel values are truncated
                with ``int()``, so they are expected to be integer-valued.

        Returns:
            A list with one entry per image. Each entry is the list of values
            returned by ``self.ipfe.encrypt``, one per patch, in the order produced
            by ``nn.Unfold`` (row-major over conv1 output positions).
        """
        # Extract patches with exactly conv1's geometry, so that ciphertext i lines
        # up with conv1 output pixel i even if kernel/stride/padding change.
        unfold = nn.Unfold(
            kernel_size=self.conv1.kernel_size,
            dilation=self.conv1.dilation,
            padding=self.conv1.padding,
            stride=self.conv1.stride,
        )
        patches = unfold(test_set)  # (B, C*kh*kw, num_patches)
        modulus = self.prime - 1

        encrypted_patches = []
        # One tolist() per batch instead of an .item() call per pixel.
        for image_patches in patches.transpose(1, 2).tolist():
            encrypted_patches.append(
                [self.ipfe.encrypt([int(v) % modulus for v in patch]) for patch in image_patches]
            )
        return encrypted_patches

    def first_conv_forward(self, x, H, W):
        """Decrypt conv1 for one encrypted image.

        Args:
            x: Ciphertexts of one image, one per patch (as produced by ``encrypt_data``).
            H, W: Spatial size of the conv1 output; ``H * W`` must equal ``len(x)``.

        Returns:
            Tensor of shape (1, out_channels, H, W) holding the conv1 output including bias.
        """
        device = next(self.parameters()).device
        biases = self.biases.detach().tolist()  # hoisted: one host transfer, not one per pixel
        scale = self.scale

        rows = []
        for sk_y, y, bias in zip(self.sk_y_array, self.y_array, biases):
            # decrypt() returns the inner product with the scaled integer filter y.
            rows.append([self.ipfe.decrypt(ct, sk_y, y) / scale + bias for ct in x])

        decrypted_maps = torch.tensor(rows, device=device)
        return decrypted_maps.view(len(rows), H, W).unsqueeze(0)

    def _classify_from_conv1(self, feat):
        """Run every layer after conv1 (bn1 ... fc2) on a plaintext conv1 output."""
        feat = self.pool1(F.relu(self.bn1(feat)))
        feat = self.pool2(F.relu(self.bn2(self.conv2(feat))))
        feat = self.pool3(F.relu(self.bn3(self.conv3(feat))))
        feat = feat.view(feat.size(0), -1)
        feat = F.relu(self.fc1(feat))
        feat = self.dropout(feat)
        return self.fc2(feat)

    def forward(self, x, H=None, W=None, encrypted=False):
        """Classify a batch, optionally computing conv1 from encrypted patches.

        Args:
            x: If ``encrypted`` is False, an image tensor (B, C, H, W). If True, a
                list with one entry per image, each the patch ciphertexts from
                ``encrypt_data``.
            H, W: conv1 output size; required only when ``encrypted`` is True.
            encrypted: Selects the IPFE path for conv1.

        Returns:
            Logits of shape (B, num_classes).

        Raises:
            ValueError: If ``encrypted`` is True and ``H`` or ``W`` is missing.
        """
        if not encrypted:
            return self._classify_from_conv1(self.conv1(x))

        if H is None or W is None:
            raise ValueError("H and W (conv1 output size) are required when encrypted=True")
        # Images are processed one at a time, so batch-dependent layers
        # (BatchNorm/Dropout in train mode) see a batch of size 1.
        logits = [self._classify_from_conv1(self.first_conv_forward(sample, H, W)) for sample in x]
        if not logits:
            return torch.empty(0, self.fc2.out_features, device=next(self.parameters()).device)
        return torch.cat(logits, dim=0)



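## Load Data

Pixels are kept as raw 0–255 values with no normalization, because `encrypt_data` casts every pixel to an integer before encryption.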

In [17]:
def _to_raw_pixel_tensor(pic):
    """Convert a (grayscale) PIL image to a float32 tensor with a leading channel axis.

    Pixel values are left unnormalized (0-255) so they stay integer-valued for IPFE.
    """
    return torch.tensor(np.array(pic), dtype=torch.float32).unsqueeze(0)


transform = transforms.Compose([transforms.Lambda(_to_raw_pixel_tensor)])
batch_size = 64
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
print(f"Test samples: {len(test_dataset)}")

Test samples: 10000


## Initialize Model

In [18]:
# Build the IPFE-enabled CNN from the pretrained plain-CNN weights.
#
# Prime-selection notes (with n = 5; times as originally recorded):
#
# kernel 3x3, stride 1, padding 1
#   1721257 -> 16 373 ms, seems to be safe
#   2300003 -> 21 721 ms, should very likely be safe
#   4590007 -> 29 470 ms, must be mathematically safe
#
# kernel 3x3, stride 2, padding 0
#   1721257 -> only 60%
#   2300003 -> only 80%
#   4590007 -> 5 708 ms
#
# NOTE(review): "safe" presumably means the prime is large enough for every
# decrypted inner product to be recovered correctly, and n = 5 presumably is the
# number of encrypted test samples - confirm.

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

ipfe_model = IPFECNN(num_classes=10, prime=4590007)
ipfe_model = ipfe_model.to(device)
print(f"IPFE-CNN model created on device: {device}")

IPFE setup done, with length: 9
weights copied from trained model
weights converted to y vectors
biases saved
sk_ys created
IPFE-CNN model created on device: cpu


In [19]:
def encrypt_test_data(model, test_loader, device, num_samples=5):
    """Encrypt the first ``num_samples`` test images with ``model.encrypt_data``.

    Images are taken from the start of ``test_loader``. Further batches are read
    only when the first one holds fewer than ``num_samples`` images, so requests
    larger than the loader's batch size are honoured. If the loader runs out,
    whatever was collected is encrypted.

    Args:
        model: Object providing ``eval()`` and ``encrypt_data(images)`` (e.g. IPFECNN).
        test_loader: Iterable of ``(images, labels)`` batches, images shaped (B, C, H, W).
        device: Device the selected images and labels are moved to.
        num_samples: Number of images to encrypt.

    Returns:
        ``(encrypted_data, labels_subset, H, W)`` where H, W are the image height and width.
    """
    model.eval()
    with torch.no_grad():
        batches = iter(test_loader)
        images, labels = next(batches)
        # Image size is taken from the first batch (assumes all batches share it).
        H, W = images.size(2), images.size(3)

        image_chunks, label_chunks = [images], [labels]
        collected = images.size(0)
        # Read more batches only if the first one is too small for num_samples.
        while collected < num_samples:
            try:
                images, labels = next(batches)
            except StopIteration:
                break  # loader exhausted: encrypt whatever was collected
            image_chunks.append(images)
            label_chunks.append(labels)
            collected += images.size(0)

        images_subset = torch.cat(image_chunks)[:num_samples].to(device)
        labels_subset = torch.cat(label_chunks)[:num_samples].to(device)

        encrypted_data = model.encrypt_data(images_subset)
        # Report how many images were actually encrypted, not how many were requested.
        print(f"Encrypted {images_subset.size(0)} samples.")

    return encrypted_data, labels_subset, H, W

In [20]:
# Encrypt the first 5 test images; H, W are their spatial size.
encrypted_data, labels, H, W = encrypt_test_data(
    model=ipfe_model,
    test_loader=test_loader,
    device=device,
    num_samples=5,
)

Encrypted 5 samples.


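## Test Model

First run the encrypted pipeline on the encrypted samples. Then run the ordinary plaintext conv1 path on the same first test images as a baseline.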

In [21]:
# Test IPFE functionality with a sample
def test_ipfe_cnn(model, encrypted_data, labels, H, W, device):
    """Test the IPFE-CNN with a sample query vector"""
    model.eval()

    with torch.no_grad():
        print("Testing IPFE-CNN forward pass on encrypted data...")
        print(f"Labels of test samples: {labels.cpu().numpy()}")

        try:
            outputs = model.forward(encrypted_data, encrypted=True, H=28, W=28)
            _, predicted = outputs.max(1)

            print(f"Predictions on encrypted data: {predicted.cpu().numpy()}")

            correct = (predicted == labels).sum().item()
            total = labels.size(0)
            print(f"Accuracy on encrypted samples: {100 * correct / total:.2f}% ({correct}/{total})")


        except Exception as e:
            print(f"Encrypted IPFE forward pass failed: {e}")


In [22]:
def test_regular_ipfe_cnn(model, test_loader, device, num_samples=5):
    """Test the IPFE-CNN with a sample query vector"""
    model.eval()

    with torch.no_grad():
        print("Testing IPFE-CNN forward pass on encrypted data...")

        data_iter = iter(test_loader)
        images, labels = next(data_iter)
        images, labels = images.to(device), labels.to(device)

        # Encrypt only a subset
        images_subset = images[:num_samples]
        labels_subset = labels[:num_samples]
        try:
            outputs = model.forward(images_subset, encrypted=False, H=28, W=28)
            _, predicted = outputs.max(1)

            print(f"Predictions on encrypted data: {predicted.cpu().numpy()}")

            correct = (predicted == labels_subset).sum().item()
            total = labels_subset.size(0)
            print(f"Accuracy on encrypted samples: {100 * correct / total:.2f}% ({correct}/{total})")


        except Exception as e:
            print(f"Encrypted IPFE forward pass failed: {e}")

In [23]:
# Encrypted run: conv1 is computed only from the IPFE ciphertexts.
print("Testing IPFE-CNN functionality...")
test_ipfe_cnn(
    model=ipfe_model,
    encrypted_data=encrypted_data,
    labels=labels,
    H=H,
    W=W,
    device=device,
)


Testing IPFE-CNN functionality...
Testing IPFE-CNN forward pass on encrypted data...
Labels of test samples: [7 2 1 0 4]
Predictions on encrypted data: [7 2 1 0 4]
Accuracy on encrypted samples: 100.00% (5/5)


In [24]:
# Plaintext baseline on the first 5 test images, for comparison with the encrypted run.
test_regular_ipfe_cnn(model=ipfe_model, test_loader=test_loader, device=device, num_samples=5)

Testing IPFE-CNN forward pass on encrypted data...
Predictions on encrypted data: [7 2 1 0 4]
Accuracy on encrypted samples: 100.00% (5/5)
