In [1]:
import os

# Expose only GPUs 0-3 to this process. This cell runs before the cell that
# imports `torch`, so the variable is already set when CUDA is first touched.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0,1,2,3"})

In [2]:
import hashlib
import json
from pathlib import Path
from typing import Callable

import torch
from datasets import Dataset, DatasetDict, load_dataset
from numpy.typing import NDArray
from torch import Tensor, nn
from tqdm import tqdm
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

from luminar.classifier import ConvolutionalLayerSpec, LuminarCNN, LuminarCNN2D
from luminar.utils import (
    PaddingDataCollator,
    compute_metrics,
    get_matched_datasets,
    get_pad_to_fixed_length_fn,
    save_model,
)

# Hugging Face access token: contents of ~/.hf_token with surrounding
# whitespace (e.g. a trailing newline) removed.
HF_TOKEN = Path.home().joinpath(".hf_token").read_text().strip()

### Encoder

In [3]:
# from luminar.encoder import LuminarEncoder


# encoder = LuminarEncoder()
# encoder.device = "cuda:0"

### Classifier

In [4]:
# Global random seed, and the number of feature positions kept per sample.
seed = 42
feature_len = 256

# Array -> array callable built by the project helper for `feature_len`.
# The exact trim/pad behaviour is defined in `luminar.utils`.
pad_to_fixed_length: Callable[[NDArray], NDArray]
pad_to_fixed_length = get_pad_to_fixed_length_fn(feature_len)

In [5]:
# Experiment selection. `domain` picks the dataset config; `agent` and
# `other_agents` are the generator splits loaded next to the human split.
domain = "blog_authorship_corpus"
agent, other_agents = "gpt_4o_mini", "gemma2_9b"

In [6]:
# Config and split names understood by the encoded PrismAI dataset repo.
datset_config_name = f"{domain}-fulltext"
dataset_split_name = f"human+{agent}+{other_agents}"


def _has_features(features) -> bool:
    """Return True when the sample's feature sequence is non-empty."""
    return len(features) > 0


def _pad_features(features) -> dict:
    """Replace a sample's features with their fixed-length counterpart."""
    return {"features": pad_to_fixed_length(features)}


# Load the combined split, rename the target column to `labels`, drop
# samples without features, then bring every feature array to a fixed
# length (features are exposed as numpy arrays for that last step).
dataset: Dataset = (
    load_dataset(
        "liberi-luminaris/PrismAI-encoded-gpt2",
        datset_config_name,
        split=dataset_split_name,
        token=HF_TOKEN,
    )  # type: ignore
    .rename_column("label", "labels")
    .filter(_has_features, input_columns=["features"], num_proc=8)
    .with_format("numpy", columns=["features"])
    .map(
        _pad_features,
        input_columns=["features"],
        desc="Trimming & Padding Features",
        num_proc=8,
    )
)
dataset

Generating o3_mini split:   0%|          | 0/275 [00:00<?, ? examples/s]

Generating human split:   0%|          | 0/18614 [00:00<?, ? examples/s]

Generating nemotron split:   0%|          | 0/1265 [00:00<?, ? examples/s]

Generating gemma2_9b split:   0%|          | 0/14674 [00:00<?, ? examples/s]

Generating gpt_4o_mini split:   0%|          | 0/4576 [00:00<?, ? examples/s]

Generating deepseek_r1_1.5b split:   0%|          | 0/182 [00:00<?, ? examples/s]

Generating deepseek_r1_32b split:   0%|          | 0/71 [00:00<?, ? examples/s]

Generating phi3_3.8b split:   0%|          | 0/3268 [00:00<?, ? examples/s]

Filter (num_proc=8):   0%|          | 0/37864 [00:00<?, ? examples/s]

Trimming & Padding Features (num_proc=8):   0%|          | 0/37864 [00:00<?, ? examples/s]

Dataset({
    features: ['agent', 'id_sample', 'id_source', 'labels', 'length', 'features'],
    num_rows: 37864
})

