In [ ]:
%%configure -f
{
    "conf": {
        "spark.jars.packages": "com.databricks:spark-xml_2.12:0.14.0",
        "spark.jars.repositories": "https://repo1.maven.org/maven2/"
   }
}

In [ ]:
# Notebook parameters. The empty defaults are presumably overridden by the
# calling Synapse pipeline (parameters cell) — confirm in the pipeline definition.
azure_storage_domain = ''    # storage DNS suffix, used as '{account}.blob.{azure_storage_domain}' / '.dfs.'
input_container = ''         # container holding the input image
blob_account_name = ''       # storage account name for both input and output containers
image_file_path = ''         # path of the input image blob inside input_container
output_container = ''        # container that holds the KML input and receives all outputs
key_vault_name = ''          # NOTE(review): not referenced by the visible cells (secrets use the 'keyvault' linked service)
config_path = ''             # config file path relative to the image folder (read in the config cell)
kml_path = ''                # path of the AIS KML file inside output_container
output_path = ''             # output prefix inside output_container
ais_image = ''               # blob path (in output_container) for the AIS-points overlay image
anomaly_image = ''           # blob path (in output_container) for the anomaly bounding-box overlay image
ship_bb_image_high_res = ''  # blob path (in output_container) for the all-bounding-boxes overlay image

In [ ]:
# Initiate logging
import logging
import base64
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

config_integration.trace_integrations(['logging'])

instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

logger = logging.getLogger(__name__)
# getLogger returns the same logger object every time this cell is re-run in a
# live session, so attach the Application Insights handler only once; otherwise
# each re-run adds another handler and every record is exported repeatedly.
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

tracer = Tracer(
    exporter=AzureExporter(
        connection_string=instrumentation_connection_string
    ),
    sampler=AlwaysOnSampler()
)

# NOTE: this path should be in sync with Terraform configuration which uploads this file
global_config_path = f'abfss://configuration@{blob_account_name}.dfs.{azure_storage_domain}/anomdet.config.global.json'

# Spool parameters
run_time_parameters = {'custom_dimensions': {
    'input_container': input_container,
    'image_file_path': image_file_path,
    'blob_account_name': blob_account_name,
    'global_config_path': global_config_path,
    'notebook_name': mssparkutils.runtime.context['notebookname']
}}

logger.info(f"INITIALISED: {mssparkutils.runtime.context['notebookname']}", extra=run_time_parameters)

In [ ]:
import os, io, sys, math
import json
import glob
import logging
import requests
import copy

from requests.exceptions import HTTPError
from pathlib import Path
from pyspark.sql import SparkSession
from py4j.protocol import Py4JJavaError
from PIL import Image, UnidentifiedImageError
from azure.storage.blob import BlobServiceClient, ContainerClient, generate_account_sas, generate_container_sas, generate_blob_sas, ResourceTypes
from azure.storage.blob import AccountSasPermissions, ContainerSasPermissions, BlobSasPermissions
from azure.identity import ClientSecretCredential
from datetime import datetime, timedelta

In [ ]:
# Initialise paths.
# Input container: https:// blob-endpoint URLs, plus abfss:// equivalents for
# Spark / mssparkutils access.
image_path = f'https://{blob_account_name}.blob.{azure_storage_domain}/{input_container}/'
image_folder = os.path.dirname(image_file_path)
image_root = image_path + image_folder
image_full_path = image_path + image_file_path
image_path_abfss = f'abfss://{input_container}@{blob_account_name}.dfs.{azure_storage_domain}/'
image_root_abfss = image_path_abfss + image_folder

# Output container, rooted at output_path (https:// and abfss:// forms).
output_dir = f'https://{blob_account_name}.blob.{azure_storage_domain}/{output_container}/{output_path}'
output_dir_abfss = f'abfss://{output_container}@{blob_account_name}.dfs.{azure_storage_domain}/{output_path}'
output_root = '/'.join((output_dir, image_folder))

