# 1. Dataset Prep


In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8
   }
}

In [ ]:
# Parameters cell: the pipeline passes the batch/manifest details in as parameters; these are the empty defaults.
manifest_file_path=''  # manifest path inside manifest_container; appended directly to the abfss:// container URL, so presumably starts with '/' - confirm
manifest_container=''  # storage container holding the batch files and the manifest
blob_account_name = ''  # storage account name used to build every abfss:// URL in this notebook
global_config_location = ''  # NOTE(review): not referenced anywhere in this notebook; the global config path is hard-coded where read_batch_config is called
azure_storage_domain = ''  # domain suffix placed after '.dfs.' in abfss:// URLs (presumably e.g. 'core.windows.net') - confirm

In [ ]:
import os
import json
from types import SimpleNamespace

from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import regexp_extract, regexp_replace, concat, lit, col, current_timestamp
from pyspark.sql.types import StringType, TimestampType, StructType, StructField, IntegerType
from py4j.protocol import Py4JJavaError
import re
import pyodbc
from pathlib import Path

# Initialise paths and batch root.
# batch_path is the root URL of the manifest's container; manifest_file_path is appended to it verbatim,
# so it is presumably a container-relative path beginning with '/' (confirm with the pipeline).
batch_path = 'abfss://{}@{}.dfs.{}'.format(manifest_container, blob_account_name, azure_storage_domain)
manifest_file_name = Path(manifest_file_path).name
batch_folder = os.path.dirname(manifest_file_path)
manifest_full_path = batch_path + manifest_file_path
batch_root = batch_path + batch_folder

# Dedicated and serverless SQL config (database names, ODBC driver, and where the parquet tables are written).
dedicated_database = 'dedicated'
database = 'minted'
driver = '{ODBC Driver 17 for SQL Server}'
output_path = 'abfss://synapse@{}.dfs.{}/minted_tables/'.format(blob_account_name, azure_storage_domain)

# Load secrets through the "keyvault" linked service, in a fixed order.
def _keyvault_secret(secret_name):
    """Return the value of ``secret_name`` fetched via the "keyvault" linked service."""
    return mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)

instrumentation_connection_string = _keyvault_secret("AppInsightsConnectionString")
sql_user_name = _keyvault_secret("SynapseSQLUserName")
sql_user_pwd = _keyvault_secret("SynapseSQLPassword")
serverless_sql_endpoint = _keyvault_secret("SynapseServerlessSQLEndpoint")
dedicated_sql_endpoint = _keyvault_secret("SynapseDedicatedSQLEndpoint")

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer
import datetime

# Route Python logging through the opencensus integration and export records/traces to App Insights.
config_integration.trace_integrations(['logging'])

logger = logging.getLogger(__name__)
# Re-executing this cell in the same session would otherwise attach another AzureLogHandler,
# and every record would then be exported once per attached handler.
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Sample every span so each notebook run is fully traced.
tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

notebook_name = mssparkutils.runtime.context['notebookname']

# Spool run-time parameters as App Insights custom dimensions.
run_time_parameters = {'custom_dimensions': {
    'batch_path': batch_path,
    'batch_folder': batch_folder,
    'batch_root': batch_root,
    'manifest_full_path': manifest_full_path,
    'notebook_name': notebook_name
} }

logger.info(f"{notebook_name}: INITIALISED", extra=run_time_parameters)

