In [38]:
from functools import partial
from itertools import islice

import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from transformers import AutoTokenizer

def unstack_element(element, n_examples=None):
    """Convert a column-oriented mapping of sequences into per-row dicts.

    Example: ``{"title": ["a", "b"], "text": ["x", "y"]}`` yields
    ``{"title": "a", "text": "x"}`` and then ``{"title": "b", "text": "y"}``.

    Args:
        element: mapping from column name to an indexable sequence. Every
            column is expected to hold at least ``n_examples`` items.
        n_examples: number of rows to yield. Defaults to the length of the
            first column, or 0 when ``element`` has no columns.

    Yields:
        One dict per row, with the same keys as ``element``.

    Raises:
        Whatever indexing a column raises (typically ``IndexError`` when a
        column is shorter than ``n_examples``). The per-column lengths are
        printed first to help diagnose ragged input.
    """
    keys = list(element.keys())
    if n_examples is None:
        # Guard the empty mapping: `keys[0]` would raise IndexError otherwise.
        n_examples = len(element[keys[0]]) if keys else 0
    for i in range(n_examples):
        micro_element = {}
        for key in keys:
            try:
                micro_element[key] = element[key][i]
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
                # are not intercepted; the original error is still re-raised.
                print([(k, len(element[k])) for k in keys])
                raise
        yield micro_element

# Hugging Face checkpoint whose vocabulary tokenizes both queries and passages.
TOKENIZER_NAME = "bert-base-uncased"
# Module-level tokenizer, read as a global by `tokenize_examples`.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
def tokenize_examples(example, query_field="query", pos_field="positive_passages",
                      neg_field="negative_passages", n_pos=1, n_neg=9,
                      query_max_length=32, psg_max_length=128):
    """Tokenize one retrieval example into fixed-length query/passage ids.

    Uses the module-level ``tokenizer``. Each passage is rendered as
    ``title + " " + text``. The first ``n_pos`` positive passages come before
    the first ``n_neg`` negative passages.

    Args:
        example: a single (non-batched) example. ``example[pos_field]`` and
            ``example[neg_field]`` are column-oriented mappings with at least
            ``"title"`` and ``"text"`` columns (unstacked via
            ``unstack_element``).
        query_field, pos_field, neg_field: column names in ``example``.
        n_pos, n_neg: maximum number of positive / negative passages kept
            (``None`` keeps all of them).
        query_max_length, psg_max_length: pad/truncate lengths in tokens.

    Returns:
        dict with
          * ``query_input_ids``: array of shape ``(1, query_max_length)``.
          * ``psgs_input_ids``: array of shape
            ``(n_passages, 1, psg_max_length)``; the singleton axis is the one
            ``package`` removes with ``squeeze(-2)``.

    Raises:
        ValueError: if the example has no positive and no negative passages.

    NOTE(review): ``n_passages`` is ``min(#pos, n_pos) + min(#neg, n_neg)``.
    Examples with fewer passages produce a shorter first axis, and batching
    them with ``np.array`` in ``package`` then fails. Filter or pad such
    examples upstream.
    """
    query = example[query_field]

    def tok(texts, max_length):
        # padding="max_length" + truncation makes every row exactly
        # `max_length` ids long, so arrays from different examples line up.
        encoded = tokenizer(
            texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=False,
            return_token_type_ids=False,
            return_tensors="np",
        )
        return encoded["input_ids"]

    def render_passages(field, limit):
        # islice stops after `limit` rows instead of unstacking every passage.
        return [
            p["title"] + " " + p["text"]
            for p in islice(unstack_element(example[field]), limit)
        ]

    passages = render_passages(pos_field, n_pos) + render_passages(neg_field, n_neg)
    if not passages:
        # Same exception type np.stack raised before, but with a clear message.
        raise ValueError(
            f"example has no passages in {pos_field!r} or {neg_field!r}"
        )

    query_input_ids = tok(query, query_max_length)
    # One batched tokenizer call instead of one call per passage. The inserted
    # axis restores the (n_passages, 1, psg_max_length) layout that stacking
    # per-passage (1, psg_max_length) arrays used to produce.
    psgs_input_ids = tok(passages, psg_max_length)[:, None, :]

    return dict(query_input_ids=query_input_ids, psgs_input_ids=psgs_input_ids)

In [6]:

# A single shard (00000 of 00012) of the NQ training parquet files; the column
# names used below suggest DPR-style records.
p = "https://huggingface.co/datasets/iohadrubin/nq/resolve/main/data/train-00000-of-00012-aebee16ac9d5ed6f.parquet"
train_dataset = load_dataset(
    "parquet",
    data_files={"train": [p]},
    split="train",
)

# This dataset's column names differ from tokenize_examples' defaults, so bind
# them explicitly. All original columns are dropped, leaving only token ids.
nq_field_names = dict(
    query_field="question",
    pos_field="positive_ctxs",
    neg_field="hard_negative_ctxs",
)
train_data = train_dataset.map(
    partial(tokenize_examples, **nq_field_names),
    desc="Running tokenizer on train dataset",
    remove_columns=train_dataset.column_names,
    num_proc=20,
    batched=False,
)


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/294M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Running tokenizer on train dataset (num_proc=20):   0%|          | 0/4907 [00:00<?, ? examples/s]

