In [1]:
#!/usr/bin/env python3
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- MODEL_DIR (assigned at the bottom of this cell) must point to your saved checkpoint directory ---


def pick_dtype():
    """Choose the floating-point dtype to run the model in on this machine.

    Returns:
        torch.float32 when CUDA is unavailable; otherwise torch.bfloat16 if
        the CUDA device reports bf16 support, else torch.float16.
    """
    # CPU-only machines always use full precision.
    if not torch.cuda.is_available():
        return torch.float32
    # On GPU, prefer bf16 and fall back to fp16 when bf16 is not supported.
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

def load_model_and_q_head(model_dir: str, device: str):
    """Load a causal LM and tokenizer, attaching a Q-head from `q_head.pt` if present.

    Args:
        model_dir: Checkpoint directory passed to `from_pretrained`. If it also
            contains a `q_head.pt` file, that file is loaded as the state dict
            of an `nn.Linear(hidden_size, vocab_size, bias=True)` head.
        device: Target device string (e.g. "cuda" or "cpu"). Reduced precision
            is only requested when it starts with "cuda".

    Returns:
        (tok, model): the tokenizer (with `pad_token` defaulted to `eos_token`
        when unset) and the model in eval mode on `device`. When the sidecar
        exists, the head is available as `model.q_head`.

    Raises:
        RuntimeError: if the hidden size or vocabulary size cannot be inferred,
            or if the sidecar state dict does not match the head (strict load).
    """
    dtype = pick_dtype() if device.startswith("cuda") else torch.float32

    # NOTE(review): trust_remote_code=True executes Python shipped with the
    # checkpoint; only point this at checkpoints you trust.
    tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=dtype if device.startswith("cuda") else None,
        trust_remote_code=True,
    ).to(device).eval()

    q_head_path = os.path.join(model_dir, "q_head.pt")
    if not os.path.exists(q_head_path):
        print("[WARN] q_head.pt not found; will fall back to using logits as Q if you trained without --use_q_head")
        return tok, model

    # Infer the hidden size from the config; different architectures use
    # different attribute names for it.
    cfg = getattr(model, "config", None)
    hidden_size = (
        getattr(cfg, "hidden_size", None)
        or getattr(cfg, "n_embd", None)
        or getattr(cfg, "hidden_dim", None)
    )
    if hidden_size is None:
        raise RuntimeError("Could not infer hidden size for q_head from model.config")

    # Infer the vocabulary size from the output embedding module, which may be
    # embedding-like (num_embeddings) or linear-like (out_features).
    out_emb = model.get_output_embeddings()
    if hasattr(out_emb, "num_embeddings"):
        vocab_size = int(out_emb.num_embeddings)
    elif hasattr(out_emb, "out_features"):
        vocab_size = int(out_emb.out_features)
    else:
        raise RuntimeError("Could not infer vocab size from model.get_output_embeddings()")

    # Build the head directly on the model's device/dtype. load_state_dict
    # copies values into these existing parameters, so no further .to() is
    # needed after loading.
    first_param = next(model.parameters())
    model_dtype = first_param.dtype
    model_device = first_param.device
    q_head = nn.Linear(hidden_size, vocab_size, bias=True).to(device=model_device, dtype=model_dtype)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered sidecar cannot execute arbitrary code, and it silences the
    # FutureWarning torch emits for the unrestricted default.
    sd = torch.load(q_head_path, map_location="cpu", weights_only=True)
    q_head.load_state_dict(sd, strict=True)

    # Attach as a submodule so callers can detect it via `model.q_head`.
    model.q_head = q_head
    print(f"[INFO] Attached q_head: hidden={hidden_size}, vocab={vocab_size}, dtype={model_dtype}, device={model_device}")

    return tok, model

@torch.no_grad()
def q_values_for_sequence(model, tok, text: str, device: str):
    """Score a single text with the Q-head (or the LM logits when there is none).

    Args:
        model: Callable model returning an object with `.logits` and
            `.hidden_states`; may carry an `nn.Module` attribute `q_head`.
        tok: Tokenizer callable producing "input_ids" and "attention_mask",
            exposing `pad_token_id`.
        text: The text to score.
        device: Device string the encoded inputs are moved to.

    Returns:
      q_logits: (1, T, V) tensor = predicted Q(s_t, a) for all tokens at each position
      q_taken:  (1, T-1) tensor = Q(s_t, a_t) for the actually-taken next token (labels);
                positions whose next token is the pad token are reported as NaN
      seq_value: scalar = predicted V(s_0) = Q(s_0, a_0) under your training target
                (NaN for a single-token input or when the first action is a pad token)
    Notes:
      - We use next-token labeling: action at time t is the token at position t+1.
      - If no q_head is present, we interpret base logits as Q (matching your training code).
    """
    enc = tok(text, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)

    out = model(input_ids=input_ids, attention_mask=attn,
                output_hidden_states=True, return_dict=True, use_cache=False)

    # If q_head exists, project last hidden; else reuse logits (as in your training code)
    if hasattr(model, "q_head") and isinstance(model.q_head, nn.Module):
        hidden_last = out.hidden_states[-1]            # (1, T, H)
        # cast hidden to the head's dtype to avoid bf16/fp16 matmul dtype errors
        if hidden_last.dtype != model.q_head.weight.dtype:
            hidden_last = hidden_last.to(model.q_head.weight.dtype)
        q_logits = model.q_head(hidden_last)           # (1, T, V)
    else:
        q_logits = out.logits                          # (1, T, V) treated as Q

    # Build next-token labels, marking pad positions with the ignore value -100.
    labels = input_ids.clone()
    if tok.pad_token_id is not None:
        labels[labels == tok.pad_token_id] = -100

    # Gather Q(s_t, a_t) for taken actions: the action at time t is labels[:, t+1].
    if q_logits.size(1) >= 2:
        taken = labels[:, 1:]                          # (1, T-1)
        ignored = taken < 0
        # -100 is not a valid gather index (it raises an out-of-range error),
        # so gather with a clamped index and blank the ignored slots afterwards.
        safe_index = taken.clamp(min=0).unsqueeze(-1)
        q_taken = torch.gather(q_logits[:, :-1, :], dim=-1, index=safe_index).squeeze(-1)  # (1, T-1)
        q_taken = q_taken.masked_fill(ignored, float("nan"))
        seq_value = q_taken[0, 0].item()  # predicted V(s_0)
    else:
        # Single token edge case: there is no taken action to score.
        q_taken = torch.empty((1, 0), device=q_logits.device, dtype=q_logits.dtype)
        seq_value = float("nan")

    return q_logits, q_taken, seq_value

