In [None]:
import json
from datasets import load_dataset
from tqdm import tqdm

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

In [None]:
# Dataset to evaluate and the columns read from each example.
DATASET_NAME: str = "TheFinAI/FinMR_sUB"
SPLIT: str = "test"
QUERY_COL: str = "query"    # text sent to the model as the user turn
ID_COL: str = "id"          # row identifier; the loop falls back to the row index
GT_COL: str = "answer"      # reference answer copied into each output record

# Model served by vLLM and the JSONL file the predictions are written to.
MODEL_NAME: str = "meta-llama/Llama-3.2-3B-Instruct"  # Replace it with your vLLM model.
OUT_PATH: str = "predictions.jsonl"

In [None]:
# Generation settings passed to vLLM's SamplingParams.
MAX_TOKENS: int = 256       # upper bound on newly generated tokens per prompt
TEMPERATURE: float = 0.0    # 0.0 selects greedy decoding in vLLM
TOP_P: float = 1.0          # 1.0 keeps the full distribution (no nucleus cut-off)

In [None]:
# vLLM engine settings.
# Number of GPUs to shard the model across with tensor parallelism
# (e.g. 4 on a 4-GPU machine; 70B-class models usually need several GPUs).
TENSOR_PARALLEL_SIZE: int = 2
# Maximum context length (prompt + generated tokens). Raise it for long
# queries, but never above the context length the model itself supports.
MAX_MODEL_LEN: int = 90000
# Fraction of each GPU's memory vLLM may reserve; 0.85-0.95 is a typical range.
GPU_MEMORY_UTILIZATION: float = 0.90

In [None]:
# =========================
# Load dataset
# =========================
ds = load_dataset(DATASET_NAME, split=SPLIT)

# The inference loop reads each row with `ex.get(QUERY_COL, "")`, so a
# misspelled or missing query column would silently turn every prompt into an
# empty user turn. Fail fast instead. ID_COL and GT_COL stay optional: the loop
# falls back to the row index and None for those.
if QUERY_COL not in ds.column_names:
    raise KeyError(
        f"Query column {QUERY_COL!r} not found in {DATASET_NAME} ({SPLIT} split); "
        f"available columns: {ds.column_names}"
    )

In [None]:
# =========================
# Init vLLM
# =========================
# NOTE(review): trust_remote_code=True executes Python shipped in the model
# repository; only keep it for models whose code you trust or that need it.
llm = LLM(
    model=MODEL_NAME,
    trust_remote_code=True,
    # Resource settings from the runtime config cell.
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    max_model_len=MAX_MODEL_LEN,
    tensor_parallel_size=TENSOR_PARALLEL_SIZE,
)

In [None]:
# Decoding configuration shared by every prompt in the batch.
sampling_params = SamplingParams(
    max_tokens=MAX_TOKENS,
    temperature=TEMPERATURE,
    top_p=TOP_P,
)

In [None]:
# =========================
# Batch inference
# =========================
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)

prompts = []  # vLLM token-id prompts, one per dataset row
meta = []     # per-row bookkeeping, aligned index-for-index with `prompts`

for i, ex in enumerate(tqdm(ds, desc="Building prompts")):
    q = ex.get(QUERY_COL, "")
    ex_id = ex.get(ID_COL, i)   # fall back to the row index when there is no id
    gt = ex.get(GT_COL, None)

    chat_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": q}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # Many chat templates (including Llama 3's) already emit the BOS token.
    # Passing the rendered string to llm.generate would let vLLM re-tokenize it
    # with special tokens and prepend a second BOS. Encode without special
    # tokens and hand vLLM the ids directly instead.
    prompt_token_ids = tokenizer.encode(chat_prompt, add_special_tokens=False)

    prompts.append({"prompt_token_ids": prompt_token_ids})
    meta.append({
        "id": ex_id,
        "query": q,
        "ground_truth": gt,
    })

# One RequestOutput per prompt, returned in the same order as `prompts`.
outputs = llm.generate(prompts, sampling_params)

In [None]:
# Sanity check: display the first generation result (raises IndexError if
# `outputs` is empty).
outputs[0]

In [None]:
# =========================
# Save jsonl
# =========================
# `meta` and `outputs` are built in an earlier cell. If llm.generate failed
# after `meta` was rebuilt, `outputs` may be left over from a previous run, and
# zip() would silently truncate or pair rows with the wrong predictions.
# Refuse to write a misaligned file.
if len(meta) != len(outputs):
    raise RuntimeError(
        f"Mismatch between prompts ({len(meta)}) and generations ({len(outputs)}); "
        "re-run the batch inference cell before saving."
    )

with open(OUT_PATH, "w", encoding="utf-8") as f:
    for m, out in zip(meta, outputs):
        # Keep only the first completion for each prompt; use "" if vLLM
        # produced none.
        pred_text = out.outputs[0].text.strip() if out.outputs else ""

        record = {
            "id": m["id"],
            "prediction": pred_text,
            "ground_truth": m["ground_truth"],
            # Optional: Retain the query for easier debugging.
            "query": m["query"],
        }

        # ensure_ascii=False keeps non-ASCII text readable in the UTF-8 file.
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

print(f"Saved: {OUT_PATH}")