# Instruct Model Inference and Final Evaluation

##Installing Dependencies

In [1]:
%%capture installation_log
!pip install vllm datasets -q

In [2]:
# Core Python libraries
from datasets import load_dataset
from pprint import pprint
import json
import pandas as pd

In [3]:
from google.colab import userdata
import os
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
hf_profile = 'aymangomaa'

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Load Dataset

In [5]:
from datasets import load_dataset

dataset = load_dataset(f"{hf_profile}/entity_extraction_ade_v2_chat_base")
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'relations', 'messages'],
        num_rows: 3458
    })
    validation: Dataset({
        features: ['text', 'relations', 'messages'],
        num_rows: 385
    })
    test: Dataset({
        features: ['text', 'relations', 'messages'],
        num_rows: 428
    })
})

In [6]:
test_data = dataset["test"]

In [7]:
pprint(test_data[9])

{'messages': [{'content': 'Extract all adverse drug effect (ADE) relationships '
                          'from the sentence. ### TEXT: METHODS: We report a '
                          'patient who had an anaphylactic reaction during the '
                          'intravenous infusion of cyclosporine.',
               'role': 'user'},
              {'content': '[{"ade": "anaphylactic reaction", "drug": '
                          '"cyclosporine"}]',
               'role': 'assistant'}],
 'relations': [{'ade': 'anaphylactic reaction', 'drug': 'cyclosporine'}],
 'text': 'METHODS: We report a patient who had an anaphylactic reaction during '
         'the intravenous infusion of cyclosporine.'}


In [8]:
# ✅ Improved Prompt Template for Qwen Instruct (ChatML)
def format_qwen_chat_prompt(text):
    """
    Formats clinical text into a structured ChatML prompt for Qwen Instruct models.

    Emphasizes structured output to improve ADE–drug extraction accuracy.

    Parameters:
        text (str): A clinical sentence describing possible ADE–drug relations.

    Returns:
        str: ChatML-formatted prompt with clear JSON extraction instruction.
    """
    sanitized_text = text.strip().replace("\n", " ")

    return (
        "<|im_start|>user\n"
        "You are a medical AI assistant. Given a clinical sentence, extract all adverse drug effect (ADE) "
        "relationships and return them as a JSON list of objects. "
        "Each object should contain two fields: 'drug' and 'ade'. "
        "Do not include any explanation or extra text.\n\n"
        f"### Sentence:\n{sanitized_text}\n\n"
        "### Expected Output Format:\n"
        '[{"drug": "drug_name", "ade": "adverse_effect"}, ...]\n'
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

# Generate prompts from test set
prompts = [format_qwen_chat_prompt(x['text']) for x in test_data]

# Preview one prompt
from pprint import pprint
pprint(prompts[4])


('<|im_start|>user\n'
 'You are a medical AI assistant. Given a clinical sentence, extract all '
 'adverse drug effect (ADE) relationships and return them as a JSON list of '
 "objects. Each object should contain two fields: 'drug' and 'ade'. Do not "
 'include any explanation or extra text.\n'
 '\n'
 '### Sentence:\n'
 'To the best of our knowledge, this is the first case of lithium-associated '
 'CDI and NDI presenting concurrently.\n'
 '\n'
 '### Expected Output Format:\n'
 '[{"drug": "drug_name", "ade": "adverse_effect"}, ...]\n'
 '<|im_end|>\n'
 '<|im_start|>assistant\n')


## Inference with Qwen3-1.7B-instruct + QLoRA

In [9]:
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from huggingface_hub import snapshot_download
import json

INFO 05-10 20:19:51 [__init__.py:239] Automatically detected platform cuda.


In [10]:
# Load Model
import torch
base_model = "Qwen/Qwen3-1.7B"
lora_repo = "aymangomaa/drug-ade-extraction-finetuned-instruct-2"
adapter_path = snapshot_download(repo_id=lora_repo)
llm = LLM(model=base_model, enable_lora=True, max_lora_rank=128,dtype=torch.float16)
sampling_params = SamplingParams(temperature=0.0, max_tokens=512)
lora_request = LoRARequest("qwen3_instruct_adapter", 1, adapter_path)

Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

INFO 05-10 20:20:11 [config.py:717] This model supports multiple tasks: {'classify', 'embed', 'generate', 'reward', 'score'}. Defaulting to 'generate'.
INFO 05-10 20:20:11 [llm_engine.py:240] Initializing a V0 LLM engine (v0.8.5.post1) with config: model='Qwen/Qwen3-1.7B', speculative_config=None, tokenizer='Qwen/Qwen3-1.7B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=40960, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_n

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 05-10 20:20:17 [loader.py:458] Loading weights took 1.74 seconds
INFO 05-10 20:20:17 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 05-10 20:20:18 [model_runner.py:1140] Model loading took 3.4930 GiB and 2.918111 seconds
INFO 05-10 20:20:29 [worker.py:287] Memory profiling takes 10.75 seconds
INFO 05-10 20:20:29 [worker.py:287] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.90) = 13.27GiB
INFO 05-10 20:20:29 [worker.py:287] model weights take 3.49GiB; non_torch_memory takes 0.05GiB; PyTorch activation peak memory takes 2.06GiB; the rest of the memory reserved for KV Cache is 7.67GiB.
INFO 05-10 20:20:29 [executor_base.py:112] # cuda blocks: 4486, # CPU blocks: 2340
INFO 05-10 20:20:29 [executor_base.py:117] Maximum concurrency for 40960 tokens per request: 1.75x
INFO 05-10 20:20:33 [model_runner.py:1450] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager 

