In [None]:
import os

# Restrict this process to one physical GPU. The variable is set before any
# torch import below so that it is in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# NOTE(review): confirm the installed transformers/unsloth versions actually
# read this variable; eager attention is also requested explicitly through
# attn_implementation="eager" when the model is loaded.
os.environ["HF_PYTORCH_ATTENTION_BACKEND"] = "eager"  # force eager attention

import torch
from torch.backends.cuda import sdp_kernel  # kept: other cells may still reference it

# sdp_kernel(...) is a context manager: calling it without `with` only builds
# the manager object and never touches a backend flag. Use the process-wide
# switches instead, so flash and memory-efficient scaled-dot-product attention
# stay disabled (math kernel only) for everything that runs afterwards.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

from unsloth import FastLanguageModel
import re, json, math, torch, random, numpy as np, pandas as pd
from collections import defaultdict, Counter
from tqdm import tqdm

# Model identifier handed to the loader, and the sequence length passed as
# `max_seq_length`.
BASE = "openai/gpt-oss-20b"
MAX_SEQ_LEN = 4096

# ---------- Load base model in 4-bit ----------
# Returns both the model and its tokenizer. Eager attention is requested
# explicitly via `attn_implementation`.
# NOTE(review): `float8_kv_cache=True` is passed together with
# `fast_inference=False`; presumably the KV-cache option only matters on the
# fast-inference path — confirm against the unsloth documentation.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE,
    max_seq_length=MAX_SEQ_LEN,
    load_in_4bit=True,
    fast_inference=False,
    attn_implementation="eager",
    float8_kv_cache=True,
)
# Put the model in evaluation mode; this script only runs inference.
model.eval()

# Pad on the left and reuse the EOS token as the padding token.
tokenizer.padding_side = "left"
tokenizer.pad_token    = tokenizer.eos_token

# Report where the first model parameter lives and the configured length.
print("Ready. Base model only (no LoRA).")
print("Device:", next(model.parameters()).device, "| Max seq len:", MAX_SEQ_LEN)


In [None]:
##### OPHTHALMOLOGY BEAMS #####

# Diverse-beam inference
import re, json, math, torch, random, numpy as np, pandas as pd
from collections import defaultdict, Counter
from tqdm import tqdm

# Device of the first model parameter. When it is a CUDA device, make it the
# current CUDA device; `index or 0` falls back to device 0 when the index is
# None.
_device = next(model.parameters()).device
if _device.type == "cuda":
    torch.cuda.set_device(_device.index or 0)

# Input questions (VAL_CSV) and destination of the results table (OUT_CSV),
# both located under DATA_DIR.
DATA_DIR = "/home/alhusain/scratch/ondevice-llm"
VAL_CSV = f"{DATA_DIR}/ophthalmology_mcq.csv"
OUT_CSV = f"{DATA_DIR}/final_results/oph_mcq_5beams.csv"

# Main hyperparameters, forwarded to gen_cot_beams() by the inference loop:
# NUM_BEAMS -> beams, DIV_PENALTY -> diversity_penalty,
# LENGTH_PENALTY -> length_penalty, NO_REPEAT_NGRAM_SIZE -> no_repeat_ngram_size,
# MAX_NEW_TOKENS -> max_new_tokens.
# NOTE(review): MAX_NEW_TOKENS is close to the 4096 max sequence length used
# when loading the model; confirm prompt + generation fits for long questions.
NUM_BEAMS = 5
DIV_PENALTY = 0.5
LENGTH_PENALTY = 1.0
NO_REPEAT_NGRAM_SIZE = 0
MAX_NEW_TOKENS = 3500  

# Load the question table and print its (rows, columns) shape.
df = pd.read_csv(VAL_CSV)
print(df.shape)
# `iloc[0:]` keeps every row; raise the start index to process only a tail of
# the table. The copy detaches the working frame from the one just read.
df = df.iloc[0:].copy()

