# MLOps with Spark MLlib & Vertex AI Pipelines
In this notebook, we create a Vertex AI pipeline for MLOps with Spark MLlib, powered by Dataproc Serverless Spark.

### One-time setup

In [50]:
# NOTE: The one-time install / kernel-restart commands below are wrapped in a
# triple-quoted string, so running this cell executes none of them; the cell
# only evaluates the string and echoes it as its output.
# NOTE(review): the quoted code calls os.getenv, but no `import os` precedes this
# cell in the notebook -- add one if the quotes are ever removed.
"""
!pip3 install --user --upgrade google-cloud-aiplatform==1.11.0 kfp==1.8.11 google-cloud-pipeline-components==1.0.1 --quiet --no-warn-conflicts

# Automatically restart kernel after installs

if not os.getenv("IS_TESTING"):
    # Automatically restart kernel after installs
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)
"""

'\n!pip3 install --user --upgrade google-cloud-aiplatform==1.11.0 kfp==1.8.11 google-cloud-pipeline-components==1.0.1 --quiet --no-warn-conflicts\n\n# Automatically restart kernel after installs\n\nif not os.getenv("IS_TESTING"):\n    # Automatically restart kernel after installs\n    import IPython\n\n    app = IPython.Application.instance()\n    app.kernel.do_shutdown(True)\n'

### 1. Setup

In [51]:
import random
from pathlib import Path as path
from typing import NamedTuple
import os

from google.cloud import aiplatform as vertex_ai
from google_cloud_pipeline_components import aiplatform as vertex_ai_components
from kfp.v2 import compiler, dsl
from kfp.v2.dsl import (Artifact, ClassificationMetrics, Condition, Input,
                        Metrics, Output, component)

#### a. Project specifics

In [52]:
import os

# Placeholders; both are overwritten from gcloud below unless IS_TESTING is set.
PROJECT_ID = ""
PROJECT_NBR = ""
# Random integer in [1, 10000]. No seed is set, so the value changes each time this
# cell runs; it is used as a suffix for the pipeline ID, batch IDs and dataset name.
UNIQUE_ID = random.randint(1, 10000)
# Selects which pipeline function is compiled later (the variant that disables
# caching on every step when True).
WITHOUT_TASK_CACHING = True
# When True, the pipeline job is submitted with an explicit VPC network.
BYO_NETWORK = True

# Get your Google Cloud project ID from gcloud
if not os.getenv("IS_TESTING"):
    # IPython `!` capture: the command's stdout lines come back as a list and the
    # first line is taken as the value. stderr is discarded for this call.
    project_id_output = !gcloud config list --format 'value(core.project)' 2>/dev/null
    PROJECT_ID = project_id_output[0]
    print("Project ID: ", PROJECT_ID)
    
    
    # Look up the numeric project number for the project ID found above.
    project_nbr_output = !gcloud projects describe $PROJECT_ID --format='value(projectNumber)'
    PROJECT_NBR = project_nbr_output[0]
    print("Project Number: ", PROJECT_NBR)
    
# The active gcloud account; this lookup runs regardless of IS_TESTING.
# NOTE(review): "UMSA" presumably stands for user-managed service account -- confirm.
umsa_output = !gcloud config list account --format "value(core.account)"
UMSA_FQN = umsa_output[0]
print("UMSA FQN: ", UMSA_FQN)
print("UNIQUE ID: ", UNIQUE_ID)

# Make the detected project the active gcloud project.
# NOTE(review): when IS_TESTING is set, PROJECT_ID is still "" at this point.
!gcloud config set project $PROJECT_ID

Project ID:  spark-s8s-mlops
Project Number:  505815944775
UMSA FQN:  s8s-lab-sa@spark-s8s-mlops.iam.gserviceaccount.com
UNIQUE ID:  9922
Updated property [core/project].


#### b. Local resources

In [53]:
# Base name shared by the scratch directory, pipeline name, batch IDs and GCS paths below.
APP_BASE_NM = "customer-churn-model"

In [54]:
# Local directory where the compiled pipeline JSON is written. `path` is the notebook's
# alias for pathlib.Path, which drops the trailing slash from the literal.
# NOTE(review): the /home/jupyter prefix is hard-coded for this notebook environment.
LOCAL_SCRATCH_DIR = path(f"/home/jupyter/scratch/{APP_BASE_NM}/")

In [55]:
# Create the scratch directory (and any missing parents); -p makes this a no-op if it exists.
# NOTE(review): mode 777 leaves the directory world-writable -- confirm this is intended.
!mkdir -m 777 -p $LOCAL_SCRATCH_DIR

In [56]:
# List the scratch directory; pipeline_<id>.json files from earlier runs show up here.
!ls -al $LOCAL_SCRATCH_DIR

