In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8
   }
}

In [ ]:
# Pipeline parameters (Synapse parameters cell). The empty defaults are
# placeholders, presumably overridden by the calling pipeline — confirm.
processed_text_tbl_name = ''    # parquet table of processed text, read from minted_tables_output_path
batch_root = ''                 # folder holding config.json; also prefixes the file URIs in the output JSON
batch_num = ''                  # batch id: used in output paths, the ruleset table name and the batch_status key
batch_description = ''          # copied into the ruleset output JSON and the Event Grid event data
input_container=''              # not referenced in this notebook
output_container=''             # container used to build the file_enrichment_uri values
blob_account_name = ''          # storage account name in the abfss:// URIs
rules_container = ''            # container receiving the ruleset_events/*.ruleset.output.json files
file_system = ''                # not referenced in this notebook
processed_images_tbl_name = ''  # parquet table of image enrichments, read from minted_tables_output_path
batch_file_count = 0            # reported as files_processed_count in each ruleset event
azure_storage_domain = ''       # domain suffix placed after "dfs." in the abfss:// URIs
minted_tables_output_path = ''  # prefix concatenated directly with table names (presumably ends with '/' — confirm)

In [ ]:
# Set to True to print intermediate counts, DataFrames and the generated JSON
# while investigating processing; every `if display_results:` check reads this.
display_results = False

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer
#import datetime

config_integration.trace_integrations(['logging'])

instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

# logging.getLogger returns the same logger object every time this cell runs in
# a session, so only attach the Application Insights handler once; re-running
# the cell would otherwise export every record once per attached handler.
logger = logging.getLogger(__name__)
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer whose spans go to the Application Insights exporter; every span is sampled.
tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# Spool parameters: run-time values attached to the log record as custom dimensions
run_time_parameters = {'custom_dimensions': {
    'processed_text_tbl_name': processed_text_tbl_name,
    'batch_description': batch_description,
    'batch_root': batch_root,
    'batch_num': batch_num,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)


In [ ]:
from azure.identity import ClientSecretCredential
from azure.eventgrid import EventGridPublisherClient, EventGridEvent
from azure.mgmt.keyvault import KeyVaultManagementClient
from types import SimpleNamespace
from datetime import datetime

import json
import rule_engine
import urllib.parse

from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql.functions import col, to_json, from_json, lit, explode, concat, udf, regexp_replace
from pyspark.sql.types import StringType, MapType, BooleanType

In [ ]:
# Initialise session and config
sc = spark.sparkContext
spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()

# config.json lives in the batch root; textFile yields it line by line, so the
# lines are re-joined before parsing.
config = json.loads(''.join(sc.textFile(f'{batch_root}/config.json').collect()))

# Set log level: anything other than "INFO" (including a missing key) means ERROR
if config.get("log_level") == "INFO":
    logger.setLevel(logging.INFO)
else:
    logger.setLevel(logging.ERROR)
    config["log_level"] = "ERROR"

log_level = config["log_level"]
rulesets_config = config["rule_sets"]
web_app_uri = rulesets_config["webapp_uri"]
subscriber_uri = rulesets_config["teams_webhook_uri"]
alert_email = rulesets_config["alert_email"]
text_rulesets_config = rulesets_config["text_rule_sets"]


def _get_keyvault_secret(secret_name):
    """Return *secret_name* from the "keyvault" linked service.

    Uses the same mssparkutils credentials API as the rest of this notebook.
    """
    return mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)


# get values from keyvault to build the Event Grid Topic event
subscription_id = _get_keyvault_secret('SubscriptionId')
resource_group_name = _get_keyvault_secret('ResourceGroupName')
event_grid_topic_name = _get_keyvault_secret('EventGridTopicName')
event_grid_topic_endpoint = _get_keyvault_secret('EventGridTopicEndpointUri')
tenant_id = _get_keyvault_secret('TenantID')
client_id = _get_keyvault_secret('ADAppRegClientId')
client_secret = _get_keyvault_secret('ADAppRegClientSecret')
event_grid_topic = f'/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.EventGrid/topics/{event_grid_topic_name}'
# Publisher client authenticated with the app registration (client credentials flow)
credential = ClientSecretCredential(tenant_id, client_id, client_secret)
client = EventGridPublisherClient(event_grid_topic_endpoint, credential)

