## Define an LTSF-Linear Model

In [1]:
import argparse
import torch
import random
import numpy as np
import onnx

import sys
sys.path.append("..")
from models.DLinear import Model as DLinear

from data_provider.data_factory import data_provider

parser = argparse.ArgumentParser(description='Autoformer & Transformer family for Time Series Forecasting')

# basic config
parser.add_argument('--is_training', type=int, default=1, help='status')
parser.add_argument('--train_only', type=bool, default=False, help='perform training on full input dataset without validation and testing')
parser.add_argument('--model_id', type=str, default='test', help='model id')
parser.add_argument('--model', type=str, default='Autoformer',
                    help='model name, options: [Autoformer, Informer, Transformer]')

# data loader
parser.add_argument('--data', type=str, default='custom', help='dataset type')
parser.add_argument('--root_path', type=str, default='../dataset/', help='root path of the data file')
parser.add_argument('--data_path', type=str, default='electricity.csv', help='data file')
parser.add_argument('--features', type=str, default='S',
                    help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
parser.add_argument('--freq', type=str, default='h',
                    help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

# forecasting task
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
parser.add_argument('--label_len', type=int, default=48, help='start token length')
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')


# DLinear
parser.add_argument('--individual', action='store_true', default=False, help='DLinear: a linear layer for each variate(channel) individually')
# Formers 
parser.add_argument('--embed_type', type=int, default=0, help='0: default 1: value embedding + temporal embedding + positional embedding 2: value embedding + temporal embedding 3: value embedding + positional embedding 4: value embedding')
parser.add_argument('--enc_in', type=int, default=21, help='encoder input size') # DLinear with --individual, use this hyperparameter as the number of channels
parser.add_argument('--dec_in', type=int, default=21, help='decoder input size')
parser.add_argument('--c_out', type=int, default=21, help='output size')
parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
parser.add_argument('--factor', type=int, default=1, help='attn factor')
parser.add_argument('--distil', action='store_false',
                    help='whether to use distilling in encoder, using this argument means not using distilling',
                    default=True)
parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
parser.add_argument('--embed', type=str, default='timeF',
                    help='time features encoding, options:[timeF, fixed, learned]')
parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')

# optimization
parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers')
parser.add_argument('--itr', type=int, default=1, help='experiments times')
parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size of train input data')
parser.add_argument('--patience', type=int, default=3, help='early stopping patience')
parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
parser.add_argument('--des', type=str, default='test', help='exp description')
parser.add_argument('--loss', type=str, default='mse', help='loss function')
parser.add_argument('--lradj', type=str, default='type1', help='adjust learning rate')
parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)

# GPU
parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
parser.add_argument('--gpu', type=int, default=0, help='gpu')
parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False)
parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multile gpus')
parser.add_argument("--seed", type=int, default=2021, help="random seed")
parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')

args = parser.parse_args([])

random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

model = onnx.load('../checkpoints/test_DLinear_custom_ftS_sl96_ll48_pl96_dm512_nh8_el2_dl1_df2048_fc1_ebtimeF_dtTrue_test_0/checkpoint.onnx')

# Check that the model is well-formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
onnx.helper.printable_graph(model.graph)


'graph main_graph (\n  %input[FLOAT, batch_sizex96x1]\n) initializers (\n  %Linear_Seasonal.bias[FLOAT, 96]\n  %Linear_Trend.bias[FLOAT, 96]\n  %onnx::MatMul_52[FLOAT, 96x96]\n  %onnx::MatMul_53[FLOAT, 96x96]\n) {\n  %/decompsition/moving_avg/Constant_output_0 = Constant[value = <Tensor>]()\n  %/decompsition/moving_avg/Constant_1_output_0 = Constant[value = <Tensor>]()\n  %/decompsition/moving_avg/Constant_2_output_0 = Constant[value = <Tensor>]()\n  %/decompsition/moving_avg/Constant_3_output_0 = Constant[value = <Tensor>]()\n  %/decompsition/moving_avg/Slice_output_0 = Slice(%input, %/decompsition/moving_avg/Constant_1_output_0, %/decompsition/moving_avg/Constant_2_output_0, %/decompsition/moving_avg/Constant_output_0, %/decompsition/moving_avg/Constant_3_output_0)\n  %onnx::Tile_13 = Constant[value = <Tensor>]()\n  %/decompsition/moving_avg/Constant_4_output_0 = Constant[value = <Tensor>]()\n  %/decompsition/moving_avg/ConstantOfShape_output_0 = ConstantOfShape[value = <Tensor>](%/d

## ZK Inference

### Prepare Model

#### Define Files Path

In [2]:
import ezkl
import tracemalloc
import os
import json
from timeit import default_timer as timer

