In [1]:
%%capture
import torch
major_version, minor_version = torch.cuda.get_device_capability()
# Must install separately since Colab has torch 2.2.1, which breaks packages
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
if major_version >= 8:
    # Use this for new GPUs like Ampere, Hopper GPUs (RTX 30xx, RTX 40xx, A100, H100, L40)
    !pip install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
else:
    # Use this for older GPUs (V100, Tesla T4, RTX 20xx)
    !pip install --no-deps xformers trl peft accelerate bitsandbytes
pass

In [2]:
from unsloth import FastLanguageModel
max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(model_name = "unsloth/llama-3-8b-bnb-4bit", # YOUR MODEL YOU USED FOR TRAINING
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.50.0.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096, padding_idx=128255)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSN

In [28]:
import os
import re

def read_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()

def parse_document(document):
    key_value_pattern = re.compile(r'^([^,:]+):,(.*)$')
    section_pattern = re.compile(r'-{50,}')
    data = {}
    for line in document.split('\n'):
        if section_pattern.match(line):
            continue
        elif key_value_pattern.match(line):
            key, val = key_value_pattern.match(line).groups()
            data[key.strip()] = val.strip()
        elif line.strip() and list(data):
            last_key = list(data)[-1]
            data[last_key] += ' ' + line.strip()
    return data

def split_criteria(text):
    return [s.strip() for s in re.split(r'\.\s+', text) if s.strip()]

def process_trial_file(file_path):
    doc = read_file(file_path)
    data = parse_document(doc)
    eligibility = data.get("Eligibility Criteria", "")
    inc, exc = "", ""
    if "||" in eligibility:
        inc, exc = eligibility.split("||")
    else:
        inc = eligibility
    return {
        "name": data.get("Name", "Unnamed Trial"),
        "inclusion": split_criteria(inc.replace("Inclusion:", "").strip()),
        "exclusion": split_criteria(exc.replace("Exclusion:", "").strip())
    }

def load_trials(folder_path):
    return [
        process_trial_file(os.path.join(folder_path, f))
        for f in os.listdir(folder_path)
        if f.endswith(".csv")
    ]

In [41]:
trials = load_trials("trials")
print(trials)

[{'name': 'Wavelia', 'inclusion': ['" Informed Consent Female subjects with an investigator assessed discrete breast abnormality of size > 1cm Able and willing to comply with the requirements of this study protocol Negative urine pregnancy test on the day of microwave imaging procedure (if of childbearing potential) intact breast skin (i.e., without bleeding lesion, no evidence of inflammation and/or erythema of the breast) Able to comfortably lie reasonably still in a prone position for approximately 15 minutes Have had biopsy more than 2 weeks prior to the microwave breast investigation (if applicable)'], 'exclusion': ['Have a cup size of A or whose breast is deemed too small to allow MBI assessment in the opinion of the investigator Are pregnant or breast-feeding Have had surgery on either breast within the past 12 months Have any active or metallic implant other than a biopsy clip Would be unsuitable for an MBI scan or unlikely to follow the protocol in the opinion of the Investiga

In [3]:
from unsloth import FastLanguageModel

max_seq_length = 4096
load_in_4bit = True
dtype = None  # auto

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
)

==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.50.0.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.3.19 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [63]:
import pandas as pd
import os
from datasets import Dataset
from sklearn.model_selection import train_test_split
import unicodedata
def normalize(text):
    return unicodedata.normalize("NFKD", text).strip().lower()
# --- 1. Load all patients from folder ---
def load_all_patients(patient_folder="patients"):
    dfs = []
    for filename in os.listdir(patient_folder):
        if filename.endswith(".csv"):
            path = os.path.join(patient_folder, filename)
            df = pd.read_csv(path)
            dfs.append(df)
    return pd.concat(dfs, ignore_index=True)

# --- 2. Build trial lookup ---
# --- 3. Build training dataset ---
def build_training_dataset(patient_df):
    training_rows = []

    for _, row in patient_df.iterrows():
        matched_trial = None
        trial_name_raw = row["trial_name"].strip().lower()
        for trial in trials:
            matched_trial = next((trial for trial in trials if trial_name_raw.startswith(trial["name"].strip().lower())),None)

        if not matched_trial:
            continue

        # Format trial description
        inc = "\n".join(f"- {c}" for c in matched_trial["inclusion"])
        exc = "\n".join(f"- {c}" for c in matched_trial["exclusion"])
        trial_description = (
            f"Trial: {matched_trial['name']}\n"
            f"Inclusion Criteria:\n{inc}\n"
            f"Exclusion Criteria:\n{exc}"
        )

        # Format patient profile
        patient_profile = "\n".join([
            f"{k}: {v}"
            for k, v in row.items()
            if k not in ["eligibility_label", "patient_id", "trial_name"]
        ])

        input_text = f"Patient Profile:\n{patient_profile}\n\n{trial_description}"

        training_rows.append({
        "instruction": "Is this patient eligible for the trial? Respond with 'eligible' or 'not eligible' and give a reason.",
        "input": input_text,
        "output": f"{row['eligibility_label'].lower()} — {row['explanation']}"
        })

    print(f"✅ Number of training rows generated: {len(training_rows)}")
    return Dataset.from_list(training_rows)


# --- 4. Generate training dataset ---
patient_df = load_all_patients("patients")
training_data = build_training_dataset(patient_df)


✅ Number of training rows generated: 20016


In [64]:
processed_data = [
    {
        "text": f"{ex['instruction']}\n\n{ex['input']}\n\n{ex['output']}{tokenizer.eos_token}"
    }
    for ex in training_data
]
print(len(processed_data))

20016


In [65]:
from datasets import Dataset
dataset = Dataset.from_list(processed_data)

In [35]:
lengths = [len(tokenizer(example["text"])["input_ids"]) for example in dataset]
print(max(lengths), sum(lengths)/len(lengths))

983 765.958


In [66]:
train_list, val_list = train_test_split(processed_data, test_size=0.1, random_state=42)
train_dataset = Dataset.from_list(train_list)
val_dataset = Dataset.from_list(val_list)

# Now create trainer and use formatting_func
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    formatting_func=lambda example: [example["text"]],
    dataset_text_field="text",
    max_seq_length=1024,
    dataset_num_proc=8,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=32,
        gradient_accumulation_steps=2,
        warmup_steps=100,
        num_train_epochs=3,
        learning_rate=1e-4,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=50,
        evaluation_strategy="no",
        eval_steps=500,
        save_steps=2000,
        save_total_limit=2,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="llama3_trial_matcher2",
    ),
)