In [7]:
# Split the data with the project helper into a DatasetDict of matched
# train/eval/test splits for `agent` plus one remaining "unmatched" dataset
# (the matching criterion itself is defined in `luminar.utils`).
datasets_matched, dataset_unmatched = get_matched_datasets(dataset, agent)

# Expose only the columns the model consumes, as torch tensors.
for _formatted in (datasets_matched, dataset_unmatched):
    _formatted.set_format("torch", columns=["labels", "features"])

Filter:   0%|          | 0/37864 [00:00<?, ? examples/s]

Filter:   0%|          | 0/37864 [00:00<?, ? examples/s]

Filter:   0%|          | 0/37864 [00:00<?, ? examples/s]

Filter:   0%|          | 0/37864 [00:00<?, ? examples/s]

Filter:   0%|          | 0/37864 [00:00<?, ? examples/s]

Filter:   0%|          | 0/37864 [00:00<?, ? examples/s]

In [8]:
# Display the matched splits to check their columns and row counts.
datasets_matched

DatasetDict({
    train: Dataset({
        features: ['agent', 'id_sample', 'id_source', 'labels', 'length', 'features'],
        num_rows: 6406
    })
    eval: Dataset({
        features: ['agent', 'id_sample', 'id_source', 'labels', 'length', 'features'],
        num_rows: 914
    })
    test: Dataset({
        features: ['agent', 'id_sample', 'id_source', 'labels', 'length', 'features'],
        num_rows: 1832
    })
})

In [9]:
# Single source of truth for model and training hyper-parameters. The same
# mapping is unpacked into the classifier constructor and read when the
# training arguments are built, so key names and insertion order are kept.
config = dict(
    # first 256 features & 13 layers for gpt2
    feature_dim=(feature_len, 13),
    feature_type="intermediate_likelihoods",
    feature_selection="first",
    conv_layer_shapes=(
        ConvolutionalLayerSpec(32, 5),
        ConvolutionalLayerSpec(64, 5),
        ConvolutionalLayerSpec(32, 3),
    ),
    projection_dim=(1024, 32),
    learning_rate=5e-4,
    max_epochs=25,
    gradient_clip_val=1.0,
    train_batch_size=32,
    eval_batch_size=1024,
    # NOTE(review): a ratio of 1.0 makes the learning-rate warm-up span the
    # whole schedule -- confirm this is intended.
    warmup_ratio=1.0,
    seed=seed,
    agent=agent,
    domain=domain,
)

In [10]:
# Name the run directory after a digest of the configuration. The built-in
# `hash()` of a string is salted per interpreter process, so it produced a
# different directory after every kernel restart (and an "x..." prefix for
# negative values). A sha256 of the key-sorted JSON is stable across runs.
config_digest = hashlib.sha256(
    json.dumps(config, sort_keys=True).encode("utf-8")
).hexdigest()[:16]

training_args = TrainingArguments(
    output_dir=f"../logs/hf/{config_digest}",
    per_device_train_batch_size=config["train_batch_size"],
    per_device_eval_batch_size=config["eval_batch_size"],
    learning_rate=config["learning_rate"],
    num_train_epochs=config["max_epochs"],
    warmup_ratio=config["warmup_ratio"],
    # Forward the seed and clipping value declared in `config` instead of
    # silently relying on the library defaults (which they currently equal).
    seed=config["seed"],
    max_grad_norm=config["gradient_clip_val"],
    logging_steps=50,
    # Evaluate and checkpoint on the same 50-step cadence so the checkpoint
    # with the lowest eval loss can be restored when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    eval_strategy="steps",
    eval_steps=50,
    eval_delay=100,
    save_strategy="steps",
    save_steps=50,
    torch_compile=True,
    torch_compile_mode="reduce-overhead",
)

The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.


In [11]:
# Build the 1-D CNN classifier from the shared configuration.
classifier = LuminarCNN(**config)


