In [1]:
from concrete import fhe
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
from torchsummary import summary

from norse.torch.functional.lif import LIFParameters
import norse.torch as snn

from sklearn.model_selection import train_test_split
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
#hyperparams
no_epochs = 2
batch_size = 32
learning_rate = 0.001

In [4]:
#transform the dataset into tensors normalized range [-1, 1]
transform = transforms.Compose(
            [transforms.ToTensor(),
            transforms.Normalize((0.5),(0.5))     
        ])

In [5]:
#data sets downloading and reading
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                        train=True,
                                        download=True,
                                        transform=transform
                                        )

test_dataset = torchvision.datasets.MNIST(root='data',
                                        train=False,
                                        download=True,
                                        transform=transform
                                        )


In [6]:
# Extract features (images) and labels from MNIST dataset
mnist_features = train_dataset.data.numpy().reshape(-1, 28, 28)
mnist_labels = train_dataset.targets.numpy()

# Reshape and expand dimensions to match the structure of load_digits dataset
x_train_mnist = np.expand_dims(mnist_features, 1)

# Split the MNIST data into train and test sets
x_train, x_test, y_train, y_test = train_test_split(
    x_train_mnist, mnist_labels,  train_size=5000, test_size=100, shuffle=True, random_state=42
)

#x_train = x_train.astype('float64')
print(x_train.dtype)
print(x_train.shape)

print(x_train)
# plt.imshow(x_train[0,0], cmap='grey')
# plt.show()
# Verify the shapes
print("Shape of x_train from MNIST:", x_train.shape)
print("Shape of x_test from MNIST:", x_test.shape)
print("Shape of y_train from MNIST:", y_train.shape)
print("Shape of y_test from MNIST:", y_test.shape)

uint8
(5000, 1, 28, 28)
[[[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 ...


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]]
Shape of x_train from MNIST: (5000, 1, 28, 28)
Shape of x_test from MNIST: (100, 1, 28, 28)
Shape of y_train from MNIST: (5000,)
Shape of y_test fro

In [7]:
lookup_table = tuple(range(255))
print(lookup_table)

(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,

In [8]:
table = fhe.LookupTable(lookup_table)

In [9]:
@fhe.compiler({"x":"encrypted"})
def ref_function(x):
    return table[x]

In [10]:
inputset = range(0,255)
print(inputset)
circuit = ref_function.compile(inputset)

range(0, 255)


In [11]:

# take an MNIST image then encrypt it
encrypted_x_train = np.empty_like(x_train, dtype=object)

start = time.time()
for i, x_train_row in enumerate(x_train):
    for j, x_train_cell in enumerate(x_train_row):
        for k, row in enumerate(x_train_cell):
            for l, cell in enumerate(row):
                # Encrypt the cell value using circuit.encrypt
                enc_value = circuit.encrypt(cell)
                enc_value = circuit.run(enc_value)
                encrypted_x_train[i, j, k, l] = enc_value
                print(enc_value)
            break
        break
    break


end = time.time()
print(f' value encrypt time : {end - start:.3f} seconds')
print(encrypted_x_train)

<concrete.fhe.compilation.value.Value object at 0x7f85a927a7d0>
<concrete.fhe.compilation.value.Value object at 0x7f85a927be50>
<concrete.fhe.compilation.value.Value object at 0x7f85a927bee0>
<concrete.fhe.compilation.value.Value object at 0x7f85a927ba00>
<concrete.fhe.compilation.value.Value object at 0x7f85a927b910>
<concrete.fhe.compilation.value.Value object at 0x7f85a927b820>
<concrete.fhe.compilation.value.Value object at 0x7f85a9279780>
<concrete.fhe.compilation.value.Value object at 0x7f85a92797b0>
<concrete.fhe.compilation.value.Value object at 0x7f85a927a350>
<concrete.fhe.compilation.value.Value object at 0x7f85a92791e0>
<concrete.fhe.compilation.value.Value object at 0x7f85a927b3a0>
<concrete.fhe.compilation.value.Value object at 0x7f85a927b550>
<concrete.fhe.compilation.value.Value object at 0x7f85a927bb20>
<concrete.fhe.compilation.value.Value object at 0x7f85a927a8c0>
<concrete.fhe.compilation.value.Value object at 0x7f86ddd69ff0>
<concrete.fhe.compilation.value.Value ob

In [13]:
decrypted_x_train = np.empty_like(x_train, dtype=int)

start = time.time()
for i, x_train_row in enumerate(encrypted_x_train):
    for j, x_train_cell in enumerate(x_train_row):
        for k, row in enumerate(x_train_cell):
            for l, cell in enumerate(row):
                # Encrypt the cell value using circuit.encrypt
                enc_value = circuit.run(cell)
                dec_value = circuit.decrypt(enc_value)
                decrypted_x_train[i, j, k, l] = dec_value
                # print(dec_value)
            break
        break
    break

end = time.time()
print(f' value encrypt time : {end - start:.3f} seconds')
print(decrypted_x_train)

 value encrypt time : 12.556 seconds
[[[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 ...


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]]


In [22]:
@fhe.compiler({"x":"encrypted", "y":"encrypted"})
def add_function(x, y):
    return x+y

In [15]:
@fhe.compiler({"x":"encrypted", "y":"encrypted"})
def multiply_function(x, y):
    return x*y

Add two cells in an image.

In [35]:
val_1 = encrypted_x_train[0,0,0,10]
val_2 = encrypted_x_train[0,0,0,11]
# enc_value = np.add(val_1, val_2)
# enc_value = val_1


enc_value = circuit.run(enc_sum)
dec_value = circuit.decrypt(enc_value)
print(f'decrypted sum {dec_value}')

TypeError: unsupported operand type(s) for +: 'Value' and 'Value'

In [164]:
value = -49

enc_value  = circuit.encrypt(value)
enc_value = circuit.run(enc_value)
dec_value = circuit.decrypt(enc_value)
print(dec_value)

1


In [147]:
c_val = circuit.encrypt_run_decrypt(0)

print(c_val)

0
