# About

This notebook trains the Luminar sequence classifier (`LuminarSequence`) on the PrismAI dataset.

In [14]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [77]:
import sys
import torch

from typing import Final, Callable
from pathlib import Path

sys.path.insert(0, str(Path().resolve().parent.parent))

from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader
from IPython.display import display, Markdown
from datasets import load_dataset
from numpy._typing import NDArray
from torch.utils.data import Dataset
from datasets import DatasetDict
from src.luminar.utils.data import get_pad_to_fixed_length_fn, get_matched_datasets
from src.luminar.utils.training import ConvolutionalLayerSpec
from src.luminar.encoder import LuminarEncoder

import numpy as np
import glob

In [16]:
class Config:
    """Notebook-wide constants: credentials, dataset location and feature dimensions."""

    # Hugging Face token, read from ~/.hf_token when this class body executes
    # (the cell fails if that file is missing).
    HF_TOKEN: Final[str] = (Path.home() / ".hf_token").read_text().strip()
    # Alternative (disabled): hub dataset id instead of a local directory.
    #DATASET_PATH: Final[str] = "liberi-luminaris/PrismAI-encoded-gpt2"
    # Root directory of the encoded data; the loading cell appends
    # <agent>/<feature_agent>/<domain> to it.
    DATASET_ROOT_PATH: Final[str] = "/storage/projects/stoeckel/prismai/encoded/fulltext/"
    # Alternative (disabled): local WSL path to the same data layout.
    #DATASET_ROOT_PATH: Final[str] = "/mnt/c/home/projects/prismAI/data/encoded/fulltext/"
    # Second component of the model's `feature_dim` (features per token position).
    NUM_INTERMEDIATE_LIKELIHOODS: Final[int] = 13
    # Fixed sequence length: encoder max_len, padding target and upper bound for span ends.
    FEATURE_LEN: Final[int] = 512
    # Seed forwarded to the model configuration.
    SEED: Final[int] = 42


## Loading & Preprocessing the datasets

Load the pre-encoded dataset splits used for training, and add token ids and sentence-level token spans to every sample.

In [17]:
# NLTK's Punkt data is needed to split the text into sentences.
import nltk

nltk.download("punkt")
from nltk.tokenize import sent_tokenize

[nltk_data] Downloading package punkt to
[nltk_data]     /home/staff_homes/kboenisc/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [18]:
def sentence_to_token_spans(text: str) -> list[tuple[int, int]]:
    """Map each sentence of ``text`` to a half-open ``(start, end)`` token range.

    Sentences come from ``sent_tokenize``; each one is tokenized on its own with
    the notebook-level ``luminar_encoder`` and the ranges are laid out back to
    back, starting at token 0. Range ends are clamped to ``Config.FEATURE_LEN``
    and no range is emitted once that limit has been reached.

    Args:
        text: The raw document text.

    Returns:
        Consecutive ``(start_token_idx, end_token_idx)`` pairs, one per kept sentence.
    """
    token_spans = []
    cursor = 0

    for sentence in sent_tokenize(text):
        n_tokens = len(luminar_encoder.tokenize(sentence)["input_ids"])

        # Everything from here on would fall into the overflow / padding region.
        if cursor >= Config.FEATURE_LEN:
            break

        span_end = min(cursor + n_tokens, Config.FEATURE_LEN)
        token_spans.append((cursor, span_end))
        cursor = span_end

    return token_spans


In [19]:
# Experiment selection: every (domain, agent, feature_agent) combination is loaded
# from <DATASET_ROOT_PATH>/<agent>/<feature_agent>/<domain> in the next cell.
domains = ['student_essays']
agents = ['gpt_4o_mini_gemma2_9b']
feature_agents = ['gpt2_512']

# Shared encoder; its tokenizer is used by `sentence_to_token_spans` and by the
# dataset preprocessing below, so this cell must run before either is called.
luminar_encoder = LuminarEncoder(max_len=Config.FEATURE_LEN)

