In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8
   }
}

In [ ]:
# Parameters cell: the pipeline passes the batch details in as parameters that
# override these defaults. Every default below is an empty-string placeholder.
key_vault_name = ''  # not referenced elsewhere in this notebook; secrets are read via the "keyvault" linked service
batch_num = ''  # batch identifier; prefixes table names and output paths
batch_root = ''  # root path of the batch; config.json is read from here and file URIs are built from it
file_system = ''  # only recorded in the logging custom dimensions
output_container = ''  # container holding the per-file enrichment JSON output
media_path = ''  # NOTE(review): not referenced elsewhere in this notebook — confirm it is still needed
enriched_media_tbl_name = ''  # parquet table (under minted_tables_path) read as the media enrichments
rules_container = ''  # container that receives the ruleset event JSON files
batch_description = ''  # free-text description copied into the alert payloads
azure_storage_domain = ''  # storage DNS suffix used in the abfss:// URIs
blob_account_name = ''  # storage account hosting the minted_tables parquet data

In [ ]:
# Debug switch: set to True to print the intermediate dataframes, counts and
# configuration while investigating a run.
display_results: bool = False

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

# Route this notebook's log records and trace spans to Application Insights.
config_integration.trace_integrations(['logging'])
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

logger = logging.getLogger(__name__)
# logging.getLogger returns the same logger object on every call, so re-running
# this cell would otherwise stack another AzureLogHandler on it and every record
# would be exported once per attached handler.
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Sample every trace so each tracer.span in the notebook is exported.
tracer = Tracer(
    exporter=AzureExporter(connection_string=instrumentation_connection_string),
    sampler=AlwaysOnSampler(),
)

# Spool parameters: custom dimensions attached to the log records of this run.
notebook_name = mssparkutils.runtime.context['notebookname']
media_contents_tbl_name = f'{batch_num}_submitted_media'
run_time_parameters = {
    'custom_dimensions': {
        'batch_num': batch_num,
        'file_system': file_system,
        'media_contents_tbl_name': media_contents_tbl_name,
        'notebook_name': notebook_name,
    }
}

logger.info(f"{notebook_name}: INITIALISED", extra=run_time_parameters)


In [ ]:
import pyodbc
from pyspark.sql.functions import current_timestamp
# Connection settings for the serverless ("minted") and dedicated SQL pools,
# plus the storage location of the minted parquet tables.
dedicated_database = 'dedicated'
database = 'minted'
driver = '{ODBC Driver 17 for SQL Server}'
minted_tables_path = 'abfss://synapse@{}.dfs.{}/minted_tables/'.format(blob_account_name, azure_storage_domain)

# SQL credentials and endpoints, read in order from the "keyvault" linked service.
sql_user_name, sql_user_pwd, serverless_sql_endpoint, dedicated_sql_endpoint = (
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
    for secret_name in (
        "SynapseSQLUserName",
        "SynapseSQLPassword",
        "SynapseServerlessSQLEndpoint",
        "SynapseDedicatedSQLEndpoint",
    )
)

In [ ]:
from azure.identity import ClientSecretCredential
from azure.eventgrid import EventGridPublisherClient, EventGridEvent
from azure.mgmt.keyvault import KeyVaultManagementClient
from types import SimpleNamespace
from datetime import datetime
import json
import rule_engine
import urllib.parse
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql.functions import col, to_json, from_json, lit, explode, concat, udf, regexp_replace, json_tuple
from pyspark.sql.types import StringType, MapType, BooleanType


