# 2. Text Prep



In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8,
     "spark.jars.packages": "com.microsoft.azure:synapseml_2.12:0.10.0-19-c3a445c5-SNAPSHOT",
     "spark.jars.repositories": "https://mmlspark.azureedge.net/maven",
     "spark.jars.excludes": "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalactic:scalactic_2.12,org.scalatest:scalatest_2.12,com.fasterxml.jackson.core:jackson-databind",
     "spark.yarn.user.classpath.first": "true"
   }
}


In [ ]:
# Notebook parameters. Every default here is an empty placeholder; presumably the
# calling pipeline injects the real values at run time -- TODO confirm.
documents_tbl_name = ""
batch_root = ""
batch_num = ""
file_system = ""
# Passed as the timeout argument to mssparkutils.notebook.run when the
# "2_Text_Extraction" notebook is invoked.
document_cracking_timeout = 180
blob_account_name = ""
minted_tables_output_path = ""

In [ ]:
import pyodbc
from pyspark.sql.functions import current_timestamp
# Dedicated and serverless SQL config
# `dedicated_database` is used for the batch_status table connections; `database`
# is used for the serverless connection that creates the external tables.
dedicated_database = "dedicated"
database = 'minted'   
driver= '{ODBC Driver 17 for SQL Server}'
# Credentials and endpoints are read from the "keyvault" linked service rather
# than being stored in the notebook.
sql_user_name = mssparkutils.credentials.getSecretWithLS("keyvault", "SynapseSQLUserName")
sql_user_pwd = mssparkutils.credentials.getSecretWithLS("keyvault", "SynapseSQLPassword")
serverless_sql_endpoint = mssparkutils.credentials.getSecretWithLS("keyvault", "SynapseServerlessSQLEndpoint")
dedicated_sql_endpoint = mssparkutils.credentials.getSecretWithLS("keyvault", "SynapseDedicatedSQLEndpoint")

In [ ]:
# Column names 
# Centralised so every cell refers to the same DataFrame column labels.
file_path_col = "file_path"
file_name_col = "file_name"
file_type_col = "file_type"
text_content_col = "text_content"
entropy_error_col = "entropy_error"
original_lang_col = "original_lang"
original_lang_translate_col = "original_lang_translate"
original_lang_prob_col = "original_lang_prob"
text_content_target_lang_col = f"text_content_target_lang"
extraction_error_col = "extraction_error"
language_detection_error_col = "language_detection_error"
translation_error_col = "translation_error"

# file_path is the column the dataframe is repartitioned on and the one carried
# into both the output and the error tables.
key_col = file_path_col

# Columns written (together with file_path) to the documents-contents table.
output_cols = [ 
    file_name_col, 
    file_type_col,
    text_content_col,
    original_lang_col, 
    text_content_target_lang_col
]

# Columns written (together with file_path) to the text-prep-errors table.
error_cols = [
    extraction_error_col,
    entropy_error_col,
    language_detection_error_col,
    translation_error_col
]

# NOTE: these columns are dropped after the processing is finished
text_content_truncated_col = "text_content_truncated"
text_content_split_col = "text_content_split"
language_detector_key_col = "language_detector_key"
translation_key_col = "translation_key"

# Load secrets
# Both key secrets are comma-separated lists; one key is picked at random per
# request/batch further down.
translation_keys = mssparkutils.credentials.getSecretWithLS("keyvault", 'TranslationKeys').split(',')
language_detector_keys = mssparkutils.credentials.getSecretWithLS("keyvault", 'TextAnalyticsKeys').split(',')
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

# Parameters for truncating and batching
# Only the first `truncated_text_len` characters of a document are sent to
# language detection.
truncated_text_len = 1000
# For Language Detection, we can send up to 1000 documents per request
# see https://docs.microsoft.com/en-us/azure/cognitive-services/language-service/language-detection/how-to/call-api#data-limits
detect_batch_size = 10

# For Translation, we can send up to 10000 characters per request
# see https://docs.microsoft.com/en-us/azure/cognitive-services/translator/request-limits
# NOTE(review): these two values are not passed to TextTranslation where it is
# constructed in this notebook, so its own 49000/50000 defaults apply -- confirm
# which limit is intended.
translate_document_min_length = 9900
translate_document_max_length = 10000