total 88
drwxrwxrwx 2 jupyter jupyter  4096 Aug 10 05:18 .
drwxr-xr-x 3 jupyter jupyter  4096 Aug 10 05:13 ..
-rw-r--r-- 1 jupyter jupyter 40953 Aug 10 05:13 pipeline_3126.json
-rw-r--r-- 1 jupyter jupyter 40953 Aug 10 05:18 pipeline_3644.json


In [57]:
# Sanity check that the scratch path resolves. The `cd` happens in a throwaway
# subshell, so the notebook's own working directory is not changed.
!cd $LOCAL_SCRATCH_DIR && pwd

/home/jupyter/scratch/customer-churn-model


In [58]:
# Show the contents of the notebook user's home directory (informational only).
# NOTE(review): the path is hard-coded for this notebook environment.
!ls -al /home/jupyter

total 84
drwxr-xr-x 12 jupyter jupyter  4096 Aug 10 05:35 .
drwxr-xr-x  3 root    root     4096 Aug 10 02:37 ..
drwxr-xr-x  5 jupyter jupyter  4096 Aug 10 05:11 .cache
drwxr-xr-x  4 jupyter jupyter  4096 Aug 10 02:37 .config
drwxr-xr-x  2 jupyter jupyter  4096 Aug 10 02:37 .docker
drwxr-xr-x  2 jupyter jupyter  4096 Aug 10 05:12 .ipynb_checkpoints
drwxr-xr-x  5 jupyter jupyter  4096 Aug 10 02:47 .ipython
drwxr-xr-x  3 jupyter jupyter  4096 Aug 10 02:46 .jupyter
drwxr-xr-x  5 jupyter jupyter  4096 Aug 10 05:12 .local
-rw-r--r--  1 jupyter jupyter 34653 Aug 10 05:35 customer_churn_training_pipeline.ipynb
drwxr-xr-x  3 jupyter jupyter  4096 Aug 10 05:13 scratch
drwxr-xr-x  3 jupyter jupyter  4096 Aug 10 02:37 src
drwxr-xr-x  4 jupyter jupyter  4096 Aug 10 02:37 tutorials


#### c. The pre-created resources

In [59]:
# Names/URIs of resources that are expected to exist already; all are derived from
# the project ID / project number detected above.

# GCS buckets, written as full gs:// URIs.
CODE_BUCKET = f"gs://s8s_code_bucket-{PROJECT_NBR}"
DATA_BUCKET = f"gs://s8s_data_bucket-{PROJECT_NBR}"
MODEL_BUCKET = f"gs://s8s_model_bucket-{PROJECT_NBR}"
# NOTE(review): unlike the buckets above, this is a bare bucket name with no gs:// scheme.
SCRATCH_BUCKET = f"s8s-spark-bucket-{PROJECT_NBR}"
# Fully qualified BigQuery dataset (<project>.<dataset>).
BQ_DS_NM = f"{PROJECT_ID}.customer_churn_ds"
# Region used for the Vertex AI SDK and for the regional resource paths below.
LOCATION = "us-central1"
# VPC name; used to build the network path when the pipeline job is submitted.
VPC_NM = f"s8s-vpc-{PROJECT_NBR}"
# Subnetwork resource path handed to the Dataproc batch steps.
SUBNET_RESOURCE_URI = f"projects/{PROJECT_ID}/regions/{LOCATION}/subnetworks/spark-snet"
# Dataproc cluster resource path handed to the batch steps as the Spark history server.
PERSISTENT_SPARK_HISTORY_SERVER_RESOURCE_URI = f"projects/{PROJECT_ID}/regions/{LOCATION}/clusters/s8s-sphs-{PROJECT_NBR}"
GCR_REPO_NM = f"s8s-spark-{PROJECT_NBR}"
# Container image (gcr.io/<project>/<name>:<tag>) handed to the Dataproc batch steps.
DOCKER_IMAGE_TAG = "1.0.0"
DOCKER_IMAGE_NM = "customer_churn_image"
DOCKER_IMAGE_FQN = f"gcr.io/{PROJECT_ID}/{DOCKER_IMAGE_NM}:{DOCKER_IMAGE_TAG}"

#### d. Pipeline entity specific

In [60]:
# Identity of this pipeline run and where its compiled definition / artifacts live.
PIPELINE_ID = UNIQUE_ID
PIPELINE_NM = "{}-pipeline".format(APP_BASE_NM)
# Local file the compiled pipeline JSON is written to (one file per pipeline ID).
PIPELINE_PACKAGE_SRC_LOCAL_PATH = "{}/pipeline_{}.json".format(LOCAL_SCRATCH_DIR, PIPELINE_ID)
# GCS location passed to the pipeline job as its pipeline root.
PIPELINE_ROOT_GCS_URI = "{}/{}/pipelines".format(MODEL_BUCKET, APP_BASE_NM)

# Echo the resolved values for visual verification.
print(f"PIPELINE_ID = {PIPELINE_ID}")
print(f"PIPELINE_NM = {PIPELINE_NM}")
print(f"PIPELINE_PACKAGE_SRC_LOCAL_PATH = {PIPELINE_PACKAGE_SRC_LOCAL_PATH}")
print(f"PIPELINE_ROOT_GCS_URI = {PIPELINE_ROOT_GCS_URI}")

