In [1]:
import os

from dotenv import load_dotenv
from pymongo import MongoClient

# Load connection settings from the env file one directory above the working
# directory, then open the MongoDB collection used by the rest of the notebook.
# NOTE(review): the path is "../env", not "../.env" — confirm this is the intended file.
load_dotenv("../env")
# Index the environment directly so a missing MONGO_DB_CONNECTION raises a
# KeyError here. With os.environ.get() the value would be None, and
# MongoClient(None) falls back to its default host instead of failing.
client = MongoClient(os.environ["MONGO_DB_CONNECTION"])
collection = client.get_database("prismai").get_collection("collected_items")

In [2]:
from itertools import batched

import datasets
from datasets import Dataset
from tqdm.auto import tqdm

# Turn off the `datasets` library's own progress bars for this kernel session
# (presumably so per-batch filter/map bars do not flood the output of the
# scoring loop, which already reports progress through tqdm).
datasets.disable_progress_bars()

In [3]:
from transition_scores.scorer import OnnxTransitionScorer
from transition_scores.data import CustomTokenizer, RollingWindowChunkTokenizer

# Build the ONNX-backed transition scorer from the exported GPT-2 model
# directory. It runs on the GPU with a batch size of 4 and tokenizes its input
# with a RollingWindowChunkTokenizer loaded from the "gpt2" checkpoint.
scorer = OnnxTransitionScorer(
    "/hot_storage/models/onnx/gpt2_onnx_o4/",
    device="cuda",
    batch_size=4,
    tokenizer=RollingWindowChunkTokenizer.from_pretrained("gpt2"),
)

[0;93m2025-01-22 18:57:48.099021612 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 24 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.[m
[0;93m2025-01-22 18:57:48.103104602 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-01-22 18:57:48.103111104 [W:onnxruntime:, session_state.cc:1170 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m


In [4]:
# Dry-run the scorer over every document in the collection.
#
# The progress bar wraps the cursor, so it advances once per document, while
# `batched` groups the stream into tuples of up to 512 documents that are
# turned into a `Dataset` and handed to the scorer. The yielded scores are
# discarded: this cell only exercises the pipeline end to end.
total = collection.count_documents({})

# Both the cursor and the progress bar are context managers. Closing them via
# `with` releases the server-side cursor and finalises the bar even when the
# loop is interrupted or raises part-way through.
with collection.find(
    projection=[
        "text",
        "chunks",
    ],
    batch_size=128,
) as cursor:
    with tqdm(cursor, total=total) as tq:
        for batch in batched(tq, 512):
            # `_id` is stringified before building the Dataset.
            dataset = Dataset.from_list([dd | {"_id": str(dd["_id"])} for dd in batch])
            # Keep only rows with non-empty text and chunks; return a real bool
            # as the filter contract expects.
            dataset = dataset.filter(lambda x: bool(x["text"] and x["chunks"]))
            if len(dataset) == 0:
                # Nothing left to score in this batch.
                continue
            for scores in scorer.process(dataset=dataset, top_k=100):
                pass

  0%|          | 0/802852 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1078 > 1024). Running this sequence through the model will result in indexing errors


KeyboardInterrupt: 

In [None]:
# Spot-check a single document: print its full text next to the concatenation
# of its chunks so the two can be compared by eye, then display the raw record.
one = collection.find_one({"_id": "22c34302-0ec6-4781-8d96-1d6a4fda049e"})
if one is None:
    # find_one returns None when nothing matches; fail with a clear message
    # instead of a TypeError on the subscript below.
    raise LookupError(
        "document not found in collected_items (check the _id value and its stored type)"
    )
print(one["text"])
print("".join(one["chunks"]))
one

In [None]:
from bson.dbref import DBRef

from transition_scores.data import LogProbs
from transition_scores.mongo import TextTransitionScore, TransitionScoreItem

# Smoke test: build a TransitionScoreItem from placeholder values and pass it
# to dict() to inspect the converted result.
# NOTE(review): the argument meanings noted below are inferred from the values
# only — confirm against the definitions in transition_scores.mongo / .data.
dict(
    TransitionScoreItem(
        DBRef("a", "b"),  # placeholder reference (collection "a", id "b")
        "gpt2",  # presumably the model name — TODO confirm
        "onnx",  # presumably the scoring backend — TODO confirm
        TextTransitionScore([LogProbs(0, 1.0, [0], [1.0])]),  # one dummy LogProbs entry
    )
)

In [1]:
from datasets import Dataset

from transition_scores.data import CustomTokenizer

# Tokenize one toy document (its full text plus the pre-split chunks) with the
# plain CustomTokenizer and display the resulting dataset to see which columns
# the tokenizer produces.
tokenizer = CustomTokenizer.from_pretrained("gpt2")

dataset = tokenizer.tokenize_dataset(
    Dataset.from_dict(
        {
            "_id": ["abc-def-123"],
            "text": [
                "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."
            ],
            "chunks": [
                [
                    "Lorem ipsum dolor sit amet,",
                    "consectetur adipiscing elit.",
                    "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
                    "Ut enim ad minim veniam,",
                    "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
                ]
            ],
        }
    )
)
dataset

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Dataset({
    features: ['_id', 'text_sha256', 'input_ids', 'attention_mask', 'length'],
    num_rows: 1
})

In [2]:
from datasets import Dataset

from transition_scores.data import CustomTokenizer, RollingWindowChunkTokenizer

# Same toy document as the CustomTokenizer example, but tokenized with the
# RollingWindowChunkTokenizer; the displayed dataset shows the columns and row
# count this tokenizer produces for a five-chunk input.
tokenizer = RollingWindowChunkTokenizer.from_pretrained("gpt2")

dataset = tokenizer.tokenize_dataset(
    Dataset.from_dict(
        {
            "_id": ["abc-def-123"],
            "text": [
                "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."
            ],
            "chunks": [
                [
                    "Lorem ipsum dolor sit amet,",
                    "consectetur adipiscing elit.",
                    "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
                    "Ut enim ad minim veniam,",
                    "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
                ]
            ],
        }
    )
)
dataset

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Dataset({
    features: ['_id', 'text_sha256', 'length', 'start_idx', 'attention_mask', 'start_token_idx', 'end_idx', 'text', 'prefix_idx', 'input_ids'],
    num_rows: 5
})

In [4]:
import torch

# Sanity check of element layout: the 16 consecutive integers 0..15 viewed as
# a (4, 2, 2) tensor and converted to nested Python lists for display.
torch.arange(16).view(4, 2, 2).tolist()

[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]], [[12, 13], [14, 15]]]