In [2]:
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from peft import PeftConfig, AutoPeftModelForCausalLM, PeftModel
import torch
import os

import pandas as pd
from datasets import load_dataset
from tqdm.auto import tqdm

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type = "nf4",
    bnb_4bit_compute_dtype = torch.bfloat16
)

adapter_name = "../adapter/Zip-Llama-aligned"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config = bnb_config,
    attn_implementation = "eager",
    use_cache = True,
    dtype = torch.bfloat16,
    device_map = "cuda:0"
)

model = PeftModel.from_pretrained(model, os.path.join(adapter_name, "policy"))
tokenizer = AutoTokenizer.from_pretrained(adapter_name, use_fast = True)

Loading checkpoint shards: 100%|██████████| 4/4 [00:28<00:00,  7.18s/it]


In [3]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

In [None]:
vllm_inf = pd.read_csv("../data/inference_result_greedy.csv")

In [None]:
gen_ds = load_dataset("json", data_files = "../data/inference_data.json", split = "train")

In [10]:
idx = 0

input_ids = tokenizer.apply_chat_template(
                gen_ds[idx]["messages"],
                add_generation_prompt=True,
                return_tensors="pt"
).to(model.device)

terminators = [tokenizer.eos_token_id]

outputs = model.generate(
    input_ids,
    max_new_tokens=512,
    eos_token_id=terminators,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
    num_beams=1
)

response = outputs[0][input_ids.shape[-1]:]
generation = tokenizer.decode(response, skip_special_tokens=True)

In [11]:
print(generation)

A female patient presented with foot ulcer and hammer toes; on physical examination, BP NA, HR NA, RR NA, and Temp NA; admission labs showed WBC 5.6, RBC 4.15, Hgb 11.1, Hct 32.9, Plt NA, MCV 79, MCH 26.6, MCHC 33.6, RDW 13.7, and Glucose 114; the most diagnostically relevant finding was a right submetatarsal ulcer with partial-thickness skin loss and exposed bone.


In [14]:
print(vllm_inf.iloc[0, 1])

A female patient presented with foot ulcer and hammer toes; on physical examination, BP NA, HR NA, RR NA and Temp NA; admission labs showed WBC 5.6, RBC 4.15, Hgb 11.1, Hct 32.9, Plt NA, MCV 79, MCH 26.6, MCHC 33.6, RDW 13.7, and Glucose 114; the most diagnostically relevant finding was a right submetatarsal ulcer with hammer digit syndrome affecting 4 toes.


In [26]:
print(generation)

A female patient presented with foot ulcer and hammer toes; on physical examination, BP NA, HR NA, RR NA, and Temp NA; admission labs showed WBC 5.6, RBC 4.15, Hgb 11.1, Hct 32.9, Plt NA, MCV 79, MCH 26.6, MCHC 33.6, RDW 13.7, and Glucose 114; the most diagnostically relevant finding was a right submetatarsal ulcer with hammer digit syndrome.


In [11]:
print(generation)

A female patient presented with foot ulcer and hammer toes; on physical examination, BP NA, HR NA, RR NA, and Temp NA; admission labs showed WBC 5.6, RBC 4.15, Hgb 11.1, Hct 32.9, Plt NA, MCV 79, MCH 26.6, MCHC 33.6, RDW 13.7, and Glucose 114; the most diagnostically relevant finding was a right submetatarsal ulcer with partial-thickness skin loss and exposed bone.


In [None]:
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
import pandas as pd


def template_dataset(example):
    return {"prompt": tokenizer.apply_chat_template(example["messages"], tokenize = False, add_generation_prompt = True)}

if __name__ == "__main__":
    base_model_path = "../base_model/Llama-3.1-8B-Instruct-nf4"
    adapter_path = "../adapter/Zip-Llama-aligned/policy"

    llm = LLM(
        model= base_model_path,
        dtype=torch.bfloat16,
        trust_remote_code = True,
        max_model_len = 32768,
        gpu_memory_utilization = 0.3,
        enable_lora = True,
        max_lora_rank = 64
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model_path, use_fast = True)
    inference_data = load_dataset("json", data_files = "../data/inference_data.json", split = "train")
    inference_data = inference_data.map(template_dataset, remove_columns = ["messages"])
    prompts = inference_data["prompt"]

  from .autonotebook import tqdm as notebook_tqdm


