<a href="https://colab.research.google.com/github/Rhicarde/CECS-574---ZKML/blob/main/ZKML_for_Privacy_Preservation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install skl2onnx onnx onnxruntime



# Dataset
View the dataset on Kaggle [here](https://www.kaggle.com/datasets/ankushpanday2/heart-attack-risk-and-prediction-dataset-in-india/data).

The dataset represents medical and lifestyle risk factors that may lead to heart disease. We will use it to train a model that can, ideally, predict early signs of heart disease from the given factors.

In [None]:
# Downloading Dataset
import kagglehub

# Fetch the Kaggle dataset through kagglehub. The returned location is kept in
# `path`, which the next cell uses to locate the CSV file.
path = kagglehub.dataset_download("ankushpanday2/heart-attack-risk-and-prediction-dataset-in-india")

# Show where the dataset files were placed.
print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/ankushpanday2/heart-attack-risk-and-prediction-dataset-in-india/versions/1


In [None]:
import pandas as pd

# Read the heart-attack CSV from the directory reported by the download cell.
df = pd.read_csv(f"{path}/heart_attack_prediction_india.csv")

# Preview the first rows.
df.head()

Unnamed: 0,Patient_ID,State_Name,Age,Gender,Diabetes,Hypertension,Obesity,Smoking,Alcohol_Consumption,Physical_Activity,...,Diastolic_BP,Air_Pollution_Exposure,Family_History,Stress_Level,Healthcare_Access,Heart_Attack_History,Emergency_Response_Time,Annual_Income,Health_Insurance,Heart_Attack_Risk
0,1,Rajasthan,42,Female,0,0,1,1,0,0,...,119,1,0,4,0,0,157,611025,0,0
1,2,Himachal Pradesh,26,Male,0,0,0,0,1,1,...,115,0,0,7,0,0,331,174527,0,0
2,3,Assam,78,Male,0,0,1,0,0,1,...,117,0,1,10,1,0,186,1760112,1,0
3,4,Odisha,58,Male,1,0,1,0,0,1,...,65,0,0,1,1,1,324,1398213,0,0
4,5,Karnataka,22,Male,0,0,0,0,0,1,...,109,0,0,9,0,0,209,97987,0,1


In [None]:
# Summarise the frame: column names, non-null counts and dtypes.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 26 columns):
 #   Column                   Non-Null Count  Dtype 
---  ------                   --------------  ----- 
 0   Patient_ID               10000 non-null  int64 
 1   State_Name               10000 non-null  object
 2   Age                      10000 non-null  int64 
 3   Gender                   10000 non-null  object
 4   Diabetes                 10000 non-null  int64 
 5   Hypertension             10000 non-null  int64 
 6   Obesity                  10000 non-null  int64 
 7   Smoking                  10000 non-null  int64 
 8   Alcohol_Consumption      10000 non-null  int64 
 9   Physical_Activity        10000 non-null  int64 
 10  Diet_Score               10000 non-null  int64 
 11  Cholesterol_Level        10000 non-null  int64 
 12  Triglyceride_Level       10000 non-null  int64 
 13  LDL_Level                10000 non-null  int64 
 14  HDL_Level                10000 non-null

# Training Model

The model used is a simple logistic regression model from scikit-learn. A logistic regression model estimates the probability of an event occurring (in this case, heart disease) from the given data.

In [None]:
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler , OneHotEncoder

# Split Dataset
# Patient_ID is a per-row identifier, not a risk factor. It is dropped together
# with the target so that remainder='passthrough' does not feed it to the model
# as an unscaled numeric feature.
x = df.drop(columns=['Heart_Attack_Risk', 'Patient_ID'])
y = df['Heart_Attack_Risk']

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=2)

# Columns that get one-hot encoded (first level dropped to avoid redundancy).
categorical_columns = ['State_Name', 'Gender']
# Continuous measurements that get standardised.
scaled_columns = ['Diastolic_BP', 'Annual_Income', 'Emergency_Response_Time', 'Systolic_BP',
                  'Cholesterol_Level', 'Triglyceride_Level', 'LDL_Level', 'HDL_Level']

