In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
#INSTALL AND IMPORT DEPENDENCIES
from IPython.display import clear_output
!pip install dask[dataframe]
clear_output()

import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "hummingbird-ml"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "xgboost"])

import json
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier as Gbc
import torch
import ezkl
import os
from torch import nn
from hummingbird.ml import convert
import pandas as pd, gc

In [4]:
BUILD95 = True
BUILD96 = True



# COLUMNS WITH STRINGS
str_type = ['ProductCD', 'card4', 'card6', 'P_emaildomain', 'R_emaildomain','M1', 'M2', 'M3', 'M4','M5',
            'M6', 'M7', 'M8', 'M9', 'id_12', 'id_15', 'id_16', 'id_23', 'id_27', 'id_28', 'id_29', 'id_30',
            'id_31', 'id_33', 'id_34', 'id_35', 'id_36', 'id_37', 'id_38', 'DeviceType', 'DeviceInfo']
str_type += ['id-12', 'id-15', 'id-16', 'id-23', 'id-27', 'id-28', 'id-29', 'id-30',
            'id-31', 'id-33', 'id-34', 'id-35', 'id-36', 'id-37', 'id-38']

# FIRST 53 COLUMNS
cols = ['TransactionID', 'TransactionDT', 'TransactionAmt',
       'ProductCD', 'card1', 'card2', 'card3', 'card4', 'card5', 'card6',
       'addr1', 'addr2', 'dist1', 'dist2', 'P_emaildomain', 'R_emaildomain',
       'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11',
       'C12', 'C13', 'C14', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8',
       'D9', 'D10', 'D11', 'D12', 'D13', 'D14', 'D15', 'M1', 'M2', 'M3', 'M4',
       'M5', 'M6', 'M7', 'M8', 'M9']

# V COLUMNS TO LOAD DECIDED BY CORRELATION EDA
# https://www.kaggle.com/cdeotte/eda-for-columns-v-and-id
v =  [1, 3, 4, 6, 8, 11]
v += [13, 14, 17, 20, 23, 26, 27, 30]
v += [36, 37, 40, 41, 44, 47, 48]
v += [54, 56, 59, 62, 65, 67, 68, 70]
v += [76, 78, 80, 82, 86, 88, 89, 91]

#v += [96, 98, 99, 104] #relates to groups, no NAN
v += [107, 108, 111, 115, 117, 120, 121, 123] # maybe group, no NAN
v += [124, 127, 129, 130, 136] # relates to groups, no NAN

# LOTS OF NAN BELOW
v += [138, 139, 142, 147, 156, 162] #b1
v += [165, 160, 166] #b1
v += [178, 176, 173, 182] #b2
v += [187, 203, 205, 207, 215] #b2
v += [169, 171, 175, 180, 185, 188, 198, 210, 209] #b2
v += [218, 223, 224, 226, 228, 229, 235] #b3
v += [240, 258, 257, 253, 252, 260, 261] #b3
v += [264, 266, 267, 274, 277] #b3
v += [220, 221, 234, 238, 250, 271] #b3

v += [294, 284, 285, 286, 291, 297] # relates to grous, no NAN
v += [303, 305, 307, 309, 310, 320] # relates to groups, no NAN
v += [281, 283, 289, 296, 301, 314] # relates to groups, no NAN
#v += [332, 325, 335, 338] # b4 lots NAN

cols += ['V'+str(x) for x in v]
dtypes = {}
for c in cols+['id_0'+str(x) for x in range(1,10)]+['id_'+str(x) for x in range(10,34)]+\
    ['id-0'+str(x) for x in range(1,10)]+['id-'+str(x) for x in range(10,34)]:
        dtypes[c] = 'float32'
for c in str_type: dtypes[c] = 'category'

In [5]:
#LOAD AND MERGE THE TWO PARTS OF THE DATASET
transact_train = pd.read_csv('/content/drive/MyDrive/ieee-fraud-detection/train_transaction.csv',index_col='TransactionID', dtype=dtypes, usecols=cols+['isFraud'])
train_id = pd.read_csv('/content/drive/MyDrive/ieee-fraud-detection/train_identity.csv',index_col='TransactionID', dtype=dtypes)
X = transact_train.merge(train_id, how='left', left_index=True, right_index=True)