PIPELINE_ID = 9922
PIPELINE_NM = customer-churn-model-pipeline
PIPELINE_PACKAGE_SRC_LOCAL_PATH = /home/jupyter/scratch/customer-churn-model/pipeline_9922.json
PIPELINE_ROOT_GCS_URI = gs://s8s_model_bucket-505815944775/customer-churn-model/pipelines


#### e. Pipeline stage agnostic

In [61]:
# GCS folder holding the PySpark scripts, and the shared utilities module that is
# shipped alongside every stage's main script (kept as a one-element list).
PY_SCRIPTS_FQP = "{}/pyspark".format(CODE_BUCKET)
PYSPARK_COMMON_UTILS_SCRIPT_FQP = ["{}/common_utils.py".format(PY_SCRIPTS_FQP)]

# Echo the resolved values for visual verification.
print(f"PY_SCRIPTS_FQP = {PY_SCRIPTS_FQP}")
print(f"PYSPARK_COMMON_UTILS_SCRIPT_FQP = {PYSPARK_COMMON_UTILS_SCRIPT_FQP}")

PY_SCRIPTS_FQP = gs://s8s_code_bucket-505815944775/pyspark
PYSPARK_COMMON_UTILS_SCRIPT_FQP = ['gs://s8s_code_bucket-505815944775/pyspark/common_utils.py']


#### f. Data preprocessing stage specific

In [62]:
# Settings for the data preprocessing Dataproc batch.
DATA_PREPROCESSING_BATCH_PREFIX = "preprocessing"
# Batch ID: <app>-<stage>-<unique id>.
DATA_PREPROCESSING_BATCH_INSTANCE_ID = "{}-{}-{}".format(
    APP_BASE_NM, DATA_PREPROCESSING_BATCH_PREFIX, UNIQUE_ID
)
DATA_PREPROCESSING_MAIN_PY_SCRIPT = "{}/preprocessing.py".format(PY_SCRIPTS_FQP)

# BigQuery table and its bq:// URI form, later used as the managed dataset's source.
DATA_PROCESSING_SINK = "{}.training_data".format(BQ_DS_NM)
DATA_PROCESSING_BQ_SINK_URI = "bq://{}".format(DATA_PROCESSING_SINK)

# Command-line arguments handed to the preprocessing script.
DATA_PREPROCESSING_ARGS = [
    "--pipelineID={}".format(UNIQUE_ID),
    "--projectID={}".format(PROJECT_ID),
    "--projectNbr={}".format(PROJECT_NBR),
    "--displayPrintStatements={}".format(True),
]

# Echo the resolved values for visual verification.
print(f"DATA_PREPROCESSING_BATCH_INSTANCE_ID = {DATA_PREPROCESSING_BATCH_INSTANCE_ID}")
print(f"DATA_PREPROCESSING_MAIN_PY_SCRIPT = {DATA_PREPROCESSING_MAIN_PY_SCRIPT}")
print(f"DATA_PROCESSING_SINK = {DATA_PROCESSING_SINK}")
print(f"DATA_PROCESSING_BQ_SINK_URI = {DATA_PROCESSING_BQ_SINK_URI}")
print(f"DATA_PREPROCESSING_ARGS = {DATA_PREPROCESSING_ARGS}")

DATA_PREPROCESSING_BATCH_INSTANCE_ID = customer-churn-model-preprocessing-9922
DATA_PREPROCESSING_MAIN_PY_SCRIPT = gs://s8s_code_bucket-505815944775/pyspark/preprocessing.py
DATA_PROCESSING_SINK = spark-s8s-mlops.customer_churn_ds.training_data
DATA_PROCESSING_BQ_SINK_URI = bq://spark-s8s-mlops.customer_churn_ds.training_data
DATA_PREPROCESSING_ARGS = ['--pipelineID=9922', '--projectID=spark-s8s-mlops', '--projectNbr=505815944775', '--displayPrintStatements=True']


#### g. Dataset registration specific

In [63]:
# Display name of the managed tabular dataset: <app>-<unique id>.
MANAGED_DATASET_NM = "{}-{}".format(APP_BASE_NM, UNIQUE_ID)

#### h. Model specific

In [64]:
# Settings for the model training Dataproc batch.
MODEL_TRAINING_BATCH_PREFIX = "training"
# Batch ID: <app>-<stage>-<unique id>.
MODEL_TRAINING_BATCH_INSTANCE_ID = "{}-{}-{}".format(
    APP_BASE_NM, MODEL_TRAINING_BATCH_PREFIX, UNIQUE_ID
)
MODEL_TRAINING_MAIN_PY_SCRIPT = "{}/model_training.py".format(PY_SCRIPTS_FQP)

