In [None]:
import copy
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from dotenv import load_dotenv
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from tqdm.auto import tqdm

from luminar.document.data import (
    FeatureDataset,
    PaddingDataloader,
    n_way_split,
)
from luminar.document.model import (
    CNNDocumentClassficationModel,
    ConvolutionalLayerSpec,
)
from luminar.features import FeatureExtractor, OneDimFeatures, Slicer, TwoDimFeatures
from luminar.mongo import PrismaiDataset

# Load environment variables from the dotenv file one directory up
# (presumably including MONGO_DB_CONNECTION, which the baseline cell reads — confirm).
load_dotenv("../env")

# Suppress warnings whose message matches the "does not have many workers" pattern.
warnings.filterwarnings("ignore", ".*does not have many workers.*")

In [None]:
# Display name -> query kwargs for each domain.
# The kwargs are splatted into PrismaiDataset(...) when loading, and the keys are
# reused as row/column labels in the heatmaps and result tables further down.
domains = {
    "Blog Authorship": {"domain": "blog_authorship_corpus"},
    "Student Essays": {"domain": "student_essays"},
    "CNN News": {"domain": "cnn_news"},
    "Euro Court Cases": {"domain": "euro_court_cases"},
    "House of Commons": {"domain": "house_of_commons"},
    "ArXiv Papers": {"domain": "arxiv_papers"},
    "Gutenberg": {"domain": "gutenberg", "lang": "en-EN"},
    "Bundestag [DE]": {"domain": "bundestag"},
    "Spiegel [DE]": {"domain": "spiegel_articles"},
    # "Gutenberg [DE]": {"domain": "gutenberg", "lang": "de-DE"},
}

In [None]:
# Maps an alternative set of display names onto the keys of `domains`.
# NOTE(review): `dmap` is not referenced by any other cell visible in this
# notebook — confirm it is still needed.
dmap = {
    "Blog Authorship": "Blog Authorship",
    "Student Essays": "Student Essays",
    "CNN News": "CNN News",
    "Euro Court Cases": "Euro Court Cases",
    "House of Commons": "House of Commons",
    "ArXiv Papers": "ArXiv Papers",
    "Gutenberg [EN]": "Gutenberg",
    "Bundestag": "Bundestag [DE]",
    "Spiegel": "Spiegel [DE]",
    # "Gutenberg [DE]": 
}

In [None]:
# Global experiment configuration. Later cells add further keys to this dict and
# the whole dict is splatted into CNNDocumentClassficationModel(**config).
config = {
    # Passed to seed_everything() before every split / training run.
    "seed": 42,
    # Split arguments handed to n_way_split().
    "eval_split": 0.1,
    "test_split": 0.2,
    # Selects which stored features / synthetic-text agent are queried from MongoDB.
    "feature_model": "gpt2",
    # "feature_model": "meta-llama/Llama-3.2-1B",
    "synth_agent": "gpt-4o-mini",
    # "synth_agent": "gemma2:9b"
}

In [None]:
# Load one PrismaiDataset per domain, keyed by the display name from `domains`.
#
# The MongoDB connection string (which embeds user name and password) is read
# from the MONGO_DB_CONNECTION environment variable instead of being hard-coded
# in the notebook; this is the same variable the baseline cell further down
# uses. A missing variable raises KeyError here rather than silently connecting
# somewhere unexpected.
datasets = {
    domain: PrismaiDataset(
        mongo_db_connection=os.environ["MONGO_DB_CONNECTION"],
        database="prismai",
        collection="features_prismai",
        feature_model=config["feature_model"],
        synth_agent=config["synth_agent"],
        # synth_agent={"$exists": True},
        # Per-domain filter kwargs ("domain", optionally "lang").
        **kwargs,
        # update_cache=True
    ).load()
    for domain, kwargs in domains.items()
}