In [ ]:
with tracer.span(name=f'Preparing config from global config and loading into memory'):
    # Initialise the session first, then take its context, so `sc` never
    # depends on a `spark` object that existed before this cell ran.
    spark = SparkSession.builder.appName(f"AnomalyDetection {mssparkutils.runtime.context}").getOrCreate()
    sc = spark.sparkContext

    def prepare_config(image_root: str, global_config_path: str) -> None:
        """
        Make sure a config file is available in the batch root.

        If ``anomdet.config.json`` does not already exist under ``image_root``,
        it is copied over from ``global_config_path``. If there is no config
        under ``global_config_path`` either, the copy fails and the error
        propagates (indicating an error in the pipeline set up).

        :param image_root: abfss:// URI of the folder containing the input image.
        :param global_config_path: abfss:// URI of the global fallback config.
        :raises Py4JJavaError: for any filesystem error other than the
            per-image config being absent.
        """
        image_config_path = f'{image_root}/anomdet.config.json'
        try:
            mssparkutils.fs.head(image_config_path)
        except Py4JJavaError as e:
            if 'java.io.FileNotFoundException' not in str(e):
                raise
            # File doesn't exist yet: seed it from the global config.
            mssparkutils.fs.cp(global_config_path, image_config_path)

    # mssparkutils.fs needs the abfss:// form of the image folder, not the https:// URL.
    prepare_config(image_root=image_root_abfss, global_config_path=global_config_path)

    # NOTE(review): prepare_config seeds 'anomdet.config.json' but this reads
    # config_path; presumably they name the same file — confirm.
    # The file's lines are concatenated without separators before parsing.
    config = json.loads(''.join(sc.textFile(f'{image_root_abfss}/{config_path}').collect()))


In [ ]:
with tracer.span(name=f'Getting Credentials, creating BlobServiceClient and sas_token'):
    # Fetch the service-principal credentials and the storage account key from
    # the 'keyvault' linked service, in the order the names are listed.
    tenant_id, client_id, client_secret, storage_account_key = (
        mssparkutils.credentials.getSecretWithLS('keyvault', secret_name)
        for secret_name in ('TenantID', 'ADAppRegClientId', 'ADAppRegClientSecret', 'StorageAccountKey')
    )
    credential = ClientSecretCredential(tenant_id, client_id, client_secret)
    service = BlobServiceClient(
        account_url=f'https://{blob_account_name}.blob.{azure_storage_domain}/',
        credential=credential,
    )
    # Account-level SAS scoped to containers and objects (not the service),
    # read/list/write/add/create/update, valid for one hour.
    sas_token = generate_account_sas(
        account_name=blob_account_name,
        account_key=storage_account_key,
        resource_types=ResourceTypes(service=False, container=True, object=True),
        permission=AccountSasPermissions(read=True, list=True, write=True, add=True, create=True, update=True),
        expiry=datetime.utcnow() + timedelta(hours=1),
    )

In [ ]:
with tracer.span(name=f'Preparing AIS to send off to GDAL server'):

    from pyspark.sql.functions import struct, col, split, to_timestamp, to_json, spark_partition_id, desc, count

    # Only AIS reports whose UTC timestamp falls in this window are kept.
    # NOTE(review): hard-coded to a single capture's time range; presumably it
    # should come from the image acquisition time — confirm before reuse.
    ais_window_start = '2021-06-23 03:27:54'
    ais_window_end = '2021-06-23 03:37:54'

    src = f"abfss://{output_container}@{blob_account_name}.dfs.{azure_storage_domain}/{kml_path}"

    try:
        df = spark.read \
            .option("rootTag", "Document") \
            .option("rowTag", "Placemark") \
            .format("com.databricks.spark.xml") \
            .load(src)
    except Exception as e:
        # Every step below needs this dataframe: stop here with context rather
        # than failing later with a NameError on `df`.
        logger.error(f'Failed to read KML from source {src}: {e}')
        raise
    logger.info(f'created dataframe from source: {src}')

    # UTC timestamp and MMSI are taken from fixed positions (8 and 1) of
    # ExtendedData.Data — NOTE(review): presumably the KML field order; confirm.
    long_list_of_boats = df.withColumn("UTC", to_timestamp(col('ExtendedData').Data[8].value, "yyyy-MM-dd HH:mm:ss")) \
                            .withColumn("MMSI", col('ExtendedData').Data[1].value)
    long_list_of_boats.createOrReplaceTempView("ship_profiles")

    query = f"(SELECT * FROM ship_profiles WHERE UTC BETWEEN '{ais_window_start}' AND '{ais_window_end}')"
    boats_in_time = spark.sql(query)
    boats_in_time.createOrReplaceTempView("ship_profiles_in_time")
    # Attach each MMSI's report count (within the window) to its rows.
    mmsi_agg_time_slice = boats_in_time.groupby("MMSI").count()
    mmsi_agg_time_slice.createOrReplaceTempView("MMSI_grouped")
    time_sliced_boats_agg = boats_in_time.join(mmsi_agg_time_slice, "MMSI", "outer")
    time_sliced_boats_agg.printSchema()
    time_sliced_boats_agg.select(count('MMSI')).show()
    time_sliced_boats_agg.orderBy(desc("count")).show()

    ais_remote_dir_name = 'AIS_sliced_and_grouped'
    ais_output_uri = f'{output_dir_abfss}/{ais_remote_dir_name}'
    try:
        # coalesce(1) yields a single JSON part file in this directory, which the
        # next cell locates by its .json extension. DataFrameWriter.json returns
        # None, so there is no response object to report.
        time_sliced_boats_agg.coalesce(1).write.mode("Overwrite").json(ais_output_uri)
    except Exception as err:
        # The anomaly-detection cell cannot run without this file.
        logger.error(f'Failed to write AIS data to {ais_output_uri}: {err}')
        raise
    logger.info(f'Wrote AIS data to {ais_output_uri}')

