In [None]:

import os

# Expose only GPU 0 to this process; this has to happen before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Intended to force eager attention.
# NOTE(review): confirm the installed transformers/unsloth actually read this variable;
# `attn_implementation="eager"` below is the setting that is passed explicitly.
os.environ["HF_PYTORCH_ATTENTION_BACKEND"] = "eager"

import torch

# Restrict PyTorch scaled-dot-product attention to the math kernel, process-wide.
# `torch.backends.cuda.sdp_kernel(...)` is a (deprecated) context manager: calling it
# without `with`, as this cell used to, never applied these settings. The global
# toggles below do.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

from unsloth import FastLanguageModel
from peft import PeftModel

BASE = "openai/gpt-oss-20b"
LORA_PATH = "/home/alhusain/scratch/ondevice-llm/lora_cot_alhusain"
MAX_SEQ_LEN = 4096

# Load the base model in 4-bit (eager attention; no fast/vLLM inference path).
model_base, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE,
    max_seq_length=MAX_SEQ_LEN,
    load_in_4bit=True,
    fast_inference=False,
    attn_implementation="eager",
    float8_kv_cache=True,
)

# Attach the LoRA adapter and switch to eval mode (disables dropout) for inference.
model = PeftModel.from_pretrained(model_base, LORA_PATH)
model.eval()

# Tokenizer tweaks for generation: left padding, EOS reused as the pad token.
tokenizer.padding_side = "left"
tokenizer.pad_token    = tokenizer.eos_token

print("Ready. PEFT attached:", isinstance(model, PeftModel))
print("Active adapters:", list(getattr(model, "peft_config", {}).keys()))
print("Device:", next(model.parameters()).device, "| Max seq len:", MAX_SEQ_LEN)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.9.5: Fast Gpt_Oss patching. Transformers: 4.56.1.
   \\   /|    NVIDIA RTX 6000 Ada Generation. Num GPUs = 4. Max memory: 47.383 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gpt_Oss does not support SDPA - switching to fast eager.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.


Unsloth: Making `model.base_model.model.model` require gradients


In [None]:
# Diverse-beam CoT inference with external per-beam scoring
import re, json, math, torch, random, numpy as np, pandas as pd
from collections import defaultdict, Counter
from tqdm import tqdm

# Make the model's GPU the current CUDA device so implicit CUDA allocations land there.
_device = next(model.parameters()).device
if _device.type == "cuda":
    torch.cuda.set_device(_device.index or 0)

DATA_DIR = "/home/alhusain/scratch/ondevice-llm"
VAL_CSV  = f"{DATA_DIR}/eurorad_val.csv"

# Fraction of ALL beams that must share a (non-empty) label for it to be the consensus answer.
AGREE_THRESHOLD = 0.5

# Positional slice of the validation set to process. START_ROW skips rows at the top
# (presumably ones already processed in an earlier run -- confirm); use 0 for a full run.
# END_ROW=None means "to the end".
START_ROW = 38
END_ROW   = None

df = pd.read_csv(VAL_CSV)
df = df.iloc[START_ROW:END_ROW].copy()

# Prompt 
def make_msgs_from_row(row):
    """Build the chat messages (system + user) for one validation case.

    Args:
        row: a mapping with ``.get`` (a pandas row ``Series`` or a ``dict``) holding the
            ``PostDescription`` and ``DifferentialDiagnosisList`` fields.

    Returns:
        A two-element list of ``{"role": ..., "content": ...}`` dicts suitable for
        ``tokenizer.apply_chat_template``. Missing, ``None`` or NaN fields are rendered as
        empty strings (not the literal text "nan"/"None").
    """
    def _field(key):
        val = row.get(key, "")
        # pandas stores empty CSV cells as NaN; str(NaN) would leak "nan" into the prompt.
        if val is None or (pd.api.types.is_scalar(val) and pd.isna(val)):
            return ""
        return str(val).strip()

    post = _field("PostDescription")
    ddx  = _field("DifferentialDiagnosisList")
    # NOTE(review): rule 1 contains a stray ". and"; kept verbatim in case the LoRA
    # adapter was trained on this exact prompt wording.
    user = (
        "You are an expert radiologist.\n"
        "Given a clinical case description and a finite list of candidate diagnoses,\n"
        "reason step by step using only the provided information, and then choose the single most likely final diagnosis FROM THE LIST.\n\n"
        f"PostDescription: {post}\n"
        f"DifferentialDiagnosisList: {ddx}\n\n"
        "Response rules:\n"
        "1) First, provide your reasoning inside <analysis> ... </analysis>. and keep it under 10 sentences.\n"
        "2) Then, provide the single chosen diagnosis inside <final> ... </final>.\n"
        "3) The <final> diagnosis must be copied VERBATIM from the DifferentialDiagnosisList.\n"
        "4) Do not include any text outside of the <analysis> and <final> blocks."
    )
    return [
        {"role": "system", "content": "You are an expert radiologist."},
        {"role": "user",   "content": user},
    ]