[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/119e4613239c2a870617e80c50d91776c6682016415184e1039b60e5ef5d8116.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/4aa6c04025b8043dccbb076975ccd2bc2492c5d5172798eb9adeeaee382f8586.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/3fb225030a91434546e94cddd9b435edbce7c31b51b1f14829cea50f209d1bff.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/841e418fa927a8868452fa7c33fecf637691d908aa5e19593c101f12bb430fc7.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/693d7989e4d059cf2330448fba44db75c0ce07c23ab3f8b5f4974a017e4ddb14.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/e17b34be8696a18364a4ffaefa50e1b68b3fee34af4a6e75203c75fc0d555a1f.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/ae07a41994953a0d6f9844233d8ae8858135956e65e24c00250b6f90857eaf54.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/1f8a7aa0d5e70435378180ce907ae70bcd6369f516164db9658d06619e84da58.pkl


[PrismaiDataset] Loading Documents from MongoDB: 0it [00:00, ?it/s]

[PrismaiDataset] Writing Cache File /tmp/luminar/5ba2b5f045656e518aaba12da3883c015cc1639b065a45b63cc1f53943330ffe.pkl


In [None]:
# Split every domain's dataset into train / eval / test parts.
train_splits = {}
eval_splits = {}
test_splits = {}
for domain, dataset in datasets.items():
    # Re-seed before each split so every domain is split with the same seed.
    seed_everything(config["seed"])
    train_dataset, eval_dataset, test_dataset = n_way_split(
        dataset,
        config["eval_split"],
        config["test_split"],
        # NOTE(review): presumably lets n_way_split derive the first (train)
        # fraction from the remaining ones — confirm against luminar.document.data.
        infer_first=True,
    )
    train_splits[domain] = train_dataset
    eval_splits[domain] = eval_dataset
    test_splits[domain] = test_dataset

Seed set to 42
Seed set to 42
Seed set to 42
Seed set to 42
Seed set to 42
Seed set to 42
Seed set to 42
Seed set to 42
Seed set to 42


## Features

In [None]:
# Feature setup: pick exactly one feature_dim / featurizer / slicer combination.
# The commented-out lines are alternative configurations kept for quick switching.
# feature_dim = OneDimFeatures(256)
# featurizer = FeatureExtractor.Likelihood()
# featurizer = FeatureExtractor.LogLikelihoodLogRankRatio()
# config["second_dim_as_channels"] = False
feature_dim = TwoDimFeatures(256, 13)
# featurizer = FeatureExtractor.LikelihoodTopkLikelihoodRatio(16)
featurizer = FeatureExtractor.IntermediateLogits(13)
# config["second_dim_as_channels"] = False
config["second_dim_as_channels"] = True

# The slicer is sized from the first feature dimension (feature_dim[0] // 4, multiple=4).
# slicer = Slicer.Random(feature_dim[0])
slicer = Slicer.RandomMultiple(feature_dim[0] // 4, multiple=4, stride=16)
# slicer = Slicer.RandomMultiple(feature_dim[0] // 4, 4)

# Record the choices in `config`; featurizer and slicer are stored as their repr strings.
config["feature_dim"] = feature_dim
config["featurizer"] = repr(featurizer)
config["slicer"] = repr(slicer)

# Default num_samples used by featurize(); the test split overrides it with num_samples_test.
config["num_samples"] = None
config["num_samples_test"] = 32


def featurize(dataset, num_samples=None) -> FeatureDataset:
    """Build a FeatureDataset from ``dataset`` with the notebook-level slicer and featurizer.

    The documents are wrapped in a nested tqdm progress bar (position 1, not
    left on screen). A falsy ``num_samples`` (``None`` or ``0``) falls back to
    ``config["num_samples"]``.
    """
    documents = tqdm(dataset, position=1, leave=False)
    if not num_samples:
        num_samples = config["num_samples"]
    return FeatureDataset(documents, slicer, featurizer, num_samples=num_samples)

In [None]:
# Featurize every split of every domain.
# Train and eval use the default config["num_samples"]; the test split uses
# config["num_samples_test"] instead.
train_datasets, eval_datasets, test_datasets = {}, {}, {}
# Only the keys of `datasets` are needed here; the loop variable `dataset` is unused.
for domain, dataset in tqdm(datasets.items()):
    train_datasets[domain] = featurize(train_splits[domain])
    eval_datasets[domain] = featurize(eval_splits[domain])
    test_datasets[domain] = featurize(test_splits[domain], num_samples=config["num_samples_test"])

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/1050 [00:00<?, ?it/s]

IndexError: index is out of bounds for dimension with size 0

## In-Domain Training & Evaluation

In [None]:
# Training hyper-parameters, merged into the global `config` in place.
# max_epochs, gradient_clip_val and batch_size are read explicitly by the
# Trainer / dataloader cells below; the full dict also goes to the model constructor.
config |= {
    # "projection_dim": 32,
    "learning_rate": 0.0001,
    "warmup_steps": 69,
    "max_epochs": 50,
    "gradient_clip_val": 1.0,
    "batch_size": 32,
}

# SeqXGPT Layer Configuration
# Five ConvolutionalLayerSpec entries: (64, 5), three times (128, 3), then (64, 3).
# NOTE(review): the two arguments are presumably (channels, kernel size) — confirm
# against luminar.document.model.ConvolutionalLayerSpec.
config["conv_layer_shapes"] = [
    ConvolutionalLayerSpec(64, 5),
    *[ConvolutionalLayerSpec(128, 3)] * 3,
    ConvolutionalLayerSpec(64, 3),
]

In [None]:
from torch.utils.data import ConcatDataset, DataLoader


class PaddingDataloader(DataLoader):
    def __init__(self, *args, feature_dim: tuple[int, ...], **kwargs):
        kwargs["collate_fn"] = self._collate_fn
        super().__init__(*args, **kwargs)
        self.feature_dim = feature_dim

    def _collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
        features = torch.nn.utils.rnn.pad_sequence(
            [x["features"] for x in batch], batch_first=True
        )

        # In case we get a batch of sequences, that are all too short,
        # we need to pad them to the correct length as given by the feature_dim.
        # - First dimension is the batch size.
        # - Second dimension is the sequence length.
        # - Third dimension is the feature dimension, if 2D features are used.
        match features.shape, self.feature_dim:
            case (_, s1), (d1,) if s1 < d1:
                p2d = (0, 0, 0, d1 - s1)
                features = torch.nn.functional.pad(features, p2d, "constant", 0.0)
            case (_, s1, _), (d1, _) if s1 < d1:
                p2d = (0, 0, 0, d1 - s1, 0, 0)
                features = torch.nn.functional.pad(features, p2d, "constant", 0.0)
        labels = torch.tensor([x["labels"] for x in batch])

        return {"features": features, "labels": labels}


# Reads feature_dim and batch_size from the notebook-level `config` at call time.
def get_dataloader(*dataset, **kwargs) -> PaddingDataloader:
    """Create a PaddingDataloader over one dataset or the concatenation of several.

    A single positional dataset is used as-is; any other number of positional
    datasets is wrapped in a ``ConcatDataset``. Extra keyword arguments
    (e.g. ``shuffle=True``) are forwarded to the dataloader.
    """
    source = dataset[0] if len(dataset) == 1 else ConcatDataset(dataset)
    loader_options = {
        "feature_dim": config["feature_dim"],
        "batch_size": config["batch_size"],
    }
    return PaddingDataloader(source, **loader_options, **kwargs)

In [None]:
# Train one model per domain, using only that domain's train / eval data.
models_in_domain = {}
for domain in tqdm(domains, position=0, leave=False):
    # Same seed for every domain so that runs are comparable.
    seed_everything(config["seed"])

    train_dataloader = get_dataloader(train_datasets[domain], shuffle=True)
    eval_dataloader = get_dataloader(eval_datasets[domain])

    # The whole config dict is passed to the model constructor as keyword arguments.
    model = CNNDocumentClassficationModel(**config)
    trainer = Trainer(
        max_epochs=config["max_epochs"],
        # TensorBoard logs go to logs/in_domain/<featurizer class name>/<domain>.
        logger=pl_loggers.TensorBoardLogger(
            save_dir=f"logs/in_domain/{type(featurizer).__name__}",
            name=domain,
        ),
        gradient_clip_val=config["gradient_clip_val"],
        # Early stopping watches "val_loss" (mode="min") with a patience of 3.
        callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
        deterministic=True,
    )

    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )
    models_in_domain[domain] = model
    # NOTE: `trainer` stays bound to the last domain's Trainer after the loop;
    # the evaluation cell below relies on that leftover variable.


In [None]:
# Evaluate every in-domain model on the test split of every domain.
# results_in_domain[domain] is a list with one metrics dict per evaluated domain,
# in the iteration order of `domains`.
#
# NOTE: `trainer` is not created in this cell — it is whatever Trainer was last
# bound in the kernel (in notebook order: the one from the final iteration of
# the in-domain training loop).
results_in_domain = {}
for domain in tqdm(domains, position=0, leave=False):
    # Silence the per-run Lightning progress bars while looping.
    trainer.progress_bar_callback.disable()
    metrics = []
    for other in domains:
        # A validation pass on the target domain's eval split runs before each test pass.
        # NOTE(review): presumably this lets the model derive validation-based
        # state (e.g. thresholds for the "@best" metrics) — confirm in the model code.
        trainer.validate(
            models_in_domain[domain],
            get_dataloader(
                # train_datasets[other],
                eval_datasets[other],
            ),
            verbose=False,
        )
        # trainer.test returns a list; its first element is merged with the
        # name of the evaluated domain.
        metrics.append(
            {"other": other}
            | trainer.test(
                models_in_domain[domain],
                get_dataloader(test_datasets[other]),
                verbose=False,
            )[0]
        )
    trainer.progress_bar_callback.enable()

    results_in_domain[domain] = metrics
    print(domain, metrics)

In [None]:
# import pandas as pd

# _metric = "test_roc_auc"


# results = []
# for domain in domains:
#     results.append(
#         [
#             results_in_domain[domain]["metrics"][i][_heatmap_metric]
#             for i in range(len(domains))
#         ]
#     )

# df = pd.DataFrame(results, columns=domains, index=domains)
# df

## All-Domain Training & Evaluation

In [None]:
# Train a single model on the concatenation of all domains' training data.
seed_everything(config["seed"])
# Passing several datasets makes get_dataloader wrap them in a ConcatDataset.
train_dataloader = get_dataloader(*train_datasets.values(), shuffle=True)
eval_dataloader = get_dataloader(*eval_datasets.values())
test_dataloader = get_dataloader(*test_datasets.values())

model = CNNDocumentClassficationModel(**config)
trainer = Trainer(
    max_epochs=config["max_epochs"],
    # No `name` is given here, so the logger's default run name applies.
    logger=pl_loggers.TensorBoardLogger(
        save_dir=f"logs/all_domains/{type(featurizer).__name__}",
    ),
    gradient_clip_val=config["gradient_clip_val"],
    # Early stopping watches "val_loss" (mode="min") with a patience of 3.
    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
    deterministic=True,
)
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=eval_dataloader,
)
# Stored next to the per-domain models under the special key "all".
models_in_domain["all"] = model

