In [5]:
import pandas as pd

# The six training splits differ only in the model directory and the difficulty
# name embedded in the file name, so each path is built from those two parts
# instead of repeating the same read_csv call six times.
SPLIT_ROOT = "../../../data/data_splits/entropy_fallback"


def load_train_split(model_dir: str, difficulty: str) -> pd.DataFrame:
    """Read one tab-separated training split.

    Parameters
    ----------
    model_dir : str
        Sub-directory of ``SPLIT_ROOT`` holding the splits ("qwen" or "phi" here).
    difficulty : str
        Difficulty name used in the file name ("easy", "middle" or "hard" here).

    Returns
    -------
    pandas.DataFrame
        Contents of ``<SPLIT_ROOT>/<model_dir>/train_df_<difficulty>.tsv``,
        with the first line of the file used as the header.
    """
    return pd.read_csv(
        f"{SPLIT_ROOT}/{model_dir}/train_df_{difficulty}.tsv",
        sep="\t",
        header=0,
    )


mmlu_qwen_train_df_easy = load_train_split("qwen", "easy")
mmlu_qwen_train_df_mid = load_train_split("qwen", "middle")
mmlu_qwen_train_df_hard = load_train_split("qwen", "hard")

mmlu_phi4_train_df_easy = load_train_split("phi", "easy")
mmlu_phi4_train_df_mid = load_train_split("phi", "middle")
mmlu_phi4_train_df_hard = load_train_split("phi", "hard")

In [6]:
# Load the distillation output and keep only the two columns used downstream:
# the response text that gets tokenized and the question_id used as the merge key.
distill_df = pd.read_csv(
    "../../../data/out/distillation/mmlu_deepseek_v3.tsv",
    sep="\t",
    header=0,
)[["distill_response", "question_id"]]

In [7]:
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer


def count_response_tokens(df, tokenizer: AutoTokenizer, response_column: str):
    """Return the total number of tokens over all string responses in one column.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame that holds the responses.
    tokenizer
        Object whose ``encode(text)`` returns a sized sequence of token ids.
    response_column : str
        Name of the column in ``df`` whose values are tokenized.

    Returns
    -------
    int
        Sum of ``len(tokenizer.encode(response))`` over every string cell of
        the column. Non-string cells (for example NaN where a response is
        missing) are skipped and contribute nothing.
    """
    responses = df[response_column]
    token_cnt = 0

    # Walk the column directly: iterrows() would materialise a full Series per
    # row only to read a single cell from it.
    for response in tqdm(responses, total=len(responses)):
        if not isinstance(response, str):
            continue

        token_cnt += len(tokenizer.encode(response))

    return token_cnt


def count_response_tokens_by_split(easy_df, mid_df, hard_df, distill_df, tokenizer):
    """Count response tokens per difficulty split for four training mixes.

    For each split two token totals are computed:

    * distill: tokens of ``distill_response`` after inner-joining the split
      with ``distill_df`` on ``question_id``;
    * sft: tokens of ``answer_index + 1`` rendered as a string
      (presumably a 1-based answer label -- confirm against the training code).

    The input frames are never modified, so the function can be called
    repeatedly on the same frames.

    Parameters
    ----------
    easy_df, mid_df, hard_df : pandas.DataFrame
        Splits containing at least ``question_id`` and a numeric ``answer_index``.
    distill_df : pandas.DataFrame
        Frame containing ``question_id`` and ``distill_response``.
    tokenizer
        Tokenizer forwarded to ``count_response_tokens``.

    Returns
    -------
    tuple of int
        ``(total_sft, total_distill, alternative, pipeline)`` where
        ``total_sft``     = sft tokens of easy + mid + hard,
        ``total_distill`` = distill tokens of easy + mid + hard,
        ``alternative``   = distill easy + distill mid + sft hard,
        ``pipeline``      = sft easy + sft mid + distill hard.
    """
    distill_response_column = "distill_response"
    sft_response_column = "answer_index"

    splits = {"easy": easy_df, "mid": mid_df, "hard": hard_df}

    # All distill counts are computed before any sft count, matching the
    # original order of the progress bars (easy, mid, hard for each kind).
    distill_cnt = {
        name: count_response_tokens(
            pd.merge(split_df, distill_df, on="question_id"),
            tokenizer,
            distill_response_column,
        )
        for name, split_df in splits.items()
    }

    # Build the label strings in a derived one-column frame instead of
    # overwriting the caller's column: the in-place version turned
    # answer_index into strings, so a second call failed on "str + 1".
    sft_cnt = {
        name: count_response_tokens(
            (split_df[sft_response_column] + 1).astype(str).to_frame(),
            tokenizer,
            sft_response_column,
        )
        for name, split_df in splits.items()
    }

    total_distill_token_count = distill_cnt["easy"] + distill_cnt["mid"] + distill_cnt["hard"]
    total_sft_token_count = sft_cnt["easy"] + sft_cnt["mid"] + sft_cnt["hard"]
    pipeline_token_count = sft_cnt["easy"] + sft_cnt["mid"] + distill_cnt["hard"]
    alternative_token_count = distill_cnt["easy"] + distill_cnt["mid"] + sft_cnt["hard"]

    return total_sft_token_count, total_distill_token_count, alternative_token_count, pipeline_token_count


