In [1]:
import os
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter, MaxNLocator, LogLocator 
from matplotlib import transforms
from datasets import load_dataset, get_dataset_split_names, DatasetDict
from tqdm import tqdm
tqdm.pandas()

# Global plotting defaults. These calls mutate matplotlib/seaborn state, so their order matters.
sns.set_color_codes("colorblind")
sns.set_theme(style="white")
sns.set_context("talk")
# NOTE(review): the palette returned by the next call is discarded here; if the intent
# was to make "tab10" the active palette, `sns.set_palette("tab10")` would be needed — confirm.
sns.color_palette("tab10")

# Use a serif font family (Times New Roman) for all plot text.
sns.set_style({'font.family':'serif', 'font.serif':'Times New Roman'})

  from .autonotebook import tqdm as notebook_tqdm


## Load Datasets

In [2]:
# Notebook-wide configuration.
split_sample_size = None  # rows to take from each split when loading; None loads the whole split
label_title_padding = 10  # not referenced by the cells visible in this notebook
study_pile = False  # not referenced by the cells visible in this notebook
RECITATION_THRESHOLD = 5  # memorized rows with more duplicates than this are labelled "Recitation"

# Every figure is saved below a directory named after the recitation threshold.
figures_path = f"scale+time_figures/recitation_threshold_{RECITATION_THRESHOLD}/"
os.makedirs(figures_path, exist_ok=True)

In [3]:
# Hub repositories holding the final-checkpoint and intermediate-checkpoint filter results.
memories_path = "usvsnsp/generation-semantic-filters"
intermediate_path = "usvsnsp/generation-semantic-intermediate-filters"

# Containers filled by the loading cell: one for "memories" splits, one for the remaining splits.
memories_dataset, pile_dataset = DatasetDict(), DatasetDict()

# Gather the split names of both repositories and keep only the deduplicated ones.
available_split_names = get_dataset_split_names(memories_path) + get_dataset_split_names(intermediate_path)
splits = [split_name for split_name in available_split_names if "deduped" in split_name]
print(splits)

In [None]:
print("Loading datasets...")
print(f"Split sample size: {split_sample_size}")

for split in tqdm(splits):
    # Intermediate splits end in ".<training step>"; every other split is treated as the
    # final checkpoint (step 143000). Checking the whole suffix with `isdigit()` also
    # handles one-character suffixes, which indexing its second character would not.
    step_suffix = split.split(".")[-1]
    checkpoint = int(step_suffix) if step_suffix.isdigit() else 143000

    # Key used inside the DatasetDicts, e.g. "memories_deduped_12b" -> "deduped.12b".
    formatted_split_name = split.replace("memories_", "").replace("deduped_", "deduped.").replace("pile_", "")

    # Final checkpoints are read from `memories_path`, intermediate ones from `intermediate_path`.
    dataset_path = memories_path if checkpoint == 143000 else intermediate_path

    # Optionally load only the first `split_sample_size` rows of the split.
    split_spec = f"{split}[:{split_sample_size}]" if split_sample_size else split

    # Splits whose name contains "memories" go to `memories_dataset`, all others to `pile_dataset`.
    target_dataset = memories_dataset if "memories" in split else pile_dataset
    target_dataset[formatted_split_name] = load_dataset(dataset_path, split=split_spec)

display(memories_dataset)
display(pile_dataset)

In [None]:
# Model-size label -> parameter count.
# Entries are listed in ascending parameter order on purpose: other cells iterate this
# mapping and pass its keys/values to `reindex(...)`, so insertion order becomes row order.
split_to_param_count = {
    "70m": 70000000,
    "160m": 160000000,
    "410m": 410000000,
    "1b": 1000000000,
    "1.4b": 1400000000,
    "2.8b": 2800000000,
    "6.9b": 6900000000,
    "12b": 12000000000,
}