y = X['isFraud']


In [6]:
#DIVIDE THE DATA INTO TRAINING AND TESTING SET
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
X_train.shape, y_train.shape, X_test.shape,  y_test.shape

In [8]:
# NORMALIZE D COLUMNS
for i in range(1,16):
    if i in [1,2,3,5,9]: continue
    X_train['D'+str(i)] =  X_train['D'+str(i)] - X_train.TransactionDT/np.float32(24*60*60)
    X_test['D'+str(i)] = X_test['D'+str(i)] - X_test.TransactionDT/np.float32(24*60*60)

In [9]:
# LABEL ENCODE AND MEMORY REDUCE
for i,f in enumerate(X_train.columns):
    # FACTORIZE CATEGORICAL VARIABLES
    if (str(X_train[f].dtype)=='category')|(X_train[f].dtype=='object'):
        df_comb = pd.concat([X_train[f],X_test[f]],axis=0)
        df_comb,_ = df_comb.factorize(sort=True)
        if df_comb.max()>32000: print(f,'needs int32')
        X_train[f] = df_comb[:len(X_train)].astype('int16')
        X_test[f] = df_comb[len(X_train):].astype('int16')
    # SHIFT ALL NUMERICS POSITIVE. SET NAN to -1
    elif f not in ['TransactionAmt','TransactionDT']:
        mn = np.min((X_train[f].min(),X_test[f].min()))
        X_train[f] -= np.float32(mn)
        X_test[f] -= np.float32(mn)

        X_train[f] = X_train[f].fillna(-1)
        X_test[f] = X_test[f].fillna(-1)

In [10]:
# FREQUENCY ENCODE TOGETHER
def encode_FE(df1, df2, cols):
    for col in cols:
        df = pd.concat([df1[col],df2[col]])
        vc = df.value_counts(dropna=True, normalize=True).to_dict()
        vc[-1] = -1
        nm = col+'_FE'
        df1[nm] = df1[col].map(vc)
        df1[nm] = df1[nm].astype('float32')
        df2[nm] = df2[col].map(vc)
        df2[nm] = df2[nm].astype('float32')
        print(nm,', ',end='')

# LABEL ENCODE
def encode_LE(col,train=X_train,test=X_test,verbose=True):
    df_comb = pd.concat([train[col],test[col]],axis=0)
    df_comb,_ = df_comb.factorize(sort=True)
    nm = col
    if df_comb.max()>32000:
        train[nm] = df_comb[:len(train)].astype('int32')
        test[nm] = df_comb[len(train):].astype('int32')
    else:
        train[nm] = df_comb[:len(train)].astype('int16')
        test[nm] = df_comb[len(train):].astype('int16')
    del df_comb; x=gc.collect()
    if verbose: print(nm,', ',end='')

# GROUP AGGREGATION MEAN AND STD
# https://www.kaggle.com/kyakovlev/ieee-fe-with-some-eda
def encode_AG(main_columns, uids, aggregations=['mean'], train_df=X_train, test_df=X_test,
              fillna=True, usena=False):
    # AGGREGATION OF MAIN WITH UID FOR GIVEN STATISTICS
    for main_column in main_columns:
        for col in uids:
            for agg_type in aggregations:
                new_col_name = main_column+'_'+col+'_'+agg_type
                temp_df = pd.concat([train_df[[col, main_column]], test_df[[col,main_column]]])
                if usena: temp_df.loc[temp_df[main_column]==-1,main_column] = np.nan
                temp_df = temp_df.groupby([col])[main_column].agg([agg_type]).reset_index().rename(
                                                        columns={agg_type: new_col_name})

                temp_df.index = list(temp_df[col])
                temp_df = temp_df[new_col_name].to_dict()

                train_df[new_col_name] = train_df[col].map(temp_df).astype('float32')
                test_df[new_col_name]  = test_df[col].map(temp_df).astype('float32')

                if fillna:
                    train_df[new_col_name].fillna(-1,inplace=True)
                    test_df[new_col_name].fillna(-1,inplace=True)

                print("'"+new_col_name+"'",', ',end='')