# Useful for debugging
display_dataframes = False

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer
import datetime

# Enable the opencensus 'logging' trace integration.
config_integration.trace_integrations(['logging'])

# Send this notebook's log records to Application Insights.
logger = logging.getLogger(__name__)
logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer used for the `with tracer.span(...)` blocks throughout the notebook;
# AlwaysOnSampler means every span is sampled.
tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# Spool parameters
# Passed as `extra` so the run parameters accompany the log record as custom dimensions.
run_time_parameters = {'custom_dimensions': {
  'documents_tbl_name': documents_tbl_name,
  'batch_root': batch_root,
  'batch_num': batch_num,
  'file_system': file_system,
  'notebook_name': mssparkutils.runtime.context['notebookname']
} }

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
from time import sleep
# Update Status Table
def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """Return the most advanced status row recorded for a batch.

    Queries ``[dbo].[batch_status]`` on the dedicated SQL pool for the row with
    the highest ``num_stages_complete`` whose ``batch_id`` equals ``batch_num``.

    Args:
        batch_num: Value matched against ``[batch_id]`` (bound as a parameter).
        driver: ODBC driver name placed in the connection string.
        dedicated_sql_endpoint: Host name of the dedicated SQL endpoint.
        dedicated_database: Database that holds ``batch_status``.
        sql_user_name: SQL login.
        sql_user_pwd: SQL password.

    Returns:
        Tuple ``(num_stages_complete, description)`` of the selected row.

    Raises:
        LookupError: If no status row exists for ``batch_num``.
        pyodbc.Error: If connecting or querying fails.
    """
    # The batch id is bound with `?`, so no string interpolation is needed here.
    query = """
        SELECT TOP (1)
        [num_stages_complete], [description]
        FROM [dbo].[batch_status]
        WHERE [batch_id] = ?
        ORDER BY [num_stages_complete] DESC;
    """
    connection_string = (
        f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;'
        f'DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'
    )

    with pyodbc.connect(connection_string, autocommit=True) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, batch_num)
            row = cursor.fetchone()

    # fetchone() yields None when the result set is empty; fail with a clear
    # message instead of an opaque tuple-unpacking TypeError.
    if row is None:
        raise LookupError(f'No batch_status row found for batch_id {batch_num!r}')

    num_stages_complete, description = row
    return num_stages_complete, description

def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd):
    """Advance the batch's status row by one stage, retrying on failure.

    Reads the current ``num_stages_complete`` for ``batch_num``, increments it
    and writes ``status``, ``update_time_stamp`` and ``num_stages_complete``
    back to ``batch_status``. Up to 10 attempts are made, 3 seconds apart.

    Args:
        status_text: Human-readable text appended to the ``[n/10]`` prefix.
        minted_tables_path: Unused; kept so existing callers keep working.
        batch_num: Batch id whose row is updated.
        driver: ODBC driver name.
        dedicated_sql_endpoint: Host name of the dedicated SQL endpoint.
        sql_user_name: SQL login.
        sql_user_pwd: SQL password.

    Raises:
        Exception: The error from the final attempt, once all attempts failed.

    Note:
        The database name comes from the module-level ``dedicated_database``.
    """
    max_attempts = 10
    retry_delay_seconds = 3
    sql_command = "UPDATE batch_status SET status = ?, update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"
    connection_string = (
        'DRIVER=' + driver + ';SERVER=tcp:' + dedicated_sql_endpoint
        + ';PORT=1433;DATABASE=' + dedicated_database
        + ';UID=' + sql_user_name + ';PWD=' + sql_user_pwd
    )

    for attempt in range(1, max_attempts + 1):
        try:
            stages_complete, _ = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            stages_complete += 1
            status = f'[{stages_complete}/10] {status_text}'
            time_stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            with pyodbc.connect(connection_string, autocommit=True) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
                    cursor.commit()
            return
        except Exception as e:
            exc_str = str(e)
            logger.warning(f'Failed to update status table: {exc_str}, retrying . . .')
            # Out of attempts: re-raise immediately (keeping the original
            # traceback) rather than sleeping once more for nothing.
            if attempt == max_attempts:
                raise
            sleep(retry_delay_seconds)

# Record in the batch_status table that this notebook has started.
update_status_table('Text Prep Started', minted_tables_output_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd)

