# 2.1 Text Enrichment


In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 4,
     "spark.jars.packages": "com.microsoft.azure:synapseml_2.12:0.10.0-19-c3a445c5-SNAPSHOT",
      "spark.jars.repositories": "https://mmlspark.azureedge.net/maven",
      "spark.jars.excludes": "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalactic:scalactic_2.12,org.scalatest:scalatest_2.12,com.fasterxml.jackson.core:jackson-databind",
      "spark.yarn.user.classpath.first": "true"
   }
}

In [ ]:
# Run parameters. Every value below is an empty placeholder default; presumably the
# caller overrides them at run time (TODO confirm).
documents_contents_tbl_name = ''  # appended to minted_tables_output_path to locate the input parquet
batch_num = ''  # prefix of the output table name, also attached to the log records
batch_root = ''  # folder that read_batch_config() loads config.json from
file_system = ''  # applied as the Hadoop fs.defaultFS before the batch config is read
display_dataframes = False  # when True, the intermediate dataframes are shown
minted_tables_output_path = ''  # base path of the input and output parquet tables

In [ ]:
import pyodbc
# Settings used later to build the pyodbc connection string.
database = 'minted'
driver = '{ODBC Driver 17 for SQL Server}'

# Each value is fetched with getSecretWithLS from "keyvault", in the order listed.
sql_user_name, sql_user_pwd, serverless_sql_endpoint = (
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
    for secret_name in ("SynapseSQLUserName", "SynapseSQLPassword", "SynapseServerlessSQLEndpoint")
)

In [ ]:
# Load keys, set defaults
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")
# The secret value is split on commas, so it may carry several keys; the result is always a list
text_analytics_keys = mssparkutils.credentials.getSecretWithLS("keyvault", 'TextAnalyticsKeys').split(',')

# Settings passed to the TextAnalyze stage and the mini-batch transformer
cog_svc_concurrency = 1
cog_svc_batch_size = 15 # Rows per mini-batch. The /analyze endpoint that TextAnalyze uses is documented to allow batches of up to 25 documents
cog_svc_intial_polling_delay = 15000 # Time (in ms) to wait before first poll for results
cog_svc_polling_delay = 10000 # Time (in ms) to wait between repeated polling for results
cog_svc_maximum_retry_count = 100 # Maximum number of polling retries. 100 => 100 * 10s + 15s = 1015s ~= 17 mins allowed for a job to complete

# Column names
file_path_col = 'file_path'
text_col = 'text_content_target_lang'  # input text read from the documents contents table
text_split_col = 'text_content_target_lang_split'  # output of the page splitter, later one chunk per row
chunk_number_col = 'chunk_number'  # position of a chunk within its document
text_analysis_col = 'text_analysis'  # TextAnalyze output column
text_analysis_error_col = 'text_analysis_error'  # TextAnalyze error column
text_analytics_key_col = 'text_analytics_key'  # column TextAnalyze reads the subscription key from
named_entities_col = 'named_entities'
pii_col = 'pii'
key_phrases_col = 'key_phrases'
pii_redacted_text_col = 'pii_redacted_text'

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

# Switch on the OpenCensus integration for the standard logging module.
config_integration.trace_integrations(['logging'])

# Notebook logger with an Azure log handler attached, reporting at INFO until the
# batch config says otherwise.
azure_log_handler = AzureLogHandler(connection_string=instrumentation_connection_string)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(azure_log_handler)

# Tracer that samples every span and exports it with the same connection string.
tracer = Tracer(
    exporter=AzureExporter(connection_string=instrumentation_connection_string),
    sampler=AlwaysOnSampler(),
)

# Spool parameters: the run settings attached to the start-up log record.
run_time_parameters = {
    'custom_dimensions': dict(
        batch_num=batch_num,
        documents_contents_tbl_name=documents_contents_tbl_name,
        cog_svc_concurrency=cog_svc_concurrency,
        cog_svc_batch_size=cog_svc_batch_size,
        file_system=file_system,
        notebook_name=mssparkutils.runtime.context['notebookname'],
    )
}

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
import json
import os
import random
import uuid
from types import SimpleNamespace

import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql.types import StringType, StructType, StructField
from pyspark import SparkContext
from pyspark.sql import SparkSession

from synapse.ml.cognitive import *
from synapse.ml.featurize.text import PageSplitter
from synapse.ml.stages import FixedMiniBatchTransformer
from synapse.ml.stages import FlattenBatch

# Initialise session and config
# `spark` is not defined anywhere in this file before this line, so it has to be supplied
# by the notebook runtime; the context handle is taken from it before `spark` is rebound.
sc = spark.sparkContext
# NOTE(review): the app name interpolates the whole runtime context object, whereas the log
# messages use context['notebookname'] — confirm this is intended.
spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()

