In [1]:
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset, get_dataset_split_names, DatasetDict
from tqdm import tqdm

sns.set_theme(style="darkgrid")
sns.set_context("talk")
tqdm.pandas()

  from .autonotebook import tqdm as notebook_tqdm


## Load Datasets

In [2]:
split_sample_size = 10000

In [3]:
memories_path = "usvsnsp/memories-semantic-memorization-filter-results"
get_dataset_split_names(memories_path)
memories_dataset = DatasetDict()

# get splits that have deduped in the name
splits = [split for split in get_dataset_split_names(memories_path) if "deduped" in split]
for split in tqdm(splits):
    memories_dataset[split] = load_dataset(memories_path, split=f"{split}[:{split_sample_size}]" if split_sample_size else split)

memories_dataset

  0%|          | 0/8 [00:00<?, ?it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--memories-semantic-memorization-filter-results-7ad10bc8c7f6aa70/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
 12%|█▎        | 1/8 [00:00<00:03,  1.96it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--memories-semantic-memorization-filter-results-7ad10bc8c7f6aa70/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
 25%|██▌       | 2/8 [00:01<00:03,  1.87it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--memories-semantic-memorization-filter-results-7ad10bc8c7f6aa70/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
 38%|███▊      | 3/8 [00:01<00:02,  1.93it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--memories-semantic-memorization-filter-r

DatasetDict({
    memories.deduped.1.4b: Dataset({
        features: ['sequence_id', 'text', 'sequence_duplicates', 'max_frequency', 'avg_frequency', 'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency', 'frequencies', 'is_incrementing', 'tokens', 'repeating_offset', 'num_repeating', 'smallest_repeating_chunk', 'memorization_score', 'templating_frequency_0.9', 'templating_frequency_0.8', 'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity'],
        num_rows: 10000
    })
    memories.deduped.12b: Dataset({
        features: ['sequence_id', 'text', 'sequence_duplicates', 'max_frequency', 'avg_frequency', 'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency', 'frequencies', 'is_incrementing', 'tokens', 'repeating_offset', 'num_repeating', 'smallest_repeating_chunk', 'memorization_score', 'templating_frequency_0.9', 'templating_frequency_0.8', 'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity'],
        num_rows: 10000
    

In [4]:
pile_path = "usvsnsp/pile-semantic-memorization-filter-results"
get_dataset_split_names(pile_path)
pile_dataset = DatasetDict()

splits = [split for split in get_dataset_split_names(pile_path) if "deduped" in split]
for split in tqdm(splits):
    pile_dataset[split] = load_dataset(pile_path, split=f"{split}[:{split_sample_size}]" if split_sample_size else split)

pile_dataset

  0%|          | 0/8 [00:00<?, ?it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--pile-semantic-memorization-filter-results-e8ad7274ba998093/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
 12%|█▎        | 1/8 [00:00<00:03,  1.98it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--pile-semantic-memorization-filter-results-e8ad7274ba998093/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
 25%|██▌       | 2/8 [00:01<00:03,  1.68it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--pile-semantic-memorization-filter-results-e8ad7274ba998093/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
 38%|███▊      | 3/8 [00:01<00:02,  1.83it/s]Found cached dataset parquet (/home/kyle/.cache/huggingface/datasets/usvsnsp___parquet/usvsnsp--pile-semantic-memorization-filter-results-e8ad7274b

DatasetDict({
    pile.deduped.1.4b: Dataset({
        features: ['sequence_id', 'text', 'sequence_duplicates', 'max_frequency', 'avg_frequency', 'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency', 'frequencies', 'is_incrementing', 'tokens', 'repeating_offset', 'num_repeating', 'smallest_repeating_chunk', 'memorization_score', 'templating_frequency_0.9', 'templating_frequency_0.8', 'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity'],
        num_rows: 10000
    })
    pile.deduped.12b: Dataset({
        features: ['sequence_id', 'text', 'sequence_duplicates', 'max_frequency', 'avg_frequency', 'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency', 'frequencies', 'is_incrementing', 'tokens', 'repeating_offset', 'num_repeating', 'smallest_repeating_chunk', 'memorization_score', 'templating_frequency_0.9', 'templating_frequency_0.8', 'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity'],
        num_rows: 10000
    })
    p

In [5]:
split_to_param_count = {
    "70m": 70000000,
    "160m": 160000000,
    "410m": 410000000,
    "1b": 1000000000,
    "1.4b": 1400000000,
    "2.8b": 2800000000,
    "6.9b": 6900000000,
    "12b": 12000000000,
}

In [6]:
combined_dataframe = None
for split in tqdm(memories_dataset, desc="Loading Memories"):
    current_frame = memories_dataset[split].to_pandas()
    current_frame.drop(columns=["text", "frequencies", "tokens"], inplace=True)
    current_frame["Model"] = ".".join(split.split(".")[2:])
    current_frame["Param Count"] = split_to_param_count[current_frame["Model"].iloc[0]]
    current_frame["Deduped"] = "deduped" in split
    current_frame["Memorized"] = True
    if combined_dataframe is None:
        combined_dataframe = current_frame
    else:
        combined_dataframe = pd.concat([combined_dataframe, current_frame])

for split in tqdm(pile_dataset, desc="Loading Pile"):
    current_frame = pile_dataset[split].to_pandas()
    current_frame.drop(columns=["text", "frequencies", "tokens"], inplace=True)
    current_frame["Model"] = ".".join(split.split(".")[2:])
    current_frame["Param Count"] = split_to_param_count[current_frame["Model"].iloc[0]]
    current_frame["Deduped"] = "deduped" in split
    current_frame["Memorized"] = False
    combined_dataframe = pd.concat([combined_dataframe, current_frame])

display(combined_dataframe.shape)
combined_dataframe.head()

Loading Memories: 100%|██████████| 8/8 [00:00<00:00, 30.43it/s]
Loading Pile: 100%|██████████| 8/8 [00:00<00:00, 30.24it/s]


(160000, 22)

Unnamed: 0,sequence_id,sequence_duplicates,max_frequency,avg_frequency,min_frequency,median_frequency,p25_frequency,p75_frequency,is_incrementing,repeating_offset,...,memorization_score,templating_frequency_0.9,templating_frequency_0.8,prompt_perplexity,generation_perplexity,sequence_perplexity,Model,Param Count,Deduped,Memorized
0,21590,55,11740996961,937904100.0,3053059,277329702.0,20962725,395603541,True,0,...,1.0,22,130,1.598633,1.00293,1.603516,1.4b,1400000000,True,True
1,30252,21482,10346382453,2780063000.0,1869557,385281005.0,13592032,695610999,True,0,...,1.0,696,2333,1.178711,1.0,1.178711,1.4b,1400000000,True,True
2,35232,21829,11740996961,2526616000.0,860666,42752068.5,6514834,1502731047,False,0,...,1.0,116,389,1.164062,1.013672,1.178711,1.4b,1400000000,True,True
3,62350,10,9362638615,839352900.0,397964,16284663.0,4176692,918861018,False,0,...,1.0,1,1,1.933594,1.035156,2.001953,1.4b,1400000000,True,True
4,75902,1,11740996961,980826700.0,783155,216143855.5,43385287,909893795,False,0,...,1.0,18,329,1.779297,1.005859,1.791016,1.4b,1400000000,True,True


In [7]:
# drop cases where generation_perplexity is -1
before_count = combined_dataframe.shape[0]
combined_dataframe = combined_dataframe[combined_dataframe["generation_perplexity"] != -1]
after_count = combined_dataframe.shape[0]
print(f"Dropped {before_count - after_count} rows with -1 generation_perplexity")

# set num_repeating = 0 if -1
combined_dataframe.loc[combined_dataframe["num_repeating"] == -1, "num_repeating"] = 0
display(combined_dataframe.value_counts("num_repeating").head())

Dropped 13 rows with -1 generation_perplexity


num_repeating
0     154873
2       1093
4        754
32       497
6        418
Name: count, dtype: int64

### Assign Examples to Taxonomy

In [8]:
def get_category(row):
    if row["Memorized"] == False:
        return "Not Memorized"
    if row["sequence_duplicates"] >= 200:
        return "Recitation"
    if row["is_incrementing"] or row["num_repeating"] != -1:
        return "Reconstruction"

    return "Recollection"

combined_dataframe["category"] = combined_dataframe.progress_apply(lambda row: get_category(row), axis=1)
combined_dataframe.value_counts("category")

100%|██████████| 159987/159987 [00:01<00:00, 90960.90it/s] 


category
Not Memorized     80000
Reconstruction    47014
Recitation        32973
Name: count, dtype: int64

In [9]:
# code_path = "usvsnsp/pile-pythia-code-vs-nl-scores"
# code_dataset = load_dataset(code_path)["train"].to_pandas()
# code_dataset

In [10]:
# # Join combined_dataframe with code_dataset on sequence_id
# combined_dataframe = combined_dataframe.merge(code_dataset, on="sequence_id", how="inner")
# combined_dataframe["is_code"] = combined_dataframe["nl_score"] <= 0.45
# display(combined_dataframe.shape)
# combined_dataframe.head()

In [11]:
# box_plot_token_stats = []
# for param_count in tqdm(split_to_param_count.values()):
#     sub_plots = []
#     for is_memorized in [True, False]:
#         model_examples = combined_dataframe[(combined_dataframe["Param Count"] == param_count) & (combined_dataframe["Memorized"] == is_memorized)]
#         sub_plots.append({
#             "mean": model_examples["avg_frequency"].mean(),
#             "med": model_examples["median_frequency"].mean(),
#             "q1": model_examples["p25_frequency"].mean(),
#             "q3": model_examples["p75_frequency"].mean(),
#             "whislo": model_examples["min_frequency"].mean(),
#             "whishi": model_examples["max_frequency"].mean(),
#         })

#     box_plot_token_stats.append(sub_plots)

# box_plot_token_stats

In [12]:
box_plot_token_stats = []
for param_count in tqdm(split_to_param_count.values()):
    model_examples = combined_dataframe[combined_dataframe["Param Count"] == param_count]
    box_plot_token_stats.append({
        # "label": str(param_count),
        "mean": model_examples["avg_frequency"].mean(),
        "med": model_examples["median_frequency"].mean(),
        "q1": model_examples["p25_frequency"].mean(),
        "q3": model_examples["p75_frequency"].mean(),
        "whislo": model_examples["min_frequency"].mean(),
        "whishi": model_examples["max_frequency"].mean(),
    })


box_plot_token_stats

100%|██████████| 8/8 [00:00<00:00, 237.89it/s]


[{'mean': 1825873775.0383341,
  'med': 217947783.51687837,
  'q1': 25506369.640428085,
  'q3': 1794763340.2740047,
  'whislo': 6669234.314462893,
  'whishi': 10927637798.817913},
 {'mean': 1824578740.2041776,
  'med': 203097249.5979647,
  'q1': 23469151.846376956,
  'q3': 1787233613.706706,
  'whislo': 5526651.434615192,
  'whishi': 10989805343.758564},
 {'mean': 1811900453.8840518,
  'med': 186002567.02240223,
  'q1': 20547499.620562058,
  'q3': 1782834454.339934,
  'whislo': 3538349.2325732573,
  'whishi': 10995104127.622612},
 {'mean': 1802332541.0913756,
  'med': 174383212.50712535,
  'q1': 19302464.214660734,
  'q3': 1755750327.5931296,
  'whislo': 2970367.6941847093,
  'whishi': 11016705581.756739},
 {'mean': 1799504678.3529613,
  'med': 176575392.96699834,
  'q1': 19529202.6879844,
  'q3': 1760173906.0016,
  'whislo': 3108316.172358618,
  'whishi': 11006584245.78804},
 {'mean': 1802351611.881844,
  'med': 171248812.59442973,
  'q1': 18499750.483724188,
  'q3': 1763076621.183109,

In [75]:
combined_dataframe = combined_dataframe.sort_values("Param Count").dropna(subset=["sequence_perplexity"])

## Plot Graphs

In [90]:
"""
['sequence_id', 'sequence_duplicates', 'max_frequency', 'avg_frequency',
'min_frequency', 'median_frequency', 'p25_frequency', 'p75_frequency',
'is_incrementing', 'repeating_offset', 'num_repeating',
'smallest_repeating_chunk', 'memorization_score',
'templating_frequency_0.9', 'templating_frequency_0.8',
'prompt_perplexity', 'generation_perplexity', 'sequence_perplexity',
'Model', 'Param Count', 'Deduped', 'Memorized', 'category']
"""

titles = {
    # Categorical
    "category": "Count of Memories by Taxonomical Category",
    "sequence_duplicates": "Mean Duplication Per Example",
    "is_incrementing": "Percent of Sequences That Are Incrementing",

    # Length of repeating subsequences
    "num_repeating": "Mean Token Length For Repeating Subsequences",

    # Cosine Similarities
    "templating_frequency_0.9": "Mean Number of Examples 0.9 Cosime Similarity To Each Example",
    "templating_frequency_0.8": "Mean Number of Examples 0.8 Cosime Similarity To Each Example",

    # Perplexity
    # "prompt_perplexity": "Mean Prompt Perplexity",
    # "sequence_perplexity": "Mean Sequence Perplexity",
    # "generation_perplexity": "Mean Generation Perplexity",

    # Token frequencies
    # "token_frequency": "Mean Token Frequency Statistics",
    "median_frequency": "Mean Median Frequency for All Unique Tokens in Each Sequence",
    "avg_frequency": "Mean Average Frequency for All Unique Tokens in Each Sequence",
    "p25_frequency": "Mean 25th Percentile Frequency for All Unique Tokens in Each Sequence",
    "min_frequency": "Mean Minimum Frequency for All Unique Tokens in Each Sequence",
}

# create subplots where each metric is on its own row. The first column is fo rmemorized overall and the second is broken down by category.
fig, axes = plt.subplots(len(titles), 2, figsize=(30, 15 * len(titles)))

for metric in tqdm(titles):
    for column in [0, 1]:
        title_text = titles[metric]

        if metric == "token_frequency":
            sns.boxplot(
                data=combined_dataframe,
                y="avg_frequency",
                x="Model",
                ax=axes[list(titles.keys()).index(metric), column],
                gap=0.5,
                hue="category" if column == 1 else "Memorized",
            )

        elif metric == "category":
            if column == 0:
                # histogram of the count of each category by model
                sns.histplot(
                    data=combined_dataframe[combined_dataframe["Memorized"] == True],
                    x="Model",
                    hue="category",
                    ax=axes[list(titles.keys()).index(metric), column],
                    multiple="stack",
                    stat="count",
                    common_norm=False,
                )

            else:
                fig.delaxes(axes[list(titles.keys()).index(metric), column])

        else:
            sns.lineplot(
                data=combined_dataframe,
                x="Param Count",
                y=metric,
                ax=axes[list(titles.keys()).index(metric), column],
                hue="category" if column == 1 else "Memorized",
            )

        # log x axis if line plot
        if metric not in ["category", "token_frequency"]:
            axes[list(titles.keys()).index(metric), column].set_xscale("log")

        # set title
        axes[list(titles.keys()).index(metric), column].set_title(title_text)

        # make title bold
        axes[list(titles.keys()).index(metric), column].title.set_weight("bold")

        # set x label based off the title
        quant_metic = title_text.split()[0]
        axes[list(titles.keys()).index(metric), column].set_ylabel(quant_metic)


# add margins between rows
plt.subplots_adjust(hspace=0.25)

 44%|████▍     | 4/9 [00:15<00:21,  4.34s/it]