if display_results:
    print(config)

In [ ]:
import pyodbc
from pyspark.sql.functions import current_timestamp
# SQL connection settings: database names for the dedicated and serverless pools
# plus the ODBC driver used for every pyodbc connection in this notebook.
dedicated_database, database = 'dedicated', 'minted'
driver = '{ODBC Driver 17 for SQL Server}'

# SQL login and the two pool endpoints, read from the "keyvault" linked service
# (fetched in this order: user, password, serverless endpoint, dedicated endpoint).
(
    sql_user_name,
    sql_user_pwd,
    serverless_sql_endpoint,
    dedicated_sql_endpoint,
) = [
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_key)
    for secret_key in (
        "SynapseSQLUserName",
        "SynapseSQLPassword",
        "SynapseServerlessSQLEndpoint",
        "SynapseDedicatedSQLEndpoint",
    )
]

## Evaluate rules for full text of documents

In [ ]:
# Read the processed (translated) text table for this batch.
with tracer.span(name=f'Read the dataframe from the given table {processed_text_tbl_name}'):
    docs_df = spark.read.parquet(f'{minted_tables_output_path}{processed_text_tbl_name}')

# Keep only the columns rule evaluation needs. Text rows carry no explanations,
# and nulls in numeric columns (e.g. cluster, bigint in the SQL table) become -1.
df_original_text = (
    docs_df
    .select(
        docs_df.file_name.alias("ruleset_file_name"),
        docs_df.text_content_target_lang,
        docs_df.original_lang,
        docs_df.cluster,
    )
    .withColumn("Explanations", lit("n/a"))
    .na.fill(value=-1)
)

if display_results:
    print(df_original_text.count())
    df_original_text.show()

In [ ]:
with tracer.span(name='Evaluate fulltext conditions for documents'):
    # Load the rulesets to search for and pass to flashtext (https://flashtext.readthedocs.io/en/latest/)
    from flashtext import KeywordProcessor

    # Union of the matching rows of every ruleset; None until the first ruleset is processed.
    all_flagged_text_df_final = None

    # enumerate all the text_rulesets and add full text conditions to a processor per ruleset
    for ruleset in text_rulesets_config:
        # each fulltext condition is the body of a JSON object, so wrap it in braces before parsing
        keyword_processor = KeywordProcessor()
        for fulltext_condition in ruleset["fulltext_conditions"]:
            condition_dict = json.loads("{" + fulltext_condition["condition"] + "}")
            keyword_processor.add_keywords_from_dict(condition_dict)

        # UDF to call flashtext search for a text value; returns the matches (with span
        # info) as a JSON string, "[]" when nothing matches. The processor is bound as a
        # default argument so each ruleset's UDF explicitly carries its own processor.
        def fulltext_search(text, processor=keyword_processor):
            return json.dumps(processor.extract_keywords(text, span_info=True))
        udf_fulltext_search = udf(fulltext_search, StringType())

        # perform a single search for all fulltext conditions in the ruleset to save
        # processing time, keep only matching rows and tag them as text files
        all_flagged_text_df_fulltext = (
            df_original_text
            .withColumn("fulltext_result", udf_fulltext_search(df_original_text.text_content_target_lang))
            .filter(col("fulltext_result") != "[]")
            .withColumn("ruleset_name", lit(ruleset["rule_set_name"]))
            .drop("text_content_target_lang")
            .withColumn("file_type", lit('text'))
            .withColumn("file_enrichment_folder", lit('/text_processing_json/'))
        )

        # append results to all_flagged_text_df_final, which builds for all rulesets
        if all_flagged_text_df_final is None:
            all_flagged_text_df_final = all_flagged_text_df_fulltext
        else:
            all_flagged_text_df_final = all_flagged_text_df_final.union(all_flagged_text_df_fulltext)

    if all_flagged_text_df_final is None:
        # No text rulesets configured: keep an empty frame with the same columns so
        # the merge and persistence cells still work instead of raising NameError.
        all_flagged_text_df_final = (
            df_original_text.limit(0)
            .withColumn("fulltext_result", lit(None).cast(StringType()))
            .withColumn("ruleset_name", lit(None).cast(StringType()))
            .drop("text_content_target_lang")
            .withColumn("file_type", lit('text'))
            .withColumn("file_enrichment_folder", lit('/text_processing_json/'))
        )

