# One-time prerequisites
Install the Azure Blob Storage, Google Auth, and tqdm packages: uncomment the `%pip install` lines in the next cell and run it once.

In [None]:
# %pip install azure-storage-blob==12.20.0
# %pip install google-auth
# %pip install tqdm

In [None]:
from azure.storage.blob import BlobClient

import google.auth
from google.auth.transport.requests import Request as GoogleAuthRequest

import json
import os
import requests
import time
import uuid

In [None]:
# Base URL of the TSPS service. Defaults to the dev deployment; set the
# TSPS_URL environment variable to point this notebook at another deployment.
TSPS_URL = os.environ.get("TSPS_URL", "https://tsps.dsde-dev.broadinstitute.org")
# Version of the imputation_beagle pipeline this notebook targets.
IMPUTATION_BEAGLE_VERSION = "0.0.1"

def get_access_token():
    """Return a current OAuth2 access token from Application Default Credentials.

    Looks up the ambient credentials with ``google.auth.default()`` and
    refreshes them, so that ``.token`` holds a current bearer token string.
    """
    creds, _project_id = google.auth.default()
    creds.refresh(GoogleAuthRequest())
    return creds.token

def prepare_imputation_pipeline(multi_sample_vcf_path, output_basename, token):
    """Prepare a new imputation_beagle run on the TSPS service.

    A fresh random job id is generated client-side and sent together with the
    pipeline inputs. The service echoes the job id and returns, per input,
    upload details (this notebook reads the ``sasUrl`` and ``azcopyCommand``
    fields of each entry).

    Args:
        multi_sample_vcf_path: Path of the multi-sample VCF, sent as the
            ``multiSampleVcf`` pipeline input.
        output_basename: Value sent as the ``outputBasename`` pipeline input.
        token: Bearer access token for the TSPS API.

    Returns:
        ``(job_id, file_input_upload_urls)``: the job id from the response and
        its ``fileInputUploadUrls`` mapping.

    Raises:
        Exception: if the service does not answer with HTTP 200; the message
            is the raw response body.
    """
    request_body = {
        "jobId": str(uuid.uuid4()),
        # Send the pinned pipeline version rather than a placeholder value.
        "pipelineVersion": IMPUTATION_BEAGLE_VERSION,
        "pipelineInputs": {
            "multiSampleVcf": multi_sample_vcf_path,
            "outputBasename": output_basename
        }
    }

    uri = f"{TSPS_URL}/api/pipelineruns/v1/imputation_beagle/prepare"
    headers = {
        "Authorization": f"Bearer {token}",
        "accept": "application/json",
        "Content-Type": "application/json"
    }

    response = requests.post(uri, json=request_body, headers=headers)

    if response.status_code != 200:
        raise Exception(response.text)

    result = response.json()
    job_id = result['jobId']

    print(f"Successfully prepared imputation pipeline run with job_id {job_id}")

    return job_id, result['fileInputUploadUrls']

# run imputation beagle pipeline
def start_imputation_pipeline(job_id, description, token):
    """Start an already-prepared imputation_beagle run on the TSPS service.

    Args:
        job_id: Job id returned by ``prepare_imputation_pipeline``.
        description: Description sent with the start request.
        token: Bearer access token for the TSPS API.

    Raises:
        Exception: if the service does not answer with HTTP 202 (Accepted);
            the message is the raw response body.
    """
    endpoint = f"{TSPS_URL}/api/pipelineruns/v1/imputation_beagle/start"
    payload = {"description": description, "jobControl": {"id": job_id}}
    request_headers = {
        "Authorization": f"Bearer {token}",
        "accept": "application/json",
        "Content-Type": "application/json",
    }

    reply = requests.post(endpoint, json=payload, headers=request_headers)
    if reply.status_code == 202:
        print(f"Successfully started imputation pipeline run for job_id {job_id}")
        return
    raise Exception(reply.text)


