In [3]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from tqdm.auto import tqdm

In [4]:
import warnings
warnings.filterwarnings('ignore')

In [5]:
# Validation split of the "all" configuration of MMLU.
dataset = load_dataset(
    path="cais/mmlu",
    name="all",
    split="validation",
)

In [6]:
# Checkpoint shared by the tokenizer and the causal LM.
# Alternative that was tried: 'distilbert/distilgpt2'
model_name = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_name)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [11]:
# Keep only the first `num_samples` rows (never more than the dataset holds).
num_samples = 1
ds = dataset.select(indices=range(min(num_samples, len(dataset))))
ds

Dataset({
    features: ['question', 'subject', 'choices', 'answer'],
    num_rows: 1
})

In [35]:
# Decoding configuration passed to generate() and the strategy functions.
args = {
    'strategy': 'best_of_n',   # one of 'best_of_n', 'greedy', 'beam_search'
    'best_of': 5,              # number of sampled candidates drawn by best_of_n
    'num_beams': 5,            # beam width looked up by beam_search
    'max_new_tokens': 10,      # upper bound on generated tokens per sequence
    'k': True,                 # truthy -> best_of_n samples from the top-k logits only
    'temperature': 1.0,        # divisor applied to the logits before sampling
    # The value was left blank in the original cell (a SyntaxError); record the
    # checkpoint name that the model-loading cell uses.
    'model': model_name,
}

In [39]:
def beam_search(model, input_ids, attention_mask, args):
    """
    Beam search decoding with num_beams.

    At every step each unfinished beam is extended with its `num_beams` most
    likely next tokens; the `num_beams` candidates with the highest summed
    log-probability survive. Scores are plain sums (no length normalisation).

    Args:
        model: callable causal LM. It is called with `input_ids`,
            `attention_mask`, `use_cache` (and `past_key_values` after the
            first step) and must return an object exposing `.logits` and
            `.past_key_values`. `model.config.eos_token_id` is read.
        input_ids: prompt token ids. Only row 0 is decoded (`[0]` / `.item()`
            are used), so the batch size must be 1.
        attention_mask: mask matching `input_ids`; a 1 is appended per new token.
        args: options providing `num_beams` and `max_new_tokens`, either as a
            dict (what the notebook passes) or as an attribute-style object.

    Returns:
        Tensor holding only the generated ids of the best-scoring beam
        (prompt stripped, EOS included if it was produced).

    Raises:
        ValueError: if `num_beams` is smaller than 1.
    """
    import copy  # local import so this cell does not depend on the import cell

    def _opt(name):
        # Accept both the dict used elsewhere in the notebook and namespace objects.
        return args[name] if isinstance(args, dict) else getattr(args, name)

    beam_size = _opt('num_beams')
    max_new_tokens = _opt('max_new_tokens')
    if beam_size < 1:
        raise ValueError(f"num_beams must be >= 1, got {beam_size}")

    eos_id = model.config.eos_token_id
    # Some configs store several EOS ids; normalise to a set for the membership test.
    eos_ids = set(eos_id) if isinstance(eos_id, (list, tuple)) else {eos_id}
    device = input_ids.device
    seq_len = input_ids.size(1)

    # initialize beam candidates
    beams = [{
        'ids': input_ids,
        'mask': attention_mask,
        'past': None,
        'score': 0.0,
        'done': False
    }]

    with torch.no_grad():  # decoding only: do not build an autograd graph
        for _ in range(max_new_tokens):
            all_candidates = []
            for beam in beams:
                if beam['done']:
                    all_candidates.append(beam)
                    continue
                if beam['past'] is None:
                    out = model(
                        input_ids=beam['ids'],
                        attention_mask=beam['mask'],
                        use_cache=True
                    )
                else:
                    out = model(
                        input_ids=beam['ids'][:, -1:],
                        attention_mask=beam['mask'],
                        # Sibling beams hold the SAME parent cache object, and
                        # caches may be extended in place by the forward pass,
                        # so each expansion works on its own copy.
                        past_key_values=copy.deepcopy(beam['past']),
                        use_cache=True
                    )
                past = out.past_key_values
                log_probs = torch.log_softmax(out.logits[:, -1, :], dim=-1)
                # Never request more continuations than the vocabulary holds.
                width = min(beam_size, log_probs.size(-1))
                topk_logprobs, topk_idx = torch.topk(log_probs, width, dim=-1)
                topk_logprobs = topk_logprobs[0]
                topk_idx = topk_idx[0]
                # Every child of this beam has the same (one longer) mask.
                grown_mask = torch.cat([
                    beam['mask'],
                    torch.ones((1, 1), dtype=beam['mask'].dtype, device=device)
                ], dim=-1)
                for j in range(width):
                    next_tok = topk_idx[j].unsqueeze(0).unsqueeze(0)
                    finished = next_tok.item() in eos_ids
                    all_candidates.append({
                        'ids': torch.cat([beam['ids'], next_tok], dim=-1),
                        'mask': grown_mask,
                        # Finished beams are never expanded again; drop their cache.
                        'past': None if finished else past,
                        'score': beam['score'] + topk_logprobs[j].item(),
                        'done': finished
                    })
            # select top beams
            beams = sorted(all_candidates, key=lambda x: x['score'], reverse=True)[:beam_size]
            if all(b['done'] for b in beams):
                break

    # pick best beam
    best_beam = max(beams, key=lambda x: x['score'])
    return best_beam['ids'][:, seq_len:]

