In [None]:
import logging
import pathlib
import sys
from itertools import product

import pandas as pd
from pandas import DataFrame
from tqdm import tqdm

from topollm.analysis.compare_sampling_methods.make_plots import (
    Y_AXIS_LIMITS_ONLY_FULL,
    create_boxplot_of_mean_over_different_sampling_seeds,
    generate_fixed_params_text,
)
from topollm.analysis.compare_sampling_methods.run_general_comparisons import (
    filter_dataframe_based_on_filters_dict,
    load_and_concatenate_saved_dataframes,
)
from topollm.config_classes.constants import NAME_PREFIXES_TO_FULL_DESCRIPTIONS, TOPO_LLM_REPOSITORY_BASE_PATH
from topollm.typing.enums import Verbosity

# Notebook-level logger; DEBUG level so that debug messages are not filtered out by the logger itself
default_logger: logging.Logger = logging.getLogger(name=__name__)
default_logger.setLevel(level=logging.DEBUG)

# Layout of every record written by the stdout handler below
formatter = logging.Formatter(fmt="[%(asctime)s][%(levelname)8s][%(name)s] %(message)s (%(filename)s:%(lineno)s)")

# Handler that writes records to stdout, using the formatter above
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(level=logging.DEBUG)
stream_handler.setFormatter(fmt=formatter)

# Attach the handler only if the logger has none yet,
# so that re-executing this cell does not produce duplicated log lines
if len(default_logger.handlers) == 0:
    default_logger.addHandler(hdlr=stream_handler)

# Notebook-wide defaults that the cells below pass on to the analysis helpers
logger: logging.Logger = default_logger
verbosity: Verbosity = Verbosity.NORMAL

# Smoke test: emit one message per level to confirm that the handler is attached
logger.debug(msg="This is a debug message.")
logger.info(msg="This is an info message.")

In [None]:
# Root directory below which the saved comparison dataframes are searched
comparisons_folder_base_path = pathlib.Path(
    TOPO_LLM_REPOSITORY_BASE_PATH,
    "data/analysis/sample_sizes/",
    "run_general_comparisons/",
    "array_truncation_size=5000/",
    "analysis/twonn/",
)

# Load every saved dataframe found under the root directory into a single frame
concatenated_df: DataFrame = load_and_concatenate_saved_dataframes(root_dir=comparisons_folder_base_path)

# Columns whose distinct values are listed as a quick overview of the loaded data
columns_to_investigate: list[str] = ["data_full", "data_subsampling_full", "model_partial_name"]

for column_name in columns_to_investigate:
    # Separator line followed by the header line for this column
    print(
        "=" * 30,
        f"Unique values in column '{column_name = }':",
        sep="\n",
    )
    print(concatenated_df[column_name].unique())

# Column names, dtypes and non-null counts of the loaded frame
concatenated_df.info()

In [None]:
# Show the concatenated dataframe as the output of this cell
concatenated_df

### Investigate different model checkpoints

In [None]:
from topollm.analysis.compare_sampling_methods.make_plots import Y_AXIS_LIMITS

# TODO: Save the raw data to files


