In [1]:
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset, get_dataset_split_names, DatasetDict
from tqdm import tqdm

# Register tqdm's `progress_apply` on pandas objects, then apply the seaborn
# theme: dark grid styling with the larger "talk" context.
tqdm.pandas()
sns.set_theme(context="talk", style="darkgrid")

  from .autonotebook import tqdm as notebook_tqdm


## Load Datasets

In [2]:
# Padding passed as `pad=` to `set_title` when titling the figures below.
label_title_padding = 10
# Number of rows to load from each split; None loads every split in full.
split_sample_size = None

In [3]:
memories_path = "usvsnsp/memories-semantic-memorization-filter-results"
memories_dataset = DatasetDict()

# Ask the hub for the split names once and keep only the ones that have
# "deduped" in the name.
splits = [split for split in get_dataset_split_names(memories_path) if "deduped" in split]
for split in tqdm(splits):
    # With a sample size configured, load only the leading slice of the split
    # (e.g. "name[:1000]"); otherwise load the split in full.
    split_spec = f"{split}[:{split_sample_size}]" if split_sample_size else split
    memories_dataset[split] = load_dataset(memories_path, split=split_spec)

memories_dataset

Downloading data: 100%|██████████| 1.18G/1.18G [03:15<00:00, 6.06MB/s]
Downloading data: 100%|██████████| 277M/277M [00:43<00:00, 6.30MB/s]
Downloading data: 100%|██████████| 571M/571M [01:33<00:00, 6.10MB/s]
Downloading data: 100%|██████████| 784M/784M [02:06<00:00, 6.22MB/s]
Downloading data: 100%|██████████| 422M/422M [01:10<00:00, 5.99MB/s]
Downloading data: 100%|██████████| 1.04G/1.04G [02:43<00:00, 6.36MB/s]
Downloading data: 100%|██████████| 162M/162M [00:22<00:00, 7.13MB/s]
Generating memories.deduped.1.4b split: 100%|██████████| 1048097/1048097 [00:01<00:00, 569707.43 examples/s]
Generating memories.deduped.12b split: 100%|██████████| 1871216/1871216 [00:03<00:00, 525746.21 examples/s]
Generating memories.deduped.160m split: 100%|██████████| 581195/581195 [00:00<00:00, 615243.95 examples/s]
Generating memories.deduped.1b split: 100%|██████████| 1032865/1032865 [00:01<00:00, 565019.29 examples/s]
Generating memories.deduped.2.8b split: 100%|██████████| 1355211/1355211 [00:02<00

DatasetDict({
    memories.deduped.70m: Dataset({
        features: ['sequence_id', 'text', 'sequence_duplicates', 'max_frequency', 'avg_frequency', 'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency', 'frequencies', 'is_incrementing', 'tokens', 'repeating_offset', 'num_repeating', 'smallest_repeating_chunk', 'memorization_score', 'templating_frequency_0.9', 'templating_frequency_0.8', 'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity'],
        num_rows: 411448
    })
    memories.deduped.160m: Dataset({
        features: ['sequence_id', 'text', 'sequence_duplicates', 'max_frequency', 'avg_frequency', 'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency', 'frequencies', 'is_incrementing', 'tokens', 'repeating_offset', 'num_repeating', 'smallest_repeating_chunk', 'memorization_score', 'templating_frequency_0.9', 'templating_frequency_0.8', 'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity'],
        num_rows: 581195
  

In [4]:
pile_path = "usvsnsp/pile-semantic-memorization-filter-results"
# The container must exist before the loop assigns into it; without this line
# the first `pile_dataset[split] = ...` raises NameError on a fresh kernel.
pile_dataset = DatasetDict()

# Ask the hub for the split names once and keep only the ones that have
# "deduped" in the name.
splits = [split for split in get_dataset_split_names(pile_path) if "deduped" in split]
for split in tqdm(splits):
    # With a sample size configured, load only the leading slice of the split
    # (e.g. "name[:1000]"); otherwise load the split in full.
    split_spec = f"{split}[:{split_sample_size}]" if split_sample_size else split
    pile_dataset[split] = load_dataset(pile_path, split=split_spec)

pile_dataset

Downloading readme: 100%|██████████| 2.48k/2.48k [00:00<00:00, 6.85MB/s]
Downloading data: 100%|██████████| 1.51G/1.51G [04:12<00:00, 5.95MB/s]
Downloading data: 100%|██████████| 1.51G/1.51G [04:04<00:00, 6.16MB/s]
Downloading data: 100%|██████████| 1.51G/1.51G [04:28<00:00, 5.61MB/s]
Downloading data: 100%|██████████| 1.51G/1.51G [03:59<00:00, 6.28MB/s]
  0%|          | 0/8 [2:15:32<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Model-size suffix of a split name -> parameter count.
# Insertion order runs from the smallest to the largest model; later cells
# reindex by these keys/values to order their plots.
split_to_param_count = {
    "70m": 70_000_000,
    "160m": 160_000_000,
    "410m": 410_000_000,
    "1b": 1_000_000_000,
    "1.4b": 1_400_000_000,
    "2.8b": 2_800_000_000,
    "6.9b": 6_900_000_000,
    "12b": 12_000_000_000,
}

In [None]:
columns_to_drop = ["frequencies", "tokens"]


def _split_to_frame(dataset, split, memorized):
    """Convert one dataset split into a pandas frame tagged with model metadata.

    Args:
        dataset: Mapping from split name to a dataset exposing `to_pandas()`.
        split: Split name; everything after the second "." is the model size
            (e.g. "memories.deduped.1.4b" -> "1.4b").
        memorized: Value stored in the "Memorized" column for every row.

    Returns:
        The split as a DataFrame without `columns_to_drop`, plus the columns
        "Model", "Param Count", "Deduped" and "Memorized".

    Raises:
        KeyError: If the model size is not a key of `split_to_param_count`.
    """
    # Derive the model name from the split name itself rather than from the
    # first row of the frame, so an empty split does not raise IndexError.
    model_name = ".".join(split.split(".")[2:])
    return (
        dataset[split]
        .to_pandas()
        .drop(columns=columns_to_drop)
        .assign(**{
            "Model": model_name,
            "Param Count": split_to_param_count[model_name],
            "Deduped": "deduped" in split,
            "Memorized": memorized,
        })
    )


# Collect every per-split frame first and concatenate once; concatenating
# inside the loop re-copies the accumulated frame on every iteration.
split_frames = [
    _split_to_frame(memories_dataset, split, True)
    for split in tqdm(memories_dataset, desc="Loading Memories")
]
split_frames += [
    _split_to_frame(pile_dataset, split, False)
    for split in tqdm(pile_dataset, desc="Loading Pile")
]

# Row labels are intentionally not reset, so they repeat across splits.
combined_dataframe = (
    pd.concat(split_frames)
    .sort_values("Param Count")
    .dropna(subset=["sequence_perplexity"])
)
display(combined_dataframe.shape)
combined_dataframe.head()

In [None]:
# Drop rows whose generation_perplexity is the sentinel value -1.
before_count = combined_dataframe.shape[0]
has_generation_perplexity = combined_dataframe["generation_perplexity"] != -1
# Take an explicit copy: the filtered frame is written to right below, and
# assigning into a boolean-indexed slice triggers SettingWithCopyWarning.
combined_dataframe = combined_dataframe[has_generation_perplexity].copy()
after_count = combined_dataframe.shape[0]
print(f"Dropped {before_count - after_count} rows with -1 generation_perplexity")

# Replace the -1 sentinel in num_repeating with 0.
combined_dataframe.loc[combined_dataframe["num_repeating"] == -1, "num_repeating"] = 0
display(combined_dataframe.value_counts("num_repeating").head())

### Assign Examples to Taxonomy

In [None]:
def get_category(row, recitation_threshold=200):
    """Assign one example to a taxonomy category.

    The checks run in priority order, so the first matching rule wins.

    Args:
        row: Mapping (DataFrame row or dict) with the keys "Memorized",
            "sequence_duplicates", "is_incrementing" and "num_repeating".
        recitation_threshold: Minimum `sequence_duplicates` for a memorized
            example to count as "Recitation". Defaults to 200.

    Returns:
        "Not Memorized", "Recitation", "Reconstruction" or "Recollection".
    """
    if not row["Memorized"]:
        return "Not Memorized"
    if row["sequence_duplicates"] >= recitation_threshold:
        return "Recitation"
    # Incrementing sequences, or ones with a non-zero repeating length, are
    # labelled "Reconstruction".
    if row["is_incrementing"] or row["num_repeating"] != 0:
        return "Reconstruction"

    return "Recollection"

# Label every row with its taxonomy category (row-wise, with a progress bar),
# then show how many examples land in each category per model.
combined_dataframe["category"] = combined_dataframe.progress_apply(get_category, axis=1)
combined_dataframe.value_counts(["Model", "Deduped", "category"])

# 12 deduped
# memories_frame = combined_dataframe[(combined_dataframe["Memorized"] == True) & (combined_dataframe["Deduped"] == True) & (combined_dataframe["Model"] == "12b")]
# memories_frame["category"] = memories_frame.progress_apply(lambda row: get_category(row), axis=1)
# memories_frame.value_counts("category")

In [None]:
# Load the code-vs-natural-language scores and keep the "train" split as a
# pandas frame.
code_path = "usvsnsp/pile-pythia-code-vs-nl-scores"
code_splits = load_dataset(code_path)
code_dataset = code_splits["train"].to_pandas()
code_dataset

In [None]:
# # Join combined_dataframe with code_dataset on sequence_id
# combined_dataframe = combined_dataframe.merge(code_dataset, on="sequence_id", how="inner")
# combined_dataframe["is_code"] = combined_dataframe["nl_score"] <= 0.45
# display(combined_dataframe.shape)
# combined_dataframe.head()

In [None]:
# box_plot_token_stats = []
# for param_count in tqdm(split_to_param_count.values()):
#     model_examples = combined_dataframe[combined_dataframe["Param Count"] == param_count]
#     box_plot_token_stats.append({
#         # "label": str(param_count),
#         "mean": model_examples["avg_frequency"].mean(),
#         "med": model_examples["median_frequency"].mean(),
#         "q1": model_examples["p25_frequency"].mean(),
#         "q3": model_examples["p75_frequency"].mean(),
#         "whislo": model_examples["min_frequency"].mean(),
#         "whishi": model_examples["max_frequency"].mean(),
#     })


## Plot Graphs

In [None]:
# Keep only the rows flagged as coming from a deduped split.
deduped_plotting_frame = combined_dataframe.loc[combined_dataframe["Deduped"].eq(True)]
deduped_plotting_frame

In [None]:
# Memorized subset of the deduped rows; this is what the category figures plot.
deduped_memories = deduped_plotting_frame.loc[deduped_plotting_frame["Memorized"].eq(True)]

In [None]:
# Share of each taxonomy category among every model's memorized, deduped
# samples (rows: models ordered by size, columns: categories).
# It is built here from names defined in earlier cells, so the cell also works
# after Restart & Run All; a bare `percents_frame` raised NameError because the
# name is first assigned in the figure cell further down.
percents_frame = (
    deduped_memories
    .groupby("Model")["category"]
    .value_counts(normalize=True)
    .unstack()
    .reindex(list(split_to_param_count))
)
percents_frame

## Figure: Trends in Categories Over Scale

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(20, 6))

# Left panel: number of memorized samples per category, against model size.
counts_frame = (
    deduped_memories
    .value_counts(["Param Count", "category"])
    .unstack()
    .reindex(list(split_to_param_count.values()))
)
counts_frame.plot.line(
    ax=axes[0],
    rot=0,
    ylabel="Count",
    marker="o",
    markersize=10,
    linewidth=4,
)

# Parameter counts span several orders of magnitude, so use a log x axis.
axes[0].set_xscale("log")
axes[0].set_title("Memorized Samples by Category", pad=label_title_padding)

# Right panel: share of each category within every model's memorized samples.
# This reads `deduped_memories` (defined in an earlier cell) instead of
# `plotting_frame`, which is only created by a later cell. Grouping also
# tolerates models with no rows, where `unique()[0]` raised IndexError.
percents_frame = (
    deduped_memories
    .groupby("Model")["category"]
    .value_counts(normalize=True)
    .unstack()
    .reindex(list(split_to_param_count))
)

# Normalized bars stacked by category, one bar per model, no gap between bars.
percents_frame.plot.bar(
    stacked=True,
    ax=axes[1],
    rot=0,
    width=1,
    ylabel="Percent",
)

axes[1].set_title("Memorized Samples by Category", pad=label_title_padding)
# Show fractions as whole percentages; round() avoids float noise such as
# 0.58 * 100 == 57.99999999999999 being truncated to 57.
axes[1].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}%".format(round(x * 100))))

# Drop the right-hand legend; both panels share the legend placed below.
axes[1].get_legend().remove()

# Common legend centered below the figure, without a legend box.
axes[0].legend(loc='upper center', bbox_to_anchor=(1, -0.125), ncol=4, frameon=False)

# Both panels share the same x label.
for ax in axes:
    ax.set_xlabel("Parameter Count")

# Save the figure as figure_categories_count_percents.png.
plt.savefig("figure_categories_count_percents.png", bbox_inches='tight')

## Combined Plots

In [None]:
# Columns available on the plotting frames at this point:
# ['sequence_id', 'sequence_duplicates', 'max_frequency', 'avg_frequency',
# 'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency',
# 'is_incrementing', 'repeating_offset', 'num_repeating',
# 'smallest_repeating_chunk', 'memorization_score',
# 'templating_frequency_0.9', 'templating_frequency_0.8',
# 'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity',
# 'Model', 'Param Count', 'Deduped', 'Memorized', 'category']

# Metric name -> subplot title. Uncomment an entry to add a row to the figure.
titles = {
    # Categorical
    "category": "Count of Memories by Taxonomical Category",
    # "sequence_duplicates": "Mean Duplication Per Example",
    # "is_incrementing": "Percent of Sequences That Are Incrementing",

    # # Length of repeating subsequences
    # "num_repeating": "Mean Token Length For Repeating Subsequences",

    # # Cosine Similarities
    # "templating_frequency_0.9": "Mean Number of Examples 0.9 Cosine Similarity To Each Example",
    # "templating_frequency_0.8": "Mean Number of Examples 0.8 Cosine Similarity To Each Example",

    # # Perplexity
    # "prompt_perplexity": "Mean Prompt Perplexity",
    # "sequence_perplexity": "Mean Sequence Perplexity",
    # "generation_perplexity": "Mean Generation Perplexity",

    # # Token frequencies
    # "token_frequency": "Mean Token Frequency Statistics",
    # "median_frequency": "Mean Median Frequency for All Unique Tokens in Each Sequence",
    # "avg_frequency": "Mean Average Frequency for All Unique Tokens in Each Sequence",
    # "p25_frequency": "Mean 25th Percentile Frequency for All Unique Tokens in Each Sequence",
    # "min_frequency": "Mean Minimum Frequency for All Unique Tokens in Each Sequence",
}

# One row per metric: the first column is the overall view, the second is the
# breakdown by category. squeeze=False keeps `axes` two-dimensional even when
# a single metric is enabled, so no placeholder "null" row (and the empty
# subplot it produced) is needed.
fig, axes = plt.subplots(len(titles), 2, figsize=(15, 7.5 * len(titles)), squeeze=False)

for row, metric in enumerate(tqdm(titles)):
    for column in [0, 1]:
        ax = axes[row, column]
        title_text = titles[metric]

        if metric == "token_frequency":
            sns.boxplot(
                data=deduped_plotting_frame,
                y="avg_frequency",
                x="Model",
                ax=ax,
                gap=0.5,
                hue="category" if column == 1 else "Memorized",
            )

        elif metric == "category":
            plotting_frame = deduped_plotting_frame[deduped_plotting_frame["Memorized"] == True]
            if column == 0:
                sns.histplot(
                    data=plotting_frame,
                    x="Model",
                    hue="category",
                    ax=ax,
                    multiple="stack",
                    stat="count",
                    common_norm=False,
                )
            else:
                title_text = title_text.replace("Count", "Percent")

                # Share of each category within every model's memorized samples.
                category_shares = (
                    plotting_frame
                    .groupby("Model")["category"]
                    .value_counts(normalize=True)
                    .unstack()
                )
                # Order the bars by model size; the default index order is
                # alphabetical ("1.4b", "12b", "160m", ...). Models without
                # rows are skipped.
                model_order = [model for model in split_to_param_count if model in category_shares.index]

                # Normalized bars stacked by category, one bar per model,
                # no space between bars.
                category_shares.reindex(model_order).plot.bar(
                    stacked=True,
                    ax=ax,
                    rot=0,
                    width=1,
                )

                ax.margins(x=0)

                # Render the y axis as whole percentages. A formatter keeps the
                # labels tied to the ticks, unlike fixed tick labels.
                ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda value, pos: f"{round(value * 100)}%"))
        else:
            sns.lineplot(
                data=deduped_plotting_frame.reset_index(),
                x="Param Count",
                y=metric,
                ax=ax,
                markers=True,
                hue="category" if column == 1 else "Memorized",
                marker="o",
            )

        # Log x axis for the line plots only.
        if metric not in ["category", "token_frequency"]:
            ax.set_xscale("log")

        ax.set_title(title_text)

        # The y label is the first word of the title ("Count", "Percent", "Mean", ...).
        ax.set_ylabel(title_text.split()[0])

        # Turn off scientific notation on the y axis where the formatter
        # supports it; the percent formatter above has no such switch.
        y_formatter = ax.yaxis.get_major_formatter()
        if hasattr(y_formatter, "set_scientific"):
            y_formatter.set_scientific(False)


# Add margins between rows.
plt.subplots_adjust(hspace=0.25)

# Save the figure.
plt.savefig("metrics_analysis.png", dpi=300, bbox_inches="tight")

In [None]:
# Export a random sample of the 12b model's "Recollection" examples to CSV.
recollection_sample_size = 500
recollection_sample_seed = 0  # fixed so a re-run exports the same rows

recollection_candidates = deduped_memories[
    (deduped_memories["Model"] == "12b") & (deduped_memories["category"] == "Recollection")
]
recollection_sample = (
    recollection_candidates
    # Cap the sample size at the number of candidates: sampling more rows than
    # exist (without replacement) raises ValueError.
    .sample(
        n=min(recollection_sample_size, len(recollection_candidates)),
        random_state=recollection_sample_seed,
    )
    # These columns are constant for this subset or are derived metadata.
    .drop(columns=["Model", "Param Count", "Deduped", "Memorized", "category"])
)
recollection_sample.to_csv("recollection_sample.csv", index=False)
recollection_sample


In [None]:

# Descriptive statistics of the three perplexity columns for the deduped 12b
# rows, grouped by category and transposed so each category is a column.
# No rounding is applied to the values.
stats_Frame = deduped_plotting_frame.loc[deduped_plotting_frame["Model"] == "12b"].copy()
perplexity_columns = ["prompt_perplexity", "generation_perplexity", "sequence_perplexity"]
stats = stats_Frame.groupby("category")[perplexity_columns].describe().transpose()
stats
