In [None]:
import copy
import os
import warnings

import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from tqdm.auto import tqdm

# Load environment variables from ../env (presumably including the
# MONGO_DB_CONNECTION value read further down - confirm).
load_dotenv(dotenv_path="../env")

# Ignore warnings whose message matches the "does not have many workers" pattern.
warnings.filterwarnings(
    action="ignore",
    message=".*does not have many workers.*",
)

In [None]:
from luminar.document.data import (
    FeatureDataset,
    PaddingDataloader,
    n_way_split,
)
from luminar.document.model import CNNDocumentClassficationModel, ConvolutionalLayerSpec
from luminar.features import FeatureExtractor, OneDimFeatures, Slicer, TwoDimFeatures
from luminar.mongo import PrismaiDataset

In [None]:
# Display name -> extra keyword arguments forwarded to `PrismaiDataset` for that domain.
# Insertion order is the order in which domains are loaded and trained.
domains = {
    "Blog Authorship": dict(domain="blog_authorship_corpus"),
    "Student Essays": dict(domain="student_essays"),
    "CNN News": dict(domain="cnn_news"),
    "Euro Court Cases": dict(domain="euro_court_cases"),
    "House of Commons": dict(domain="house_of_commons"),
    "ArXiv Papers": dict(domain="arxiv_papers"),
    "Gutenberg": dict(domain="gutenberg", lang="en-EN"),
    "Bundestag [DE]": dict(domain="bundestag"),
    "Spiegel [DE]": dict(domain="spiegel_articles"),
}

In [None]:
# In-memory cache filled by `load_datasets`: effective-config hash -> per-domain datasets.
# Re-running this cell empties the cache.
datasets_with_config = dict()

In [None]:
from hashlib import sha256


def load_datasets(_config: dict) -> dict[str, dict[str, FeatureDataset]]:
    seed = _config.get("seed", 42)
    effective_config = [f"seed={seed}"]

    feature_model = _config["feature_model"]
    effective_config.append(f"feature_model={feature_model}")

    synth_agent = _config["synth_agent"]
    effective_config.append(f"synth_agent={synth_agent}")

    feature_dim = _config["feature_dim"]
    effective_config.append(repr(feature_dim))

    first_dim = feature_dim[0]
    match _config["slicer"].lower():
        case "f" | "first":
            slicer = Slicer.First(first_dim)
        case "r" | "random":
            slicer = Slicer.Random(first_dim)
        case "multiple" | "randommultiple":
            multiple = _config["multiple"]
            stride = _config.get("stride", 1)
            sort = _config.get("sort", False)
            slicer = Slicer.RandomMultiple(
                first_dim // multiple,
                multiple=multiple,
                stride=stride,
                sort=sort,
            )
        case invalid:
            raise ValueError(f"Invalid slicer: {invalid}")
    effective_config.append(repr(slicer))

    match _config["featurizer"].lower():
        case "l" | "likelihood":
            featurizer = FeatureExtractor.Likelihood()
        case "il" | "intermediate" | "intermediatelikelihood":
            last_n = _config.get(
                "last_n", 13 if "gpt2" in _config["feature_model"] else 17
            )
            featurizer = FeatureExtractor.IntermediateLikelihood(last_n)
        case "llr" | "lllrr" | "loglikelihoodlogrankratio":
            featurizer = FeatureExtractor.LogLikelihoodLogRankRatio()
        case "tkllr" | "topklikelihoodlikelihoodratio":
            top_k = _config.get("top_k", 13)
            featurizer = FeatureExtractor.TopkLikelihoodLikelihoodRatio(top_k)
        case "ltklr" | "likelihoodtopklikelihoodratio":
            top_k = _config.get("top_k", 13)
            featurizer = FeatureExtractor.LikelihoodTopkLikelihoodRatio(top_k)
        case invalid:
            raise ValueError(f"Invalid featurizer: {invalid}")
    effective_config.append(repr(featurizer))

    num_samples = _config.get("num_samples", None)
    if num_samples:
        effective_config.append(f"num_samples={num_samples}")

    eval_split = _config.get("eval_split", 0.1)
    effective_config.append(f"eval_split={eval_split}")

    test_split = _config.get("test_split", 0.2)
    effective_config.append(f"test_split={test_split}")

    ################################################################################

    effective_config_hash = sha256(",".join(effective_config).encode()).hexdigest()
    print(
        f"Effective Config: {effective_config_hash}\n - {'\n - '.join(effective_config)}"
    )

    if effective_config_hash in datasets_with_config:
        return datasets_with_config[effective_config_hash]

    datasets = {}
    for domain, domain_config in tqdm(domains.items(), position=0):
        seed_everything(seed)
        datasets[domain] = dict(
            zip(
                ("train", "eval", "test"),
                (
                    FeatureDataset(
                        tqdm(dataset, position=1, leave=False),
                        slicer,
                        featurizer,
                        num_samples=num_samples,
                    )
                    for dataset in n_way_split(
                        PrismaiDataset(
                            os.getenv("MONGO_DB_CONNECTION"),
                            database="prismai",
                            collection="features_prismai",
                            feature_model=feature_model,
                            synth_agent=synth_agent,
                            **domain_config,
                            update_cache=True,  # TODO: Remove
                        ),
                        eval_split,
                        test_split,
                        infer_first=True,
                    )
                ),
            )
        )

    datasets_with_config[effective_config_hash] = datasets

    return datasets

