# Text Final

In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8
   }
}

In [ ]:
# Notebook parameters. Every value is an empty-string placeholder in this cell;
# later cells read them as table names, storage locations and batch metadata.
# NOTE(review): presumably overridden by the calling pipeline at run time - confirm.

# Intermediate tables (names are appended to minted_tables_output_path when read)
minted_tables_output_path = ''
documents_contents_tbl_name = ''
enriched_text_tbl_name = ''
summarized_text_xsum_tbl_name = ''
summarized_text_dailymail_tbl_name = ''
clustered_text_tbl_name = ''
clustered_multimodal_tbl_name = ''
text_prep_errors_tbl_name = ''

# Batch identification
batch_root = ''
batch_num = ''
batch_description = ''
manifest_file_name = ''

# Storage account and containers
blob_account_name = ''
azure_storage_domain = ''
input_container = ''
output_container = ''

In [ ]:
import pyodbc
import pandas as pd
from pyspark.sql.functions import current_timestamp
# SQL settings used to build the pyodbc connection strings in later cells:
# `database` is paired with the serverless endpoint, `dedicated_database`
# with the dedicated endpoint.
driver = '{ODBC Driver 17 for SQL Server}'
database = 'minted'
dedicated_database = "dedicated"

# Secrets: all four values are read from the "keyvault" linked service, in
# the order the target names are listed on the left-hand side.
sql_user_name, sql_user_pwd, serverless_sql_endpoint, dedicated_sql_endpoint = (
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
    for secret_name in (
        "SynapseSQLUserName",
        "SynapseSQLPassword",
        "SynapseServerlessSQLEndpoint",
        "SynapseDedicatedSQLEndpoint",
    )
)

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

# Enable the OpenCensus integration for the standard `logging` module
config_integration.trace_integrations(['logging'])

# A single connection string (read from the key vault) is shared by the log
# handler and the trace exporter below
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

# Notebook logger: INFO level, records exported through AzureLogHandler
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(
    AzureLogHandler(connection_string=instrumentation_connection_string)
)

# Tracer used for the `with tracer.span(...)` sections; every span is sampled
tracer = Tracer(
    exporter=AzureExporter(connection_string=instrumentation_connection_string),
    sampler=AlwaysOnSampler(),
)

# Spool parameters: attached to the log record as custom dimensions
run_time_parameters = {
    'custom_dimensions': dict(
        documents_contents_tbl_name=documents_contents_tbl_name,
        enriched_text_tbl_name=enriched_text_tbl_name,
        clustered_text_tbl_name=clustered_text_tbl_name,
        batch_description=batch_description,
        batch_root=batch_root,
        batch_num=batch_num,
        notebook_name=mssparkutils.runtime.context['notebookname'],
    )
}

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
# Column names shared by the cells below.

# Document contents
file_path_col = "file_path"
file_name_col = "file_name"
file_type_col = "file_type"
text_content_col = "text_content"
original_lang_col = "original_lang"
text_content_target_lang_col = "text_content_target_lang"

# Per-stage error columns
extraction_error_col = "extraction_error"
language_detection_error_col = "language_detection_error"
translation_error_col = "translation_error"
text_analysis_error_col = "text_analysis_error"
summarization_xsum_error_col = "summarization_xsum_error"
summarization_dailymail_error_col = "summarization_dailymail_error"

# Enrichment and summarisation
key_phrases_col = "key_phrases"
named_entities_col = "named_entities"
pii_col = "pii"
pii_redacted_text_col = "pii_redacted_text"
summarized_text_xsum_col = "summarized_text_xsum"
summarized_text_dailymail_col = "summarized_text_dailymail"

# Clustering
processed_text_col = "processed_text"
cluster_col = "cluster"
x_col = "X"
y_col = "Y"
topic_name_col = "topic_name"

# Combined error column built from the per-stage error columns
errors_col = "errors"

# For multimodal clustering results: same names with a "_1" suffix, so they
# can sit next to the text-clustering columns in one dataframe
summarized_text_xsum_col_1 = f"{summarized_text_xsum_col}_1"
topic_name_col_1 = f"{topic_name_col}_1"
cluster_col_1 = f"{cluster_col}_1"
x_col_1 = f"{x_col}_1"
y_col_1 = f"{y_col}_1"

key_col = file_path_col