# Parsers
ANALYSIS_RE = re.compile(r"<analysis>\s*(.*?)\s*</analysis>", re.I | re.S)
FINAL_RE    = re.compile(r"<final>\s*(.*?)\s*</final>",       re.I | re.S)


def _last_tag_match(pattern, text):
    """Return the LAST match of `pattern` in `text`, or None.

    The last tagged block is used because the model may echo the prompt's own
    "<final> ... </final>" template while reasoning, before emitting its real answer.
    """
    matches = list(pattern.finditer(text))
    return matches[-1] if matches else None


def parse_final_only(text: str):
    """Extract the chosen diagnosis from a generation.

    Returns the stripped content of the last <final>...</final> block (case-insensitive,
    may span lines; an empty block yields ""). When no such block exists, falls back to
    the last non-empty line. ``None``/empty input yields "".
    """
    text = text or ""
    m = _last_tag_match(FINAL_RE, text)
    if m:
        return m.group(1).strip()
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    return lines[-1] if lines else ""


def parse_analysis_final(text: str):
    """Return ``(analysis, final)`` parsed from a generation.

    ``analysis`` is the stripped content of the last <analysis>...</analysis> block, or ""
    if absent. ``final`` follows exactly the same rules as ``parse_final_only``.
    """
    text = text or ""
    m = _last_tag_match(ANALYSIS_RE, text)
    analysis = m.group(1).strip() if m else ""
    return analysis, parse_final_only(text)

# Utilities 
def _model_compute_dtype(model):
    for p in model.parameters():
        if p is not None:
            return p.dtype
    return torch.float32

# Diverse beams (deterministic)     
def gen_cot_beams(msgs, effort="high", max_new_tokens=512,
                  beams=8, diversity_penalty=0.35, length_penalty=1.0, no_repeat_ngram_size=0):
    """Generate `beams` diverse completions with deterministic group (diverse) beam search.

    Uses the module-level `model` and `tokenizer`. Every beam is its own group
    (`num_beam_groups == num_beams`), so `diversity_penalty` pushes the beams apart.

    Args:
        msgs: chat messages passed to `tokenizer.apply_chat_template`.
        effort: forwarded to the chat template as `reasoning_effort`.
        max_new_tokens, beams, diversity_penalty, length_penalty, no_repeat_ngram_size:
            forwarded to `model.generate` (`beams` sets num_beams, num_beam_groups and
            num_return_sequences).

    Returns:
        (gens, enc, cont_ids_list):
          gens          -- decoded continuation per beam (special tokens skipped, stripped).
          enc           -- the encoded prompt dict on `model.device`, reused for scoring.
          cont_ids_list -- per-beam continuation ids, each shaped [1, T_i], cut right after
                           the first EOS so the EOS padding that `generate` appends to beams
                           that finished early is not included in later scoring.
    """
    # Encode once; these tensors are kept so beams can be scored under the same context.
    enc = tokenizer.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        reasoning_effort=effort,
    )
    enc = {k: v.to(model.device) for k, v in enc.items()}
    eos_id = tokenizer.eos_token_id

    with torch.inference_mode():
        out = model.generate(
            **enc,
            do_sample=False,
            num_beams=beams,
            num_beam_groups=beams,
            num_return_sequences=beams,
            diversity_penalty=diversity_penalty,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            return_dict_in_generate=False,
            output_scores=False,
        )

    prompt_len = enc["input_ids"].shape[1]
    gens, cont_ids_list = [], []
    for seq in out[:, prompt_len:]:
        cont_ids = seq
        if eos_id is not None:
            eos_hits = (seq == eos_id).nonzero(as_tuple=True)[0]
            if eos_hits.numel() > 0:
                # pad_token_id == eos_token_id, so everything after the first EOS is padding.
                cont_ids = seq[: int(eos_hits[0]) + 1]
        cont_ids_list.append(cont_ids.unsqueeze(0))  # [1, T_i]
        gens.append(tokenizer.decode(cont_ids, skip_special_tokens=True).strip())
    return gens, enc, cont_ids_list

