In [1]:
import re
import datasets
import transformers
import torch
import numpy as np
import rich
import rich.panel
import rich.markup
import rich.table
import rich.rule
import matplotlib.pyplot as plt
import more_itertools as mit


In [2]:
# Extracted from open-instruct, slightly modified
def verify_gsm8k_sample(
        model_output, 
        ground_truth_answer, 
        logits, 
        tokenizer,
    ):
    
    # model_output = model_output.split("<|assistant|>\n")[-1].strip()

    # gsm is easy: extract numbers, and then just compare last number with answer.
    # matches how we do eval.
    predictions = None

    # replace numbers like `x,xxx` with `xxxx`
    response = re.sub(r"(\d),(\d)", r"\1\2", model_output)
    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", response)

    # Find the token indices of the numbers that we extracted
    if numbers:
        predictions = numbers[-1]
    else:
        predictions = response

    # Put everything in a rich table and print it
    table = rich.table.Table(
        show_header=False, 
        show_lines=True, 
        box=rich.box.ROUNDED, 
        border_style="blue",
    )

    # Remove the text "<|im_end|>" from the end of the text
    while model_output.endswith(tokenizer.pad_token):
        model_output = model_output[:-len(tokenizer.pad_token)]

    model_output = model_output.replace(
        tokenizer.bos_token, 
        f"\n{tokenizer.bos_token}"
    ).strip()

    table.add_row("[bold blue]Model Output:", model_output)
    table.add_row("[bold blue]Parsed predictions:", predictions)
    rich.print(table)

    return str(predictions).lower() == str(ground_truth_answer).lower()

def box(title, text, title_style="[bold blue]"):
    """ Just a simple function to print a box with a title and text """
    rich.print(rich.panel.Panel.fit(
        rich.markup.escape(text), 
        title=f"{title_style}" + rich.markup.escape(title), 
        border_style="blue",
        title_align="left",
    ))
    

In [None]:
# model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
model_name = "Qwen/Qwen2.5-7B-Instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, padding_side="left")

config.json:   0%|          | 0.00/679 [00:00<?, ?B/s]

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/3.06k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

In [4]:
gsm8k = datasets.load_dataset("gsm8k", "main")["train"]

In [5]:
###########################################################################
# Prepare example questions and generate associated responses.
###########################################################################
batch_size = 30

inputs = [
    [
        {"role": "user", "content": gsm8k[i]["question"]},
    ] for i in range(batch_size)
]

templated_text = tokenizer.apply_chat_template(
    inputs, 
    tokenize=False,
    add_generation_prompt=True,
)

templated = tokenizer(
    templated_text,
    return_tensors="pt", 
    padding=True,
    return_offsets_mapping=True,
)

output = model.generate(
    input_ids=templated.input_ids.to(0), 
    attention_mask=templated.attention_mask.to(0), 
    max_new_tokens=512,
    do_sample=False,
    output_scores=True,
    return_dict_in_generate=True,
)


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


In [6]:
outputs = tokenizer.batch_decode(output.sequences) 

re_tokenizized = tokenizer(
    outputs, 
    padding=True, 
    padding_side="right", 
    return_tensors="pt",
    return_offsets_mapping=True,
)

# for i, output_i in enumerate(outputs):
#     box(f"Model Output {i}", output_i.replace(tokenizer.pad_token, ""))

assert (re_tokenizized["input_ids"] == output.sequences.cpu()).all()


RuntimeError: The size of tensor a (627) must match the size of tensor b (626) at non-singleton dimension 1

In [7]:
#################################################################################
# Find the answer in the offset mapping
#################################################################################
def extract_start_end(*, generated_text, sequence_ids,):

    last_match = mit.last(re.finditer(r"\d+", generated_text))
    assert isinstance(generated_text, str), type(generated_text).mro()

    re_tokenizized = tokenizer(
        generated_text, 
        padding=True, 
        padding_side="right", 
        return_tensors="pt",
        return_offsets_mapping=True,
    )

    offset_mapping = mit.one(re_tokenizized["offset_mapping"])

    start_token_pos = None
    for i, (start, end) in enumerate(offset_mapping):
        if start == last_match.start():
            start_token_pos = i
            break

    assert start_token_pos is not None

    # Find end of answer.
    end_token_pos = None
    for i, (start, end) in enumerate(
        offset_mapping[start_token_pos:], 
        start=start_token_pos,
    ):
        if end == last_match.end():
            end_token_pos = i
            break

    assert end_token_pos is not None, (
        start_token_pos, 
        end_token_pos, 
        len(re_tokenizized.input_ids[0])
    )
    assert end_token_pos >= start_token_pos, (end_token_pos, start_token_pos)
    
    re_decoded_answer = tokenizer.decode(
        sequence_ids[start_token_pos:end_token_pos + 1], 
        skip_special_tokens=True,
    ).strip()

    return start_token_pos, end_token_pos
