# Multimodal Clustering

In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8
   }
}

In [ ]:
# Parameters for this notebook. The empty-string defaults are presumably overridden by the
# calling Synapse pipeline (assumption: this is the notebook's parameters cell).
azure_storage_domain = ''  # domain suffix used in the abfss://<container>@<account>.dfs.<domain> URIs below
batch_num = ''  # prefix for the names of the output tables written at the end of this notebook
batch_root = ''  # folder that holds (or will receive) this batch's config.json
blob_account_name = ''  # storage account name used in the abfss:// URIs below
documents_contents_tbl_name = ''  # parquet table with document summaries, read from minted_tables_output_path
enriched_images_tbl_name = ''  # parquet table with image analysis results, read from minted_tables_output_path
minted_tables_output_path = ''  # table names are appended by plain concatenation, so this presumably ends with '/'
model_name = '' # This is the sentence-transformers model from Hugging Face that we will use from the Hugging Face repo - actual name on Hugging Face - "sentence-transformers/all-mpnet-base-v2"
# This notebook requires the main environment.yml file to be updated with this notebook's necessary packages.

In [ ]:
import pyodbc
database = 'minted'   # serverless SQL database in which the external tables are created at the end
driver= '{ODBC Driver 17 for SQL Server}'  # ODBC driver name used in the pyodbc connection strings
# Serverless SQL credentials and endpoint, read from the "keyvault" linked service.
sql_user_name = mssparkutils.credentials.getSecretWithLS("keyvault", "SynapseSQLUserName")
sql_user_pwd = mssparkutils.credentials.getSecretWithLS("keyvault", "SynapseSQLPassword")
serverless_sql_endpoint = mssparkutils.credentials.getSecretWithLS("keyvault", "SynapseServerlessSQLEndpoint")
# Set to True to print intermediate DataFrames while debugging.
display_dataframes = False

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

# Enable opencensus' integration with the standard logging module.
config_integration.trace_integrations(['logging'])

# Application Insights connection string, read from the "keyvault" linked service.
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

