In [None]:
from typing import Any, Callable, List, Optional

import torch
from indxr import Indxr
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm


In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a primary indexer plus optional side indexers.

    Samples are read from ``main``. If a ``callback`` is supplied, each sample
    is passed through it before being returned, which lets the caller enrich
    the sample with records looked up in ``others``.

    Args:
        main: Collection the dataset iterates over; it must support
            ``main[index]`` and ``len(main)``.
        others: Extra indexers handed to ``callback`` unchanged, as a single
            list, in the order given here.
        callback: Callable applied to every sample. It is called as
            ``callback(sample, others)`` when ``others`` is non-empty and as
            ``callback(sample)`` otherwise.
    """

    def __init__(
        self,
        main: Optional[Indxr] = None,
        others: Optional[List[Indxr]] = None,
        callback: Optional[Callable] = None,
    ):
        self.main = main
        self.others = others
        self.callback = callback

    def __getitem__(self, index: int) -> Any:
        """Return the ``index``-th sample so that ``dataset[i]`` works.

        The return type is whatever ``callback`` produces (or the raw record
        from ``main`` when no callback is set), hence ``Any`` rather than
        ``str``.
        """
        sample = self.main[index]

        if not self.callback:
            return sample

        # Truthiness check on purpose: both ``None`` and an empty list mean
        # "no side indexers", in which case the callback gets only the sample.
        if self.others:
            return self.callback(sample, self.others)

        return self.callback(sample)

    def __len__(self) -> int:
        """Return the number of samples, so that ``len(dataset)`` works."""
        return len(self.main)


In [None]:
# One Indxr per JSONL file: queries (a), users (b) and documents (c).
a = Indxr(
    kind="jsonl",
    path="data/queries.jsonl",
    key_id="q_id",
)
b = Indxr(
    kind="jsonl",
    path="data/users.jsonl",
    key_id="u_id",
)
c = Indxr(
    kind="jsonl",
    path="data/docs.jsonl",
    key_id="d_id",
)

In [None]:
def do_something(query, others):
    """Expand a query record into the texts it refers to.

    Args:
        query: Mapping with the keys ``"text"``, ``"pos_doc_ids"``,
            ``"neg_doc_ids"`` and ``"user_id"``.
        others: Two-element sequence ``(users, docs)``. ``users`` must offer
            ``get(id)`` and ``docs`` must offer ``mget(ids)``.

    Returns:
        A 4-tuple ``(query_text, pos_texts, neg_texts, user_texts)`` where the
        last three items are lists holding the ``"text"`` field of the
        positive, negative and user-associated documents respectively.
    """
    users, docs = others

    def fetch_texts(doc_ids):
        # Bulk-fetch the documents and keep only their text field.
        return [record["text"] for record in docs.mget(doc_ids)]

    positives = fetch_texts(query["pos_doc_ids"])
    negatives = fetch_texts(query["neg_doc_ids"])

    # The user record lists the ids of the documents tied to that user.
    profile = users.get(query["user_id"])
    history = fetch_texts(profile["doc_ids"])

    return query["text"], positives, negatives, history


# Queries drive the iteration; users and docs are handed to the callback,
# which unpacks them in exactly this order.
dataset = Dataset(main=a, others=[b, c], callback=do_something)

In [None]:
%%timeit
# Time one indexed fetch of the first sample, including the callback's work.
x = dataset[0]

In [None]:
# Shuffled batches of 32 samples, loaded by a single worker process.
# (prefetch_factor is not set, so the DataLoader default applies.)
train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=1)

# Drain the loader once; batches are discarded, only the progress bar is of interest.
for _ in tqdm(train_dataloader):
    pass