In [6]:
from QwenMCQA import *  # wildcard import; expected to provide legacy_prompt and cot_prompt
# Explicit imports come after the wildcard so they take precedence over anything it brings in.
from functools import partial

from datasets import load_dataset
from transformers import AutoTokenizer

def compute_truncation_and_overflow(dataset, tokenizer, max_length, prompt_function):
    """Print how much prompt text would fall beyond ``max_length`` tokens.

    Two percentages are printed:

    * the share of examples whose tokenized prompt is longer than ``max_length``;
    * the share of all prompt tokens that sit past position ``max_length``.

    Args:
        dataset: Iterable of examples; each one is passed to ``prompt_function``.
            It does not need to support ``len()``.
        tokenizer: Callable invoked as ``tokenizer(prompt, add_special_tokens=False)``
            whose result exposes the token ids under ``["input_ids"]``.
        max_length: Token budget per prompt.
        prompt_function: Callable mapping an example to an indexable result whose
            first element is the prompt (e.g. ``legacy_prompt`` or ``cot_prompt``).

    Returns:
        None. The statistics are only printed.
    """
    total_examples = 0
    truncated_examples = 0
    total_tokens = 0
    overflow_tokens = 0

    for example in dataset:
        prompt = prompt_function(example)[0]
        # Special tokens are excluded, so only the prompt text itself is measured.
        num_tokens = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])

        # Counting here (instead of len(dataset)) also supports iterables without __len__.
        total_examples += 1
        total_tokens += num_tokens

        excess = num_tokens - max_length
        if excess > 0:
            truncated_examples += 1
            overflow_tokens += excess

    # Guard both ratios: an empty dataset reports 0% instead of raising ZeroDivisionError.
    example_ratio = truncated_examples / total_examples if total_examples > 0 else 0.0
    token_ratio = overflow_tokens / total_tokens if total_tokens > 0 else 0.0

    # example-level stats
    print(f"Examples truncated: {example_ratio*100:5.2f}%")
    # token-level stats
    print(f"Tokens beyond max_length: {token_ratio*100:5.2f}%")

# usage: checkpoint whose tokenizer is measured, and the token budget passed as max_length
BASE_MODEL = "Qwen/Qwen3-0.6B-Base"
MAX_LEN = 512

# Load the fast tokenizer for BASE_MODEL and the "train" split of the MCQA dataset.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
ds = load_dataset("NicoHelemon/MNLP_M3_mcqa_dataset", split="train")

In [7]:
# Truncation statistics for prompts built by legacy_prompt at the MAX_LEN budget.
compute_truncation_and_overflow(
    ds,
    tokenizer,
    max_length=MAX_LEN,
    prompt_function=legacy_prompt,
)

Examples truncated:  4.76%
Tokens beyond max_length:  2.40%


In [9]:
# Same measurement for prompts built by cot_prompt, with cot_prob fixed to 1.
cot_prompt_1 = partial(cot_prompt, cot_prob=1)

compute_truncation_and_overflow(
    ds,
    tokenizer,
    max_length=MAX_LEN,
    prompt_function=cot_prompt_1,
)

Examples truncated:  9.79%
Tokens beyond max_length:  3.17%
