In [0]:
from azure.storage.fileshare import ShareServiceClient
import os
import tqdm
from pathlib import Path
from pyspark.sql import functions as F
import asyncio
from delta.tables import *
import time
import pydicom
from functools import lru_cache


In [0]:
# Batch size: number of table rows selected (SQL LIMIT) and processed per loop iteration.
# Measured: 10 -> avg 9s/iter; 20 -> avg 12s/iter; 30 -> avg 18s/iter.
MAX_CONCURRENT_COPIES = 20
# A step is retried only when more than this many hours have passed since its last run.
RETRY_INTERVAL_HOURS = 12
# Upper bound on the number of batch iterations in a single notebook run.
MAX_ITERATIONS = 1000
# A step is no longer selected once its try counter reaches this value.
MAX_TRIES = 10

In [0]:
# Record the Spark-side timestamp at which this run started; the reporting
# cells at the end of the notebook filter on it.
start_time = spark.sql("SELECT CURRENT_TIMESTAMP() AS start_time").first()["start_time"]
print(start_time)

In [0]:
@lru_cache(maxsize=1)
def _getShareClient():
    """Build the share client for the "intfileshare" Azure file share (once).

    The storage account name and key are read from the "adc_store" secret
    scope. The result is cached so that the secrets are fetched and the
    service client is constructed a single time per kernel rather than once
    for every file.
    """
    acc_name = dbutils.secrets.get(scope="adc_store", key="pacs_intfileshare_accname")
    acc_key = dbutils.secrets.get(scope="adc_store", key="pacs_intfileshare_acckey")

    connection_string = (
        "DefaultEndpointsProtocol=https;"
        f"AccountName={acc_name};"
        f"AccountKey={acc_key};"
        "EndpointSuffix=core.windows.net"
    )

    service_client = ShareServiceClient.from_connection_string(connection_string)
    return service_client.get_share_client("intfileshare")


def getSrcFileClient(src_path):
    """Return a file client for ``src_path`` inside the "intfileshare" share.

    Args:
        src_path: Path of the file relative to the root of the share.

    Returns:
        The file client obtained from the (cached) share client.
    """
    return _getShareClient().get_file_client(src_path)

In [0]:
@lru_cache(maxsize=128, typed=False)
def _lookupPersonId(accession_nbr):
    """Query the reference table for the person id of one accession number.

    Only successful lookups are cached: if the query raises, the exception
    propagates and ``lru_cache`` stores nothing, so a transient failure is
    retried on the next call instead of being remembered as 'unknown'.

    The filter is built with the DataFrame API so that the accession number
    (which is read from file content) is never spliced into SQL text.
    """
    row = (
        spark.table("4_prod.pacs.all_pacs_ref_nbr")
        .where(F.col("refnbr") == F.lit(accession_nbr))
        .agg(F.max("MillPersonId").alias("MillPersonId"))
        .first()
    )
    person_id = row["MillPersonId"]
    return 'unknown' if person_id is None else str(person_id)


def retrievePersonId(accession_nbr):
    """Return the person id for ``accession_nbr`` as a string.

    Args:
        accession_nbr: Accession number to look up, or None.

    Returns:
        ``str(MAX(MillPersonId))`` over the rows whose ``refnbr`` equals the
        accession number, or the literal 'unknown' when the accession number
        is None, no matching value exists, or the lookup fails.
    """
    if accession_nbr is None:
        return 'unknown'

    try:
        # str() mirrors how the value was previously formatted into the query
        # and guarantees a hashable cache key.
        return _lookupPersonId(str(accession_nbr))
    except Exception as exc:
        # Best-effort lookup: report the problem but keep the pipeline going.
        print(f"retrievePersonId: lookup failed ({type(exc).__name__}: {exc})")
        return 'unknown'

    