INFO 11-20 10:59:16 [utils.py:253] non-default args: {'trust_remote_code': True, 'dtype': torch.bfloat16, 'max_model_len': 32768, 'gpu_memory_utilization': 0.3, 'disable_log_stats': True, 'enable_lora': True, 'max_lora_rank': 64, 'model': 'base_model/Llama-3.1-8B-Instruct-nf4'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 11-20 10:59:16 [model.py:631] Resolved architecture: LlamaForCausalLM
INFO 11-20 10:59:16 [model.py:1745] Using max model len 32768


2025-11-20 10:59:17,106	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 11-20 10:59:17 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=16384.
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:17 [core.py:93] Initializing a V1 LLM engine (v0.11.1) with config: model='base_model/Llama-3.1-8B-Instruct-nf4', speculative_config=None, tokenizer='base_model/Llama-3.1-8B-Instruct-nf4', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=bitsandbytes, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=Obse

Traceback (most recent call last):
  File "/root/miniconda3/envs/LLM/lib/python3.12/site-packages/tvm_ffi/utils/_build_optional_torch_c_dlpack.py", line 836, in <module>
    main()
  File "/root/miniconda3/envs/LLM/lib/python3.12/site-packages/tvm_ffi/utils/_build_optional_torch_c_dlpack.py", line 829, in main
    build_ninja(build_dir=str(build_dir))
  File "/root/miniconda3/envs/LLM/lib/python3.12/site-packages/tvm_ffi/cpp/extension.py", line 353, in build_ninja
    raise RuntimeError("\n".join(msg))
RuntimeError: ninja exited with status 1
stdout:
[1/2] c++ -MMD -MF main.o.d -std=c++17 -fPIC -O3 -DBUILD_WITH_CUDA -D_GLIBCXX_USE_CXX11_ABI=1 -I/root/miniconda3/envs/LLM/lib/python3.12/site-packages/tvm_ffi/include -I/root/miniconda3/envs/LLM/include/python3.12 -I/root/miniconda3/envs/LLM/lib/python3.12/site-packages/torch/include -I/root/miniconda3/envs/LLM/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/usr/local/cuda-12.3/include -c /tmp/tvm-ffi-torch-c-dlpack-gr

[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:24 [cuda.py:418] Valid backends: ['FLASH_ATTN', 'FLASHINFER', 'TRITON_ATTN', 'FLEX_ATTENTION']
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:24 [cuda.py:427] Using FLASH_ATTN backend.
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:24 [bitsandbytes_loader.py:791] Loading weights with BitsAndBytes quantization. May take a while ...


Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:00<00:00, 25.70it/s]
[1;36m(EngineCore_DP0 pid=3440340)[0;0m 
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:05<00:05,  5.34s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:06<00:00,  2.96s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:06<00:00,  3.32s/it]
[1;36m(EngineCore_DP0 pid=3440340)[0;0m 


[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:31 [punica_selector.py:20] Using PunicaWrapperGPU.
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:31 [gpu_model_runner.py:3334] Model loading took 6.0187 GiB memory and 11.383862 seconds
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:44 [backends.py:631] Using cache directory: /root/.cache/vllm/torch_compile_cache/9b5f756d58/rank_0_0/backbone for vLLM's torch.compile
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:44 [backends.py:647] Dynamo bytecode transform time: 12.31 s
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:48 [backends.py:210] Directly load the compiled graph(s) for dynamic shape from the cache, took 3.238 s
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:51 [monitor.py:34] torch.compile takes 15.55 s in total
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 10:59:52 [gpu_worker.py:359] Available KV cache memory: 30.96 GiB
[1;36m(EngineCore_DP0 pid=344

Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   0%|          | 0/102 [00:00<?, ?it/s]



Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 102/102 [00:15<00:00,  6.72it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 102/102 [00:13<00:00,  7.29it/s]


[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 11:00:22 [gpu_model_runner.py:4240] Graph capturing finished in 30 secs, took 3.21 GiB
[1;36m(EngineCore_DP0 pid=3440340)[0;0m INFO 11-20 11:00:22 [core.py:250] init engine (profile, create kv cache, warmup model) took 51.07 seconds
INFO 11-20 11:00:24 [llm.py:352] Supported tasks: ['generate']


In [5]:
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=512
)

In [29]:
output = llm.generate(
    [prompts[0]],
    sampling_params,
    lora_request = LoRARequest("adapter", 1, adapter_path)
)

Adding requests: 100%|██████████| 1/1 [00:00<00:00, 70.97it/s]
Processed prompts: 100%|██████████| 1/1 [00:01<00:00,  1.25s/it, est. speed input: 2002.06 toks/s, output: 96.13 toks/s]


In [27]:
output[0].outputs[0].text

'A female patient presented with foot ulcer and hammer toes; on physical examination, BP NA, HR NA, RR NA and Temp NA; admission labs showed WBC 5.6, RBC 4.15, Hgb 11.1, Hct 32.9, Plt NA, MCV 79, MCH 26.6, MCHC 33.6, RDW 13.7, and Glucose 114; the most diagnostically relevant finding was a right submetatarsal ulcer with hammer digit syndrome of the right foot.'

In [30]:
output[0].outputs[0].text

'A female patient presented with foot ulcer and hammer toes; on physical examination, BP NA, HR NA, RR NA and Temp NA; admission labs showed WBC 5.6, RBC 4.15, Hgb 11.1, Hct 32.9, Plt NA, MCV 79, MCH 26.6, MCHC 33.6, RDW 13.7, and Glucose 114; the most diagnostically relevant finding was a right submetatarsal ulcer with hammer digit syndrome of the right foot.'

In [40]:
prompts[0] == output[0].prompt

True

In [6]:
gen_ds = gen_ds.map(
    lambda example: {"token_len": len(tokenizer.apply_chat_template(example["messages"], tokenize = True))}
)

Map: 100%|██████████| 10000/10000 [00:43<00:00, 231.17 examples/s]


In [7]:
import numpy as np

token_len = np.array(gen_ds["token_len"])

In [8]:
np.argmax(token_len)

np.int64(8922)

In [9]:
output = llm.generate(
    [prompts[np.argmax(token_len)]],
    sampling_params,
    lora_request = LoRARequest("adapter", 1, adapter_path)
)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]



Adding requests: 100%|██████████| 1/1 [00:00<00:00, 15.13it/s]
Processed prompts: 100%|██████████| 1/1 [00:04<00:00,  4.10s/it, est. speed input: 4194.99 toks/s, output: 29.79 toks/s]


In [None]:
output[0].prompt

'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are the world’s leading expert in survival analysis. From a discharge summary, extract Chief Complaint, Physical Exam, and Admission Labs (Pertinent Results) and produce one sentence. The sentence will be used for hazard calculation, so be precise, clinically accurate, and concise.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n \r\nName:  ___                   Unit No:   ___\r\n \r\nAdmission Date:  ___              Discharge Date:   ___\r\n \r\nDate of Birth:  ___             Sex:   M\r\n \r\nService: MEDICINE\r\n \r\nAllergies: \r\nPatient recorded as having No Known Allergies to Drugs\r\n \r\nAttending: ___\r\n \r\nChief Complaint:\r\nHeadache\r\n \r\nMajor Surgical or Invasive Procedure:\r\nRight central venous line placement (___)\r\n \r\nHistory of Present Illness:\r\n___ year old male with 3 weeks of worsening headache. Around 4 \r\nweeks back he had fevers with sore throat and was found to have \r\ntonsi

In [13]:
prompts[np.argmax(token_len)] == output[0].prompt

True