In [None]:
def get_frame_from_split(dataset, split_name, is_pile_sample):
    """Convert one dataset split into a pandas frame annotated with model metadata.

    Args:
        dataset: Mapping from split name to an object exposing ``to_pandas()``.
        split_name: Key of the split to convert, e.g. ``"deduped.12b"`` (final
            checkpoint) or ``"deduped.12b.23000"`` (intermediate checkpoint).
        is_pile_sample: Value stored in the ``IsPileSample`` column of every row.

    Returns:
        The split as a DataFrame without the columns listed in the module-level
        ``columns_to_drop`` and with the added columns ``Checkpoint``,
        ``TrainingPercentage``, ``Model``, ``Param Count``, ``Deduped``,
        ``Memorized``, ``IsPileSample`` and ``IsCode``.

    Raises:
        KeyError: If the parsed model label is not a key of the module-level
            ``split_to_param_count``.
    """
    # Use the `split_name` argument throughout; the function must not depend on a
    # loop variable that happens to exist in the calling cell.
    current_frame = dataset[split_name].to_pandas().drop(columns=columns_to_drop)

    # A purely numeric last dot-separated segment is a training step; otherwise the
    # split is a final checkpoint ("12b", "4b" from "1.4b", ... are not all digits).
    step_suffix = split_name.split(".")[-1]
    # Everything after "deduped" minus the separator, e.g. "12b" or "12b.23000".
    model_and_step = split_name.split("deduped")[-1].lstrip("._")
    if step_suffix.isdigit():
        checkpoint = int(step_suffix)
        # Strip ".<step>" so dotted sizes such as "1.4b" stay intact.
        model = model_and_step[: -(len(step_suffix) + 1)]
    else:
        checkpoint = "Final"
        model = model_and_step

    current_frame["Checkpoint"] = checkpoint
    # Step 143000 is treated as the end of training (TrainingPercentage == 1).
    current_frame["TrainingPercentage"] = 1 if checkpoint == "Final" else checkpoint / 143000
    current_frame["Model"] = model
    # Look the size up from the parsed label so empty splits do not fail on `.iloc[0]`.
    current_frame["Param Count"] = split_to_param_count[model]
    current_frame["Deduped"] = "deduped" in split_name
    current_frame["Memorized"] = current_frame["memorization_score"] >= 1
    current_frame["IsPileSample"] = is_pile_sample
    # Rows with an `nl_scores` value of at most 0.45 are flagged as code.
    current_frame["IsCode"] = current_frame["nl_scores"] <= 0.45
    return current_frame


# Columns removed from every split before the frames are combined.
columns_to_drop = ["frequencies", "tokens", "text"]

# Collect the per-split frames and concatenate them once; concatenating inside the
# loop re-copies all previously accumulated rows on every iteration.
# The loop variable keeps the name `split` in case helper code reads the
# module-level variable of that name.
split_frames = []
for split in tqdm(memories_dataset, desc="Loading Memories"):
    split_frames.append(get_frame_from_split(memories_dataset, split, False))

for split in tqdm(pile_dataset, desc="Loading Pile"):
    split_frames.append(get_frame_from_split(pile_dataset, split, True))

combined_dataframe = pd.concat(split_frames).sort_values("Param Count")
combined_dataframe

In [None]:
# Keep every model except "160m", then show how many rows each remaining model has.
is_not_160m = combined_dataframe["Model"] != "160m"
combined_dataframe = combined_dataframe[is_not_160m]
combined_dataframe.value_counts("Model")

## Assign Examples to Taxonomy

In [None]:
def get_category(row):
    """Return the taxonomy label for one row.

    The checks run in priority order: rows that are not memorized, then rows with
    more than ``RECITATION_THRESHOLD`` sequence duplicates, then rows flagged as
    incrementing or repeating; everything else is a recollection.
    """
    if row["Memorized"] == False:
        label = "Not Memorized"
    elif row["sequence_duplicates"] > RECITATION_THRESHOLD:
        label = "Recitation"
    elif row["is_incrementing"] or row["is_repeating"]:
        label = "Reconstruction"
    else:
        label = "Recollection"
    return label

# Label every row with its taxonomy category. `assign` returns a new frame, which
# avoids writing a column into a frame that was produced by boolean-mask filtering
# (pandas reports that pattern as SettingWithCopyWarning).
combined_dataframe = combined_dataframe.assign(
    category=combined_dataframe.progress_apply(get_category, axis=1)
)
combined_dataframe.value_counts(["Model", "Checkpoint", "category"])

## Plot Graphs

In [None]:
# Hard-coded order of the memorized taxonomy categories, used to select and order
# columns in the tables below ("Not Memorized" is not included).
categories = ["Recitation", "Reconstruction", "Recollection"]

### Figure: Count and Memories by Taxonomy Across Time and Scale