In [None]:
# Evaluate the all-domain model: first on each domain's test split, then on the
# concatenated test set. The result list therefore has len(test_datasets) + 1
# entries, the last one tagged "ALL".
trainer.progress_bar_callback.disable()
metrics_all_domains = []
for other, dataset in test_datasets.items():
    # Validation pass on the matching eval split before each test pass.
    trainer.validate(
        models_in_domain["all"],
        get_dataloader(
            # train_datasets[other],
            eval_datasets[other],
        ),
        verbose=False,
    )
    metrics_all_domains.append(
        {"other": other}
        | trainer.test(models_in_domain["all"], get_dataloader(dataset), verbose=False)[
            0
        ]
    )
# Same validate-then-test sequence on the concatenated eval / test dataloaders
# created in the all-domain training cell.
trainer.validate(models_in_domain["all"], eval_dataloader, verbose=False)
metrics_all_domains += [
    {
        "other": "ALL",
        **trainer.test(models_in_domain["all"], test_dataloader, verbose=False)[0],
    }
]
trainer.progress_bar_callback.enable()

In [None]:
def df_to_latex_heatmap(_df: pd.DataFrame):
    """Print ``_df`` as the argument list of a ``plotHeatmap`` LaTeX macro.

    The printed call has three brace groups: the comma-joined index labels,
    one brace-wrapped row per data row, and the comma-joined column labels.
    Each cell is written as ``<value to 4 decimals>/<value rounded to 2 decimals>``.
    Index and column labels must be strings.
    """
    row_labels = ",".join(_df.index)

    formatted_rows = []
    for record in _df.reset_index().values:
        # record[0] is the former index label; the remaining entries are the cells.
        cells = [f"{val:.4f}/{round(val, 2):.2f}" for val in record[1:]]
        formatted_rows.append("{" + ",".join(cells) + "}")
    body = ",%\n    ".join(formatted_rows)

    column_labels = ",".join(_df.columns)

    print(
        "\\plotHeatmap{" + row_labels + "}{%\n    " + body + "%\n}{" + column_labels + "}"
    )