Unsloth: Tokenizing ["text"] (num_proc=8):   0%|          | 0/18014 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=8):   0%|          | 0/2002 [00:00<?, ? examples/s]

In [67]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 18,014 | Num Epochs = 3 | Total steps = 843
O^O/ \_/ \    Batch size per device = 32 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (32 x 2 x 1) = 64
 "-____-"     Trainable parameters = 41,943,040/8,000,000,000 (0.52% trained)


Step,Training Loss
50,0.0187
100,0.0216
150,0.0193
200,0.0187
250,0.0182
300,0.018
350,0.0179
400,0.018
450,0.0181
500,0.0179


TrainOutput(global_step=843, training_loss=0.018177144962980514, metrics={'train_runtime': 13294.2007, 'train_samples_per_second': 4.065, 'train_steps_per_second': 0.063, 'total_flos': 2.498490899210699e+18, 'train_loss': 0.018177144962980514})

In [68]:
trainer.save_model("llama3_trial_matcher2")
tokenizer.save_pretrained("llama3_trial_matcher2")

('llama3_trial_matcher2/tokenizer_config.json',
 'llama3_trial_matcher2/special_tokens_map.json',
 'llama3_trial_matcher2/tokenizer.json')

In [50]:
from unsloth import FastLanguageModel
import torch

# --- Model Setup ---
max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "llama3_trial_matcher2",  # Fine-tuned model
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model)

# --- Patient Profile ---
patient_profile = """Patient is over 18: Yes
ECOG performance status is 0 or 1: Yes
Organ function is adequate: Yes
Has received radiotherapy: Yes
Has distant metastasis: No
Has history of cancer: No
Has cardiac condition: No
Has BRCA mutation: No
Has received HER2-targeted therapy before: No
Has received endocrine therapy before: No
Has active infection: No
Is pregnant or breastfeeding: No
Left ventricular ejection fraction is below 50%: No
Has arrhythmia: No
Previously treated with CTLA-4 inhibitor: No
Previously treated with CD137 agent: No
Previously treated with OX40 agent: No
Previously treated with topoisomerase inhibitor: No
Estrogen receptor positive: No
Progesterone receptor positive: No
HER2 negative: Yes
Triple negative breast cancer: Yes
Smoker: No
Family history of cancer: Yes
BMI over 30: No"""

