In [None]:
from transformers import pipeline, AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import numpy as np
import torch
import os

tqdm.pandas()

In [None]:
class HFMemoriesDataset(Dataset):
    is_dataframe = False

    def __init__(self, memories, tokenizer, sample=None):
        self.tokenizer = tokenizer
        self.memories = memories
        if sample is not None:
            self.memories = self.memories.to_pandas().sample(sample)
            self.is_dataframe = True

    def __getitem__(self, index):
        memory_record = (
            self.memories.iloc[index] if self.is_dataframe else self.memories[index]
        )
        decoded_text = self.tokenizer.decode(memory_record["tokens"])
        return decoded_text

    def __len__(self):
        return len(self.memories)


def load_tokenizer(split_name):
    isDeduped = split_name.startswith("deduped")
    model = split_name.split("duped.")[-1]
    corresponding_model = f"EleutherAI/pythia-{model}{'-deduped' if isDeduped else ''}"
    tokenizer =  AutoTokenizer.from_pretrained(corresponding_model)
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def load_model(split_name):
    isDeduped = split_name.startswith("deduped")
    model = split_name.split("duped.")[-1]
    corresponding_model = f"EleutherAI/pythia-{model}{'-deduped' if isDeduped else ''}"
    device = torch.device("cuda:7")
    return GPTNeoXForCausalLM.from_pretrained(corresponding_model).eval().to(device)


def calculate_perplexity(logits, labels):
    # Store the probabilities for each token. These will be summed later, but having the
    # individual probabilities is helpful for debugging.
    token_probs = []

    # Don't include the final token logits. There are no labels for
    # these since the sequence has ended.
    shifted_logits = logits.detach()[:-1, :]

    for token_index in range(len(shifted_logits)):
        # Map the logits to probabilities.
        predicted_probs = torch.softmax(shifted_logits[token_index], dim=0)
        # Get the probability of the correct label.
        label_prob = predicted_probs[labels[token_index + 1]]
        # Store the probability for this token.
        token_probs.append(label_prob.detach())

    # Caluclate the log-likelyhood of the sequence by summing the probabilities
    # of each token and then taking the log.
    log_likelihood = torch.log(torch.stack(token_probs)).sum()

    # Caluclate the cross entropy by dividing the negative log-likelihood by the number of tokens.
    cross_entropy = -log_likelihood / len(shifted_logits)

    # Calculate the perplexity by taking the exponential of the cross entropy.
    perplexity = torch.exp(cross_entropy).item()
    return perplexity

In [None]:
split_name = "duped.70m"
memories = load_dataset("EleutherAI/pythia-memorized-evals")[split_name]
tokenizer = load_tokenizer(split_name)
memories_dataset = HFMemoriesDataset(
    load_dataset("EleutherAI/pythia-memorized-evals")[split_name], 
    tokenizer)

pythia_model = load_model(split_name)

In [None]:
data_loader = DataLoader(memories_dataset, batch_size=32)
hf_perplexities = []
all_perplexities = []

with torch.no_grad():
    for batch in [memories_dataset[0]]:
        tokenized_batch = tokenizer(
            batch, return_tensors="pt", max_length=512, truncation=True, padding=True
        )
        tokenized_batch.to(torch.device("cuda:7"))
        labels = tokenized_batch["input_ids"]

        outputs = pythia_model(**tokenized_batch, labels=labels)
        logits = outputs.logits.detach()
        hf_perplexities += [torch.exp(outputs.loss).item()]

        all_perplexities += [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))]

print(hf_perplexities)
print(all_perplexities)

## Analyze Forgotten Sequences

In [13]:
pile_deduped = load_dataset("EleutherAI/the_pile_deduplicated", data_files="the_pile_deduplicated/data/train-00000-of-01650-f70471ee3deb09c0")
pile_deduped

In [6]:
np.load("/mnt/ssd-1/data/pile_20B_tokenizer/pile_20B_tokenizer_text_document_valid_indexmap_193280ns_2048sl_1234s_doc_idx.npy")

array([210157511, 210161511, 209560440, ..., 209717117, 210003242,
       210038150], dtype=int32)

In [5]:
import numpy as np
np.fromfile("/mnt/ssd-1/data/pile_20B_tokenizer/pile_20B_tokenizer_text_document.bin", dtype=np.int64, count=1000)

array([    3668700955018363,   400257481914056989,  3969363441657577948,
          86695379471547791,    92045187114467833,    92043541450926309,
       -4372995173464932099,   840784584587936071,   127228270056112141,
           3959854180925721,  1375298886691944441,    90074036973207837,
           4315080783298807,   128644853326872763,    86985461560967252,
          79097693980656073,    86975832272013576,   267122175300600265,
         275282583673635862,    71223489932099848,   277826225948334803,
           4229076281589776,   257838693656232123,   194074544027926715,
         890588411811725565,  8741487963370750533,    79096933787181756,
          52636623804762518,   161567963777695210,   211670355204768015,
          86985066425682110,     3664702338567174,   459653396829206163,
        3926014148261904726,   835699264731676688,   245749327074246168,
        1840001978743980307,     3948187357741488,    86975832349937055,
         228559009100595569,   415218600494437200, 

In [19]:
n = 100
with open("/mnt/ssd-1/data/pile_20B_tokenizer/pile_20B_tokenizer_text_document.bin", mode='rb') as file: # b is important -> binary
    for i in range(n):
        print(file.readline(n))
        fileContent = next(file)
        # print(fileContent)

b'{\x046\x01\xaa\x08\r\x00\x1d\x01.$\x0f\x00\x8e\x05\xdc\x01l\x04\x05\x03\x167\x8f\xbd\x11\x01\xfd\x004\x01\xf9\x01>)\x9c\x02G\x01\xe5(\r\x00\x1d\x01G\x01\xfd\x00 \x11\x0f\x00P\xc3G\x01\xfd\x00 \x11\xab\x0b\r\x00\x15\x02p\x01\xc4\x01\x19\x01\xd1ew\x11\x0e\x00\xf9wC\x01\r\x0b\x16\x13\x1d\x01F\x01\xdc\x01@\x01\xf7\x00D\t'
b'\x19\x01;\n'
b'q\x01\xa1\x01V\x01\xf0\x025\x01\x06\x0c\x19\x01\xea\n'
b'\x0f\x01\x05\x035\x02\x7f\xb4\xa8\x08\x9c\x02V\x9c\x1d\x01k\xba+\x04`\x01\xe6\x15\x0f\x00\xbb\x00\xbb\x00\x10\x066\x01\xf7\x00g\x07\r\x00\x11\x01\xea\x08\r\x00\x16\x04\x83\x02\x810\x82\x01\x9f\x02>\x02\x19\x01\x96\t\x0f\x00\xdc\x04\xd1\x7f\x97\x01\xfd\x00\xfc..$\r\x00\xa1\x01v\x04\x87\x1c\x19\x01\x15\x03\xc3\x05`\x01\x0f\x005\x01\xc9\x01E\x00'
b'\x19\x175\x01M\x07\x8b\x02\x15\x032\x02\x0f\x00\xbb\x00\xbb\x00\x94\x0c\xe6\x0cv\x0bq\x01\x19\x01\xea\n'
b'\x00\x15\x025\x01\xdb\x11\xc9\x01U\x00\x86\x04\x11\x01OH9\x01@\x04\x1b\x00\xfc\x0f\n'
b"w#P\x13}\x0f\xcd\x03\xcb\x02\xfd\x00\xe6\x15\x1e'\x0f\x00\x0c