if display_results:
    all_flagged_text_df_final.groupBy("ruleset_name").count().show()
    all_flagged_text_df_final.show()

# Evaluate rules for enrichments on images

In [ ]:
# Read the image-enrichment table for this batch.
with tracer.span(name=f'Read the dataframe from the given table {processed_images_tbl_name}'):
    images_df = spark.read.parquet(f'{minted_tables_output_path}{processed_images_tbl_name}')

# Project the columns the rules need, strip single quotes from the explanations
# (presumably so they cannot clash with the JSON built later — confirm) and mark
# images as having no source language.
df_images_processed = (
    images_df
    .select(
        images_df.file_name.alias("ruleset_file_name"),
        images_df.analysis_results,
        images_df.cluster,
        images_df.Explanations,
    )
    .withColumn('Explanations', regexp_replace('Explanations', "'", ""))
    .withColumn("original_lang", lit('n/a'))
)

if display_results:
    print(df_images_processed.count())
    df_images_processed.show()

In [ ]:
with tracer.span(name='Evaluate fulltext conditions for image enrichments'):
    # Load the rulesets to search for and pass to flashtext (https://flashtext.readthedocs.io/en/latest/)
    from flashtext import KeywordProcessor

    # Union of the matching rows of every ruleset; None until the first ruleset is processed.
    all_flagged_images_df_final = None

    # enumerate all the text_rulesets and add full text conditions to a processor per ruleset
    for ruleset in text_rulesets_config:
        # each fulltext condition is the body of a JSON object, so wrap it in braces before parsing
        keyword_processor = KeywordProcessor()
        for fulltext_condition in ruleset["fulltext_conditions"]:
            condition_dict = json.loads("{" + fulltext_condition["condition"] + "}")
            keyword_processor.add_keywords_from_dict(condition_dict)

        # UDF to call flashtext search for a text value; returns the matches (with span
        # info) as a JSON string, "[]" when nothing matches. The processor is bound as a
        # default argument so each ruleset's UDF explicitly carries its own processor.
        def fulltext_search(text, processor=keyword_processor):
            return json.dumps(processor.extract_keywords(text, span_info=True))
        udf_fulltext_search = udf(fulltext_search, StringType())

        # search the stringified analysis results once per ruleset, keep only matching
        # rows and tag them as image files with their enrichment folder
        all_flagged_images_df_fulltext = (
            df_images_processed
            .withColumn("fulltext_result", udf_fulltext_search(df_images_processed.analysis_results.cast(StringType())))
            .filter(col("fulltext_result") != "[]")
            .withColumn("ruleset_name", lit(ruleset["rule_set_name"]))
            .drop("analysis_results")
            .withColumn("file_type", lit('image'))
            .withColumn("file_enrichment_folder", lit('/image_processing_json/'))
        )

        # append results to all_flagged_images_df_final, which builds for all rulesets.
        # No count() here: it would trigger a full Spark job over the growing union on
        # every iteration.
        if all_flagged_images_df_final is None:
            all_flagged_images_df_final = all_flagged_images_df_fulltext
        else:
            all_flagged_images_df_final = all_flagged_images_df_final.union(all_flagged_images_df_fulltext)

    if all_flagged_images_df_final is None:
        # No text rulesets configured: keep an empty frame with the same columns so
        # the merge and persistence cells still work instead of raising NameError.
        all_flagged_images_df_final = (
            df_images_processed.limit(0)
            .withColumn("fulltext_result", lit(None).cast(StringType()))
            .withColumn("ruleset_name", lit(None).cast(StringType()))
            .drop("analysis_results")
            .withColumn("file_type", lit('image'))
            .withColumn("file_enrichment_folder", lit('/image_processing_json/'))
        )


