In [None]:
import logging
import pathlib
import sys

from pandas import DataFrame

from topollm.analysis.compare_sampling_methods.run_general_comparisons import (
    load_and_concatenate_saved_dataframes,
)
from topollm.config_classes.constants import TOPO_LLM_REPOSITORY_BASE_PATH

default_logger: logging.Logger = logging.getLogger(name=__name__)


def ensure_stdout_handler(target_logger: logging.Logger) -> None:
    """Attach a stdout ``StreamHandler`` to ``target_logger`` if it has none yet.

    Re-running a notebook cell that calls ``addHandler`` unconditionally stacks
    one more handler per run, so every log record is printed multiple times.
    This helper makes the setup idempotent.

    Args:
        target_logger: Logger that should write its records to ``sys.stdout``.
    """
    for existing_handler in target_logger.handlers:
        writes_to_stdout = (
            isinstance(existing_handler, logging.StreamHandler)
            and getattr(existing_handler, "stream", None) is sys.stdout
        )
        if writes_to_stdout:
            # A stdout handler is already attached; nothing to do.
            return

    target_logger.addHandler(
        hdlr=logging.StreamHandler(
            stream=sys.stdout,
        ),
    )


# Add stdout handler to default logger (safe to re-run)
ensure_stdout_handler(target_logger=default_logger)

logger: logging.Logger = default_logger

In [None]:
# Folder holding the saved comparison dataframes, given as path segments
# below the repository base path.
comparisons_folder_base_path = pathlib.Path(TOPO_LLM_REPOSITORY_BASE_PATH).joinpath(
    "data/analysis/sample_sizes/",
    "run_general_comparisons/",
    "array_truncation_size=5000/analysis/twonn/",
)

In [None]:
# Load the saved dataframes found under the comparisons folder into one frame
# via the project helper.
concatenated_df: DataFrame = load_and_concatenate_saved_dataframes(root_dir=comparisons_folder_base_path)

# Print the column/dtype/non-null summary of the combined frame.
concatenated_df.info()

In [None]:
# Columns whose distinct values we want to list before choosing filters.
columns_to_investigate: list[str] = [
    "data_full",
    "data_subsampling_full",
]

for column_name in columns_to_investigate:
    print(30 * "=")
    # `!r` prints just the quoted column name; the previous `{column_name = }`
    # inside quotes produced output like 'column_name = 'data_full''.
    print(
        f"Unique values in column {column_name!r}:",
    )
    print(
        concatenated_df[column_name].unique(),
    )

In [None]:
# List the distinct model descriptions present in the loaded results.
column_name = "model_full"

concatenated_df.loc[:, column_name].unique()

In [None]:
# Investigate the influence of the data subsampling method on the results

from pandas.core.frame import DataFrame

from topollm.analysis.compare_sampling_methods.run_general_comparisons import filter_dataframe_based_on_filters_dict
from topollm.config_classes.constants import NAME_PREFIXES_TO_FULL_DESCRIPTIONS

# Choose one concrete configuration to hold fixed while the number of
# subsampled data points varies.
# NOTE(review): the values below are picked by position in `.unique()`, which
# follows order of first appearance in the loaded frame; the selection changes
# if the loaded files or their order change — confirm the intended values.
unique_data_full_values = concatenated_df["data_full"].unique()
data_full = unique_data_full_values[1]

unique_data_subsampling_split_values = concatenated_df["data_subsampling_split"].unique()
data_subsampling_split = unique_data_subsampling_split_values[2]

data_subsampling_sampling_mode: str = "random"

unique_model_full_values = concatenated_df["model_full"].unique()
model_full = unique_model_full_values[0]

# Column name for the deduplication setting, looked up from the project mapping.
dedup_column_name = NAME_PREFIXES_TO_FULL_DESCRIPTIONS["dedup"]

# Column -> required value; key order is kept as in the original definition.
concatenated_filters_dict = {
    "data_full": data_full,
    "model_full": model_full,
    "data_subsampling_split": data_subsampling_split,
    "data_subsampling_sampling_mode": data_subsampling_sampling_mode,
    "data_prep_sampling_method": "random",
    "data_prep_sampling_samples": 100_000,
    dedup_column_name: "array_deduplicator",
    "local_estimates_samples": 60_000,
    "n_neighbors": 128,
}