In [ ]:
with tracer.span(name=f"Load config: {mssparkutils.runtime.context['notebookname']}"):
    # Initialise session, create (if necessary) and read batch config.
    # getOrCreate() reuses an already-active session; sc is taken from the session we hold.
    spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()
    sc = spark.sparkContext

    def copy_global_config(config_path: str, global_config_path: str):
        """
        Copy the global config file at ``global_config_path`` to ``config_path``.

        read_batch_config calls this only when the batch root has no config.json yet.
        If the copy fails with a Py4JJavaError (for example because there is no config at
        global_config_path, which indicates an error in pipeline set-up), the error is logged
        and re-raised.
        """
        logger.info("Loading global config")
        try:
            mssparkutils.fs.cp(global_config_path, config_path)
        except Py4JJavaError as e:
            logger.exception(e)
            raise

    def read_batch_config(batch_root: str, global_config_path: str):
        """
        Return ``{batch_root}/config.json`` parsed into nested SimpleNamespace objects.

        If the batch has no config.json, the global config is copied there first.
        The file is read on the driver via the Hadoop FileSystem Java API rather than through
        Spark, because there is no need to have multiple nodes read individual lines and join
        them back together.
        """
        hadoop_conf = sc._jsc.hadoopConfiguration()
        # Change our default file system from 'synapse' to 'input' (the batch container).
        # NOTE(review): this mutates fs.defaultFS for the whole Spark session and uses the
        # global batch_path rather than the batch_root argument.
        hadoop_conf.set("fs.defaultFS", batch_path)

        fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
        config_file = f'{batch_root}/config.json'
        config_path = sc._jvm.org.apache.hadoop.fs.Path(config_file)

        # If we don't have a batch config, copy the global one.
        if not fs.exists(config_path):
            copy_global_config(config_file, global_config_path)

        # Open the file directly rather than through Spark, and always release the stream.
        input_stream = fs.open(config_path)  # FSDataInputStream
        try:
            reader = sc._jvm.java.io.BufferedReader(
                sc._jvm.java.io.InputStreamReader(input_stream, sc._jvm.java.nio.charset.StandardCharsets.UTF_8)
            )
            config_string = reader.lines().collect(sc._jvm.java.util.stream.Collectors.joining("\n"))
        finally:
            input_stream.close()

        # Load it into JSON, turning every object into a SimpleNamespace for attribute access.
        return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

    # NOTE: this path should be in sync with Terraform configuration which uploads this file
    config = read_batch_config(batch_root, global_config_path=f'abfss://configuration@{blob_account_name}.dfs.{azure_storage_domain}/config.global.json')

    # Set log level: only "INFO" keeps info logging; any other value (or no log_level at all) means ERROR.
    if getattr(config, 'log_level', None) == "INFO":
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)
        config.log_level = "ERROR"

In [ ]:
from time import sleep
def update_status_table(batch_num, description, max_attempts=10, retry_delay_seconds=3):
    """
    Mark ``batch_num`` as '[1/10] Pipeline Started' in the dedicated pool's batch_status table.

    If no row exists for the batch, one is inserted with date_submitted set to now. Otherwise
    only that batch's row is updated (description, status, update_time_stamp, and
    num_stages_complete = 1).

    The statement is attempted up to ``max_attempts`` times (at least once), sleeping
    ``retry_delay_seconds`` between failed attempts. If every attempt fails, the last
    exception is re-raised.
    """
    from contextlib import closing

    sql_command = (
        "IF NOT EXISTS (SELECT * FROM batch_status WHERE batch_id = ?) "
        "INSERT INTO batch_status (batch_id, date_submitted, description, status, update_time_stamp, num_stages_complete) "
        "VALUES (?, ?, ?, '[1/10] Pipeline Started', ?, 1) "
        "ELSE UPDATE batch_status SET description = ?, status = '[1/10] Pipeline Started', "
        "update_time_stamp = ?, num_stages_complete = 1 "
        # Without this filter the UPDATE would overwrite the status of every batch.
        "WHERE batch_id = ?"
    )
    connection_string = (
        'DRIVER=' + driver + ';SERVER=tcp:' + dedicated_sql_endpoint + ';PORT=1433;DATABASE=' + dedicated_database
        + ';UID=' + sql_user_name + ';PWD=' + sql_user_pwd
    )

    attempts = max(1, max_attempts)
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            time_stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            # pyodbc's connection context manager does not close the connection, so close it explicitly.
            with closing(pyodbc.connect(connection_string, autocommit=True)) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(
                        sql_command,
                        batch_num,                # IF NOT EXISTS ... WHERE batch_id = ?
                        batch_num, time_stamp,    # INSERT batch_id, date_submitted
                        description, time_stamp,  # INSERT description, update_time_stamp
                        description, time_stamp,  # UPDATE description, update_time_stamp
                        batch_num,                # UPDATE ... WHERE batch_id = ?
                    )
            return
        except Exception as e:
            last_error = e
            logger.warning(f'Failed to update status table: {str(e)}, retrying . . .')
            if attempt < attempts:
                sleep(retry_delay_seconds)

    raise last_error

