In [35]:
# Quantization code from Song Han's TinyML class
import torch

import matplotlib.pyplot as plt

import numpy as np

from PIL import Image

def get_quantized_range(bitwidth):
    """
    signed integer range representable with ``bitwidth`` bits (two's complement)
    :param bitwidth: [int] quantization bit width
    :return:
        [int] quantized_min, i.e. -2**(bitwidth - 1)
        [int] quantized_max, i.e.  2**(bitwidth - 1) - 1
    """
    half_span = 1 << (bitwidth - 1)
    return -half_span, half_span - 1

def get_quantization_scale_and_zero_point(fp_tensor, bitwidth):
    """
    get quantization scale and zero point for a single tensor
    :param fp_tensor: [torch.(cuda.)Tensor or np.ndarray] floating tensor to be
        quantized; only its ``.max()`` and ``.min()`` are used
    :param bitwidth: [int] quantization bit width
    :return:
        [float] scale (always a plain Python float, always > 0 for finite input)
        [int] zero_point, clipped to [quantized_min, quantized_max]
    """
    quantized_min, quantized_max = get_quantized_range(bitwidth)

    # Convert the extrema to Python floats up front. ``.max()`` / ``.min()``
    # return NumPy or torch scalars, and e.g. ``np.float32`` is not a subclass
    # of ``float``, so without this the documented ``[float]`` return type
    # does not hold for callers that type-check the scale.
    fp_max = float(fp_tensor.max())
    fp_min = float(fp_tensor.min())

    # scale
    fp_range = fp_max - fp_min
    if fp_range > 0:
        scale = fp_range / (quantized_max - quantized_min)
    else:
        # Constant tensor: a zero range would give scale == 0 and a division
        # by zero below. Use the magnitude of the single value (or 1.0 for an
        # all-zero tensor) so the value is still representable.
        scale = abs(fp_max) or 1.0

    # zero_point
    zero_point = quantized_min - fp_min / scale

    # clip the zero_point to fall in [quantized_min, quantized_max]
    if zero_point < quantized_min:
        zero_point = quantized_min
    elif zero_point > quantized_max:
        zero_point = quantized_max
    else:  # convert from float to int using round()
        zero_point = round(zero_point)
    return scale, int(zero_point)

def linear_quantize(fp_tensor, bitwidth, scale, zero_point, dtype=np.int8) -> np.ndarray:
    """
    linear quantization for single fp_tensor
      from
        r = fp_tensor = (quantized_tensor - zero_point) * scale
      we have,
        q = quantized_tensor = int(round(fp_tensor / scale)) + zero_point
    :param fp_tensor: [np.ndarray] floating tensor to be quantized
    :param bitwidth: [int] quantization bit width
    :param scale: [float] scaling factor
    :param zero_point: [int] the desired centroid of tensor values
    :param dtype: integer NumPy dtype of the returned array (default np.int8);
        the caller must pick one wide enough for ``bitwidth``
    :return:
        [np.ndarray] quantized tensor whose values are integers of type ``dtype``
    """
    # NumPy scalar types are accepted as well: np.float32 / np.int64 are not
    # subclasses of the Python built-ins but are perfectly valid here.
    assert isinstance(scale, (float, np.floating))
    assert isinstance(zero_point, (int, np.integer))

    # scale the fp_tensor and round the floating value to integer value
    rounded_tensor = np.round(fp_tensor / scale)

    # shift the rounded_tensor to make zero_point 0
    shifted_tensor = rounded_tensor + zero_point

    # clamp the shifted_tensor to lie in bitwidth-bit range
    quantized_min, quantized_max = get_quantized_range(bitwidth)
    quantized_tensor = shifted_tensor.clip(quantized_min, quantized_max)

    # Honor the requested storage type (previously hard-coded to np.int8,
    # which silently ignored ``dtype`` and wrapped for bitwidth > 8).
    return quantized_tensor.astype(dtype)

