In [ ]:
# Notebook parameters. Each defaults to an empty string; presumably the calling
# pipeline overrides them (verify against the pipeline definition).
azure_storage_domain: str = ''    # DNS suffix, used as '<account>.blob.<domain>' / '<account>.dfs.<domain>'
input_container: str = ''         # container holding the input image
blob_account_name: str = ''       # storage account name
image_file_path: str = ''         # path of the input image inside input_container
output_container: str = ''        # container receiving tiles and bounding-box JSON
key_vault_name: str = ''          # not referenced elsewhere in this notebook's visible code
ship_bb_image_high_res: str = ''  # not referenced elsewhere in this notebook's visible code
input_image_low_res: str = ''     # copied to out_img_blob_path in the GDAL cell (otherwise unused)
config_path: str = ''             # config JSON path, relative to the input container root
output_path: str = ''             # folder inside output_container for this run's outputs

In [ ]:
# Initiate logging
import logging
import base64
from opencensus.ext.azure.log_exporter import AzureLogHandler
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import config_integration
from opencensus.trace.samplers import AlwaysOnSampler
from opencensus.trace.tracer import Tracer

config_integration.trace_integrations(['logging'])

instrumentation_connection_string = mssparkutils.credentials.getSecretWithLS("keyvault", "AppInsightsConnectionString")

# `logging.getLogger` returns the same logger object each time this cell is re-run in the
# same session, so attach the Azure handler only once; otherwise every record would be
# exported once per re-run.
logger = logging.getLogger(__name__)
if not any(isinstance(handler, AzureLogHandler) for handler in logger.handlers):
    logger.addHandler(AzureLogHandler(connection_string=instrumentation_connection_string))
logger.setLevel(logging.INFO)

# Tracer exporting spans to Application Insights; AlwaysOnSampler samples every trace.
tracer = Tracer(
    exporter=AzureExporter(connection_string=instrumentation_connection_string),
    sampler=AlwaysOnSampler(),
)

# NOTE: this path should be in sync with Terraform configuration which uploads this file
global_config_path = f'abfss://configuration@{blob_account_name}.dfs.{azure_storage_domain}/anomdet.config.global.json'

# Looked up once and reused for both the custom dimensions and the start-up message.
notebook_name = mssparkutils.runtime.context['notebookname']

# Spool parameters: passed to the start-up log record via `extra` under 'custom_dimensions'.
run_time_parameters = {'custom_dimensions': {
    'input_container': input_container,
    'image_file_path': image_file_path,
    'blob_account_name': blob_account_name,
    'global_config_path': global_config_path,
    'notebook_name': notebook_name,
}}

logger.info(f"INITIALISED: {notebook_name}", extra=run_time_parameters)

In [ ]:
import os, io, sys, math
import json
import glob
import requests
import copy

from requests.exceptions import HTTPError
from pathlib import Path
from pyspark.sql import SparkSession
from py4j.protocol import Py4JJavaError
from PIL import Image, UnidentifiedImageError
from azure.storage.blob import  ResourceTypes, ContainerClient, BlobServiceClient, generate_account_sas, generate_container_sas, generate_blob_sas, AccountSasPermissions, ContainerSasPermissions, BlobSasPermissions
from azure.identity import ClientSecretCredential
from datetime import datetime, timedelta

In [ ]:
# Initialise paths.
# The storage account is addressed two ways: the HTTPS blob endpoint (SDK / REST) and
# ABFSS (Spark / mssparkutils). Build each host prefix once.
_blob_endpoint = f'https://{blob_account_name}.blob.{azure_storage_domain}'
_dfs_host = f'{blob_account_name}.dfs.{azure_storage_domain}'

image_folder = os.path.dirname(image_file_path)

# Input container roots (both end with '/') and paths derived from them.
image_path = f'{_blob_endpoint}/{input_container}/'
image_path_abfss = f'abfss://{input_container}@{_dfs_host}/'
image_root = image_path + image_folder
image_root_abfss = image_path_abfss + image_folder
image_full_path = image_path + image_file_path

# Output locations for this run.
output_dir = f'{_blob_endpoint}/{output_container}/{output_path}'
output_dir_abfss = f'abfss://{output_container}@{_dfs_host}/{output_path}'
output_root = f'{output_dir}/{image_folder}'