def plot_heatmap(_df: pd.DataFrame):
    """Draw ``_df`` as an annotated seaborn heatmap on a fixed 0..1 colour scale.

    Tick labels are taken from the notebook-level ``domains`` (plus "ALL" for
    the rows and "AVG" for the columns), not from ``_df`` itself, so ``_df``
    must have matching dimensions.
    """
    domain_names = list(domains)
    palette = sns.cubehelix_palette(rot=-0.2, as_cmap=True)

    ax = sns.heatmap(
        _df,
        annot=True,
        fmt=".2f",
        vmin=0.0,
        vmax=1.0,
        cmap=palette,
        xticklabels=domain_names + ["AVG"],
        yticklabels=domain_names + ["ALL"],
        square=True,
        # small annotation font so the numbers fit into the cells
        annot_kws={"fontsize": 8},
        cbar=False,
    )

    # Tilt the x-axis labels by 45 degrees, anchored at their right end.
    plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment="right")

    plt.tight_layout()
    plt.show()


def get_df_from_metrics(
    _metrics_in_domain, _metrics_all_domains, _metric: str = "test_f1@best"
) -> pd.DataFrame:
    """Assemble the train-domain x test-domain table for one metric.

    Args:
        _metrics_in_domain: Mapping from domain name to a list of metrics dicts,
            one per evaluated domain, in the order of the notebook-level ``domains``.
        _metrics_all_domains: List of metrics dicts for the all-domain model:
            one per domain (same order) followed by a final entry for the
            combined test set.
        _metric: Key to read from every metrics dict.

    Returns:
        A DataFrame with one row per training domain plus an "ALL" row, one
        column per test domain plus an "AVG" column holding the row mean.
        For the "ALL" row, "AVG" is the metric measured on the combined test
        set rather than the mean of the per-domain values.
    """
    names = list(domains)
    rows = [
        [_metrics_in_domain[name][i][_metric] for i in range(len(names))]
        for name in names
    ]
    rows.append([m[_metric] for m in _metrics_all_domains[: len(names)]])

    _df = pd.DataFrame(rows, columns=names, index=names + ["ALL"])
    _df["AVG"] = _df.mean(axis=1)

    # Replace the averaged value of the "ALL" row with the actually measured one.
    # A single .loc assignment is used on purpose: the former chained form
    # `_df["AVG"][-1] = ...` writes to a temporary under pandas Copy-on-Write
    # and relies on the deprecated positional fallback of Series.__getitem__.
    _df.loc["ALL", "AVG"] = _metrics_all_domains[-1][_metric]
    return _df