def updateDstDicomPatientId(dst_filepath):
    """Rewrite the PatientID of the DICOM file at ``dst_filepath`` in place.

    The new PatientID is ``retrievePersonId(<AccessionNumber>)``.

    * For a file named DICOMDIR (case-insensitive), the accession number is
      taken from the first directory record that exposes one, and every
      directory record that has a PatientID element is updated with the
      resulting id.
    * For any other file, the dataset's own AccessionNumber is used (None if
      absent) and the dataset's PatientID is set, being added if missing.

    Args:
        dst_filepath: Path of the file to read and overwrite.

    Raises:
        Whatever ``pydicom.dcmread`` / ``save_as`` raise, e.g. when the file
        is not valid DICOM.
    """
    dcm = pydicom.dcmread(dst_filepath)

    if os.path.basename(dst_filepath).upper() == "DICOMDIR":
        # Default to None so that a DICOMDIR without any AccessionNumber maps
        # to 'unknown' instead of failing on an unbound variable.
        nbr = None
        for record in dcm.DirectoryRecordSequence:
            try:
                nbr = record["AccessionNumber"].value
                break
            except Exception:
                # Record has no (readable) AccessionNumber; try the next one.
                continue

        pid = retrievePersonId(nbr)

        for record in dcm.DirectoryRecordSequence:
            try:
                record["PatientID"].value = pid
            except Exception:
                # Records without a PatientID element are left untouched.
                continue

    else:
        try:
            nbr = dcm["AccessionNumber"].value
        except Exception:
            nbr = None

        # Attribute assignment updates the element if present and creates it
        # otherwise, so a missing PatientID no longer raises KeyError.
        dcm.PatientID = retrievePersonId(nbr)

    dcm.save_as(dst_filepath)




In [0]:

async def copySrcFileToDst(row):
    """Copy one source file to its destination and update its PatientID.

    Steps, each skipped when the row already marks it as "done":

    1. Copy: download the file from the source share and write it to
       ``row["dst_filepath"]`` (parent directories are created as needed).
    2. Delete: source deletion is currently disabled, so the delete status is
       always reported as "pending".
    3. Process: rewrite the PatientID of the destination file. This is only
       attempted when the copy is done; otherwise the status is "pending".

    Note: the body contains no ``await``, so coroutines gathered together run
    one after another rather than concurrently.

    Args:
        row: Mapping/Row with the keys ``item_id``, ``src_root``,
            ``src_proj_dir``, ``src_subdirs``, ``src_filename``,
            ``dst_filepath``, ``copy_status``, ``process_status``.

    Returns:
        Tuple ``(item_id, copy_status, process_status, delete_status)``.
    """
    # --- 1. copy -----------------------------------------------------------
    if row["copy_status"] == "done":
        copy_status = "done"
    else:
        try:
            # Built inside the try so that a malformed row (e.g. a None path
            # component) marks this item as failed instead of aborting the batch.
            src_path = os.path.join(
                row["src_root"], row["src_proj_dir"], row["src_subdirs"], row["src_filename"]
            )
            file_bytes = getSrcFileClient(src_path).download_file().readall()

            # Make parent directory if not exist
            Path(row["dst_filepath"]).parent.absolute().mkdir(parents=True, exist_ok=True)

            with open(row["dst_filepath"], "wb") as f:
                f.write(file_bytes)

            copy_status = "done"
        except Exception as exc:
            print(f"copy failed for item_id={row['item_id']}: {type(exc).__name__}: {exc}")
            copy_status = "failed"

    # --- 2. delete source --------------------------------------------------
    # Deleting the source file is disabled (delete_file() is never called), so
    # this step must never claim "done". Reporting "pending" unconditionally
    # matches the previous effective behaviour without touching the file
    # client, which was undefined for rows that were already copied.
    delete_status = "pending"

    # --- 3. update Patient ID in dst file ----------------------------------
    if row["process_status"] == "done":
        process_status = "done"
    elif copy_status != "done":
        # Nothing (complete) to process yet; do not record a spurious failure.
        process_status = "pending"
    else:
        try:
            updateDstDicomPatientId(row["dst_filepath"])
            process_status = "done"
        except Exception as exc:
            print(f"process failed for item_id={row['item_id']}: {type(exc).__name__}: {exc}")
            process_status = "failed"

    return (row["item_id"], copy_status, process_status, delete_status)