model_path = os.path.join('network.onnx')
compiled_model_path = os.path.join('network.ezkl')
pk_path = os.path.join('test.pk')
vk_path = os.path.join('test.vk')
settings_path = os.path.join('settings.json')
srs_path = os.path.join('kzg.srs')
witness_path = os.path.join('witness.json')
data_path = os.path.join('input.json')
proof_path = os.path.join('test.pf')
sol_code_path = os.path.join('verify.sol')
abi_path = os.path.join('verify.abi')

#### Convert Model to ONNX

In [39]:
# Model was trained by '/mnt/LTSF-Linear/scripts/EXP-LongForecasting/Linear/electricity.sh' and 
# stored into the checkpoint state 'checkpoint.pth' in the '/mnt/LTSF-Linear/checkpoints' folder.
# Now we need to export the onnx file from this state file with model inputs.

seq_len = 720
pred_len = 24
target = 40

configs = Configs(seq_len)
circuit = Model(configs)
check_point_model = '../checkpoints/checkpoint_{}_24_tg{}.pth'.format(seq_len, target)
state_dict = torch.load(check_point_model)
circuit.load_state_dict(state_dict)
print(circuit)


# After training, export to onnx (network.onnx) and create a data file (input.json)
x = 10*torch.rand(1,*[seq_len, 1], requires_grad=True)

# Flips the neural net into inference mode
circuit.eval()

# Export the model
torch.onnx.export(circuit,               # model being run
                      x,                   # model input (or a tuple for multiple inputs)
                      model_path,            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
                      dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                    'output' : {0 : 'batch_size'}})

data_array = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file:
json.dump( data, open(data_path, 'w' ))

input_size = os.stat(data_path).st_size / 1024
onnx_size = os.stat(model_path).st_size / 1024
print("Input.json size: {}KB".format(input_size))
print("network.onnx size: {}KB".format(onnx_size))

Model(
  (Linear): Linear(in_features=720, out_features=24, bias=True)
)
Input.json size: 13.5732421875KB
network.onnx size: 68.064453125KB


### Setup

In [40]:
!ezkl table -M network.onnx