# Logger whose records are shipped to Application Insights through AzureLogHandler.
# The level set here may be lowered to ERROR once the batch config has been read.
logger = logging.getLogger(__name__)
logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer for the `with tracer.span(...)` blocks used throughout this notebook; spans are
# exported to Application Insights and AlwaysOnSampler records every one of them.
tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# Spool parameters: passed to the logger as `extra`, under the 'custom_dimensions' key
# that AzureLogHandler attaches to the log record.
run_time_parameters = {'custom_dimensions': {
    'documents_contents_tbl_name': documents_contents_tbl_name,
    'enriched_images_tbl_name': enriched_images_tbl_name,
    'batch_root': batch_root,
    'batch_num': batch_num,
    'model_name': model_name,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }
  
logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
import json
import os
import ntpath
import numpy as np
from types import SimpleNamespace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, util
import ray

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, pandas_udf, lit, col
from pyspark.sql.types import StringType, StructField, StructType, FloatType, IntegerType

# Added to support possible BERTopic PickleError
# https://github.com/MaartenGr/BERTopic/issues/517
# Workaround from the linked issue: pin FlatTree's __module__ to "pynndescent.rp_trees" so
# that pickling looks the class up under that module name.
import pynndescent
pynndescent.rp_trees.FlatTree.__module__  = "pynndescent.rp_trees"

# Unused: an earlier way of deriving batch_root from a manifest path (batch_root is now a parameter).
#batch_path = f'abfss://{manifest_container}@{blob_account_name}.dfs.{azure_storage_domain}'
#batch_folder = os.path.dirname(manifest_file_path)
#batch_root = f'{batch_path}{batch_folder}' 

with tracer.span(name=f"Load config: {mssparkutils.runtime.context['notebookname']}"):
    # Initialise session, create (if necessary) and read batch config.
    # `sc` comes from the session the notebook runtime already provides; getOrCreate()
    # returns the already-active session if there is one.
    sc = spark.sparkContext
    spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()

    def copy_global_config(config_path: str, global_config_path: str):
        """
        Copy the global config file at ``global_config_path`` to ``config_path``.

        The copy is unconditional; the caller decides whether a batch config is missing.
        If the copy fails (e.g. there is no global config, which indicates an error in the
        pipeline set-up), the error is logged and re-raised.
        """
        logger.info("Loading global config")
        try:
            mssparkutils.fs.cp(global_config_path, config_path)
        except Exception as e:
            # Deliberately not `except Py4JJavaError`: that name is not imported in this
            # notebook, so referencing it would turn any copy failure into a NameError.
            # Nothing is swallowed here because the original error is re-raised unchanged.
            logger.exception(e)
            raise

    def read_batch_config(batch_root: str, global_config_path: str):
        """
        Read ``{batch_root}/config.json`` and return it as nested SimpleNamespace objects.

        The file is read on the driver through the Hadoop FileSystem Java API rather than
        through Spark: it is one small file, so there is no point letting several nodes read
        individual lines and then joining them back together.

        If the batch config does not exist yet, the global config is copied into the batch
        root first (see ``copy_global_config``).

        Side effect: sets ``fs.defaultFS`` on the shared Hadoop configuration to the
        'input' container of ``blob_account_name``.
        """
        jvm = sc._jvm
        hadoop_conf = sc._jsc.hadoopConfiguration()

        # Change our file system from 'synapse' to 'input'
        hadoop_conf.set("fs.defaultFS", f'abfss://input@{blob_account_name}.dfs.{azure_storage_domain}')

        fs = jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
        config_path = jvm.org.apache.hadoop.fs.Path(f'{batch_root}/config.json')

        # If we don't have a batch config, copy the global one.
        if not fs.exists(config_path):
            copy_global_config(f'{batch_root}/config.json', global_config_path)

        # Open the file directly rather than through Spark, and always close the stream
        # (the reader wrappers hold no resources of their own).
        input_stream = fs.open(config_path)  # FSDataInputStream
        try:
            reader = jvm.java.io.BufferedReader(
                jvm.java.io.InputStreamReader(input_stream, jvm.java.nio.charset.StandardCharsets.UTF_8)
            )
            config_string = reader.lines().collect(jvm.java.util.stream.Collectors.joining("\n"))
        finally:
            input_stream.close()

        # py4j hands the Java String back as a Python str, so it can be parsed directly.
        return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

    # NOTE: this path should be in sync with Terraform configuration which uploads this file
    config = read_batch_config(batch_root, global_config_path=f'abfss://configuration@{blob_account_name}.dfs.{azure_storage_domain}/config.global.json')

    # Set log level: any value other than "INFO" is normalised to "ERROR".
    if config.log_level == "INFO":
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)
        config.log_level = "ERROR"

# Model parameters 

In [ ]:
# Model parameters, read from the multimodal_clustering section of the batch config loaded above.
batch_size = int(config.multimodal_clustering.batch_size) # Texts per Ray predict() call; model dependent - Needs to be added to Synapse global config
num_cpus = int(config.multimodal_clustering.num_cpus) # Number of Ray CPUs and of model actors; resource dependent - Needs to be added to Synapse global config
bert_topic_diversity = config.multimodal_clustering.diversity # Passed to BERTopic(diversity=...) - Needs to be added to Synapse global config
nr_topics= config.multimodal_clustering.nr_topics # (UNUSED) BERTopic Topic reduction parameter - Needs to be added to Synapse global config