In [ ]:
import json
import random
from collections import defaultdict
import math
import re
from types import SimpleNamespace
from typing import List

import pyspark.sql.functions as F
from pyspark.sql.functions import col
from pyspark.sql.types import StringType, StructType, StructField
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

from synapse.ml.stages import FlattenBatch, FixedMiniBatchTransformer, DynamicMiniBatchTransformer
from synapse.ml.cognitive import Translate, LanguageDetector
from synapse.ml.featurize.text import PageSplitter

def read_batch_config(batch_root: str):
    """Load ``<batch_root>/config.json`` and return it as nested namespaces.

    The file is read on the driver through the Hadoop FileSystem API (via the
    JVM gateway) instead of through Spark: it is a single small file, so there
    is no benefit in having multiple nodes read individual lines and joining
    them back together.

    Side effects:
        Sets ``fs.defaultFS`` on the shared Hadoop configuration to the
        module-level ``file_system``.

    Args:
        batch_root: Directory that contains ``config.json``.

    Returns:
        The parsed JSON document, with every JSON object converted to a
        ``SimpleNamespace`` so values can be read as attributes.

    Raises:
        Exception: Whatever ``fs.open`` raises when the file cannot be opened
            (a missing file is logged first), or ``json.JSONDecodeError`` when
            the content is not valid JSON.
    """
    # Change our file system from 'synapse' to 'input'
    hadoop_conf = sc._jsc.hadoopConfiguration()
    hadoop_conf.set("fs.defaultFS", file_system)

    fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
    config_path = sc._jvm.org.apache.hadoop.fs.Path(f'{batch_root}/config.json')

    # A missing config is only logged here; the open() below then fails and
    # that error propagates to the caller.
    if not fs.exists(config_path):
        logger.error(f'{config_path} not found.')

    # Open our file directly rather than through spark
    input_stream = fs.open(config_path)  # FSDataInputStream
    try:
        reader = sc._jvm.java.io.BufferedReader(
            sc._jvm.java.io.InputStreamReader(input_stream, sc._jvm.java.nio.charset.StandardCharsets.UTF_8)
        )
        config_string = reader.lines().collect(sc._jvm.java.util.stream.Collectors.joining("\n"))
    finally:
        # Always release the underlying stream, even if reading fails.
        input_stream.close()

    # Load it into json
    return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

# NOTE(review): `spark` is read on the first line of this span before the cell
# assigns it, so the cell relies on the runtime having already defined `spark`
# -- confirm.
with tracer.span(name='Initialise Spark session'):
    sc = spark.sparkContext
    # NOTE(review): the whole runtime context object is interpolated into the app
    # name here, whereas other cells use context['notebookname'] -- confirm intent.
    spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()

with tracer.span(name=f"Load config: {mssparkutils.runtime.context['notebookname']}"):
    # Log the failure to Application Insights before letting it abort the notebook.
    try:
        config = read_batch_config(batch_root)
    except Exception as e:
        logger.exception(e)
        raise e

    # Set log level
    # Anything other than "INFO" is treated as "ERROR", and the config value is
    # rewritten to match.
    if config.log_level == "INFO":
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)
        config.log_level = "ERROR"

In [ ]:
with tracer.span(name='Extract document contents'):
    # Extract the text from complex document types
    documents_cracked_view_name = f"{batch_num}_cracked_view"

    # Run the extraction notebook, bounded by document_cracking_timeout; its
    # results are read back below through the view named
    # `documents_cracked_view_name`.
    mssparkutils.notebook.run("2_Text_Extraction", document_cracking_timeout, {
        "file_system": file_system,
        "documents_tbl_name": f'{documents_tbl_name}',
        "documents_cracked_view_name": documents_cracked_view_name,
        "minted_tables_output_path": f'{minted_tables_output_path}'
    })

    # NOTE(review): documents_cracked_tbl_name is not referenced anywhere else in
    # this notebook -- confirm whether it is still needed.
    documents_cracked_tbl_name = f"{batch_num}_documents_cracked"
    df = spark.sql(f"""
        SELECT {file_name_col}, {file_type_col}, {file_path_col}, {text_content_col}, {extraction_error_col} 
        FROM {documents_cracked_view_name}
    """)
    
    # Repartition df by file_path to optimise upcoming joins
    df = df.repartition(file_path_col)

    if display_dataframes:
        df.show()