filtered_concatenated_df: DataFrame = filter_dataframe_based_on_filters_dict(
    df=concatenated_df,
    filters_dict=concatenated_filters_dict,
)

# Quick sanity check on how many rows survived the filters.
print(f"{filtered_concatenated_df.shape = }")

In [None]:
# Display the rows that matched `concatenated_filters_dict` for visual inspection.
filtered_concatenated_df

In [None]:
# For every distinct value of "data_subsampling_number_of_samples",
# report how many rows of the filtered dataframe carry that value.

data_subsampling_number_of_samples_values = filtered_concatenated_df["data_subsampling_number_of_samples"].unique()

for data_subsampling_number_of_samples in data_subsampling_number_of_samples_values:
    # Restrict the already-filtered frame to the current sample count.
    filtered_concatenated_df_for_number_of_samples: DataFrame = filter_dataframe_based_on_filters_dict(
        df=filtered_concatenated_df,
        filters_dict={"data_subsampling_number_of_samples": data_subsampling_number_of_samples},
    )

    print(
        f"{data_subsampling_number_of_samples = }: {filtered_concatenated_df_for_number_of_samples.shape = }",
    )

# Also list which sampling seeds occur in the filtered frame.
unique_sampling_seeds = filtered_concatenated_df["data_subsampling_sampling_seed"].unique()
print("Unique data_subsampling_sampling_seed:\n", unique_sampling_seeds)

In [None]:
# Show the filtered frame again, directly before it is cleaned and plotted.
filtered_concatenated_df

In [None]:
from topollm.analysis.compare_sampling_methods.make_plots import (
    Y_AXIS_LIMITS_ONLY_FULL,
    create_boxplot_of_mean_over_different_sampling_seeds,
    generate_fixed_params_text,
)

# # # #
# START Additional data cleaning:
# Keep only those rows where "array_data.size" is at least MIN_ARRAY_DATA_SIZE.
# NOTE(review): this comment previously said 30_000 while the code filtered at
# 50_000; the code's value is kept here — confirm the intended threshold.

MIN_ARRAY_DATA_SIZE: int = 50_000

filtered_concatenated_df_cleaned = filtered_concatenated_df[
    filtered_concatenated_df["array_data.size"] >= MIN_ARRAY_DATA_SIZE
]

# END Additional data cleaning
# # # #


data_for_different_data_subsampling_number_of_samples_analysis_df: DataFrame = filtered_concatenated_df_cleaned

# Text describing the fixed filter values, passed to the plotting helper below.
fixed_params_text: str = generate_fixed_params_text(
    filters_dict=concatenated_filters_dict,
)

x_column_name = "data_subsampling_number_of_samples"

# One boxplot per predefined y-axis range.
# To plot a single custom range instead, iterate over e.g. [(6.0, 10.0)].
for y_min, y_max in Y_AXIS_LIMITS_ONLY_FULL.values():
    create_boxplot_of_mean_over_different_sampling_seeds(
        subset_local_estimates_df=data_for_different_data_subsampling_number_of_samples_analysis_df,
        plot_save_path=None,  # TODO: Select path
        raw_data_save_path=None,  # TODO: Select path
        x_column_name=x_column_name,
        y_column_name="array_data_truncated_mean",
        seed_column_name="data_subsampling_sampling_seed",
        fixed_params_text=fixed_params_text,
        y_min=y_min,
        y_max=y_max,
        logger=logger,
    )

In [None]:
# Group the filtered rows by the x-axis column so single groups can be inspected.
# NOTE(review): this groups `filtered_concatenated_df`, not the size-cleaned
# frame that was used for the plots — confirm this is intended.
filtered_concatenated_df_grouped = filtered_concatenated_df.groupby(x_column_name, observed=True)

In [None]:
# Inspect the rows whose grouping-column value is 2000 and show which
# sampling seeds they contain.
selected_group = filtered_concatenated_df_grouped.get_group(2000)

selected_group.loc[:, "data_subsampling_sampling_seed"]