In [0]:
import asyncio
import functools
import os
import time
from pathlib import Path

import tqdm
from azure.storage.fileshare import ShareServiceClient
from delta.tables import *
from pyspark.sql import functions as F


In [0]:
# --- Run configuration ------------------------------------------------------
# Batch size: rows fetched (SQL LIMIT) and copied per loop iteration; also the
# divisor used to estimate how many iterations are needed.
# Observed timings: 10 -> ~9 s/iter, 20 -> ~12 s/iter, 30 -> ~18 s/iter.
# NOTE(review): time per iteration grows with batch size, which suggests the
# copies in a batch were not overlapping — see copySrcFileToDst.
MAX_CONCURRENT_COPIES = 30

# Items tried within this many hours are skipped by the pending-items query.
RETRY_INTERVAL_HOURS = 12

# Hard upper bound on loop iterations, whatever the estimate says.
MAX_ITERATIONS = 1_000

# Items are no longer selected once num_copy_tries reaches this value.
MAX_TRIES = 10

In [0]:
# Spark-side timestamp taken at the start of this run; the summary cells at the
# end count rows whose last_copy_run_at is later than this value.
start_time = spark.sql("SELECT CURRENT_TIMESTAMP() AS start_time").first()["start_time"]
print(start_time)

In [0]:
@functools.lru_cache(maxsize=1)
def _get_share_client():
    """Build (once) and return the ShareClient for the `intfileshare` file share.

    The result is cached, so the two Databricks secret lookups and the service
    client construction happen once per notebook session rather than once per
    copied file. Re-running this cell redefines the function and clears the
    cache (e.g. after a key rotation). A failed lookup raises and is not cached.
    """
    acc_name = dbutils.secrets.get(scope="adc_store", key="pacs_intfileshare_accname")
    acc_key = dbutils.secrets.get(scope="adc_store", key="pacs_intfileshare_acckey")

    connection_string = f"DefaultEndpointsProtocol=https;AccountName={acc_name};AccountKey={acc_key};EndpointSuffix=core.windows.net"
    share_name = "intfileshare"

    return ShareServiceClient.from_connection_string(connection_string).get_share_client(share_name)


def getSrcFileClient(src_path):
    """Return a file client for `src_path` on the `intfileshare` Azure file share.

    Args:
        src_path: path of the file relative to the share root.

    Returns:
        The client returned by ShareClient.get_file_client(src_path), built on
        top of a share client that is cached across calls.
    """
    return _get_share_client().get_file_client(src_path)

In [0]:

def _download_to_dst(file_client, dst_filepath):
    """Blocking helper: download the whole source file and write it to dst_filepath."""
    file_bytes = file_client.download_file().readall()

    # Make parent directory if it does not exist
    Path(dst_filepath).parent.absolute().mkdir(parents=True, exist_ok=True)

    # Write to Databricks
    with open(dst_filepath, "wb") as f:
        f.write(file_bytes)


async def copySrcFileToDst(row):
    """Copy one file from the Azure file share to its destination path.

    The source path is os.path.join(src_root, src_proj_dir, src_subdirs,
    src_filename), resolved through getSrcFileClient. The file is written to
    row["dst_filepath"], and missing parent directories are created first.

    Args:
        row: a record (e.g. a Spark Row of pacs_file_copy) that supports item
            access for src_root, src_proj_dir, src_subdirs, src_filename,
            dst_filepath and item_id.

    Returns:
        (item_id, 'done') on success, or (item_id, 'failed') if any step raised.
        Failures are printed rather than raised, so one bad file does not abort
        the whole asyncio.gather() batch.
    """
    src_path = os.path.join(row["src_root"], row["src_proj_dir"], row["src_subdirs"], row["src_filename"])

    try:
        # Client creation (secret lookups) stays on the event-loop thread.
        file_client = getSrcFileClient(src_path)

        # The download and write are blocking. Without an await here,
        # asyncio.gather() would run these copies strictly one after another,
        # so offload them to the loop's default thread-pool executor. Note that
        # this executor caps its worker count (min(32, os.cpu_count() + 4) on
        # Python 3.8+), so real parallelism may be below the batch size.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _download_to_dst, file_client, row["dst_filepath"])
    except Exception as e:
        # Best-effort: record the failure for retry, but don't hide why it failed.
        print(f"Copy failed for item_id={row['item_id']} ({src_path}): {e!r}")
        return (row["item_id"], 'failed')

    return (row["item_id"], 'done')