def pretty_topk(q_logits, tok, t: int, k: int = 10):
    """Utility: show top-k actions by Q at position t.

    Args:
        q_logits: Tensor indexed as q_logits[0, t] to obtain a (V,) vector of
            per-token Q values.
        tok: Tokenizer exposing `convert_ids_to_tokens`.
        t: Sequence position to inspect.
        k: Number of entries requested; clamped to V so that asking for more
            entries than exist returns all of them instead of raising.

    Returns:
        List of (token_string, q_value) pairs, highest Q first.
    """
    logits_t = q_logits[0, t]  # (V,)
    # torch.topk raises when k exceeds the dimension size, so cap it.
    k = min(k, logits_t.size(-1))
    topv, topi = torch.topk(logits_t, k)
    toks = tok.convert_ids_to_tokens(topi.tolist())
    return list(zip(toks, topv.tolist()))

# Checkpoint to inspect. Set the Q_MODEL_DIR environment variable to score a
# different checkpoint without editing this cell. Another run that was tried:
#   /nfs/shuozhe/clean_pretrain/checkpoints/Qwen2.5-0.5B-TinyStories-q_reg_0.7/model_3000
MODEL_DIR = os.environ.get(
    "Q_MODEL_DIR",
    "/nfs/shuozhe/clean_pretrain/checkpoints/Qwen2.5-0.5B-TinyStories-q_reg_ce_head_only/model_2000",
)
device = "cuda" if torch.cuda.is_available() else "cpu"
tok, model = load_model_and_q_head(MODEL_DIR, device)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at /nfs/shuozhe/clean_pretrain/checkpoints/Qwen2.5-0.5B-TinyStories-q_reg_ce_head_only/model_2000 were not used when initializing Qwen2ForCausalLM: {'q_head.weight', 'q_head.bias'}
- This IS expected if you are initializing Qwen2ForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2ForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  sd = torch.load(q_head_path, map_location="cpu")


[INFO] Attached q_head: hidden=896, vocab=151665, dtype=torch.bfloat16, device=cuda:0


In [18]:
# Example text to score
text = "Once upon a time, there was a girl"
# text = "Once upon a time, there was a bot"

q_logits, q_taken, seq_value = q_values_for_sequence(model, tok, text, device)
print(f"Sequence: {text!r}")
print(f"Predicted V(s_0) (i.e., Q(s_0, a_0)): {seq_value:.4f}")

# Inspect a specific position's top-Q actions
t = 1
top_k = 100
top = pretty_topk(q_logits, tok, t=t, k=top_k)
# Report the number of entries actually shown so the header matches the list.
print(f"\nTop-{len(top)} Q(s_{t}, a) suggestions:")
for tok_str, val in top:
    print(f"{tok_str:>12s} : {val:.4f}")

# Also print Q for the actually taken actions (labels)
if q_taken.numel() > 0:
    print("\nQ(s_t, a_t) along the trajectory:")
    # Tokenize once, with the same tokenizer settings q_values_for_sequence
    # uses, so tokens stay aligned with q_taken even if the tokenizer adds
    # special tokens. The action at step i is the token at position i + 1.
    action_tokens = tok.convert_ids_to_tokens(tok(text)["input_ids"][1:])
    for i, (v, action_token) in enumerate(zip(q_taken[0].tolist(), action_tokens)):
        print(f"t={i:2d} : {v:.4f} (token={action_token!r})")

Sequence: 'Once upon a time, there was a girl'
Predicted V(s_0) (i.e., Q(s_0, a_0)): 0.2051

Top-10 Q(s_1, a) suggestions:
           . : 2.4688
           , : 1.6328
        Ġand : 1.5391
        Ġthe : 1.4844
          Ġa : 1.3359
         Ġto : 1.2031
        Ġwas : 1.0156
         .ĊĊ : 0.4805
         Ġit : 0.4629
        Ġher : 0.4160
        ĠShe : 0.4082
         ĠHe : 0.4023
         Ġin : 0.3691
         Ġhe : 0.3242
          Ġ" : 0.3105
       Ġsaid : 0.3086
        Ġshe : 0.2988
       Ġwith : 0.2773
        Ġhis : 0.2773
        Ġday : 0.2754
       ĠThey : 0.2695
        Once : 0.2598
      Ġthere : 0.2480
       Ġtime : 0.2383
         Ġso : 0.2334
         Ġof : 0.2275
        Ġhad : 0.2275
       ĠLily : 0.2275
       Ġupon : 0.2158
       Ġthat : 0.2139
     Ġlittle : 0.2129
          's : 0.2109
       Ġthey : 0.2002
         Ġon : 0.1963
       Ġvery : 0.1963
        ĠThe : 0.1904
        Ġsaw : 0.1904
        Ġbig : 0.1768
       Ġplay : 0.1699
         ĠIt : 0.16