In [ ]:
%%configure -f
{
"conf": {
     "spark.dynamicAllocation.disableIfMinMaxNotSpecified.enabled": true,
     "spark.dynamicAllocation.enabled": true,
     "spark.dynamicAllocation.minExecutors": 2,
     "spark.dynamicAllocation.maxExecutors": 8
   }
}

In [ ]:
# Notebook parameters. Presumably these defaults are overridden by the values the
# calling pipeline passes in — confirm against the pipeline definition.
# NOTE(review): later cells also reference output_container and blob_account_name,
# which are not declared here; presumably the pipeline injects them — confirm, and
# consider declaring them here so all notebook inputs are visible in one place.
batch_num = ""
batch_root = ""                  # directory holding the batch's config.json
file_system = ""                 # set as Hadoop fs.defaultFS before reading the batch config
media_contents_tbl_name = ""
batch_description = ""
batch_file_count = 0             # sent in each Video Indexer submission's description
azure_storage_domain = ''        # used to build the dfs./blob. storage host names
minted_tables_output_path = ""   # prefix joined directly with table names (no separator added)

In [ ]:
from azure.identity import ClientSecretCredential
from azure.eventgrid import EventGridPublisherClient, EventGridEvent
from azure.mgmt.keyvault import KeyVaultManagementClient
from types import SimpleNamespace
from datetime import datetime
import json
import os
import requests as req
import time
import hashlib
import random
import sys
import urllib.parse
from pyspark.sql import SparkSession
from pyspark import SparkContext
from pyspark.sql.functions import col, to_json, from_json, lit, explode, concat, udf, json_tuple, current_timestamp
from pyspark.sql.types import StringType, MapType, BooleanType, StructType, StructField, StringType
from requests.structures import CaseInsensitiveDict
from azure.identity import ClientSecretCredential

In [ ]:
# Load keys, set defaults
# Every secret is read from the "keyvault" linked service exactly once.
instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")
subscription_id = mssparkutils.credentials.getSecretWithLS("keyvault", "SubscriptionId")
resource_group_name = mssparkutils.credentials.getSecretWithLS("keyvault", "ResourceGroupName")
tenant_id = mssparkutils.credentials.getSecretWithLS("keyvault", "TenantID")
client_id = mssparkutils.credentials.getSecretWithLS("keyvault", "ADAppRegClientId")
client_secret = mssparkutils.credentials.getSecretWithLS("keyvault", "ADAppRegClientSecret")
storage_account_name = mssparkutils.credentials.getSecretWithLS("keyvault", "StorageAccountName")
storage_account_key = mssparkutils.credentials.getSecretWithLS("keyvault", "StorageAccountKey")
vi_account_name = mssparkutils.credentials.getSecretWithLS("keyvault", "VideoIndexerAccountName")
apiUrl = "<<TF_VAR_azure_avam_api_domain>>"  # APIs are documented here... https://api-portal.videoindexer.ai/

# Root folder for this batch's enrichment output; the Video Indexer callback is told
# to write each file's JSON under it.
# NOTE(review): output_container and blob_account_name are not declared in the
# parameters cell; presumably they are injected by the pipeline — confirm. Note that
# media URLs elsewhere use storage_account_name instead of blob_account_name.
enrichment_output_path_root = f'abfss://{output_container}@{blob_account_name}.dfs.{azure_storage_domain}/{batch_num}'

azure_resource_manager = "<<TF_VAR_azure_arm_management_api>>"
credential = ClientSecretCredential(tenant_id, client_id, client_secret)

# Batch-level query parameters appended to the Video Indexer callback URL; per-file
# parameters are appended later by generate_vi_callback_url.
callback_params = [("batch_num", batch_num), ("file_system", file_system),
    ("batch_root", batch_root), ("output_container", output_container),
    ("batch_description", batch_description)]
callbackurl = "".join(["<<VI_CALLBACK_URL>>&",
    urllib.parse.urlencode(callback_params)])

In [ ]:
# Initiate logging
import logging
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

config_integration.trace_integrations(['logging'])

# instrumentation_connection_string was already read from Key Vault in the
# "Load keys" cell; it is reused here instead of being fetched a second time.
notebook_name = mssparkutils.runtime.context['notebookname']

