# Traditional ML Models for ZKML: Decision Tree

*In this series of tutorials, we delve into the world of traditional machine learning models for ZKML. Despite the hype surrounding advanced AI techniques, traditional ML models often offer superior performance, or sufficiently robust results, for specific applications. This is particularly true for ZKML use cases, where computational proof costs can be a critical factor. We aim to equip you with guides on how to implement machine learning algorithms suitable for applications on the Giza platform. This includes practical steps for converting your scikit-learn models to the ONNX format, transpiling them to Orion Cairo, and deploying inference endpoints for prediction in AI Actions.*

In this tutorial, you will learn how to use the Giza tools with a Decision Tree model.

## Before Starting
Before we start, ensure that you have installed the Giza stack, created a user, and logged in.

```
! giza users login # Login to your account
```

## Create and Train a Decision Tree Model
We'll start by creating a simple decision tree model with scikit-learn and training it on the iris dataset. We will then use the [Hummingbird](https://github.com/microsoft/hummingbird) library to convert the model into a PyTorch model.

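```python
import numpy as np
from hummingbird.ml import convert
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load the iris dataset and cast features to float32 to match the tensor dtype used at export
iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)

# Hold out a test split and fit a decision tree (fixed seeds for reproducibility)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
clr = DecisionTreeClassifier(random_state=42)
clr.fit(X_train, y_train)

# Convert the fitted tree into a PyTorch module with Hummingbird
model = convert(clr, "torch", X_test[:1]).model
```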
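Decision trees suit ZKML partly because a prediction only walks a single root-to-leaf path of threshold comparisons. To make that concrete, here is a small illustrative sketch, separate from the Giza workflow, that traverses a fitted scikit-learn tree by hand through its `tree_` arrays and checks the result against `predict`. The helper `predict_by_traversal` and the `X_demo_*` names are our own (chosen so they don't overwrite the tutorial's variables), and the sketch assumes a single-output classifier like the one trained above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def predict_by_traversal(clf, X):
    """Hypothetical helper: predict by walking the fitted tree with threshold comparisons."""
    tree = clf.tree_
    preds = np.empty(X.shape[0], dtype=np.int64)
    comparisons = np.zeros(X.shape[0], dtype=np.int64)
    for i, row in enumerate(X):
        node = 0  # start at the root
        # In scikit-learn, a leaf has children_left == -1
        while tree.children_left[node] != -1:
            comparisons[i] += 1
            # Samples with feature <= threshold go left, others go right
            if row[tree.feature[node]] <= tree.threshold[node]:
                node = tree.children_left[node]
            else:
                node = tree.children_right[node]
        # tree.value[node][0] holds the per-class weights stored at the leaf
        preds[i] = clf.classes_[np.argmax(tree.value[node][0])]
    return preds, comparisons


# Train a separate demo tree (float32 features, as in the tutorial)
demo = load_iris()
X_demo = demo.data.astype(np.float32)
X_demo_train, X_demo_test, y_demo_train, y_demo_test = train_test_split(
    X_demo, demo.target, random_state=0
)
clf = DecisionTreeClassifier(random_state=0).fit(X_demo_train, y_demo_train)

manual, n_cmp = predict_by_traversal(clf, X_demo_test)
print("matches scikit-learn:", np.array_equal(manual, clf.predict(X_demo_test)))
print("max comparisons per sample:", n_cmp.max(), "| tree depth:", clf.get_depth())
```

The number of comparisons per sample is bounded by the tree depth, which is one reason shallow trees tend to be inexpensive to prove.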

## Convert the Model to ONNX Format
Giza only supports ONNX models, so you'll need to convert the model to ONNX format. This can be done after training.

```python
import torch.onnx

input_sample = torch.from_numpy(X_test[:1])

# Specify the path to save the ONNX model
onnx_model_path = "decision_tree.onnx"

# Export the model
torch.onnx.export(model,
                  input_sample,
                  onnx_model_path,     # where to save the model
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=17,    # the ONNX opset version to target
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],   # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},  # allow a variable batch size
                                'output': {0: 'batch_size'}})
```

## Transpile your model to Orion Cairo

We will use the Giza CLI to transpile our ONNX model to Orion Cairo.

```
! giza transpile decision_tree.onnx --output-path verifiable_dt
```

## Deploy an inference endpoint

Now that our model is transpiled to Cairo, we can deploy an endpoint to run verifiable inferences. We will use the Giza CLI again to deploy and run the endpoint.
Be sure to replace the `--model-id` and `--version-id` values with the IDs provided during transpilation.

```
! giza endpoints deploy --model-id 569 --version-id 2
```

## Run a verifiable inference in AI Actions

To run a verifiable inference, you could call the endpoint URL obtained after deployment directly. However, this approach requires you to serialize the input for the Cairo program manually and to handle the deserialization of the output yourself. To make this process more user-friendly and keep you within a Python environment, we've introduced AI Actions, a Python SDK designed to facilitate the creation of ML workflows and the execution of verifiable predictions. When you initiate a prediction, the SDK automatically retrieves the endpoint URL you deployed earlier, converts your input into a Cairo-compatible format, executes the prediction, and then converts the output back into a NumPy object. More info about [AI Actions here.](https://actions.gizatech.xyz/)
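Part of that conversion involves fixed-point numbers such as the `FP16x16` type that appears in the prediction code of this tutorial. Assuming Orion's FP16x16 stores each value as a magnitude scaled by 2^16 plus a separate sign flag, the number-format side of the conversion roughly looks like the sketch below. This is not the SDK's actual serialization code: `to_fp16x16` and `from_fp16x16` are hypothetical helpers, and the round-half-to-even choice (`np.rint`) is our assumption.

```python
import numpy as np

FRACTIONAL_BITS = 16
SCALE = 1 << FRACTIONAL_BITS  # 2**16 = 65536


def to_fp16x16(values):
    """Hypothetical encoder: floats -> (magnitude, sign) with 16 fractional bits."""
    arr = np.asarray(values, dtype=np.float64)
    mag = np.rint(np.abs(arr) * SCALE).astype(np.int64)  # scaled, rounded magnitude
    sign = arr < 0  # True for negative values
    return mag, sign


def from_fp16x16(mag, sign):
    """Hypothetical decoder: (magnitude, sign) -> floats."""
    vals = np.asarray(mag, dtype=np.float64) / SCALE
    return np.where(sign, -vals, vals)


mag, sign = to_fp16x16([1.5, -0.25, 5.1])
print(mag.tolist(), sign.tolist())  # [98304, 16384, 334234] [False, True, False]
print(from_fp16x16(mag, sign))      # close to the inputs, within 1/65536
```

Values like 1.5 and -0.25 round-trip exactly, while 5.1 picks up a small quantization error, which is why decoded outputs may differ slightly from the floating-point model.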

First, ensure you have an AI Actions workspace created. This step grants access to a user-friendly UI dashboard, enabling you to monitor and manage workflows with ease.

```
! giza workspaces get

# 🚨 If you haven't set up a workspace yet, you can create one with the command below:
# ! giza workspaces create
```

Now let's run a verifiable inference with AI Actions. To design your workflow in AI Actions, define your tasks with the `@task` decorator and then orchestrate them in a function decorated with `@action`. You can track the progress of your workflow via the workspace URL provided earlier.

```python
from giza_actions.action import action
from giza_actions.model import GizaModel
from giza_actions.task import task

MODEL_ID = 569  # Update with your model ID
VERSION_ID = 2  # Update with your version ID


@task(name="PredictDTModel")
def prediction(input_data, model_id, version_id):
    model = GizaModel(id=model_id, version=version_id)

    result = model.predict(
        input_feed={'input': input_data},
        # A decision tree classifier always returns (labels, probabilities), hence this output dtype
        custom_output_dtype="(Tensor<i32>, Tensor<FP16x16>)",
    )
    print(result)

    return result


@action(name="ExecuteCairoDT", log_prints=True)
def execution():
    # The input data type should match the model's expected input (float32, 4 features)
    input_data = input_sample.numpy()

    result = prediction(input_data, MODEL_ID, VERSION_ID)

    return result


execution()
```
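The `custom_output_dtype` used in the prediction task declares two outputs: integer class labels and fixed-point class probabilities. Assuming the prediction result can be unpacked into those two arrays (inspect the printed `result` to confirm the exact structure your SDK version returns), a small helper like the hypothetical `describe_prediction` below maps them back to iris species names. It is demonstrated here with made-up arrays, not real endpoint output.

```python
import numpy as np
from sklearn.datasets import load_iris


def describe_prediction(labels, probabilities, class_names):
    """Hypothetical helper: pair each predicted label with its class name and probability."""
    labels = np.asarray(labels).ravel()
    probabilities = np.asarray(probabilities, dtype=np.float64).reshape(len(labels), -1)
    return [
        (str(class_names[int(lbl)]), float(probs[int(lbl)]))
        for lbl, probs in zip(labels, probabilities)
    ]


# Made-up outputs standing in for (labels, probabilities)
names = load_iris().target_names
fake_labels = np.array([0, 2])
fake_probs = np.array([[1.0, 0.0, 0.0], [0.0, 0.25, 0.75]])
print(describe_prediction(fake_labels, fake_probs, names))
# [('setosa', 1.0), ('virginica', 0.75)]
```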

## Download the proof

Initiating a verifiable inference sets off a proving job on our server, sparing you the complexities of installing and configuring the prover yourself. Upon completion, you can download your proof.

First, let's check the status of the proving job to make sure it has completed.

🚨 Remember to replace the `--endpoint-id` and `--proof-id` values with the IDs assigned to you during this tutorial.

```
! giza endpoints get-proof --endpoint-id 36 --proof-id 3bb53193c43048b7b47abfefe32b569a
```

Once the proof is ready, you can download it.

```
! giza endpoints download-proof --endpoint-id 36 --proof-id 3bb53193c43048b7b47abfefe32b569a --output-path zkdt.proof
```