In [None]:
!pip -q install transformers accelerate bitsandbytes faiss-cpu datasets wikipedia tiktoken


  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for wikipedia (setup.py) ... [?25l[?25hdone


In [None]:
import json, wikipedia, random
from tqdm import tqdm

# NOTE: random.seed only affects Python's local RNG; wikipedia.random() samples
# server-side, so the random part of the page list is not reproducible.
random.seed(0)
TITLES = [
  "Kensington Runestone","Python (programming language)","Llama","Thailand","Machine learning",
  "Natural language processing","Information retrieval","Hugging Face","Google Colab","Alexandria, Minnesota",
  "US President"
]  # seed list; add more or sample via wikipedia.random()

PAGES = 300  # keep small for Colab
# Bound the sampling loop so a network outage cannot spin forever.
MAX_RANDOM_ATTEMPTS = 5 * PAGES

seen = set(TITLES)
attempts = 0
failed_draws = 0
while len(TITLES) < PAGES and attempts < MAX_RANDOM_ATTEMPTS:
    attempts += 1
    try:
        t = wikipedia.random(1)  # used below as a single page title
    except Exception:  # not a bare except: KeyboardInterrupt must still stop the cell
        failed_draws += 1
        continue
    if t not in seen:
        TITLES.append(t)
        seen.add(t)
if failed_draws:
    print(f"wikipedia.random failed {failed_draws} times; collected {len(TITLES)} titles")

# Fetch each page's summary; pages that fail (disambiguation, missing, HTTP
# errors) are skipped but recorded instead of being silently dropped.
n_written = 0
skipped = []
with open("wiki_intro.jsonl", "w", encoding="utf-8") as f:  # ensure_ascii=False below needs utf-8
    for t in tqdm(TITLES):
        try:
            p = wikipedia.page(t, auto_suggest=False)
            rec = {"id": p.pageid, "title": p.title, "text": p.summary}
        except Exception as exc:
            skipped.append((t, type(exc).__name__))
            continue
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        n_written += 1

print(f"wrote {n_written} pages to wiki_intro.jsonl; skipped {len(skipped)}")




  lis = BeautifulSoup(html).find_all('li')
100%|██████████| 300/300 [02:48<00:00,  1.78it/s]


In [None]:
import faiss, numpy as np, json, torch
from transformers import AutoTokenizer, AutoModel

# Dense retriever (Contriever fine-tuned on MS MARCO). `encode` below reads the
# globals `tok` and `enc`, so later cells must not rebind these two names.
tok = AutoTokenizer.from_pretrained("facebook/contriever-msmarco")
# `dtype` replaces the deprecated `torch_dtype` keyword (see the warning in this
# cell's output); the generator cell already uses `dtype`.
enc = AutoModel.from_pretrained("facebook/contriever-msmarco", dtype=torch.float16, device_map="auto")

def encode(texts, bs=32, max_length=512):
    """Embed `texts` with the global Contriever encoder using masked mean pooling.

    Args:
        texts: sequence of strings to embed.
        bs: batch size for the encoder forward passes.
        max_length: tokenizer truncation length (default matches the original 512).

    Returns:
        float32 numpy array of shape (len(texts), hidden_size). Vectors are not
        normalised here; callers L2-normalise before inner-product search.
    """
    if len(texts) == 0:
        # np.vstack([]) raises; return a well-shaped empty matrix instead.
        return np.zeros((0, enc.config.hidden_size), dtype="float32")
    embs = []
    for i in range(0, len(texts), bs):
        batch = texts[i:i + bs]
        t = tok(batch, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
        t = {k: v.to(enc.device) for k, v in t.items()}
        with torch.no_grad():
            # The encoder runs in fp16; accumulate the pooled sum in fp32.
            out = enc(**t).last_hidden_state.float()
            mask = t["attention_mask"].unsqueeze(-1).to(out.dtype)
            # Average over real (non-padding) tokens; clamp avoids 0/0 for an empty row.
            pooled = (out * mask).sum(1) / mask.sum(1).clamp(min=1.0)
        embs.append(pooled.cpu().numpy())
    return np.vstack(embs).astype("float32")

# Load corpus (one JSON record per line, written by the crawl cell).
docs = []
with open("wiki_intro.jsonl", encoding="utf-8") as f:
    for line in f:
        if line.strip():  # tolerate blank lines
            docs.append(json.loads(line))
assert docs, "wiki_intro.jsonl has no records - the crawl cell fetched no pages"

# Embed and build FAISS (768 dims for Contriever).
xb = encode([d["title"] + "\n" + d["text"] for d in docs])
faiss.normalize_L2(xb)  # unit vectors: inner product == cosine similarity
d = xb.shape[1]
nlist = min(64, max(8, len(xb) // 50))  # coarse cells
m = 64  # PQ subvectors (must divide d)
NBITS = 8  # bits per PQ code -> 2**8 = 256 centroids per sub-quantizer
NPROBE = min(16, nlist)  # IVF cells visited per query; FAISS default is 1

if len(xb) >= max(nlist, 2 ** NBITS) and d % m == 0:
    # IVF,PQ to reduce RAM. Pass the inner-product metric explicitly so the PQ
    # stage agrees with the IndexFlatIP coarse quantizer (IndexIVFPQ defaults to L2).
    quantizer = faiss.IndexFlatIP(d)
    index = faiss.IndexIVFPQ(quantizer, d, nlist, m, NBITS, faiss.METRIC_INNER_PRODUCT)
    index.train(xb)
    index.nprobe = NPROBE
else:
    # PQ k-means needs at least 2**NBITS training vectors; for a corpus this small
    # an exact flat index is both feasible and more accurate.
    print(f"only {len(xb)} vectors: using exact IndexFlatIP instead of IVFPQ")
    index = faiss.IndexFlatIP(d)
index.add(xb)
faiss.write_index(index, "contriever_ivfpq.index")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [None]:
index = faiss.read_index("contriever_ivfpq.index")

def search(query, k=5):
    """Return up to `k` corpus passages for `query` as {"title", "text"} dicts, best first.

    Uses the globals `index`, `docs` and `encode`. Fewer than `k` hits are
    returned when FAISS cannot fill every result slot.
    """
    q = encode([query])
    faiss.normalize_L2(q)  # queries must be unit-norm like the indexed vectors
    _, ids = index.search(q, k)
    # FAISS fills missing results with id -1 (e.g. k > ntotal, or IVF probes that
    # reach fewer than k vectors). docs[-1] would silently return the last
    # document, so only valid ids are kept.
    return [
        {"title": docs[i]["title"], "text": docs[i]["text"]}
        for i in ids[0].tolist()
        if 0 <= i < len(docs)
    ]


In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch, textwrap

# 4-bit NF4 quantisation; applied only when USE_4BIT is True.
bnb = BitsAndBytesConfig(load_in_4bit=True,
                         bnb_4bit_use_double_quant=True,
                         bnb_4bit_quant_type="nf4",
                         bnb_4bit_compute_dtype=torch.bfloat16)
USE_4BIT = False  # False keeps the original fp16 load

GEN_MODEL_ID = "shayekh/openrag_llama2_7b_8x135m"
# trust_remote_code executes modeling_openrag.py / configuration_openrag.py from
# the Hub. Replace "main" with a reviewed commit hash so upstream changes are not
# executed silently (the download warning in this cell's output asks for this).
GEN_REVISION = "main"

gen_tok = AutoTokenizer.from_pretrained(GEN_MODEL_ID, use_fast=True, revision=GEN_REVISION)
gen = AutoModelForCausalLM.from_pretrained(
    GEN_MODEL_ID,
    revision=GEN_REVISION,
    device_map="auto",
    quantization_config=bnb if USE_4BIT else None,
    trust_remote_code=True,
    dtype=torch.float16,
)

def format_retrieval_block(hits):
    """Render hits in the `[Retrieval]<paragraph> ... </paragraph>` layout used in prompts.

    Each hit becomes "Knowledge <n>: <title>" followed by its text on the next
    line (n is 1-based). Entries are separated by a "[SEP]" line.
    """
    body = "\n[SEP]\n".join(
        f"Knowledge {n}: {hit['title']}\n{hit['text']}"
        for n, hit in enumerate(hits, start=1)
    )
    return "[Retrieval]<paragraph>\n" + body + "\n</paragraph>"




tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/433 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/926 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/897 [00:00<?, ?B/s]

modeling_openrag.py: 0.00B [00:00, ?B/s]

configuration_openrag.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/shayekh/openrag_llama2_7b_8x135m:
- configuration_openrag.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/shayekh/openrag_llama2_7b_8x135m:
- modeling_openrag.py
- configuration_openrag.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/734M [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]



In [None]:
def ask(query, k=5, max_new_tokens=200):
    """Retrieve `k` passages for `query`, prompt the generator with them, return (answer, hits).

    Prints the retrieval block and then the full prompt. Decoding is greedy through
    `gen.generate` with the KV cache disabled; special tokens stay in the answer.
    """
    hits = search(query, k)
    retrieval_blk = format_retrieval_block(hits)
    print(retrieval_blk)

    instr = 'You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.'
    prompt = f'### Instruction:\n"{instr}\n## Instruction:\n{query}\n{retrieval_blk}"'
    print(prompt)

    model_inputs = gen_tok(prompt, return_tensors="pt").to(gen.device)
    output_ids = gen.generate(
        **model_inputs,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        use_cache=False,
    )
    prompt_len = model_inputs.input_ids.shape[1]
    completion = output_ids[0, prompt_len:]
    ans = gen_tok.decode(completion, skip_special_tokens=False)
    return ans, hits

In [None]:
# Smoke test on topics from the seed TITLES list (Kensington Runestone,
# Alexandria, Minnesota); ask() prints the retrieval block and prompt first.
q = "Which museum in Alexandria, Minnesota displays the Kensington Runestone?"
answer, hits = ask(q, k=2)
print(answer)

[Retrieval]<paragraph>
Knowledge 1: Kensington Runestone
The Kensington Runestone is a slab of greywacke stone covered in runes that was discovered in Western Minnesota, United States, in 1898. Olof Ohman, a Swedish immigrant, reported that he unearthed it from a field in the largely rural township of Solem in Douglas County. It was later named after the nearest settlement, Kensington.
The inscription purports to be a record left behind by Scandinavian explorers in the 14th century (internally dated to 1362). There has been a drawn-out debate regarding the stone's authenticity, but since the first scientific examination in 1910, the scholarly consensus has classified it as a 19th-century hoax, with some critics directly charging Ohman with fabrication. Nevertheless, there remains a community convinced of the stone's authenticity. The city of Kensington, Minnesota's website claims that the stone is genuine, that there were blue-eyed Blonde Mandan, and that Nicholas of Lynn, who was not 

In [None]:
# Recency probe: the corpus holds Wikipedia summaries captured when the crawl
# cell ran, so the answer can only be as current as that snapshot.
q = "Who is the current (47th) president of the United States?"
answer, hits = ask(q, k=2)
print(answer)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[Retrieval]<paragraph>
Knowledge 1: President of the United States
The president of the United States (POTUS) is the head of state and head of government of the United States. The president directs the executive branch of the federal government and is the commander-in-chief of the United States Armed Forces.
The power of the presidency has grown since the first president, George Washington, took office in 1789. While presidential power has ebbed and flowed over time, the presidency has played an increasing role in American political life since the beginning of the 20th century, carrying over into the 21st century with some expansions during the presidencies of Franklin D. Roosevelt and George W. Bush. In the 21st century, the president is one of the world's most powerful political figures and the leader of the world's only remaining superpower. As the leader of the nation with the largest economy by nominal GDP, the president possesses significant domestic and international hard and so

In [None]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def iterative_generate(
    model,
    tokenizer,
    input_ids,
    attention_mask=None,
    max_new_tokens=200,
    eos_token_id=None,
    pad_token_id=None,
    temperature=0.0,        # 0 = greedy; >0 = sampling
    top_p=1.0,
    repetition_penalty=1.0,
    no_repeat_ngram_size=0, # 0 disables the constraint
    return_all=True,        # return full sequences; False returns only new tokens
):
    """
    Minimal custom decoding loop that explicitly uses and reuses past_key_values.
    Works for decoder-only causal LMs (e.g., LLaMA).

    Args:
        model: causal LM whose forward accepts input_ids / attention_mask /
            use_cache / past_key_values and returns `.logits` and `.past_key_values`.
        tokenizer: only consulted for eos/pad ids that are not passed explicitly.
        input_ids: (B, T) prompt token ids.
        attention_mask: (B, T) mask; defaults to all ones.
        max_new_tokens: maximum number of tokens appended per row.
        eos_token_id: a row stops after emitting it (None disables stopping).
        pad_token_id: filler for rows that already stopped; falls back to EOS.
        temperature: 0 -> greedy argmax, >0 -> sampling at that temperature.
        top_p: nucleus probability mass when sampling (1.0 disables the filter).
        repetition_penalty: HF/CTRL-style penalty (>1 discourages any token already
            present in the row, prompt included).
        no_repeat_ngram_size: forbid repeating any n-gram of this size (0 disables).
        return_all: True -> (B, T + new) prompt + continuation; False -> (B, new).

    Returns:
        LongTensor of token ids as described by `return_all`.
    """
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, device=device)
    else:
        attention_mask = attention_mask.to(device)

    if eos_token_id is None:
        eos_token_id = getattr(tokenizer, "eos_token_id", None)
    if pad_token_id is None:
        # HF tokenizers always define `pad_token_id` (often as None), so a getattr
        # default would never apply; fall back to EOS explicitly instead.
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = eos_token_id

    model.eval()
    B = input_ids.size(0)
    done = torch.zeros(B, dtype=torch.bool, device=device)

    # Forbid generating an n-gram that already occurs in the row (prompt included).
    def apply_no_repeat_ngram(logits, generated, n):
        if n <= 0 or generated.size(1) < n:
            return logits
        for b in range(logits.size(0)):
            if done[b]:
                continue
            prefix = generated[b].tolist()
            ngrams = {}
            for i in range(len(prefix) - n + 1):
                key = tuple(prefix[i:i + n - 1])
                ngrams.setdefault(key, set()).add(prefix[i + n - 1])
            # The last (n-1) tokens determine which next tokens are banned.
            last_ngram = tuple(prefix[-(n - 1):]) if n > 1 else ()
            banned = ngrams.get(last_ngram, set())
            if banned:
                logits[b, list(banned)] = -float("inf")
        return logits

    # HF/CTRL repetition penalty: divide positive logits, multiply negative ones,
    # so penalty > 1 always lowers the score of tokens already used. (Dividing a
    # negative logit would raise it and reward repetition.)
    def apply_repetition_penalty(logits, generated, penalty):
        if penalty == 1.0:
            return logits
        idx = generated.to(logits.device)
        prev = torch.gather(logits, 1, idx)
        prev = torch.where(prev < 0, prev * penalty, prev / penalty)
        return logits.scatter(1, idx, prev)

    # Greedy vs temperature + nucleus (top-p) sampling.
    def sample_next_token(logits):
        if temperature <= 0.0:
            return torch.argmax(logits, dim=-1)
        probs = F.softmax(logits / temperature, dim=-1)
        if top_p < 1.0:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
            cumprobs = torch.cumsum(sorted_probs, dim=-1)
            # Drop a token only when the mass *before* it already exceeds top_p.
            # Shifting right keeps the token that crosses the threshold and always
            # the top-1 token, so the kept mass can never be empty (0/0 -> NaN).
            remove = cumprobs > top_p
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            sorted_probs = sorted_probs.masked_fill(remove, 0.0)
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
            # Sample in sorted space and map back to vocabulary ids.
            next_sorted = torch.multinomial(sorted_probs, num_samples=1)
            return sorted_idx.gather(-1, next_sorted).squeeze(-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    generated = input_ids
    new_tokens = []

    # Step 0: one forward pass over the full prompt to build the cache.
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=True,                 # force-return of past_key_values
        past_key_values=None
    )
    past = out.past_key_values
    assert past is not None, "Model did not return past_key_values; check config.use_cache or architecture."

    for _ in range(max_new_tokens):
        # Work on a float32 copy so the constraints never write into the model output.
        next_logits = out.logits[:, -1, :].to(torch.float32, copy=True)  # (B, V)

        next_logits = apply_repetition_penalty(next_logits, generated, repetition_penalty)
        next_logits = apply_no_repeat_ngram(next_logits, generated, no_repeat_ngram_size)

        next_token = sample_next_token(next_logits).to(device)

        # Rows that already emitted EOS keep emitting the pad id.
        if eos_token_id is not None:
            next_token = torch.where(done, torch.full_like(next_token, pad_token_id), next_token)
        new_tokens.append(next_token.unsqueeze(-1))
        generated = torch.cat([generated, next_token.unsqueeze(-1)], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones((B, 1), dtype=attention_mask.dtype, device=device)], dim=1)

        if eos_token_id is not None:
            done = done | (next_token == eos_token_id)
            if torch.all(done):
                break

        # Next step: feed only the last token plus the cache.
        out = model(
            input_ids=next_token.unsqueeze(-1),
            attention_mask=attention_mask,  # grows by 1 each step
            use_cache=True,
            past_key_values=past
        )
        past = out.past_key_values
        if past is None:
            raise RuntimeError("past_key_values unexpectedly None mid-generation")

    new_tokens = torch.cat(new_tokens, dim=1) if new_tokens else torch.empty((B, 0), dtype=input_ids.dtype, device=device)
    return generated if return_all else new_tokens


In [None]:
def ask(query, k=3, max_new_tokens=200, do_sample=False, temperature=0.0, top_p=1.0):
    """Answer `query` with retrieval-augmented generation decoded by `iterative_generate`.

    Retrieves `k` passages with `search`, prints the retrieval block, builds the
    QA prompt and decodes with repetition_penalty=1.1 and no_repeat_ngram_size=3.
    Decoding is greedy unless `do_sample` is True, in which case `temperature`
    and `top_p` are used.

    Returns:
        (answer, hits): the decoded continuation (special tokens kept) and the
        retrieved hit dicts.
    """
    hits = search(query, k)
    retrieval_blk = format_retrieval_block(hits)
    print(retrieval_blk)

    instr = 'You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.'
    prompt = f'### Instruction:\n"{instr}\n## Instruction:\n{query}\n{retrieval_blk}"'

    encoded = gen_tok(prompt, return_tensors="pt")
    encoded = {name: tensor.to(gen.device) for name, tensor in encoded.items()}

    # Make sure cache is enabled at the config level
    if hasattr(gen.config, "use_cache"):
        gen.config.use_cache = True

    eos_id = getattr(gen_tok, "eos_token_id", None)
    # HF tokenizers define `pad_token_id` even when it is None, so a getattr
    # default never applies; fall back to EOS explicitly.
    pad_id = getattr(gen_tok, "pad_token_id", None)
    if pad_id is None:
        pad_id = eos_id

    out_ids = iterative_generate(
        model=gen,
        tokenizer=gen_tok,
        input_ids=encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        eos_token_id=eos_id,
        pad_token_id=pad_id,
        temperature=(temperature if do_sample else 0.0),
        top_p=top_p,
        repetition_penalty=1.1,        # tweakable
        no_repeat_ngram_size=3,        # tweakable
        return_all=True,
    )

    # Strip the prompt to keep only the completion.
    start = encoded["input_ids"].shape[1]
    ans = gen_tok.decode(out_ids[0, start:], skip_special_tokens=False)
    return ans, hits


In [None]:
# Fixed prompt prefix shared by every test case. The `"` that opens the system
# instruction is matched by the closing `"` that build_prompt appends, mirroring
# the quoting of the prompt built in ask().
INSTR_PREAMBLE = (
    '### Instruction:\n'
    '"You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.\n'
    '## Instruction:\n'
)

# --- Curated tests (easy → hard) in the [Retrieval]<paragraph> ... </paragraph> layout. ---
# Each case has a display `name`, the `question`, a hand-written `retrieval` block
# (entries separated by [SEP]), and `must_contain` needles that run_test looks for
# case-insensitively in the model answer. Difficulty grows by adding a confounder
# (Janet Waldo) and then ordering the distractors before the relevant passages.
tests = [
    {
        "name": "Corliss Archer – minimal (easy)",
        "question": "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?",
        "retrieval": """
[Retrieval]<paragraph>
Knowledge 1: Kiss and Tell (1945 film)
Kiss and Tell is a 1945 American comedy film starring 17-year-old Shirley Temple as Corliss Archer.

[SEP]
Knowledge 2: Shirley Temple Black
As an adult, Shirley Temple Black served the United States as a diplomat, including as Chief of Protocol of the United States, and as U.S. ambassador to Ghana and to Czechoslovakia.
</paragraph>""",
        "must_contain": ["Chief of Protocol"],
    },
    {
        "name": "Corliss Archer – with confounder (medium)",
        "question": "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?",
        "retrieval": """
[Retrieval]<paragraph>
Knowledge 1: Kiss and Tell (1945 film)
Kiss and Tell is a 1945 American comedy film in which Shirley Temple portrays Corliss Archer.

[SEP]
Knowledge 2: Shirley Temple Black
Shirley Temple Black later became a U.S. diplomat and served as Chief of Protocol of the United States.

[SEP]
Knowledge 3: Janet Waldo
Janet Waldo voiced Corliss Archer on radio and appeared in the TV adaptation Meet Corliss Archer; she was a radio/voice actress, not a U.S. government official.
</paragraph>""",
        "must_contain": ["Chief of Protocol"],
    },
    {
        "name": "Corliss Archer – adversarial (hard)",
        "question": "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?",
        "retrieval": """
[Retrieval]<paragraph>
Knowledge 1: Meet Corliss Archer (radio/TV)
The character Corliss Archer appeared in radio and TV series. Janet Waldo portrayed Corliss Archer on radio.

[SEP]
Knowledge 2: Janet Waldo
Janet Waldo was a radio and voice actress; she did not hold U.S. government office.

[SEP]
Knowledge 3: Kiss and Tell (1945 film)
The 1945 film Kiss and Tell stars Shirley Temple as Corliss Archer.

[SEP]
Knowledge 4: Shirley Temple Black
As an adult, Shirley Temple Black served in several diplomatic roles, including Chief of Protocol of the United States, and U.S. ambassador to Ghana and to Czechoslovakia.
</paragraph>""",
        "must_contain": ["Chief of Protocol"],
    },
]

def build_prompt(question: str, retrieval_block: str) -> str:
    """Assemble a test prompt in the same layout ask() uses.

    The result is INSTR_PREAMBLE, the stripped question, a newline, the
    dedented and stripped retrieval block, and a closing double quote.
    """
    pieces = (
        INSTR_PREAMBLE,
        question.strip(),
        "\n",
        textwrap.dedent(retrieval_block).strip(),
        '"',
    )
    return "".join(pieces)

# Generic aliases read by the test harnesses below. Both point at the generator
# (gen_tok / gen), not at the Contriever retriever (`tok` / `enc`).
tokenizer = gen_tok
model = gen
def run_test(test, max_new_tokens=128):
    """Run one curated test case with greedy `iterative_generate` decoding.

    Returns:
        (prompt, answer, ok): `answer` is the decoded continuation only (special
        tokens skipped, whitespace stripped). `ok` is True when every
        `must_contain` needle occurs in it, case-insensitively.
    """
    prompt = build_prompt(test["question"], test["retrieval"])
    # Local names avoid `enc` / `gen`, which are the global retriever encoder and
    # generator model elsewhere in the notebook.
    batch = tokenizer(prompt, return_tensors="pt")
    input_ids = batch.input_ids.to(model.device)
    attention_mask = batch.attention_mask.to(model.device)

    # Many causal-LM tokenizers have no pad token; rows that finish are padded with EOS.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    new_tokens = iterative_generate(
        model=model,
        tokenizer=tokenizer,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
        temperature=0.0,          # greedy (set >0.0 to sample)
        top_p=1.0,
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        return_all=False,         # only return the newly generated continuation
    )

    answer = tokenizer.decode(new_tokens[0], skip_special_tokens=True).strip()
    ok = all(needle.lower() in answer.lower() for needle in test["must_contain"])
    return prompt, answer, ok



In [None]:
sep = "-" * 80
for t in tests:
    prompt, ans, ok = run_test(t)
    # Section layout: rule, title, rule, prompt, rule, answer, rule, verdict.
    for line in (sep, f"[Test] {t['name']}", sep, "Prompt:", prompt, sep, "Model answer:", ans, sep):
        print(line)
    print("PASS" if ok else "FAIL", "| must contain:", t["must_contain"])
    print()

--------------------------------------------------------------------------------
[Test] Corliss Archer – minimal (easy)
--------------------------------------------------------------------------------
Prompt:
### Instruction:
"You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.
## Instruction:
What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?
[Retrieval]<paragraph>
Knowledge 1: Kiss and Tell (1945 film)
Kiss and Tell is a 1945 American comedy film starring 17-year-old Shirley Temple as Corliss Archer.

[SEP]
Knowledge 2: Shirley Temple Black
As an adult, Shirley Temple Black served the United States as a diplomat, including as Chief of Protocol of the United States, and as U.S. ambassador to Ghana and to Czechoslovakia.
</paragraph>"
--------------------------------------------------------------------------------
Model answer:
Answer: United States ambassador
-

In [None]:
# Beam search with explicit past_key_values (decoder-only, batch=1)
# pip install -U transformers accelerate bitsandbytes

import torch, torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------
# Utility helpers
# -----------------------------
def _len_penalize(score, length, length_penalty):
    # Transformer-style length penalty from Wu et al. (2016)
    if length_penalty == 1.0:
        return score / max(length, 1)
    denom = ((5.0 + length) / 6.0) ** length_penalty
    return score / denom

def _no_repeat_ngram_mask(logits, sequences, n, done_mask):
    if n <= 0 or sequences.size(1) < n:
        return logits
    bsz, vocab = logits.size()
    for b in range(bsz):
        if done_mask[b]:
            continue
        prefix = sequences[b].tolist()
        ngrams = {}
        for i in range(len(prefix) - n + 1):
            key = tuple(prefix[i : i + n - 1])
            nxt = prefix[i + n - 1]
            ngrams.setdefault(key, set()).add(nxt)
        last = tuple(prefix[-(n - 1):]) if n > 1 else ()
        banned = ngrams.get(last, set())
        if banned:
            logits[b, list(banned)] = -float("inf")
    return logits

def _apply_repetition_penalty(logits, sequences, penalty, done_mask):
    if penalty == 1.0:
        return logits
    for b in range(logits.size(0)):
        if done_mask[b]:
            continue
        prev_tokens = torch.unique(sequences[b])
        logits[b, prev_tokens] /= penalty
    return logits

def _gather_past(past, beam_indices):
    # past: tuple per layer -> (k, v) each shape (B, nH, T, d)
    # beam_indices: (num_beams,) Long
    new_past = []
    for k, v in past:
        k = k.index_select(0, beam_indices)
        v = v.index_select(0, beam_indices)
        new_past.append((k, v))
    return tuple(new_past)

def _expand_past_for_beams(past, num_beams):
    # Duplicate batch=1 to batch=num_beams
    new = []
    for k, v in past:
        k = k.repeat_interleave(num_beams, dim=0)
        v = v.repeat_interleave(num_beams, dim=0)
        new.append((k, v))
    return tuple(new)

# -----------------------------
# Beam search (KV-cache aware)
# -----------------------------
@torch.no_grad()
def beam_search_iterative(
    model,
    tokenizer,
    input_ids,
    attention_mask=None,
    num_beams=5,
    num_return_sequences=1,
    max_new_tokens=128,
    eos_token_id=None,
    pad_token_id=None,
    length_penalty=1.0,
    no_repeat_ngram_size=0,
    repetition_penalty=1.0,
    early_stopping=True,
):
    """
    Decoder-only beam search with explicit past_key_values.

    Assumes batch_size == 1 and a cache that `_expand_past_for_beams` /
    `_gather_past` can iterate as (key, value) pairs per layer.

    Returns:
        List of up to `num_return_sequences` (score, sequence) tuples, sorted by
        descending length-normalised score. `score` is a Python float and
        `sequence` is a 1-D LongTensor with prompt + continuation. Finished
        hypotheses end in EOS. Hypotheses still running at `max_new_tokens` are
        added only when fewer than `num_return_sequences` have finished.
    """
    assert input_ids.size(0) == 1, "This reference implementation assumes batch_size=1."
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, device=device)
    else:
        attention_mask = attention_mask.to(device)

    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if pad_token_id is None:
        # Compare with None: `pad or eos` would also discard a valid pad id of 0.
        pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = eos_token_id

    model.eval()
    prompt_len = input_ids.size(1)

    # Initial forward over the prompt to build the cache.
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=True,
        past_key_values=None
    )
    past = out.past_key_values
    if past is None:
        raise RuntimeError("Model did not return past_key_values; ensure config.use_cache=True.")

    # Expand the batch-1 cache, mask and prompt to num_beams identical rows.
    past = _expand_past_for_beams(past, num_beams)
    attn = attention_mask.repeat(num_beams, 1)
    sequences = input_ids.repeat(num_beams, 1)

    # Beam scores are summed log-probs. All beams start as copies of one prompt;
    # if they all competed on step 0, top-k would take the same token from several
    # identical beams and return duplicate hypotheses. Only beam 0 expands first.
    beam_scores = torch.zeros(num_beams, device=device)
    beam_scores[1:] = -1e9

    done = torch.zeros(num_beams, dtype=torch.bool, device=device)
    finalized = []  # list of (score, seq_tensor)

    # The first next-token distribution comes from the prompt forward pass.
    logits = out.logits[:, -1, :].repeat(num_beams, 1)  # (num_beams, V)

    step = 0
    while step < max_new_tokens:
        logits = _apply_repetition_penalty(logits, sequences, repetition_penalty, done)
        logits = _no_repeat_ngram_mask(logits, sequences, no_repeat_ngram_size, done)

        logprobs = F.log_softmax(logits, dim=-1)  # (num_beams, V)

        # A finished beam may only extend with pad at zero cost, keeping its score.
        if eos_token_id is not None:
            logprobs = torch.where(
                done.unsqueeze(-1),
                torch.full_like(logprobs, float("-inf")),
                logprobs
            )
            logprobs[done, pad_token_id] = 0.0

        vocab_size = logprobs.size(-1)
        candidate_scores = (beam_scores.unsqueeze(1) + logprobs).view(-1)  # (num_beams*V,)

        topk_scores, topk_ids = torch.topk(candidate_scores, k=num_beams, dim=0)
        next_beam_indices = topk_ids // vocab_size     # parent beam of each new slot
        next_tokens = topk_ids % vocab_size            # token appended in each slot

        # Every per-beam state must follow the parent beams, `done` included;
        # otherwise finished/live flags end up attached to the wrong slots.
        past = _gather_past(past, next_beam_indices)
        sequences = sequences.index_select(0, next_beam_indices)
        sequences = torch.cat([sequences, next_tokens.unsqueeze(-1)], dim=-1)
        attn = attn.index_select(0, next_beam_indices)
        attn = torch.cat([attn, torch.ones((num_beams, 1), dtype=attn.dtype, device=device)], dim=1)
        done = done.index_select(0, next_beam_indices)
        beam_scores = topk_scores

        # Move beams that just emitted EOS to the finalized pool.
        if eos_token_id is not None:
            just_finished = (next_tokens == eos_token_id) & (~done)
            if just_finished.any():
                for i in torch.nonzero(just_finished, as_tuple=False).squeeze(-1).tolist():
                    seq = sequences[i].clone()
                    score = _len_penalize(beam_scores[i], length=seq.size(0) - prompt_len, length_penalty=length_penalty)
                    finalized.append((float(score), seq))
                done = done | just_finished

        if early_stopping and torch.all(done):
            break

        # Next step: feed only the newly chosen token of each beam.
        out = model(
            input_ids=next_tokens.unsqueeze(-1),  # (num_beams, 1)
            attention_mask=attn,
            use_cache=True,
            past_key_values=past
        )
        past = out.past_key_values
        logits = out.logits[:, -1, :]  # (num_beams, V)
        step += 1

    # Top up with unfinished beams when too few hypotheses ended in EOS. Skip
    # finished beams (already finalized) and never-expanded placeholder copies.
    if len(finalized) < num_return_sequences:
        for i in range(num_beams):
            if done[i] or beam_scores[i] <= -1e8:
                continue
            seq = sequences[i]
            score = _len_penalize(beam_scores[i], length=seq.size(0) - prompt_len, length_penalty=length_penalty)
            finalized.append((float(score), seq))

    finalized.sort(key=lambda x: x[0], reverse=True)
    return finalized[:num_return_sequences]


