# imports


In [1]:
import wandb

wandb.init(mode="disabled")

In [2]:
import gc
import os
import time
import warnings
import sys

import lightning as L
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import transformers
from datasets import Dataset
from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint
from rich.table import Table
from scipy.special import softmax
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from torch import nn
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    LinearLR,
    SequentialLR,
)
from torch.utils import data
from torchinfo import summary
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)

from utils.management import clean_mem, create_logger, free_vars
from utils.preprocess import (
    add_classification_preds_rule_subreddit,
    create_master_dataset,
    sanitize_comment,
)
from utils.reoberta import get_custom_roberta

In [3]:
os.environ["WANDB_DISABLED"] = "true"
warnings.simplefilter("ignore")

In [4]:
logger = create_logger()
training_logger = create_logger(name="model", log_file="training.log")

2025-08-02 16:44:15 | reddit_moderation | INFO | Logger 'reddit_moderation' created successfully
2025-08-02 16:44:15 | reddit_moderation | INFO | Log level: INFO
2025-08-02 16:44:15 | reddit_moderation | INFO | Log file: logs/reddit_moderation_20250802_164415.log
2025-08-02 16:44:15 | model | INFO | Logger 'model' created successfully
2025-08-02 16:44:15 | model | INFO | Log level: INFO
2025-08-02 16:44:15 | model | INFO | Log file: logs/training.log


# load dataset


In [5]:
# INPUT_PATH = os.path.join("/", "kaggle", "input")
INPUT_PATH = os.path.join("..", "data")
train = pd.read_csv(
    os.path.join(INPUT_PATH, "jigsaw-agile-community-rules", "train.csv")
)
test = pd.read_csv(os.path.join(INPUT_PATH, "jigsaw-agile-community-rules", "test.csv"))
submission = pd.read_csv(
    os.path.join(INPUT_PATH, "jigsaw-agile-community-rules", "sample_submission.csv")
)
features = pd.read_csv(os.path.join(INPUT_PATH, "jigsaw", "features.csv"))[
    "features"
].tolist()
subreddits = pd.read_csv(os.path.join(INPUT_PATH, "jigsaw", "subreddits.csv"))[
    "subreddit"
].tolist()

# load models


first model to load will be a basic roberta model, and hopefully we can do some stuff with it


In [6]:
# _MODEL_DIR = os.path.join("/", "kaggle", "input")
# MODEL_PATH = {
#     "classifier": os.path.join(
#         _MODEL_DIR, "facebookai-roberta-large-mnli", "transformers", "default", "1"
#     ),
#     "nli": os.path.join(
#         _MODEL_DIR,
#         "moritzlaurerdeberta-v3-base-mnli-fever-anli",
#         "transformers",
#         "default",
#         "1",
#     ),
# }
_MODEL_VERSION_PATH = os.path.join(
    "transformers",
    "default",
    "1",
)
_MODEL_DIR = os.path.join("..", "model")
MODEL_PATH = {
    "classifier-OLD": os.path.join(_MODEL_DIR, "facebookai-roberta-large-mnli"),
    "nli": os.path.join(
        _MODEL_DIR,
        "nli-deberta-v3-small",
    ),
    "classifier": os.path.join(_MODEL_DIR, "roberta-base"),
}
logger.info(MODEL_PATH)

2025-08-02 16:44:15 | reddit_moderation | INFO | {'classifier-OLD': '../model/facebookai-roberta-large-mnli', 'nli': '../model/nli-deberta-v3-small', 'classifier': '../model/roberta-base'}


# training dataset prep


In [7]:
for c in [
    "body",
    "rule",
    "subreddit",
    "positive_example_1",
    "positive_example_2",
    "negative_example_1",
    "negative_example_2",
]:
    logger.info(f"Cleaning {c = }")
    train[c] = train[c].apply(sanitize_comment)
    test[c] = test[c].apply(sanitize_comment)
master_dataset = create_master_dataset(train, test, logger)
dataset = Dataset.from_pandas(master_dataset)
free_vars([train, test, master_dataset], logger=logger)

