# imports


In [1]:
# import wandb

# wandb.init(mode="disabled")

In [2]:
import os
import warnings
from typing import List, Literal, Union

import lightning as L
import numpy as np
import pandas as pd
import torch
from lightning.pytorch import callbacks as lcb
from scipy.special import softmax
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import (
    classification_report,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from torch import nn, optim, utils
from torch.nn import functional as F
from torch.optim import lr_scheduler
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score
from transformers import (
    AutoTokenizer,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaPreTrainedModel,
)

In [3]:
ON_KAGGLE = all(os.path.exists(p) for p in ["/kaggle/working", "/kaggle/input"])

if ON_KAGGLE:
    INPUT_PATH = os.path.join("/", "kaggle", "input")
    OUTPUT_PATH = os.path.join("/", "kaggle", "working")
    _MODEL_VERSION_PATH = os.path.join(
        "transformers",
        "default",
        "1",
    )
    _MODEL_DIR = os.path.join("/", "kaggle", "input")
    MODEL_PATH = {
        "classifier": os.path.join(_MODEL_DIR, "roberta-base", _MODEL_VERSION_PATH),
    }
else:
    INPUT_PATH = os.path.join("..", "data")
    OUTPUT_PATH = os.path.join(".", "rulewise-output")
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    _MODEL_DIR = os.path.join("..", "model")
    MODEL_PATH = {
        "classifier": "FacebookAI/roberta-base",
    }

TRY_PROBABILITY_CALIBRATION = False

In [4]:
os.environ["WANDB_DISABLED"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
warnings.simplefilter("ignore")

# utilities


In [5]:
import gc
import inspect
import logging
import os
import sys
import time
import traceback
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Union

import psutil
import torch


def clean_mem():
    # import gc
    # import os
    # import sys
    # import time
    # import traceback

    # import psutil
    # import torch

    process = psutil.Process(os.getpid())

    # Measure RAM before cleanup
    ram_before = process.memory_info().rss / (1024**2)  # in MB

    # Measure GPU before cleanup
    if torch.cuda.is_available():
        gpu_alloc_before = torch.cuda.memory_allocated() / (1024**2)  # in MB
        gpu_reserved_before = torch.cuda.memory_reserved() / (1024**2)  # in MB
    else:
        gpu_alloc_before = gpu_reserved_before = 0

    # clean all traceback
    if hasattr(sys, "last_traceback"):
        traceback.clear_frames(sys.last_traceback)
        delattr(sys, "last_traceback")
    if hasattr(sys, "last_type"):
        delattr(sys, "last_type")
    if hasattr(sys, "last_value"):
        delattr(sys, "last_value")

    # clean all ipython history
    if "get_ipython" in globals():
        try:
            from IPython import get_ipython

            ip = get_ipython()
            user_ns = ip.user_ns
            ip.displayhook.flush()
            pc = ip.displayhook.prompt_count + 1
            for n in range(1, pc):
                user_ns.pop("_i" + repr(n), None)
            user_ns.update(dict(_i="", _ii="", _iii=""))
            hm = ip.history_manager
            hm.input_hist_parsed[:] = [""] * pc
            hm.input_hist_raw[:] = [""] * pc
            hm._i = hm._ii = hm._iii = hm._i00 = ""
        except Exception as e:
            print("ipython mem could not be cleared")

    # do a garbage collection and flush cuda cache
    gc.collect()
    torch.cuda.empty_cache()

    # Give system a small moment to settle (helps RAM measurement be more accurate)
    time.sleep(0.1)

    # Measure RAM after cleanup
    ram_after = process.memory_info().rss / (1024**2)  # in MB

    # Measure GPU after cleanup
    if torch.cuda.is_available():
        gpu_alloc_after = torch.cuda.memory_allocated() / (1024**2)  # in MB
        gpu_reserved_after = torch.cuda.memory_reserved() / (1024**2)  # in MB
    else:
        gpu_alloc_after = gpu_reserved_after = 0

    # Report freed memory
    print(
        f"RAM freed: {ram_before - ram_after:.2f} MB ({ram_before:.2f} -> {ram_after:.2f})"
    )
    if torch.cuda.is_available():
        print(
            f"GPU allocated freed: {gpu_alloc_before - gpu_alloc_after:.2f} MB ({gpu_alloc_before:.2f} -> {gpu_alloc_after:.2f})"
        )
        print(
            f"GPU reserved freed: {gpu_reserved_before - gpu_reserved_after:.2f} MB ({gpu_reserved_before:.2f} -> {gpu_reserved_after:.2f})"
        )
    else:
        print("No GPU detected.")


def create_logger(
    name: str = "reddit_moderation",
    log_level: str = "INFO",
    log_file: Optional[Union[str, Path]] = None,
    log_dir: Optional[Union[str, Path]] = "logs",
    console_output: bool = True,
    file_output: bool = True,
    format_string: Optional[str] = None,
    max_bytes: int = 10_000_000,  # 10MB
    backup_count: int = 5,
    include_timestamp_in_filename: bool = True,
) -> logging.Logger:
    """
    Create a fully featured logger for the Reddit comment moderation system.

    This logger is designed to handle all aspects of the multi-stage classification
    pipeline including zero-shot classification, fine-tuning, and evaluation.

    Parameters
    ----------
    name : str, optional
        Logger name, by default "reddit_moderation"
    log_level : str, optional
        Logging level ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
        by default "INFO"
    log_file : str or Path, optional
        Specific log file path. If None, auto-generates based on name and timestamp
    log_dir : str or Path, optional
        Directory for log files, by default "logs"
    console_output : bool, optional
        Whether to output logs to console, by default True
    file_output : bool, optional
        Whether to output logs to file, by default True
    format_string : str, optional
        Custom log format string, by default None (uses comprehensive format)
    max_bytes : int, optional
        Maximum log file size before rotation, by default 10MB
    backup_count : int, optional
        Number of backup log files to keep, by default 5
    include_timestamp_in_filename : bool, optional
        Whether to include timestamp in log filename, by default True

    Returns
    -------
    logging.Logger
        Configured logger instance ready for use

    Examples
    --------
    >>> # Basic usage
    >>> logger = create_logger()
    >>> logger.info("Starting Reddit comment classification pipeline")

    >>> # Advanced usage for training
    >>> training_logger = create_logger(
    ...     name="distilbert_training",
    ...     log_level="DEBUG",
    ...     log_file="training_session.log"
    ... )
    >>> training_logger.debug("Training batch processed")

    >>> # For evaluation only
    >>> eval_logger = create_logger(
    ...     name="model_evaluation",
    ...     console_output=False,
    ...     log_file="evaluation_results.log"
    ... )
    """

    # Create logger
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, log_level.upper()))

    # Clear existing handlers to avoid duplication
    logger.handlers.clear()

    # Default comprehensive format for ML workflows
    if format_string is None:
        format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"

    formatter = logging.Formatter(format_string, datefmt="%Y-%m-%d %H:%M:%S")

    # Console handler
    if console_output:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(getattr(logging, log_level.upper()))
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    # File handler with rotation
    if file_output:
        # Create log directory
        if log_dir:
            log_dir = Path(log_dir)
            log_dir.mkdir(exist_ok=True)

        # Generate log filename
        if log_file is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            if include_timestamp_in_filename:
                log_filename = f"{name}_{timestamp}.log"
            else:
                log_filename = f"{name}.log"
            log_file = log_dir / log_filename if log_dir else Path(log_filename)
        else:
            log_file = Path(log_file)
            if log_dir and not log_file.is_absolute():
                log_file = Path(log_dir) / log_file

        # Create rotating file handler
        from logging.handlers import RotatingFileHandler

        file_handler = RotatingFileHandler(
            log_file, maxBytes=max_bytes, backupCount=backup_count, encoding="utf-8"
        )
        file_handler.setLevel(getattr(logging, log_level.upper()))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # Add some useful methods to the logger
    def log_dataset_info(dataset, dataset_name="Dataset"):
        """Log dataset information"""
        logger.info(f"{dataset_name} Info:")
        logger.info(f"  - Size: {len(dataset):,} samples")
        logger.info(f"  - Columns: {dataset.column_names}")
        if "labels" in dataset.column_names:
            import numpy as np

            labels = np.array(dataset["labels"])
            unique, counts = np.unique(labels, return_counts=True)
            logger.info(f"  - Label distribution: {dict(zip(unique, counts))}")

    def log_model_info(model, model_name="Model"):
        """Log model information"""
        logger.info(f"{model_name} Info:")
        if hasattr(model, "config"):
            logger.info(f"  - Model type: {model.config.model_type}")
            logger.info(f"  - Hidden size: {model.config.hidden_size}")
            if hasattr(model.config, "num_labels"):
                logger.info(f"  - Number of labels: {model.config.num_labels}")

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"  - Total parameters: {total_params:,}")
        logger.info(f"  - Trainable parameters: {trainable_params:,}")

    def log_training_args(training_args):
        """Log training arguments"""
        logger.info("Training Configuration:")
        logger.info(f"  - Learning rate: {training_args.learning_rate}")
        logger.info(f"  - Batch size: {training_args.per_device_train_batch_size}")
        logger.info(
            f"  - Gradient accumulation: {training_args.gradient_accumulation_steps}"
        )
        logger.info(f"  - Epochs: {training_args.num_train_epochs}")
        logger.info(f"  - Weight decay: {training_args.weight_decay}")
        logger.info(f"  - LR scheduler: {training_args.lr_scheduler_type}")
        logger.info(f"  - Warmup ratio: {training_args.warmup_ratio}")

    def log_metrics(metrics, stage=""):
        """Log evaluation metrics"""
        stage_prefix = f"{stage} " if stage else ""
        logger.info(f"{stage_prefix}Metrics:")
        for metric, value in metrics.items():
            if isinstance(value, float):
                logger.info(f"  - {metric}: {value:.4f}")
            else:
                logger.info(f"  - {metric}: {value}")

    # Attach utility methods to logger
    logger.log_dataset_info = log_dataset_info
    logger.log_model_info = log_model_info
    logger.log_training_args = log_training_args
    logger.log_metrics = log_metrics

    # Log logger creation
    logger.info(f"Logger '{name}' created successfully")
    logger.info(f"Log level: {log_level}")
    if file_output:
        logger.info(f"Log file: {log_file}")

    return logger


