In [None]:
from datasets import load_dataset
import pandas as pd
import numpy as np
import duckdb
import ast

# Reload imported modules automatically before each cell is executed, so edits
# to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# (Re)load the `sql` IPython extension that provides the %sql / %%sql magics.
%reload_ext sql

# SqlMagic settings: hand query results back as pandas DataFrames and keep the
# cell output free of status feedback and connection strings.
%config SqlMagic.autopandas = True
%config SqlMagic.feedback = False
%config SqlMagic.displaycon = False

# duckdb.connect() is called without a database path here; the connection is
# then registered with the SQL magic under the alias "duckdb".
conn = duckdb.connect()
%sql conn --alias duckdb

In [None]:
# Download TrialBench dataset from HuggingFace
#
# Every row of the "train" split describes one table: `task`, `phase` and `type`
# identify it, and `data` holds the table serialized as a Python dict literal.
hf_dataset = load_dataset("ML2Healthcare/ClinicalTrialDataset")
hf_columns = hf_dataset["train"].to_dict()

# NOTE(review): eval() executes whatever expression the downloaded payload
# contains. It is kept because the payload uses the bare name `nan`, which
# ast.literal_eval cannot parse; only run this against a dataset you trust.
data = {
    f"{task}_{phase}_{type_}": pd.DataFrame.from_dict(
        eval(payload, {"nan": np.nan})
    )
    for task, phase, type_, payload in zip(
        hf_columns["task"],
        hf_columns["phase"],
        hf_columns["type"],
        hf_columns["data"],
    )
}

# List the available table names, one per line.
for table_name in data:
    print(table_name)

In [None]:
# Merge all samples from TrialBench
#
# For every (phase, split) pair, join the feature table (`_x`) with its label
# table (`_y`) on the index, tag the rows with their phase and split, and stack
# everything into one frame that is cached on disk as parquet.
phases = np.arange(1, 5).astype(str)
splits = ["train", "test"]
df_list = []

for phase in phases:
    for split in splits:
        # Table names have the form
        # "trial-approval-prediction-Phase<p>-<split>_<x|y>_Phase<p>_<split>_<x|y>".
        task_part = f"trial-approval-prediction-Phase{phase}-{split}"
        table_part = f"Phase{phase}_{split}"
        df_x = data[f"{task_part}_x_{table_part}_x"]
        df_y = data[f"{task_part}_y_{table_part}_y"]

        # .assign() returns a new frame, so the tables cached in `data` are not
        # modified as a side effect of running this cell.
        df_merged = df_x.assign(phase=phase, split=split).merge(
            df_y, how="inner", left_index=True, right_index=True
        )
        df_list.append(df_merged)

raw_df = pd.concat(df_list)
raw_df.to_parquet("raw_trialbench.parquet")

In [None]:
# Start again from the cached merge result.
raw_df = pd.read_parquet("raw_trialbench.parquet")

# Columns retained from the raw frame, in output order.
cols_to_keep = [
    "brief_summary/textblock",
    "brief_title",
    "condition",
    "condition_browse/mesh_term",
    "icdcode",
    "location/facility/address/city",
    "sponsors/lead_sponsor/agency_class",
    "study_design_info/primary_purpose",
    "phase",
    "smiless",
    "outcome",
    "split",
]

# Short names for the slash-delimited source columns. "index" is the column
# created by reset_index() below; it is exposed as the trial id `nctid`
# (presumably the frame is indexed by NCT id -- verify against the dataset).
short_names = {
    "index": "nctid",
    "condition_browse/mesh_term": "mesh_term",
    "location/facility/address/city": "location",
    "sponsors/lead_sponsor/agency_class": "lead_sponsor",
    "study_design_info/primary_purpose": "primary_purpose",
}

selected_df = raw_df[cols_to_keep]
indexed_df = selected_df.reset_index()
trialbench_df = indexed_df.rename(columns=short_names)

In [None]:
def clean(x):
    """Normalize a raw cell value into a list.

    Args:
        x: The raw value. Expected to be ``None``, a stringified Python list
            such as ``"['a', 'b']"``, or a plain string holding one value.

    Returns:
        * ``None`` for missing input: ``None``, NaN, or an empty or
          whitespace-only string.
        * ``x`` unchanged when it is already a list, so applying the function
          to an already-cleaned column is a no-op.
        * The parsed list when ``x`` is a string starting with ``"["``.
        * ``[x]`` for any other value.
        * ``None`` when the list literal cannot be parsed; the offending value
          and the error are printed.
    """
    # Missing values: None, or a float NaN (NaN is the only float != itself).
    if x is None or (isinstance(x, float) and x != x):
        return None
    # Already parsed, e.g. when the cell that applies this function is re-run.
    if isinstance(x, list):
        return x
    # Any other non-string scalar is treated as a single value.
    if not isinstance(x, str):
        return [x]
    if not x.strip():
        return None
    if x.startswith("["):
        try:
            return ast.literal_eval(x)
        except (ValueError, SyntaxError) as e:
            # literal_eval raises SyntaxError for malformed text (e.g. an
            # unclosed bracket) and ValueError for non-literal content.
            print("Got ", x)
            print(e)
            return None
    return [x]


def clean_icd(x):
    """Parse a doubly-encoded ICD code cell into a list of codes.

    The cell is first parsed with ``clean``; the first element of that result
    is then parsed with ``clean`` again and returned.

    NOTE(review): only the first element of the outer list is used, so codes
    stored in later elements are dropped -- confirm this is intended.

    Args:
        x: Raw cell value, handled by ``clean``.

    Returns:
        The list produced by cleaning the first outer element, or ``None`` when
        the cell is missing, cannot be parsed, or parses to an empty list.
    """
    cleaned = clean(x)
    # Covers both a missing/unparseable cell (None) and an empty outer list,
    # which would otherwise raise IndexError on cleaned[0].
    if not cleaned:
        return None
    return clean(cleaned[0])


# Convert the raw cell values of the multi-valued columns with `clean`.
for list_col in ("condition", "mesh_term", "location", "smiless"):
    trialbench_df[list_col] = trialbench_df[list_col].apply(clean)

# icdcode goes through clean_icd, which cleans the cell and then cleans the
# first element of the result once more.
trialbench_df["icdcode"] = trialbench_df["icdcode"].apply(clean_icd)

In [None]:
%%sql
trialbench_df <<
WITH ctg_trials AS (
    -- Distinct combinations of trial id, start date and the two description fields.
    -- NOTE(review) a trial with more than one distinct combination yields several
    -- rows here and would duplicate its TrialBench row in the join below.
    -- Confirm that nctid is unique in the parquet file.
    SELECT DISTINCT
        nctid,
        start_date,
        briefSummary,
        detailedDescription
    FROM "ctg-projected-processed.parquet"
)
SELECT
    t.*,
    c.start_date,
    CAST(year(c.start_date) AS VARCHAR) AS start_year,
    c.briefSummary AS brief_summary,
    c.detailedDescription AS detailed_description
FROM trialbench_df AS t
-- LEFT JOIN keeps every TrialBench row even when no match exists
LEFT JOIN ctg_trials AS c
    ON c.nctid = t.nctid

In [None]:
# Write the final trialbench_df frame to disk as parquet.
trialbench_df.to_parquet("trialbench.parquet")