In [ ]:
# Read the subscription, service-principal and storage secrets from the
# "keyvault" linked service inside a trace span, then derive run defaults.
with tracer.span(name='load values from key vault'):
    (
        instrumentation_connection_string,
        subscription_id,
        resource_group_name,
        tenant_id,
        client_id,
        client_secret,
        storage_account_name,
        storage_account_key,
    ) = [
        mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
        for secret_name in (
            "AppInsightsConnectionString",
            "SubscriptionId",
            "ResourceGroupName",
            "TenantID",
            "ADAppRegClientId",
            "ADAppRegClientSecret",
            "StorageAccountName",
            "StorageAccountKey",
        )
    ]

    # Per-batch name of the ruleset evaluation output table.
    ruleset_eval_tbl_name = f'{batch_num}_media_ruleset_eval'

    # Placeholder substituted at deployment time (Terraform variable).
    azure_resource_manager = "<<TF_VAR_azure_arm_management_api>>"
    credential = ClientSecretCredential(tenant_id, client_id, client_secret)

In [ ]:
# Initialise the Spark session and load the batch configuration.
# Build (or fetch) the session first and derive the SparkContext from it, so this
# cell no longer reads the `spark` global before the builder has run.
spark = SparkSession.builder.appName(f"Media Processing {mssparkutils.runtime.context}").getOrCreate()
sc = spark.sparkContext
# textFile yields one element per line; re-joining with newlines reproduces the
# file faithfully and keeps line/column positions in any JSONDecodeError meaningful.
config = json.loads('\n'.join(sc.textFile(f'{batch_root}/config.json').collect()))

# Only "INFO" keeps INFO-level logging. Anything else, including a config with no
# "log_level" key at all, falls back to ERROR, and the config is normalised to match.
if config.get("log_level") == "INFO":
    logger.setLevel(logging.INFO)
else:
    logger.setLevel(logging.ERROR)
    config["log_level"] = "ERROR"

log_level = config["log_level"]
rulesets_config = config["rule_sets"]
web_app_uri = rulesets_config["webapp_uri"]
subscriber_uri = rulesets_config["teams_webhook_uri"]
alert_email = rulesets_config["alert_email"]
text_rulesets_config = rulesets_config["text_rule_sets"]

# Values from Key Vault used to build and publish the Event Grid topic event.
event_grid_topic_name = mssparkutils.credentials.getSecretWithLS('keyvault', 'EventGridTopicName')
event_grid_topic_endpoint = mssparkutils.credentials.getSecretWithLS('keyvault', 'EventGridTopicEndpointUri')
event_grid_topic = f'/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.EventGrid/topics/{event_grid_topic_name}'
credential = ClientSecretCredential(tenant_id, client_id, client_secret)
client = EventGridPublisherClient(event_grid_topic_endpoint, credential)

if display_results:
    print(config)

## Evaluate rules for full text of media if this is the last expected callback

In [ ]:
# Count the media files in this batch still marked 'Processing'. Rules and alerts
# are only evaluated once every one of them has been processed.
from contextlib import closing

submitted_media_tbl_name = f'{batch_num}_submitted_media'
# Table names cannot be bound as ODBC parameters, so bracket-quote the identifier
# and escape any closing bracket it contains.
quoted_submitted_media_tbl = '[' + submitted_media_tbl_name.replace(']', ']]') + ']'
df_enriched_media_sql = f"SELECT COUNT(*) FROM [dbo].{quoted_submitted_media_tbl} WHERE state = 'Processing'"
media_file_count = 0
serverless_conn_str = (
    f'DRIVER={driver};SERVER=tcp:{serverless_sql_endpoint};PORT=1433;'
    f'DATABASE={database};UID={sql_user_name};PWD={sql_user_pwd}'
)
# pyodbc's Connection context manager commits on exit but does not close the
# connection, so closing() is used to release both the connection and the cursor.
with closing(pyodbc.connect(serverless_conn_str)) as conn, closing(conn.cursor()) as cursor:
    cursor.execute(df_enriched_media_sql)
    media_file_count, *_ = cursor.fetchone()
print(media_file_count)