## Attempting JSON parsing

In [ ]:
import json

def walk(current_node):
    """Flatten a parsed JSON value into newline-separated string content.

    Recursively visits lists and dict values, keeping only non-empty string
    leaves. Dict keys and non-string leaves (numbers, booleans, ``None``) are
    dropped.

    Args:
        current_node: Any value produced by ``json.loads``.

    Returns:
        The string leaves joined by ``'\\n'`` in traversal order, or ``''``
        when the node contains no string content.
    """
    if isinstance(current_node, str):
        return current_node

    if isinstance(current_node, dict):
        children = current_node.values()
    elif isinstance(current_node, list):
        children = current_node
    else:
        # Numbers, booleans and None carry no text.
        return ''

    flattened = (walk(child) for child in children)
    return '\n'.join(text for text in flattened if text)

def attempt_json_handling(input_string):
    """Replace a JSON document with its concatenated string values.

    If ``input_string`` parses as a JSON object or array, the tree is flattened
    with ``walk`` (string values joined by newlines). Anything else is returned
    unchanged: non-JSON text, ``None``, and text that merely happens to be a
    JSON scalar such as ``42`` or ``true`` (flattening those would erase the
    document's content).

    Args:
        input_string: Document text, or ``None``.

    Returns:
        The flattened text for JSON objects/arrays, otherwise ``input_string``.
    """
    try:
        # This will fail for non-json docs (ValueError) and for None (TypeError).
        parsed = json.loads(input_string)
        if isinstance(parsed, (dict, list)):
            return walk(parsed)
    except (TypeError, ValueError, RecursionError):
        # Not JSON, or nested too deeply to traverse: keep the original text.
        return input_string
    return input_string

# Wrap the JSON handler as a string-returning Spark UDF and overwrite the text
# column in place with its result.
udf_attempt_json_handling = F.udf(lambda z: attempt_json_handling(z), StringType())
df = df.withColumn(text_content_col, udf_attempt_json_handling(col(text_content_col)))

## Detecting text with unusually high or low entropy

Some languages have especially large character sets (Chinese, Japanese). We will explicitly prevent these documents from being filtered out later in the notebook.

In [ ]:
# Very short documents show a wide spread of entropy values, so anything below
# this length (after whitespace simplification) is never flagged.
entropy_min_char_length = 100

# Accepted entropy band. The bounds were chosen by examining the perf2 and
# Abbottabad datasets.
entropy_min = 2.7
entropy_max = 3.5

# Runs of whitespace are collapsed to a single space before counting so that
# layout padding does not skew the calculated entropy too much.
whitespace_simplifier = re.compile(r"\s+")

# Languages for which high character entropy must be tolerated. Probably not
# exhaustive. The set is consulted at the end of the notebook to keep documents
# detected as one of these languages even when they raised an entropy error.
high_entropy_languages = {
    'yue',      # Cantonese (traditional)
    'lzh',      # Chinese (literary)
    'zh-Hans',  # Chinese Simplified
    'zh-Hant',  # Chinese Traditional
    'ja',       # Japanese
}


def check_for_entropy_issues(input_string):
    """Describe an entropy problem with a document's text, if there is one.

    The character-level entropy (natural logarithm) of the whitespace-simplified
    text is compared against ``entropy_min`` and ``entropy_max``.

    Args:
        input_string: The document text.

    Returns:
        ``None`` when the text is shorter than ``entropy_min_char_length`` or
        its entropy lies within the accepted band; otherwise a message saying
        the entropy is too low or too high. Non-``str`` input yields a message
        saying the entropy could not be calculated.
    """
    if type(input_string) is not str:
        return 'Could not calculate entropy for text content of type {}'.format(
            type(input_string))

    simplified = whitespace_simplifier.sub(' ', input_string)
    n_chars = len(simplified)
    if n_chars < entropy_min_char_length:
        return None

    # Tally occurrences of each character, in first-seen order.
    tallies = {}
    for ch in simplified:
        tallies[ch] = tallies.get(ch, 0) + 1

    # Accumulate -p * ln(p) over the observed characters.
    entropy = 0.0
    for count in tallies.values():
        p = count / n_chars
        entropy -= p * math.log(p)

    # Report an extreme value, if any.
    if entropy < entropy_min:
        return 'Char-level entropy too low ({:0.6f} < {:0.6f}) minimum'.format(
            entropy, entropy_min)
    if entropy > entropy_max:
        return 'Char-level entropy too high ({:0.6f} > {:0.6f}) maximum'.format(
            entropy, entropy_max)
    return None