if display_results:
    print(all_flagged_images_df_final.count())
    all_flagged_images_df_final.groupBy("ruleset_name").count().show()
    all_flagged_images_df_final.show()


# Raise rule-based alerts

In [ ]:
# Merge results from images and documents. union matches columns by position,
# so both sides are first projected onto the same column order.
result_columns = ["ruleset_file_name", "ruleset_name", "fulltext_result", "file_type", "file_enrichment_folder", "original_lang", "cluster", "Explanations"]
images_df = all_flagged_images_df_final.select(*result_columns)
text_df = all_flagged_text_df_final.select(*result_columns)
# orderBy returns a new DataFrame; it must be assigned for the sort to take effect
all_flagged_combined_df_final = images_df.union(text_df).orderBy(['ruleset_file_name'], ascending=[True])

if display_results:
    all_flagged_combined_df_final.groupBy("ruleset_name").count().show()
    all_flagged_combined_df_final.show()

In [ ]:
with tracer.span(name='Persist ruleset evaluations as table'):
    # Write the combined ruleset matches as parquet under the minted tables path,
    # overwriting any previous output for this batch.
    ruleset_eval_tbl_name = f'{batch_num}_ruleset_eval'
    all_flagged_combined_df_final.write.mode("overwrite").parquet(f'{minted_tables_output_path}{ruleset_eval_tbl_name}')
    # Register an external table over the parquet folder in the serverless database
    # unless a table with this name already exists.
    # NOTE(review): the table name comes from the batch_num parameter and is
    # interpolated directly into the SQL (identifiers cannot be bound as ODBC
    # parameters); batch_num is assumed to be a trusted pipeline value — confirm.
    # NOTE(review): <<STORAGE_ACCOUNT_NAME>> looks like a deployment-time placeholder
    # that must be substituted before this runs — confirm.
    # NOTE(review): LOCATION is hard-coded to 'minted_tables/...', while the parquet is
    # written under minted_tables_output_path; both are assumed to be the same folder.
    sql_command = f'''
        IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{ruleset_eval_tbl_name}') 
        CREATE EXTERNAL TABLE [{ruleset_eval_tbl_name}] (
            [ruleset_file_name] nvarchar(4000),
            [ruleset_name] nvarchar(4000),
            [fulltext_result] nvarchar(max),
            [file_type] nvarchar(4000),
            [file_enrichment_folder] nvarchar(1000),
            [original_lang] nvarchar(4000),
            [cluster] bigint, 
            [Explanations] nvarchar(max)
        )
        WITH (
            LOCATION = 'minted_tables/{ruleset_eval_tbl_name}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
        )
    '''
    # Connect to the serverless SQL endpoint (database `database`) with the SQL login
    # read from Key Vault.
    # NOTE(review): pyodbc's connection context manager commits on exit but may not
    # close the connection — confirm against the pyodbc docs.
    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
        with conn.cursor() as cursor:
            cursor.execute(sql_command)