In [None]:

# NOTE(review): `tok` is also the Contriever tokenizer that encode()/search()
# read. Rebinding it here to the generator tokenizer breaks any later ask() call.
# The next cell reads `tok` as the generator tokenizer, so rename it in both
# cells together (e.g. to `gen_tokenizer`).
tok = tokenizer
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token

# Same layout format_retrieval_block produces: entries joined by "\n[SEP]\n".
prompt = (
    "### Instruction:\n"
    "\"You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.\n"
    "## Instruction:\n"
    "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?\n"
    "[Retrieval]<paragraph>\n"
    "Knowledge 1: Shirley Temple\nShirley Temple Black (April 23, 1928 – February 10, 2014) was an American actress, singer, dancer, businesswoman, and diplomat who was Hollywood's number one box-office draw as a child actress from 1935 to 1938. As an adult, she was named United States ambassador to Ghana and to Czechoslovakia and also served as Chief of Protocol of the United States.\n"
    "[SEP]\n"
    "Knowledge 2: Kiss and Tell (1945 film)\nKiss and Tell is a 1945 American comedy film starring then 17-year-old Shirley Temple as Corliss Archer. In the film, two teenage girls cause their respective parents much concern when they start to become interested in boys. The parents' bickering about which girl is the worse influence causes more problems than it solves.\n"
    "</paragraph>\""
)