## In-Domain Training & Evaluation

In [None]:
class ShiftUnitIntervalDataloader(PaddingDataloader):
    """
    Padding dataloader that rescales the collated features with x -> 2 * x - 1,
    which maps the unit interval [0, 1] onto [-1, 1].
    """

    def _collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
        collated = super()._collate_fn(batch)
        features = collated["features"]
        collated["features"] = features * 2 - 1
        return collated
    
class FilterNanDataloader(PaddingDataloader):
    """
    Padding dataloader that sanitises non-finite feature values after collation:
    NaN becomes 0.0, +inf becomes 1.0 and -inf becomes -1.0. No samples are dropped.
    """

    def _collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
        collated = super()._collate_fn(batch)
        collated["features"] = collated["features"].nan_to_num(
            nan=0.0,
            posinf=1.0,
            neginf=-1.0,
        )
        return collated

In [None]:
def train_in_domain(_config: dict, _config_name: str):
    """
    Train, validate and test one `CNNDocumentClassficationModel` per domain.

    For every domain returned by `load_datasets(_config)` a fresh model is
    trained on that domain's "train" split (early stopping on "val_loss" with
    patience 3, evaluated on the "eval" split) and then tested on its "test"
    split. TensorBoard logs go to `logs/ablation_in_domain/<_config_name>/<domain>`.

    Args:
        _config: Full experiment configuration. It is passed as keyword
            arguments to the model, and the keys "feature_dim",
            "gradient_clip_val" (required) and "seed", "batch_size",
            "max_epochs", "shift_unit_interval" (optional) are read here.
        _config_name: Label stored in the "name" field of every result row.

    Returns:
        One dict per domain with "name", "domain", the test metrics reported by
        the trainer, and a stringified copy of `_config` under "config".
    """
    datasets = load_datasets(_config)

    # Use the same fallback as `load_datasets`, so a config without "seed" does
    # not fail with a KeyError only after all datasets have been loaded.
    seed = _config.get("seed", 42)

    # Settings shared by all three dataloaders of every domain.
    loader_kwargs = {
        "feature_dim": _config["feature_dim"],
        "batch_size": _config.get("batch_size", 32),
    }
    # NOTE(review): the shift variant derives from PaddingDataloader directly, so
    # it does not apply the nan_to_num clean-up that FilterNanDataloader performs.
    DataLoaderCls: type[PaddingDataloader] = (
        ShiftUnitIntervalDataloader
        if _config.get("shift_unit_interval", False)
        else FilterNanDataloader
    )

    results_in_domain = []
    for domain, dataset in datasets.items():
        print(f"Training in domain: {domain}")
        # Re-seed per domain so each run starts from the same RNG state.
        seed_everything(seed)

        train_dataloader = DataLoaderCls(
            dataset["train"], shuffle=True, **loader_kwargs
        )
        eval_dataloader = DataLoaderCls(dataset["eval"], **loader_kwargs)
        test_dataloader = DataLoaderCls(dataset["test"], **loader_kwargs)

        model = CNNDocumentClassficationModel(**_config)
        trainer = Trainer(
            max_epochs=_config.get("max_epochs", 25),
            logger=pl_loggers.TensorBoardLogger(
                save_dir=f"logs/ablation_in_domain/{_config_name}",
                name=domain,
            ),
            gradient_clip_val=_config["gradient_clip_val"],
            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
            deterministic=True,
        )

        trainer.fit(
            model,
            train_dataloaders=train_dataloader,
            val_dataloaders=eval_dataloader,
        )

        # A single test dataloader yields exactly one metrics dict.
        (metrics,) = trainer.test(model, test_dataloader, verbose=False)

        results_in_domain.append(
            {
                "name": _config_name,
                "domain": domain,
                **metrics,
                "config": str({k: str(v) for k, v in _config.items()}),
            }
        )
    return results_in_domain