In [ ]:
with tracer.span(name='Read the documents contents table'):
    df = spark.read.parquet(minted_tables_output_path + documents_contents_tbl_name) \
        .select('file_path', col('summarized_text_xsum')) \
        .withColumn('file_type', lit('text')) \
        .withColumnRenamed('summarized_text_xsum', 'summary_text')

    # Convert to Pandas DataFrame (collects the whole table onto the driver)
    df_txt_pd = df.toPandas()

    # Remove the first and last character (specifically the enclosing '[' and ']'), then strip
    # leading/trailing whitespace. The .str accessor turns a missing summary into NaN instead of
    # raising TypeError the way a plain `lambda x: x[1:-1]` would.
    df_txt_pd['summary_text'] = df_txt_pd['summary_text'].str[1:-1].str.strip()

    # Keep only summaries of at least 5 characters (a NaN length compares False, so missing
    # summaries are dropped as well). df_txt_pd is what gets concatenated with the image
    # captions afterwards, so it is rebound to the filtered rows for the filter to take effect.
    filtered_df = df_txt_pd[df_txt_pd['summary_text'].str.len() >= 5]
    df_txt_pd = filtered_df

In [ ]:
# Create UDFs to extract caption text and clean it up
def _first_caption_text(analysis_results):
    """
    Return the text of the first caption in an image's analysis results, or '' if there is none.

    A null analysis, an empty or null description, empty or null captions, a null first
    caption and a null caption text all yield ''. This avoids raising TypeError inside the
    Python worker, and avoids returning a null that would match neither the `== ''` nor the
    `!= ''` filter applied when the images table is read.
    """
    if analysis_results is None:
        return ''
    description = analysis_results['description']
    if not description:
        return ''
    captions = description['captions']
    if not captions:
        return ''
    first_caption = captions[0]
    if first_caption is None:
        return ''
    return first_caption['text'] or ''


get_caption = udf(_first_caption_text, StringType())
# Upper-case the first character and lower-case the rest (str.capitalize semantics).
cap_caption = udf(lambda x: x.capitalize(), StringType())
# Append a full stop so each caption reads like a sentence.
add_period_caption = udf(lambda x: x+'.' , StringType())

with tracer.span(name='Read the enriched images table'):
    image_table_path = minted_tables_output_path + enriched_images_tbl_name
    df_img = (
        spark.read.parquet(image_table_path)
        .select('path', 'analysis_results', 'read_results')
        .withColumn('file_type', lit('images'))
        .withColumnRenamed('path', 'file_path')
        .withColumn('summary_text', get_caption(col('analysis_results')))
    )

    # The raw analysis columns are only needed to derive the caption; neither output keeps them.
    raw_analysis_cols = ['analysis_results', 'read_results']

    # Images without a caption are collected separately so they can still be reported.
    df_img_empty = df_img.where(df_img.summary_text == '').drop(*raw_analysis_cols)
    df_img_empty_pd = df_img_empty.toPandas()

    # Captioned images: normalise each caption to "Capitalised sentence."
    df_img = (
        df_img.where(df_img.summary_text != '')
        .withColumn('summary_text', cap_caption(col('summary_text')))
        .withColumn('summary_text', add_period_caption(col('summary_text')))
        .drop(*raw_analysis_cols)
    )
    df_img_pd = df_img.toPandas()

In [ ]:
# Debug output: preview the captioned images.
if display_dataframes:
    print(df_img_pd.head())

In [ ]:
import pandas as pd

with tracer.span(name='Concatenate the text contents and enriched images captions'):
    # Text rows first, then image rows, renumbered 0..n-1.
    df_combined_pd = pd.concat([df_txt_pd, df_img_pd], ignore_index=True)

