In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from time import time
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt 
import torch

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch import nn
import torch.nn.functional as F

from fase.nn.conv import *

In [3]:
num_workers = 0
batch_size = 32
valid_size = 0.2


## Scale 
transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
     ])

train_data = datasets.CIFAR10('data', train=True,
                              download=True,
                              transform=transform
                             )
test_data = datasets.CIFAR10('data', train=False,
                             download=True, 
                             transform=transform
                            )

num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders (combine dataset and sampler)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, 
    sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, 
    num_workers=num_workers)

# specify the image classes
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

Files already downloaded and verified
Files already downloaded and verified


In [4]:
from approximate import approx_relu, approx_sign
from fase.nn.models import ConvNeuralNet

xfactor = 40
activation = lambda x : xfactor * approx_relu(x/xfactor, degree = 5, repeat=3)

org_model = ConvNeuralNet(num_classes=10, activation=activation)

# FHE

In [None]:
#import fase
from fase.core import seal_ckks
from fase.core.seal_ckks import SEALContext
from fase.seal import Ciphertext

In [5]:
poly_modulus_degree = 2**15
scale_bit = 50
coeff_moduli = [60] + [scale_bit] * 14 + [60]

sec = SEALContext(poly_modulus_degree=poly_modulus_degree,
                             coeff_moduli=coeff_moduli,
                             scale_bit=scale_bit)

SEAL CKKS scheme is ready


## Load image


In [7]:
import torchvision.transforms as transforms

img = np.array(Image.open("./bird6.png"))
to_tensor = transforms.ToTensor() # [n_channel, nh, nw]
img_tensor = to_tensor(img).unsqueeze(0) # [n_batch, n_channel, nh, nw]
n_batch, n_channel, nh, nw = img_tensor.shape

print(img_tensor.shape)

torch.Size([1, 3, 32, 32])


## Load trained parameters

In [8]:
fn_param = "SimpleCNN_ReLU_minimax_v2.pt"
trained_param = torch.load(fn_param)
trained_param = {k: v.cpu() for k, v in trained_param.items()} # to cpu()
org_model.load_state_dict(trained_param)
org_model.eval() ## If not eval(), running_mean and running_var of batch_norm changes

# To numpy
params_np = {k: v.numpy() for k, v in trained_param.items()}

In [9]:
import fase.nn.utils as utils
util = utils.Seal_checker(sec)

In [10]:
# 테스트용 이미지 (새_)

img_this_example = img_tensor[0] # Assume batch size = 1
img_enc = [sec.encrypt(this_channel.ravel()) for this_channel in img_this_example]


#util.check_decrypt(img_enc[0])

print(img_tensor.min(), img_tensor.max())

tensor(0.0824) tensor(1.)


# Evaluation

In [None]:
tmp1, _nh2, _nw2 = my_conv2D_FHE(sec, img_enc, nh, nw, org_model.conv_layer1.weight) # list of ctxts
tmp2, _nh2, _nw2 = my_conv2D_FHE(sec, tmp1, nh, nw, org_model.conv_layer2.weight) # list of ctxts
tmp3, _nh2, _nw2 = fhe_avg_pool(sec, tmp2, nh, nw, 
                                kernel_size=org_model.pool.kernel_size, 
                                stride_in=1)
tmp4, _nh2, _nw2 = my_conv2D_FHE(sec, tmp3, nh, nw, org_model.conv_layer3.weight,
                                stride_in=2) 