In [None]:
# Data-selection settings shared by every experiment configuration.
base_config = dict(
    seed=1337,
    eval_split=0.1,
    test_split=0.2,
    feature_model="gpt2",  # alternative: "meta-llama/Llama-3.2-1B"
    synth_agent="gpt-4o-mini",  # alternative: "gemma2:9b"
)

In [None]:
# Baseline configuration: `base_config` extended with feature, model and
# optimisation settings. Every ablation below overrides a few of these keys.
default_config = {
    **base_config,
    # --- Feature parameters ---
    "feature_dim": TwoDimFeatures(256, 13),
    "slicer": "First",  # alternative: "RandomMultiple"
    "multiple": 4,
    "stride": 16,
    "featurizer": "IL",
    "last_n": 13,
    # --- Model parameters ---
    "projection_dim": 32,
    "second_dim_as_channels": True,
    # SeqXGPT layer configuration
    "conv_layer_shapes": [
        ConvolutionalLayerSpec(64, 5),
        *[ConvolutionalLayerSpec(128, 3)] * 3,
        ConvolutionalLayerSpec(64, 3),
    ],
    # --- Hyper-parameters ---
    "learning_rate": 0.0001,
    "warmup_steps": 66,
    "max_epochs": 25,
    "gradient_clip_val": 1.0,
    "batch_size": 32,
}

In [None]:
# Accumulates one result row per (configuration, domain) across all ablation cells.
# Re-running this cell discards previously collected rows.
results_ablation = list()

### Default Configuration

In [None]:
# Baseline run: the unmodified default configuration, recorded as "default".
results_ablation += train_in_domain(default_config, "default")

In [None]:
# Default setup, additionally passing num_samples=8 through to the feature datasets.
config = {**default_config, "num_samples": 8}
results_ablation += train_in_domain(config, "default_num_samples=8")

In [None]:
# Default setup, additionally passing num_samples=16 through to the feature datasets.
config = {**default_config, "num_samples": 16}
results_ablation += train_in_domain(config, "default_num_samples=16")

### Fewer Intermediate Likelihoods

In [None]:
# IntermediateLikelihood featurizer with last_n=11; the second entry of
# feature_dim is set to the same value.
config = {
    **default_config,
    "last_n": 11,
    "feature_dim": TwoDimFeatures(256, 11),
}
results_ablation += train_in_domain(config, "il_last_n=11")

