In [1]:
from functools import partial
import hashlib

import datasets
from dotenv import load_dotenv
import tiktoken
import torch
import vec2text 

# Load variables from a local .env file (if present) into the process environment (python-dotenv).
load_dotenv()

def compute_cosine_similarity(embeddings1, embeddings2):
    """Return the row-wise cosine similarity of two embedding batches.

    Similarity is taken along dim 1, so two ``(batch, dim)`` inputs give a
    ``(batch,)`` tensor. Inputs must be broadcastable against each other.
    """
    functional = torch.nn.functional
    row_similarity = functional.cosine_similarity(embeddings1, embeddings2, dim=1)
    return row_similarity


  from .autonotebook import tqdm as notebook_tqdm


## Setup for error analysis

Let's take the first $n = 100$ rows (`N_SAMPLES`) of the precomputed MS MARCO val dataset, which the code loads from the `train` split of `jxm/msmarco__openai_ada2`.

In [2]:
# Keep only a small, fixed slice of rows: the inversion step further down is slow.
N_SAMPLES = 100
dataset = datasets.load_dataset("jxm/msmarco__openai_ada2")["train"].select(range(N_SAMPLES))

In [3]:
tokenizer = tiktoken.get_encoding("cl100k_base")
MAX_LENGTH = 128


def truncate_text(example):
    """Truncate every text in a batch to at most MAX_LENGTH cl100k_base tokens.

    Meant for ``Dataset.map(..., batched=True)``: ``example["text"]`` is a list
    of strings and is replaced by the truncated strings; ``example`` is returned.

    Special-token strings such as ``"<|endoftext|>"`` that occur inside a passage
    are encoded as ordinary text. With tiktoken's default
    (``disallowed_special="all"``) they would raise ``ValueError`` instead.
    Cutting at a token boundary can split a multi-byte UTF-8 character; tiktoken's
    default decoding (``errors="replace"``) then emits U+FFFD for the partial bytes.
    """
    token_lists = tokenizer.encode_batch(example["text"], disallowed_special=())
    truncated_token_lists = [tokens[:MAX_LENGTH] for tokens in token_lists]
    example["text"] = tokenizer.decode_batch(truncated_token_lists)
    return example

In [4]:
# Truncate all passages to MAX_LENGTH tokens; batched so tiktoken encodes many texts per call.
dataset = dataset.map(
    truncate_text,
    num_proc=12,
    batched=True,
    batch_size=1024,
)

Map (num_proc=12): 100%|██████████| 100/100 [00:00<00:00, 713.44 examples/s]


In [5]:
# Expects a single row per call (use with Dataset.map(..., batched=False)).
def get_text_hash(example):
    """Store the MD5 hex digest of the row's UTF-8 ``text`` under ``source_id``.

    Mutates ``example`` in place and returns it. Rows with identical text
    therefore get the same ``source_id``.
    """
    text_bytes = example["text"].encode("utf-8")
    digest = hashlib.md5(text_bytes)
    example["source_id"] = digest.hexdigest()
    return example
    

# Give each passage a content-hash id; get_text_hash handles one row at a time.
dataset = dataset.map(
    get_text_hash,
    num_proc=12,
    batched=False,
)

Map (num_proc=12): 100%|██████████| 100/100 [00:00<00:00, 715.05 examples/s]


In [6]:
# Each original passage is step 0 of its own trajectory, with similarity 1.0 to its
# own embedding. `sim` is a float column so it matches the cosine similarities that
# get_trajectory appends for later steps. Sized from the dataset itself rather than
# N_SAMPLES so it stays correct if the selection above changes.
n_rows = len(dataset)
dataset = dataset.add_column(name="step", column=[0] * n_rows)
dataset = dataset.add_column(name="sim", column=[1.0] * n_rows)

## Generating samples

