In [1]:
# # Load model directly
# from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("/ocean/projects/med230010p/yji3/llama3_70B")
# model = AutoModelForCausalLM.from_pretrained("/ocean/projects/med230010p/yji3/llama3_70B")

In [1]:
!pip show transformers
!pip install transformers==4.43.0

Name: transformers
Version: 4.43.0
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /jet/home/yji3/.local/lib/python3.10/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: llama-recipes, optimum, peft, pyserini, sentence-transformers, transformer-lens


In [2]:
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer

model_id="/ocean/projects/med230010p/yji3/llama3_70B"

# tokenizer = LlamaTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)


  from .autonotebook import tqdm as notebook_tqdm
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 30/30 [20:32<00:00, 41.07s/it]


# LLAMA 70 B for matching

In [3]:
import json
from nltk.tokenize import sent_tokenize
import time
import os
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer

model_id="/ocean/projects/med230010p/yji3/llama3_70B"

def parse_criteria(criteria):
    output = ""
    criteria = criteria.split("\n\n")
    
    idx = 0
    for criterion in criteria:
        criterion = criterion.strip()

        if "inclusion criteria" in criterion.lower() or "exclusion criteria" in criterion.lower():
            continue

        if len(criterion) < 5:
            continue
    
        output += f"{idx}. {criterion}\n" 
        idx += 1
    
    return output


def print_trial(
    trial_info: dict,
    inc_exc: str,
) -> str:
    """Given a dict of trial information, returns a string of trial."""
    
    trial = f"Title: {trial_info['brief_title']}\n"
    trial += f"Target diseases: {', '.join(trial_info['diseases_list'])}\n"
    trial += f"Interventions: {', '.join(trial_info['drugs_list'])}\n"
    trial += f"Summary: {trial_info['brief_summary']}\n"
    
    if inc_exc == "inclusion":
        trial += "Inclusion criteria:\n %s\n" % parse_criteria(trial_info['inclusion_criteria'])
    elif inc_exc == "exclusion":
        trial += "Exclusion criteria:\n %s\n" % parse_criteria(trial_info['exclusion_criteria']) 

    return trial


def get_matching_prompt(
    trial_info: dict,
    inc_exc: str,
    patient: str,
) -> str:
    """Output the prompt."""
    prompt = f"You are a helpful assistant for clinical trial recruitment. Your task is to compare a given patient note and the {inc_exc} criteria of a clinical trial to determine the patient's eligibility at the criterion level.\n"

    if inc_exc == "inclusion":
        prompt += "The factors that allow someone to participate in a clinical study are called inclusion criteria. They are based on characteristics such as age, gender, the type and stage of a disease, previous treatment history, and other medical conditions.\n"
    
    elif inc_exc == "exclusion":
        prompt += "The factors that disqualify someone from participating are called exclusion criteria. They are based on characteristics such as age, gender, the type and stage of a disease, previous treatment history, and other medical conditions.\n"

    prompt += f"You should check the {inc_exc} criteria one-by-one, and output the following three elements for each criterion:\n"
    prompt += f"\tElement 1. For each {inc_exc} criterion, briefly generate your reasoning process: First, judge whether the criterion is not applicable (not very common), where the patient does not meet the premise of the criterion. Then, check if the patient note contains direct evidence. If so, judge whether the patient meets or does not meet the criterion. If there is no direct evidence, try to infer from existing evidence, and answer one question: If the criterion is true, is it possible that a good patient note will miss such information? If impossible, then you can assume that the criterion is not true. Otherwise, there is not enough information.\n"
    prompt += f"\tElement 2. If there is relevant information, you must generate a list of relevant sentence IDs in the patient note. If there is no relevant information, you must annotate an empty list.\n" 
    prompt += f"\tElement 3. Classify the patient eligibility for this specific {inc_exc} criterion: "
    
    if inc_exc == "inclusion":
        prompt += 'the label must be chosen from {"not applicable", "not enough information", "included", "not included"}. "not applicable" should only be used for criteria that are not applicable to the patient. "not enough information" should be used where the patient note does not contain sufficient information for making the classification. Try to use as less "not enough information" as possible because if the note does not mention a medically important fact, you can assume that the fact is not true for the patient. "included" denotes that the patient meets the inclusion criterion, while "not included" means the reverse.\n'
    elif inc_exc == "exclusion":
        prompt += 'the label must be chosen from {"not applicable", "not enough information", "excluded", "not excluded"}. "not applicable" should only be used for criteria that are not applicable to the patient. "not enough information" should be used where the patient note does not contain sufficient information for making the classification. Try to use as less "not enough information" as possible because if the note does not mention a medically important fact, you can assume that the fact is not true for the patient. "excluded" denotes that the patient meets the exclusion criterion and should be excluded in the trial, while "not excluded" means the reverse.\n'
    
    prompt += "You should output only a JSON dict exactly formatted as: dict{str(criterion_number): list[str(element_1_brief_reasoning), list[int(element_2_sentence_id)], str(element_3_eligibility_label)]}."
    
    user_prompt = f"Here is the patient note, each sentence is led by a sentence_id:\n{patient}\n\n" 
    user_prompt += f"Here is the clinical trial:\n{print_trial(trial_info, inc_exc)}\n\n"
    user_prompt += f"Plain JSON output:"

    return prompt, user_prompt


