# K fold cross validation with BQML
***Make a copy of this notebook***

Created By: steveswalker@

Created On: 08/01/2021

Updated On: 08/01/2021
### This notebook is hard-coded to train a BOOSTED_TREE_CLASSIFIER.  It builds one CREATE MODEL query per fold, then submits the K training jobs one after another; the jobs themselves run asynchronously in BigQuery.
Thanks to rthallam@ for the callback example.

Improvement opps:
*   Add eval results to a permanent BQ table to allow for historical kfold tracking
*   Add different BQML model support
*   Consider Project Beatrix for scheduling/'callable' options


In [None]:
import os
import time

from google.cloud import bigquery

!pip install ipython-autotime
# times each cell runtime.
%load_ext autotime

PROJECT = "bigquery-test-project-166321"  # changeme
DS = "bqml"  # changeme
MODEL_NAME = "kfold_"
K = 5  # change if necessary

In [None]:
from google.colab import auth

# Authenticate the Colab user, then export the configured project through the
# environment and echo back the value that was actually stored.
auth.authenticate_user()
os.environ.update({"GCLOUD_PROJECT": PROJECT})
print("Authenticated and set project to {}".format(os.environ["GCLOUD_PROJECT"]))

In [None]:
# One BigQuery client shared by every cell below.  The project is passed
# explicitly so the client does not depend on an environment variable having
# been set by an earlier cell.
client = bigquery.Client(project=PROJECT)
# IDs of submitted jobs that have not finished yet.  The done-callback
# (flag_completed_query) removes an ID when its job completes, and the wait
# loop keeps sleeping until this set is empty.
poll_jobs = set()

In [None]:
# Build one CREATE MODEL statement per fold.
#
# Every row is assigned to a fold 'K1'..'K<K>' by hashing its JSON form, and
# fold N's model is trained with DATA_SPLIT_METHOD='CUSTOM': rows where
# bCustomSplit is TRUE (the rows of fold N) are held out for evaluation and the
# remaining rows are used for training.
#
# ksplit_col is deliberately NOT part of the final SELECT: every selected
# column other than the label and the split column is treated as a feature,
# so projecting the fold label would train on it and then evaluate on a
# category value the model never saw.
TRAIN_SQL_TEMPLATE = (
    "CREATE OR REPLACE MODEL `{model}` "
    "OPTIONS(model_type='BOOSTED_TREE_CLASSIFIER', DATA_SPLIT_METHOD='CUSTOM', "
    "DATA_SPLIT_COL='bCustomSplit', input_label_cols=['income_bracket']) AS "
    "WITH cte AS ("
    "SELECT *, CONCAT('K', CAST(MOD(ABS(FARM_FINGERPRINT(TO_JSON_STRING(x))), {k}) + 1 AS STRING)) AS ksplit_col "
    "FROM `bigquery-public-data.ml_datasets.census_adult_income` x) "
    "SELECT age, workclass, functional_weight, education, education_num, marital_status, "
    "relationship, race, sex, capital_gain, capital_loss, hours_per_week, native_country, "
    "income_bracket, IF(ksplit_col = 'K{fold}', TRUE, FALSE) AS bCustomSplit "
    "FROM cte"
)

# an easy way to debug is to print(queries[0]) so you can see the final, formatted sql
queries = [
    TRAIN_SQL_TEMPLATE.format(
        model="{}.{}.{}{}".format(PROJECT, DS, MODEL_NAME, fold), k=K, fold=fold
    )
    for fold in range(1, K + 1)
]

In [None]:
def handle_status(query_status):
    """Print a one-line status report for a BigQuery job.

    A BigQuery job's ``state`` is ``PENDING``, ``RUNNING`` or ``DONE``; a job
    that failed is reported as ``DONE`` with ``error_result`` populated.  The
    error is therefore detected from ``error_result`` rather than from the
    state string, so failed jobs are no longer reported as plain ``DONE``.

    Args:
        query_status: Job-like object exposing ``job_id`` and ``state`` and,
            optionally, ``error_result`` (a falsy value means "no error").
    """
    # error_result is read defensively so any job-like object can be passed.
    error = getattr(query_status, "error_result", None)
    if error:
        print(
            "Job {} is currently in state {} with error: {}".format(
                query_status.job_id, query_status.state, error
            )
        )
    else:
        # Covers PENDING, RUNNING and successful DONE alike.
        print(
            "Job {} is currently in state {}".format(
                query_status.job_id, query_status.state
            )
        )