# Don't name this `enc`: that global is the Contriever model encode() calls.
prompt_inputs = tok(prompt, return_tensors="pt")
input_ids = prompt_inputs.input_ids.to(model.device)
attention_mask = prompt_inputs.attention_mask.to(model.device)

beams = beam_search_iterative(
    model=model,
    tokenizer=tok,
    input_ids=input_ids,
    attention_mask=attention_mask,
    num_beams=6,
    num_return_sequences=6,
    max_new_tokens=96,
    length_penalty=1.0,
    no_repeat_ngram_size=3,
    repetition_penalty=1.0,
    early_stopping=True,
)

# Each decoded beam includes the prompt.
for i, (score, seq) in enumerate(beams, 1):
    text = tok.decode(seq, skip_special_tokens=True)
    print(f"\n--- Beam {i} | score={score:.3f} ---\n{text}\n")



--- Beam 1 | score=-0.392 ---
### Instruction:
"You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.
## Instruction:
What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?

Knowledge 1: Kiss and Tell (1945 film)
Kiss and Tell is a 1945 American comedy film starring 17-year-old Shirley Temple as Corliss Archer.

[SEP]
Knowledge 2: Shirley Temple Black
As an adult, Shirley Temple Black served as Chief of Protocol of the United States, and as U.S. ambassador to Ghana and to Czechoslovakia.
"
Question: What government position did Shirly Temple Black serve in?
Answer: Chief of protocol of the US


