In [1]:
# Blank placeholders for every deployment-specific setting used by this
# notebook. Strings are immutable, so sharing one "" across names is safe.
PROJECT_ID = PROJECT_REGION = ""
GCS_BUCKET_NAME = ""
VERTEX_DATASET_NAME = VERTEX_MODEL_NAME = VERTEX_PREDICTION_NAME = ""
BQ_DATASET_NAME = BQ_TRAIN_TABLE = BQ_PREDICT_TABLE = ""

In [61]:
# --- Google Cloud project -------------------------------------------------
PROJECT_ID      = "wb-ai-acltr-tbs-3-pr-a62583"
PROJECT_REGION  = "northamerica-northeast1"
GCS_BUCKET_NAME = "bkt_b2b_wf_prediction"

# --- Vertex AI display names ----------------------------------------------
VERTEX_DATASET_NAME    = "b2b_wf_prediction_panorama"
VERTEX_MODEL_NAME      = "b2b_wf_prediction_panorama"
VERTEX_PREDICTION_NAME = "b2b_wf_prediction_batch"

# --- BigQuery dataset and tables ------------------------------------------
BQ_DATASET_NAME  = "b2b_wf_prediction"
BQ_TRAIN_TABLE   = "vw_wf_experiment_historical"
BQ_PREDICT_TABLE = "bq_wf_temp_predictions"

# SQL expression (not a Python date): it is interpolated verbatim into the
# WHERE clause that bounds the training data.
TRAIN_TEST_DATA_SPLIT = "DATE('2024-07-01')"

In [62]:
import google.cloud.aiplatform as aiplatform
from google.cloud import bigquery
import datetime

# All BigQuery URIs share the same "bq://<project>.<dataset>" prefix, which is
# also the destination prefix handed to the batch prediction job.
PREDICTION_OUTPUT_PREFIX = f"bq://{PROJECT_ID}.{BQ_DATASET_NAME}"
TRAINING_DATASET_BQ_PATH = f"{PREDICTION_OUTPUT_PREFIX}.{BQ_TRAIN_TABLE}"
PREDICTION_DATASET_BQ_PATH = f"{PREDICTION_OUTPUT_PREFIX}.{BQ_PREDICT_TABLE}"

# Staging bucket URI: the bucket name is the project id and the configured
# bucket name joined by an underscore.
BUCKET_URI = "gs://" + "_".join((PROJECT_ID, GCS_BUCKET_NAME))

In [63]:
# Point the Vertex AI SDK at the configured project, region and staging bucket
# for every later aiplatform call in this notebook.
aiplatform.init(
    project=PROJECT_ID,
    location=PROJECT_REGION,
    staging_bucket=BUCKET_URI,
)

In [64]:
# BigQuery client bound to the same project and region as the Vertex AI SDK.
client = bigquery.Client(location=PROJECT_REGION, project=PROJECT_ID)



In [65]:
from dataclasses import dataclass

@dataclass(frozen=True)
class Experiment:
    """Settings for one forecasting experiment run by this notebook.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. The two list fields are still ordinary mutable lists.
    """
    # Label appended to the Vertex AI dataset display name.
    name: str
    # Columns selected from the historical view; also the source of the
    # attribute columns once the time and target columns are removed.
    experiment_columns: list[str]
    # Columns for an optional GROUP BY clause; an empty list means no grouping.
    group_by_columns: list[str]
    # Passed as `optimization_objective` to the AutoML forecasting job.
    objective: str
    # Passed as `forecast_horizon` to the training run.
    forecast_horizon: int
    # Passed as `context_window` to the training run.
    context_window: int
    # Passed as `data_granularity_unit` to the training run (e.g. 'day').
    data_granularity_unit: str
    data_granularity_unit: str
    

In [66]:
# Daily-granularity experiment: no grouping, RMSE objective, a 184-step
# horizon and a 552-step context window.
daily_forecast_experiment = Experiment(
    name="panorama_daily_forecast",
    objective="minimize-rmse",
    data_granularity_unit="day",
    forecast_horizon=184,
    context_window=552,
    group_by_columns=[],
    experiment_columns=[
        "Appointment_Day",
        "District", "Region_Type",
        "Product", "Product_Grp",
        "Technology",
        "Work_Order_Action", "Work_Order_Action_Grp",
        "Work_Force",
        "SWT",
    ],
)