# Perform matching and get json-formatted information on anomalies + successful pairings

In [ ]:
with tracer.span(name=f'Send AIS and BB data to GDAL server to do anomaly detection'):
    # Define a helper function for calling the endpoint
    def call_detect_anomalies(gdal_endpoint, anom_det_meta, api_key):
        """
        POST an anomaly-detection request to ``{gdal_endpoint}/anomaly_detection/``.

        :param gdal_endpoint: base URL of the GDAL server.
        :param anom_det_meta: JSON-serialisable request body.
        :param api_key: API key, sent in both the ``Gdal-Subscription-Key`` and
            ``KEY`` headers.
        :returns: the ``requests`` response, or ``""`` if the request itself
            raised (the exception is logged; the caller's ``raise_for_status()``
            then fails and is handled there).
        """
        resp = ""
        try:
            headers = {
                # Request headers
                "Content-Type": "application/json",
                "Gdal-Subscription-Key": api_key,
                "KEY": api_key
            }
            resp = requests.post(url=f"{gdal_endpoint}/anomaly_detection/", json=anom_det_meta, headers=headers)
        except Exception as e:
            # logging applies extra positional args as %-format arguments, so the
            # exception needs a placeholder in the message.
            logger.error('Exception while calling the anomaly_detection endpoint: %s', e)
        return resp

    def _make_blob_sas(container_name, blob_name, permission, description):
        """
        Return a one-hour SAS token for ``blob_name`` in ``container_name``.

        Success is logged at INFO. A failure is logged and re-raised: every token
        is required for the request body below, so continuing would only fail
        later with a less helpful NameError.
        """
        try:
            token = generate_blob_sas(account_name=blob_account_name,
                                      container_name=container_name,
                                      blob_name=blob_name,
                                      account_key=storage_account_key,
                                      permission=permission,
                                      expiry=datetime.utcnow() + timedelta(hours=1))
        except Exception as e:
            logger.error(f"unable to create {description} blob sas: {e} ")
            raise
        logger.info(f"created blob sas for file: https://{blob_account_name}.blob.{azure_storage_domain}/{container_name}/{blob_name}")
        return token

    # Get SAS keys for:
    # - the input image
    img_sas_tkn = _make_blob_sas(input_container, image_file_path,
                                 BlobSasPermissions(read=True), 'input image')

    # - the bounding box list
    bb_json_name = 'bb_' + os.path.splitext(image_file_path)[0] + '.json'
    bb_remote_json_path = f'{output_path}/{bb_json_name}'
    bb_sas_tkn = _make_blob_sas(output_container, bb_remote_json_path,
                                BlobSasPermissions(read=True), 'remote bb json')

    # - the AIS point list
    #   The AIS cell wrote a Spark output directory; list it and use its .json
    #   part file (the last one listed, if there are several).
    service_client = BlobServiceClient(account_url=f'https://{blob_account_name}.blob.{azure_storage_domain}/', credential=sas_token)
    container_client = service_client.get_container_client(output_container)
    blobs_list = container_client.list_blobs(name_starts_with=f"{output_path}/{ais_remote_dir_name}")
    ais_remote_path = ""
    for blob in blobs_list:
        if os.path.splitext(blob.name)[1] == '.json':
            logger.info(f"found a json file: {blob.name}")
            ais_remote_path = blob.name
    if not ais_remote_path:
        # Without this check the request would silently reference blob path '/'.
        missing_ais_msg = f"no AIS .json file found under {output_container}/{output_path}/{ais_remote_dir_name}"
        logger.error(missing_ais_msg)
        raise FileNotFoundError(missing_ais_msg)

    # NOTE(review): the leading '/' is passed to both the SAS and the GDAL
    # server unchanged; presumably the server expects it — confirm before changing.
    ais_remote_path = '/' + ais_remote_path
    ais_sas_tkn = _make_blob_sas(output_container, ais_remote_path,
                                 BlobSasPermissions(read=True), 'ais')

    # - the output json file (write-only SAS; presumably the server writes its results here)
    out_json_remote_path = f'{output_path}/anomaly_detection_output.json'
    out_json_sas_tkn = _make_blob_sas(output_container, out_json_remote_path,
                                      BlobSasPermissions(write=True),
                                      'anomaly detection result json')

    # Prepare the request body
    anomaly_detection_metadata = {
        "img_meta": {
            "blob_acct": blob_account_name,
            "sas_token": img_sas_tkn,
            "container": input_container,
            "blob_path": image_file_path
        },
        "ais_meta": {
            "blob_acct": blob_account_name,
            "sas_token": ais_sas_tkn,
            "container": output_container,
            "blob_path": ais_remote_path
        },
        "bb_meta": {
            "blob_acct": blob_account_name,
            "sas_token": bb_sas_tkn,
            "container": output_container,
            "blob_path": bb_remote_json_path
        },
        "anom_det_meta": {
            "translate_options": config['translate_options'],
            "visualization_config": config['visualization_options'],
            "matching_config": config['matching_options']
        },
        "out_json_meta": {
            "blob_acct": blob_account_name,
            "sas_token": out_json_sas_tkn,
            "container": output_container,
            "blob_path": out_json_remote_path
        }
    }

    # Show the request body with SAS tokens masked: cell output may be kept with
    # the notebook run, and the out_json_meta token grants write access.
    redacted_metadata = {
        section: {k: ('<redacted>' if k == 'sas_token' else v) for k, v in meta.items()}
        for section, meta in anomaly_detection_metadata.items()
    }
    print(json.dumps(redacted_metadata, indent=4))

    # Call the endpoint and parse its JSON reply; on any failure gdal_output is None.
    try:
        # api auth key needs to match API_KEY in ship_anomaly_detection/gdal_server.py
        anom_det_resp = call_detect_anomalies(
            config['gdal_host']['app_url'],
            anomaly_detection_metadata,
            config['gdal_host']['key'])
        anom_det_resp.raise_for_status()
    except HTTPError as http_err:
        logger.error(f'HTTP error occurred: {http_err}')
        gdal_output = None
    except Exception as err:
        logger.error(f'Other error occurred: {err}')
        gdal_output = None
    else:
        logger.info(f'Success. Response: {anom_det_resp.status_code}')
        gdal_output = json.loads(anom_det_resp.text)