In [None]:
# Keep only the rows that were not loaded as Pile samples.
is_memories_row = combined_dataframe["IsPileSample"].eq(False)
combined_memories_dataframe = combined_dataframe[is_memories_row]
combined_memories_dataframe

In [None]:
# Number of 12b memory rows available at each training fraction.
twelve_b_memories = combined_memories_dataframe[combined_memories_dataframe["Model"] == "12b"]
twelve_b_memories.value_counts("TrainingPercentage")

In [None]:
# Build the count tables used by the plots: category counts across model scale
# (final checkpoints only) and across training time (12b model only).
final_checkpoint_memories = combined_memories_dataframe[combined_memories_dataframe["Checkpoint"] == "Final"]
counts_frame_scale = (
    final_checkpoint_memories
    .value_counts(["Param Count", "category"])
    .unstack()
    .reindex(split_to_param_count.values())
)
# Order the columns by `categories` and drop model sizes with a missing category count.
counts_frame_scale = counts_frame_scale[categories].dropna()
counts_frame_scale.to_csv(f"final_checkpoint_counts_recitation={RECITATION_THRESHOLD}.csv")
display(counts_frame_scale)

intermediate_frame = combined_memories_dataframe[combined_memories_dataframe["Model"] == "12b"]
# Training fractions present for the 12b model, in ascending order.
sorted_checkpoints = sorted(intermediate_frame["TrainingPercentage"].unique())

counts_frame_time = (
    intermediate_frame
    .value_counts(["TrainingPercentage", "category"])
    .unstack()
    .reindex(sorted_checkpoints)
)
counts_frame_time = counts_frame_time[categories].dropna()
counts_frame_time.to_csv(f"intermediate_checkpoint_counts_recitation={RECITATION_THRESHOLD}.csv")
display(counts_frame_time)

### Figure: Percents and Memories by Taxonomy Across Time and Scale

In [None]:
# Category counts per parameter count over every memories row (no checkpoint filter is applied here).
category_counts_by_size = combined_memories_dataframe.value_counts(["Param Count", "category"]).unstack()
counts_frame = category_counts_by_size.reindex(split_to_param_count.values())[categories].dropna()
display(counts_frame)

In [None]:
# Share of each taxonomy category among 12b memories at every training checkpoint.
# The rows are restricted to the 12b model so the percentages describe the same
# population as `counts_frame_time`; filtering on TrainingPercentage alone would mix
# the final checkpoints of every other model size into the last (100%) entry.
time_memories_frame = combined_memories_dataframe[combined_memories_dataframe["Model"] == "12b"]

all_percents_time = []
for checkpoint in tqdm(sorted_checkpoints):
    checkpoint_examples = time_memories_frame[time_memories_frame["TrainingPercentage"] == checkpoint]
    category_shares = checkpoint_examples.value_counts("category", normalize=True)
    all_percents_time.extend(
        {"TrainingPercentage": checkpoint, "category": category, "percent": share}
        for category, share in category_shares.items()
    )

percents_frame_time = (
    pd.DataFrame(all_percents_time)
    .pivot(index="TrainingPercentage", columns="category", values="percent")
    .reindex(sorted_checkpoints)
)
# Replace the fractional index by whole-percent labels for the bar-chart x axis.
percents_frame_time.index = [f"{int(percent * 100)}%" for percent in percents_frame_time.index]
percents_frame_time = percents_frame_time[categories]
percents_frame_time.to_csv(f"percents_frame_time_recitation={RECITATION_THRESHOLD}.csv")
display(percents_frame_time)

In [None]:
# Share of each taxonomy category among the memories of every model size.
# Only final checkpoints are used so the percentages describe the same population as
# `counts_frame_scale`; without this filter a model that also has intermediate
# checkpoints would be averaged over all of its checkpoints.
final_memories_frame = combined_memories_dataframe[combined_memories_dataframe["Checkpoint"] == "Final"]

all_percents_scale = []
for param_count in tqdm(split_to_param_count.values()):
    model_examples = final_memories_frame[final_memories_frame["Param Count"] == param_count]
    if len(model_examples) == 0:
        continue

    model_name = model_examples["Model"].iloc[0]
    model_percents = model_examples.value_counts("category", normalize=True).to_dict()
    for category in categories:
        all_percents_scale.append({
            "Model": model_name,
            "Param Count": param_count,
            "category": category,
            # A category with no rows for this model contributes 0% instead of raising KeyError.
            "percent": model_percents.get(category, 0.0),
        })

