In [1]:
import torch
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

torch.cuda.set_per_process_memory_fraction(1.0, device=0)  # 75% of total memory (30GB)

torch.backends.cudnn.benchmark = False  # Avoids excessive memory allocation
torch.backends.cudnn.enabled = False  # Disables unnecessary optimizations that consume memory
torch.cuda.empty_cache()  # Frees up any unused reserved memory
import gc
torch.cuda.empty_cache()
gc.collect()  # Clears all Python garbage


108

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import re
import pandas as pd
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, cohen_kappa_score
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import itertools

# Load the model and tokenizer from Hugging Face
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Convert to Pandas DataFrame
ground_truth_df = pd.read_csv('Dataset/Final_HF_Data_With_Trial_Info.csv')
ground_truth_df.head()

Unnamed: 0,annotation_id,patient_id,note,NCTID,trial_title,criterion_type,criterion_text,gpt4_explanation,explanation_correctness,gpt4_sentences,expert_sentences,gpt4_eligibility,expert_eligibility,training,brief_title,phase,drugs,diseases,enrollment,brief_summary
0,0,sigir-20141,0. A 58-year-old African-American woman presen...,NCT01397994,Study to Assess Efficacy of Nicorandil+Atenolo...,inclusion,Patients of chronic stable angina with abnorma...,The patient note does not provide direct evide...,Correct,"[0, 1, 2]","[0, 1, 2]",not enough information,not enough information,True,Study to Assess Efficacy of Nicorandil+Atenolo...,Phase 4,"['Nicorandil', 'Atenolol']",['Chronic Stable Angina'],40.0,This study is to determine the anti-anginal an...
1,1,sigir-20141,0. A 58-year-old African-American woman presen...,NCT01397994,Study to Assess Efficacy of Nicorandil+Atenolo...,inclusion,Male and female,The patient is identified as a female in the n...,Correct,[0],[0],included,included,True,Study to Assess Efficacy of Nicorandil+Atenolo...,Phase 4,"['Nicorandil', 'Atenolol']",['Chronic Stable Angina'],40.0,This study is to determine the anti-anginal an...
2,2,sigir-20141,0. A 58-year-old African-American woman presen...,NCT01397994,Study to Assess Efficacy of Nicorandil+Atenolo...,inclusion,Age 25 to 65 years,"The patient is 58 years old, which falls withi...",Correct,[0],[0],included,included,True,Study to Assess Efficacy of Nicorandil+Atenolo...,Phase 4,"['Nicorandil', 'Atenolol']",['Chronic Stable Angina'],40.0,This study is to determine the anti-anginal an...
3,3,sigir-20141,0. A 58-year-old African-American woman presen...,NCT01397994,Study to Assess Efficacy of Nicorandil+Atenolo...,inclusion,"Patient must understand and be willing, able a...",The patient note mentions that the patient wil...,Correct,[8],[8],included,included,True,Study to Assess Efficacy of Nicorandil+Atenolo...,Phase 4,"['Nicorandil', 'Atenolol']",['Chronic Stable Angina'],40.0,This study is to determine the anti-anginal an...
4,4,sigir-20141,0. A 58-year-old African-American woman presen...,NCT01397994,Study to Assess Efficacy of Nicorandil+Atenolo...,inclusion,Patient must be able to give voluntary written...,The patient note mentions that the patient wil...,Correct,[8],[8],included,included,True,Study to Assess Efficacy of Nicorandil+Atenolo...,Phase 4,"['Nicorandil', 'Atenolol']",['Chronic Stable Angina'],40.0,This study is to determine the anti-anginal an...


In [None]:
### 🚀 **Fix: Efficient GPU Execution (Batch Processing)** ###
def generate_response(prompt):
    """
    Generates structured JSON output while ensuring GPU execution.
    """
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to("cuda")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=False,  
            temperature=0.0,  
            max_new_tokens=1024,  
            eos_token_id=tokenizer.eos_token_id,  
            pad_token_id=tokenizer.pad_token_id,
            repetition_penalty=1.2
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return response