# Generate four overlays
## Helper function

In [ ]:
def call_make_overlay(gdal_host_url, make_overlay_metadata, key, *, timeout=None):
    """
    Ask the GDAL server to render an overlay image.

    :param gdal_host_url: base URL of the GDAL server; ``/make_overlay/`` is appended.
    :param make_overlay_metadata: JSON-serialisable request body.
    :param key: API key, sent in both the ``Gdal-Subscription-Key`` and ``KEY`` headers.
    :param timeout: optional ``requests`` timeout (seconds, or a
        ``(connect, read)`` tuple). The default ``None`` waits indefinitely,
        matching the previous behaviour.
    :returns: the decoded JSON response, or ``None`` if the request raised,
        returned an HTTP error status, or the body was not valid JSON. The
        error is logged with its traceback.
    """
    try:
        resp = requests.post(
            url=f"{gdal_host_url}/make_overlay/",
            json=make_overlay_metadata,
            headers={
                'Content-Type': 'application/json',
                'Gdal-Subscription-Key': key,
                'KEY': key},
            timeout=timeout)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        # Best-effort by design: callers treat a failed overlay as None.
        logger.exception(f"Exception during call to make_overlay endpoint: {e}")
        return None

## AIS points only

In [ ]:
# Write-only SAS (one hour) for the AIS-points overlay image.
out_ais_img_remote_path = ais_image
out_ais_img_sas_tkn = generate_blob_sas(
    account_name=blob_account_name,
    account_key=storage_account_key,
    container_name=output_container,
    blob_name=out_ais_img_remote_path,
    permission=BlobSasPermissions(write=True),
    expiry=datetime.utcnow() + timedelta(hours=1),
)