# -------------------------
# Prompt (letters-only)
# -------------------------
# Instruction text placed at the top of every user message: answer with the
# choice letters only.
SHORT_INSTR = "\n".join((
    "You are an ophthalmology multiple-choice assistant.",
    "Output ONLY the correct answer letters (A–K), concatenated without spaces or punctuation.",
    "Examples: A, BD, ACE.",
    "Do not output any words, explanations, or reasoning.",
    "Your response MUST be letters only.",
))


def make_msgs_from_row(row):
    """Build the chat messages (system + user) for one question row.

    Args:
        row: Mapping-like object exposing ``.get``; its ``"Question"`` entry
            is converted with ``str`` and stripped. A missing entry yields an
            empty question.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts: a fixed system
        message followed by the user message, which is ``SHORT_INSTR``, a
        blank line, and ``"Question: <text>"`` terminated by a newline.
    """
    question = str(row.get("Question", "")).strip()
    system_msg = {"role": "system", "content": "You are an expert ophthalmology assistant."}
    user_msg = {"role": "user", "content": f"{SHORT_INSTR}\n\nQuestion: {question}\n"}
    return [system_msg, user_msg]

# -------------------------
# Parsers (letters only)
# -------------------------
# Marker searched for in the decoded generation; the answer is read from the
# text that follows it. Matching is case-insensitive and tolerates whitespace
# between the two words.
# NOTE(review): presumably this is what remains of the chat template's channel
# markers once special tokens are skipped during decoding — confirm.
ASSISTANTFINAL_TAG = re.compile(r"assistant\s*final", re.I)

_CHOICE_LETTERS = "ABCDEFGHIJK"
_FIRST_WORD = re.compile(r"[A-Za-z]+")
_ANSWER_WORD = re.compile(r"[A-Ka-k]+")


def _normalize_choice_letters_only(text: str) -> str:
    """Return the distinct A–K letters found in *text* in canonical A..K order.

    The text is upper-cased first, so lower-case letters count too. ``None``
    or text without any A–K letter yields ``""``.
    """
    present = set(re.findall(r"[A-K]", (text or "").upper()))
    return "".join(ch for ch in _CHOICE_LETTERS if ch in present)


def parse_after_assistantfinal(text: str) -> str:
    """Extract the answer letters that follow the final-answer marker.

    The LAST occurrence of the marker is used, so a mention of
    "assistant final" earlier in the text cannot hijack the parse. After the
    marker, leading non-letter characters (whitespace, punctuation, markup)
    are skipped and the first alphabetic word is inspected: it is accepted
    only if every character is an A–K letter (either case). This avoids
    reading stray letters out of ordinary words such as "The" or "Answer".

    Args:
        text: Decoded generation; ``None`` or ``""`` is treated as a failure.

    Returns:
        The normalized answer (distinct letters, upper-case, A..K order), or
        ``""`` when there is no marker or no letters-only word right after it.
    """
    if not text:
        return ""

    last_marker = None
    for last_marker in ASSISTANTFINAL_TAG.finditer(text):
        pass
    if last_marker is None:
        # No assistantfinal marker -> failure
        return ""

    tail = text[last_marker.end():]  # everything after the last marker
    first_word = _FIRST_WORD.search(tail)
    if first_word is None or not _ANSWER_WORD.fullmatch(first_word.group(0)):
        # Marker found, but what follows is not a pure A–K letter run -> failure
        return ""
    return _normalize_choice_letters_only(first_word.group(0))


