## Imports

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.cryptography.optimized_cnn_ipfe import IPFE, decrypt_patches_batch
from src.utils.notebook_helper import encrypt_test_data, test_ipfe_cnn, test_regular_ipfe_cnn, load_data

In [2]:
# Index of the pretrained checkpoint (models/cnn_model_<index>.pth) whose weights IPFECNN loads.
base_model = 1
model_path = "models/cnn_model_{}.pth".format(base_model)

## Define Model

In [3]:
class IPFECNN(nn.Module):
    """CNN whose first convolution can be evaluated on IPFE-encrypted images.

    conv1's filters are turned into integer vectors (``y_array``) by fixed-point
    scaling with ``SCALE``, and one IPFE key is derived per filter
    (``sk_y_array``). ``encrypt_data`` encrypts every conv1-sized input patch;
    ``first_conv_forward`` decrypts one value per (filter, patch) pair and
    rescales it (``/ SCALE + bias``) into conv1's output. Every layer after conv1
    runs in plaintext. Pretrained weights are loaded at construction time.
    """

    # conv1 geometry. encrypt_data extracts patches with the same values so the
    # encrypted patches line up exactly with conv1's receptive fields.
    IN_CHANNELS = 1
    KERNEL_SIZE = 3
    STRIDE = 3
    PADDING = 1
    # Fixed-point factor that turns float conv1 weights into integer IPFE key
    # vectors; decrypted values are divided by it again.
    SCALE = 10000

    def __init__(self, num_classes=10, prime=6898777, weights_path=None):
        """Set up IPFE, build the layers, load pretrained weights and derive keys.

        Args:
            num_classes: number of output logits.
            prime: modulus handed to ``IPFE``; encrypted values are also reduced
                modulo ``prime - 1`` in ``encrypt_data``.
            weights_path: checkpoint (a state_dict) to load. When None, the
                notebook-level ``model_path`` global is used.
        """
        super().__init__()
        self.prime = prime
        self.ipfe = IPFE(prime)
        # Length of one flattened conv1 patch (1 channel x 3 x 3 = 9).
        self.encryption_length = self.IN_CHANNELS * self.KERNEL_SIZE ** 2

        self.ipfe.setup(self.encryption_length)
        print("IPFE setup done, with length:", self.encryption_length)

        self._build_layers(num_classes)
        # Only falls back to the notebook-level ``model_path`` when no path is given.
        self._load_pretrained(model_path if weights_path is None else weights_path)
        self._derive_function_keys()

    @property
    def biases(self):
        """conv1's bias parameter.

        Exposed as a property: assigning the Parameter to an attribute would make
        nn.Module register it a second time (an extra 'biases' state_dict key).
        """
        return self.conv1.bias

    def _build_layers(self, num_classes):
        """Create the three conv/BN/pool blocks and the fully connected head."""
        # First convolutional block - this one is evaluated under IPFE.
        self.conv1 = nn.Conv2d(
            self.IN_CHANNELS, 16,
            kernel_size=self.KERNEL_SIZE, stride=self.STRIDE, padding=self.PADDING,
        )
        self.bn1 = nn.BatchNorm2d(16)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second convolutional block
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Third convolutional block
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Fully connected layers. fc1 expects a 1x1 map after the three poolings
        # (e.g. 10x10 -> 5x5 -> 2x2 -> 1x1 for the H=W=10 used in this notebook).
        self.fc1 = nn.Linear(64 * 1 * 1, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, num_classes)

    def _load_pretrained(self, path):
        """Copy the trained weights stored at ``path`` into this module."""
        # Load onto the CPU, where the freshly built parameters live; callers move
        # the whole module afterwards with .to(device). This avoids relying on a
        # ``device`` global that is only defined in a later cell.
        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)
        print("weights copied from trained model")

    def _derive_function_keys(self):
        """Turn conv1's filters into integer y vectors and derive one IPFE key each."""
        self.weights = self.conv1.weight.data
        # (out_channels, in_channels, k, k) -> one flat, fixed-point integer vector per filter.
        flat_filters = self.weights.reshape(self.weights.size(0), -1)
        self.y_array = torch.round(flat_filters * self.SCALE).long().tolist()
        print("weights converted to y vectors")
        # Biases are read through the ``biases`` property (conv1.bias).
        print("biases saved")
        self.sk_y_array = [self.ipfe.key_derive(y) for y in self.y_array]
        print("sk_ys created")

    def encrypt_data(self, test_set):
        """Encrypt every conv1 patch of each image in a batch.

        Args:
            test_set: image batch accepted by ``nn.Unfold`` (N, C, H, W). Its
                flattened patch length C*k*k must equal ``encryption_length``.

        Returns:
            list with one ``(ct0_array, cts_array)`` pair per image: ``ct0_array``
            is int64 of shape (num_patches,), ``cts_array`` is int64 of shape
            (num_patches, patch_size).

        Raises:
            ValueError: if the patch length does not match the IPFE setup length.
        """
        unfold = nn.Unfold(kernel_size=self.KERNEL_SIZE, stride=self.STRIDE, padding=self.PADDING)
        patches = unfold(test_set)  # (B, patch_size, num_patches)
        _, patch_size, num_patches = patches.shape
        if patch_size != self.encryption_length:
            raise ValueError(
                f"patch length {patch_size} does not match IPFE setup length "
                f"{self.encryption_length}"
            )

        # Quantize all pixels at once: .long() truncates toward zero like int(), and
        # tensor % follows Python's sign-of-divisor rule, so negative values wrap
        # into [0, prime - 2]. contiguous() keeps each per-patch row C-contiguous.
        quantized = (patches.transpose(1, 2).long() % (self.prime - 1)).contiguous().cpu().numpy()

        encrypted_batch = []
        for image_patches in quantized:  # (num_patches, patch_size) int64
            ct0_array = np.zeros(num_patches, dtype=np.int64)
            cts_array = np.zeros((num_patches, patch_size), dtype=np.int64)
            for p_idx, patch_int in enumerate(image_patches):
                ct0, ct = self.ipfe.encrypt(patch_int)
                ct0_array[p_idx] = ct0
                cts_array[p_idx, :] = np.asarray(ct, dtype=np.int64)
            encrypted_batch.append((ct0_array, cts_array))

        return encrypted_batch

    def first_conv_forward(self, encrypted_image, H, W):
        """Recover conv1's feature maps for one encrypted image.

        Args:
            encrypted_image: ``(ct0_array, cts_array)`` pair as produced by
                ``encrypt_data``.
            H, W: spatial size of conv1's output; ``H * W`` must equal the number
                of encrypted patches.

        Returns:
            float tensor of shape (1, num_kernels, H, W) on the model's device.

        Raises:
            ValueError: if the image does not hold exactly ``H * W`` patches
                (checked before any decryption work is done).
        """
        device = next(self.parameters()).device
        ct0_array, cts_array = encrypted_image
        num_patches = ct0_array.shape[0]
        if num_patches != H * W:
            raise ValueError(f"got {num_patches} encrypted patches, expected H*W = {H * W}")

        num_kernels = len(self.sk_y_array)
        decrypted_maps = torch.zeros(num_kernels, num_patches, device=device)

        for k, (sk_y, y) in enumerate(zip(self.sk_y_array, self.y_array)):
            # Batch-decrypt all patches for this filter (Numba helper); the result is
            # expected to be the fixed-point patch . y value for every patch.
            decrypted_vals = decrypt_patches_batch(
                ct0_array, cts_array, int(sk_y), np.array(y, dtype=np.int64),
                self.ipfe.g, self.prime,
            )
            bias = self.biases[k].item()
            # Undo the fixed-point scaling and add conv1's bias.
            decrypted_maps[k, :] = torch.as_tensor(
                decrypted_vals / self.SCALE + bias, dtype=decrypted_maps.dtype, device=device
            )

        return decrypted_maps.view(1, num_kernels, H, W)

    def _classify(self, feat):
        """Run every layer after conv1 (BN/ReLU/pool blocks, flatten, FC head)."""
        feat = self.pool1(F.relu(self.bn1(feat)))
        feat = self.pool2(F.relu(self.bn2(self.conv2(feat))))
        feat = self.pool3(F.relu(self.bn3(self.conv3(feat))))
        feat = feat.view(feat.size(0), -1)
        feat = F.relu(self.fc1(feat))
        feat = self.dropout(feat)
        return self.fc2(feat)

    def forward(self, x, H=None, W=None, encrypted=False):
        """Compute class logits.

        Args:
            x: plaintext image batch when ``encrypted`` is False; otherwise an
                iterable of per-image ``(ct0_array, cts_array)`` pairs.
            H, W: conv1 output size; required only when ``encrypted`` is True.
            encrypted: select the IPFE path for conv1.

        Returns:
            logits of shape (batch, num_classes).

        Raises:
            ValueError: if ``encrypted`` is True and H or W is missing.
        """
        if not encrypted:
            return self._classify(self.conv1(x))

        if H is None or W is None:
            raise ValueError("H and W are required when encrypted=True")
        # Each encrypted image goes through the network on its own (batch of 1).
        logits = [self._classify(self.first_conv_forward(sample, H, W)) for sample in x]
        if not logits:
            return torch.empty(0, self.fc2.out_features, device=next(self.parameters()).device)
        return torch.cat(logits, dim=0)