In [0]:
def update_table(results, table_name="1_inland.sectra.pacs_file_copy"):
    """Merge per-item copy outcomes into the copy-tracking Delta table.

    For every row of `table_name` whose item_id appears in `results`:
      * copy_status      <- the reported status ('done' / 'failed' from copySrcFileToDst)
      * last_copy_run_at <- CURRENT_TIMESTAMP()
      * num_copy_tries   <- num_copy_tries + 1
      * process_status   <- 'pending' if the copy succeeded, otherwise unchanged
    Entries in `results` with no matching item_id are ignored (there is no
    insert clause).

    Args:
        results: sequence of (item_id, status) tuples.
        table_name: fully-qualified name of the Delta table to update.

    Returns:
        None. An empty `results` is a no-op, because spark.createDataFrame
        cannot infer a schema from an empty list.
    """
    if not results:
        return

    result_df = spark.createDataFrame(data=results, schema=["item_id", "status"])

    (
        DeltaTable.forName(spark, table_name).alias("t")
        .merge(result_df.alias("s"), "t.item_id == s.item_id")
        .whenMatchedUpdate(set={
            "copy_status": "s.status",
            "last_copy_run_at": "CURRENT_TIMESTAMP()",
            "num_copy_tries": "t.num_copy_tries+1",
            "process_status": "CASE WHEN s.status = 'done' THEN 'pending' ELSE t.process_status END"
        })
        .execute()
    )


async def async_update_table(results, table_name="1_inland.sectra.pacs_file_copy"):
    """Coroutine wrapper around update_table(), for scheduling with asyncio.create_task.

    NOTE: the Spark/Delta calls inside are blocking, so the event loop is
    blocked while the merge runs. Scheduling this only defers the work until
    the loop next gets control; it does not run concurrently with other
    coroutines.
    """
    update_table(results, table_name)

In [0]:
# Estimated number of batches the copy loop needs. The query counts items still
# eligible for a copy attempt (same filter as the loop below), divides by the
# batch size, truncates with CAST(... AS INT), and adds one. This gives one
# extra iteration when the count is an exact multiple; the loop exits early on
# an empty batch.
# NOTE(review): `LOWER(copy_status) != 'done'` is NULL (row not counted) when
# copy_status is NULL — confirm rows never carry a NULL copy_status.
est_num_iter = spark.sql(f"""
    SELECT CAST(COUNT(*)/{MAX_CONCURRENT_COPIES} AS INT) + 1 AS estimation
    FROM 1_inland.sectra.pacs_file_copy
    WHERE 
        active_ind = 1 
        AND LOWER(copy_status) != 'done'
        AND num_copy_tries < {MAX_TRIES}
        AND (
            TIMEDIFF(HOUR, last_copy_run_at, CURRENT_TIMESTAMP()) > {RETRY_INTERVAL_HOURS}
            OR last_copy_run_at IS NULL)
""").first()["estimation"]


In [0]:

bg_update_tasks = set()


for _ in tqdm.tqdm(range(min(MAX_ITERATIONS, est_num_iter))):
    df = spark.sql(f"""
        SELECT *
        FROM 1_inland.sectra.pacs_file_copy
        WHERE 
            active_ind = 1 
            AND LOWER(copy_status) != 'done'
            AND num_copy_tries < {MAX_TRIES}
            AND (
                TIMEDIFF(HOUR, last_copy_run_at, CURRENT_TIMESTAMP()) > {RETRY_INTERVAL_HOURS}
                OR last_copy_run_at IS NULL)
        LIMIT {MAX_CONCURRENT_COPIES}     
    """)

    if df.count() > 0:
        pass
    else:
        print("\nNo pending job. Exiting loop.\n")
        break

    coros = [copySrcFileToDst(row) for row in df.collect()]
    results = await asyncio.gather(*coros)

    # nonasychonous update
    update_table(results)
    
    
    # asynchronous update: ~10% faster
    #task = asyncio.create_task(async_update_table(results))
    #bg_update_tasks.add(task)
    #task.add_done_callback(bg_update_tasks.discard)

while len(bg_update_tasks) > 0:
    print(f"\nwaiting for table update jobs (n={len(bg_update_tasks)})")
    time.sleep(10)
    

In [0]:
# Summary: rows marked 'done' with a last_copy_run_at later than this run's start_time.
done_jobs = spark.sql(f"""
    SELECT COUNT(*) AS job_count
    FROM 1_inland.sectra.pacs_file_copy
    WHERE 
    copy_status = 'done'
    AND last_copy_run_at > '{start_time}'
""").first()["job_count"]
print("Number of successful jobs:", done_jobs)

In [0]:
# Summary: rows marked 'failed' with a last_copy_run_at later than this run's start_time.
failed_jobs = spark.sql(f"""
    SELECT COUNT(*) AS job_count
    FROM 1_inland.sectra.pacs_file_copy
    WHERE 
    copy_status = 'failed'
    AND last_copy_run_at > '{start_time}'
""").first()["job_count"]
print("Number of failed jobs:", failed_jobs)