# -------------------------
# Generation 
# -------------------------
def gen_cot_beams(msgs, effort="medium", max_new_tokens=8,
                  beams=3, diversity_penalty=0.5, length_penalty=1.0, no_repeat_ngram_size=0):
    """Generate `beams` continuations for one chat prompt with grouped beam search.

    The prompt is rendered once through the tokenizer's chat template (with a
    generation prompt appended) and moved to the model's device. Generation
    runs without sampling, with one beam per beam group
    (``num_beam_groups == num_beams``) and every beam returned.

    Args:
        msgs: Chat messages passed to ``tokenizer.apply_chat_template``.
        effort: Currently unused; it is not forwarded to the chat template.
        max_new_tokens: Upper bound on generated tokens per sequence.
        beams: Number of beams, beam groups and returned sequences.
        diversity_penalty: Forwarded to ``model.generate``.
        length_penalty: Forwarded to ``model.generate``.
        no_repeat_ngram_size: Forwarded to ``model.generate``.

    Returns:
        A tuple ``(gens, enc, cont_ids_list)``: the stripped decoded
        continuations (special tokens skipped), the dict of prompt tensors on
        the model's device, and one tensor of continuation ids per returned
        sequence, each with a leading dimension of size 1.
    """
    # Encode once
    templated = tokenizer.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    enc = {name: tensor.to(model.device) for name, tensor in templated.items()}

    generate_kwargs = dict(
        do_sample=False,
        num_beams=beams,
        num_beam_groups=beams,
        num_return_sequences=beams,
        diversity_penalty=diversity_penalty,
        length_penalty=length_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=False,
        output_scores=False,
    )
    with torch.inference_mode():
        sequences = model.generate(**enc, **generate_kwargs)

    # Everything past the prompt length is newly generated.
    prompt_len = enc["input_ids"].shape[1]
    continuations = [seq[prompt_len:] for seq in sequences]
    cont_ids_list = [ids.unsqueeze(0) for ids in continuations]
    gens = [
        tokenizer.decode(ids, skip_special_tokens=True).strip()
        for ids in continuations
    ]
    return gens, enc, cont_ids_list

# Seeds & eval
model.eval()
# Seeding is left disabled; generation below is called with do_sample=False.
#torch.manual_seed(42); random.seed(42); np.random.seed(42)

# Best effort: flag the generation config as trusting remote code.
# NOTE(review): presumably needed so generate() may load the grouped beam
# search decoding code — confirm against the installed transformers version.
# A failure is reported instead of being swallowed silently, but it does not
# abort the run.
try:
    model.generation_config.trust_remote_code = True
except Exception as exc:
    print(f"Warning: could not set generation_config.trust_remote_code: {exc!r}")

# Run inference
def _majority_label(labels):
    """Return the most frequent non-empty label, or "" when every label is empty.

    Ties are resolved in favour of the label whose first occurrence has the
    lowest beam index: the counter keeps labels in first-seen order and
    ``max`` returns the first maximal entry it meets.
    """
    tally = Counter(lab for lab in labels if lab)
    if not tally:
        return ""
    return max(tally, key=tally.get)


final_answers = []
beam_labels_col, all_beams_text = [], []

for _, row in tqdm(df.iterrows(), total=len(df)):
    msgs = make_msgs_from_row(row)

    # Only the decoded texts are needed here; the encodings and raw ids
    # returned alongside them are ignored.
    gens, _enc_unused, _ids_unused = gen_cot_beams(
        msgs,
        max_new_tokens=MAX_NEW_TOKENS,
        beams=NUM_BEAMS,
        diversity_penalty=DIV_PENALTY,
        length_penalty=LENGTH_PENALTY,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
    )

    # One parsed label per beam ("" marks a beam that could not be parsed),
    # then a majority vote across the beams.
    per_beam_labels = [parse_after_assistantfinal(g) for g in gens]
    chosen_label = _majority_label(per_beam_labels)

    final_answers.append(chosen_label)
    beam_labels_col.append(json.dumps(per_beam_labels, ensure_ascii=False))
    all_beams_text.append(" ||| ".join(gens))

# Save
import os

# Attach the voted answer, the per-beam labels (JSON-encoded list) and the raw
# generations (joined with " ||| ") to a copy of the input table.
df_out = df.copy()
df_out["final_answer"] = final_answers
df_out["beam_final_labels"] = beam_labels_col
df_out["raw_generations"] = all_beams_text

# Create the destination folder first: to_csv() does not create missing
# directories, and failing here would discard the whole inference run.
_out_dir = os.path.dirname(OUT_CSV)
if _out_dir:
    os.makedirs(_out_dir, exist_ok=True)

df_out.to_csv(OUT_CSV, index=False)
print("Saved:", OUT_CSV)