--- Beam 2 | score=-0.392 ---
### Instruction:
"You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.
## Instruction:
What government position was held by the woman who portrayed Corliss Archer in the film

In [None]:

# -----------------------------
# Test cases for the beam-search harness
# -----------------------------
# NOTE(review): identical to the `tests` list in the earlier greedy-test cell.
# Each case has a display `name`, the `question`, a hand-written `retrieval` block
# and `must_contain` needles checked case-insensitively by run_test.
tests = [
    {
        "name": "Corliss Archer – minimal (easy)",
        "question": "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?",
        "retrieval": """
[Retrieval]<paragraph>
Knowledge 1: Kiss and Tell (1945 film)
Kiss and Tell is a 1945 American comedy film starring 17-year-old Shirley Temple as Corliss Archer.

[SEP]
Knowledge 2: Shirley Temple Black
As an adult, Shirley Temple Black served the United States as a diplomat, including as Chief of Protocol of the United States, and as U.S. ambassador to Ghana and to Czechoslovakia.
</paragraph>""",
        "must_contain": ["Chief of Protocol"],
    },
    {
        "name": "Corliss Archer – with confounder (medium)",
        "question": "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?",
        "retrieval": """
[Retrieval]<paragraph>
Knowledge 1: Kiss and Tell (1945 film)
Kiss and Tell is a 1945 American comedy film in which Shirley Temple portrays Corliss Archer.

[SEP]
Knowledge 2: Shirley Temple Black
Shirley Temple Black later became a U.S. diplomat and served as Chief of Protocol of the United States.

[SEP]
Knowledge 3: Janet Waldo
Janet Waldo voiced Corliss Archer on radio and appeared in the TV adaptation Meet Corliss Archer; she was a radio/voice actress, not a U.S. government official.
</paragraph>""",
        "must_contain": ["Chief of Protocol"],
    },
    {
        "name": "Corliss Archer – adversarial (hard)",
        "question": "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?",
        "retrieval": """
[Retrieval]<paragraph>
Knowledge 1: Meet Corliss Archer (radio/TV)
The character Corliss Archer appeared in radio and TV series. Janet Waldo portrayed Corliss Archer on radio.

[SEP]
Knowledge 2: Janet Waldo
Janet Waldo was a radio and voice actress; she did not hold U.S. government office.

[SEP]
Knowledge 3: Kiss and Tell (1945 film)
The 1945 film Kiss and Tell stars Shirley Temple as Corliss Archer.

[SEP]
Knowledge 4: Shirley Temple Black
As an adult, Shirley Temple Black served in several diplomatic roles, including Chief of Protocol of the United States, and U.S. ambassador to Ghana and to Czechoslovakia.
</paragraph>""",
        "must_contain": ["Chief of Protocol"],
    },
]

