### Imports

In [3]:
!pip install datasets transformers torch  # install required packages
!pip install transformers faiss-cpu

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import math
from datasets import load_dataset
import torch.utils.data as d
from transformers import BertTokenizer, BertModel
from torchtext.data import get_tokenizer
from torch import nn as nn
from typing import List, Tuple, Any, Optional
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import normalize




### Parameters

In [4]:
# Data-loading configuration used by the cells below.
BATCH_SIZE: int = 16   # articles (sequences) per batch
NUM_BATCHES: int = 10  # number of batches to draw from the corpus
HIDDEN_DIM: int = 768  # expected width of each token embedding produced by the encoder
SEQ_LEN: int = 20      # tokens per article (intended for give_dataloader's seq_len)

### Data similarity batching for Lory

From the Lory paper:


"We adapt the pipeline of in-context pre-training (Shi et al., 2024) in our approach. Given a set of documents $D$, for each document $d \in D$, we first use Contriever (Izacard et al., 2022) to retrieve top-$k$ most similar documents $N(d)$. The similarity between the document $d_i$ and $d_j$ is defined as the cosine similarity of their Contriever embeddings, i.e., $\mathrm{sim}(d_i, d_j) = \cos(C(d_i), C(d_j))$, where $C$ denotes the Contriever encoder model. We implement an efficient approximate nearest-neighbors search based on the FAISS library (Johnson et al., 2019). Then, we sort all the documents according to the similarity and construct training instances by batch consecutive documents. We use the same greedy algorithm as Shi et al. (2024). We start from a single document and repeatedly add the document that has the highest similarity value and has not been added to the list; we restart the process with a new document if all documents that are connected to the last document of the list are selected. We repeat this process until there are no documents left."

### Facebook Contriever as tokenizer and encoder

In [5]:
class MyIterableDataset(d.IterableDataset):
    """Stream encoder token embeddings for a selection of dataset articles.

    Each yielded item is the encoder's ``last_hidden_state`` for one article,
    tokenized to exactly ``seq_len`` tokens (padded / truncated). It is yielded
    as a numpy array of shape ``[seq_len, hidden_size]``.

    Args:
        dataset: Indexable collection whose items are mappings with a ``"text"``
            key (e.g. a Hugging Face ``Dataset`` split).
        tokenizer: Hugging Face-style tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        model: Encoder whose output exposes ``last_hidden_state``.
        seq_len: Tokens per article after padding/truncation.
        article_indices: Sized, indexable sequence of positions in ``dataset``
            to encode, in order. With several ``DataLoader`` workers the
            positions are split into contiguous shards, one per worker.
    """

    def __init__(self, dataset, tokenizer, model, seq_len, article_indices):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.model = model
        self.seq_len = seq_len
        self.article_indices = article_indices

    def _embed(self, article):
        """Tokenize one article and return its token embeddings as a numpy array."""
        tokenized = self.tokenizer(article, padding='max_length', truncation=True, max_length=self.seq_len, return_tensors='pt')
        with torch.no_grad():
            outputs = self.model(tokenized['input_ids'], attention_mask=tokenized['attention_mask'])
            # Drop the batch dimension of size 1; .cpu() keeps this valid for GPU-resident models.
            return outputs.last_hidden_state.squeeze(0).cpu().numpy()

    def _shard_bounds(self):
        """Return the half-open [start, end) slice of ``article_indices`` this worker handles."""
        total = len(self.article_indices)
        worker_info = d.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator handles every position.
            return 0, total
        per_worker = int(math.ceil(total / float(worker_info.num_workers)))
        start = worker_info.id * per_worker
        return start, min(start + per_worker, total)

    def __iter__(self):
        start, end = self._shard_bounds()
        # start/end index into article_indices, so each slot is mapped through it to
        # obtain the real dataset position; offset or non-contiguous selections work.
        return (self._embed(self.dataset[self.article_indices[pos]]["text"]) for pos in range(start, end))

