Fully Homomorphically Encrypted Fashion-MNIST CNN Example
=========================================================

- This example will download Fashion-MNIST (a drop-in replacement for MNIST)
- Prepare Fashion-MNIST
- Train a very basic CNN on Fashion-MNIST in plaintext
- Run inference on the test set using both plaintexts and cyphertexts for comparison

Download Fashion-MNIST
----------------------

- Get Fashion-MNIST as a zipped up set of CSVs
- Unzip Fashion-MNIST

In [None]:
import sys
!{sys.executable} -m pip install pyvis seaborn

In [None]:
import os
import requests
import zipfile
import logging
import time
import datetime
import copy
import seaborn as sns
logging.basicConfig(level=logging.ERROR)
sns.set_theme(style="whitegrid")
# logging.basicConfig(level=logging.DEBUG)

In [None]:
cwd = os.getcwd() # current working directory
print(cwd)

In [None]:
data_dir = os.path.join(cwd, "datasets")
os.makedirs(data_dir, exist_ok=True)  # create the directory only if it is missing
print(data_dir)

In [None]:
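mnist_zip = os.path.join(data_dir, "mnist.zip")
if os.path.exists(mnist_zip):
    print("Skipping Fashion-MNIST download")
else:
    print("Downloading Fashion-MNIST")
    mnist_url = "http://nextcloud.deepcypher.me/s/wjLa6YFw8Bcbra9/download"
    # verify=False disables TLS certificate checks if the server redirects to HTTPS
    r = requests.get(mnist_url, allow_redirects=True, verify=False)
    r.raise_for_status()  # fail loudly instead of saving an HTML error page as mnist.zip
    with open(mnist_zip, "wb") as f:
        f.write(r.content)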

In [None]:
unzip_dir = os.path.join(data_dir, "mnist")
os.makedirs(unzip_dir, exist_ok=True)  # create the directory only if it is missing
with zipfile.ZipFile(mnist_zip, "r") as zip_ref:
    zip_ref.extractall(unzip_dir)

"Wrangle"/ prepare Fashion-MNIST
--------------------------------

- Read in the Fashion-MNIST CSVs
- Split training and testing features (x) from target (y)
- Normalise x into the range 0-1 to prevent values exploding towards infinity under our approximations (y holds the integer class labels 0-9 and is left unscaled)
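
To see why the 0-1 range matters, here is a minimal sketch comparing the exact sigmoid with a cubic polynomial approximation, 0.5 + 0.197x - 0.004x³. We assume these coefficients for illustration (a low-degree fit of this shape is commonly cited in FHE work); fhez's own activation approximations may use different ones. A polynomial needs only additions and multiplications, so CKKS can evaluate it, but it only tracks the true function for small inputs and diverges wildly at magnitudes like raw 0-255 pixel values.

```python
import math


def sigmoid(x: float) -> float:
    """Exact logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_cubic(x: float) -> float:
    """Cubic approximation of sigmoid (coefficients assumed for illustration).

    Uses only + and *, so it could be evaluated on CKKS cyphertexts,
    but it is only accurate for small |x|.
    """
    return 0.5 + 0.197 * x - 0.004 * x ** 3


# Normalised pixels (0-1) stay close; an unnormalised pixel (255) explodes.
for pixel in (0.0, 0.5, 1.0, 255.0):
    print(f"x={pixel:>6}: sigmoid={sigmoid(pixel):.4f}  cubic={sigmoid_cubic(pixel):.4f}")
```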

In [None]:
import pandas as pd
import numpy as np
import tqdm

In [None]:
train_file = os.path.join(unzip_dir, "fashion-mnist_train.csv") 
test_file = os.path.join(unzip_dir, "fashion-mnist_test.csv")
train = pd.read_csv(train_file)
test = pd.read_csv(test_file)
# train

In [None]:
train_y = train.iloc[:, 0]
train_x = train.iloc[:, 1:]/255 # normalise to 0-1 preventing explosion
test_x = test.iloc[:, 1:]/255 # normalise to 0-1 preventing explosion
test_y = test.iloc[:, 0]
train_x = train_x.to_numpy()
train_y = train_y.to_numpy()
test_x = test_x.to_numpy()
test_y = test_y.to_numpy()
print(train_x.shape)
print(train_y.shape)

In [None]:
# train_x[0]

In [None]:
train_x[0].shape

Define Neural Network
---------------------

- Use [NetworkX](https://networkx.org/) to construct a **multi-directed graph** as a neural network
- Nodes of this graph are abstractions of neural network components with forward, backward (backpropagation), update (weight update/ optimisation), and cost (computational depth of traversal to the node)
- We use nodes that inherit from the abstract base class [fhez.nn.graph.node.Node](https://python-fhez.readthedocs.io/en/latest/nodes/node.html#node), so if you need to define your own type of node, inherit from it to match the API the network traverser expects
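
The node idea can be illustrated with a small standalone sketch. It does *not* inherit from `fhez.nn.graph.node.Node`, whose exact abstract methods and signatures should be checked in its documentation; the `Scale` class, its `cost` value and the dict-based graph are hypothetical. Python's `graphlib` (3.9+) stands in for NetworkX so the example has no extra dependencies: nodes are visited in topological order, which is what a traverser over a directed graph must do.

```python
from graphlib import TopologicalSorter


class Scale:
    """Hypothetical toy node computing y = w * x with forward/backward/update/cost."""

    def __init__(self, w: float):
        self.w = w
        self._x = None      # input cached during forward, needed by backward
        self._grad_w = 0.0  # accumulated weight gradient

    def forward(self, x: float) -> float:
        self._x = x
        return self.w * x

    def backward(self, gradient: float) -> float:
        self._grad_w += gradient * self._x  # dL/dw = dL/dy * x
        return gradient * self.w            # dL/dx = dL/dy * w, passed to the parent

    def update(self, lr: float = 0.1) -> None:
        self.w -= lr * self._grad_w
        self._grad_w = 0.0

    @property
    def cost(self) -> int:
        return 1  # this node adds one multiplication to the traversal depth


# Graph as {node_name: set_of_predecessor_names}; "x" is the input.
edges = {"a": {"x"}, "b": {"a"}}
nodes = {"a": Scale(2.0), "b": Scale(3.0)}


def forward(x: float) -> float:
    """Run each node once, in topological order, feeding it its parent's output."""
    values = {"x": x}
    for name in TopologicalSorter(edges).static_order():
        if name == "x":
            continue
        (parent,) = edges[name]  # single-parent chain for simplicity
        values[name] = nodes[name].forward(values[parent])
    return values["b"]


print(forward(1.5))  # 2.0 * 3.0 * 1.5
```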

In [None]:
import networkx as nx
from fhez.nn.graph.prefab import orbweaver
graph = orbweaver()
print(graph)  # modify this like any other NetworkX graph, e.g. add a new CNN layer using the ready-made nodes

In [None]:
# # optionally modify the graph
# # here we replace RELU with sigmoid activation for comparison only, ReLU is almost certainly better
# from fhez.nn.activation.sigmoid import Sigmoid
# graph.nodes(data=True)["CNN-RELU"]["node"] = Sigmoid()
# print(graph.nodes(data=True)["CNN-RELU"])
# for i in range(10):
#     graph.nodes(data=True)["Dense-RELU-{}".format(i)]["node"] = Sigmoid()
#     print(graph.nodes(data=True)["Dense-RELU-{}".format(i)])


Visualise the graph
-------------------

In [None]:
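import copy

def strip(graph):
    """Return a deep copy of graph without the "node" objects attached to each node.

    pyvis copies node attributes into its HTML output, so the Python node
    objects are removed first; only the graph structure is needed to draw it.
    """
    g = copy.deepcopy(graph)
    for _, data in g.nodes(data=True):
        # data["title"] = "{}:\n{}".format(type(data["node"]), repr(data["node"]))
        data.pop("node", None)  # drop the node object if present
    return g

print(graph)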

In [None]:
from pyvis.network import Network

stripped = strip(graph)
print(stripped)

net = Network('700px', '700px', bgcolor='#222222', font_color='white', notebook=True)
net.from_nx(stripped)
# net.show_buttons(filter_="physics")
net.show("graph.html")

Train Using Plaintext Data
--------------------------

- Instantiate our neural networks
- Compute the forward pass of our neural networks
- Compute the backward pass of our neural networks
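
The forward, backward and update cycle, applied over mini-batches (the training call below uses `batch_size=5`), can be illustrated on a one-feature linear model in pure Python. This is a generic mini-batch gradient descent sketch, not fhez's `train` function; the model, the squared-error loss and the hyperparameters are assumptions chosen so it converges quickly.

```python
import random


def train_linear(xs, ys, batch_size=5, lr=0.1, epochs=300, seed=0):
    """Fit y ~ w*x + b with mini-batch gradient descent on 0.5 * squared error."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    data = list(zip(xs, ys))
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            grad_w = grad_b = 0.0
            for x, y in batch:
                y_hat = w * x + b       # forward pass
                err = y_hat - y         # dL/dy_hat for L = 0.5 * err**2
                grad_w += err * x       # backward pass: dL/dw
                grad_b += err           # backward pass: dL/db
            w -= lr * grad_w / len(batch)  # update with batch-averaged gradients
            b -= lr * grad_b / len(batch)
    return w, b


xs = [i / 10 for i in range(20)]
ys = [2.0 * x + 1.0 for x in xs]
w, b = train_linear(xs, ys)
print(round(w, 3), round(b, 3))
```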

I would like to stress that FHE is not a panacea.
You may be wondering why we don't train the neural network using cyphertexts. The simple answer is: *where/ when do we stop?*
This refers to two *stops* in particular: when do we stop training if we cannot see the loss, and where does the cyphertext *stop*? For instance, do we carry the cyphertext all the way through, which means our neural network weights become encrypted?
The solution to the first *stop* is simple but expensive: we compute the training-test divergence on the client side, where the keys exist, so that we can find the optimal *training stop* point. This, however, requires a continued connection to the client.

There are many answers to where we might figuratively stop the cyphertext, but if privacy is of pivotal concern then the only real answer is never. Any plaintext weights could in theory be used to reconstruct the data they were trained on, so if we do the forward pass in cyphertext but the backward pass in plaintext, we gain no privacy, since the data is then known to the data processor.
However, if we stick to cyphertexts all the way up to and including the weight update, then, because the gradients come from the encrypted inputs and are themselves cyphertexts, the weights become encrypted after the first iteration. Every subsequent calculation then takes *significantly* (orders of magnitude) longer, since cyphertext-cyphertext operations are much slower even than cyphertext-plaintext ones.
This is not to mention the poor compatibility of loss functions with FHE, since many require some form of division, which must be approximated.

Thus we think the optimal solution is actually transfer learning: train on a similar dataset, transfer that understanding to the related problem, and infer using cyphertexts only. That way privacy is maintained, since the plaintext weights are never touched by the client's data, and we still get encrypted inference, albeit with lower accuracy, without incurring the cyphertext-cyphertext cost of encrypted weights.
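
On the point that many loss functions need division: one FHE-friendly way to approximate a reciprocal is Newton-Raphson iteration, which uses only multiplication and subtraction. This is a generic sketch of that technique, not necessarily what fhez does. The initial guess `y0` is an assumption and must come from a known bound on the input, because the iteration only converges when `0 < a * y0 < 2`; under encryption each iteration also consumes multiplicative depth.

```python
def reciprocal_newton(a: float, y0: float, iterations: int = 6) -> float:
    """Approximate 1/a using only multiplication and subtraction.

    Newton-Raphson on f(y) = 1/y - a gives y <- y * (2 - a * y).
    The relative error squares each iteration once 0 < a * y0 < 2.
    """
    y = y0
    for _ in range(iterations):
        y = y * (2.0 - a * y)
    return y


# Example: if a is known to lie in (0, 10], y0 = 0.1 keeps a * y0 in (0, 1].
print(reciprocal_newton(4.0, 0.1))
```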

In [None]:
from fhez.nn.graph.utils import train, infer
from fhez.nn.loss.cce import CCE

In [None]:
# cnn = Layer_CNN(weights=( 1, 6, 6 ), stride=[ 1, 4, 4 ], bias=0)
# dense = None
# for cyphertext in row_encrypted_generator(data=train_x, shape=( 1, 28, 28 )):
#     cnn_acti = cnn.forward(cyphertext)
#     if dense is None:
#         dense = Layer_ANN(weights=(len(cnn_acti),), bias=0)
#     dense.forward(cnn_acti)

In [None]:
train_dict = {
    "x": [],
    "y": [],
}

# reshape each flat 784-pixel row back into a 28x28 image
for x, y in zip(train_x, train_y):
    train_dict["x"].append(np.reshape(x, (28, 28)))
    train_dict["y"].append(y)

In [None]:
tt = time.time()
output = train(graph=graph, inputs=train_dict, batch_size=5, debug=False)
tt = time.time() - tt


In [None]:
trained_graph = copy.deepcopy(graph)

In [None]:
output

Plaintext Inference
-------------------

- Measure accuracy on the test set in plaintext space for comparison

In [None]:
test_dict = {
    "x": [],
    "y": [],
}
for x, y in zip(test_x, test_y):
    test_dict["x"].append(np.reshape(x, (28, 28)))
    test_dict["y"].append(y)

pi = time.time()
# infer only needs the inputs, not the labels
y_hats = infer(graph=graph, inputs={"x": test_dict["x"]})["y_hat"]
pi = time.time() - pi

In [None]:
sample=y_hats[20:40]
sample

In [None]:
true_sample = test_dict["y"][20:40]
true_sample

In [None]:
accurates = np.where(np.isclose(y_hats, test_dict["y"]))
len(accurates[0])

In [None]:
accuracy = len(accurates[0])/len(test_dict["y"])
print("Accuracy {}".format(accuracy))

In [None]:
from fhez.nn.activation.sigmoid import Sigmoid  # needed for the activation check below

csv_path = "fashion_MNIST_results.csv"
current_result = pd.DataFrame({"accuracy": [accuracy],
                               "training_time": [tt],
                               "plain_inference_time": [pi],
                               "datetime": [datetime.datetime.now()],
                               "y_hat_sample": [sample],
                               "y_sample": [true_sample],
                               "inference_size": [len(test_dict["y"])],
                               "activation": ["Sigmoid"] if isinstance(graph.nodes(data=True)["CNN-RELU"]["node"], Sigmoid) else ["ReLU"],
                              })
try:
    all_results = pd.read_csv(csv_path, index_col=False)
    # DataFrame.append was removed in pandas 2.0, so concatenate instead
    all_results = pd.concat([all_results, current_result], ignore_index=True)
except FileNotFoundError:
    all_results = current_result
all_results.to_csv(csv_path, index=False)
all_results

In [None]:
ax = sns.boxplot(y="accuracy", x="activation", data=all_results)
ax = sns.swarmplot(y="accuracy", x="activation", data=all_results, color=".25")
ax.set(title="Model Accuracy by Activation")
fig = ax.get_figure()
fig.savefig("fashion-mnist-swarm.png") 

Not great, not terrible (3.6 roentgen). Absolute network performance can always be improved with newer/ better architectures, and with more epochs if the network has not yet learnt all it can from the training set. We use a simple network of one CNN layer and ten dense nodes, followed by softmax into categorical cross entropy. There are much better architectures available, but we are concerned with the encryption here, so we don't want to overcomplicate things. The network is a means to an end: a baseline for comparison.
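
For reference, the softmax into categorical cross entropy head can be written in a few lines. This is a plain-Python sketch of the standard formulas, not fhez's `CCE` implementation, and the three-class logits are made up for illustration (Fashion-MNIST has ten classes). Note that it needs `exp`, `log` and a division, none of which CKKS evaluates natively, which is part of why the loss stays on the plaintext side.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(probs, label, eps=1e-12):
    """CCE for one example with an integer class label: -log p(label)."""
    return -math.log(max(probs[label], eps))  # eps guards against log(0)


logits = [2.0, 1.0, 0.1]  # hypothetical scores, one per class
probs = softmax(logits)
print([round(p, 3) for p in probs])
print(round(categorical_cross_entropy(probs, 0), 3))
```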

Encrypted Inference
-------------------

- Find accuracy against the test set again, but this time in encrypted space

Parameterise Encoding/ Encryption and Create an Encrypted Generator
-------------------------------------------------------------------

- Parameterise our neural network graph encryption nodes
- Automatically set parameterisation using AtoFHE

In [None]:
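import seal

# CKKS parameters for the graph's encryption nodes
encryption_parameters = {
            "scheme": seal.scheme_type.CKKS,
            "poly_modulus_degree": 8192*2,  # N = 16384, i.e. 8192 CKKS slots
            # prime bit sizes: 45-bit first and last primes, ten 30-bit middle primes
            "coefficient_modulus":
                [45, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 45],
            "scale": pow(2.0, 30),  # 2^30, matching the 30-bit middle primes
            "cache": True,
}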
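
These parameters can be sanity-checked by hand. The sketch below assumes SEAL's default 128-bit security bounds (the maximum total coefficient-modulus bits per polynomial degree from the HomomorphicEncryption.org standard) and the usual CKKS layout, where the first prime holds the final result, the last is the special key-switching prime, and each middle prime allows one rescale, roughly one multiplication. The helper `summarise_ckks` is hypothetical, not part of fhez or SEAL.

```python
import math

# Max total coefficient-modulus bits for 128-bit classical security (SEAL's default table).
MAX_BITS_128 = {1024: 27, 2048: 54, 4096: 109, 8192: 218, 16384: 438, 32768: 881}


def summarise_ckks(poly_modulus_degree, coefficient_modulus, scale):
    """Hypothetical helper: derive slots, modulus size, security fit and depth."""
    total_bits = sum(coefficient_modulus)
    return {
        "slots": poly_modulus_degree // 2,  # CKKS packs N/2 values per cyphertext
        "total_bits": total_bits,
        "within_128_bit_security": total_bits <= MAX_BITS_128[poly_modulus_degree],
        "rescales_available": len(coefficient_modulus) - 2,  # middle primes only
        "scale_bits": math.log2(scale),  # should match the middle prime size
    }


summary = summarise_ckks(8192 * 2, [45] + [30] * 10 + [45], pow(2.0, 30))
print(summary)
```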

In [None]:
# # Generate encrypted data piecemeal (as it can get very large)
# def row_encrypted_generator(data: np.ndarray, shape: tuple):
#     """Generate encrypted data of desired shape from rows."""
#     for row in data:
#         row = np.reshape(row, newshape=shape)  # reshape to image shape; rows are already normalised to 0-1 above
#         yield ReArray(row, **encryption_parameters)
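
To get a feel for why encrypted data is generated piecemeal, here is a rough size estimate. It assumes SEAL's in-memory layout: a fresh CKKS cyphertext holds two polynomials of N coefficients, with one 64-bit word per coefficient for every prime in the data level (all primes except the special one). The helper `ckks_ciphertext_bytes` is hypothetical, and serialised sizes can be smaller when SEAL compresses them.

```python
def ckks_ciphertext_bytes(poly_modulus_degree, coefficient_modulus, size=2):
    """Hypothetical helper: approximate in-memory bytes of one fresh CKKS cyphertext."""
    data_primes = len(coefficient_modulus) - 1  # the last (special) prime is excluded
    return size * poly_modulus_degree * data_primes * 8  # 8 bytes per 64-bit word


n_bytes = ckks_ciphertext_bytes(8192 * 2, [45] + [30] * 10 + [45])
print(f"{n_bytes / 2**20:.2f} MiB per cyphertext")
```

That comes to roughly 2.75 MiB per cyphertext. How many cyphertexts fhez's `ReArray` uses per 28x28 image depends on its packing, which we do not assume here, but even one per image across the 10,000-image test set would be about 27 GiB, hence yielding encrypted rows lazily.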