In [35]:
import pathlib
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Result folders that can be analysed, one entry per dataset / split / sample-size combination.
data_folder_list = [
    "data-multiwoz21_split-train_ctxt-dataset_entry_samples-10000_feat-col-ner_tags/",
    "data-multiwoz21_split-validation_ctxt-dataset_entry_samples-3000_feat-col-ner_tags/",
    "data-multiwoz21_split-test_ctxt-dataset_entry_samples-3000_feat-col-ner_tags/",
    "data-one-year-of-tsla-on-reddit_split-train_ctxt-dataset_entry_samples-10000_feat-col-ner_tags/",
    "data-one-year-of-tsla-on-reddit_split-validation_ctxt-dataset_entry_samples-3000_feat-col-ner_tags/",
]

# Change the index to analyse a different entry of `data_folder_list`.
selected_data_folder = data_folder_list[4]

# NOTE(review): the root below is an absolute, machine-specific path;
# it has to be adapted before the notebook can run elsewhere.
file_path = (
    pathlib.Path(
        "/Users/ruppik/git-source/Topo_LLM/data/analysis/sample_sizes/run_general_comparisons/analysis/twonn/",
    )
    / selected_data_folder
    / "lvl-token/add-prefix-space-True_max-len-512/model-roberta-base_task-masked_lm/layer--1_agg-mean/norm-None/full_local_estimates_df.csv"
)

# Plots created later in the notebook are saved next to the loaded CSV file.
results_base_directory_path: pathlib.Path = file_path.parent

local_estimates_df: pd.DataFrame = pd.read_csv(file_path)

local_estimates_df

In [None]:
# Select a subset of the data with the same parameters.
# This allows comparing over different seeds.


def filter_dataframe(
    df: pd.DataFrame,
    filters: dict[str, Any],
) -> pd.DataFrame:
    """Keep only the rows whose columns equal the requested values.

    The conditions are applied one after another, so a row survives only if
    it satisfies every entry of `filters`. The input frame is not modified.

    Args:
        df:
            The DataFrame to select rows from.
        filters:
            Mapping from column name to the value that column must equal.

    Returns:
        A new DataFrame holding the rows that match all conditions;
        with an empty `filters` mapping this is a copy of `df`.

    """
    remaining_rows = df.copy()
    for column_name, wanted_value in filters.items():
        row_matches = remaining_rows[column_name] == wanted_value
        remaining_rows = remaining_rows.loc[row_matches]
    return remaining_rows


# Function to generate the text for the fixed parameters to be displayed on the plot
def generate_fixed_params_text(filters: dict[str, Any]) -> str:
    """Render the filter settings as one ``key: value`` line per entry.

    Args:
        filters:
            A dictionary of column names and the values they were fixed to.

    Returns:
        str:
            The entries in dictionary order, separated by newlines
            (an empty string for an empty dictionary).

    """
    return "\n".join(f"{name}: {setting}" for name, setting in filters.items())


# Parameters that are held fixed for the comparison.
#
# `local_estimates_samples` is deliberately not part of the filters,
# since we want to compare the results for different sample sizes.
# The key order is kept as-is because it is also the order in which
# the parameters are listed in the plot annotation.
filters: dict[str, Any] = {
    "data_prep_sampling_method": "random",
    "deduplication": "array_deduplicator",
    "n_neighbors": 128,
    "data_prep_sampling_samples": 50000,
}

subset_local_estimates_df = filter_dataframe(local_estimates_df, filters)

subset_local_estimates_df

In [None]:
# Display pandas' default summary statistics for the filtered subset.
subset_local_estimates_df.describe()

In [None]:
from matplotlib.ticker import AutoLocator, MultipleLocator