# The host prefixes were only scaffolding; drop them so the notebook namespace is unchanged.
del _blob_endpoint, _dfs_host

In [ ]:
with tracer.span(name='Preparing config from global config and loading into memory'):
    # Initialise session, create (if necessary) and read config.
    # Obtain the session first and then take its context, so this cell does not read
    # `spark` before assigning it.
    spark = SparkSession.builder.appName(f"AnomalyDetection {mssparkutils.runtime.context}").getOrCreate()
    sc = spark.sparkContext

    def prepare_config(image_root: str, global_config_path: str):
        """
        Make sure a config file is available in the batch root.

        If ``anomdet.config.json`` does not already exist under ``image_root``, it is copied
        over from ``global_config_path``. If there is no config under ``global_config_path``
        either, the copy fails and the error propagates (indicating an error in pipeline set up).

        :param image_root: folder that should contain the config; it is passed to
            ``mssparkutils.fs``, so use an ``abfss://`` path (e.g. ``image_root_abfss``).
        :param global_config_path: path of the global config to copy from.
        :raises Py4JJavaError: for any failure other than the config file being absent.
        """
        # rstrip avoids a '//' when image_root ends with '/' (e.g. image at the container root).
        image_config_path = f"{image_root.rstrip('/')}/anomdet.config.json"
        try:
            mssparkutils.fs.head(image_config_path)
        except Py4JJavaError as e:
            if 'java.io.FileNotFoundException' not in str(e):
                raise
            # File doesn't exist: copy over the global config.
            mssparkutils.fs.cp(global_config_path, image_config_path)

    # prepare_config(image_root=image_root_abfss, global_config_path=global_config_path)

    # image_path_abfss already ends with '/': join with exactly one separator.
    config_file_abfss = f"{image_path_abfss.rstrip('/')}/{config_path.lstrip('/')}"
    # textFile yields the file line by line; the lines are joined back before JSON parsing.
    config = json.loads(''.join(sc.textFile(config_file_abfss).collect()))


In [ ]:
with tracer.span(name='Getting Credentials, creating BlobServiceClient and sas_token'):
    # Every secret is read from the 'keyvault' linked service.
    _read_secret = mssparkutils.credentials.getSecretWithLS
    tenant_id = _read_secret('keyvault', 'TenantID')
    client_id = _read_secret('keyvault', 'ADAppRegClientId')
    client_secret = _read_secret('keyvault', 'ADAppRegClientSecret')
    storage_account_key = _read_secret('keyvault', 'StorageAccountKey')

    # Service-principal credential -> blob service client for the storage account.
    credential = ClientSecretCredential(tenant_id, client_id, client_secret)
    service = BlobServiceClient(
        account_url=f'https://{blob_account_name}.blob.{azure_storage_domain}/',
        credential=credential,
    )

    # Account SAS signed with the account key: container + object scope (no service scope),
    # read/list/write/add/create/update, expiring one hour from now (UTC).
    sas_token = generate_account_sas(
        account_name=str(blob_account_name),
        account_key=str(storage_account_key),
        resource_types=ResourceTypes(service=False, container=True, object=True),
        permission=AccountSasPermissions(read=True, list=True, write=True, add=True, create=True, update=True),
        expiry=datetime.utcnow() + timedelta(hours=1),
    )

    # Only a local alias; keep the notebook namespace unchanged.
    del _read_secret