def linear_quantize_feature(fp_tensor, bitwidth):
    """
    quantize one feature/weight tensor end to end: derive the scale and zero
    point from the tensor itself, then quantize with them
    :param fp_tensor: floating tensor to be quantized (passed unchanged to
        get_quantization_scale_and_zero_point and linear_quantize)
    :param bitwidth: [int] quantization bit width
    :return:
        quantized tensor
        scale
        zero point
    """
    qparams = get_quantization_scale_and_zero_point(fp_tensor, bitwidth)
    return (linear_quantize(fp_tensor, bitwidth, *qparams), *qparams)

# Load the trained checkpoint (a mapping of parameter name -> tensor).
# - map_location="cpu": every tensor is converted to NumPy right below, so
#   deserialize straight onto the CPU; this also lets a checkpoint saved from
#   a CUDA device load on a CPU-only host.
# - weights_only=True: restrict unpickling to tensors and plain containers
#   instead of executing arbitrary pickled code.
checkpoint = torch.load("cpu/NN/mnist_100_10.pt", map_location="cpu", weights_only=True)
# Alternative checkpoint:
# checkpoint = torch.load("./NN/MNIST_12_layers.pt")

# Parameter name -> float NumPy array.
weights_biases = {
    name: param.detach().cpu().numpy() for name, param in checkpoint.items()
}

# Bit width used for quantization. The values are *stored* as int8 (hence the
# dict name) but only span the 3-bit signed range [-4, 3].
QUANT_BITWIDTH = 3

int8_quant = {}

for key in weights_biases:
    print(key)
    # linear_quantize_feature returns (quantized, scale, zero_point); only the
    # quantized tensor is kept.
    int8_quant[key] = linear_quantize_feature(weights_biases[key], QUANT_BITWIDTH)[0]

fc1.weight
fc1.bias


  checkpoint = torch.load("cpu/NN/mnist_100_10.pt")


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Pick the GPU when one is visible to torch, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Keyword arguments intended for the train / test DataLoaders.
train_kwargs = dict(batch_size=64)
test_kwargs = dict(batch_size=64)
if use_cuda:
    cuda_kwargs = dict(num_workers=1, pin_memory=True, shuffle=True)
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)


def _binarize(x):
    """Map every element strictly greater than 0.7 to 1.0 and the rest to 0.0."""
    return torch.where(x > 0.7, torch.tensor(1.0), torch.tensor(0.0))


def _flatten(x):
    """Collapse the tensor into a single dimension."""
    return x.view(-1)


# Per-image preprocessing: resize to 10x10, convert to a tensor, normalize,
# threshold (on the normalized values) to {0, 1}, then flatten.
transform = transforms.Compose([
    transforms.Resize((10, 10)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(_binarize),
    transforms.Lambda(_flatten),
])

def filter_0_and_1(dataset):
    """
    Keep only the samples whose target equals 0 or 1.

    The dataset is modified in place: both ``dataset.targets`` and
    ``dataset.data`` are re-indexed with the positions of the kept samples.

    :param dataset: object exposing indexable ``targets`` and ``data`` attributes
    :return: the same ``dataset`` object, after filtering
    """
    keep = []
    for position, label in enumerate(dataset.targets):
        if label == 0 or label == 1:
            keep.append(position)

    dataset.targets = dataset.targets[keep]
    dataset.data = dataset.data[keep]
    return dataset

# NOTE: despite its name, ``test_dataset`` is built with train=True; the
# train=False split is ``dataset2``.
test_dataset = datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transform,
)
# test_dataset = filter_0_and_1(test_dataset)

dataset2 = datasets.MNIST(
    root='../data',
    train=False,
    transform=transform,
)
# dataset2 = filter_0_and_1(dataset2)

# Create the DataLoaders.
# NOTE(review): train_kwargs / test_kwargs defined earlier are not passed here,
# so both loaders run with DataLoader's defaults — confirm this is intended.
train_loader = DataLoader(test_dataset)
test_loader = DataLoader(dataset2)

