In [None]:
from typing import Callable, Dict

import torch

from indxr import Indxr


class Dataset(torch.utils.data.Dataset):
    """Dataset whose samples are read from a single ``Indxr`` index.

    Args:
        indxr_kwargs: Keyword arguments forwarded unchanged to ``Indxr``.
        callback: Optional callable applied to each fetched sample before it
            is returned.
    """

    def __init__(
        self,
        indxr_kwargs: Dict,
        callback: Callable = None,
    ):
        self.callback = callback
        self.indxr_args = indxr_kwargs
        self.main_index = Indxr(**self.indxr_args)

    def __getitem__(self, index: int) -> str:
        """Return the sample at ``index``, routed through ``callback`` when one is set."""
        sample = self.main_index[index]
        return self.callback(sample) if self.callback else sample

    def __len__(self) -> int:
        """Return the number of samples held by the underlying index."""
        return len(self.main_index)


In [None]:
# Build the dataset from a JSONL file; `indxr_kwargs` is forwarded to `Indxr`
# by `Dataset.__init__`. No callback is supplied here.
dataset = Dataset(
    indxr_kwargs={
        "kind": "jsonl",
        "path": "tests/test_data/sample.jsonl",
    }
)

In [None]:
# Fetch the first sample; without a callback the item comes straight from the index.
dataset[0]

In [None]:
from typing import Callable, Dict

import torch

from indxr import Indxr


class Dataset(torch.utils.data.Dataset):
    """Dataset backed by an ``Indxr`` index built from ``kwargs``.

    Args:
        kwargs: Keyword arguments forwarded to ``Indxr`` to build the index.
            When ``None``, no index is created and ``main_index`` is ``None``.
        callback: Optional callable applied to each fetched sample before it
            is returned.
    """

    def __init__(
        self,
        kwargs: Dict = None,
        callback: Callable = None,
    ):
        self.kwargs = kwargs
        self.callback = callback
        # `__getitem__` and `__len__` read `self.main_index`, so it has to be
        # assigned here; previously it was never set and every access failed
        # with AttributeError.
        self.main_index = Indxr(**kwargs) if kwargs is not None else None

    def _require_index(self):
        """Return the index, or raise if the dataset was built without one."""
        if self.main_index is None:
            raise RuntimeError(
                "Dataset has no index: pass `kwargs` so an Indxr can be built."
            )
        return self.main_index

    # Support indexing such that dataset[i] can be used to get i-th sample
    def __getitem__(self, index: int) -> str:
        """Return the sample at ``index``, routed through ``callback`` when one is set.

        Raises:
            RuntimeError: If the dataset was constructed without ``kwargs``.
        """
        sample = self._require_index()[index]

        if self.callback:
            return self.callback(sample)

        return sample

    # This allows to call len(dataset) to get the dataset size
    def __len__(self) -> int:
        """Return the number of samples in the index.

        Raises:
            RuntimeError: If the dataset was constructed without ``kwargs``.
        """
        return len(self._require_index())


In [None]:
# NOTE(review): the `Dataset` class defined in the cell above accepts only
# `kwargs` and `callback`; `indexes=` is not part of that signature.
# NOTE(review): `do_something` is defined in the cell that follows this one,
# so on a fresh kernel that cell has to run first.
dataset = Dataset(
    indexes=[
        Indxr("queries.jsonl", key_id="q_id"),
        Indxr("users.jsonl", key_id="u_id"),
        Indxr("docs.jsonl", key_id="d_id"),
    ],
    callback=do_something,
)

In [None]:
def do_something(query, others=None):
    """Assemble the text fields that belong to one query record.

    Args:
        query: Mapping with the keys ``"text"``, ``"user_id"``,
            ``"pos_doc_ids"`` and ``"neg_doc_ids"``.
        others: Optional ``(users, docs)`` pair of lookup objects exposing
            ``get(id)``. When omitted, the module-level name ``others`` is
            used, which is what the function relied on before.

    Returns:
        Tuple ``(query_text, pos_docs, neg_docs, user_docs)`` where the last
        three items are lists of document texts.

    Raises:
        NameError: If ``others`` is omitted and no module-level ``others``
            exists.
    """
    if others is None:
        # Backward-compatible fallback to the global the original code read.
        try:
            others = globals()["others"]
        except KeyError:
            raise NameError("name 'others' is not defined") from None
    users, docs = others

    # Look documents up one id at a time.
    pos_docs = [docs.get(doc_id)["text"] for doc_id in query["pos_doc_ids"]]

    # The negative ids are stored on the query record. The previous code read
    # them from `user` before `user` was assigned, which raised
    # UnboundLocalError on every call.
    neg_docs = [docs.get(doc_id)["text"] for doc_id in query["neg_doc_ids"]]

    user = users.get(query["user_id"])
    user_docs = [docs.get(doc_id)["text"] for doc_id in user["doc_ids"]]

    return query["text"], pos_docs, neg_docs, user_docs