clustering_cols = [
    file_path_col,
    processed_text_col,
    cluster_col,
    x_col,
    y_col,
    original_lang_col,
]

# For multimodal clustering results
multimodal_clustering_cols = [
    file_path_col,
    summarized_text_xsum_col_1,
    topic_name_col_1,
    cluster_col_1,
    x_col_1,
    y_col_1,
]

output_cols = [
    file_name_col,
    file_type_col,
    text_content_col,
    original_lang_col,
    text_content_target_lang_col,
]

# Columns that are folded into `errors_col` and then dropped
error_cols = [
    extraction_error_col,
    language_detection_error_col,
    translation_error_col,
    text_analysis_error_col,
]

In [ ]:
import json
import os
import csv
from types import SimpleNamespace

from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql.functions import col
import pyspark.sql.functions as F

# Initialise session and config.
# The session is bound first and the SparkContext is derived from it, so `sc`
# no longer depends on a `spark` name that merely happens to exist before this
# cell runs.
# NOTE(review): the app name interpolates the whole runtime context mapping,
# not just context['notebookname'] - confirm this is intended.
spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()
sc = spark.sparkContext

# config.json is read as text lines from the batch root, concatenated, and
# parsed so that every JSON object becomes a SimpleNamespace (attribute access).
config = json.loads(
    ''.join(sc.textFile(f'{batch_root}/config.json').collect()),
    object_hook=lambda dictionary: SimpleNamespace(**dictionary),
)

# Set log level: only "INFO" is honoured. Any other value - including a config
# without a log_level entry - falls back to ERROR and is recorded on the config.
if getattr(config, "log_level", None) == "INFO":
    logger.setLevel(logging.INFO)
else:
    logger.setLevel(logging.ERROR)
    config.log_level = "ERROR"


In [ ]:
# Load the intermediate tables of this batch and merge them into `docs_df`.
# documents_contents_df is the left side of every join, so all of its rows are kept.
with tracer.span(name=f'Read the dataframe from the given table {documents_contents_tbl_name}'):

    # Each table is read as parquet from <minted_tables_output_path><table name>
    documents_contents_df = spark.read.parquet(f'{minted_tables_output_path}{documents_contents_tbl_name}')
    enriched_text_df = spark.read.parquet(f'{minted_tables_output_path}{enriched_text_tbl_name}')
    clustered_text_df = spark.read.parquet(f'{minted_tables_output_path}{clustered_text_tbl_name}')
    text_prep_errors_df = spark.read.parquet(f'{minted_tables_output_path}{text_prep_errors_tbl_name}')
    
    # For multimodal clustering results
    # Columns are renamed with a "_1" suffix so they do not clash with the
    # text-clustering columns of the same name once everything is joined.
    clustered_multimodal_df = (spark.read.parquet(f'{minted_tables_output_path}{clustered_multimodal_tbl_name}')
        .withColumnRenamed(summarized_text_xsum_col, summarized_text_xsum_col_1)
        .withColumnRenamed(topic_name_col, topic_name_col_1)
        .withColumnRenamed(cluster_col, cluster_col_1)
        .withColumnRenamed(x_col, x_col_1)
        .withColumnRenamed(y_col, y_col_1)
    )
    # Keep only the rows whose type_of_file is 'text' ...
    clustered_multimodal_df = (clustered_multimodal_df
        .where(clustered_multimodal_df.type_of_file == 'text')
    )	
    # ... and drop the discriminator column afterwards
    clustered_multimodal_df = (clustered_multimodal_df
        .drop(clustered_multimodal_df.type_of_file)
    )

    # Left outer joins on file_path; the select then picks each column from
    # the specific source dataframe and a constant batch_num column is appended.
    docs_df = (documents_contents_df
        .join(enriched_text_df, [documents_contents_df[file_path_col] == enriched_text_df[file_path_col]],'left_outer')
        .join(clustered_text_df, [documents_contents_df[file_path_col] == clustered_text_df[file_path_col]],'left_outer')
        .join(text_prep_errors_df, [documents_contents_df[file_path_col] == text_prep_errors_df[file_path_col]],'left_outer')
        .join(clustered_multimodal_df, [documents_contents_df[file_path_col] == clustered_multimodal_df[file_path_col]],'left_outer') # For multimodal clustering results
        .select(
            documents_contents_df.file_path,
            documents_contents_df.file_name,
            documents_contents_df.file_type,
            documents_contents_df.text_content,
            documents_contents_df.original_lang,
            documents_contents_df.text_content_target_lang,
            enriched_text_df[key_phrases_col],
            enriched_text_df[named_entities_col],
            enriched_text_df[pii_col],
            enriched_text_df[pii_redacted_text_col],
            enriched_text_df[text_analysis_error_col],
            clustered_text_df[processed_text_col],
            clustered_text_df[cluster_col],
            clustered_text_df[x_col],
            clustered_text_df[y_col],
            text_prep_errors_df[extraction_error_col],
            text_prep_errors_df[language_detection_error_col],
            text_prep_errors_df[translation_error_col],
            clustered_multimodal_df[summarized_text_xsum_col_1], # For multimodal clustering results
            clustered_multimodal_df[topic_name_col_1], # For multimodal clustering results
            clustered_multimodal_df[cluster_col_1], # For multimodal clustering results
            clustered_multimodal_df[x_col_1], # For multimodal clustering results
            clustered_multimodal_df[y_col_1] # For multimodal clustering results
        )
        .withColumn("batch_num", F.lit(batch_num))
    )
    # Re-group errors to simplify processing on the UI
    # `errors` becomes an array with one struct per entry of error_cols:
    #   message - the error text, or "" when the source column is null or empty
    #   stage   - the source column name with "_error" removed
    docs_df = docs_df.withColumn(errors_col, F.array(*[ 
        F.struct(F.when((docs_df[column].isNull() | (docs_df[column] == "")), F.lit("")).otherwise(docs_df[column]).alias("message"),
        F.lit(column.replace("_error", "")).alias("stage")) 
        for column in error_cols]))
    # The individual error columns are now redundant
    docs_df = docs_df.drop(*error_cols)