# Model labels in mapping order, excluding 160m.
model_keys = [key for key in split_to_param_count.keys() if key != "160m"]
percents_frame_scale = (
    pd.DataFrame(all_percents_scale)
    .pivot(index="Model", columns="category", values="percent")
    .reindex(model_keys)
)
# Put the columns in the canonical category order.
percents_frame_scale = percents_frame_scale[categories]
percents_frame_scale.to_csv(f"percents_frame_scale_recitation={RECITATION_THRESHOLD}.csv")
display(percents_frame_scale)

## Figure: Combined Counts + Percents Plot

### Same chart but with bounded bar charts

In [None]:
# One row of four panels, left to right: counts across scale, percents across scale,
# counts across time, percents across time.
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
scale_counts_ax, scale_percents_ax, time_counts_ax, time_percents_ax = axes

percents_figure_y_lim = (0.75, 1.01)
percents_figure_y_ticks = [0.8, 0.9, 1]


def _whole_percent_label(value, _position):
    """Format a fractional tick value as a whole-number percentage string."""
    return "{:,}%".format(int(value * 100))


def _plot_bounded_percent_bars(frame, ax, x_label):
    """Draw `frame` as gap-free stacked bars on `ax` with the shared, truncated y range.

    After plotting, three "." glyphs are stacked to the left of the bars, positioned
    with x in data coordinates and y in axes coordinates.
    """
    frame.plot.bar(
        stacked=True,
        ax=ax,
        rot=0,
        width=1,
    )
    ax.yaxis.set_major_formatter(plt.FuncFormatter(_whole_percent_label))
    ax.set_ylim(percents_figure_y_lim)
    ax.set_yticks(percents_figure_y_ticks)
    ax.tick_params(axis='x', rotation=20, labelsize=10)
    ax.tick_params(axis='y', labelsize=10)
    ax.set_xlabel(x_label)
    ax.get_legend().remove()

    # Add ...
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
    for dot_height in (0.16, 0.13, 0.09):
        ax.text(-1.5, dot_height, ".", fontsize=24, transform=trans, ha='center', va='top')


# Panel 1: memory counts per category across model scale (log-log axes).
# Its legend is placed below the row and is the only legend kept in the figure.
display(counts_frame_scale)
sns.lineplot(ax=scale_counts_ax, data=counts_frame_scale, dashes=False, markers=True, markersize=8)
scale_counts_ax.tick_params(axis='x', rotation=20, labelsize=10)
scale_counts_ax.set_yscale("log")
scale_counts_ax.set_xscale("log")
scale_counts_ax.tick_params(axis='y', labelsize=10)
scale_counts_ax.set_ylabel("Memories")
scale_counts_ax.set_xlabel("Parameters")
scale_counts_ax.legend(loc='upper center', bbox_to_anchor=(2.25, -0.3), ncol=4, frameon=False)

# Panel 2: category percentages across model scale.
display(percents_frame_scale)
_plot_bounded_percent_bars(percents_frame_scale, scale_percents_ax, "Parameters")

# Panel 3: memory counts per category across training time (log y axis).
display(counts_frame_time)
sns.lineplot(ax=time_counts_ax, data=counts_frame_time, dashes=False, markers=True, markersize=8)
time_counts_ax.set_yscale("log")
time_counts_ax.tick_params(axis='x', labelsize=10)
time_counts_ax.tick_params(axis='y', labelsize=10)
time_counts_ax.tick_params(axis='x', rotation=30)
time_counts_ax.set_xlabel("Training Time")
time_counts_ax.xaxis.set_major_formatter(PercentFormatter(1))
time_counts_ax.legend().remove()

# Panel 4: category percentages across training time.
display(percents_frame_time)
_plot_bounded_percent_bars(percents_frame_time, time_percents_ax, "Training Time")


# Give all x labels the same vertical position.
fig.align_xlabels()

# Save the figure
fig.savefig(f"{figures_path}/categories_counts_percents_across_time+scale_bounded_bars.pdf", bbox_inches="tight")

## Code VS NL by Category Across Time and Scale

In [None]:
# Overall number of memory rows in each taxonomy category.
memory_categories = combined_memories_dataframe["category"]
memory_categories.value_counts()

In [None]:
# Directory receiving one CSV per (category, metric, axis) combination.
stats_dir = "./data/nl_vs_code_stats"
os.makedirs(stats_dir, exist_ok=True)