In [20]:
# Nested result mapping: datasets[domain][agent][feature_agent] -> DatasetDict.
datasets = {}

# Padding helper built for Config.FEATURE_LEN; in the sanity check below the padded
# token column has shape (512, 1).
pad_to_fixed_length: Callable[[NDArray], NDArray] = get_pad_to_fixed_length_fn(Config.FEATURE_LEN)

for domain in domains:
    for agent in agents:
        for feature_agent in feature_agents:
            dataset_path = Path(Config.DATASET_ROOT_PATH) / agent / feature_agent / domain

            # Fail early with a readable message instead of an opaque loader error.
            if not dataset_path.exists():
                raise FileNotFoundError(f"Dataset path {dataset_path} does not exist.")

            # Collect the Arrow shards of every split in a stable (sorted) order.
            data_files = {
                "train": sorted(str(f) for f in dataset_path.glob("train/*.arrow")),
                "test": sorted(str(f) for f in dataset_path.glob("test/*.arrow")),
                "eval": sorted(str(f) for f in dataset_path.glob("eval/*.arrow")),
            }
            # Drop splits without any shard so that only existing splits are requested.
            data_files = {k: v for k, v in data_files.items() if v}

            # These datasets are already matched, so we can load them directly
            # We need the tokenized text since we label sequences based on sentences.
            dataset_dict = (
                load_dataset(
                    "arrow",
                    data_files=data_files,
                )
                # Pass 1: add the padded token-id column and the per-sentence token spans.
                .map(
                    lambda batch: {
                        "tokenized_text": [
                            pad_to_fixed_length(
                                np.array(luminar_encoder.tokenize(t)["input_ids"]).reshape(-1, 1)
                            )
                            for t in batch["text"]
                        ],
                        "sentence_token_spans": [
                            sentence_to_token_spans(t)
                            for t in batch["text"]
                        ]
                    },
                    batched=True,
                    desc="Tokenizing, padding, and aligning sentences"
                )
                # Pass 2: every sentence span inherits the label of its document.
                .map(
                    lambda batch: {
                        "span_labels": [
                            [label] * len(spans)
                            for label, spans in zip(batch["labels"], batch["sentence_token_spans"])
                        ]
                    },
                    batched=True,
                    desc="Assigning labels to sentence spans"
                )
            )

            datasets.setdefault(domain, {}).setdefault(agent, {})[feature_agent] = dataset_dict
            print(f"Loaded dataset for domain '{domain}' with agent '{agent}' and feature agent '{feature_agent}'")

# Last expression of the cell: show the nested structure that was loaded.
datasets

Tokenizing, padding, and aligning sentences:   0%|          | 0/50734 [00:00<?, ? examples/s]

Tokenizing, padding, and aligning sentences:   0%|          | 0/14496 [00:00<?, ? examples/s]

Tokenizing, padding, and aligning sentences:   0%|          | 0/7248 [00:00<?, ? examples/s]

Assigning labels to sentence spans:   0%|          | 0/50734 [00:00<?, ? examples/s]

Assigning labels to sentence spans:   0%|          | 0/14496 [00:00<?, ? examples/s]

Assigning labels to sentence spans:   0%|          | 0/7248 [00:00<?, ? examples/s]

Loaded dataset for domain 'student_essays' with agent 'gpt_4o_mini_gemma2_9b' and feature agent 'gpt2_512'


{'student_essays': {'gpt_4o_mini_gemma2_9b': {'gpt2_512': DatasetDict({
       train: Dataset({
           features: ['agent', 'id_sample', 'id_source', 'labels', 'text', 'features', 'tokenized_text', 'sentence_token_spans', 'span_labels'],
           num_rows: 50734
       })
       test: Dataset({
           features: ['agent', 'id_sample', 'id_source', 'labels', 'text', 'features', 'tokenized_text', 'sentence_token_spans', 'span_labels'],
           num_rows: 14496
       })
       eval: Dataset({
           features: ['agent', 'id_sample', 'id_source', 'labels', 'text', 'features', 'tokenized_text', 'sentence_token_spans', 'span_labels'],
           num_rows: 7248
       })
   })}}}