# Write two parquet datasets (error counts and the full processed text) and
# register an external table over each of them on the serverless SQL endpoint.
with tracer.span(name='Persist processed text as SQL table'):

    # The batch_num_document_enrichment_errors dataframe is used as a simplified query in the powerBI query
    # errors.stage / errors.message are arrays; zipping and inlining them gives
    # one row per (stage, message) pair, which is then counted per distinct pair.
    batch_num_document_enrichment_errors = docs_df.select("errors.stage", "errors.message")
    batch_num_document_enrichment_errors = batch_num_document_enrichment_errors.selectExpr('inline(arrays_zip(stage,message))') \
        .groupBy("stage", "message").count() \
        .withColumnRenamed("stage", "errors.stage") \
        .withColumnRenamed("message", "errors.message") \
        .withColumnRenamed("count", "Count")
    batch_num_document_enrichment_errors_tbl_name = f'{batch_num}_document_enrichment_errors'
    batch_num_document_enrichment_errors.write.mode("overwrite").parquet(f'{minted_tables_output_path}{batch_num_document_enrichment_errors_tbl_name}')
    
    # Create the external table for batch_num_document_enrichment_errors to be used in powerBI
    # The IF NOT EXISTS guard means an already existing table of that name is left untouched.
    # NOTE(review): <<STORAGE_ACCOUNT_NAME>> looks like a deployment-time placeholder -
    # confirm it is substituted before this notebook runs.
    sql_command_document_enrichment_errors = f'''
        IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{batch_num_document_enrichment_errors_tbl_name}') 
        CREATE EXTERNAL TABLE [{batch_num_document_enrichment_errors_tbl_name}] (
            [errors.stage] nvarchar(1000), 
            [errors.message] nvarchar(1000),
            [Count] bigint, 
        )
        WITH (
            LOCATION = 'minted_tables/{batch_num_document_enrichment_errors_tbl_name}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
        )
    '''
    # Table for batch_num processed_text
    # processed_text_tbl_name is also reported in the notebook output at the end.
    processed_text_tbl_name = f'{batch_num}_processed_text'
    docs_df.write.mode("overwrite").parquet(f'{minted_tables_output_path}{processed_text_tbl_name}')
    
    # External table over the processed-text parquet files, same IF NOT EXISTS guard as above
    docs_df_sql_command = f"""IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{processed_text_tbl_name}') 
        CREATE EXTERNAL TABLE [{processed_text_tbl_name}] 
        (
            [file_path] nvarchar(4000), 
            [file_name] nvarchar(4000), 
            [file_type] nvarchar(4000), 
            [text_content] nvarchar(4000), 
            [original_lang] nvarchar(4000),
            [text_content_target_lang] nvarchar(4000),
            [key_phrases] varchar(MAX),
            [named_entities] varchar(MAX),
            [pii] varchar(MAX),
            [pii_redacted_text] nvarchar(4000),
            [processed_text] nvarchar(4000),
            [cluster] bigint,
            [X] float,
            [Y] float,
            [summarized_text_xsum_1] nvarchar(4000),
            [topic_name_1] nvarchar(4000),
            [cluster_1] bigint,
            [X_1] float,
            [Y_1] float,
            [batch_num] nvarchar(4000),
            [errors] varchar(MAX)
        ) WITH (
                LOCATION = 'minted_tables/{processed_text_tbl_name}/**', 
                DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
                FILE_FORMAT = [SynapseParquetFormat]
                )"""

    # Both DDL statements run on one connection to the serverless endpoint / `database`
    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
      with conn.cursor() as cursor:
        cursor.execute(docs_df_sql_command)
        cursor.execute(sql_command_document_enrichment_errors)