In [ ]:
with tracer.span(name='Process geoTIFF through GDAL.Transform'):

    def _gdal_headers(api_key):
        """Build the JSON request headers for the GDAL server (key sent as both 'Gdal-Subscription-Key' and 'KEY')."""
        return {
            # Request headers
            "Content-Type": "application/json",
            "Gdal-Subscription-Key": api_key,
            "KEY": api_key,
        }

    def call_img_prep(gdal_endpoint, img_prep_meta, api_key):
        """
        POST ``img_prep_meta`` to ``{gdal_endpoint}/img_prep/``.

        :param gdal_endpoint: base URL of the GDAL server.
        :param img_prep_meta: JSON-serialisable request body.
        :param api_key: GDAL server API key.
        :return: the ``requests.Response``; HTTP error statuses are NOT raised here.
        :raises Exception: a failure while sending the request is logged with its traceback
            and re-raised, so the caller reports the real cause. Previously an empty-string
            placeholder was returned, and the caller then failed on ``''.raise_for_status()``.
        """
        url = f"{gdal_endpoint}/img_prep/"
        try:
            return requests.post(url=url, json=img_prep_meta, headers=_gdal_headers(api_key))
        except Exception:
            logger.exception('Exception calling %s', url)
            raise

    def call_info(gdal_endpoint, info_metadata, api_key):
        """
        POST ``info_metadata`` to ``{gdal_endpoint}/img_info/`` and pretty-print the JSON reply.

        :return: the ``requests.Response``.
        :raises Exception: request or JSON-decoding failures are logged with traceback and re-raised.
        """
        url = f"{gdal_endpoint}/img_info/"
        try:
            resp = requests.post(url=url, json=info_metadata, headers=_gdal_headers(api_key))
            print(json.dumps(resp.json(), indent=4, sort_keys=True))
        except Exception:
            logger.exception('Exception calling %s', url)
            raise
        return resp

    gdal_host_url = config['gdal_host']['app_url']
    # NOTE(review): not used by the visible code; kept in case other cells rely on it.
    out_img_blob_path = input_image_low_res

    # Read-only SAS for the input image blob, valid for one hour.
    in_blob_sas_tkn = generate_blob_sas(account_name=blob_account_name,
                                        container_name=input_container,
                                        blob_name=image_file_path,
                                        account_key=storage_account_key,
                                        permission=BlobSasPermissions(read=True),
                                        expiry=datetime.utcnow() + timedelta(hours=1))

    # Container SAS so the GDAL server can write its outputs into the output container.
    out_cont_sas_tkn = generate_container_sas(account_name=blob_account_name,
                                              container_name=output_container,
                                              account_key=storage_account_key,
                                              permission=ContainerSasPermissions(read=True, list=True, write=True, add=True, create=True, update=True),
                                              expiry=datetime.utcnow() + timedelta(hours=1))

    translate_opts = config["translate_options"]
    tile_opts = config["tile_options"]
    in_img_meta = {
        "blob_acct": blob_account_name,
        "sas_token": in_blob_sas_tkn,
        "container": input_container,
        "blob_path": image_file_path,
    }
    out_cont_meta = {
        "blob_acct": blob_account_name,
        "sas_token": out_cont_sas_tkn,
        "container": output_container,
    }
    img_prep_meta = {
        "in_image": in_img_meta,
        "out_container": out_cont_meta,
        "translate_options": translate_opts,
        "tile_options": tile_opts,
    }

    # get img info
    info_config = {"format": "json"}
    gdal_info = {
        "info_options": info_config,
        "in_img": in_img_meta,
    }

    try:
        ########################################################################################
        #   get img info (debugging aid, disabled)
        #
        #    * fetches the metadata of the input satellite image file
        #    * useful for debugging future semantics involving the image metadata
        #
        #   info_resp = call_info(gdal_host_url, gdal_info, config['gdal_host']['key'])
        #   logger.info(json.dumps(info_resp.json(), indent=4, sort_keys=True))
        ########################################################################################

        # Prepare the image for sending off to the inference API.
        # The API auth key needs to match API_KEY in ship_anomaly_detection/gdal_server.py
        img_prep_resp = call_img_prep(gdal_host_url, img_prep_meta, config['gdal_host']['key'])
        img_prep_resp.raise_for_status()
    except HTTPError as http_err:
        logger.error(f'HTTP error occurred: {http_err}')
    except Exception as err:
        logger.error(f'Other error occurred: {err}')
    else:
        logger.info(f'Success. Response: {img_prep_resp.status_code} - {img_prep_resp.text}')
        gdal_output = json.loads(img_prep_resp.text)