2025-08-02 16:44:15 | reddit_moderation | INFO | Cleaning c = 'body'
2025-08-02 16:44:15 | reddit_moderation | INFO | Cleaning c = 'rule'
2025-08-02 16:44:15 | reddit_moderation | INFO | Cleaning c = 'subreddit'
2025-08-02 16:44:16 | reddit_moderation | INFO | Cleaning c = 'positive_example_1'
2025-08-02 16:44:16 | reddit_moderation | INFO | Cleaning c = 'positive_example_2'
2025-08-02 16:44:17 | reddit_moderation | INFO | Cleaning c = 'negative_example_1'
2025-08-02 16:44:17 | reddit_moderation | INFO | Cleaning c = 'negative_example_2'
2025-08-02 16:44:18 | reddit_moderation | INFO | Starting master dataset creation
2025-08-02 16:44:18 | reddit_moderation | INFO | Input - Train: 2029 rows, Test: 10 rows
2025-08-02 16:44:18 | reddit_moderation | INFO | Concatenating all dataset parts
2025-08-02 16:44:18 | reddit_moderation | INFO | Master dataset created successfully: 10185 total records
2025-08-02 16:44:18 | reddit_moderation | INFO | Violation distribution: {1: 5109, 0: 5076}
2025-0

# extract features and tags


In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# rule_classifier = pipeline(
#     "zero-shot-classification", model=MODEL_PATH["nli"], device=device
# )
subreddit_classifier = pipeline(
    "zero-shot-classification", model=MODEL_PATH["nli"], device=device
)
rule_vectorizer = TfidfVectorizer().fit(features)
rule_vecs = rule_vectorizer.transform(features)

# subreddit_vectorizer = TfidfVectorizer().fit(subreddits)
# subreddit_vecs = subreddit_vectorizer.transform(subreddits)

Device set to use cuda


In [9]:
dataset_with_features, rule_lookup, subreddit_lookup = (
    add_classification_preds_rule_subreddit(
        dataset,
        rule_vectorizer,
        rule_vecs,
        features,
        subreddit_classifier,
        subreddits,
        logger=logger,
    )
)

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

2025-08-02 16:44:19 | reddit_moderation | INFO | Starting prediction for column 'rule'
2025-08-02 16:44:19 | reddit_moderation | INFO | Processing 2 texts from batch


2025-08-02 16:44:19 | reddit_moderation | INFO | Sample 0: top scores ['ban advertising:0.765', 'ban spam:0.383', 'ban illegal content:0.137'], above threshold: 2, chosen: 2
2025-08-02 16:44:19 | reddit_moderation | INFO | Sample 1: top scores ['enforce respectful conduct:0.000', 'enforce on topic content:0.000', 'ban low effort content:0.000'], above threshold: 0, chosen: 1
2025-08-02 16:44:19 | reddit_moderation | INFO | Prediction completed for 2 samples on column 'rule'
2025-08-02 16:44:19 | reddit_moderation | INFO | Sample predictions: ['ban advertising, ban spam', 'enforce respectful conduct']


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

2025-08-02 16:44:20 | reddit_moderation | INFO | Starting NLI prediction for column 'subreddit'
2025-08-02 16:44:20 | reddit_moderation | INFO | Processing 100 texts with NLI classifier
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 7: using best label news with score 0.788
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 10: using best label question answers with score 0.512
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 15: using best label food with score 0.755
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 30: using best label niche with score 0.544
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 33: using best label relationship with score 0.657
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 41: using best label nsfw with score 0.678
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 45: using best label regional with score 0.757
2025-08-02 16:44:31 | reddit_moderation | INFO | Sample 46: using best label relationship with sc

Attach predictions:   0%|          | 0/10185 [00:00<?, ? examples/s]

In [10]:
free_vars(
    vars_to_delete=[
        "dataset",
        "rule_vectorizer",
        "rule_vecs",
        "features",
        "subreddits",
        "subreddit_classifier",
    ],
    namespace=globals(),
    logger=logger,
)

2025-08-02 16:44:31 | reddit_moderation | INFO | Deleted variable 'dataset'
2025-08-02 16:44:31 | reddit_moderation | INFO | Deleted variable 'rule_vectorizer'
2025-08-02 16:44:31 | reddit_moderation | INFO | Deleted variable 'rule_vecs'
2025-08-02 16:44:31 | reddit_moderation | INFO | Deleted variable 'features'
2025-08-02 16:44:31 | reddit_moderation | INFO | Deleted variable 'subreddits'
2025-08-02 16:44:31 | reddit_moderation | INFO | Deleted variable 'subreddit_classifier'
2025-08-02 16:44:31 | reddit_moderation | INFO | GPU memory freed: 0.00 MB
2025-08-02 16:44:32 | reddit_moderation | INFO | RAM memory freed: 0.00 MB
RAM freed: 0.00 MB (1888.94 -> 1888.94)
GPU allocated freed: 0.00 MB (550.63 -> 550.63)
GPU reserved freed: 0.00 MB (578.00 -> 578.00)