In [21]:
# Sanity check: render one training sample (index `idx`) of every loaded dataset as
# Markdown - split sizes, raw features, text, token ids, sentence spans and labels.
idx = 0
for domain, agents_dict in datasets.items():
    for agent, feature_agents_dict in agents_dict.items():
        for feature_agent, dataset in feature_agents_dict.items():
            # NOTE: the template indexes the 'train', 'test' and 'eval' splits directly,
            # so all three must have been found by the loading cell.
            md = f"""
**Domain:** `{domain}`
**Agent:** `{agent}`
**Feature Agent:** `{feature_agent}`

**Train:** {len(dataset['train'])}
**Test:** {len(dataset['test'])}
**Eval:** {len(dataset['eval'])}

**Example-Features:**
`{dataset['train'][idx]['features'][:2]}...`

**Feature-Shape:**
`{np.asarray(dataset['train'][idx]['features']).shape}`

**Example text:**
`{dataset['train'][idx]['text']}`

**Example-Tokenized Text:**
`{dataset['train'][idx]['tokenized_text'][:10]}...`

**Tokenized Text Shape:**
`{np.asarray(dataset['train'][idx]['tokenized_text']).shape}`

**Sentence-Token-Spans:**
`{dataset['train'][idx]['sentence_token_spans']}`

**Example Sentence-Token Span decoded:**
`{luminar_encoder.tokenizer.decode(np.asarray(dataset['train'][idx]['tokenized_text'])[dataset['train'][idx]['sentence_token_spans'][0][0]:dataset['train'][idx]['sentence_token_spans'][0][1]].flatten().tolist())}`

**Example span labels:**
`{dataset['train'][idx]['span_labels']}`

**Example label:**
`{dataset['train'][idx]['labels']}`

---
"""
            # Rich display so the Markdown is rendered instead of printed verbatim.
            display(Markdown(md))


**Domain:** `student_essays`
**Agent:** `gpt_4o_mini_gemma2_9b`
**Feature Agent:** `gpt2_512`

**Train:** 50734
**Test:** 14496
**Eval:** 7248

**Example-Features:**
`[[4.4132913899375126e-05, 4.203895392974451e-45, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00015184479707386345], [9.98157247522613e-06, 2.268475629563227e-09, 2.4706817147723825e-10, 2.598772308459729e-10, 4.608947598572222e-11, 2.177385494128714e-10, 7.711160811448015e-13, 6.366168985201537e-13, 2.281021511211531e-17, 9.533206545892253e-25, 7.987888483894126e-31, 0.0, 0.029878340661525726]]...`

**Feature-Shape:**
`(512, 13)`

**Example text:**
`[Your Name] [Your Address] [City, State, Zip Code] [Email Address] [Date] The Honorable [Senator's Name] [Senator's Office Address] [City, State, Zip Code] Dear Senator [Senator's Last Name], I hope this letter finds you well. I’m writing to you today as a concerned citizen who’s really worried about the way our presidential elections work. Honestly, I can’t wrap my head around the Electoral College anymore. It feels like an outdated system that just doesn’t represent us, the people. I mean, how is it fair that a candidate can win the presidency without winning the popular vote? It’s like saying my vote doesn’t count just because I live in a state that leans one way or another. That doesn’t seem right, does it? We should all have an equal say in who leads our country, and the popular vote is the only way to truly reflect the will of the people. The Electoral College creates a situation where some votes are more valuable than others, and that’s just plain wrong. It leads to candidates ignoring vast swathes of the country because they know they won’t win those states. Isn’t it time we changed that? We need a system that encourages candidates to engage with all of us, not just the folks in swing states. I urge you to consider advocating for a popular vote system. It’s time for our elections to truly represent the voice of the people. Thank you for your time, and I hope you’ll take my concerns to heart.`