In [None]:
# Display the full configuration used for the runs above.
config

In [None]:
# Heatmap + LaTeX export for the metric stored under "test_f1@0.5".
_metric = "test_f1@0.5"
# _metric = "test_acc@0.5"

df = get_df_from_metrics(results_in_domain, metrics_all_domains, _metric)
plot_heatmap(df)
df_to_latex_heatmap(df)

In [None]:
# Heatmap + LaTeX export for the metric stored under "test_f1@best".
_metric = "test_f1@best"
# _metric = "test_acc@best"

df = get_df_from_metrics(results_in_domain, metrics_all_domains, _metric)
plot_heatmap(df)
df_to_latex_heatmap(df)

In [None]:
# Heatmap + LaTeX export for the metric stored under "test_roc_auc".
_metric = "test_roc_auc"

df = get_df_from_metrics(results_in_domain, metrics_all_domains, _metric)
plot_heatmap(df)
df_to_latex_heatmap(df)

In [None]:
# Display the configuration again next to the plots above.
config

In [None]:
from sklearn.metrics import f1_score


def predict(model, dataloader):
    """Run ``model`` over ``dataloader`` and collect labels and sigmoid outputs.

    Each batch must provide ``batch["features"]`` and ``batch["labels"]``.
    Gradients are disabled for the forward passes. The model's train/eval mode
    is not changed here.

    Returns:
        ``(labels, preds)`` as NumPy arrays, concatenated over all batches.
    """
    collected_labels, collected_preds = [], []
    with torch.no_grad():
        for batch in dataloader:
            collected_labels += batch["labels"].tolist()
            scores = model(batch["features"]).sigmoid()
            collected_preds += scores.tolist()

    return np.array(collected_labels), np.array(collected_preds)


def get_f1_threshold(labels: np.ndarray, preds: np.ndarray) -> float:
    """Return the decision threshold that maximises F1 for the positive class.

    7501 candidate thresholds evenly spaced over [0.25, 1] are tried; a sample
    counts as predicted positive when its score is strictly greater than the
    threshold. Ties are resolved in favour of the lowest threshold.

    The positive class is label 1, matching ``f1_score(..., average="binary")``
    used by the evaluation loop. (Previously ``tp`` counted label-0 samples
    predicted as 0 — i.e. true negatives — so the threshold was tuned for the
    F1 of class 0 instead.)

    Args:
        labels: Binary ground-truth labels, one per sample.
        preds: Predicted scores; 1-D, or a single column of shape (n, 1).

    Returns:
        The best threshold, or the first candidate (0.25) if F1 is 0 everywhere.
    """
    thresholds = np.linspace(0.25, 1, 7501)
    # One row of binary decisions per candidate threshold.
    preds_thresholded = preds.T > thresholds.reshape(-1, 1)

    decisions_for_positives = preds_thresholded[:, labels == 1]
    decisions_for_negatives = preds_thresholded[:, labels == 0]
    tp = np.sum(decisions_for_positives, axis=1)
    fp = np.sum(decisions_for_negatives, axis=1)
    fn = np.sum(~decisions_for_positives, axis=1)

    # 0/0 happens when nothing is positive and nothing is predicted positive;
    # treat F1 as 0 there instead of emitting a RuntimeWarning.
    with np.errstate(divide="ignore", invalid="ignore"):
        f1_thresholded = 2 * tp / (2 * tp + fp + fn)
    f1_thresholded[np.isnan(f1_thresholded)] = 0
    f1_threshold = thresholds[np.argmax(f1_thresholded)]
    return f1_threshold