In [None]:
### 🚀 **Fix: Extract JSON Correctly** ###
def extract_json_from_response(response):
    """
    Extracts the first valid JSON object from the model response while ignoring extra text.
    """
    try:
        # Locate the JSON block that starts after "```json"
        json_start = response.find("```json")
        if json_start == -1:
            return {"error": "No valid JSON block found in response", "raw_response": response}
        
        # Extract everything after "```json"
        response = response[json_start + 7:]  # Skip past ```json

        # Locate the first valid JSON object in the extracted portion
        json_match = re.search(r'\{.*?\}', response, re.DOTALL)
        if not json_match:
            return {"error": "No valid JSON found in response", "raw_response": response}

        json_str = json_match.group()  # Extract JSON string

        # Try parsing JSON
        result = json.loads(json_str)
        return result  # Return the extracted JSON

    except json.JSONDecodeError as e:
        return {"error": f"Failed to parse JSON: {str(e)}", "raw_json": json_str}

    except Exception as e:
        return {"error": f"Unexpected error: {str(e)}"}

In [None]:
def print_trial(trial_info: dict, inc_exc: str) -> str:
    """Given a dict of trial information, returns a string of trial."""
    
    trial = f"Title: {trial_info.get('trial_title', 'N/A')}\n"
    
    # Ensure diseases and drugs are lists before joining
    diseases = trial_info.get("diseases", [])
    drugs = trial_info.get("drugs", [])
    
    if not isinstance(diseases, list):
        diseases = [str(diseases)]  # Convert to list if it's a string or None
    if not isinstance(drugs, list):
        drugs = [str(drugs)]  # Convert to list if it's a string or None
    
    trial += f"Target diseases: {', '.join(diseases)}\n"
    trial += f"Interventions: {', '.join(drugs)}\n"

    if inc_exc == "inclusion":
        trial += f"Inclusion criteria:\n {trial_info.get('criterion_text', 'N/A')}\n"
    elif inc_exc == "exclusion":
        trial += f"Exclusion criteria:\n {trial_info.get('criterion_text', 'N/A')}\n"

    return trial

In [None]:
### 🚀 **Fix: Improve Prompt Clarity** ###
def get_matching_prompt(trial_info: dict, inc_exc: str, patient: str) -> (str, str):
    """
    Constructs the prompt to ensure only actual criteria are evaluated (not diseases/drugs).
    """
    prompt = f"You are a clinical trial eligibility assistant. Your task is to compare a given patient note with the {inc_exc} criteria of a clinical trial and determine the patient's eligibility at the criterion level.\n\n"
    prompt += "**IMPORTANT:** Only ONE eligibility criterion is provided at a time.\n"
    prompt += "DO NOT evaluate trial metadata such as 'diseases' or 'drugs' as criteria. Only evaluate the provided eligibility criterion.\n\n"
    if inc_exc == "inclusion":
        prompt += ("Inclusion criteria are the factors that allow someone to participate in a clinical study. "
                   "They may include characteristics such as age, gender, disease stage, treatment history, and other medical conditions.\n\n")
    elif inc_exc == "exclusion":
        prompt += ("Exclusion criteria are the factors that disqualify someone from participating in a clinical study. "
                   "They may include characteristics such as age, gender, disease stage, treatment history, and other medical conditions.\n\n")
    if inc_exc == "inclusion":
        prompt += 'the eligibility_label label must be chosen from {"not applicable", "not enough information", "included", "not included"}. "not applicable" should only be used for criteria that are not applicable to the patient. "not enough information" should be used where the patient note does not contain sufficient information for making the classification. Try to use as less "not enough information" as possible because if the note does not mention a medically important fact, you can assume that the fact is not true for the patient. "included" denotes that the patient meets the inclusion criterion, while "not included" means the reverse.\n'
    elif inc_exc == "exclusion":
        prompt += 'the eligibility_label label must be chosen from {"not applicable", "not enough information", "excluded", "not excluded"}. "not applicable" should only be used for criteria that are not applicable to the patient. "not enough information" should be used where the patient note does not contain sufficient information for making the classification. Try to use as less "not enough information" as possible because if the note does not mention a medically important fact, you can assume that the fact is not true for the patient. "excluded" denotes that the patient meets the exclusion criterion and should be excluded in the trial, while "not excluded" means the reverse.\n'
    
    prompt += "**Note:** The patient has already provided informed consent for participation. Any criteria related to consent should be considered met.\n\n"
    prompt += f"\nFor each {inc_exc} criterion, do the following:\n"
    prompt += "1. Output the exact text of the criterion as 'criterion_text'.\n"
    prompt += "2. Set 'criteria_type' as either 'inclusion' or 'exclusion'.\n"
    prompt += "3. Provide a 'brief_reasoning' explaining your evaluation process.\n"
    prompt += "4. List 'relevant_sentences' as a list of sentence IDs from the patient note supporting your reasoning.\n"
    prompt += "5. Assign an 'eligibility_label'.\n"
    prompt += "Output only a JSON object. Do not include any extra text.\n\n"
    
    user_prompt = f"Patient note:\n{patient}\n\nTrial Information:\n"
    user_prompt += f"- NCTID: {trial_info['NCTID']}\n"
    user_prompt += f"- Drugs: {trial_info['drugs']}\n"
    user_prompt += f"- Diseases: {trial_info['diseases']} (For context only, NOT an eligibility criterion.)\n\n"
    user_prompt += f"Eligibility Criteria to evaluate:\n{trial_info['criterion_text']}\n\n"
    return prompt, user_prompt