def _percent_tick_label(value, _position):
    """Format a fractional tick value as a whole-number percentage string."""
    return "{:,}%".format(int(value * 100))


def _remove_legend_if_present(ax):
    """Remove the legend of `ax`; do nothing when the axes has no legend."""
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()


# One figure row per entry; "All Memories" means no category filter.
row_categories = ["All Memories", "Recitation", "Reconstruction", "Recollection"]
# x labels of the four columns; only written on the last row.
column_x_labels = ["Parameters", "Parameters", "Training Time", "Training Time"]

fig, axes = plt.subplots(4, 4, figsize=(20, 16))
for row_index, row_category in tqdm(enumerate(row_categories), desc="Plotting Categories", total=len(row_categories)):
    scale_counts_ax, scale_percents_ax, time_counts_ax, time_percents_ax = axes[row_index]
    is_last_row = row_index == len(row_categories) - 1

    if row_category == "All Memories":
        category_rows = combined_memories_dataframe
    else:
        category_rows = combined_memories_dataframe[combined_memories_dataframe["category"] == row_category]
    # `assign` builds a new frame, so neither a filtered slice nor
    # `combined_memories_dataframe` itself is modified by adding the "Type" column.
    row_frame = category_rows.assign(Type=np.where(category_rows["IsCode"], "Code", "NL"))

    # Scale panels use final checkpoints only; time panels use the 12b model only.
    final_rows = row_frame[row_frame["Checkpoint"] == "Final"]
    twelve_b_rows = row_frame[row_frame["Model"] == "12b"]

    # Count of NL vs Code across scale.
    nl_code_counts_frame_scale = (
        final_rows
        .value_counts(["Param Count", "Type"])
        .unstack()
        .reindex(list(split_to_param_count.values()))
        .dropna()
    )
    nl_code_counts_frame_scale.to_csv(f"{stats_dir}/{row_category}_counts_scale.csv")

    # Percent of NL vs Code across scale, computed on the same final-checkpoint rows
    # as the counts so the two panels describe the same population.
    scale_percent_records = []
    for param_count in tqdm(split_to_param_count.values(), desc="Calculating Percents Across Scale"):
        model_examples = final_rows[final_rows["Param Count"] == param_count]
        if len(model_examples) == 0:
            continue
        model_name = model_examples["Model"].iloc[0]
        type_shares = model_examples.value_counts("Type", normalize=True)
        for text_type, share in type_shares.items():
            scale_percent_records.append({
                "Model": model_name,
                "Param Count": param_count,
                "Type": text_type,
                "percent": share,
            })

    nl_code_percents_frame_scale = (
        pd.DataFrame(scale_percent_records)
        .pivot(index="Model", columns="Type", values="percent")
        .reindex(model_keys)
    )
    nl_code_percents_frame_scale.to_csv(f"{stats_dir}/{row_category}_percents_scale.csv")

    # Column 1: counts across scale.
    sns.lineplot(ax=scale_counts_ax, data=nl_code_counts_frame_scale, dashes=False, markers=True, markersize=8)
    scale_counts_ax.tick_params(axis='x', rotation=20, labelsize=10)
    scale_counts_ax.set_xscale("log")
    scale_counts_ax.tick_params(axis='y', labelsize=10)
    scale_counts_ax.set_ylabel(row_category)

    # Column 2: percents across scale.
    nl_code_percents_frame_scale.plot.bar(
        stacked=True,
        ax=scale_percents_ax,
        rot=0,
        width=1,
    )
    scale_percents_ax.yaxis.set_major_formatter(plt.FuncFormatter(_percent_tick_label))
    scale_percents_ax.tick_params(axis='x', rotation=20, labelsize=10)
    scale_percents_ax.tick_params(axis='y', labelsize=10)

    # Column 3: counts across training time. The checkpoint list is kept in a
    # loop-local name so this cell does not overwrite variables of earlier cells.
    row_checkpoints = sorted(twelve_b_rows["TrainingPercentage"].unique())
    nl_code_counts_frame_time = (
        twelve_b_rows
        .value_counts(["TrainingPercentage", "Type"])
        .unstack()
        .reindex(row_checkpoints)
    )
    nl_code_counts_frame_time.to_csv(f"{stats_dir}/{row_category}_counts_time.csv")

    sns.lineplot(ax=time_counts_ax, data=nl_code_counts_frame_time, dashes=False, markers=True, markersize=8)
    time_counts_ax.tick_params(axis='x', labelsize=10)
    time_counts_ax.tick_params(axis='y', labelsize=10)
    time_counts_ax.tick_params(axis='x', rotation=30)
    time_counts_ax.xaxis.set_major_formatter(PercentFormatter(1))

    # Column 4: percents across training time, on the same 12b rows as the counts.
    time_percent_records = []
    for checkpoint in tqdm(row_checkpoints):
        checkpoint_examples = twelve_b_rows[twelve_b_rows["TrainingPercentage"] == checkpoint]
        type_shares = checkpoint_examples.value_counts("Type", normalize=True)
        for text_type, share in type_shares.items():
            time_percent_records.append({
                "TrainingPercentage": checkpoint,
                "Type": text_type,
                "percent": share,
            })

    nl_code_percents_frame_time = (
        pd.DataFrame(time_percent_records)
        .pivot(index="TrainingPercentage", columns="Type", values="percent")
        .reindex(row_checkpoints)
    )
    # The CSV keeps the fractional index; the labels are only replaced for plotting.
    nl_code_percents_frame_time.to_csv(f"{stats_dir}/{row_category}_percents_time.csv")
    nl_code_percents_frame_time.index = [f"{int(percent * 100)}%" for percent in nl_code_percents_frame_time.index]

    nl_code_percents_frame_time.plot.bar(
        stacked=True,
        ax=time_percents_ax,
        rot=0,
        width=1,
    )
    time_percents_ax.yaxis.set_major_formatter(plt.FuncFormatter(_percent_tick_label))
    time_percents_ax.tick_params(axis='x', rotation=20, labelsize=10)
    time_percents_ax.tick_params(axis='y', labelsize=10)

    # x labels only on the last row.
    for ax, x_label in zip(axes[row_index], column_x_labels):
        ax.set_xlabel(x_label if is_last_row else "")

    # A single legend, attached to the first panel of the last row and placed below the grid.
    if is_last_row:
        scale_counts_ax.legend(loc='upper center', bbox_to_anchor=(2.25, -0.3), ncol=4, frameon=False, fontsize=24)
    else:
        _remove_legend_if_present(scale_counts_ax)
    for ax in (scale_percents_ax, time_counts_ax, time_percents_ax):
        _remove_legend_if_present(ax)