# External per-beam scoring (length-normalized logprob) 
@torch.no_grad()
def score_continuation_given_ids(model, enc_input_ids, cont_ids) -> float:
    """
    Compute mean log p(cont_ids | enc_input_ids) directly on token IDs (no decode->encode).

    Shapes:
      enc_input_ids: [1, Lp]  (Lp >= 1)
      cont_ids:      [1, Lc]

    Returns the per-token mean log-probability of the continuation. An empty continuation
    returns -inf so that it can never win an argmax over beams (a plain 0.0 would be the
    best possible score).

    Raises:
      ValueError: if the prompt is empty (there is no position to predict the first token).
    """
    n_prompt = int(enc_input_ids.shape[1])
    n_cont = int(cont_ids.shape[1])
    if n_prompt == 0:
        raise ValueError("score_continuation_given_ids needs a non-empty prompt")
    if n_cont == 0:
        return float("-inf")

    input_ids = torch.cat([enc_input_ids, cont_ids.to(enc_input_ids.device)], dim=1)  # [1, Lp+Lc]
    attn_mask = torch.ones_like(input_ids)

    compute_dtype = _model_compute_dtype(model)
    with torch.autocast(device_type="cuda", dtype=compute_dtype, enabled=compute_dtype != torch.float32):
        # The last token never predicts anything we score, so it is dropped from the input.
        logits = model(input_ids=input_ids[:, :-1], attention_mask=attn_mask[:, :-1]).logits  # [1, Lp+Lc-1, V]

    # Position t predicts token t+1, so the continuation is predicted from index Lp-1 onward.
    seg_logits = logits[:, n_prompt - 1:, :]  # [1, Lc, V]
    seg_target = input_ids[:, n_prompt:]      # [1, Lc]
    # log_softmax in float32: a reduced-precision softmax over a large vocabulary loses accuracy.
    tok_logprobs = (
        torch.log_softmax(seg_logits.float(), dim=-1)
        .gather(-1, seg_target.unsqueeze(-1))
        .squeeze(-1)
    )  # [1, Lc]
    return float(tok_logprobs.sum().item()) / n_cont

# Seeds & eval
model.eval()
torch.manual_seed(1234); random.seed(1234); np.random.seed(1234)
try:
    # Meant to silence the group-beam-search `custom_generate` deprecation warning.
    # NOTE(review): that warning still appears in this cell's output, so this attribute does
    # not appear to reach `generate()`; kept as a best-effort, harmless setting.
    model.generation_config.trust_remote_code = True
except Exception as exc:  # best-effort: report instead of silently swallowing
    print(f"Could not set generation_config.trust_remote_code: {exc!r}")

# Decoding settings for this run
NUM_BEAMS         = 18
MAX_NEW_TOKENS    = 512
DIVERSITY_PENALTY = 0.5

OUT_CSV = f"{DATA_DIR}/val_preds_cot_divbeam_extscores_long_prompt_0.5diversity.csv"
# An interrupted run is written here instead, so partial results never overwrite a complete file.
PARTIAL_CSV = OUT_CSV[: -len(".csv")] + "_partial.csv"


def _select_beam(per_beam_labels, per_beam_scores, threshold):
    """Pick ``(beam_index, label)`` for one case.

    Consensus rule: if the most common non-empty label is produced by at least
    `threshold` of ALL beams (empty labels count in the denominator), return it together
    with the first beam (in `generate` return order) that produced it. Otherwise return
    the beam with the highest external score and whatever label it produced (possibly "").
    """
    counts = Counter(lab for lab in per_beam_labels if lab)
    if counts:
        top_lab, top_count = counts.most_common(1)[0]
        if top_count / len(per_beam_labels) >= threshold:
            return per_beam_labels.index(top_lab), top_lab
    best = int(np.argmax(per_beam_scores))
    return best, per_beam_labels[best]