# Write one JSON file per document row into the batch's output container.
with tracer.span(name='Persist processed text as json'):
    # Root of this batch in the output container; also used by later cells
    output_path = f'abfss://{output_container}@{blob_account_name}.dfs.{azure_storage_domain}/{batch_num}'

    # Serialise every column of the row into a single JSON string column
    docs_df = docs_df \
        .withColumn("json", F.to_json(F.struct(col("*"))))

    # Only the file name and the serialised row are needed on the driver.
    # Collecting the full rows would ship every column twice (once raw and
    # once inside `json`).
    out_lst = docs_df.select("file_name", "json").collect()

    # NOTE(review): the target name is derived from file_name only, so two
    # documents with the same file_name (different file_path) would write to
    # the same JSON file - confirm file names are unique within a batch.
    for row in out_lst:
        json_path = f'{output_path}/text_processing_json/{row.file_name}.output.json'
        mssparkutils.fs.put(json_path, row.json, overwrite=True)

def _put_rows_as_csv(header, rows, destination, scratch_file='clusters.csv'):
    """Write `header` and `rows` as CSV and upload the text to `destination`.

    The CSV is first written to a local scratch file because the output
    container cannot be written to directly via open(); the file is then read
    back and uploaded with mssparkutils.fs.put (overwriting any existing file).
    The open() modes are the same as in the original inline code so the
    uploaded text is unchanged.

    Args:
        header: sequence of column names, written as the first CSV row.
        rows: iterable of row sequences (e.g. collected Spark rows).
        destination: target path handed to mssparkutils.fs.put.
        scratch_file: local file used as intermediate storage.
    """
    with open(scratch_file, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)

    with open(scratch_file, 'r') as f:
        mssparkutils.fs.put(destination, f.read(), overwrite=True)


