# Setup

In [None]:
# standard library imports
import os
import subprocess
from typing import Optional

# related third party imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from numpy.typing import ArrayLike, NDArray
from transformers import AutoTokenizer

# local application/library specific imports
from tools.utils import activate_latex, deactivate_latex, ensure_dir
from data_loader.data_loader import QDET

In [2]:
###### INPUTS ######
# When True, the `if PRINT_PAPER:` section re-renders the histograms between
# activate_latex()/deactivate_latex() and saves them as PDF files.
PRINT_PAPER = True
# Forwarded as `sans_serif=` to activate_latex() for the paper figures.
SANS_SERIF = True

# Functions

In [4]:
def compute_num_tokens(
    input_sentences: ArrayLike, tokenizer: "AutoTokenizer"
) -> np.ndarray:
    """Compute the number of tokens in each sentence.

    Parameters
    ----------
    input_sentences : ArrayLike
        Sentences to tokenize. A single string is treated as one sentence;
        arrays, Series and other iterables are converted to a plain list
        before being handed to the tokenizer.
    tokenizer : AutoTokenizer
        Tokenizer object. It is called with the list of sentences and must
        return a mapping whose ``"input_ids"`` entry holds one sequence of
        token ids per sentence.

    Returns
    -------
    np.ndarray
        Integer array with the number of tokens in each sentence. Empty
        input yields an empty integer array without calling the tokenizer.
    """
    # Tokenizers generally validate that batched input is a list/tuple of
    # strings, so array-like containers (NumPy arrays, pandas Series, ...)
    # are normalised to a plain Python list first.
    if hasattr(input_sentences, "tolist"):
        input_sentences = input_sentences.tolist()
    if isinstance(input_sentences, str):
        # A bare string is one sentence, not an iterable of characters.
        sentences = [input_sentences]
    else:
        sentences = list(input_sentences)

    if not sentences:
        # Keep an integer dtype even when there is nothing to count.
        return np.array([], dtype=int)

    input_ids = tokenizer(sentences)["input_ids"]
    return np.array([len(ids) for ids in input_ids], dtype=int)


In [7]:
def print_token_info(num_tokens: ArrayLike) -> None:
    """Print summary statistics of per-sequence token counts.

    The mean and median are printed with two decimals, followed by the
    maximum and minimum, one statistic per line.

    Parameters
    ----------
    num_tokens : ArrayLike
        Array of number of tokens per sequence.
    """
    # Each statistic is computed right before it is printed, so the output
    # order (mean, median, max, min) is fixed.
    average = np.mean(num_tokens)
    print(f"Mean: {average:.2f}")

    middle = np.median(num_tokens)
    print(f"Median: {middle:.2f}")

    largest = np.max(num_tokens)
    print(f"Max: {largest}")

    smallest = np.min(num_tokens)
    print(f"Min: {smallest}")