In [ ]:
try:
    with tracer.span(name='Convert manifest to tables'):
        # Read the manifest into a dataframe with one row per line (single column 'value').
        df = spark.read.text(manifest_full_path)

        # The first two manifest lines carry the batch number and the batch description.
        headers = df.head(2)
        if len(headers) < 2:
            raise ValueError(
                f'Manifest {manifest_full_path} must start with a batch number line and a description line'
            )
        batch_num = headers[0].asDict()['value']
        description = headers[1].asDict()['value']

        # Sanitise batch_num to ASCII letters, digits and '_' because it becomes part of table names,
        # folder names and SQL identifiers below. Also convert it to lower case, because Power BI is
        # case sensitive and Synapse/Azure Storage (presumably) force names to lower case.
        batch_num = re.sub(r'[^A-Za-z0-9_]', '', batch_num).lower()
        if not batch_num:
            raise ValueError(
                f'Manifest {manifest_full_path} has no usable batch number (letters, digits or _)'
            )

        # Separate the file extension (without the dot) into a new column and build each file's full path.
        # NOTE(review): df still contains the two header lines; they only survive the filters below if
        # they happen to end in something that looks like a file extension.
        df = (
            df.withColumnRenamed("value", "file_name")
            .withColumn("file_type", regexp_replace(regexp_extract("file_name", r"\.[0-9a-zA-Z]+$", 0), r"\.", ""))
            .withColumn("file_path", concat(lit(batch_root + "/"), col("file_name")))
        )

        # Save the contents of the manifest file (rows with an extension) for the Power BI summary pages.
        batch_df = (
            df.withColumn('batch_num', lit(batch_num))
            .withColumn('batch_desc', lit(description))
            .filter(col('file_type').isNotNull() & (col('file_type') != ''))
        )
        batch_df_name = f'{batch_num}_manifest'

        batch_status_schema = StructType([
            StructField("batch_id", StringType(), True),
            StructField("date_submitted", TimestampType(), True),
            StructField("description", StringType(), True),
            StructField("status", StringType(), True),
            StructField("update_time_stamp", TimestampType(), True),
            StructField("num_stages_complete", IntegerType(), True)])

        batch_status_path = f'{output_path}batch_status/'
        print("Path: " + batch_status_path)
        try:
            batch_status_df = spark.read.format("parquet").schema(batch_status_schema).load(batch_status_path)
        except Exception:
            # The batch_status parquet could not be read (presumably it does not exist yet):
            # create an empty one with the expected schema.
            batch_status_df = spark.createDataFrame(spark.sparkContext.emptyRDD(), batch_status_schema)
            batch_status_df.write.parquet(f'{output_path}batch_status')

        update_status_table(batch_num, description)

        # Select rows into new dataframes, per file type (matching is case sensitive).
        doc_file_types = ["txt", "TXT", "docx", "DOCX", "doc", "DOC", "pdf", "PDF", "pptx", "PPTX", "ppt", "PPT", "html", "HTML", "htm", "HTM", "json", "JSON"]
        docs_df = df.where(df.file_type.isin(doc_file_types))
        docs_df_name = f'{batch_num}_documents'

        img_file_types = ["jpg", "JPG", "jpeg", "JPEG", "png", "PNG", "gif", "GIF", "bmp", "BMP", "tif", "TIF"]
        img_df = df.where(df.file_type.isin(img_file_types))
        img_df_name = f'{batch_num}_images'

        media_file_types = ["avi", "AVI", "mp4", "MP4", "mp3", "MP3", "mpg", "MPG", "wmv", "WMV", "wav", "WAV", "mov", "MOV"]
        media_df = df.where(df.file_type.isin(media_file_types))
        media_df_name = f'{batch_num}_media'

        # Get count of files in manifest (pre-processed).
        batch_file_count = batch_df.count()
        media_file_count = media_df.count()
        image_file_count = img_df.count()
        text_file_count = docs_df.count()

        # Persist new dataframes as parquet tables.
        batch_df.write.mode("overwrite").parquet(f'{output_path}{batch_df_name}')
        docs_df.write.mode("overwrite").parquet(f'{output_path}{docs_df_name}')
        img_df.write.mode("overwrite").parquet(f'{output_path}{img_df_name}')
        media_df.write.mode("overwrite").parquet(f'{output_path}{media_df_name}')

        def _external_table_sql(table_name, columns):
            """Return DDL creating external table ``table_name`` over minted_tables/<table_name>/ unless it already exists."""
            return (
                f"IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{table_name}') "
                f"CREATE EXTERNAL TABLE [{table_name}] ({columns}) "
                f"WITH (LOCATION = 'minted_tables/{table_name}/**', "
                "DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], "
                "FILE_FORMAT = [SynapseParquetFormat])"
            )

        # Column lists mirror the parquet schemas written above. Table names are safe to interpolate
        # because batch_num was sanitised to [a-z0-9_].
        file_columns = "[file_name] nvarchar(4000), [file_type] nvarchar(4000), [file_path] nvarchar(4000)"
        batch_columns = file_columns + ", [batch_num] nvarchar(4000), [batch_desc] nvarchar(4000)"

        # Create remote SQL tables over the parquet files.
        batch_df_sql_command = _external_table_sql(batch_df_name, batch_columns)
        docs_df_sql_command = _external_table_sql(docs_df_name, file_columns)
        img_df_sql_command = _external_table_sql(img_df_name, file_columns)
        media_df_sql_command = _external_table_sql(media_df_name, file_columns)

        with pyodbc.connect('DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd) as conn:
            with conn.cursor() as cursor:
                cursor.execute(batch_df_sql_command)
                cursor.execute(docs_df_sql_command)
                cursor.execute(img_df_sql_command)
                cursor.execute(media_df_sql_command)

        # Return json block with names of tables.
        output = {'custom_dimensions': {
            'batch_tbl_name': batch_df_name,
            'documents_tbl_name': docs_df_name,
            'images_tbl_name': img_df_name,
            'media_tbl_name': media_df_name,
            'batch_num': batch_num,
            'batch_description': description,
            'batch_root': batch_root,
            'manifest_file_name': manifest_file_name,
            'file_system': batch_path,
            'notebook_name': mssparkutils.runtime.context['notebookname'],
            'batch_file_count': batch_file_count,
            'media_file_count': media_file_count,
            'image_file_count': image_file_count,
            'text_file_count': text_file_count,
            'blob_account_name': blob_account_name,
            'minted_tables_output_path': output_path,
            'instrumentation_connection_string': instrumentation_connection_string
        } }

except Exception as e:
    logger.exception(e)
    raise

# Return the object to the pipeline.
# The exit value still carries every field (including the App Insights connection string) for the
# pipeline, but the secret read from Key Vault is left out of the telemetry record.
logged_output = {'custom_dimensions': {
    key: value
    for key, value in output['custom_dimensions'].items()
    if key != 'instrumentation_connection_string'
}}
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=logged_output)
mssparkutils.notebook.exit(output['custom_dimensions'])