def trialgpt_matching(trial: dict, patient: str):
    results = {}
    # hard code model
    # model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)
    # tokenizer = AutoTokenizer.from_pretrained(model_id)

    # doing inclusions and exclusions in separate prompts
    for inc_exc in ["inclusion", "exclusion"]:
        system_prompt, user_prompt = get_matching_prompt(trial, inc_exc, patient)
        eval_prompt = f"""
           content: {system_prompt},
           content": {user_prompt},
        """
        model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
        model.eval()
        with torch.no_grad():
            res = tokenizer.decode(model.generate(**model_input, max_new_tokens=200,pad_token_id=tokenizer.eos_token_id)[0], skip_special_tokens=True)
            message = res[len(eval_prompt):]

        try:
            results[inc_exc] = json.loads(message)
        except:
            results[inc_exc] = message

    return results

# inference matching

In [None]:
# from trialgpt import trialgpt_matching
import json
from nltk.tokenize import sent_tokenize
import os
import sys
# corpuses = ["sigir", "trec_2021", "trec_2022"]
corpuses = ["trec_2022"]


def process_patient(patient):
    sents = sent_tokenize(patient)
    sents.append("The patient will provide informed consent, and will comply with the trial protocol without any practical issues.")
    sents = [f"{idx}. {sent}" for idx, sent in enumerate(sents)]
    patient = "\n".join(sents)
    return patient
for corpus in corpuses:
    original_dataset = json.load(open(f"{corpus}/retrieved_trials.json"))
    model_name = "llama"
    # gender_race_patient_note = json.load(open(f"sensitive_changed_patient_note/{corpus}_sensitive.json"))
    social_patient_note = json.load(open(f"sensitive_changed_patient_note/{corpus}_all_sensitive.json"))
    output_path = f"fairness_results/matching_results_{corpus}_{model_name}.json" 

    # Dict{Str(patient_id): Dict{Str(label): Dict{Str(trial_id): Str(output)}}}
    if os.path.exists(output_path):
        output = json.load(open(output_path))
    else:
        output = {}   
    for instance in original_dataset: 
        # Dict{'patient': Str(patient), '0': Str(NCTID), ...}
        patient_id = instance["patient_id"]
        patient = instance["patient"]
        sensitive = social_patient_note[patient_id]
        cat_sensitive = {}
        for k, v in sensitive.items():
            # if v!=None and k!="patient":
            if v!=None:
                cat_sensitive[k] = process_patient(v)
        # same process 
        original_patient = process_patient(patient)
        
        # initialize the patient id in the output 
        if patient_id not in output:
            output[patient_id] = {}
            for sensitive_type, _ in cat_sensitive.items(): # this time I also added the normal 
                output[patient_id][sensitive_type] = {"0": {}, "1": {}, "2": {}}
        
        
        # only use the sensitive changed as the input patient note
        for sensitive_type, sensitive_note in cat_sensitive.items():
            for label in ["2", "1", "0"]:
                if label not in instance: continue
                
                for trial in instance[label]:  # here I limit every label only using 30
                    trial_id = trial["NCTID"]
                    
                    # already calculated and cached
                    if trial_id in output[patient_id][sensitive_type][label]:
                        continue

                    # in case anything goes wrong (e.g., API calling errors)
                    # try:
                    results = trialgpt_matching(trial, sensitive_note)
                    # results = "sdsa"
                    output[patient_id][sensitive_type][label][trial_id] = results
                    # output[patient_id][]
                    with open(output_path, "w") as f:
                        json.dump(output, f, indent=4)

                    

## Generate ranking