In [None]:
import requests
import google.auth
import json
import time
from datetime import datetime, timedelta
from google.auth.transport.requests import Request as GoogleAuthRequest
from pprint import pprint

In [None]:
# Which backend to query. The helper functions below branch on exactly
# "AZURE" or "GOOGLE"; any other value makes them return None.
# cloud = "AZURE" 
cloud = "GOOGLE"

# workspace_project / workspace_name are only used on the GOOGLE path, where
# they are interpolated into the Rawls workspace URL.

# for GOOGLE PROD workspace
# workspace_name = "Imputation_pipeline_testing"
# workspace_project = "morgan-fieldeng"

# for GOOGLE DEV workspace
workspace_name = "tsps_gcp_scratch_space_mma"
workspace_project = "general-dev-billing-account"

# workspace_id is only used on the AZURE path, where it is interpolated into
# the Leonardo apps URL to look up the cromwell-reader proxy.

# for AZURE workspace
workspace_id = "a4d49543-65cb-43ec-a9c5-631e71fa77b9" # new
# workspace_id = "c8a8a57d-f9aa-4a66-8a1a-b0af1b3c10c8" # old: tsps_dev_bp_02_01_2024_v1/imputation-pipeline-testing-03-07-2024


In [None]:
# Interactive gcloud login (IPython shell escape). get_access_token() below
# obtains its credentials via google.auth.default(); presumably this login is
# what makes those credentials available -- confirm for your environment.
!gcloud auth login --update-adc

In [None]:
# Deployment environment name; interpolated into the Rawls and Leonardo
# hostnames built by the helper functions in this cell.
ENV = "dev"
# Service-account e-mail derived from ENV. Not referenced by any other code
# visible in this notebook.
TEASPOONS_SA_EMAIL = f"tsps-{ENV}@broad-dsde-{ENV}.iam.gserviceaccount.com"

def get_access_token():
    """Return a freshly refreshed bearer token for the default Google credentials.

    The credentials are resolved with ``google.auth.default()`` and refreshed
    on every call, so each call returns a currently valid token string.
    """
    default_credentials = google.auth.default()[0]
    default_credentials.refresh(GoogleAuthRequest())
    return default_credentials.token


def get_workflows_url(cloud, workspace_id, workspace_project, workspace_name, token):
    """Return the base URL used to fetch workflow metadata for the chosen cloud.

    Args:
        cloud: "GOOGLE" or "AZURE".
        workspace_id: workspace id, used only for AZURE.
        workspace_project: workspace project, used only for GOOGLE.
        workspace_name: workspace name, used only for GOOGLE.
        token: bearer token, used only for the AZURE lookup.

    Returns:
        For GOOGLE, the Rawls workspace URL built from the module-level ENV.
        For AZURE, whatever get_cromwell_url_azure returns.
        None for any other value of ``cloud``.
    """
    if cloud == "GOOGLE":
        return f"https://rawls.dsde-{ENV}.broadinstitute.org/api/workspaces/{workspace_project}/{workspace_name}"

    if cloud == "AZURE":
        return get_cromwell_url_azure(workspace_id, token)

    # Unknown cloud: no URL can be built.
    return None


def get_cromwell_url_azure(workspace_id, token):
    """Get the cromwell-reader proxy URL for an Azure workspace.

    Lists the apps of ``workspace_id`` through the Leonardo API (host chosen
    by the module-level ENV) and returns the ``cromwell-reader`` proxy URL of
    the first app whose ``appType`` is ``WORKFLOWS_APP``.

    Args:
        workspace_id: id of the workspace whose apps are listed.
        token: bearer token sent in the Authorization header.

    Returns:
        The cromwell-reader URL, or None when the request fails or no
        suitable app exists. In both failure cases a message is printed.
    """
    uri = f"https://leonardo.dsde-{ENV}.broadinstitute.org/api/apps/v2/{workspace_id}?includeDeleted=false"

    headers = {"Authorization": "Bearer " + token,
               "accept": "application/json"}

    response = requests.get(uri, headers=headers)

    if response.status_code != 200:
        # Do not hand the error body back as if it were a URL: callers
        # interpolate the return value straight into request URIs.
        print("error fetching workspace apps")
        print(response.text)
        return None

    for app in json.loads(response.text):
        if app.get('appType') != 'WORKFLOWS_APP':
            continue
        # proxyUrls may be absent or null while an app is still provisioning.
        cromwell_url = (app.get('proxyUrls') or {}).get('cromwell-reader')
        if cromwell_url is not None:
            return cromwell_url

    # Reached when no WORKFLOWS_APP exposes a cromwell-reader URL.
    print("Cromwell is missing in current workspace")
    return None


