In [1]:
import os
import warnings

import pandas as pd
from dotenv import load_dotenv
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from tqdm.auto import tqdm, trange

from luminar.data import (
    FeatureDataset,
    PaddingDataloader,
    n_way_split,
)
from luminar.model import CNNDocumentClassficationModel, ConvolutionalLayerSpec
from luminar.features import FeatureExtractor, OneDimFeatures, Slicer, TwoDimFeatures
from luminar.mongo import PrismaiDataset

# Load environment variables from ../.env (MONGO_DB_CONNECTION is read via
# os.getenv further down; presumably it is defined in that file).
load_dotenv(dotenv_path="../.env")

# Suppress any warning whose message matches the "does not have many workers" pattern.
warnings.filterwarnings(action="ignore", message=".*does not have many workers.*")

In [2]:
# Display label -> extra keyword arguments handed to PrismaiDataset when the
# domain is loaded. Insertion order is also the row order of the result tables.
domains = dict(
    [
        ("Blog Authorship", dict(domain="blog_authorship_corpus")),
        ("Student Essays", dict(domain="student_essays")),
        ("CNN News", dict(domain="cnn_news")),
        ("Euro Court Cases", dict(domain="euro_court_cases")),
        ("House of Commons", dict(domain="house_of_commons")),
        ("ArXiv Papers", dict(domain="arxiv_papers")),
        ("Gutenberg", dict(domain="gutenberg", lang="en-EN")),
        ("Bundestag", dict(domain="bundestag")),
        ("Spiegel", dict(domain="spiegel_articles")),
    ]
)

In [3]:
# Shared experiment configuration; later cells add further keys to this dict
# and the whole mapping is eventually splatted into the model constructor.
config = dict(
    eval_split=0.1,
    test_split=0.2,
    feature_model="gpt2",
    synth_agent="gpt-4o-mini",
    document_type="fulltext",
)

## Features

In [None]:
# Sizes shared by the feature shape, the featurizer and the slicer below.
first_dim = 256
k = 13
multiple = 4

feature_dim = TwoDimFeatures(first_dim, k)
featurizer = FeatureExtractor.IntermediateLikelihood(k)