# Every remaining column is passed through unchanged. sparse_threshold=0 forces
# a dense array, which the later torch.tensor(...) conversions require
# regardless of how sparse the one-hot block turns out to be.
ct = ColumnTransformer(
    transformers= [('onehot', OneHotEncoder(drop='first'), categorical_columns),
                   ('normal', StandardScaler(), scaled_columns)
                  ],
    remainder='passthrough',
    sparse_threshold=0)

# Fit the preprocessing on the training split only, then reuse it for the test split.
x_train_ct = ct.fit_transform(x_train)
x_test_ct = ct.transform(x_test)

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix

# Logistic Regression Model, trained on the transformed training split.
# fit() returns the estimator itself, so construction and training are chained.
model = LogisticRegression(max_iter=2000).fit(x_train_ct, y_train)

# Test Model Accuracy on the held-out split
y_pred = model.predict(x_test_ct)
print('Accuracy Score: {}%'.format(accuracy_score(y_test, y_pred) * 100))

Accuracy Score: 71.75%


In [None]:
import torch
import torch.nn as nn


class LogisticRegression(nn.Module):
    """Single linear layer followed by a sigmoid.

    NOTE: this class shares its name with the scikit-learn estimator imported
    in an earlier cell and shadows it once this cell has run.

    Args:
        input_dim: number of input features fed to the linear layer.
    """

    def __init__(self, input_dim):
        super().__init__()
        # One output unit: the logit for the positive class.
        self.linear = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        """Return sigmoid(linear(x)) for the input tensor ``x``."""
        logits = self.linear(x)
        return logits.sigmoid()

In [None]:
# Convert to Torch Model for EZKL
# `model` is the fitted scikit-learn estimator from the training cell. By this
# point `LogisticRegression` resolves to the torch module class defined in the
# previous cell (it shadows the sklearn import), so cell order matters here.

# Get sklearn model weights and bias
weights = model.coef_  # Shape: (1, n_features)
bias = model.intercept_  # Shape: (1,)

# Create Torch model
# The input width is taken from the transformed training matrix.
input_dim = x_train_ct.shape[1]
torch_model = LogisticRegression(input_dim)

# Convert sklearn weights to PyTorch tensors and assign them
# The randomly initialised parameters of the linear layer are replaced
# with float32 copies of the trained coefficients.
torch_model.linear.weight = nn.Parameter(torch.tensor(weights, dtype=torch.float32))
torch_model.linear.bias = nn.Parameter(torch.tensor(bias, dtype=torch.float32))

In [None]:
# Testing torch model to ensure same result

# Held-out features and labels as float32 tensors
x_test_torch = torch.tensor(x_test_ct, dtype=torch.float32)
y_torch = torch.tensor(y_test.tolist(), dtype=torch.float32)

# Switch to inference mode before scoring
torch_model.eval()

with torch.no_grad():
    # Probabilities from the model, with the trailing singleton dimension removed
    y_pred_probs = torch_model(x_test_torch).squeeze()
    # Threshold at 0.5: 1 means heart disease risk, 0 means none
    y_pred_labels = y_pred_probs.ge(0.5).to(torch.int32)

# Check accuracy against the true labels
accuracy = accuracy_score(y_torch, y_pred_labels.numpy())

print(f"PyTorch Accuracy: {accuracy * 100:.2f}%")

PyTorch Accuracy: 71.75%


# Implementing EZKL circuit

In [None]:
# check if notebook is in colab
# Only the Colab probe is allowed to fail silently. A failing pip install now
# raises instead of being swallowed by a bare `except`.
try:
    import google.colab  # noqa: F401 -- imported only to detect Colab
except ImportError:
    # rely on local installation of ezkl if the notebook is not in colab
    pass
else:
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])

import os
import json
import ezkl

In [None]:
# Required Files for EZKL
# File names for the exported ONNX model, the sample input and the calibration
# input. All are relative to the current working directory.
model_path, data_path, cal_data_path = (
    os.path.join(file_name)
    for file_name in ('network.onnx', 'input.json', 'calibration.json')
)

In [None]:
# Display the shape of the test tensor; its second dimension is used as the ONNX input width below.
x_test_torch.shape

torch.Size([2000, 51])

In [None]:
# Convert model to ONNX

# create a random input (a single sample with the same feature width as the test set)
# NOTE: this rebinds `x`, which an earlier cell used for the feature DataFrame.
x = torch.randn(1, x_test_torch.shape[1])

