In [1]:
import os
import pickle

from dotenv import load_dotenv
from lightning.pytorch import seed_everything
from sklearn.metrics import auc, roc_curve

# from luminar.baselines import log_likelihood_log_rank_ratio
from luminar.mongo import MongoDBAdapter
from simple_dataset.dataset import Dataset
from transition_scores.data import TransitionScores

# Notebook configuration: location of the dotenv file and the global RNG seed.
ENV_FILE = "../env"
SEED = 42

# Load environment variables from the dotenv file, then seed the RNGs so
# repeated runs of the notebook are reproducible.
load_dotenv(ENV_FILE)
seed_everything(SEED)

Seed set to 42


42

In [2]:
import torch


def log_likelihood_log_rank_ratio(transition_scores: "TransitionScores") -> float:
    """Compute the log-likelihood / log-rank ratio of a scored sequence.

    The score is ``-sum(log(p_i + 1e-8)) / sum(log1p(r_i + 1e-8))``, taken over
    the positions whose target probability is non-zero, where ``p_i`` and
    ``r_i`` are the entries of ``target_probs`` and ``target_ranks``.

    Args:
        transition_scores: Object exposing ``target_probs`` and
            ``target_ranks`` as numeric sequences of equal length (the same
            boolean mask is used to index both).

    Returns:
        The ratio as a plain Python ``float``, matching the annotation (the
        previous implementation returned a 0-dim ``torch.Tensor``). The result
        is ``nan`` when no position has a non-zero probability (0 / 0).
    """
    target_probs = torch.tensor(transition_scores.target_probs)
    target_ranks = torch.tensor(transition_scores.target_ranks)

    # Positions with probability exactly 0 are dropped from both sums.
    # NOTE(review): presumably these are padding / unscored positions - confirm.
    mask = target_probs.ne(0.0)

    # The small epsilon guards the logarithms; adding a Python float also
    # promotes an integer rank tensor to floating point before ``log1p``.
    neg_log_likelihood = -(target_probs[mask] + 1e-8).log().sum()
    log_rank = (target_ranks[mask] + 1e-8).log1p().sum()

    # ``.item()`` unwraps the 0-dim tensor so callers receive a real float.
    return (neg_log_likelihood / log_rank).item()


In [3]:
# Evaluate the log-likelihood / log-rank ratio baseline separately for each
# domain and print the resulting ROC AUC.
for domain in (
    # "arxiv_papers",
    "blog_authorship_corpus",
    "bundestag",
    "cnn_news",
    "euro_court_cases",
    # "gutenberg",
    # "house_of_commons",
    # "spiegel_articles",
    "student_essays",
):
    db = MongoDBAdapter(
        os.environ.get("MONGO_DB_CONNECTION"),
        "prismai",
        "collected_items",
        "synthesized_texts",
        "transition_scores",
        domain=domain,
        source_collection_limit=1500,
    )
    # NOTE(security): pickle.load can execute arbitrary code from the file.
    # Assumes the cache file was produced locally by this project - confirm
    # before pointing this at a cache from any other origin.
    with db.get_cache_file("pkl").open("rb") as f:
        dataset = Dataset(pickle.load(f))

    dataset = (
        # Keep only documents with exactly two feature entries of two distinct
        # document types, so both classes are present for every kept document.
        dataset.filter(
            lambda doc: len(doc["features"]) == 2
            and len({ts["document"]["type"] for ts in doc["features"]}) == 2
        )
        # One record per feature entry.
        .flat_map(lambda doc: doc["features"])
        .apply(TransitionScores.merge, "transition_scores")
        .map(
            lambda doc: {
                "llr": log_likelihood_log_rank_ratio(doc["transition_scores"]),
                # Label 0 for "source" documents, 1 for every other type.
                "labels": int(doc["document"]["type"] != "source"),
            }
        )
    )

    scored = [(doc["llr"], doc["labels"]) for doc in dataset]
    if not scored:
        # zip(*[]) yields nothing to unpack and would abort the whole loop with
        # an opaque ValueError; report the empty domain and move on instead.
        print(f"{domain}: no documents left after filtering, skipping")
        continue
    llr, labels = zip(*scored)

    fpr, tpr, _ = roc_curve(labels, llr)
    roc_auc = auc(fpr, tpr)
    print(f"{domain}: {roc_auc}")

blog_authorship_corpus: 0.48696861155622584
bundestag: 0.8565655823166545
cnn_news: 0.9434142222222223
euro_court_cases: 0.9052147067706178
student_essays: 0.84698