In [ ]:
# Load the media enrichments written by the previous notebook and stop here unless
# every submitted media file has finished (either Processed or Failed).
with tracer.span(name=f'Read the dataframe from the given table {enriched_media_tbl_name}'):
    media_df = spark.read.parquet(f'{minted_tables_path}{enriched_media_tbl_name}')
    # Lift the "state" field out of the enrichments JSON into its own column.
    media_df = media_df.select(
        col("media_path"),
        col("media_file_name"),
        col("video_id"),
        col("enrichments"),
        col("original_lang"),
        json_tuple(col("enrichments"), "state"),
    ).toDF("media_path", "media_file_name", "video_id", "enrichments", "original_lang", "state")

    # A single aggregation pass instead of one count() job per state.
    state_counts = {row["state"]: row["count"] for row in media_df.groupBy("state").count().collect()}
    processed_count = state_counts.get("Processed", 0)
    failed_count = state_counts.get("Failed", 0)

    if display_results:
        media_df.show()
        print("processed files =", processed_count)
        print("failed files =", failed_count)

    # Abort the notebook if not all files are complete - we haven't received all the callbacks.
    if media_file_count > processed_count + failed_count:
        output = {
            'custom_dimensions': {
                'batch_num': batch_num,
                'ruleset_eval_tbl_name': ruleset_eval_tbl_name,
                'notebook_name': mssparkutils.runtime.context['notebookname']
            }
        }
        # Exit the notebook if this isn't the last expected callback.
        logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT ", extra=output)
        mssparkutils.notebook.exit("exited as the full set of media files is not complete - we only raise a single alert")

In [ ]:
with tracer.span(name='Evaluate fulltext conditions for documents'):

    # Only reached once the full media batch is complete - the previous cell exits
    # the notebook otherwise. Each ruleset's keywords are loaded into flashtext
    # (https://flashtext.readthedocs.io/en/latest/).
    from flashtext import KeywordProcessor

    def _make_fulltext_udf(processor):
        """Return a Spark UDF that runs *processor*'s keyword search over a text column.

        The processor is bound explicitly, so each ruleset's UDF always searches
        with that ruleset's own keywords rather than a shared loop variable.
        The UDF returns the JSON-encoded list from ``extract_keywords(..., span_info=True)``.
        """
        def fulltext_search(text):
            return json.dumps(processor.extract_keywords(text, span_info=True))
        return udf(fulltext_search, StringType())

    # Start from an empty frame with the output schema, so all_flagged_media is
    # always defined (and writable downstream) even when no rulesets are configured.
    all_flagged_media = (
        media_df
        .withColumn("fulltext_result", lit(None).cast(StringType()))
        .withColumn("ruleset_name", lit(None).cast(StringType()))
        .limit(0)
    )

    for ruleset in text_rulesets_config:
        # Each condition string is the body of a JSON object without its braces.
        keyword_processor = KeywordProcessor()
        for fulltext_condition in ruleset["fulltext_conditions"]:
            condition_dict = json.loads("{" + fulltext_condition["condition"] + "}")
            keyword_processor.add_keywords_from_dict(condition_dict)

        # A single search covers all fulltext conditions of the ruleset, to save processing time.
        udf_fulltext_search = _make_fulltext_udf(keyword_processor)
        media_df_fulltext_output = media_df.withColumn("fulltext_result", udf_fulltext_search(media_df.enrichments))

        ruleset_flagged = (
            media_df_fulltext_output
            .filter(media_df_fulltext_output.fulltext_result != "[]")
            .withColumn("ruleset_name", lit(ruleset["rule_set_name"]))
        )
        # Accumulate across rulesets. Previously each iteration overwrote the result,
        # so only the last ruleset could ever raise an alert downstream.
        all_flagged_media = all_flagged_media.unionByName(ruleset_flagged)

if display_results:
    all_flagged_media.show()

# Get count of files in the original manifest


In [ ]:
# Read the original batch manifest. Its row count is reported as the
# files-processed metric in every ruleset alert.
manifest_df_name = '{}_manifest'.format(batch_num)
with tracer.span(name='Read the dataframe from the given table ' + manifest_df_name):
    batch_df = spark.read.parquet(minted_tables_path + manifest_df_name)
    batch_file_count = batch_df.count()