# poll for imputation beagle job; if successful, return the pipelineOutput object (dict)
def check_imputation_job_status(job_id, token):
    """Poll the TSPS service once for the result of an imputation_beagle job.

    Args:
        job_id: Job id returned by ``prepare_imputation_pipeline``.
        token: Bearer access token for the TSPS API.

    Returns:
        A ``(status, payload)`` tuple, where ``status`` is
        ``jobReport.status`` from the response and ``payload`` is:
          * the ``pipelineOutput`` dict when the status is ``SUCCEEDED``;
          * the ``errorReport`` (``None`` if absent) for any other status
            returned with HTTP 200;
          * ``None`` while the service answers HTTP 202 (still running).

    Raises:
        Exception: for any HTTP status other than 200 or 202; the message
            includes the raw response body.
    """
    uri = f"{TSPS_URL}/api/pipelineruns/v1/imputation_beagle/result/{job_id}"
    headers = {
        "Authorization": f"Bearer {token}",
        "accept": "application/json",
        "Content-Type": "application/json"
    }

    http_response = requests.get(uri, headers=headers)
    status_code = http_response.status_code

    # Check the status before parsing: error bodies may not be JSON, and the
    # raw text is what goes into the error message.
    if status_code not in (200, 202):
        raise Exception(f'pipeline failed with a {status_code} status code. has response {http_response.text}')

    result = http_response.json()
    job_status = result['jobReport']['status']

    if status_code == 202:
        print("tsps pipeline still running")
        return job_status, None

    # HTTP 200: the job is completed; inspect its final status.
    if job_status == 'SUCCEEDED':
        print(f"pipeline has succeeded: {result}")
        return job_status, result['pipelineOutput']
    return job_status, result.get('errorReport')


def sizeof_fmt(num, suffix="B"):
    """Render a byte count as a human-readable string with binary (IEC) prefixes.

    The value is divided by 1024 until its magnitude drops below 1024, then
    printed with one decimal place, e.g. ``1536 -> "1.5KiB"``. Values too large
    for the ``Zi`` prefix are reported in ``Yi``.

    Args:
        num: Number of units (bytes by default); may be negative.
        suffix: Unit symbol appended after the prefix.
    """
    scaled = num
    for prefix in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if abs(scaled) < 1024.0:
            break
        scaled /= 1024.0
    else:
        # Ran out of prefixes without the value dropping below 1024.
        return f"{scaled:.1f}Yi{suffix}"
    return f"{scaled:3.1f}{prefix}{suffix}"
    

def upload_file_with_azcopy(local_file_path, write_sas_url):
    """Upload a local file to a blob through a write-enabled SAS URL, printing progress.

    Despite the name, this uses the Azure Storage Blob Python SDK
    (``BlobClient.upload_blob``) rather than the ``azcopy`` command-line tool.

    Args:
        local_file_path: Path of the local file to upload.
        write_sas_url: Full blob URL including a SAS token that permits writes.
    """
    blob_client = BlobClient.from_blob_url(write_sas_url)
    file_size_bytes = os.path.getsize(local_file_path)
    file_size_human_readable = sizeof_fmt(file_size_bytes)

    print(f"uploading file from {local_file_path}, file size: {file_size_human_readable} \n")

    start = time.time()

    def upload_progress_report(response):
        # Invoked via raw_response_hook with each raw HTTP response of the upload.
        current = response.context.get('upload_stream_current')  # there's also a 'download_stream_current'
        total = response.context.get('data_stream_total')

        # No running byte count on this response: report it as fully uploaded.
        if current is None:
            current = total

        # Nothing meaningful to report, and avoid dividing by zero, when the
        # total size is unknown or the file is empty.
        if not total:
            return

        percent_done = round(100 * current / total, 1)
        duration_m = round((time.time() - start) / 60, 1)
        # '\r' rewrites the same output line on each update.
        print(f"uploaded {sizeof_fmt(current)} out of {sizeof_fmt(total)} total ({percent_done}%, {duration_m} min elapsed) \t\t\t", end='\r')

    # upload the file
    with open(file=local_file_path, mode="rb") as blob_file:
        blob_client.upload_blob(blob_file, max_concurrency=8, raw_response_hook=upload_progress_report)