# Alternative slicer kept for reference: Slicer.First(first_dim)
# The slice size is first_dim // multiple, so size * multiple equals first_dim.
slicer = Slicer.RandomMultiple(first_dim // multiple, multiple=multiple, stride=16)

# Record the feature setup in the shared config. The featurizer and slicer are
# stored as their repr strings; feature_dim is stored as the object itself.
config |= {
    "second_dim_as_channels": True,
    "feature_dim": feature_dim,
    "featurizer": repr(featurizer),
    "slicer": repr(slicer),
    "num_samples": None,
}


def featurize(dataset) -> FeatureDataset:
    """Turn ``dataset`` into a FeatureDataset using the notebook-level settings.

    The input is wrapped in a nested tqdm progress bar (``position=1``, removed
    when finished) and passed to FeatureDataset together with the module-level
    ``slicer`` and ``featurizer``. ``num_samples`` is read from ``config`` at
    call time.
    """
    progress = tqdm(dataset, position=1, leave=False)
    return FeatureDataset(
        progress, slicer, featurizer, num_samples=config["num_samples"]
    )

In [6]:
config["seed"] = 42

# Ten equally sized fractions per domain; the training cells below use them as
# 1 validation + 2 test + 7 training subsets.
sizes = [0.1] * 10
splits = {}
for domain, kwargs in tqdm(domains.items(), desc="Domains", position=0):
    # Re-seed before every domain so each split does not depend on the domains
    # processed before it.
    seed_everything(config["seed"], verbose=False)

    documents = PrismaiDataset(
        mongo_db_connection=os.getenv("MONGO_DB_CONNECTION"),
        database="prismai",
        collection="features_prismai",
        feature_model=config["feature_model"],
        synth_agent=config["synth_agent"],
        document_type=config["document_type"],
        additional_match_conditions=config.get("additional_match_conditions", {}),
        **kwargs,
        # update_cache=True,
    ).load(verbose=False)

    # Featurize each subset in the order n_way_split yields them.
    splits[domain] = list(map(featurize, n_way_split(documents, *sizes)))

Domains:   0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/163 [00:01<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/163 [00:00<?, ?it/s]

  0%|          | 0/162 [00:00<?, ?it/s]

  0%|          | 0/162 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/149 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:00<?, ?it/s]

In [7]:
# Model and optimisation hyperparameters, merged into the shared config.
config.update(
    projection_dim=32,
    learning_rate=0.0001,
    warmup_steps=66,
    max_epochs=25,
    gradient_clip_val=1.0,
    batch_size=32,
)

# SeqXGPT Layer Configuration
# The three middle entries are one ConvolutionalLayerSpec object repeated.
config["conv_layer_shapes"] = (
    [ConvolutionalLayerSpec(64, 5)]
    + [ConvolutionalLayerSpec(128, 3)] * 3
    + [ConvolutionalLayerSpec(64, 3)]
)

In [8]:
from torch.utils.data import ConcatDataset


def get_dataloader(*dataset, **kwargs) -> PaddingDataloader:
    """Build a PaddingDataloader over one or several datasets.

    A single positional dataset is used as-is; otherwise the positional
    datasets are wrapped in a ConcatDataset. ``feature_dim`` and ``batch_size``
    are read from the notebook-level ``config`` at call time (not captured at
    definition time). Any other keyword arguments, such as ``shuffle=True``,
    are forwarded to PaddingDataloader unchanged.
    """
    combined = dataset[0] if len(dataset) == 1 else ConcatDataset(dataset)
    return PaddingDataloader(
        combined,
        batch_size=config["batch_size"],
        feature_dim=config["feature_dim"],
        **kwargs,
    )

## In-Domain Training & Evaluation

In [None]:
from collections import defaultdict

# In-domain evaluation: five folds per domain, each training a fresh model on
# that domain's own subsets and collecting the test metrics per domain.
metrics_in_domain = defaultdict(list)
for domain, subsets in tqdm(splits.items()):
    for _ in trange(5, desc=domain, position=1):
        seed_everything(config["seed"], verbose=False)

        # Cross-validation by rotation: the head subset becomes the validation
        # fold, the next two form the test set, the remainder is training data.
        # NOTE: this rotates the lists stored in `splits` in place, so the
        # state seen by later cells depends on how often this cell has run.
        validation_fold = subsets.pop(0)
        test_parts, train_parts = subsets[:2], subsets[2:]
        subsets.append(validation_fold)

        test_dataloader = get_dataloader(*test_parts)
        train_dataloader = get_dataloader(*train_parts, shuffle=True)
        eval_dataloader = get_dataloader(validation_fold)

        model = CNNDocumentClassficationModel(**config)
        tb_logger = pl_loggers.TensorBoardLogger(
            save_dir=f"logs/in_domain/{type(featurizer).__name__}",
            name=domain,
        )
        # Early stopping monitors val_loss (minimised) with patience 3.
        stop_on_plateau = EarlyStopping(monitor="val_loss", mode="min", patience=3)
        trainer = Trainer(
            max_epochs=config["max_epochs"],
            logger=tb_logger,
            gradient_clip_val=config["gradient_clip_val"],
            callbacks=[stop_on_plateau],
            deterministic=True,
        )
        trainer.progress_bar_callback.disable()

        trainer.fit(
            model,
            train_dataloaders=train_dataloader,
            val_dataloaders=eval_dataloader,
        )
        # The unpacking requires trainer.test to return exactly one result entry.
        (metrics,) = trainer.test(model, test_dataloader, verbose=False)
        metrics_in_domain[domain].append(metrics)

In [10]:
# Display the fully assembled configuration as this cell's output.
config

{'eval_split': 0.1,
 'test_split': 0.2,
 'feature_model': 'gpt2',
 'synth_agent': 'gpt-4o-mini',
 'document_type': 'fulltext',
 'second_dim_as_channels': True,
 'feature_dim': TwoDimFeatures(width=256, height=13),
 'featurizer': 'IntermediateLikelihood(last_n=13)',
 'slicer': 'SliceRandomMultiple(size=64, multiple=4, stride=16, sort=False)',
 'num_samples': None,
 'seed': 42,
 'projection_dim': 32,
 'learning_rate': 0.0001,
 'warmup_steps': 66,
 'max_epochs': 25,
 'gradient_clip_val': 1.0,
 'batch_size': 32,
 'conv_layer_shapes': [(64, 5, 1),
  (128, 3, 1),
  (128, 3, 1),
  (128, 3, 1),
  (64, 3, 1)]}

In [11]:
# One record per (domain, fold), averaged per domain and ordered like `domains`.
df = (
    pd.DataFrame(
        [
            {
                "domain": domain,
                "test_auroc": metric["test_auroc"],
                "test_f1@0.5": metric["test_f1@0.5"],
            }
            for domain in domains
            for metric in metrics_in_domain[domain]
        ]
    )
    .groupby("domain")
    .mean()
    # groupby sorts alphabetically; restore the insertion order of `domains`
    # by sorting on each label's position in that dict.
    .sort_index(
        key=lambda labels: labels.map(
            {name: position for position, name in enumerate(domains)}
        )
    )
)

# index=False drops the domain labels, so the LaTeX table contains only the
# two metric columns, one row per domain in the order above.
latex_table = df.to_latex(float_format="%.3f", index=False)
print(latex_table)
df

\begin{tabular}{rr}
\toprule
test_auroc & test_f1@0.5 \\
\midrule
0.982 & 0.906 \\
0.976 & 0.912 \\
0.972 & 0.916 \\
0.990 & 0.937 \\
0.976 & 0.914 \\
0.987 & 0.925 \\
0.964 & 0.883 \\
0.943 & 0.865 \\
0.928 & 0.855 \\
\bottomrule
\end{tabular}



Unnamed: 0_level_0,test_auroc,test_f1@0.5
domain,Unnamed: 1_level_1,Unnamed: 2_level_1
Blog Authorship,0.996168,0.968689
Student Essays,0.979569,0.895166
CNN News,0.981229,0.912783
Euro Court Cases,0.991806,0.934931
House of Commons,0.973987,0.881006
ArXiv Papers,0.984814,0.933906
Gutenberg,0.971568,0.912373
Bundestag,0.963048,0.892373
Spiegel,0.930374,0.845966


## Out-of-Domain Training & Evaluation

In [None]:
from collections import defaultdict

# Out-of-domain (leave-one-domain-out) evaluation: for every held-out domain,
# train five models on all other domains and test on the held-out domain only.
metrics_out_of_domain = defaultdict(list)
for domain in tqdm(splits.keys()):
    for _ in trange(5, desc=domain, position=1):
        seed_everything(config["seed"], verbose=False)
        train_subsets = []
        eval_subsets = []
        # NOTE: every list in `splits` is rotated in place by one position per
        # fold, so the folds used here depend on the state left behind by
        # earlier iterations and by previously executed cells.
        for other, subsets in splits.items():
            if other == domain:
                # Held-out domain: rotate first, then test on the two leading
                # subsets. None of its data is used for training or validation.
                subsets.append(subsets.pop(0))
                test_dataset = subsets[:2]
            else:
                # Training domain: the head subset is the validation fold, the
                # next two subsets are skipped, and the rest is training data.
                eval_dataset = subsets.pop(0)
                eval_subsets.append(eval_dataset)
                train_subsets.extend(subsets[2:])
                subsets.append(eval_dataset)

        train_dataloader = get_dataloader(*train_subsets, shuffle=True)
        eval_dataloader = get_dataloader(*eval_subsets)
        test_dataloader = get_dataloader(*test_dataset)

        model = CNNDocumentClassficationModel(**config)
        trainer = Trainer(
            max_epochs=config["max_epochs"],
            logger=pl_loggers.TensorBoardLogger(
                # Out-of-domain runs get their own log root so they are not
                # stored next to the in-domain runs of the same domain name.
                save_dir=f"logs/out_of_domain/{type(featurizer).__name__}",
                name=domain,
            ),
            gradient_clip_val=config["gradient_clip_val"],
            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
            deterministic=True,
        )
        trainer.progress_bar_callback.disable()

        trainer.fit(
            model,
            train_dataloaders=train_dataloader,
            val_dataloaders=eval_dataloader,
        )
        # The unpacking requires trainer.test to return exactly one result entry.
        (metrics,) = trainer.test(model, test_dataloader, verbose=False)
        metrics_out_of_domain[domain].append(metrics)

        print(domain, metrics)

In [19]:
# One record per (held-out domain, fold), averaged per domain and ordered like
# `domains`.
df = (
    pd.DataFrame(
        [
            {
                "domain": domain,
                "test_auroc": metric["test_auroc"],
                "test_f1@0.5": metric["test_f1@0.5"],
            }
            for domain in domains
            for metric in metrics_out_of_domain[domain]
        ]
    )
    .groupby("domain")
    .mean()
    # groupby sorts alphabetically; restore the insertion order of `domains`
    # by sorting on each label's position in that dict.
    .sort_index(
        key=lambda labels: labels.map(
            {name: position for position, name in enumerate(domains)}
        )
    )
)

# index=False drops the domain labels, so the LaTeX table contains only the
# two metric columns, one row per domain in the order above.
latex_table = df.to_latex(float_format="%.3f", index=False)
print(latex_table)
df

\begin{tabular}{rr}
\toprule
test_roc_auc & test_f1@0.5 \\
\midrule
0.453 & 0.193 \\
0.381 & 0.318 \\
0.619 & 0.235 \\
0.694 & 0.394 \\
0.540 & 0.479 \\
0.551 & 0.156 \\
0.330 & 0.025 \\
0.501 & 0.257 \\
0.554 & 0.076 \\
\bottomrule
\end{tabular}



Unnamed: 0_level_0,test_roc_auc,test_f1@0.5
domain,Unnamed: 1_level_1,Unnamed: 2_level_1
Blog Authorship,0.446162,0.474457
Student Essays,0.763817,0.732818
CNN News,0.973505,0.883103
Euro Court Cases,0.940748,0.853289
House of Commons,0.959928,0.853791
ArXiv Papers,0.971661,0.857507
Gutenberg,0.971357,0.897811
Bundestag$_{de}$,0.795373,0.723981
Spiegel$_{de}$,0.78148,0.663948