In [ ]:
with tracer.span(name=f'Mount {model_name} sentence-transformers model'):
    synapse_models_root = f'abfss://synapse@{blob_account_name}.dfs.{azure_storage_domain}/models/{model_name}'
    # Cache directory later handed to the Ray actors as TRANSFORMERS_CACHE (same on both paths).
    # NOTE(review): this is an abfss:// URI rather than a local path; confirm that the
    # transformers library can actually use it as a cache directory.
    transformers_cache_path = f'{synapse_models_root}/.cache/'
    try:
        # Read the first character of the model's config.json. This raises if the model has
        # not been uploaded to the workspace storage, in which case we fall back below.
        mssparkutils.fs.head(f'{synapse_models_root}/config.json', 1)
        mount_point = '/mnt'
        jobId = mssparkutils.env.getJobId()
        linkedStorageName = f'{mssparkutils.env.getWorkspaceName()}-WorkspaceDefaultStorage'

        mssparkutils.fs.mount(
            f'abfss://synapse@{blob_account_name}.dfs.{azure_storage_domain}/',
            mount_point,
            {'linkedService': f'{linkedStorageName}'}
        )

        # Please note the differences with the synfs protocol
        # https://docs.microsoft.com/en-us/azure/synapse-analytics/spark/synapse-file-mount-api#how-to-access-files-under-mount-point-via-local-file-system-api
        model_location = f'/synfs/{jobId}{mount_point}/models/{model_name}/'
        print(model_location)
        print(transformers_cache_path)
        logger.info(f'Using {model_name} model from {model_location}.')
        model_path = model_location
    except Exception as e:
        # Any failure (model not uploaded, mount failure, ...) falls back to loading the model
        # by name from Hugging Face. Log the reason so an unexpected fallback is not silent.
        logger.warning(f'Could not use the uploaded {model_name} model ({e!r}); falling back to Hugging Face.')
        print(transformers_cache_path)
        model_path = f'{model_name}'
        print(model_path)
        logger.info(f'Using {model_path} model from HuggingFace.')

In [ ]:
# Get list of unique text summaries. Only distinct texts are embedded; results are joined back
# to every row with that text by the merge on the summary text further down.
text_summs_list = df_combined_pd['summary_text'].unique().tolist()

In [ ]:
with tracer.span(name=f'Use Ray and run sentence-transformers {model_name} to get embeddings'):
    import os
    from sentence_transformers import SentenceTransformer, util
    import ray

    # Ray Actor (class) implementation
    @ray.remote
    class st_mpnet_text_model:
        """Ray actor holding one SentenceTransformer model; ``predict`` embeds a batch of texts."""

        # Want to avoid CPU contention: one actor per CPU, each limited to a single MKL thread.
        # These statements run when the class body executes, i.e. in this (driver) process at
        # definition time, before ray.init() is called in the next cell, so they also affect
        # the driver's environment.
        # NOTE(review): this relies on Ray worker processes inheriting the driver environment;
        # ray.init(runtime_env={"env_vars": {...}}) would make that explicit.
        # MKL reads MKL_NUM_THREADS (plural); the singular spelling is silently ignored.
        os.environ["MKL_NUM_THREADS"] = "1"
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ['TRANSFORMERS_CACHE'] = transformers_cache_path

        def __init__(self, model_dir: str):
            """Load the model from ``model_dir`` (a mounted local path or a Hugging Face model name)."""
            self.model = self.load_model(model_dir)

            # Optional: quantize the model to speed up inference (requires `import torch`).
            # Resource: https://pytorch.org/docs/stable/quantization.html
            # self.model = torch.quantization.quantize_dynamic(self.model, {torch.nn.Linear}, dtype=torch.qint8)

        def load_model(self, model_dir: str):
            """Return a SentenceTransformer loaded from ``model_dir``."""
            return SentenceTransformer(model_dir)

        def predict(self, txt_batch):
            """
            Embed ``txt_batch`` (a list of strings).

            Returns a pandas DataFrame with one row per input text. The embedding dimensions
            are integer-named columns, and the input text is in 'summarized_text_xsum'.
            """
            text_emb = self.model.encode(txt_batch)
            text_emb_tbl = pd.DataFrame(text_emb)
            text_emb_tbl['summarized_text_xsum'] = txt_batch
            return text_emb_tbl
    

In [ ]:
# Object refs of the submitted predict() calls, in submission order.
task_ids = []

# Index of the actor that received the most recent batch.
a_id = 0

print('NUMBER OF CPUS', num_cpus)
ray.init(num_cpus=num_cpus)