# Convenience function for quick setup
def setup_project_logging(debug_mode: bool = False) -> logging.Logger:
    """
    Quick setup for the Reddit moderation project logging.

    Parameters
    ----------
    debug_mode : bool
        If True, sets log level to DEBUG and enables verbose logging

    Returns
    -------
    logging.Logger
        Configured project logger
    """
    log_level = "DEBUG" if debug_mode else "INFO"

    return create_logger(
        name="reddit_moderation_pipeline",
        log_level=log_level,
        log_dir="project_logs",
        include_timestamp_in_filename=True,
    )


def get_ram_usage():
    process = psutil.Process()
    return process.memory_info().rss  # bytes


def free_vars(
    vars_to_delete: List[Union[str, object]],
    namespace: Optional[dict] = None,
    try_gpu: bool = True,
    logger=None,
):
    """
    Deletes variables by name or reference, frees RAM and GPU (PyTorch) memory,
    logs actions via logger if provided.

    Args:
      vars_to_delete: list of variable names (str) or object refs
      namespace: dict to remove names from (defaults to caller's globals())
      try_gpu: clear GPU memory for torch objects
      logger: logging object or None (use print)
    Returns:
      (freed_ram_bytes, freed_gpu_bytes)
    """
    # Setup logger if not provided
    if logger is None:

        def logger(msg):
            print(msg)

    else:
        logger = logger.info

    # Automatic namespace resolution
    if namespace is None:
        # Get frame of the caller, locals then globals
        frame = inspect.currentframe().f_back
        namespace = frame.f_globals

    before_ram = get_ram_usage()
    try:
        import torch
    except ImportError:
        torch = None

    freed_gpu_bytes = 0
    torch_objs = []
    deleted = []

    for var in vars_to_delete:
        if isinstance(var, str):
            obj = namespace.get(var, None)
            if obj is not None:
                deleted.append(var)
                if torch and try_gpu:
                    torch_objs.append(obj)
                del namespace[var]
                logger(f"Deleted variable '{var}'")
            else:
                logger(f"Variable '{var}' not found in namespace")
        else:
            # Try to remove all names referencing the object
            names = [n for n, v in namespace.items() if v is var]
            for n in names:
                del namespace[n]
                deleted.append(n)
                logger(f"Deleted variable '{n}' (by reference)")
            if not names:
                logger(
                    f"Could not find a variable name for object {var!r}, may not be deleted"
                )
            if torch and try_gpu:
                torch_objs.append(var)

    if torch and try_gpu and torch_objs and torch.cuda.is_available():
        before_gpu = torch.cuda.memory_allocated()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        after_gpu = torch.cuda.memory_allocated()
        freed_gpu_bytes = after_gpu - before_gpu
        logger(f"GPU memory freed: {freed_gpu_bytes/(1024**2):.2f} MB")
    # Always run gc
    gc.collect()
    after_ram = get_ram_usage()
    freed_ram_bytes = after_ram - before_ram
    logger(f"RAM memory freed: {freed_ram_bytes/(1024**2):.2f} MB")
    clean_mem()
    # return freed_ram_bytes, freed_gpu_bytes

