# Inference

# Define global variables and imports

In [None]:
import json
import os
import shutil
import time

from pathlib import Path
from src.serving import export
from src import feature_utils

from google.cloud import aiplatform as vertex_ai

In [None]:
# Environment-specific settings; edit all three for your own setup.
# Change to your project.
PROJECT_ID = "jk-mlops-dev"
# Change to your region.
REGION = "us-central1"
# Change to your bucket.
STAGING_BUCKET = "jk-merlin-dev"

# Build Docker image

In [None]:
# Build the image described by ../src/Dockerfile.hugectr (build context ../src/) and tag it as
# gcr.io/<PROJECT_ID>/hugectr-training:<IMAGE_VERSION>.
# NOTE(review): IMAGE_VERSION is not assigned in any cell visible before this one — confirm it is
# defined (as a notebook variable or an environment variable) before running; otherwise the tag
# after ':' presumably expands to nothing and the build would be rejected.
! docker build -t gcr.io/$PROJECT_ID/hugectr-training:$IMAGE_VERSION -f ../src/Dockerfile.hugectr ../src/

# Define variables for Inference

In [None]:
# Local directory used as the working area for downloaded and exported artifacts.
LOCAL_WORKSPACE = "/home/jupyter/staging"
# GCS prefix for models, located under the staging bucket.
MODEL_ARTIFACTS_REPOSITORY = f"gs://{STAGING_BUCKET}/models"

In [None]:
# Model identifiers and free-text description.
MODEL_NAME = "deepfm"
MODEL_VERSION = "v01"
MODEL_DESCRIPTION = "HugeCTR DeepFM model"

# Display names derived from the model name and version.
MODEL_DISPLAY_NAME = f"criteo-hugectr-{MODEL_NAME}-{MODEL_VERSION}"
ENDPOINT_DISPLAY_NAME = f"hugectr-{MODEL_NAME}-{MODEL_VERSION}"

In [None]:
# Change to GCS path of the nvt workflow.
WORKFLOW_MODEL_PATH = 'gs://criteo-datasets/criteo_processed_parquet/workflow'
# Change to GCS path of the hugectr trained model.
HUGECTR_MODEL_PATH = 'gs://merlin-models/hugectr_deepfm_21.09'

In [None]:
# Container image settings: image name, Dockerfile path, and the gcr.io URI
# built from the project ID and the image name.
IMAGE_NAME = "triton-deploy-hugectr"
DOCKERFILE = "src/Dockerfile.triton"
IMAGE_URI = f'gcr.io/{PROJECT_ID}/{IMAGE_NAME}'

In [None]:
# NOTE(review): this cell assigns the same two values as an earlier cell in this
# notebook; keep both in sync when editing, or drop one of them.
# Change to GCS path of the nvt workflow.
WORKFLOW_MODEL_PATH = 'gs://criteo-datasets/criteo_processed_parquet/workflow'
# Change to GCS path of the hugectr trained model.
HUGECTR_MODEL_PATH = 'gs://merlin-models/hugectr_deepfm_21.09'

# Export Triton Ensemble model

In [None]:
# Start from an empty local workspace: delete the directory tree if it already
# exists, then create it again.
if os.path.isdir(LOCAL_WORKSPACE):
    shutil.rmtree(LOCAL_WORKSPACE)
os.makedirs(LOCAL_WORKSPACE)

# Recursively copy the workflow and the trained model from their GCS paths
# into the local workspace.
!gsutil -m cp -r {WORKFLOW_MODEL_PATH} {LOCAL_WORKSPACE}
!gsutil -m cp -r {HUGECTR_MODEL_PATH} {LOCAL_WORKSPACE}

In [None]:
# Sizing constants passed to export.export_ensemble.
NUM_SLOTS = 26
MAX_NNZ = 2
EMBEDDING_VECTOR_SIZE = 11
MAX_BATCH_SIZE = 64

# Column lists come from feature_utils; one output per label column.
continuous_columns = feature_utils.continuous_columns()
categorical_columns = feature_utils.categorical_columns()
label_columns = feature_utils.label_columns()
num_outputs = len(label_columns)

# Locations inside the local workspace. The workflow and model directories
# reuse the last component of their GCS paths; the ensemble directory name
# carries a timestamp taken when this cell runs.
local_workflow_path = Path(LOCAL_WORKSPACE).joinpath(
    Path(WORKFLOW_MODEL_PATH).parts[-1]
)
local_saved_model_path = Path(LOCAL_WORKSPACE).joinpath(
    Path(HUGECTR_MODEL_PATH).parts[-1]
)
local_ensemble_path = Path(LOCAL_WORKSPACE).joinpath(
    f'triton-ensemble-{time.strftime("%Y%m%d%H%M%S")}'
)
model_repository_path = "/models"

In [None]:
# Call export.export_ensemble with the local workflow/model paths, the feature
# column lists, and the sizing constants set in the preceding cells.
export.export_ensemble(
    model_name=MODEL_NAME,
    # Local input and output paths.
    workflow_path=local_workflow_path,
    saved_model_path=local_saved_model_path,
    output_path=local_ensemble_path,
    # Feature and label columns.
    categorical_columns=categorical_columns,
    continuous_columns=continuous_columns,
    label_columns=label_columns,
    # Sizing parameters.
    num_slots=NUM_SLOTS,
    max_nnz=MAX_NNZ,
    num_outputs=num_outputs,
    embedding_vector_size=EMBEDDING_VECTOR_SIZE,
    max_batch_size=MAX_BATCH_SIZE,
    model_repository_path=model_repository_path,
)