## Select experiment

In [67]:
# Every later cell reads its settings from `running_experiment`; rebind it here
# to switch the notebook to a different Experiment.
running_experiment = daily_forecast_experiment

## Create train data view

In [68]:
def create_series_identifier(columns):
    """Build the SQL expression for the `Series_Identifier` column.

    Each column is wrapped in ``COALESCE(<column>, 'None')`` so a NULL value
    does not turn the whole concatenation into NULL, and the parts are joined
    with a single-space literal inside ``CONCAT(...)``.

    Args:
        columns: Iterable of column names (used verbatim in the SQL text).

    Returns:
        A string of the form
        ``CONCAT(COALESCE(a, 'None'), ' ', COALESCE(b, 'None')) as Series_Identifier``.

    Raises:
        ValueError: If ``columns`` is empty, because ``CONCAT()`` with no
            arguments is not valid SQL.
    """
    columns = list(columns)
    if not columns:
        raise ValueError(
            "create_series_identifier requires at least one column"
        )

    coalesce_parts = [f"COALESCE({column}, 'None')" for column in columns]
    # A ' ' SQL literal sits between consecutive parts as the separator.
    concat_arguments = ", ' ', ".join(coalesce_parts)
    return f"CONCAT({concat_arguments}) as Series_Identifier"

In [69]:
# Names of the special columns in the training view.
time_column = "Appointment_Day"
target_column = "SWT"
time_series_identifier_column = "Series_Identifier"

# Comma-separated SELECT list built from the experiment's columns.
query_columns = ",".join(running_experiment.experiment_columns)

# Optional GROUP BY clause; stays empty when the experiment defines no
# grouping columns.
groupy_by = (
    "GROUP BY " + ",".join(running_experiment.group_by_columns)
    if running_experiment.group_by_columns
    else ""
)

# Attribute columns are the experiment columns minus the time and target
# columns. Work on a copy so the experiment's own list is left untouched.
ATTRIBUTE_COLUMNS = list(running_experiment.experiment_columns)
for _reserved_column in (time_column, target_column):
    ATTRIBUTE_COLUMNS.remove(_reserved_column)

# Column transformations for the training job: timestamp for the time column,
# numeric for the target and categorical for every attribute column.
COLUMN_SPECS = {
    time_column: "timestamp",
    target_column: "numeric",
    **{attribute: "categorical" for attribute in ATTRIBUTE_COLUMNS},
}

In [70]:
# CTE selecting the experiment columns from the historical view, limited to
# rows strictly before the train/test split date. `groupy_by` is either empty
# or a complete GROUP BY clause.
experiment_data_cte = f"""
WITH historical_table AS (
  SELECT 
    {query_columns}
  FROM `b2b_wf_prediction.vw_wf_historical`
  WHERE Appointment_Day < {TRAIN_TEST_DATA_SPLIT}
  {groupy_by}
)"""


# The view is created under BQ_DATASET_NAME so it is always the same object
# that TRAINING_DATASET_BQ_PATH (built from BQ_DATASET_NAME) points at.
experiment_train_data_query = f"""
CREATE OR REPLACE VIEW `{BQ_DATASET_NAME}.{BQ_TRAIN_TABLE}` AS 
{experiment_data_cte}
SELECT 
  {create_series_identifier(ATTRIBUTE_COLUMNS)},
  {query_columns}
FROM historical_table
"""

# Append the experiment name only once, so re-running this cell does not keep
# growing the display name (…_exp_exp_exp). Switching to a different
# experiment still requires re-running the configuration cell first.
_experiment_suffix = f"_{running_experiment.name}"
if not VERTEX_DATASET_NAME.endswith(_experiment_suffix):
    VERTEX_DATASET_NAME += _experiment_suffix

In [72]:
# Run the CREATE OR REPLACE VIEW statement built above and wait for it to
# complete before the dataset cells that read the view.
client.query_and_wait(experiment_train_data_query)