In [6]:
import markdown2
from bs4 import BeautifulSoup
import re
from unidecode import unidecode


def sanitize_comment(comment):
    # Convert markdown to HTML, then extract the text (HTML tags removed)
    html = markdown2.markdown(comment)
    text = BeautifulSoup(html, features="html.parser").get_text()

    text = re.sub(r"\[([^\]]+)\]\(([^)]+)\)", r"\1", comment)
    # Then re-run markdown2 and extract text again to clean up
    html = markdown2.markdown(text)
    text = BeautifulSoup(html, features="html.parser").get_text()

    url_pattern = re.compile(r"((?:http|https)://[^\s]+|www\.[^\s]+)", re.IGNORECASE)
    text = url_pattern.sub(lambda m: m.group(0), text)

    # Convert non-unicode characters to unicode (ASCII compatible)
    text = unidecode(text)

    # Normalize whitespace
    text = " ".join(text.split()).lower()

    return text

# read datasets from path


In [7]:
logger = create_logger(name="rulewise")

2025-08-17 18:09:00 | rulewise | INFO | Logger 'rulewise' created successfully
2025-08-17 18:09:00 | rulewise | INFO | Log level: INFO
2025-08-17 18:09:00 | rulewise | INFO | Log file: logs/rulewise_20250817_180900.log


In [8]:
logger.info(f"{ON_KAGGLE = }")
logger.info(f"{INPUT_PATH = }")
logger.info(f"{OUTPUT_PATH = }")
logger.info(f"{_MODEL_DIR = }")
logger.info(f"{MODEL_PATH = }")
logger.info(f"{TRY_PROBABILITY_CALIBRATION = }")