def read_batch_config(batch_root: str):
    """
    Load ``{batch_root}/config.json`` and return it as nested ``SimpleNamespace`` objects.

    We read the config file using the Java File System API as we do not need to let multiple nodes
    read individual lines and join it all back together again.

    Note that this sets the Hadoop ``fs.defaultFS`` option to the notebook-level ``file_system``
    value and does not restore the previous value afterwards.

    Args:
        batch_root: Folder that contains ``config.json``.

    Returns:
        The parsed JSON document, with every JSON object converted to a ``SimpleNamespace``.

    Raises:
        FileNotFoundError: If ``{batch_root}/config.json`` does not exist.
    """
    jvm = sc._jvm
    hadoop_conf = sc._jsc.hadoopConfiguration()

    # Make `file_system` the default file system so the path below resolves against it
    hadoop_conf.set("fs.defaultFS", file_system)

    fs = jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
    config_path = jvm.org.apache.hadoop.fs.Path(f'{batch_root}/config.json')

    # There is no fallback config to copy, so stop here with a clear message instead of
    # letting the open() call below fail.
    if not fs.exists(config_path):
        logger.error(f'{config_path} not found.')
        raise FileNotFoundError(f'{config_path} not found.')

    # Open our file directly rather than through spark
    reader = jvm.java.io.BufferedReader(
        jvm.java.io.InputStreamReader(fs.open(config_path), jvm.java.nio.charset.StandardCharsets.UTF_8)
    )
    try:
        config_string = reader.lines().collect(jvm.java.util.stream.Collectors.joining("\n"))
    finally:
        # Closing the reader also closes the wrapped input stream
        reader.close()

    # Load it into json
    return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

with tracer.span(name=f"Load config: {mssparkutils.runtime.context['notebookname']}"):
    # Any failure while reading the config is logged and then propagated unchanged.
    try:
        config = read_batch_config(batch_root)
    except Exception as config_error:
        logger.exception(config_error)
        raise

    # Only "INFO" is honoured as a log level; every other value is normalised to "ERROR".
    if config.log_level != "INFO":
        config.log_level = "ERROR"
        logger.setLevel(logging.ERROR)
    else:
        logger.setLevel(logging.INFO)

In [ ]:
with tracer.span(name='Load documents contents table'):
    # Only the file path and the text column are needed downstream.
    df = (
        spark.read
        .parquet(minted_tables_output_path + documents_contents_tbl_name)
        .select('file_path', text_col)
    )

In [ ]:
with tracer.span(name='Split large documents'):
    # Splitter that writes the pages of each document's text into the split column.
    page_splitter = (
        PageSplitter()
        .setInputCol(text_col)
        .setOutputCol(text_split_col)
        .setMaximumPageLength(5000)
        .setMinimumPageLength(4500)
    )

    # posexplode turns the array of pages into one row per page and adds each page's
    # position in the array as the chunk number.
    df_split = page_splitter.transform(df).select(
        "file_path",
        text_col,
        F.posexplode(text_split_col).alias(chunk_number_col, text_split_col),
    )

    if display_dataframes:
        df_split.show()

In [ ]:
with tracer.span(name='Group rows into batches'):
    # This will reduce the number of API calls to Cognitive Services
    fmbt = FixedMiniBatchTransformer().setBatchSize(cog_svc_batch_size)
    df_batched = fmbt.transform(df_split)

    if display_dataframes:
        df_batched.show()

In [ ]:
with tracer.span(name='Distribute cognitive service keys across rows'):
    def rand_key():
        """Return one of the configured Text Analytics keys, picked uniformly at random."""
        return random.choice(text_analytics_keys)

    # Spark treats UDFs as deterministic by default and may then drop or repeat
    # invocations while optimising the plan; a random UDF has to opt out of that.
    udf_rand_key = F.udf(rand_key, StringType()).asNondeterministic()

    # Written under the same column-name constant that the TextAnalyze stage reads its key from.
    df_batched = df_batched.withColumn(text_analytics_key_col, udf_rand_key())

In [ ]:
with tracer.span(name='Define TextAnalyze transform'):
    # Text Analysis (key phrase extraction, named entity recognition, PII recognition + redaction)
    # The subscription key is read per row from the key column; results are written to the
    # analysis column and failures to the error column. One task of each type is requested,
    # each with model-version "latest".
    # NOTE(review): the angle-bracket token on the last line of the call chain is not valid
    # Python as written; presumably a deployment step substitutes it before the notebook
    # runs — keep it exactly as is.
    text_analyze = (TextAnalyze()
        .setTextCol(text_split_col)
        .setLocation(config.location)
        .setSubscriptionKeyCol(text_analytics_key_col)
        .setOutputCol(text_analysis_col)
        .setErrorCol(text_analysis_error_col)
        .setConcurrency(cog_svc_concurrency)
        .setInitialPollingDelay(cog_svc_intial_polling_delay)
        .setPollingDelay(cog_svc_polling_delay)
        .setMaxPollingRetries(cog_svc_maximum_retry_count)
        .setSuppressMaxRetriesExceededException(True)
        .setLanguage(config.prep.target_language)
        .setEntityRecognitionTasks([{"parameters": { "model-version": "latest"}}])
        .setKeyPhraseExtractionTasks([{"parameters": { "model-version": "latest"}}])
        .setEntityRecognitionPiiTasks([{"parameters": { "model-version": "latest"}}])
        <<SYNAPSE_ML_TEXT_ANALYZE_ENDPOINT_CMD>>
        )

    # Apply the stage to the batched rows
    df_results_batched = text_analyze.transform(df_batched)

    if display_dataframes:
        df_results_batched.show()

