# Evaluate Python UDF Performance (Iris)

In [1]:
import numpy as np
import pandas as pd

import time
import json

import duckdb
from duckdb.typing import *
from duckdb import typing

import torch
import torch.nn as nn
import pyarrow as pa

from pprint import pprint

### Config

In [2]:
times = {}

In [3]:
table_name = 'iris'
times['dataset'] = 'iris'
times['pipeline'] = 'python_udf_vec|d5-w512'

ori_workload = False
drop_table = False

if not ori_workload:
    workload = 20000000

table_name

'iris'

### Load iris data

In [4]:
times['workload'] = workload

### Load the compiled model

In [5]:
con = duckdb.connect("../test.db")

In [6]:
def load_iris():
    model_path = "/Users/udeshuk/Developer/mldb/models/iris_d5_w512.pt"
    load_iris.model = torch.jit.load(model_path)
    return True

load_iris.model = None

st = time.perf_counter_ns()

con.create_function("load_iris", load_iris, [], BOOLEAN, side_effects=True)
res = con.sql("SELECT load_iris()").show()

et = time.perf_counter_ns()
times["load"] = (et - st)/1000

┌─────────────┐
│ load_iris() │
│   boolean   │
├─────────────┤
│ true        │
└─────────────┘



### Predict

In [7]:
# @udf.scalar.pyarrow
# def predict_fare(x: dt.float64) -> dt.float32:
def predict_iris(a, b, c, d):
    a = torch.from_numpy(a.to_numpy()[:, None]).float()
    b = torch.from_numpy(b.to_numpy()[:, None]).float()
    c = torch.from_numpy(c.to_numpy()[:, None]).float()
    d = torch.from_numpy(d.to_numpy()[:, None]).float()

    # print(a.shape)
    
    # x = torch.Tensor([a, b, c, d])
    x = torch.cat([a, b, c, d], -1)
    # print(x)

    out = load_iris.model(x).detach().argmax(dim=1)

    # print(out)
    
    return pa.array(out.numpy())

con.sql("PRAGMA enable_profiling='json'")
con.create_function("predict_iris", predict_iris, [FLOAT, FLOAT, FLOAT, FLOAT], FLOAT, side_effects=True, type='arrow')
res = con.sql(f"EXPLAIN ANALYZE SELECT *, predict_iris(sepal_length, sepal_width, sepal_width, sepal_width) FROM iris_20_0 LIMIT {workload}").fetchall()

con.close()

  a = torch.from_numpy(a.to_numpy()[:, None]).float()


In [8]:
# pprint(res)

In [9]:
prediction = json.loads(res[0][1])['children'][0]['children'][0]['children'][0]['children'][0]
times["predict"] = prediction['timing'] * 1000000

move = prediction['children'][0]
times["move"] = move['timing'] * 1000000

In [10]:
prediction

{'name': 'PROJECTION',
 'timing': 550.454093,
 'cardinality': 20000000,
 'extra_info': 'sepal_length\nsepal_width\npetal_length\npetal_width\npredict_iris(sepal_length, sepal_width, sepal_width, sepal_width)\n',
 'timings': [],
 'children': [{'name': 'SEQ_SCAN ',
   'timing': 0.891964,
   'cardinality': 20000000,
   'extra_info': 'iris_20_0\n[INFOSEPARATOR]\nsepal_length\nsepal_width\npetal_length\npetal_width\n[INFOSEPARATOR]\nEC: 20000000',
   'timings': [],
   'children': []}]}

In [11]:
times_df = pd.DataFrame.from_records([times]).loc[:, ['dataset', 'pipeline', 'workload', 'move', 'load', 'predict']]

con = duckdb.connect("../test.db")
tables = con.sql(f"SHOW TABLES").df()
has_table = tables[tables.name == 'times'].shape[0] == 1

if drop_table:
    con.sql(f"DROP TABLE IF EXISTS times")
    
if not has_table:
    con.sql(f"CREATE TABLE times (dataset varchar, workload integer, move integer, load integer, predict integer, ts timestamp DEFAULT current_timestamp)")

con.sql(f"INSERT INTO times (dataset, pipeline, workload, move, load, predict) SELECT * FROM times_df")
con.sql(f"SELECT * FROM times").show()
con.close()

┌─────────┬──────────┬─────────┬───────┬────────────┬─────────────────────────┬────────────────────────┐
│ dataset │ workload │  move   │ load  │  predict   │           ts            │        pipeline        │
│ varchar │  int32   │  int32  │ int32 │   int32    │        timestamp        │        varchar         │
├─────────┼──────────┼─────────┼───────┼────────────┼─────────────────────────┼────────────────────────┤
│ iris    │  1000000 │   66407 │ 11404 │     644862 │ 2024-04-16 15:39:06.832 │ duckdb_python          │
│ iris    │  1000000 │   69000 │ 14205 │     711004 │ 2024-04-16 15:39:13.889 │ duckdb_python          │
│ iris    │  1000000 │   87882 │ 15013 │     714544 │ 2024-04-16 15:39:37.287 │ duckdb_python          │
│ iris    │  1000000 │   66269 │ 14332 │     696981 │ 2024-04-16 15:39:42.694 │ duckdb_python          │
│ iris    │  1000000 │   75381 │ 14991 │     691012 │ 2024-04-16 15:39:54.801 │ duckdb_python          │
│ iris    │  5000000 │  328913 │ 15558 │    5179009 │ 2

In [12]:
con.close()