In [1]:
# default_exp core

# XGB2SQL

> Convert a trained XGBoost model into a SQL query that reproduces its predictions inside the database.

In [2]:
#hide
from nbdev.showdoc import *

In [3]:
#export
import functools
import json
from pathlib import Path
from typing import Callable, List, Literal, Optional

import pandas as pd


def clean_multiline_str(prefix_nls: int = 0, suffix_nls: int = 0, spaces: int = 0):
    """
    Decorator factory that normalises the indentation of a function's multi-line string result.

    The decorated function's return value is processed as follows:

    1. The smallest leading-whitespace width found among its non-blank lines is
       removed from every line.
    2. Every line (blank ones included) is prefixed with ``spaces`` spaces.
    3. Leading and trailing newline characters are stripped from the result.
    4. ``prefix_nls`` newlines are prepended and ``suffix_nls`` newlines appended.

    Parameters
    ----------
    prefix_nls:
        Number of newline characters to put in front of the cleaned string.
    suffix_nls:
        Number of newline characters to put after the cleaned string.
    spaces:
        Number of spaces used to re-indent every line.

    Returns
    -------
    A decorator; the function it produces accepts the same arguments as the
    wrapped function and returns the cleaned string.
    """
    leading_newlines = "\n" * prefix_nls
    trailing_newlines = "\n" * suffix_nls
    indent = " " * spaces

    def wrap(multi_line_str_func: Callable[..., str]):
        # Keep the wrapped function's name/docstring visible to tooling.
        @functools.wraps(multi_line_str_func)
        def inner(*args, **kwargs):
            lines = multi_line_str_func(*args, **kwargs).split("\n")
            # `default=0` covers results with no non-blank line at all
            # (e.g. an empty string), where a bare min() would raise ValueError.
            common_indent = min(
                (len(line) - len(line.lstrip()) for line in lines if line.strip()),
                default=0,
            )
            # Slicing past the end of a short (whitespace-only) line yields "".
            reindented = (f"{indent}{line[common_indent:]}" for line in lines)
            body = "\n".join(reindented).strip("\n")
            return f"{leading_newlines}{body}{trailing_newlines}"
        return inner
    return wrap


class XGBFmap:
    '''
    Creates and removes the feature-map ("fmap") text file for a DataFrame.

    Each line of the file is ``<row> <name> <typ>`` (space separated), where
    ``typ`` is derived from the column dtype:

    * ``"int"`` -- the column name is listed in ``non_indicator_ints``
    * ``"i"``   -- the dtype string starts with ``"int"`` (treated as a 1/0 indicator)
    * ``"q"``   -- everything else

    Integer columns that are *not* 1/0 indicators must therefore be passed in
    ``non_indicator_ints``, otherwise they are written as indicators.
    '''
    def __init__(self,
                 path: Optional[str] = None,
                 file_txt: str = "fmap.txt"
                 ) -> None:
        """
        Parameters
        ----------
        path:
            Directory that holds the fmap file; the current working directory
            is used when omitted.
        file_txt:
            File name of the fmap file.
        """
        base_dir = Path(path) if path else Path.cwd()
        self.PATH = base_dir / file_txt

    @staticmethod
    def _fmap_type(name, dtype, plain_ints) -> str:
        # Maps one column to its fmap type code ("int", "i" or "q").
        if name in plain_ints:
            return "int"
        if str(dtype).startswith("int"):
            return "i"
        return "q"

    def create_fmap(self,
                    X: pd.DataFrame,
                    non_indicator_ints: Optional[List[str]] = None
                    ) -> None:
        """
        Write the fmap file for the columns of ``X`` to ``self.PATH``.

        Parameters
        ----------
        X:
            Frame whose column names and dtypes define the feature map.
        non_indicator_ints:
            Names of integer columns that are not 1/0 indicators.
        """
        # `is not None` (rather than truthiness) also accepts array-likes.
        plain_ints = set(non_indicator_ints) if non_indicator_ints is not None else set()

        # Read names and dtypes positionally; this does not depend on how the
        # column index is named and also works for a frame without columns.
        names = list(X.columns)
        typs = [
            self._fmap_type(name, dtype, plain_ints)
            for name, dtype in zip(names, X.dtypes)
        ]
        fmaps_df = pd.DataFrame(
            {"row": range(len(names)), "name": names, "typ": typs}
        )

        fmaps_df.to_csv(self.PATH, header=False, index=False, sep=" ")

    def delete_fmap(self) -> None:
        """Remove the fmap file; raises ``FileNotFoundError`` if it does not exist."""
        self.PATH.unlink()