fig.align_xlabels()
plt.tight_layout()
plt.subplots_adjust(hspace=0.25, wspace=0.25)
fig.savefig(f"{figures_path}/nl_code_counts_percents_across_time+scale.pdf", bbox_inches="tight")

### Create LaTeX Table for Code Proportion by Category Across Time and Scale

In [None]:
# Category columns of the LaTeX table, in output order.
latex_table_categories = ["All Memories", "Recitation", "Reconstruction", "Recollection"]
# Inverse of `split_to_param_count`: parameter count -> model-size label.
param_count_to_split = {value: key for key, value in split_to_param_count.items()}


def _format_latex_row(label, frame):
    """Build one LaTeX table row: the label followed by Code and NL percentages per category.

    `frame` must contain at least one row for every entry of `latex_table_categories`;
    the first matching row is used. All "%" characters are escaped for LaTeX.
    """
    cells = [label]
    for table_category in latex_table_categories:
        category_row = frame[frame["Category"] == table_category]
        for text_type in ("Code", "NL"):
            cells.append(f"{category_row[text_type].iloc[0]:.2%}")
    return (" & ".join(cells) + " \\\\").replace("%", "\\%")


# Load every "<category>_percents_<axis>.csv" file from `stats_dir`.
percent_frames = []
for file_name in sorted(os.listdir(stats_dir)):
    if not file_name.endswith(".csv"):
        continue
    name_parts = file_name[: -len(".csv")].split("_")
    if len(name_parts) < 2:
        continue
    category, metric = name_parts[0], name_parts[1]
    if "counts" in metric:
        continue

    current_frame = pd.read_csv(f"{stats_dir}/{file_name}")
    # An unnamed index column holds the training fraction; rename it before any
    # default column is added so the rename cannot create a duplicate column.
    if "Unnamed: 0" in current_frame.columns and "TrainingPercentage" not in current_frame.columns:
        current_frame = current_frame.rename(columns={"Unnamed: 0": "TrainingPercentage"})

    current_frame["Category"] = category
    current_frame["Metric"] = metric

    # Files without a training fraction describe final checkpoints.
    if "TrainingPercentage" not in current_frame.columns:
        current_frame["TrainingPercentage"] = 1

    if "Param Count" not in current_frame.columns:
        if "Model" in current_frame.columns:
            # Derive the size from the model label so sorting by "Param Count" is meaningful.
            current_frame["Param Count"] = current_frame["Model"].map(split_to_param_count)
        else:
            # Files without model information describe the 12b model.
            current_frame["Param Count"] = split_to_param_count["12b"]

    if "Model" not in current_frame.columns:
        current_frame["Model"] = current_frame["Param Count"].map(param_count_to_split)

    percent_frames.append(current_frame)