In [17]:



### 🚀 **Fix: Query Model & Debugging** ###
def query_model(patient_note, trial_info, inc_exc):
    """
    Queries the model and extracts JSON from the response.
    """
    system_prompt, user_prompt = get_matching_prompt(trial_info, inc_exc, patient_note)
    full_prompt = system_prompt + "\n\n" + user_prompt

    # print("\n🚀 DEBUG: Querying Model with Prompt:\n", full_prompt[:1000])  # Print first 1000 chars of prompt
    
    try:
        response = generate_response(full_prompt)
        print("\n📜 RAW MODEL RESPONSE:\n", response)  # Debugging print statement
    except Exception as e:
        print("ERROR: Model failed to generate response:", str(e))
        response = ""

    extracted_json = extract_json_from_response(response)

    print("\n✅ EXTRACTED JSON:\n", extracted_json)  # Debugging print statement

    return extracted_json


### 🚀 **Fix: Process Each Patient-Trial Case** ###
def process_patient_trial(patient_id, patient_note, trial, inc_exc):
    """
    Runs the trial evaluation for a single patient and criterion.
    """
    NCTID = trial.get("NCTID", "N/A")
    evaluation_result = query_model(patient_note, trial, inc_exc)

    return {
        "patient_id": patient_id,
        "NCTID": NCTID,
        "evaluation_result": evaluation_result,
        "criteria_type": inc_exc
    }
    
def save_results_to_json(results, filename="eligibility_matching_results.json"):
    with open(filename, "w") as f:
        json.dump(results, f, indent=4)
    print(f"Matching process completed! Results saved to {filename}")


### 🚀 **Fix: Parallel Processing (Only if Needed)** ###
test_patient_id = "sigir-20141"
one_patient_df = ground_truth_df
if one_patient_df.empty:
    print("❌ ERROR: No matching patient ID found in dataset!")
else:
    print("\n🚀 Running Model for All Rows on GPU (Batched Processing)...\n")

    results = []
    batch_inputs = []
    batch_metadata = []

    for idx, row in tqdm(one_patient_df.iterrows(), total=len(one_patient_df), desc="🔄 Preparing Inputs"):
        patient_id = row["patient_id"]
        patient_note = row["note"]
        trial = {
            "criterion_text": row["criterion_text"], 
            "NCTID": row["NCTID"], 
            "drugs": row["drugs"], 
            "diseases": row["diseases"],
            "trial_title": row["trial_title"]
        }
        inc_exc = str(row["criterion_type"]).strip().lower()
    
        system_prompt, user_prompt = get_matching_prompt(trial, inc_exc, patient_note)
        full_prompt = system_prompt + "\n\n" + user_prompt
    
        batch_inputs.append(full_prompt)
        batch_metadata.append((patient_id, trial, inc_exc))


    batch_size = 70  # ✅ Define batch size (adjust based on memory)

    with torch.inference_mode():
        num_batches = (len(batch_inputs) + batch_size - 1) // batch_size  # Compute total batches
    
        print("\n⚡ Running Batched Inference on GPU...\n")
    
        responses = []
        for i in tqdm(range(num_batches), desc="📝 Processing Batches"):
            batch_start = i * batch_size
            batch_end = min((i + 1) * batch_size, len(batch_inputs))
    
            batch_chunk = batch_inputs[batch_start:batch_end]  # ✅ Process smaller batches
            inputs = tokenizer(batch_chunk, return_tensors="pt", padding=True, truncation=True).to("cuda")
    
            outputs = model.generate(
                **inputs,
                do_sample=False,
                temperature=0.0,
                max_new_tokens=1024,  # ✅ Lower max tokens to prevent OOM
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                repetition_penalty=1.2
            )
    
            responses.extend([tokenizer.decode(output, skip_special_tokens=True).strip() for output in outputs])
    
    # ✅ Extract JSON from each response
    for i, response in tqdm(enumerate(responses), total=len(responses), desc="📝 Extracting JSON"):
        extracted_json = extract_json_from_response(response)
        results.append({
            "patient_id": batch_metadata[i][0],
            "NCTID": batch_metadata[i][1]["NCTID"],
            "evaluation_result": extracted_json,
            "criteria_type": batch_metadata[i][2]
        })



    print("\n✅ FINAL RESULTS (BATCHED PROCESSING ON GPU)\n")
    save_results_to_json(results, "eligibility_matching_results_unquant.json")