# modelling


In [27]:
# basemodel = get_custom_roberta(MODEL_PATH["classifier"])
basemodel = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH["classifier"])
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH["classifier"])
print(f"{tokenizer.pad_token = } | {tokenizer.eos_token = }")

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at ../model/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer.pad_token = '<pad>' | tokenizer.eos_token = '</s>'


In [12]:
basemodel

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
         

## freeze and unfreeze stuff


In [13]:
def _freeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = False


def _unfreeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad = True


def _get_layer(model, layer_idx: int):
    """
    Returns one transformer block (RobertaLayer) by index.
    layer_idx = 0 .. 23 for roberta-large
    """
    return model.layer[layer_idx]


def _cast_module(module: nn.Module, dtype: torch.dtype):
    """recursively casts module parameters & buffers"""
    for p in module.parameters(recurse=False):
        if p.dtype == torch.float32:
            p.data = p.data.to(dtype=dtype)
    for b in module.buffers(recurse=False):
        if b.dtype == torch.float32:
            b.data = b.data.to(dtype=dtype)
    for child in module.children():
        _cast_module(child, dtype)

## create the dataset


In [14]:
dataset = dataset_with_features.to_pandas()
free_vars(["dataset_with_features"], namespace=globals(), logger=logger)

2025-08-02 16:44:32 | reddit_moderation | INFO | Deleted variable 'dataset_with_features'
2025-08-02 16:44:32 | reddit_moderation | INFO | GPU memory freed: 0.00 MB
2025-08-02 16:44:32 | reddit_moderation | INFO | RAM memory freed: 0.00 MB
RAM freed: 0.00 MB (1898.63 -> 1898.63)
GPU allocated freed: 0.00 MB (8.12 -> 8.12)
GPU reserved freed: 0.00 MB (20.00 -> 20.00)


In [15]:
def _make_prompt(row):
    prompt = f"""Rule: {row['rule']}
Subreddit: {row['subreddit']} ({row['predicted_subreddit_feature']})
Content restrictions: {row['predicted_rule_feature']}

Comment: "{row['comment']}"

Question: Does this comment violate the rule?
Answer:"""
    return prompt


dataset["prompt"] = dataset.apply(_make_prompt, axis=1)

In [16]:
train, val = train_test_split(
    dataset,
    test_size=0.3,
    random_state=42,
    shuffle=True,
    stratify=dataset["violation"].astype(str) + "-" + dataset["rule"],
)

In [17]:
if tokenizer.eos_token is None:
    print("yes")

In [18]:
class CommentDataset(data.Dataset):
    def __init__(self, prompts, labels, tokenizer):
        super().__init__()
        self.prompts = prompts
        self.labels = labels
        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            if not self.tokenizer.eos_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.pad_token = "[PAD]"

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        encodings = self.tokenizer(
            self.prompts[index], padding="max_length", max_length=512, truncation=True
        )
        attention_mask = encodings["attention_mask"]
        input_ids = encodings["input_ids"]
        label = self.labels[index]
        return (
            torch.tensor(input_ids),
            torch.tensor(attention_mask),
            torch.tensor(label),
        )

In [19]:
train_dataset = CommentDataset(
    prompts=train["prompt"].tolist(),
    labels=train["violation"].tolist(),
    tokenizer=tokenizer,
)
val_dataset = CommentDataset(
    prompts=val["prompt"].tolist(),
    labels=val["violation"].tolist(),
    tokenizer=tokenizer,
)

train_dataloader = data.DataLoader(
    train_dataset, batch_size=32, num_workers=4, pin_memory=False
)
val_dataloader = data.DataLoader(
    val_dataset, batch_size=32, num_workers=4, pin_memory=False
)

In [20]:
attn, in_id, label = next(iter(train_dataloader))
print(f"{attn.shape = } | {in_id.shape = } | {label.shape = }")

attn.shape = torch.Size([32, 512]) | in_id.shape = torch.Size([32, 512]) | label.shape = torch.Size([32])


In [21]:
clean_mem()

