In [None]:
import json
import os
import time
from google.cloud import storage
from google.cloud.storage.blob import Blob
import pandas as pd
import optuna
import firecloud.api as fapi
import pickle
from io import StringIO

# Google Cloud and Terra setup.
# These module-level names are read by the functions below; edit them to point
# the notebook at a different project, bucket, workspace or method configuration.
PROJECT_ID = '1091079109155'  # GCP project passed to the Cloud Storage client
BUCKET_NAME = 'fc-0c540a8a-11ec-4cf7-b7ff-39de06b1bca3'  # destination bucket for save_trial_outputs (presumably the workspace bucket — confirm)
billing_project = 'broad-firecloud-dsde-methods'  # workspace namespace passed first to every fapi call
workspace = 'malaria-filtering-optimization-staging_monica'  # Terra workspace name
cnamespace = 'malaria-filtering-optimization-staging'  # method configuration namespace
configname = 'FilterAndEvaluate'  # method configuration that is updated and submitted for each trial

# Shared Cloud Storage client used for every bucket read/write in this module.
storage_client = storage.Client(project=PROJECT_ID)


def get_hyperparameters_path_from_config():
    """Return the hyperparameters JSON path currently set in the method configuration.

    Fetches the ``configname`` method configuration from the Terra workspace and
    returns its ``hyperparameters_json`` input with the surrounding double quotes
    removed (the value is stored as a quoted string literal, see
    ``submit_trial_workflow``).

    Returns:
        str: The configured path, expected to be a ``gs://bucket/object`` URI.

    Raises:
        Exception: If the workspace configuration cannot be fetched.
        KeyError: If the configuration has no ``hyperparameters_json`` input.
    """
    response = fapi.get_workspace_config(billing_project, workspace, cnamespace, configname)
    if response.status_code != 200:
        # Fail with the server's message instead of a KeyError on an error payload.
        raise Exception(f"Failed to get workspace config: {response.text}")
    config = response.json()
    hyperparameters_json_path = config['inputs']['FilterAndEvaluate.JointVcfFiltering.hyperparameters_json']
    return hyperparameters_json_path.strip('"')  # Remove the literal's quotes

