In [None]:
import requests
import google.auth
import json
from google.auth.transport.requests import Request as GoogleAuthRequest
from pprint import pprint

In [None]:
# Deployment environment name; it is interpolated into the Leonardo host name
# used by get_cromwell_url and into the service-account address below.
ENV = "dev"
# Service-account e-mail derived from ENV.
# NOTE(review): not referenced by any of the visible cells — confirm it is still needed.
TEASPOONS_SA_EMAIL = f"tsps-{ENV}@broad-dsde-{ENV}.iam.gserviceaccount.com"

def get_access_token():
    """Return an access token for the default Google credentials.

    The credentials come from ``google.auth.default()`` and are refreshed
    once, using a ``GoogleAuthRequest`` transport, before the token is read.
    """
    creds, _project = google.auth.default()
    creds.refresh(GoogleAuthRequest())
    return creds.token


def get_cromwell_url(workspace_id, token):
    """Get the cromwell-reader proxy URL for a workspace.

    Lists the workspace's apps through the Leonardo ``apps/v2`` endpoint of
    the configured ``ENV`` and picks the first ``WORKFLOWS_APP`` entry that
    exposes a ``cromwell-reader`` proxy URL.

    Args:
        workspace_id: Workspace identifier interpolated into the Leonardo URL.
        token: Bearer token sent in the ``Authorization`` header.

    Returns:
        The cromwell-reader URL when a matching app is found; ``None`` (after
        printing a notice) when no app matches; or the raw response body when
        Leonardo answers with a status other than 200.
    """
    uri = f"https://leonardo.dsde-{ENV}.broadinstitute.org/api/apps/v2/{workspace_id}?includeDeleted=false"

    headers = {"Authorization": "Bearer " + token,
               "accept": "application/json"}

    response = requests.get(uri, headers=headers)

    if response.status_code != 200:
        # Kept for backward compatibility: callers receive the error body as-is.
        # NOTE(review): a caller cannot tell this apart from a URL by type alone.
        return response.text

    for app in json.loads(response.text):
        # Tolerate apps that have no proxyUrls mapping (missing or null)
        # instead of raising KeyError/TypeError on them.
        proxy_urls = app.get('proxyUrls') or {}
        cromwell_url = proxy_urls.get('cromwell-reader')
        if app.get('appType') == 'WORKFLOWS_APP' and cromwell_url is not None:
            return cromwell_url

    # Reached only when no app matched; previously this path raised
    # UnboundLocalError because the result variable was never initialised.
    print("Cromwell is missing in current workspace")
    return None

def get_cromwell_workflow(cromwell_url, workflow_id, token):
    """Fetch the filtered metadata of one Cromwell workflow.

    Only the keys listed as ``includeKey`` query parameters are requested.

    Args:
        cromwell_url: Base URL of the Cromwell service.
        workflow_id: Id of the workflow (or sub-workflow) to look up.
        token: Bearer token sent in the ``Authorization`` header.

    Returns:
        The decoded JSON body on HTTP 200, otherwise the raw response text.
    """
    uri = f"{cromwell_url}/api/workflows/v1/{workflow_id}/metadata?includeKey=end&includeKey=status&includeKey=backendStatus&includeKey=executionStatus&includeKey=subWorkflowId&includeKey=workflowName"

    auth_headers = {
        "Authorization": "Bearer " + token,
        "accept": "application/json",
    }
    reply = requests.get(uri, headers=auth_headers)

    # Any status other than 200 is handed back as plain text, not parsed JSON.
    return reply.json() if reply.status_code == 200 else reply.text

def clean_task_name(task_name, workflow_name="ImputationBeagle"):
    """Return ``task_name`` with every ``"<workflow_name>."`` occurrence removed."""
    qualifier = workflow_name + "."
    return task_name.replace(qualifier, "")

def print_in_red_bold(text):
    """Print ``text`` between the ``1;31`` ANSI escape sequence and the reset sequence."""
    bold_red, reset = '\x1b[1;31m', '\x1b[0m'
    print(bold_red + text + reset)

def print_in_grey(text):
    """Print ``text`` between the ``0;37`` ANSI escape sequence and the reset sequence."""
    grey, reset = '\x1b[0;37m', '\x1b[0m'
    print(grey + text + reset)

