In [None]:
import os
from datetime import datetime
from xml.etree import ElementTree as ET

import numpy as np
import pandas as pd


In [None]:
## Pick out illustrative example trials for time-duration prediction
# Root directory of the downloaded trial data. Set the CLINICAL_TRIAL_PATH
# environment variable to run the notebook on another machine; otherwise the
# original author's local path is used.
trial_path = os.environ.get(
    "CLINICAL_TRIAL_PATH",
    "/Users/leo/Documents/code/clinical-trial/original/Downloads/",
)


In [None]:
# Index of trial XML files: one path per line, relative to trial_path.
input_file = f"{trial_path}/trials/all_xml.txt"
with open(input_file, 'r', encoding='utf-8') as fin:
    # Skip blank lines (e.g. a trailing newline): an empty entry would later be
    # joined into the bare trial_path directory and make ET.parse fail.
    input_file_lst = [line.strip() for line in fin if line.strip()]

LOGGER.log_with_depth(input_file_lst[:10])

In [None]:
# IQVIA trial outcome labels; the first column is used as the NCT id and the
# second as the outcome label (presumably the CSV's column order — confirm
# against the file header).
iqvia_outcomes = pd.read_csv("/Users/leo/Documents/code/clinical-trial/Clinical-Trials-Time-Prediction/data/IQVIA/trial_outcomes_v1.csv")

LOGGER.log_with_depth(iqvia_outcomes)

# Select columns by position with .iloc: plain row[0] on a label-indexed Series
# relies on pandas' deprecated positional fallback. Building the dict from two
# columns also avoids the slow row-by-row iterrows(). When an NCT id appears
# more than once, the last row wins.
nctid2outcome = dict(zip(iqvia_outcomes.iloc[:, 0], iqvia_outcomes.iloc[:, 1]))

In [None]:
def parse_date(date_str, formats=("%B %d, %Y", "%B %Y")):
    """Parse a trial date string into a ``datetime``.

    By default accepts full dates such as "January 15, 2005" ("%B %d, %Y") and
    month-only dates such as "January 2005" ("%B %Y"); a month-only date
    resolves to the first day of that month. Surrounding whitespace is ignored.

    Args:
        date_str: the date text to parse.
        formats: ``strptime`` format strings, tried in order; the first match wins.

    Returns:
        The parsed ``datetime``.

    Raises:
        ValueError / TypeError: the error from the last format attempted, after
        it has been logged, when no format matches (TypeError if ``date_str``
        is not a string).
    """
    text = date_str.strip() if isinstance(date_str, str) else date_str
    last_error = ValueError("no date formats to try")
    for fmt in formats:
        try:
            return datetime.strptime(text, fmt)
        except (ValueError, TypeError) as exc:
            last_error = exc
    LOGGER.log_with_depth(last_error)
    raise last_error

def calculate_duration(start_date, completion_date):
    """Return the whole number of days from ``start_date`` to ``completion_date``.

    Both arguments are date strings understood by ``parse_date``. When either
    one is empty/falsy the duration is unknown and the sentinel ``-1`` is
    returned. Note that a completion date exactly one day before the start
    also yields ``-1``, which is indistinguishable from the sentinel.
    """
    if not (start_date and completion_date):
        return -1

    begin = parse_date(start_date)
    end = parse_date(completion_date)
    return (end - begin).days

def xml_file_2_tuple(xml_file):
    """Extract the fields used for duration prediction from one trial XML file.

    Args:
        xml_file: path or file-like object accepted by ``ET.parse``.

    Returns:
        ``("non-Interventional",)`` if ``<study_type>`` is missing or not
        "Interventional".
        ``("Biological",)`` if no ``<intervention>`` has type "Drug" (the label
        is historical: any trial without a drug intervention lands here).
        Otherwise the 9-tuple
        ``(nctid, status, why_stop, phase, conditions, criteria, drugs, duration, outcome)``:
        ``conditions`` and ``drugs`` are lower-cased lists, missing text fields
        are ``''``, ``duration`` is in days (``-1`` when a date is missing), and
        ``outcome`` is looked up in the module-level ``nctid2outcome`` (``-1``
        when the trial is not listed).
    """
    root = ET.parse(xml_file).getroot()
    nctid = root.findtext('id_info/nct_id')  # e.g. 'NCT00000102'

    # findtext() returns None for a missing element, so a file without
    # <study_type> is classified instead of raising AttributeError.
    if root.findtext('study_type') != 'Interventional':
        return ("non-Interventional",)

    # Keep only drug interventions (biologicals are deliberately excluded).
    drug_interventions = [
        intervention.findtext('intervention_name', default='')
        for intervention in root.findall('intervention')
        if intervention.findtext('intervention_type') == 'Drug'
    ]
    if not drug_interventions:
        return ("Biological",)

    # Optional single-valued fields: '' when absent or empty.
    status = root.findtext('overall_status', default='')
    why_stop = root.findtext('why_stopped', default='')
    phase = root.findtext('phase', default='')
    criteria = root.findtext('eligibility/criteria/textblock', default='')

    # A <condition> element without text would otherwise break .lower().
    conditions = [(condition.text or '').lower() for condition in root.findall('condition')]
    drugs = [name.lower() for name in drug_interventions]

    start_date = root.findtext('start_date', default='')
    # Prefer the primary completion date; fall back to the overall completion
    # date when the primary one is missing or empty.
    completion_date = (root.findtext('primary_completion_date')
                       or root.findtext('completion_date', default=''))

    # calculate_duration returns -1 itself when either date is empty.
    duration = calculate_duration(start_date, completion_date)
    outcome = nctid2outcome.get(nctid, -1)

    return nctid, status, why_stop, phase, conditions, criteria, drugs, duration, outcome


In [None]:
# Draw up to N_SAMPLES distinct trials reproducibly (seeded, without
# replacement so no trial is processed or logged twice) and log the ones whose
# IQVIA outcome equals TARGET_OUTCOME.
N_SAMPLES = 10000
SEED = 0
TARGET_OUTCOME = "Terminated, Poor enrollment"

rng = np.random.default_rng(SEED)
n_draw = min(N_SAMPLES, len(input_file_lst))
for random_idx in rng.choice(len(input_file_lst), size=n_draw, replace=False):
    result = xml_file_2_tuple(f"{trial_path}/{input_file_lst[random_idx]}")

    # 1-tuples mark skipped trials (non-interventional / no drug intervention).
    if len(result) == 1:
        continue

    nctid, status, why_stop, phase, conditions, criteria, drugs, duration, outcome = result

    if outcome != TARGET_OUTCOME:
        continue

    for label, value in (
        ("conditions", conditions),
        ("drugs", drugs),
        ("criteria", criteria),
        ("duration", duration),
        ("outcome", outcome),
        ("nctid", nctid),
        ("phase", phase),
        ("status", status),
        ("why_stop", why_stop),
    ):
        LOGGER.log_with_depth(f"{label}: {value}")
    LOGGER.log_with_depth("=====================================\n")
