# EZKL Workflow

To prove the inference of a trained model using EZKL, we need to follow the steps below. As an example to illustrate the process, let's consider that we have just trained a simple perceptron model using PyTorch:

In [1]:
#!pip install --upgrade --force-reinstall -r requirements.txt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision

# MNIST dataset
train, test = (torchvision.datasets.MNIST(
    './data', 
    train=is_train,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        torchvision.transforms.Lambda(lambda x: x.view(-1))
    ]),
    download=True
) for is_train in [True, False])

input_size, output_size = 28 * 28, 10

# Define the model
perceptron = nn.Sequential(
    nn.Linear(input_size, output_size),
)

# Create a dataset and data loader
train_loader, test_loader = (DataLoader(
    dataset, 
    batch_size=32, 
    shuffle=True
) for dataset in [train, test])

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(perceptron.parameters(), lr=0.01)

# Train the model
perceptron.train()
for data, label in train_loader:
    output = perceptron(data)
    loss = criterion(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [2]:
# Test the model
perceptron.eval()
with torch.no_grad():
    correct, total = 0, 0
    for data, label in test_loader:
        output = perceptron(data)
        _, predicted = torch.max(output.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()
    print(f'Accuracy: {100 * correct / total:.2f}%')

Accuracy: 85.69%


1. **Model Conversion**

Convert the trained model to the ONNX format. In this case, PyTorch provides the function torch.onnx.export. Other frameworks also have similar functions or external tools to convert models to ONNX (e.g., TensorFlow's tf2onnx). Nevertheless, Sklearn models are slighly more complicated to convert to suitable ONNX format, so we must first convert the model to a PyTorch using hummingbird.ml and then convert it to ONNX. We won't cover this process in this article but you can find more information in one of EZKL's notebooks.

Before converting our model to ONNX format, we need to tell the converter the input shape of the model. This can be done by passing a dummy input tensor.

In [7]:
# Choose any valid input tensor (1st input of the test dataset)
input_sample = next(iter(test_loader))[0][0].unsqueeze(0)

torch.onnx.export(
    perceptron,
    input_sample,
    "data/models/perceptron.onnx",
    opset_version=10,
    export_params=True,                # Store the trained parameter weights inside the model file
    do_constant_folding=True,          # Optimize constant values in the model graph
    input_names = ['input'],             # Input and output labels to appear in the ONNX graph 
    output_names = ['output'],
    dynamic_axes={
        'input' : {0 : 'batch_size'},    # Variable length axes
        'output' : {0 : 'batch_size'}
    }
)

verbose: False, log level: Level.ERROR



2. **Setup**

EZKL has several setup functions in their exposed API, namely gen_settings, calibrate_settings, compile_circuit, get_srs, setup, and gen_witness, we've group them together in this bullet point to describe the high level setup process. The signature of each function should be self-explanatory.

In [8]:
import os
import ezkl
import json

def create_file(filename: str) -> str:
    # If the file already exists, clear it
    if os.path.exists(filename):
        os.remove(filename)

    open(filename, 'w').close()
    return filename

# We have to create empty files manually before running the setup
INPUT = create_file("data/ezkl/input_data.json")
SETTINGS = create_file("data/ezkl/settings.json")
CALIBRATION = create_file("data/ezkl/calibration.json")
WITNESS = create_file("data/ezkl/witness.json")
COMPILED_MODEL = create_file("data/ezkl/model.compiled")
VK = create_file("data/ezkl/vk.json")
PK = create_file("data/ezkl/pk.json")
PROOF = create_file("data/ezkl/proof.json")

def setup(onnx_file: str, model: nn.Module, input_sample: torch.Tensor):

    # Create empty files for each of the required inputs
    for filename in [INPUT, SETTINGS, CALIBRATION, WITNESS, COMPILED_MODEL, VK, PK, PROOF]:
        create_file(filename)

    # Save the input data to a file in the expected format
    input_data = {
        'input_shapes': list(input_sample.shape),
        'input_data': input_sample.detach().numpy().tolist(),
        "output_data": model(input_sample).detach().numpy().tolist()
    }

    json.dump(
        input_data,
        open(INPUT, 'w')
    )

    # Run each setup function and verify that it succeeded
    assert ezkl.gen_settings(
        onnx_file,
        SETTINGS
    )

    calibration_data = {
        'input_data': torch.randn(20, input_sample.shape[1]).numpy().tolist()
    }

    json.dump(
        calibration_data,
        open(CALIBRATION, 'w')
    )

    assert ezkl.calibrate_settings(
        INPUT,
        onnx_file,
        SETTINGS, 
        "resources"
    )

    assert ezkl.compile_circuit(
        onnx_file,
        COMPILED_MODEL,
        SETTINGS
    )

    assert ezkl.get_srs(
        SETTINGS
    )

    ezkl.gen_witness(
        "data/ezkl/input_data.json",
        COMPILED_MODEL,
        WITNESS
    )

    assert ezkl.setup(
        COMPILED_MODEL,
        VK,
        PK
    )

setup("data/models/perceptron.onnx", perceptron, input_sample)

Using 6 columns for non-linearity table.
Using 11 columns for non-linearity table.
calibration failed extended k is too large to accommodate the quotient polynomial with logrows 6
Using 11 columns for non-linearity table.
calibration failed extended k is too large to accommodate the quotient polynomial with logrows 6
calibration failed max lookup input (-478801996, 243305849) is too large
calibration failed max lookup input (-478877341, 243282263) is too large
calibration failed max lookup input (-957655723, 486630371) is too large
calibration failed max lookup input (-1915080799, 973077036) is too large


 <------------- Numerical Fidelity Report (input_scale: 13, param_scale: 13, scale_input_multiplier: 10) ------------->

+---------------+--------------+-------------+---------------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| mean_error    | median_error | max_error   | min_error     | mean

3. **Proof Generation**: Generate the proof using the `gen_proof` function. This function takes the arithmetization of the model, the witness, the public key, and the proof file as inputs and writes the proof to the specified file.

In [9]:
import pprint

proof = ezkl.prove(
    WITNESS,
    COMPILED_MODEL,
    PK,
    PROOF,
    "single",
)

pprint.pprint(proof)

{'instances': [['2ae4a17d93f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                '2108b8e093f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                'ffebd7c293f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                '015c5e1600000000000000000000000000000000000000000000000000000000',
                'a50ed01f00000000000000000000000000000000000000000000000000000000',
                '0b92b0e693f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                '79034d9193f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                'e9f13a2300000000000000000000000000000000000000000000000000000000',
                '44202f0d00000000000000000000000000000000000000000000000000000000',
                'd0cfdb3900000000000000000000000000000000000000000000000000000000']],
 'proof': '0x2427cae5f23ce51de1000fdbcd3b4c931d04e172947d227ce649344cd21eaabd1399dfd586f1b1f8d6f179b28bc2b6a3eefa3e5d0474d805cd55ed36c706f9ee034824fe1e70cbabb

4. **Proof Verification**: ...

In [11]:
assert ezkl.verify(
    PROOF,
    SETTINGS,
    VK
) == True

# Token Trend Forecasting

We adapted code from one of [GIZA's examples](https://github.com/gizatechxyz/Giza-Hub/tree/token_trend_action/awesome-giza-actions/trend_token_prediction) the idea is to train multiple models with different accuracies and then compare the costs of proving the inference of each model. We first 
explain the feature extraction process in detail, which is not explained in the original example.

### Data

We will use the [Giza's dataset hub](https://github.com/gizatechxyz/datasets), which contains a collection of datasets that are relevant for blockchain applications. These datasets are publicly available and can be loaded using the `DatasetsLoader` class from the `giza_datasets` package, like so:

In [15]:
from giza.datasets import DatasetsLoader
import polars as pl

# Load the desired dataset
DatasetsLoader().load("tokens-daily-prices-mcap-volume")

# For pretty printing
pl.Config.set_tbl_hide_column_data_types(True)

Dataset tokens-daily-prices-mcap-volume not found in cache. Downloading from GCS.
Dataset read from cache.


polars.config.Config

##### Token Daily Price Data
Contains daily price data (price, market capitalization, and volume) for a set of tokens (e.g., WBTC, WETH, etc.).

In [16]:
print("First few rows of the dataset:")
print(DatasetsLoader().load('tokens-daily-prices-mcap-volume').head(n = 3))

First few rows of the dataset:
Dataset read from cache.
Loading dataset tokens-daily-prices-mcap-volume from cache.
shape: (3, 5)
┌────────────┬─────────────┬────────────┬──────────────────┬───────┐
│ date       ┆ price       ┆ market_cap ┆ volumes_last_24h ┆ token │
╞════════════╪═════════════╪════════════╪══════════════════╪═══════╡
│ 2019-02-01 ┆ 3438.360403 ┆ 0.0        ┆ 20589.040403     ┆ WBTC  │
│ 2019-02-02 ┆ 3472.243307 ┆ 0.0        ┆ 12576.723906     ┆ WBTC  │
│ 2019-02-03 ┆ 3461.058341 ┆ 0.0        ┆ 1852.526033      ┆ WBTC  │
└────────────┴─────────────┴────────────┴──────────────────┴───────┘


##### Top APY per protocol
Contains the top Annual Percentage Yield (APY) for each protocol in the dataset.

In [17]:
print("First few rows of the dataset:")
print(DatasetsLoader().load('top-pools-apy-per-protocol').head(n = 3))

First few rows of the dataset:
Dataset top-pools-apy-per-protocol not found in cache. Downloading from GCS.
Dataset read from cache.
shape: (3, 6)
┌────────────┬──────────┬─────┬─────────┬──────────────────┬──────────┐
│ date       ┆ tvlUsd   ┆ apy ┆ project ┆ underlying_token ┆ chain    │
╞════════════╪══════════╪═════╪═════════╪══════════════════╪══════════╡
│ 2022-02-28 ┆ 12808    ┆ 0.0 ┆ aave-v2 ┆ STETH            ┆ Ethereum │
│ 2022-03-01 ┆ 46045250 ┆ 0.0 ┆ aave-v2 ┆ STETH            ┆ Ethereum │
│ 2022-03-02 ┆ 90080754 ┆ 0.0 ┆ aave-v2 ┆ STETH            ┆ Ethereum │
└────────────┴──────────┴─────┴─────────┴──────────────────┴──────────┘


##### TVL per project tokens
Contains the Total Value Locked (TVL) for each project in the dataset.

In [18]:
print("First few rows of the dataset:")
print(DatasetsLoader().load('tvl-per-project-tokens').head(n = 3))

First few rows of the dataset:
Dataset tvl-per-project-tokens not found in cache. Downloading from GCS.
Dataset read from cache.
shape: (3, 47)
┌───────┬──────┬────────┬──────┬───┬──────┬──────┬────────────┬─────────┐
│ 1INCH ┆ AAVE ┆ AAVE.E ┆ AMPL ┆ … ┆ YFI  ┆ ZRX  ┆ date       ┆ project │
╞═══════╪══════╪════════╪══════╪═══╪══════╪══════╪════════════╪═════════╡
│ null  ┆ null ┆ null   ┆ null ┆ … ┆ null ┆ null ┆ 2020-11-29 ┆ aave-v2 │
│ null  ┆ null ┆ null   ┆ null ┆ … ┆ null ┆ null ┆ 2020-11-30 ┆ aave-v2 │
│ null  ┆ null ┆ null   ┆ null ┆ … ┆ null ┆ null ┆ 2020-12-01 ┆ aave-v2 │
└───────┴──────┴────────┴──────┴───┴──────┴──────┴────────────┴─────────┘


In [20]:
import itertools

TOKEN = "WETH"
LAG = 1
DAYS = [1, 3, 7, 15, 30]
START_DATE = pl.datetime(2022, 6, 1)

token_data = DatasetsLoader().load('tokens-daily-prices-mcap-volume')

# Filter the dataset to only include the target token
# Add a column with the labels for the target token
# Add columns with the price difference over the specified DAYS
# Expand the date column into day_of_week, month_of_year, and year
target_token_price_trend = token_data \
    .filter(pl.col("token") == TOKEN) \
    .with_columns(
        ((pl.col("price").shift(-LAG) - pl.col("price")) > 0).cast(pl.Int8).alias("label")
    ) \
    .with_columns(
        pl.col("price").diff(n = days).alias(f"price_diff_{days}_days")
        for days in DAYS
    ) \
    .with_columns(
        (pl.col("price") - pl.col("price").shift(days) > 0).cast(pl.Int8).alias(f"trend_{days}_days")
        for days in DAYS
    ) \
    .with_columns([
        pl.col("date").dt.weekday().alias("day"),
        pl.col("date").dt.month().alias("month"),
        pl.col("date").dt.year().alias("year")
    ])

print("First few rows of the dataset:")
print(target_token_price_trend.head(n = 3))

Dataset read from cache.
Loading dataset tokens-daily-prices-mcap-volume from cache.
First few rows of the dataset:
shape: (3, 19)
┌────────────┬─────────┬────────────┬──────────────────┬───┬───────────────┬─────┬───────┬──────┐
│ date       ┆ price   ┆ market_cap ┆ volumes_last_24h ┆ … ┆ trend_30_days ┆ day ┆ month ┆ year │
╞════════════╪═════════╪════════════╪══════════════════╪═══╪═══════════════╪═════╪═══════╪══════╡
│ 2018-02-14 ┆ 839.535 ┆ 0.0        ┆ 54776.5          ┆ … ┆ null          ┆ 3   ┆ 2     ┆ 2018 │
│ 2018-02-15 ┆ 947.358 ┆ 0.0        ┆ 111096.0         ┆ … ┆ null          ┆ 4   ┆ 2     ┆ 2018 │
│ 2018-02-16 ┆ 886.961 ┆ 0.0        ┆ 57731.7          ┆ … ┆ null          ┆ 5   ┆ 2     ┆ 2018 │
└────────────┴─────────┴────────────┴──────────────────┴───┴───────────────┴─────┴───────┴──────┘


In [21]:
token_data = DatasetsLoader().load('tokens-daily-prices-mcap-volume')
correlations = {}

# List all tokens in the dataset
tokens = token_data.get_column("token").unique().to_list()

# Calculate the correlation between the target token and all other tokens
for token_1, token_2 in itertools.permutations(tokens, r=2):
    
    # Filter the dataset and get the price and date columns
    token_1_data = token_data.filter(pl.col("token") == token_1) \
        .select(["date", "price"])
    token_2_data = token_data.filter(pl.col("token") == token_2) \
        .select(["date", "price"])
    
    # Join the datasets on the date column
    joined_data = token_1_data.join(token_2_data, on="date", suffix="_compare")

    # Nested dictionary to store the correlation between the two tokens
    correlations[token_1] = correlations.get(token_1, {}) 
    correlations[token_1][token_2] = correlations[token_1] \
        .get(token_2, {
            day: joined_data \
                    .with_columns(pl.col("price_compare").shift(day)) \
                    .select(pl.corr("price", "price_compare").alias("correlation")) \
                    .get_column("correlation")[0]
            for day in DAYS
        })


#pprint.pprint(correlations)

Dataset read from cache.
Loading dataset tokens-daily-prices-mcap-volume from cache.


In [22]:
K = 5

# Dataframe to store the final results
price_dataset = target_token_price_trend

# Retrive the relevant data from the nested dictionary
target_token_correlations = correlations[TOKEN]

# Get the top K correlated tokens for each lag
top_k_correlated_tokens_by_lag = {
    lag: sorted(target_token_correlations.items(), key=lambda x: x[1][lag], reverse=True)[:K]
    for lag in DAYS
}

top_k_correlated_tokens_15_days = top_k_correlated_tokens_by_lag[15]

for token, _ in top_k_correlated_tokens_15_days:

    # Column names for the price differences
    price_diff_columns = [f"price_diff_{token}_{days}" for days in DAYS]
    price_trend_columns = [f"price_trend_{token}_{days}" for days in DAYS]

    # Filter the dataset to only include the correlated token
    token_prices = token_data.filter(pl.col("token") == token)

    # Add columns with the price differences for each day
    token_prices = token_prices \
        .with_columns(
            pl.col("price").diff(n = days).alias(tag)
            for days, tag in zip(DAYS, price_diff_columns)
        ) \
        .with_columns([
            (pl.col("price") - pl.col("price").shift(days) > 0).cast(pl.Int8).alias(tag)
            for days, tag in zip(DAYS, price_trend_columns)
        ]) \
        .select(["date"] + price_diff_columns + price_trend_columns)

    # Join the dataset with the target token dataset
    price_dataset = price_dataset.join(token_prices, on="date", how="left")

print("First few rows of the dataset:")
print(price_dataset.head(n = 3))

First few rows of the dataset:
shape: (3, 69)
┌────────────┬─────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬───────────┐
│ date       ┆ price   ┆ market_ca ┆ volumes_l ┆ … ┆ price_tre ┆ price_tre ┆ price_tre ┆ price_tre │
│            ┆         ┆ p         ┆ ast_24h   ┆   ┆ nd_GNO_3  ┆ nd_GNO_7  ┆ nd_GNO_15 ┆ nd_GNO_30 │
╞════════════╪═════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡
│ 2018-02-14 ┆ 839.535 ┆ 0.0       ┆ 54776.5   ┆ … ┆ 0         ┆ 1         ┆ 0         ┆ 0         │
│ 2018-02-15 ┆ 947.358 ┆ 0.0       ┆ 111096.0  ┆ … ┆ 1         ┆ 1         ┆ 0         ┆ 0         │
│ 2018-02-16 ┆ 886.961 ┆ 0.0       ┆ 57731.7   ┆ … ┆ 1         ┆ 1         ┆ 0         ┆ 0         │
└────────────┴─────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴───────────┘


  price_dataset = price_dataset.join(token_prices, on="date", how="left")


In [23]:
top_apy_per_protocol = DatasetsLoader().load("top-pools-apy-per-protocol")

# Filter the dataset to only include protocols with the target token
""" unique_token_projects = top_apy_per_protocol \
    .filter(pl.col("underlying_token").str.contains(TOKEN)) \
    .filter(pl.col("date") > START_DATE) \
    .unique("project") \
    .pivot(index="date", columns="project", values=["apy", "tvlUsd"]) """

apy_df = top_apy_per_protocol \
    .filter(pl.col("underlying_token").str.contains(TOKEN)) \
    .with_columns(
            pl.col("project") + "_" + pl.col("chain") +  pl.col("underlying_token")
    ) \
    .drop(["underlying_token", "chain"])

unique_projects = apy_df \
    .filter(pl.col("date") <= START_DATE) \
    .select("project") \
    .unique()

apy_df_token = apy_df.join(
    unique_projects, 
    on="project", 
    how="inner"
)

unique_token_projects = apy_df_token.pivot(
    index="date",
    columns="project",
    values=["tvlUsd", "apy"]
)

print("First few rows of the dataset:")
print(unique_token_projects.head(n = 3))
print("Number of rows in the dataset:", len(unique_token_projects))

Dataset read from cache.
Loading dataset top-pools-apy-per-protocol from cache.
First few rows of the dataset:
shape: (3, 91)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ date      ┆ tvlUsd_pr ┆ tvlUsd_pr ┆ tvlUsd_pr ┆ … ┆ apy_proje ┆ apy_proje ┆ apy_proje ┆ apy_proj │
│           ┆ oject_aav ┆ oject_aav ┆ oject_aav ┆   ┆ ct_uniswa ┆ ct_uniswa ┆ ct_uniswa ┆ ect_year │
│           ┆ e-v2_Ethe ┆ e-v2_Poly ┆ e-v2_Aval ┆   ┆ p-v3_Arbi ┆ p-v3_Arbi ┆ p-v3_Ethe ┆ n-financ │
│           ┆ reu…      ┆ gon…      ┆ anc…      ┆   ┆ tru…      ┆ tru…      ┆ reu…      ┆ e_Ethe…  │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ 2022-02-1 ┆ 246215633 ┆ 560180650 ┆ 719972444 ┆ … ┆ null      ┆ null      ┆ null      ┆ null     │
│ 1         ┆ 5         ┆           ┆           ┆   ┆           ┆           ┆           ┆          │
│ 2022-02-1 ┆ 246420416 ┆ 537846447 ┆ 672831429 ┆ … ┆ null      ┆ 

In [24]:
tvl_df = DatasetsLoader().load("tvl-per-project-tokens") \
    .unique(subset=["date", "project"]) \
    .filter(pl.col("date") > START_DATE) 

tvl_per_projects_token = tvl_df[[TOKEN, "project", "date"]].pivot(
    index="date",
    columns="project",
    values=TOKEN
)

print("First few rows of the dataset:")
print(tvl_per_projects_token.head(n = 3))

Dataset read from cache.
Loading dataset tvl-per-project-tokens from cache.
First few rows of the dataset:
shape: (3, 20)
┌────────────┬──────────┬──────────┬───────────┬───┬───────────┬───────────┬───────────┬───────────┐
│ date       ┆ aave-v2  ┆ aura     ┆ pancakesw ┆ … ┆ rocket-po ┆ uniswap-v ┆ balancer- ┆ uniswap-v │
│            ┆          ┆          ┆ ap-amm    ┆   ┆ ol        ┆ 2         ┆ v2        ┆ 3         │
╞════════════╪══════════╪══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡
│ 2023-05-14 ┆ 4.1997e8 ┆ 1.6514e8 ┆ null      ┆ … ┆ 1.1390e9  ┆ null      ┆ 2.4431e8  ┆ 7.7659e8  │
│ 2022-12-07 ┆ 3.8163e8 ┆ 1.1824e8 ┆ null      ┆ … ┆ 4.1207e8  ┆ null      ┆ 1.9926e8  ┆ 6.0181e8  │
│ 2022-07-31 ┆ 9.0687e8 ┆ 6.2815e6 ┆ null      ┆ … ┆ 3.5818e8  ┆ null      ┆ 1.6390e8  ┆ 9.0439e8  │
└────────────┴──────────┴──────────┴───────────┴───┴───────────┴───────────┴───────────┴───────────┘


In [25]:
# Join the datasets by the date column to create the final dataset
final_dataset = price_dataset \
    .join(tvl_per_projects_token, on="date", how="inner") \
    .join(unique_token_projects, on="date", how="inner")

# Drop unnecessary columns and rows with irrelevant data
# - columns with token, market_cap , date and current price: not relevant for the model
# - rows with year < 2022: historical data is not relevant
final_dataset = final_dataset \
    .filter(pl.col("year") >= 2022) \
    .drop(["token", "market_cap", "date", "price", "month"])
final_dataset = final_dataset.slice(0, len(final_dataset) - 1)
# Drop columns if there are lots of missing values
THRESHOLD = 0.2
max_nulls = THRESHOLD * final_dataset.shape[0]
columns_to_keep = [
        col_name for col_name in final_dataset.columns if final_dataset[col_name].null_count() <= max_nulls
]
final_dataset = final_dataset.select(columns_to_keep    )

# Split the dataset into features and labels
features = final_dataset.drop("label")
labels = final_dataset["label"]

# Normalize the training dataset and fill missing values
for col in features.columns:
    mean_val = features[col].mean()
    std_dev = features[col].std() if features[col].std() != 0 else 1
    features = features.with_columns(((features[col].fill_null(mean_val) - mean_val) / std_dev).alias(col))

print("First few rows of the dataset:")
print(features.head(n = 3))

First few rows of the dataset:
shape: (3, 162)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ volumes_l ┆ price_dif ┆ price_dif ┆ price_dif ┆ … ┆ apy_proje ┆ apy_proje ┆ apy_proje ┆ apy_proj │
│ ast_24h   ┆ f_1_days  ┆ f_3_days  ┆ f_7_days  ┆   ┆ ct_uniswa ┆ ct_uniswa ┆ ct_uniswa ┆ ect_year │
│           ┆           ┆           ┆           ┆   ┆ p-v3_Arbi ┆ p-v3_Arbi ┆ p-v3_Ethe ┆ n-financ │
│           ┆           ┆           ┆           ┆   ┆ tru…      ┆ tru…      ┆ reu…      ┆ e_Ethe…  │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ 1.024704  ┆ -2.197109 ┆ 0.11474   ┆ -0.843464 ┆ … ┆ -0.576543 ┆ -0.54936  ┆ -0.044265 ┆ -0.92143 │
│           ┆           ┆           ┆           ┆   ┆           ┆           ┆           ┆ 2        │
│ 0.498189  ┆ 0.064754  ┆ -1.764524 ┆ 0.120361  ┆ … ┆ -0.678317 ┆ -0.580351 ┆ -0.044265 ┆ -0.94637 │
│           ┆           ┆           ┆       

## Model Training

##### Splitting the dataset

We need to convert the dataframes to torch tensors and split the dataset into training and testing sets. As usual, we choose a reasonable split ratio (e.g., 80% training and 20% testing) and shuffle the data before splitting it: 

In [26]:
# Convert the dataset to a PyTorch tensor
features_tensor = torch.tensor(features.to_numpy(), dtype=torch.float32)
labels_tensor = torch.tensor(labels.to_numpy(), dtype=torch.float32)

# Get a random permutation of the indices
indices = torch.randperm(len(features_tensor))
train_indices = indices[:int(0.75 * len(features_tensor))]
test_indices = indices[int(0.75 * len(features_tensor)):]
train_features, train_labels = features_tensor[train_indices], labels_tensor[train_indices]
test_features, test_labels = features_tensor[test_indices], labels_tensor[test_indices]

#### Model Definition

For the sake of simplicity, we illustrate the training process using a simple perceptron model no hidden layers and a single output neuron. We use the Sigmoid activation function to output the probability of the token price increasing:

In [27]:
model = nn.Sequential(
    nn.Linear(len(features.columns), 1),
    nn.Sigmoid()
)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#### Training the Model

We then train the model on the Binary Cross-Entropy loss function and the Adam optimizer. For the hyperparameters, we use a learning rate of 0.01 and 500 epochs:

In [28]:
def train_and_test_model(model, train_features, train_labels, test_features, test_labels, criterion, optimizer):
    model.train()
    for _ in range(1000):
        optimizer.zero_grad()
        output = model(train_features)
        loss = criterion(output, train_labels.unsqueeze(1))
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        output = model(test_features)
        predicted = torch.tensor([1 if x > 0.5 else 0 for x in output])
        total += test_labels.size(0)
        correct += (predicted == test_labels).sum().item()
    return 100 * correct / total
acc = train_and_test_model(model, train_features, train_labels, test_features, test_labels, criterion, optimizer)
print(f'Accuracy: {acc:.2f}%')

Accuracy: 66.67%


This simple Logistic Regression model obtains an accuracy of 0.70 on the test set. In the original example from GIZA, the authors trained a Multilayer Perceptron model with 2 hidden layers of decreasing input size (64 and 32). Using their feature extraction process, their MLP achieved an accuracy of around 0.65. In the final part of this article, we will show how the accuracy can be further improved by training more complex models.

#### Benchmarking Proof Generation

Once the model is trained, we can generate a proof of its inference using EZKL. We first convert the model to the ONNX format and set up the proof generation process as described in the previous sections. To simplify the process, let us write a few helper functions to convert the model to ONNX format and time the proof generation process:

In [103]:
import math
import os
from typing import Tuple, Union
import time
import torch
import sklearn as sk

def to_onnx(model, input_sample, onnx_file):
    # if the file already exists, delete it
    if os.path.exists(onnx_file):
        os.remove(onnx_file)
    torch.onnx.export(
        model,
        input_sample,
        onnx_file,
        input_names = ['input'],             # Input and output labels to appear in the ONNX graph 
        output_names = ['output'],
        opset_version=10,
        do_constant_folding=True,
        export_params=True, 
        dynamic_axes={
            'input' : {0 : 'batch_size'},    # Variable length axes
            'output' : {0 : 'batch_size'}
        }
    )

def prove():
    _= ezkl.prove(
        WITNESS,
        COMPILED_MODEL,
        PK,
        PROOF,
        "single",
    )

def verify():
    assert ezkl.verify(
        PROOF,
        SETTINGS,
        VK,
    ) == True

import contextlib

def bench_ezkl_single_round(
    model: Union[torch.nn.Module, sk.base.BaseEstimator],
    sample: torch.Tensor, 
) -> Tuple[float, float, float, int]:
    
    setup_time = -time.time()
    with contextlib.redirect_stderr(None):
        setup("data/models/perceptron.onnx", model, sample)
    setup_time += time.time()

    logrows = json.load(open(SETTINGS, 'r'))["run_args"]["logrows"]
    
    # Sleep for 1 second to make sure Rust has enough time to write the files
    time.sleep(1) 

    prove_time = -time.time()
    prove()
    prove_time += time.time()

    time.sleep(1) 

    verify_time = -time.time()
    verify()
    verify_time += time.time()

    return setup_time, prove_time, verify_time, logrows

def bench_ezkl(
    model: Union[torch.nn.Module, sk.base.BaseEstimator],
    test_features: torch.Tensor,
    rounds: int = 1,
) -> Tuple[float, float, float]:
    
    # Convert the model to ONNX and calibrate it
    to_onnx(model, test_features[0].unsqueeze(0), "data/models/perceptron.onnx")  

    setup_time, prove_time, verify_time, logrows = [], [], [], []
    for _ in range(rounds):
        # Reload the module to avoid any caching issues

        from importlib import reload
        import ezkl
        reload(ezkl)

        # randomly sample a feature from the test dataset
        sample = test_features[torch.randint(0, len(test_features), (1,))]
        s, p, v, l = bench_ezkl_single_round(model, sample)
        setup_time.append(s), prove_time.append(p), verify_time.append(v), logrows.append(l)

    # Calculate the average and standard deviation of the timings
    avg_setup, avg_prove, avg_verify, avg_logrows = (
        sum(setup_time) / rounds, 
        sum(prove_time) / rounds, 
        sum(verify_time) / rounds,
        sum(logrows) / rounds
    )

    std_setup, std_prove, std_verify, std_logrows = (
        (sum((s - avg_setup) ** 2 for s in setup_time) / rounds) ** 0.5,
        (sum((p - avg_prove) ** 2 for p in prove_time) / rounds) ** 0.5,
        (sum((v - avg_verify) ** 2 for v in verify_time) / rounds) ** 0.5,
        (sum((l - avg_logrows) ** 2 for l in logrows) / rounds) ** 0.5
    )

    print(f"Setup time: {str(avg_setup)[:5]} ± {str(std_setup/math.sqrt(rounds))[:5]} [s]")
    print(f"Prover time: {str(avg_prove)[:5]} ± {str(std_prove/math.sqrt(rounds))[:5]} [s]")
    print(f"Verifier time: {str(avg_verify)[:5]} ± {str(std_verify/math.sqrt(rounds))[:5]} [s]")
    print(f"Logrows: {str(avg_logrows)[:5]} ± {str(std_logrows/math.sqrt(rounds))[:5]}")
    
    return setup_time, prove_time, verify_time, logrows

We can now time the `setup`, `prove`, and `verify` functions by calling the `bench_ezkl` function, which allows us to obtain average times with error margins by specifying the number of `rounds`. Let's benchmark the proof generation process for the simple perceptron model:

In [104]:
_ = bench_ezkl(
    model,
    test_features,
    rounds=1
)

Setup time: 0.491 ± 0.0 [s]
Prover time: 0.487 ± 0.0 [s]
Verifier time: 0.015 ± 0.0 [s]
Logrows: 12.0 ± 0.0


## Accuracy vs. Proving Costs

For the main part of this article, we will compare the accuracy of the model with the cost of proving its inference. On the one hand, we increase the number of hidden layers and neurons of the perceptron model to show how a linear increase yields a linear increase in proof cost but a diminishing return in accuracy. On the other hand, we show how different architectures (e.g., Decision Trees, Random Forests, and SVMs) can obtain similar accuracies with varying proof costs.

#### Increasing Model Complexity

Let's start by increasing the complexity of the perceptron model. We evaluate perceptrons with one, two and three hidden layers for which we vary the number of neurons per layer as follows. We define the possible number of neurons per layer to be one of the following: [4, 8, 16, 32, 64, 128]. In addition, for any two consecutive layers, the outermost layer must have striclty less neurons than the inner one. We then train the model for each configuration and obtain the accuracies:

In [None]:
from typing import List

# Function to create a MLP model with the specified information
def create_mlp_model(layer_info: List[Tuple[int, int]]) -> nn.Module:
    layers = (
        nn.Linear(in_size, out_size)
        for in_size, out_size in layer_info
    )
    return nn.Sequential(*layers, nn.Sigmoid())


# Function to train and return the model accuracy
def train_and_return_model(model: nn.Module) -> Tuple[nn.Module, float]:
    # Get a random permutation of the indices
    indices = torch.randperm(len(features_tensor))
    train_indices = indices[:int(0.75 * len(features_tensor))]
    test_indices = indices[int(0.75 * len(features_tensor)):]
    train_features, train_labels = features_tensor[train_indices], labels_tensor[train_indices]
    test_features, test_labels = features_tensor[test_indices], labels_tensor[test_indices]

    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0012)
    acc = train_and_test_model(model, train_features, train_labels, test_features, test_labels, criterion, optimizer)
    return model, acc

def get_num_params(name: str) -> int:
    numbers = [len(features.columns)] + [
        int(s) for s in name.split('_')[3:] if s.isdigit()] + [1]
    weights = sum([s1 * s2 for s1, s2 in zip(numbers, numbers[1:])])
    biases = sum(numbers[1:])
    return weights + biases

# Given the number of layers, return all possible decreasing configurations
# Where the number of neurons is in [4, 8, 16, 32, 64]
def get_all_configurations(n_layers: int, in_features: int = len(features.columns)) -> List[List[int]]:
    if n_layers == 0:
        return [[(in_features, 1)]]

    in_sizes = [128, 64, 32, 16, 8, 4]
    combinations = list(itertools.combinations_with_replacement(in_sizes, n_layers))
    increasing_combinations = [
        c for c in combinations
        if all(c[i] > c[i + 1] for i in range(len(c) - 1))
    ]

    sizes =[ 
        [(in_features, c[0])] + [
            (c1, c2) for c1, c2 in zip(c, c[1:])
        ] + [(c[-1], 1)]
        for c in increasing_combinations
    ]
                                
    return sizes

# Dictionary to store the accuracy of each model
acc = {}
ROUNDS = 100

for _ in range(ROUNDS):
    for layers in range(0, 4):
        for layer_info in get_all_configurations(layers):
            model, accuracy = train_and_return_model(
                create_mlp_model(layer_info)
            )
            tag = f"MLP_{layers}_layers_{'_'.join(str(x[0]) for x in layer_info[1:])}"
            acc[tag] = acc.get(tag, []) + [accuracy]

# Print the average accuracy for each model
sorted_acc = sorted(acc.items(), key=lambda x: get_num_params(x[0]))
for name, accuracies in sorted_acc:
    print(f"{name}: {str(sum(accuracies) / len(accuracies))[:5]}")

As we can see, the accuracy of the models correlates with the number of neurons per layer. However, at some point, the increase in accuracy becomes marginal or even stagnates. This is due to overfitting, as the model becomes too complex and starts to memorize the training data instead of generalizing well to unseen data. We will now observe how the proof costs increase with the number of neurons per layer:

In [None]:
def bench_configuration(layer_info):
    
    layers = len(layer_info) - 1
    tag = f"MLP_{layers}_layers_{'_'.join(str(x[0]) for x in layer_info[1:])}"
    print(f"{tag} ({get_num_params(tag)} params)\n")
    
    bench_ezkl(
        train_and_return_model(create_mlp_model(layer_info))[0],
        test_features,
        rounds=50,
    )

    print("\n")

for layers in range(0, 4):
    for layer_info in get_all_configurations(layers):
        # bench_configuration(layer_info)

In [183]:
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import RidgeClassifier
from sklearn.tree import DecisionTreeClassifier

def train_sklearn_model(model: sk.base.BaseEstimator) -> float:
    # Convert the dataset to a PyTorch tensor
    features_tensor = torch.tensor(features.to_numpy(), dtype=torch.float32)
    labels_tensor = torch.tensor(labels.to_numpy(), dtype=torch.float32)

    # Get a random permutation of the indices
    indices = torch.randperm(len(features_tensor))
    train_indices = indices[:int(0.75 * len(features_tensor))]
    test_indices = indices[int(0.75 * len(features_tensor)):]
    train_features, train_labels = features_tensor[train_indices], labels_tensor[train_indices]
    test_features, test_labels = features_tensor[test_indices], labels_tensor[test_indices]
    model.fit(train_features, train_labels)
    return model, model.score(test_features, test_labels)
max_iter = 50
# Train and test the model
for model in [SVC(kernel="linear", tol=0.01), LogisticRegression(tol=0.01), RandomForestClassifier(), RidgeClassifier(), DecisionTreeClassifier()]:
    _, acc = train_sklearn_model(model)
    print(f'Accuracy: {acc:.2f}%')


Accuracy: 0.77%
Accuracy: 0.72%
Accuracy: 0.61%
Accuracy: 0.70%
Accuracy: 0.60%


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

accuracies = open("benches/accuracy.txt", "r").read()
accuracies = accuracies.split("\n")[1:-1]
accuracies = [line.split(": ") for line in accuracies]
accuracies = {name: float(acc) for name, acc in accuracies}

def get_num_params(name: str) -> int:
    numbers = [162] + [
        int(s) for s in name.split('_')[3:] if s.isdigit()] + [1]
    weights = sum([s1 * s2 for s1, s2 in zip(numbers, numbers[1:])])
    biases = sum(numbers[1:])
    return weights + biases

sorted_acc = sorted(accuracies.items(), key=lambda x: get_num_params(x[0]))

for name, acc in sorted_acc:
    num_layers = int(name.split('_')[1])
    color = {0: 'black', 1: 'blue', 2: 'red', 3: 'green'}[num_layers]
    plt.scatter(
        get_num_params(name),
        acc,
        color=color,
        marker='x',
    )

plt.figure(figsize=(10, 6))
plt.grid(True)
plt.scatter([], [], color='black', label='No hidden layers', marker='x')
plt.scatter([], [], color='blue', label='1 hidden layer', marker='x')
plt.scatter([], [], color='red', label='2 hidden layers', marker='x')
plt.scatter([], [], color='green', label='3 hidden layers', marker='x')
plt.xscale('log')
plt.legend()
plt.xlabel("Number of parameters")
plt.ylabel("Accuracy")


for name, acc in sorted_acc:
    num_layers = int(name.split('_')[1])
    color = {0: 'black', 1: 'blue', 2: 'red', 3: 'green'}[num_layers]
    sns.regplot(
        x=[get_num_params(name)],
        y=[acc],
        color=color,
        marker='x',
    )

sns.set_style("whitegrid")
sns.despine()

# Save without borders
plt.savefig("plots/accuracy_vs_params.png", bbox_inches='tight', pad_inches=0)

In [None]:
timings = open("benches/time.txt", "r").read()

import matplotlib.pyplot as plt

# Extract the average times from the benchmark results
lines_with_text = lambda text: [line for line in timings.split("\n") if text in line]
layers = [int(line.split("_")[1]) for line in lines_with_text("layers")]
params = [get_num_params(line.split()[0]) for line in lines_with_text("params")]
setup_times = [(float(line.split()[2]), float(line.split()[4])) for line in lines_with_text("Setup time")]
prove_times = [(float(line.split()[2]), float(line.split()[4])) for line in lines_with_text("Prover time")]
verify_times = [(float(line.split()[2]), float(line.split()[4])) for line in lines_with_text("Verifier time")]
logrows = [(float(line.split()[1]), float(line.split()[3])) for line in lines_with_text("Logrows")]


# Plot the four metrics
fig, axs = plt.subplots(2, 2, figsize=(15, 10))

# Grid lines
for ax in axs.flatten():
    ax.grid(True)

# Plot the setup times
for (i, j), l, times in zip([(0, 0), (0, 1), (1, 0), (1, 1)], ["Setup", "Prover", "Verifier", "Logrows"],  [setup_times, prove_times, verify_times, logrows]):
    axs[i, j].set_title(l)
    axs[i, j].set_ylabel("Time $[s]$")
    if l == "Logrows":
        axs[i, j].set_ylabel("Logrows")

    if l == "Verifier":
        axs[i, j].set_ylabel("Time $[ms]$")
    colors = {
        0: 'black',
        1: 'blue',
        2: 'red',
        3: 'green',
    }

    # Log scale for the x-axis
    axs[i, j].set_xscale('log')

    for la, param, (time, std) in zip(layers, params, times):
        axs[i, j].errorbar(
            param,
            [time * 1000 if l == "Verifier" else time] ,
            marker='x',
            yerr=std,
            # For thinner lines
            elinewidth=0.3,
            capsize=1,
            color=colors[la],
            label=f"{la} layers",
        )

        # Add x-lims
        axs[i, j].set_xlim([min(params) / 1.5, max(params) * 1.5])

        # Save without borders
plt.savefig("plots/times_vs_params.png", bbox_inches='tight', pad_inches=0)