if not percent_frames:
    raise FileNotFoundError(f"No percent CSV files found in {stats_dir}")
combined_percents_frame = pd.concat(percent_frames)

final_checkpoints_nl_code_percents = combined_percents_frame[combined_percents_frame["TrainingPercentage"] == 1].sort_values("Param Count")

# Across Scale
for model_size in split_to_param_count:
    if model_size == "160m":
        continue

    model_frame = final_checkpoints_nl_code_percents[final_checkpoints_nl_code_percents["Model"] == model_size]
    print(_format_latex_row(model_size, model_frame))

# Across Time
# `sort_values` returns a new frame, so the filtered slice is never modified in place.
intermediate_frame_nl_code_percents = (
    combined_percents_frame[combined_percents_frame["Model"] == "12b"]
    .sort_values("TrainingPercentage", kind="stable")
)
unique_checkpoints = intermediate_frame_nl_code_percents["TrainingPercentage"].unique()
display(intermediate_frame_nl_code_percents)
display(unique_checkpoints)
for checkpoint in unique_checkpoints:
    model_frame = intermediate_frame_nl_code_percents[intermediate_frame_nl_code_percents["TrainingPercentage"] == checkpoint]
    print(_format_latex_row(f"{int(float(checkpoint) * 100)}%", model_frame))






# Feature Histograms

In [None]:
# Rows of the deduped 12b model at its final checkpoint, one row per sequence_id.
is_final_deduped_12b = (
    (combined_dataframe["Model"] == "12b")
    & (combined_dataframe["Checkpoint"] == "Final")
    & (combined_dataframe["Deduped"] == True)
)
hists_plotting_frame = combined_dataframe[is_final_deduped_12b].drop_duplicates(subset=["sequence_id"])
hists_plotting_frame

In [None]:
# Configuration of the feature histograms.
hist_splits = ["_deduped_12b"]

# Feature columns covered by the histogram configuration below.
features = [
    "sequence_duplicates",
    "max_frequency",
    "avg_frequency",
    "min_frequency",
    "median_frequency",
    "p25_frequency",
    "p75_frequency",
    "0_8_templates",
    "huffman_coding_length",
    "prompt_perplexity",
    "generation_perplexity",
    "sequence_perplexity",
    "0_8_snowclones",
    "loss",
]

# Feature column -> x-axis label. The key order of this mapping is the panel order.
name_map = {
    "sequence_duplicates": "Duplicates",
    "0_8_templates": "Textual Duplicates",
    "0_8_snowclones": "Semantic Duplicates",
    "prompt_perplexity": "Prompt PPL",
    "generation_perplexity": "Generation PPL",
    "sequence_perplexity": "Sequence PPL",
    "loss": "Loss",
    "max_frequency": "Max Token Freq.",
    "avg_frequency": "Mean Token Freq.",
    "min_frequency": "Min Token Freq.",
    "median_frequency": "Median Token Freq.",
    "p25_frequency": "P25 Token Freq.",
    "p75_frequency": "P75 Token Freq.",
    "huffman_coding_length": "Huffman Length",
}

# Number of histogram bins per feature (entries listed in panel order).
bins_per_feature = {
    "sequence_duplicates": 10,
    "0_8_templates": 20,
    "0_8_snowclones": 60,
    "prompt_perplexity": 100,
    "generation_perplexity": 35,
    "sequence_perplexity": 100,
    "loss": 50,
    "max_frequency": 15,
    "avg_frequency": 150,
    "min_frequency": 150,
    "median_frequency": 100,
    "p25_frequency": 100,
    "p75_frequency": 20,
    "huffman_coding_length": 60,  # plotted on a linear axis, unlike the other features
}