# Persist both clustering result sets as parquet + external table + CSV, and
# copy the manifest and config next to them.
with tracer.span(name='Persist clustering results as CSV and a table'):
    # Type 1: text clustering results
    clustered_df_1 = docs_df.select(
        col('file_path'),
        col('processed_text'),
        col('cluster'),
        col('X'),
        col('Y'),
        col('original_lang'),
        col('batch_num'),
    )

    # Type 2: multimodal clustering results, with the "_1" columns renamed
    # back to their base names
    clustered_df_2 = docs_df.select(
        col('file_path'),
        col(summarized_text_xsum_col_1).alias(summarized_text_xsum_col),
        col(topic_name_col_1).alias(topic_name_col),
        col(cluster_col_1).alias(cluster_col),
        col(x_col_1).alias(x_col),
        col(y_col_1).alias(y_col),
        col('original_lang'),
        col('batch_num'),
    )

    # Save the clustering results type 1 as a single table
    clustered_text_report_tbl_name_1 = f'{batch_num}_clustered_text_report_1'
    clustered_df_1.write.mode("overwrite").parquet(f"{minted_tables_output_path}{clustered_text_report_tbl_name_1}")
    clustered_lst_1 = clustered_df_1.collect()

    clustered_df_sql_command_1 = f"""IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{clustered_text_report_tbl_name_1}') 
    CREATE EXTERNAL TABLE [{clustered_text_report_tbl_name_1}] 
    (
        [file_path] nvarchar(4000), 
        [processed_text] nvarchar(max), 
        [cluster] bigint,
        [X] float,
        [Y] float,
        [original_lang] nvarchar(4000),
        [batch_num] nvarchar(4000)
    ) WITH (
            LOCATION = 'minted_tables/{clustered_text_report_tbl_name_1}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
            )"""

    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
      with conn.cursor() as cursor:
        cursor.execute(clustered_df_sql_command_1)

    # Output as CSV.
    # Previously both result sets were uploaded to clusters.csv, so this one
    # was always overwritten by type 2. It now gets its own file; clusters.csv
    # keeps holding the type 2 data exactly as before.
    _put_rows_as_csv(
        clustered_df_1.columns,
        clustered_lst_1,
        f'{output_path}/text_processing_clustering/clusters_1.csv',
    )

    # Save the clustering results type 2 as a single table
    clustered_text_report_tbl_name_2 = f'{batch_num}_clustered_text_report_2'
    clustered_df_2.write.mode("overwrite").parquet(f"{minted_tables_output_path}{clustered_text_report_tbl_name_2}")
    clustered_lst_2 = clustered_df_2.collect()

    clustered_df_sql_command_2 = f"""IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{clustered_text_report_tbl_name_2}') 
    CREATE EXTERNAL TABLE [{clustered_text_report_tbl_name_2}] 
    (
        [file_path] nvarchar(4000), 
        [summarized_text_xsum] nvarchar(max), 
		[topic_name] nvarchar(max), 
        [cluster] bigint,
        [X] float,
        [Y] float,
        [original_lang] nvarchar(4000),
        [batch_num] nvarchar(4000)
    ) WITH (
            LOCATION = 'minted_tables/{clustered_text_report_tbl_name_2}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
            )"""

    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
      with conn.cursor() as cursor:
        cursor.execute(clustered_df_sql_command_2)

    # Output as CSV (same destination as before)
    _put_rows_as_csv(
        clustered_df_2.columns,
        clustered_lst_2,
        f'{output_path}/text_processing_clustering/clusters.csv',
    )

    # copy the manifest + config to output (once; the copy used to be issued
    # twice with identical arguments)
    mssparkutils.fs.cp(f'{batch_root}/{manifest_file_name}', f'{output_path}/manifest.txt')
    mssparkutils.fs.cp(f'{batch_root}/config.json', output_path)


In [ ]:
from pyspark.sql.functions import explode, count, countDistinct, udf, row_number
from pyspark.sql.types import FloatType, StringType, IntegerType, StructType, StructField
import math
from pyspark.sql.window import Window

# Number of top-scoring key phrases concatenated into each cluster label
max_num_terms_to_retain = 5

with tracer.span(name=f'Calculate TF-IDF from the given table {processed_text_tbl_name}'):
    # One row per (cluster, key phrase); rows lacking a cluster or key phrases are dropped
    small_docs_df = docs_df.select(col('cluster'), col('key_phrases')).na.drop('any')
    small_docs_df = small_docs_df.withColumn('term', explode('key_phrases')).drop('key_phrases')

    # Generate a table of term frequencies within each cluster
    terms_with_tf = small_docs_df.groupBy('cluster', 'term').agg(count('cluster').alias('tf'))

    # Calculate the number of clusters in which each term appears, and then take
    # the (slightly modified) inverse of that to generate the "IDF" table:
    #   idf = ln((num_clusters + 1) / (df + 1))
    # Computed with native column expressions rather than a Python UDF; the
    # result is cast to float to keep the column type the UDF used to return.
    num_clusters = small_docs_df.select(countDistinct('cluster')).collect()[0][0]
    idf_expr = F.log(F.lit(num_clusters + 1) / (col('df') + 1)).cast('float')
    terms_with_idf = small_docs_df.groupBy('term').agg(countDistinct('cluster').alias('df')) \
        .withColumn('idf', idf_expr).drop('df')

    # Multiply term frequency by inverse document frequency to get TF-IDF
    tf_idf_df = terms_with_tf.join(terms_with_idf, on='term', how='left') \
        .withColumn('tf_idf', col('tf') * col('idf')).drop('tf', 'idf')