def update_and_upload_hyperparameters(trial, config_path):
    """Sample hyperparameters for *trial* and upload them as a trial-specific JSON.

    The JSON at *config_path* must contain a ``search_space`` mapping of
    parameter name -> spec. Each spec has a ``type`` of ``'int'``, ``'float'``
    or ``'categorical'``:

    * ``int`` / ``float`` need ``low`` and ``high`` and may optionally set
      ``step`` and ``log`` (forwarded to Optuna's ``suggest_int`` /
      ``suggest_float``);
    * ``categorical`` needs ``options``.

    The sampled values are stored under ``updated_values`` and the whole
    document is written next to the source file as
    ``<base>_trial_<trial.number>.json``.

    Args:
        trial: Optuna trial supplying ``number`` and the ``suggest_*`` methods.
            ``run_optimization`` also passes a completed (frozen) trial here to
            regenerate the best trial's file.
        config_path: ``gs://bucket/path/file.json`` URI of the source JSON. A
            ``_trial_<n>`` suffix left by an earlier trial is stripped when
            deriving the base file name.

    Returns:
        str: ``gs://`` URI of the uploaded trial-specific JSON.

    Raises:
        ValueError: If *config_path* has no bucket/object part, or a
            search-space entry has an unsupported ``type``.
        FileNotFoundError: If the source JSON does not exist.
    """
    import posixpath  # GCS object names always use '/', independent of the local OS

    # Split "gs://bucket/dir/file.json" into bucket and object name.
    uri = config_path[len("gs://"):] if config_path.startswith("gs://") else config_path
    bucket_name, _, base_path = uri.partition("/")
    if not bucket_name or not base_path:
        raise ValueError(f"Expected a gs://bucket/object path, got '{config_path}'")
    directory_path, json_filename = posixpath.split(base_path)

    # Derive the base model name: drop any '_trial_<n>' suffix from a previous
    # trial, then the extension (e.g. 'model_trial_3.json' -> 'model').
    base_model_name = json_filename.split('_trial_')[0]
    base_model_name = posixpath.splitext(base_model_name)[0]

    trial_specific_filename = f"{base_model_name}_trial_{trial.number}.json"
    trial_specific_path = posixpath.join(directory_path, trial_specific_filename)

    bucket = storage_client.bucket(bucket_name)

    # Load the source hyperparameters JSON.
    original_blob = bucket.blob(base_path)
    if not original_blob.exists():
        raise FileNotFoundError(f"The hyperparameters file does not exist at the path '{config_path}'")
    config = json.loads(original_blob.download_as_text())

    # Sample one value per search-space entry.
    new_updated_values = {}
    for param, specs in config.get('search_space', {}).items():
        param_type = specs['type']
        if param_type == 'int':
            new_updated_values[param] = trial.suggest_int(
                param, specs['low'], specs['high'],
                step=specs.get('step', 1), log=specs.get('log', False))
        elif param_type == 'float':
            new_updated_values[param] = trial.suggest_float(
                param, specs['low'], specs['high'],
                step=specs.get('step'), log=specs.get('log', False))
        elif param_type == 'categorical':
            new_updated_values[param] = trial.suggest_categorical(param, specs['options'])
        else:
            # Skipping silently would run the workflow with a partial parameter set.
            raise ValueError(f"Unsupported search-space type '{param_type}' for parameter '{param}'")

    config['updated_values'] = new_updated_values

    # Save the updated JSON alongside the source file.
    updated_blob = bucket.blob(trial_specific_path)
    updated_blob.upload_from_string(json.dumps(config), content_type='application/json')

    return f"gs://{bucket_name}/{trial_specific_path}"



def submit_trial_workflow(trial, updated_hyperparameters_path):
    """Point the method configuration at a hyperparameters JSON and submit it.

    NOTE: this updates the workspace's method configuration in place; the new
    ``hyperparameters_json`` value stays set after the submission, which is what
    ``get_hyperparameters_path_from_config`` reads back later.

    Args:
        trial: The Optuna trial being evaluated. It is not read here; the
            parameter is kept for interface compatibility.
        updated_hyperparameters_path: ``gs://`` URI of the trial's JSON.

    Returns:
        str: The Terra submission id.

    Raises:
        Exception: If fetching or updating the configuration, or creating the
            submission, fails.
    """
    config_response = fapi.get_workspace_config(billing_project, workspace, cnamespace, configname)
    if config_response.status_code != 200:
        raise Exception(f"Failed to get workspace config: {config_response.text}")
    workspace_config = config_response.json()

    # The input is stored as a string literal, hence the embedded double quotes.
    workspace_config['inputs']['FilterAndEvaluate.JointVcfFiltering.hyperparameters_json'] = f'"{updated_hyperparameters_path}"'

    update_response = fapi.update_workspace_config(billing_project, workspace, cnamespace, configname,
                                                   body=workspace_config)
    if update_response.status_code != 200:
        raise Exception(f"Failed to update workspace config: {update_response.text}")

    submission_response = fapi.create_submission(billing_project, workspace, cnamespace, configname)
    if submission_response.status_code != 201:
        raise Exception(f"Failed to submit workflow: {submission_response.text}")
    return submission_response.json()['submissionId']