In [48]:
def plot_token_hist(
    num_tokens_dict: dict[str, NDArray],
    bins: int = 20,
    vline: Optional[int] = None,
    savename: Optional[str] = None,
) -> plt.Axes:
    """Plot overlaid histograms of the number of tokens in sentences.

    Parameters
    ----------
    num_tokens_dict : dict[str, NDArray]
        Maps a legend label to the array of token counts to histogram.
    bins : int, optional
        Histogram bins, by default 20
    vline : Optional[int], optional
        Position of a dashed vertical line marking a token limit; its legend
        entry reads "<vline> tokens". No line is drawn when None (default).
    savename : Optional[str], optional
        Path to save the figure to, by default None (figure is not saved).

    Returns
    -------
    plt.Axes
        Axes object containing the histograms.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 3))
    for name, num_tokens in num_tokens_dict.items():
        ax.hist(num_tokens, bins=bins, alpha=0.5, label=name)
    # Compare against None so that a limit of 0 is still drawn, and label the
    # line with the actual limit instead of a fixed "512 tokens".
    if vline is not None:
        ax.axvline(vline, color="red", linestyle="--", label=f"{vline} tokens")
    ax.set_xlabel("Number of tokens")
    ax.set_ylabel("Frequency")
    ax.legend()
    ax.grid(True, linestyle="--")
    # Turn TeX rendering off for the tick labels so they follow the
    # sans-serif font when one is in use.
    # NOTE(review): `_usetex` is a private formatter attribute and may change
    # between matplotlib versions.
    ax.xaxis.get_major_formatter()._usetex = False
    ax.yaxis.get_major_formatter()._usetex = False
    if savename is not None:
        fig.tight_layout()
        out_dir = os.path.dirname(savename)
        # A bare file name has no directory component to create.
        if out_dir:
            ensure_dir(out_dir)
        fig.savefig(savename)
    plt.show()
    return ax

# Compute tokens

In [None]:
# Constructor settings that are identical for both datasets.
shared_loader_kwargs = {"output_type": "regression", "small_dev": None, "seed": 42}

# ARC: 7 classes, balanced.
loader = QDET(name="arc", num_classes=7, balanced=True, **shared_loader_kwargs)
arc_dataset = loader.load_all()

# RACE++: 3 classes, not balanced.
loader = QDET(name="race_pp", num_classes=3, balanced=False, **shared_loader_kwargs)
racepp_dataset = loader.load_all()

In [None]:
# Load each tokenizer once and reuse it for both datasets instead of calling
# `from_pretrained` again for every dataset/tokenizer combination.
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
modernbert_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# ARC - BERT
num_tokens_bert_arc = compute_num_tokens(
    input_sentences=arc_dataset["train"]["text"],
    tokenizer=bert_tokenizer,
)

# RACE++ - BERT
num_tokens_bert_racepp = compute_num_tokens(
    input_sentences=racepp_dataset["train"]["text"],
    tokenizer=bert_tokenizer,
)

# ARC - ModernBERT
num_tokens_modernbert_arc = compute_num_tokens(
    input_sentences=arc_dataset["train"]["text"],
    tokenizer=modernbert_tokenizer,
)

# RACE++ - ModernBERT
num_tokens_modernbert_racepp = compute_num_tokens(
    input_sentences=racepp_dataset["train"]["text"],
    tokenizer=modernbert_tokenizer,
)

In [None]:
# Token-count summary: ARC training text, BERT tokenizer.
print_token_info(num_tokens=num_tokens_bert_arc)

In [None]:
# Token-count summary: RACE++ training text, BERT tokenizer.
print_token_info(num_tokens=num_tokens_bert_racepp)

In [None]:
# Token-count summary: ARC training text, ModernBERT tokenizer.
print_token_info(num_tokens=num_tokens_modernbert_arc)

In [None]:
# Token-count summary: RACE++ training text, ModernBERT tokenizer.
print_token_info(num_tokens=num_tokens_modernbert_racepp)

In [None]:
# ARC: overlay the BERT and ModernBERT token-count histograms, with a
# vertical marker at 512 tokens.
plot_token_hist(
    num_tokens_dict=dict(
        BERT=num_tokens_bert_arc,
        ModernBERT=num_tokens_modernbert_arc,
    ),
    vline=512,
)

In [None]:
# RACE++: overlay the BERT and ModernBERT token-count histograms, with a
# vertical marker at 512 tokens.
plot_token_hist(
    num_tokens_dict=dict(
        BERT=num_tokens_bert_racepp,
        ModernBERT=num_tokens_modernbert_racepp,
    ),
    vline=512,
)

In [None]:
if PRINT_PAPER:
    # Paper figures: output file name -> token counts per tokenizer.
    paper_figures = {
        "arc_token_hist.pdf": {
            "BERT": num_tokens_bert_arc,
            "ModernBERT": num_tokens_modernbert_arc,
        },
        "racepp_token_hist.pdf": {
            "BERT": num_tokens_bert_racepp,
            "ModernBERT": num_tokens_modernbert_racepp,
        },
    }
    activate_latex(sans_serif=SANS_SERIF)
    try:
        for filename, counts in paper_figures.items():
            plot_token_hist(
                num_tokens_dict=counts,
                vline=512,
                savename=os.path.join("output", "figures", filename),
            )
    finally:
        # Always undo activate_latex(), even if plotting or saving fails, so
        # the setting does not leak into later cells.
        deactivate_latex()