In [7]:
# from chainlit.utils import mount_chainlit
import logging
import time
from typing import Dict, Iterable, List, Optional

import torch
from datasets import Dataset, load_dataset
from pydantic import BaseModel
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
from enum import Enum

# Route INFO-and-above records to ``query_logs.log`` (append mode, default format).
# Once the root logger owns a handler, any later ``logging.basicConfig`` call is a
# no-op unless it passes ``force=True``.
logging.basicConfig(level=logging.INFO, filename="query_logs.log")

class BertHFPath(str, Enum):
    """Hugging Face Hub ids of the encoders available for benchmarking.

    The ``str`` mixin makes every member a real string; ``.value`` is the id
    passed to ``from_pretrained``.
    """

    modern_bert_large_embed = "lightonai/modernbert-embed-large"
    modern_bert_base_embed = "nomic-ai/modernbert-embed-base"
    modern_bert_base = "answerdotai/ModernBERT-base"
    modern_bert_large = "answerdotai/ModernBERT-large"
    gte_base = "Alibaba-NLP/gte-base-en-v1.5"
    gte_large = "Alibaba-NLP/gte-large-en-v1.5"
    modern_bert_base_v2 = "Alibaba-NLP/gte-modernbert-base"
    # Misspelled name kept for backward compatibility with existing callers.
    daberta_base = "microsoft/deberta-v3-base"
    # Correctly spelled alias: same value, so Enum makes it the *same* member as
    # ``daberta_base`` (iteration order and ``len`` are unchanged).
    deberta_base = "microsoft/deberta-v3-base"

class BenchmarkQueryRequest(BaseModel):
    """Pydantic schema describing the parameters of one benchmark run.

    Not referenced elsewhere in this notebook; presumably it mirrors the
    arguments of ``benchmark_on_corpus`` — TODO confirm against its callers.
    """

    # Encoder (Hugging Face Hub id) to benchmark; defaults to ModernBERT-base.
    encoder_path:BertHFPath =  BertHFPath.modern_bert_base
    # Batch size for the run (presumably chunks per forward pass — confirm).
    batch_size: int = 20
    # Number of passes; the schema also accepts None.
    # NOTE(review): benchmark_on_corpus expects an int here — confirm how None is handled.
    num_pass: Optional[int] = 100

In [9]:
import random
from functools import partial
from pathlib import Path
from typing import Dict, Iterator, List, Optional, cast

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer



