In [18]:
import os

from dotenv import load_dotenv
from pymongo import MongoClient

# Load connection settings (e.g. MONGO_DB_CONNECTION) from the env file one
# directory up; variables already present in the environment are kept.
load_dotenv("../env")
# Index os.environ directly so a missing variable raises a KeyError here,
# instead of MongoClient(None) silently falling back to its localhost default.
client = MongoClient(os.environ["MONGO_DB_CONNECTION"])
db = client.get_database("prismai")
collection = db.get_collection("collected_items")

In [19]:
# WARNING: destructive. Uncommenting this drops the entire
# `transition_scores` collection. It is commented out on purpose; the output
# shown below this cell is from a past run.
# db.drop_collection("transition_scores")

{'nIndexesWas': 1, 'ns': 'prismai.transition_scores', 'ok': 1.0}

In [8]:
# Handle to the collection of model-generated texts.
stc = db["synthesized_texts"]

In [17]:
# Peek at one synthesized text produced by gpt-4o-mini.
stc.find_one(filter={"agent": "gpt-4o-mini"})

{'_id': ObjectId('679793f8bb95bc200c6a2068'),
 'ref_id': DBRef('collected_items', 'fb556a23-7abf-47b2-93f6-ab99eeedec0b'),
 '_ref_id': DBRef('collected_items', ObjectId('678fb3abdbe3ac531644d65e')),
 'domain': 'bundestag',
 'date': '2019-02-13T00:00:00',
 'source': 'https://bundestag-mine.de/api/DashboardController/GetNLPSpeechById/e9e8a810-308a-4719-6138-08da0f22a008',
 'lang': 'de-DE',
 'agent': 'gpt-4o-mini',
 'text': 'Sehr geehrte Damen und Herren, liebe Kolleginnen und Kollegen,\n\nheute stehen wir hier im Deutschen Bundestag, um über eine der drängendsten Herausforderungen unserer Zeit zu sprechen: die Energiewende und ihre weitreichenden Implikationen für unser Land, unsere Gesellschaft und unseren Planeten. Wir alle sind uns der Verantwortung bewusst, die wir tragen, nicht nur gegenüber unseren Bürgerinnen und Bürgern, sondern auch gegenüber zukünftigen Generationen. Die Klimakrise ist nicht länger ein fernes Szenario, sondern eine gegenwärtige Realität, die uns alle betrifft.\

In [2]:

# Handle to the collection holding computed transition scores.
tsc = db["transition_scores"]

In [7]:
# Count score documents whose `_ref_id` points into `synthesized_texts`.
tsc.count_documents(filter={"_ref_id.$ref": "synthesized_texts"})

52739

In [6]:
# Imports for the exploratory scoring cells below.
from itertools import batched  # itertools.batched requires Python 3.12+

import datasets
from datasets import Dataset
from tqdm.auto import tqdm

# Turn off the `datasets` library's own progress bars (e.g. for
# `Dataset.filter`); the loop below wraps its cursor in a tqdm bar instead.
datasets.disable_progress_bars()

In [None]:
from transition_scores.pre_processor.chunks import RollingWindowChunkPreProcessor
from transition_scores.pre_processor.text import TextPreProcessor
from transition_scores.scorer import OnnxTransitionScorer

# The same exported GPT-2 ONNX directory is used for both the model and the
# rolling-window pre-processor.
gpt2_onnx_dir = "/hot_storage/models/onnx/gpt2_onnx_o4/"
rolling_window = RollingWindowChunkPreProcessor.from_pretrained(gpt2_onnx_dir)
scorer = OnnxTransitionScorer(
    gpt2_onnx_dir,
    pre_processor=rolling_window,
    batch_size=1,
    device="cuda",
    top_k=4,
)

In [None]:
from bson import DBRef

total = collection.count_documents({})

# Debug run: score documents in batches of 16 until the scorer yields its first
# result, then stop with `break`. The result is kept in `scores` for the next
# cell, and a truncated preview is printed.
scores = None
cursor = collection.find(
    projection=[
        "text",
        "chunks",
    ],
    batch_size=128,
)
try:
    with tqdm(cursor, total=total) as tq:
        for rows in batched(tq, 16):
            records = [
                {
                    "ref": {
                        "$ref": "collected_items",
                        "$id": str(row.pop("_id")),
                    },
                    # Dataset.from_list takes its columns from the first
                    # record, so make sure the filtered fields always exist.
                    "text": None,
                    "chunks": None,
                }
                | row
                for row in rows
            ]
            dataset = Dataset.from_list(records).filter(
                lambda x: bool(x["text"]) and bool(x["chunks"])
            )
            if len(dataset) == 0:
                # Nothing scorable in this batch.
                continue
            scores = next(iter(scorer.process(dataset)), None)
            if scores is not None:
                break
finally:
    cursor.close()

print(str(scores)[:500])

In [None]:
# Pull the pre-processor's additional fields out of the top level of `scores`
# and nest them under "feature_metadata"; all other keys stay where they are.
_scores = scores.copy()
feature_metadata = {
    field: _scores.pop(field) for field in scorer.pre_processor.additional_fields
}
transposed = _scores | {"feature_metadata": feature_metadata}
transposed

In [None]:
# Inspect one document: print its full text, then its chunks joined together,
# to compare them by eye.
doc_uuid = "22c34302-0ec6-4781-8d96-1d6a4fda049e"
# Other cells read the ObjectId `_id` and the UUID `id` as separate fields, so
# the UUID may live in either one; match both.
one = collection.find_one({"$or": [{"_id": doc_uuid}, {"id": doc_uuid}]})
assert one is not None, f"no document with _id/id {doc_uuid!r} in {collection.name}"
print(one["text"])
print("".join(one["chunks"]))
one

In [None]:
from bson.dbref import DBRef

from transition_scores.data import LogProbs
from transition_scores.mongo import TextTransitionScore, TransitionScoreItem

# Smoke test: build a minimal item from placeholder values and view it as a dict.
dummy_ref = DBRef("a", "b")
dummy_score = TextTransitionScore([LogProbs(0, 1.0, [0], [1.0])])
dummy_item = TransitionScoreItem(dummy_ref, "gpt2", "onnx", dummy_score)
dict(dummy_item)

In [None]:
from datasets import Dataset

from transition_scores.pre_processor.text import TextPreProcessor

tokenizer = TextPreProcessor.from_pretrained("gpt2")

# Toy single-document input: the full text plus the same text split into
# sentence-like chunks.
lorem_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."
lorem_chunks = [
    "Lorem ipsum dolor sit amet,",
    "consectetur adipiscing elit.",
    "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam,",
    "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
]
raw_dataset = Dataset.from_dict(
    {
        "_id": ["abc-def-123"],
        "text": [lorem_text],
        "chunks": [lorem_chunks],
    }
)
dataset = tokenizer.prepare_dataset(raw_dataset)
dataset

In [None]:
from datasets import Dataset

from transition_scores.pre_processor.chunks import RollingWindowChunkPreProcessor
from transition_scores.pre_processor.text import TextPreProcessor

tokenizer = RollingWindowChunkPreProcessor.from_pretrained("gpt2")

# Same toy document as the TextPreProcessor example, but run through the
# rolling-window chunk pre-processor.
lorem_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."
lorem_chunks = [
    "Lorem ipsum dolor sit amet,",
    "consectetur adipiscing elit.",
    "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.",
    "Ut enim ad minim veniam,",
    "quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.",
]
raw_dataset = Dataset.from_dict(
    {
        "_id": ["abc-def-123"],
        "text": [lorem_text],
        "chunks": [lorem_chunks],
    }
)
dataset = tokenizer.prepare_dataset(raw_dataset)
dataset

In [1]:
import json
import os
from argparse import ArgumentParser, Namespace
from itertools import batched
from pathlib import Path

import datasets
from datasets import Dataset
from dotenv import load_dotenv
from pymongo import MongoClient
from tqdm import tqdm, trange

from transition_scores.pre_processor.chunks import RollingWindowChunkPreProcessor
from transition_scores.scorer import OnnxTransitionScorer

# --- Configuration ----------------------------------------------------------
# Cursor batch size for reads and number of results per `insert_many` call.
mongodb_batch_size = 16
mongodb_filter_query = {}
# Number of source documents fetched and scored per outer-loop step.
dataset_batch_size = 8
# Cap on the number of source documents processed; set to None for a full run.
max_documents = 16

# --- Connections and model --------------------------------------------------
# Load the env file (same path as the setup cell) so this cell also works in a
# fresh kernel; variables already set in the environment are not overridden.
load_dotenv("../env")
# Fail fast with a KeyError instead of MongoClient(None) falling back to localhost.
mongodb_client = MongoClient(os.environ["MONGO_DB_CONNECTION"])
mongodb_database = mongodb_client.get_database("prismai")
source_collection = mongodb_database.get_collection("collected_items")
target_collection = mongodb_database.get_collection("test")


scorer = OnnxTransitionScorer(
    "/hot_storage/models/onnx/gpt2_onnx_o4/",
    batch_size=4,
    device="cuda",
)

# Every pre-processor in this list is run over each fetched batch.
pre_processors = [RollingWindowChunkPreProcessor.from_pretrained("gpt2")]


def _to_record(row: dict) -> dict:
    """Flatten one `collected_items` document into a record for `Dataset.from_list`.

    `_id` (and `id`, if present) are popped from `row` and re-added as
    `{"$ref": "collected_items", "$id": ...}` dicts under `_ref_id` / `ref_id`;
    `ref_id` is None when the document has no `id`. Pre-existing `_ref_id` /
    `ref_id` fields are kept as `_orig_ref_id` / `orig_ref_id`. `text` and
    `chunks` default to None so that every record carries those columns.
    """
    refs = {
        "_ref_id": {
            "$ref": "collected_items",
            # NOTE(review): presumably stringified because an ObjectId cannot be
            # stored in a `datasets.Dataset` — confirm before changing.
            "$id": str(row.pop("_id")),
        }
    }
    if "id" in row:
        refs["ref_id"] = {
            "$ref": "collected_items",
            "$id": row.pop("id"),
        }
    else:
        refs["ref_id"] = None

    # NOTE(review): the projection used below does not request these fields,
    # so these branches only apply if the projection is widened.
    if "_ref_id" in row:
        refs["_orig_ref_id"] = row.pop("_ref_id")
    if "ref_id" in row:
        refs["orig_ref_id"] = row.pop("ref_id")

    record = refs | row
    # Dataset.from_list takes its columns from the FIRST record; make sure the
    # fields read by the filter exist even if that document lacks them.
    record.setdefault("text", None)
    record.setdefault("chunks", None)
    return record


# --- Fetch, score, insert ---------------------------------------------------
num_documents = source_collection.count_documents(mongodb_filter_query)
num_to_process = (
    num_documents if max_documents is None else min(num_documents, max_documents)
)
tq_fetch = trange(
    0,
    num_to_process,
    dataset_batch_size,
    desc=f"Processing Document Batches of {dataset_batch_size} from {source_collection.full_name}",
)
for offset in tq_fetch:
    batch = [
        _to_record(row)
        for row in source_collection.find(
            mongodb_filter_query,
            projection=[
                "text",
                "chunks",
                "id",
            ],
            batch_size=mongodb_batch_size,
            limit=dataset_batch_size,
            skip=offset,
            # skip/limit paging needs a stable order; without a sort, MongoDB
            # may repeat or omit documents between pages.
            sort=[("_id", 1)],
        )
    ]
    if not batch:
        # No documents at this offset; nothing left to process.
        break

    dataset = Dataset.from_list(batch).filter(
        lambda x: bool(x["text"]) and bool(x["chunks"]),
        keep_in_memory=not datasets.is_caching_enabled(),
    )
    if len(dataset) == 0:
        # Every document in this batch lacked text or chunks.
        continue

    for pre_processor in pre_processors:
        processed_dataset = scorer.process(dataset, pre_processor)
        for r_batch in batched(
            tqdm(
                processed_dataset,
                desc="Inserting Batch Results",
                position=1,
                leave=False,
            ),
            mongodb_batch_size,
        ):
            target_collection.insert_many(r_batch, ordered=False)