**Example-Tokenized Text:**
`[[58], [7120], [6530], [60], [685], [7120], [17917], [60], [685], [14941]]...`

**Tokenized Text Shape:**
`(512, 1)`

**Sentence-Token-Spans:**
`[[0, 61], [61, 87], [87, 103], [103, 121], [121, 141], [141, 168], [168, 179], [179, 209], [209, 233], [233, 256], [256, 266], [266, 288], [288, 300], [300, 317], [317, 336]]`

**Example Sentence-Token Span decoded:**
`[Your Name] [Your Address] [City, State, Zip Code] [Email Address] [Date] The Honorable [Senator's Name] [Senator's Office Address] [City, State, Zip Code] Dear Senator [Senator's Last Name], I hope this letter finds you well.`

**Example span labels:**
`[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]`

**Example label:**
`1`

---


# Training

Train the Luminar sequence classifier (`LuminarSequence`) on the datasets prepared above.

In [86]:
class LuminarSequenceDataset(Dataset):
    """In-memory dataset of per-document features with sentence-level labels.

    Every kept example is stored as a dict with the keys ``"features"`` (a tensor
    built from ``example[feature_key]``), ``"sentence_spans"`` and ``"span_labels"``.
    Documents with fewer than ``min_sentences`` sentence spans are skipped; the
    number of skipped documents is available as ``num_skipped``.

    Args:
        dataset: Iterable of mappings providing ``feature_key``,
            ``"sentence_token_spans"`` and ``"span_labels"``.
        feature_key: Name of the column holding the per-token features.
        min_sentences: Minimum number of sentence spans a document needs to be
            kept. Defaults to 10, the threshold that was previously hard-coded.
    """

    def __init__(self, dataset, feature_key="features", min_sentences=10):
        self.min_sentences = min_sentences
        self.num_skipped = 0  # documents dropped for having too few sentences
        self.samples = []
        for example in dataset:
            spans = example["sentence_token_spans"]
            if len(spans) < min_sentences:
                self.num_skipped += 1
                continue
            self.samples.append({
                "features": torch.tensor(example[feature_key]),  # (seq_len, feature_dim)
                "sentence_spans": spans,
                "span_labels": example["span_labels"],
            })

    def __len__(self):
        """Return the number of kept documents."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the stored sample dict at position ``idx``."""
        return self.samples[idx]

def collate_fn(batch):
    """Merge a list of dataset samples into one batch dict.

    The per-sample feature tensors are stacked along a new leading batch
    dimension. Sentence spans and span labels differ in length between
    documents, so they are passed through as plain Python lists
    (one entry per sample).
    """
    feature_rows, spans_per_sample, labels_per_sample = [], [], []
    for sample in batch:
        feature_rows.append(sample["features"])
        spans_per_sample.append(sample["sentence_spans"])
        labels_per_sample.append(sample["span_labels"])

    return {
        "features": torch.stack(feature_rows),
        "sentence_spans": spans_per_sample,
        "span_labels": labels_per_sample,
    }

In [87]:
# All three splits come from the single domain / agent / feature-agent combination loaded above.
encoded_splits = datasets["student_essays"]["gpt_4o_mini_gemma2_9b"]["gpt2_512"]

train_dataset = LuminarSequenceDataset(encoded_splits["train"])
eval_dataset = LuminarSequenceDataset(encoded_splits["eval"])
test_dataset = LuminarSequenceDataset(encoded_splits["test"])
print(len(train_dataset))
print(train_dataset[0])

# Dedicated, seeded generator: the training shuffle order is then reproducible
# across kernel restarts and independent of the global RNG state.
shuffle_generator = torch.Generator().manual_seed(Config.SEED)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
    generator=shuffle_generator,
)
# Evaluation and test data are iterated in their stored order.
eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