def get_cromwell_workflow(cloud, workflows_url, submission_id, workflow_id, token):
    """Fetch metadata for one workflow from the backend selected by ``cloud``.

    Builds the cloud-specific metadata URI (restricted to the keys the status
    report needs) and delegates the HTTP call to the matching helper.

    Args:
        cloud: "AZURE" or "GOOGLE".
        workflows_url: base URL as returned by get_workflows_url.
        submission_id: submission id; only part of the GOOGLE URI.
        workflow_id: id of the workflow (or subworkflow) to fetch.
        token: bearer token passed on to the helper.

    Returns:
        The helper's return value, or None for any other ``cloud``.
    """
    if cloud == "GOOGLE":
        google_uri = f"{workflows_url}/submissions/{submission_id}/workflows/{workflow_id}?includeKey=attempt&includeKey=backendStatus&includeKey=status&includeKey=start&includeKey=end&includeKey=executionStatus&includeKey=shardIndex&includeKey=subWorkflowId&includeKey=workflowName&expandSubWorkflows=false"
        return get_cromwell_workflow_google(google_uri, token)

    if cloud == "AZURE":
        azure_uri = f"{workflows_url}/api/workflows/v1/{workflow_id}/metadata?includeKey=attempt&includeKey=start&includeKey=end&includeKey=status&includeKey=backendStatus&includeKey=executionStatus&includeKey=subWorkflowId&includeKey=workflowName"
        return get_cromwell_workflow_azure(azure_uri, token)

    return None


def get_cromwell_workflow_azure(uri, token):
    """GET workflow metadata from ``uri`` and return the decoded JSON body.

    A non-200 status is reported on stdout (a fixed message followed by the
    response text) but does not stop the function: the body is still decoded
    and returned.
    """
    response = requests.get(
        uri,
        headers={"Authorization": "Bearer " + token,
                 "accept": "application/json"},
    )

    request_failed = response.status_code != 200
    if request_failed:
        print("error fetching cromwell workflow metadata")
        print(response.text)

    return response.json()


def get_cromwell_workflow_google(uri, token, attempt=1):
    """GET workflow metadata from ``uri``, retrying on 404, and return the JSON body.

    A 404 is retried while ``attempt`` is at most 3. With the default
    ``attempt=1`` that allows up to three retries, sleeping 2, 3 and 4
    seconds before them. Any other non-200 status, or a 404 once the retries
    are used up, is reported on stdout. The body of the last response is
    decoded and returned either way.

    Args:
        uri: full metadata URI.
        token: bearer token sent in the Authorization header.
        attempt: number of the current attempt; retries start counting here.
    """
    request_headers = {"Authorization": "Bearer " + token,
                       "accept": "application/json"}

    while True:
        response = requests.get(uri, headers=request_headers)
        code = response.status_code
        if code == 200:
            break

        if code == 404 and attempt <= 3:
            # sometimes we get a transient 404, so wait a little longer each
            # time and ask again
            attempt += 1
            print(f"retrying call after {attempt} sec, attempt {attempt}")
            time.sleep(attempt)
            continue

        print("error fetching cromwell workflow metadata")
        print(response.text)
        break

    return response.json()


def clean_task_name(task_name, workflow_name="ImputationBeagle"):
    """Return ``task_name`` with every "<workflow_name>." occurrence removed."""
    qualifier = workflow_name + "."
    return task_name.replace(qualifier, "")

def print_in_red_bold(text):
    """Print ``text`` wrapped in the ANSI escape codes for bold red."""
    print("".join(('\x1b[1;31m', text, '\x1b[0m')))

def print_in_grey(text):
    """Print ``text`` wrapped in the ANSI escape codes for grey."""
    print("".join(('\x1b[0;37m', text, '\x1b[0m')))

def get_retries(task_info):
    """Return the ``attempt`` value of every entry whose attempt number exceeds 1.

    Values are returned as stored (not converted), in input order.
    """
    retried_attempts = []
    for call in task_info:
        attempt = call['attempt']
        if int(attempt) > 1:
            retried_attempts.append(attempt)
    return retried_attempts

def get_task_durations(task_info, successful_only=True):
    """Collect ``[start, end]`` pairs for entries that have an ``end`` value.

    Args:
        task_info: list of call entries (dicts).
        successful_only: when true, keep only entries whose
            ``executionStatus`` is ``'Done'``.

    Returns:
        A list of two-element lists, in input order.
    """
    pairs = []
    for call in task_info:
        if 'end' not in call:
            continue
        if successful_only and call['executionStatus'] != 'Done':
            continue
        pairs.append([call['start'], call['end']])
    return pairs

def _log_for_task(log, task_name, values):
    """Append ``values`` to the list stored under ``task_name``, creating it on first use."""
    log.setdefault(task_name, []).extend(values)