Capturing CUDA graph shapes:   0%|          | 0/35 [00:00<?, ?it/s]

INFO 05-10 20:21:09 [model_runner.py:1592] Graph capturing finished in 37 secs, took 0.40 GiB
INFO 05-10 20:21:09 [llm_engine.py:437] init engine (profile, create kv cache, warmup model) took 51.77 seconds


In [11]:
# Generate Predictions
outputs_base = llm.generate(prompts, sampling_params, lora_request=lora_request)
# Save Raw Predictions
with open("outputs_qwen3_instruct.json", "w") as f:
    json.dump([o.outputs[0].text for o in outputs_base], f, indent=2)

Processed prompts:   0%|          | 0/428 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]



In [54]:
# Clean Generated Outputs and Normalization
# Regex cleanup of raw predictions to extract structured relationships.
import json
import re
import pandas as pd

# Load raw outputs
with open("outputs_qwen3_instruct.json", "r") as f:
    raw_outputs = json.load(f)

# Extract clean ADE-drug pairs
cleaned_outputs = []

for output in raw_outputs:
    try:
        # Strip <think>...</think>
        cleaned = re.sub(r"<think>.*?</think>", "", output, flags=re.DOTALL).strip()

        # Extract all JSON-like {...} objects inside brackets
        list_match = re.findall(r'{\s*"drug"\s*:\s*".+?",\s*"ade"\s*:\s*".+?"\s*}', cleaned, flags=re.DOTALL)

        parsed = [json.loads(item) for item in list_match] if list_match else []

    except Exception as e:
        parsed = []
    cleaned_outputs.append(parsed)

# Save cleaned predictions
with open("outputs_qwen3_instruct_cleaned.json", "w") as f:
    json.dump(cleaned_outputs, f, indent=2)




In [55]:
# Load predictions and ground truth
import json
import pandas as pd
from IPython.display import display

with open("outputs_qwen3_instruct_cleaned.json") as f:
    cleaned_outputs = json.load(f)

# Load test_data if not already loaded
# with open("test_data.json") as f:
#     test_data = json.load(f)

results = []

for idx in range(len(test_data)):
    true_set = set(
        (rel["drug"].lower(), rel["ade"].lower())
        for rel in test_data[idx]["relations"]
    )

    pred_set = set(
        (rel.get("drug", "").lower(), rel.get("ade", "").lower())
        for rel in cleaned_outputs[idx]
        if "drug" in rel and "ade" in rel
    )

    correct_set = true_set & pred_set

    results.append({
        "idx": idx,
        "text": test_data[idx]["text"],
        "ground_truth": list(true_set),
        "prediction": list(pred_set),
        "correct": list(correct_set)
    })

# Save results
with open("qwen3_instruct_all_evaluated.json", "w") as f:
    json.dump(results, f, indent=2)