with tracer.span(name='Concatenate the top TF-IDF terms for each cluster to form a label'):
    # Retain the top terms for each cluster. Ties on tf_idf are broken by the
    # term itself so the same input always produces the same label.
    windower = Window.partitionBy('cluster').orderBy(col('tf_idf').desc(), col('term'))
    pd_df = tf_idf_df.withColumn('row', row_number().over(windower)) \
        .filter(col('row') <= max_num_terms_to_retain).drop('tf_idf').toPandas()

    # Concatenate the top terms for each cluster as a comma-separated string
    pd_df = pd_df.groupby('cluster').apply(lambda x: ', '.join(x.sort_values(by='row', ascending=True)['term'])).reset_index()
    if pd_df.empty:
        # Nothing to label: fall back to an empty frame with the expected columns
        logger.info("DataFrame is empty")
        pd_df = pd.DataFrame([], columns=['cluster', 'label'])
    else:
        pd_df.columns = ['cluster', 'label']

with tracer.span(name='Persist cluster labels as table'):
    # NOTE(review): `cluster` is written as a 32-bit integer here while the
    # external table below declares it as bigint - confirm this is intended.
    cluster_labels_df_schema = StructType([
        StructField('cluster', IntegerType(), True),
        StructField('label', StringType(), True)])
    cluster_labels_df = spark.createDataFrame(pd_df, schema=cluster_labels_df_schema)

    cluster_labels_tbl_name = f'{batch_num}_cluster_labels'
    cluster_labels_df.write.mode("overwrite").parquet(f"{minted_tables_output_path}{cluster_labels_tbl_name}")

    # External table over the label parquet files; skipped when the table already exists
    cluster_labels_sql_command = f"""IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{cluster_labels_tbl_name}') 
    CREATE EXTERNAL TABLE [{cluster_labels_tbl_name}] 
    (
        [cluster] bigint,
        [label] nvarchar(4000)
    ) WITH (
            LOCATION = 'minted_tables/{cluster_labels_tbl_name}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
            )"""

    with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
        with conn.cursor() as cursor:
            cursor.execute(cluster_labels_sql_command)

# Generate a table for the top words for each cluster

In [ ]:
# Build the table of the 50 most frequent words per cluster.
# clustered_df_1 holds the processed text of each document with its cluster
# cluster_labels_df holds the cluster labels

# Join clustered_df_1 and cluster_labels_df on the cluster column
top_words_df = (clustered_df_1
    .join(cluster_labels_df, [clustered_df_1[cluster_col] == cluster_labels_df[cluster_col]], 'left_outer')
    .select(
        clustered_df_1.processed_text,
        clustered_df_1.cluster,
        cluster_labels_df.label
    )
)

# Drop rows with any null value. `label` is not used further down; it only
# takes part in this filter, so documents whose cluster has no label are removed.
top_words_df = top_words_df.na.drop()

# Group by cluster and concatenate all the processed_text
top_words_df = top_words_df.groupBy("cluster") \
    .agg(
        F.concat_ws(' ', F.collect_list("processed_text")).alias("words")
    )

# Expand the concatenated processed_text to one word per row
top_words_df = top_words_df.withColumn('words', explode(F.split('words', ' ')))

# Splitting on a single space yields empty strings for empty documents and
# repeated spaces; those are not words and must not be counted or ranked.
top_words_df = top_words_df.filter(col('words') != '')

# Count the occurrences of each word within its cluster
top_words_df = top_words_df.groupBy("cluster", "words").count()

# Rank the words by descending count; ties are broken by the word itself so
# the selection below is reproducible
my_window = Window.partitionBy("cluster").orderBy(F.col("count").desc(), F.col("words"))

# Select top 50 words per cluster
top_words_df = top_words_df.withColumn("row_rank_by_cluster", F.row_number().over(my_window)) \
    .filter(col('row_rank_by_cluster') <= 50).drop("row_rank_by_cluster")

cluster_top_words_tbl_name = f'{batch_num}_cluster_top_words'
top_words_df.write.mode("overwrite").parquet(f"{minted_tables_output_path}{cluster_top_words_tbl_name}")

# External table over the top-words parquet files; skipped when the table already exists
cluster_top_words_sql_command = f"""IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{cluster_top_words_tbl_name}') 
CREATE EXTERNAL TABLE [{cluster_top_words_tbl_name}] 
(
    [cluster] bigint,
    [words] nvarchar(4000),
    [count] bigint
) WITH (
        LOCATION = 'minted_tables/{cluster_top_words_tbl_name}/**', 
        DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
        FILE_FORMAT = [SynapseParquetFormat]
        )"""

