# EZKL Workflow

To prove the inference of a trained model with EZKL, we follow the steps below. As a running example, suppose we have just trained a simple perceptron on MNIST with PyTorch:

In [24]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision

# MNIST dataset (downloaded to ./data on first use); each image is flattened to a 784-vector
train, test = (torchvision.datasets.MNIST(
    './data', 
    train=is_train,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        torchvision.transforms.Lambda(lambda x: x.view(-1))
    ])
) for is_train in [True, False])

input_size, output_size = 28 * 28, 10

# Define the model
perceptron = nn.Sequential(
    nn.Linear(input_size, output_size),
)

# Create data loaders for the training and test sets
train_loader, test_loader = (DataLoader(
    dataset, 
    batch_size=32, 
    shuffle=True
) for dataset in [train, test])

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(perceptron.parameters(), lr=0.01)

# Train the model
perceptron.train()
for data, label in train_loader:
    output = perceptron(data)
    loss = criterion(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [25]:
# Test the model
perceptron.eval()
with torch.no_grad():
    correct, total = 0, 0
    for data, label in test_loader:
        output = perceptron(data)
        _, predicted = torch.max(output.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()
    print(f'Accuracy: {100 * correct / total:.2f}%')

Accuracy: 87.16%


1. **Model Conversion**

Convert the trained model to the ONNX format. PyTorch provides `torch.onnx.export` for this; other frameworks have similar functions or external converters (e.g., tf2onnx for TensorFlow). Models from scikit-learn are slightly more complicated: they must first be converted to a PyTorch model with hummingbird.ml and then exported to ONNX. We won't cover that process in this article, but one of EZKL's notebooks describes it.

Before converting the model to ONNX, the exporter needs to know the model's input shape. We provide it by passing an example input tensor, on which `torch.onnx.export` runs (traces) the model.

In [26]:
# Any valid input works: take the first sample of a (shuffled) test batch and add a batch dimension
input_sample = next(iter(test_loader))[0][0].unsqueeze(0)

torch.onnx.export(
    perceptron,
    input_sample,
    "perceptron.onnx",
    input_names = ['input'],             # Input and output labels to appear in the ONNX graph 
    output_names = ['output'],
    dynamic_axes={
        'input' : {0 : 'batch_size'},    # Variable length axes
        'output' : {0 : 'batch_size'}
    }
)

2. **Setup**

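EZKL exposes several setup functions in its API: `gen_settings`, `calibrate_settings`, `compile_circuit`, `get_srs`, `gen_witness` and `setup`. We group them in this step to describe the high-level setup process: `gen_settings` generates circuit settings from the ONNX model, `calibrate_settings` tunes those settings (e.g., the fixed-point scales) using sample data, `compile_circuit` compiles the model into a circuit, `get_srs` fetches the structured reference string (SRS), `gen_witness` computes the witness for a given input, and `setup` generates the proving and verifying keys.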

In [28]:
import ezkl
import json

def create_file(filename: str) -> str:
    open(filename, 'w').close()
    return filename

# We have to create empty files manually before running the setup
INPUT = create_file("input_data.json")
SETTINGS = create_file("settings.json")
CALIBRATION = create_file("calibration.json")
WITNESS = create_file("witness.json")
COMPILED_MODEL = create_file("compiled_model.json")
VK = create_file("vk.json")
PK = create_file("pk.json")
PROOF = create_file("proof.pf")

def setup(model, onnx_file, input_sample):

    # Save the input data to a file in the expected format
    input_data = {
        'input_shapes': list(input_sample.shape),
        'input_data': input_sample.detach().numpy().tolist(),
        "output_data": model(input_sample).detach().numpy().tolist()
    }

    with open(INPUT, 'w') as f:
        json.dump(input_data, f)

    # Run each setup function and verify that it succeeded
    assert ezkl.gen_settings(
        onnx_file,
        SETTINGS
    )

    # Use the same sample as calibration data
    with open(CALIBRATION, 'w') as f:
        json.dump(input_data, f)

    assert ezkl.calibrate_settings(
        CALIBRATION,
        onnx_file,
        SETTINGS, 
        "resources"
    )

    assert ezkl.compile_circuit(
        onnx_file,
        COMPILED_MODEL,
        SETTINGS
    )

    assert ezkl.get_srs(
        SETTINGS
    )

    ezkl.gen_witness(
        INPUT,
        COMPILED_MODEL,
        WITNESS
    )

    assert ezkl.setup(
        COMPILED_MODEL,
        VK,
        PK
    )

setup(perceptron, "perceptron.onnx", input_sample)

Using 6 columns for non-linearity table.
Using 12 columns for non-linearity table.
calibration failed extended k is too large to accommodate the quotient polynomial with logrows 6
Using 12 columns for non-linearity table.
calibration failed extended k is too large to accommodate the quotient polynomial with logrows 6
calibration failed max lookup input (-430708640, 367947557) is too large
calibration failed max lookup input (-430782343, 367916286) is too large
calibration failed max lookup input (-861471861, 735938776) is too large
calibration failed max lookup input (-1722893789, 1471769531) is too large


 <------------- Numerical Fidelity Report (input_scale: 13, param_scale: 13, scale_input_multiplier: 10) ------------->

+---------------+--------------+--------------+---------------+----------------+------------------+---------------+---------------+--------------------+--------------------+------------------------+
| mean_error    | median_error | max_error    | min_error     | m
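The calibration report above mentions `input_scale` and `param_scale`. Our understanding of EZKL's documentation is that a scale `s` means a float `x` is represented by the integer `round(x * 2**s)`. Under that assumption, the sketch below mimics the quantization on a random linear layer shaped like the perceptron (784 inputs, 10 outputs) and compares it with floating point. The helper names and random weights are hypothetical, and this is not EZKL's implementation; it only illustrates why a larger scale reduces numerical error while producing larger integers, which is consistent with calibration attempts failing with "max lookup input ... is too large".

```python
import numpy as np

def quantize(x: np.ndarray, scale: int) -> np.ndarray:
    # Assumed fixed-point encoding: multiply by 2**scale and round to an integer
    return np.round(x * 2.0 ** scale).astype(np.int64)

def fixed_point_linear(x, w, b, scale):
    # Quantize inputs and weights at the same scale
    xq, wq = quantize(x, scale), quantize(w, scale)
    # A product of two scale-s integers carries scale 2s, so the bias is quantized at 2s
    bq = quantize(b, 2 * scale)
    yq = xq @ wq.T + bq
    # Dequantize back to floats; also report the largest integer produced
    return yq / 2.0 ** (2 * scale), int(np.abs(yq).max())

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 784))               # one flattened 28x28 input (hypothetical)
w = rng.normal(scale=0.05, size=(10, 784))  # hypothetical weights
b = rng.normal(scale=0.05, size=10)         # hypothetical bias
y_float = x @ w.T + b

errors, magnitudes = {}, {}
for scale in (4, 8, 13):
    y_fixed, largest_int = fixed_point_linear(x, w, b, scale)
    errors[scale] = float(np.abs(y_fixed - y_float).max())
    magnitudes[scale] = largest_int
    print(f"scale={scale:2d}  max error={errors[scale]:.2e}  largest integer={largest_int}")
```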

3. **Proof Generation**: Generate the proof with `ezkl.prove`. It takes the witness, the compiled circuit, the proving key, the output path for the proof and the proof type (`"single"` here), writes the proof to the specified file and returns it.

In [29]:
import pprint

proof = ezkl.prove(
    WITNESS,
    COMPILED_MODEL,
    PK,
    PROOF,
    "single",
)

pprint.pprint(proof)

{'instances': [['73c5f88893f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                '0554ec0000000000000000000000000000000000000000000000000000000000',
                '8d62cde493f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                'e42d5bee93f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                '2a8d65e693f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                '5007cf0600000000000000000000000000000000000000000000000000000000',
                '8b978faf93f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                'fbe4369793f5e1439170b97948e833285d588181b64550b829a031e1724e6430',
                'dbb7bf5700000000000000000000000000000000000000000000000000000000',
                '03fda0d793f5e1439170b97948e833285d588181b64550b829a031e1724e6430']],
 'proof': '0x121bde1ec665884902396b9e7fd2c7cbc76a90fab3920ba3d3bdcfc3889ac11b2ff83062ec0faa43cef4ea45614698ca1ac79430f201ab539e61c175b2ec801602f01ff501e1aff40

4. **Proof Verification**: ...

In [30]:
assert ezkl.verify(
    PROOF,
    SETTINGS,
    VK
) == True
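The four steps above form a chain in which every EZKL call reads files written by earlier calls. The sketch below records those file dependencies, using the file names from the constants defined in the setup cell, and checks that a given call order never reads a file before it exists. The step table is a hand-written, hypothetical summary rather than part of EZKL's API, and the dependency of `setup`, `prove` and `verify` on the SRS fetched by `get_srs` is an assumption, since the notebook never passes an SRS path explicitly.

```python
from typing import Dict, List, Set, Tuple

# (step, files read, files written), read off the EZKL calls in this notebook.
# "srs" is an assumed implicit dependency: the notebook never passes an SRS path.
PIPELINE: List[Tuple[str, Set[str], Set[str]]] = [
    ("gen_settings", {"perceptron.onnx"}, {"settings.json"}),
    ("calibrate_settings", {"calibration.json", "perceptron.onnx", "settings.json"}, {"settings.json"}),
    ("compile_circuit", {"perceptron.onnx", "settings.json"}, {"compiled_model.json"}),
    ("get_srs", {"settings.json"}, {"srs"}),
    ("gen_witness", {"input_data.json", "compiled_model.json"}, {"witness.json"}),
    ("setup", {"compiled_model.json", "srs"}, {"vk.json", "pk.json"}),
    ("prove", {"witness.json", "compiled_model.json", "pk.json", "srs"}, {"proof.pf"}),
    ("verify", {"proof.pf", "settings.json", "vk.json", "srs"}, set()),
]


def missing_inputs(order: List[str], provided: Set[str]) -> Dict[str, Set[str]]:
    """For each step in `order`, report inputs not provided up front or produced earlier."""
    steps = {name: (reads, writes) for name, reads, writes in PIPELINE}
    available = set(provided)
    problems: Dict[str, Set[str]] = {}
    for name in order:
        reads, writes = steps[name]
        missing = reads - available
        if missing:
            problems[name] = missing
        available |= writes
    return problems


# Files the notebook writes itself before calling EZKL
USER_FILES = {"perceptron.onnx", "input_data.json", "calibration.json"}
NOTEBOOK_ORDER = [name for name, _, _ in PIPELINE]

print(missing_inputs(NOTEBOOK_ORDER, USER_FILES))  # {}
swapped = ["compile_circuit", "gen_settings", "calibrate_settings"] + NOTEBOOK_ORDER[3:]
print(missing_inputs(swapped, USER_FILES))  # {'compile_circuit': {'settings.json'}}
```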

# Token Trend Forecasting

We adapted code from one of [Giza's examples](https://github.com/gizatechxyz/Giza-Hub/tree/token_trend_action/awesome-giza-actions/trend_token_prediction). The idea is to train several models with different accuracies and compare the cost of proving the inference of each one. We first explain the feature extraction process in detail, since the original example does not cover it.

### Data

We will use [Giza's dataset hub](https://github.com/gizatechxyz/datasets), which contains a collection of datasets relevant to blockchain applications. These datasets are publicly available and can be loaded with the `DatasetsLoader` class from the `giza_datasets` package:

In [8]:
from giza_datasets import DatasetsLoader
import polars as pl

# Load the desired dataset
DatasetsLoader().load("tokens-daily-prices-mcap-volume")

# For pretty printing
pl.Config.set_tbl_hide_column_data_types(True)

Dataset read from cache.
Loading dataset tokens-daily-prices-mcap-volume from cache.


polars.config.Config

##### Token Daily Price Data
Contains daily price data (price, market capitalization, and volume) for a set of tokens (e.g., WBTC, WETH, etc.).

In [9]:
print("First few rows of the dataset:")
print(DatasetsLoader().load('tokens-daily-prices-mcap-volume').head(n = 3))

First few rows of the dataset:
Dataset read from cache.
Loading dataset tokens-daily-prices-mcap-volume from cache.
shape: (3, 5)
┌────────────┬─────────────┬────────────┬──────────────────┬───────┐
│ date       ┆ price       ┆ market_cap ┆ volumes_last_24h ┆ token │
╞════════════╪═════════════╪════════════╪══════════════════╪═══════╡
│ 2019-02-01 ┆ 3438.360403 ┆ 0.0        ┆ 20589.040403     ┆ WBTC  │
│ 2019-02-02 ┆ 3472.243307 ┆ 0.0        ┆ 12576.723906     ┆ WBTC  │
│ 2019-02-03 ┆ 3461.058341 ┆ 0.0        ┆ 1852.526033      ┆ WBTC  │
└────────────┴─────────────┴────────────┴──────────────────┴───────┘


##### Top APY per protocol
Contains the top Annual Percentage Yield (APY) for each protocol in the dataset.

In [10]:
print("First few rows of the dataset:")
print(DatasetsLoader().load('top-pools-apy-per-protocol').head(n = 3))

First few rows of the dataset:
Dataset read from cache.
Loading dataset top-pools-apy-per-protocol from cache.
shape: (3, 6)
┌────────────┬──────────┬─────┬─────────┬──────────────────┬──────────┐
│ date       ┆ tvlUsd   ┆ apy ┆ project ┆ underlying_token ┆ chain    │
╞════════════╪══════════╪═════╪═════════╪══════════════════╪══════════╡
│ 2022-02-28 ┆ 12808    ┆ 0.0 ┆ aave-v2 ┆ STETH            ┆ Ethereum │
│ 2022-03-01 ┆ 46045250 ┆ 0.0 ┆ aave-v2 ┆ STETH            ┆ Ethereum │
│ 2022-03-02 ┆ 90080754 ┆ 0.0 ┆ aave-v2 ┆ STETH            ┆ Ethereum │
└────────────┴──────────┴─────┴─────────┴──────────────────┴──────────┘


##### TVL per project tokens
Contains the Total Value Locked (TVL) for each project in the dataset.

In [11]:
print("First few rows of the dataset:")
print(DatasetsLoader().load('tvl-per-project-tokens').head(n = 3))

First few rows of the dataset:
Dataset read from cache.
Loading dataset tvl-per-project-tokens from cache.
shape: (3, 47)
┌───────┬──────┬────────┬──────┬───┬──────┬──────┬────────────┬─────────┐
│ 1INCH ┆ AAVE ┆ AAVE.E ┆ AMPL ┆ … ┆ YFI  ┆ ZRX  ┆ date       ┆ project │
╞═══════╪══════╪════════╪══════╪═══╪══════╪══════╪════════════╪═════════╡
│ null  ┆ null ┆ null   ┆ null ┆ … ┆ null ┆ null ┆ 2020-11-29 ┆ aave-v2 │
│ null  ┆ null ┆ null   ┆ null ┆ … ┆ null ┆ null ┆ 2020-11-30 ┆ aave-v2 │
│ null  ┆ null ┆ null   ┆ null ┆ … ┆ null ┆ null ┆ 2020-12-01 ┆ aave-v2 │
└───────┴──────┴────────┴──────┴───┴──────┴──────┴────────────┴─────────┘


In [12]:
import itertools

TOKEN = "WETH"
LAG = 1
DAYS = [1, 3, 7, 15, 30]

token_data = DatasetsLoader().load('tokens-daily-prices-mcap-volume')

# Filter the dataset to only include the target token
# Add a column with the labels for the target token
# Add columns with the price difference over the specified DAYS
# Expand the date column into day_of_week, month_of_year, and year
target_token_price_trend = token_data \
    .filter(pl.col("token") == TOKEN) \
    .with_columns(
        ((pl.col("price").shift(-LAG) - pl.col("price")) > 0).cast(pl.Int8).alias("label")
    ) \
    .with_columns(
        pl.col("price").diff(n = days).alias(f"price_diff_{days}_days")
        for days in DAYS
    ) \
    .with_columns(
        (pl.col("price") - pl.col("price").shift(days) > 0).cast(pl.Int8).alias(f"trend_{days}_days")
        for days in DAYS
    ) \
    .with_columns([
        pl.col("date").dt.weekday().alias("day"),
        pl.col("date").dt.month().alias("month"),
        pl.col("date").dt.year().alias("year")
    ])

print("First few rows of the dataset:")
print(target_token_price_trend.head(n = 3))
print("Number of rows in the dataset:", len(target_token_price_trend))

Dataset read from cache.
Loading dataset tokens-daily-prices-mcap-volume from cache.
First few rows of the dataset:
shape: (3, 19)
┌────────────┬─────────┬────────────┬──────────────────┬───┬───────────────┬─────┬───────┬──────┐
│ date       ┆ price   ┆ market_cap ┆ volumes_last_24h ┆ … ┆ trend_30_days ┆ day ┆ month ┆ year │
╞════════════╪═════════╪════════════╪══════════════════╪═══╪═══════════════╪═════╪═══════╪══════╡
│ 2018-02-14 ┆ 839.535 ┆ 0.0        ┆ 54776.5          ┆ … ┆ null          ┆ 3   ┆ 2     ┆ 2018 │
│ 2018-02-15 ┆ 947.358 ┆ 0.0        ┆ 111096.0         ┆ … ┆ null          ┆ 4   ┆ 2     ┆ 2018 │
│ 2018-02-16 ┆ 886.961 ┆ 0.0        ┆ 57731.7          ┆ … ┆ null          ┆ 5   ┆ 2     ┆ 2018 │
└────────────┴─────────┴────────────┴──────────────────┴───┴───────────────┴─────┴───────┴──────┘
Number of rows in the dataset: 2173
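The Polars expressions above are compact but easy to misread. The pure-Python sketch below spells out, on a plain list of daily prices, what the `label`, `price_diff_{n}_days` and `trend_{n}_days` columns contain, with `None` playing the role of Polars' nulls. The function names are hypothetical, and the sketch assumes one row per consecutive day, as the row-based `shift` and `diff` calls do.

```python
from typing import List, Optional

def next_day_label(prices: List[float], lag: int = 1) -> List[Optional[int]]:
    # 1 if the price `lag` rows ahead is strictly higher, 0 otherwise; None when unknown
    return [
        int(prices[i + lag] - prices[i] > 0) if i + lag < len(prices) else None
        for i in range(len(prices))
    ]

def price_diff(prices: List[float], n: int) -> List[Optional[float]]:
    # Change relative to the price n rows earlier, like pl.col("price").diff(n)
    return [prices[i] - prices[i - n] if i >= n else None for i in range(len(prices))]

def trend(prices: List[float], n: int) -> List[Optional[int]]:
    # 1 if the price rose over the last n rows, 0 otherwise; None for the first n rows
    return [None if d is None else int(d > 0) for d in price_diff(prices, n)]

prices = [10.0, 12.0, 11.0, 11.0, 15.0]
print(next_day_label(prices))  # [1, 0, 0, 1, None]
print(price_diff(prices, 3))   # [None, None, None, 1.0, 3.0]
print(trend(prices, 1))        # [None, 1, 0, 0, 1]
```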


In [None]:
token_data = DatasetsLoader().load('tokens-daily-prices-mcap-volume')
correlations = {}

# List all tokens in the dataset
tokens = token_data.get_column("token").unique().to_list()

# Calculate lagged price correlations for every ordered pair of tokens
for token_1, token_2 in itertools.permutations(tokens, r=2):
    
    # Filter the dataset and get the price and date columns
    token_1_data = token_data.filter(pl.col("token") == token_1) \
        .select(["date", "price"])
    token_2_data = token_data.filter(pl.col("token") == token_2) \
        .select(["date", "price"])
    
    # Join the datasets on the date column
    joined_data = token_1_data.join(token_2_data, on="date", suffix="_compare")

    # Nested dictionary: correlations[token_1][token_2][day] is the correlation between
    # token_1's price and token_2's price shifted down by `day` rows
    correlations.setdefault(token_1, {})[token_2] = {
        day: joined_data
            .with_columns(pl.col("price_compare").shift(day))
            .select(pl.corr("price", "price_compare").alias("correlation"))
            .get_column("correlation")[0]
        for day in DAYS
    }


#pprint.pprint(correlations)
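Each entry `correlations[token_1][token_2][day]` pairs `token_1`'s price on a given date with `token_2`'s price `day` rows earlier, because `shift(day)` moves `price_compare` down by `day` rows. The NumPy sketch below reproduces that pairing for two aligned series and checks it on synthetic data in which one series is an exact linear function of the other with a 3-day delay. `lagged_corr` is a hypothetical helper; unlike the Polars version, it simply drops the first `lag` positions instead of relying on null handling.

```python
import numpy as np

def lagged_corr(x, y, lag: int) -> float:
    """Pearson correlation between x[t] and y[t - lag] over the overlapping range."""
    if lag < 0:
        raise ValueError("lag must be non-negative")
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if lag > 0:
        # Align x[lag:] with y[:-lag], i.e. x at time t with y at time t - lag
        x, y = x[lag:], y[:-lag]
    return float(np.corrcoef(x, y)[0, 1])

rng = np.random.default_rng(1)
y = rng.normal(size=50)      # hypothetical "token_2" prices
x = np.empty(50)
x[:3] = rng.normal(size=3)
x[3:] = 2 * y[:-3] + 1       # "token_1" follows token_2 with a 3-day delay
print(lagged_corr(x, y, 3))  # close to 1
print(lagged_corr(x, y, 0))  # much weaker without the right lag
```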

In [14]:
K = 10

# Dataframe to store the final results
price_dataset = target_token_price_trend

# Retrieve the relevant data from the nested dictionary
target_token_correlations = correlations[TOKEN]

# Get the top K correlated tokens for each lag
top_k_correlated_tokens_by_lag = {
    lag: sorted(target_token_correlations.items(), key=lambda x: x[1][lag], reverse=True)[:K]
    for lag in DAYS
}

top_k_correlated_tokens_15_days = top_k_correlated_tokens_by_lag[15]

for token, _ in top_k_correlated_tokens_15_days:

    # Column names for the price differences and trends
    price_diff_columns = [f"price_diff_{token}_{days}" for days in DAYS]
    price_trend_columns = [f"price_trend_{token}_{days}" for days in DAYS]

    # Filter the dataset to only include the correlated token
    token_prices = token_data.filter(pl.col("token") == token)

    # Add columns with the price differences for each day
    token_prices = token_prices \
        .with_columns(
            pl.col("price").diff(n = days).alias(tag)
            for days, tag in zip(DAYS, price_diff_columns)
        ) \
        .with_columns([
            (pl.col("price") - pl.col("price").shift(days) > 0).cast(pl.Int8).alias(tag)
            for days, tag in zip(DAYS, price_trend_columns)
        ]) \
        .select(["date"] + price_diff_columns + price_trend_columns)

    # Join the dataset with the target token dataset
    price_dataset = price_dataset.join(token_prices, on="date", how="left")
    # Only the most correlated token is joined; remove `break` to join all K tokens
    break

print("First few rows of the dataset:")
print(price_dataset.head(n = 3))

First few rows of the dataset:
shape: (3, 29)
┌────────────┬─────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬───────────┐
│ date       ┆ price   ┆ market_ca ┆ volumes_l ┆ … ┆ price_tre ┆ price_tre ┆ price_tre ┆ price_tre │
│            ┆         ┆ p         ┆ ast_24h   ┆   ┆ nd_ETH_3  ┆ nd_ETH_7  ┆ nd_ETH_15 ┆ nd_ETH_30 │
╞════════════╪═════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡
│ 2018-02-14 ┆ 839.535 ┆ 0.0       ┆ 54776.5   ┆ … ┆ 0         ┆ 1         ┆ 0         ┆ 0         │
│ 2018-02-15 ┆ 947.358 ┆ 0.0       ┆ 111096.0  ┆ … ┆ 1         ┆ 1         ┆ 0         ┆ 0         │
│ 2018-02-16 ┆ 886.961 ┆ 0.0       ┆ 57731.7   ┆ … ┆ 1         ┆ 1         ┆ 0         ┆ 0         │
└────────────┴─────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴───────────┘
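The `sorted(...)[:K]` step assumes every correlation is a number. If two tokens share too little history, a correlation could come back missing or NaN (an assumption about the data, not something visible in the output above), and sorting would then raise (for `None`) or give an unreliable order (for NaN). The sketch below is a hedged alternative that uses `heapq.nlargest` and skips such entries; the function name, token names and values are made up.

```python
import heapq
import math
from typing import Dict, List, Optional, Tuple

def top_k_correlated(
    correlations: Dict[str, Dict[int, Optional[float]]], lag: int, k: int
) -> List[Tuple[str, float]]:
    # Keep only tokens with a usable (non-missing, non-NaN) correlation at this lag
    candidates = [
        (token, by_lag[lag])
        for token, by_lag in correlations.items()
        if by_lag.get(lag) is not None and not math.isnan(by_lag[lag])
    ]
    # The k largest correlations, highest first
    return heapq.nlargest(k, candidates, key=lambda item: item[1])

example = {
    "TOKEN_A": {15: 0.52},
    "TOKEN_B": {15: 0.91},
    "TOKEN_C": {15: None},
    "TOKEN_D": {15: float("nan")},
}
print(top_k_correlated(example, lag=15, k=2))  # [('TOKEN_B', 0.91), ('TOKEN_A', 0.52)]
```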


In [15]:
top_apy_per_protocol = DatasetsLoader().load("top-pools-apy-per-protocol")

START_DATE = pl.datetime(2022, 6, 1)

# Keep only pools whose underlying token contains the target token and
# identify each pool by its project, chain and underlying token
apy_df = top_apy_per_protocol \
    .filter(pl.col("underlying_token").str.contains(TOKEN)) \
    .with_columns(
        (pl.col("project") + "_" + pl.col("chain") + pl.col("underlying_token")).alias("project")
    ) \
    .drop(["underlying_token", "chain"])

# Keep only pools that already existed on or before START_DATE
unique_projects = apy_df \
    .filter(pl.col("date") <= START_DATE) \
    .select("project") \
    .unique()

apy_df_token = apy_df.join(
    unique_projects, 
    on="project", 
    how="inner"
)

unique_token_projects = apy_df_token.pivot(
    index="date",
    columns="project",
    values=["tvlUsd", "apy"]
)

print("First few rows of the dataset:")
print(unique_token_projects.head(n = 3))
print("Number of rows in the dataset:", len(unique_token_projects))

Dataset read from cache.
Loading dataset top-pools-apy-per-protocol from cache.
First few rows of the dataset:
shape: (3, 91)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ date      ┆ tvlUsd_pr ┆ tvlUsd_pr ┆ tvlUsd_pr ┆ … ┆ apy_proje ┆ apy_proje ┆ apy_proje ┆ apy_proj │
│           ┆ oject_aav ┆ oject_aav ┆ oject_aav ┆   ┆ ct_uniswa ┆ ct_uniswa ┆ ct_uniswa ┆ ect_year │
│           ┆ e-v2_Ethe ┆ e-v2_Poly ┆ e-v2_Aval ┆   ┆ p-v3_Arbi ┆ p-v3_Arbi ┆ p-v3_Ethe ┆ n-financ │
│           ┆ reumW…    ┆ gonWE…    ┆ anche…    ┆   ┆ trumW…    ┆ trumW…    ┆ reumW…    ┆ e_Ethere │
│           ┆           ┆           ┆           ┆   ┆           ┆           ┆           ┆ …        │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ 2022-02-1 ┆ 246215633 ┆ 560180650 ┆ 719972444 ┆ … ┆ null      ┆ null      ┆ null      ┆ null     │
│ 1         ┆ 5         ┆           ┆           ┆   ┆           ┆ 

In [16]:
tvl_df = DatasetsLoader().load("tvl-per-project-tokens") \
    .unique(subset=["date", "project"]) \
    .filter(pl.col("date") > START_DATE) 

tvl_per_projects_token = tvl_df[[TOKEN, "project", "date"]].pivot(
    index="date",
    columns="project",
    values=TOKEN
)

print("First few rows of the dataset:")
print(tvl_per_projects_token.head(n = 3))

Dataset read from cache.
Loading dataset tvl-per-project-tokens from cache.
First few rows of the dataset:
shape: (3, 20)
┌────────────┬──────────┬──────────┬──────┬───┬───────────┬────────────┬────────────┬──────────────┐
│ date       ┆ aave-v2  ┆ aave-v3  ┆ aura ┆ … ┆ sushiswap ┆ uniswap-v2 ┆ uniswap-v3 ┆ yearn-financ │
│            ┆          ┆          ┆      ┆   ┆           ┆            ┆            ┆ e            │
╞════════════╪══════════╪══════════╪══════╪═══╪═══════════╪════════════╪════════════╪══════════════╡
│ 2022-06-12 ┆ 1.1741e9 ┆ 3.1182e8 ┆ null ┆ … ┆ null      ┆ null       ┆ 1.0583e9   ┆ null         │
│ 2022-07-06 ┆ 6.5735e8 ┆ 1.2864e8 ┆ null ┆ … ┆ 3.9497e7  ┆ null       ┆ 7.8917e8   ┆ null         │
│ 2022-07-07 ┆ 6.3625e8 ┆ 1.3310e8 ┆ null ┆ … ┆ 4.1564e7  ┆ null       ┆ 7.7908e8   ┆ null         │
└────────────┴──────────┴──────────┴──────┴───┴───────────┴────────────┴────────────┴──────────────┘


In [17]:
# Join the datasets by the date column to create the final dataset
final_dataset = price_dataset \
    .join(tvl_per_projects_token, on="date", how="inner") \
    .join(unique_token_projects, on="date", how="inner")

# Drop unnecessary columns and rows with irrelevant data
# - columns token, market_cap, date, price and month: not used as features
# - rows with year < 2022: older historical data is not used
final_dataset = final_dataset \
    .filter(pl.col("year") >= 2022) \
    .drop(["token", "market_cap", "date", "price", "month"])

# Drop the last (most recent) row, whose next-day label may be missing
final_dataset = final_dataset.slice(0, len(final_dataset) - 1)

# Drop columns with too many missing values (more than 20% nulls)
THRESHOLD = 0.2
max_nulls = THRESHOLD * final_dataset.shape[0]
columns_to_keep = [
    col_name for col_name in final_dataset.columns
    if final_dataset[col_name].null_count() <= max_nulls
]
final_dataset = final_dataset.select(columns_to_keep)

# Split the dataset into features and labels
features = final_dataset.drop("label")
labels = final_dataset["label"]

# Standardize every feature column (statistics computed over all rows) and
# fill missing values with the column mean, which maps them to 0
for col in features.columns:
    mean_val = features[col].mean()
    std_dev = features[col].std() or 1  # avoid dividing by a zero (or missing) std
    features = features.with_columns(((features[col].fill_null(mean_val) - mean_val) / std_dev).alias(col))

print("First few rows of the dataset:")
print(features.head(n = 3))

First few rows of the dataset:
shape: (3, 138)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ date      ┆ price     ┆ market_ca ┆ volumes_l ┆ … ┆ apy_proje ┆ apy_proje ┆ apy_proje ┆ apy_proj │
│           ┆           ┆ p         ┆ ast_24h   ┆   ┆ ct_uniswa ┆ ct_uniswa ┆ ct_uniswa ┆ ect_year │
│           ┆           ┆           ┆           ┆   ┆ p-v3_Arbi ┆ p-v3_Arbi ┆ p-v3_Ethe ┆ n-financ │
│           ┆           ┆           ┆           ┆   ┆ trumW…    ┆ trumW…    ┆ reumW…    ┆ e_Ethere │
│           ┆           ┆           ┆           ┆   ┆           ┆           ┆           ┆ …        │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ 2022-06-0 ┆ 1828.2324 ┆ 0.0       ┆ 1.6046e9  ┆ … ┆ 4.87082   ┆ 27.33455  ┆ 0.0       ┆ 0.55534  │
│ 2         ┆ 82        ┆           ┆           ┆   ┆           ┆           ┆           ┆          │
│ 2022-06-0 ┆ 1832.6794 ┆ 0.0       ┆ 1.2320
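The normalization loop replaces each null with the column mean and then standardizes the column, so imputed values end up at exactly 0. The NumPy sketch below applies the same transformation to a 2-D array in which `NaN` stands in for nulls. It uses `ddof=1` to match the sample standard deviation that Polars' `std()` computes by default and, like the loop, divides by 1 for constant columns. `standardize_columns` is a hypothetical helper.

```python
import numpy as np

def standardize_columns(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x, dtype=float)
    mean = np.nanmean(x, axis=0)             # column means, ignoring NaN
    std = np.nanstd(x, axis=0, ddof=1)       # sample std, ignoring NaN
    std = np.where(std == 0, 1.0, std)       # constant columns: divide by 1
    filled = np.where(np.isnan(x), mean, x)  # mean imputation
    return (filled - mean) / std

data = np.array([
    [1.0, 5.0],
    [np.nan, 5.0],
    [3.0, 5.0],
])
print(standardize_columns(data))
```

Because the means and standard deviations are computed over all rows before the train/test split, test rows influence the scaling; computing them on the training rows only and reusing them for the test rows would avoid that leakage.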

## Model Training

#### Splitting the dataset

We convert the dataframes to PyTorch tensors and split the dataset into training and testing sets. As usual, we shuffle the data and choose a reasonable split ratio (here, 75% training and 25% testing):

In [81]:
# Convert the dataset to a PyTorch tensor
features_tensor = torch.tensor(features.to_numpy(), dtype=torch.float32)
labels_tensor = torch.tensor(labels.to_numpy(), dtype=torch.float32)

# Get a random permutation of the indices
indices = torch.randperm(len(features_tensor))
train_indices = indices[:int(0.75 * len(features_tensor))]
test_indices = indices[int(0.75 * len(features_tensor)):]
train_features, train_labels = features_tensor[train_indices], labels_tensor[train_indices]
test_features, test_labels = features_tensor[test_indices], labels_tensor[test_indices]


#### Model Definition

We start with a small multi-layer perceptron (MLP) with a single hidden layer of 32 units and a ReLU activation. The final layer uses a sigmoid activation to output the probability that the token price increases. We also define the loss function (mean squared error) and the optimizer (Adam):

In [55]:
model = nn.Sequential(
    nn.Linear(len(features.columns), 32),
    nn.ReLU(),
    nn.Linear(32, 1),
    nn.Sigmoid()
)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

#### Training the Model

We train the model using the training dataset and evaluate its performance on the testing dataset:

In [56]:
model.train()
for epoch in range(500):
    optimizer.zero_grad()
    output = model(train_features)
    loss = criterion(output, train_labels.unsqueeze(1))
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    output = model(test_features)
    # Threshold the sigmoid output at 0.5 to obtain binary predictions
    predicted = (output.squeeze(1) > 0.5).float()
    correct = (predicted == test_labels).sum().item()
    total = test_labels.size(0)
    print(f'Accuracy: {100 * correct / total:.2f}%')

Accuracy: 67.33%


## Benchmarking Proof Generation

Now that we have trained the model, we can generate a proof of its inference using EZKL. We first convert the model to the ONNX format and set up the proof generation process as described in the previous sections:

In [131]:
def to_onnx(model, input_sample, onnx_file):
    torch.onnx.export(
        model,
        input_sample,
        onnx_file,
        input_names = ['input'],             # Input and output labels to appear in the ONNX graph 
        output_names = ['output'],
        dynamic_axes={
            'input' : {0 : 'batch_size'},    # Variable length axes
            'output' : {0 : 'batch_size'}
        }
    )

input_sample = next(iter(test_features)).unsqueeze(0)
to_onnx(model, input_sample, "perceptron.onnx")

We also define a few helper functions for the benchmarking process:

In [59]:
from typing import List, Tuple, Union
import contextlib
import os
import statistics
import time

import torch
import sklearn.base

def prove():
    # Generate a proof from the current witness, compiled circuit and proving key
    ezkl.prove(
        WITNESS,
        COMPILED_MODEL,
        PK,
        PROOF,
        "single",
    )

def verify():
    # Check the proof against the settings and the verifying key
    assert ezkl.verify(
        PROOF,
        SETTINGS,
        VK,
    )

def bench_ezkl_single_round(
    model: Union[torch.nn.Module, sklearn.base.BaseEstimator],
    model_onnx_file: str, 
    sample: torch.Tensor, 
) -> Tuple[float, float, float]:
    # Time each stage separately with a monotonic, high-resolution clock
    start = time.perf_counter()
    setup(model, model_onnx_file, sample)
    setup_time = time.perf_counter() - start

    start = time.perf_counter()
    prove()
    prove_time = time.perf_counter() - start

    start = time.perf_counter()
    verify()
    verify_time = time.perf_counter() - start

    return setup_time, prove_time, verify_time

def bench_ezkl(
    model: Union[torch.nn.Module, sklearn.base.BaseEstimator],
    model_onnx_file: str,
    sample: torch.Tensor,
    rounds: int = 1
) -> Tuple[List[float], List[float], List[float]]:
    setup_times, prove_times, verify_times = [], [], []
    for _ in range(rounds):
        # Silence Python-level stderr output (logs written directly by
        # EZKL's native code may still appear)
        with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
            s, p, v = bench_ezkl_single_round(model, model_onnx_file, sample)
        setup_times.append(s)
        prove_times.append(p)
        verify_times.append(v)

    # Report the mean and population standard deviation over all rounds
    for name, times in [("Setup", setup_times), ("Prover", prove_times), ("Verifier", verify_times)]:
        print(f"{name} time: {str(statistics.mean(times))[:5]} ± {str(statistics.pstdev(times))[:5]} [s]")

    return setup_times, prove_times, verify_times

Let's benchmark the proof generation process for the MLP model we trained earlier:

In [60]:
_ = bench_ezkl(
    model,
    "perceptron.onnx",
    input_sample,
    rounds=10
)

Setup time: 6.129 ± 0.456 [s]
Prover time: 8.216 ± 0.429 [s]
Verifier time: 0.014 ± 0.001 [s]


#### Accuracy vs. Proof Cost

Let us now define a more complex model that slightly increases the accuracy of the predictions at the cost of a longer proof generation time. We use a deeper MLP with three hidden layers of 128, 64 and 32 units:

In [130]:
model = nn.Sequential(
    nn.Linear(len(features.columns), 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
    nn.Sigmoid()
)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(500):
    optimizer.zero_grad()
    output = model(train_features)
    loss = criterion(output, train_labels.unsqueeze(1))
    loss.backward()
    optimizer.step()

model.eval()
with torch.no_grad():
    output = model(test_features)
    # Threshold the sigmoid output at 0.5 to obtain binary predictions
    predicted = (output.squeeze(1) > 0.5).float()
    correct = (predicted == test_labels).sum().item()
    total = test_labels.size(0)
    print(f'Accuracy: {100 * correct / total:.2f}%')

Accuracy: 73.33%


In [132]:
to_onnx(model, input_sample, "perceptron.onnx")

_ = bench_ezkl(
    model,
    "perceptron.onnx",
    input_sample,
    rounds=10
)

Setup time: 11.49 ± 0.727 [s]
Prover time: 15.09 ± 0.341 [s]
Verifier time: 0.021 ± 0.000 [s]
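One rough way to compare the two models is by parameter count: an `nn.Linear(n_in, n_out)` layer has `n_in * n_out` weights plus `n_out` biases, and the ReLU and Sigmoid layers have none. The sketch below counts parameters for both architectures as a function of the number of input features. The exact feature count depends on the columns kept above, so `n_features = 100` is only a placeholder, and the helper names are hypothetical. Parameter count is just a proxy for circuit size, so it should not be expected to predict the measured setup and prover times exactly.

```python
from typing import List, Tuple

def count_linear_params(layers: List[Tuple[int, int]]) -> int:
    # Each (n_in, n_out) linear layer has n_in * n_out weights and n_out biases
    return sum(n_in * n_out + n_out for n_in, n_out in layers)

def small_mlp(n_features: int) -> List[Tuple[int, int]]:
    return [(n_features, 32), (32, 1)]

def deep_mlp(n_features: int) -> List[Tuple[int, int]]:
    return [(n_features, 128), (128, 64), (64, 32), (32, 1)]

n_features = 100  # placeholder; the notebook uses len(features.columns)
print(count_linear_params(small_mlp(n_features)))  # 3265
print(count_linear_params(deep_mlp(n_features)))   # 23297
```

In the notebook itself, `sum(p.numel() for p in model.parameters())` gives the same count directly for a trained `nn.Sequential` model.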