# COMBINE FEATURES
def encode_CB(col1,col2,df1=X_train,df2=X_test):
    nm = col1+'_'+col2
    df1[nm] = df1[col1].astype(str)+'_'+df1[col2].astype(str)
    df2[nm] = df2[col1].astype(str)+'_'+df2[col2].astype(str)
    encode_LE(nm,verbose=False)
    print(nm,', ',end='')

# GROUP AGGREGATION NUNIQUE
def encode_AG2(main_columns, uids, train_df=X_train, test_df=X_test):
    for main_column in main_columns:
        for col in uids:
            comb = pd.concat([train_df[[col]+[main_column]],test_df[[col]+[main_column]]],axis=0)
            mp = comb.groupby(col)[main_column].agg(['nunique'])['nunique'].to_dict()
            train_df[col+'_'+main_column+'_ct'] = train_df[col].map(mp).astype('float32')
            test_df[col+'_'+main_column+'_ct'] = test_df[col].map(mp).astype('float32')
            print(col+'_'+main_column+'_ct, ',end='')

In [11]:
X_train.info()

<class 'pandas.core.frame.DataFrame'>
Index: 531486 entries, 3029238.0 to 3108958.0
Columns: 214 entries, isFraud to DeviceInfo
dtypes: float32(182), float64(1), int16(31)
memory usage: 406.5 MB


In [12]:
cols = list( X_train.columns )
cols.remove('TransactionDT')
for c in ['D6','D7','D8','D9','D12','D13','D14']:
    cols.remove(c)

# FAILED TIME CONSISTENCY TEST
for c in ['C3','M5','id_08','id_33']:
    cols.remove(c)
for c in ['card4','id_07','id_14','id_21','id_30','id_32','id_34']:
    cols.remove(c)
for c in ['id_'+str(x) for x in range(22,28)]:
    cols.remove(c)

In [13]:
cols = list( X_train.columns )
cols.remove('TransactionDT')
for c in ['D6','D7','D8','D9','D12','D13','D14']:
    cols.remove(c)

# FAILED TIME CONSISTENCY TEST
for c in ['C3','M5','id_08','id_33']:
    cols.remove(c)
for c in ['card4','id_07','id_14','id_21','id_30','id_32','id_34']:
    cols.remove(c)
for c in ['id_'+str(x) for x in range(22,28)]:
    cols.remove(c)

In [14]:
print('NOW USING THE FOLLOWING',len(cols),'FEATURES.')
np.array(cols)

NOW USING THE FOLLOWING 189 FEATURES.