## Raise rule-based alerts

In [ ]:
with tracer.span(name='Persist as json and queue event grid topic'):
    from datetime import timezone

    output_path = f'abfss://{output_container}@{storage_account_name}.dfs.{azure_storage_domain}/{batch_num}'
    rules_output_path = f'abfss://{rules_container}@{storage_account_name}.dfs.{azure_storage_domain}/{batch_num}'
    # With a naive datetime, %Z renders as an empty string, so the old timestamp
    # carried no zone at all. Use an explicit UTC time with the ISO-8601 "Z" suffix.
    now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    media_enrichment_folder = "/media_processing_json/"

    def _file_event_detail(flagged_row):
        """Build the per-file entry of a ruleset event from one flagged media row."""
        return {
            "file_uri": f'{batch_root}/{flagged_row.media_file_name}',
            "file_type_class": "media",
            "file_enrichment_uri": flagged_row.file_enrichment_uri,
            "original_lang": flagged_row.original_lang,
            "cluster": "n/a",
            "Explanations": "n/a",
            "fulltext_search_detail": flagged_row.fulltext_result,
        }

    for ruleset in text_rulesets_config:
        filtered_files = all_flagged_media.filter(col("ruleset_name") == ruleset["rule_set_name"])
        filtered_files = filtered_files.withColumn(
            "file_enrichment_uri",
            concat(lit(output_path), lit(media_enrichment_folder), col("media_file_name"), lit(".output.json")),
        )
        # Collect once and count locally instead of running a separate count() job.
        filtered_files_list = filtered_files.collect()
        flagged_files_count = len(filtered_files_list)
        if display_results:
            print(flagged_files_count)

        # Only raise an event when at least one file matched this ruleset.
        if flagged_files_count == 0:
            continue

        file_list_json = [_file_event_detail(row) for row in filtered_files_list]

        # Build the output JSON-encoded string.
        output_json = json.dumps({
            "batch_id": batch_num,
            "batch_description": batch_description,
            "rule_set_config": ruleset,
            "eventDate": now,
            "eventDetails": file_list_json
        })

        # Write the rules JSON output to the storage container, ready for downstream use.
        ruleset_output_name = f'{ruleset["rule_set_name"]}.media.ruleset.output.json'
        p = f'{rules_output_path}/ruleset_events/{ruleset_output_name}'
        mssparkutils.fs.put(p, output_json, overwrite=True)

        # Build the Event Grid payload; it is sent to the topic and ultimately becomes a summary alert.
        webapp_alert_page_path = f'{batch_num}/ruleset_events/{ruleset_output_name}'
        webapp_alert_page_parameter = json.dumps({
            "Path": webapp_alert_page_path,
            "Filter": {
                "original_lang": "",
                "cluster": -1,
                "Explanations": ""
            }
        })

        ruleset_event_data_obj = {
            "batch_id": batch_num,
            "batch_description": batch_description,
            "rule_set_config": ruleset,
            "eventDate": now,
            "eventMetrics": {
                "rule_events_count": flagged_files_count,
                "files_processed_count": batch_file_count,
                "event_detail_uri": f"https://{web_app_uri}/alert/{urllib.parse.quote(webapp_alert_page_parameter)}"
            },
            "teams_webhook_endpoint": subscriber_uri,
            "alert_email": alert_email
        }

        # The payload carries the Teams webhook URI and alert e-mail, so it is
        # only echoed to the notebook output when debugging.
        if display_results:
            print(ruleset_event_data_obj)

        try:
            # Queue the Event Grid message.
            event = EventGridEvent(
                data=ruleset_event_data_obj,
                subject="MINTED/rulesetmediaevent",
                event_type="MINTED.ruleTriggered",
                data_version="1.0",
                topic=event_grid_topic,
            )
            client.send(event)
        except Exception:
            logger.exception("Failed to publish ruleset event to Event Grid")
            raise