logger = logging.getLogger(__name__)
# getLogger returns the same logger object every time this cell runs, so only attach
# the App Insights handler once; otherwise re-running the cell duplicates every record.
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# Spool parameters: attached as custom dimensions to the log record below.
run_time_parameters = {'custom_dimensions': {
    'batch_num': batch_num,
    'file_system': file_system,
    'media_contents_tbl_name': media_contents_tbl_name,
    'notebook_name': notebook_name
} }

logger.info(f"{notebook_name}: INITIALISED", extra=run_time_parameters)

In [ ]:
# Initialise session and config
# sc is taken from the pre-existing `spark` session before `spark` is rebound below;
# read_batch_config uses sc for Hadoop/JVM access.
sc = spark.sparkContext
# NOTE(review): getOrCreate() returns an already-active session if there is one, in
# which case the appName likely has no effect. The name also embeds the whole runtime
# context dict (not just the notebook name) and says "TextProcessing" — confirm intended.
spark = SparkSession.builder.appName(f"TextProcessing {mssparkutils.runtime.context}").getOrCreate()

def read_batch_config(batch_root: str):
    """
    Load ``{batch_root}/config.json`` and return it as nested ``SimpleNamespace`` objects.

    The file is read on the driver through the Hadoop FileSystem Java API rather than
    through Spark, since there is no need to have multiple nodes read individual lines
    and join them all back together again.

    Side effect: sets ``fs.defaultFS`` on the shared SparkContext's Hadoop configuration
    to the notebook-level ``file_system`` value, so ``batch_root`` resolves against it.

    Raises:
        FileNotFoundError: if the config file does not exist (the error is logged first).
    """
    jvm = sc._jvm
    hadoop_conf = sc._jsc.hadoopConfiguration()

    # Change our file system from 'synapse' to 'input'
    hadoop_conf.set("fs.defaultFS", file_system)

    fs = jvm.org.apache.hadoop.fs.FileSystem.get(hadoop_conf)
    config_path = jvm.org.apache.hadoop.fs.Path(f'{batch_root}/config.json')

    # Without a batch config there is nothing to load: fail here with a clear Python
    # error instead of letting fs.open raise an opaque Java exception.
    if not fs.exists(config_path):
        logger.error(f'{config_path} not found.')
        raise FileNotFoundError(f'{config_path} not found.')

    # Open the file directly rather than through Spark. Closing the reader also closes
    # the wrapped FSDataInputStream, so nothing is leaked on the JVM side.
    reader = jvm.java.io.BufferedReader(
        jvm.java.io.InputStreamReader(fs.open(config_path), jvm.java.nio.charset.StandardCharsets.UTF_8)
    )
    try:
        config_string = reader.lines().collect(jvm.java.util.stream.Collectors.joining("\n"))
    finally:
        reader.close()

    # Load it into json, turning every JSON object into an attribute-accessible namespace.
    return json.loads(config_string, object_hook=lambda dictionary: SimpleNamespace(**dictionary))

with tracer.span(name=f"Load config: {mssparkutils.runtime.context['notebookname']}"):
    try:
        config = read_batch_config(batch_root)
    except Exception as e:
        logger.exception(e)
        raise  # bare raise keeps the original traceback intact

    # Set log level: only "INFO" enables informational logging. Any other value,
    # including a config without a log_level key, falls back to ERROR, and the
    # config is normalised to say so.
    if getattr(config, "log_level", None) == "INFO":
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)
        config.log_level = "ERROR"

In [ ]:
import pyodbc
# Dedicated and serverless SQL config: database names and the ODBC driver to use.
dedicated_database, database = 'dedicated', 'minted'
driver = '{ODBC Driver 17 for SQL Server}'

# Secrets: SQL login, then the serverless and dedicated pool endpoints, fetched in
# that order from the "keyvault" linked service.
sql_user_name, sql_user_pwd, serverless_sql_endpoint, dedicated_sql_endpoint = [
    mssparkutils.credentials.getSecretWithLS("keyvault", secret_name)
    for secret_name in (
        "SynapseSQLUserName",
        "SynapseSQLPassword",
        "SynapseServerlessSQLEndpoint",
        "SynapseDedicatedSQLEndpoint",
    )
]