In [None]:
# IntermediateLikelihood featurizer with last_n=9; the second entry of
# feature_dim is set to the same value.
config = {
    **default_config,
    "last_n": 9,
    "feature_dim": TwoDimFeatures(256, 9),
}
results_ablation += train_in_domain(config, "il_last_n=9")

In [None]:
# IntermediateLikelihood featurizer with last_n=7; the second entry of
# feature_dim is set to the same value.
config = {
    **default_config,
    "last_n": 7,
    "feature_dim": TwoDimFeatures(256, 7),
}
results_ablation += train_in_domain(config, "il_last_n=7")

In [None]:
# IntermediateLikelihood featurizer with last_n=5; the second entry of
# feature_dim is set to the same value.
config = {
    **default_config,
    "last_n": 5,
    "feature_dim": TwoDimFeatures(256, 5),
}
results_ablation += train_in_domain(config, "il_last_n=5")

In [None]:
# IntermediateLikelihood featurizer with last_n=3; the second entry of
# feature_dim is set to the same value.
config = {
    **default_config,
    "last_n": 3,
    "feature_dim": TwoDimFeatures(256, 3),
}
results_ablation += train_in_domain(config, "il_last_n=3")

In [None]:
# IntermediateLikelihood featurizer with last_n=2; the second entry of
# feature_dim is set to the same value.
config = {
    **default_config,
    "last_n": 2,
    "feature_dim": TwoDimFeatures(256, 2),
}
results_ablation += train_in_domain(config, "il_last_n=2")

### No Convolution

In [None]:
# Pass an empty list of convolutional layer specs to the model.
config = {**default_config, "conv_layer_shapes": []}
results_ablation += train_in_domain(config, "no_convolution")

### 2D-Convolution

In [None]:
# Switch second_dim_as_channels off (the default configuration sets it to True).
config = {**default_config, "second_dim_as_channels": False}
results_ablation += train_in_domain(config, "2d_convolution")

### No Projection

In [None]:
# Set projection_dim to None instead of the default 32.
config = {**default_config, "projection_dim": None}
results_ablation += train_in_domain(config, "no_projection")

### Shift Unit Interval to [-1, 1]

In [None]:
# Makes train_in_domain pick ShiftUnitIntervalDataloader (features * 2 - 1).
# NOTE(review): that loader does not apply FilterNanDataloader's nan_to_num
# clean-up, so this run changes two things at once - confirm this is intended.
config = {**default_config, "shift_unit_interval": True}
results_ablation += train_in_domain(config, "shift_unit_interval")

### Different Slicing Methods

In [None]:
# Use Slicer.Random instead of Slicer.First.
config = {**default_config, "slicer": "Random"}
results_ablation += train_in_domain(config, "slice_random")

In [None]:
# Slicer.RandomMultiple with multiple=2 and stride=16
# (load_datasets passes 256 // 2 as its first argument).
config = {
    **default_config,
    "slicer": "RandomMultiple",
    "multiple": 2,
    "stride": 16,
}
results_ablation += train_in_domain(config, "slice_random_multiple_2")

In [None]:
# Slicer.RandomMultiple with multiple=4 and stride=16
# (load_datasets passes 256 // 4 as its first argument).
config = {
    **default_config,
    "slicer": "RandomMultiple",
    "multiple": 4,
    "stride": 16,
}
results_ablation += train_in_domain(config, "slice_random_multiple_4")

In [None]:
# Slicer.RandomMultiple with multiple=4, stride=16 and sort=True.
config = {
    **default_config,
    "slicer": "RandomMultiple",
    "multiple": 4,
    "stride": 16,
    "sort": True,
}
results_ablation += train_in_domain(config, "slice_random_multiple_4_sorted")

In [None]:
# Slicer.RandomMultiple with multiple=4 and a larger stride of 64.
config = {
    **default_config,
    "slicer": "RandomMultiple",
    "multiple": 4,
    "stride": 64,
}
results_ablation += train_in_domain(config, "slice_random_multiple=4_stride=64")