# Display top 30 rows
df = pd.DataFrame(results)
display(df[["idx", "text", "ground_truth", "prediction", "correct"]].head(10))



Unnamed: 0,idx,text,ground_truth,prediction,correct
0,0,We present a case report of a patient with typ...,"[(chloramphenicol sodium succinate, hypersensi...","[(chloramphenicol sodium succinate, hypersensi...","[(chloramphenicol sodium succinate, hypersensi..."
1,1,The ototoxicity of quinine can accurately be s...,"[(quinine, ototoxicity)]","[(quinine, ototoxicity)]","[(quinine, ototoxicity)]"
2,2,Patient 1 presented bilateral ballism 1 week a...,"[(heroin, bilateral ballism)]","[(heroin, bilateral ballism)]","[(heroin, bilateral ballism)]"
3,3,A 58-year-old woman developed unilateral acute...,"[(scopolamine, unilateral acute angle-closure ...","[(scopolamine, unilateral acute angle-closure ...","[(scopolamine, unilateral acute angle-closure ..."
4,4,"To the best of our knowledge, this is the firs...","[(lithium, ndi), (lithium, cdi)]","[(lithium, ndi), (lithium, cdi)]","[(lithium, ndi), (lithium, cdi)]"
5,5,CONCLUSION: A 26-year-old man with bipolar dis...,"[(carbamazepine, hyperammonemia)]","[(carbamazepine, hyperammonemia)]","[(carbamazepine, hyperammonemia)]"
6,6,RESULTS: Quetiapine was associated with leucop...,"[(quetiapine, leucopenia), (quetiapine, agranu...","[(quetiapine, leucopenia), (quetiapine, agranu...","[(quetiapine, leucopenia), (quetiapine, agranu..."
7,7,Hepatopathy subsided after the cessation of ca...,"[(carbamazepine, hepatopathy), (lynestrenol, h...",[],[]
8,8,Carbamazepine induced right bundle branch bloc...,"[(carbamazepine, right bundle branch block)]","[(carbamazepine, right bundle branch block)]","[(carbamazepine, right bundle branch block)]"
9,9,METHODS: We report a patient who had an anaphy...,"[(cyclosporine, anaphylactic reaction)]","[(cyclosporine, anaphylactic reaction)]","[(cyclosporine, anaphylactic reaction)]"


In [56]:
import json
import re

# Load cleaned predictions
with open("outputs_qwen3_instruct_cleaned.json", "r") as f:
    preds = json.load(f)

# Load test data (you must have already loaded or defined `test_data`)
# For example: test_data = load_dataset("your_dataset")["test"]

# Utility: Normalize text
def normalize(text):
    return re.sub(r"[^a-z0-9]", "", text.lower().strip())

# Utility: Extract normalized (drug, ade) pairs
def extract_pairs(rel_list):
    try:
        return set(
            (normalize(d.get("drug", "")), normalize(d.get("ade", "")))
            for d in rel_list
            if "drug" in d and "ade" in d
        )
    except:
        return set()

# Compute metrics
tp, fp, fn = 0, 0, 0

for i in range(len(test_data)):
    true_pairs = extract_pairs(test_data[i]["relations"])
    pred_pairs = extract_pairs(preds[i])

    tp += len(true_pairs & pred_pairs)
    fp += len(pred_pairs - true_pairs)
    fn += len(true_pairs - pred_pairs)

precision = tp / (tp + fp + 1e-8)
recall = tp / (tp + fn + 1e-8)
f1 = 2 * precision * recall / (precision + recall + 1e-8)

# Save score to file
score_dict = {
    "TP": tp,
    "FP": fp,
    "FN": fn,
    "Precision": round(precision, 4),
    "Recall": round(recall, 4),
    "F1": round(f1, 4)
}

with open("eval_score_instruct.json", "w") as f:
    json.dump(score_dict, f, indent=2)

# Print summary
print("✅ Evaluation Results:")
for k, v in score_dict.items():
    print(f"{k}: {v}")


✅ Evaluation Results:
TP: 298
FP: 72
FN: 370
Precision: 0.8054
Recall: 0.4461
F1: 0.5742