In [0]:
def update_table(results):
    """Merge a batch of step results into ``1_inland.sectra.pacs_file_copy``.

    For every row of the table whose ``item_id`` matches a result, the three
    status columns are overwritten, the three ``last_*_run_at`` columns are
    set to the current timestamp and the three ``num_*_tries`` counters are
    incremented by one.

    Args:
        results: Iterable of ``(item_id, copy_status, process_status,
            delete_status)`` tuples. An empty batch is a no-op.
    """
    if not results:
        # Nothing to merge; also avoids createDataFrame failing to infer a
        # schema from an empty list.
        return

    result_df = spark.createDataFrame(
        data=results,
        schema=["item_id", "copy_status", "process_status", "delete_status"],
    )

    target = DeltaTable.forName(spark, "1_inland.sectra.pacs_file_copy")
    (
        target.alias("t")
        .merge(result_df.alias("s"), "t.item_id == s.item_id")
        .whenMatchedUpdate(set={
            "copy_status": "s.copy_status",
            "last_copy_run_at": "CURRENT_TIMESTAMP()",
            "num_copy_tries": "t.num_copy_tries+1",
            "process_status": "s.process_status",
            "last_process_run_at": "CURRENT_TIMESTAMP()",
            "num_process_tries": "t.num_process_tries+1",
            "src_delete_status": "s.delete_status",
            "last_delete_run_at": "CURRENT_TIMESTAMP()",
            "num_delete_tries": "t.num_delete_tries+1",
        })
        .execute()
    )

async def async_update_table(results):
    """Merge copy-step results into ``1_inland.sectra.pacs_file_copy``.

    Only the copy columns are updated: ``copy_status`` is overwritten,
    ``last_copy_run_at`` is set to the current timestamp and
    ``num_copy_tries`` is incremented. ``process_status`` is reset to
    'pending' for items whose copy status is 'done' and left unchanged
    otherwise.

    NOTE(review): the DataFrame schema has two columns, so ``results`` must
    be ``(item_id, status)`` pairs. The 4-tuples returned by
    ``copySrcFileToDst`` do not fit this schema.

    Args:
        results: Iterable of ``(item_id, status)`` pairs. An empty batch is a
            no-op.
    """
    if not results:
        # Nothing to merge; also avoids createDataFrame failing to infer a
        # schema from an empty list.
        return

    result_df = spark.createDataFrame(data=results, schema=["item_id", "status"])

    target = DeltaTable.forName(spark, "1_inland.sectra.pacs_file_copy")
    (
        target.alias("t")
        .merge(result_df.alias("s"), "t.item_id == s.item_id")
        .whenMatchedUpdate(set={
            "copy_status": "s.status",
            "last_copy_run_at": "CURRENT_TIMESTAMP()",
            "num_copy_tries": "t.num_copy_tries+1",
            "process_status": "CASE WHEN s.status = 'done' THEN 'pending' ELSE t.process_status END",
        })
        .execute()
    )

In [0]:
# Estimated number of batches needed for this run: the count of active rows
# whose copy step is still due (not done, tries left, retry interval elapsed
# or never run), divided by the batch size, plus one.
# Note that only the copy step is counted here.
est_num_iter = spark.sql(f"""
    SELECT CAST(COUNT(*)/{MAX_CONCURRENT_COPIES} AS INT) + 1 AS estimation
    FROM 1_inland.sectra.pacs_file_copy
    WHERE 
        active_ind = 1 
        AND LOWER(copy_status) != 'done'
        AND num_copy_tries < {MAX_TRIES}
        AND (
            TIMEDIFF(HOUR, last_copy_run_at, CURRENT_TIMESTAMP()) > {RETRY_INTERVAL_HOURS}
            OR last_copy_run_at IS NULL)
""").first()["estimation"]