🚀 Running Model for All Rows on GPU (Batched Processing)...



🔄 Preparing Inputs: 100%|███████████████| 1015/1015 [00:00<00:00, 19975.31it/s]



⚡ Running Batched Inference on GPU...



📝 Processing Batches: 100%|█████████████████| 15/15 [1:24:47<00:00, 339.18s/it]
📝 Extracting JSON: 100%|████████████████| 1015/1015 [00:00<00:00, 61179.24it/s]


✅ FINAL RESULTS (BATCHED PROCESSING ON GPU)

Matching process completed! Results saved to eligibility_matching_results_unquant.json





In [18]:
import json
import csv

def convert_json_to_csv(json_filename="eligibility_matching_results_unquant.json", 
                        csv_filename="eligibility_matching_results_unquant.tsv"):
    """
    Convert the JSON results to a TSV file for benchmarking.
    Each evaluation (criterion) becomes one row in the TSV.
    The output includes fields:
      patient_id, NCTID, criteria_type, criterion_text, brief_reasoning, 
      relevant_sentences, eligibility_label.
    """
    with open(json_filename, "r", encoding="utf-8") as f:
        data = json.load(f)

    rows = []
    fieldnames = ["patient_id", "NCTID", "criterion_type", "criterion_text", 
                  "brief_reasoning", "relevant_sentences", "eligibility_label"]
    
    for result in data:
        patient_id = result.get("patient_id", "")
        NCTID = result.get("NCTID", "")
        criterion_type = result.get("criteria_type", "")
        eval_result = result.get("evaluation_result", {})

        # Handle cases where the evaluation result is an error
        if "error" in eval_result:
            row = {
                "patient_id": patient_id,
                "NCTID": NCTID,
                "criterion_type": criterion_type,
                "criterion_text": "ERROR",
                "brief_reasoning": eval_result.get("error", ""),
                "relevant_sentences": "",
                "eligibility_label": "ERROR"
            }
            rows.append(row)
            continue  # Skip to the next record

        # Extract evaluation details
        criterion_text = eval_result.get("criterion_text", "N/A")
        brief_reasoning = eval_result.get("brief_reasoning", "No reasoning provided")
        relevant_sentences = ", ".join(map(str, eval_result.get("relevant_sentences", [])))
        eligibility_label = eval_result.get("eligibility_label", eval_result.get("eligibility", "unknown"))

        row = {
            "patient_id": patient_id,
            "NCTID": NCTID,
            "criterion_type": criterion_type,
            "criterion_text": criterion_text,
            "brief_reasoning": brief_reasoning,
            "relevant_sentences": relevant_sentences,
            "eligibility_label": eligibility_label
        }
        rows.append(row)

    # Write to TSV file
    with open(csv_filename, "w", newline="", encoding="utf-8") as tsvfile:
        writer = csv.DictWriter(tsvfile, fieldnames=fieldnames, delimiter="\t", quoting=csv.QUOTE_MINIMAL)
        writer.writeheader()
        for row in rows:
            writer.writerow(row)

    print(f"✅ TSV file generated and saved to: {csv_filename}")

# Run conversion
convert_json_to_csv()


✅ TSV file generated and saved to: eligibility_matching_results_unquant.tsv


In [20]:
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

pred_df = pd.read_csv("eligibility_matching_results_unquant.tsv", delimiter="\t", encoding="utf-8")

gt_df = ground_truth_df.copy()

# For merging, we need to align key columns.
# In the predicted DataFrame, we have:
#    patient_id, trial_id, criterion_text, criterion_type, eligibility_label
# In the ground truth, we use:
#    patient_id, trial_id, criterion_text, criterion_type, expert_eligibility