#################################################################################

good_bads = []
softmax_prods_good = []
softmax_prods_bad = []
assert len(outputs) == batch_size, (len(outputs), batch_size)
sequence_length = output.sequences.shape[-1] - templated.input_ids.shape[-1]
vocab_size = len(tokenizer.vocab)

for i in range(batch_size):
    reference_answer = gsm8k[i]["answer"].split("####")[-1].strip()

    # table = rich.table.Table(show_header=False, show_lines=True, box=rich.box.ROUNDED, border_style="blue")
    # table.add_row("reference answer", reference_answer)
    # table.add_row("generated text", outputs[i])
    # rich.print(table)

    start_token_pos, end_token_pos = extract_start_end(
        generated_text=outputs[i], 
        sequence_ids=output.sequences[i],
    )

    input_length = templated.input_ids.shape[-1]
    answer_tokens = output.sequences[i][start_token_pos:end_token_pos + 1].tolist()

    # print("output.sequences.shape", output.sequences.shape)
    # print("output.scores"         , len(output.scores))
    # print("output.scores shapes"  , [x.shape for x in output.scores], 0)

    output_only_seq_len = len(output.scores)

    confidence_scores = []
    assert len(output.scores) == sequence_length, (len(output.scores), sequence_length)

    for entry in output.scores[start_token_pos - input_length:end_token_pos + 1 - input_length]:
        # print(entry.shape)
        assert len(entry) == batch_size, (len(entry), batch_size)
        entry = entry[i]
        # assert len(entry) == len(tokenizer.vocab), (len(entry), len(tokenizer.vocab))
        confidence_scores.append(entry)

    confidence_scores = torch.stack(confidence_scores, 0)
    
    assert len(answer_tokens) == len(confidence_scores), (len(answer_tokens), len(confidence_scores))
    # assert confidence_scores.shape == (len(answer_tokens), len(tokenizer.vocab)), ((len(answer_tokens), len(tokenizer.vocab)), confidence_scores.shape, [x.shape for x in output.scores])
    
    softmaxed = torch.tensor([
        torch.softmax(confidence_scores, dim=-1)[in_answer_seq_pos, token_id].item() 
        for in_answer_seq_pos, token_id in enumerate(answer_tokens)
    ], dtype=torch.float32)

    is_wrong = reference_answer.strip() != tokenizer.decode(answer_tokens).strip()
    table = rich.table.Table(
        show_header=False, 
        show_lines=True, 
        box=rich.box.ROUNDED, 
        border_style="blue",
    )
    table.add_row("is_wrong"                , "[bold]" + ("[red]" if is_wrong else "[green]") + str(is_wrong))
    table.add_row("reference answer"        , reference_answer)
    table.add_row("generated text"          , outputs[i])
    table.add_row("answer_tokens"           , str(answer_tokens))
    table.add_row("confidence_scores shape" , str(confidence_scores.shape))
    table.add_row("answer_tokens and scores", str([(token_id, torch.softmax(confidence_scores, dim=-1)[in_answer_seq_pos, token_id].item()) for in_answer_seq_pos, token_id in enumerate(answer_tokens)]))
    table.add_row("Decoded answer_tokens"   , str([tokenizer.decode(x) for x in answer_tokens]))
    table.add_row("softmax_prod"              , str(softmaxed.prod().item()))
    table.add_row("softmax_average"           , str(softmaxed.mean().item()))
    table.add_row("softmax_first"             , str(softmaxed[0:1].item()))
    table.add_row("geometric average softmax"              , str(softmaxed.prod().pow(1/len(softmaxed)).item()))
    rich.print(table)

    good_bads.append(is_wrong)
    if is_wrong:
        softmax_prods_bad .append(softmaxed.mean().item())
    else:
        softmax_prods_good.append(softmaxed.mean().item())

    print("Good", np.mean(softmax_prods_good))
    print("Bad", np.mean(softmax_prods_bad))
    print("Accuracy", 1 - np.mean(good_bads))