RAM freed: 0.00 MB (1900.61 -> 1900.61)
GPU allocated freed: 0.00 MB (8.12 -> 8.12)
GPU reserved freed: 0.00 MB (20.00 -> 20.00)


In [22]:
basemodel

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
         

In [28]:
basemodel(in_id, attn)

: 

## define metrics


In [20]:
def column_averaged_auc(logits, labels):
    """
    Compute AUC score for binary classification.

    Parameters
    ----------
    logits : torch.Tensor
        Raw model outputs of shape (batch_size, 1)
    labels : torch.Tensor
        Ground truth labels of shape (batch_size, 1)

    Returns
    -------
    float
        AUC score
    """
    # convert to probabilities using sigmoid for binary classification
    probs = F.sigmoid(logits).squeeze().detach().cpu().numpy()
    labels_np = labels.squeeze().detach().cpu().numpy()

    # handle edge case where all labels are same class
    if len(set(labels_np)) == 1:
        return 0.5

    return roc_auc_score(y_true=labels_np, y_score=probs)

## finetune model


first freeze the entire model


In [21]:
DEVICE = torch.device("cuda")

now make the custom model by adding some more layers


In [22]:
class ProgressLogger(Callback):  # Callback already imported
    """
    Hugging-Face-style progress table that prints *once* per
    `log_every_n_steps` and once after every validation epoch.

    • Overwrites the previous table instead of appending new prints.
    • Adds columns lazily when the metric first appears
      → no more NaN spam.
    """

    def __init__(self, log_every_n_steps: int = 50):
        super().__init__()
        self.n = log_every_n_steps
        self._rows: list[dict] = []  # accumulated rows
        self._cols: list[str] = ["step", "epoch"]
        self._last_step = -1  # to suppress duplicates
        self._last_len = 0  # last #chars printed

    # ----------------- helpers ----------------- #
    @staticmethod
    def _scalar(x):
        if hasattr(x, "detach"):
            x = x.detach().cpu()
        return float(x)

    def _collect(self, step: int, epoch: int, metrics: dict):
        row = {"step": step, "epoch": epoch}
        for k, v in metrics.items():
            if k in {"step", "epoch"}:
                continue
            # add a new column the first time we see this metric
            if k not in self._cols:
                self._cols.append(k)
            try:
                row[k] = self._scalar(v)
            except Exception:
                pass  # skip non-scalars
        self._rows.append(row)

    def _print_table(self):
        df = pd.DataFrame(self._rows, columns=self._cols)
        with pd.option_context("display.float_format", "{:.5f}".format):
            table = df.to_string(index=False)

        # wipe previous printout
        sys.stdout.write("\r" + " " * self._last_len + "\r")
        sys.stdout.write(table)
        sys.stdout.flush()
        self._last_len = len(table)

    # ---------------- Lightning hooks ---------------- #
    def on_train_batch_end(self, trainer, pl_module, *_):
        step = trainer.global_step
        if step and step % self.n == 0 and step != self._last_step:
            self._collect(step, trainer.current_epoch, trainer.callback_metrics)
            self._print_table()
            self._last_step = step

    def on_validation_epoch_end(self, trainer, pl_module):
        step = trainer.global_step
        if step != self._last_step:  # skip if already printed for this step
            self._collect(step, trainer.current_epoch, trainer.callback_metrics)
            self._print_table()
            sys.stdout.write("\n")  # newline so checkpoint msgs start clean
            self._last_step = step