# --- Trial Description (Formatted) ---
instruction = "Is this patient eligible for the trial? Respond with 'eligible' or 'not eligible' and give a reason."

trial_description = """Trial: ASCENT-05

Inclusion Criteria:
- Age > 18 years
- Residual invasive triple negative breast cancer (TNBC) in the breast or lymph nodes after neoadjuvant therapy and surgery
- TNBC defined as ER and PR < 10%, and HER2-negative per ASCO/CAP guidelines (IHC/ISH)
- Adequate excision and surgical removal of all clinically evident disease in breast and/or lymph nodes
- Adequately recovered from surgery
- Submission of both pre-neoadjuvant treatment diagnostic biopsy and resected residual invasive disease tissue
- ECOG performance status 0-1
- Received appropriate radiotherapy and recovered before starting study treatment
- Adequate organ function

Exclusion Criteria:
- Stage IV (metastatic) breast cancer
- History of prior (ipsi- or contralateral) invasive breast cancer
- Prior treatment with stimulatory or coinhibitory T-cell receptor agents (e.g., CTLA-4, OX-40, CD137)
- Prior treatment with any HER2-directed agent
- Prior or concurrent treatment with any endocrine therapy agent
- Evidence of recurrent disease following preoperative therapy and surgery
- Prior treatment with topoisomerase 1 inhibitors or ADCs containing a topoisomerase inhibitor
- Individuals with germline BRCA mutations
- Myocardial infarction or unstable angina pectoris within 6 months of enrollment
- History of serious ventricular arrhythmia (ventricular tachycardia or fibrillation)
- High-grade atrioventricular block or other serious cardiac arrhythmias
- Left ventricular ejection fraction (LVEF) < 50%
- Active serious infections requiring antimicrobial therapy"""

prompt = f"{instruction}\n\nPatient Profile:\n{patient_profile}\n\n{trial_description}\n\nAnswer:"
# --- Tokenize + Predict ---
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(
    **inputs,
    max_new_tokens=1024,
    temperature=0.0,
    do_sample=False,
    repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id,
)

# --- Decode and Clean ---
response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()

if prompt.lower() in response:
    response = response.replace(prompt.lower(), "").strip()

print("🧠 Final output:", response)


==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.50.0.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
🧠 Final output: eligible


In [None]:
alpaca_prompt = "{}\n\n{}\n\n"

In [69]:
import re
import os
import pandas as pd
import unicodedata
from sklearn.metrics import classification_report
from tqdm import tqdm
import torch

# 🔍 Normalize helper
def normalize(text):
    return unicodedata.normalize("NFKD", text).strip().lower()

# 🧠 Trial lookup from your trial definitions
trial_lookup = {trial["name"]: trial for trial in trials}

# 🧠 Load test patient files
test_folder = "test_patients"
all_predictions = []
all_true_labels = []
all_trial_names = []

for filename in os.listdir(test_folder):
    if not filename.endswith(".csv"):
        continue

    filepath = os.path.join(test_folder, filename)
    df = pd.read_csv(filepath)

    # 🧠 Loop over patients
    for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Evaluating {filename}"):
        trial_name_raw = row["trial_name"].strip().lower()
        matched_trial = next(
            (trial for trial in trials if normalize(trial["name"]) in trial_name_raw),
            None
        )

        if not matched_trial:
            print(f"⚠️ Could not match trial: {trial_name_raw}")
            continue

        # 🧾 Build trial description
        inc = "\n".join(f"- {c}" for c in matched_trial["inclusion"])
        exc = "\n".join(f"- {c}" for c in matched_trial["exclusion"])
        trial_description = f"Trial: {matched_trial['name']}\nInclusion Criteria:\n{inc}\nExclusion Criteria:\n{exc}"

        patient_profile = row["natural_language_profile"]
        prompt = f"""Is this patient eligible for the trial? Respond with 'eligible' or 'not eligible' and give a reason.

Patient Profile:
{patient_profile}

{trial_description}

Answer:"""

        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            temperature=0.0,
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
        )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower()
        if prompt.lower() in response:
            response = response.replace(prompt.lower(), "").strip()

        answer_match = re.search(r"answer:\s*(.*)", response, re.IGNORECASE | re.DOTALL)

        if answer_match:
            clean_answer = answer_match.group(1).strip()
            match = re.match(r"^(eligible|not eligible)", clean_answer, re.IGNORECASE)
            prediction = match.group(1).lower() if match else "not eligible"
        else:
            prediction = "not eligible"  # fallback

        true_label = row["eligibility_label"].strip().lower()

        all_predictions.append(prediction)
        all_true_labels.append(true_label)
        all_trial_names.append(matched_trial["name"])