class XGB2SQL(object):
    """
    Converts a fitted XGBoost model into a single SQL query.

    Every boosted tree becomes one ``CASE ... END`` column inside a
    ``booster_output`` CTE; a final ``SELECT`` adds the per-tree columns up and
    turns the sum into a score. An example use case: train on a small sample,
    convert the model, then let a warehouse (e.g. Redshift/BigQuery) generate
    predictions for very large tables.
    """

    def __init__(self,
                 xgb_model,
                 model_type: Literal["Classifier", "Regressor"],
                 model_name: str = "XGB_model",
                 *,
                 data_source_name: str,
                 index_columns: Optional[List[str]] = None,
                 fmap_path: Optional[str] = None,
                 fmap_file_txt: str = "fmap.txt",
                 output_path: Optional[str] = None,
                 output_filename: str = "model.sql"
                 ) -> None:
        """
        Parameters
        ----------
        xgb_model: xgboost
            Fitted model exposing ``base_score`` and ``get_booster()``.
            https://xgboost.readthedocs.io/en/latest/tutorials/model.html
        model_type:
            ``"Classifier"`` applies a logistic transform to the summed tree
            outputs; any other value adds ``base_score`` to the sum.
        model_name:
            Written into the ``model_name`` output column. Useful for versioning.
        data_source_name:
            Name of the SQL table to query from; it must expose the columns the
            trees split on. Keyword-only, because it is required but follows a
            defaulted parameter.
        index_columns:
            Columns passed through unchanged into the final output.
        fmap_path, fmap_file_txt:
            Directory (default: current working directory) and file name of the
            feature-map file handed to ``Booster.get_dump``.
        output_path, output_filename:
            Directory (default: current working directory) and file name the
            generated query is written to by ``run``.
        """
        self.xgb_base_score = xgb_model.base_score
        self.xgb_booster = xgb_model.get_booster()

        self.model_type = model_type
        self.model_name = model_name

        self.data_source_name = data_source_name
        # Copy into a list so `index_columns + [...]` in run() works for any sequence.
        self.index_columns = list(index_columns) if index_columns else []
        self.index_string = ", ".join(self.index_columns)

        fmap_dir = Path(fmap_path) if fmap_path else Path.cwd()
        self.FMAP_PATH = fmap_dir / fmap_file_txt

        output_dir = Path(output_path) if output_path else Path.cwd()
        self.OUTPUT_PATH = output_dir / output_filename

    def _json_parse(self) -> List[dict]:
        """Dump the booster as JSON and return one parsed dict per tree."""
        # fmap must be read in from a file path
        dumped_trees = self.xgb_booster.get_dump(
            dump_format="json", fmap=self.FMAP_PATH.as_posix()
        )
        return [json.loads(tree) for tree in dumped_trees]

    @clean_multiline_str(prefix_nls=2, suffix_nls=1)
    def _sql_eval(self,
                  columns: List[str]
                  ) -> str:
        """
        Build the final ``SELECT`` that combines the per-tree columns into a score.

        Parameters
        ----------
        columns:
            Names of the per-tree columns defined in the ``booster_output`` CTE.
        """
        column_string = " + ".join(columns)

        if self.model_type == "Classifier":
            # Note: xgboost doesn't use base_score for predict_proba(X)[:,1]
            # NOTE(review): the backticks around EXP are emitted verbatim; confirm
            # that the target SQL dialect accepts a back-quoted function name.
            score_string = (
                f"1 / ( 1 + `EXP` ( - ({column_string}) ) )"
            )
        else:
            score_string = f"{column_string} + {self.xgb_base_score}"

        select_items = [self.index_string] if self.index_columns else []
        select_items += [
            f"{score_string} AS score",
            f"'{self.model_name}' AS model_name",
            "CURRENT_TIMESTAMP AS _created_at",
        ]
        select_body = ",\n  ".join(select_items)

        return f"SELECT\n  {select_body}\nFROM booster_output"

    def _extract_clean_json(self, obj):
        """
        Flatten one dumped tree into lookup tables keyed by node id.

        This must always be applied to individual estimators within the json, not the full json.

        Returns
        -------
        (leaves, nodes):
            ``leaves`` maps leaf node id -> leaf value. ``nodes`` maps every
            node id -> dict holding its ``parent`` id (``None`` for the root)
            and, for split nodes, the split description.
        """
        leaves = {}
        nodes = {}

        def _visit(item, parent_id=None):
            if isinstance(item, list):
                for child in item:
                    _visit(child, parent_id)
                return
            if not isinstance(item, dict):
                return

            node_id = item["nodeid"]
            if "leaf" in item:
                leaves[node_id] = item["leaf"]
                nodes[node_id] = {"parent": parent_id}
            elif "split_condition" not in item:
                # No threshold in the dump: treated as an indicator split (= 1 / <> 1).
                nodes[node_id] = {
                    "parent": parent_id,
                    "split_column": item["split"],
                    "split_number": 1,
                    "if_equal_to": item["yes"],
                    "if_not_equal_to": item["no"],
                }
            else:
                nodes[node_id] = {
                    "parent": parent_id,
                    "split_column": item["split"],
                    "split_number": item["split_condition"],
                    "if_less_than": item["yes"],
                    "if_greater_than": item["no"],
                    "if_null": item["missing"],
                }

            # Descend into every list-valued entry (the child nodes).
            for value in item.values():
                if isinstance(value, list):
                    _visit(value, node_id)

        _visit(obj)
        return leaves, nodes

    @clean_multiline_str(prefix_nls=1, spaces=18)
    def _recurse_case_whens(self,
                            leaf_id: int,
                            leaf_value: float,
                            splits
                            ) -> str:
        """
        Build the ``WHEN <conditions> THEN <leaf_value>`` clause for one leaf.

        Walks from the leaf up to the root via the ``parent`` links in
        ``splits`` and collects, for every ancestor, the condition that routes
        a row towards this leaf. Conditions are emitted root-first.
        """
        conditions: List[str] = []
        node_id = leaf_id

        while True:
            parent_id = splits[node_id]["parent"]
            if parent_id is None:
                break
            parent = splits[parent_id]

            if "if_less_than" in parent:
                if parent["if_less_than"] == node_id:
                    operator = "<"
                elif parent["if_greater_than"] == node_id:
                    operator = ">="
                else:
                    break
                condition = f"({parent['split_column']} {operator} {parent['split_number']})"
                # Missing values follow this branch too when it is the default direction.
                if parent["if_null"] == node_id:
                    condition = f"({condition} OR ({parent['split_column']} IS NULL))"
            elif "if_equal_to" in parent:
                if parent["if_equal_to"] == node_id:
                    operator = "="
                elif parent["if_not_equal_to"] == node_id:
                    operator = "<>"
                else:
                    break
                condition = f"({parent['split_column']} {operator} {parent['split_number']})"
            else:
                break

            conditions.append(condition)
            node_id = parent_id

        # Collected leaf -> root; SQL reads better root -> leaf.
        conditions.reverse()

        predicate = " AND ".join(conditions) if conditions else "TRUE"
        return f"WHEN {predicate} THEN {leaf_value}"

    def run(self) -> str:
        """
        Generate the SQL query, write it to ``self.OUTPUT_PATH`` and return it.

        The feature-map file at ``self.FMAP_PATH`` is read while this runs.
        """
        # Get xgb in json format
        tree_json = self._json_parse()
        estimators = []
        columns = []

        # For each estimator
        for i, tree in enumerate(tree_json):

            # Recurse to clean json and generate CASE WHENs
            leaves, splits = self._extract_clean_json(tree)
            case_whens = [
                self._recurse_case_whens(leaf_id, leaf_value, splits)
                for leaf_id, leaf_value in leaves.items()
            ]
            column = f"column_{i}"
            columns.append(column)

            # Generate the full CASE WHEN
            estimator = f"""
                CASE {''.join(case_whens)}
                END AS {column}"""
            estimator = clean_multiline_str(prefix_nls=0, suffix_nls=0, spaces=16)(
                lambda: estimator
            )()
            estimators.append(estimator)

        # Combine all estimators
        query_top = f"""
            WITH booster_output AS (
              SELECT
                {', '.join(self.index_columns + estimators)}
              FROM {self.data_source_name}
            )
            """
        query_top = clean_multiline_str(prefix_nls=0, suffix_nls=0, spaces=0)(
            lambda: query_top
        )()

        # Create the last SELECT clause
        query_bottom = self._sql_eval(columns)

        # Final query
        query = f"{query_top}{query_bottom}"

        with open(self.OUTPUT_PATH, "w") as f:
            f.write(query)

        return query