def create_histograms_over_model_checkpoints(
    concatenated_df: pd.DataFrame,
    concatenated_filters_dict: dict,
    base_model_partial_name: str = "model=roberta-base",
    figsize: tuple[int, int] = (24, 8),
    common_prefix_path: pathlib.Path | None = None,
    raw_data_save_path: pathlib.Path | None = None,
    verbosity: Verbosity = Verbosity.NORMAL,
    logger: logging.Logger = default_logger,
) -> None:
    """Plot "array_data_truncated_mean" per model checkpoint, including the base model as a reference.

    Despite the function name, the plots are produced by
    `create_boxplot_of_mean_over_different_sampling_seeds`, i.e. they are boxplots.

    The function selects two subsets of `concatenated_df`:

    1. The rows matching `concatenated_filters_dict`.
    2. The rows matching the same filters, but with "model_partial_name" replaced by
       `base_model_partial_name`. In this subset, "model_checkpoint" is set to -1.

    Both subsets are concatenated and the boxplot helper is called once for every
    `(y_min, y_max)` pair in `Y_AXIS_LIMITS`.

    Args:
        concatenated_df: Dataframe containing the rows for all models and checkpoints. It is not modified.
        concatenated_filters_dict: Column name to required value; passed to
            `filter_dataframe_based_on_filters_dict`. It is not modified.
        base_model_partial_name: Value of "model_partial_name" which identifies the base model rows.
        figsize: Figure size forwarded to the boxplot helper.
        common_prefix_path: If given, each plot is saved to
            `<common_prefix_path>/plots/y_<y_min>_<y_max>.pdf`. If None, no plot save path is passed on.
        raw_data_save_path: Forwarded unchanged to every call of the boxplot helper.
        verbosity: Verbosity forwarded to the filter and plotting helpers.
        logger: Logger forwarded to the filter and plotting helpers.

    """
    filtered_concatenated_df: pd.DataFrame = filter_dataframe_based_on_filters_dict(
        df=concatenated_df,
        filters_dict=concatenated_filters_dict,
        verbosity=verbosity,
        logger=logger,
    )

    # # # #
    # Filter for the dataframe with just the base model data

    # Shallow copy, so that the filters dict of the caller is left untouched
    same_filters_but_for_base_model = concatenated_filters_dict.copy()
    same_filters_but_for_base_model["model_partial_name"] = base_model_partial_name

    filtered_for_base_model_concatenated_df: pd.DataFrame = filter_dataframe_based_on_filters_dict(
        df=concatenated_df,
        filters_dict=same_filters_but_for_base_model,
        verbosity=verbosity,
        logger=logger,
    )

    # Mark the base model rows with the checkpoint value -1
    # (presumably so that they are distinguishable from the actual checkpoints on the x-axis).
    # `assign` returns a new dataframe, which avoids writing into a filtered view of `concatenated_df`
    # (no SettingWithCopyWarning, and the input data is not mutated).
    filtered_for_base_model_concatenated_df = filtered_for_base_model_concatenated_df.assign(
        model_checkpoint=-1,
    )

    # # # #
    # Create a dataframe by concatenating the two dataframes
    data_for_checkpoint_analysis_df: pd.DataFrame = pd.concat(
        objs=[filtered_concatenated_df, filtered_for_base_model_concatenated_df],
        ignore_index=True,
    )

    # # # #
    # Make a boxplot of "array_data_truncated_mean" with the "model_checkpoint" column on the x-axis,
    # once for every configured y-axis range

    fixed_params_text: str = generate_fixed_params_text(
        filters_dict=concatenated_filters_dict,
    )

    for y_min, y_max in Y_AXIS_LIMITS.values():
        plot_save_path: pathlib.Path | None = None
        if common_prefix_path is not None:
            plot_save_path = pathlib.Path(
                common_prefix_path,
                "plots",
                f"y_{y_min}_{y_max}.pdf",
            )

        create_boxplot_of_mean_over_different_sampling_seeds(
            subset_local_estimates_df=data_for_checkpoint_analysis_df,
            x_column_name="model_checkpoint",
            y_column_name="array_data_truncated_mean",
            fixed_params_text=fixed_params_text,
            figsize=figsize,  # This should be a bit larger than the default, because we have more checkpoints to show
            y_min=y_min,
            y_max=y_max,
            plot_save_path=plot_save_path,
            raw_data_save_path=raw_data_save_path,
            verbosity=verbosity,
            logger=logger,
        )


# # # #
# Select which analysis to run and call the analysis

# Datasets for which the checkpoint analysis is run
data_full_list_to_process = [
    "data=multiwoz21_spl-mode=do_nothing_ctxt=dataset_entry_feat-col=ner_tags",
    "data=one-year-of-tsla-on-reddit_spl-mode=proportions_spl-shuf=True_spl-seed=0_tr=0.8_va=0.1_te=0.1_ctxt=dataset_entry_feat-col=ner_tags",
]

# Dataset splits for which the checkpoint analysis is run
data_subsampling_split_to_process = [
    "train",
    "validation",
    "test",
]

# Fine-tuned models whose checkpoints are compared
model_partial_name_list_to_process = [
    "model=model-roberta-base_task-masked_lm_multiwoz21-train-10000-ner_tags_ftm-standard_lora-None_5e-05-constant-0.01-50",
    "model=model-roberta-base_task-masked_lm_one-year-of-tsla-on-reddit-train-10000-ner_tags_ftm-standard_lora-None_5e-05-constant-0.01-50",
]

# Note: The "model_seed" column contains type integer values
language_model_seed_list_to_process = [
    1234,
]

data_subsampling_sampling_mode: str = "random"

# Materialize the full grid of combinations:
# `itertools.product` has no length, so tqdm could not show the total or an ETA for the bare iterator.
combinations_to_process = list(
    product(
        data_full_list_to_process,
        data_subsampling_split_to_process,
        model_partial_name_list_to_process,
        language_model_seed_list_to_process,
    ),
)

for data_full, data_subsampling_split, model_partial_name, language_model_seed in tqdm(
    combinations_to_process,
    desc="Processing different combinations of data_full, data_subsampling_split, model_partial_name, and language_model_seed",
):
    # Column values which a row needs to match to be part of this analysis
    concatenated_filters_dict = {
        "data_full": data_full,
        "data_subsampling_sampling_mode": data_subsampling_sampling_mode,
        "data_subsampling_split": data_subsampling_split,
        "data_subsampling_number_of_samples": 10_000,
        "model_partial_name": model_partial_name,
        "model_seed": language_model_seed,
        "data_prep_sampling_method": "random",
        "data_prep_sampling_samples": 150_000,
        NAME_PREFIXES_TO_FULL_DESCRIPTIONS["dedup"]: "array_deduplicator",
        "local_estimates_samples": 60_000,
        "n_neighbors": 128,
    }

    # Output directory for this combination.
    # Note: The `=` specifier in the f-strings yields components of the form `name=<repr of value>`.
    common_prefix_path = pathlib.Path(
        TOPO_LLM_REPOSITORY_BASE_PATH,
        "data",
        "saved_plots",
        "mean_estimates_over_different_checkpoints",
        data_full,
        f"{data_subsampling_split=}",
        model_partial_name,
        f"{language_model_seed=}",
    )

    create_histograms_over_model_checkpoints(
        concatenated_df=concatenated_df,
        concatenated_filters_dict=concatenated_filters_dict,
        figsize=(22, 8),
        common_prefix_path=common_prefix_path,
        verbosity=verbosity,
        logger=logger,
    )