In [1]:
from typing import Callable, List

import torch

from indxr import Indxr

In [2]:

class Dataset(torch.utils.data.Dataset):
    """Dataset that reads samples from a main index and can consult extra ones.

    Args:
        main: Index that samples are read from; its length is the dataset size.
        others: Optional additional indexes handed to ``callback`` as its
            second argument.
        callback: Optional callable applied to each fetched sample.
    """

    def __init__(
        self,
        main: Indxr = None,
        others: List[Indxr] = None,
        callback: Callable = None,
    ):
        self.main, self.others, self.callback = main, others, callback

    def __getitem__(self, index: int) -> str:
        """Return sample ``index``, transformed by ``callback`` when one is set."""
        sample = self.main[index]

        # No callback: hand back the raw record.
        if not self.callback:
            return sample

        # The callback receives the extra indexes only when some were given.
        if self.others:
            return self.callback(sample, self.others)
        return self.callback(sample)

    def __len__(self) -> int:
        """Return the number of records in the main index."""
        return len(self.main)


In [3]:
# Open one JSONL-backed index per file. Later cells use `a` (queries) as the
# dataset's main index and pass `b` (users) and `c` (docs) as `others`.
a = Indxr(kind="jsonl", path="data/queries.jsonl", key_id="q_id")
b = Indxr(kind="jsonl", path="data/users.jsonl", key_id="u_id")
c = Indxr(kind="jsonl", path="data/docs.jsonl", key_id="d_id")

In [4]:
def do_something(query, others):
    """Collect a query's text plus the texts of its related documents.

    Args:
        query: Record with the keys ``"text"``, ``"user_id"``,
            ``"pos_doc_ids"`` and ``"neg_doc_ids"``.
        others: ``(users, docs)`` pair of lookup objects exposing ``get(id)``.

    Returns:
        Tuple ``(query_text, pos_docs, neg_docs, user_docs)``; the last three
        items are lists of document texts.
    """
    users, docs = others

    def texts_for(doc_ids):
        # One `get` per id; a batched `docs.mget(doc_ids)` variant was
        # sketched as an alternative.
        return [docs.get(doc_id)["text"] for doc_id in doc_ids]

    pos_docs = texts_for(query["pos_doc_ids"])
    neg_docs = texts_for(query["neg_doc_ids"])

    # The user's own documents are reached through the user record.
    user = users.get(query["user_id"])
    user_docs = texts_for(user["doc_ids"])

    return query["text"], pos_docs, neg_docs, user_docs


# def do_something(query, others):
#     users, docs = others

#     pos_doc_ids = query["pos_doc_ids"]
#     neg_doc_ids = query["neg_doc_ids"]

#     user = users.get(query["user_id"])
#     user_doc_ids = user["doc_ids"]

#     docs = docs.mget(pos_doc_ids + neg_doc_ids + user_doc_ids)
#     docs = [doc["text"] for doc in docs]

#     pos_docs = docs[: len(pos_doc_ids)]
#     neg_docs = docs[len(pos_doc_ids) : len(pos_doc_ids) + len(neg_doc_ids)]
#     user_docs = docs[len(pos_doc_ids) + len(neg_doc_ids) :]

#     # assert len(pos_docs) == len(pos_doc_ids), "pos_docs"
#     # assert len(neg_docs) == len(neg_doc_ids), "neg_docs"
#     # assert len(user_docs) == len(user_doc_ids), "user_docs"

#     return query["text"], pos_docs, neg_docs, user_docs


# Wire the indexes together: samples come from `a` (queries), while `b` (users)
# and `c` (docs) are handed to the callback as `others`, in the order in which
# `do_something` unpacks them.
dataset = Dataset(
    main=a,
    others=[
        b,
        c,
    ],
    callback=do_something,
)


In [6]:
%%timeit
# Time one sample fetch through the dataset, including the `do_something` callback.
x = dataset[0]

1.41 ms ± 47.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [7]:
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Serve the dataset in shuffled batches of 32. With `num_workers=0` PyTorch
# loads batches in the main process (see the DataLoader documentation); the
# `prefetch_factor` option is left commented out.
train_dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    # prefetch_factor=4,
)