# Lower x-axis limit per feature.
min_threshold = {
    "sequence_duplicates": 10**0,
    "0_8_templates": 10**-0.6,
    "0_8_snowclones": 10**0,
    "prompt_perplexity": 10**0,
    "generation_perplexity": 10**0,
    "sequence_perplexity": 10**0,
    "loss": 10**-0.5,
    "max_frequency": 10**8.5,
    "avg_frequency": 10**8,
    "min_frequency": 10**4.5,
    "median_frequency": 10**6,
    "p25_frequency": 10**5,
    "p75_frequency": 10**6,
    "huffman_coding_length": 2,
}

# Upper x-axis limit per feature.
max_threshold = {
    "sequence_duplicates": 10**7,
    "0_8_templates": 10**4,
    "0_8_snowclones": 10**3.1,
    "prompt_perplexity": 10**1.5,
    "generation_perplexity": 10**1.1,
    "sequence_perplexity": 10**2.5,
    "loss": 10**0.5,
    "max_frequency": 10**10,
    "avg_frequency": 10**10,
    "min_frequency": 10**8,
    "median_frequency": 10**10,
    "p25_frequency": 10**9,
    "p75_frequency": 10**10,
    "huffman_coding_length": 6,
}

e = 1e-10  # small offset added before taking log10 of the smallest observed value
num_rows = 2
num_columns = 7
fig, axs = plt.subplots(num_rows, num_columns, figsize=(17, 5))
axs = axs.flatten()

# The memorized / not-memorized split does not depend on the feature, so do it once.
memorized_rows = hists_plotting_frame[hists_plotting_frame["Memorized"] == True]
unmemorized_rows = hists_plotting_frame[hists_plotting_frame["Memorized"] == False]

for i, split in enumerate(hist_splits):
    # `name_map` drives both the panel order and the progress-bar total.
    for j, fx in tqdm(enumerate(name_map.keys()), desc="Plotting Features", total=len(name_map)):
        ax = axs[i * num_columns + j]

        # Keep non-negative values only; the comparison also drops missing values.
        memorized_values = memorized_rows[fx]
        memorized_values = memorized_values[memorized_values >= 0].reset_index(drop=True)
        pile_values = unmemorized_rows[fx]
        pile_values = pile_values[pile_values >= 0].reset_index(drop=True)

        if fx == "huffman_coding_length":
            # Linear axis: an integer bin count lets the plotting call pick equal-width bins.
            bins = bins_per_feature[fx]
        else:
            # Log axis: log-spaced bin edges over the joint value range of both groups.
            # `e` keeps log10 finite when the smallest value is 0.
            lowest_value = min(memorized_values.min(), pile_values.min()) + e
            highest_value = max(memorized_values.max(), pile_values.max())
            bins = np.logspace(np.log10(lowest_value), np.log10(highest_value), bins_per_feature[fx])

        # Step outlines leave no whitespace between the bars of continuous features.
        sns.histplot(data=pile_values, bins=bins, label="Pile", ax=ax, stat="percent", element="step")
        sns.histplot(data=memorized_values, bins=bins, label="Memorized", ax=ax, stat="percent", element="step")

        is_linear_feature = fx == "huffman_coding_length"
        ax.set_xscale("linear" if is_linear_feature else "log")
        ax.yaxis.set_major_formatter(PercentFormatter(xmax=100, decimals=0))
        # NOTE(review): "loss" has min/max thresholds configured but no x limits are
        # applied to it — confirm this is intentional.
        if fx != "loss":
            ax.set_xlim(min_threshold[fx], max_threshold[fx])
        # The x locator is set after the scale, because changing the scale resets it.
        ax.xaxis.set_major_locator(MaxNLocator(nbins=3) if is_linear_feature else LogLocator(numticks=3))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=3))

        ax.tick_params(axis="both", labelsize=14)
        ax.set_xlabel(name_map[fx], fontsize=16)
        # Only the first panel of each row carries a y label.
        if j % num_columns == 0:
            ax.set_ylabel("Percentage", fontsize=16)
        else:
            ax.set_ylabel("")

fig.legend(labels=["Not Memorized", "Memorized"], loc="upper center", bbox_to_anchor=(0.5, -0.005), ncol=2, fontsize=18, frameon=False)
fig.align_xlabels()
plt.tight_layout()
fig.savefig(f"{figures_path}/histograms_percents.pdf", bbox_inches="tight")
plt.show()