# Wrap the entropy check as a string-returning Spark UDF; the resulting column
# holds the check's message for each document, or null when it returned None.
udf_entropy = F.udf(lambda z: check_for_entropy_issues(z), StringType())
df = df.withColumn(entropy_error_col, udf_entropy(col(text_content_col)))

## Detecting language in the source text, producing a column `original_lang`

In [ ]:
with tracer.span(name='Set Batch Size'):
    # For performance, we assume that the whole text will be in the same language as the first truncated_text_len characters.
    # This allows us to truncate the text and then take advantage of the batch API to reduce number of calls to Cognitive Services.

    # Truncate text
    df_detect = df.withColumn(text_content_truncated_col, F.substring(col(text_content_col), 0, truncated_text_len))

    # Group rows into batches
    # Batches are of size detect_batch_size.
    fmbt = (FixedMiniBatchTransformer()
          .setBatchSize(detect_batch_size))

    df_detect = fmbt.transform(df_detect)

    if display_dataframes:
        df_detect.show()

In [ ]:

def rand_key(keys: List[str]) -> str:
    """Return one element of ``keys`` chosen uniformly at random.

    Args:
        keys: Non-empty list of candidate keys.

    Returns:
        A single key from ``keys``.

    Raises:
        IndexError: If ``keys`` is empty.
    """
    return random.choice(keys)

# NOTE(review): despite the span name, the keys distributed here are the
# language-detector keys; the name is left unchanged so existing telemetry
# queries keep matching.
with tracer.span(name='Distribute translation keys across rows'):
    # The UDF returns a random key on every call, so it is marked
    # non-deterministic; otherwise Spark is free to eliminate duplicate
    # invocations or to invoke it more often than it appears in the query.
    udf_language_detector_key = F.udf(
        lambda: rand_key(language_detector_keys), StringType()
    ).asNondeterministic()
    df_detect = df_detect.withColumn(language_detector_key_col, udf_language_detector_key())

In [ ]:
with tracer.span(name='Detect Language'):
    # Language detection
    # NOTE: Language detection is also performed as part of Translate API,
    # however, we don't run Translate on text in target language, as it is quite expensive and about 3x time slower than Detect.
    # The detector reads its subscription key from a column, sends the truncated
    # text, and writes its result and any error into separate columns.
    # NOTE(review): the `<<...>>` line below is a template placeholder and is not
    # valid Python as written; presumably it is substituted before the notebook
    # is deployed -- confirm.
    detect = (
        LanguageDetector()
        .setSubscriptionKeyCol(language_detector_key_col)
        .setLocation(config.location)
        .setTextCol(text_content_truncated_col)
        .setOutputCol(original_lang_col)
        .setErrorCol(language_detection_error_col)
        <<SYNAPSE_ML_LANG_DETECT_ENDPOINT_CMD>>
    )

    df_detect_results_batched = detect.transform(df_detect)

    if display_dataframes:
        df_detect_results_batched.show()

In [ ]:
# Schema used to parse the detector's error "response" JSON so that only
# error.message is kept in the error column.
error_response_schema = StructType(
    [StructField("error", StructType(
        [StructField("code", StringType()), StructField("message", StringType())]
    ))]
)

with tracer.span(name='Join Detect results back with source dataframe'):
    # Flatten the columns to separate individual files again
    flattener = FlattenBatch()
    df_detect_results = flattener.transform(df_detect_results_batched)
    # Unwrap the result column
    # original_lang becomes detectedLanguage.iso6391Name; the error column becomes
    # the "message" field parsed from the error response.
    df_detect_results = df_detect_results\
        .withColumn(original_lang_col, df_detect_results[original_lang_col].detectedLanguage.iso6391Name)\
        .withColumn(language_detection_error_col, F.from_json(
            df_detect_results[language_detection_error_col]["response"], error_response_schema)["error"]["message"])

    if display_dataframes:
        df_detect_results.show()