# Every AIS point: the unpaired ones followed by the AIS half (index 1) of each
# bbox/AIS pair, all passed under the 'unpaired_ais_points' key.
features = {
    'unpaired_ais_points': [
        *gdal_output['anomalies']['unpaired_ais_points'],
        *(pair[1] for pair in gdal_output['anomalies']['bbox_ais_pairs']),
    ]
}

make_overlay_metadata = dict(
    in_img_meta=dict(
        blob_acct=blob_account_name,
        sas_token=img_sas_tkn,
        container=input_container,
        blob_path=image_file_path,
    ),
    out_img_meta=dict(
        blob_acct=blob_account_name,
        sas_token=out_ais_img_sas_tkn,
        container=output_container,
        blob_path=out_ais_img_remote_path,
    ),
    translate_options=config['translate_options'],
    visualization_config=config['visualization_options'],
    features_json_str=json.dumps(features),
)

resp_json = call_make_overlay(config['gdal_host']['app_url'], make_overlay_metadata, config['gdal_host']['key'])

## All Bounding Boxes

In [ ]:
# Write-only SAS (one hour) for the all-bounding-boxes overlay image.
out_bb_img_remote_path = ship_bb_image_high_res
out_bb_img_sas_tkn = generate_blob_sas(
    account_name=blob_account_name,
    account_key=storage_account_key,
    container_name=output_container,
    blob_name=out_bb_img_remote_path,
    permission=BlobSasPermissions(write=True),
    expiry=datetime.utcnow() + timedelta(hours=1),
)

# Every bounding box: the unpaired ones followed by the bbox half (index 0) of
# each bbox/AIS pair. A confidence-score threshold (defined in the config file)
# has already been applied, so this is a subset of the OD model's results.
features = {
    'unpaired_bboxes': [
        *gdal_output['anomalies']['unpaired_bboxes'],
        *(pair[0] for pair in gdal_output['anomalies']['bbox_ais_pairs']),
    ]
}

# All boxes are passed as 'unpaired_bboxes', so draw them in the user's
# "all bounding boxes" colour (by default, yellow) by substituting it for the
# anomaly colour in a copy of the visualization options.
bbox_visualization_config = {
    **config['visualization_options'],
    'anom_bb_color': config['visualization_options']['all_bb_color'],
}

make_overlay_metadata = dict(
    in_img_meta=dict(
        blob_acct=blob_account_name,
        sas_token=img_sas_tkn,
        container=input_container,
        blob_path=image_file_path,
    ),
    out_img_meta=dict(
        blob_acct=blob_account_name,
        sas_token=out_bb_img_sas_tkn,
        container=output_container,
        blob_path=out_bb_img_remote_path,
    ),
    translate_options=config['translate_options'],
    visualization_config=bbox_visualization_config,
    features_json_str=json.dumps(features),
)

resp_json = call_make_overlay(config['gdal_host']['app_url'], make_overlay_metadata, config['gdal_host']['key'])

## Anomaly Bounding Boxes

In [ ]:
# Get a SAS token for the output image (write-only, valid for one hour)
out_bb_img_remote_path = anomaly_image
out_bb_img_sas_tkn = generate_blob_sas(
    account_name=blob_account_name,
    container_name=output_container,
    blob_name=out_bb_img_remote_path,
    account_key=storage_account_key,
    permission=BlobSasPermissions(write=True),
    expiry=datetime.utcnow() + timedelta(hours=1))