def _count_trainable(module) -> int:
    """Sum the element counts of all parameters that require gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Print the architecture followed by a per-component parameter budget.
print(classifier)
print(
    "Parameters:\n  conv_layers:",
    _count_trainable(classifier.conv_layers),
    "\n  projection:",
    _count_trainable(classifier.projection),
    "\n  classifier:",
    _count_trainable(classifier.classifier),
    "\n  total:",
    _count_trainable(classifier),
)

LuminarCNN(
  (conv_layers): Sequential(
    (0): Conv1d(13, 32, kernel_size=(5,), stride=(1,), padding=(2,))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=(2,))
    (3): LeakyReLU(negative_slope=0.01)
    (4): Conv1d(64, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (5): LeakyReLU(negative_slope=0.01)
  )
  (projection): Sequential(
    (0): Linear(in_features=32, out_features=1024, bias=True)
    (1): SiLU()
    (2): Linear(in_features=1024, out_features=32, bias=True)
    (3): SiLU()
    (4): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Linear(in_features=8192, out_features=1, bias=True)
  (criterion): BCEWithLogitsLoss()
)
Parameters:
  conv_layers: 18592 
  projection: 66592 
  classifier: 8193 
  total: 93377


In [None]:
# conv_layer_spec = (
#     ConvolutionalLayerSpec(8, (7, 1)),
#     ConvolutionalLayerSpec(16, (7, 1)),
#     ConvolutionalLayerSpec(32, 7),
#     ConvolutionalLayerSpec(64, 5),
#     ConvolutionalLayerSpec(32, 3),
#     ConvolutionalLayerSpec(3, 3),
# )

# classifier = LuminarCNN2D(conv_layer_shapes=conv_layer_spec, **config)
# print(classifier)
# print("num. parameters:", sum(1 for p in classifier.parameters() if p.requires_grad))

In [12]:
# Stop once the monitored metric has failed to improve for 5 evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

# Train on the matched train split and validate on the matched eval split.
trainer = Trainer(
    model=classifier,
    args=training_args,
    train_dataset=datasets_matched["train"],
    eval_dataset=datasets_matched["eval"],
    # data_collator=PaddingDataCollator(config["feature_dim"]),
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

W0527 12:46:45.950000 15692 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Step,Training Loss,Validation Loss,F1 Human,F1 Ai,F1 Weighted,Accuracy,Roc Auc,F1 Human Median,F1 Ai Median,F1 Weighted Median,Accuracy Median,Roc Auc Median,Threshold Median,Ground Truth Human,Ground Truth Ai
100,0.6931,0.693188,0.0,0.666667,0.333333,0.5,0.5,0.515847,0.514786,0.515317,0.515317,0.515317,0.504259,457,457
150,0.6933,0.693159,0.0,0.666667,0.333333,0.5,0.5,0.518033,0.516977,0.517505,0.517505,0.517505,0.502709,457,457
200,0.6931,0.693216,0.666667,0.0,0.333333,0.5,0.5,0.522976,0.522976,0.522976,0.522976,0.522976,0.493509,457,457
250,0.6932,0.693218,0.666667,0.0,0.333333,0.5,0.5,0.538293,0.538293,0.538293,0.538293,0.538293,0.492708,457,457
300,0.6927,0.693485,0.666667,0.0,0.333333,0.5,0.5,0.555799,0.555799,0.555799,0.555799,0.555799,0.485933,457,457
350,0.6938,0.693044,0.0,0.666667,0.333333,0.5,0.5,0.56674,0.56674,0.56674,0.56674,0.56674,0.501967,457,457
400,0.6929,0.692967,0.0,0.666667,0.333333,0.5,0.5,0.560175,0.560175,0.560175,0.560175,0.560175,0.505519,457,457
450,0.6932,0.692609,0.157895,0.654321,0.406108,0.509847,0.509847,0.571116,0.571116,0.571116,0.571116,0.571116,0.502338,457,457
500,0.6924,0.691161,0.394326,0.619768,0.507047,0.532823,0.532823,0.562363,0.562363,0.562363,0.562363,0.562363,0.505605,457,457
550,0.6901,0.685972,0.342065,0.637829,0.489947,0.532823,0.532823,0.564551,0.564551,0.564551,0.564551,0.564551,0.530294,457,457


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled
early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


Evaluating on eval set
{
    "eval_loss": 0.4723757207393646,
    "eval_f1_human": 0.7631296891747053,
    "eval_f1_ai": 0.753072625698324,
    "eval_f1_weighted": 0.7581011574365147,
    "eval_accuracy": 0.7582056892778993,
    "eval_roc_auc": 0.7582056892778993,
    "eval_f1_human_median": 0.7658643326039387,
    "eval_f1_ai_median": 0.7658643326039387,
    "eval_f1_weighted_median": 0.7658643326039387,
    "eval_accuracy_median": 0.7658643326039387,
    "eval_roc_auc_median": 0.7658643326039387,
    "eval_threshold_median": 0.4669828712940216,
    "eval_ground_truth_human": 457,
    "eval_ground_truth_ai": 457,
    "eval_runtime": 3.0184,
    "eval_samples_per_second": 302.81,
    "eval_steps_per_second": 0.331,
    "epoch": 13.18407960199005
}
Evaluating on test set


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


{
    "test_loss": 0.48548758029937744,
    "test_f1_human": 0.7638085218306154,
    "test_f1_ai": 0.7453204764605785,
    "test_f1_weighted": 0.754564499145597,
    "test_accuracy": 0.7549126637554585,
    "test_roc_auc": 0.7549126637554585,
    "test_f1_human_median": 0.7554585152838428,
    "test_f1_ai_median": 0.7554585152838428,
    "test_f1_weighted_median": 0.7554585152838428,
    "test_accuracy_median": 0.7554585152838428,
    "test_roc_auc_median": 0.7554585152838428,
    "test_threshold_median": 0.4562075734138489,
    "test_ground_truth_human": 916,
    "test_ground_truth_ai": 916,
    "test_runtime": 5.9088,
    "test_samples_per_second": 310.045,
    "test_steps_per_second": 0.338,
    "epoch": 13.18407960199005
}
Evaluating on unmatched set


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


{
    "unmatched_loss": 0.711101770401001,
    "unmatched_f1_human": 0.7261893381767199,
    "unmatched_f1_ai": 0.4997380217733015,
    "unmatched_f1_weighted": 0.6306656815620655,
    "unmatched_accuracy": 0.6460873146622734,
    "unmatched_roc_auc": 0.6153920472896133,
    "unmatched_f1_human_mean": 0.6853739055547582,
    "unmatched_f1_ai_mean": 0.5762613570462014,
    "unmatched_f1_weighted_mean": 0.6393471048206297,
    "unmatched_accuracy_mean": 0.6388797364085668,
    "unmatched_roc_auc_mean": 0.6312046035450574,
    "unmatched_threshold_mean": 0.35619881749153137,
    "unmatched_ground_truth_human": 14038,
    "unmatched_ground_truth_ai": 10242,
    "unmatched_runtime": 78.9497,
    "unmatched_samples_per_second": 307.537,
    "unmatched_steps_per_second": 0.304,
    "epoch": 13.18407960199005
}


In [13]:
# Train the model. The training arguments set `load_best_model_at_end=True`,
# so the Trainer restores the best checkpoint itself when training finishes;
# the private `_load_best_model()` call is therefore not needed.
trainer.train()
classifier = trainer.model

# Evaluate each split exactly once (the unmatched set alone takes over a
# minute) and keep the metrics in variables for later cells.
print("Evaluating on eval set")
metrics_eval = trainer.evaluate()
print(json.dumps(metrics_eval, indent=4))

print("Evaluating on test set")
metrics_test = trainer.evaluate(
    datasets_matched["test"],  # type: ignore
    metric_key_prefix="test",
)
print(json.dumps(metrics_test, indent=4))

print("Evaluating on unmatched set")
metrics_unmatched = trainer.evaluate(
    dataset_unmatched,  # type: ignore
    metric_key_prefix="unmatched",
)
print(json.dumps(metrics_unmatched, indent=4))

# Persist the trained model together with its configuration.
path = save_model(trainer, config)

Step,Training Loss,Validation Loss,Model Preparation Time,F1 Human,F1 Ai,F1 Weighted,Accuracy,Roc Auc,F1 Human Median,F1 Ai Median,F1 Weighted Median,Accuracy Median,Roc Auc Median,Threshold Median,Ground Truth Human,Ground Truth Ai
100,0.4349,0.464132,0.0028,0.76516,0.76873,0.766945,0.766958,0.766958,0.768053,0.768053,0.768053,0.768053,0.768053,0.509632,457,457
150,0.4304,0.464665,0.0028,0.771214,0.762542,0.766878,0.766958,0.766958,0.768053,0.768053,0.768053,0.768053,0.768053,0.469459,457,457
200,0.4292,0.467792,0.0028,0.758465,0.772824,0.765644,0.765864,0.765864,0.765864,0.765864,0.765864,0.765864,0.765864,0.541059,457,457
250,0.418,0.464086,0.0028,0.767721,0.766191,0.766956,0.766958,0.766958,0.765864,0.765864,0.765864,0.765864,0.765864,0.494606,457,457
300,0.4521,0.46294,0.0028,0.764904,0.777423,0.771164,0.771335,0.771335,0.772429,0.772429,0.772429,0.772429,0.772429,0.526848,457,457
350,0.4085,0.465042,0.0028,0.775293,0.762655,0.768974,0.769147,0.769147,0.772429,0.772429,0.772429,0.772429,0.772429,0.462333,457,457
400,0.4641,0.46456,0.0028,0.762014,0.781971,0.771992,0.772429,0.772429,0.768053,0.768053,0.768053,0.768053,0.768053,0.564801,457,457
450,0.4422,0.459317,0.0028,0.776931,0.774477,0.775704,0.775711,0.775711,0.774617,0.774617,0.774617,0.774617,0.774617,0.492813,457,457
500,0.4199,0.461984,0.0028,0.763636,0.780591,0.772114,0.772429,0.772429,0.768053,0.768053,0.768053,0.768053,0.768053,0.546638,457,457
550,0.4552,0.460499,0.0028,0.771186,0.755656,0.763421,0.763676,0.763676,0.776805,0.776805,0.776805,0.776805,0.776805,0.446511,457,457


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled
early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


Evaluating on eval set
{
    "eval_loss": 0.4396685063838959,
    "eval_model_preparation_time": 0.0028,
    "eval_f1_human": 0.778021978021978,
    "eval_f1_ai": 0.7799564270152506,
    "eval_f1_weighted": 0.7789892025186143,
    "eval_accuracy": 0.7789934354485777,
    "eval_roc_auc": 0.7789934354485777,
    "eval_f1_human_median": 0.7789934354485777,
    "eval_f1_ai_median": 0.7789934354485777,
    "eval_f1_weighted_median": 0.7789934354485777,
    "eval_accuracy_median": 0.7789934354485777,
    "eval_roc_auc_median": 0.7789934354485777,
    "eval_threshold_median": 0.5049842596054077,
    "eval_ground_truth_human": 457,
    "eval_ground_truth_ai": 457,
    "eval_runtime": 3.1784,
    "eval_samples_per_second": 287.563,
    "eval_steps_per_second": 0.315,
    "epoch": 8.208955223880597
}
Evaluating on test set


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


{
    "test_loss": 0.4434773027896881,
    "test_model_preparation_time": 0.0028,
    "test_f1_human": 0.7862881628280665,
    "test_f1_ai": 0.7779632721202003,
    "test_f1_weighted": 0.7821257174741334,
    "test_accuracy": 0.7822052401746725,
    "test_roc_auc": 0.7822052401746725,
    "test_f1_human_median": 0.7805676855895196,
    "test_f1_ai_median": 0.7805676855895196,
    "test_f1_weighted_median": 0.7805676855895196,
    "test_accuracy_median": 0.7805676855895196,
    "test_roc_auc_median": 0.7805676855895196,
    "test_threshold_median": 0.4700171649456024,
    "test_ground_truth_human": 916,
    "test_ground_truth_ai": 916,
    "test_runtime": 5.9346,
    "test_samples_per_second": 308.7,
    "test_steps_per_second": 0.337,
    "epoch": 8.208955223880597
}
Evaluating on unmatched set


early stopping required metric_for_best_model, but did not find eval_loss so early stopping is disabled


{
    "unmatched_loss": 0.6880214214324951,
    "unmatched_model_preparation_time": 0.0028,
    "unmatched_f1_human": 0.7347699552009418,
    "unmatched_f1_ai": 0.5488625618777463,
    "unmatched_f1_weighted": 0.6563488875561243,
    "unmatched_accuracy": 0.6659390444810543,
    "unmatched_roc_auc": 0.6410347646507168,
    "unmatched_f1_human_mean": 0.703363477141938,
    "unmatched_f1_ai_mean": 0.5968031870961473,
    "unmatched_f1_weighted_mean": 0.6584132921893437,
    "unmatched_accuracy_mean": 0.6581960461285008,
    "unmatched_roc_auc_mean": 0.6502854388675099,
    "unmatched_threshold_mean": 0.385267436504364,
    "unmatched_ground_truth_human": 14038,
    "unmatched_ground_truth_ai": 10242,
    "unmatched_runtime": 77.2966,
    "unmatched_samples_per_second": 314.115,
    "unmatched_steps_per_second": 0.31,
    "epoch": 8.208955223880597
}


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.image import AxesImage
from numpy.typing import NDArray

# Default colour map for feature heat maps.
cubehelix = sns.cubehelix_palette(as_cmap=True)


def visualize_features(features: NDArray, cmap=cubehelix, size=4) -> AxesImage:
    """Draw a feature matrix as a heat map with the axes hidden.

    Args:
        features: Array-like with at least two dimensions; rows are drawn
            top to bottom and columns left to right.
        cmap: Colour map handed to ``imshow``.
        size: Length in inches of the figure's shorter side.

    Returns:
        The ``AxesImage`` created by ``imshow``.
    """
    rows_per_col = features.shape[0] / features.shape[1]
    # `figsize` is (width, height). Keep the shorter side at `size` and
    # stretch the longer one, so the canvas has the same orientation as the
    # data: tall inputs get a tall figure, wide inputs a wide one.
    if rows_per_col > 1:
        width, height = size, size * rows_per_col
    else:
        width, height = size / rows_per_col, size
    _, ax = plt.subplots(figsize=(width, height))
    image = ax.imshow(
        features,
        cmap=cmap,
        # Colour range covers at least [0, 1] and widens if the data exceeds it.
        vmin=min(0.0, features.min()),
        vmax=max(1.0, features.max()),
        aspect="equal",
    )
    ax.set_axis_off()
    plt.tight_layout()
    return image


In [None]:
# Inspect the first matched test sample: raw features first.
sample = datasets_matched["test"][0]
print(sample["labels"])
features = sample["features"]
print(features.shape)
visualize_features(features.T)
plt.show()

# Run the forward passes on whichever device the model's weights live on
# instead of assuming a CUDA device is available.
model_device = next(classifier.parameters()).device

# Convolution stack without its final activation. The input is transposed
# and given a batch dimension; the output is transposed back.
with torch.no_grad():
    conv = (
        classifier.conv_layers[:-1](features.T.unsqueeze(0).to(model_device))
        .cpu()[0]
        .T
    )
print(conv.shape)
visualize_features(conv.clip(0, 1).numpy().T)
plt.show()

# Projection head without its last module (the Flatten in the printed model).
with torch.no_grad():
    ff = classifier.projection[:-1](conv.unsqueeze(0).to(model_device)).cpu()[0]
print(ff.shape)
visualize_features(ff.numpy().T)
plt.show()

In [None]:
# Inspect the last matched test sample: raw features first.
sample = datasets_matched["test"][-1]
print(sample["labels"])
features = sample["features"]
print(features.shape)
visualize_features(features.T)
plt.show()

# Run the forward passes on whichever device the model's weights live on
# instead of assuming a CUDA device is available.
model_device = next(classifier.parameters()).device

# Convolution stack without its final activation. The input is transposed
# and given a batch dimension; the output is transposed back.
with torch.no_grad():
    conv = (
        classifier.conv_layers[:-1](features.T.unsqueeze(0).to(model_device))
        .cpu()[0]
        .T
    )
print(conv.shape)
visualize_features(conv.clip(0, 1).numpy().T)
plt.show()

# Projection head without its last module (the Flatten in the printed model).
with torch.no_grad():
    ff = classifier.projection[:-1](conv.unsqueeze(0).to(model_device)).cpu()[0]
print(ff.shape)
visualize_features(ff.numpy().T)
plt.show()

In [None]:
# NOTE(review): presumably a deliberate stop so that "Run All" halts here and
# the cells below are only executed by hand -- confirm, and consider removing
# or moving those cells instead of raising.
raise RuntimeError()

In [None]:
import json

# NOTE(review): `scores` is not assigned in any cell visible in this
# notebook; it has to exist in the kernel before this cell can run.
# Serialise once, echo the result, then write the same text to disk.
scores_json = json.dumps(scores, indent=4)
print(scores_json)
with open("../logs/luminar/gpt2_first_128-3_epochs.json", "w") as f:
    f.write(scores_json)

In [None]:
# Domain subsets to load; each maps to the dataset config "<subset>-fulltext".
subset_names = (
    "blog_authorship_corpus",
    "student_essays",
    "cnn_news",
    "euro_court_cases",
    "house_of_commons",
    "arxiv_papers",
    "gutenberg_en",
    "en",
    "bundestag",
    "spiegel_articles",
    "gutenberg_de",
    "de",
)

# Config name -> combined human + gpt_4o_mini split of the encoded dataset.
datasets = {
    f"{name}-fulltext": load_dataset(
        "liberi-luminaris/PrismAI-encoded-gpt2",
        f"{name}-fulltext",
        token=HF_TOKEN,
        split="human+gpt_4o_mini",
    )
    for name in subset_names
}

In [None]:
# Full-text CNN news samples (human + gpt_4o_mini).
dataset = load_dataset(
    "liberi-luminaris/PrismAI-fulltext", "cnn_news", split="human+gpt_4o_mini"
)


def _is_human(sample) -> bool:
    """Return True for samples whose `agent` field is "human"."""
    return sample["agent"] == "human"


dataset_human = dataset.filter(_is_human)

# Pick 8/10 of the human samples (integer division first) after a seeded
# shuffle and collect their source ids.
n_selected = len(dataset_human) // 10 * 8
source_ids = set(dataset_human.shuffle(seed=42).take(n_selected)["id_source"])

# Keep every row, whatever its agent, whose source id was selected.
dataset_train = dataset.filter(lambda sample: sample["id_source"] in source_ids)
In [None]:
def _truncate_features(batch) -> dict:
    """Keep the first 256 positions along axis 1 of a batch's features."""
    return {"features": batch["features"][:, :256]}