In [40]:
def greedy_search(model, input_ids, attention_mask, args):
    """
    Greedy decoding: pick the highest-probability token at each step.

    Args:
        model: callable causal LM. It is called with `input_ids`,
            `attention_mask`, `use_cache` (and `past_key_values` after the
            first step) and must return an object exposing `.logits` and
            `.past_key_values`. `model.config.eos_token_id` is read.
        input_ids: prompt token ids. `.item()` is called on the chosen token,
            so the batch size must be 1.
        attention_mask: mask matching `input_ids`; a 1 is appended per new token.
        args: options providing `max_new_tokens`, either as a dict (what the
            notebook passes) or as an attribute-style object.

    Returns:
        Tensor with only the generated ids (prompt stripped); generation stops
        after `max_new_tokens` tokens or right after an EOS token, which is
        kept in the output.
    """
    # The notebook's `args` is a dict; attribute access on it raised AttributeError.
    if isinstance(args, dict):
        max_new_tokens = args['max_new_tokens']
    else:
        max_new_tokens = args.max_new_tokens

    eos_id = model.config.eos_token_id
    # Some configs store several EOS ids; normalise to a set for the membership test.
    eos_ids = set(eos_id) if isinstance(eos_id, (list, tuple)) else {eos_id}
    device = input_ids.device
    generated = input_ids
    mask = attention_mask
    past = None

    with torch.no_grad():  # decoding only: do not build an autograd graph
        for _ in range(max_new_tokens):
            if past is None:
                # First step: run the whole prompt and start the KV cache.
                out = model(
                    input_ids=generated,
                    attention_mask=mask,
                    use_cache=True
                )
            else:
                # Later steps: feed only the newest token plus the cache.
                out = model(
                    input_ids=generated[:, -1:],
                    attention_mask=mask,
                    past_key_values=past,
                    use_cache=True
                )
            logits = out.logits[:, -1, :]
            past = out.past_key_values
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
            mask = torch.cat([
                mask,
                torch.ones((generated.size(0), 1), dtype=mask.dtype, device=device)
            ], dim=-1)
            if next_token.item() in eos_ids:
                break
    return generated[:, input_ids.size(1):]