# Only the bounding boxes the server left unpaired with any AIS point, i.e. the
# anomalies. A confidence-score threshold (defined in the config file) has been
# applied, so this is a subset of the OD model's results.
features = {
    'unpaired_bboxes': gdal_output['anomalies']['unpaired_bboxes']
}

make_overlay_metadata = {
    'in_img_meta': {
        "blob_acct": blob_account_name,
        "sas_token": img_sas_tkn,
        "container": input_container,
        "blob_path": image_file_path
    },
    'out_img_meta': {
        'blob_acct': blob_account_name,
        'sas_token': out_bb_img_sas_tkn,
        'container': output_container,
        'blob_path': out_bb_img_remote_path
    },
    'translate_options': config['translate_options'],
    # Use the visualization options unchanged so anomalies keep their own colour
    # ('anom_bb_color'), as in the pairs overlay. Only the all-bounding-boxes
    # overlay substitutes 'all_bb_color', because there every box is passed as
    # 'unpaired_bboxes'.
    'visualization_config': config['visualization_options'],
    'features_json_str': json.dumps(features)
}

resp_json = call_make_overlay(
    config['gdal_host']['app_url'],
    make_overlay_metadata,
    config['gdal_host']['key'])

## Bounding boxes and AIS points with pairs indicated

In [ ]:
# Write-only SAS (one hour) for the bbox/AIS pairs overlay image, stored next
# to the other outputs under output_path.
out_anom_img_remote_path = f'{output_path}/bb_ais_pairs.png'
out_anom_img_sas_tkn = generate_blob_sas(
    account_name=blob_account_name,
    account_key=storage_account_key,
    container_name=output_container,
    blob_name=out_anom_img_remote_path,
    permission=BlobSasPermissions(write=True),
    expiry=datetime.utcnow() + timedelta(hours=1),
)

# Pass every feature group through unchanged, so the pairings are drawn too.
features = {
    group: gdal_output['anomalies'][group]
    for group in ('unpaired_ais_points', 'unpaired_bboxes', 'bbox_ais_pairs')
}

make_overlay_metadata = dict(
    in_img_meta=dict(
        blob_acct=blob_account_name,
        sas_token=img_sas_tkn,
        container=input_container,
        blob_path=image_file_path,
    ),
    out_img_meta=dict(
        blob_acct=blob_account_name,
        sas_token=out_anom_img_sas_tkn,
        container=output_container,
        blob_path=out_anom_img_remote_path,
    ),
    translate_options=config['translate_options'],
    visualization_config=config['visualization_options'],
    features_json_str=json.dumps(features),
)

resp_json = call_make_overlay(config['gdal_host']['app_url'], make_overlay_metadata, config['gdal_host']['key'])

# Wrap up

In [ ]:
# Extract datetime and location information needed for the cell's output.
# gdal_output['img_metadata'] is scanned as text, one line at a time
# (presumably gdalinfo-style output — confirm).
corner_labels = ('Upper Left', 'Lower Left', 'Upper Right', 'Lower Right')
my_datetime, corner_lines = '', []
for line in gdal_output['img_metadata'].split('\n'):
    if 'TIFFTAG_DATETIME=' in line:
        my_datetime = line.split('TIFFTAG_DATETIME=')[1].strip()
    if any(label in line for label in corner_labels):
        corner_lines.append(line.strip())
location = '\n'.join(corner_lines)

output = {
    'discrepancy': {
        # An anomaly is any detected bounding box left unpaired with an AIS point.
        'anomaly_found': len(gdal_output['anomalies']['unpaired_bboxes']) > 0,
        'anomaly_location': location,
        'anomaly_time': my_datetime,
        'notebook_name': mssparkutils.runtime.context['notebookname']}}

# Return the object to the pipeline.
# The fields are passed under 'custom_dimensions', like the INITIALISED log call.
# A plain extra=output only sets a `discrepancy` attribute on the record, which
# AzureLogHandler does not export as custom properties (per opencensus-ext-azure docs).
logger.info(f"{mssparkutils.runtime.context['notebookname']}: OUTPUT",
            extra={'custom_dimensions': output['discrepancy']})
mssparkutils.notebook.exit(json.dumps(output['discrepancy']))