def _print_subtask_summary(subtask_name, subinfo):
    """Print progress for one call of a subworkflow that has not succeeded yet."""
    # Every entry with attempt > 1 re-runs an existing shard, so it does not
    # add to the number of shards that are expected to complete.
    n_expected = len(subinfo) - len(get_retries(subinfo))
    n_succeeded = sum(1 for call in subinfo if call['executionStatus'] == 'Done')

    failed = []
    for call in subinfo:
        if call['executionStatus'] != 'Failed':
            continue
        if 'backendStatus' in call:
            failed.append(f"shard {call['shardIndex']} failed with status {call['backendStatus']}")
        else:
            failed.append(f"shard {call['shardIndex']} failed with unknown backendStatus")

    if failed:
        print_in_red_bold(f"     {subtask_name} : {n_succeeded}/{n_expected} total shards succeeded, {len(failed)} failed. error messages:")
        for message in failed:
            print_in_red_bold(f"       {message}")
    elif n_succeeded == n_expected:
        # Compare against the retry-adjusted total, the same denominator the
        # messages print; comparing against the raw entry count reported a
        # finished-after-retry subtask as still in progress.
        print_in_grey(f"     {subtask_name} : Succeeded")
    else:
        print(f"     {subtask_name} : In progress; {n_succeeded}/{n_expected} shards succeeded so far")


def display_task_statuses(cloud, workflows_url, submission_id, workflow_id, token):
    """Print a status report for a workflow and collect retry/duration data.

    Fetches the workflow metadata, prints one line per call (red when any
    entry failed, grey when complete), and for calls whose name contains
    "ScatterAt" fetches each shard's subworkflow and reports on it as well.
    Subworkflows that have not succeeded get a per-call breakdown.

    Args:
        cloud: "AZURE" or "GOOGLE"; forwarded to get_cromwell_workflow.
        workflows_url: base URL forwarded to get_cromwell_workflow.
        submission_id: submission id forwarded to get_cromwell_workflow.
        workflow_id: id of the top-level workflow to report on.
        token: bearer token forwarded to get_cromwell_workflow.

    Returns:
        A ``(retries, durations)`` tuple of dicts keyed by cleaned task name.
        ``retries`` maps to the list of attempt numbers greater than 1;
        ``durations`` maps to the list of ``[start, end]`` pairs of entries
        with status Done. Both include the calls of fetched subworkflows.

    Raises:
        KeyError: if the metadata lacks an expected key such as ``status``,
            ``calls``, ``executionStatus``, ``attempt`` or ``shardIndex``.
    """
    response = get_cromwell_workflow(cloud, workflows_url, submission_id, workflow_id, token)

    retries = {}  # dict of "taskname": [list of attempt numbers]
    durations = {}  # dict of "taskname": [list of [start, end] pairs]

    print(f"status: {response['status']}")

    for task, info in response['calls'].items():
        task_name = clean_task_name(task)

        n_failed = sum(1 for call in info if call['executionStatus'] == 'Failed')
        n_succeeded = sum(1 for call in info if call['executionStatus'] == 'Done')
        failure_msg = "" if n_failed == 0 else f"- {n_failed} failures"

        task_retries = get_retries(info)
        _log_for_task(retries, task_name, task_retries)
        _log_for_task(durations, task_name, get_task_durations(info))

        # subtract retries from total number of entries to get the effective expected total
        n_expected = len(info) - len(task_retries)
        line_to_print = f"{task_name} : {n_succeeded}/{n_expected} complete tasks {failure_msg}"
        if failure_msg:
            print_in_red_bold(line_to_print)
        elif n_succeeded == n_expected:
            print_in_grey(line_to_print)
        else:
            print(line_to_print)

        if "ScatterAt" not in task:
            continue

        # Each entry of a scatter call is one shard backed by a subworkflow.
        for shard_call in info:
            shard = shard_call['shardIndex']

            if 'subWorkflowId' not in shard_call:
                print(f"  Shard {shard} has no workflowId yet")
                continue

            subworkflow_response = get_cromwell_workflow(
                cloud, workflows_url, submission_id, shard_call['subWorkflowId'], token)
            subworkflow_status = subworkflow_response['status']
            shard_status_msg = f"  Shard {shard} status: {subworkflow_status}"

            subworkflow_calls = subworkflow_response['calls']
            for subtask, subinfo in subworkflow_calls.items():
                subtask_name = clean_task_name(subtask)
                _log_for_task(retries, subtask_name, get_retries(subinfo))
                _log_for_task(durations, subtask_name, get_task_durations(subinfo))

            if subworkflow_status == "Succeeded":
                print_in_grey(shard_status_msg)
                continue

            # Unfinished or failed shard: show where each of its calls stands.
            print(shard_status_msg)
            for subtask, subinfo in subworkflow_calls.items():
                _print_subtask_summary(clean_task_name(subtask), subinfo)

    return retries, durations