In [36]:
# Quantized weight matrix stored under the "fc1.weight" key.
weights = int8_quant["fc1.weight"]
# Image tensor of the first sample in ``test_dataset``, converted to an
# integer NumPy vector; ``enc()`` reads this global as the plaintext message.
m = test_dataset[0][0].numpy().astype(int)
# Display the message vector as the cell output.
m

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [62]:
# Scheme parameters.
k = 500  # number of mask polynomials / secret bits
N = 100  # number of coefficients per polynomial (length of the message vector)

q = 2**16  # ciphertext modulus
p = 2**10  # plaintext modulus; messages are scaled by q / p before masking

# Seeded generator so that A and s are identical on every re-execution.
# Without it, re-running this cell silently invalidates ciphertexts (newA,
# newB, ...) computed in other cells from the previous A and s.
# This is a demo setting: a fixed seed means a fixed, publicly known secret.
SEED = 0
rng = np.random.default_rng(SEED)

# Public mask: k polynomials with N coefficients drawn uniformly from [0, q).
A = [rng.integers(0, q, size=N) for _ in range(k)]

# Secret: k random bits. Each bit is repeated N times so that s has shape
# (k, N) and s[idx] is a length-N vector whose entries all equal bit idx.
secret_bits = rng.integers(0, 2, size=k)
s = np.array([secret_bits] * N).T

def polynomial_mult(s0, s1, size=N, base=q):
    """
    Truncated product of two coefficient sequences, reduced modulo ``base``.

    Coefficient ``d`` of the result is the sum of ``s0[i] * s1[j]`` over all
    ``i + j == d``. Terms whose degree would be ``size`` or higher are dropped
    (there is no wrap-around).

    :param s0: first coefficient sequence (lowest degree first)
    :param s1: second coefficient sequence (lowest degree first)
    :param size: number of coefficients in the result
    :param base: modulus applied to every result coefficient
    :return: list of ``size`` coefficients, each reduced mod ``base``
    """
    coeffs = [0] * size

    for deg0, c0 in enumerate(s0):
        for deg1, c1 in enumerate(s1):
            degree = deg0 + deg1
            if degree >= size:
                # deg1 only grows from here, so every remaining term of this
                # row is out of range as well.
                break
            coeffs[degree] += c0 * c1

    return [coeff % base for coeff in coeffs]

def enc():
    """
    Encrypt the module-level message ``m`` with the module-level ``A`` and ``s``.

    The message is scaled by ``q / p`` and then, for every row ``idx``, the
    truncated polynomial product of ``A[idx]`` and ``s[idx]`` is added, with a
    reduction mod ``q`` after each addition.

    No noise/error term is added: an earlier draft sampled ``E`` in
    ``[-delta/LAMBDA, delta/LAMBDA)`` but that step was left disabled.

    :return: [np.ndarray] masked, scaled message (float array, values mod q)
    """
    delta = q / p
    B = np.array(m) * delta

    for idx, a_row in enumerate(A):
        mask = np.array(polynomial_mult(a_row, s[idx], N, q))
        B = (B + mask) % q

    return B

def dec(B, c = 1):
    """
    Decrypt ``B`` using the module-level ``A`` and ``s``.

    Starts from ``(B * c) mod q`` and, for every row ``idx``, subtracts the
    truncated polynomial product of ``(c**2 * A[idx]) mod q`` and
    ``s[idx] mod q``, reducing mod ``q`` after each subtraction. The remainder
    is divided by ``q / p``, rounded, and finally divided by ``c``.

    :param B: ciphertext body (array-like)
    :param c: constant factor applied as described above (default 1)
    :return: [np.ndarray] ``np.round(remainder / (q / p)) / c``
    """
    remainder = (np.array(B) * c) % q
    c_squared = c**2
    for idx, a_row in enumerate(A):
        mask = polynomial_mult((c_squared * a_row) % q, (s[idx]) % q, N, q)
        remainder = (remainder - np.array(mask)) % q
    # can check bottom bits and add one if needed
    step = q / p
    return np.round(remainder / step) / c