def give_dataloader(development=True, batch_size=64, seq_len=20, num_batches=None, model_name='facebook/contriever'):
    """Build a DataLoader of encoder token embeddings for Wikipedia articles.

    Args:
        development: If True, use the ``20220301.simple`` Wikipedia config;
            otherwise ``20220301.en``.
        batch_size: Articles per yielded batch.
        seq_len: Tokens per article after padding/truncation.
        num_batches: If given, only the first ``num_batches * batch_size``
            articles (capped at the split size) are encoded; ``None`` encodes
            every article in the split.
        model_name: Hugging Face id of the tokenizer/encoder pair to load.

    Returns:
        A ``DataLoader`` over ``MyIterableDataset``. The identity ``collate_fn``
        keeps each batch as a plain list of per-article embedding arrays.
    """
    config = "20220301.simple" if development else "20220301.en"
    wiki_huggingface_dataset = load_dataset("wikipedia", config)["train"]

    num_articles = wiki_huggingface_dataset.num_rows
    if num_batches is not None:
        # Cap at the split size so iteration never indexes past the last article.
        num_articles = min(num_batches * batch_size, num_articles)
    article_indices = range(num_articles)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The full encoder is required: it turns token ids into the embeddings used for similarity search.
    model = AutoModel.from_pretrained(model_name).eval()

    ds = MyIterableDataset(wiki_huggingface_dataset, tokenizer, model, seq_len, article_indices=article_indices)
    return d.DataLoader(ds, batch_size=batch_size, collate_fn=lambda x: x)


data_loader = give_dataloader(development=True, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, num_batches=NUM_BATCHES)

# Materialise the batches once: the iterable dataset defines no len(), and every
# pass over the loader re-runs the encoder on every article.
embedded_batches = list(data_loader)
sample = embedded_batches[0]
num_batches = len(embedded_batches)

# Print the shape of the sample batch to verify the embedding dimensions
print("Number of batches in the dataloader:", num_batches)
print("Batch Size:", len(sample))
print("Seq Length:", len(sample[0]))
print("Embedding size (hidden size):", len(sample[0][0]))
assert sample[0].shape == (SEQ_LEN, HIDDEN_DIM), sample[0].shape


Access to the secret `HF_TOKEN` has not been granted on this notebook.
You will not be requested again.
Please restart the session if you want to be prompted again.


Downloading builder script:   0%|          | 0.00/36.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

The repository for wikipedia contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/wikipedia.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/134M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/205328 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Number of batches in the dataloader: 10
Batch Size: 16
Seq Length: 20
Embedding size (hidden size): 768


In [8]:
class SequenceDataset(Dataset):
    """Map-style dataset that serves items straight from an in-memory sequence.

    Args:
        sequences: Any sized, indexable container (e.g. a list of per-sequence
            embedding arrays). It is stored as-is, without copying, and items
            are returned unchanged.
    """

    def __init__(self, sequences):
        self.sequences = sequences

    def __getitem__(self, idx):
        items = self.sequences
        return items[idx]

    def __len__(self):
        items = self.sequences
        return len(items)

def _greedy_similarity_batches(embeddings, neighbours, num_batches, batch_size):
    """Group row indices of ``embeddings`` into batches of mutually similar rows.

    Greedy chaining: start from the lowest unvisited index, then keep appending
    the unvisited neighbour with the highest dot product to the last index added.
    A batch closes when it is full or when that index has no unvisited neighbours.
    Once ``num_batches`` batches exist (``None`` = no limit), the unreached
    indices, in ascending order, first top up short batches and then open new
    batches up to the limit.

    Args:
        embeddings: 2-D array of L2-normalised row vectors.
        neighbours: Mapping from each row index to the set of its neighbour indices.
        num_batches: Maximum number of batches, or ``None`` for no limit.
        batch_size: Target number of indices per batch (must be >= 1).

    Returns:
        List of batches, each a list of row indices; no index appears twice.
    """
    limit = float("inf") if num_batches is None else num_batches
    num_rows = len(embeddings)
    visited = set()
    batches = []

    for start in range(num_rows):
        if len(batches) >= limit:
            break
        if start in visited:
            continue
        batch = [start]
        visited.add(start)
        while len(batch) < batch_size:
            last = batch[-1]
            candidates = neighbours[last] - visited
            if not candidates:
                break  # every neighbour of `last` is already taken: close this batch
            best = max(candidates, key=lambda j: float(np.dot(embeddings[last], embeddings[j])))
            batch.append(best)
            visited.add(best)
        batches.append(batch)

    # Indices the greedy pass never reached are less similar; use them to fill
    # short batches first, then to open new ones while the limit allows.
    leftovers = [i for i in reversed(range(num_rows)) if i not in visited]  # pop() yields ascending order
    for batch in batches:
        while len(batch) < batch_size and leftovers:
            batch.append(leftovers.pop())
    while leftovers and len(batches) < limit:
        batches.append([leftovers.pop() for _ in range(min(batch_size, len(leftovers)))])
    return batches