# Command-line arguments handed to the training script.
MODEL_TRAINING_ARGS = [
    "--pipelineID={}".format(UNIQUE_ID),
    "--projectID={}".format(PROJECT_ID),
    "--projectNbr={}".format(PROJECT_NBR),
    "--displayPrintStatements={}".format(True),
]

# GCS path of the metrics JSON that the evaluation component reads.
MODEL_METRICS_BUCKET_FQP = "gs://s8s_metrics_bucket-{}/{}/{}/{}/full/metrics.json".format(
    PROJECT_NBR, APP_BASE_NM, MODEL_TRAINING_BATCH_PREFIX, UNIQUE_ID
)

# Echo the resolved values for visual verification.
print(f"MODEL_TRAINING_BATCH_INSTANCE_ID = {MODEL_TRAINING_BATCH_INSTANCE_ID}")
print(f"MODEL_TRAINING_MAIN_PY_SCRIPT = {MODEL_TRAINING_MAIN_PY_SCRIPT}")
print(f"MODEL_TRAINING_ARGS = {MODEL_TRAINING_ARGS}")
print(f"MODEL_METRICS_BUCKET_FQP = {MODEL_METRICS_BUCKET_FQP}")

MODEL_TRAINING_BATCH_INSTANCE_ID = customer-churn-model-training-9922
MODEL_TRAINING_MAIN_PY_SCRIPT = gs://s8s_code_bucket-505815944775/pyspark/model_training.py
MODEL_TRAINING_ARGS = ['--pipelineID=9922', '--projectID=spark-s8s-mlops', '--projectNbr=505815944775', '--displayPrintStatements=True']
MODEL_METRICS_BUCKET_FQP = gs://s8s_metrics_bucket-505815944775/customer-churn-model/training/9922/full/metrics.json


#### i. Hyperparameter tuning specific

In [65]:
# Gate for the conditional tuning step: the pipeline compares the evaluation
# component's threshold metric against AUPR_THRESHOLD.
AUPR_THRESHOLD = 0.5
AUPR_HYPERTUNE_CONDITION = "[AUPR_HYPERTUNE]"

# Settings for the hyperparameter tuning Dataproc batch.
HYPERPARAMETER_TUNING_BATCH_PREFIX = "hyperparameter-tuning"
# Batch ID: <app>-<stage>-<unique id>.
HYPERPARAMETER_TUNING_BATCH_INSTANCE_ID = "{}-{}-{}".format(
    APP_BASE_NM, HYPERPARAMETER_TUNING_BATCH_PREFIX, UNIQUE_ID
)

# Command-line arguments handed to the tuning script.
HYPERPARAMETER_TUNING_ARGS = [
    "--pipelineID={}".format(UNIQUE_ID),
    "--projectID={}".format(PROJECT_ID),
    "--projectNbr={}".format(PROJECT_NBR),
    "--displayPrintStatements={}".format(True),
]

HYPERPARAMETER_TUNING_MAIN_PY_SCRIPT = "{}/hyperparameter_tuning.py".format(PY_SCRIPTS_FQP)
# GCS path of the metrics JSON for the tuning stage.
HYPERPARAMETER_TUNING_BUCKET_FQP = "gs://s8s_metrics_bucket-{}/{}/{}/{}/full/metrics.json".format(
    PROJECT_NBR, APP_BASE_NM, HYPERPARAMETER_TUNING_BATCH_PREFIX, UNIQUE_ID
)

# Echo the resolved values for visual verification.
print(f"HYPERPARAMETER_TUNING_BATCH_INSTANCE_ID = {HYPERPARAMETER_TUNING_BATCH_INSTANCE_ID}")
print(f"HYPERPARAMETER_TUNING_MAIN_PY_SCRIPT = {HYPERPARAMETER_TUNING_MAIN_PY_SCRIPT}")
print(f"HYPERPARAMETER_TUNING_ARGS = {HYPERPARAMETER_TUNING_ARGS}")

HYPERPARAMETER_TUNING_BATCH_INSTANCE_ID = customer-churn-model-hyperparameter-tuning-9922
HYPERPARAMETER_TUNING_MAIN_PY_SCRIPT = gs://s8s_code_bucket-505815944775/pyspark/hyperparameter_tuning.py
HYPERPARAMETER_TUNING_ARGS = ['--pipelineID=9922', '--projectID=spark-s8s-mlops', '--projectNbr=505815944775', '--displayPrintStatements=True']


### 2. Initialize Vertex AI SDK for Python

In [66]:
# Set the SDK-wide defaults (project, region, staging bucket) for later Vertex AI calls.
# The SDK documents staging_bucket in "gs://..." form, but SCRATCH_BUCKET is defined as
# a bare bucket name, so the scheme is added here (and not doubled if already present).
vertex_ai.init(
    project=PROJECT_ID,
    location=LOCATION,
    staging_bucket=(
        SCRATCH_BUCKET
        if SCRATCH_BUCKET.startswith("gs://")
        else f"gs://{SCRATCH_BUCKET}"
    ),
)