# One actor, and therefore one model copy, per CPU.
model_actors = [st_mpnet_text_model.remote(model_path) for _ in range(num_cpus)]

# Hand out batches round robin; the first batch goes to actor 1 % num_cpus.
for batch_start in range(0, len(text_summs_list), batch_size):
    a_id = (a_id + 1) % num_cpus
    batch_texts = text_summs_list[batch_start:batch_start + batch_size]
    task_ids.append(model_actors[a_id].predict.remote(batch_texts))

In [ ]:
# Sanity check: one task per submitted batch.
print('Number of Task IDs: ', len(task_ids))

In [ ]:
import datetime
main_start_time = datetime.datetime.now()

# Each entry is the list returned by one ray.get() call, i.e. the DataFrames of the tasks
# that ray.wait() reported as finished in that round.
ready_results = []

# Drain the pending refs in completion order. With the default num_returns=1, each
# ray.wait() blocks until one more task finishes and returns the rest as still pending.
# It is best to set a timeout if using this in production setting so that tasks that are taking too long don't continue to run and impact the production environment and resources.
while task_ids:
    finished, task_ids = ray.wait(task_ids, timeout=None)
    ready_results.append(ray.get(finished))

print('Total inference time (batches): ', datetime.datetime.now()-main_start_time)


In [ ]:
with tracer.span(name=f'Shutdown Ray'):
    # Shutdown Ray: every result has already been fetched with ray.get(), so the actors can go.
    ray.shutdown()

In [ ]:
with tracer.span(name=f'Return results from Ray into a pandas DataFrame'):
    # Flatten the per-round lists of per-batch DataFrames and stack them row-wise.
    batch_frames = [frame for finished_round in ready_results for frame in finished_round]
    all_embs_tbl = pd.concat(batch_frames)

In [ ]:
with tracer.span(name=f'Run BERTopic to perform topic modeling'):
    import numpy as np
    from bertopic import BERTopic
    from sklearn.feature_extraction.text import CountVectorizer

    # Prepare embeddings: every column except the text column is an embedding dimension.
    # The width is taken from the data rather than hard-coded to 768. A fixed iloc[:, :768]
    # would include the text column for narrower models and drop dimensions for wider ones.
    embeddings = all_embs_tbl.drop(columns=['summarized_text_xsum']).to_numpy()

    # Documents, in the same row order as the embeddings
    docs = all_embs_tbl['summarized_text_xsum'].tolist()

    # Load vectorizer model
    vectorizer_model = CountVectorizer(stop_words='english')

    # Train BERTopic on the pre-computed embeddings. nr_topics is left at BERTopic's default
    # (None, i.e. no topic reduction); pass nr_topics="auto" to enable topic reduction.
    topic_model = BERTopic(vectorizer_model=vectorizer_model, diversity=bert_topic_diversity)
    fitted_topic_model = topic_model.fit(docs, embeddings)

    # Run the visualization with the original embeddings to visualize documents by topic
    f1 = fitted_topic_model.visualize_documents(docs, embeddings=embeddings)

In [ ]:
with tracer.span(name=f'Get the x and y values from the fitted topic model'):
    all_outputs = []

    # Each plotly trace of the document figure presumably holds the points of one topic:
    # hover text = document text, trace name = topic label, x/y = 2-D coordinates.
    for trace in f1.data:
        hover_texts = trace['hovertext']
        topic_labels = np.array([trace['name']] * len(hover_texts))
        for text, label, x_coord, y_coord in zip(hover_texts, topic_labels, trace['x'], trace['y']):
            all_outputs.append({
                'summarized_text_xsum': text,
                'Topic_Name_Main': label,
                'X': x_coord,
                'Y': y_coord,
            })

In [ ]:
with tracer.span(name=f'Create a table with the text, topic, x and y values'):
    # One row per plotted document; missing values become '' so empty texts can be dropped.
    res_data = pd.json_normalize(all_outputs).fillna('')
    has_text = res_data['summarized_text_xsum'] != ''
    res_data_final = res_data[has_text].reset_index(drop=True)