In [ ]:
from time import sleep
from datetime import datetime, timedelta
# Update Status Table
def get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd):
    """
    Return ``(num_stages_complete, description)`` for the most advanced status row of a batch.

    Queries ``[dbo].[batch_status]`` on the dedicated SQL pool, ordering by
    ``num_stages_complete`` descending and taking the first row.

    Raises:
        LookupError: if no ``batch_status`` row exists for ``batch_num``.
    """
    query = """
        SELECT TOP (1) 
        [num_stages_complete], [description]
        FROM [dbo].[batch_status] 
        WHERE [batch_id] = ?
        ORDER BY [num_stages_complete] DESC;
    """
    connection_string = f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'
    with pyodbc.connect(connection_string, autocommit=True) as conn:
        with conn.cursor() as cursor:
            # batch_num is bound as a parameter, never interpolated into the SQL text.
            cursor.execute(query, batch_num)
            row = cursor.fetchone()

    # fetchone() returns None when nothing matches; report that explicitly instead of
    # failing with "cannot unpack non-iterable NoneType".
    if row is None:
        raise LookupError(f"No batch_status row found for batch_id {batch_num!r}")
    num_stages_complete, description = row
    return num_stages_complete, description

def update_status_table(status_text, minted_tables_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd, max_retries=10, retry_delay_seconds=3):
    """
    Advance a batch's stage counter by one and record ``status_text`` in ``batch_status``.

    Reads the current stage count with ``get_recent_status``, then updates the row with
    ``status = '[<stages>/10] <status_text>'``, the current local timestamp and the new
    stage count. The database name comes from the notebook-level ``dedicated_database``.

    Args:
        status_text: human-readable status appended after the ``[n/10]`` progress prefix.
        minted_tables_path: unused; kept for interface compatibility.
        batch_num: batch id whose row is updated.
        driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd: connection settings.
        max_retries: total number of attempts before giving up (default 10).
        retry_delay_seconds: pause between failed attempts (default 3).

    Raises:
        ValueError: if ``max_retries`` is less than 1.
        Exception: the last error seen, once every attempt has failed.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    sql_command = "UPDATE batch_status SET status = ?, update_time_stamp = ?, num_stages_complete = ? WHERE batch_id = ?"
    connection_string = f'DRIVER={driver};SERVER=tcp:{dedicated_sql_endpoint};PORT=1433;DATABASE={dedicated_database};UID={sql_user_name};PWD={sql_user_pwd}'

    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            stages_complete, _ = get_recent_status(batch_num, driver, dedicated_sql_endpoint, dedicated_database, sql_user_name, sql_user_pwd)
            stages_complete += 1
            # The pipeline is tracked as 10 stages in total.
            status = f'[{stages_complete}/10] {status_text}'
            time_stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            with pyodbc.connect(connection_string, autocommit=True) as conn:
                with conn.cursor() as cursor:
                    cursor.execute(sql_command, status, time_stamp, stages_complete, batch_num)
                    cursor.commit()
            return
        except Exception as e:
            last_exc = e
            logger.warning(f'Failed to update status table: {e}, retrying . . .')
            # No point waiting after the final attempt; raise straight away instead.
            if attempt < max_retries:
                sleep(retry_delay_seconds)

    raise last_exc

In [ ]:
# Load the media-contents table (parquet) produced for this batch. The table path is
# the output prefix followed directly by the table name.
with tracer.span(name='Load media contents table'):
    df_media_contents = spark.read.parquet(
        minted_tables_output_path + media_contents_tbl_name
    )

# Retrieve access tokens and submit media file for processing

In [ ]:
from datetime import datetime, timedelta
api_version = config.video_indexer_api_version

# Seed the refresh deadline 55 minutes in the past so that the very first staleness
# check treats the tokens as expired and refreshes them.
token_refresh_time = datetime.now() + timedelta(minutes=-55)

#Get ARM bearer token used to query AVAM for account access token.
def get_arm_token():
    """
    Return an Azure Resource Manager bearer token string.

    The token is obtained from the notebook-level ``credential``
    (a ``ClientSecretCredential``) for the ARM ``/.default`` scope. It is later used
    to ask the Video Indexer resource for an account access token.
    """
    logger.info('Retrieving ARM token')
    access_token = credential.get_token("<<TF_VAR_azure_arm_management_api>>/.default")
    logger.info(f"ARM token retreived at {datetime.now()}")
    return access_token.token

# Get account level access token for Azure Video Analyzer for Media
def get_account_access_token(arm_token):
    """
    Return an account-level Contributor access token for Azure Video Analyzer for Media.

    Calls the ARM ``generateAccessToken`` action on the Video Indexer account, using
    ``arm_token`` as the bearer token.

    Raises:
        requests.HTTPError: if ARM answers with an error status, instead of a confusing
            ``KeyError: 'accessToken'`` when the error body is parsed.
    """
    logger.info('Retrieving AVAM access token')
    request_url = f'{azure_resource_manager}/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.VideoIndexer/accounts/{vi_account_name}/generateAccessToken?api-version={api_version}'
    headers = CaseInsensitiveDict()
    headers["Accept"] = "application/json"
    headers["Authorization"] = "Bearer " + arm_token
    # Same payload as before, written directly as a dict (JSON null -> None).
    body = {"permissionType": "Contributor", "scope": "Account", "projectId": None, "videoId": None}
    response = req.post(request_url, headers=headers, json=body)
    response.raise_for_status()
    access_token = response.json()["accessToken"]
    logger.info(f"AVAM access token retreived at {datetime.now()}")
    return access_token

# Refresh access tokens
def refresh_access_tokens():
    """
    Fetch a fresh ARM token and, with it, a fresh AVAM account access token.

    Returns:
        dict with keys ``"arm"`` and ``"avam"``. Both tokens are used elsewhere, and the
        ARM token is also needed to obtain the AVAM access token.

    Side effect: only once both tokens have been obtained, sets the global
    ``token_refresh_time`` to 55 minutes after the refresh started. This presumably
    leaves a margin before a 1-hour token lifetime — confirm.
    """
    global token_refresh_time

    logger.info('Refreshing ARM and AVAM tokens.')
    refresh_started = datetime.now()
    arm_token = get_arm_token()
    new_tokens = {
        "arm": arm_token,
        "avam": get_account_access_token(arm_token),
    }
    # Advance the deadline only after success. If either request raises, the deadline
    # stays in the past, so the next caller retries rather than reusing stale tokens
    # for another 55 minutes.
    token_refresh_time = refresh_started + timedelta(minutes=55)
    return new_tokens

#initial get of each token
tokens = refresh_access_tokens()

In [ ]:
# Get Account information: the Video Indexer account id and location, both needed to
# build the upload endpoint URL.
# these top level API's are documented and testable here... https://docs.microsoft.com/en-us/rest/api/videoindexer/accounts/list
request_url = f'{azure_resource_manager}/subscriptions/{subscription_id}/resourcegroups/{resource_group_name}/providers/Microsoft.VideoIndexer/accounts/{vi_account_name}/?api-version={api_version}'
headers = CaseInsensitiveDict()
headers["Accept"] = "application/json"
headers["Authorization"] = "Bearer " + tokens["arm"]
response = req.get(request_url, headers=headers)
# Surface the HTTP error (e.g. 401/404) instead of an opaque KeyError below.
response.raise_for_status()
response = response.json()
vi_account_id = response['properties']['accountId']
vi_account_location = response['location']

In [ ]:
# user defined function to generate media file url with SAS
from azure.storage.blob import BlobClient, generate_blob_sas, BlobSasPermissions

def get_blob_sas(account_name, account_key, container_name, blob_name):
    """
    Return a read-only SAS token string for a single blob.

    The token is signed with the storage account key and expires one hour after it is
    generated, measured from ``datetime.utcnow()``.
    """
    expires_at = datetime.utcnow() + timedelta(hours=1)
    read_only = BlobSasPermissions(read=True)
    return generate_blob_sas(
        account_name=account_name,
        container_name=container_name,
        blob_name=blob_name,
        account_key=account_key,
        permission=read_only,
        expiry=expires_at,
    )

In [ ]:
import time

## Configuration Variables #################################################################################################################################
# Extra seconds added on top of the service's Retry-After value before re-posting, as a safety margin.
extra_retry_buffer = 10
# Maximum number of throttled re-posts before giving up on a file and moving on.
max_retry_loops = 5
############################################################################################################################################################

# helper function to generate endcoded callback URL for Video Indexer service
def generate_vi_callback_url(json_path, media_path, file_name, callbackurl):
    """
    Append the per-file parameters to the Video Indexer callback URL.

    ``json_path`` (where the enrichment JSON output should go), ``media_path`` and
    ``file_name`` are form-encoded with ``urllib.parse.urlencode``. The result is joined
    to ``callbackurl`` with ``'&'``, so ``callbackurl`` is expected to already carry a
    query string.
    """
    per_file_query = urllib.parse.urlencode(
        (("json_path", json_path), ("media_path", media_path), ("media_file_name", file_name))
    )
    return "&".join((callbackurl, per_file_query))

# helper function to submit a media file url with SAS to VI
def _split_abfss_media_path(media_path):
    """Split ``abfss://<container>@<account>.dfs.<azure_storage_domain>/<path>`` into (container, blob path)."""
    # Container: the text between 'abfss://' and the first '@'.
    scheme = 'abfss://'
    container_start = media_path.find(scheme) + len(scheme)
    container_end = media_path.find('@')
    container = media_path[container_start:container_end]

    # Blob path: everything after 'dfs.<domain>/'.
    host_suffix = f'dfs.{azure_storage_domain}/'
    blob_start = media_path.find(host_suffix) + len(host_suffix)
    return container, media_path[blob_start:]


