In [None]:
from unsloth import FastLanguageModel
from transformers import AutoTokenizer

MODEL_NAME = "openai/gpt-oss-20b"
MAX_SEQ_LEN = 4096

# Load the base checkpoint in 4-bit so LoRA adapters can be trained on top of it (QLoRA).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    max_seq_length = MAX_SEQ_LEN,
    load_in_4bit = True, # QLoRA
)

# Attach LoRA adapters to every linear layer: rank 32, alpha 64 (scale = alpha / r = 2), dropout 0.05.
model = FastLanguageModel.get_peft_model(
    model,
    r=32, lora_alpha=64, lora_dropout=0.05,
    target_modules="all-linear",
)

tokenizer.padding_side = "left"
# Only fall back to EOS as the pad token when the tokenizer ships without one.
# Forcing pad == eos makes collators that mask pad positions (labels == pad_token_id -> -100)
# also mask the real end-of-sequence token, so the model is never trained to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
from datasets import load_dataset


DATA_DIR = "/home/alhusain/scratch/ondevice-llm/eurorad"

# Read the two JSONL splits; each record is later rendered through the
# tokenizer's chat template from its "messages" field.
raw = load_dataset(
    "json",
    data_files=dict(
        train=f"{DATA_DIR}/train_cot.jsonl",
        val=f"{DATA_DIR}/val.jsonl",
    ),
)

def to_text(batch):
    """Render every conversation of a batch into one training string.

    Args:
        batch: batched mapping with a "messages" column; each entry is a list
            of chat messages accepted by the tokenizer's chat template.

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation.
    """
    rendered = [
        tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            # Full dialogues (including the target answer) are used for training,
            # so no open assistant turn is appended.
            add_generation_prompt=False,
            reasoning_effort="high",
        )
        for conversation in batch["messages"]
    ]
    return {"text": rendered}

# Replace the raw columns with a single rendered "text" column in every split.
# NOTE(review): the columns to drop are taken from the train split; this presumes
# the val split has the same columns -- confirm.
ds = raw.map(
    to_text,
    batched=True,
    remove_columns=raw["train"].column_names,
)

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments

# Same value as the MAX_SEQ_LEN defined in the model-loading cell.
MAX_SEQ_LEN = 4096
# Checkpoints and the final adapter are written below this directory.
OUTPUT_DIR = "/home/alhusain/scratch/ondevice-llm"

# Optimisation settings for the LoRA fine-tune.
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # 1 x 16 = 16 sequences per optimizer step, per device
    learning_rate=1e-4,
    num_train_epochs=2,
    logging_steps=20,
    save_steps=100,                  # checkpoint every 100 steps
    eval_strategy="steps",
    eval_steps=200,                  # evaluate on the "val" split every 200 steps
    bf16=True,                       # bfloat16 mixed precision; fp16 explicitly disabled
    fp16=False,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,                # first 10% of steps are warm-up
    gradient_checkpointing=True,
    report_to="none",                # no external experiment tracker
)

# Supervised fine-tuning on the pre-rendered "text" column, one example per
# sequence (packing disabled).
# NOTE(review): `tokenizer=`, `dataset_text_field=` and `max_seq_length=` as direct
# SFTTrainer arguments depend on the installed TRL version -- confirm against the environment.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LEN,
    packing=False,
    args=args,
)

# Run training, then save the result under "lora_long_prompt" (the path the
# inference cell loads the adapter from).
trainer.train()
trainer.save_model(f"{OUTPUT_DIR}/lora_long_prompt")

In [None]:
from unsloth import FastLanguageModel
from peft import PeftModel

BASE = "openai/gpt-oss-20b"
LORA_PATH = "/home/alhusain/scratch/ondevice-llm/lora_long_prompt/"

# Reload the 4-bit base checkpoint, then wrap it with the adapter saved by the training cell.
model_base, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE,
    max_seq_length=4096,
    load_in_4bit=True,
)
model = PeftModel.from_pretrained(model_base, LORA_PATH)
model.eval()

# Sanity output: confirm the wrapper type and list the adapter names registered
# in `peft_config` (empty list if the attribute is missing).
print("Using PEFT adapter:", isinstance(model, PeftModel))
print("Active adapters:", [name for name in getattr(model, "peft_config", {}).keys()])

In [None]:
# Diverse-beam CoT inference with external per-beam scoring
import re, json, math, torch, random, numpy as np, pandas as pd
from collections import defaultdict, Counter
from tqdm import tqdm