### 3. Define custom components

In [67]:
@component(
    base_image="python:3.8",
    packages_to_install=["numpy==1.21.2", "pandas==1.3.3", "scikit-learn==0.24.2"],
)
def fnEvaluateModel(
    metricsUri: str,
    metrics: Output[Metrics],
    plots: Output[ClassificationMetrics],
) -> NamedTuple("Outputs", [("threshold_metric", float)]):
    """Log the model's test metrics and plots, and return the area under the PR curve.

    Args:
        metricsUri: gs:// URI of the metrics JSON file. It is opened through the
            path obtained by replacing the "gs://" prefix with "/gcs/".
        metrics: Output artifact that receives the scalar test metrics.
        plots: Output artifact that receives the ROC curve and confusion matrix.

    Returns:
        A one-field named tuple whose ``threshold_metric`` is the JSON's
        "test_area_prc" value.
    """
    import json

    import numpy as np
    from sklearn.metrics import confusion_matrix, roc_curve

    class_labels = ["yes", "no"]

    # Load the metrics document from the /gcs/ path that mirrors the gs:// URI.
    local_metrics_path = metricsUri.replace("gs://", "/gcs/")
    with open(local_metrics_path, mode="r") as metrics_fh:
        report = json.load(metrics_fh)

    # Look up every scalar before logging anything, so a missing key fails up front.
    scalar_metrics = [
        (display_name, report[json_key])
        for display_name, json_key in (
            ("Test_areaUnderROC", "test_area_roc"),
            ("Test_areaUnderPRC", "test_area_prc"),
            ("Test_Accuracy", "test_accuracy"),
            ("Test_f1-score", "test_f1"),
            ("Test_Precision", "test_precision"),
            ("Test_Recall", "test_recall"),
        )
    ]
    for display_name, value in scalar_metrics:
        metrics.log_metric(display_name, value)

    # ROC curve from the true labels and scores, with True as the positive class.
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(
        y_true=np.array(report["true"]),
        y_score=np.array(report["score"]),
        pos_label=True,
    )

    # Confusion matrix from the true labels and the predicted labels.
    # NOTE(review): confusion_matrix orders classes by sorted label value; confirm that
    # this order matches ["yes", "no"].
    conf_matrix = confusion_matrix(
        np.array(report["true"]), np.array(report["prediction"])
    )

    plots.log_roc_curve(
        false_pos_rate.tolist(), true_pos_rate.tolist(), cutoffs.tolist()
    )
    plots.log_confusion_matrix(class_labels, conf_matrix.tolist())

    # The area under the PR curve is the value the pipeline compares to its threshold.
    result_type = NamedTuple("Outputs", [("threshold_metric", float)])
    return result_type(report["test_area_prc"])


### 4. Define Vertex AI Pipeline

In [68]:
@dsl.pipeline(
    name=PIPELINE_NM,
    description="A SparkMLlib MLOps Vertex pipeline")
