In [None]:
# Enable the IPython autoreload extension. Mode 2 reloads all imported modules
# before each cell is executed, so edits to the project source are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
import pathlib
import pprint
from logging import Formatter
from logging.handlers import RotatingFileHandler

from hydra import compose, initialize_config_dir
from rich.console import Console
from rich.logging import RichHandler
from rich.table import Table
from rich.text import Text

from topollm.config_classes.constants import HYDRA_CONFIGS_BASE_PATH
from topollm.config_classes.main_config import MainConfig
from topollm.logging.initialize_configuration_and_log import initialize_configuration
from topollm.model_finetuning.load_tokenizer_from_finetuning_config import load_tokenizer_from_finetuning_config
from topollm.typing.enums import Verbosity

# Verbosity level passed to the topollm helpers called further down.
verbosity = Verbosity.NORMAL

# The file log carries timestamp, level, logger name and source location;
# the rich console handler only receives the bare message.
LOGFORMAT_FILE = "[%(asctime)s][%(levelname)8s][%(name)s] %(message)s (%(filename)s:%(lineno)s)"
LOGFORMAT_RICH = "%(message)s"

# Console handler: rich-formatted records written to stderr.
error_console = Console(stderr=True)
rich_handler = RichHandler(console=error_console)
rich_handler.setFormatter(Formatter(LOGFORMAT_RICH))

# File handler: rotate once the file reaches 10 MiB, keep at most 10 backups.
rotating_file_path = pathlib.Path("logs") / "load_tokenizer_and_model.log"
# Create the log directory first; the handler opens its file on construction.
rotating_file_path.parent.mkdir(parents=True, exist_ok=True)
rotating_file_handler = RotatingFileHandler(
    filename=rotating_file_path,
    maxBytes=10 * 1024 * 1024,
    backupCount=10,
)

# basicConfig assigns LOGFORMAT_FILE only to handlers that have no formatter
# yet, so the rich handler keeps LOGFORMAT_RICH and the file handler gets
# LOGFORMAT_FILE.
logging.basicConfig(
    handlers=[rich_handler, rotating_file_handler],
    format=LOGFORMAT_FILE,
    level=logging.INFO,
)

logger = logging.getLogger("load_tokenizer_and_model")

In [None]:
# Using hydra in Jupyter notebook:
# https://github.com/facebookresearch/hydra/blob/main/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb
from topollm.model_handling.tokenizer.tokenizer_modifier.factory import get_tokenizer_modifier

# Hydra is initialised from the project's config directory.
abs_config_dir = pathlib.Path(HYDRA_CONFIGS_BASE_PATH)

# Command-line style overrides applied on top of "main_config".
hydra_overrides = [
    "feature_flags.finetuning.use_wandb=false",
    "finetuning=finetuning_for_token_classification",
]

with initialize_config_dir(config_dir=str(abs_config_dir), version_base=None):
    config = compose(
        config_name="main_config",
        overrides=hydra_overrides,
        return_hydra_config=True,
    )
    # Log the composed configuration in pretty-printed form.
    config_as_text = pprint.pformat(config, indent=4)
    logger.info(config_as_text)

    # Turn the composed config into the project's MainConfig object while
    # Hydra is still initialised.
    main_config: MainConfig = initialize_configuration(config=config, logger=logger)

# # # #
# Build the tokenizer from the finetuning section of the configuration:
# load the base tokenizer, then pass it through the configured modifier.
finetuning_config = main_config.finetuning

base_tokenizer = load_tokenizer_from_finetuning_config(
    finetuning_config=finetuning_config,
    verbosity=verbosity,
    logger=logger,
)

tokenizer_modifier = get_tokenizer_modifier(
    tokenizer_modifier_config=finetuning_config.base_model.tokenizer_modifier,
    verbosity=verbosity,
    logger=logger,
)

tokenizer = tokenizer_modifier.modify_tokenizer(tokenizer=base_tokenizer)

In [None]:
# # # # # # # # # # # # # # # # # # # # # # # #
# START Example data
#
# Three parallel lists describing one example sequence of 100 positions.
# In each list the first 25 entries are written out explicitly and the
# remaining 75 entries all share a single value.

# Token ids: 25 explicit ids, then 75 entries of id 1.
example_token_ids: list[int] = [
    0, 8814, 7, 2487, 456, 16415, 45835, 3175, 27785, 3309, 359, 2393, 1173,
    32, 32659, 103, 8419, 897, 8419, 46909, 23184, 897, 28696, 155, 2,
] + [1] * 75