In [ ]:
with tracer.span(name=f'Merge embeddings table with the table of x and y values'):
    # Left join: every input row is kept; rows whose text got no clustering output get NaN.
    merged_df = df_combined_pd.merge(
        res_data_final,
        how='left',
        left_on='summary_text',
        right_on='summarized_text_xsum',
    )
    result_df = merged_df[['file_path', 'file_type', 'summarized_text_xsum', 'Topic_Name_Main', 'X', 'Y']]
    # Stringify the topic names; an unmatched (NaN) topic becomes the string 'nan'.
    result_df['Topic_Name_Main'] = result_df['Topic_Name_Main'].apply(str)

In [ ]:
# Helpers that split a BERTopic topic label such as "3_word_word" into its parts.
# 'other' and non-string values (e.g. NaN from an unmatched merge) map to the catch-all cluster.
def cluster_util_func(topic_name):
    """
    Return the integer cluster id encoded before the first '_' of ``topic_name``.

    Returns -1 for 'other', for non-string values, and for labels whose prefix is not an integer.
    """
    if not isinstance(topic_name, str) or topic_name == "other":
        return -1
    try:
        return int(topic_name.split("_")[0])
    except ValueError:
        return -1

def cluster_name__util_func(topic_name):
    """
    Return the part of ``topic_name`` after the first '_' (the topic's keyword label).

    Returns 'other' for 'other' and for non-string values.
    """
    if not isinstance(topic_name, str) or topic_name == "other":
        return 'other'
    # Uses topic_name itself; an unrelated module-level name here would turn every
    # cluster name into 'other' through the error fallback.
    return "_".join(topic_name.split("_")[1:])

# Rows that found no topic in the merge carry the string 'nan'; relabel them 'other'.
topic_is_missing = result_df['Topic_Name_Main'] == 'nan'
result_df['Topic_Name_Main'] = result_df['Topic_Name_Main'].mask(topic_is_missing, 'other')

# Split each topic label into an integer cluster id and its keyword label.
result_df['cluster'] = result_df['Topic_Name_Main'].apply(cluster_util_func)
result_df['cluster_name'] = result_df['Topic_Name_Main'].apply(cluster_name__util_func)

# Lookup table: one row per (cluster id, auto-generated label) pair, ordered by cluster id.
small_tbl = (
    result_df[['cluster', 'cluster_name']]
    .drop_duplicates()
    .sort_values(by='cluster')
    .reset_index(drop=True)
)

if display_dataframes:
    print(small_tbl.head())

In [ ]:
# Give the caption-less images the same columns as result_df: no text, no coordinates,
# and the catch-all 'other' topic / cluster -1.
df_img_empty_pd = df_img_empty_pd.assign(
    summarized_text_xsum='',
    Topic_Name_Main='other',
    X=None,
    Y=None,
    cluster=-1,
    cluster_name='other',
).drop(columns=['summary_text'])

if display_dataframes:
    print(df_img_empty_pd.head())

In [ ]:
# Debug output: number of clustered rows and number of caption-less image rows.
if display_dataframes:
    print(len(result_df))
    print(len(df_img_empty_pd))

In [ ]:
# Per-file results: the clustered documents/images followed by the caption-less images
# (cluster -1), with file path, file type, summary text, topic, cluster id and 2-D coordinates.
output_columns = ['file_path', 'type_of_file', 'summarized_text_xsum', 'topic_name', 'cluster', 'X', 'Y']
output_df = (
    pd.concat([result_df, df_img_empty_pd])
    .reset_index(drop=True)
    .astype({'cluster': 'Int64'})
    .rename(columns={'Topic_Name_Main': 'topic_name', 'file_type': 'type_of_file'})
)[output_columns]

if display_dataframes:
    print(len(output_df))
    print(output_df.head())