class DatasetWrapper:
    """Expose a directory of ``*.txt`` files as token-sized text chunks.

    Every file is tokenized, split into windows of at most ``chunk_size`` tokens,
    and each window is decoded back to text. The chunks become the rows
    ``{"id": ..., "content": ...}`` of a Hugging Face ``Dataset``. That dataset can
    optionally be re-tokenized and served through a PyTorch ``DataLoader``.
    """

    def __init__(
        self,
        path_to_files: str,
    ) -> None:
        """Index the corpus rooted at ``path_to_files``.

        Args:
            path_to_files: Directory searched recursively for ``*.txt`` files.

        Raises:
            AssertionError: If no ``*.txt`` file is found.
        """
        self.root = path_to_files
        self.find_all_files()

    def find_all_files(self) -> None:
        """Store every ``*.txt`` file under ``self.root`` in ``self.list_files``.

        ``rglob`` order depends on the filesystem, so the list is sorted. This
        keeps the chunk order, and therefore the batches, reproducible.

        Raises:
            AssertionError: If no ``*.txt`` file is found.
        """
        self.list_files: List[Path] = sorted(Path(self.root).rglob("*.txt"))
        assert len(self.list_files), "No documents found"

    def _read_file_in_chunks(
        self,
        file_path: Path,
        tokenizer: AutoTokenizer,
        chunk_size: int,
        use_random_chunk_size: bool = False,
    ) -> Iterator[Dict[str, str]]:
        """Yield consecutive text chunks of ``file_path``, measured in tokens.

        Args:
            file_path: UTF-8 text file to split.
            tokenizer: Tokenizer used to split the text and to decode each chunk back.
            chunk_size: Tokens per chunk. With ``use_random_chunk_size``, this is the
                upper bound of a size drawn uniformly from ``[1, chunk_size]`` per chunk.
            use_random_chunk_size: Draw a fresh random size for every chunk.

        Yields:
            ``{"id": ..., "content": ...}``. ``id`` encodes the file path, the token
            offset *after* the chunk, and that chunk's token budget.

        Raises:
            AssertionError: If ``chunk_size`` is None.
            ValueError: If ``chunk_size`` < 1. A zero size would never advance the
                offset and would loop forever.
        """
        assert chunk_size is not None
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        with file_path.open(encoding="utf-8") as f:
            content = f.read()

        # Token ids of the whole file as a 1-D tensor. Special tokens added here are
        # dropped again by ``skip_special_tokens`` when each chunk is decoded.
        tokens = tokenizer(content, truncation=False, return_tensors="pt")["input_ids"].squeeze(0)

        index = 0
        while index < len(tokens):
            # Use a separate per-chunk size. Reassigning ``chunk_size`` itself would
            # shrink the random range on every iteration until it collapsed to 1.
            current_size = random.randint(1, chunk_size) if use_random_chunk_size else chunk_size
            chunk_tokens = tokens[index : index + current_size]
            chunk_text = tokenizer.decode(chunk_tokens.tolist(), skip_special_tokens=True)
            index += current_size

            yield {
                "id": f"{str(file_path).replace('/', '_')}_{index}_{current_size}",
                "content": chunk_text,
            }

    def get_iterator(
        self,
        tokenizer: AutoTokenizer,
        chunk_size: int = 2048,
        use_random_chunk_size: bool = False,
    ) -> Iterator[Dict[str, str]]:
        """Yield the chunks of every corpus file, one file after another.

        Arguments are forwarded unchanged to ``_read_file_in_chunks``.
        """
        for file_path in self.list_files:
            yield from self._read_file_in_chunks(
                file_path, tokenizer, chunk_size, use_random_chunk_size
            )

    def get_dataloader(
        self,
        batch_size: int = 10,
        tokenizer_path: BertHFPath = BertHFPath.modern_bert_base_embed,
        tokenize: bool = False,
        chunk_size: int = 2048,
        use_random_chunk_size: bool = False,
    ) -> DataLoader[Dict[str, torch.Tensor]]:
        """Build a sequential ``DataLoader`` over the chunked corpus.

        Args:
            batch_size: Rows per batch; an incomplete final batch is dropped.
            tokenizer_path: Hub id whose tokenizer is used for chunking and tokenizing.
            tokenize: If True, rows are re-tokenized and only the model-input columns
                (``input_ids``, ``attention_mask`` and, when present,
                ``token_type_ids``) are returned as tensors. If False, batches hold
                the raw ``id``/``content`` strings.
            chunk_size: Token budget per chunk (upper bound when random).
            use_random_chunk_size: Draw a random size in ``[1, chunk_size]`` per chunk.

        Returns:
            A non-shuffling ``DataLoader`` with ``drop_last=True`` and ``pin_memory=True``.
        """
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path.value)

        dataset = Dataset.from_generator(
            partial(
                self.get_iterator,
                tokenizer=tokenizer,
                chunk_size=chunk_size,
                use_random_chunk_size=use_random_chunk_size,
            )
        )
        if tokenize:
            # Tokenize with the loader's batch size. With a sequential sampler and
            # drop_last=True, each loader batch is then exactly one ``map`` batch,
            # padded to a single length. With map's default of 1000 rows, a loader
            # batch straddling two map batches could mix padded lengths, and
            # ``default_collate`` would fail to stack them.
            dataset = dataset.map(
                partial(self.tokenization, tokenizer=tokenizer),
                batched=True,
                batch_size=batch_size,
            )
            # A sorted list gives a deterministic column order (a set does not).
            columns_of_interest = sorted(
                {"input_ids", "attention_mask", "token_type_ids"}.intersection(
                    dataset.column_names
                )
            )
            dataset.set_format("pt", columns=columns_of_interest, output_all_columns=False)

        dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=True, pin_memory=True)
        return dataloader

    def tokenization(
        self,
        example: Dict[str, List[str]],
        tokenizer: "AutoTokenizer | BertHFPath" = BertHFPath.modern_bert_base_embed,
    ) -> Dict[str, torch.Tensor]:
        """Tokenize a batch of rows for ``Dataset.map(..., batched=True)``.

        Args:
            example: Batch dict whose ``"content"`` entry holds the chunk texts.
            tokenizer: A loaded tokenizer, or a ``BertHFPath`` whose tokenizer is
                loaded on the fly. The enum default is not callable by itself.

        Returns:
            The tokenizer output as PyTorch tensors, padded to the longest text in
            the batch.
        """
        if isinstance(tokenizer, BertHFPath):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer.value)
        return cast(
            Dict[str, torch.Tensor],
            tokenizer(example["content"], return_tensors="pt", padding=True),
        )