In [7]:
# Load vec2text's pretrained corrector for "text-embedding-ada-002" embeddings.
# get_trajectory below uses it as a global.
corrector = vec2text.load_pretrained_corrector("text-embedding-ada-002")

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [8]:
# Compute the vec2text correction trajectory for each example (called below with n_steps=50).
# Works for any batch size: every example in the batch is inverted in turn.
def get_trajectory(n_steps, examples, sequence_beam_width=4):
    """Append the vec2text inversion trajectory of every example in a batch.

    Intended for ``Dataset.map(..., batched=True)``. For each row, the global
    ``corrector`` inverts the row's ``embeddings_A`` using ``n_steps`` correction
    steps. One new row is appended per hypothesis returned by
    ``vec2text.invert_embeddings_and_return_hypotheses``.

    Args:
        n_steps: Number of correction steps (passed to vec2text as ``num_steps``).
        examples: Batched mapping of column name -> list of values. Must have
            exactly the columns ``source_id``, ``text``, ``embeddings_A``,
            ``step`` and ``sim``.
        sequence_beam_width: Beam width passed through to vec2text. Defaults to
            4, the value this notebook originally hard-coded.

    Returns:
        Mapping of column name -> list holding the original rows followed by
        the new hypothesis rows. Each new row:
            - keeps the original ``source_id``;
            - gets ``step`` = 1..k in the order returned by vec2text;
            - gets ``sim`` = cosine similarity between its embedding and the
              original embedding.

    Raises:
        ValueError: If the batch's columns differ from those listed above. Any
            other column would come back shorter than the rest. Also raised if
            vec2text returns a different number of strings than embeddings.
    """
    expected_columns = {"source_id", "text", "embeddings_A", "step", "sim"}
    if set(examples.keys()) != expected_columns:
        raise ValueError(
            f"get_trajectory fills exactly the columns {sorted(expected_columns)}, "
            f"but the batch has {sorted(examples.keys())}"
        )

    new_examples = {k: [] for k in examples.keys()}

    for i, original_embedding in enumerate(examples["embeddings_A"]):
        # Add a leading batch dimension (assumes each stored embedding is a flat vector).
        original_embedding = torch.Tensor(original_embedding).cuda().unsqueeze(0)

        output_strings, hypothesis_embeddings = vec2text.invert_embeddings_and_return_hypotheses(
            original_embedding,
            corrector,
            num_steps=n_steps,
            sequence_beam_width=sequence_beam_width,
        )
        n_hypotheses = len(hypothesis_embeddings)
        if len(output_strings) != n_hypotheses:
            raise ValueError(
                f"vec2text returned {len(output_strings)} strings but "
                f"{n_hypotheses} embeddings for source_id {examples['source_id'][i]}"
            )

        # Append one row per hypothesis.
        new_examples["source_id"] += [examples["source_id"][i]] * n_hypotheses
        # Presumably one list of strings per step; [0] picks the entry for this
        # single-embedding batch.
        new_examples["text"] += [output[0] for output in output_strings]
        new_examples["embeddings_A"] += [emb.squeeze().tolist() for emb in hypothesis_embeddings]
        new_examples["step"] += range(1, n_hypotheses + 1)
        new_examples["sim"] += [
            compute_cosine_similarity(original_embedding, embedding).item()
            for embedding in hypothesis_embeddings
        ]

    return {k: examples[k] + new_examples[k] for k in examples.keys()}


In [9]:
# Invert (up to) the first 100 rows, recording N_STEPS correction steps per row.
# Very slow: the recorded run took ~73 s per example (~2 h for 100 rows).
# Capping with len(dataset) avoids an IndexError when fewer rows are loaded.
N_STEPS = 50
n_trajectory_rows = min(100, len(dataset))
test_dataset = dataset.select(range(n_trajectory_rows)).map(
    partial(get_trajectory, N_STEPS), batched=True, batch_size=1
)

Map: 100%|██████████| 100/100 [2:01:31<00:00, 72.91s/ examples]


In [110]:
# Order rows by step, breaking ties by source_id. One multi-key sort gives this order
# deterministically. Two chained .sort() calls keep the first ordering among ties only
# when the second sort is stable.
test_dataset = test_dataset.sort(["step", "source_id"])

In [10]:
# Lightweight CSV export without the embedding vectors; the cell displays the bytes written.
(
    test_dataset
    .remove_columns(["embeddings_A"])
    .to_csv("test_50_no_emb.csv", index=False)
)

Creating CSV from Arrow format: 100%|██████████| 6/6 [00:00<00:00, 180.96ba/s]


1849366

In [11]:
# Full export, including the embeddings_A vectors that the CSV export above drops.
test_dataset.to_parquet("test_50_emb.parquet")

Creating parquet from Arrow format: 100%|██████████| 6/6 [00:00<00:00, 30.20ba/s]


65762328

In [100]:
# List the distinct source passages present in the trajectory dataset.
# NOTE(review): the output below (10 ids) and the 120-row filter in the next cell
# appear to come from an earlier kernel session (execution counts 100/102/110 vs
# 1-11). They don't match the 100-example run above; re-run to refresh.
test_dataset.unique("source_id")

['8bd034ea81f91372874cf6d90dffbba1',
 '41ff020092780cbe3f0bb1a19af9a9bb',
 'b41462004c2f175c26b021580d52ebdb',
 'ee1c1f0fd5e5b4d8c57b5f7ce4a524c7',
 'd320cf0dec7398aff7157ae6bf50d95a',
 '43d708365012811206eae310f234d268',
 '0251de9b50ca73ac3f2f8d0b1d8f6b7b',
 'aab5f71e8417e4e22dde53048c8aee21',
 '85d7a4ce47403d7eb621eb2814069bec',
 'adc5c918b742688fa96a7da70fce56d4']

In [102]:
# Pull out every trajectory row (original + hypotheses) for one source passage.
example = test_dataset.filter(
    lambda row: row["source_id"] == "8bd034ea81f91372874cf6d90dffbba1"
)

Filter: 100%|██████████| 120/120 [00:00<00:00, 2324.35 examples/s]