# Manual re-evaluation: binary F1 of every model (per-domain models plus "all")
# on every domain's test split, using a fixed 0.5 decision threshold.
manual_evaluation = []
for domain in tqdm(list(domains.keys()) + ["all"]):
    model = models_in_domain[domain]
    manual_evaluation_row = []
    for other in domains:
        # labels, preds = predict(model, get_dataloader(eval_datasets[other]))
        labels, preds = predict(model, get_dataloader(test_datasets[other]))
        # NOTE(review): f1_threshold is computed but never used — the score
        # below uses the fixed 0.5 cut-off. The commented-out lines suggest the
        # threshold was meant to be tuned on the eval split and applied to the
        # test split; confirm the intended behaviour.
        f1_threshold = get_f1_threshold(labels, preds)
        # labels, preds = predict(model, get_dataloader(test_datasets[other]))
        manual_evaluation_row.append(f1_score(labels, preds > 0.5, average="binary"))
    manual_evaluation.append(manual_evaluation_row)

In [None]:
# Rows: model (one per training domain plus "all"); columns: evaluated domain.
# Note that this rebinds the notebook-level `df`.
df = pd.DataFrame(manual_evaluation, columns=domains, index=list(domains.keys()) + ["all"])
df

## Cross-Domain Training & Evaluation

In [None]:
# Leave-one-domain-out training: for every domain, train on all OTHER domains
# and then evaluate on the test split of every domain (including the held-out one).
results_cross_domain = {}
for domain in tqdm(domains, position=0, leave=True):
    seed_everything(config["seed"])

    # `domain` is the held-out domain; everything else goes into train / eval.
    train_other = [train_datasets[other] for other in domains if other != domain]
    eval_other = [eval_datasets[other] for other in domains if other != domain]
    train_dataloader = get_dataloader(*train_other, shuffle=True)
    eval_dataloader = get_dataloader(*eval_other)

    model = CNNDocumentClassficationModel(**config)
    trainer = Trainer(
        max_epochs=config["max_epochs"],
        # Log under logs/cross_domain/... — this used to be a copy of the
        # in-domain path ("logs/in_domain/...") with the same run name, which
        # mixed cross-domain runs into the in-domain TensorBoard directory.
        logger=pl_loggers.TensorBoardLogger(
            save_dir=f"logs/cross_domain/{type(featurizer).__name__}",
            name=domain,
        ),
        gradient_clip_val=config["gradient_clip_val"],
        # Early stopping watches "val_loss" (mode="min") with a patience of 3.
        callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
        deterministic=True,
    )

    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=eval_dataloader,
    )

    trainer.progress_bar_callback.disable()
    metrics = []
    for other, dataset in test_datasets.items():
        # Validation pass on the target domain's eval split before each test pass.
        trainer.validate(model, get_dataloader(eval_datasets[other]), verbose=False)
        metrics.append(
            {
                "other": other,
            }
            | trainer.test(
                model,
                get_dataloader(dataset),
                verbose=False,
            )[0]
        )
    trainer.progress_bar_callback.enable()

    # Keep a snapshot of the config next to the metrics of this run.
    results_cross_domain[domain] = {
        "domain": domain,
        "config": copy.deepcopy(config),
        "metrics": metrics,
    }
    print(domain, metrics)


In [None]:
# Cross-domain result table for one metric.
# Rows: the held-out domain of each leave-one-out run; columns: evaluated test domain.
_metric = "test_roc_auc"
# _metric = "test_acc@best"

results = []
for domain in domains:
    results.append(
        [
            results_cross_domain[domain]["metrics"][i][_metric]
            for i in range(len(domains))
        ]
    )

df = pd.DataFrame(results, columns=list(domains), index=list(domains))
# Row-wise mean over all evaluated domains.
df["AVG"] = df.mean(axis=1)
df

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Heatmap of the cross-domain table built in the previous cell.
# Same styling as plot_heatmap(), but without the extra "ALL" row label,
# because the cross-domain table has no all-domain row.
ax = sns.heatmap(
    df,
    annot=True,
    fmt=".2f",
    vmax=1.0,
    vmin=0.0,
    cmap=sns.cubehelix_palette(rot=-0.2, as_cmap=True),
    yticklabels=list(domains),
    xticklabels=list(domains) + ["AVG"],
    square=True,
    # reduce annotation font size
    annot_kws={"fontsize": 8},
    cbar=False,
)