Good nan
Bad 0.9996699094772339
Accuracy 0.0


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Good nan
Bad 0.9990869164466858
Accuracy 0.0


Good 0.9997540712356567
Bad 0.9990869164466858
Accuracy 0.33333333333333337


Good 0.9997540712356567
Bad 0.9977306524912516
Accuracy 0.25


Good 0.9997540712356567
Bad 0.9922066479921341
Accuracy 0.19999999999999996


Good 0.9997540712356567
Bad 0.9880954146385192
Accuracy 0.16666666666666663


Good 0.9997540712356567
Bad 0.9895069499810537
Accuracy 0.1428571428571429


Good 0.9997540712356567
Bad 0.990957430430821
Accuracy 0.125


Good 0.9997540712356567
Bad 0.9918598979711533
Accuracy 0.11111111111111116


Good 0.9997540712356567
Bad 0.9927095439698961
Accuracy 0.09999999999999998


Good 0.9997540712356567
Bad 0.9750606298446656
Accuracy 0.09090909090909094


Good 0.9997540712356567
Bad 0.9603350812738592
Accuracy 0.08333333333333337


Good 0.9997540712356567
Bad 0.963624894618988
Accuracy 0.07692307692307687


Good 0.9997540712356567
Bad 0.966179398389963
Accuracy 0.0714285714285714


Good 0.9997540712356567
Bad 0.9264072626829147
Accuracy 0.06666666666666665


Good 0.9997540712356567
Bad 0.9313133895397187
Accuracy 0.0625


Good 0.9997540712356567
Bad 0.9353943672031164
Accuracy 0.05882352941176472


Good 0.9997540712356567
Bad 0.9391585781293756
Accuracy 0.05555555555555558


Good 0.9997540712356567
Bad 0.9199444303909937
Accuracy 0.052631578947368474


Good 0.9997540712356567
Bad 0.9122940819514426
Accuracy 0.050000000000000044


Good 0.9997540712356567
Bad 0.9160747423768043
Accuracy 0.04761904761904767


Good 0.9997540712356567
Bad 0.9198682265622276
Accuracy 0.045454545454545414


Good 0.9997540712356567
Bad 0.899804103103551
Accuracy 0.04347826086956519


Good 0.9997540712356567
Bad 0.9039572723533796
Accuracy 0.04166666666666663


Good 0.9997540712356567
Bad 0.9079589731991291
Accuracy 0.040000000000000036


Good 0.9997540712356567
Bad 0.9116083490848541
Accuracy 0.038461538461538436


Good 0.9997540712356567
Bad 0.9150071499439386
Accuracy 0.03703703703703709


RuntimeError: stack expects a non-empty TensorList

In [38]:
confidence_scores.shape

torch.Size([2, 49152])

In [10]:
print(outputs[0])

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|im_end|>
<|im_start|>assistant
Natalia sold 48 clips in April.
In May, she sold half as many clips as in April, so she sold 48 / 2 = 24 clips.
Altogether, Natalia sold 48 + 24 = 72 clips in April and May.
#### 72
The answer is: 72<|im_end|>


In [11]:
text_list = tokenizer.batch_decode(output)

for i, output_text in enumerate(text_list):
    output_text = output_text[output_text.find("<|im_start|>"):]
    while output_text.endswith("<|im_end|>"):
        output_text = output_text[:-len("<|im_end|>")]
        
    messages = [x for x in output_text.split("<|im_start|>") if x.strip()]

    for message in messages:
        lines = message.split("\n")
        box(lines[0].capitalize(), "\n".join(lines[1:]))
    
    verify_gsm8k_sample(output_text, gsm8k[i]["answer"])
    if i < len(text_list) - 1:
        rich.print(rich.rule.Rule())


TypeError: argument 'ids': Can't extract `str` to `Vec`

In [10]:
batch_size = 30