[1;34m[[0m[1;34m*[0m[1;34m][0m [0s, ezkl] - [1;37m
[1;37m | [0m 
[1;37m | [0m         ███████╗███████╗██╗  ██╗██╗
[1;37m | [0m         ██╔════╝╚══███╔╝██║ ██╔╝██║
[1;37m | [0m         █████╗    ███╔╝ █████╔╝ ██║
[1;37m | [0m         ██╔══╝   ███╔╝  ██╔═██╗ ██║
[1;37m | [0m         ███████╗███████╗██║  ██╗███████╗
[1;37m | [0m         ╚══════╝╚══════╝╚═╝  ╚═╝╚══════╝
[1;37m | [0m 
[1;37m | [0m         -----------------------------------------------------------
[1;37m | [0m         Easy Zero Knowledge for Layers.
[1;37m | [0m         -----------------------------------------------------------
[1;37m | [0m 
[1;37m | [0m         [0m
[1;34m[[0m[1;34m*[0m[1;34m][0m [0s, ezkl] - [1;37mcommand: 
[1;37m | [0m  [1m{[0m[1;37m
[1;37m | [0m   [1;34m"[0m[1;37m[1;34mcommand[0m[1;37m[1;34m"[0m[1;37m: [1m{[0m[1;37m
[1;37m | [0m     [1;34m"[0m[1;37m[1;34mTable[0m[1;37m[1;34m"[0m[1;37m: [1m{[0m[1;37m
[1;37m | [0m       [1;34m"[

In [41]:
# Setup is performed by the application developer, who then deploys the resulting artifacts to production.

!RUST_LOG=trace
# TODO: Dictionary outputs
# Before setup can run, the settings need to be generated with gen-settings
#  and optionally calibrate-settings, and the model must be compiled.

tracemalloc.start()
start = timer()

res = ezkl.gen_settings(model_path, settings_path)
assert res == True

res = await ezkl.calibrate_settings(data_path, model_path, settings_path, "resources")
assert res == True

res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

# srs path
res = ezkl.get_srs(srs_path, settings_path)

# HERE WE SETUP THE CIRCUIT PARAMS
# WE GOT KEYS
# WE GOT CIRCUIT PARAMETERS
# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK

res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        srs_path,
    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)

end = timer()
print("time used: {} seconds.".format(end - start))

snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
print("[ Top 10 ]")
for stat in top_stats[:10]:
    print(stat)

spawning module 2
spawning module 2


time used: 16.07956880517304 seconds.
[ Top 10 ]
/usr/lib/python3.10/ast.py:50: size=2278 KiB, count=46892, average=50 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:265: size=429 KiB, count=5833, average=75 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:236: size=340 KiB, count=3347, average=104 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:263: size=231 KiB, count=690, average=343 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:227: size=151 KiB, count=5, average=30.3 KiB
/usr/lib/python3.10/tracemalloc.py:558: size=81.7 KiB, count=1544, average=54 B
/usr/lib/python3.10/sre_compile.py:804: size=75.4 KiB, count=125, average=617 B
/usr/lib/python3.10/tracemalloc.py:67: size=54.6 KiB, count=874, average=64 B
/mnt/zkpet/venv/lib/python3.10/site-packages/IPython/core/compilerop.py:174: size=43.5 KiB, count=450, average=99 B
/usr/lib/python3.10/tracemalloc.py:505: size=41.0 KiB, count=741, average=5

### Prove

In [42]:
# Prove, invoked with ezkl prove at the cli or ezkl.prove() in Python, is called by the prover, often on the client.

# the witness data for the claim: an (input, output) pair (x,y) such that model(input) = output.
# this pair can be produced from x using the gen-witness command.
# now generate the witness file 

tracemalloc.start()
start = timer()
res = ezkl.gen_witness(
        data_path, 
        compiled_model_path, 
        witness_path
      )
# assert os.path.isfile(witness_path)

# GENERATE A PROOF

res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
        srs_path,
        "single"
    )

# print(res)
# assert os.path.isfile(proof_path)

end = timer()
prove_time = end - start
print("time used: {} seconds".format(prove_time))

snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
print("[ Top 10 ]")
for stat in top_stats[:10]:
    print(stat)

proof_size = os.stat(proof_path).st_size / 1024
print("{} size: {}KB".format(proof_path, proof_size))

spawning module 2


time used: 2.98009037040174 seconds
[ Top 10 ]
/usr/lib/python3.10/ast.py:50: size=2278 KiB, count=46892, average=50 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:265: size=429 KiB, count=5833, average=75 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:236: size=340 KiB, count=3347, average=104 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:263: size=231 KiB, count=690, average=343 B
/mnt/zkpet/venv/lib/python3.10/site-packages/executing/executing.py:227: size=151 KiB, count=5, average=30.3 KiB
/usr/lib/python3.10/tracemalloc.py:558: size=81.5 KiB, count=1536, average=54 B
/usr/lib/python3.10/sre_compile.py:804: size=75.4 KiB, count=125, average=617 B
/usr/lib/python3.10/tracemalloc.py:67: size=54.8 KiB, count=876, average=64 B
/mnt/zkpet/venv/lib/python3.10/site-packages/IPython/core/compilerop.py:174: size=43.5 KiB, count=450, average=99 B
/usr/lib/python3.10/tracemalloc.py:505: size=42.8 KiB, count=774, average=57 

### Verify

#### VERIFY off-chain

In [43]:
start = timer()
res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,
        srs_path,
    )

assert res == True
print("verified")
end = timer()
print("time used: {} seconds".format(end - start))

verified
time used: 0.02656630612909794 seconds


#### VERIFY on-chain

In [44]:
# Create verifier contract
res = ezkl.create_evm_verifier(
        vk_path,
        srs_path,
        settings_path,
        sol_code_path,
        abi_path,
    )
verifier_size = os.stat(sol_code_path).st_size / 1024
print("{} size: {}KB".format(sol_code_path, verifier_size))

verify.sol size: 53.421875KB


In [45]:
# Deploy the verifier contract onchain
sol_code_path = os.path.join("verify.sol")
address_path = os.path.join('contractAddr.txt')
# assuming anvil is running
res = ezkl.deploy_evm(
    address_path,
    sol_code_path
)

In [46]:
with open(address_path, 'r') as f:
  addr = f.readline()

In [47]:
# # Verify proof onchain
# res = ezkl.verify_evm(
#     proof_path,
#     addr
# )

## Stats

In [48]:
input_size = os.stat(data_path).st_size / 1024
onnx_size = os.stat(model_path).st_size / 1024
print("{} size: {}KB".format(data_path, input_size))
print("{} size: {}KB".format(model_path, onnx_size))
proof_size = os.stat(proof_path).st_size / 1024
print("{} size: {}KB".format(proof_path, proof_size))
print("prove time used: {} seconds".format(prove_time))
verifier_size = os.stat(sol_code_path).st_size / 1024
print("{} size: {}KB".format(sol_code_path, verifier_size))

input.json size: 13.5732421875KB
network.onnx size: 68.064453125KB
test.pf size: 20.458984375KB
prove time used: 2.98009037040174 seconds
verify.sol size: 53.421875KB
