In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8,
     "spark.sql.autoBroadcastJoinThreshold": -1
   }
}

In [ ]:
# Notebook parameters. All default to empty strings here; the real values are
# expected to come from the calling pipeline (NOTE(review): presumably this is
# the Synapse "parameters" cell — confirm in the notebook metadata).
images_tbl_name = ""  # images parquet table name, read from f'{minted_tables_output_path}{images_tbl_name}'
batch_root = ""  # batch folder: config.json and the image files are read from here
batch_num = ""  # batch id: matched against batch_status.batch_id and used as the output table prefix
file_system = ""  # value assigned to Hadoop fs.defaultFS before reading config.json
azure_storage_domain = ""  # not referenced in the visible cells
blob_account_name = ""  # not referenced in the visible cells
minted_tables_output_path = ""  # path prefix for the input images table and the output image-contents table

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer
import datetime

# Enable the OpenCensus "logging" trace integration.
config_integration.trace_integrations(['logging'])

# Application Insights connection string, fetched from the "keyvault" linked service.
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS(
    "keyvault", "AppInsightsConnectionString"
)

# Notebook logger: exports records to Application Insights and starts at INFO.
logger = logging.getLogger(__name__)
logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer that keeps every span (AlwaysOnSampler) and exports to the same resource.
tracer = Tracer(exporter=AzureExporter(connection_string=instrumentation_connection_string), sampler=AlwaysOnSampler())

# Run-time parameters, attached to log records under the 'custom_dimensions' key.
run_time_parameters = {
    'custom_dimensions': dict(
        batch_root=batch_root,
        batch_num=batch_num,
        file_system=file_system,
        images_tbl_name=images_tbl_name,
        notebook_name=mssparkutils.runtime.context['notebookname'],
    )
}

logger.info(f"{mssparkutils.runtime.context['notebookname']}: INITIALISED", extra=run_time_parameters)

In [ ]:
import json
import random
import io
from types import SimpleNamespace
from typing import List
from PIL import Image, ImageFile
from math import trunc

import pyspark.sql.functions as F
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType, StructType, StructField, BinaryType
from pyspark import SparkContext
from pyspark.sql import SparkSession