with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
    with conn.cursor() as cursor:
        cursor.execute(cluster_top_words_sql_command)

In [ ]:
from time import sleep
import datetime
# Update Status Table
def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """Return the most advanced status recorded for a batch.

    Queries [dbo].[batch_status] on the dedicated SQL endpoint for the row of
    `batch_num` with the highest num_stages_complete.

    Args:
        batch_num: value matched against the batch_id column.
        driver: ODBC driver name used in the connection string.
        dedicated_sql_endpoint: host name of the dedicated SQL endpoint.
        dedicated_database: database to connect to.
        sql_user_name: SQL login name.
        sql_user_pwd: SQL login password.

    Returns:
        Tuple (num_stages_complete, description) of the selected row.

    Raises:
        LookupError: if batch_status holds no row for `batch_num`.
    """
    query = """
        SELECT TOP (1) 
        [num_stages_complete], [description]
        FROM [dbo].[batch_status] 
        WHERE [batch_id] = ? 
        ORDER BY [num_stages_complete] DESC;
    """
    connection_string = (
        f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;'
        f'DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'
    )
    with pyodbc.connect(connection_string, autocommit=True) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, batch_num)
            row = cursor.fetchone()
            # fetchone() yields None when the query matched nothing; report
            # that explicitly instead of failing on tuple unpacking.
            if row is None:
                raise LookupError(f'No batch_status row found for batch_id {batch_num!r}')
            num_stages_complete, description = row
            return num_stages_complete, description

def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd,
                        *, max_retries=10, retry_delay_s=3, total_stages=9):
    """Advance the batch's stage counter by one and store a new status text.

    Reads the current num_stages_complete via get_recent_status, increments
    it, and updates status, update_time_stamp and num_stages_complete of the
    batch_status rows whose batch_id equals `batch_num`. Any failure is
    logged as a warning and the whole read-and-update is retried.

    The database name is taken from the module-level `dedicated_database`.

    Args:
        status_text: text stored after the "[n/total]" prefix.
        minted_tables_path: not used by this function; kept for existing callers.
        batch_num: value matched against the batch_id column.
        driver: ODBC driver name used in the connection string.
        dedicated_sql_endpoint: host name of the dedicated SQL endpoint.
        sql_user_name: SQL login name.
        sql_user_pwd: SQL login password.
        max_retries: number of attempts before giving up (at least 1).
        retry_delay_s: seconds to wait between two attempts.
        total_stages: denominator shown in the "[n/total]" status prefix.

    Raises:
        ValueError: if `max_retries` is smaller than 1.
        Exception: the error of the last attempt when every attempt failed.
    """
    if max_retries < 1:
        raise ValueError('max_retries must be at least 1')

    sql_command = "UPDATE batch_status SET status = ?, update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"
    last_error = None

    for attempt in range(1, max_retries + 1):
        try:
            stages_complete, description = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            stages_complete += 1
            status = f'[{stages_complete}/{total_stages}] {status_text}'
            time_stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+dedicated_sql_endpoint+';PORT=1433;DATABASE='+dedicated_database+';UID='+sql_user_name+';PWD='+ sql_user_pwd+'',autocommit=True) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
                    # Kept from the original code; the connection is opened with autocommit=True
                    cursor.commit()
            return
        except Exception as e:
            last_error = e
            exc_str = str(e)
            logger.warning(f'Failed to update status table: {exc_str}, retrying . . .')
            # Wait only when another attempt follows; after the final failure
            # the error is raised immediately.
            if attempt < max_retries:
                sleep(retry_delay_s)

    raise last_error


# Mark the text-analysis stage of this batch as complete
update_status_table('Text Analysis Complete', minted_tables_output_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd)

In [ ]:
# Result of this notebook: the batch, the processed-text table written above
# and the notebook name, wrapped as custom dimensions for the log record
output = {
    'custom_dimensions': dict(
        batch_num=batch_num,
        processed_text_tbl_name=processed_text_tbl_name,
        notebook_name=mssparkutils.runtime.context['notebookname'],
    )
}

# Log the result, then hand the inner mapping back to the pipeline as the exit value
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])