In [66]:

class IterableDatasetWrapper(IterableDataset):
    """Endless iterable view over a dataset, reshuffled after every pass.

    The wrapped ``dataset`` must be iterable and provide ``shuffle()``, which
    returns a new dataset (as Hugging Face ``datasets.Dataset`` does). When used
    with a multi-worker ``DataLoader`` it must also provide
    ``shard(num_shards=..., index=...)``.

    The first pass yields items in the dataset's stored order. Each later pass
    iterates a freshly shuffled copy of this worker's shard.
    """

    def __init__(self, dataset):
        # Was `super(IterableDatasetWrapper).__init__()`, which builds an
        # unbound super object and never runs the parent initializer.
        super().__init__()
        self.dataset = dataset

    def _worker_shard(self):
        """Return the slice of ``self.dataset`` this DataLoader worker reads."""
        info = get_worker_info()
        if info is None or info.num_workers <= 1:
            return self.dataset
        # Every worker gets its own replica of this object. Without sharding,
        # each replica iterates the full dataset, so every example would be
        # yielded once per worker.
        return self.dataset.shard(num_shards=info.num_workers, index=info.id)

    def __iter__(self):
        data = self._worker_shard()
        while True:
            yield from data
            # Shuffle a local copy rather than reassigning self.dataset, so
            # iterating does not mutate the wrapper's state.
            data = data.shuffle()

def package(result):
    """Collate a list of example dicts into a dict of batched numpy arrays.

    Keys are taken from the first example. For each key, the per-example
    values are combined with ``np.array`` and then ``squeeze(-2)`` drops the
    second-to-last axis, which must have size 1.
    """
    return {
        name: np.array([row[name] for row in result]).squeeze(-2)
        for name in result[0]
    }
def get_dataloader(data, batch_size, num_workers=16, prefetch_factor=256):
    """Build an endless, reshuffling DataLoader over ``data``.

    Args:
        data: dataset accepted by ``IterableDatasetWrapper``.
        batch_size: number of examples per batch.
        num_workers: DataLoader worker processes (0 loads in the main process).
        prefetch_factor: batches pre-loaded per worker. It is only forwarded
            when ``num_workers > 0``, because DataLoader raises if it is set
            for single-process loading.

    Returns:
        ``torch.utils.data.DataLoader`` yielding dicts built by ``package``.
    """
    loader_kwargs = {}
    if num_workers > 0:
        # With the defaults, up to 16 * 256 = 4096 batches are prefetched in
        # total; lower this if host memory is tight.
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return DataLoader(
        IterableDatasetWrapper(data),
        batch_size=batch_size,
        # Pass the function itself: a lambda cannot be pickled when worker
        # processes are started with the "spawn" method.
        collate_fn=package,
        num_workers=num_workers,
        **loader_kwargs,
    )



In [67]:
# Tiny batches are enough to eyeball the collated tensor shapes.
dloader = get_dataloader(train_data, batch_size=2)

In [68]:
# Draw one batch from a fresh iterator over the DataLoader for inspection.
b = next(iter(dloader))

In [69]:
# Display the whole batch dict (numpy abbreviates the large passage array).
b

{'query_input_ids': array([[ 101, 2502, 2210, 3658, 2161, 1016, 2129, 2116, 4178,  102,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
        [ 101, 2040, 6369, 3403, 2005, 1037, 2611, 2066, 2017,  102,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0]]),
 'psgs_input_ids': array([[[  101,  2502,  2210, ...,   102,     0,     0],
         [  101,  2210,  2111, ..., 18868,  1010,   102],
         [  101, 15883,  2007, ...,  2285,  2418,   102],
         ...,
         [  101,  5487, 20996, ...,  1000,  1012,   102],
         [  101,  2502,  2567, ...,  1015,  1012,   102],
         [  101,  2129,  1045, ...,  2544,  1997,   102]],
 
        [[  101,  3403,  2005, ...,  2316,  1005,   102],
         [  101,  3403,  2005, ...,  2051,  1000,   102],
         [  101,  3403,  2005, ...,  228

In [58]:
# Keys come from the dicts returned by tokenize_examples.
b.keys()

dict_keys(['query_input_ids', 'psgs_input_ids'])

In [61]:
# Query ids: one row per batch element, padded/truncated to 32 tokens.
b["query_input_ids"]

array([[ 101, 2502, 2210, 3658, 2161, 1016, 2129, 2116, 4178,  102,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [ 101, 2040, 6369, 3403, 2005, 1037, 2611, 2066, 2017,  102,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0]])

In [60]:
# (batch, passages per example, tokens per passage): 1 positive + up to 9
# negatives per example, each padded/truncated to 128 tokens.
b["psgs_input_ids"].shape

(2, 10, 128)