In [3]:
import wallaroo
from wallaroo.object import EntityNotFoundError
import pandas as pd
import os
#Used for the Wallaroo SDK version 2023.1
os.environ["ARROW_ENABLED"]="True"

# bigquery library
from google.cloud import bigquery
from google.oauth2 import service_account
import db_dtypes

# Wallaroo instance to connect to. The API and Keycloak host names are both
# derived from this prefix/suffix pair.
#
# Disabled alternatives kept for reference:
#   wl = wallaroo.Client()
#   wallarooPrefix = "product-uat-ee"
#   wallarooSuffix = "wallaroocommunity.ninja"
wallarooPrefix = "doc-test"
wallarooSuffix = "wallaroocommunity.ninja"

# Connect with explicit endpoints and SSO authentication.
wl = wallaroo.Client(
    auth_type="sso",
    api_endpoint=f"https://{wallarooPrefix}.api.{wallarooSuffix}",
    auth_endpoint=f"https://{wallarooPrefix}.keycloak.{wallarooSuffix}",
)

# Resolve the run settings. When this script runs as a Wallaroo task
# (wl.in_task() is true) the task arguments may override any of the four
# settings; a setting that is not supplied falls back to its default.
if wl.in_task():
    # Fetch the arguments once and reuse them for both logging and lookup.
    arguments = wl.task_args()
    print(arguments)

    workspace_name = arguments.get("workspace_name", "bigquerystatsmodelworkspace")
    pipeline_name = arguments.get("pipeline_name", "bigquerystatsmodelpipeline")
    bigquery_connection_input_name = arguments.get(
        "bigquery_connection_input_name", "bigqueryforecastinputs"
    )
    bigquery_connection_output_name = arguments.get(
        "bigquery_connection_output_name", "bigqueryforecastoutputs"
    )
else:
    # we're not in the task, so use the default values
    # NOTE(review): the pipeline default here ("...pipeline02") differs from
    # the in-task fallback ("...pipeline") -- confirm this is intentional.
    workspace_name = "bigquerystatsmodelworkspace"
    pipeline_name = "bigquerystatsmodelpipeline02"
    bigquery_connection_input_name = "bigqueryforecastinputs"
    bigquery_connection_output_name = "bigqueryforecastoutputs"
    
# helper methods to retrieve workspaces and pipelines

def get_workspace(name):
    """Return the workspace called *name*, creating it when none exists.

    Scans every workspace returned by ``wl.list_workspaces()``. If more than
    one has a matching name, the last one listed is returned. When no
    workspace matches, a new one is created with ``wl.create_workspace``.
    """
    workspace = None
    for ws in wl.list_workspaces():
        if ws.name() == name:
            workspace = ws
    # Identity check: None is a singleton, so ``is`` is the correct test.
    if workspace is None:
        workspace = wl.create_workspace(name)
    return workspace

def get_pipeline(name):
    """Return the first pipeline named *name*, building one if lookup fails.

    The lookup goes through ``wl.pipelines_by_name``; when it raises
    ``EntityNotFoundError`` a new pipeline is built with
    ``wl.build_pipeline`` instead.
    """
    try:
        matches = wl.pipelines_by_name(name)
        return matches[0]
    except EntityNotFoundError:
        return wl.build_pipeline(name)

# set the workspace and pipeline
workspace = get_workspace(workspace_name)
wl.set_current_workspace(workspace)
print(wl.get_current_workspace())

pipeline = get_pipeline(pipeline_name)

# deploy the pipeline
print("\nDeploying pipeline.")
print(pipeline_name)
pipeline.deploy()