In [ ]:
with tracer.span(name='Run Custom Vision labeling'):

    def get_base64_encoded_image(image_path):
        """
        Read a local file and return its bytes base64-encoded as text.

        :param image_path: the filename and path
        :return: base64 string of the file contents
        """
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def predict(img_data):
        """
        Send one base64-encoded image to the vision web app's ``/api/extraction`` endpoint.

        Reads the module-level ``API_KEY`` and ``WEB_APP_URL``, which are assigned further
        down in this cell, before the first call.

        :param img_data: base64 text of the image.
        :return: the decoded JSON response body.
        :raises Exception: request, HTTP-status or JSON-decoding errors are logged with
            traceback and re-raised unchanged.
        """
        headers = {
            # Request headers
            "Content-Type": "application/json",
            "Ocp-Apim-Subscription-Key": API_KEY,
            "KEY": API_KEY,
        }
        body = {
            "values": [
                {
                    "recordId": "0",
                    "data": {
                        "images": {
                            "data": img_data
                        }
                    }
                }
            ]
        }
        url = f"{WEB_APP_URL}/api/extraction"
        # Never print the headers: they carry the API key.
        print(url)
        try:
            resp = requests.post(url=url, json=body, headers=headers)
            print(resp)
            # Without this, an error body would be stored as a prediction and break the
            # bounding-box conversion later on.
            resp.raise_for_status()
            return resp.json()
        except Exception:
            logger.exception('Exception calling %s', url)
            raise

    def download(client, source, dest):
        '''
        Download a file or directory to a path on the local filesystem
        '''
        if not dest:
            raise Exception('A destination must be provided')

        blobs = ls_files(client, source, recursive=True)
        if blobs:
            # if source is a directory, dest must also be a directory
            if source != '' and not source.endswith('/'):
                source += '/'
            if not dest.endswith('/'):
                dest += '/'
            # append the directory name from source to the destination
            dest += os.path.basename(os.path.normpath(source)) + '/'

            blobs = [source + blob for blob in blobs]
            for blob in blobs:
                blob_dest = dest + os.path.relpath(blob, source)
                download_file(client, blob, blob_dest)
        else:
            download_file(client, source, dest)

    def download_file(client, source, dest):
        '''
        Download a single file to a path on the local filesystem
        '''
        # dest is a directory if ending with '/' or '.', otherwise it's a file
        if dest.endswith('.'):
            dest += '/'
        blob_dest = dest + os.path.basename(source) if dest.endswith('/') else dest

        print(f'Downloading {source} to {blob_dest}')
        os.makedirs(os.path.dirname(blob_dest), exist_ok=True)
        bc = client.get_blob_client(blob=source)
        with open(blob_dest, 'wb') as file:
            data = bc.download_blob()
            file.write(data.readall())

    def ls_files(client, path, recursive=False):
        '''
        List files under a path (relative to it), optionally recursively
        '''
        if path != '' and not path.endswith('/'):
            path += '/'

        files = []
        for blob in client.list_blobs(name_starts_with=path):
            relative_path = os.path.relpath(blob.name, path)
            if recursive or '/' not in relative_path:
                files.append(relative_path)
        return files

    def ls_dirs(client, path, recursive=False):
        '''
        List directories under a path (relative to it), optionally recursively
        '''
        if path != '' and not path.endswith('/'):
            path += '/'

        dirs = []
        for blob in client.list_blobs(name_starts_with=path):
            relative_dir = os.path.dirname(os.path.relpath(blob.name, path))
            if relative_dir and (recursive or '/' not in relative_dir) and relative_dir not in dirs:
                dirs.append(relative_dir)
        return dirs

    blob_service_client = BlobServiceClient(account_url=f'https://{blob_account_name}.blob.{azure_storage_domain}/', credential=storage_account_key)
    container_client = blob_service_client.get_container_client(output_container)
    remote_container_files = ls_files(container_client, f'{output_path}/tiles', recursive=True)

    # Keep only tiles whose extension equals the configured tile output type. Compare the
    # whole extension, case-insensitively; a "last three characters" check misses 4-letter
    # types such as 'jpeg' or 'tiff'.
    tile_ext = str(config["tile_options"]["out_type"]).lower().lstrip('.')
    remote_tile_paths = [
        possible_tile_path
        for possible_tile_path in remote_container_files
        if os.path.splitext(possible_tile_path)[1].lower().lstrip('.') == tile_ext
    ]

    API_KEY = config['computer_vision']['key'] # This is a secret key on your app service - only requests with this key will be allowed
    WEB_APP_URL = config['computer_vision']['app_url'] # e.g. "http://0.0.0.0:6000" for local testing

    chips_and_chip_bounding_boxes = {}
    for chip_file in remote_tile_paths:
        local_chip_path = f'tiles/{chip_file}'
        try:
            download_file(container_client, f'{output_path}/tiles/{chip_file}', local_chip_path)
        except Exception as err:
            logger.error(f'Other error occurred: {err}')
            # Nothing (or only a stale file) to score without a successful download.
            continue
        logger.info(f'Success. Downloaded: {chip_file}')

        try:
            chips_and_chip_bounding_boxes[chip_file] = predict(get_base64_encoded_image(local_chip_path))
        except Exception as err:
            logger.error(f'Other error occurred: {err}')
        else:
            logger.info(f'Success. Response: {chips_and_chip_bounding_boxes[chip_file]}')