# Make the GPU that holds the model's first parameter the current CUDA device
# (index 0 when the device carries no explicit index).
_device = next(model.parameters()).device
if _device.type == "cuda":
    torch.cuda.set_device(0 if not _device.index else _device.index)

DATA_DIR = "/home/alhusain/scratch/ondevice-llm"
VAL_CSV = f"{DATA_DIR}/eurorad_val.csv"

# Validation cases; make_msgs_from_row reads the "PostDescription" and
# "DifferentialDiagnosisList" columns of each row.
df = pd.read_csv(filepath_or_buffer=VAL_CSV)
#df = df.iloc[101:].copy()

# Prompt 
def _cell_text(value) -> str:
    """Return a table cell as stripped text; missing values (None/NaN/NA) become ""."""
    if value is None:
        return ""
    # pandas stores empty CSV cells as float NaN; str(NaN) would inject the
    # literal text "nan" into the prompt.
    if not isinstance(value, str) and pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value).strip()


def make_msgs_from_row(row):
    """Build the chat messages (system + user) for one validation case.

    Args:
        row: mapping-like object with a ``.get`` method (e.g. a pandas Series or
            a dict) providing "PostDescription" and "DifferentialDiagnosisList".
            Missing keys and missing values (None/NaN) are rendered as empty text.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts: a short system
        message followed by the user prompt containing the case, the candidate
        list and the response rules.
    """
    post = _cell_text(row.get("PostDescription", ""))
    ddx  = _cell_text(row.get("DifferentialDiagnosisList", ""))
    user = (
        "You are an expert radiologist. You will receive a case and a finite list of candidate diagnoses. "
        "Reason step by step, then pick exactly one diagnosis copied verbatim from the list. "
        "You may use standard radiology knowledge to interpret the given findings (don't invent new findings).\n\n"
        f"PostDescription:\n{post}\n\n"
        f"DifferentialDiagnosisList:\n{ddx}\n\n"
        "Response rules:\n"
        "1) Provide <analysis> then <final>. No text outside these blocks.\n"
        "2) In <final>, copy the chosen label verbatim from the list. Do not include any other text.\n\n"
        "<analysis> content and order:\n"
        "A) Key findings (bullets): age/sex; clinical context; organ or site; modality and sequences; morphology and matrix; "
        "signal/enhancement; pathognomonic phrases; risk factors.\n"
        "B) Candidate pass (coverage required): for each i=1..K write\n"
        "   - [i] Pros: 1–2 explicit findings that support it (quote phrases from PostDescription).\n"
        "   - [i] Cons: 1–2 explicit contradictions or missing must-haves.\n"
        "   If the required organ/site is incompatible, mark CONTRADICTION and discard.\n"
        "C) Ranking with tie-breaks, in order:\n"
        "   1) Pathognomonic or hallmark findings win.\n"
        "   2) Exact anatomic site and laterality match.\n"
        "   3) Specific subtype over umbrella when supported.\n"
        "   4) Age, epidemiology, and stated risk factors.\n"
        "   5) If still tied, choose the narrowest diagnosis among remaining.\n"
        "D) Sanity check and alignment:\n"
        "   - State: Top candidate: [i] LABEL.\n"
        "   - Confirm LABEL exists verbatim in the list and matches your ranking.\n"
        "</analysis>\n\n"
        "<final> LABEL </final>\n"
    )
    return [
        {"role": "system", "content": "You are an expert radiologist."},
        {"role": "user",   "content": user},
    ]

# Parsers
# Case-insensitive, dot-matches-newline patterns capturing the trimmed body of each tag.
ANALYSIS_RE = re.compile(r"<analysis>\s*(.*?)\s*</analysis>", re.IGNORECASE | re.DOTALL)
FINAL_RE = re.compile(r"<final>\s*(.*?)\s*</final>", re.IGNORECASE | re.DOTALL)

def parse_analysis_final(text: str):
    """Split a generation into its ``<analysis>`` and ``<final>`` contents.

    Args:
        text: raw generated text; ``None`` or "" is treated as empty.

    Returns:
        ``(analysis, final)``. Each is the stripped body of the first matching
        tag pair, or "" when the tag is absent. When no non-empty ``<final>``
        body is found, ``final`` falls back to the last non-blank line of the
        text (or "" if there is none).
    """
    source = text or ""

    analysis_match = ANALYSIS_RE.search(source)
    final_match = FINAL_RE.search(source)
    analysis = analysis_match.group(1).strip() if analysis_match else ""
    final = final_match.group(1).strip() if final_match else ""
    if final:
        return analysis, final

    # Fallback: walk backwards to the last line that is not blank.
    for raw_line in reversed(source.splitlines()):
        candidate = raw_line.strip()
        if candidate:
            return analysis, candidate
    return analysis, ""