In [12]:
import xgboost as xgb
from xgb2sql import XGBFmap, XGB2SQL

In [11]:
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X_raw, y = load_breast_cancer(return_X_y=True)

# XGBFmap.create_fmap reads `X.dtypes`, so the features must be a DataFrame.
# The columns are named f0, f1, ... to match the identifiers in the SQL shown below.
X = pd.DataFrame(X_raw, columns=[f"f{i}" for i in range(X_raw.shape[1])])

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)



In [19]:
# Fit a classifier with only 5 boosting rounds (presumably to keep the demo SQL short).
m = xgb.XGBClassifier(n_estimators=5)
# Trailing expression: the fitted estimator's repr is shown as the cell output.
m.fit(X_train, y_train)

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
              colsample_bynode=1, colsample_bytree=1, gamma=0,
              learning_rate=0.1, max_delta_step=0, max_depth=3,
              min_child_weight=1, missing=None, n_estimators=5, n_jobs=1,
              nthread=None, objective='binary:logistic', random_state=0,
              reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
              silent=None, subsample=1, verbosity=1)

In [20]:
xgbFmap = XGBFmap()
xgbFmap.create_fmap(X)
tree = XGB2SQL(xgb_model=m, model_type='Classifier',data_source_name='breast_cancer')
xgbFmap.delete_fmap()