In [36]:
def best_of_n(model, input_ids, attention_mask, args):
    """
    Run Best-of-N sampling: draw args['best_of'] samples via multinomial
    sampling and return the one with the highest sum of log-probabilities.

    Sampling uses the temperature-scaled (and optionally top-k truncated)
    distribution. The score of a sample is the sum of the model's unscaled
    log-probabilities of the sampled tokens; it is accumulated from the logits
    already computed while sampling, so no separate scoring pass is needed.

    Args:
        model: callable causal LM. It is called with `input_ids`,
            `attention_mask`, `use_cache` (and `past_key_values` after the
            first step) and must return an object exposing `.logits` and
            `.past_key_values`. `model.config.eos_token_id` is read.
        input_ids: prompt token ids. `.item()` is called on the sampled token,
            so the batch size must be 1.
        attention_mask: mask matching `input_ids`; a 1 is appended per new token.
        args: dict with
            - 'best_of': number of samples to draw (>= 1),
            - 'max_new_tokens': token budget per sample,
            - 'temperature' (optional, default 1.0): must be > 0,
            - 'k' (optional): True -> top-k sampling with k=5; a positive int ->
              top-k sampling with that k; False/0/missing -> full distribution.

    Returns:
        Tensor with only the generated ids of the best-scoring sample
        (prompt stripped, EOS included if it was sampled).

    Raises:
        ValueError: if 'best_of' < 1 or 'temperature' <= 0.
    """
    num_candidates = args['best_of']
    if num_candidates < 1:
        raise ValueError(f"best_of must be >= 1, got {num_candidates}")
    temperature = args.get('temperature', 1.0)
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # `k` was a boolean switch with a hard-coded size of 5; keep that meaning
    # for True and additionally accept an explicit integer size.
    k_opt = args.get('k', False)
    if isinstance(k_opt, bool):
        top_k = 5 if k_opt else 0
    else:
        top_k = int(k_opt or 0)

    eos_id = model.config.eos_token_id
    # Some configs store several EOS ids; normalise to a set for the membership test.
    eos_ids = set(eos_id) if isinstance(eos_id, (list, tuple)) else {eos_id}
    device = input_ids.device
    seq_len = input_ids.size(1)
    max_new_tokens = args['max_new_tokens']

    best_seq = None
    best_score = float('-inf')
    with torch.no_grad():  # sampling only: do not build an autograd graph
        for _ in range(num_candidates):
            generated = input_ids
            mask = attention_mask
            past = None
            score = 0.0
            for _step in range(max_new_tokens):
                if past is None:
                    out = model(
                        input_ids=generated,
                        attention_mask=mask,
                        use_cache=True
                    )
                else:
                    out = model(
                        input_ids=generated[:, -1:],
                        attention_mask=mask,
                        past_key_values=past,
                        use_cache=True
                    )
                logits = out.logits[:, -1, :]
                past = out.past_key_values
                # Score with the model's own (temperature-free) distribution.
                raw_log_probs = torch.log_softmax(logits, dim=-1)
                # apply temperature
                scaled = logits / temperature
                if top_k > 0:
                    # Clamp so a small vocabulary cannot make topk fail.
                    k = min(top_k, scaled.size(-1))
                    topk_vals, topk_idx = torch.topk(scaled, k, dim=-1)
                    probs = torch.zeros_like(scaled).scatter_(
                        1, topk_idx, torch.softmax(topk_vals, dim=-1)
                    )
                else:
                    probs = torch.softmax(scaled, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                score += raw_log_probs[0, next_token.item()].item()
                generated = torch.cat([generated, next_token], dim=-1)
                mask = torch.cat([
                    mask,
                    torch.ones((generated.size(0), 1), dtype=mask.dtype, device=device)
                ], dim=-1)
                if next_token.item() in eos_ids:
                    break
            # `best_seq is None` guarantees a result even if a score is NaN.
            if best_seq is None or score > best_score:
                best_score = score
                best_seq = generated[:, seq_len:]
    return best_seq

In [41]:
def generate(model, input_ids, attention_mask, args):
    """
    Dispatch to the decoding routine selected by args['strategy'].

    Args:
        model, input_ids, attention_mask: forwarded unchanged to the routine.
        args: dict whose 'strategy' entry is 'best_of_n', 'greedy' or
            'beam_search'; the whole dict is forwarded as well.

    Returns:
        Whatever the selected decoding routine returns.

    Raises:
        ValueError: if the strategy name is not one of the three above
            (previously this case silently returned None).
    """
    strategy = args['strategy']
    if strategy == 'best_of_n':
        return best_of_n(model, input_ids, attention_mask, args)
    if strategy == 'greedy':
        return greedy_search(model, input_ids, attention_mask, args)
    if strategy == 'beam_search':
        return beam_search(model, input_ids, attention_mask, args)
    raise ValueError(
        f"Unknown decoding strategy {strategy!r}; "
        "expected 'best_of_n', 'greedy' or 'beam_search'"
    )


In [45]:
# Build one prompt per example, decode with the custom generate(), and record
# the prompt, the decoded prediction and the gold label for later scoring.
prompts, preds, trues = [], [], []
for ex in tqdm(ds, desc="Generation"):
    question, choices, label = ex['question'], ex.get('choices', []), ex['answer']
    print(question)
    print(choices)
    print(label)

    # Label the options (A), (B), ... in dataset order.
    choices_str = ' '.join(f"({chr(65+i)}) {c}" for i,c in enumerate(choices))
    print(choices_str)

    prompt = (
            f"You are an assistant. Read the question and answer with a single letter."
            f"\nQuestion: {question}\nChoices: {choices_str}\nAnswer:"
    )

    print("prompt:", f"-length {len(prompt)}-", prompt)
    inputs = tokenizer(prompt, return_tensors='pt')
    print('input_ids:', f"length{inputs['input_ids'].shape}", inputs['input_ids'])
    print('attention_mask:', f"length{inputs['attention_mask'].shape}", inputs['attention_mask'])
    print("inputs items", inputs.items())
    # Follow the model's device instead of hard-coding 'cpu'.
    inputs = {k: v.to(model.device) for k,v in inputs.items()}
    print("inputs mapped", inputs)

    # CUSTOM GENERATE FUNCTION
    # Inference only: make sure no autograd graph is kept while decoding.
    with torch.no_grad():
        response = generate(model, inputs['input_ids'], inputs['attention_mask'], args)
    print(response)
    text = tokenizer.decode(response[0], skip_special_tokens=True).strip()

    print(text)

    # These lists were created but never filled; store the results per example.
    prompts.append(prompt)
    preds.append(text)
    trues.append(label)

Generation:   0%|          | 0/1 [00:00<?, ?it/s]

The cyclic subgroup of Z_24 generated by 18 has order
['4', '8', '12', '6']
0
(A) 4 (B) 8 (C) 12 (D) 6
prompt: -length 178- You are an assistant. Read the question and answer with a single letter.
Question: The cyclic subgroup of Z_24 generated by 18 has order
Choices: (A) 4 (B) 8 (C) 12 (D) 6
Answer:
input_ids: lengthtorch.Size([1, 60]) tensor([[151646,   2610,    525,    458,  17847,     13,   4457,    279,   3405,
            323,   4226,    448,    264,   3175,   6524,    624,  14582,     25,
            576,  76002,  80115,    315,   1863,     62,     17,     19,   7907,
            553,    220,     16,     23,    702,   1973,    198,  89283,     25,
            320,     32,      8,    220,     19,    320,     33,      8,    220,
             23,    320,     34,      8,    220,     16,     17,    320,     35,
              8,    220,     21,    198,  16141,     25]])
attention_mask: lengthtorch.Size([1, 60]) tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [51]:
# Last prompt token id, kept as a one-column slice of the 2-D id tensor.
inputs['input_ids'][..., -1:]

tensor([[25]])

In [46]:
# Show the decoded string left in `text` by the last iteration of the generation loop.
text

'200, 200,'