def _retry_throttled_post(response, request_url, headers, params):
    """
    Re-post to AVAM while it answers HTTP 429 and return the last response received.

    Each retry waits for the Retry-After value of the most recent 429 response plus
    ``extra_retry_buffer`` seconds. At most ``max_retry_loops`` retries are made. See
    https://docs.microsoft.com/en-us/azure/azure-video-analyzer/video-analyzer-for-media-docs/considerations-when-use-at-scale#respect-throttling
    """
    if response.status_code != 429:
        return response

    logger.info('AVAM response code == 429, entering retry loop')
    for attempt in range(1, max_retry_loops + 1):
        # Honour the wait requested by the latest throttled response, not only the first.
        # NOTE(review): assumes Retry-After is in seconds; a missing header counts as 0.
        retry_after = int(response.headers.get('Retry-After', 0))
        time.sleep(retry_after + extra_retry_buffer)

        logger.info(f'AVAM request post attempt {attempt}/{max_retry_loops}')
        response = req.post(request_url, headers=headers, params=params)
        if response.status_code != 429:
            return response

    logger.error('AVAM request post has reached max retry attempts.')
    return response


def post_media_for_enrichment(media_path, file_name, file_type):
    """
    Submit one media file to Video Indexer (AVAM) for enrichment.

    The file's abfss:// path is turned into an https blob URL with a one-hour read SAS.
    That URL is posted to the AVAM upload endpoint together with a callback URL that
    carries the path where the enrichment JSON should be written. Throttled (429)
    responses are retried by ``_retry_throttled_post``.

    Args:
        media_path: abfss:// path of the media file.
        file_name: used (truncated) as the AVAM video name and in the JSON output path.
        file_type: currently unused; kept for interface compatibility.

    Returns:
        The final AVAM response body re-serialised as a JSON string.
    """
    # Check to see if ARM and AVAM access tokens need to be refreshed, both expire after 1 hr.
    global tokens
    if token_refresh_time < datetime.now():
        tokens = refresh_access_tokens()

    # Build the blob link with SAS. blob_url (no query string) is the only form that is
    # logged, so the SAS token never reaches App Insights.
    media_container, blob_path = _split_abfss_media_path(media_path)
    sas_token = get_blob_sas(storage_account_name, storage_account_key, media_container, blob_path)
    blob_url = f'https://{storage_account_name}.blob.{azure_storage_domain}/{media_container}/{blob_path}'
    media_url = f'{blob_url}?{sas_token}'

    # path to dump enrichments json
    json_path = f'{enrichment_output_path_root}/media_processing_json/{file_name}.output.json'

    # post media file for processing and provide a callback url to notify when completed
    params = CaseInsensitiveDict()
    params["accessToken"] = tokens["avam"]
    params["name"] = file_name[:79]      # truncated to 79 chars to stay within the 80-char maximum
    params["description"] = json.dumps({
        "media_path": media_path,
        "file_name": file_name,
        "batch_number": batch_num,
        "batch_file_count": batch_file_count
    })
    params["privacy"] = "private"
    params["partition"] = "partition"
    params["videoUrl"] = media_url
    params["language"] = "auto"
    params["callbackUrl"] = generate_vi_callback_url(json_path, media_path, file_name, callbackurl)
    request_url = f'{apiUrl}/{vi_account_location}/Accounts/{vi_account_id}/Videos'
    # Authentication is via the accessToken query parameter. The same headers are used
    # for retries (previously retries sent the unrelated global ARM-token headers).
    request_headers = {'Accept': 'application/json'}

    logger.info(f"Posting to AVAM, Media URL: {blob_url}")
    response = req.post(request_url, headers=request_headers, params=params)
    logger.info(f"Post response code from AVAM: {response.status_code}, Media URL: {blob_url}")

    response = _retry_throttled_post(response, request_url, request_headers, params)

    # Only claim success when the final response actually succeeded.
    if response.ok:
        logger.info(f"Media: {blob_url}, successfully posted to AVAM.")
    else:
        logger.error(f"Media: {blob_url}, AVAM post failed with response code {response.status_code}.")

    return json.dumps(response.json())

    