In [8]:
# Drain the loader once, discarding every batch — presumably to gauge
# iteration speed from the progress bar. The recorded run was interrupted.
for x in tqdm(train_dataloader):
    continue

KeyboardInterrupt: 

In [45]:
%%timeit
# Time a single lookup on the queries index alone (no dataset callback involved).
asd = a[345]

36.5 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [5]:
import linecache
import json

In [51]:
%%timeit
# Baseline: fetch one line through `linecache` and parse it with `json`.
# NOTE(review): `linecache.getline` numbers lines from 1, so line 345 may not
# be the same record as `a[345]` — confirm before comparing the timings.
asd = json.loads(linecache.getline("data/queries.jsonl", 345))

10.1 µs ± 94.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [6]:
# `linecache` line numbers start at 1, so asking for line 0 yields an empty string.
linecache.getline("data/queries.jsonl", 0)

''

In [53]:
# Show the record stored at index 345 of the queries index.
a[345]

{'q_id': 'q345',
 'text': 'lorem ipsum',
 'user_id': 'u3',
 'pos_doc_ids': ['d2440828',
  'd2350610',
  'd2622392',
  'd155325',
  'd3129857',
  'd341540',
  'd3524478',
  'd4897201',
  'd3832800',
  'd830828'],
 'neg_doc_ids': ['d4599337',
  'd1752145',
  'd4820849',
  'd3226581',
  'd2328669',
  'd3295127',
  'd2117101',
  'd2760129',
  'd1658716',
  'd4754674',
  'd1891680',
  'd2588664',
  'd4291885',
  'd1363998',
  'd572726',
  'd4319392',
  'd4200941',
  'd2031437',
  'd1719342',
  'd2097254',
  'd706280',
  'd1093021',
  'd4621948',
  'd4939106',
  'd2100583',
  'd3055014',
  'd2763698',
  'd4939144',
  'd806842',
  'd1400714',
  'd52605',
  'd2196060',
  'd1295359',
  'd4122526',
  'd682028',
  'd4690378',
  'd4346522',
  'd4215622',
  'd1214980',
  'd4150395',
  'd4678115',
  'd467257',
  'd2845576',
  'd2011460',
  'd4185745',
  'd3300502',
  'd2154419',
  'd4323957',
  'd211592',
  'd772400',
  'd964586',
  'd1795364',
  'd417819',
  'd2848940',
  'd3955915',
  'd3631490',


In [12]:
from indxr import Indxr
import json
import orjson

In [2]:
# Open the queries index; the following cells use its `index` mapping for raw
# seek/readline benchmarks.
a = Indxr(kind="jsonl", path="data/queries.jsonl", key_id="q_id")

In [5]:
# Keep the first 100 values of `a.index`; the benchmark cells pass each one to
# `file.seek`, i.e. they are used as offsets into the JSONL file.
positions = list(a.index.values())[:100]

In [9]:
%%timeit

# Pre-size the buffer so the loop only assigns by position.
lines = [None] * len(positions)

# Jump to each stored offset and read one raw line. The file is opened inside
# the timed body, so the cost of `open` is part of the measurement.
with open("data/queries.jsonl", "rb") as file:
    for i, position in enumerate(positions):
        file.seek(position)
        lines[i] = file.readline()

# Decode every line with the standard-library `json` parser.
res = [json.loads(line) for line in lines]

1.32 ms ± 95.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
%%timeit

# Pre-size the buffer so the loop only assigns by position.
lines = [None] * len(positions)

# Same seek/readline loop as the `json` benchmark; the file is opened inside
# the timed body.
with open("data/queries.jsonl", "rb") as file:
    for i, position in enumerate(positions):
        file.seek(position)
        lines[i] = file.readline()

# Only the decoder differs from the previous benchmark: `orjson` instead of `json`.
res = [orjson.loads(line) for line in lines]

684 µs ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [10]:
# Open the file once so the next cell can reuse the handle — presumably to
# time seek/readline without the cost of `open`.
# NOTE(review): this handle is never closed in the visible cells.
file = open("data/queries.jsonl", "rb")

In [11]:
%%timeit

# Allocate the buffer inside the cell, as the sibling benchmarks do. The cell
# previously wrote into a `lines` list it never created, so it only ran when
# that name was already present in the kernel.
lines = [None] * len(positions)

# `file` is the handle opened in the previous cell and is reused across runs.
for i, position in enumerate(positions):
    file.seek(position)
    lines[i] = file.readline()

# Decode every line with the standard-library `json` parser.
res = [json.loads(line) for line in lines]

1.23 ms ± 9.75 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