In [ ]:
# from pyspark.sql.functions import regexp_replace, length, col
# Output table names, prefixed with the batch number.
clustered_text_lookup_tbl_name = f'{batch_num}_clustered_multimodal_lookup'
clustered_multimodal_tbl_name = f'{batch_num}_clustered_multimodal'

# Change back to spark data frame (small_tbl: cluster id -> label lookup; output_df: per-file results)
df_spark_small = spark.createDataFrame(small_tbl)
df_spark = spark.createDataFrame(output_df)

if display_dataframes:
    df_spark_small.show()

In [ ]:
# Debug output: show the per-file results as a Spark DataFrame.
if display_dataframes:
    df_spark.show()

In [ ]:
with tracer.span(name='Persist to clustered multimodal lookup table'):
    from contextlib import closing

    df_spark_small.write.mode("overwrite").parquet(f'{minted_tables_output_path}{clustered_text_lookup_tbl_name}')
    if display_dataframes:
        df_spark_small.show()
    df_spark_small.printSchema()

    # Register the parquet output as an external table in the serverless SQL database,
    # unless a table with that name already exists.
    # NOTE(review): the table name (derived from the batch_num parameter) is interpolated
    # straight into the DDL, so batch_num must be trusted input. <<STORAGE_ACCOUNT_NAME>> looks
    # like a deployment-time placeholder; confirm it is substituted before this runs.
    df_output_sql_command_lk = f"IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{clustered_text_lookup_tbl_name}') CREATE EXTERNAL TABLE [{clustered_text_lookup_tbl_name}] ([cluster] bigint, [cluster_name] nvarchar(4000)) WITH (LOCATION = 'minted_tables/{clustered_text_lookup_tbl_name}/**', DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], FILE_FORMAT = [SynapseParquetFormat])"

    # pyodbc's Connection context manager commits on a clean exit but does not close the
    # connection, so closing() is added to release it; the inner `with conn` keeps the commit.
    with closing(pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd)) as conn:
        with conn:
            with conn.cursor() as cursor:
                cursor.execute(df_output_sql_command_lk)

In [ ]:
with tracer.span(name='Persist to clustered multimodal table'):
    from contextlib import closing

    # Drop NaN values before writing out.
    # NOTE(review): na.drop() removes every row with a null/NaN in ANY column. That presumably
    # includes the caption-less images (their X and Y were set to None) and any row whose text
    # got no clustering output, so those rows never reach the table. Confirm this is intended.
    df_spark = df_spark.na.drop()
    df_spark.write.mode("overwrite").parquet(f'{minted_tables_output_path}{clustered_multimodal_tbl_name}')
    if display_dataframes:
        df_spark.show()
    df_spark.printSchema()

    # Register the parquet output as an external table in the serverless SQL database,
    # unless a table with that name already exists.
    # NOTE(review): the table name (derived from the batch_num parameter) is interpolated
    # straight into the DDL, so batch_num must be trusted input. <<STORAGE_ACCOUNT_NAME>> looks
    # like a deployment-time placeholder; confirm it is substituted before this runs.
    df_output_sql_command = f"IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{clustered_multimodal_tbl_name}') CREATE EXTERNAL TABLE [{clustered_multimodal_tbl_name}] ([file_path] nvarchar(4000), [type_of_file] nvarchar(4000), [summarized_text_xsum] nvarchar(4000), [topic_name] nvarchar(4000), [cluster] bigint, [X] float, [Y] float ) WITH (LOCATION = 'minted_tables/{clustered_multimodal_tbl_name}/**', DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], FILE_FORMAT = [SynapseParquetFormat])"

    # pyodbc's Connection context manager commits on a clean exit but does not close the
    # connection, so closing() is added to release it; the inner `with conn` keeps the commit.
    with closing(pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd)) as conn:
        with conn:
            with conn.cursor() as cursor:
                cursor.execute(df_output_sql_command)

