This notebook demonstrates converting a [Python vtreat](https://github.com/WinVector/vtreat) transformation into a [data algebra](https://github.com/WinVector/data_algebra) pipeline, which can in turn be converted into SQL queries.
[R vtreat](https://winvector.github.io/vtreat/) already has similar functionality with [as_rquery_plan()](https://winvector.github.io/vtreat/reference/as_rquery_plan.html).

Let's demonstrate this with a simple problem.

First we import our modules.

In [11]:
import pandas as pd

from data_algebra.data_ops import *
import data_algebra.SQLite
import vtreat
from vtreat_db_adapter import as_data_algebra_pipeline

In [12]:
# Sample rows taken from the UCI "Diabetes 130-US hospitals" data set:
# https://archive.ics.uci.edu/ml/datasets/Diabetes+130-US+hospitals+for+years+1999-2008

data = pd.read_csv("diabetes_head.csv")
n = len(data)
# Tag every row with its original position (0 .. n-1) so rows can be
# identified again after they have been processed elsewhere.
data["orig_index"] = range(n)

data


Unnamed: 0,encounter_id,patient_nbr,race,gender,age,weight,admission_type_id,discharge_disposition_id,admission_source_id,time_in_hospital,...,glipizide-metformin,glimepiride-pioglitazone,metformin-rosiglitazone,metformin-pioglitazone,change,diabetesMed,readmitted,visit_number,revisit,orig_index
0,2278392,8222157,Caucasian,Female,[0-10),,6,25,1,1,...,No,No,No,No,No,No,False,1,False,0
1,64410,86047875,AfricanAmerican,Female,[20-30),,1,1,7,2,...,No,No,No,No,No,Yes,False,1,False,1
2,500364,82442376,Caucasian,Male,[30-40),,1,1,7,2,...,No,No,No,No,Ch,Yes,False,1,False,2
3,35754,82637451,Caucasian,Male,[50-60),,2,1,2,3,...,No,No,No,No,No,Yes,False,1,False,3
4,55842,84259809,Caucasian,Male,[60-70),,3,1,2,4,...,No,No,No,No,Ch,Yes,False,1,False,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,8860284,94419315,Hispanic,Male,[50-60),,6,1,17,3,...,No,No,No,No,Ch,Yes,False,1,False,995
996,8860944,338247,Caucasian,Female,[60-70),,1,1,7,4,...,No,No,No,No,No,Yes,False,2,True,996
997,8864718,695439,Caucasian,Male,[70-80),,1,1,7,1,...,No,No,No,No,No,Yes,False,1,False,997
998,8866632,103586670,Caucasian,Male,[70-80),[100-125),6,1,17,6,...,No,No,No,No,No,Yes,False,1,False,998


Unnamed: 0,encounter_id,patient_nbr,race,gender,age,weight,admission_type_id,discharge_disposition_id,admission_source_id,time_in_hospital,...,glipizide-metformin,glimepiride-pioglitazone,metformin-rosiglitazone,metformin-pioglitazone,change,diabetesMed,readmitted,visit_number,revisit,orig_index
0,2278392,8222157,Caucasian,Female,[0-10),,6,25,1,1,...,No,No,No,No,No,No,False,1,False,0
1,64410,86047875,AfricanAmerican,Female,[20-30),,1,1,7,2,...,No,No,No,No,No,Yes,False,1,False,1
2,500364,82442376,Caucasian,Male,[30-40),,1,1,7,2,...,No,No,No,No,Ch,Yes,False,1,False,2
3,35754,82637451,Caucasian,Male,[50-60),,2,1,2,3,...,No,No,No,No,No,Yes,False,1,False,3
4,55842,84259809,Caucasian,Male,[60-70),,3,1,2,4,...,No,No,No,No,Ch,Yes,False,1,False,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,8860284,94419315,Hispanic,Male,[50-60),,6,1,17,3,...,No,No,No,No,Ch,Yes,False,1,False,995
996,8860944,338247,Caucasian,Female,[60-70),,1,1,7,4,...,No,No,No,No,No,Yes,False,2,True,996
997,8864718,695439,Caucasian,Male,[70-80),,1,1,7,1,...,No,No,No,No,No,Yes,False,1,False,997
998,8866632,103586670,Caucasian,Male,[70-80),[100-125),6,1,17,6,...,No,No,No,No,No,Yes,False,1,False,998


In [13]:
# Outcome column, plus the bookkeeping columns that are carried along as-is.
outcome_name = "readmitted"
cols_to_copy = ["orig_index", "encounter_id", "patient_nbr", outcome_name]
# Only two input variables are used here; the alternative (every column that
# is not copied) would be:
#   [c for c in data.columns if c not in cols_to_copy]
# NOTE(review): `vars` shadows the Python builtin of the same name.
vars = ["time_in_hospital", "weight"]
columns = [*vars, *cols_to_copy]

data.loc[:, columns]


Unnamed: 0,time_in_hospital,weight,orig_index,encounter_id,patient_nbr,readmitted
0,1,,0,2278392,8222157,False
1,2,,1,64410,86047875,False
2,2,,2,500364,82442376,False
3,3,,3,35754,82637451,False
4,4,,4,55842,84259809,False
...,...,...,...,...,...,...
995,3,,995,8860284,94419315,False
996,4,,996,8860944,338247,False
997,1,,997,8864718,695439,False
998,6,[100-125),998,8866632,103586670,False


Unnamed: 0,time_in_hospital,weight,orig_index,encounter_id,patient_nbr,readmitted
0,1,,0,2278392,8222157,False
1,2,,1,64410,86047875,False
2,2,,2,500364,82442376,False
3,3,,3,35754,82637451,False
4,4,,4,55842,84259809,False
...,...,...,...,...,...,...
995,3,,995,8860284,94419315,False
996,4,,996,8860944,338247,False
997,1,,997,8864718,695439,False
998,6,[100-125),998,8866632,103586670,False


In [14]:
# Configure a treatment plan for a binomial outcome (target value True) and
# fit/apply it to the selected columns in one fit_transform() call.
# NOTE(review): the two saved renderings of this cell's output show different
# weight_logit_code values, which suggests the result is not deterministic
# from run to run -- confirm and consider fixing a random seed.
treatment = vtreat.BinomialOutcomeTreatment(
    outcome_name=outcome_name,
    outcome_target=True,
    cols_to_copy=cols_to_copy,
    params=vtreat.vtreat_parameters(
        {
            "filter_to_recommended": False,
            "sparse_indicators": False,
        }
    ),
)
data_treated = treatment.fit_transform(data.loc[:, columns])

data_treated

Unnamed: 0,orig_index,encounter_id,patient_nbr,readmitted,weight_is_bad,time_in_hospital,weight_logit_code,weight_prevalence_code,weight_lev__NA_
0,0,2278392,8222157,False,1.0,1.0,0.005970,0.993,1.0
1,1,64410,86047875,False,1.0,2.0,0.007204,0.993,1.0
2,2,500364,82442376,False,1.0,2.0,0.007204,0.993,1.0
3,3,35754,82637451,False,1.0,3.0,0.005970,0.993,1.0
4,4,55842,84259809,False,1.0,4.0,0.006001,0.993,1.0
...,...,...,...,...,...,...,...,...,...
995,995,8860284,94419315,False,1.0,3.0,0.007162,0.993,1.0
996,996,8860944,338247,False,1.0,4.0,0.007204,0.993,1.0
997,997,8864718,695439,False,1.0,1.0,0.007199,0.993,1.0
998,998,8866632,103586670,False,0.0,6.0,0.000000,0.001,0.0


Unnamed: 0,orig_index,encounter_id,patient_nbr,readmitted,weight_is_bad,time_in_hospital,weight_logit_code,weight_prevalence_code,weight_lev__NA_
0,0,2278392,8222157,False,1.0,1.0,0.004770,0.993,1.0
1,1,64410,86047875,False,1.0,2.0,0.007164,0.993,1.0
2,2,500364,82442376,False,1.0,2.0,0.007201,0.993,1.0
3,3,35754,82637451,False,1.0,3.0,0.004770,0.993,1.0
4,4,55842,84259809,False,1.0,4.0,0.007164,0.993,1.0
...,...,...,...,...,...,...,...,...,...
995,995,8860284,94419315,False,1.0,3.0,0.007204,0.993,1.0
996,996,8860944,338247,False,1.0,4.0,0.007201,0.993,1.0
997,997,8864718,695439,False,1.0,1.0,0.007164,0.993,1.0
998,998,8866632,103586670,False,0.0,6.0,0.000000,0.001,0.0


In [15]:
# Export the fitted treatment plan as a table. Each row (see the output below)
# names a treatment, the original and derived variable, a value, and the
# replacement to use for that value.
transform_as_data = treatment.description_matrix()

transform_as_data

Unnamed: 0,treatment_class,treatment,orig_var,variable,value,replacement
0,IndicateMissingTransform,missing_indicator,weight,weight_is_bad,_NA_,1.0
1,CleanNumericTransform,clean_copy,time_in_hospital,time_in_hospital,_NA_,4.802
2,MappedCodeTransform,logit_code,weight,weight_logit_code,[0-25),0.0
3,MappedCodeTransform,logit_code,weight,weight_logit_code,[100-125),0.0
4,MappedCodeTransform,logit_code,weight,weight_logit_code,[50-75),0.0
5,MappedCodeTransform,logit_code,weight,weight_logit_code,[75-100),-2.129774
6,MappedCodeTransform,logit_code,weight,weight_logit_code,_NA_,0.006737
7,MappedCodeTransform,prevalence_code,weight,weight_prevalence_code,[0-25),0.001
8,MappedCodeTransform,prevalence_code,weight,weight_prevalence_code,[100-125),0.001
9,MappedCodeTransform,prevalence_code,weight,weight_prevalence_code,[50-75),0.001


Unnamed: 0,treatment_class,treatment,orig_var,variable,value,replacement
0,IndicateMissingTransform,missing_indicator,weight,weight_is_bad,_NA_,1.0
1,CleanNumericTransform,clean_copy,time_in_hospital,time_in_hospital,_NA_,4.802
2,MappedCodeTransform,logit_code,weight,weight_logit_code,[0-25),0.0
3,MappedCodeTransform,logit_code,weight,weight_logit_code,[100-125),0.0
4,MappedCodeTransform,logit_code,weight,weight_logit_code,[50-75),0.0
5,MappedCodeTransform,logit_code,weight,weight_logit_code,[75-100),-2.129774
6,MappedCodeTransform,logit_code,weight,weight_logit_code,_NA_,0.006737
7,MappedCodeTransform,prevalence_code,weight,weight_prevalence_code,[0-25),0.001
8,MappedCodeTransform,prevalence_code,weight,weight_prevalence_code,[100-125),0.001
9,MappedCodeTransform,prevalence_code,weight,weight_prevalence_code,[50-75),0.001


In [16]:
# Convert the treatment plan table into a data algebra pipeline.
# `source` describes the input data table; `treatment_table_name` is the name
# under which the plan table is referenced inside the generated pipeline
# (it appears as the joined TableDescription in the printed output).
ops = as_data_algebra_pipeline(
    source=descr(data=data.loc[:, columns]),
    vtreat_descr=transform_as_data,
    treatment_table_name='transform_as_data',
)

print(ops)

(
    TableDescription(
        table_name="data",
        column_names=[
            "time_in_hospital",
            "weight",
            "orig_index",
            "encounter_id",
            "patient_nbr",
            "readmitted",
        ],
    )
    .extend(
        {
            "weight_is_bad": "(weight.is_bad()).if_else(1.0, 0.0)",
            "weight_lev__NA_": "(weight.coalesce('_NA_') == '_NA_').if_else(1.0, 0.0)",
        }
    )
    .extend({"weight": "weight.coalesce('_NA_')"})
    .natural_join(
        b=TableDescription(
            table_name="transform_as_data",
            column_names=[
                "treatment_class",
                "treatment",
                "orig_var",
                "variable",
                "value",
                "replacement",
            ],
        )
        .select_rows(
            "(treatment_class == 'MappedCodeTransform') and (orig_var == 'weight') and (variable == 'weight_logit_code')"
        )
        .extend({"weight": "v

In [17]:
# Evaluate the pipeline directly on the DataFrames, supplying each input
# table under the table name the pipeline expects.
transformed = ops.eval({'data': data.loc[:, columns], 'transform_as_data': transform_as_data})

transformed

Unnamed: 0,time_in_hospital,orig_index,encounter_id,patient_nbr,readmitted,weight_is_bad,weight_lev__NA_,weight_logit_code,weight_prevalence_code
0,1,0,2278392,8222157,False,1.0,1.0,0.006737,0.993
1,2,1,64410,86047875,False,1.0,1.0,0.006737,0.993
2,2,2,500364,82442376,False,1.0,1.0,0.006737,0.993
3,3,3,35754,82637451,False,1.0,1.0,0.006737,0.993
4,4,4,55842,84259809,False,1.0,1.0,0.006737,0.993
...,...,...,...,...,...,...,...,...,...
995,3,995,8860284,94419315,False,1.0,1.0,0.006737,0.993
996,4,996,8860944,338247,False,1.0,1.0,0.006737,0.993
997,1,997,8864718,695439,False,1.0,1.0,0.006737,0.993
998,6,998,8866632,103586670,False,0.0,0.0,0.000000,0.001


Unnamed: 0,time_in_hospital,orig_index,encounter_id,patient_nbr,readmitted,weight_is_bad,weight_lev__NA_,weight_logit_code,weight_prevalence_code
0,1,0,2278392,8222157,False,1.0,1.0,0.006737,0.993
1,2,1,64410,86047875,False,1.0,1.0,0.006737,0.993
2,2,2,500364,82442376,False,1.0,1.0,0.006737,0.993
3,3,3,35754,82637451,False,1.0,1.0,0.006737,0.993
4,4,4,55842,84259809,False,1.0,1.0,0.006737,0.993
...,...,...,...,...,...,...,...,...,...
995,3,995,8860284,94419315,False,1.0,1.0,0.006737,0.993
996,4,996,8860944,338247,False,1.0,1.0,0.006737,0.993
997,1,997,8864718,695439,False,1.0,1.0,0.006737,0.993
998,6,998,8866632,103586670,False,0.0,0.0,0.000000,0.001


In [18]:
# Obtain an example SQLite database handle and translate the same pipeline
# into a SQL query for that dialect.
db_handle = data_algebra.SQLite.example_handle()

sql = db_handle.to_sql(ops)
print(sql)

-- data_algebra SQL https://github.com/WinVector/data_algebra
--  dialect: SQLiteModel
--       string quote: '
--   identifier quote: "
WITH
 "extend_0" AS (
  SELECT  -- .extend({ 'weight_is_bad': '(weight.is_bad()).if_else(1.0, 0.0)', 'weight_lev__NA_': "(weight.coalesce('_NA_') == '_NA_').if_else(1.0, 0.0)"})
   "orig_index" ,
   "encounter_id" ,
   "time_in_hospital" ,
   "patient_nbr" ,
   "readmitted" ,
   "weight" ,
   CASE WHEN is_bad("weight") THEN 1.0 WHEN NOT is_bad("weight") THEN 0.0 ELSE NULL END AS "weight_is_bad" ,
   CASE WHEN (COALESCE("weight", '_NA_') = '_NA_') THEN 1.0 WHEN NOT (COALESCE("weight", '_NA_') = '_NA_') THEN 0.0 ELSE NULL END AS "weight_lev__NA_"
  FROM
   "data"
 ) ,
 "extend_1" AS (
  SELECT  -- .extend({ 'weight': "weight.coalesce('_NA_')"})
   "orig_index" ,
   "encounter_id" ,
   "time_in_hospital" ,
   "weight_is_bad" ,
   "patient_nbr" ,
   "weight_lev__NA_" ,
   "readmitted" ,
   COALESCE("weight", '_NA_') AS "weight"
  FROM
   "extend_0"
 ) ,
 

In [19]:
# Load both input tables into the database. allow_overwrite=True (together
# with the DROP below) keeps this cell re-runnable within one kernel session;
# without them a second run fails because the tables already exist.
db_handle.insert_table(data.loc[:, columns], table_name='data', allow_overwrite=True)
db_handle.insert_table(transform_as_data, table_name='transform_as_data', allow_overwrite=True)

# Materialize the pipeline's result as table `res` inside the database.
db_handle.execute('DROP TABLE IF EXISTS res')
db_handle.execute('CREATE TABLE res AS ' + sql)

# Spot-check the first five and last five rows, selected by original index.
lst_ids = str(tuple(list(range(5)) + list(range(n-5, n))))
q = f'SELECT * FROM res WHERE orig_index IN {lst_ids} ORDER BY orig_index LIMIT 10'
print(q)
db_handle.read_query(q)

SELECT * FROM res WHERE orig_index IN (0, 1, 2, 3, 4, 995, 996, 997, 998, 999) ORDER BY orig_index LIMIT 10


Unnamed: 0,orig_index,weight_logit_code,encounter_id,weight_is_bad,patient_nbr,weight_prevalence_code,weight_lev__NA_,readmitted,time_in_hospital
0,0,0.006737,2278392,1.0,8222157,0.993,1.0,0,1
1,1,0.006737,64410,1.0,86047875,0.993,1.0,0,2
2,2,0.006737,500364,1.0,82442376,0.993,1.0,0,2
3,3,0.006737,35754,1.0,82637451,0.993,1.0,0,3
4,4,0.006737,55842,1.0,84259809,0.993,1.0,0,4
5,995,0.006737,8860284,1.0,94419315,0.993,1.0,0,3
6,996,0.006737,8860944,1.0,338247,0.993,1.0,0,4
7,997,0.006737,8864718,1.0,695439,0.993,1.0,0,1
8,998,0.0,8866632,0.0,103586670,0.001,0.0,0,6
9,999,0.006737,8867106,1.0,4988970,0.993,1.0,0,9


SELECT * FROM res WHERE orig_index IN (0, 1, 2, 3, 4, 995, 996, 997, 998, 999) ORDER BY orig_index LIMIT 10


Unnamed: 0,orig_index,weight_logit_code,encounter_id,weight_is_bad,patient_nbr,weight_prevalence_code,weight_lev__NA_,readmitted,time_in_hospital
0,0,0.006737,2278392,1.0,8222157,0.993,1.0,0,1
1,1,0.006737,64410,1.0,86047875,0.993,1.0,0,2
2,2,0.006737,500364,1.0,82442376,0.993,1.0,0,2
3,3,0.006737,35754,1.0,82637451,0.993,1.0,0,3
4,4,0.006737,55842,1.0,84259809,0.993,1.0,0,4
5,995,0.006737,8860284,1.0,94419315,0.993,1.0,0,3
6,996,0.006737,8860944,1.0,338247,0.993,1.0,0,4
7,997,0.006737,8864718,1.0,695439,0.993,1.0,0,1
8,998,0.0,8866632,0.0,103586670,0.001,0.0,0,6
9,999,0.006737,8867106,1.0,4988970,0.993,1.0,0,9


In [20]:
# Release the database handle now that the demonstration is complete.
db_handle.close()