def reorganize_dataloader(dataloader, num_batches, batch_size, k=10):
    """Regroup sequence embeddings so that each batch holds mutually similar sequences.

    This approximates the Lory / in-context pre-training batching. Every sequence
    embedding is flattened and L2-normalised, and its top-``k`` neighbours are
    found with FAISS. Batches are then grown greedily by appending the most
    similar unvisited neighbour of the last sequence added. Batches that stall
    are topped up with leftover (less similar) sequences.

    Args:
        dataloader: Iterable of batches; each batch is an iterable of per-sequence
            embedding arrays. It is iterated exactly once.
        num_batches: Maximum number of batches in the result (``None`` = no limit).
        batch_size: Target number of sequences per batch (must be >= 1).
        k: Neighbours retrieved per sequence (the sequence itself included);
            clamped to the number of sequences.

    Returns:
        A ``DataLoader`` whose batches are exactly the similarity batches computed
        here, so a short batch stays short instead of absorbing the next one.
        Sequences that do not fit into ``num_batches`` batches are dropped.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Keep the original per-sequence arrays; only flattened copies go to FAISS.
    sequences = [sequence for batch in dataloader for sequence in batch]
    if not sequences:
        return DataLoader(SequenceDataset([]), batch_sampler=[])

    # Flatten each embedding to one vector and L2-normalise it. For unit vectors,
    # L2 nearest neighbours are exactly the cosine-similarity nearest neighbours.
    flat = np.stack([np.asarray(sequence).reshape(-1) for sequence in sequences])
    embeddings = np.ascontiguousarray(normalize(flat, axis=1), dtype=np.float32)  # FAISS expects float32

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    # Clamp k, because FAISS pads missing results with label -1 when k exceeds
    # the index size. Any remaining negative labels are dropped, otherwise -1
    # would act as a Python index for the last sequence.
    _, neighbour_ids = index.search(embeddings, min(k, len(sequences)))
    neighbours = {i: {int(j) for j in row if j >= 0} for i, row in enumerate(neighbour_ids)}

    batches = _greedy_similarity_batches(embeddings, neighbours, num_batches, batch_size)

    # batch_sampler yields exactly the computed batches. Re-chunking a flattened
    # list with batch_size would merge neighbouring batches whenever one is short.
    return DataLoader(SequenceDataset(sequences), batch_sampler=batches)



### Testing

In [None]:
# Create the dataloader and materialise its batches once: the iterable dataset
# defines no len(), and every pass over the loader re-runs the encoder.
data_loader = give_dataloader(development=True, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, num_batches=NUM_BATCHES)
embedded_batches = list(data_loader)
sample = embedded_batches[0]
num_batches = len(embedded_batches)

# Print the shape of the sample batch to verify the embedding dimensions
print("Number of batches in the dataloader:", num_batches)
print("Batch Size:", len(sample))
print("Seq Length:", len(sample[0]))
print("Embedding size (hidden size):", len(sample[0][0]))

# Reorganize the cached batches (any iterable of batches is accepted). The source
# batch size is reused so all NUM_BATCHES * BATCH_SIZE sequences fit; a smaller
# batch_size would silently drop the surplus.
reorganized_dataloader = reorganize_dataloader(embedded_batches, num_batches=NUM_BATCHES, batch_size=BATCH_SIZE, k=10)

# Verify the reorganized dataloader; it is map-style, so len() needs no extra pass.
sample = next(iter(reorganized_dataloader))
num_batches = len(reorganized_dataloader)

# Print the shape of the sample batch to verify the embedding dimensions
print("Number of batches in the reorganized dataloader:", num_batches)
print("Batch Size reorganized:", len(sample))
print("Seq Length reorganized:", len(sample[0]))
print("Embedding size (hidden size) reorganized:", len(sample[0][0]))

### Lory