In [23]:
class MlpHead(nn.Module):
    def __init__(
        self,
        input_dim: int = 1024,
        num_hidden_layers: int = 3,
        num_output_classes: int = 1,
        hidden_dim: int = 1024,
        dropout: float = 0.2,
        *args,
        **kwargs
    ):
        super().__init__()
        hidden_layers = [
            nn.Linear(in_features=input_dim, out_features=hidden_dim, bias=True),
            nn.Dropout(p=dropout),
            nn.ReLU(),
        ]
        for i in range(num_hidden_layers - 1):
            out_features = (
                (hidden_dim // 2) if (i == (num_hidden_layers - 2)) else hidden_dim
            )
            layer = [
                nn.Linear(in_features=hidden_dim, out_features=out_features, bias=True),
                nn.Dropout(p=dropout),
                nn.ReLU(),
            ]
            hidden_layers.extend(layer)
        self.hidden_layers = nn.Sequential(*hidden_layers)
        self.classifier_head = nn.Linear(
            in_features=out_features, out_features=num_output_classes
        )

    def forward(self, x):
        hidden_layer_output = self.hidden_layers(x)
        out = self.classifier_head(hidden_layer_output)
        return out.flatten()

In [24]:
class BasedRedditMod(L.LightningModule):
    def __init__(
        self,
        basemodel,
        num_hidden_layers: int = 3,
        num_output_classes: int = 1,
        hidden_dim: int = 1024,
        lr: float = 3e-4,
        dropout: float = 0.2,
        model_save_path: str = "",
        # scheduler parameters
        scheduler_type: str = "cosine_warmup",
        warmup_epochs: int = 2,
        max_epochs: int = 20,
        # step-based logging and validation
        log_every_n_steps: int = 100,
        val_check_interval: int = 500,
        save_every_n_steps: int = 1000,
        # early stopping parameters
        early_stopping_patience: int = 3,
        early_stopping_min_delta: float = 0.001,
        early_stopping_monitor: str = "val_auroc",
        *args,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.basemodel = basemodel
        self.lr = lr
        self.model_save_path = model_save_path
        self.scheduler_type = scheduler_type
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs

        # step-based configuration
        self.log_every_n_steps = log_every_n_steps
        self.val_check_interval = val_check_interval
        self.save_every_n_steps = save_every_n_steps

        # freeze base model weights
        for param in self.basemodel.parameters():
            param.requires_grad = True

        # early stopping config
        self.early_stopping_config = {
            "patience": early_stopping_patience,
            "min_delta": early_stopping_min_delta,
            "monitor": early_stopping_monitor,
        }

        self.mlphead = MlpHead(
            input_dim=basemodel.config.hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_output_classes=num_output_classes,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )
        self.loss_fn = nn.BCEWithLogitsLoss()

        # epoch-based metrics
        self.train_acc = BinaryAccuracy()
        self.train_auroc = BinaryAUROC()
        self.val_acc = BinaryAccuracy()
        self.val_auroc = BinaryAUROC()

        # step-based metrics for intermediate logging
        self.train_acc_steps = BinaryAccuracy()
        self.train_auroc_steps = BinaryAUROC()
        self.val_acc_steps = BinaryAccuracy()
        self.val_auroc_steps = BinaryAUROC()

        # freeze, unfreeze and cast
        _freeze_module(self.basemodel.embeddings)
        # _cast_module(reddit_mod.basemodel.embeddings, torch.bfloat16)
        for idx in range(0, 12 // 2):
            _freeze_module(_get_layer(self.basemodel.encoder, idx))
        # _cast_module(_get_layer(reddit_mod.basemodel.encoder, idx), torch.bfloat16)
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(
            f"Trainable params: {trainable/1e6:.1f} M  /  Total params: {total/1e6:.1f} M"
        )
        # _cast_module(self, torch.bfloat16)

    def get_callbacks(self):
        """
        Create all necessary callbacks for training.

        Returns
        -------
        list
            List of configured PyTorch Lightning callbacks
        """
        callbacks = []

        # early stopping callback
        early_stop = EarlyStopping(
            monitor=self.early_stopping_config["monitor"],
            patience=self.early_stopping_config["patience"],
            min_delta=self.early_stopping_config["min_delta"],
            mode="max",
            verbose=True,
            strict=True,
        )
        callbacks.append(early_stop)

        # step-based checkpoint callback
        step_checkpoint = ModelCheckpoint(
            dirpath=f"{self.model_save_path}/step_checkpoints/",
            filename="model-step-{step:06d}-{val_auroc:.4f}",
            every_n_train_steps=self.save_every_n_steps,
            save_top_k=3,
            monitor=self.early_stopping_config["monitor"],
            mode="max",
            save_last=True,
            verbose=True,
        )
        callbacks.append(step_checkpoint)

        # epoch-based checkpoint callback (best model)
        epoch_checkpoint = ModelCheckpoint(
            dirpath=f"{self.model_save_path}/epoch_checkpoints/",
            filename="best-model-{epoch:02d}-{val_auroc:.4f}",
            monitor=self.early_stopping_config["monitor"],
            mode="max",
            save_top_k=1,
            save_last=False,
            verbose=True,
        )
        callbacks.append(epoch_checkpoint)

        return callbacks

    def load_best_model(self, checkpoint_dir: str = None):
        """
        Load the best model checkpoint based on validation metrics.

        Parameters
        ----------
        checkpoint_dir : str, optional
            Directory containing checkpoints. If None, uses model_save_path
        """
        import glob
        import os

        if checkpoint_dir is None:
            checkpoint_dir = f"{self.model_save_path}/epoch_checkpoints/"

        # find best checkpoint file
        checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "best-model-*.ckpt"))

        if not checkpoint_files:
            print("No best model checkpoint found")
            return

        # load the best checkpoint (should be only one due to save_top_k=1)
        best_checkpoint = checkpoint_files[0]
        print(f"Loading best model from: {best_checkpoint}")

        # load state dict
        checkpoint = torch.load(best_checkpoint, map_location=self.device)
        self.load_state_dict(checkpoint["state_dict"])

    def forward(self, input_ids, attention_mask):
        # base model forward pass (frozen)
        with torch.no_grad():
            pooled = self.basemodel(
                input_ids=input_ids, attention_mask=attention_mask
            ).pooler_output
        return self.mlphead(pooled).unsqueeze(1)

    def training_step(self, batch, batch_idx, *args, **kwargs):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask)

        # ensure correct dtypes and shapes for bce loss
        labels_float = labels.float().unsqueeze(1)
        loss = self.loss_fn(logits, labels_float)

        # compute probabilities for metrics
        probs = torch.sigmoid(logits)
        labels_int = labels_float.int()

        # update epoch-based metrics
        self.train_acc.update(probs, labels_int)
        self.train_auroc.update(probs, labels_int)

        # update step-based metrics
        self.train_acc_steps.update(probs, labels_int)
        self.train_auroc_steps.update(probs, labels_int)

        # always log loss
        self.log("train_loss", loss, on_step=True, prog_bar=True)

        # step-based metric logging
        if self.global_step % self.log_every_n_steps == 0:
            self._log_step_metrics(is_train=True)

        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask)

        # ensure correct dtypes and shapes for bce loss
        labels_float = labels.float().unsqueeze(1)
        loss = self.loss_fn(logits, labels_float)

        # compute probabilities for metrics
        probs = torch.sigmoid(logits)
        labels_int = labels_float.int()

        # update epoch-based metrics
        self.val_acc.update(probs, labels_int)
        self.val_auroc.update(probs, labels_int)

        # update step-based metrics
        self.val_acc_steps.update(probs, labels_int)
        self.val_auroc_steps.update(probs, labels_int)

        # log validation loss
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def _log_step_metrics(self, is_train=True):
        """Log metrics at specified step intervals."""
        prefix = "train" if is_train else "val"
        acc_metric = self.train_acc_steps if is_train else self.val_acc_steps
        auroc_metric = self.train_auroc_steps if is_train else self.val_auroc_steps

        # compute step-based metrics
        acc_score = acc_metric.compute()
        auroc_score = auroc_metric.compute()

        # log step metrics
        self.log(f"{prefix}_acc_steps", acc_score, on_step=True, prog_bar=True)
        self.log(f"{prefix}_auroc_steps", auroc_score, on_step=True, prog_bar=True)
        self.log("global_step", float(self.global_step), on_step=True)

        # reset step metrics for next window
        acc_metric.reset()
        auroc_metric.reset()

    def on_validation_epoch_end(self):
        """Compute and log epoch-aggregated validation metrics."""
        # compute epoch metrics
        val_acc_epoch = self.val_acc.compute()
        val_auroc_epoch = self.val_auroc.compute()

        # log epoch metrics
        self.log("val_acc", val_acc_epoch, prog_bar=True)
        self.log("val_auroc", val_auroc_epoch, prog_bar=True)

        # reset metrics for next epoch
        self.val_acc.reset()
        self.val_auroc.reset()

    def on_train_epoch_end(self):
        """Compute and log epoch-aggregated training metrics."""
        # compute epoch metrics
        train_acc_epoch = self.train_acc.compute()
        train_auroc_epoch = self.train_auroc.compute()

        # log epoch metrics
        self.log("train_acc", train_acc_epoch, prog_bar=True)
        self.log("train_auroc", train_auroc_epoch, prog_bar=True)

        # reset metrics for next epoch
        self.train_acc.reset()
        self.train_auroc.reset()

    def configure_optimizers(self):
        """Configure optimizer and learning rate scheduler."""
        # optimizer setup - only optimize mlp head parameters
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            betas=(0.9, 0.98),
            eps=1e-6,
            weight_decay=0.1,
        )

        if self.scheduler_type == "cosine_warmup":
            return self._configure_cosine_warmup_scheduler(optimizer)
        elif self.scheduler_type == "linear_warmup":
            return self._configure_linear_warmup_scheduler(optimizer)
        else:
            return optimizer

    def _configure_cosine_warmup_scheduler(self, optimizer):
        """Configure cosine annealing scheduler with warmup."""
        warmup_steps = self.warmup_epochs
        total_steps = self.max_epochs

        # warmup scheduler
        warmup_scheduler = LinearLR(
            optimizer, start_factor=0.01, total_iters=warmup_steps
        )

        # cosine annealing scheduler
        cosine_scheduler = CosineAnnealingLR(
            optimizer, T_max=total_steps - warmup_steps, eta_min=self.lr * 0.01
        )

        # combine warmup + cosine
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "val_auroc",
            },
        }

    def _configure_linear_warmup_scheduler(self, optimizer):
        """Configure linear decay scheduler with warmup."""
        warmup_steps = self.warmup_epochs
        total_steps = self.max_epochs

        # warmup scheduler
        warmup_scheduler = LinearLR(
            optimizer, start_factor=0.01, total_iters=warmup_steps
        )

        # linear decay scheduler
        decay_scheduler = LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.01,
            total_iters=total_steps - warmup_steps,
        )

        # combine warmup + decay
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, decay_scheduler],
            milestones=[warmup_steps],
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "val_auroc",
            },
        }

