In [44]:
from concrete import fhe
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import time

from sklearn.model_selection import train_test_split

In [45]:
#device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [46]:
#hyperparams
no_epochs = 2
batch_size = 32
learning_rate = 0.001

In [47]:
#transform the dataset into tensors normalized range [-1, 1]
transform = transforms.Compose(
            [transforms.ToTensor(),
            transforms.Normalize((0.5),(0.5))     
        ])

In [48]:
#data sets downloading and reading
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                        train=True,
                                        download=True,
                                        transform=transform
                                        )

test_dataset = torchvision.datasets.MNIST(root='data',
                                        train=False,
                                        download=True,
                                        transform=transform
                                        )

In [49]:
# Extract features (images) and labels from MNIST dataset
mnist_features = train_dataset.data.numpy().reshape(-1, 28, 28)
mnist_labels = train_dataset.targets.numpy()

# Reshape and expand dimensions to match the structure of load_digits dataset
x_train_mnist = np.expand_dims(mnist_features, 1)

# Split the MNIST data into train and test sets
x_train, x_test, y_train, y_test = train_test_split(
    x_train_mnist, mnist_labels,  train_size=5000, test_size=100, shuffle=True, random_state=42
)

#x_train = x_train.astype('float64')
print(x_train.dtype)
print(x_train.shape)

print(x_train)
# plt.imshow(x_train[0,0], cmap='grey')
# plt.show()
# Verify the shapes
print("Shape of x_train from MNIST:", x_train.shape)
print("Shape of x_test from MNIST:", x_test.shape)
print("Shape of y_train from MNIST:", y_train.shape)
print("Shape of y_test from MNIST:", y_test.shape)

uint8
(5000, 1, 28, 28)
[[[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 ...


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]


 [[[0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   ...
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]
   [0 0 0 ... 0 0 0]]]]
Shape of x_train from MNIST: (5000, 1, 28, 28)
Shape of x_test from MNIST: (100, 1, 28, 28)
Shape of y_train from MNIST: (5000,)
Shape of y_test fro

In [50]:
@fhe.compiler({"image": "encrypted"})
def to_grayscale(image):
    with fhe.tag("scaling.r"):
        r = image[:, :, 0]
        r = (r * 0.30).astype(np.int64)

    with fhe.tag("combining.rgb"):
        gray = r 
        
    with fhe.tag("creating.result"):
        gray = np.expand_dims(gray, axis=2)
        result = np.concatenate((gray, gray, gray), axis=2)
    
    return result


In [51]:
configuration = fhe.Configuration(
    enable_unsafe_features=True,
    use_insecure_key_cache=True,
    insecure_key_cache_location=".keys",

    # To enable displaying progressbar
    show_progress=True,
    # To enable showing tags in the progressbar (does not work in notebooks)
    progress_tag=True,
    # To give a title to the progressbar
    progress_title="Evaluation:",
)

In [52]:
inputset = x_train
print(f"Compilation started @ {time.strftime('%H:%M:%S', time.localtime())}")
start = time.time()
circuit = to_grayscale.compile(inputset, configuration)
end = time.time()
print(f"(took {end - start:.3f} seconds)")

Compilation started @ 19:20:35
(took 0.969 seconds)


In [53]:
print(f"Key generation started @ {time.strftime('%H:%M:%S', time.localtime())}")
start = time.time()
circuit.keygen()
end = time.time()
print(f"(took {end - start:.3f} seconds)")

print()

Key generation started @ 19:20:40
(took 1.587 seconds)



In [54]:
image_data_b = x_train[0]
print(f"Evaluation started @ {time.strftime('%H:%M:%S', time.localtime())}")
start = time.time()
enc_image = circuit.encrypt(image_data_b)
enc_image = circuit.run(enc_image)
end = time.time()
print(f"(took {end - start:.3f} seconds)")

Evaluation started @ 19:20:46
           __________________________________________________
Evaluation:████ 100%
(took 15.184 seconds)


In [72]:
#print(x_train)
x_train_a = x_train
sum = x_train_a[0][0][0][0] + x_train_a[0][0][0][1]
print(f'sum {sum}')

multiply = x_train_a[0][0][0][0] * x_train_a[0][0][0][1]
print(f'multi {multiply}')

sum 0
multi 0


In [76]:
@fhe.compiler({"x":"encrypted"})
def add(x):
    sum =  x[0,0,0] + x[0,0,0]
    return sum

In [77]:
x_train_a = x_train
circuit_add = add.compile(x_train_a)

enc_image_data = circuit_add.encrypt(x_train_a[0])
enc_image_add = circuit_add.run(enc_image_data)
dec_image = circuit_add.decrypt(enc_image)
print(dec_image)

Tracer<output=EncryptedScalar<uint8>>
150


In [62]:
@fhe.compiler({"x":"encrypted"})
def multiply(x):
    sum =  x[0,0,0] * x[0,0,0]
    return sum

In [63]:
circuit_multiply = multiply.compile(x_train_a)

#enc_image = circuit_multiply.encrypt(x_train_a[0])
enc_image_multiply = circuit_multiply.run(enc_image_data)
dec_image = circuit_multiply.decrypt(enc_image_multiply)
print(dec_image)

0


In [64]:
#lookup_table = tuple(range(256))

lookup_table = (0,) * 128 + (1,) * 128

lookup_table 
print(lookup_table)

(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)


In [65]:
table = fhe.LookupTable(lookup_table)

In [66]:
@fhe.compiler({"x":"encrypted"})
def spike_function(x):
    return table[x]

In [67]:
inputset = range(0,255)
print(inputset)
circuit_spiking = spike_function.compile(inputset)

range(0, 255)


In [68]:
value = 10
enc_value = circuit_spiking.encrypt(value)
enc_value_a = circuit_spiking.run(enc_value)

value = 140
enc_value = circuit_spiking.encrypt(value)
enc_value_b = circuit_spiking.run(enc_value)

enc_value_c = circuit_spiking.run(enc_image_add)
enc_value_d = circuit_spiking.run(enc_image_multiply)

print(enc_value_a)
print(enc_value_b)
print(enc_value_c)
print(enc_value_c)

<concrete.fhe.compilation.value.Value object at 0x7efcd4c3ceb0>
<concrete.fhe.compilation.value.Value object at 0x7efcd4c3cac0>
<concrete.fhe.compilation.value.Value object at 0x7efcd4dccf40>
<concrete.fhe.compilation.value.Value object at 0x7efcd4dccf40>


In [70]:
dec_value = circuit_spiking.decrypt(enc_value_a)
print(f'decrypted enc_value_a {dec_value}')

dec_value = circuit_spiking.decrypt(enc_value_b)
print(f'decrypted enc_value_b {dec_value}')

dec_value = circuit_spiking.decrypt(enc_value_c)
print(f'decrypted enc_value_c {dec_value}')

dec_value = circuit_spiking.decrypt(enc_value_d)
print(f'decrypted enc_value_d {dec_value}')

decrypted enc_value_a 0
decrypted enc_value_b 1
decrypted enc_value_c 0
decrypted enc_value_d 0