# Export the model
torch.onnx.export(torch_model,                     # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  model_path,                # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

# Flatten the sample input into a plain Python list
data_array = ((x).detach().numpy()).reshape([-1]).tolist()

# input_data holds one flattened array per model input
data = dict(input_data = [data_array])

# Serialize data into file. The context manager flushes and closes the handle,
# so later steps that read this file see the complete JSON.
with open(data_path, 'w') as data_file:
    json.dump(data, data_file)

In [None]:
# NOTE(review): `!VAR=value` runs in a separate shell, so this line does not
# change the environment of the kernel process that hosts ezkl. If trace
# logging is really wanted, `%env RUST_LOG=trace` would be needed -- confirm intent.
!RUST_LOG=trace
# TODO: Dictionary outputs
# Called without arguments, so ezkl's default file locations are used --
# presumably matching model_path; verify against the ezkl defaults.
res = ezkl.gen_settings()
assert res == True

In [None]:
# use the test set to calibrate the circuit
cal_data = dict(input_data = x_test_torch.flatten().tolist())

# Serialize calibration data into file:
json.dump(data, open(cal_data_path, 'w'))

# Optimize for resources, we cap logrows at 12 to reduce setup and proving time, at the expense of accuracy
# You may want to increase the max logrows if accuracy is a concern
res = await ezkl.calibrate_settings(target = "resources", max_logrows = 50, scales = [14])


 <------------- Numerical Fidelity Report (input_scale: 14, param_scale: 14, scale_input_multiplier: 1) ------------->

+-----------------+-----------------+-----------------+-----------------+----------------+------------------+----------------+----------------+--------------------+--------------------+------------------------+
| mean_error      | median_error    | max_error       | min_error       | mean_abs_error | median_abs_error | max_abs_error  | min_abs_error  | mean_squared_error | mean_percent_error | mean_abs_percent_error |
+-----------------+-----------------+-----------------+-----------------+----------------+------------------+----------------+----------------+--------------------+--------------------+------------------------+
| -0.000019401312 | -0.000019401312 | -0.000019401312 | -0.000019401312 | 0.000019401312 | 0.000019401312   | 0.000019401312 | 0.000019401312 | 0.0000000003764109 | -0.00021295297     | 0.00021295297          |
+-----------------+----------------

In [None]:
# Compile the circuit. No paths are passed, so ezkl's default file locations
# are used -- NOTE(review): confirm they match the files written above.
res = ezkl.compile_circuit()
assert res == True

In [None]:
# Fetch the SRS needed by the following setup step (async call, default arguments).
# NOTE(review): the result is not checked here, unlike the neighbouring cells.
res = await ezkl.get_srs()

In [None]:
# Run ezkl's setup step with default arguments and fail fast if it does not report success.
res = ezkl.setup()
assert res == True

# Create Witness

In [None]:
# Generate the Witness for the proof

# now generate the witness file
# witness_path is only used for the existence check below. gen_witness() is
# called with defaults -- presumably writing to the same file name; verify.
witness_path = os.path.join('witness.json')

res = await ezkl.gen_witness()
assert os.path.isfile(witness_path)

In [None]:
# Generate the proof
proof_path = os.path.join('proof.json')

# The returned object is kept in `proof`; a later cell reads its
# "instances" and "proof" entries.
proof = ezkl.prove(proof_type="single", proof_path=proof_path)

# NOTE(review): this prints the entire proof, which makes for a very long cell output.
print(proof)
# Confirm the proof file was written.
assert os.path.isfile(proof_path)

{'instances': [['d505000000000000000000000000000000000000000000000000000000000000']], 'proof': '0x28d3c9eb1436b21114e0c631845006207a8251dbcbdc7492bba1e34bd39d1ecd0ba1210d1c9b60800e01c785af6cff6b8701b32e347d02370ab678cf70c993bf0a25a249d5e84ca2dc7fcd071937e55b55f619341ec89fffad3772e299b20b10168837fac8a03db3cf4bfae81d30371fdf5453bc63c5b1b16d27ec979243b34d0c3dd3d733e40e613899d006dce2e9c771d360cce33183e7a6114e805666f64e0c7eb7ad9fe270c2562b934ac2273b97bdf7d9a0596a1b33f19018346df3ce1a1ae703c3f1e9d89c115474e2df56f0c9bb5e91a7e150309745b7005a8ca804d12c989eb7ab1fdb15aee2229348fe6c317867a1f055b01a3c8efc9de84e6df7770df89d7a1061e8667ab34faefd9ab810a3273341c2b6d0dd17c556e24c9abaa40d50602cb0932c314fc4e75de496c5120c0c605d0738394763d22d99a57f105f0f1707d844771f1fcb739fb73333cde364014356a8a7f21060cdc5e33703afa22c84066503b2749b03b81f868670145daba58adf29d7ea6ddfc3af3cc92d93c81efe61e76ba18f81867e524093084dce151b53f82c4f931753d3978cbdaa35191f31d46738d333142c225f5325273b6a22ead726b44ee175b2b0ddd45072efc421d1fc

In [None]:
# verify our proof
# Called with default arguments; the assert stops the notebook if verification does not succeed.
res = ezkl.verify()

assert res == True
print("verified")

verified


# Create Verifier

In [None]:
# check if notebook is in colab
try:
    import google.colab
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "solc-select"])
    !solc-select install 0.8.20
    !solc-select use 0.8.20
    !solc --version