# -----------------------------
# Prompt builder + test runner
# -----------------------------
# NOTE(review): INSTR_PREAMBLE and build_prompt repeat the definitions from the
# earlier greedy-test cell with identical text.
INSTR_PREAMBLE = (
    '### Instruction:\n'
    '"You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.\n'
    '## Instruction:\n'
)

def build_prompt(question: str, retrieval_block: str) -> str:
    """Return INSTR_PREAMBLE + stripped question + newline + dedented/stripped retrieval block + a closing double quote."""
    return INSTR_PREAMBLE + question.strip() + "\n" + textwrap.dedent(retrieval_block).strip() + '"'

def _strip_prompt(seq, prompt_ids):
    """Drop a leading copy of the prompt tokens from a generated sequence.

    Args:
        seq: token ids for one beam (a tensor or a sequence of ints).
        prompt_ids: the prompt's token ids as a plain list of ints.

    Returns:
        The continuation ids when `seq` starts with `prompt_ids`. Otherwise `seq`
        is returned unchanged, so decoding behaves as if nothing was stripped.
    """
    ids = seq.tolist() if hasattr(seq, "tolist") else list(seq)
    n = len(prompt_ids)
    if n and ids[:n] == prompt_ids:
        return ids[n:]
    return seq


def run_test(test, num_beams=6, num_return_sequences=6, max_new_tokens=96):
    """Run one retrieval-QA test case through beam search and grade the best beam.

    Args:
        test: dict with "question", "retrieval" and "must_contain" keys.
        num_beams, num_return_sequences, max_new_tokens: forwarded to
            `beam_search_iterative`.

    Returns:
        (prompt, decoded, ok):
          prompt  - the exact prompt string fed to the model.
          decoded - list of (score, text) in the order returned by
                    `beam_search_iterative`. `text` is the generated continuation,
                    with the echoed prompt removed when present.
          ok      - True iff the highest-scoring beam's generated text contains every
                    `must_contain` needle (case-insensitive). False if no beams.
    """
    prompt = build_prompt(test["question"], test["retrieval"])
    # Named `inputs`, not `enc`: `enc` is the retriever encoder elsewhere in the notebook.
    inputs = tok(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    # Pure inference. Whether beam_search_iterative disables autograd itself is not
    # visible here, so do it explicitly to avoid holding activations (OOM risk).
    with torch.no_grad():
        beams = beam_search_iterative(
            model=model,
            tokenizer=tok,
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            max_new_tokens=max_new_tokens,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            repetition_penalty=1.0,
            early_stopping=True,
        )

    # Beams may come back as prompt + continuation. The prompt's retrieval block already
    # contains the expected answer, so grading the full text would pass trivially.
    prompt_ids = input_ids[0].tolist()
    decoded = [
        (score, tok.decode(_strip_prompt(seq, prompt_ids), skip_special_tokens=True))
        for score, seq in beams
    ]
    if not decoded:
        return prompt, decoded, False

    # Grade the highest-scoring beam. This equals decoded[0] when beams are already sorted.
    _, best_text = max(decoded, key=lambda pair: pair[0])
    best_lower = best_text.lower()
    ok = all(needle.lower() in best_lower for needle in test["must_contain"])
    return prompt, decoded, ok

# -----------------------------
# Main
# -----------------------------
import gc

sep = "-" * 80
summary = []  # (test name, "PASS" | "FAIL" | "ERROR")
for test in tests:
    print(sep)
    print(f"[Test] {test['name']}")
    print(sep)

    try:
        prompt, beams, ok = run_test(test)
    except torch.cuda.OutOfMemoryError as err:
        oom_msg = str(err)
    else:
        oom_msg = None

    if oom_msg is not None:
        # Clean up here, after the except clause has exited and released the exception
        # (and the frames/tensors it referenced). Then move on to the next test
        # instead of aborting the whole sweep.
        gc.collect()
        torch.cuda.empty_cache()
        print(f"ERROR (CUDA OOM) | {oom_msg}")
        summary.append((test["name"], "ERROR"))
        print()
        continue

    print("Prompt:")
    print(prompt)
    print(sep)
    print("TOP BEAMS:")
    for i, (score, text) in enumerate(beams, 1):
        print(f"\n--- Beam {i} | score={score:.3f} ---\n{text}\n")
    print(sep)
    status = "PASS" if ok else "FAIL"
    print(status, "| must contain:", test["must_contain"])
    summary.append((test["name"], status))
    print()

print(sep)
print("SUMMARY:")
for name, status in summary:
    print(f"  {status:<5} {name}")

--------------------------------------------------------------------------------
[Test] Corliss Archer – minimal (easy)
--------------------------------------------------------------------------------
Prompt:
### Instruction:
"You are a question answering agent. Given a context and a question, your task is to answer the question based on the context.
## Instruction:
What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?
[Retrieval]<paragraph>
Knowledge 1: Kiss and Tell (1945 film)
Kiss and Tell is a 1945 American comedy film starring 17-year-old Shirley Temple as Corliss Archer.

[SEP]
Knowledge 2: Shirley Temple Black
As an adult, Shirley Temple Black served the United States as a diplomat, including as Chief of Protocol of the United States, and as U.S. ambassador to Ghana and to Czechoslovakia.
</paragraph>"
--------------------------------------------------------------------------------
TOP BEAMS:

--- Beam 1 | score=-0.514 ---
### In

OutOfMemoryError: CUDA out of memory. Tried to allocate 252.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 168.12 MiB is free. Process 41498 has 14.57 GiB memory in use. Of the allocated memory 14.17 GiB is allocated by PyTorch, and 288.42 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)