def wait_for_workflow_completion(submission_id, poll_interval=60, timeout=None):
    """Block until a Terra submission finishes and return its first workflow id.

    Polls the submission every *poll_interval* seconds until its status is
    ``'Done'`` or ``'Aborted'``. Reaching ``'Done'`` does not imply the workflow
    succeeded; callers detect failures when fetching outputs.

    Args:
        submission_id: Terra submission id returned by ``fapi.create_submission``.
        poll_interval: Seconds to sleep between status checks (default 60).
        timeout: Maximum seconds to wait; ``None`` (default) waits indefinitely.

    Returns:
        str: The ``workflowId`` of the submission's first workflow.

    Raises:
        Exception: If the submission status cannot be fetched.
        TimeoutError: If *timeout* elapses before the submission finishes.
        RuntimeError: If the submission finished without a launched workflow.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        response = fapi.get_submission(billing_project, workspace, submission_id)
        if response.status_code != 200:
            raise Exception(f"Failed to get submission status: {response.text}")
        submission_status = response.json()
        status = submission_status['status']

        if status in ('Done', 'Aborted'):
            workflows = submission_status.get('workflows') or []
            # A submission aborted (or failed) before launch may have no workflow id.
            workflow_id = workflows[0].get('workflowId') if workflows else None
            if workflow_id is None:
                raise RuntimeError(
                    f"Submission {submission_id} finished with status '{status}' but has no workflow id")
            return workflow_id

        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(f"Submission {submission_id} did not finish within {timeout} seconds")
        time.sleep(poll_interval)

def fetch_process_and_extract_scores(billing_project, workspace, submission_id, workflow_id):
    """Download a workflow's metrics pickle and return its SNP ROC AUC.

    Reads the ``FilterAndEvaluate.metrics_pkl`` output of the workflow,
    unpickles it and returns ``score_data['SCORE']['snp']['ROC_AUC']``.

    Args:
        billing_project: Workspace namespace.
        workspace: Workspace name.
        submission_id: Terra submission id.
        workflow_id: Workflow id within that submission.

    Returns:
        The SNP ROC AUC value, or ``None`` if the outputs cannot be fetched, the
        metrics path is missing/malformed, or the score keys are absent.
    """
    outputs_response = fapi.get_workflow_outputs(billing_project, workspace, submission_id, workflow_id)
    if outputs_response.status_code != 200:
        print("Failed to get workflow outputs")
        return None

    try:
        outputs = outputs_response.json()['tasks']['FilterAndEvaluate']['outputs']
        metrics_file_path = outputs['FilterAndEvaluate.metrics_pkl']
        # AttributeError/TypeError: output is null; ValueError: path has no '/'.
        bucket_name, blob_path = metrics_file_path.replace('gs://', '').split('/', 1)
    except (KeyError, TypeError, AttributeError, ValueError):
        print("Metrics file path not found in workflow outputs.")
        return None

    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_path)
    # SECURITY: pickle.loads can execute arbitrary code. This is acceptable only
    # because the pickle is produced by our own workflow; never point this at
    # data from an untrusted bucket.
    score_data = pickle.loads(blob.download_as_bytes())

    try:
        snp_score_roc_auc = score_data['SCORE']['snp']['ROC_AUC']
    except (KeyError, TypeError):
        print("Required score data not found in the pickle file.")
        return None

    print("Extracted SNP SCORE ROC_AUC:", snp_score_roc_auc)
    return snp_score_roc_auc  # ROC_AUC for SNPs drives the hyperparameter search

def save_trial_outputs(trial_number, submission_id, workflow_id):
    """Copy a workflow's output files into ``trial_<trial_number>/`` in BUCKET_NAME.

    Every ``gs://`` file output (single value or list) of the
    ``FilterAndEvaluate`` task is copied from the bucket it lives in to
    ``gs://BUCKET_NAME/trial_<trial_number>/<file name>``. Non-file outputs
    (numbers, booleans, nulls, non-gs strings) are skipped.

    NOTE(review): outputs that share a file name overwrite each other in the
    destination directory.

    Args:
        trial_number: Optuna trial number used to name the destination directory.
        submission_id: Terra submission id.
        workflow_id: Workflow id within that submission.

    Raises:
        Exception: If the workflow outputs cannot be fetched.
    """
    import posixpath  # GCS object names always use '/'

    outputs_response = fapi.get_workflow_outputs(billing_project, workspace, submission_id, workflow_id)
    if outputs_response.status_code != 200:
        raise Exception("Failed to get workflow outputs")

    base_path = f"trial_{trial_number}"
    outputs = outputs_response.json()['tasks']['FilterAndEvaluate']['outputs']
    destination_bucket = storage_client.bucket(BUCKET_NAME)

    for file_paths in outputs.values():
        if not isinstance(file_paths, list):
            file_paths = [file_paths]
        for file_path in file_paths:
            if not isinstance(file_path, str) or not file_path.startswith('gs://'):
                continue  # not a file in Cloud Storage; nothing to copy
            source_bucket_name, _, source_blob_name = file_path[len('gs://'):].partition('/')
            if not source_blob_name:
                continue

            trial_specific_path = posixpath.join(base_path, posixpath.basename(source_blob_name))

            # Read from the bucket the output actually lives in, which need not
            # be BUCKET_NAME (e.g. a separate execution bucket).
            source_bucket = storage_client.bucket(source_bucket_name)
            source_blob = source_bucket.blob(source_blob_name)
            source_bucket.copy_blob(source_blob, destination_bucket, trial_specific_path)

def objective(trial):
    """Optuna objective: run one FilterAndEvaluate workflow and return its score.

    Samples and uploads the trial's hyperparameters, submits the workflow,
    waits for it, and returns the SNP ROC AUC it reports. The submission id,
    workflow id and hyperparameters path are stored as trial user attributes.

    Any error, or a missing score, is reported as ``float('-inf')`` so the
    study keeps running; ``run_optimization`` filters such trials out.

    Args:
        trial: The Optuna trial being evaluated.

    Returns:
        float: The score, or ``-inf`` on failure.
    """
    try:
        config_path = get_hyperparameters_path_from_config()
        updated_hyperparameters_path = update_and_upload_hyperparameters(trial, config_path)
        trial.set_user_attr("hyperparameters_path", updated_hyperparameters_path)

        submission_id = submit_trial_workflow(trial, updated_hyperparameters_path)
        # Record the submission before the long wait so it stays traceable even
        # if polling fails.
        trial.set_user_attr("submission_id", submission_id)

        workflow_id = wait_for_workflow_completion(submission_id)
        trial.set_user_attr("workflow_id", workflow_id)

        score = fetch_process_and_extract_scores(billing_project, workspace, submission_id, workflow_id)
    except Exception as e:
        # Top-level boundary for one trial: report and score it as a failure.
        print(f"Error during trial {trial.number}: {str(e)}")
        return float('-inf')

    return score if score is not None else float('-inf')

def run_optimization(n_trials=2):
    """Run the Optuna study, then evaluate the best trial once in the test phase.

    Args:
        n_trials: Number of Optuna trials to run (default 2).

    Errors are caught and printed rather than raised.
    """
    try:
        study = optuna.create_study(direction='maximize')
        study.optimize(objective, n_trials=n_trials)

        # Only COMPLETE trials with a real score count. FAILED trials (e.g. a NaN
        # score) have value None, which would make max() raise TypeError, and
        # objective() reports errors as -inf. The '>' check also rejects NaN.
        successful_trials = [
            t for t in study.trials
            if t.state == optuna.trial.TrialState.COMPLETE
            and t.value is not None
            and t.value > float('-inf')
        ]
        best_trial = max(successful_trials, key=lambda t: t.value, default=None)

        if best_trial is None:
            print("No successful trials were completed.")
            return

        print(f"Best successful trial: {best_trial.number}")
        print(f"Best value: {best_trial.value}")
        for key, value in best_trial.params.items():
            print(f"{key}: {value}")

        # Regenerate (and re-upload) the best trial's hyperparameters JSON; the
        # frozen trial's suggest_* calls return its stored parameter values.
        best_hyperparameters_path = update_and_upload_hyperparameters(best_trial, get_hyperparameters_path_from_config())

        # Execute the test phase only once with these best hyperparameters.
        test_scores = test_phase(best_hyperparameters_path)

        # Compare against None: a legitimate score of 0.0 is falsy.
        if test_scores is not None:
            print("Test scores have been successfully extracted and processed.")
        else:
            print("Unable to extract or process test scores.")
    except Exception as e:
        print(f"An error occurred during the optimization or test phase: {e}")


def test_phase(updated_hyperparameters_path,
               score_extra_args="--ignore-all-filters -L Pf3D7_03_v3 "
                                "--resource-matching-strategy START_POSITION_AND_MINIMAL_REPRESENTATION"):
    """Submit one test run with the given hyperparameters and return its score.

    Sets the method configuration's ``hyperparameters_json`` and
    ``score_extra_args`` inputs, submits the workflow, waits for it, and
    extracts the SNP ROC AUC.

    NOTE(review): both inputs remain set in the workspace's method
    configuration afterwards. ``submit_trial_workflow`` only overwrites
    ``hyperparameters_json``, so later optimization runs will inherit these
    test-only ``score_extra_args`` unless the input is reset.

    Args:
        updated_hyperparameters_path: ``gs://`` URI of the hyperparameters JSON.
        score_extra_args: Extra arguments for the scoring step (unquoted; the
            default is the original test configuration).

    Returns:
        The extracted score, or ``None`` if it could not be extracted.

    Raises:
        Exception: If fetching or updating the configuration, or creating the
            submission, fails.
    """
    config_response = fapi.get_workspace_config(billing_project, workspace, cnamespace, configname)
    if config_response.status_code != 200:
        raise Exception(f"Failed to get workspace config for testing: {config_response.text}")
    workspace_config = config_response.json()

    # Inputs are stored as string literals, hence the embedded double quotes.
    inputs = workspace_config['inputs']
    inputs['FilterAndEvaluate.JointVcfFiltering.hyperparameters_json'] = f'"{updated_hyperparameters_path}"'
    inputs['FilterAndEvaluate.JointVcfFiltering.score_extra_args'] = f'"{score_extra_args}"'

    update_response = fapi.update_workspace_config(billing_project, workspace, cnamespace, configname, workspace_config)
    if update_response.status_code != 200:
        raise Exception(f"Failed to update workspace config for testing: {update_response.text}")

    submission_response = fapi.create_submission(billing_project, workspace, cnamespace, configname)
    if submission_response.status_code != 201:
        raise Exception(f"Failed to submit test workflow: {submission_response.text}")

    submission_id = submission_response.json()['submissionId']
    print(f"Test workflow submitted successfully. Submission ID: {submission_id}")

    workflow_id = wait_for_workflow_completion(submission_id)
    print(f"Test workflow completed. Workflow ID: {workflow_id}")

    test_scores = fetch_process_and_extract_scores(billing_project, workspace, submission_id, workflow_id)
    # Compare against None: a legitimate score of 0.0 is falsy.
    if test_scores is not None:
        print(f"Extracted test phase scores: {test_scores}")
    else:
        print("Failed to extract scores from the test phase.")
    return test_scores


if __name__ == '__main__':
    # Entry point: run the hyperparameter search followed by the single test-phase run.
    run_optimization()


[I 2024-05-07 05:43:27,045] A new study created in memory with name: no-name-bc859036-861f-4e7e-a6d7-972d441005ba
[I 2024-05-07 05:56:58,154] Trial 0 finished with value: 0.9798681333060042 and parameters: {'n_estimators': 73, 'max_features': 4, 'contamination': 0.0764158410101489}. Best is trial 0 with value: 0.9798681333060042.


Extracted SNP SCORE ROC_AUC: 0.9798681333060042