# rely on local installation if the notebook is not in colab
except:
    pass

Installing solc '0.8.20'...
Version '0.8.20' installed.
Switched global version to 0.8.20
solc, the solidity compiler commandline interface
Version: 0.8.20+commit.a1b79de6.Linux.g++


In [None]:
# Output locations for the generated Solidity verifier and its ABI.
sol_code_path = os.path.join('Verifier.sol')
abi_path = os.path.join('Verifier.abi')

# Generate the EVM verifier. Only the two output paths are passed; everything
# else falls back to ezkl's defaults.
res = await ezkl.create_evm_verifier(
        sol_code_path=sol_code_path,
        abi_path=abi_path,
    )

assert res == True
# Confirm the Solidity source was written.
assert os.path.isfile(sol_code_path)

In [None]:
# Flatten every field element of every instance group and convert each one
# with ezkl.felt_to_big_endian, the form printed below for on-chain use.
onchain_input_array = [
    ezkl.felt_to_big_endian(field_element)
    for instance in proof["instances"]
    for field_element in instance
]

# Build a bracketed list of quoted values. join() places the separators, so
# there is no trailing comma, and an empty instance group no longer leaves a
# dangling ", " behind.
formatted_output = "[" + ", ".join(f'"{value}"' for value in onchain_input_array) + "]"

# This will be the values you use onchain
# copy them over to remix and see if they verify
# What happens when you change a value?
print("pubInputs: ", formatted_output)
print("proof: ", proof["proof"])

pubInputs:  ["0x00000000000000000000000000000000000000000000000000000000000005d5"]
proof:  0x28d3c9eb1436b21114e0c631845006207a8251dbcbdc7492bba1e34bd39d1ecd0ba1210d1c9b60800e01c785af6cff6b8701b32e347d02370ab678cf70c993bf0a25a249d5e84ca2dc7fcd071937e55b55f619341ec89fffad3772e299b20b10168837fac8a03db3cf4bfae81d30371fdf5453bc63c5b1b16d27ec979243b34d0c3dd3d733e40e613899d006dce2e9c771d360cce33183e7a6114e805666f64e0c7eb7ad9fe270c2562b934ac2273b97bdf7d9a0596a1b33f19018346df3ce1a1ae703c3f1e9d89c115474e2df56f0c9bb5e91a7e150309745b7005a8ca804d12c989eb7ab1fdb15aee2229348fe6c317867a1f055b01a3c8efc9de84e6df7770df89d7a1061e8667ab34faefd9ab810a3273341c2b6d0dd17c556e24c9abaa40d50602cb0932c314fc4e75de496c5120c0c605d0738394763d22d99a57f105f0f1707d844771f1fcb739fb73333cde364014356a8a7f21060cdc5e33703afa22c84066503b2749b03b81f868670145daba58adf29d7ea6ddfc3af3cc92d93c81efe61e76ba18f81867e524093084dce151b53f82c4f931753d3978cbdaa35191f31d46738d333142c225f5325273b6a22ead726b44ee175b2b0ddd45072efc421d1fc5594b

In [None]:
# Move On-Chain