In [21]:
print(tree.run())

WITH booster_output AS (
	SELECT
		CASE
			WHEN ((f7 < 0.0489199981) OR (f7 IS NULL))
			AND ((f20 < 16.8250008) OR (f20 IS NULL))
			AND ((f10 < 0.591250002) OR (f10 IS NULL))
		THEN 0.191869915
			WHEN ((f7 < 0.0489199981) OR (f7 IS NULL))
			AND ((f20 < 16.8250008) OR (f20 IS NULL))
			AND (f10 >= 0.591250002)
		THEN 0
			WHEN ((f7 < 0.0489199981) OR (f7 IS NULL))
			AND (f20 >= 16.8250008)
			AND ((f1 < 18.9599991) OR (f1 IS NULL))
		THEN 0.120000005
			WHEN ((f7 < 0.0489199981) OR (f7 IS NULL))
			AND (f20 >= 16.8250008)
			AND (f1 >= 18.9599991)
		THEN -0.13333334
			WHEN (f7 >= 0.0489199981)
			AND ((f23 < 785.799988) OR (f23 IS NULL))
			AND ((f21 < 23.7399998) OR (f21 IS NULL))
		THEN 0.155555561
			WHEN (f7 >= 0.0489199981)
			AND ((f23 < 785.799988) OR (f23 IS NULL))
			AND (f21 >= 23.7399998)
		THEN -0.100000001
			WHEN (f7 >= 0.0489199981)
			AND (f23 >= 785.799988)
			AND ((f1 < 14.3000002) OR (f1 IS NULL))
		THEN 0
			WHEN (f7 >= 0.0489199981)
			AND (f23 >= 785.799988)


In [22]:
# Import only the name that is used instead of pulling in the whole module namespace.
from nbdev.export import notebook2script

# Export the `#export` cells of the notebooks to their Python modules.
notebook2script()

Converted 00_core.ipynb.
Converted index.ipynb.