In [25]:
reddit_mod = BasedRedditMod(
    basemodel=basemodel,
    lr=2e-6,
    scheduler_type="cosine_warmup",
    warmup_epochs=1,
    max_epochs=5,
    model_save_path="./training_outputs",
    # step-based configuration
    log_every_n_steps=len(train_dataloader) // 8,
    val_check_interval=len(train_dataloader) // 4,
    save_every_n_steps=len(train_dataloader) // 4,
    # early stopping
    early_stopping_patience=5,
    early_stopping_monitor="val_auroc",
)

Trainable params: 45.5 M  /  Total params: 127.0 M


In [33]:
reddit_mod

BasedRedditMod(
  (basemodel): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm)

In [36]:
# create model with learning rate scheduling
# get existing callbacks


print(
    summary(
        reddit_mod,
        input_size=[(64, 514)] * 2,
        dtypes=[torch.long, torch.long],
        col_names=[
            # "input_size",
            "output_size",
            "num_params",
            "trainable",
        ],  # include trainable_params column
    )
)

/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [15,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [15,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [15,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [15,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [15,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [15,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: index

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [RobertaEmbeddings: 2, Embedding: 3, Embedding: 3, Embedding: 3, LayerNorm: 3, Dropout: 3]

In [37]:
clean_mem()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [28]:
callbacks = reddit_mod.get_callbacks()

# add HF-style progress bar
# hf_progress = HuggingFaceStyleProgressBar()
progress = ProgressLogger(log_every_n_steps=reddit_mod.log_every_n_steps)
callbacks.append(progress)
trainer = L.Trainer(
    callbacks=callbacks,
    accelerator="gpu",
    max_epochs=reddit_mod.max_epochs,
    # step-based logging configuration
    log_every_n_steps=reddit_mod.log_every_n_steps,
    # step-based validation configuration
    val_check_interval=reddit_mod.val_check_interval,
    # enable checkpointing
    enable_checkpointing=True,
    # enable progress bar for step tracking
    enable_progress_bar=True,
    precision=16,
    gradient_clip_val=1.0,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [29]:
trainer.fit(reddit_mod, train_dataloader, val_dataloader)

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name              | Type              | Params | Mode 
-----------------------------------------------------------------
0  | basemodel         | RobertaModel      | 124 M  | train
1  | mlphead           | MlpHead           | 2.4 M  | train
2  | loss_fn           | BCEWithLogitsLoss | 0      | train
3  | train_acc         | BinaryAccuracy    | 0      | train
4  | train_auroc       | BinaryAUROC       | 0      | train
5  | val_acc           | BinaryAccuracy    | 0      | train
6  | val_auroc         | BinaryAUROC       | 0      | train
7  | train_acc_steps   | BinaryAccur

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

RuntimeError: The expanded size of the tensor (2048) must match the existing size (514) at non-singleton dimension 1.  Target sizes: [64, 2048].  Tensor sizes: [1, 514]

# generate predictions