def dec2(A, B, s, N = 100):
    """
    Decrypt ``B`` with an explicitly supplied mask ``A`` and secret ``s``.

    For every row ``idx`` the truncated polynomial product (``N`` coefficients,
    modulus ``q``) of ``A[idx] mod q`` and ``s[idx] mod q`` is subtracted from
    ``B``, reducing mod ``q`` each time. The remainder is divided by ``q / p``,
    rounded, and reduced mod ``q``.

    :param A: sequence of mask rows
    :param B: ciphertext body (array-like)
    :param s: sequence of secret rows, indexed like ``A``
    :param N: number of coefficients kept by each polynomial product
    :return: [np.ndarray] ``np.round(remainder / (q / p)) % q``
    """
    remainder = (np.array(B)) % q
    for idx, a_row in enumerate(A):
        mask = np.array(polynomial_mult(a_row % q, (s[idx]) % q, N, q))
        remainder = (remainder - mask) % q
    # can check bottom bits and add one if needed
    step = q / p
    return np.round(remainder / step) % q

"""
Operations on ciphertext
"""

def add_ct(ct1, ct2):
    """
    Add two ciphertexts.

    :param ct1: first ciphertext
    :param ct2: second ciphertext
    :return: ``ct1 + ct2`` (no modular reduction is applied here)
    """
    summed = ct1 + ct2
    return summed

def add_constant(ct, c):
    """
    Add a plaintext constant to a ciphertext.

    :param ct: ciphertext
    :param c: constant; it is scaled by ``q / p`` before being added
    :return: ``ct + c * q / p`` (no modular reduction is applied here)
    """
    offset = c * q / p
    return ct + offset

def mul_constant(ct, A, c):
    """
    Multiply a ciphertext and its mask by a constant.

    :param ct: ciphertext body
    :param A: mask belonging to ``ct``
    :param c: constant multiplier
    :return: tuple ``((c * ct) mod q, (c * A) mod q)``
    """
    scaled_ct = (c * ct) % q
    scaled_mask = (c * A) % q
    return scaled_ct, scaled_mask

In [73]:
# Round-trip check: decrypting a fresh encryption should print the same vector
# as the plaintext message ``m`` on the next line.
# (A leading bare ``enc()`` call was dropped: its result was discarded and, not
# being the cell's last expression, never displayed, so it only cost a full
# extra encryption.)
print(np.array(dec(enc())).astype(int))
print(m)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1 1 0 0
 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1
 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1 1 0 0
 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1
 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [54]:
# Zero-filled placeholders; ``newA`` and ``newB`` are reassigned by the cells
# that compute np.dot(weights, ...).
array1 = np.zeros((10, 500))
newA = array1

array2 = np.zeros((10, 1))
newB = array2

# First column of ``s`` repeated 10 times, laid out as 10 columns.
newS = np.array([s[:, 0] for _ in range(10)]).T

In [78]:
# Stack the mask rows into one 2-D array so it can be transposed and used in a
# matrix product.
A = np.array(A)
print(A.shape)
print(A.T.shape)
print(weights.shape)

# Apply the quantized weight matrix to the mask and reduce mod q.
newA = (weights @ A.T) % q

print(A.T)
print(weights)
print(newA)

(500, 100)
(100, 500)
(10, 100)
[[47319 24744 34242 ... 57927 36854 57258]
 [ 3060  3425   922 ... 11973  8537 36519]
 [ 7508 10707 17614 ... 31409 35953 22040]
 ...
 [35984 22449 31734 ... 17440 14156 16343]
 [53874 62325 28935 ... 58392 20035 33686]
 [41437 58154  9353 ...  4619 55655 63223]]