# Everything after the deploy runs under try/finally so the pipeline is always
# undeployed, even when a BigQuery call or the inference raises.
try:
    # get the connections
    big_query_input_connection = wl.get_connection(name=bigquery_connection_input_name)
    big_query_output_connection = wl.get_connection(name=bigquery_connection_output_name)

    # Read each connection's details once instead of on every use.
    input_details = big_query_input_connection.details()
    output_details = big_query_output_connection.details()
    input_table_id = f"{input_details['dataset']}.{input_details['table']}"
    output_table_id = f"{output_details['dataset']}.{output_details['table']}"

    # Set the bigquery input and output credentials
    bigquery_input_credentials = service_account.Credentials.from_service_account_info(
        input_details)
    bigquery_output_credentials = service_account.Credentials.from_service_account_info(
        output_details)

    # start the input and output clients
    bigqueryinputclient = bigquery.Client(
        credentials=bigquery_input_credentials,
        project=input_details['project_id']
    )
    bigqueryoutputclient = bigquery.Client(
        credentials=bigquery_output_credentials,
        project=output_details['project_id']
    )

    try:
        # Take the 7 most recent rows by dteday, put them back in ascending
        # date order, then drop the date column before inference.
        inference_dataframe_input = bigqueryinputclient.query(
                f"""
                (select dteday, temp, holiday, workingday, windspeed
                FROM {input_table_id}
                ORDER BY dteday DESC LIMIT 7)
                ORDER BY dteday
                """
            ).to_dataframe().drop(columns=['dteday'])

        # convert the input rows to a dict and show them
        inference_inputs = inference_dataframe_input.to_dict()
        print(f"\n{inference_inputs}")

        # perform the inference and display the result
        results = pipeline.infer(inference_inputs)
        print(f"\n{results[0]['forecast']}")

        # Look up the output table before inserting.
        # NOTE(review): the returned table object is not used afterwards;
        # presumably this call is kept as an existence check -- confirm.
        output_table = bigqueryoutputclient.get_table(output_table_id)

        # upload the inference results
        job = bigqueryoutputclient.query(
                f"""
                INSERT {output_table_id}
                VALUES
                (current_timestamp(), "{results[0]['forecast']}")
                """
            )
        # Wait for the insert to finish (and surface any error) before the
        # verification query runs and before the client is closed.
        job.result()

        # Show the last 5 output inserts
        # Get the last insert to the output table to verify
        task_inference_results = bigqueryoutputclient.query(
                f"""
                SELECT *
                FROM {output_table_id}
                ORDER BY date DESC
                LIMIT 5
                """
            ).to_dataframe()

        print(f"\n{task_inference_results}")
    finally:
        # close the bigquery clients
        bigqueryinputclient.close()
        bigqueryoutputclient.close()
finally:
    # undeploy the pipeline
    print("\nUndeploying pipeline.")
    pipeline.undeploy()


{'name': 'bigquerystatsmodelworkspace', 'id': 6, 'archived': False, 'created_by': 'eafd452e-1b6a-4ca4-aac9-1c1da3ee8301', 'created_at': '2023-05-11T18:51:50.813231+00:00', 'models': [{'name': 'bigquerystatsmodelmodel', 'versions': 2, 'owner_id': '""', 'last_update_time': datetime.datetime(2023, 5, 11, 18, 57, 59, 121701, tzinfo=tzutc()), 'created_at': datetime.datetime(2023, 5, 11, 18, 51, 53, 252933, tzinfo=tzutc())}], 'pipelines': [{'name': 'bigquerystatsmodelpipeline02', 'create_time': datetime.datetime(2023, 5, 11, 19, 50, 20, 988291, tzinfo=tzutc()), 'definition': '[]'}, {'name': 'bigquerystatsmodelpipeline', 'create_time': datetime.datetime(2023, 5, 11, 18, 51, 52, 266084, tzinfo=tzutc()), 'definition': '[]'}]}

Deploying pipeline.
bigquerystatsmodelpipeline02

{'temp': {0: 0.291304, 1: 0.243333, 2: 0.254167, 3: 0.253333, 4: 0.253333, 5: 0.255833, 6: 0.215833}, 'holiday': {0: 1, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}, 'workingday': {0: 0, 1: 1, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1}, 'windsp