with tracer.span(name='Persist as json and queue event grid topic'):
    # For every text ruleset: write a JSON summary of the files it flagged to the
    # rules container and publish a MINTED.ruleTriggered event to Event Grid.
    output_path = f'abfss://{output_container}@{blob_account_name}.dfs.{azure_storage_domain}/{batch_num}'
    rules_output_path = f'abfss://{rules_container}@{blob_account_name}.dfs.{azure_storage_domain}/{batch_num}'
    # NOTE(review): datetime.now() is naive, so %Z renders as an empty string and the
    # timestamp carries no timezone; kept unchanged for downstream compatibility.
    now = datetime.now().strftime("%Y-%m-%dT%H:%M:%S%Z")

    # iterate over each ruleset
    for ruleset in text_rulesets_config:
        rule_set_name = ruleset["rule_set_name"]
        filtered_files = all_flagged_combined_df_final.filter(col("ruleset_name") == rule_set_name)
        filtered_files = filtered_files.withColumn("file_enrichment_uri", concat(lit(output_path), col("file_enrichment_folder"), col("ruleset_file_name"), lit(".output.json")))
        flagged_files_count = filtered_files.count()

        # Collect the flagged files to the driver and describe each one. The JSON is
        # built from Python objects and serialised with json.dumps, so quotes,
        # apostrophes or backslashes in file names, explanations or matched keywords
        # are escaped instead of producing invalid JSON, and a null cluster is emitted
        # as JSON null rather than the bare token None.
        event_details = [
            {
                "file_uri": f'{batch_root}/{filtered_file.ruleset_file_name}',
                "file_type_class": filtered_file.file_type,
                "file_enrichment_uri": filtered_file.file_enrichment_uri,
                "original_lang": filtered_file.original_lang,
                "cluster": filtered_file.cluster,
                "Explanations": filtered_file.Explanations,
                # fulltext_result is the JSON list produced by the flashtext UDF
                "fulltext_search_detail": json.loads(filtered_file.fulltext_result),
            }
            for filtered_file in filtered_files.collect()
        ]

        # build output json
        output_json = json.dumps({
            "batch_id": str(batch_num),
            "batch_description": str(batch_description),
            "rule_set_config": ruleset,
            "eventDate": now,
            "eventDetails": event_details,
        }, ensure_ascii=False)

        if display_results:
            print(output_json)

        # write rules json output to the storage container ready for downstream use
        p = f'{rules_output_path}/ruleset_events/{rule_set_name}.ruleset.output.json'  # path to rules output file
        mssparkutils.fs.put(p, output_json, overwrite=True)

        # generate the Event Grid schema and send to Event Grid Topic and ultimately to a summary alert
        webapp_alert_page_path = f'{batch_num}/ruleset_events/{rule_set_name}.ruleset.output.json'
        # JSON page parameter for the web app's alert page, URL-quoted into event_detail_uri
        webapp_alert_page_parameter = json.dumps(
            {"Path": webapp_alert_page_path, "Filter": {"original_lang": "", "cluster": -1, "Explanations": ""}},
            ensure_ascii=False,
        )
        ruleset_event_data_obj = {
            "batch_id": str(batch_num),
            "batch_description": str(batch_description),
            "rule_set_config": ruleset,
            "eventDate": now,
            "eventMetrics": {
                "rule_events_count": str(flagged_files_count),
                "files_processed_count": str(batch_file_count),
                "event_detail_uri": f"https://{web_app_uri}/alert/{urllib.parse.quote(webapp_alert_page_parameter)}",
            },
            "teams_webhook_endpoint": str(subscriber_uri),
            "alert_email": str(alert_email),
        }

        try:
            # queue event grid message
            event = EventGridEvent(data=ruleset_event_data_obj, subject="MINTED/rulesetevent", event_type="MINTED.ruleTriggered", data_version="1.0", topic=event_grid_topic)
            client.send(event)
        except Exception as e:
            logger.exception(e)
            raise