In [ ]:
# Undo the mini-batching applied before the service call.
df_results = FlattenBatch().transform(df_results_batched)

if display_dataframes:
    df_results.show()

In [ ]:
# split out text analysis results
# "<analysis column>.*" promotes every field of the analysis struct to a top-level column
df_results = df_results.select(
    file_path_col, text_analysis_col + ".*", text_analysis_error_col, chunk_number_col
)

if display_dataframes:
    df_results.show()

In [ ]:
# collate split rows back to single row per document

# Shape of the JSON error response; only error.message is used below
error_response_schema = StructType(
    [StructField("error", StructType(
        [StructField("code", StringType()), StructField("message", StringType())]
    ))]
)

sorted_text_col = "sorted_text"
df_output = (
    df_results.select(
        file_path_col,
        # we only have a single task for each task type, so unpack it
        col("entityRecognition")[0]["result"]["entities"].alias(named_entities_col),
        col("entityRecognitionPii")[0]["result"]["entities"].alias(pii_col),
        col("keyPhraseExtraction")[0]["result"]["keyPhrases"].alias(key_phrases_col),
        col("entityRecognitionPii")[0]["result"]["redactedText"].alias(pii_redacted_text_col),
        # parse the error response as JSON and keep only its message
        F.from_json(df_results[text_analysis_error_col]["response"], error_response_schema)["error"]["message"].alias(text_analysis_error_col),
        chunk_number_col
    )
    .groupby(file_path_col)
    .agg(
        # each chunk contributes an array, so collect the arrays and flatten them into one;
        # the column-name constants are the same ones used for the aliases above
        F.flatten(F.collect_list(named_entities_col)).alias(named_entities_col),
        F.flatten(F.collect_list(pii_col)).alias(pii_col),
        F.flatten(F.collect_list(key_phrases_col)).alias(key_phrases_col),
        # pair each redacted chunk with its chunk number so sorting restores document order
        F.sort_array(F.collect_list(F.struct(chunk_number_col, pii_redacted_text_col))).alias(sorted_text_col),
        # a single message per document: the greatest one in string ordering
        F.max(col(text_analysis_error_col)).alias(text_analysis_error_col)
    )
    # join the ordered chunks back into one redacted text per document
    .withColumn(
        pii_redacted_text_col,
        F.concat_ws("", col(f"{sorted_text_col}.{pii_redacted_text_col}"))
    )
    .drop(sorted_text_col)
)

if display_dataframes:
    df_output.show()

In [ ]:
with tracer.span(name='Persist enriched text as table'):
    enriched_text_tbl_name = f'{batch_num}_enriched_text'

    # df_output is never cached in this notebook, so every Spark action on it re-runs the
    # whole lineage, TextAnalyze calls included. Only show it when explicitly asked to,
    # as the other cells do.
    if display_dataframes:
        df_output.show()
    df_output.printSchema()
    print(enriched_text_tbl_name)

    # Write the result as parquet, replacing any earlier output for this batch
    df_output.write.mode("overwrite").parquet(f'{minted_tables_output_path}{enriched_text_tbl_name}')

    # Register an external table over the parquet files unless one with that name already exists
    df_output_sql_command = f"IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{enriched_text_tbl_name}') CREATE EXTERNAL TABLE [{enriched_text_tbl_name}] ([file_path] nvarchar(4000), [named_entities] varchar(MAX), [pii] varchar(MAX), [key_phrases] varchar(MAX), [text_analysis_error] varchar(MAX), [pii_redacted_text] varchar(MAX)) WITH (LOCATION = 'minted_tables/{enriched_text_tbl_name}/**', DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], FILE_FORMAT = [SynapseParquetFormat])"
    
    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
      with conn.cursor() as cursor:
        cursor.execute(df_output_sql_command)

# Values reported back for this run, including the name of the new table.
output_dimensions = dict(
    batch_num=batch_num,
    enriched_text_tbl_name=enriched_text_tbl_name,
    file_system=file_system,
    notebook_name=mssparkutils.runtime.context['notebookname'],
)
output = {'custom_dimensions': output_dimensions}

# Log the values, then hand the same values to the pipeline as the notebook's exit value.
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])