<google.cloud.bigquery.table.RowIterator at 0x7fffa7d853f0>

In [74]:
# Look up an existing time-series dataset by display name. The value is
# double-quoted so a name containing spaces or other special characters is
# still matched as a single literal.
dataset_list = aiplatform.TimeSeriesDataset.list(
    filter=f'display_name="{VERTEX_DATASET_NAME}"'
)

if dataset_list:
    # Reuse the first match instead of creating a duplicate dataset.
    print("... using existent dataset ... ")
    dataset = dataset_list[0]
else:
    print("... creating new dataset ... ")
    dataset = aiplatform.TimeSeriesDataset.create(
        display_name=VERTEX_DATASET_NAME,
        bq_source=[TRAINING_DATASET_BQ_PATH],
    )

... using existent dataset ... 


In [73]:
# Look up an existing model by display name. The value is double-quoted so a
# name containing spaces or other special characters is still matched as a
# single literal.
model_list = aiplatform.Model.list(
    filter=f'display_name="{VERTEX_MODEL_NAME}"'
)

if model_list:
    # An existing model becomes the parent, so the training run uploads its
    # result as a new version of that model.
    print("... using existent model ... ")
    model = model_list[0]
    print(model)
    parent_model = model.resource_name
else:
    print("... training a new model ... ")
    parent_model = None

... training a new model ... 


In [12]:
# Define (but do not start) the AutoML forecasting job for the running
# experiment.
training_job = aiplatform.AutoMLForecastingTrainingJob(
    column_specs=COLUMN_SPECS,
    optimization_objective=running_experiment.objective,
    display_name=VERTEX_MODEL_NAME,
)

In [None]:
# Launch the training run; keyword arguments are grouped by purpose.
model = training_job.run(
    # --- data and column roles ---
    dataset=dataset,
    time_column=time_column,
    target_column=target_column,
    time_series_identifier_column=time_series_identifier_column,
    time_series_attribute_columns=ATTRIBUTE_COLUMNS,
    available_at_forecast_columns=[time_column],
    unavailable_at_forecast_columns=[target_column],
    weight_column=None,
    predefined_split_column_name=None,
    # --- forecasting window and granularity ---
    data_granularity_unit=running_experiment.data_granularity_unit,
    data_granularity_count=1,
    context_window=running_experiment.context_window,
    forecast_horizon=running_experiment.forecast_horizon,
    # --- training budget ---
    budget_milli_node_hours=1000,
    # --- model registration ---
    model_display_name=VERTEX_MODEL_NAME,
    parent_model=parent_model,
    is_default_version=True,
    model_version_description=f"Model generated on {datetime.date.today().isoformat()}",
)

In [None]:
# TODO: Create the temp table for the predictions

In [None]:
# Run a blocking (sync=True) batch prediction that reads its instances from
# BigQuery and writes predictions with explanations back to BigQuery.
batch_prediction_job = model.batch_predict(
    job_display_name=VERTEX_PREDICTION_NAME,
    # input
    instances_format="bigquery",
    bigquery_source=PREDICTION_DATASET_BQ_PATH,
    # output
    predictions_format="bigquery",
    bigquery_destination_prefix=PREDICTION_OUTPUT_PREFIX,
    # behaviour
    generate_explanation=True,
    sync=True,
)

In [None]:
# TODO: Clean up the temp predictions table

In [None]:


# Output table reported by the finished batch prediction job. It is read here
# but not yet used anywhere in this cell.
batch_table  = batch_prediction_job.output_info.bigquery_output_table

# TODO: manually fill the fields based on the fields on the dataset table
# NOTE(review): the SQL below is still the literal placeholder text "QUERY",
# so this cell cannot succeed until the real statement is written.
query_job = client.query(
    f"""
        QUERY
    """,
)

# TODO: delete the batch table

# Wait for the query to finish, for at most one hour (3600 seconds).
query_job.result(timeout=3600)

In [None]:
# TODO: calculate the metrics

In [None]:
# TODO: save the metrics in bq_wf_evaluation