# Text Final

In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8
   }
}

In [ ]:
# Run parameters. Every value below is an empty placeholder except the timeout.
# NOTE(review): presumably overridden by the caller at run time — confirm.

# Names of the tables this notebook reads.
documents_contents_tbl_name = ''
summarized_text_xsum_tbl_name = ''
summarized_text_dailymail_tbl_name = ''

# Identification of the batch being processed.
batch_root = ''
batch_num = ''
batch_description = ''

# Storage locations.
blob_account_name = ''
azure_storage_domain = ''
input_container = ''
output_container = ''
minted_tables_output_path = ''

# NOTE(review): unit is presumably seconds — confirm.
rule_set_eval_timeout = 180

In [ ]:
# Column names

# Join key shared by every input table.
file_path_col = "file_path"
key_col = file_path_col

# Summary text columns (xsum / dailymail).
summarized_text_xsum_col = "summarized_text_xsum"
summarized_text_dailymail_col = "summarized_text_dailymail"

# Per-stage error columns, plus the single column they are folded into.
summarization_xsum_error_col = "summarization_xsum_error"
summarization_dailymail_error_col = "summarization_dailymail_error"
error_cols = [summarization_xsum_error_col, summarization_dailymail_error_col]
errors_col = "errors"

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer
import datetime

# Enable the OpenCensus integration for the standard logging module.
config_integration.trace_integrations(['logging'])

# Connection string shared by the log handler and the trace exporter below.
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