47578
{'features': tensor([[4.4133e-05, 4.2039e-45, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         1.5184e-04],
        [9.9816e-06, 2.2685e-09, 2.4707e-10,  ..., 7.9879e-31, 0.0000e+00,
         2.9878e-02],
        [1.5927e-05, 2.7541e-12, 5.1212e-10,  ..., 7.7882e-28, 2.9427e-44,
         2.8827e-01],
        ...,
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]]), 'sentence_spans': [[0, 61], [61, 87], [87, 103], [103, 121], [121, 141], [141, 168], [168, 179], [179, 209], [209, 233], [233, 256], [256, 266], [266, 288], [288, 300], [300, 317], [317, 336]], 'span_labels': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [88]:
# Preview only the first (index, batch) pair. Stopping after one iteration avoids
# collating the entire training set into a list just to print a single batch.
for batch in enumerate(train_loader):
    print(batch)
    break


(0, {'features': tensor([[[8.2419e-07, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 5.7127e-05],
         [1.0031e-06, 5.1468e-02, 1.3811e-02,  ..., 1.0000e+00,
          1.0000e+00, 6.3995e-02],
         [3.0254e-04, 3.5541e-04, 1.1336e-02,  ..., 1.1198e-09,
          4.9380e-24, 5.7311e-02],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]],

        [[2.4651e-05, 2.8026e-45, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.7961e-04],
         [3.8299e-05, 6.5291e-14, 5.8796e-09,  ..., 3.2720e-42,
          0.0000e+00, 9.2997e-03],
         [8.6544e-06, 1.5348e-09, 1.7457e-06,  ..., 7.9098e-30,
          0.0000e+00, 8.0821e-01],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00

In [89]:
#from luminar.sequence_classifier import LuminarSequence
from src.luminar.sequence_classifier import LuminarSequence

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyperparameters passed as keyword arguments to `LuminarSequence` in the next cell.
# Within this notebook only "learning_rate", "max_epochs" and
# "early_stopping_patience" are read back directly (optimizer / training call).
config = {
    # (sequence length, features per token); the sample inspected above has shape (512, 13).
    "feature_dim": (Config.FEATURE_LEN, Config.NUM_INTERMEDIATE_LIKELIHOODS),
    "feature_type": "intermediate_likelihoods",
    "feature_selection": "first",
    # NOTE(review): the printed model config shows these as (32, 5, 1), (64, 5, 1), (32, 3, 1);
    # the meaning of the individual fields is defined by ConvolutionalLayerSpec - confirm there.
    "conv_layer_shapes": (
        ConvolutionalLayerSpec(32, 5),
        ConvolutionalLayerSpec(64, 5),
        ConvolutionalLayerSpec(32, 3),
    ),
    "lstm_hidden_dim": 128,
    "lstm_layers": 1,
    "projection_dim": 32,
    "early_stopping_patience": 3,
    "learning_rate": 5e-4,
    "max_epochs": 40,
    # NOTE(review): the training loop in this notebook hard-codes a clip norm of 1.0
    # and does not read this value.
    "gradient_clip_val": 1.0,
    # NOTE(review): the DataLoaders were created earlier with batch sizes 32 (train) and
    # 64 (eval/test); these two entries do not influence them.
    "train_batch_size": 32,
    "eval_batch_size": 1024,
    # NOTE(review): no scheduler is created in this notebook; confirm how (or whether)
    # LuminarSequence uses the warmup ratio.
    "warmup_ratio": 1.0,
    "seed": Config.SEED,
    "rescale_features": False
}

In [90]:
# Build the classifier from the hyperparameter dict above and move it to the selected device.
luminar = LuminarSequence(**config).to(device)

LuminarTrainingConfig(feature_dim=(512, 13), feature_type='intermediate_likelihoods', feature_selection='first', conv_layer_shapes=((32, 5, 1), (64, 5, 1), (32, 3, 1)), lstm_hidden_dim=128, lstm_layers=1, projection_dim=32, early_stopping_patience=3, learning_rate=0.0005, max_epochs=40, gradient_clip_val=1.0, train_batch_size=32, eval_batch_size=1024, warmup_ratio=1.0, seed=42, rescale_features=False)


In [91]:
# Adam over all model parameters, using the learning rate from the config dict.
optimizer = torch.optim.Adam(luminar.parameters(), lr=config["learning_rate"])

In [92]:
def evaluate_metrics(model, data_loader, device):
    """Compute sentence-level accuracy and F1 of ``model`` over ``data_loader``.

    The model is switched to eval mode and run without gradient tracking. Logits
    are passed through a sigmoid and thresholded at 0.5; the flattened predictions
    are compared against the concatenated per-document span labels.

    Args:
        model: Callable as ``model(features, sentence_spans)`` returning an object
            with a ``logits`` tensor.
        data_loader: Iterable of batches with the keys ``"features"``,
            ``"sentence_spans"`` and ``"span_labels"``.
        device: Device the feature tensors are moved to.

    Returns:
        Tuple ``(accuracy, f1)``.
    """
    model.eval()
    predicted = []
    expected = []

    with torch.no_grad():
        for batch in data_loader:
            inputs = batch["features"].to(device)
            spans = batch["sentence_spans"]
            labels_per_sample = batch["span_labels"]

            logits = model(inputs, spans).logits
            probabilities = torch.sigmoid(logits).view(-1).cpu().numpy()
            predicted.extend((probabilities > 0.5).astype(int))

            # Flatten the per-document label lists in batch order.
            label_tensors = [
                torch.tensor(sample_labels, dtype=torch.int)
                for sample_labels in labels_per_sample
            ]
            expected.extend(torch.cat(label_tensors).cpu().numpy())

    return accuracy_score(expected, predicted), f1_score(expected, predicted)

def evaluate(model, eval_loader, device):
    """Return the mean per-batch loss of ``model`` over ``eval_loader``.

    The model is put into eval mode and run without gradient tracking. Each
    batch contributes its scalar ``output.loss`` once; the sum is divided by
    the number of batches reported by ``len(eval_loader)``.

    Args:
        model: Callable as ``model(features, sentence_spans, span_labels=...)``
            returning an object with a scalar ``loss``.
        eval_loader: Sized iterable of batches with the keys ``"features"``,
            ``"sentence_spans"`` and ``"span_labels"``.
        device: Device the feature tensors are moved to.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for batch in eval_loader:
            output = model(
                batch["features"].to(device),
                batch["sentence_spans"],
                span_labels=batch["span_labels"],
            )
            batch_losses.append(output.loss.item())

    return sum(batch_losses) / len(eval_loader)

def train_and_evaluate(model, train_loader, eval_loader, optimizer, device, epochs, patience=3,
                       test_loader=None, gradient_clip_val=1.0):
    """Train ``model`` with early stopping and leave it holding its best weights.

    After every epoch the mean evaluation loss is computed with ``evaluate``.
    Whenever it improves, a snapshot of the weights is taken; training stops
    once ``patience`` consecutive epochs bring no improvement. The best snapshot
    is loaded back into ``model`` before returning - both after an early stop
    and after the full epoch budget has been used.

    Args:
        model: Module callable as ``model(features, sentence_spans, span_labels=...)``
            whose output exposes ``loss``.
        train_loader: Batches used for optimisation.
        eval_loader: Batches used for model selection / early stopping.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device the feature tensors are moved to.
        epochs: Maximum number of epochs.
        patience: Number of consecutive non-improving epochs tolerated.
        test_loader: Optional loader reported on (accuracy / F1) every 5 epochs.
            When omitted, a notebook-level ``test_loader`` variable is used if one
            exists, which keeps calls written for the previous signature working.
        gradient_clip_val: Maximum gradient norm; ``None`` disables clipping.
    """
    import copy  # local import: keeps this cell independent of the import cell

    if test_loader is None:
        # Backward compatibility with the former implicit use of a global loader.
        test_loader = globals().get("test_loader")

    best_eval_loss = float("inf")
    best_train_loss = None
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        # Training
        model.train()
        total_train_loss = 0
        for batch in train_loader:
            features = batch["features"].to(device)
            sentence_spans = batch["sentence_spans"]
            span_labels = batch["span_labels"]

            optimizer.zero_grad()
            output = model(features, sentence_spans, span_labels=span_labels)
            loss = output.loss

            loss.backward()
            if gradient_clip_val is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)
            optimizer.step()

            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)

        # Evaluation
        avg_eval_loss = evaluate(model, eval_loader, device)
        print(f"Train Loss: {avg_train_loss:.4f} | Eval Loss: {avg_eval_loss:.4f}")

        # Every 5 epochs: evaluate on the test set (reporting only, not used for selection)
        if test_loader is not None and (epoch + 1) % 5 == 0:
            acc, f1 = evaluate_metrics(model, test_loader, device)
            print(f"Test Accuracy: {acc:.4f} | Test F1: {f1:.4f}")

        # Early Stopping & Checkpoint
        if avg_eval_loss < best_eval_loss:
            best_eval_loss = avg_eval_loss
            best_train_loss = avg_train_loss
            epochs_no_improve = 0
            # state_dict() hands out references to the live parameters; without a
            # deep copy the "best" snapshot would silently follow later updates.
            best_model_state = copy.deepcopy(model.state_dict())
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping after {epoch+1} epochs.")
                break

    # Restore the best weights regardless of how the loop ended.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    if best_train_loss is None:
        # No epoch produced an improvement (e.g. epochs == 0), so there is nothing to format.
        print(f"\nBest Eval Loss: {best_eval_loss:.4f} | Best Train Loss: n/a")
    else:
        print(f"\nBest Eval Loss: {best_eval_loss:.4f} | Best Train Loss: {best_train_loss:.4f}")

In [93]:
# Run the training loop; the epoch budget and early-stopping patience come from the config dict.
train_and_evaluate(luminar,
                   train_loader,
                   eval_loader,
                   optimizer,
                   device,
                   epochs=config["max_epochs"],
                   patience=config["early_stopping_patience"])


Epoch 1/40
Train Loss: 0.5412 | Eval Loss: 0.5080

Epoch 2/40
Train Loss: 0.4657 | Eval Loss: 0.4499

Epoch 3/40
Train Loss: 0.4460 | Eval Loss: 0.4335

Epoch 4/40
Train Loss: 0.4332 | Eval Loss: 0.4296

Epoch 5/40
Train Loss: 0.4195 | Eval Loss: 0.4254
Test Accuracy: 0.7941 | Test F1: 0.8017

Epoch 6/40
Train Loss: 0.4083 | Eval Loss: 0.4012

Epoch 7/40
Train Loss: 0.4013 | Eval Loss: 0.4153

Epoch 8/40
Train Loss: 0.3922 | Eval Loss: 0.4464

Epoch 9/40
Train Loss: 0.3870 | Eval Loss: 0.3980

Epoch 10/40
Train Loss: 0.3794 | Eval Loss: 0.4099
Test Accuracy: 0.8159 | Test F1: 0.8451

Epoch 11/40
Train Loss: 0.3746 | Eval Loss: 0.3814

Epoch 12/40
Train Loss: 0.3675 | Eval Loss: 0.3757

Epoch 13/40
Train Loss: 0.3625 | Eval Loss: 0.3759

Epoch 14/40
Train Loss: 0.3585 | Eval Loss: 0.3608

Epoch 15/40
Train Loss: 0.3540 | Eval Loss: 0.3733
Test Accuracy: 0.8249 | Test F1: 0.8304

Epoch 16/40
Train Loss: 0.3491 | Eval Loss: 0.3546

Epoch 17/40
Train Loss: 0.3448 | Eval Loss: 0.3642

Epoc