array(['isFraud', 'TransactionAmt', 'ProductCD', 'card1', 'card2',
       'card3', 'card5', 'card6', 'addr1', 'addr2', 'dist1', 'dist2',
       'P_emaildomain', 'R_emaildomain', 'C1', 'C2', 'C4', 'C5', 'C6',
       'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'D1', 'D2',
       'D3', 'D4', 'D5', 'D10', 'D11', 'D15', 'M1', 'M2', 'M3', 'M4',
       'M6', 'M7', 'M8', 'M9', 'V1', 'V3', 'V4', 'V6', 'V8', 'V11', 'V13',
       'V14', 'V17', 'V20', 'V23', 'V26', 'V27', 'V30', 'V36', 'V37',
       'V40', 'V41', 'V44', 'V47', 'V48', 'V54', 'V56', 'V59', 'V62',
       'V65', 'V67', 'V68', 'V70', 'V76', 'V78', 'V80', 'V82', 'V86',
       'V88', 'V89', 'V91', 'V107', 'V108', 'V111', 'V115', 'V117',
       'V120', 'V121', 'V123', 'V124', 'V127', 'V129', 'V130', 'V136',
       'V138', 'V139', 'V142', 'V147', 'V156', 'V160', 'V162', 'V165',
       'V166', 'V169', 'V171', 'V173', 'V175', 'V176', 'V178', 'V180',
       'V182', 'V185', 'V187', 'V188', 'V198', 'V203', 'V205', 'V207',
       'V209'

Ezkl and Halo2 were generating huge values. I figured out from the ezkl documentation that including calibration data was optional so I first removed the calibration data lines of code which was also generating huge values beyond the maximum required by ezkl. Then I scaled down the non-categorical features into a range of 0 to 1 using min-max normalization. This solved the problem.

In [15]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

X_train = pd.DataFrame(X_train)

# Identify numerical columns including other numerical types
numerical_cols = X_train.select_dtypes(include=['float64', 'int64', 'float32', 'int32']).columns

if len(numerical_cols) > 0:
    # Initialize the MinMaxScaler
    scaler = MinMaxScaler()

    # Fit the scaler on numerical data and transform it
    X_train[numerical_cols] = scaler.fit_transform(X_train[numerical_cols])
else:
    print("No numerical columns to scale.")


In [16]:
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

X_test = pd.DataFrame(X_test)

# Identify numerical columns including other numerical types
numerical_cols = X_test.select_dtypes(include=['float64', 'int64', 'float32', 'int32']).columns

if len(numerical_cols) > 0:
    # Initialize the MinMaxScaler
    scaler = MinMaxScaler()

    # Fit the scaler on numerical data and transform it
    X_test[numerical_cols] = scaler.fit_transform(X_test[numerical_cols])
else:
    print("No numerical columns to scale in X_test.")


In [17]:
#Divide the training data into training and validation set
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)

In [18]:
X_train.info()

<class 'pandas.core.frame.DataFrame'>
Index: 478337 entries, 3548099.0 to 3020988.0
Columns: 214 entries, isFraud to DeviceInfo
dtypes: float64(183), int16(31)
memory usage: 698.0 MB


In [19]:
X_train.head()

Unnamed: 0_level_0,isFraud,TransactionDT,TransactionAmt,ProductCD,card1,card2,card3,card4,card5,card6,...,id_31,id_32,id_33,id_34,id_35,id_36,id_37,id_38,DeviceType,DeviceInfo
TransactionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3548099.0,0.0,0.939438,0.00184,4,0.918544,0.383234,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3519483.0,0.0,0.887201,0.003906,2,0.459991,0.706587,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3034032.0,0.0,0.065605,0.003372,4,0.782076,0.856287,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3500145.0,1.0,0.849384,0.002341,1,0.392389,0.762475,0.386364,2,0.905797,1,...,103,0.757576,25,2,1,0,1,1,0,208
3570044.0,0.0,0.986237,0.00137,4,0.103645,0.001996,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1


In [20]:
X_val.head()

Unnamed: 0_level_0,isFraud,TransactionDT,TransactionAmt,ProductCD,card1,card2,card3,card4,card5,card6,...,id_31,id_32,id_33,id_34,id_35,id_36,id_37,id_38,DeviceType,DeviceInfo
TransactionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3509589.0,0.0,0.867468,0.001807,4,0.414233,0.91018,0.386364,3,0.92029,0,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3258470.0,0.0,0.412741,0.000741,0,0.855656,0.89022,0.651515,3,0.282609,1,...,90,0.0,-1,-1,0,0,1,0,1,12
3333800.0,0.0,0.538127,0.000805,4,0.890147,0.149701,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3206881.0,0.0,0.32307,0.006566,4,0.108301,0.780439,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3013389.0,0.0,0.037402,0.001003,0,0.072201,0.89022,0.651515,3,0.92029,0,...,51,0.0,-1,-1,0,0,1,1,0,581


In [21]:
X_test.head()

Unnamed: 0_level_0,isFraud,TransactionDT,TransactionAmt,ProductCD,card1,card2,card3,card4,card5,card6,...,id_31,id_32,id_33,id_34,id_35,id_36,id_37,id_38,DeviceType,DeviceInfo
TransactionID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3457624.0,0.0,0.767428,0.022659,4,0.392212,0.762475,0.386364,2,0.905797,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3552820.0,0.0,0.948825,0.003386,4,0.66358,0.443114,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3271083.0,0.0,0.437781,0.00149,4,0.482745,0.023952,0.386364,2,0.905797,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3226689.0,0.0,0.355327,0.003139,0,0.855746,0.89022,0.651515,3,0.282609,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1
3268855.0,0.0,0.432477,0.003369,4,0.833429,0.780439,0.386364,3,0.92029,1,...,-1,0.0,-1,-1,-1,-1,-1,-1,-1,-1


**XGBOOST TRAINING**

In [22]:
# CONVERT THE PANDAS DATAFRAME TO NUMPY ARRAYS WITH FLOAT32 DATATYPE AS IN THE OFFICIAL EZKL XGBOOST CODE
X_train = np.array(X_train, dtype=np.float32)
y_train = np.array(y_train, dtype=np.float32)
X_val = np.array(X_val, dtype=np.float32)
y_val = np.array(y_val, dtype=np.float32)
X_test = np.array(X_test, dtype=np.float32)
y_test = np.array(y_test, dtype=np.float32)

In [23]:
X_val

array([[ 0.0000000e+00,  8.6746782e-01,  1.8066427e-03, ...,
        -1.0000000e+00, -1.0000000e+00, -1.0000000e+00],
       [ 0.0000000e+00,  4.1274107e-01,  7.4129994e-04, ...,
         0.0000000e+00,  1.0000000e+00,  1.2000000e+01],
       [ 0.0000000e+00,  5.3812677e-01,  8.0467446e-04, ...,
        -1.0000000e+00, -1.0000000e+00, -1.0000000e+00],
       ...,
       [ 0.0000000e+00,  9.3356955e-01,  1.8066427e-03, ...,
        -1.0000000e+00, -1.0000000e+00, -1.0000000e+00],
       [ 0.0000000e+00,  6.1563271e-01,  4.7884689e-04, ...,
         0.0000000e+00,  0.0000000e+00,  5.1800000e+02],
       [ 0.0000000e+00,  9.1070735e-01,  7.7492849e-04, ...,
        -1.0000000e+00, -1.0000000e+00, -1.0000000e+00]], dtype=float32)

When 2000, 1000, 800 estimators are used in XGBoost, the code that SETS UP THE ZK PROOF CIRCUIT PARAMS
throws this
Error --> *(PanicException: dynamic lookup or shuffle should only have one block)*

However, 500 estimators works without any error

In [25]:
clf = Gbc(n_estimators=500,
          max_depth=12,
          learning_rate=0.02,
          subsample=0.8
          )

In [26]:
#clf.fit(X_train, y_train, eval_set=[(X_val,y_val)], Verbose=50)

In [27]:
clf.fit(X_train, y_train)

#import xgboost as xgb
#print("XGBoost version:", xgb.__version__)

**USING EZKL FOR XGBOOST INFERENCING ON THE FRAUD DETECTION DATA**

1. Convert the Scikit-Learn's XGBoost model to a PyTorch model using the HummingBird *convert* method
2.  

In [28]:
# Convert to torch
torch_gbt = convert(clf, 'torch')

In [29]:
# Convert the entire test set to a tensor
X_test_tensor = torch.tensor(X_test)

# Perform batch prediction using the entire test set
torch_preds = torch_gbt.predict(X_test_tensor)

# Use sklearn for batch prediction as well
sk_preds = clf.predict(X_test)

# Compare predictions and calculate differences
diffs = torch_preds != sk_preds

# Count the number of differences
num_diff = diffs.sum().item()  # Use .item() to get a Python number from a tensor

print("num diff: ", num_diff)

num diff:  0


In [40]:
torch_preds.shape, sk_preds.shape

((59054,), (59054,))

In [39]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, confusion_matrix

# Assuming 'preds' contains predicted probabilities from clf.predict_proba
# Thresholding at 0.5 to get binary predictions (class labels)
pred_labels = (sk_preds >= 0.5).astype(int)

# Accuracy
accuracy = accuracy_score(y_test, pred_labels)

# Precision
precision = precision_score(y_test, pred_labels)

# Recall
recall = recall_score(y_test, pred_labels)

# Print the results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

Accuracy: 1.0000
Precision: 1.0000
Recall: 1.0000


In [41]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, confusion_matrix

# Assuming 'preds' contains predicted probabilities from clf.predict_proba
# Thresholding at 0.5 to get binary predictions (class labels)
pred_labels = (torch_preds >= 0.5).astype(int)

# Accuracy
accuracy = accuracy_score(y_test, pred_labels)

# Precision
precision = precision_score(y_test, pred_labels)

# Recall
recall = recall_score(y_test, pred_labels)

# Print the results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

Accuracy: 1.0000
Precision: 1.0000
Recall: 1.0000


In [30]:
model_path = os.path.join('network.onnx')
compiled_model_path = os.path.join('network.compiled')
pk_path = os.path.join('test.pk')
vk_path = os.path.join('test.vk')
settings_path = os.path.join('settings.json')

witness_path = os.path.join('witness.json')
data_path = os.path.join('input.json')

In [31]:
# !!!!!!!!!!!!!!!!! This cell will flash a warning about onnx runtime compat but it is fine !!!!!!!!!!!!!!!!!!!!!


# export to onnx format


# Input to the model
shape = X_train.shape[1:]
x = torch.rand(1, *shape, requires_grad=False)
torch_out = torch_gbt.predict(x)
# Export the model
torch.onnx.export(torch_gbt.model,               # model being run
                  # model input (or a tuple for multiple inputs)
                  x,
                  # where to save the model (can be a file or file-like object)
                  "network.onnx",
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=18,          # the ONNX version to export the model to
                  input_names=['input'],   # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})

d = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(input_shapes=[shape],
            input_data=[d],
            output_data=[(o).reshape([-1]).tolist() for o in torch_out])

# Serialize data into file:
json.dump(data, open("input.json", 'w'))


In [32]:
run_args = ezkl.PyRunArgs()
run_args.variables = [("batch_size", 1)]

# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True


In [33]:
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

In [34]:
# srs path
res = await ezkl.get_srs( settings_path)

In [35]:
# now generate the witness file

res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [36]:
# HERE WE SETUP THE CIRCUIT PARAMS
# WE GOT KEYS
# WE GOT CIRCUIT PARAMETERS
# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK



res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,

    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

In [37]:
# GENERATE A PROOF
%%time

proof_path = os.path.join('test.pf')

res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,

        "single",
    )

print(res)
assert os.path.isfile(proof_path)

{'instances': [['0000000000000000000000000000000000000000000000000000000000000000', '4000000000000000000000000000000000000000000000000000000000000000', '4000000000000000000000000000000000000000000000000000000000000000']], 'proof': '0x15fd47986f7e4e6b75cf69f2714b2d20a194feb3ea9f0c5754f524f191c6734326103251d0557504ce81bef2d64f8a4eff34f1a0749a1bf95e73b2ed0da2482c20dfb89ebaa747355e9e84659c4cabc19b5e0c73a70eb9aeb4b330b991232e8000ad8fdb5307d8c0d7be49c5b7e2d69a18c788787213cbe333a1f9e56ee9591f09f56e53278366f6ffb3d003a751dfd3c810ba913c668ce275768151048a2a00010718a35b5e5f5da8fef6546b5c9d5df9fed62438fb1e54245821f77c57832d27514ed4319f6bf09960c54ef02c68553dfd94eda772f159653df17942d8610f03d6c46bf517bab747edacac59c192a825cd5f787d8c18c33d9ef94009a4ab770087c665037c68e9e350590b103a5a972eb3fe017528c1ebffa3caf46eb6d814128f917fb2c86890a6c7ed7a56e497a4d37ccf81ed47b1de1aadca11552aacd003e54f066cd1e5dd7f3d6556e6df99f4aa0f6998abc36f5fed6c8419bdbfeb30132c463155bd264e6a14c33c8f4edd77db1ca826f58546564ef7ca2271a676

In [38]:
# VERIFY IT
%%time

res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,

    )

assert res == True
print("verified")

verified
CPU times: user 312 ms, sys: 8.51 ms, total: 320 ms
Wall time: 108 ms