def create_boxplot_of_mean_over_different_sampling_seeds(
    subset_local_estimates_df: pd.DataFrame,
    plot_save_path: pathlib.Path | None = None,
    *,
    y_min: float = 6.5,
    y_max: float = 15.5,
    show_plot: bool = True,
    connect_points: bool = True,
    fixed_params: dict[str, Any] | None = None,
) -> None:
    """Draw a box plot of `array_data_truncated_mean` per `local_estimates_samples` value.

    The individual values are overlaid as red points. Optionally, the points
    that share the same `data_prep_sampling_seed` are connected by a line.

    The input DataFrame is not modified; the function works on a private copy.

    Args:
        subset_local_estimates_df:
            The data to plot. Must contain the columns
            `local_estimates_samples` and `array_data_truncated_mean`,
            and additionally `data_prep_sampling_seed` if `connect_points` is set.
        plot_save_path:
            Where to save the figure. Missing parent directories are created.
            Nothing is saved if this is None.
        y_min:
            Lower limit of the y-axis.
        y_max:
            Upper limit of the y-axis.
        show_plot:
            Whether to call `plt.show()`. If False, the figure is closed
            before returning so that repeated calls do not accumulate open figures.
        connect_points:
            Whether to connect the points belonging to the same sampling seed.
        fixed_params:
            The fixed parameters to list in the annotation box.
            If None, the notebook-level variable `filters` is used when it exists;
            if that does not exist either, no annotation box is drawn.

    """
    # Work on a copy: the caller usually passes a filtered slice of a larger frame,
    # and writing a column into it would change the caller's data
    # (and can trigger pandas' SettingWithCopyWarning).
    plot_df = subset_local_estimates_df.copy()

    # Convert the x column to an ordered categorical BEFORE plotting,
    # so that the box positions and the categorical codes used for the
    # connecting lines below are derived from the same column.
    plot_df["local_estimates_samples"] = pd.Categorical(
        plot_df["local_estimates_samples"],
        ordered=True,
    )

    # Backward compatibility: earlier versions read the notebook-level `filters` directly.
    if fixed_params is None:
        fixed_params = globals().get("filters")

    fig, ax = plt.subplots(figsize=(10, 6))

    # Set the fixed y-axis limits
    ax.set_ylim(y_min, y_max)

    # Automatically set major and minor tick locators
    ax.yaxis.set_major_locator(AutoLocator())  # Auto-adjust major ticks
    ax.yaxis.set_minor_locator(MultipleLocator(0.1))  # Set minor ticks for finer grid

    # Enable the grid with different styling for major and minor lines
    ax.grid(which="major", axis="y", color="gray", linestyle="-", linewidth=0.6, alpha=0.5)  # Major grid lines
    ax.grid(which="minor", axis="y", color="gray", linestyle="--", linewidth=0.3, alpha=0.3)  # Minor grid lines

    # Create boxplot and stripplot
    sns.boxplot(
        x="local_estimates_samples",
        y="array_data_truncated_mean",
        data=plot_df,
        ax=ax,
    )
    sns.stripplot(
        x="local_estimates_samples",
        y="array_data_truncated_mean",
        data=plot_df,
        color="red",
        jitter=False,
        dodge=True,
        marker="o",
        alpha=0.5,
        ax=ax,
    )

    # Connect the points from the same seed across different samples if requested
    if connect_points:
        unique_seeds = plot_df["data_prep_sampling_seed"].unique()
        colormap = plt.colormaps.get_cmap("tab20")

        for idx, seed in enumerate(unique_seeds):
            seed_data = plot_df[plot_df["data_prep_sampling_seed"] == seed]
            # Sort seed_data by 'local_estimates_samples' for consistent plotting
            seed_data = seed_data.sort_values(by="local_estimates_samples")

            # Plot lines connecting the same seed points.
            # Labels are attached to the first two lines only; no legend is drawn here.
            ax.plot(
                seed_data["local_estimates_samples"].cat.codes,  # Using categorical codes for proper x ordering
                seed_data["array_data_truncated_mean"],
                linestyle="-",
                linewidth=1,
                alpha=0.7,
                color=colormap(idx / len(unique_seeds)),  # Use a different color for each seed
                label=f"Seed {seed}" if idx < 2 else "",
            )

    # Adding additional information about the fixed parameters in the plot
    if fixed_params is not None:
        fixed_params_text = generate_fixed_params_text(fixed_params)
        ax.text(
            x=0.55,
            y=0.25,
            s=f"Fixed Parameters:\n{fixed_params_text}",
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="top",
            bbox={
                "boxstyle": "round",
                "facecolor": "wheat",
                "alpha": 0.3,
            },
        )

    # Save plot to the specified path if provided
    if plot_save_path is not None:
        plot_save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_save_path, bbox_inches="tight")

    # Show plot if needed; otherwise release the figure so it does not stay open.
    if show_plot:
        plt.show()
    else:
        plt.close(fig)


# All plots of this comparison are collected in one sub-directory next to the loaded results.
output_directory = results_base_directory_path / "different_sampling_seeds"

# Create one plot with and one without the lines connecting points of the same seed.
for connect_points in (True, False):
    # The `=` specifier writes the expression text itself into the file name,
    # which is why the loop variable has to keep the name `connect_points`.
    plot_file_name = (
        "array_data_truncated_mean_boxplot_"
        f"{filters['n_neighbors']=}_{filters['data_prep_sampling_samples']=}_{connect_points=}.pdf"
    )
    plot_save_path = output_directory / plot_file_name

    create_boxplot_of_mean_over_different_sampling_seeds(
        subset_local_estimates_df,
        plot_save_path,
        y_min=6.5,
        y_max=15.5,
        show_plot=False,
        connect_points=connect_points,
    )