# Text Summarization

In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 4
   }
}

In [ ]:
documents_contents_tbl_name = ''  # Input table name; read as parquet from minted_tables_output_path + this name (no separator is added)
batch_root = ''  # Root folder of the batch; '{batch_root}/config.json' is read as the run configuration
batch_num = ''  # Batch identifier; prefixed to output_tbl_name to form the output table name
model_name = ''  # This is the text summarisation model we will use: currently either "google/pegasus-xsum" || "google/pegasus-cnn_dailymail"
output_col_name = ''  # Output column that receives the generated summary
error_col_name = ''  # Output column that receives the error message when summarisation of a row fails
output_tbl_name = ''  # Suffix of the output table name ('{batch_num}_{output_tbl_name}')
file_system = ''  # Set as the Hadoop fs.defaultFS before the batch config is read
blob_account_name = ''  # Storage account checked for a pre-staged copy of the model under /models
azure_storage_domain = ''  # Storage domain suffix; abfss URLs are built as '{account}.dfs.{azure_storage_domain}'
minted_tables_output_path = ''  # Base path prefix for input and output parquet tables (table names are appended directly)

In [ ]:
import pyodbc
# Target database and ODBC driver for the serverless SQL endpoint that the
# summarisation output table is registered in.
database = 'minted'
driver = '{ODBC Driver 17 for SQL Server}'

# Credentials and endpoint come from the "keyvault" linked service, fetched in this order.
_sql_secret_names = ("SynapseSQLUserName", "SynapseSQLPassword", "SynapseServerlessSQLEndpoint")
sql_user_name, sql_user_pwd, serverless_sql_endpoint = [
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
    for secret_name in _sql_secret_names
]

# Set to True to print the resulting DataFrame (df.show()) while debugging.
display_dataframes = False

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

config_integration.trace_integrations(['logging'])

instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

logger = logging.getLogger(__name__)
# getLogger() returns the same logger object on every call, so re-running this
# cell in the same session would otherwise append another AzureLogHandler and
# export every record once per run. Attach the handler only if none is present.
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer exporting spans to Application Insights; every span is sampled.
tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# Spool parameters: sent as custom dimensions with the INITIALISED record below.
run_time_parameters = {'custom_dimensions': {
    'documents_contents_tbl_name': documents_contents_tbl_name,
    'batch_root': batch_root,
    'batch_num': batch_num,
    'model_name': model_name,
    'output_col_name': output_col_name,
    'error_col_name': error_col_name,
    'output_tbl_name': output_tbl_name,
    'file_system': file_system,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
import json
import os
from types import SimpleNamespace

from pyspark import SparkContext
from pyspark.sql import SparkSession
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, StructField, StructType

# Initialise session and config.
# Resolve the session first, then take the SparkContext from it, so `sc` is
# guaranteed to belong to the session bound to `spark`. A session already exists
# at this point in the notebook, so getOrCreate() returns it rather than starting a new one.
# The app name uses the notebook name rather than the repr of the whole runtime context dict.
spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context['notebookname']}").getOrCreate()
sc = spark.sparkContext

def read_batch_config(batch_root: str):
    """
    Read ``{batch_root}/config.json`` and return it as nested SimpleNamespace objects.

    The file is read on the driver through the Hadoop FileSystem API (via the JVM
    gateway) rather than with Spark. It is a single small file, so there is no point
    in letting multiple nodes read individual lines and joining them back together.

    Side effect: sets the Hadoop configuration key ``fs.defaultFS`` to the
    notebook-level ``file_system`` parameter before resolving the path.

    :param batch_root: folder of the batch that contains ``config.json``.
    :returns: the parsed JSON; every JSON object becomes a ``SimpleNamespace``.
    :raises FileNotFoundError: if ``config.json`` does not exist under ``batch_root``.
    """
    hadoop_conf = sc._jsc.hadoopConfiguration()
    # Change our file system from 'synapse' to 'input'
    hadoop_conf.set("fs.defaultFS", file_system)

    fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
    config_path = sc._jvm.org.apache.hadoop.fs.Path(f'{batch_root}/config.json')

    # Fail fast with a clear Python error rather than continuing into fs.open(),
    # which would fail with an opaque Java exception.
    if not fs.exists(config_path):
        logger.error(f'{config_path} not found.')
        raise FileNotFoundError(f'{config_path} not found.')

    # Open the file directly rather than through Spark, and always close the
    # underlying stream so the handle is not leaked.
    input_stream = fs.open(config_path)  # FSDataInputStream
    try:
        reader = sc._jvm.java.io.BufferedReader(
            sc._jvm.java.io.InputStreamReader(input_stream, sc._jvm.java.nio.charset.StandardCharsets.UTF_8)
        )
        config_string = reader.lines().collect(sc._jvm.java.util.stream.Collectors.joining("\n"))
    finally:
        input_stream.close()

    # Load it into json; object_hook turns each JSON object into a SimpleNamespace
    # so the config can be accessed with attribute syntax (config.summarization.num_beams).
    return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

with tracer.span(name=f"Load config: {mssparkutils.runtime.context['notebookname']}"):
    try:
        config = read_batch_config(batch_root)
    except Exception as e:
        logger.exception(e)
        # Bare raise re-raises the active exception unchanged.
        raise

    # Set log level: INFO only when explicitly requested; anything else, including
    # a config without a log_level entry, falls back to ERROR.
    if getattr(config, "log_level", None) == "INFO":
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)
        # Record the level actually applied back onto the config.
        config.log_level = "ERROR"