logger = logging.getLogger(__name__)
# logging.getLogger returns the same logger object on every call, so
# re-running this cell would stack one more AzureLogHandler each time and
# every record would then be exported once per handler. Attach it only once.
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer used for the `tracer.span(...)` sections; AlwaysOnSampler samples every span.
tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# Spool parameters: the run's inputs, attached to the log record as custom dimensions.
run_time_parameters = {'custom_dimensions': {
    'documents_contents_tbl_name': documents_contents_tbl_name,
    'summarized_text_xsum_tbl_name': summarized_text_xsum_tbl_name,
    'summarized_text_dailymail_tbl_name': summarized_text_dailymail_tbl_name,
    'batch_description': batch_description,
    'batch_root': batch_root,
    'batch_num': batch_num,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
import json
import os
import csv
from types import SimpleNamespace

from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql.functions import col
import pyspark.sql.functions as F

# Initialise session and config
# Bind the session first and derive the context from it, so this cell does not
# rely on a `spark` variable that already exists before the cell runs.
spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()
sc = spark.sparkContext

# config.json under batch_root is read as text lines, concatenated, and parsed;
# the object_hook turns every JSON object into a SimpleNamespace so that
# settings are reachable as attributes (config.log_level).
config = json.loads(
    ''.join(sc.textFile(f'{batch_root}/config.json').collect()),
    object_hook=lambda dictionary: SimpleNamespace(**dictionary),
)

# Set log level
# Only "INFO" is honoured. Any other value, or a config file with no log_level
# entry at all, falls back to ERROR and is written back onto the config object.
if getattr(config, 'log_level', None) == "INFO":
    logger.setLevel(logging.INFO)
else:
    logger.setLevel(logging.ERROR)
    config.log_level = "ERROR"


In [ ]:
import pyodbc
from pyspark.sql.functions import current_timestamp
# Dedicated and serverless SQL config
driver = '{ODBC Driver 17 for SQL Server}'
database = 'minted'
dedicated_database = 'dedicated'

# secrets
# Fetched one by one, in the listed order, through
# mssparkutils.credentials.getSecretWithLS("keyvault", <secret name>).
sql_user_name, sql_user_pwd, serverless_sql_endpoint, dedicated_sql_endpoint = [
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
    for secret_name in (
        "SynapseSQLUserName",
        "SynapseSQLPassword",
        "SynapseServerlessSQLEndpoint",
        "SynapseDedicatedSQLEndpoint",
    )
]

In [ ]:
with tracer.span(name='Read and join all dataframes'):

    # Load the three parquet inputs from the minted tables location.
    document_contents_df = spark.read.parquet(f'{minted_tables_output_path}{documents_contents_tbl_name}')
    xsum_df = spark.read.parquet(f'{minted_tables_output_path}{summarized_text_xsum_tbl_name}')
    dailymail_df = spark.read.parquet(f'{minted_tables_output_path}{summarized_text_dailymail_tbl_name}')

    # Left-outer joins on the file path keep every document row, whether or
    # not a matching xsum / dailymail row exists.
    joined_df = document_contents_df.join(xsum_df, file_path_col, 'left_outer')
    joined_df = joined_df.join(dailymail_df, file_path_col, 'left_outer')

    # Columns to keep, each taken from the dataframe it comes from.
    document_columns = [
        document_contents_df[name]
        for name in (
            'file_path',
            'file_name',
            'file_type',
            'text_content',
            'original_lang',
            'text_content_target_lang',
        )
    ]
    summary_columns = [
        xsum_df[summarized_text_xsum_col],
        xsum_df[summarization_xsum_error_col],
        dailymail_df[summarized_text_dailymail_col],
        dailymail_df[summarization_dailymail_error_col],
    ]

    docs_df = joined_df.select(*document_columns, *summary_columns)
    docs_df = docs_df.withColumn('batch_num', F.lit(batch_num))

# Re-group errors to simplify processing on the UI:
# one {message, stage} struct per error column, collected into a single array.
error_structs = []
for error_column in error_cols:
    raw_error = docs_df[error_column]
    # A null or empty error is normalised to the empty string.
    is_blank = raw_error.isNull() | (raw_error == "")
    message = F.when(is_blank, F.lit("")).otherwise(raw_error).alias("message")
    # The stage name is the column name with "_error" removed.
    stage = F.lit(error_column.replace("_error", "")).alias("stage")
    error_structs.append(F.struct(message, stage))

# The individual error columns are dropped once folded into the array.
docs_df = docs_df.withColumn(errors_col, F.array(*error_structs)).drop(*error_cols)

with tracer.span(name='Persist processed text as SQL tables'):
    processed_summarised_text_tbl_name = f'{batch_num}_summarised_processed_text'

    # Write the joined result as parquet, replacing any previous output.
    parquet_target = f'{minted_tables_output_path}{processed_summarised_text_tbl_name}'
    docs_df.write.mode("overwrite").parquet(parquet_target)

    # DDL that registers an external table over the parquet files, guarded so
    # nothing is created when a table of that name is already listed in sys.tables.
    sql_command = f'''
        IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{processed_summarised_text_tbl_name}') 
        CREATE EXTERNAL TABLE [{processed_summarised_text_tbl_name}] (
            [file_path] nvarchar(4000), 
            [file_name] nvarchar(4000), 
            [file_type] nvarchar(4000), 
            [text_content] nvarchar(max),
            [original_lang] nvarchar(4000),
            [text_content_target_lang] nvarchar(max),
            [summarized_text_xsum] nvarchar(max),
            [summarized_text_dailymail] nvarchar(max),
            [batch_num] nvarchar(4000),
            [errors] varchar(max)
        )
        WITH (
            LOCATION = 'minted_tables/{processed_summarised_text_tbl_name}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
        )
    '''

    # Run the DDL against the serverless endpoint.
    serverless_connection_string = (
        f'DRIVER={driver};SERVER=tcp:{serverless_sql_endpoint};PORT=1433;'
        f'DATABASE={database};UID={sql_user_name};PWD={sql_user_pwd}'
    )
    with pyodbc.connect(serverless_connection_string) as conn, conn.cursor() as cursor:
        cursor.execute(sql_command)


with tracer.span(name='Persist processed text as json'):
    # One JSON file per document is written beneath this batch folder.
    output_path = f'abfss://{output_container}@{blob_account_name}.dfs.{azure_storage_domain}/{batch_num}'

    # Serialise each row (all columns) into a single JSON string column.
    docs_df = docs_df \
        .withColumn("json", F.to_json(F.struct(col("*"))))

    # Only the file name and the JSON payload are needed on the driver.
    # Selecting just those two avoids shipping the full document text twice
    # (raw columns plus the JSON copy), and toLocalIterator() streams the rows
    # instead of materialising the whole dataset in driver memory as collect() does.
    for row in docs_df.select('file_name', 'json').toLocalIterator():
        p = f'{output_path}/text_summarisation_processing_json/{row.file_name}.output.json'
        mssparkutils.fs.put(p, row.json, overwrite=True)


In [ ]:
from time import sleep
# Update Status Table
def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """Return the most advanced status row recorded for a batch.

    Selects the [dbo].[batch_status] row with the highest
    [num_stages_complete] among rows whose [batch_id] equals ``batch_num``.

    Args:
        batch_num: Value matched against [batch_id] (bound as a query parameter).
        driver: ODBC driver name used in the connection string.
        dedicated_sql_endpoint: Server address used in the connection string.
        dedicated_database: Database name used in the connection string.
        sql_user_name: SQL login.
        sql_user_pwd: SQL password.

    Returns:
        A ``(num_stages_complete, description)`` tuple.

    Raises:
        LookupError: If no row exists for ``batch_num``.
    """
    query = """
        SELECT TOP (1) 
        [num_stages_complete], [description]
        FROM [dbo].[batch_status] 
        WHERE [batch_id] = ?
        ORDER BY [num_stages_complete] DESC;
    """
    connection_string = (
        f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;'
        f'DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'
    )
    with pyodbc.connect(connection_string, autocommit=True) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, batch_num)
            row = cursor.fetchone()
            # fetchone() yields None when the query matched nothing; unpacking
            # that directly would fail with an unhelpful TypeError.
            if row is None:
                raise LookupError(f'No batch_status row found for batch_id {batch_num!r}')
            num_stages_complete, description = row
            return num_stages_complete, description

def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd, max_retries=10, retry_delay=3):
    """Advance the batch's stage counter by one and record a new status text.

    Reads the current stage count through ``get_recent_status``, increments
    it, and updates the ``batch_status`` row(s) for ``batch_num`` with the
    status ``'[<stage>/10] <status_text>'``, the current local timestamp and
    the new stage count. Any failure is logged as a warning and retried.

    Args:
        status_text: Human-readable text appended after the stage counter.
        minted_tables_path: Unused; kept so existing callers keep working.
        batch_num: Value matched against ``batch_id``.
        driver: ODBC driver name used in the connection string.
        dedicated_sql_endpoint: Server address used in the connection string.
        sql_user_name: SQL login.
        sql_user_pwd: SQL password.
        max_retries: Maximum number of attempts (default 10).
        retry_delay: Seconds to wait between attempts (default 3).

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The error from the last attempt, once all attempts failed.

    Note:
        The database name is read from the module-level ``dedicated_database``
        rather than from a parameter.
    """
    if max_retries < 1:
        raise ValueError('max_retries must be at least 1')

    sql_command = "UPDATE batch_status SET status = ?, update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"
    connection_string = (
        f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;'
        f'DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'
    )

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            stages_complete, description = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            stages_complete += 1
            status = f'[{stages_complete}/10] {status_text}'
            time_stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            # The connection is opened with autocommit=True, so the UPDATE is
            # committed as it executes and no explicit commit is required.
            with pyodbc.connect(connection_string, autocommit=True) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
            return
        except Exception as e:
            last_error = e
            logger.warning(f'Failed to update status table: {e}, retrying . . .')
            # Waiting is only useful when another attempt follows.
            if attempt < max_retries:
                sleep(retry_delay)

    raise last_error

# Record that this notebook's stage has finished for the current batch.
update_status_table(
    status_text='Text Summarization Complete',
    minted_tables_path=minted_tables_output_path,
    batch_num=batch_num,
    driver=driver,
    dedicated_sql_endpoint=dedicated_sql_endpoint,
    sql_user_name=sql_user_name,
    sql_user_pwd=sql_user_pwd,
)

In [ ]:
# return name of new table
notebook_name = mssparkutils.runtime.context['notebookname']

# Values handed back on exit; the same dict is logged as custom dimensions.
exit_payload = {
    'batch_num': batch_num,
    'processed_text_tbl_name': processed_summarised_text_tbl_name,
    'notebook_name': notebook_name,
}
output = {'custom_dimensions': exit_payload}

# Return the object to the pipeline
logger.info(f"{notebook_name}: OUTPUT", extra=output)
mssparkutils.notebook.exit(exit_payload)