In [ ]:
# Persist the combined flagged-media results as a parquet table under
# minted_tables_path (named by ruleset_eval_tbl_name), overwriting any previous output.
with tracer.span(name='Persist ruleset evaluations as table'):
    ruleset_eval_output_path = minted_tables_path + ruleset_eval_tbl_name
    all_flagged_media.write.mode("overwrite").parquet(ruleset_eval_output_path)

In [ ]:
from time import sleep
from contextlib import closing

# Update Status Table


def _dedicated_connection_string(driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """Build the ODBC connection string for the dedicated SQL pool."""
    return (
        f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;'
        f'DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'
    )


def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """Return ``(num_stages_complete, description)`` of the most advanced status row for *batch_num*.

    Raises:
        LookupError: if ``dbo.batch_status`` has no row for *batch_num*.
    """
    query = """
        SELECT TOP (1) 
        [num_stages_complete], [description]
        FROM [dbo].[batch_status] 
        WHERE [batch_id] = ?
        ORDER BY [num_stages_complete] DESC;
    """
    conn_str = _dedicated_connection_string(driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
    # pyodbc's Connection context manager does not close the connection; closing() does.
    with closing(pyodbc.connect(conn_str, autocommit=True)) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(query, batch_num)
        row = cursor.fetchone()
    if row is None:
        raise LookupError(f'No batch_status row found for batch_id {batch_num!r}')
    num_stages_complete, description = row
    return num_stages_complete, description


def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd,
                        dedicated_db=None, max_retries=10, retry_delay_seconds=3, total_stages=10):
    """Advance the batch's stage counter by one and record *status_text* in ``batch_status``.

    The status is stored as ``"[<stage>/<total_stages>] <status_text>"``. Each attempt
    re-reads the current stage, so a retry after a failed read does not skip a stage.

    Args:
        status_text: human-readable description of the stage just completed.
        minted_tables_path: unused; kept so existing calls keep working.
        batch_num: ``batch_id`` of the row to update.
        driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd: ODBC connection settings.
        dedicated_db: dedicated pool database; defaults to the notebook's ``dedicated_database``.
        max_retries: total number of attempts before giving up (must be >= 1).
        retry_delay_seconds: pause between failed attempts.
        total_stages: denominator shown in the status string.

    Raises:
        ValueError: if *max_retries* < 1.
        Exception: the last error raised once every attempt has failed.
    """
    if max_retries < 1:
        raise ValueError('max_retries must be at least 1')
    if dedicated_db is None:
        dedicated_db = dedicated_database

    conn_str = _dedicated_connection_string(driver, dedicated_sql_endpoint, dedicated_db, sql_user_name, sql_user_pwd)
    sql_command = "UPDATE batch_status SET status = ?, update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"

    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            stages_complete, _ = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_db, sql_user_name, sql_user_pwd)
            stages_complete += 1
            status = f'[{stages_complete}/{total_stages}] {status_text}'
            time_stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            # autocommit=True, so the UPDATE is committed as soon as it executes.
            with closing(pyodbc.connect(conn_str, autocommit=True)) as conn, closing(conn.cursor()) as cursor:
                cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
            return
        except Exception as e:
            last_exc = e
            logger.warning(f'Failed to update status table: {e}, retrying . . .')
            # No point sleeping after the final attempt.
            if attempt < max_retries:
                sleep(retry_delay_seconds)

    raise last_exc

update_status_table('Media Processing Complete', minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd)

## Complete notebook outputs

In [ ]:
# Hand the name of the ruleset evaluation table back to the calling pipeline,
# logging the same dimensions to Application Insights first.
_nb_name = mssparkutils.runtime.context['notebookname']
output = {
    'custom_dimensions': {
        'batch_num': batch_num,
        'ruleset_eval_tbl_name': ruleset_eval_tbl_name,
        'notebook_name': _nb_name,
    }
}

logger.info(f"{_nb_name}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])