# Normalize the merge keys in both DataFrames.
def clean_keys(df, keys):
    for key in keys:
        df[key] = df[key].astype(str).str.lower().str.strip()
    return df

merge_keys = ["patient_id", "NCTID", "criterion_text", "criterion_type"]
pred_df = clean_keys(pred_df, merge_keys)
gt_df = clean_keys(gt_df, merge_keys)

# check a few key samples:
print("Predicted key sample:")
print(pred_df[merge_keys].head())
print("\nGround truth key sample:")
print(gt_df[merge_keys].head())

# Merge the two DataFrames on the common keys.
merged_df = pd.merge(
    pred_df,
    gt_df,
    on=merge_keys,
    how="inner",
    suffixes=("_pred", "_truth")
)

print("Number of matched records:", len(merged_df))
if len(merged_df) == 0:
    print("No records matched. Verify that the keys match between your predicted data and ground truth.")
else:
    # For evaluation, we compare the predicted eligibility_label with the expert_eligibility from the ground truth.
    y_pred = merged_df["eligibility_label"]
    y_true = merged_df["expert_eligibility"]
    
    accuracy = accuracy_score(y_true, y_pred)
    print("Accuracy:", accuracy)
    
    print("Classification Report:")
    print(classification_report(y_true, y_pred))
    
    print("Confusion Matrix:")
    print(confusion_matrix(y_true, y_pred))


Predicted key sample:
    patient_id        NCTID  \
0  sigir-20141  nct01397994   
1  sigir-20141  nct01397994   
2  sigir-20141  nct01397994   
3  sigir-20141  nct01397994   
4  sigir-20141  nct01397994   

                                      criterion_text criterion_type  
0  inclusion - patients of chronic stable angina ...      inclusion  
1                                    male and female      inclusion  
2                                 age 25 to 65 years      inclusion  
3  patient must understand and be willing, able a...      inclusion  
4  patient must be able to give voluntary written...      inclusion  

Ground truth key sample:
    patient_id        NCTID  \
0  sigir-20141  nct01397994   
1  sigir-20141  nct01397994   
2  sigir-20141  nct01397994   
3  sigir-20141  nct01397994   
4  sigir-20141  nct01397994   

                                      criterion_text criterion_type  
0  patients of chronic stable angina with abnorma...      inclusion  
1                 

In [10]:
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Compute overall accuracy.
accuracy = accuracy_score(merged_df["expert_eligibility"], merged_df["eligibility_label"])
print("Overall Accuracy:", accuracy)

# Print detailed classification report.
print("Classification Report:")
print(classification_report(merged_df["expert_eligibility"], merged_df["eligibility_label"]))

# Print confusion matrix.
print("Confusion Matrix:")
print(confusion_matrix(merged_df["expert_eligibility"], merged_df["eligibility_label"]))

# Identify misclassified rows.
misclassified_df = merged_df[merged_df["expert_eligibility"] != merged_df["eligibility_label"]]

# Save misclassified rows to a TSV (tab-delimited) file.
misclassified_df.to_csv("misclassified_results.tsv", index=False, sep="\t", encoding="utf-8")
print("Misclassified rows saved to misclassified_results.tsv")

# Optional: Print misclassification counts by ground truth vs. predicted.
print("Misclassification counts by ground truth vs. predicted:")
print(misclassified_df.groupby(["expert_eligibility", "eligibility_label"]).size())


Overall Accuracy: 0.6280254777070063
Classification Report:
                        precision    recall  f1-score   support

              excluded       0.21      0.88      0.34         8
              included       0.77      0.84      0.80       131
        not applicable       0.12      0.02      0.04        47
not enough information       0.45      0.51      0.48       203
          not excluded       0.81      0.69      0.74       378
          not included       0.23      0.61      0.34        18

              accuracy                           0.63       785
             macro avg       0.43      0.59      0.46       785
          weighted avg       0.65      0.63      0.63       785

Confusion Matrix:
[[  7   0   0   0   1   0]
 [  0 110   0  15   0   6]
 [  2   1   1  17  25   1]
 [  9  25   0 104  36  29]
 [ 15   2   7  94 260   0]
 [  0   5   0   2   0  11]]
Misclassified rows saved to misclassified_results.tsv
Misclassification counts by ground truth vs. predicted:
expert