[[ 2  2  2  2  2  2  2  2  2  2  2  2  2  1  1  1  1  2  2  2  2  2  1  1
   1  1  1  1  1  2  2  2  1  1  1  0  1  2  2  2  2  2  1  1  0 -2 -1  2
   3  2  2  2  2  1 -2 -4 -1  2  2  2  2  2  2  1 -1 -1  1  1  2  2  2  2
   1  2  1  1  1  1  2  2  2  2  2  1  1  1  1  2  2  2  2  2  2  2  2  2
   2  2  2  2]
 [ 2  2  2  2  2  2  2  2  2  2  2  2  2  1  1  1  1  2  2  2  2  2  1  0
   0  1  0  1  1  2  2  2  1 -1 -1  2  0 -1  1  2  2  2  1 -2  0  3 -1  0
   2  2  2  2  1 -2  1  2 -3  0  1  2  2  2  0  0  2  1 -2  0  1  2  2  2
   0  0  1  0 -1  1  2  2  2  2  1  1  0  1  1  2  2  2  2  2  2  2  2  2
   2  2  2  2]
 [ 2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  1  2  2  2  2  2  2  2
   1  1 

In [77]:
# Fresh encryption of ``m`` as a column vector.
b = enc().reshape((100, 1))
print(b.shape)

# Apply the quantized weight matrix to the ciphertext body. No reduction mod q
# happens here; dec2 reduces its input mod q before unmasking.
newB = weights @ b
print(newB.shape)
print(newS.shape)

print(newS)
print(newB)

(100, 1)
(10, 1)
(500, 10)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [1 1 1 ... 1 1 1]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [1 1 1 ... 1 1 1]]
[[4757037.]
 [4170844.]
 [4821920.]
 [4644913.]
 [4221736.]
 [4557902.]
 [4494688.]
 [4166521.]
 [4304344.]
 [4133439.]]


In [63]:
# Shapes of the mask rows and ciphertext body handed to dec2 below.
print(newA.T.shape)
print(newB[:, 0].shape)
# NOTE(review): dec2 unmasks with polynomial_mult over the rows of newA.T;
# confirm that this matches `weights` applied to the mask enc() adds. The
# recorded output of this cell differs from the plaintext product of `weights`
# and `m` shown elsewhere in the notebook, although execution counts are
# non-sequential, so the recorded outputs may simply be stale.
dec2(newA.T, newB[:, 0], newS, N=10)

(500, 10)
(10,)


array([  67., 1007.,  598.,  699.,  491.,  420.,  779.,  427.,  361.,
        378.])

In [65]:
# Plaintext reference: apply the quantized weights directly to the message.
print(m.reshape((100, 1)).shape)
print(weights.shape)
print(m)
print(weights)
weights @ m.reshape((100, 1))

(100, 1)
(10, 100)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 1 1 0 0
 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 1
 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[[ 2  2  2  2  2  2  2  2  2  2  2  2  2  1  1  1  1  2  2  2  2  2  1  1
   1  1  1  1  1  2  2  2  1  1  1  0  1  2  2  2  2  2  1  1  0 -2 -1  2
   3  2  2  2  2  1 -2 -4 -1  2  2  2  2  2  2  1 -1 -1  1  1  2  2  2  2
   1  2  1  1  1  1  2  2  2  2  2  1  1  1  1  2  2  2  2  2  2  2  2  2
   2  2  2  2]
 [ 2  2  2  2  2  2  2  2  2  2  2  2  2  1  1  1  1  2  2  2  2  2  1  0
   0  1  0  1  1  2  2  2  1 -1 -1  2  0 -1  1  2  2  2  1 -2  0  3 -1  0
   2  2  2  2  1 -2  1  2 -3  0  1  2  2  2  0  0  2  1 -2  0  1  2  2  2
   0  0  1  0 -1  1  2  2  2  2  1  1  0  1  1  2  2  2  2  2  2  2  2  2
   2  2  2  2]
 [ 2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  2  1  2  2  2  2  2  2  2
   1  1  1  1  1  2  2  2  1 -1 -2  0  1  1  1  2  2  2  0 -3 -2  0  1  0
   1  2

array([[10],
       [ 1],
       [ 6],
       [13],
       [ 1],
       [14],
       [ 5],
       [ 8],
       [13],
       [ 5]])