# 📊 Final report
print("\n📊 Overall Classification Report:")
print(classification_report(all_true_labels, all_predictions, digits=3))

# 📈 Optional: Per-trial breakdown
print("\n📈 Trial-wise Accuracy:")
df_results = pd.DataFrame({
    "trial": all_trial_names,
    "true": all_true_labels,
    "pred": all_predictions,
})

for trial in df_results["trial"].unique():
    subset = df_results[df_results["trial"] == trial]
    print(f"\n📍 {trial}")
    print(classification_report(subset['true'], subset['pred'], digits=3))


Evaluating Wavelia.csv: 100%|██████████| 200/200 [03:22<00:00,  1.01s/it]
Evaluating UCARE.csv: 100%|██████████| 200/200 [00:42<00:00,  4.72it/s]
Evaluating PREcoopERA.csv: 100%|██████████| 200/200 [03:00<00:00,  1.11it/s]
Evaluating CAMBRIA-2.csv: 100%|██████████| 200/200 [02:49<00:00,  1.18it/s]
Evaluating ASCENT-05.csv: 100%|██████████| 200/200 [02:24<00:00,  1.38it/s]
Evaluating EPIK-B5.csv: 100%|██████████| 200/200 [03:29<00:00,  1.05s/it]
Evaluating MK-2870-012.csv: 100%|██████████| 200/200 [03:20<00:00,  1.00s/it]
Evaluating EMBER-4.csv: 100%|██████████| 200/200 [03:39<00:00,  1.10s/it]
Evaluating TREAT_ctDNA_study.csv: 100%|██████████| 200/200 [03:30<00:00,  1.05s/it]


📊 Overall Classification Report:
              precision    recall  f1-score   support

    eligible      0.820     0.892     0.855       900
not eligible      0.882     0.804     0.841       900

    accuracy                          0.848      1800
   macro avg      0.851     0.848     0.848      1800
weighted avg      0.851     0.848     0.848      1800


📈 Trial-wise Accuracy:

📍 Wavelia
              precision    recall  f1-score   support

    eligible      0.833     0.100     0.179       100
not eligible      0.521     0.980     0.681       100

    accuracy                          0.540       200
   macro avg      0.677     0.540     0.430       200
weighted avg      0.677     0.540     0.430       200


📍 UCARE
              precision    recall  f1-score   support

    eligible      0.840     1.000     0.913       100
not eligible      1.000     0.810     0.895       100

    accuracy                          0.905       200
   macro avg      0.920     0.905     0.904       




In [49]:
def rule_based_label(patient, trial):
    # Exclusion is strict: any True → Not Eligible
    for exc_key in trial["exclusion"]:
        if patient.get(exc_key, False):
            return "Not Eligible"

    # Inclusion is flexible: inclusion not required to all be True
    return "Eligible"
def evaluate_patient_labels(patient_df):
    mismatches = []

    for _, row in patient_df.iterrows():
        trial = trial_lookup.get(row["trial_name"])
        if not trial:
            continue

        patient_dict = row.drop(["eligibility_label", "patient_id", "trial_name"]).to_dict()
        predicted = rule_based_label(patient_dict, trial)

        if predicted.lower() != row["eligibility_label"].lower():
            mismatches.append({
                "patient_id": row["patient_id"],
                "trial_name": row["trial_name"],
                "true_label": row["eligibility_label"],
                "predicted_label": predicted,
            })

    return pd.DataFrame(mismatches)
mismatch_df = evaluate_patient_labels(patient_df)
print(f"🧐 Found {len(mismatch_df)} mismatches")
display(mismatch_df.head())

🧐 Found 500 mismatches


Unnamed: 0,patient_id,trial_name,true_label,predicted_label
0,ASCE_0501,ASCENT-05,Not Eligible,Eligible
1,ASCE_0502,ASCENT-05,Not Eligible,Eligible
2,ASCE_0503,ASCENT-05,Not Eligible,Eligible
3,ASCE_0504,ASCENT-05,Not Eligible,Eligible
4,ASCE_0505,ASCENT-05,Not Eligible,Eligible



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