def fnSparkMlopsPipeline(
    project_id: str = PROJECT_ID,
    location: str = LOCATION,
    service_account: str = UMSA_FQN,
    subnetwork_uri: str = SUBNET_RESOURCE_URI,
    spark_phs_nm: str = PERSISTENT_SPARK_HISTORY_SERVER_RESOURCE_URI,
    container_image: str = DOCKER_IMAGE_FQN,
    common_utils_py_fqn: list = PYSPARK_COMMON_UTILS_SCRIPT_FQP,
    data_preprocessing_pyspark_batch_id: str = DATA_PREPROCESSING_BATCH_INSTANCE_ID,
    data_preprocessing_main_py_fqn: str = DATA_PREPROCESSING_MAIN_PY_SCRIPT,
    data_preprocessing_args: list = DATA_PREPROCESSING_ARGS,
    managed_dataset_display_nm: str = MANAGED_DATASET_NM,
    managed_dataset_src_uri: str = DATA_PROCESSING_BQ_SINK_URI,
    model_training_pyspark_batch_id: str = MODEL_TRAINING_BATCH_INSTANCE_ID,
    model_training_main_py_fqn: str = MODEL_TRAINING_MAIN_PY_SCRIPT,
    model_training_metrics_fqp: str = MODEL_METRICS_BUCKET_FQP,
    model_training_args: list = MODEL_TRAINING_ARGS,
    threshold: float = AUPR_THRESHOLD,
    hyperparameter_tuning_pyspark_batch_id: str = HYPERPARAMETER_TUNING_BATCH_INSTANCE_ID,
    hyperparameter_tuning_main_py_fqn: str = HYPERPARAMETER_TUNING_MAIN_PY_SCRIPT,
    hyperparameter_tuning_args: list = HYPERPARAMETER_TUNING_ARGS,
    hyperparameter_tuning_metrics_fqp: str = HYPERPARAMETER_TUNING_BUCKET_FQP,
):
    """Preprocess data, register it, train and evaluate a model, then conditionally tune it.

    Steps 1, 3 and 5 each run a PySpark script as a Dataproc batch; all three share
    the same project, location, container image, subnetwork, Spark history server,
    service account and common-utilities file. Task caching is left at the
    compile-time default for every step.

    Args:
        project_id: Google Cloud project ID used by every step.
        location: Region used by every step.
        service_account: Service account the Dataproc batches run as.
        subnetwork_uri: Subnetwork resource path for the Dataproc batches.
        spark_phs_nm: Dataproc cluster resource path used as Spark history server.
        container_image: Container image for the Dataproc batches.
        common_utils_py_fqn: Extra Python file URIs shipped with every batch.
        data_preprocessing_pyspark_batch_id: Batch ID of the preprocessing step.
        data_preprocessing_main_py_fqn: Main script URI of the preprocessing step.
        data_preprocessing_args: Script arguments of the preprocessing step.
        managed_dataset_display_nm: Display name of the managed tabular dataset.
        managed_dataset_src_uri: BigQuery source URI of the managed dataset.
        model_training_pyspark_batch_id: Batch ID of the training step.
        model_training_main_py_fqn: Main script URI of the training step.
        model_training_metrics_fqp: URI of the metrics JSON read by the evaluation step.
        model_training_args: Script arguments of the training step.
        threshold: Tuning runs only when the evaluation metric is >= this value.
        hyperparameter_tuning_pyspark_batch_id: Batch ID of the tuning step.
        hyperparameter_tuning_main_py_fqn: Main script URI of the tuning step.
        hyperparameter_tuning_args: Script arguments of the tuning step.
        hyperparameter_tuning_metrics_fqp: URI of the tuning-stage metrics JSON.
            Not consumed by any step in this pipeline.
    """
    from google_cloud_pipeline_components.experimental.dataproc import \
        DataprocPySparkBatchOp

    # Step 1. PRE-PROCESS DATA in PREP FOR MODEL TRAINING
    # ....................................................................
    preprocessingStep = DataprocPySparkBatchOp(
        project=project_id,
        location=location,
        container_image=container_image,
        subnetwork_uri=subnetwork_uri,
        spark_history_dataproc_cluster=spark_phs_nm,
        service_account=service_account,
        batch_id=data_preprocessing_pyspark_batch_id,
        main_python_file_uri=data_preprocessing_main_py_fqn,
        python_file_uris=common_utils_py_fqn,
        args=data_preprocessing_args,
    ).set_display_name("Preprocessing")

    # Step 2. REGISTER PRE-PROCESSED DATA AS MANAGED DATASET
    # ....................................................................
    # Ordered after preprocessing explicitly: there is no data dependency between them.
    createManagedDatasetStep = vertex_ai_components.TabularDatasetCreateOp(
        display_name=managed_dataset_display_nm,
        bq_source=managed_dataset_src_uri,
        project=project_id,
        location=location,
    ).after(preprocessingStep).set_display_name("Dataset registration")

    # Step 3. TRAIN MODEL
    # ....................................................................
    # Depends only on preprocessing, not on the dataset registration step.
    trainSparkMLModelStep = DataprocPySparkBatchOp(
        project=project_id,
        location=location,
        container_image=container_image,
        subnetwork_uri=subnetwork_uri,
        spark_history_dataproc_cluster=spark_phs_nm,
        service_account=service_account,
        batch_id=model_training_pyspark_batch_id,
        main_python_file_uri=model_training_main_py_fqn,
        python_file_uris=common_utils_py_fqn,
        args=model_training_args,
    ).after(preprocessingStep).set_display_name("Model training")

    # Step 4. EVALUATE MODEL
    # ....................................................................
    evaluateModelStep = (
        fnEvaluateModel(model_training_metrics_fqp)
        .after(trainSparkMLModelStep)
        .set_display_name("Evaluate model")
    )

    # Step 5. CONDITIONAL HYPERPARAMETER TUNING
    # ....................................................................
    # Runs only when the evaluation step's threshold metric is at or above `threshold`.
    with Condition(
        evaluateModelStep.outputs["threshold_metric"] >= threshold,
        name="AUPR Threshold Exceeded",
    ):
        # HYPERPARAMETER TUNING
        hyperparameterTuningStep = DataprocPySparkBatchOp(
            project=project_id,
            location=location,
            container_image=container_image,
            subnetwork_uri=subnetwork_uri,
            spark_history_dataproc_cluster=spark_phs_nm,
            service_account=service_account,
            batch_id=hyperparameter_tuning_pyspark_batch_id,
            main_python_file_uri=hyperparameter_tuning_main_py_fqn,
            python_file_uris=common_utils_py_fqn,
            args=hyperparameter_tuning_args,
        ).after(evaluateModelStep).set_display_name("Hyperparameter tuning")