# Spark UDF wrapper around post_media_for_enrichment that returns a string column,
# for use inside a DataFrame transformation.
# NOTE: not using currently as workaround for video indexer issue with oversaturation
udf_post_media_for_enrichment = udf(post_media_for_enrichment, returnType=StringType())

In [ ]:
############################################
##### Submit via withColumn()
############################################
# Disabled alternative: submit every file through udf_post_media_for_enrichment inside
# a DataFrame transformation. The snippet is kept as a bare string literal, so it is
# never executed. The plain loop in the next cell is used instead: the UDF is noted as
# unused because of Video Indexer oversaturation, and the loop cell notes that the
# Azure credential object is non-picklable.
'''with tracer.span(name='Starting cell withColumn thru each video and call the VI endpoint...'):
    # Submit all media files using the UDF and add fresponse to df
    df_media_contents_submitted = df_media_contents \
        .withColumn("vi_submission_detail", udf_post_media_for_enrichment(
            df_media_contents.path, df_media_contents.file_name, df_media_contents.file_type)) '''

In [ ]:
############################################
##### Submit via Looper
############################################

with tracer.span(name='Starting cell to loop thru each video and call the VI endpoint...'):
    import time

    # Schema of the submission results (all string columns). It is used for the empty
    # DataFrame when no media files are in manifest.txt, and also for the populated
    # DataFrame, so Spark never has to infer column types from the collected rows.
    empty_RDD = spark.sparkContext.emptyRDD()
    columns = StructType([
        StructField("path", StringType()),
        StructField("file_name", StringType()),
        StructField("file_type", StringType()),
        StructField("vi_submission_detail", StringType()),
    ])
    empty_df = spark.createDataFrame(data = empty_RDD, schema = columns)
    ## Configuration Variables #################################################################################################################################
    ## variables to configure the amount of time to sleep per loop through each video and
    ## the modulus for how many to run in a row before a longer pause if needed
    sleep_modulus = 5                   # the amount of videos per batch to run through before sleeping for a longer time period
    sleep_time_per_normal_loop = 0      # the amount of time to sleep per video in general
    sleep_time_per_modulus_loop = 0     # the amount of time to sleep after a given batch of videos have been processed
    ############################################################################################################################################################

    # convert the dataframe with each video into a simple list to make easier to loop through
    # in an old fashion for loop, as the azure credential object is non-picklable
    media_contents_list = df_media_contents.collect()
    submitted_list = []
    submitted_list_cols = ["path", "file_name", "file_type", "vi_submission_detail"]
    # Count the rows already collected instead of launching a second Spark job with count().
    totalrows = len(media_contents_list)
    counter = 1

    # main loop to process all videos to video indexer service
    for row in media_contents_list:
        path = row['path']
        file_name = row['file_name']
        file_type = row['file_type']

        # call helper function to submit video to the video indexer service
        response = post_media_for_enrichment(path, file_name, file_type)
        # keep file info plus the service response, to save into a table for later review
        submitted_list.append([path, file_name, file_type, str(response)])
        logger.info(f'Submitted media {counter}/{totalrows}: {file_name}')

        # use this section only if needing to add sleeps per video or per batch of videos
        if counter % sleep_modulus == 0:
            time.sleep(sleep_time_per_modulus_loop)
        else:
            time.sleep(sleep_time_per_normal_loop)
        counter += 1

    # convert our new list of processing results into a dataframe for later saving into a table
    df_media_contents_submitted = empty_df
    if media_contents_list:
        # Explicit schema: type inference fails if a column is None in every row.
        df_media_contents_submitted = spark.createDataFrame(submitted_list, schema=columns)
    else:
        update_status_table('Media Processing Complete', minted_tables_output_path, batch_num, driver, dedicated_sql_endpoint, sql_user_name, sql_user_pwd)