2025-08-17 18:09:00 | rulewise | INFO | ON_KAGGLE = False
2025-08-17 18:09:00 | rulewise | INFO | INPUT_PATH = '../data'
2025-08-17 18:09:00 | rulewise | INFO | OUTPUT_PATH = './rulewise-output'
2025-08-17 18:09:00 | rulewise | INFO | _MODEL_DIR = '../model'
2025-08-17 18:09:00 | rulewise | INFO | MODEL_PATH = {'classifier': 'FacebookAI/roberta-base'}
2025-08-17 18:09:00 | rulewise | INFO | TRY_PROBABILITY_CALIBRATION = False


In [9]:
train = pd.read_csv(
    os.path.join(INPUT_PATH, "jigsaw-agile-community-rules", "train.csv")
)
test = pd.read_csv(os.path.join(INPUT_PATH, "jigsaw-agile-community-rules", "test.csv"))
submission = pd.read_csv(
    os.path.join(INPUT_PATH, "jigsaw-agile-community-rules", "sample_submission.csv")
)

## clean the dataset


In [10]:
for c in [
    "body",
    "rule",
    "subreddit",
    "positive_example_1",
    "positive_example_2",
    "negative_example_1",
    "negative_example_2",
]:
    logger.info(f"Cleaning {c = }")
    train[c] = train[c].apply(sanitize_comment)
    test[c] = test[c].apply(sanitize_comment)

2025-08-17 18:09:00 | rulewise | INFO | Cleaning c = 'body'
2025-08-17 18:09:01 | rulewise | INFO | Cleaning c = 'rule'
2025-08-17 18:09:01 | rulewise | INFO | Cleaning c = 'subreddit'
2025-08-17 18:09:01 | rulewise | INFO | Cleaning c = 'positive_example_1'
2025-08-17 18:09:02 | rulewise | INFO | Cleaning c = 'positive_example_2'
2025-08-17 18:09:02 | rulewise | INFO | Cleaning c = 'negative_example_1'
2025-08-17 18:09:03 | rulewise | INFO | Cleaning c = 'negative_example_2'


## melt the dataset

there will be 3 parts

-   actual training
-   training examples
-   testing examples


In [11]:
train_main = (
    train[["row_id", "body", "rule", "rule_violation"]]
    .rename(columns={"rule_violation": "label"})
    .assign(split="train")
)
train_examples = pd.concat(
    [
        (
            train[["row_id", "rule", "positive_example_1", "positive_example_2"]]
            .melt(id_vars=["row_id", "rule"], value_name="body")
            .assign(label=1, split="train")
            .drop(columns=["variable"])
        ),
        (
            train[["row_id", "rule", "negative_example_1", "negative_example_2"]]
            .melt(id_vars=["row_id", "rule"], value_name="body")
            .assign(label=0, split="train")
            .drop(columns=["variable"])
        ),
    ]
)
test_examples = pd.concat(
    [
        (
            test[["row_id", "rule", "positive_example_1", "positive_example_2"]]
            .melt(id_vars=["row_id", "rule"], value_name="body")
            .assign(label=1, split="test")
            .drop(columns=["variable"])
        ),
        (
            test[["row_id", "rule", "negative_example_1", "negative_example_2"]]
            .melt(id_vars=["row_id", "rule"], value_name="body")
            .assign(label=0, split="test")
            .drop(columns=["variable"])
        ),
    ]
)

In [12]:
fulldata = pd.concat(
    [train_main, train_examples[train_main.columns], test_examples[train_main.columns]],
    ignore_index=True,
)

# since we are using the examples there will be duplicates
fulldata = fulldata.drop_duplicates(subset=["body", "rule", "label"])
fulldata.to_csv(os.path.join(OUTPUT_PATH, "fulldata.csv"), index=False)

In [13]:
free_vars(
    ["train", "train_main", "train_examples", "test_examples"],
    namespace=globals(),
    logger=logger,
)

2025-08-17 18:09:03 | rulewise | INFO | Deleted variable 'train'
2025-08-17 18:09:03 | rulewise | INFO | Deleted variable 'train_main'
2025-08-17 18:09:03 | rulewise | INFO | Deleted variable 'train_examples'
2025-08-17 18:09:03 | rulewise | INFO | Deleted variable 'test_examples'
2025-08-17 18:09:04 | rulewise | INFO | RAM memory freed: 16.25 MB
RAM freed: 0.00 MB (892.02 -> 892.02)
No GPU detected.


## create dataset and dataloader

when we are creating the dataloader, we should specify the rule


In [14]:
class CommentDataset(utils.data.Dataset):
    def __init__(self, dataset, tokenizer, *args, **kwargs):
        super().__init__()
        dataset = dataset
        self.row_id = dataset["row_id"].tolist()
        self.rule = dataset["rule"].tolist()
        self.body = dataset["body"].tolist()
        self.label = dataset["label"].tolist()

        rule_encodings = tokenizer(self.rule, truncation=True, padding="max_length")
        comment_encodings = tokenizer(self.body, truncation=True, padding="max_length")

        self.rule_input_ids = rule_encodings["input_ids"]
        self.rule_attention_mask = rule_encodings["attention_mask"]
        self.comment_input_ids = comment_encodings["input_ids"]
        self.comment_attention_mask = comment_encodings["attention_mask"]

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        return (
            torch.tensor(self.rule_input_ids[index]),
            torch.tensor(self.rule_attention_mask[index]),
            torch.tensor(self.comment_input_ids[index]),
            torch.tensor(self.comment_attention_mask[index]),
            torch.tensor(self.label[index]),
            torch.tensor(self.row_id[index]),
        )