In [ ]:
with tracer.span(name='get vision model to work with aoai geolocate'):
    class NormalAoaiGeolocateInput:
        """
        The Target defines the domain-specific interface used by the client code.
        """
        def request(self) -> str:
            return "Target: The default target's behavior."

    class VisionModelReturn:
        """
        The Adaptee: the raw vision-model response for one chip.

        :param chip_name: tile file name (relative to the tiles folder); ``Adapter`` parses
            its last two '_'-separated fields as the chip's y / x offsets.
        :param data: the decoded JSON response for that chip.
        """
        def __init__(self, chip_name, data) -> None:
            self.chip_name = chip_name
            self.data = data

        def specific_request(self) -> dict:
            return self.data

    class Adapter(NormalAoaiGeolocateInput):
        """
        The Adapter wraps a ``VisionModelReturn`` (composition) and exposes its boxes
        converted to full-image pixel coordinates through ``request()``.
        """
        def __init__(self, visionmodelreturn: VisionModelReturn) -> None:
            self.visionmodelreturn = visionmodelreturn

        @staticmethod
        def get_full_image_cs(img_tile_size, tile_overlap, chip_annot, offset):
            """
            Map one box from chip coordinates to full-image pixel coordinates.

            Each coordinate is multiplied by the tile size (box values are treated as
            fractions of the chip), floored, and shifted by the chip origin
            ``(offset - 1) * (tile_size - overlap)``. The ``- 1`` implies 1-based tile
            indices laid out with a stride of ``tile_size - overlap``.

            :param img_tile_size: tile edge length in pixels (int or numeric string).
            :param tile_overlap: overlap between neighbouring tiles in pixels.
            :param chip_annot: dict with 'bottomX', 'topX', 'bottomY', 'topY'.
            :param offset: dict with 'x' and 'y' chip offsets (int or numeric string).
            :return: dict with the same four keys, in full-image pixels.
            """
            stride = int(img_tile_size) - int(tile_overlap)
            # Chip origin per axis; the axis is the last character of each key ('X' / 'Y').
            origin = {
                'X': (int(offset['x']) - 1) * stride,
                'Y': (int(offset['y']) - 1) * stride,
            }
            return {
                key: math.floor(float(chip_annot[key]) * float(img_tile_size)) + origin[key[-1]]
                for key in ('bottomX', 'topX', 'bottomY', 'topY')
            }

        def request(self) -> dict:
            """
            Return a deep copy of the vision-model response with every box converted to
            full-image pixel coordinates. The copy is also stored on ``self.adapted_returns``.

            Box values are multiplied by the chip size to get chip pixels, then the chip's
            offset from the image origin is added. The offsets come from the chip file name,
            whose last two '_'-separated fields are ``<y>_<x>.<ext>``.
            """
            tile_size = config["tile_options"]["tile_size"]
            tile_overlap = config['tile_options']['overlap']
            # get the chip offset from origin that's written in the chip filename
            chip_offset = {}
            chip_offset['y'], chip_offset['x'] = self.visionmodelreturn.chip_name.split('_')[-2:]
            chip_offset['x'] = os.path.splitext(chip_offset['x'])[0]
            print(f"tile size: {tile_size}  chip offsets: {int(chip_offset['x'])}, {int(chip_offset['y'])}")
            self.adapted_returns = copy.deepcopy(self.visionmodelreturn.data)

            original_boxes = self.visionmodelreturn.data['values'][0]['data']['boxes']
            adapted_boxes = self.adapted_returns['values'][0]['data']['boxes']
            for idx, detection in enumerate(original_boxes):
                print('Detection %d before: %s' % (idx, json.dumps(detection, indent=4, sort_keys=True)))
                full_image_cs_annot = Adapter.get_full_image_cs(img_tile_size=tile_size, tile_overlap=tile_overlap, chip_annot=detection['box'], offset=chip_offset)
                adapted_boxes[idx]['box'].update(full_image_cs_annot)
                print('Detection %d after : %s' % (idx, json.dumps(adapted_boxes[idx], indent=4, sort_keys=True)))

            return self.adapted_returns

    # NOTE(review): duplicates download_file from the Custom Vision cell (minus the progress
    # print) and is not called in this cell; kept so the same definition stays in scope.
    def download_file(client, source, dest):
        '''
        Download a single file to a path on the local filesystem
        '''
        # dest is a directory if ending with '/' or '.', otherwise it's a file
        if dest.endswith('.'):
            dest += '/'
        blob_dest = dest + os.path.basename(source) if dest.endswith('/') else dest

        os.makedirs(os.path.dirname(blob_dest), exist_ok=True)
        bc = client.get_blob_client(blob=source)
        with open(blob_dest, 'wb') as file:
            data = bc.download_blob()
            file.write(data.readall())

    def upload_file(client, source, dest):
        '''
        Upload a single local file to a path inside the container, overwriting any existing blob
        '''
        with open(source, 'rb') as data:
            client.upload_blob(name=dest, data=data, overwrite=True)

    # init access and files
    blob_service_client = BlobServiceClient(account_url=f'https://{blob_account_name}.blob.{azure_storage_domain}/', credential=storage_account_key)
    container_client = blob_service_client.get_container_client(output_container)

    # converted_bounding_boxes: per-chip Adapter objects (for the big-picture image)
    converted_bounding_boxes = {}
    # big_list_of_bbs: every converted box, written to blob for the image-cs -> lat/long step
    big_list_of_bbs = []

    # for each chip that we ran ship inference on
    for chip_name, data in chips_and_chip_bounding_boxes.items():
        vision_return = VisionModelReturn(chip_name, data)
        converted_bounding_boxes[chip_name] = Adapter(vision_return)
        converted_bounding_boxes[chip_name].request()
        chip_boxes = converted_bounding_boxes[chip_name].adapted_returns['values'][0]['data']['boxes']
        # Only serialise the boxes when DEBUG records will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("chip:%s \ndata:\n%s", chip_name, json.dumps(chip_boxes, indent=4, sort_keys=True))

        big_list_of_bbs.extend(chip_boxes)

    # Write the full-image-cs bounding boxes as JSON, then upload to '<output_path>/bb_<image stem>.json'.
    bb_local_json_name = 'bb_' + os.path.splitext(os.path.split(image_file_path)[1])[0] + '.json'
    bb_remote_json_path = f'{output_path}/{bb_local_json_name}'
    with open(bb_local_json_name, 'w+', encoding="utf-8") as file:
        json.dump(big_list_of_bbs, file)
    try:
        upload_file(container_client, bb_local_json_name, bb_remote_json_path)
    except Exception as e:
        logger.error(f'Other error occurred: {e}')
    else:
        logger.info(f'Success. Upload to blob: {bb_remote_json_path}')