## Translate any text that isn't already in the target language

In [ ]:
import uuid
import requests
import pandas as pd
import time

# Raised by TextTranslation.translate_page once every retry has failed; caught
# in TextTranslation.translate_row and stored in the row's error column.
class TranslationException(Exception):
    pass

class TextTranslation:
    """Translate a text column of pandas rows into a target language.

    Instances are callable on an iterator of pandas DataFrames and yield one
    transformed DataFrame per input, which is the shape
    ``DataFrame.mapInPandas`` expects. Each row whose detected language differs
    from the target language is split into pages, every page is sent to the
    Translator ``translate`` endpoint, and the translated pages are
    concatenated. Rows already in the target language are copied through.

    Args:
        translation_location: Region sent in ``Ocp-Apim-Subscription-Region``.
        translation_keys: Subscription keys; one is picked at random per page.
        translation_endpoint: Base URL of the service, with or without a
            trailing slash.
        target_lang: Language code to translate into.
        original_lang_column: Column holding the detected source language;
            overwritten with the language reported by the translation service.
        input_text_column: Column holding the text to translate.
        output_column: Column that receives the translated text.
        error_column: Column that receives the error text when translation
            fails.
        min_page_len: Earliest offset within a page at which a split may occur.
        max_page_len: Maximum number of characters per page.
        boundary_regex: Pattern marking acceptable split points.
    """

    # Retry policy for one page: number of attempts, initial back-off, and how
    # much the back-off grows after each failed attempt (both in seconds).
    max_attempts = 10
    initial_wait_seconds = 3
    wait_increment_seconds = 5

    def __init__(
        self,
        translation_location,
        translation_keys,
        translation_endpoint,
        target_lang='en',
        original_lang_column='original_lang',
        input_text_column='text_content',
        output_column='text_content_target_lang',
        error_column='translation_error',
        min_page_len=49000,
        max_page_len=50000,
        boundary_regex='\\s',
    ):
        self.translation_endpoint = translation_endpoint
        # Accept endpoints configured with or without a trailing slash.
        self.translation_url = f"{translation_endpoint.rstrip('/')}/translate"
        self.translation_location = translation_location
        self.translation_keys = translation_keys
        self.target_lang = target_lang
        self.original_lang_column = original_lang_column
        self.input_text_column = input_text_column
        self.output_column = output_column
        self.error_column = error_column
        self.min_page_len = min_page_len
        self.max_page_len = max_page_len
        self.boundary_regex = boundary_regex

    def transform_df(self, df):
        """Return a new DataFrame built from translating every row of ``df``."""
        return pd.DataFrame.from_records([self.transform(row) for _idx, row in df.iterrows()])

    def transform(self, row):
        """Translate a single row; see ``translate_row``."""
        return self.translate_row(row)

    def __call__(self, df_iter):
        """Yield the translated counterpart of each DataFrame in ``df_iter``."""
        for df in df_iter:
            yield self.transform_df(df)

    def page_splitter(self, input_text):
        """Split ``input_text`` into pages of at most ``max_page_len`` characters.

        Each page ends just after the last ``boundary_regex`` match found
        between ``min_page_len`` and ``max_page_len`` characters from the page
        start; when no boundary exists there, the page is cut at
        ``max_page_len``. Concatenating the pages reproduces the input.

        Returns:
            List of page strings (a single-element list for short text).
        """
        text_len = len(input_text)
        if text_len < self.max_page_len:
            return [input_text]

        pages = []
        page_start_idx = 0
        while page_start_idx + self.max_page_len < text_len:
            left_idx = page_start_idx + self.min_page_len
            right_idx = page_start_idx + self.max_page_len
            # look between min_len and max_len for a boundary
            boundary_ends = [
                match.end()
                for match in re.finditer(self.boundary_regex, input_text[left_idx:right_idx])
            ]
            # if we don't find one we take the max_len of string
            page_end_idx = left_idx + boundary_ends[-1] if boundary_ends else right_idx
            pages.append(input_text[page_start_idx:page_end_idx])
            page_start_idx = page_end_idx

        # handle the end of the text
        if page_start_idx < text_len:
            pages.append(input_text[page_start_idx:])

        return pages

    @property
    def translation_key(self):
        """A subscription key picked at random from ``translation_keys``."""
        return random.choice(self.translation_keys)

    def _attempt_translation(self, params, headers, body):
        """Send one translate request and parse the reply.

        Returns:
            Tuple ``(detected_language, detection_score, translated_text)``.

        Raises:
            TranslationException: On a transport error, a non-JSON reply, or a
                non-success HTTP status. The message is the JSON error body
                when one was returned.
        """
        try:
            response = requests.post(self.translation_url, params=params, headers=headers, json=body)
        except requests.RequestException as exc:
            raise TranslationException(f'{type(exc).__name__}: {exc}') from exc

        # Gateways can answer with non-JSON error pages; report that as a
        # failed attempt instead of letting the decode error escape.
        try:
            payload = response.json()
        except ValueError as exc:
            raise TranslationException(
                f'HTTP {response.status_code}: response body is not valid JSON'
            ) from exc

        if not response.ok:
            raise TranslationException(json.dumps(payload))

        detected_language = payload[0]['detectedLanguage']
        translated_page = payload[0]['translations'][0]['text']
        return detected_language['language'], detected_language['score'], translated_page

    def translate_page(self, page):
        """Translate one page, retrying with a growing back-off on failure.

        Args:
            page: Text of a single page.

        Returns:
            Tuple ``(detected_language, detection_score, translated_text)``.

        Raises:
            TranslationException: When every attempt failed; the message is
                the error from the last attempt.
        """
        params = {
            'api-version': '3.0',
            'to': [self.target_lang]
        }

        headers = {
            'Ocp-Apim-Subscription-Key': self.translation_key,
            'Ocp-Apim-Subscription-Region': self.translation_location,
            'Content-type': 'application/json',
            'X-ClientTraceId': str(uuid.uuid4())
        }

        body = [{'text': page}]
        wait_time = self.initial_wait_seconds
        last_error = ''
        for attempt in range(1, self.max_attempts + 1):
            try:
                return self._attempt_translation(params, headers, body)
            except TranslationException as exc:
                last_error = exc
            # No point in waiting once the final attempt has failed.
            if attempt < self.max_attempts:
                time.sleep(wait_time)
                wait_time += self.wait_increment_seconds

        raise TranslationException(str(last_error))

    def translate_text(self, input_text):
        """Translate a whole document page by page.

        Returns:
            Tuple ``(original_lang, translated_text)`` where ``original_lang``
            is the language of the page detected with the highest score and
            ``translated_text`` is the concatenation of the translated pages.
        """
        original_lang = ''
        best_score = 0
        translated_pages = []
        for page in self.page_splitter(input_text):
            page_lang, page_score, translated_page = self.translate_page(page)
            translated_pages.append(translated_page)
            # Keep the language of the most confidently detected page.
            if page_score > best_score:
                best_score = page_score
                original_lang = page_lang
        return original_lang, ''.join(translated_pages)

    def translate_row(self, row):
        """Fill the output and error columns of ``row`` and return it.

        * Text is ``None``: both columns stay ``None``.
        * Row already in the target language: the text is copied to the output.
        * Otherwise the text is translated; on success the output and the
          original-language column are updated, on ``TranslationException``
          the error column receives the message.
        """
        input_text = row[self.input_text_column]
        row[self.output_column] = None
        row[self.error_column] = None

        if input_text is None:
            return row

        if row[self.original_lang_column] != self.target_lang:
            try:
                original_lang, text_content_target_lang = self.translate_text(input_text)
                row[self.output_column] = text_content_target_lang
                row[self.original_lang_column] = original_lang
            except TranslationException as e:
                row[self.error_column] = str(e)
        else:
            row[self.output_column] = input_text

        return row