In [ ]:
def _config_flag(value) -> bool:
    """
    Interpret a config value as a boolean.

    Non-string values (JSON booleans/numbers) go through ``bool()``. Strings are
    parsed, because ``bool("false")`` would otherwise be True; only "true", "1",
    "yes" and "y" (case-insensitive, surrounding whitespace ignored) count as True.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes", "y")
    return bool(value)


# NOTE(review): described as the summary length, but summarize() passes it to the
# tokenizer as the input truncation length, not to generate() — confirm intent.
max_length = int(config.summarization.max_length)  # Max length for summary
num_beams = int(config.summarization.num_beams)  # Number of beams to use for beam search - read from global config and passed in
skip_special_tokens = _config_flag(config.summarization.skip_special_tokens)  # HF Skip special tokens - read from global config and passed in
clean_up_tokenization_spaces = _config_flag(config.summarization.clean_up_tokenization_spaces)  # Use HF to clean tokenization - read from global config and passed in

summarisation_config = {'custom_dimensions': {
    'batch_num': batch_num,
    'max_length': max_length,
    'num_beams': num_beams,
    'skip_special_tokens': skip_special_tokens,
    'clean_up_tokenization_spaces': clean_up_tokenization_spaces,
    'notebook_name': mssparkutils.runtime.context['notebookname']
} }

logger.info(f"{mssparkutils.runtime.context['notebookname']}: RUN_CONFIG", extra=summarisation_config)

class Models:
    """Holder for the Hugging Face tokenizer/model pair used by summarize()."""

    def __init__(self, summarizer_model, tokenizer=None, model=None):
        # NOTE(review): `summarizer_model` is accepted but unused here; weights are
        # only loaded by an explicit load_summarisation_model() call.
        self.model = model
        self.tokenizer = tokenizer

    def load_summarisation_model(self, summarizer_model):
        """Load tokenizer and seq2seq model from `summarizer_model` (a model id or a local path)."""
        source = summarizer_model
        self.tokenizer = AutoTokenizer.from_pretrained(source)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(source)


def summarize(text) -> "tuple[str, str]":
    """
    Summarise one document; used as the body of a Spark UDF.

    Reads the module-level ``summarizer_model`` (tokenizer + model) and the
    generation settings ``max_length``, ``num_beams``, ``skip_special_tokens``
    and ``clean_up_tokenization_spaces``.

    :param text: the document text of one row.
    :returns: ``(summary, error)``. On success ``error`` is ""; on any failure
        ``summary`` is "" and ``error`` holds the exception message (the
        exception is logged, not raised, so one bad row does not fail the job).
    """
    try:
        # NOTE(review): max_length truncates the *input* text here; it is not passed
        # to generate(), so it does not bound the summary length — confirm intent.
        token_inputs = summarizer_model.tokenizer([text], max_length=max_length, return_tensors='pt', truncation=True)
        summary_ids = summarizer_model.model.generate(token_inputs['input_ids'], num_beams=num_beams,
                                                      early_stopping=True)  # Stop beam search once num_beams candidates are finished
        summaries = [
            summarizer_model.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens,
                                              clean_up_tokenization_spaces=clean_up_tokenization_spaces)
            for ids in summary_ids
        ]
        # The UDF schema declares summary_text as StringType, so return a plain
        # string rather than the list of decoded sequences (one per input here).
        return " ".join(summaries), ""
    except Exception as e:
        logger.exception(e)
        return "", str(e)


In [ ]:
with tracer.span(name='Read the documents contents table'):
    # Keep just the document path and its text, spread over 320 partitions so
    # the summarisation UDF can run in parallel.
    documents_contents_path = minted_tables_output_path + documents_contents_tbl_name
    df = (
        spark.read.parquet(documents_contents_path)
        .select('file_path', 'text_content_target_lang')
        .repartition(320)
    )

with tracer.span(name=f'Mount {model_name} summarisation model'):
    try:
        # Read the first character of the model's config.json to check whether a copy
        # is staged under /models in the workspace storage. If it is missing this
        # raises and we fall through to loading the model from HuggingFace.
        mssparkutils.fs.head(f'abfss://synapse@{blob_account_name}.dfs.{azure_storage_domain}/models/{model_name}/config.json', 1)
        mount_point = '/mnt'
        jobId = mssparkutils.env.getJobId()
        linkedStorageName = f'{mssparkutils.env.getWorkspaceName()}-WorkspaceDefaultStorage'

        # NOTE(review): if mount_point is already mounted (e.g. when this cell is re-run)
        # mount may fail and trigger the HuggingFace fallback below — confirm; an
        # mssparkutils.fs.unmount(mount_point) call used to be left commented out here.
        mssparkutils.fs.mount(
            f'abfss://synapse@{blob_account_name}.dfs.{azure_storage_domain}/',
            mount_point,
            {'linkedService': f'{linkedStorageName}'}
        )

        # Mounted files are reached via the local path /synfs/{jobId}{mount_point}.
        # Please note the differences with the synfs protocol
        # https://docs.microsoft.com/en-us/azure/synapse-analytics/spark/synapse-file-mount-api#how-to-access-files-under-mount-point-via-local-file-system-api
        model_location = f'/synfs/{jobId}{mount_point}/models/{model_name}/'
        logger.info(f'Using {model_name} model from {model_location}.')
        # From here on model_name is the local directory that gets loaded.
        model_name = model_location
    except Exception as e:
        # Falling back is deliberate, but keep the reason instead of discarding it.
        logger.info(f'Using {model_name} model from HuggingFace.')
        logger.warning(f'Staged copy of {model_name} not used ({type(e).__name__}: {e}); falling back to HuggingFace.')

with tracer.span(name=f'Download and instantiate model {model_name}'):
    # Create an empty holder, then load tokenizer and weights from model_name
    # (a model id, or the mounted local path when a staged copy was found).
    summarizer_model = Models(summarizer_model=None)
    summarizer_model.load_summarisation_model(model_name)

with tracer.span(name=f'Run summarisation {model_name}'):
    # The UDF returns a (summary_text, error) struct per row; its two fields become
    # the configured output columns and the source text column is dropped.
    summary_struct = StructType([
        StructField("summary_text", StringType()),
        StructField("error", StringType()),
    ])
    udf_summarize = udf(summarize, summary_struct)
    summarized_result = udf_summarize(df.text_content_target_lang)

    source_text_column = df.text_content_target_lang
    df = df.withColumn(output_col_name, summarized_result.summary_text)
    df = df.withColumn(error_col_name, summarized_result.error)
    df = df.drop(source_text_column)

if display_dataframes:
    df.show()

In [ ]:
with tracer.span(name='Persist summarisation as table'):
    summarized_text_tbl_name = f'{batch_num}_{output_tbl_name}'

    # Write the summaries as parquet; the external table created below points at this folder.
    df.write.mode("overwrite").parquet(f'{minted_tables_output_path}{summarized_text_tbl_name}')

    # Register the parquet folder as an external table in the serverless SQL database,
    # unless a table with that name already exists.
    # NOTE(review): names are interpolated straight from pipeline parameters; they must
    # not contain quote or ']' characters.
    df_sql_command = f"""IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{summarized_text_tbl_name}') 
    CREATE EXTERNAL TABLE [{summarized_text_tbl_name}] 
    (
        [file_path] nvarchar(4000), 
        [{output_col_name}] nvarchar(4000), 
        [{error_col_name}] nvarchar(4000)
    ) WITH (
            LOCATION = 'minted_tables/{summarized_text_tbl_name}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
            )"""

    # pyodbc's Connection context manager only commits/rolls back on exit and leaves
    # the connection open, so commit and close explicitly. Closing without a commit
    # (on error) discards the pending work.
    conn = pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd)
    try:
        with conn.cursor() as cursor:
            cursor.execute(df_sql_command)
        conn.commit()
    finally:
        conn.close()

# Payload handed back to the calling pipeline; it names the table created above.
_exit_dimensions = {
    'batch_num': batch_num,
    'summarized_text_tbl_name': summarized_text_tbl_name,
    'file_system': file_system,
    'notebook_name': mssparkutils.runtime.context['notebookname'],
}
output = {'custom_dimensions': _exit_dimensions}

# Log the payload, then return it as this notebook's exit value.
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])