# Tool Wear Detection: Batch Prediction

In this notebook, you will use the tool wear classification model trained in [1_tool_wear_train.ipynb](./1_tool_wear_train.ipynb) to run batch predictions on CNC milling machine telemetry stored in Manufacturing Data Engine (MDE).

## Before you begin

### Set your project ID

If you don't know your project ID, the cell below tries to read it from your gcloud configuration.

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

In [None]:
if PROJECT_ID == "" or PROJECT_ID is None or PROJECT_ID == "[your-project-id]":
    # Get your GCP project id from gcloud
    shell_output = ! gcloud config list --format 'value(core.project)' 2>/dev/null
    PROJECT_ID = shell_output[0]
    print("Project ID:", PROJECT_ID)

## Import libraries

In [None]:
from datetime import datetime

from google.cloud import aiplatform as vertex_ai
from google.cloud import bigquery
from google.cloud.aiplatform import Model
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

## Initialize Vertex AI and BigQuery clients

In [None]:
vertex_ai.init(project=PROJECT_ID)

bq_client = bigquery.Client(project=PROJECT_ID)

## Explore MDE BigQuery tables

In an MDE deployment with the default configuration, MDE creates a BigQuery dataset called `sfp_data` that contains the six tables described below:

| Table Name           | Partition By        | Description                                                                      |
|----------------------|---------------------|----------------------------------------------------------------------------------|
| ComponentDataSeries  | eventTimestamp      | N/A; not implemented                                                             |
| ContinuousDataSeries | eventTimestampStart | Time series of consecutive states defined by event data, start time and end time |
| DiscreteDataSeries   | eventTimestamp      | Time series of event data associated with a single timestamp                     |
| NumericDataSeries    | eventTimestamp      | Time series of numerical event data associated with a single timestamp           |
| InsertErrors         | insertTimestamp     | Dead letter queue of messages that fail to be ingested into BigQuery             |
| OperationsDashboard  | eventTimestamp      | Operational logs of MDE data pipelines                                           |
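The partition column matters when you query these tables: filtering on it (for example, `DATE(eventTimestamp) = CURRENT_DATE()` in the queries below) typically lets BigQuery scan only the matching partitions instead of the whole table. The following sketch builds such a filtered query for any of the six tables. The helper `recent_rows_query` and its parameters are hypothetical; it only formats a SQL string and does not call BigQuery.

```python
# Partition column for each MDE table, taken from the table above.
PARTITION_COLUMNS = {
    "ComponentDataSeries": "eventTimestamp",
    "ContinuousDataSeries": "eventTimestampStart",
    "DiscreteDataSeries": "eventTimestamp",
    "NumericDataSeries": "eventTimestamp",
    "InsertErrors": "insertTimestamp",
    "OperationsDashboard": "eventTimestamp",
}


def recent_rows_query(project_id: str, table: str, days: int = 1, limit: int = 10) -> str:
    """Hypothetical helper: build a query that filters on the table's partition column.

    days=1 means "today only", days=2 means "today and yesterday", and so on.
    """
    if table not in PARTITION_COLUMNS:
        raise ValueError(f"Unknown MDE table: {table}")
    if days < 1:
        raise ValueError("days must be at least 1")
    column = PARTITION_COLUMNS[table]
    return (
        f"SELECT * FROM `{project_id}.sfp_data.{table}`\n"
        f"WHERE DATE({column}) >= DATE_SUB(CURRENT_DATE(), INTERVAL {days - 1} DAY)\n"
        f"ORDER BY {column} DESC\n"
        f"LIMIT {limit};"
    )
```

You could pass the returned string to `bq_client.query(...)` just like the queries in this notebook.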

To classify the condition of the CNC milling machine tool, you will use only the `DiscreteDataSeries` and `NumericDataSeries` tables, because the dataset contains both non-numerical and numerical telemetry.

In this section, you will use the [BigQuery Python client](https://cloud.google.com/bigquery/docs/reference/libraries#client-libraries-install-python) to explore the basics of MDE BigQuery tables. 

In [None]:
# Explore the BigQuery tables in MDE
query = f"""
SELECT
    *
FROM
    `{PROJECT_ID}.sfp_data.INFORMATION_SCHEMA.TABLES`;
"""

results = bq_client.query(query)

results.to_dataframe()

In [None]:
# Query for 3 numeric messages from today
query = f"""
SELECT
    *
FROM
    `{PROJECT_ID}.sfp_data.NumericDataSeries`
WHERE
  DATE(eventTimestamp) = CURRENT_DATE()
ORDER BY
  eventTimestamp DESC
LIMIT 3;
"""

results = bq_client.query(query)

results.to_dataframe()

In [None]:
# Query for 3 discrete messages from today
query = f"""
SELECT
    *
FROM
    `{PROJECT_ID}.sfp_data.DiscreteDataSeries`
WHERE
  DATE(eventTimestamp) = CURRENT_DATE()
ORDER BY
  eventTimestamp DESC
LIMIT 3;
"""

results = bq_client.query(query)

results.to_dataframe()

> Note: `DiscreteDataSeries` and `NumericDataSeries` contain many similar fields (such as `tagName`, `eventTimestamp`, and `payload`). This is by design, since the data series share many properties. Some of these fields are described below:

| Field | Description |
|-------|-------------|
| tagName  | Unique label for a single time series in the Cloud |
| edgeTagName | Time series label used by edge appliance |
| eventTimestamp | Timestamp when event is captured |
| payload | Message body |
| payloadQualifier | Dynamic descriptors for time series (e.g., unit of measure) |
| metadata | Slowly changing contextual data enriched by data pipelines |
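The `payload` field is stored as JSON, and its shape differs between the two tables. The stored procedure later in this notebook reads numeric payloads at the JSON path `$.value` and discrete payloads at `$`. The sketch below roughly mimics that extraction in plain Python, assuming a numeric payload like `{"value": 12.5}` and a discrete payload that is a bare JSON scalar such as `"unworn"`. Both sample shapes and the helper `extract_feature_value` are illustrative assumptions, not the MDE schema.

```python
import json
from typing import Optional


def extract_feature_value(payload_json: str, series_type: str) -> Optional[str]:
    """Hypothetical helper: pull a scalar out of a payload, roughly like JSON_EXTRACT_SCALAR.

    Numeric payloads are read at '$.value', discrete payloads at '$',
    matching the JSON paths used by the stored procedure.
    """
    payload = json.loads(payload_json)
    if series_type == "numeric":
        value = payload.get("value") if isinstance(payload, dict) else None
    elif series_type == "discrete":
        value = payload
    else:
        raise ValueError(f"Unknown series type: {series_type}")
    # JSON_EXTRACT_SCALAR returns NULL for missing values, objects, and arrays.
    if value is None or isinstance(value, (dict, list)):
        return None
    # JSON booleans are rendered in lowercase, as in the JSON text.
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)
```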

In [None]:
# Get all CNC milling machine numeric tags
query = f"""
SELECT
  DISTINCT tagName 
FROM 
  `{PROJECT_ID}.sfp_data.NumericDataSeries`
WHERE
  DATE(eventTimestamp) = CURRENT_DATE()
  AND STARTS_WITH(tagName, 'cncmilling_')
ORDER BY tagName;
"""

results = bq_client.query(query)

results.to_dataframe()

In [None]:
# Get all CNC milling machine discrete tags
query = f"""
SELECT
  DISTINCT tagName 
FROM 
  `{PROJECT_ID}.sfp_data.DiscreteDataSeries`
WHERE
  DATE(eventTimestamp) = CURRENT_DATE()
  AND STARTS_WITH(tagName, 'cncmilling_')
ORDER BY tagName;
"""

results = bq_client.query(query)

results.to_dataframe()

> Observe the mapping of tag names to the feature names in the [CNC milling machine dataset](https://www.kaggle.com/datasets/shasun/tool-wear-detection-in-cnc-mill). For example, `cncmilling_m1_current_feedrate` is mapped to `M1_CURRENT_FEEDRATE` and `cncmilling_z1_outputvoltage` is mapped to `Z1_OutputVoltage`.
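The convention can be expressed as a one-line function: prefix the feature name with `cncmilling_` and lowercase it. This sketch (the helper name `feature_to_tag` is hypothetical) reproduces both examples above.

```python
TAG_PREFIX = "cncmilling_"


def feature_to_tag(feature_name: str) -> str:
    """Hypothetical helper: map a dataset feature name to its MDE tag name."""
    return TAG_PREFIX + feature_name.lower()


print(feature_to_tag("M1_CURRENT_FEEDRATE"))  # cncmilling_m1_current_feedrate
print(feature_to_tag("Z1_OutputVoltage"))     # cncmilling_z1_outputvoltage
```

Lowercasing discards the original capitalization, so a feature name cannot be recovered from a tag name alone. That is one reason to keep an explicit tag-to-feature mapping table.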

## Create prediction features

From exploring the MDE BigQuery tables, you can make the following observations:

1. Each prediction feature is stored as a separate time series tag, with one row per event.
1. Each tag can be stored in either the `DiscreteDataSeries` or the `NumericDataSeries` table.
1. The tag names are different from the feature names.
1. The tag names follow a naming convention and can be mapped to feature names.
1. Some features are stored as tag metadata (e.g., `material`).

To use the MDE data for tool wear classification, you will pivot the required tags in the `DiscreteDataSeries` and `NumericDataSeries` tables and extract additional features from the tag metadata. You will then assemble the pivoted tables and extracted metadata into a BigQuery view for batch prediction.
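To see what the pivot does, here is a minimal in-memory sketch of the same long-to-wide transformation, assuming rows of `(eventTimestamp, tagName, value)` and a tag-to-feature dictionary. The function `pivot_tags` is hypothetical; it is not what BigQuery runs, only an illustration of the reshaping.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

Row = Tuple[str, str, str]  # (eventTimestamp, tagName, value)


def pivot_tags(rows: Iterable[Row], tag_to_feature: Dict[str, str]) -> List[Dict[str, str]]:
    """Turn long (timestamp, tag, value) rows into one wide row per timestamp.

    Tags missing from the mapping are dropped (like an INNER JOIN on the mapping
    table). If a tag appears twice at the same timestamp, the first value wins,
    loosely mirroring ANY_VALUE, which makes no guarantee about which value it picks.
    """
    wide: Dict[str, Dict[str, str]] = defaultdict(dict)
    for timestamp, tag, value in rows:
        feature = tag_to_feature.get(tag)
        if feature is None:
            continue
        wide[timestamp].setdefault(feature, value)
    # One output row per timestamp, ordered by timestamp.
    return [{"eventTimestamp": ts, **features} for ts, features in sorted(wide.items())]
```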

In this section, you will:

1. Create a BigQuery dataset for storing ML-related artifacts.
1. Create a BigQuery table to store mapping between the tag names and feature names.
1. Create a BigQuery stored procedure that dynamically generates a BigQuery view containing the features that have not yet been predicted.
1. Call the BigQuery stored procedure to generate the prediction features view.

In [None]:
query = f"""
CREATE SCHEMA IF NOT EXISTS `{PROJECT_ID}.ml`
  OPTIONS (
    description = 'Dataset for storing machine learning artifacts',
    location = 'us'
  );
"""

results = bq_client.query(query)

results.to_dataframe()

In [None]:
query = f"""
CREATE OR REPLACE TABLE `{PROJECT_ID}.ml.tool_wear_tags` (
  tagName STRING,
  featureName STRING
) OPTIONS (
    description = 'Mapping between tags and features for tool wear classification'
);
"""

results = bq_client.query(query)

results.to_dataframe()

Next, populate the `tool_wear_tags` table with the tag-to-feature mapping.

In [None]:
query = f"""
INSERT `{PROJECT_ID}.ml.tool_wear_tags` (tagName, featureName)
VALUES 
('cncmilling_clamp_pressure', 'clamp_pressure')
, ('cncmilling_feedrate', 'feedrate')
, ('cncmilling_m1_current_feedrate', 'M1_CURRENT_FEEDRATE')
, ('cncmilling_m1_current_program_number', 'M1_CURRENT_PROGRAM_NUMBER')
, ('cncmilling_m1_sequence_number', 'M1_sequence_number')
, ('cncmilling_machining_process', 'Machining_Process')
, ('cncmilling_s1_accelerationdiff', 'S1_AccelerationDiff')
, ('cncmilling_s1_actualacceleration', 'S1_ActualAcceleration')
, ('cncmilling_s1_actualposition', 'S1_ActualPosition')
, ('cncmilling_s1_actualvelocity', 'S1_ActualVelocity')
, ('cncmilling_s1_commandacceleration', 'S1_CommandAcceleration')
, ('cncmilling_s1_commandposition', 'S1_CommandPosition')
, ('cncmilling_s1_commandvelocity', 'S1_CommandVelocity')
, ('cncmilling_s1_currentfeedback', 'S1_CurrentFeedback')
, ('cncmilling_s1_dcbusvoltage', 'S1_DCBusVoltage')
, ('cncmilling_s1_outputcurrent', 'S1_OutputCurrent')
, ('cncmilling_s1_outputpower', 'S1_OutputPower')
, ('cncmilling_s1_outputvoltage', 'S1_OutputVoltage')
, ('cncmilling_s1_positiondiff', 'S1_PositionDiff')
, ('cncmilling_s1_systeminertia', 'S1_SystemInertia')
, ('cncmilling_s1_velocitydiff', 'S1_VelocityDiff')
, ('cncmilling_x1_accelerationdiff', 'X1_AccelerationDiff')
, ('cncmilling_x1_actualacceleration', 'X1_ActualAcceleration')
, ('cncmilling_x1_actualposition', 'X1_ActualPosition')
, ('cncmilling_x1_actualvelocity', 'X1_ActualVelocity')
, ('cncmilling_x1_commandacceleration', 'X1_CommandAcceleration')
, ('cncmilling_x1_commandposition', 'X1_CommandPosition')
, ('cncmilling_x1_commandvelocity', 'X1_CommandVelocity')
, ('cncmilling_x1_currentfeedback', 'X1_CurrentFeedback')
, ('cncmilling_x1_dcbusvoltage', 'X1_DCBusVoltage')
, ('cncmilling_x1_outputcurrent', 'X1_OutputCurrent')
, ('cncmilling_x1_outputpower', 'X1_OutputPower')
, ('cncmilling_x1_outputvoltage', 'X1_OutputVoltage')
, ('cncmilling_x1_positiondiff', 'X1_PositionDiff')
, ('cncmilling_x1_velocitydiff', 'X1_VelocityDiff')
, ('cncmilling_y1_accelerationdiff', 'Y1_AccelerationDiff')
, ('cncmilling_y1_actualacceleration', 'Y1_ActualAcceleration')
, ('cncmilling_y1_actualposition', 'Y1_ActualPosition')
, ('cncmilling_y1_actualvelocity', 'Y1_ActualVelocity')
, ('cncmilling_y1_commandacceleration', 'Y1_CommandAcceleration')
, ('cncmilling_y1_commandposition', 'Y1_CommandPosition')
, ('cncmilling_y1_commandvelocity', 'Y1_CommandVelocity')
, ('cncmilling_y1_currentfeedback', 'Y1_CurrentFeedback')
, ('cncmilling_y1_dcbusvoltage', 'Y1_DCBusVoltage')
, ('cncmilling_y1_outputcurrent', 'Y1_OutputCurrent')
, ('cncmilling_y1_outputpower', 'Y1_OutputPower')
, ('cncmilling_y1_outputvoltage', 'Y1_OutputVoltage')
, ('cncmilling_y1_positiondiff', 'Y1_PositionDiff')
, ('cncmilling_y1_velocitydiff', 'Y1_VelocityDiff')
, ('cncmilling_z1_accelerationdiff', 'Z1_AccelerationDiff')
, ('cncmilling_z1_actualacceleration', 'Z1_ActualAcceleration')
, ('cncmilling_z1_actualposition', 'Z1_ActualPosition')
, ('cncmilling_z1_actualvelocity', 'Z1_ActualVelocity')
, ('cncmilling_z1_commandacceleration', 'Z1_CommandAcceleration')
, ('cncmilling_z1_commandposition', 'Z1_CommandPosition')
, ('cncmilling_z1_commandvelocity', 'Z1_CommandVelocity')
, ('cncmilling_z1_currentfeedback', 'Z1_CurrentFeedback')
, ('cncmilling_z1_dcbusvoltage', 'Z1_DCBusVoltage')
, ('cncmilling_z1_outputcurrent', 'Z1_OutputCurrent')
, ('cncmilling_z1_outputvoltage', 'Z1_OutputVoltage')
, ('cncmilling_z1_positiondiff', 'Z1_PositionDiff')
, ('cncmilling_z1_velocitydiff', 'Z1_VelocityDiff')
, ('cncmilling_tool_condition', 'tool_condition')
, ('cncmilling_material', 'material');
"""

results = bq_client.query(query)

results.to_dataframe()
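Because every feature in the list above follows the naming convention, the `VALUES` list could also be generated rather than typed by hand. The sketch below, with a hypothetical `build_insert_statement` helper, only builds the SQL string; you would still run it with `bq_client.query(...)`. It assumes feature names contain no single quotes and rejects any that do.

```python
from typing import Iterable

TAG_PREFIX = "cncmilling_"


def build_insert_statement(project_id: str, feature_names: Iterable[str]) -> str:
    """Hypothetical helper: generate the INSERT for ml.tool_wear_tags from feature names."""
    rows = []
    for feature in feature_names:
        # Guard against breaking out of the SQL string literal.
        if "'" in feature:
            raise ValueError(f"Unexpected quote in feature name: {feature}")
        tag = TAG_PREFIX + feature.lower()
        rows.append(f"('{tag}', '{feature}')")
    if not rows:
        raise ValueError("At least one feature name is required")
    values = "\n, ".join(rows)
    return (
        f"INSERT `{project_id}.ml.tool_wear_tags` (tagName, featureName)\n"
        f"VALUES\n{values};"
    )


print(build_insert_statement("my-project", ["feedrate", "M1_CURRENT_FEEDRATE"]))
```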

In [None]:
query = f"""
SELECT
  *
FROM
  `{PROJECT_ID}.ml.tool_wear_tags`
ORDER BY
  tagName 
LIMIT 5;
"""

results = bq_client.query(query)

results.to_dataframe()

In [None]:
query = f'''
CREATE OR REPLACE PROCEDURE `{PROJECT_ID}.ml.create_prediction_features_view`(view_name STRING)
BEGIN
    DECLARE select_expr STRING;
    DECLARE pivot_expr STRING;
    DECLARE query STRING;
    DECLARE last_prediction_timestamp DEFAULT (
      SELECT
        IFNULL(
          MAX(eventTimestamp),
          TIMESTAMP_SECONDS(0)
        ) AS last_prediction_timestamp
      FROM
        `sfp_data.DiscreteDataSeries`
      WHERE
        DATE(eventTimestamp) >= DATE_FROM_UNIX_DATE(0)
        AND tagName = 'tool_wear_predictions'
    );
    
    -- Dynamically create select expressions
    SET select_expr = (
      WITH SelectExpr AS (
        SELECT 
          CASE
            WHEN EXISTS (SELECT DISTINCT tagName FROM `sfp_data.DiscreteDataSeries` d WHERE d.eventTimestamp > last_prediction_timestamp AND d.tagName = t.tagName)
              THEN FORMAT(', JSON_EXTRACT_SCALAR(payload_%s, "$") AS %s', t.tagName, t.featureName)
            WHEN EXISTS (SELECT DISTINCT tagName FROM `sfp_data.NumericDataSeries` n WHERE n.eventTimestamp > last_prediction_timestamp AND n.tagName = t.tagName)
              THEN FORMAT(', JSON_EXTRACT_SCALAR(payload_%s, "$.value") AS %s', t.tagName, t.featureName)
            ELSE ''
          END AS expr
        FROM
          `ml.tool_wear_tags` t
      )
      SELECT 
        STRING_AGG(SelectExpr.expr, ' ')
      FROM 
        SelectExpr
    );
    
    -- Dynamically create pivot expression
    SET pivot_expr = (
      SELECT
        STRING_AGG(DISTINCT CONCAT("'", tagName, "'"))
      FROM
        `ml.tool_wear_tags`
    );
    
    -- Format query using dynamic select and pivot expressions
    SET query = (
      SELECT FORMAT(r"""
        CREATE OR REPLACE VIEW %s AS
        WITH metadata AS (
            SELECT
              d.eventTimestamp
              , CASE
                  WHEN (REGEXP_CONTAINS(mkv.schemaIdentifier, r'^(cncmill/\d+/local/cncmill/\d+)$') AND mkv.key = 'material')
                    THEN 'cncmilling_material'
                  ELSE NULL
                END AS tagName
              , TO_JSON_STRING(mkv.value) AS payload
            FROM
              `sfp_data.DiscreteDataSeries` d, d.metadataKV mkv
            WHERE
              TIMESTAMP_TRUNC(d.eventTimestamp, SECOND) > "%t"
              AND d.tagName = 'cncmilling_tool_condition'
              AND ARRAY_LENGTH(metadataKV) != 0
        )
        SELECT 
          eventTimestamp
          %s
          , JSON_EXTRACT_SCALAR(payload_cncmilling_material, "$") AS material
        FROM 
        (
          SELECT
            TIMESTAMP_TRUNC(n.eventTimestamp, SECOND) AS eventTimestamp
            , n.tagName
            , TO_JSON_STRING(n.payload) AS payload
          FROM
            `sfp_data.NumericDataSeries` n
          INNER JOIN
            `ml.tool_wear_tags` ntag
          ON
            n.tagName = ntag.tagName
          WHERE
            TIMESTAMP_TRUNC(n.eventTimestamp, SECOND) > "%t"
          
          UNION ALL
          
          SELECT
            TIMESTAMP_TRUNC(d.eventTimestamp, SECOND) AS eventTimestamp
            , d.tagName
            , TO_JSON_STRING(d.payload) AS payload
          FROM
            `sfp_data.DiscreteDataSeries` d
          INNER JOIN
            `ml.tool_wear_tags` dtag
          ON
            d.tagName = dtag.tagName
          WHERE
            TIMESTAMP_TRUNC(d.eventTimestamp, SECOND) > "%t"
          
          UNION ALL
          
          SELECT
            TIMESTAMP_TRUNC(metadata.eventTimestamp, SECOND) AS eventTimestamp
            , metadata.tagName
            , TO_JSON_STRING(metadata.payload) AS payload
          FROM
            metadata
          INNER JOIN
            `ml.tool_wear_tags` mtag
          ON
            metadata.tagName = mtag.tagName
        )
        PIVOT
        (
          ANY_VALUE(payload) as payload
          FOR tagName
          IN (
            %s
          )
        )
    """,
        view_name,
        last_prediction_timestamp, 
        select_expr, 
        last_prediction_timestamp, 
        last_prediction_timestamp, 
        pivot_expr)
    );

    -- Execute query
    EXECUTE IMMEDIATE query;
END
'''

results = bq_client.query(query)

results.to_dataframe()
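The stored procedure builds its SQL dynamically. For each row in `ml.tool_wear_tags`, the `CASE` expression emits a `JSON_EXTRACT_SCALAR` column with path `$` if the tag has new rows in `DiscreteDataSeries`, path `$.value` if it has new rows in `NumericDataSeries` (the discrete check is evaluated first), and nothing otherwise. The pivot list is the comma-separated set of quoted tag names. The Python sketch below reproduces that string-building logic, assuming the sets of recently seen discrete and numeric tags are already known. The function names are hypothetical, empty entries are skipped rather than joined as spaces, and tag names are sorted because `STRING_AGG(DISTINCT ...)` does not guarantee an order.

```python
from typing import Dict, Set


def build_select_expr(
    tag_to_feature: Dict[str, str], discrete_tags: Set[str], numeric_tags: Set[str]
) -> str:
    """Hypothetical mirror of the procedure's select_expr construction."""
    parts = []
    for tag, feature in tag_to_feature.items():
        if tag in discrete_tags:  # checked first, like the first WHEN branch
            parts.append(f', JSON_EXTRACT_SCALAR(payload_{tag}, "$") AS {feature}')
        elif tag in numeric_tags:
            parts.append(f', JSON_EXTRACT_SCALAR(payload_{tag}, "$.value") AS {feature}')
        # Tags with no new rows contribute nothing (the ELSE '' branch).
    return " ".join(parts)


def build_pivot_expr(tag_to_feature: Dict[str, str]) -> str:
    """Hypothetical mirror of pivot_expr: quoted, comma-separated distinct tag names."""
    return ",".join(f"'{tag}'" for tag in sorted(set(tag_to_feature)))
```

`cncmilling_material` produces no select expression here because it comes from tag metadata; the view adds it separately as the `material` column.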

In [None]:
current_time_string = datetime.now().strftime("%Y%m%d_%H%M%S")
view_name = f"ml.prediction_features_{current_time_string}"
bq_params = {"view_name": view_name}

view_name

In [None]:
query = f"""CALL `{PROJECT_ID}.ml.create_prediction_features_view`('{view_name}');"""

results = bq_client.query(query)

results.to_dataframe()

To inspect the SQL that the stored procedure generated, query the `ml.INFORMATION_SCHEMA.VIEWS` view.

In [None]:
query = f"""
SELECT
  * 
FROM 
  `{PROJECT_ID}.ml.INFORMATION_SCHEMA.VIEWS`
WHERE
  table_name = '{view_name.split('.')[1]}';
"""

results = bq_client.query(query)
df = results.to_dataframe()

print(df["view_definition"].values[0])

In [None]:
query = f"""
SELECT *
FROM `{PROJECT_ID}.{view_name}`
ORDER BY eventTimestamp DESC
LIMIT 5;
"""

results = bq_client.query(query)

results.to_dataframe()

> Note: the prediction features are now in the format the model expects: each feature is a column, and each row holds the features recorded at one timestamp (truncated to the second).

## Trigger Batch Prediction

After the prediction features are curated, you will trigger a [batch prediction](https://cloud.google.com/vertex-ai/docs/tabular-data/classification-regression/get-batch-predictions) using the previously trained tool wear classification AutoML model.

In this section, you will create a batch prediction job that reads instances from the BigQuery view and writes the prediction results to the MDE Cloud Storage bucket for batch ingestion.

### Initialize variables

In [None]:
LOCATION = "us-central1"
API_ENDPOINT = f"{LOCATION}-aiplatform.googleapis.com"
CLIENT_OPTIONS = {"api_endpoint": API_ENDPOINT}
PARENT = f"projects/{PROJECT_ID}/locations/{LOCATION}"

BQ_ML_DATASET = "ml"
NAME_PREFIX = "tool_wear"
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

### Get the most recently updated Vertex AI model

In [None]:
MODEL_NAME = "[your-model-name]"  # Format: 'projects/{project}/locations/{location}/models/{model_id}'

if MODEL_NAME == "" or MODEL_NAME is None or MODEL_NAME == "[your-model-name]":
    # Get latest updated model with prefix
    models = Model.list(
        order_by="update_time desc",
    )

    MODEL_NAME = None
    for m in models:
        if m.display_name.startswith(NAME_PREFIX):
            MODEL_NAME = m.resource_name
            print(f"Vertex AI model found: {m.display_name}")
            break

    if MODEL_NAME is None:
        print(
            f'Vertex AI model with prefix "{NAME_PREFIX}" not found. '
            "Please search using a new prefix or locate the model resource name manually."
        )
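Both `MODEL_NAME` and `TRAINING_DATASET_NAME` must be full resource names in the format shown in the comments, not display names. A small parser like the hypothetical `parse_resource_name` below can catch a pasted display name early and lets you compare the resource's region with `LOCATION`.

```python
import re
from typing import Dict

# Matches 'projects/{project}/locations/{location}/models|datasets/{id}'.
_RESOURCE_RE = re.compile(
    r"^projects/(?P<project>[^/]+)"
    r"/locations/(?P<location>[^/]+)"
    r"/(?P<collection>models|datasets)"
    r"/(?P<id>[^/]+)$"
)


def parse_resource_name(name: str) -> Dict[str, str]:
    """Hypothetical helper: split a model or dataset resource name into its parts."""
    match = _RESOURCE_RE.match(name)
    if match is None:
        raise ValueError(f"Not a model or dataset resource name: {name!r}")
    return match.groupdict()
```

For example, since the batch prediction job is created under `PARENT`, which uses `LOCATION`, you would expect `parse_resource_name(MODEL_NAME)["location"] == LOCATION` to hold before submitting the job.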

### Get the most recently updated Vertex AI dataset

In [None]:
TRAINING_DATASET_NAME = "[your-dataset-name]"  # Format: 'projects/{project}/locations/{location}/datasets/{dataset_id}'

if (
    TRAINING_DATASET_NAME == ""
    or TRAINING_DATASET_NAME is None
    or TRAINING_DATASET_NAME == "[your-dataset-name]"
):
    # Get latest updated dataset with prefix
    client = vertex_ai.gapic.DatasetServiceClient(client_options=CLIENT_OPTIONS)
    datasets = client.list_datasets(
        request={
            "parent": PARENT,
            "order_by": "update_time desc",
        }
    )

    TRAINING_DATASET_NAME = None
    for d in datasets:
        if d.display_name.startswith(NAME_PREFIX):
            TRAINING_DATASET_NAME = d.name
            print(f"Vertex AI dataset found: {d.display_name}")
            break

    if TRAINING_DATASET_NAME is None:
        print(
            f'Vertex AI dataset with prefix "{NAME_PREFIX}" not found. '
            "Please search using a new prefix or locate the dataset resource name manually."
        )

### [Create batch prediction job](https://cloud.google.com/vertex-ai/docs/samples/aiplatform-create-batch-prediction-job-sample)

For the full list of parameters for configuring batch prediction, see the [Vertex AI Python SDK Reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1beta1.types.BatchPredictionJob).

In [None]:
DISPLAY_NAME = f"tool_wear_{TIMESTAMP}"
model_parameters_dict = {}
MODEL_PARAMETERS = json_format.ParseDict(model_parameters_dict, Value())
INSTANCE_FORMAT = "bigquery"  # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models#Model.FIELDS.supported_input_storage_formats
BQ_SOURCE_URI = (
    f"bq://{PROJECT_ID}.{view_name}"  # Format: bq://projectId.bqDatasetId.bqTableId
)
PREDICTIONS_FORMAT = "jsonl"  # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.models#Model.FIELDS.supported_output_storage_formats
GCS_OUTPUT_URI_PREFIX = f"gs://{PROJECT_ID}-batch-ingestion/tool_wear_predictions/"  # Format: gs://bucket_name/folder_name
MACHINE_TYPE = "n1-standard-2"
MIN_NODES = 1
MAX_NODES = 3

# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = JobServiceClient(client_options=CLIENT_OPTIONS)

batch_prediction_job = {
    "display_name": DISPLAY_NAME,
    "model": MODEL_NAME,
    "model_parameters": MODEL_PARAMETERS,
    "input_config": {
        "instances_format": INSTANCE_FORMAT,
        "bigquery_source": {"input_uri": BQ_SOURCE_URI},
    },
    "output_config": {
        "predictions_format": PREDICTIONS_FORMAT,
        "gcs_destination": {"output_uri_prefix": GCS_OUTPUT_URI_PREFIX},
    },
    "dedicated_resources": {
        "machine_spec": {
            "machine_type": MACHINE_TYPE,
        },
        "starting_replica_count": MIN_NODES,
        "max_replica_count": MAX_NODES,
    },
    "generate_explanation": True,
    "explanation_spec": {
        "parameters": {"sampled_shapley_attribution": {"path_count": 3}}
    },
}


response = client.create_batch_prediction_job(
    parent=PARENT, batch_prediction_job=batch_prediction_job
)

print("response:", response.name)
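The comments in the cell above describe the expected URI formats. The hypothetical `validate_batch_job_config` below checks a dictionary shaped like `batch_prediction_job` against those formats and the replica counts before submission. It is only a local sanity check under those assumptions and does not replicate the service's own validation.

```python
from typing import Any, Dict, List


def validate_batch_job_config(job: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return a list of problems found in a job dict (empty if none)."""
    problems = []

    # Expected format: bq://projectId.bqDatasetId.bqTableId
    source = job["input_config"]["bigquery_source"]["input_uri"]
    if not source.startswith("bq://"):
        problems.append("BigQuery source must start with bq://")
    else:
        # rsplit keeps project IDs that themselves contain dots intact.
        parts = source[len("bq://"):].rsplit(".", 2)
        if len(parts) != 3 or not all(parts):
            problems.append("BigQuery source must look like bq://project.dataset.table")

    # Expected format: gs://bucket_name/folder_name
    dest = job["output_config"]["gcs_destination"]["output_uri_prefix"]
    if not dest.startswith("gs://"):
        problems.append("Output prefix must start with gs://")

    resources = job["dedicated_resources"]
    start = resources["starting_replica_count"]
    maximum = resources["max_replica_count"]
    if not 1 <= start <= maximum:
        problems.append("Need 1 <= starting_replica_count <= max_replica_count")

    return problems
```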

## Explore prediction result

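Wait for the batch prediction job to complete before you query the prediction results. The results are written to the MDE batch ingestion bucket, so they appear in the `DiscreteDataSeries` table under the `tool_wear_predictions` tag only after MDE has ingested them.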
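One way to wait is to poll the job state. The sketch below takes a `get_state` callable so that it can be tested without Google Cloud. It treats the listed `JOB_STATE_*` names as terminal, which is an assumption based on the Vertex AI `JobState` enum, and `wait_for_job` is a hypothetical helper.

```python
import time
from typing import Callable

# States after which the job will not change any more (assumed from the JobState enum).
TERMINAL_STATES = {
    "JOB_STATE_SUCCEEDED",
    "JOB_STATE_FAILED",
    "JOB_STATE_CANCELLED",
    "JOB_STATE_EXPIRED",
    "JOB_STATE_PARTIALLY_SUCCEEDED",
}


def wait_for_job(
    get_state: Callable[[], str],
    poll_seconds: float = 60.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Hypothetical helper: poll get_state() until a terminal state or the timeout."""
    waited = 0.0
    while True:
        state = get_state()
        if state in TERMINAL_STATES:
            return state
        if waited >= timeout_seconds:
            raise TimeoutError(f"Job still in {state} after {waited:.0f} s")
        sleep(poll_seconds)
        waited += poll_seconds
```

In this notebook you could call it as `wait_for_job(lambda: client.get_batch_prediction_job(name=response.name).state.name)`, using the `JobServiceClient` created above. After the job succeeds, allow time for MDE to ingest the output files before running the query below.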

In [None]:
query = f"""
SELECT
  *
FROM
  `{PROJECT_ID}.sfp_data.DiscreteDataSeries`
WHERE
  DATE(eventTimestamp) > DATE_SUB(CURRENT_DATE(),INTERVAL 1 DAY)
  AND tagName = 'tool_wear_predictions'
ORDER BY
  eventTimestamp DESC
LIMIT 5;
"""

results = bq_client.query(query)

results.to_dataframe()