# Attention mask: 1 for the first 25 positions, 0 for the remaining 75.
example_attention_mask = [1] * 25 + [0] * 75

# Aligned labels: indices into the label list, with -100 marking positions
# that carry no label; the trailing 75 positions are all -100.
example_labels_aligned = [
    -100, 0, 0, 7, 0, 0, -100, -100, 0, 9, 0, 9, -100,
    0, 0, 0, 0, 0, 0, -100, 0, 0, 0, 0, -100,
] + [-100] * 75

# END Example data
# # # # # # # # # # # # # # # # # # # # # # # #

In [None]:
# Decode the token_ids: look up the token for each id, one id per call.
example_token_ids_decoded = list(map(tokenizer.convert_ids_to_tokens, example_token_ids))

logger.info("example_token_ids_decoded:\n%s", example_token_ids_decoded)

# Label strings; a label's position in this list is its integer label index.
label_list: list[str] = [
    "O",
    "B-corporation", "I-corporation",
    "B-creative-work", "I-creative-work",
    "B-group", "I-group",
    "B-location", "I-location",
    "B-person", "I-person",
    "B-product", "I-product",
]


def decode_label(
    label: int,
    label_list: list[str],
    *,
    padding_label: int = -100,
    padding_symbol: str = "X",
) -> str:
    """Decode label from label index to label string.

    Args:
        label: Integer label index, or ``padding_label`` for a position
            without a label.
        label_list: Label strings; the index in the list is the label index.
        padding_label: Index value that marks a position without a label.
        padding_symbol: String returned for ``padding_label``.

    Returns:
        ``padding_symbol`` if ``label`` equals ``padding_label``, otherwise
        ``label_list[label]``.

    Raises:
        IndexError: If ``label`` is negative (and not ``padding_label``) or
            is beyond the end of ``label_list``.

    """
    if label == padding_label:
        return padding_symbol
    # Python would silently resolve a negative index from the end of the list
    # (e.g. -1 -> last label); reject it instead of returning a wrong label.
    if label < 0:
        msg = f"Label index {label} is negative and is not the padding label {padding_label}."
        raise IndexError(msg)
    return label_list[label]


# Translate every aligned label index into its label string
# (-100 entries become "X").
example_labels_aligned_mapped = [
    decode_label(label=aligned_label, label_list=label_list) for aligned_label in example_labels_aligned
]

logger.info("example_labels_aligned_mapped:\n%s", example_labels_aligned_mapped)

In [None]:
# Combine the parallel per-position lists into one list of 5-tuples:
# (token id, attention mask, decoded token, label index, label string).
# strict=True makes zip raise ValueError if the lists differ in length.
merged_list = [
    *zip(
        example_token_ids,
        example_attention_mask,
        example_token_ids_decoded,
        example_labels_aligned,
        example_labels_aligned_mapped,
        strict=True,
    ),
]

logger.info("merged_list:\n%s", merged_list)

In [None]:
console = Console()

# Create a table with headers
table = Table(title="Merged List")

# One column per element of the merged tuples, in tuple order.
# Numeric columns are right-justified; the others use the default.
column_specs = (
    ("Token ID", {"justify": "right"}),
    ("Attention Mask", {"justify": "right"}),
    ("Token Decoded", {}),
    ("Label Aligned", {"justify": "right"}),
    ("Label Aligned Mapped", {}),
)
for header, column_options in column_specs:
    table.add_column(header, **column_options)

# Add one row per position; every cell is converted to a string for display.
for item in merged_list:
    table.add_row(*(str(field) for field in item))

# Print the table to the console
console.print(table)


def log_table_to_string(rich_table: Table) -> Text:
    """Render a Rich table off-screen and return the captured output as ``Text``.

    The table is printed to a separate 150-column console whose output is
    captured instead of displayed; the captured string is then parsed with
    ``Text.from_ansi``.
    """
    capture_console = Console(width=150)
    with capture_console.capture() as captured_output:
        capture_console.print(rich_table)
    rendered = captured_output.get()
    return Text.from_ansi(rendered)


# Send the rendered table through the logger so it also reaches the log file.
table_as_text = log_table_to_string(table)
logger.info("table:\n%s", table_as_text)