def _print_subworkflow_calls(subworkflow_response):
    """Print one summary line (plus failure details) per call of a sub-workflow."""
    # A sub-workflow without a 'calls' entry simply prints nothing.
    for subtask, subinfo in subworkflow_response.get('calls', {}).items():
        n_total = len(subinfo)
        # the following doesn't deal with retries smartly
        n_succeeded = sum(1 for call in subinfo if call['executionStatus'] == 'Done')

        # .get() so a failed entry without a backendStatus is reported
        # (as "None") instead of aborting the whole report with KeyError.
        failed = [
            f"shard {call['shardIndex']} failed with status {call.get('backendStatus')}"
            for call in subinfo
            if call['executionStatus'] == 'Failed'
        ]
        n_failed = len(failed)

        name = clean_task_name(subtask)
        if n_failed > 0:
            print_in_red_bold(f"     {name} : {n_succeeded}/{n_total} total shards succeeded, {n_failed} failed. error messages:")
            for message in failed:
                print_in_red_bold(f"       {message}")
        elif n_succeeded == n_total:
            print_in_grey(f"     {name} : Succeeded")
        else:
            print(f"     {name} : In progress; {n_succeeded}/{n_total} shards succeeded so far")


def _print_scatter_shards(cromwell_url, shard_calls, token):
    """Print the status of every shard of a scatter call, drilling into sub-workflows."""
    for shard_call in shard_calls:
        shard = shard_call['shardIndex']

        if 'subWorkflowId' not in shard_call:
            print(f"  Shard {shard} has no workflowId yet")
            continue

        subworkflow_response = get_cromwell_workflow(
            cromwell_url, shard_call['subWorkflowId'], token
        )
        if not isinstance(subworkflow_response, dict):
            # get_cromwell_workflow returns the raw error text on a non-200 reply.
            print_in_red_bold(f"  Shard {shard} metadata request failed: {subworkflow_response}")
            continue

        subworkflow_status = subworkflow_response['status']
        shard_status_msg = f"  Shard {shard} status: {subworkflow_status}"

        if subworkflow_status == "Succeeded":
            print_in_grey(shard_status_msg)
        else:
            print(shard_status_msg)
            _print_subworkflow_calls(subworkflow_response)


def display_task_statuses(cromwell_url, workflow_id, token):
    """Print a colour-coded progress report for a workflow and its scatter shards.

    For every call in the workflow metadata a line with the number of
    completed entries is printed: bold red when any entry failed, grey when
    all entries are complete, plain otherwise. Calls whose name contains
    ``"ScatterAt"`` and that are failed or unfinished are expanded shard by
    shard, fetching each shard's sub-workflow metadata.

    Args:
        cromwell_url: Base URL of the Cromwell service.
        workflow_id: Id of the top-level workflow to report on.
        token: Bearer token passed through to ``get_cromwell_workflow``.

    Returns:
        None. All output goes to stdout.
    """
    response = get_cromwell_workflow(cromwell_url, workflow_id, token)
    if not isinstance(response, dict):
        # A non-200 reply comes back as raw text; report it instead of
        # failing with TypeError when indexing a string.
        print_in_red_bold(f"Could not fetch metadata for workflow {workflow_id}: {response}")
        return

    print(f"status: {response['status']}")

    # Metadata without a 'calls' entry is treated as having no calls to report.
    for task, info in response.get('calls', {}).items():
        n_total = len(info)
        # note the following doesn't deal with retries smartly
        n_failed = sum(1 for call in info if call['executionStatus'] == 'Failed')
        failure_msg = "" if n_failed == 0 else f"- {n_failed} failures"
        n_complete = sum(1 for call in info if 'end' in call)

        line_to_print = f"{clean_task_name(task)} : {n_complete}/{n_total} complete tasks {failure_msg}"
        if failure_msg:
            print_in_red_bold(line_to_print)
        elif n_complete == n_total:
            print_in_grey(line_to_print)
        else:
            print(line_to_print)

        # Only failed or unfinished scatter calls are expanded per shard.
        if "ScatterAt" in task and (n_failed > 0 or n_complete < n_total):
            _print_scatter_shards(cromwell_url, info, token)

In [None]:
# Placeholder: replace with the id of the workspace to inspect; it is passed
# to get_cromwell_url, which interpolates it into the Leonardo URL.
workspace_id = "YOUR WORKSPACE ID HERE"

In [None]:
# Obtain the bearer token that is reused for the Leonardo and Cromwell requests below.
token = get_access_token()

In [None]:
# Look up the workspace's cromwell-reader URL. The result is None when no
# matching app is listed, and the raw error text when Leonardo does not
# answer with HTTP 200 — check it before running the next cell.
cromwell_url = get_cromwell_url(workspace_id, token)

In [None]:
# Placeholder: replace with the id of the workflow to report on.
workflow_id = "YOUR WORKFLOW ID HERE" 

# Print the per-task (and, for scatter calls, per-shard) status report.
display_task_statuses(cromwell_url, workflow_id, token)