In [ ]:
# return name of new table (logged as custom dimensions and handed back to the pipeline)
output = {'custom_dimensions': {
    'batch_num': batch_num,
    'clustered_multimodal_tbl_name': clustered_multimodal_tbl_name,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }

# Return the object to the pipeline.
# NOTE(review): exit() is given a dict rather than a string; confirm how the calling pipeline
# parses this exit value. exit() also stops the notebook, so no later cell would run.
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])


# Raise clustering complete event

In [ ]:
# # Prepare the event contents
# with tracer.span(name='preparing contents to send to event grid'):   
#     from datetime import datetime
#     now = datetime.now().strftime("%Y-%m-%dT%H:%M:%S%Z")    
#     web_app_uri = config.rule_sets.webapp_uri
#     subscriber_uri = config.rule_sets.teams_webhook_uri
#     alert_email = config.rule_sets.alert_email    
#     df_cluster_count = df_spark.groupBy("cluster").count()
#     df_cluster_count = df_cluster_count.orderBy('cluster', ascending=True)
#     cluster_json_list = df_cluster_count.toJSON().collect()
#     num_of_clusters = df_cluster_count.distinct().count ()
#     cluster_output = ''
#     for x in range(len(cluster_json_list)): 
#         cluster_output = cluster_output + ', ' + cluster_json_list[x]   
#     cluster_output = cluster_output[2:]
#     cluster_output_str = ''.join(cluster_output)

#     # generate the Event Grid json 
#     event_data = f'{{"batch_id": "{batch_num}",' \
#         f'"batch_description": "{batch_description}",' \
#         f'"eventDate": "{now}",' \
#         f'"eventMetrics": {{' \
#         f'  "event_type": "text",' \
#         f'  "files_processed_count": "{text_file_count}",' \
#         f'  "event_detail_uri": "https://{web_app_uri}/reports",' \
#         f'  "num_of_clusters": {num_of_clusters},' \
#         f'  "clusters": [' \
#         f'      {cluster_output_str}' \
#         f'  ]' \
#         f'}},' \
#         f'"teams_webhook_endpoint": "{subscriber_uri}",' \
#         f'"alert_email": "{alert_email}"' \
#         f'}}'

#     event_data_obj = json.loads(event_data)

In [ ]:
# # Raise the event
# with tracer.span(name='sending message to event grid'):    
#     from azure.identity import ClientSecretCredential
#     from azure.eventgrid import EventGridPublisherClient, EventGridEvent    

#     # Get value from keyvault to build Event Grid Topic event
#     subscription_id = TokenLibrary.getSecretWithLS("keyvault", 'SubscriptionId')
#     resource_group_name = TokenLibrary.getSecretWithLS("keyvault", 'ResourceGroupName')
#     event_grid_topic_name = TokenLibrary.getSecretWithLS("keyvault", 'EventGridTopicName')
#     event_grid_topic_endpoint = TokenLibrary.getSecretWithLS("keyvault", 'EventGridTopicEndpointUri')
#     tenant_id = TokenLibrary.getSecretWithLS("keyvault", 'TenantID')
#     client_id = TokenLibrary.getSecretWithLS("keyvault", 'ADAppRegClientId')
#     client_secret = TokenLibrary.getSecretWithLS("keyvault", 'ADAppRegClientSecret')
#     event_grid_topic = f'/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.EventGrid/topics/{event_grid_topic_name}'
#     credential = ClientSecretCredential(tenant_id, client_id, client_secret)
#     client = EventGridPublisherClient(event_grid_topic_endpoint, credential)

#     try:
#         # queue event grid message
#         event = EventGridEvent(data=event_data_obj, subject="MINTED/ClusterAlert", event_type="MINTED.ruleTriggered", data_version="1.0", topic=event_grid_topic)
#         client.send(event)
#         print("done")
#     except Exception as e:
#         logger.exception(e)
#         raise e


# # Return the object to the pipeline
# logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
# mssparkutils.notebook.exit(output['custom_dimensions'])