def _rows_done():
    """Rows fully recorded in every result list (an interrupt can land between appends)."""
    return min(len(top_reasonings), len(final_answers), len(beam_labels_col),
               len(beam_scores_col), len(all_beams_text))


def _save_predictions(path):
    """Write the first `_rows_done()` rows of `df` plus the prediction columns to `path`."""
    n_done = _rows_done()
    out = df.iloc[:n_done].copy()
    out["reasoning"]            = top_reasonings[:n_done]
    out["final_answer"]         = final_answers[:n_done]    # consensus label or max-confidence beam
    out["beam_final_labels"]    = beam_labels_col[:n_done]
    out["beam_sequence_scores"] = beam_scores_col[:n_done]  # external scores (mean log-probs)
    out["raw_generations"]      = all_beams_text[:n_done]
    out.to_csv(path, index=False)
    print("Saved:", path, f"({n_done}/{len(df)} rows)")
    return out


# Run inference
top_reasonings, final_answers = [], []
beam_labels_col, beam_scores_col, all_beams_text = [], [], []

completed = False
try:
    for _, row in tqdm(df.iterrows(), total=len(df)):
        msgs = make_msgs_from_row(row)

        gens, enc, cont_ids_list = gen_cot_beams(
            msgs,
            max_new_tokens=MAX_NEW_TOKENS,
            beams=NUM_BEAMS,
            diversity_penalty=DIVERSITY_PENALTY,
            length_penalty=1.0,
            no_repeat_ngram_size=0,
        )

        # External score per beam, computed on token ids under the same prompt encoding
        enc_input_ids = enc["input_ids"]
        per_beam_scores = [score_continuation_given_ids(model, enc_input_ids, cont_ids)
                           for cont_ids in cont_ids_list]

        # Labels are parsed from text only to extract <final>
        per_beam_labels = [parse_final_only(g) for g in gens]

        chosen_idx, chosen_label = _select_beam(per_beam_labels, per_beam_scores, AGREE_THRESHOLD)
        chosen_reasoning, _ = parse_analysis_final(gens[chosen_idx])

        # Collect outputs (lists -> JSON strings so they survive the CSV round-trip)
        top_reasonings.append(chosen_reasoning)
        final_answers.append(chosen_label)
        beam_labels_col.append(json.dumps(per_beam_labels, ensure_ascii=False))
        beam_scores_col.append(json.dumps(per_beam_scores, ensure_ascii=False))
        all_beams_text.append(" ||| ".join(gens))
    completed = True
finally:
    # Persist whatever finished, also on KeyboardInterrupt/errors (which then propagate).
    if completed:
        df_out = _save_predictions(OUT_CSV)
    elif _rows_done() > 0:
        df_out = _save_predictions(PARTIAL_CSV)

  0%|          | 0/171 [00:00<?, ?it/s]Group Beam Search is scheduled to be moved to a `custom_generate` repository in v4.55.0. To prevent loss of backward compatibility, add `trust_remote_code=True` to your `generate` call.
  5%|▌         | 9/171 [07:03<2:10:09, 48.21s/it]

In [None]:
# Salvage the results of an interrupted inference loop: save exactly the rows that were
# processed, instead of a hard-coded row count that has to match the loop's progress.
OUT_CSV  = f"{DATA_DIR}/val_preds_newscored_18beams_v2.csv"

_result_cols = {
    "reasoning":            top_reasonings,
    "final_answer":         final_answers,
    "beam_final_labels":    beam_labels_col,
    "beam_sequence_scores": beam_scores_col,
    "raw_generations":      all_beams_text,
}
# An interrupt can land between the per-row appends, so trust the shortest list.
n_done = min(len(values) for values in _result_cols.values())

# Rows are processed in order, so the first n_done rows of df line up with the lists.
df_out = df.iloc[:n_done].copy()
for col, values in _result_cols.items():
    df_out[col] = values[:n_done]

df_out.to_csv(OUT_CSV, index=False)
print("Saved:", OUT_CSV)