def _parse_utc_timestamp(timestamp):
    """Parse a 'Z'-suffixed ISO-8601 timestamp, with or without fractional seconds."""
    try:
        return datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
    except ValueError:
        # %f needs at least one digit, so a timestamp that falls on a whole
        # second ("...T10:00:00Z") must be parsed without the fraction.
        return datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%SZ")


def get_duration(timestamp_pair):
    """Return the time elapsed between the two timestamps of a pair.

    Args:
        timestamp_pair: sequence whose first item is the start and whose
            second item is the end, both as strings such as
            "2024-07-11T10:00:00.123Z" or "2024-07-11T10:00:00Z".

    Returns:
        ``end - start`` as a ``datetime.timedelta``.

    Raises:
        ValueError: if a timestamp matches neither accepted format.
    """
    start = _parse_utc_timestamp(timestamp_pair[0])
    end = _parse_utc_timestamp(timestamp_pair[1])
    return end - start

def avg_td(timedeltas):
    """Return the mean of a list of timedeltas, or None when the list is empty."""
    count = len(timedeltas)
    if count == 0:
        return None

    total = sum(timedeltas, timedelta(0))
    return total / count


In [None]:
# Sanity check: show which environment and cloud the following cells will query.
print(f"running on {ENV} {cloud}")

In [None]:
# Resolve the base URL for metadata requests. On AZURE this performs an HTTP
# lookup and can be None when no cromwell-reader URL is found.
workflows_url = get_workflows_url(cloud, workspace_id, workspace_project, workspace_name, get_access_token())

In [None]:
# Fetch a bearer token. NOTE(review): the later cells visible here call
# get_access_token() directly rather than reading this variable.
token = get_access_token()

In [None]:
# Choose the submission/workflow to inspect, then print its status report.
# Earlier runs are kept below as commented-out alternatives; exactly one
# submission_id / workflow_id pair should be active.

# AZURE DEV

# workflow_id = "aac49125-c06d-4e0f-9012-14e85b3f5aa5" # july 11 500 samples
# workflow_id = "93ac95f9-1edc-456c-b989-8578ef0355bd" # july 12 1000 samples failed
# workflow_id = "f0f8ada6-08bc-4faf-b749-3e660fb2ea76" # july 14 1000 samples
# workflow_id = "4e233132-173a-45f1-a8b4-751cea4914bb"

# submission_id = None

# GOOGLE PROD

# GOOGLE DEV
# submission_id = "c313cd58-eca1-4a22-8d95-0f1df44debca" # 42 samples
# workflow_id = "8f7250e9-e6a0-4563-9922-33939740b152"

# submission_id = "355548fd-71e9-45b3-b213-4e482e0dfedf"
# workflow_id = "0d297916-5cda-4a32-ac5d-79ad35b23bbc" # 500 samples
# workflow_id = "c80cbb40-30bc-408b-80ba-1f9c6cff873b" # 1000 samples
# workflow_id = "b72cbf26-7454-4c37-a64c-df0d5ccbe022" # 5000 samples

submission_id = "9ae58e2e-6559-4eef-a0c2-fedc4dd434bc"
workflow_id = "523fdcc8-2410-463f-806f-a7a6387484db"

# A fresh token is requested for this call. retries and durations are dicts
# keyed by task name and are consumed by the summary cells that follow.
retries, durations = display_task_statuses(cloud, workflows_url, submission_id, workflow_id, get_access_token())

In [None]:
# Overall retry count first, then one line per task with its own count.
n_retries = sum(map(len, retries.values()))
print(f"Total retries executed: {n_retries}. Details:\n")
for task, retry_list in retries.items():
    print(f"{task}: {len(retry_list)}")
# pprint(retries)

In [None]:
# Mean wall-clock time of the completed entries of each task; fractional
# seconds are cut off for display. Tasks with no average, and tasks whose
# name contains "Scatter", are not listed.
print("Average completed task times (HH:MM:SS):\n")
for task, duration_list in durations.items():
    avg_duration = avg_td(list(map(get_duration, duration_list)))
    if not avg_duration or "Scatter" in task:
        continue
    print(f"{str(avg_duration).split('.')[0]} (n={len(duration_list)}) -> {task}")
    

In [None]:
# Break the retries of a single task down by attempt number (2 through 4):
# for each attempt number, count how many recorded retries carry it.
task_name = 'CountVariantsInChunksBeagle'
task_retries = retries[task_name]
for attempt in range(2, 5):
    n_attempts = task_retries.count(attempt)
    print(f"{attempt} attempts: {n_attempts}")