### Investigate the influence of the data subsampling method on the results

In [None]:
def log_subsampling_number_of_samples_values(
    filtered_concatenated_df: pd.DataFrame,
) -> None:
    """Print how many rows exist per value of "data_subsampling_number_of_samples".

    For every distinct value in the "data_subsampling_number_of_samples" column,
    the dataframe is filtered down to that value and the shape of the result is printed.
    Afterwards, the distinct values of the "data_subsampling_sampling_seed" column are printed.
    """
    distinct_number_of_samples_values = filtered_concatenated_df["data_subsampling_number_of_samples"].unique()

    # Note: The loop variable and the name of the filtered dataframe appear verbatim in the printed text
    # (f-string `=` specifier), so they must not be renamed without changing the output.
    for data_subsampling_number_of_samples in distinct_number_of_samples_values:
        filtered_concatenated_df_for_number_of_samples: pd.DataFrame = filter_dataframe_based_on_filters_dict(
            df=filtered_concatenated_df,
            filters_dict={"data_subsampling_number_of_samples": data_subsampling_number_of_samples},
        )

        print(
            f"{data_subsampling_number_of_samples = }: {filtered_concatenated_df_for_number_of_samples.shape = }",
        )

    distinct_sampling_seeds = filtered_concatenated_df["data_subsampling_sampling_seed"].unique()
    print(
        "Unique data_subsampling_sampling_seed:\n",
        distinct_sampling_seeds,
    )


# NOTE: The selections below pick values by their position in the output of `.unique()`,
# which follows the order of first appearance in `concatenated_df`.
# They silently select something else if the loaded data or its order changes,
# and raise an IndexError if fewer distinct values are present.
data_full = concatenated_df["data_full"].unique()[1]
data_subsampling_split = concatenated_df["data_subsampling_split"].unique()[2]
data_subsampling_sampling_mode: str = "random"

model_full = concatenated_df["model_full"].unique()[0]

# Column values which a row needs to match to be part of this analysis.
# "data_subsampling_number_of_samples" is deliberately not fixed here, because it is the x-axis of the plots below.
concatenated_filters_dict = {
    "data_full": data_full,
    "model_full": model_full,
    "data_subsampling_split": data_subsampling_split,
    "data_subsampling_sampling_mode": data_subsampling_sampling_mode,
    "data_prep_sampling_method": "random",
    "data_prep_sampling_samples": 100_000,
    NAME_PREFIXES_TO_FULL_DESCRIPTIONS["dedup"]: "array_deduplicator",
    "local_estimates_samples": 60_000,
    "n_neighbors": 128,
}

filtered_concatenated_df: DataFrame = filter_dataframe_based_on_filters_dict(
    df=concatenated_df,
    filters_dict=concatenated_filters_dict,
)

log_subsampling_number_of_samples_values(
    filtered_concatenated_df=filtered_concatenated_df,
)

print(f"{filtered_concatenated_df.shape = }")

# # # #
# START Additional data cleaning:
# Remove those samples where "array_data.size" is smaller than MINIMUM_ARRAY_DATA_SIZE

# Rows with a smaller "array_data.size" are excluded from the plots below
MINIMUM_ARRAY_DATA_SIZE: int = 50_000

filtered_concatenated_df_cleaned = filtered_concatenated_df[
    filtered_concatenated_df["array_data.size"] >= MINIMUM_ARRAY_DATA_SIZE
]

# END Additional data cleaning
# # # #


data_for_different_data_subsampling_number_of_samples_analysis_df: DataFrame = filtered_concatenated_df_cleaned

fixed_params_text: str = generate_fixed_params_text(
    filters_dict=concatenated_filters_dict,
)

x_column_name = "data_subsampling_number_of_samples"

# One boxplot per configured y-axis range
for y_min, y_max in Y_AXIS_LIMITS_ONLY_FULL.values():
    # for y_min, y_max in [(6.0, 10.0)]:
    create_boxplot_of_mean_over_different_sampling_seeds(
        subset_local_estimates_df=data_for_different_data_subsampling_number_of_samples_analysis_df,
        plot_save_path=None,  # TODO: Select path
        raw_data_save_path=None,  # TODO: Select path
        x_column_name=x_column_name,
        y_column_name="array_data_truncated_mean",
        seed_column_name="data_subsampling_sampling_seed",
        fixed_params_text=fixed_params_text,
        y_min=y_min,
        y_max=y_max,
        logger=logger,
    )