In [1]:
import inspect
from functools import partial
from typing import Generator, Iterable, NamedTuple, TypedDict

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

In [2]:
# Alias used to annotate the object returned by the model call in
# `MetricsCalculator._forward`.
# NOTE(review): `_forward` reads the result by attribute (`outputs.logits`),
# which a plain `dict` does not support — presumably the real object is a
# dict-like transformers output class; confirm and tighten this alias if so.
type _ModelOutput = dict[str, torch.Tensor | list[torch.Tensor]]


class Metrics(NamedTuple):
    """Per-sequence detection scores produced by `MetricsCalculator`."""

    # DetectLLM-LLR score (see `MetricsCalculator._calculate_log_likelihood_ratio`).
    llr: float
    # Fast-DetectGPT analytic score (see `MetricsCalculator._calculate_fast_detect_gpt`).
    fdg: float


class Sample(TypedDict):
    """One pre-tokenized document as consumed by `collate_fn`."""

    # Token ids of the document; padded with `pad_token_id` when batched.
    input_ids: list[int]
    # Mask aligned with `input_ids`; padded with 0 when batched.
    attention_mask: list[int]


class MetricsCalculator:
    """Compute DetectLLM-LLR and Fast-DetectGPT (analytic) scores with a causal LM.

    The calculator wraps a causal language model, runs batched forward passes
    over pre-tokenized documents and returns one `Metrics` tuple per document.

    Batches are expected to be right-padded (as produced by `collate_fn`), with
    an attention mask that is 1 for real tokens and 0 for padding.
    """

    def __init__(
        self,
        model: str,
        batch_size: int = 128,
        device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """
        Args:
            model (str): Name or path passed to `AutoModelForCausalLM.from_pretrained`.
            batch_size (int): Number of sequences per forward pass.
            device (str | torch.device): Device the model is moved to.
        """
        self.batch_size = batch_size
        # Must be set before assigning `self.model`: the setter moves the model
        # to `self.device`.
        self.device = torch.device(device)

        self.model = AutoModelForCausalLM.from_pretrained(model)

    @property
    def model(self):
        """The wrapped model."""
        return self._model

    @model.setter
    def model(self, model):
        self.set_model(model)

    def set_model(self, model: PreTrainedModel):
        """Replace the wrapped model and move it to the current device.

        Also records whether the model's `forward` accepts `position_ids`, so
        that `_forward` only passes them when they are supported.
        """
        self._model = model
        self._requires_position_ids = (
            "position_ids" in inspect.signature(self.model.forward).parameters
        )
        self.to(self.device)

    def to(self, device: str | torch.device):
        """Move the wrapped model to `device` and remember it. Returns `self`."""
        self.device = torch.device(device)
        self.model.to(self.device)
        return self

    def process(
        self,
        dataset: list[Sample],
        pad_token_id: int,
    ) -> list[Metrics]:
        """
        Calculate metrics for the given pre-processed dataset.

        Args:
            dataset (list[Sample]): A sequence of pre-processed documents to be processed.
            pad_token_id (int): The token ID to use for padding.

        Returns:
            list[Metrics]: A list of calculated metrics.
        """
        return list(self._generate_scores(dataset, pad_token_id))

    def _generate_scores(
        self, dataset: list[Sample], pad_token_id: int
    ) -> Generator[Metrics, None, None]:
        """Yield one `Metrics` per document, processing `dataset` in batches."""
        _collate_fn = partial(collate_fn, pad_token_id=pad_token_id)
        for input_ids, attention_mask in tqdm(
            DataLoader(
                dataset,
                shuffle=False,
                collate_fn=_collate_fn,
                batch_size=self.batch_size,
            ),
            position=2,
            leave=False,
            desc="Processing Sequences",
        ):
            yield from self._process_batch(input_ids, attention_mask, pad_token_id)

    def _process_batch(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pad_token_id: int,
    ) -> list[Metrics]:
        """
        Process a batch of input sequences and calculate the detection metrics.
        Runs a forward pass on the model and scores every sequence of the batch.

        Args:
            input_ids (torch.Tensor): Right-padded batch of token IDs.
            attention_mask (torch.Tensor): Attention mask matching `input_ids`
                (1 for real tokens, 0 for padding).
            pad_token_id (int): The token ID that has been used for padding.
                Kept for interface compatibility; sequence lengths are taken
                from `attention_mask`, because the pad ID may also occur as a
                real token (e.g. when the EOS token doubles as the pad token).

        Returns:
            list[Metrics]: One `Metrics` tuple per sequence, in batch order.
        """
        likelihoods, log_likelihoods = self._forward(input_ids, attention_mask)

        input_ids = input_ids.to(self.device)
        # Number of real (non-padding) tokens of every sequence.
        sequence_lengths = attention_mask.sum(-1).tolist()

        results = []
        for target_ids, sequence_length, likelihood, log_likelihood in zip(
            input_ids,
            sequence_lengths,
            likelihoods,
            log_likelihoods,
        ):
            # The distribution at position t predicts token t + 1, so the first
            # token has no prediction and a sequence of n tokens has n - 1 labels.
            num_labels = max(int(sequence_length) - 1, 0)
            labels = target_ids[1 : num_labels + 1].unsqueeze(-1)

            likelihood = likelihood[:num_labels]
            log_likelihood = log_likelihood[:num_labels]

            target_log_probs = log_likelihood.gather(-1, labels).squeeze(-1)

            # 0-based rank of every target token: its column in the descending
            # sort of the predicted distribution (exactly one match per row).
            _, sorted_indices = torch.sort(likelihood, descending=True)
            _, target_ranks = torch.where(sorted_indices.eq(labels))

            # Calculate DetectLLM-LLR
            llr = self._calculate_log_likelihood_ratio(target_log_probs, target_ranks)

            # Calculate Fast-DetectGPT (analytic)
            fdg = self._calculate_fast_detect_gpt(
                likelihood, log_likelihood, target_log_probs
            )

            results.append(Metrics(llr, fdg))

        return results

    def _forward(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model without gradients and normalise its logits.

        Args:
            input_ids (torch.Tensor): Batch of token IDs.
            attention_mask (torch.Tensor): Attention mask matching `input_ids`.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Softmax and log-softmax of the
            model's logits over the last dimension.
        """
        model_inputs = {
            "input_ids": input_ids.to(self.device),
            "attention_mask": attention_mask.to(self.device),
        }

        # Create `position_ids` on the fly, if required. They are only passed
        # when the model's `forward` accepts them.
        # Source: https://github.com/huggingface/transformers/blob/v4.48.1/src/transformers/generation/utils.py#L414
        if self._requires_position_ids:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            model_inputs["position_ids"] = position_ids.to(self.device)

        with torch.no_grad():
            outputs: _ModelOutput = self._model(**model_inputs)

            likelihoods: torch.Tensor = outputs.logits.softmax(-1)
            log_likelihoods: torch.Tensor = outputs.logits.log_softmax(-1)

            del outputs
        return likelihoods, log_likelihoods

    def _calculate_log_likelihood_ratio(
        self,
        target_log_probs: torch.Tensor,
        target_ranks: torch.Tensor,
        device: torch.device | None = None,
    ) -> float:
        """Implements the DetectLLM-LLR analytic criterion.

        Computes `-sum(target_log_probs) / sum(log1p(target_ranks))`.

        Args:
            target_log_probs (torch.Tensor): A tensor of log probabilities for each target token.
            target_ranks (torch.Tensor): A tensor of ranks for each target token.
            device (torch.device, optional): Device to run the calculations on.
                Defaults to None, in which case `self.device` is used.

        Returns:
            float: The calculated log-likelihood ratio.

        Source:
            - Paper: https://aclanthology.org/2023.findings-emnlp.827.pdf
            - GitHub: https://github.com/mbzuai-nlp/DetectLLM
            - Implementation:
                - https://github.com/mbzuai-nlp/DetectLLM/blob/main/baselines/all_baselines.py#L35:L42
                - https://github.com/mbzuai-nlp/DetectLLM/blob/main/baselines/all_baselines.py#L94:L100
        """
        device = device or self.device
        return (
            -torch.div(
                target_log_probs.to(device).sum(),
                target_ranks.to(device).log1p().sum(),
            )
            .cpu()
            .item()
        )

    def _calculate_fast_detect_gpt(
        self,
        likelihood: torch.Tensor,
        log_likelihood: torch.Tensor,
        target_log_probs: torch.Tensor,
        device: torch.device | None = None,
    ) -> float:
        """Implements the Fast-DetectGPT analytic criterion.

        Per position, the expectation and variance of the log-probability under
        the predicted distribution are computed; the score is the summed target
        log-probability minus the summed expectation, divided by the square
        root of the summed variance.

        Args:
            likelihood (torch.Tensor): Predicted probabilities per position.
            log_likelihood (torch.Tensor): Predicted log probabilities per position.
            target_log_probs (torch.Tensor): Log probability of each target token.
            device (torch.device, optional): Device to run the calculations on.
                Defaults to None, in which case `self.device` is used.

        Returns:
            float: The calculated Fast-DetectGPT score.

        Source:
            - Paper: https://arxiv.org/abs/2310.05130
            - GitHub: https://github.com/baoguangsheng/fast-detect-gpt
            - Implementation: https://github.com/baoguangsheng/fast-detect-gpt/blob/main/scripts/fast_detect_gpt.py#L52:L70
        """
        device = device or self.device
        probs = likelihood.to(device)
        log_probs = log_likelihood.to(device)

        expectation = (probs * log_probs).sum(-1)
        variance = (probs * log_probs.square()).sum(-1) - expectation.square()

        fast_detect_gpt = (
            target_log_probs.to(device).sum(-1) - expectation.sum(-1)
        ) / variance.sum(-1).sqrt()

        return fast_detect_gpt.cpu().item()


def collate_fn(
    batch: list[Sample],
    pad_token_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad a list of samples into a pair of batch tensors.

    Args:
        batch (list[Sample]): Samples providing `input_ids` and `attention_mask`.
        pad_token_id (int): Value used to pad `input_ids`.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: `(input_ids, attention_mask)` as
        integer (`torch.long`) tensors with the batch as first dimension.
        `input_ids` is padded with `pad_token_id`, `attention_mask` with 0.
    """
    input_ids = pad_sequence(
        [torch.tensor(sample["input_ids"], dtype=torch.long) for sample in batch],
        batch_first=True,
        padding_value=pad_token_id,
    )
    attention_mask = pad_sequence(
        [torch.tensor(sample["attention_mask"], dtype=torch.long) for sample in batch],
        batch_first=True,
        padding_value=0,
    )
    return input_ids, attention_mask


In [3]:
# NOTE: despite its name, `model` is the MetricsCalculator wrapper, not the
# underlying language model (which is available as `model.model`).
model = MetricsCalculator(model="gpt2")

In [4]:
# Tokenize two example sentences; every entry provides the `input_ids` and
# `attention_mask` keys that `collate_fn` reads.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
dataset = [
    tokenizer(text, add_special_tokens=True)
    for text in (
        "The quick brown fox jumps over the lazy dog.",
        "Once upon a time, in a land far far away...",
    )
]

In [5]:
# Fall back to the EOS token only when the tokenizer defines no pad token.
# An explicit `is not None` check is required: `pad_token_id or ...` would
# also discard a legitimate pad token ID of 0.
model.process(
    dataset,
    pad_token_id=(
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    ),
)


[A
[A
[A

[Metrics(llr=1.5427662134170532, fdg=0.526492178440094),
 Metrics(llr=2.1564900875091553, fdg=1.4833513498306274)]