In [69]:
@dsl.pipeline(
    name=PIPELINE_NM,
    description="A SparkMLlib MLOps Vertex pipeline")
def fnSparkMlopsPipelineWithoutCaching(
    project_id: str = PROJECT_ID,
    location: str = LOCATION,
    service_account: str = UMSA_FQN,
    subnetwork_uri: str = SUBNET_RESOURCE_URI,
    spark_phs_nm: str = PERSISTENT_SPARK_HISTORY_SERVER_RESOURCE_URI,
    container_image: str = DOCKER_IMAGE_FQN,
    common_utils_py_fqn: list = PYSPARK_COMMON_UTILS_SCRIPT_FQP,
    data_preprocessing_pyspark_batch_id: str = DATA_PREPROCESSING_BATCH_INSTANCE_ID,
    data_preprocessing_main_py_fqn: str = DATA_PREPROCESSING_MAIN_PY_SCRIPT,
    data_preprocessing_args: list = DATA_PREPROCESSING_ARGS,
    managed_dataset_display_nm: str = MANAGED_DATASET_NM,
    managed_dataset_src_uri: str = DATA_PROCESSING_BQ_SINK_URI,
    model_training_pyspark_batch_id: str = MODEL_TRAINING_BATCH_INSTANCE_ID,
    model_training_main_py_fqn: str = MODEL_TRAINING_MAIN_PY_SCRIPT,
    model_training_metrics_fqp: str = MODEL_METRICS_BUCKET_FQP,
    model_training_args: list = MODEL_TRAINING_ARGS,
    threshold: float = AUPR_THRESHOLD,
    hyperparameter_tuning_pyspark_batch_id: str = HYPERPARAMETER_TUNING_BATCH_INSTANCE_ID,
    hyperparameter_tuning_main_py_fqn: str = HYPERPARAMETER_TUNING_MAIN_PY_SCRIPT,
    hyperparameter_tuning_args: list = HYPERPARAMETER_TUNING_ARGS,
    hyperparameter_tuning_metrics_fqp: str = HYPERPARAMETER_TUNING_BUCKET_FQP,
):
    """Same five-step pipeline as fnSparkMlopsPipeline, with caching disabled on every step.

    Preprocess data, register it as a managed dataset, train a model, evaluate it,
    and run hyperparameter tuning when the evaluation metric reaches the threshold.
    Steps 1, 3 and 5 each run a PySpark script as a Dataproc batch. Every step calls
    ``set_caching_options(False)`` so that no task result is reused.

    Args:
        project_id: Google Cloud project ID used by every step.
        location: Region used by every step.
        service_account: Service account the Dataproc batches run as.
        subnetwork_uri: Subnetwork resource path for the Dataproc batches.
        spark_phs_nm: Dataproc cluster resource path used as Spark history server.
        container_image: Container image for the Dataproc batches.
        common_utils_py_fqn: Extra Python file URIs shipped with every batch.
        data_preprocessing_pyspark_batch_id: Batch ID of the preprocessing step.
        data_preprocessing_main_py_fqn: Main script URI of the preprocessing step.
        data_preprocessing_args: Script arguments of the preprocessing step.
        managed_dataset_display_nm: Display name of the managed tabular dataset.
        managed_dataset_src_uri: BigQuery source URI of the managed dataset.
        model_training_pyspark_batch_id: Batch ID of the training step.
        model_training_main_py_fqn: Main script URI of the training step.
        model_training_metrics_fqp: URI of the metrics JSON read by the evaluation step.
        model_training_args: Script arguments of the training step.
        threshold: Tuning runs only when the evaluation metric is >= this value.
        hyperparameter_tuning_pyspark_batch_id: Batch ID of the tuning step.
        hyperparameter_tuning_main_py_fqn: Main script URI of the tuning step.
        hyperparameter_tuning_args: Script arguments of the tuning step.
        hyperparameter_tuning_metrics_fqp: URI of the tuning-stage metrics JSON.
            Not consumed by any step in this pipeline.
    """
    from google_cloud_pipeline_components.experimental.dataproc import \
        DataprocPySparkBatchOp

    # Step 1. PRE-PROCESS DATA in PREP FOR MODEL TRAINING
    # ....................................................................
    preprocessingStep = DataprocPySparkBatchOp(
        project=project_id,
        location=location,
        container_image=container_image,
        subnetwork_uri=subnetwork_uri,
        spark_history_dataproc_cluster=spark_phs_nm,
        service_account=service_account,
        batch_id=data_preprocessing_pyspark_batch_id,
        main_python_file_uri=data_preprocessing_main_py_fqn,
        python_file_uris=common_utils_py_fqn,
        args=data_preprocessing_args,
    ).set_caching_options(False).set_display_name("Preprocessing")

    # Step 2. REGISTER PRE-PROCESSED DATA AS MANAGED DATASET
    # ....................................................................
    # Ordered after preprocessing explicitly: there is no data dependency between them.
    createManagedDatasetStep = vertex_ai_components.TabularDatasetCreateOp(
        display_name=managed_dataset_display_nm,
        bq_source=managed_dataset_src_uri,
        project=project_id,
        location=location,
    ).after(preprocessingStep).set_caching_options(False).set_display_name("Dataset registration")

    # Step 3. TRAIN MODEL
    # ....................................................................
    # Depends only on preprocessing, not on the dataset registration step.
    trainSparkMLModelStep = DataprocPySparkBatchOp(
        project=project_id,
        location=location,
        container_image=container_image,
        subnetwork_uri=subnetwork_uri,
        spark_history_dataproc_cluster=spark_phs_nm,
        service_account=service_account,
        batch_id=model_training_pyspark_batch_id,
        main_python_file_uri=model_training_main_py_fqn,
        python_file_uris=common_utils_py_fqn,
        args=model_training_args,
    ).set_caching_options(False).after(preprocessingStep).set_display_name("Model training")

    # Step 4. EVALUATE MODEL
    # ....................................................................
    evaluateModelStep = (
        fnEvaluateModel(model_training_metrics_fqp)
        .after(trainSparkMLModelStep)
        .set_caching_options(False)
        .set_display_name("Evaluate model")
    )

    # Step 5. CONDITIONAL HYPERPARAMETER TUNING
    # ....................................................................
    # Runs only when the evaluation step's threshold metric is at or above `threshold`.
    with Condition(
        evaluateModelStep.outputs["threshold_metric"] >= threshold,
        name="AUPR Threshold Exceeded",
    ):
        # HYPERPARAMETER TUNING
        hyperparameterTuningStep = DataprocPySparkBatchOp(
            project=project_id,
            location=location,
            container_image=container_image,
            subnetwork_uri=subnetwork_uri,
            spark_history_dataproc_cluster=spark_phs_nm,
            service_account=service_account,
            batch_id=hyperparameter_tuning_pyspark_batch_id,
            main_python_file_uri=hyperparameter_tuning_main_py_fqn,
            python_file_uris=common_utils_py_fqn,
            args=hyperparameter_tuning_args,
        ).after(evaluateModelStep).set_caching_options(False).set_display_name("Hyperparameter tuning")