In [None]:
import numpy as np 

# Total number of tokens produced by chunking the corpus into 8192-token windows
# with the ModernBERT-base tokenizer (padding included, if any).
corpus = DatasetWrapper("corpus/")
dataloader = corpus.get_dataloader(
    batch_size=1, tokenizer_path=BertHFPath.modern_bert_base, chunk_size=8192, tokenize=True
)
# Each batch's ``input_ids`` is 2-D (batch, seq_len), so ``shape[2]`` raises
# IndexError. ``numel()`` counts every token (== seq_len for batch_size=1).
sum(batch["input_ids"].numel() for batch in dataloader)
# np.mean([batch["input_ids"].shape[1] for batch in dataloader])


In [None]:
# Total number of tokens produced by chunking the corpus into 512-token windows
# with the GTE-base tokenizer (padding included, if any).
corpus = DatasetWrapper("corpus/")
dataloader = corpus.get_dataloader(
    batch_size=1, tokenizer_path=BertHFPath.gte_base, chunk_size=512, tokenize=True
)
# Each batch's ``input_ids`` is 2-D (batch, seq_len), so ``shape[2]`` raises
# IndexError. ``numel()`` counts every token (== seq_len for batch_size=1).
sum(batch["input_ids"].numel() for batch in dataloader)


In [11]:
import json
import logging
import time
from typing import Any, Dict

import numpy as np
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel


# Silence chatty third-party loggers so the log file only holds benchmark records.
for _noisy_logger in ("datasets", "transformers", "filelock", "urllib3", "httpx"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

# Configure logging to only write to a file.
# ``force=True`` is required. If ``logging.basicConfig`` already ran (an earlier
# cell calls it), the root logger has a handler and this call would otherwise be
# silently ignored, so the format and ``filemode="w"`` would never apply.
logging.basicConfig(
    filename="query_logs.log",  # Log file path
    level=logging.INFO,  # Set log level
    format="%(asctime)s - %(levelname)s - %(message)s",  # Log format
    filemode="w",  # Overwrite the file on each run (use "a" to append)
    force=True,
)

# Prefer CUDA, then Apple MPS. Fall back to CPU instead of assuming MPS exists.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


def _to_model_inputs(docs: Dict[str, Any], non_blocking: bool = False) -> Dict[str, Any]:
    """Move the tensor columns of a collated batch to ``device``, dropping text columns."""
    return {
        k: v.to(device, non_blocking=non_blocking)
        for k, v in docs.items()
        if k not in ("content", "id")
    }


def benchmark_on_corpus(
    encoder_path: BertHFPath,
    batch_size: int = 10,
    chunk_size: int = 4096,
    num_pass: int = 1,
    use_random_chunk_size: bool = False,
) -> Dict[str, Any]:
    """Measure encoder inference speed on the ``corpus/`` text files.

    The corpus is chunked and tokenized with ``encoder_path``'s tokenizer. The model
    is loaded in float16 on the module-level ``device`` and warmed up on the first
    batch. Every batch of the dataloader is then run through it ``num_pass`` times.
    Host-to-device copy time and forward time are measured per batch with
    ``torch.cuda.Event``, so a CUDA device is required.

    Args:
        encoder_path: Hugging Face model (and tokenizer) to benchmark.
        batch_size: Chunks per forward pass.
        chunk_size: Token budget per chunk (upper bound when ``use_random_chunk_size``).
        num_pass: Number of full sweeps over the dataloader.
        use_random_chunk_size: Draw a random chunk size in ``[1, chunk_size]`` per chunk.

    Returns:
        Dict with the experiment config under ``"exp"`` plus timing/throughput stats.
        Token counts are ``input_ids.numel()`` and therefore include padding.
        The dict is also logged as JSON at INFO level.

    Raises:
        ValueError: If ``num_pass`` < 1 or the dataloader yields no batch. In either
            case no statistics could be computed.
    """
    if num_pass < 1:
        raise ValueError(f"num_pass must be >= 1, got {num_pass}")

    torch.cuda.empty_cache()
    model = AutoModel.from_pretrained(encoder_path.value, trust_remote_code=True)
    model.to(device, torch.float16)
    # model = torch.compile(model)

    corpus = DatasetWrapper("corpus/")
    dataloader = corpus.get_dataloader(
        batch_size=batch_size,
        tokenizer_path=encoder_path,
        chunk_size=chunk_size,
        tokenize=True,
        # Previously reported in the results but never actually applied.
        use_random_chunk_size=use_random_chunk_size,
    )

    # Warm-up: two forward passes on the first batch, so one-off start-up costs
    # stay out of the timed loop.
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        raise ValueError(
            "The dataloader yielded no batch (corpus smaller than one batch with drop_last=True)"
        )
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16):
        for _ in range(2):
            model(**_to_model_inputs(first_batch))

    evaluation_times = []
    device_times = []
    token_counts = []
    start_time = time.perf_counter()
    with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16):
        for _ in tqdm(range(num_pass)):
            for docs in tqdm(dataloader):
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)

                # Host -> device transfer timing.
                start_event.record()
                inputs = _to_model_inputs(docs, non_blocking=True)
                end_event.record()
                torch.cuda.synchronize()  # Ensure all events are completed
                device_times.append(start_event.elapsed_time(end_event) / 1000)  # ms -> s

                # Count the batch that is actually run. Counting before ``inputs`` was
                # rebuilt used to attribute the previous batch's tokens to this one.
                token_counts.append(inputs["input_ids"].numel())

                # Model forward pass timing.
                start_event.record()
                model(**inputs)
                end_event.record()
                torch.cuda.synchronize()
                evaluation_times.append(start_event.elapsed_time(end_event) / 1000)

    total_time = time.perf_counter() - start_time

    num_tokken = sum(token_counts)
    total_tokens = np.sum(token_counts)
    mean_tokens = np.mean(token_counts)
    mean_time = np.mean(evaluation_times)
    var_time = np.var(evaluation_times)
    tokens_per_sec = total_tokens / total_time

    results = {
        "exp": {
            "batch_size": batch_size,
            "chunk_size": chunk_size,
            "encoder": encoder_path.value,
            "use_random_chunk_size": use_random_chunk_size,
        },
        "total_time": str(total_time),
        "mean_time": str(sum(evaluation_times) / len(evaluation_times)),
        "num_tokken": num_tokken,
        "var": var_time,
        "mean_time/tokken": mean_time / mean_tokens,
        "total_time/tokken": total_time / total_tokens,
        "token/mean_time": mean_tokens / mean_time,
        "token/total_time": total_tokens / total_time,
        "tokken/sec": tokens_per_sec,
        "device_time": np.mean(device_times),
    }

    logging.info(json.dumps(results))

    return results

In [None]:
# ModernBERT-base: very large batches of short (512-token) chunks, two sweeps.
benchmark_on_corpus(
    encoder_path=BertHFPath.modern_bert_base,
    batch_size=1600,
    chunk_size=512,
    num_pass=2,
    use_random_chunk_size=False,
)

In [None]:
# GTE-base: moderate batches of long (8192-token) chunks, two sweeps.
benchmark_on_corpus(
    encoder_path=BertHFPath.gte_base,
    batch_size=100,
    chunk_size=8192,
    num_pass=2,
    use_random_chunk_size=False,
)