# rotate x-axis labels by 45 degrees
# anchored at the right edge of the axes
for tick in ax.get_xticklabels():
    tick.set_rotation(45)
    tick.set_horizontalalignment("right")

plt.tight_layout()
# NOTE(review): the file name below still says "trained_in_domain" although this
# cell plots the cross-domain table — adjust before re-enabling.
# plt.savefig(
#     "../figures/evaluation-trained_in_domain-test_0.1-gpt2_256-rand_4-il_13_as_channels.pdf",
#     dpi=300,
# )
plt.show()

In [None]:
# Deliberate stop: aborts "Run All" here so the baseline cells below only run when executed manually.
raise RuntimeError("Stop here.")

## Baselines

In [None]:
from luminar.mongo import MongoFindDataset

# Per-domain datasets holding the precomputed baseline metrics.
# Documents are restricted to the configured feature model, to types
# "source" / "fulltext", and to either no agent or the configured synthetic
# agent. The agent is taken from config["synth_agent"] (instead of a hard-coded
# "gpt-4o-mini") so the baselines always cover the same generator as the main
# experiments.
# NOTE(review): only kwargs["domain"] is used in the filter; an optional "lang"
# entry in `domains` is ignored here — confirm that is intended.
metric_datasets = {
    name: MongoFindDataset(
        {
            "model.name": config["feature_model"],
            "document.agent": {"$in": [None, config["synth_agent"]]},
            "document.domain": kwargs["domain"],
            "document.type": {"$in": ["source", "fulltext"]},
        },
        projection={"metrics": 1, "type": "$document.type"},
        mongo_db_connection=os.environ.get("MONGO_DB_CONNECTION"),
        database="prismai",
        collection="features_prismai",
        update_cache=True,
    ).load()
    for name, kwargs in domains.items()
}


In [None]:
from sklearn.metrics import auc, f1_score, roc_curve

# AUROC and F1 of the two precomputed baseline scores ("llr" and
# "fast_detect_gpt") for every domain.
# F1 uses the midpoint between the mean score of label-0 and label-1 documents
# as threshold and takes the better of the two possible score directions.
results = {
    "domain": [],
    "llr_auroc": [],
    "llr_f1": [],
    "fdg_auroc": [],
    "fdg_f1": [],
}
for name, dataset in metric_datasets.items():
    metrics = [
        (
            x["metrics"][0]["llr"],
            x["metrics"][0]["fast_detect_gpt"],
            # label 1 = every document whose type is not "source"
            int(x["type"] != "source"),
        )
        for x in dataset
    ]
    # zip(*...) yields plain tuples. Convert them to arrays: with tuples,
    # `llr[labels == 0]` evaluates to `llr[False]` (i.e. the first element) and
    # `llr > threshold` raises a TypeError.
    llr, fdg, labels = (np.asarray(column) for column in zip(*metrics))

    fpr, tpr, _ = roc_curve(labels, llr)
    llr_auroc = auc(fpr, tpr)

    human_mean, ai_mean = np.mean(llr[labels == 0]), np.mean(llr[labels == 1])
    threshold = (human_mean + ai_mean) / 2
    llr_f1 = max(f1_score(labels, llr > threshold), f1_score(labels, llr < threshold))

    fpr, tpr, _ = roc_curve(labels, fdg)
    fdg_auroc = auc(fpr, tpr)

    human_mean, ai_mean = np.mean(fdg[labels == 0]), np.mean(fdg[labels == 1])
    threshold = (human_mean + ai_mean) / 2
    fdg_f1 = max(f1_score(labels, fdg > threshold), f1_score(labels, fdg < threshold))

    results["domain"].append(name)
    results["llr_auroc"].append(llr_auroc)
    results["llr_f1"].append(llr_f1)
    results["fdg_auroc"].append(fdg_auroc)
    results["fdg_f1"].append(fdg_f1)

# Transposed: one column per domain, one row per result key.
metric_df = pd.DataFrame.from_dict(results).T
print(metric_df.to_latex(float_format="\\np{%.3f}"))
metric_df

In [None]:
import seaborn as sns

# Distribution of the fast_detect_gpt scores by class. `fdg` and `labels` are
# left over from the LAST iteration of the baseline loop above, i.e. they cover
# only the final domain.
# Label 1 marks non-"source" documents, so truthy labels are "AI" and label 0 is
# "Human" (the two names were swapped before).
sns.histplot(x=fdg, y=["AI" if x else "Human" for x in labels], hue=labels, kde=True)