## Initialize Model

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# This notebook only runs inference: eval() disables dropout and makes BatchNorm use
# its stored running statistics (the encrypted path feeds one image at a time), so
# predictions are deterministic.
ipfe_model = IPFECNN(num_classes=10, prime=6898777).to(device).eval()

print(f"IPFE-CNN model created on device: {device}")

IPFE setup done, with length: 9
weights copied from trained model
weights converted to y vectors
biases saved
sk_ys created
IPFE-CNN model created on device: cpu


In [5]:
# Encrypt only a handful of test images: encryption is done patch by patch
# (presumably via IPFECNN.encrypt_data), which is slow.
test_loader = load_data()
encrypted_data, labels = encrypt_test_data(
    ipfe_model,
    test_loader,
    device,
    num_samples=5,
)

Test samples: 10000
Encrypted 5 samples.


## Test Model

In [6]:
# H, W = spatial size of conv1's output. With kernel 3 / stride 3 / padding 1, a
# 28x28 input gives (28 + 2*1 - 3) // 3 + 1 = 10 (assumes 28x28 images - confirm
# against load_data).
print("Testing IPFE-CNN functionality...")
test_ipfe_cnn(
    ipfe_model,
    encrypted_data,
    labels,
    H=10,
    W=10,
    device=device,
)


Testing IPFE-CNN functionality...
Testing IPFE-CNN forward pass on encrypted data...
Labels of test samples: [7 2 1 0 4]
Predictions on encrypted data: [7 2 1 0 4]
Accuracy on encrypted samples: 100.00% (5/5)


In [7]:
# Baseline on the same test images; presumably runs the plaintext conv1 path
# (NOTE(review): the helper's log lines still say "encrypted data").
test_regular_ipfe_cnn(
    ipfe_model,
    test_loader,
    device,
    num_samples=5,
)

Testing IPFE-CNN forward pass on encrypted data...
Predictions on encrypted data: [7 2 1 0 4]
Accuracy on encrypted samples: 100.00% (5/5)