In [ ]:
with tracer.span(name='get bounding boxes in format that easily compares to AIS'):

    def call_bb_translate(gdal_endpoint, bb_translate_metadata, api_key):
        """
        POST ``bb_translate_metadata`` to ``{gdal_endpoint}/bb_translate/``.

        :param gdal_endpoint: base URL of the GDAL server.
        :param bb_translate_metadata: JSON-serialisable request body.
        :param api_key: GDAL server API key (sent as both 'Gdal-Subscription-Key' and 'KEY').
        :return: the ``requests.Response``; HTTP error statuses are NOT raised here.
        :raises Exception: a failure while sending the request is logged with its traceback
            and re-raised. Previously an empty-string placeholder was returned, and the caller
            then failed on ``''.raise_for_status()`` with a misleading error.
        """
        headers = {
            # Request headers
            "Content-Type": "application/json",
            "Gdal-Subscription-Key": api_key,
            "KEY": api_key,
        }
        url = f"{gdal_endpoint}/bb_translate/"
        try:
            return requests.post(url=url, json=bb_translate_metadata, headers=headers)
        except Exception:
            logger.exception('Exception calling %s', url)
            raise

    # init access and files
    blob_service_client = BlobServiceClient(account_url=f'https://{blob_account_name}.blob.{azure_storage_domain}/', credential=storage_account_key)
    container_client = blob_service_client.get_container_client(output_container)
    gdal_host_url = config['gdal_host']['app_url']

    # Bounding-box JSON locations. The input must be exactly where the bounding-box upload
    # step writes it: f'{output_path}/bb_<image file name without extension>.json'.
    # Building it from the full image_file_path breaks as soon as the image is in a folder.
    # The projected output is written next to it.
    bb_local_json_name = 'bb_' + os.path.splitext(os.path.basename(image_file_path))[0] + '.json'
    bb_remote_json_path = f'{output_path}/{bb_local_json_name}'
    bb_local_json_path = f'full_img/{bb_local_json_name}'
    proj_bb_local_json_name = 'proj_' + bb_local_json_name
    proj_bb_remote_json_path = f'{output_path}/{proj_bb_local_json_name}'

    # Fresh one-hour SAS tokens: read the image, read the boxes, write the projected boxes.
    in_img_blob_sas_tkn = generate_blob_sas(account_name=blob_account_name,
                                            container_name=input_container,
                                            blob_name=image_file_path,
                                            account_key=storage_account_key,
                                            permission=BlobSasPermissions(read=True),
                                            expiry=datetime.utcnow() + timedelta(hours=1))

    in_bb_blob_sas_tkn = generate_blob_sas(account_name=blob_account_name,
                                           container_name=output_container,
                                           blob_name=bb_remote_json_path,
                                           account_key=storage_account_key,
                                           permission=BlobSasPermissions(read=True),
                                           expiry=datetime.utcnow() + timedelta(hours=1))
    out_bb_blob_sas_tkn = generate_blob_sas(account_name=blob_account_name,
                                            container_name=output_container,
                                            blob_name=proj_bb_remote_json_path,
                                            account_key=storage_account_key,
                                            permission=BlobSasPermissions(write=True),
                                            expiry=datetime.utcnow() + timedelta(hours=1))

    in_img_metadata = {
        "blob_acct": blob_account_name,
        # Use the token generated above. The one from the GDAL cell may already have
        # expired if the earlier steps took over an hour.
        "sas_token": in_img_blob_sas_tkn,
        "container": input_container,
        "blob_path": image_file_path,
    }
    in_bb_metadata = {
        "blob_acct": blob_account_name,
        "sas_token": in_bb_blob_sas_tkn,
        "container": output_container,
        "blob_path": bb_remote_json_path,
    }
    out_bb_metadata = {
        "blob_acct": blob_account_name,
        "sas_token": out_bb_blob_sas_tkn,
        "container": output_container,
        "blob_path": proj_bb_remote_json_path,
    }

    gdal_bb_translate = {
        "in_img": in_img_metadata,
        "in_bbs": in_bb_metadata,
        "out_bbs": out_bb_metadata,
    }

    try:
        # api auth key needs to match API_KEY in ship_anomaly_detection/gdal_server.py
        # translate bbs
        bb_translate_resp = call_bb_translate(gdal_host_url, gdal_bb_translate, config['gdal_host']['key'])
        bb_translate_resp.raise_for_status()
    except HTTPError as http_err:
        logger.error(f'HTTP error occurred: {http_err}')
    except Exception as err:
        logger.error(f'Other error occurred: {err}')
    else:
        logger.info(f'Success. Response: {bb_translate_resp.status_code} - {bb_translate_resp.text}')
        gdal_output = json.loads(bb_translate_resp.text)
    