In [0]:

bg_update_tasks = set()


for _ in tqdm.tqdm(range(min(MAX_ITERATIONS, est_num_iter))):
    df = spark.sql(f"""
        SELECT *
        FROM 1_inland.sectra.pacs_file_copy
        WHERE 
            active_ind = 1
            AND (
                    (
                        LOWER(copy_status) != 'done'
                        AND num_copy_tries < {MAX_TRIES}
                        AND (TIMEDIFF(HOUR, last_copy_run_at, CURRENT_TIMESTAMP()) > {RETRY_INTERVAL_HOURS} OR last_copy_run_at IS NULL)
                    ) OR (
                        LOWER(copy_status) = 'done'
                        AND LOWER(process_status) != 'done'
                        AND num_process_tries < {MAX_TRIES}
                        AND (TIMEDIFF(HOUR, last_process_run_at, CURRENT_TIMESTAMP()) > {RETRY_INTERVAL_HOURS} OR last_process_run_at IS NULL)
                    ) OR (
                        LOWER(copy_status) = 'done'
                        AND LOWER(src_delete_status) != 'done'
                        AND num_delete_tries < {MAX_TRIES}
                        AND (TIMEDIFF(HOUR, last_delete_run_at, CURRENT_TIMESTAMP()) > {RETRY_INTERVAL_HOURS} OR last_delete_run_at IS NULL)
                    )
                )
            
        LIMIT {MAX_CONCURRENT_COPIES}     
    """)

    if df.count() > 0:
        pass
    else:
        print("\nNo pending job. Exiting loop.\n")
        break

    coros = [copySrcFileToDst(row) for row in df.collect()]
    results = await asyncio.gather(*coros)

    # nonasychonous update
    update_table(results)
    
    
    # asynchronous update: ~10% faster
    #task = asyncio.create_task(async_update_table(results))
    #bg_update_tasks.add(task)
    #task.add_done_callback(bg_update_tasks.discard)

while len(bg_update_tasks) > 0:
    print(f"\nwaiting for table update jobs (n={len(bg_update_tasks)})")
    time.sleep(10)
    

In [0]:
# Count rows whose copy or process step finished as 'done' after this run started.
done_jobs = spark.sql(f"""
    SELECT COUNT(*) AS job_count
    FROM 1_inland.sectra.pacs_file_copy
    WHERE 
    (copy_status = 'done'AND last_copy_run_at > '{start_time}')
    OR (process_status = 'done' AND last_process_run_at > '{start_time}')
""").first()["job_count"]
print("Number of successful jobs:", done_jobs)

In [0]:
# Count rows whose copy or process step ended as 'failed' after this run started.
failed_jobs = spark.sql(f"""
    SELECT COUNT(*) AS job_count
    FROM 1_inland.sectra.pacs_file_copy
    WHERE 
    (copy_status = 'failed' AND last_copy_run_at > '{start_time}')
    OR (process_status = 'failed' AND last_process_run_at > '{start_time}')
""").first()["job_count"]
print("Number of failed jobs:", failed_jobs)

In [0]:
# Show up to 1000 rows whose copy or process step ended as 'failed' after this
# run started. Note: this rebinds `failed_jobs` from the count computed in the
# previous cell to a DataFrame.
failed_jobs = spark.sql(f"""
    SELECT *
    FROM 1_inland.sectra.pacs_file_copy
    WHERE 
    (copy_status = 'failed' AND last_copy_run_at > '{start_time}')
    OR (process_status = 'failed' AND last_process_run_at > '{start_time}')
    LIMIT 1000
""")
display(failed_jobs)