def parse_final_only(text: str):
    """Extract just the answer label from a generation.

    Returns the stripped body of the first ``<final>...</final>`` pair when one
    is present (this may be "" for an empty tag). Without such a pair, returns
    the last non-blank line of the text, or "" when the text has none.
    """
    source = text or ""
    hit = FINAL_RE.search(source)
    if hit is not None:
        return hit.group(1).strip()

    for raw_line in reversed(source.splitlines()):
        candidate = raw_line.strip()
        if candidate:
            return candidate
    return ""

# Utilities 
def _model_compute_dtype(model):
    for p in model.parameters():
        if p is not None:
            return p.dtype
    return torch.float32

# Diverse beams (deterministic)     
def gen_cot_beams(msgs, effort="high", max_new_tokens=1024,
                  beams=8, diversity_penalty=0.35, length_penalty=1.0, no_repeat_ngram_size=0):
    """Generate ``beams`` deterministic, diversity-penalised continuations for one prompt.

    Uses the module-level ``model`` and ``tokenizer``. Sampling is disabled and
    ``num_beam_groups == num_beams``, so every returned sequence comes from its
    own beam group.

    Args:
        msgs: chat messages passed to the tokenizer's chat template.
        effort: forwarded to the chat template as ``reasoning_effort``.
        max_new_tokens: generation budget per beam.
        beams: number of beams, beam groups and returned sequences.
        diversity_penalty, length_penalty, no_repeat_ngram_size: forwarded to
            ``model.generate``.

    Returns:
        ``(gens, enc, cont_ids_list)`` where ``gens`` is the decoded text of each
        continuation (special tokens removed), ``enc`` is the encoded prompt as a
        dict of tensors on the model device, and ``cont_ids_list`` holds one
        ``[1, T_i]`` tensor of continuation token ids per beam, trimmed after
        the first EOS so that padding is not scored later.
    """
    # Encode once; keep these tensors to later score beams under the same context
    enc = tokenizer.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
        reasoning_effort=effort,
    )
    enc = {k: v.to(model.device) for k, v in enc.items()}
    eos_id = tokenizer.eos_token_id

    with torch.inference_mode():
        out = model.generate(
            **enc,
            do_sample=False,
            num_beams=beams,
            num_beam_groups=beams,
            num_return_sequences=beams,
            diversity_penalty=diversity_penalty,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            return_dict_in_generate=False,
            output_scores=False,
        )

    # Return both texts (for parsing) and the raw token IDs (for scoring)
    prompt_len = enc["input_ids"].shape[1]
    gens, cont_ids_list = [], []
    for i in range(out.size(0)):
        cont_ids = out[i, prompt_len:]  # continuation ids, right-padded to the longest beam
        if eos_id is not None:
            # Finished beams are padded with pad_token_id (== eos here). Keep the
            # first EOS and drop everything after it, otherwise the padding would
            # be averaged into the length-normalized score.
            eos_positions = (cont_ids == eos_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                cont_ids = cont_ids[: int(eos_positions[0]) + 1]
        cont_ids_list.append(cont_ids.unsqueeze(0))  # [1, T_i]
        gens.append(tokenizer.decode(cont_ids, skip_special_tokens=True).strip())
    return gens, enc, cont_ids_list

# External per-beam scoring (length-normalized logprob) 
@torch.no_grad()
def score_continuation_given_ids(model, enc_input_ids, cont_ids) -> float:
    """
    Compute mean log p(cont_ids | enc_input_ids) directly on token IDs (no decode->encode).

    Shapes:
      enc_input_ids: [1, Lp]
      cont_ids:      [1, Lc]

    Returns:
      The sum of the continuation tokens' log-probabilities divided by Lc
      (0.0 for an empty continuation). The log-softmax and the sum are done in
      float32 so reduced-precision logits do not distort the score.
    """
    n_cont = int(cont_ids.shape[1])
    if n_cont == 0:
        # Nothing to score; same value the averaged empty sum would give.
        return 0.0

    cont_ids = cont_ids.to(enc_input_ids.device)
    input_ids = torch.cat([enc_input_ids, cont_ids], dim=1)  # [1, Lp+Lc]
    attn_mask = torch.ones_like(input_ids)

    compute_dtype = _model_compute_dtype(model)
    # Autocast only makes sense for half-precision float dtypes; follow the
    # device the tensors actually live on instead of assuming CUDA.
    use_autocast = compute_dtype in (torch.float16, torch.bfloat16)
    with torch.autocast(
        device_type=input_ids.device.type,
        dtype=compute_dtype if use_autocast else None,
        enabled=use_autocast,
    ):
        # Position t of the logits predicts token t+1, so the last input token is not fed.
        logits = model(input_ids=input_ids[:, :-1], attention_mask=attn_mask[:, :-1]).logits  # [1, Lp+Lc-1, V]

    label_start = enc_input_ids.shape[1]
    seg_logits = logits[:, label_start - 1 :, :].float()  # [1, Lc, V], upcast for stable log-probs
    seg_target = input_ids[:, label_start:]               # [1, Lc]
    tok_logprobs = (
        torch.log_softmax(seg_logits, dim=-1)
        .gather(-1, seg_target.unsqueeze(-1))
        .squeeze(-1)
    )  # [1, Lc]
    return float(tok_logprobs.sum().item()) / n_cont

# Seeds & eval 
# Inference mode for dropout etc., then seed every RNG in use with the same value.
model.eval()
torch.manual_seed(1234)
random.seed(1234)
np.random.seed(1234)

# Best effort: flag the generation config as trusting remote code; any failure is ignored.
# NOTE(review): presumably needed for the group-beam-search generation path -- confirm.
try:
    setattr(model.generation_config, "trust_remote_code", True)
except Exception:
    pass

# Run inference 
# Per-row accumulators, written out as CSV columns after the loop.
top_reasonings = []
final_answers = []
beam_labels_col = []
beam_scores_col = []
all_beams_text = []

for _, row in tqdm(df.iterrows(), total=len(df)):
    msgs = make_msgs_from_row(row)

    gens, enc, cont_ids_list = gen_cot_beams(
        msgs,
        max_new_tokens=1100,
        beams=11,
        diversity_penalty=0.5,
        length_penalty=1.0,
        no_repeat_ngram_size=0,
    )

    # External score of every beam, computed on token ids under the same prompt.
    per_beam_scores = [
        score_continuation_given_ids(model, enc["input_ids"], beam_ids)
        for beam_ids in cont_ids_list
    ]

    # Label of every beam, taken from its <final> block (or the fallback line).
    per_beam_labels = list(map(parse_final_only, gens))

    # Majority vote over non-empty labels; ties go to the label seen first.
    votes = Counter(label for label in per_beam_labels if label)
    if not votes:
        # all labels empty → fall back to first beam
        chosen_idx = 0
        chosen_label = per_beam_labels[0] if per_beam_labels else ""
    else:
        chosen_label = votes.most_common(1)[0][0]
        # The first beam carrying the winning label supplies the reasoning.
        chosen_idx = per_beam_labels.index(chosen_label)

    chosen_reasoning = parse_analysis_final(gens[chosen_idx])[0]

    top_reasonings.append(chosen_reasoning)
    final_answers.append(chosen_label)
    beam_labels_col.append(json.dumps(per_beam_labels, ensure_ascii=False))
    beam_scores_col.append(json.dumps(per_beam_scores, ensure_ascii=False))  # length-normalized logprobs
    all_beams_text.append(gens)
# Save
OUT_CSV  = f"{DATA_DIR}/val_preds_long_prompt_11beams.csv"

# One output row per input row: the input columns plus predictions and per-beam details.
df_out = df.copy()
df_out["reasoning"]         = top_reasonings   # <analysis> text of the beam chosen for each row
df_out["final_answer"]      = final_answers    # majority-vote label across beams ("" if no beam gave one)
df_out["beam_final_labels"] = beam_labels_col  # JSON list of per-beam labels
df_out["beam_sequence_scores"] = beam_scores_col # JSON list of per-beam mean log-probs

# One raw_generation<k> column per beam; rows with fewer beams are padded with "".
# default=0 keeps this working when the input CSV has no rows.
max_beams = max((len(x) for x in all_beams_text), default=0)
for j in range(max_beams):
    df_out[f"raw_generation{j+1}"] = [row[j] if j < len(row) else "" for row in all_beams_text]

df_out.to_csv(OUT_CSV, index=False)
print("Saved:", OUT_CSV)