In [13]:
import json
import os
import re

import pandas as pd
from tqdm import tqdm
import requests

In [14]:
# Base URL of the ClinicalTrials.gov v2 "studies" endpoint; get_protocol
# appends an NCT id to it to request a single study.
CTG_API = "https://clinicaltrials.gov/api/v2/studies/"

In [15]:
# Directory that get_protocol searches for locally stored `<NCT id>.json`
# study files before falling back to the API.
# The default is the original machine-specific location; set the
# CTG_STUDIES_DIR environment variable to point somewhere else without
# editing the notebook.
ctg_folder = os.environ.get(
    'CTG_STUDIES_DIR',
    '/Users/xx/Documents/Repositories/biogen-capstone/raw_data/ctg-studies.json',
)

In [16]:
# Pick up every CSV in the working directory whose name starts with "phase"
# (the per-phase train/valid/test splits), in the order the OS lists them.
pre_files = list(filter(
    re.compile(r'^phase.+\.csv$').match,
    os.listdir(),
))
pre_files

['phase_I_test.csv',
 'phase_I_valid.csv',
 'phase_III_train.csv',
 'phase_II_valid.csv',
 'phase_III_test.csv',
 'phase_II_train.csv',
 'phase_III_valid.csv',
 'phase_II_test.csv',
 'phase_I_train.csv']

In [17]:
# Name of the trial-identifier column: it is the index of the input CSVs and
# the key under which each protocol record stores its NCT id.
id_col = 'nctid'

In [22]:
def get_protocol(id: str) -> dict[str, str]:
    """Build the protocol record (primary measures and design) for one trial.

    The study JSON is read from ``<ctg_folder>/<id>.json`` when that file
    exists; otherwise it is requested from the ClinicalTrials.gov API.

    Args:
        id: NCT identifier of the trial, e.g. ``'NCT01288573'``.

    Returns:
        A dict holding the trial id under ``id_col`` plus the ``'measures'``
        and ``'design'`` text. If the API answers with a non-200 status, the
        status is printed and both text fields are empty strings, so a single
        unavailable study does not abort a batch run.

    Raises:
        KeyError: if the study JSON has no ``'protocolSection'`` entry.
        requests.RequestException: on network failure or timeout.
    """
    filepath = os.path.join(ctg_folder, f'{id}.json')
    if os.path.exists(filepath):  # Load from local file
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    else:  # Retrieve via API
        # Without a timeout a stalled connection would hang the whole loop.
        response = requests.get(CTG_API + id, timeout=30)
        if response.status_code != 200:
            print(f"{response.status_code} response for {id}")
            # Previously execution fell through with `data` unbound and
            # crashed; return an explicit empty record instead.
            return {id_col: id, 'measures': '', 'design': ''}
        data = response.json()

    protocol_section = data['protocolSection']
    return {
        id_col: id,
        'measures': get_measures(protocol_section),
        'design': get_design(protocol_section),
    }


def get_measures(data: dict) -> str:
    """Render the primary outcome measures of a protocol section as text.

    Each entry of ``data['outcomesModule']['primaryOutcomes']`` becomes a
    block ``"Measure <n>:"`` followed by the entry's values, one per line;
    blocks are separated by a blank line.

    Args:
        data: the ``protocolSection`` mapping of a study.

    Returns:
        The formatted text, or ``''`` when the outcomes are missing or do not
        have the expected shape (non-mapping entries, non-string values).
    """
    try:
        outcomes = data['outcomesModule']['primaryOutcomes']
        blocks = []
        for number, outcome in enumerate(outcomes, start=1):
            # Joined outside the f-string: a backslash inside an f-string
            # replacement field is a SyntaxError before Python 3.12.
            body = '\n'.join(outcome.values())
            blocks.append(f"Measure {number}:\n{body}")
    except (KeyError, TypeError, AttributeError):
        # Only lookup/shape failures mean "no measures"; anything else
        # (e.g. KeyboardInterrupt) must propagate.
        return ''
    return '\n\n'.join(blocks)