In [15]:
def prepare_datasets(fulldata, test, test_size, rule):
    if isinstance(rule, str):
        rule = [rule]
    subdata = fulldata[fulldata["rule"].isin(rule)]
    subtest = test[test["rule"].isin(rule)]

    logger.info(f"{len(subdata)}/{len(fulldata)} rows remain in fulldata.")
    logger.info(f"{len(subtest)}/{len(test)} rows remain in test.")
    train, val = train_test_split(
        subdata,
        test_size=test_size,
        shuffle=True,
        random_state=42,
        stratify=subdata["label"],
    )
    subtest["label"] = 0
    columns = ["row_id", "body", "rule", "label"]
    return train[columns], val[columns], subtest[columns]

In [16]:
def prepare_dataloaders(train, val, test, tokenizer, batch_size=8):
    train_dataset = CommentDataset(train, tokenizer)
    train_dataloader = utils.data.DataLoader(
        train_dataset,
        drop_last=True,
        shuffle=True,
        # num_workers=4,
        batch_size=batch_size,
    )
    val_dataset = CommentDataset(val, tokenizer)
    val_dataloader = utils.data.DataLoader(
        val_dataset,
        drop_last=True,
        shuffle=True,
        # num_workers=4,
        batch_size=2 * batch_size,
    )
    test_dataset = CommentDataset(test, tokenizer)
    test_dataloader = utils.data.DataLoader(
        test_dataset,
        drop_last=False,
        shuffle=False,
        # num_workers=4,
        batch_size=4 * batch_size,
    )

    return train_dataloader, val_dataloader, test_dataloader

# modelling


In [17]:
logger.info(MODEL_PATH)

2025-08-17 18:09:06 | rulewise | INFO | {'classifier': 'FacebookAI/roberta-base'}


## define the model classes


### classifier head


In [18]:
class Classifier(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int = 2, dropout: float = 0.2):
        super().__init__()

        # a layer norm to ensure that means and standard deviations are standardised
        self.layer_norm = nn.LayerNorm(hidden_size * 4)
        # now take a clf head
        # this should bring the compressed thingy
        # down to num_labels
        self.clf = nn.Sequential(
            nn.Linear(hidden_size * 4, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 4, num_labels),
        )

    def forward(self, rule_outputs, comment_outputs):
        # sentence similarity thingy from sbert did not work
        # back to just the difference
        diff = torch.abs(rule_outputs - comment_outputs)
        # actually back to more features
        # - rule output
        # - comment output
        # - diff
        # - elementwise product
        combined = torch.cat(
            [rule_outputs, comment_outputs, diff, rule_outputs * comment_outputs], dim=1
        )
        # apply layer norm
        combined = self.layer_norm(combined)
        # now do the classification
        logits = self.clf(combined)
        return logits

### difference roberta