### Ad-Hoc LLR


In [None]:
from luminar.baselines import llr_from_transition_scores
from simple_dataset import Dataset as SimpleDataset
from transition_scores.data import FeatureValues


def precompute_llr(split: list[dict]):
    """Compute an LLR score and a binary label for every feature entry in ``split``.

    Every document's ``"features"`` list is flattened into one dataset; each
    entry is scored via ``llr_from_transition_scores`` on its
    ``"transition_scores"`` and labelled 1 unless its ``"type"`` is ``"source"``.

    Returns:
        ``{"llr": np.ndarray, "labels": np.ndarray}``.
    """

    def score_entry(x):
        transition_scores = FeatureValues(**x["transition_scores"])
        return {
            "llr": llr_from_transition_scores(transition_scores),
            "labels": int(x["type"] != "source"),
        }

    flattened = SimpleDataset(split).flat_map(lambda doc: doc["features"])
    scored = flattened.map(score_entry, in_place=False)

    return {
        "llr": np.array(scored["llr"]),
        "labels": np.array(scored["labels"]),
    }


In [None]:
import pandas as pd
from sklearn.metrics import auc, roc_curve


def llr_metrics(llr: np.ndarray, labels: np.ndarray):
    """Summarise how well thresholding ``llr`` separates label 0 from label 1.

    A sample is predicted as 1 when its score is strictly greater than the
    threshold. Two thresholds are reported:

    * "simple": the midpoint between the mean score of each class;
    * "best": the most accurate of 10001 thresholds evenly spaced between the
      smallest and largest score (lowest threshold wins ties).

    Returns:
        Dict with ``acc_simple``, ``threshold_simple``, ``acc_best``,
        ``threshold_best`` and ``auroc`` (area under the ROC curve of the raw scores).
    """
    mean_neg = float(np.mean(llr[labels == 0]))
    mean_pos = float(np.mean(llr[labels == 1]))
    midpoint = mean_neg + (mean_pos - mean_neg) / 2
    midpoint_accuracy = np.mean((llr > midpoint) == labels)

    # Row i holds the per-sample correctness for candidate threshold i.
    candidates = np.linspace(llr.min(), llr.max(), 10001)
    correct = (llr > candidates[:, None]) == labels
    accuracy_per_candidate = correct.mean(axis=1)
    best = np.argmax(accuracy_per_candidate)

    fpr, tpr, _ = roc_curve(labels, llr)

    return {
        "acc_simple": midpoint_accuracy,
        "threshold_simple": midpoint,
        "acc_best": accuracy_per_candidate[best],
        "threshold_best": candidates[best],
        "auroc": auc(fpr, tpr),
    }


### LLR on Whole Datasets

In [None]:
# LLR scores and labels for every domain, computed on the complete (unsplit) datasets.
llr_datasets = {}
for domain, dataset in tqdm(datasets.items()):
    llr_datasets[domain] = precompute_llr(dataset)

In [None]:
# One row of LLR threshold / AUROC metrics per domain (whole datasets).
results_llr = []
for domain, split in llr_datasets.items():
    results_llr.append({"domain": domain} | llr_metrics(split["llr"], split["labels"]))

pd.DataFrame(results_llr)

### LLR on Test Splits

In [None]:
# Same LLR precomputation, restricted to the (un-featurized) test splits.
llr_test_datasets = {}
for domain, test_dataset in tqdm(test_splits.items()):
    llr_test_datasets[domain] = precompute_llr(test_dataset)

In [None]:
# One row of LLR threshold / AUROC metrics per domain (test splits only).
results_test_llr = []
for domain, split in llr_test_datasets.items():
    results_test_llr.append(
        {"domain": domain} | llr_metrics(split["llr"], split["labels"])
    )

pd.DataFrame(results_test_llr)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# This cell was a verbatim copy of the body of plot_heatmap(); call the helper
# instead so the styling lives in one place.
# NOTE(review): plot_heatmap() labels the rows with list(domains) + ["ALL"], so
# `df` must be a table produced by get_df_from_metrics() — presumably the
# in-domain table; confirm which `df` is meant at this point of the notebook.
plot_heatmap(df)
# To export the figure, savefig has to run before plt.show() inside plot_heatmap():
# plt.savefig(
#     "../figures/evaluation-trained_in_domain-test_0.1-gpt2_256-rand_4-il_13_as_channels.pdf",
#     dpi=300,
# )