# Same keys as `datasets`; features are exposed as numpy arrays (all other
# columns passed through) and truncated batch-wise.
datasets_truncated = {
    config_name: ds.with_format(
        "numpy", columns=["features"], output_all_columns=True
    ).map(_truncate_features, batched=True)
    for config_name, ds in datasets.items()
}

In [None]:
# Drop the configs whose names start with "de-" or "en-".
_excluded_prefixes = ("de-", "en-")
datasets_considered = {
    name: ds
    for name, ds in datasets_truncated.items()
    if not name.startswith(_excluded_prefixes)
}

In [None]:
import pandas as pd

# Row order of the final table (display names).
domains = [
    "Web Blogs",
    "Essays",
    "CNN",
    "ECHR",
    "HoC",
    "arXiv",
    "Gutenberg$_{en}$",
    "Bundestag$_{de}$",
    "Spiegel$_{de}$",
    "Gutenberg$_{de}$",
    "All$_{en}$",
    "All$_{de}$",
]
# Dataset subset name -> display name used in the table.
name_map = {
    "blog_authorship_corpus": "Web Blogs",
    "student_essays": "Essays",
    "cnn_news": "CNN",
    "euro_court_cases": "ECHR",
    "house_of_commons": "HoC",
    "arxiv_papers": "arXiv",
    "gutenberg_en": "Gutenberg$_{en}$",
    "bundestag": "Bundestag$_{de}$",
    "spiegel_articles": "Spiegel$_{de}$",
    "gutenberg_de": "Gutenberg$_{de}$",
    "en": "All$_{en}$",
    "de": "All$_{de}$",
}