In [8]:
from transformers import AutoTokenizer

# Each result tuple is ordered (sft, distill, alternative, pipeline).
# "advantage" is the fraction of tokens saved by the pipeline mix relative to
# distilling every split: 1 - pipeline / distill.

# Qwen splits, counted with the Qwen2.5-3B-Instruct tokenizer.
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
qwen_token_count = count_response_tokens_by_split(
    easy_df=mmlu_qwen_train_df_easy,
    mid_df=mmlu_qwen_train_df_mid,
    hard_df=mmlu_qwen_train_df_hard,
    distill_df=distill_df,
    tokenizer=qwen_tokenizer,
)

# Phi splits, counted with the Phi-4-mini-instruct tokenizer.
phi4_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
phi4_token_count = count_response_tokens_by_split(
    easy_df=mmlu_phi4_train_df_easy,
    mid_df=mmlu_phi4_train_df_mid,
    hard_df=mmlu_phi4_train_df_hard,
    distill_df=distill_df,
    tokenizer=phi4_tokenizer,
)

print(
    f"Qwen 3B token count: sft = {qwen_token_count[0]}, distill = {qwen_token_count[1]}, alternative = {qwen_token_count[2]}, pipeline = {qwen_token_count[3]}, advantage = {1 - qwen_token_count[3] / qwen_token_count[1]}"
)
print(
    f"Phi4-mini token count: sft = {phi4_token_count[0]}, distill = {phi4_token_count[1]}, alternative = {phi4_token_count[2]}, pipeline = {phi4_token_count[3]}, advantage = {1 - phi4_token_count[3] / phi4_token_count[1]}"
)

100%|██████████| 900/900 [00:00<00:00, 1094.47it/s]
100%|██████████| 900/900 [00:01<00:00, 888.71it/s]
100%|██████████| 900/900 [00:01<00:00, 796.33it/s]
100%|██████████| 900/900 [00:00<00:00, 40892.56it/s]
100%|██████████| 900/900 [00:00<00:00, 39578.03it/s]
100%|██████████| 900/900 [00:00<00:00, 41661.13it/s]
100%|██████████| 900/900 [00:00<00:00, 1215.56it/s]
100%|██████████| 900/900 [00:00<00:00, 1069.60it/s]
100%|██████████| 900/900 [00:00<00:00, 1091.96it/s]
100%|██████████| 900/900 [00:00<00:00, 34713.08it/s]
100%|██████████| 900/900 [00:00<00:00, 38351.62it/s]
100%|██████████| 900/900 [00:00<00:00, 39243.11it/s]


Qwen 3B token count: sft = 2948, distill = 1956671, alternative = 1158232, pipeline = 801387, advantage = 0.5904334453773782
Phi4-mini token count: sft = 2700, distill = 1481696, alternative = 950794, pipeline = 533602, advantage = 0.6398707967086366