In [ ]:
from time import sleep
#from datetime import datetime, timedelta
# Update Status Table
def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """Return the latest progress entry recorded for a batch.

    Queries ``[dbo].[batch_status]`` in the dedicated SQL pool for the row of
    ``batch_num`` with the highest ``num_stages_complete``.

    Args:
        batch_num: value matched against ``[batch_id]`` (bound as a parameter).
        driver: ODBC driver name, including braces.
        dedicated_sql_endpoint: host of the dedicated SQL pool.
        dedicated_database: database containing ``batch_status``.
        sql_user_name: SQL login name.
        sql_user_pwd: SQL login password.

    Returns:
        tuple: ``(num_stages_complete, description)`` of that row.

    Raises:
        LookupError: if ``batch_status`` has no row for ``batch_num``.
        pyodbc.Error: on connection or query failure.
    """
    from contextlib import closing

    query = """
        SELECT TOP (1) 
        [num_stages_complete], [description]
        FROM [dbo].[batch_status] 
        WHERE [batch_id] = ? 
        ORDER BY [num_stages_complete] DESC;
    """

    connection_string = f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'
    # closing() guarantees the connection is released as soon as the query is done
    with closing(pyodbc.connect(connection_string, autocommit=True)) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, batch_num)
            row = cursor.fetchone()

    if row is None:
        # fetchone() returns None when nothing matched; unpacking it would raise an opaque TypeError
        raise LookupError(f'No batch_status row found for batch_id {batch_num!r}')
    num_stages_complete, description = row
    return num_stages_complete, description

def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd, max_attempts=10, retry_delay_seconds=3):
    """Advance a batch's stage counter in ``batch_status`` and record a status text.

    Reads the current ``num_stages_complete`` via ``get_recent_status``, increments
    it and writes ``status`` as ``'[<stages>/10] <status_text>'`` together with the
    current local time. Any failure is logged and retried.

    Args:
        status_text: human-readable description of the completed stage.
        minted_tables_path: unused; kept for backward compatibility with callers.
        batch_num: batch id whose row is updated.
        driver: ODBC driver name, including braces.
        dedicated_sql_endpoint: host of the dedicated SQL pool.
        sql_user_name: SQL login name.
        sql_user_pwd: SQL login password.
        max_attempts: total number of attempts before giving up (default 10).
        retry_delay_seconds: pause between failed attempts (default 3).

    Raises:
        ValueError: if ``max_attempts`` is less than 1.
        Exception: the error from the last attempt when every attempt fails.

    NOTE(review): the read and the update use separate connections, so concurrent
    updaters for the same batch could race on num_stages_complete.
    """
    from contextlib import closing

    if max_attempts < 1:
        raise ValueError('max_attempts must be at least 1')

    # The database name comes from the notebook-level `dedicated_database` setting.
    connection_string = 'DRIVER='+driver+';SERVER=tcp:'+dedicated_sql_endpoint+';PORT=1433;DATABASE='+dedicated_database+';UID='+sql_user_name+';PWD='+ sql_user_pwd
    sql_command = "UPDATE batch_status SET status = ?, update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"

    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            stages_complete, _description = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            stages_complete += 1
            status = f'[{stages_complete}/10] {status_text}'
            time_stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            with closing(pyodbc.connect(connection_string, autocommit=True)) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
                    cursor.commit()
            return
        except Exception as e:
            last_error = e
            logger.warning(f'Failed to update status table: {e}, retrying . . .')
            # no point sleeping after the final attempt; raise straight away
            if attempt < max_attempts:
                sleep(retry_delay_seconds)

    raise last_error

# Mark this notebook's stage as done: increments num_stages_complete for the batch
# in batch_status and records the status text (update_status_table retries on failure).
update_status_table('Text Ruleset Evaluation Complete', minted_tables_output_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd)

## Complete notebook outputs

In [ ]:
# Report the new ruleset table name to the logs and hand it back to the pipeline.
notebook_name = mssparkutils.runtime.context['notebookname']
output = {
    'custom_dimensions': {
        'batch_num': batch_num,
        'ruleset_eval_tbl_name': ruleset_eval_tbl_name,
        'notebook_name': notebook_name,
    }
}

# The custom dimensions double as the notebook's exit value.
logger.info(f"{notebook_name}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])