# Schema of the dataframe returned by mapInPandas: the columns carried through
# from language detection plus the two columns TextTranslation adds (translated
# text and translation error). All columns are strings.
output_schema = StructType([
    StructField(file_name_col, StringType()),
    StructField(file_type_col, StringType()),
    StructField(file_path_col, StringType()),
    StructField(text_content_col, StringType()),
    StructField(extraction_error_col, StringType()),
    StructField(entropy_error_col, StringType()),
    StructField(text_content_truncated_col, StringType()),
    StructField(language_detector_key_col, StringType()),
    StructField(language_detection_error_col, StringType()),
    StructField(original_lang_col, StringType()),
    StructField(text_content_target_lang_col, StringType()),
    StructField(translation_error_col, StringType())
])

# Page-length limits are not passed here, so TextTranslation's defaults
# (min_page_len=49000, max_page_len=50000) apply.
text_translator = TextTranslation(
    config.location,
    translation_keys,
    config.translation_endpoint,
    target_lang=config.prep.target_language,
    original_lang_column=original_lang_col,
    input_text_column=text_content_col,
    output_column=text_content_target_lang_col,
    error_column=translation_error_col,
    )

# The translator instance is callable on an iterator of pandas DataFrames, which
# is what mapInPandas passes to it.
df_translate = df_detect_results.mapInPandas(text_translator, output_schema)