def download_with_azcopy(read_sas_url, local_file_path=None):
    """Download a blob through a read-enabled SAS URL to a local file, printing progress.

    Despite the name, this uses the Azure Storage Blob Python SDK
    (``BlobClient.download_blob``) rather than the ``azcopy`` command-line tool.

    Args:
        read_sas_url: Full blob URL including a SAS token that permits reads.
        local_file_path: Destination path. When ``None``, the last path
            segment of the URL (with the SAS query string removed) is used,
            relative to the current working directory.
    """
    blob_client = BlobClient.from_blob_url(read_sas_url)

    if local_file_path is None:
        # extract the file name from the sas url: drop the query string, keep the last path segment
        local_file_path = read_sas_url.split("?")[0].split("/")[-1]

    start = time.time()

    def download_progress_report(response):
        # Invoked via raw_response_hook with each raw HTTP response of the download.
        current = response.context.get('download_stream_current')
        total = response.context.get('data_stream_total')
        # Skip responses without a running byte count, and avoid dividing by
        # zero when the blob is empty or its size is unknown.
        if current is not None and total:
            percent_done = round(100 * current / total, 1)
            duration_m = round((time.time() - start) / 60, 1)
            print(f"downloaded {sizeof_fmt(current)} out of {sizeof_fmt(total)} total ({percent_done}%, {duration_m} min elapsed) \t\t\t", end='\r')

    print(f"downloading file to {local_file_path} \n")

    # Stream the blob straight into the file instead of buffering the whole
    # blob in memory with readall(). With max_concurrency > 1, readinto()
    # needs a seekable target, which a regular file opened "wb" is.
    with open(file=local_file_path, mode="wb") as blob_file:
        download_stream = blob_client.download_blob(max_concurrency=8, raw_response_hook=download_progress_report)
        download_stream.readinto(blob_file)


## Prepare your imputation run

In [None]:
# Run configuration. Each key of local_input_files is a pipeline input name and
# each value is the local file uploaded for that input.
description = "notebook run to dev 5 from indiana - added maxRetries"
output_basename = "palantir_merged_samples"
local_input_files = dict(
    multiSampleVcf="palantir_merged_input_samples.liftedover.vcf.gz",
)

In [None]:
# Register the run with TSPS; the reply carries the job id and per-input upload URLs.
multi_sample_vcf_path = local_input_files["multiSampleVcf"]
job_id, file_input_upload_urls = prepare_imputation_pipeline(
    multi_sample_vcf_path=multi_sample_vcf_path,
    output_basename=output_basename,
    token=get_access_token(),
)

## Choose one of the following two methods to upload your data:

In [None]:
# How to upload the input files:
#   "notebook"     - upload directly from this Jupyter notebook
#   "command line" - print azcopy commands to run manually from your computer's command line
upload_method = "notebook"

In [None]:
# For each input the service asked for, either upload it now or print the
# azcopy command that uploads it.
for input_key, upload_info in file_input_upload_urls.items():
    if upload_method == "command line":
        print(f"command to upload {input_key}:\n")
        print(upload_info['azcopyCommand'] + "\n")
    elif upload_method == "notebook":
        upload_file_with_azcopy(local_input_files[input_key], upload_info["sasUrl"])

## Start your prepared imputation run

In [None]:
# Kick off the prepared run (do this only after all inputs have been uploaded).
start_imputation_pipeline(job_id=job_id, description=description, token=get_access_token())

## Check the pipeline run status
Re-run the cell below to poll again until the run has finished.

In [None]:
# Poll once: `response` holds the job status string, `output` its payload.
response, output = check_imputation_job_status(job_id, get_access_token())
for value in (response, output):
    print(value)

## Once the pipeline run has status SUCCEEDED, retrieve your outputs

In [None]:
# How to retrieve the outputs:
#   "notebook"     - download directly into this notebook's working directory
#   "command line" - print azcopy commands to run manually from your computer's command line
download_method = "notebook"

In [None]:
# Only a SUCCEEDED run has a pipelineOutput mapping of output name -> read SAS URL.
# For other statuses `output` is None or an error report and must not be iterated.
if response != "SUCCEEDED":
    raise RuntimeError(
        f"pipeline run {job_id} has status {response}; outputs can only be retrieved "
        f"once it has SUCCEEDED. details: {output}"
    )

for output_file_key, read_sas_url in output.items():

    if download_method == "notebook":
        download_with_azcopy(read_sas_url)

    elif download_method == "command line":
        print(f"command to download {output_file_key}:\n")
        # Quote the SAS URL: its query string contains '&', which the shell
        # would otherwise treat as a command separator.
        print(f'azcopy copy "{read_sas_url}" . \n')