In [None]:
# Slicer.RandomMultiple with multiple=4, stride=64 and sort=True.
config = {
    **default_config,
    "slicer": "RandomMultiple",
    "multiple": 4,
    "stride": 64,
    "sort": True,
}
results_ablation += train_in_domain(
    config, "slice_random_multiple=4_stride=64_sorted"
)

In [None]:
# Slicer.RandomMultiple with multiple=8 and stride=16
# (load_datasets passes 256 // 8 as its first argument).
config = {
    **default_config,
    "slicer": "RandomMultiple",
    "multiple": 8,
    "stride": 16,
}
results_ablation += train_in_domain(config, "slice_random_multiple_8")

### Different Feature Sizes

In [None]:
# Shrink the first feature dimension from 256 to 64.
config = {**default_config, "feature_dim": TwoDimFeatures(64, 13)}
results_ablation += train_in_domain(config, "slice_64")

In [None]:
# Shrink the first feature dimension from 256 to 128.
config = {**default_config, "feature_dim": TwoDimFeatures(128, 13)}
results_ablation += train_in_domain(config, "slice_128")

In [None]:
# Grow the first feature dimension from 256 to 512.
config = {**default_config, "feature_dim": TwoDimFeatures(512, 13)}
results_ablation += train_in_domain(config, "slice_512")

### Other Features

In [None]:
# "tkllr" selects FeatureExtractor.TopkLikelihoodLikelihoodRatio in load_datasets; top_k=13.
config = {**default_config, "featurizer": "tkllr", "top_k": 13}
results_ablation += train_in_domain(config, "featurizer_tkllr")

In [None]:
# "ltklr" selects FeatureExtractor.LikelihoodTopkLikelihoodRatio in load_datasets; top_k=13.
config = {**default_config, "featurizer": "ltklr", "top_k": 13}
results_ablation += train_in_domain(config, "featurizer_ltklr")

In [None]:
# "lllrr" selects FeatureExtractor.LogLikelihoodLogRankRatio in load_datasets.
# Uses OneDimFeatures(256) and therefore turns second_dim_as_channels off.
config = {
    **default_config,
    "featurizer": "lllrr",
    "feature_dim": OneDimFeatures(256),
    "second_dim_as_channels": False,
}
results_ablation += train_in_domain(config, "featurizer_lllrr")

In [None]:
# "likelihood" selects FeatureExtractor.Likelihood in load_datasets.
# Uses OneDimFeatures(256) and therefore turns second_dim_as_channels off.
config = {
    **default_config,
    "featurizer": "likelihood",
    "feature_dim": OneDimFeatures(256),
    "second_dim_as_channels": False,
}
results_ablation += train_in_domain(config, "featurizer_likelihood")

### Smaller CNN

In [None]:
# Three-layer CNN given as (16, 5), (32, 3), (16, 3).
# NOTE(review): plain tuples here, whereas default_config uses
# ConvolutionalLayerSpec - presumably the model accepts both; confirm.
config = {
    **default_config,
    "conv_layer_shapes": [(16, 5), (32, 3), (16, 3)],
}
results_ablation += train_in_domain(config, "conv_16_32_16")

In [None]:
# Three-layer CNN given as (32, 5), (64, 3), (32, 3).
# NOTE(review): plain tuples here, whereas default_config uses
# ConvolutionalLayerSpec - presumably the model accepts both; confirm.
config = {
    **default_config,
    "conv_layer_shapes": [(32, 5), (64, 3), (32, 3)],
}
results_ablation += train_in_domain(config, "conv_32_64_32")

In [None]:
# Five-layer CNN given as (32, 5), (64, 3), (64, 3), (64, 3), (32, 3).
# NOTE(review): plain tuples here, whereas default_config uses
# ConvolutionalLayerSpec - presumably the model accepts both; confirm.
config = {
    **default_config,
    "conv_layer_shapes": [(32, 5), (64, 3), (64, 3), (64, 3), (32, 3)],
}
results_ablation += train_in_domain(config, "conv_32_64_64_64_32")