In [19]:
class DifferenceRoberta(RobertaPreTrainedModel):
    config_class = RobertaConfig

    def __init__(self, config: RobertaConfig):
        super().__init__(config)
        _basemodel = RobertaForSequenceClassification(config)
        self.roberta = _basemodel.roberta
        self.classifier = Classifier(
            hidden_size=config.hidden_size, num_labels=config.num_labels, dropout=0.4
        )
        # self.classifier = _basemodel.classifier
        self.hidden_size = config.hidden_size
        free_vars(["_basemodel"], namespace=locals(), logger=logger)
        self.init_weights()
        logger.info(config)

    @staticmethod
    def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):
        """
        last_hidden_state: (B, T, H)
        attention_mask: (B, T) with 1 for real tokens, 0 for padding
        returns: (B, H) mean-pooled over token positions where attention_mask == 1
        """
        # convert mask to float on same device as last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(
            dtype=last_hidden_state.dtype
        )  # (B, T, 1)
        # sum of hidden states where mask == 1
        masked_sum = (last_hidden_state * mask).sum(dim=1)  # (B, H)
        # number of real tokens per example
        lengths = mask.sum(dim=1).clamp(min=1.0)  # (B, 1) clamp avoid div-by-zero
        pooled = masked_sum / lengths  # broadcasting divides (B, H) / (B, 1)
        return pooled

    def forward(
        self,
        rule_input_ids=None,
        rule_attention_mask=None,
        comment_input_ids=None,
        comment_attention_mask=None,
        return_dict=True,
        **kwargs
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # get last_hidden_state for both inputs
        rule_lhs = self.roberta(
            input_ids=rule_input_ids,
            attention_mask=rule_attention_mask,
            return_dict=True,
        ).last_hidden_state  # (B, T_r, H)

        comment_lhs = self.roberta(
            input_ids=comment_input_ids,
            attention_mask=comment_attention_mask,
            return_dict=True,
        ).last_hidden_state  # (B, T_c, H)

        # replace CLS with mean pooling
        rule_outputs = self.mean_pool(rule_lhs, rule_attention_mask)  # (B, H)
        comment_outputs = self.mean_pool(comment_lhs, comment_attention_mask)  # (B, H)

        # base feature: difference (as before)
        logits = self.classifier(rule_outputs, comment_outputs)
        return logits

### tie everything together in lightning module


In [20]:
class BasedRedditMod(L.LightningModule):
    def __init__(
        self,
        diffroberta: DifferenceRoberta,
        # model_path: str,
        # num_labels: int,
        max_steps: int,
        weight: torch.Tensor = None,
        logger: logging.Logger = logger,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.diffroberta = diffroberta
        self.loss_fn = nn.CrossEntropyLoss(weight=weight)
        self.max_steps = max_steps

        self.accuracy = BinaryAccuracy()
        self.auroc = BinaryAUROC()
        self.previous_auroc = 0

        self.mylogger = logger

    def training_step(self, batch, batch_idx):
        rid, ram, cid, cam, labels, row_ids = batch
        logits = self.diffroberta(rid, ram, cid, cam)
        loss = self.loss_fn(logits, labels)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx, *args, **kwargs):
        rid, ram, cid, cam, labels, row_ids = batch
        logits = self.diffroberta(rid, ram, cid, cam)
        loss = self.loss_fn(logits, labels)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)

        # compute the metrics
        probs, preds = self._get_predictions_from_logits(logits, labels, row_ids)
        self.accuracy.update(preds=preds, target=labels)
        self.auroc.update(preds=probs, target=labels)

        self.log_dict(
            {
                "accuracy": self.accuracy.compute(),
                "auroc": self.auroc.compute(),
            },
            on_epoch=True,
            on_step=False,
            prog_bar=False,
        )
        return loss

    def on_validation_epoch_end(self):
        accuracy = self.accuracy.compute()
        auroc = self.auroc.compute()
        change = auroc - self.previous_auroc
        self.previous_auroc = auroc

        self.accuracy.reset()
        self.auroc.reset()

        lrs = [pg["lr"] for pg in self.trainer.optimizers[0].param_groups]

        # robust step lookup: prefer trainer.global_step, fall back to self.global_step or 0
        step = getattr(self.trainer, "global_step", None)
        if step is None:
            step = getattr(self, "global_step", 0)

        self.mylogger.info(
            f"#:{step}: acc = {accuracy*100.0:.2f} % | "
            f"auroc = {auroc:.2f} (D = {change:.3f}) | "
            f"lr = {lrs[0]:.2e}"
        )

    def predict_step(self, batch, batch_idx, *args, **kwargs):
        rid, ram, cid, cam, labels, row_ids = batch
        logits = self.diffroberta(rid, ram, cid, cam)
        probs, _ = self._get_predictions_from_logits(logits, labels, row_ids)
        return row_ids, probs, labels

    def _get_predictions_from_logits(self, logits, labels, row_ids):
        # get the probs and the preds
        probs_full = F.softmax(logits, dim=1)
        probs = probs_full[:, 1]
        preds = torch.argmax(probs_full, dim=1)

        assert preds.shape == probs.shape == labels.shape == row_ids.shape

        return probs, preds

    def configure_optimizers(self):
        lr = 2e-5
        wd = 0.01
        no_decay = ["bias", "LayerNorm.weight"]
        param_groups = [
            {
                "params": [
                    p
                    for n, p in self.diffroberta.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": wd,
            },
            {
                "params": [
                    p
                    for n, p in self.diffroberta.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimiser = optim.AdamW(param_groups, lr=lr)

        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer=optimiser, T_max=self.max_steps, eta_min=5e-6
        )

        try:
            self.mylogger.info(
                f"param_groups={len(optimiser.param_groups)} | max_steps={getattr(self,'max_steps',None)}"
            )
            for i, pg in enumerate(optimiser.param_groups):
                n = sum(p.numel() for p in pg.get("params", []) if p is not None)
                lr = pg.get("lr", None)
                wd = pg.get("weight_decay", None)
                self.mylogger.info(f"pg[{i}] params={n:,} lr={lr:.2e} wd={wd}")
        except Exception as e:
            self.mylogger.exception(f"_log_opt_info failed: {e}")

        return {
            "optimizer": optimiser,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

# training


### constants


In [21]:
def collate_predictions(submission_list):
    submission = pd.concat(submission_list, ignore_index=True)
    submission.to_csv(os.path.join(OUTPUT_PATH, "submission.csv"))
    return submission

In [22]:
BATCH_SIZE = 8
GRAD_ACCUMULATION_STEPS = 4
VALIDATION_PER_N_STEPS = 2 * BATCH_SIZE * GRAD_ACCUMULATION_STEPS
# EPOCHS = 5

logger.info(f"{BATCH_SIZE = }")
logger.info(f"{GRAD_ACCUMULATION_STEPS = }")
logger.info(f"{VALIDATION_PER_N_STEPS = }")

2025-08-17 18:09:06 | rulewise | INFO | BATCH_SIZE = 8
2025-08-17 18:09:06 | rulewise | INFO | GRAD_ACCUMULATION_STEPS = 4
2025-08-17 18:09:06 | rulewise | INFO | VALIDATION_PER_N_STEPS = 64


### get all rules


In [23]:
# get all the rules
all_rules = fulldata["rule"].unique().tolist()
id2rule = dict(enumerate(all_rules))
rule2id = {v: k for k, v in id2rule.items()}

logger.info(rule2id)

2025-08-17 18:09:06 | rulewise | INFO | {'no advertising: spam, referral links, unsolicited advertising, and promotional content are not allowed.': 0, 'no legal advice: do not offer or request legal advice.': 1}


## run loop to train


In [24]:
all_submissions = []
# ------------------------
# initialise the tokenizer
# ------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH["classifier"])
logger.info(f"{tokenizer.pad_token = } | {tokenizer.eos_token = }")

for i, rule in enumerate(all_rules):
    if i == 0:
        pass
    # -------------------------------
    # get the rule
    # make the output paths and stuff
    # -------------------------------
    rule_output_path = os.path.join(OUTPUT_PATH, f"rule-{rule2id[rule]}")
    os.makedirs(rule_output_path, exist_ok=True)
    logger.info(rule_output_path)

    # --------------------------
    # prepare the datasets
    # as well as the dataloaders
    # --------------------------
    train, val, subtest = prepare_datasets(fulldata.copy(), test.copy(), 0.25, rule)
    train_dataloader, val_dataloader, subtest_dataloader = prepare_dataloaders(
        train, val, subtest, tokenizer, batch_size=BATCH_SIZE
    )
    logger.info(
        f"train steps = {len(train_dataloader)} | val steps = {len(val_dataloader)}"
    )
    weight = torch.Tensor(
        train["label"].value_counts(normalize=True).sort_index().values
    )
    free_vars(["train", "val", "subtest"], namespace=globals(), logger=logger)

    # --------------------
    # initialise the model
    # --------------------
    diffroberta = DifferenceRoberta.from_pretrained(
        MODEL_PATH["classifier"], num_labels=2
    )
    model = BasedRedditMod(
        diffroberta=diffroberta,
        # model_path=MODEL_PATH["classifier"],
        # num_labels=2,
        max_steps=(4 * VALIDATION_PER_N_STEPS) // GRAD_ACCUMULATION_STEPS,
        weight=weight,
        logger=logger,
    )

    # ---------
    # callbacks
    # ---------
    # ideally i want 2 checkpoints to be saved
    # the best one
    # the latest one
    # i want the checkpoints to be saved immediately after validation check is performed
    checkpoint_callback = lcb.ModelCheckpoint(
        monitor="auroc",
        dirpath=rule_output_path,
        mode="max",
        save_top_k=1,
        save_last=True,
        save_on_train_epoch_end=False,
    )
    # have two early stopping callbacks
    # primary - stop if auroc does not improve much
    # secondary - stop if val loss does not improve much
    early_stopping_callback_auroc = lcb.EarlyStopping(
        monitor="auroc", min_delta=1e-4, patience=2, mode="max", verbose=True
    )
    early_stopping_callback_loss = lcb.EarlyStopping(
        monitor="val_loss", min_delta=1e-5, patience=4, mode="min", verbose=True
    )

    # ---------------------
    # customise the trainer
    # ---------------------
    trainer = L.Trainer(
        # limit_train_batches=2 * VALIDATION_PER_N_STEPS,  # this is only for rapid iteration
        max_steps=8 * VALIDATION_PER_N_STEPS,
        # max_epochs=1,
        accelerator="cuda",
        # devices=2,
        # train in mixed bf16 precision
        precision="bf16-mixed",
        # each training batch size is 8
        # accumulate gradients over 4 batches... so eff. batch size is 32
        accumulate_grad_batches=GRAD_ACCUMULATION_STEPS,
        # clip gradients' global norm to <=0.5 using gradient_clip_algorithm='norm' by default
        gradient_clip_val=0.5,
        # perform eval every 32 steps
        val_check_interval=VALIDATION_PER_N_STEPS,
        # model checkpointing
        default_root_dir=rule_output_path,
        # callbacks
        callbacks=[
            checkpoint_callback,
            early_stopping_callback_auroc,
            # early_stopping_callback_loss,
        ],
    )

    # -------------
    # fit the model
    # -------------
    model.train()
    trainer.fit(
        model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader
    )
    free_vars(
        ["model", "train_dataloader"],
        namespace=globals(),
        logger=logger,
    )

    # --------------------------------------------
    # load the best model and use it for inference
    # --------------------------------------------
    best_ckpt = checkpoint_callback.best_model_path
    logger.info(f"Loading model for inference and next rule from {best_ckpt}")
    model = BasedRedditMod.load_from_checkpoint(
        best_ckpt,
        strict=False,
        diffroberta=diffroberta,
        # model_path=MODEL_PATH["classifier"],
        # num_labels=2,
        max_steps=(4 * VALIDATION_PER_N_STEPS) // GRAD_ACCUMULATION_STEPS,
        logger=logger,
    )

    # ----------------------------------------------
    # evaluate predictions on validation (TEMPORARY)
    # ----------------------------------------------
    model.eval()
    # this will return a list of length = len(dataloader)
    # each element will be a tuple of length 3
    # tuple = (row_ids, probs)
    outs = trainer.predict(model=model, dataloaders=val_dataloader)
    row_ids = torch.cat([o[0] for o in outs], dim=0).detach().cpu().numpy()
    probs = torch.cat([o[1] for o in outs], dim=0).detach().cpu().numpy()
    labels = torch.cat([o[2] for o in outs], dim=0).detach().cpu().numpy()
    preds = (probs > 0.5).astype(float)
    _evaluation_for_this_rule = pd.DataFrame(
        {"row_id": row_ids, "prob": probs, "label": labels, "pred": preds}
    )
    free_vars(
        ["outs", "row_ids", "probs", "labels", "preds"],
        namespace=globals(),
        logger=logger,
    )
    logger.info(
        "\n"
        + classification_report(
            _evaluation_for_this_rule["label"], _evaluation_for_this_rule["pred"]
        )
    )
    logger.info(
        f"auc = {roc_auc_score(_evaluation_for_this_rule['label'], _evaluation_for_this_rule['prob']):.2f}"
    )
    if i == 1:
        break

    # -----------------
    # write predictions
    # -----------------
    model.eval()
    # this will return a list of length = len(dataloader)
    # each element will be a tuple of length 3
    # tuple = (row_ids, probs)
    outs = trainer.predict(model=model, dataloaders=subtest_dataloader)
    row_ids = torch.cat([o[0] for o in outs], dim=0).detach().cpu().numpy()
    probs = torch.cat([o[1] for o in outs], dim=0).detach().cpu().numpy()
    free_vars(["outs"], namespace=globals(), logger=logger)

    # -------------------------------
    # perform probability calibration
    # -------------------------------
    # EXPERIMENTAL
    # perform prob calibration with isotonic regression
    # needs to be evaluated
    # try a submission without calibration and one with calibration
    if TRY_PROBABILITY_CALIBRATION and len(val_dataloader) > 1_000:
        reg = IsotonicRegression(y_min=0, y_max=1, out_of_bounds="clip")
        reg.fit(_evaluation_for_this_rule["prob"], _evaluation_for_this_rule["label"])
        calibrated_probs = reg.transform(probs)
    else:
        calibrated_probs = probs

    _submission_for_this_rule = pd.DataFrame(
        {"row_id": row_ids, "rule_violation": calibrated_probs}
    )
    all_submissions.append(_submission_for_this_rule)

    free_vars(
        [
            "model",
            "diffroberta",
            "val_dataloader",
            "subtest_dataloader",
            "row_ids",
            "probs",
            "trainer",
            "_evaluation_for_this_rule",
            "reg",
            "calibrated_probs",
        ],
        namespace=globals(),
        logger=logger,
    )

2025-08-17 18:09:06 | rulewise | INFO | tokenizer.pad_token = '<pad>' | tokenizer.eos_token = '</s>'
2025-08-17 18:09:06 | rulewise | INFO | ./rulewise-output/rule-0
2025-08-17 18:09:06 | rulewise | INFO | 860/1874 rows remain in fulldata.
2025-08-17 18:09:06 | rulewise | INFO | 9/10 rows remain in test.
2025-08-17 18:09:06 | rulewise | INFO | train steps = 80 | val steps = 13
2025-08-17 18:09:06 | rulewise | INFO | Deleted variable 'train'
2025-08-17 18:09:06 | rulewise | INFO | Deleted variable 'val'
2025-08-17 18:09:06 | rulewise | INFO | Deleted variable 'subtest'
2025-08-17 18:09:06 | rulewise | INFO | RAM memory freed: 0.00 MB
RAM freed: 0.00 MB (988.82 -> 988.82)
No GPU detected.
2025-08-17 18:09:08 | rulewise | INFO | Deleted variable '_basemodel'
2025-08-17 18:09:08 | rulewise | INFO | RAM memory freed: 0.00 MB
RAM freed: 0.00 MB (992.19 -> 992.19)
No GPU detected.
2025-08-17 18:09:09 | rulewise | INFO | RobertaConfig {
  "_attn_implementation_autoset": true,
  "architectures"

Some weights of DifferenceRoberta were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.clf.0.bias', 'classifier.clf.0.weight', 'classifier.clf.3.bias', 'classifier.clf.3.weight', 'classifier.clf.6.bias', 'classifier.clf.6.weight', 'classifier.layer_norm.bias', 'classifier.layer_norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MisconfigurationException: `CUDAAccelerator` can not run on your system since the accelerator is not available. The following accelerator(s) is available and can be passed into `accelerator` argument of `Trainer`: ['cpu'].

## save


In [None]:
collate_predictions(all_submissions)

Unnamed: 0,row_id,rule_violation
0,2029,0.046725
1,2031,0.880386
2,2032,0.886813
3,2033,0.882833
4,2034,0.040846
5,2035,0.869715
6,2036,0.042722
7,2037,0.040846
8,2038,0.880386