def read_batch_config(batch_root: str):
    """
    Load ``<batch_root>/config.json`` and return it as nested ``SimpleNamespace`` objects.

    The file is read on the driver through the Hadoop FileSystem Java API (via py4j)
    rather than through Spark, so the JSON is not split into lines across executors
    and joined back together.

    Side effect: sets ``fs.defaultFS`` on the Spark context's Hadoop configuration to
    the notebook-level ``file_system`` parameter.

    Args:
        batch_root: Folder containing the batch's ``config.json``.

    Returns:
        The parsed JSON, with every JSON object converted to a ``SimpleNamespace``.

    Raises:
        FileNotFoundError: if ``config.json`` does not exist under ``batch_root``.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    jvm = sc._jvm
    hadoop_conf = sc._jsc.hadoopConfiguration()
    # Resolve paths against the pipeline-supplied file system.
    hadoop_conf.set("fs.defaultFS", file_system)

    fs = jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
    config_path = jvm.org.apache.hadoop.fs.Path(f'{batch_root}/config.json')

    # Fail fast with a clear Python error instead of an opaque Java exception from fs.open.
    if not fs.exists(config_path):
        message = f'{config_path} not found.'
        logger.error(message)
        raise FileNotFoundError(message)

    input_stream = fs.open(config_path)  # FSDataInputStream
    try:
        reader = jvm.java.io.BufferedReader(
            jvm.java.io.InputStreamReader(input_stream, jvm.java.nio.charset.StandardCharsets.UTF_8)
        )
        config_string = reader.lines().collect(jvm.java.util.stream.Collectors.joining("\n"))
    finally:
        # Release the underlying file handle even if reading fails.
        input_stream.close()

    return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

with tracer.span(name='Initialise Spark session'):
    # Obtain the session first (getOrCreate reuses a running one), then take its
    # SparkContext, so `sc` always belongs to the session the notebook uses and this
    # cell does not depend on a pre-existing `spark` variable.
    spark = SparkSession.builder.appName(f"ImageProcessing {mssparkutils.runtime.context}").getOrCreate()
    sc = spark.sparkContext

with tracer.span(name=f"Load config: {mssparkutils.runtime.context['notebookname']}"):
    try:
        config = read_batch_config(batch_root)
    except Exception as e:
        logger.exception(e)
        # Bare `raise` re-raises the active exception with its traceback unchanged.
        raise

    # Set log level: only an explicit "INFO" keeps INFO logging. Any other value,
    # including a config with no log_level entry, falls back to ERROR and is
    # normalised on the config object.
    if getattr(config, "log_level", None) == "INFO":
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)
        config.log_level = "ERROR"

In [ ]:
import pyodbc
from pyspark.sql.functions import current_timestamp
# --- SQL connection settings (dedicated and serverless pools) ---------------
driver = '{ODBC Driver 17 for SQL Server}'
dedicated_database = 'dedicated'  # database on the dedicated SQL pool
database = 'minted'  # database on the serverless SQL pool

# Credentials and endpoints, read in this order from the "keyvault" linked service.
sql_user_name, sql_user_pwd, serverless_sql_endpoint, dedicated_sql_endpoint = (
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
    for secret_name in (
        "SynapseSQLUserName",
        "SynapseSQLPassword",
        "SynapseServerlessSQLEndpoint",
        "SynapseDedicatedSQLEndpoint",
    )
)

In [ ]:
from time import sleep
# Update Status Table
def odbc_connection_string(driver, server, database, user, password):
    """
    Build a SQL Server ODBC connection string for ``tcp:<server>`` on port 1433.

    ``UID`` and ``PWD`` are wrapped in braces, with any ``}`` doubled (ODBC's
    escaping rule). This stops credentials containing ``;``, ``=`` or braces from
    breaking the connection string.
    """
    def braced(value):
        return '{' + str(value).replace('}', '}}') + '}'

    return (
        f'DRIVER={driver};SERVER=tcp:{server};PORT=1433;DATABASE={database};'
        f'UID={braced(user)};PWD={braced(password)}'
    )


def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """
    Return ``(num_stages_complete, description)`` for the most advanced status row of a batch.

    Queries ``[dbo].[batch_status]`` on the dedicated SQL pool for ``batch_id = batch_num``
    (passed as a query parameter) and takes the row with the highest ``num_stages_complete``.

    Raises:
        LookupError: if no status row exists for ``batch_num``.
        pyodbc errors: on connection or query failure.
    """
    from contextlib import closing

    query = """
        SELECT TOP (1)
        [num_stages_complete], [description]
        FROM [dbo].[batch_status]
        WHERE [batch_id] = ?
        ORDER BY [num_stages_complete] DESC;
    """
    conn_str = odbc_connection_string(driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
    # pyodbc's connection context manager commits but does not close; closing() releases it.
    with closing(pyodbc.connect(conn_str, autocommit=True)) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(query, batch_num)
        row = cursor.fetchone()
    if row is None:
        raise LookupError(f'No batch_status row found for batch_id {batch_num!r}')
    num_stages_complete, description = row
    return num_stages_complete, description


def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd,
                        max_attempts=10, retry_delay=3):
    """
    Advance the batch's stage counter by one and record ``status_text`` in ``batch_status``.

    Reads the current ``num_stages_complete`` via ``get_recent_status`` and increments it.
    It then updates the ``batch_status`` rows for ``batch_num`` with:

    * the status ``"[<n>/10] <status_text>"``;
    * the local timestamp;
    * the new stage count.

    The dedicated database name comes from the notebook-level ``dedicated_database``.

    Args:
        status_text: Human-readable stage description.
        minted_tables_path: Unused; kept for call compatibility.
        batch_num: Batch id to update.
        driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd: Connection details.
        max_attempts: Total attempts before giving up (at least one is always made).
        retry_delay: Seconds to wait between failed attempts.

    Raises:
        The exception from the last failed attempt once all attempts are exhausted.
    """
    from contextlib import closing

    attempts = max(1, max_attempts)
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            stages_complete, _ = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            stages_complete += 1
            status = f'[{stages_complete}/10] {status_text}'
            time_stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            sql_command = "UPDATE batch_status SET status = ?,  update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"
            conn_str = odbc_connection_string(driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            # autocommit=True commits the UPDATE immediately; closing() releases the connection.
            with closing(pyodbc.connect(conn_str, autocommit=True)) as conn, closing(conn.cursor()) as cursor:
                cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
            return
        except Exception as e:
            last_error = e
            logger.warning(f'Failed to update status table: {e}, retrying . . .')
            # No point sleeping after the final attempt.
            if attempt < attempts:
                sleep(retry_delay)

    raise last_error

# Record in batch_status that image preparation has started.
update_status_table(
    status_text='Image Prep Started',
    minted_tables_path=minted_tables_output_path,
    batch_num=batch_num,
    driver=driver,
    dedicated_sql_endpoint=dedicated_sql_endpoint,
    sql_user_name=sql_user_name,
    sql_user_pwd=sql_user_pwd,
)

In [ ]:
with tracer.span(name='Get image contents'):
    # Load image contents into a table used by downstream notebooks.
    images_df = spark.read.parquet(f'{minted_tables_output_path}{images_tbl_name}')

    # Image extensions to pick up. Glob matching is case-sensitive, so both
    # lower- and upper-case spellings are listed.
    image_extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG', 'gif', 'GIF', 'bmp', 'BMP', 'tif', 'TIF']

    # One recursive listing with a "*.{jpg,JPG,...}" alternation glob, instead of one
    # full scan of batch_root per extension followed by a chain of unions.
    images_content_df = (
        spark.read.format("binaryFile")
        .option("recursiveFileLookup", "true")
        .option("pathGlobFilter", f"*.{{{','.join(image_extensions)}}}")
        .load(batch_root)
    )
    # Keep only files listed in the images table; its file_path duplicates `path`, so drop it.
    images_content_df = images_content_df.join(images_df, images_df.file_path == images_content_df.path, 'inner').drop('file_path')

## De-duplicate Images
For the MINTED 2.0 Accelerator, it is assumed that the image data provided has already been de-duplicated.

In [ ]:
with tracer.span(name='De-duplicate images'):
    # Insert image de-duplication logic here, based on the types of image content being consumed.
    # Intentionally a no-op for now: the input images are assumed to be de-duplicated already.
    pass

## Resize Images
Azure Computer Vision (Cognitive Services) requires images to be at least 50x50 pixels. Calls to Computer Vision fail for smaller images, and that failure fails the pipeline, so images below this size are upscaled here.

In [ ]:
def pil_save_format(image_type):
    """
    Map a file-type string (e.g. ``"jpg"``, ``"TIF"``, ``".png"``) to a PIL save-format name.

    PIL registers JPEG and TIFF writers as ``"JPEG"`` and ``"TIFF"``; ``"JPG"`` and ``"TIF"``
    are not valid format names. So ``jpg``/``jpeg`` and ``tif``/``tiff`` (any case) are
    mapped explicitly, and any other type is upper-cased.
    """
    normalized = image_type.lower().lstrip('.')
    aliases = {'jpg': 'JPEG', 'jpeg': 'JPEG', 'tif': 'TIFF', 'tiff': 'TIFF'}
    return aliases.get(normalized, normalized.upper())


def upscaled_size(width, height, min_size=50):
    """
    Return the ``(width, height)`` needed for both sides to reach ``min_size`` pixels.

    The aspect ratio is preserved. Returns ``None`` when the image is already large enough.

    The scale factor is the largest of 1, ``min_size / width`` and ``min_size / height``.
    Scaled sides are truncated to ints and clamped to ``min_size``, so floating-point
    rounding (e.g. 49.99999...) cannot leave a side one pixel short.

    Raises:
        ZeroDivisionError: if ``width`` or ``height`` is 0.
    """
    scale = max(1, min_size / width, min_size / height)
    if scale <= 1:
        return None
    return max(min_size, trunc(width * scale)), max(min_size, trunc(height * scale))


def resize(image_data, image_type, min_size=50):
    """
    Upscale an encoded image so that both sides are at least ``min_size`` pixels.

    Used as a Spark UDF (returning ``BinaryType``) over the ``content`` and ``file_type``
    columns. The image is decoded with PIL (truncated files allowed) and converted to RGB.
    When either side is below ``min_size`` it is resized proportionally and re-encoded in
    the format named by ``image_type``.

    Args:
        image_data: Encoded image bytes.
        image_type: File type/extension of the image, e.g. ``"jpg"`` or ``"png"``.
        min_size: Minimum width and height in pixels.

    Returns:
        The upscaled, re-encoded bytes. ``image_data`` is returned unchanged when no
        upscaling is needed, or when decoding, resizing or encoding fails.
    """
    try:
        save_format = pil_save_format(image_type)
        # Let PIL decode images whose data ends early instead of raising.
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        width, height = image.size
        target_size = upscaled_size(width, height, min_size)
        if target_size is None:
            return image_data
        print(f'resizing {width} x {height} -> {target_size[0]} x {target_size[1]}')
        buffer = io.BytesIO()
        image.resize(target_size).save(buffer, save_format)
        return buffer.getvalue()
    except Exception:
        # Best effort: keep the original bytes rather than failing the whole Spark job.
        # NOTE(review): this runs inside a UDF on executors, where the AzureLogHandler
        # attached on the driver may not be present — confirm these logs are captured.
        logger.exception('Failed to resize image.')
        return image_data

In [ ]:
with tracer.span(name='Resize images'):
    # Wrap `resize` as a binary-returning UDF and rewrite `content` in place, upscaling
    # images below the minimum size and passing larger ones through unchanged.
    udf_resize = udf(resize, BinaryType())
    images_content_df = images_content_df.withColumn(
        "content",
        udf_resize(col("content"), col("file_type")),
    )


with tracer.span(name='Persist resized images to table'):
    from contextlib import closing

    image_contents_tbl_name = f"{batch_num}_image_contents"
    images_content_df.write.mode("overwrite").parquet(f'{minted_tables_output_path}{image_contents_tbl_name}')

    # DDL identifiers cannot be bound as query parameters, so the table name is escaped
    # for each place it is spliced into the T-SQL:
    #   - inside [brackets], ']' is doubled;
    #   - inside 'quotes', "'" is doubled.
    tbl_identifier = image_contents_tbl_name.replace(']', ']]')
    tbl_literal = image_contents_tbl_name.replace("'", "''")
    # Register a serverless external table over the parquet output unless it already exists.
    # NOTE(review): LOCATION hard-codes 'minted_tables/' relative to the data source, which
    # presumably matches minted_tables_output_path — confirm. <<STORAGE_ACCOUNT_NAME>> looks
    # like a deployment-time placeholder.
    sql_command = f'''
        IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{tbl_literal}') 
        CREATE EXTERNAL TABLE [{tbl_identifier}] (
            [path] nvarchar(1000), 
            [modificationTime] datetime2(7), 
            [length] bigint, 
            [content] varbinary(max), 
            [file_name] nvarchar(1000), 
            [file_type] nvarchar(1000)
        )
        WITH (
            LOCATION = 'minted_tables/{tbl_literal}/**', 
            DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], 
            FILE_FORMAT = [SynapseParquetFormat]
        )
    '''
    conn_str = 'DRIVER='+driver+';SERVER=tcp:'+serverless_sql_endpoint+';PORT=1433;DATABASE='+database+';UID='+sql_user_name+';PWD='+ sql_user_pwd
    # pyodbc's connection context manager only commits; commit explicitly, then close via closing().
    with closing(pyodbc.connect(conn_str)) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(sql_command)
        conn.commit()


In [ ]:
# Values handed back to the calling pipeline / downstream notebooks.
output = {
    'custom_dimensions': dict(
        batch_num=batch_num,
        image_contents_tbl_name=image_contents_tbl_name,
        notebook_name=mssparkutils.runtime.context['notebookname'],
    )
}

# Log the output, then exit the notebook with the custom dimensions as its return value.
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])