def get_design(data: dict) -> str:
    """Render the study-design summary of a protocol section as text.

    Reads ``data['designModule']['designInfo']`` and emits one line per
    available field, in this order: allocation, intervention model, primary
    purpose, masking. Fields that are absent are skipped.

    Args:
        data: the ``protocolSection`` mapping of a study.

    Returns:
        The newline-separated summary with surrounding whitespace stripped,
        or ``''`` when there is no design information.
    """
    try:
        design_info = data['designModule']['designInfo']
    except (KeyError, TypeError):
        return ''

    labelled_fields = (
        ('Allocation', 'allocation'),
        ('Intervention', 'interventionModel'),
        ('Primary purpose', 'primaryPurpose'),
    )
    lines = []
    for label, key in labelled_fields:
        try:
            lines.append(f"{label}: {design_info[key]}")
        except (KeyError, TypeError):
            continue

    try:
        masking_info = design_info['maskingInfo']
        # NOTE(review): the masking line is emitted only when BOTH 'masking'
        # and 'whoMasked' are present, so a study that has 'masking' but no
        # 'whoMasked' list gets no masking line at all. Kept as in the
        # original; confirm this is intended.
        who_masked = ' '.join(masking_info['whoMasked'])
        lines.append(f"{masking_info['masking']} masking: {who_masked}")
    except (KeyError, TypeError):
        pass

    return '\n'.join(lines).strip()


def get_description(data: dict) -> str:
    """Return the detailed description of a protocol section.

    Args:
        data: the ``protocolSection`` mapping of a study.

    Returns:
        ``data['descriptionModule']['detailedDescription']``, or ``''`` when
        either level is missing or is not a mapping.
    """
    try:
        return data['descriptionModule']['detailedDescription']
    except (KeyError, TypeError):
        # Only a failed lookup means "no description"; other exceptions
        # (e.g. KeyboardInterrupt) must not be swallowed.
        return ''


# Smoke test: build and display the protocol record for a single trial.
get_protocol('NCT01288573')

{'nctid': 'NCT01288573',
 'measures': 'Measure 1:\nProportion of patients achieving at least a doubling of peripheral blood CD34+ count during Stage 2\nUp to 5 days',
 'design': 'Allocation: RANDOMIZED\nIntervention: PARALLEL\nPrimary purpose: TREATMENT'}

In [None]:
# For every phase CSV: look up the protocol text of each trial and write the
# enriched table next to the input as `re_<original name>`.
for pre_file in pre_files:
    df = pd.read_csv(pre_file).set_index(id_col)

    # Fetch each trial once. Building protocol_df from the raw index would
    # make the join below multiply rows for any id that appears more than
    # once, and would repeat the lookup for it.
    unique_ids = df.index.unique()
    protocols = [get_protocol(nct_id)
                 for nct_id in tqdm(unique_ids, desc=pre_file)]

    # Explicit columns keep set_index working when the CSV has no rows.
    protocol_df = pd.DataFrame(
        protocols, columns=[id_col, 'measures', 'design']
    ).set_index(id_col)

    re_df = df.join(protocol_df)
    assert len(re_df) == len(df), f"row count changed while enriching {pre_file}"
    re_df.to_csv(f're_{pre_file}')

phase_I_test.csv: 100%|██████████| 627/627 [00:00<00:00, 2235.79it/s]
phase_I_valid.csv: 100%|██████████| 117/117 [00:01<00:00, 65.57it/s]
phase_III_train.csv: 100%|██████████| 3094/3094 [01:23<00:00, 37.20it/s]
phase_II_valid.csv: 100%|██████████| 446/446 [00:08<00:00, 51.02it/s]
phase_III_test.csv: 100%|██████████| 1146/1146 [00:01<00:00, 1058.96it/s]
phase_II_train.csv: 100%|██████████| 4005/4005 [01:29<00:00, 44.53it/s] 
phase_III_valid.csv: 100%|██████████| 344/344 [00:08<00:00, 42.02it/s]
phase_II_test.csv: 100%|██████████| 1654/1654 [00:01<00:00, 1222.29it/s]
phase_I_train.csv: 100%|██████████| 1044/1044 [00:13<00:00, 74.61it/s] 