In [ ]:
# Pull the "state" field of each Video Indexer response into its own column, so that
# anomaly responses (anything apart from "State: Processing") are easy to identify.
# https://sparkbyexamples.com/pyspark/pyspark-json-functions-with-examples/
df_media_contents_submitted = (
    df_media_contents_submitted
    .select(
        "path",
        "file_name",
        "file_type",
        "vi_submission_detail",
        json_tuple(col("vi_submission_detail"), "state"),
    )
    .toDF("path", "file_name", "file_type", "vi_submission_detail", "state")
)

In [ ]:
with tracer.span(name='Persist submitted media details as table'):
    submitted_media_tbl_name = f'{batch_num}_submitted_media'
    df_media_contents_submitted \
        .write.mode("overwrite").parquet(f'{minted_tables_output_path}{submitted_media_tbl_name}')

    # The table name derives from the pipeline-supplied batch_num and cannot be bound as
    # a parameter in DDL, so escape it for each context it is embedded in: ']' is
    # doubled inside a [bracketed] identifier, and "'" is doubled inside a string literal.
    # Ordinary names are unchanged by this.
    tbl_identifier = submitted_media_tbl_name.replace(']', ']]')
    tbl_literal = submitted_media_tbl_name.replace("'", "''")
    # NOTE(review): LOCATION hard-codes 'minted_tables/' while the parquet is written
    # under minted_tables_output_path — confirm the two always point at the same folder.
    ext_table_command = (
        f"IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{tbl_literal}') "
        f"CREATE EXTERNAL TABLE [{tbl_identifier}] ("
            "[path] nvarchar(4000), "
            "[file_name] nvarchar(4000), "
            "[file_type] nvarchar(4000), "
            "[vi_submission_detail] nvarchar(4000),"
            "[state] nvarchar(4000)"
        ") "
        f"WITH (LOCATION = 'minted_tables/{tbl_literal}/**', DATA_SOURCE = [synapse_<<STORAGE_ACCOUNT_NAME>>_dfs_core_windows_net], FILE_FORMAT = [SynapseParquetFormat])"
    )
    with pyodbc.connect(f'DRIVER={driver};SERVER=tcp:{serverless_sql_endpoint};PORT=1433;DATABASE={database};UID={sql_user_name};PWD={sql_user_pwd}') as conn:
        with conn.cursor() as cursor:
            cursor.execute(ext_table_command)
            
# Build the value handed back to the pipeline: the new table's name plus batch context.
output = {
    'custom_dimensions': dict(
        batch_num=batch_num,
        submitted_media_tbl_name=submitted_media_tbl_name,
        file_system=file_system,
        notebook_name=mssparkutils.runtime.context['notebookname'],
    )
}

# Log the result, then exit the notebook returning the custom dimensions to the pipeline.
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT", extra=output)
mssparkutils.notebook.exit(output['custom_dimensions'])