## Ablation Results

In [80]:
# One row per (configuration, domain); "config" is forced to str.
df = pd.DataFrame(results_ablation).assign(
    config=lambda frame: frame["config"].map(str)
)
# Average every numeric metric over the domains of each configuration.
_df = df.groupby("name").mean(numeric_only=True)
_df

Unnamed: 0_level_0,test_loss,test_f1@0.5,test_f1@best,test_f1_threshold,test_acc@0.5,test_acc@best,test_acc_threshold,test_roc_auc
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2d_convolution,0.235069,0.934351,0.937269,0.616556,0.933338,0.93757,0.478244,0.980989
conv_16_32_16,0.256629,0.895132,0.887979,0.405422,0.898558,0.892002,0.378144,0.952598
conv_32_64_32,0.204503,0.9212,0.916847,0.440656,0.922036,0.918671,0.379044,0.968638
conv_32_64_64_64_32,0.181412,0.930867,0.928081,0.420256,0.931564,0.92738,0.402944,0.976048
default,0.155585,0.939838,0.953069,0.429078,0.941939,0.952759,0.377489,0.984246
default_num_samples=16,0.155585,0.939838,0.953069,0.429078,0.941939,0.952759,0.377489,0.984246
default_num_samples=8,0.155585,0.939838,0.953069,0.429078,0.941939,0.952759,0.377489,0.984246
featurizer_likelihood,0.239589,0.896642,0.890439,0.541044,0.890156,0.897611,0.351656,0.950908
featurizer_lllrr,32.972841,0.780376,0.726168,0.209733,0.796837,0.803278,0.000111,0.820486
featurizer_ltklr,0.211521,0.916482,0.910855,0.401456,0.913582,0.915453,0.343567,0.96335


In [81]:
# Difference of each configuration's mean ROC-AUC and F1@0.5 to the "default"
# run, printed as a LaTeX table.
_metric_columns = ["test_roc_auc", "test_f1@0.5"]
# Label-based lookup of the baseline row; subtracting the resulting Series
# aligns on the column names, and a missing "default" row raises a KeyError
# naming the label instead of an IndexError from `.values[0]`.
_diff = _df[_metric_columns] - _df.loc["default", _metric_columns]
print(_diff.to_latex(float_format="%.3f"))

\begin{tabular}{lrr}
\toprule
 & test_roc_auc & test_f1@0.5 \\
name &  &  \\
\midrule
2d_convolution & -0.003 & -0.005 \\
conv_16_32_16 & -0.032 & -0.045 \\
conv_32_64_32 & -0.016 & -0.019 \\
conv_32_64_64_64_32 & -0.008 & -0.009 \\
default & 0.000 & 0.000 \\
default_num_samples=16 & 0.000 & 0.000 \\
default_num_samples=8 & 0.000 & 0.000 \\
featurizer_likelihood & -0.033 & -0.043 \\
featurizer_lllrr & -0.164 & -0.159 \\
featurizer_ltklr & -0.021 & -0.023 \\
featurizer_tkllr & -0.027 & -0.039 \\
il_last_n=11 & -0.000 & -0.009 \\
il_last_n=2 & -0.026 & -0.044 \\
il_last_n=3 & -0.024 & -0.039 \\
il_last_n=5 & -0.019 & -0.022 \\
il_last_n=7 & -0.010 & -0.022 \\
il_last_n=9 & -0.001 & -0.001 \\
no_convolution & -0.115 & -0.134 \\
no_projection & -0.000 & 0.010 \\
shift_unit_interval & 0.003 & 0.007 \\
slice_128 & -0.017 & -0.021 \\
slice_512 & 0.007 & 0.019 \\
slice_64 & -0.038 & -0.052 \\
slice_random & -0.003 & -0.012 \\
slice_random_multiple=4_stride=64 & -0.008 & -0.024 \\
slice_random_