def _score_row(config_name, metrics) -> dict:
    """Build one table row from a config name and its metric mapping."""
    subset = config_name.partition("-")[0]
    return {
        "domain": name_map[subset],
        "f1": metrics["f1"],
        "acc": metrics["accuracy"],
        "auroc": metrics["auroc"],
    }


# NOTE(review): `scores` is not assigned in any cell visible in this
# notebook; it has to exist in the kernel before this cell can run.
results = [_score_row(key, value) for key, value in scores.items()]

# One row per domain, ordered as in `domains`.
metric_df = pd.DataFrame(results).set_index("domain")
metric_df = metric_df.sort_index(
    key=lambda index: [domains.index(name) for name in index]
)
print(metric_df.to_latex(float_format="%.3f", index=True))
metric_df

In [None]:
# def run_detector(
#     detector: DetectorABC, datasets: dict[str, DatasetDict]
# ) -> dict[str, float]:
#     scores = {}
#     for config_name, ds in tqdm(datasets.items(), desc="Predicting on Datasets"):
#         dataset: Dataset = ds["test"].map(
#             detector.tokenize,
#             input_columns=["text"],
#             batched=True,
#             batch_size=1024,
#             desc="Tokenizing",
#         )
#         dataset = dataset.sort("length")
#         dataset = dataset.map(
#             detector.process,
#             batched=True,
#             batch_size=128,
#             desc="Predicting",
#         )

