# Installing Necessary Libraries

Cell 1

In [None]:
!pip install openai

Collecting openai
  Downloading openai-1.51.0-py3-none-any.whl.metadata (24 kB)
Collecting httpx<1,>=0.23.0 (from openai)
  Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)
Collecting jiter<1,>=0.4.0 (from openai)
  Downloading jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)
Collecting httpcore==1.* (from httpx<1,>=0.23.0->openai)
  Downloading httpcore-1.0.6-py3-none-any.whl.metadata (21 kB)
Collecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->openai)
  Downloading h11-0.14.0-py3-none-any.whl.metadata (8.2 kB)
Downloading openai-1.51.0-py3-none-any.whl (383 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m383.5/383.5 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading httpx-0.27.2-py3-none-any.whl (76 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading httpcore-1.0.6-py3-none-any.whl (78 kB)
[2K   [90m━

## Generating queries.json

Cell 2

In [None]:
import os
from openai import OpenAI
import xml.etree.ElementTree as ET
import json

# Set up the OpenAI client. The key is read from the OPENAI_API_KEY environment
# variable (e.g. Colab secrets or `%env`) so no secret is stored in the notebook.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)

# Function to parse the XML file and extract relevant patient information
def extract_patient_info(xml_file):
    """Parse an HL7 v3 (C-CDA) XML document and return its id and raw text.

    Args:
        xml_file: Path or binary file object accepted by ``ET.parse``.

    Returns:
        ``(patient_id, patient_info)`` where ``patient_id`` is the ``extension``
        attribute of the first ``hl7:id`` element found below the root, and
        ``patient_info`` is the whole document re-serialized as a unicode
        string (it is sent verbatim to the LLM for summarization).

    Raises:
        ValueError: if no ``hl7:id`` element exists or it has no ``extension``.
    """
    tree = ET.parse(xml_file)
    root = tree.getroot()

    namespace = {'hl7': 'urn:hl7-org:v3'}
    # NOTE(review): this is the first hl7:id in document order; in a C-CDA that
    # may be the document id rather than recordTarget/patientRole/id — confirm.
    id_elem = root.find('.//hl7:id', namespace)
    if id_elem is None or 'extension' not in id_elem.attrib:
        raise ValueError(f"{xml_file}: no hl7:id element with an 'extension' attribute")
    patient_id = id_elem.attrib['extension']

    # The full document is what gets summarized.
    patient_info = ET.tostring(root, encoding='unicode')

    return patient_id, patient_info

# Function to summarize patient information using OpenAI API
def summarize_patient_info(patient_info):
    """Ask the chat model for a one-sentence summary of a patient record.

    Uses the module-level ``client``; the summary is printed and returned.
    """
    request_messages = [
        {"role": "system", "content": "You are a concise medical assistant."},
        {"role": "user", "content": f"Summarize the following patient's medical information in one sentence:\n\n{patient_info}"},
    ]
    completion = client.chat.completions.create(
        model="gpt-4o",  # or "gpt-4" if you have access
        messages=request_messages,
        max_tokens=150,
    )
    one_liner = completion.choices[0].message.content
    print(one_liner)
    return one_liner

# Function to process all XML files in a directory
def process_patient_directory(directory):
    """Summarize one patient XML file, or every ``.xml`` file in a directory.

    Args:
        directory: Path to a single ``.xml`` file, or to a directory whose
            ``.xml`` files (processed in sorted filename order) are each one
            patient record.

    Returns:
        A list of ``{"_id": patient_id, "text": summary}`` dicts, one per
        processed file; empty when ``directory`` is neither a directory nor a
        path ending in ``.xml``.
    """
    if os.path.isdir(directory):
        xml_files = [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
            if filename.endswith(".xml")
        ]
    elif directory.endswith(".xml"):
        xml_files = [directory]
    else:
        xml_files = []

    patient_summaries = []
    for xml_file in xml_files:
        # Extract patient info from XML, then get the LLM summary
        patient_id, patient_info = extract_patient_info(xml_file)
        summary = summarize_patient_info(patient_info)

        patient_summaries.append({
            "_id": patient_id,
            "text": summary
        })

    return patient_summaries

Cell 3

In [None]:
# Path to the patient's C-CDA XML file to summarize.
directory_path = '/content/Abbott509_Chase285_5.xml'

# Process the file and get the patient summaries
patient_summaries = process_patient_directory(directory_path)

# Output the final list with all patient summaries
print(patient_summaries)

import json

# Save the summaries as a JSON array (queries.json) and as JSON Lines
# (queries.jsonl, one {"_id", "text"} object per line). The keyword-generation,
# retrieval and ranking cells read /content/queries.jsonl; in Colab the working
# directory is /content.
with open('queries.json', 'w') as f:
    json.dump(patient_summaries, f, indent=4)

with open('queries.jsonl', 'w') as f:
    for summary in patient_summaries:
        f.write(json.dumps(summary) + "\n")

Chase285 Abbott509, a White male born on February 2, 2012, with no known allergies, has a medical history of otitis media, streptococcal sore throat, viral sinusitis, and acute bronchitis, had multiple vaccinations including Hep B, DTaP, and influenza, and underwent respiratory therapy as part of his plan of care.
[{'_id': '5925ebf53425de8bbf004f49', 'text': 'Chase285 Abbott509, a White male born on February 2, 2012, with no known allergies, has a medical history of otitis media, streptococcal sore throat, viral sinusitis, and acute bronchitis, had multiple vaccinations including Hep B, DTaP, and influenza, and underwent respiratory therapy as part of his plan of care.'}]


# TrialGPT Retrieval

Cell 4

In [None]:
"""
generate the search keywords for each patient
"""

from openai import OpenAI
import json

import os

# Read the key from the OPENAI_API_KEY environment variable instead of
# hardcoding a secret in the notebook.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)


def get_keyword_generation_messages(note):
    """Build chat messages asking the model to summarize a patient note and
    list up to 32 prioritized conditions for clinical-trial search, as JSON.

    Args:
        note: Free-text patient description.

    Returns:
        A two-element list: the system instructions and the user prompt.
    """
    # Plain (non-f) string, so braces are literal and must not be doubled;
    # previously "{{ ... }}" was sent to the model verbatim.
    system = 'You are a helpful assistant and your task is to help search relevant clinical trials for a given patient description. Please first summarize the main medical problems of the patient. Then generate up to 32 key conditions for searching relevant clinical trials for this patient. The key condition list should be ranked by priority. Please output only a JSON dict formatted as Dict{"summary": Str(summary), "conditions": List[Str(condition)]}.'

    prompt = f"Here is the patient description: \n{note}\n\nJSON output:"

    return [
        {"role": "system", "content": system},
        {"role": "user", "content": prompt},
    ]

Cell 5

In [None]:
outputs = {}
ret_trials = {}
model = 'gpt-4o'

with open("/content/queries.jsonl", "r") as queries_file:
    query_lines = queries_file.readlines()

for line in query_lines:
    if not line.strip():
        continue  # tolerate blank lines (e.g. a trailing newline)
    entry = json.loads(line)
    messages = get_keyword_generation_messages(entry["text"])

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,
    )

    # Strip whitespace BEFORE removing the ```json fence: with a trailing
    # newline after the closing fence, strip("`") alone leaves the fence in
    # place and json.loads fails with "Extra data".
    output = response.choices[0].message.content.strip()
    keywords = json.loads(output.strip("`").strip("json"))

    outputs[entry["_id"]] = keywords
    # Keyed "gpt-4-turbo" because the retrieval cell reads
    # id2queries[qid]["gpt-4-turbo"] (its q_type), whatever `model` is.
    ret_trials[entry["_id"]] = {"raw": entry["text"], "gpt-4-turbo": keywords}

    # Checkpoint after every patient so progress survives API failures.
    with open(f"retrieval_keywords_{model}.json", "w") as out_file:
        json.dump(outputs, out_file, indent=4)

    with open("id2queries.json", "w") as out_file:
        json.dump(ret_trials, out_file, indent=4)

Cell 6

In [None]:
!pip install beir

Collecting beir
  Downloading beir-2.0.0.tar.gz (53 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/53.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.6/53.6 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sentence-transformers (from beir)
  Downloading sentence_transformers-3.1.1-py3-none-any.whl.metadata (10 kB)
Collecting pytrec_eval (from beir)
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting faiss_cpu (from beir)
  Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Collecting elasticsearch==7.9.1 (from beir)
  Downloading elasticsearch-7.9.1-py2.py3-none-any.whl.metadata (8.0 kB)
Collecting datasets (from beir)
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from dat

Cell 7

In [None]:
!pip install rank-bm25

Collecting rank-bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank-bm25
Successfully installed rank-bm25-0.2.2


Cell 8

In [None]:
from beir.datasets.data_loader import GenericDataLoader
import faiss
import json
from nltk import word_tokenize
import numpy as np
import os
from rank_bm25 import BM25Okapi
import sys
import tqdm
import torch
from transformers import AutoTokenizer, AutoModel

Cell 9

In [None]:
"""
Conduct the first stage retrieval by the hybrid retriever
"""

def get_bm25_corpus_index(corpus, corpus_file="corpus.jsonl"):
    """Build (or load from cache) a BM25 index over the trial corpus.

    Each trial is tokenized as title x3 + each listed disease x2 + body text x1,
    so title and condition matches weigh more.

    Args:
        corpus: Name used only for the cache file ``bm25_corpus_{corpus}.json``.
        corpus_file: JSON Lines file with one trial per line (``_id``, ``title``,
            ``text``, ``metadata.diseases_list``); read only when no cache exists.

    Returns:
        ``(bm25, corpus_nctids)``: a ``BM25Okapi`` index and the NCT ids in
        index order.

    NOTE: the cache is keyed only by ``corpus``; delete the cache file after
    changing the corpus file, otherwise the stale tokenization is reused.
    """
    corpus_path = f"bm25_corpus_{corpus}.json"

    # if already cached then load, otherwise build
    if os.path.exists(corpus_path):
        with open(corpus_path) as f:
            corpus_data = json.load(f)
        tokenized_corpus = corpus_data["tokenized_corpus"]
        corpus_nctids = corpus_data["corpus_nctids"]

    else:
        tokenized_corpus = []
        corpus_nctids = []

        with open(corpus_file, "r") as f:
            for line in f:
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])

                # weighting: 3 * title, 2 * condition, 1 * text
                tokens = word_tokenize(entry["title"].lower()) * 3
                for disease in entry["metadata"]["diseases_list"]:
                    tokens += word_tokenize(disease.lower()) * 2
                tokens += word_tokenize(entry["text"].lower())

                tokenized_corpus.append(tokens)

        corpus_data = {
            "tokenized_corpus": tokenized_corpus,
            "corpus_nctids": corpus_nctids,
        }

        with open(corpus_path, "w") as f:
            json.dump(corpus_data, f, indent=4)

    bm25 = BM25Okapi(tokenized_corpus)

    return bm25, corpus_nctids


def get_medcpt_corpus_index(corpus, corpus_file="corpus.jsonl", device="cuda"):
    """Build (or load from cache) a Faiss inner-product index of MedCPT embeddings.

    Each trial is encoded as the [CLS] vector of the MedCPT article encoder
    applied to (title, text), truncated to 512 tokens.

    Args:
        corpus: Name used for the caches ``{corpus}_embeds.npy`` and
            ``{corpus}_nctids.json``.
        corpus_file: JSON Lines trial corpus; read only when no cache exists.
        device: Torch device used for encoding (default "cuda", as before).

    Returns:
        ``(index, corpus_nctids)``: a ``faiss.IndexFlatIP`` whose row i
        corresponds to ``corpus_nctids[i]``.
    """
    corpus_path = f"{corpus}_embeds.npy"
    nctids_path = f"{corpus}_nctids.json"

    if os.path.exists(corpus_path):
        embeds = np.load(corpus_path)
        with open(nctids_path) as f:
            corpus_nctids = json.load(f)

    else:
        embeds = []
        corpus_nctids = []

        model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder").to(device)
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")

        with open(corpus_file, "r") as f:
            print("Encoding the corpus")
            for line in tqdm.tqdm(f.readlines()):
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])

                with torch.no_grad():
                    # tokenize the article as a (title, text) pair
                    encoded = tokenizer(
                        [[entry["title"], entry["text"]]],
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=512,
                    ).to(device)

                    # the [CLS] last hidden state is the article embedding
                    embed = model(**encoded).last_hidden_state[:, 0, :]

                embeds.append(embed[0].cpu().numpy())

        embeds = np.array(embeds)

        np.save(corpus_path, embeds)
        with open(nctids_path, "w") as f:
            json.dump(corpus_nctids, f, indent=4)

    # Faiss's Python API expects a C-contiguous float32 matrix; this is a no-op
    # for the float32 embeddings produced above.
    embeds = np.ascontiguousarray(embeds, dtype=np.float32)

    # 768 = embedding size this code assumes for the MedCPT encoders.
    index = faiss.IndexFlatIP(768)
    index.add(embeds)

    return index, corpus_nctids

Cell 10

In [None]:
corpus = "Synthetic_Mass"
# Key under which the keyword-generation cell stored the GPT output in id2queries.json
q_type = "gpt-4-turbo"

import nltk
nltk.download('punkt')

# k constant of reciprocal-rank fusion: a hit at 0-based rank r adds 1 / (r + k)
k = 20

# bm25 weight (multiplies BM25 contributions; 0 disables the retriever)
bm25_wt = 1

# medcpt weight (multiplies MedCPT contributions; 0 disables the retriever)
medcpt_wt = 1

# how many trials to retrieve per condition and to keep per query
N = 2000

with open("id2queries.json") as f:
    id2queries = json.load(f)

# NOTE: trial_info.json is written by the "Trial Info" cell further down the
# notebook; on a fresh kernel run that cell before this one.
with open("trial_info.json") as f:
    trial_info = json.load(f)

# loading the indices
bm25, bm25_nctids = get_bm25_corpus_index(corpus)
medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)

# loading the query encoder for MedCPT
model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")


def _fuse_condition_rankings(conditions):
    """Hybrid BM25 + MedCPT retrieval with weighted reciprocal-rank fusion.

    Every condition is searched by both retrievers; a trial at 0-based rank r
    for condition i gains ``retriever_wt * 1/(r + k) * 1/(i + 1)``, so earlier
    (higher-priority) conditions count more.

    Returns:
        dict mapping NCTID -> fused score (empty when there are no conditions).
    """
    nctid2score = {}
    if len(conditions) == 0:
        return nctid2score

    # a list of nctid lists for the bm25 retriever
    bm25_condition_top_nctids = []
    for condition in conditions:
        tokens = word_tokenize(condition.lower())
        bm25_condition_top_nctids.append(bm25.get_top_n(tokens, bm25_nctids, n=N))

    # doing MedCPT retrieval
    with torch.no_grad():
        encoded = tokenizer(
            conditions,
            truncation=True,
            padding=True,
            return_tensors='pt',
            max_length=256,
        ).to("cuda")

        # encode the queries (use the [CLS] last hidden states as the representations)
        embeds = model(**encoded).last_hidden_state[:, 0, :].cpu().numpy()

        # Faiss pads results with index -1 when asked for more neighbours than
        # the index holds; -1 would silently map to the LAST nctid in Python,
        # inflating its score. Never ask for more than ntotal, and drop -1s.
        _, inds = medcpt.search(embeds, k=min(N, medcpt.ntotal))

    medcpt_condition_top_nctids = [
        [medcpt_nctids[ind] for ind in ind_list if ind >= 0]
        for ind_list in inds
    ]

    for condition_idx, (bm25_top, medcpt_top) in enumerate(
        zip(bm25_condition_top_nctids, medcpt_condition_top_nctids)
    ):
        condition_wt = 1 / (condition_idx + 1)
        # BM25 first, then MedCPT, for each condition
        for retriever_wt, top_nctids in ((bm25_wt, bm25_top), (medcpt_wt, medcpt_top)):
            if retriever_wt <= 0:
                continue
            for rank, nctid in enumerate(top_nctids):
                nctid2score[nctid] = nctid2score.get(nctid, 0) + retriever_wt * (1 / (rank + k)) * condition_wt

    return nctid2score


# then conduct the searches, saving the top N trials per query
output_path = f"qid2nctids_results_{q_type}_{corpus}_k{k}_bm25wt{bm25_wt}_medcptwt{medcpt_wt}_N{N}.json"

qid2nctids = {}
retrieved_trials_final = []

with open("/content/queries.jsonl", "r") as f:
    query_lines = f.readlines()

# Every query is retrieved inside this loop (previously the retrieval code sat
# outside the loop body, so only the last query was processed).
for line in tqdm.tqdm(query_lines):
    if not line.strip():
        continue
    entry = json.loads(line)
    query = entry["text"]
    qid = entry["_id"]
    print(qid)

    conditions = id2queries[qid][q_type]["conditions"]

    ranked = sorted(_fuse_condition_rankings(conditions).items(), key=lambda x: -x[1])
    top_nctids = [nctid for nctid, _ in ranked[:N]]
    qid2nctids[qid] = top_nctids

    print(qid2nctids[qid])

    retrieved_trials_final.append({
        "patient_id": qid,
        "patient": query,
        "trials": [trial_info[nctid] for nctid in top_nctids],
    })

with open(output_path, "w") as f:
    json.dump(qid2nctids, f, indent=4)

with open("retrieved_trials.json", "w") as f:
    json.dump(retrieved_trials_final, f, indent=4)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Encoding the corpus


100%|██████████| 5/5 [00:00<00:00, 29.52it/s]
100%|██████████| 2/2 [00:00<00:00, 9467.95it/s]

sigir-20141
sigir-20142
['NCT00665366', 'NCT00188279', 'NCT02110251', 'NCT02073188', 'NCT02490241']





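## Trial Info

Builds `trial_info.json`, a map from each NCTID to that trial's metadata, out of `corpus.jsonl`. The retrieval cell earlier in the notebook loads this file, so on a fresh kernel run this section first.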

Cell 11

In [None]:
import json

# Source corpus (one trial per JSON line) and destination lookup file
input_file_path = '/content/corpus.jsonl'
output_file_path = '/content/trial_info.json'

# NCTID -> trial metadata; each metadata dict also carries its own "NCTID"
trial_info = {}
with open(input_file_path, 'r') as corpus_file:
    for record in map(json.loads, corpus_file):
        nctid = record["_id"]
        record_metadata = record["metadata"]
        record_metadata["NCTID"] = nctid
        trial_info[nctid] = record_metadata

with open(output_file_path, 'w') as out_file:
    json.dump(trial_info, out_file, indent=4)

print(f"trial_info.json has been created at: {output_file_path}")

trial_info.json has been created at: /content/trial_info.json


# TrialGPT Matching

Cell 12

In [None]:
import nltk
nltk.download('punkt')

import json
from nltk.tokenize import sent_tokenize
import time
import os

from openai import OpenAI

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Cell 13

In [None]:
"""
TrialGPT-Matching main functions.
"""

import os

# Read the key from the OPENAI_API_KEY environment variable instead of
# hardcoding a secret in the notebook.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)

def parse_criteria(criteria):
    """Number the substantive paragraphs of a criteria block.

    Paragraphs are separated by blank lines and stripped. Paragraphs that
    mention "inclusion criteria"/"exclusion criteria" (section headers) and
    fragments shorter than 5 characters are dropped; the rest are returned as
    "<n>. <text>\\n" lines numbered from 0.
    """
    kept = [
        paragraph
        for paragraph in (chunk.strip() for chunk in criteria.split("\n\n"))
        if len(paragraph) >= 5
        and "inclusion criteria" not in paragraph.lower()
        and "exclusion criteria" not in paragraph.lower()
    ]
    return "".join(f"{number}. {paragraph}\n" for number, paragraph in enumerate(kept))


def print_trial(
    trial_info: dict,
    inc_exc: str,
) -> str:
    """Render a trial's metadata as prompt text.

    Always contains title, target diseases, interventions and summary. When
    ``inc_exc`` is "inclusion" or "exclusion", the numbered criteria of that
    kind (via ``parse_criteria``) are appended; any other value appends none.
    """
    lines = [
        f"Title: {trial_info['brief_title']}",
        f"Target diseases: {', '.join(trial_info['diseases_list'])}",
        f"Interventions: {', '.join(trial_info['drugs_list'])}",
        f"Summary: {trial_info['brief_summary']}",
    ]
    rendered = "\n".join(lines) + "\n"

    if inc_exc in ("inclusion", "exclusion"):
        heading = "Inclusion criteria" if inc_exc == "inclusion" else "Exclusion criteria"
        rendered += "%s:\n %s\n" % (heading, parse_criteria(trial_info[inc_exc + "_criteria"]))

    return rendered


def get_matching_prompt(
	trial_info: dict,
	inc_exc: str,
	patient: str,
) -> tuple[str, str]:
	"""Build the (system, user) prompts for criterion-level matching.

	Args:
		trial_info: Trial metadata dict, rendered into the user prompt by
			``print_trial`` (which appends the numbered ``inc_exc`` criteria).
		inc_exc: "inclusion" or "exclusion"; selects which criteria are
			evaluated and which label set the model must use.
		patient: Patient note whose sentences are prefixed with sentence ids.

	Returns:
		``(prompt, user_prompt)``: the system prompt with task instructions and
		the required JSON output format, and the user prompt holding the
		patient note and the trial text.
	"""
	# Task framing: compare the note against one kind of criteria.
	prompt = f"You are a helpful assistant for clinical trial recruitment. Your task is to compare a given patient note and the {inc_exc} criteria of a clinical trial to determine the patient's eligibility at the criterion level.\n"

	if inc_exc == "inclusion":
		prompt += "The factors that allow someone to participate in a clinical study are called inclusion criteria. They are based on characteristics such as age, gender, the type and stage of a disease, previous treatment history, and other medical conditions.\n"

	elif inc_exc == "exclusion":
		prompt += "The factors that disqualify someone from participating are called exclusion criteria. They are based on characteristics such as age, gender, the type and stage of a disease, previous treatment history, and other medical conditions.\n"

	# Per-criterion output spec: reasoning, evidence sentence ids, label.
	prompt += f"You should check the {inc_exc} criteria one-by-one, and output the following three elements for each criterion:\n"
	prompt += f"\tElement 1. For each {inc_exc} criterion, briefly generate your reasoning process: First, judge whether the criterion is not applicable (not very common), where the patient does not meet the premise of the criterion. Then, check if the patient note contains direct evidence. If so, judge whether the patient meets or does not meet the criterion. If there is no direct evidence, try to infer from existing evidence, and answer one question: If the criterion is true, is it possible that a good patient note will miss such information? If impossible, then you can assume that the criterion is not true. Otherwise, there is not enough information.\n"
	prompt += f"\tElement 2. If there is relevant information, you must generate a list of relevant sentence IDs in the patient note. If there is no relevant information, you must annotate an empty list.\n"
	prompt += f"\tElement 3. Classify the patient eligibility for this specific {inc_exc} criterion: "

	# The allowed labels differ between inclusion and exclusion criteria; the
	# ranking cell counts exactly these label strings.
	if inc_exc == "inclusion":
		prompt += 'the label must be chosen from {"not applicable", "not enough information", "included", "not included"}. "not applicable" should only be used for criteria that are not applicable to the patient. "not enough information" should be used where the patient note does not contain sufficient information for making the classification. Try to use as less "not enough information" as possible because if the note does not mention a medically important fact, you can assume that the fact is not true for the patient. "included" denotes that the patient meets the inclusion criterion, while "not included" means the reverse.\n'
	elif inc_exc == "exclusion":
		prompt += 'the label must be chosen from {"not applicable", "not enough information", "excluded", "not excluded"}. "not applicable" should only be used for criteria that are not applicable to the patient. "not enough information" should be used where the patient note does not contain sufficient information for making the classification. Try to use as less "not enough information" as possible because if the note does not mention a medically important fact, you can assume that the fact is not true for the patient. "excluded" denotes that the patient meets the exclusion criterion and should be excluded in the trial, while "not excluded" means the reverse.\n'

	prompt += "You should output only a JSON dict exactly formatted as: dict{str(criterion_number): list[str(element_1_brief_reasoning), list[int(element_2_sentence_id)], str(element_3_eligibility_label)]}."

	user_prompt = f"Here is the patient note, each sentence is led by a sentence_id:\n{patient}\n\n"
	user_prompt += f"Here is the clinical trial:\n{print_trial(trial_info, inc_exc)}\n\n"
	user_prompt += f"Plain JSON output:"

	return prompt, user_prompt


def trialgpt_matching(trial: dict, patient: str, model: str):
    """Run criterion-level matching of one patient against one trial.

    Sends two chat requests: one for the inclusion criteria, then one for the
    exclusion criteria.

    Args:
        trial: Trial metadata dict (as stored in ``trial_info.json``).
        patient: Patient note with numbered sentences.
        model: Currently unused; see the NOTE on the API call.

    Returns:
        ``{"inclusion": ..., "exclusion": ...}`` where each value is the parsed
        JSON reply (criterion number -> [reasoning, sentence ids, label]) or,
        if the reply is not valid JSON, the raw reply text.
    """
    results = {}

    # doing inclusions and exclusions in separate prompts
    for inc_exc in ["inclusion", "exclusion"]:
        system_prompt, user_prompt = get_matching_prompt(trial, inc_exc, patient)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        response = client.chat.completions.create(
            # NOTE(review): hardcoded; the `model` argument (e.g. "gpt-4-turbo"
            # from the matching cell) is ignored. Use `model=model` only if that
            # model is available to this API key.
            model="gpt-4o",
            messages=messages,
            temperature=0,
        )

        # Remove surrounding whitespace, then any ```json ... ``` fence.
        message = response.choices[0].message.content.strip()
        message = message.strip("`").strip("json")

        try:
            results[inc_exc] = json.loads(message)
        except json.JSONDecodeError:
            # Keep the raw text so the failure stays visible in the cached results.
            results[inc_exc] = message

    return results

Cell 14

In [None]:
"""
Running the TrialGPT matching for three cohorts (sigir, TREC 2021, TREC 2022).
"""

import json
from nltk.tokenize import sent_tokenize
import os
import sys

corpus = "Synthetic_Mass"
# Used in the output filename and passed through to trialgpt_matching.
model = "gpt-4-turbo"

with open("retrieved_trials.json") as f:
    dataset = json.load(f)

output_path = f"matching_results_{corpus}_{model}.json"

# Dict{Str(patient_id): Dict{"trials": Dict{Str(trial_id): matching result}}}
# Reload earlier results so an interrupted run resumes where it stopped.
if os.path.exists(output_path):
    with open(output_path) as f:
        output = json.load(f)
else:
    output = {}

for instance in dataset:
    # Dict{'patient_id': Str, 'patient': Str(note), 'trials': List[Dict(trial metadata)]}
    patient_id = instance["patient_id"]

    # Number every sentence so the model can cite evidence by sentence id; the
    # appended sentence lets consent/compliance criteria be judged.
    sents = sent_tokenize(instance["patient"])
    sents.append("The patient will provide informed consent, and will comply with the trial protocol without any practical issues.")
    patient = "\n".join(f"{idx}. {sent}" for idx, sent in enumerate(sents))

    # initialize the patient id in the output
    output.setdefault(patient_id, {"trials": {}})

    for trial in instance["trials"]:
        trial_id = trial["NCTID"]

        # already calculated and cached
        if trial_id in output[patient_id]["trials"]:
            continue

        # best-effort: report failures (e.g., API errors) and move on; rerun
        # the cell later to retry only the missing trials.
        try:
            output[patient_id]["trials"][trial_id] = trialgpt_matching(trial, patient, model)

            with open(output_path, "w") as f:
                json.dump(output, f, indent=4)

        except Exception as e:
            print(f"Matching failed for patient {patient_id}, trial {trial_id}: {e}")
            continue

# TrialGPT Ranking

Cell 15

In [None]:
import nltk
nltk.download('punkt')

import json
from nltk.tokenize import sent_tokenize
import time
import os

from openai import OpenAI


# Read the key from the OPENAI_API_KEY environment variable instead of
# hardcoding a secret in the notebook.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)

def convert_criteria_pred_to_string(prediction: dict, trial_info: dict) -> str:
    """Render criterion-level TrialGPT predictions as plain text for the ranking prompt.

    The trial's criteria text is split on blank lines and numbered the same way
    as in the matching prompt (header paragraphs and fragments under 5
    characters skipped). Each prediction entry whose key is a known criterion
    number and which has exactly three elements [reasoning, sentence ids,
    label] becomes a short text block. The printed "criterion N" is the
    entry's position in the prediction dict, not its key.
    """
    pieces = []

    for section in ("inclusion", "exclusion"):
        # criterion number (as str) -> criterion text
        numbered = {}
        for paragraph in trial_info[section + "_criteria"].split("\n\n"):
            paragraph = paragraph.strip()
            lowered = paragraph.lower()
            if "inclusion criteria" in lowered or "exclusion criteria" in lowered or len(paragraph) < 5:
                continue
            numbered[str(len(numbered))] = paragraph

        for position, (criterion_key, preds) in enumerate(prediction[section].items()):
            if criterion_key not in numbered or len(preds) != 3:
                continue

            pieces.append(f"{section} criterion {position}: {numbered[criterion_key]}\n")
            pieces.append(f"\tPatient relevance: {preds[0]}\n")
            if len(preds[1]) > 0:
                pieces.append(f"\tEvident sentences: {preds[1]}\n")
            pieces.append(f"\tPatient eligibility: {preds[2]}\n")

    return "".join(pieces)

def convert_pred_to_prompt(patient: str, pred: dict, trial_info: dict) -> tuple[str, str]:
    """Build the (system, user) prompts for trial-level relevance/eligibility scoring.

    Args:
        patient: Patient note with numbered sentences.
        pred: Criterion-level matching output ({"inclusion": ..., "exclusion": ...}).
        trial_info: Trial metadata dict (brief_title, diseases_list,
            brief_summary and the criteria text).

    Returns:
        ``(prompt, user_prompt)``: the scoring instructions with the required
        JSON format, and the patient / trial / prediction context.
    """
    # get the trial string
    trial = f"Title: {trial_info['brief_title']}\n"
    trial += f"Target conditions: {', '.join(trial_info['diseases_list'])}\n"
    trial += f"Summary: {trial_info['brief_summary']}"

    # then get the prediction strings
    pred = convert_criteria_pred_to_string(pred, trial_info)

    # construct the prompt
    prompt = "You are a helpful assistant for clinical trial recruitment. You will be given a patient note, a clinical trial, and the patient eligibility predictions for each criterion.\n"
    prompt += "Your task is to output two scores, a relevance score (R) and an eligibility score (E), between the patient and the clinical trial.\n"
    prompt += "First explain the consideration for determining patient-trial relevance. Predict the relevance score R (0~100), which represents the overall relevance between the patient and the clinical trial. R=0 denotes the patient is totally irrelevant to the clinical trial, and R=100 denotes the patient is exactly relevant to the clinical trial.\n"
    prompt += "Then explain the consideration for determining patient-trial eligibility. Predict the eligibility score E (-R~R), which represents the patient's eligibility to the clinical trial. Note that -R <= E <= R (the absolute value of eligibility cannot be higher than the relevance), where E=-R denotes that the patient is ineligible (not included by any inclusion criteria, or excluded by all exclusion criteria), E=R denotes that the patient is eligible (included by all inclusion criteria, and not excluded by any exclusion criteria), E=0 denotes the patient is neutral (i.e., no relevant information for all inclusion and exclusion criteria).\n"
    prompt += 'Please output a JSON dict formatted as Dict{"relevance_explanation": Str, "relevance_score_R": Float, "eligibility_explanation": Str, "eligibility_score_E": Float, "eligibilityCriteriaMet": Str}.'
    # Leading space: without it the two sentences were glued together ("...}.Make").
    prompt += ' Make sure to mention just "Yes" or "No" for the "eligibilityCriteriaMet" key in the output JSON dict.'

    user_prompt = "Here is the patient note:\n"
    user_prompt += patient + "\n\n"
    user_prompt += "Here is the clinical trial description:\n"
    user_prompt += trial + "\n\n"
    user_prompt += "Here are the criterion-level eligibility predictions:\n"
    user_prompt += pred + "\n\n"
    user_prompt += "Plain JSON output:"

    return prompt, user_prompt


def trialgpt_aggregation(patient: str, trial_results: dict, trial_info: dict, model: str):
    """Ask the chat model for trial-level relevance and eligibility scores.

    Returns the parsed JSON reply; the prompt requests the keys
    relevance_explanation, relevance_score_R, eligibility_explanation,
    eligibility_score_E and eligibilityCriteriaMet. ``json.loads`` raises if
    the reply is not valid JSON. ``model`` is currently unused: the request is
    hardcoded to gpt-4o.
    """
    system_prompt, user_prompt = convert_pred_to_prompt(patient, trial_results, trial_info)

    reply = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0,
    ).choices[0].message.content

    # drop surrounding whitespace, then any ```json ... ``` fence
    cleaned = reply.strip().strip("`").strip("json")
    return json.loads(cleaned)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Cell 16

In [None]:
"""
Using GPT to aggregate the scores by itself.
"""

from beir.datasets.data_loader import GenericDataLoader
import json
from nltk.tokenize import sent_tokenize
import os
import sys
import time

Cell 17

In [None]:
if __name__ == "__main__":
    corpus = "Synthetic_Mass"
    model = "gpt-4-turbo"

    # the path of the matching results
    matching_results_path = "/content/matching_results_Synthetic_Mass_gpt-4-turbo.json"
    with open(matching_results_path) as f:
        results = json.load(f)

    # loading the trial2info dict
    with open("trial_info.json") as f:
        trial2info = json.load(f)

    # loading the patient notes: _id -> text
    queries_path = "/content/queries.jsonl"
    queries = {}
    with open(queries_path, 'r') as f:
        for line in f:
            query = json.loads(line)
            queries[query["_id"]] = query["text"]

    # output file path; reload earlier results so the run can resume
    output_path = f"aggregation_results_{corpus}_{model}.json"

    if os.path.exists(output_path):
        with open(output_path) as f:
            output = json.load(f)
    else:
        output = {}

    # patient-level
    for patient_id, info in results.items():
        # number the note's sentences exactly as the matching cell did
        sents = sent_tokenize(queries[patient_id])
        sents.append("The patient will provide informed consent, and will comply with the trial protocol without any practical issues.")
        patient = "\n".join(f"{idx}. {sent}" for idx, sent in enumerate(sents))

        output.setdefault(patient_id, {})

        # group-level: the matching cell stores a single "trials" group per patient
        for trials in info.values():

            # trial-level
            for trial_id, trial_results in trials.items():
                # already cached results
                if trial_id in output[patient_id]:
                    continue

                if not isinstance(trial_results, dict):
                    output[patient_id][trial_id] = "matching result error"

                    with open(output_path, "w") as f:
                        json.dump(output, f, indent=4)

                    continue

                # specific trial information
                trial_info = trial2info[trial_id]

                # best-effort: report failures and continue; rerun to retry
                try:
                    output[patient_id][trial_id] = trialgpt_aggregation(patient, trial_results, trial_info, model)

                    with open(output_path, "w") as f:
                        json.dump(output, f, indent=4)

                except Exception as e:
                    print(f"Aggregation failed for patient {patient_id}, trial {trial_id}: {e}")
                    continue

Cell 18

In [None]:
"""
Rank the trials given the matching and aggregation results
"""

import json
import sys

# Guards the inclusion ratio against division by zero.
eps = 1e-9


def _count_labels(predictions):
    """Count eligibility labels in one criterion-level prediction dict.

    ``predictions`` maps criterion number -> [reasoning, sentence ids, label].
    Entries that are not 3-element lists/tuples with a string label are
    ignored; a non-dict (e.g. the raw reply text kept when the model's JSON
    could not be parsed) yields no counts.
    """
    counts = {}
    if not isinstance(predictions, dict):
        return counts
    for info in predictions.values():
        if isinstance(info, (list, tuple)) and len(info) == 3 and isinstance(info[2], str):
            counts[info[2]] = counts.get(info[2], 0) + 1
    return counts


def get_matching_score(matching):
    """Criterion-level matching score for one patient-trial pair.

    score = included / (included + not included + not enough information)
            - 1 if any inclusion criterion is "not included"
            - 1 if any exclusion criterion is "excluded"

    "not applicable" labels do not count. A malformed or missing inclusion or
    exclusion part contributes nothing instead of raising.
    """
    inc = _count_labels(matching.get("inclusion"))
    exc = _count_labels(matching.get("exclusion"))

    included = inc.get("included", 0)
    not_inc = inc.get("not included", 0)
    no_info_inc = inc.get("not enough information", 0)
    excluded = exc.get("excluded", 0)

    score = included / (included + not_inc + no_info_inc + eps)

    if not_inc > 0:
        score -= 1

    if excluded > 0:
        score -= 1

    return score


def get_agg_score(assessment):
    """Aggregation score (R + E) / 100 from a trial-level LLM assessment.

    ``assessment`` is the dict stored by the aggregation cell. If it is not a
    dict (e.g. the "matching result error" marker), lacks either score, or a
    score is not numeric, the score is 0.
    """
    try:
        rel_score = float(assessment["relevance_score_R"])
        eli_score = float(assessment["eligibility_score_E"])
    except (KeyError, TypeError, ValueError):
        # narrow catch: a bare except would also swallow KeyboardInterrupt
        rel_score = 0
        eli_score = 0

    return (rel_score + eli_score) / 100

Cell 19

In [None]:
if __name__ == "__main__":
    # the results paths
    matching_results_path = "/content/matching_results_Synthetic_Mass_gpt-4-turbo.json"
    agg_results_path = "/content/aggregation_results_Synthetic_Mass_gpt-4-turbo.json"
    trial_info_path = "/content/trial_info.json"

    # loading the results
    with open(matching_results_path) as f:
        matching_results = json.load(f)
    with open(agg_results_path) as f:
        agg_results = json.load(f)
    with open(trial_info_path) as f:
        trial_info = json.load(f)

    final_result = []

    # loop over the patients
    for patient_id, label2trial2results in matching_results.items():
        patient_agg = agg_results.get(patient_id, {})
        trial2score = {}

        for trial2results in label2trial2results.values():
            for trial_id, results in trial2results.items():
                matching_score = get_matching_score(results)

                if trial_id not in patient_agg:
                    print(f"Patient {patient_id} Trial {trial_id} not in the aggregation results.")
                    agg_score = 0
                else:
                    agg_score = get_agg_score(patient_agg[trial_id])

                trial2score[trial_id] = matching_score + agg_score

        sorted_trial2score = sorted(trial2score.items(), key=lambda x: -x[1])

        print()
        print(f"Patient ID: {patient_id}")
        print("Clinical trial ranking:")

        eligible_trials = []
        for trial, score in sorted_trial2score:
            print(trial, score)
            # Trials without a usable aggregation (missing, or the "matching
            # result error" marker) have no verdict: record null instead of
            # raising KeyError/TypeError.
            assessment = patient_agg.get(trial)
            criteria_met = assessment.get("eligibilityCriteriaMet") if isinstance(assessment, dict) else None
            eligible_trials.append({
                "trialID": trial,
                "trialName": trial_info[trial]["brief_title"],
                "score": score,
                "eligibilityCriteriaMet": criteria_met,
            })

        final_result.append({"patientID": patient_id, "eligibleTrials": eligible_trials})

        print("===")
        print()

    result_path = "/content/eligibility_results.json"
    with open(result_path, "w") as f:
        json.dump(final_result, f, indent=4)


Patient ID: sigir-20142
Clinical trial ranking:
NCT00665366 0.0
NCT02110251 0.0
NCT02490241 -0.80000000004
NCT00188279 -1.0
NCT02073188 -1.6666666667777779
===



Cell 20

In [None]:
# Display the per-patient trial rankings assembled by the ranking cell above.
final_result

[{'patientID': 'sigir-20142',
  'eligibleTrials': [{'trialID': 'NCT00665366',
    'trialName': 'Study to Evaluate the Efficacy and Safety of Aripiprazole Administered With Lithium or Valproate Over 12 Weeks in the Treatment of Mania in Bipolar I Disorder',
    'score': 0.0,
    'eligibilityCriteriaMet': 'No'},
   {'trialID': 'NCT02110251',
    'trialName': 'Exercise Therapy With Risk Factor Management and Life Style Coaching After Vascular Intervention for Patients With Peripheral Arterial Disease',
    'score': 0.0,
    'eligibilityCriteriaMet': 'No'},
   {'trialID': 'NCT02490241',
    'trialName': 'Lithium Therapy: Understanding Mothers, Metabolism and Mood',
    'score': -0.80000000004,
    'eligibilityCriteriaMet': 'No'},
   {'trialID': 'NCT00188279',
    'trialName': 'Minimum Dose Computed Tomography of the Thorax for Follow-up in Patients With Resected Lung Carcinoma',
    'score': -1.0,
    'eligibilityCriteriaMet': 'No'},
   {'trialID': 'NCT02073188',
    'trialName': 'Comparat