In [None]:
def flag_completed_query(future):
    """Done-callback for a BigQuery job: report its status and stop tracking it.

    Args:
        future: The finished job handed to the callback; its ``job_id`` and
            ``location`` are used to re-fetch the job before reporting.
    """
    try:
        query_status = client.get_job(future.job_id, location=future.location)
        handle_status(query_status)
    finally:
        # Always un-track the job, even when the status lookup or the report
        # fails; otherwise a `while poll_jobs:` wait loop would never finish.
        poll_jobs.discard(future.job_id)

In [None]:
# loop through the queries array and run each query, async, as a bq job.
# client.query() returns as soon as the job is submitted, so the K models
# train concurrently.
for fold_sql in queries:
    query_job = client.query(fold_sql)

    # Track the job BEFORE registering the callback so the callback can never
    # fire for an ID that has not been added yet.
    poll_jobs.add(query_job.job_id)
    # add callback function to the query job.  The callback notifies our colab when the job is done
    query_job.add_done_callback(flag_completed_query)
    # Report the job's state right after submission.
    query_status = client.get_job(query_job.job_id, location=query_job.location)
    handle_status(query_status)

# Block this cell until every callback has removed its job from poll_jobs.
while poll_jobs:
    print("waiting for queries to finish ... sleeping for 23s")
    time.sleep(23)

In [None]:
# we created the models with "CUSTOM" split, and used the kfold ID column as the split info
# which means the model build process already evaluated the model with the correct eval holdout data
# so, just run the ml.evaluate, with model name, no need to pass in holdout data set
poll_jobs = set()
eval_table = "`{}.{}.tmp_kfold`".format(PROJECT, DS)
model_refs = [
    "`{}.{}.{}{}`".format(PROJECT, DS, MODEL_NAME, fold) for fold in range(1, K + 1)
]

# create new table to hold eval data; "LIMIT 0" returns just the schema.
# LIMIT belongs to the SELECT, so it goes after the closing parenthesis of
# ML.EVALUATE.  .result() blocks until the table exists (and raises if the
# statement failed) so the INSERTs below cannot race with the table creation.
query = "CREATE OR REPLACE TABLE {} AS SELECT * FROM ML.EVALUATE(MODEL {}) LIMIT 0".format(
    eval_table, model_refs[0]
)
client.query(query).result()

# Submit one INSERT per fold model; the jobs run concurrently.
eval_jobs = []
for model_ref in model_refs:
    query = "INSERT INTO {} SELECT * FROM ML.EVALUATE(MODEL {})".format(
        eval_table, model_ref
    )
    query_job = client.query(query)
    poll_jobs.add(query_job.job_id)
    query_job.add_done_callback(flag_completed_query)
    eval_jobs.append(query_job)

# Wait for every INSERT before leaving the cell so later cells read the
# complete table; .result() re-raises the error of any job that failed.
for query_job in eval_jobs:
    query_job.result()

In [None]:
# Load the per-fold evaluation rows into a DataFrame and display up to K of them.
query = "".join(["SELECT * FROM `", PROJECT, ".", DS, ".", "tmp_kfold", "`"])
df = client.query(query).to_dataframe()
df.head(K)

In [None]:
# The k-fold validation score is the mean ROC AUC over the per-fold rows
# stored in tmp_kfold; the averaging is done by BigQuery.
query = "".join(
    [
        "SELECT AVG(roc_auc) kfold_auc_roc from `",
        PROJECT,
        ".",
        DS,
        ".",
        "tmp_kfold",
        "`",
    ]
)
df = client.query(query).to_dataframe()
# The result has a single row holding the average.
print("kFold Validation ROC AUC = " + str(df.at[0, "kfold_auc_roc"]))