#         dataset_np = dataset.select_columns(["prediction", "label"]).with_format(
#             "numpy"
#         )

#         acc, f1, auroc = get_scores(dataset_np["label"], dataset_np["prediction"])
#         scores[config_name] = {"accuracy": acc, "f1": f1, "auroc": auroc}

#         acc, f1, auroc = get_scores(
#             dataset_np["label"],
#             dataset_np["prediction"],
#             calibrated=True,
#         )
#         scores[config_name] |= {
#             "accuracy_calibrated": acc,
#             "f1_calibrated": f1,
#             "auroc_calibrated": auroc,
#         }
#     return scores


In [None]:
# def evaluate(model: LuminarClassifier, datasets: dict[str, DatasetDict]) -> dict:
#     scores = {}
#     for config_name, dataset in tqdm(datasets.items(), desc="Evaluating", leave=False):
#         ds = (
#             dataset["test"]
#             .with_format("torch", ["features"])
#             .map(model.process, batched=True, batch_size=32, desc="Predicting")
#         )
#         dataset_np = ds.select_columns(["prediction", "label"]).with_format("numpy")

#         acc, f1, auroc = get_scores(dataset_np["label"], dataset_np["prediction"])
#         scores[config_name] = {
#             "accuracy": acc,
#             "f1": f1,
#             "auroc": auroc,
#         }

#         acc, f1, auroc = get_scores(
#             dataset_np["label"],
#             dataset_np["prediction"],
#             calibrated=True,
#         )
#         scores[config_name] |= {
#             "accuracy_calibrated": acc,
#             "f1_calibrated": f1,
#             "auroc_calibrated": auroc,
#         }

#     return scores