### 5. Compile the Vertex AI Pipeline into a JSON

In [70]:
# Compile the pipeline variant selected by WITHOUT_TASK_CACHING into the local JSON
# package, then report which variant was chosen.
if not WITHOUT_TASK_CACHING:
    compiler.Compiler().compile(
        pipeline_func=fnSparkMlopsPipeline,
        package_path=PIPELINE_PACKAGE_SRC_LOCAL_PATH,
    )
    print("Executing fnSparkMlopsPipeline")
else:
    compiler.Compiler().compile(
        pipeline_func=fnSparkMlopsPipelineWithoutCaching,
        package_path=PIPELINE_PACKAGE_SRC_LOCAL_PATH,
    )
    print("Executing fnSparkMlopsPipelineWithoutCaching")

Executing fnSparkMlopsPipelineWithoutCaching


### 6. Submit the Pipeline for execution via Vertex AI SDK

In [71]:
# Build the pipeline job from the compiled JSON package.
# A non-None enable_caching overrides the per-task caching options compiled into the
# package. It is therefore forced to False only when WITHOUT_TASK_CACHING is set;
# otherwise None is passed so the compile-time settings of the selected pipeline apply.
pipeline = vertex_ai.PipelineJob(
    display_name=PIPELINE_NM,
    template_path=PIPELINE_PACKAGE_SRC_LOCAL_PATH,
    pipeline_root=PIPELINE_ROOT_GCS_URI,
    enable_caching=False if WITHOUT_TASK_CACHING else None,
)

In [72]:
# Submit the job as the service account. With BYO_NETWORK the job is attached to the
# project's VPC; otherwise network is left at None, which is submit()'s default and
# therefore the same as not passing it.
pipeline.submit(
    service_account=UMSA_FQN,
    network=f"projects/{PROJECT_NBR}/global/networks/{VPC_NM}" if BYO_NETWORK else None,
)

INFO:google.cloud.aiplatform.pipeline_jobs:Creating PipelineJob
INFO:google.cloud.aiplatform.pipeline_jobs:PipelineJob created. Resource name: projects/505815944775/locations/us-central1/pipelineJobs/customer-churn-model-pipeline-20220810053536
INFO:google.cloud.aiplatform.pipeline_jobs:To use this PipelineJob in another session:
INFO:google.cloud.aiplatform.pipeline_jobs:pipeline_job = aiplatform.PipelineJob.get('projects/505815944775/locations/us-central1/pipelineJobs/customer-churn-model-pipeline-20220810053536')
INFO:google.cloud.aiplatform.pipeline_jobs:View Pipeline Job:
https://console.cloud.google.com/vertex-ai/locations/us-central1/pipelines/runs/customer-churn-model-pipeline-20220810053536?project=505815944775