## Persist resulting dataframe and error dataframes as Synapse tables

In [ ]:
df_results = df_translate
df_results.cache() # cache as we will split out success/error tables from here

with tracer.span(name='Persist document contents as table'):
    # Remove null rows as downstream notebooks won't be able to work with them.
    # Also drop rows with entropy errors since they would appear as nonsense
    # entries in document clusters, etc. -- unless the detected original language
    # is one where we know to expect high entropy levels.
    has_no_entropy_error = df_results[entropy_error_col].isNull()
    is_high_entropy_language = df_results[original_lang_col].isin(high_entropy_languages)

    df_output = df_results\
                    .where(has_no_entropy_error | is_high_entropy_language)\
                    .select(file_path_col, *output_cols)\
                    .na.drop(subset=[text_content_target_lang_col])
    # Error rows: nothing usable came out of translation (empty or null), or an
    # entropy error was raised for a language that is not expected to be
    # high-entropy.
    df_error = df_results\
                    .where(
                        (df_results[text_content_target_lang_col] == "") |
                        (df_results[text_content_target_lang_col].isNull()) |
                        ((df_results[entropy_error_col].isNotNull()) &
                        (~is_high_entropy_language))
                        )\
                    .select(file_path_col, *error_cols)

    documents_contents_tbl_name = f"{batch_num}_documents_contents"
    df_output.write.mode("overwrite").parquet(f'{minted_tables_output_path}{documents_contents_tbl_name}')
    df_output.printSchema()

    text_prep_errors_tbl_name = f"{batch_num}_text_prep_errors"
    df_error.write.mode("overwrite").parquet(f'{minted_tables_output_path}{text_prep_errors_tbl_name}')
    df_error.printSchema()

    # create remote sql tables over the parquet file
    # The table names are interpolated because identifiers cannot be bound as
    # query parameters; they derive from the batch_num notebook parameter.
    df_output_sql_command = f"IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{documents_contents_tbl_name}') CREATE EXTERNAL TABLE [{documents_contents_tbl_name}] ([file_name] nvarchar(4000), [file_type] nvarchar(4000), [file_path] nvarchar(4000)) WITH (LOCATION = 'minted_tables/{documents_contents_tbl_name}/**', DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], FILE_FORMAT = [SynapseParquetFormat])"
    # NOTE(review): this DDL declares [file_name], but the error parquet written
    # above contains file_path plus the error columns and no file_name column --
    # confirm whether [file_path] was intended.
    # Kept in its own variable so the df_error DataFrame is not overwritten by a string.
    df_error_sql_command = f"IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{text_prep_errors_tbl_name}') CREATE EXTERNAL TABLE [{text_prep_errors_tbl_name}] ([file_name] nvarchar(4000), [extraction_error] nvarchar(4000), [entropy_error] nvarchar(4000), [language_detection_error] nvarchar(4000), [translation_error] nvarchar(4000)) WITH (LOCATION = 'minted_tables/{text_prep_errors_tbl_name}/**', DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], FILE_FORMAT = [SynapseParquetFormat])"

    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
      with conn.cursor() as cursor:
        cursor.execute(df_output_sql_command)
        cursor.execute(df_error_sql_command)

# Summary of what this run produced; logged as custom dimensions and handed
# back to the caller as the notebook's exit value.
output = {'custom_dimensions': {
    'batch_num': batch_num,
    'documents_contents_tbl_name': documents_contents_tbl_name,
    'text_prep_errors_tbl_name': text_prep_errors_tbl_name,
    'file_system': file_system,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }

# Return the object to the pipeline
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])