# MallowsPO — Unified notebook

**Purpose**: This notebook implements MallowsPO, following https://openreview.net/forum?id=d8cnezVcaW. It aims to be practical while remaining readable enough for experimentation.

**What you'll find**:

- Configuration (dataclass) and environment checks
- Utilities: tokenization helpers, log-prob computation, entropy-based dispersion estimator (as in the paper)
- Losses: DPO, MallowsPO-θ and MallowsPO-ϕ implementations
- A Trainer class with a standard PyTorch training loop; the DeepSpeed (ZeRO) config fields are provided as hooks but are not wired into the loop
- Example dataset loader for pairwise preference data (JSONL) and a runnable training snippet

---


## 0 — Install dependencies (run if packages missing)

Uncomment and run the cell below in your environment to install required packages. In many production setups, you
already have `torch`, `transformers`, `accelerate`, and `datasets` available. `trl` (for DPO helpers) and `deepspeed`
are optional.


In [None]:

# !pip install --upgrade pip
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118  # or choose CPU variant
# !pip install transformers accelerate datasets evaluate
# !pip install trl  # optional helpers for RLHF/DPO
# !pip install deepspeed  # optional; only if you want ZeRO stage 3 via DeepSpeed
# !pip install safetensors  # recommended for models
print('Installation cell — uncomment to install packages (if needed).')
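
Before uncommenting the install lines, it can help to check which packages are already importable. The sketch below uses only the standard library; the helper name `missing_packages` is our own and not part of any library listed above.

```python
import importlib.util

def missing_packages(names):
    """Return the top-level package names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

required = ['torch', 'transformers', 'accelerate', 'datasets']
optional = ['trl', 'deepspeed', 'safetensors']
print('missing required:', missing_packages(required))
print('missing optional:', missing_packages(optional))
```

Only packages reported as missing need the corresponding `pip install` line.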

## 1 — Imports & basic environment checks

In [None]:

import os, math, json, time, random, logging
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Tuple, Callable, Any
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedTokenizerBase
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of torch.optim.AdamW
from datasets import load_dataset

# optional
try:
    import deepspeed
    _HAS_DEEPSPEED = True
except Exception:
    _HAS_DEEPSPEED = False

print('torch:', torch.__version__)
print('deepspeed available:', _HAS_DEEPSPEED)


## 2 — Configuration dataclass (merged)

In [None]:

@dataclass
class MallowsPOConfig:
    # Model / tokenizer
    model_name: str = 'gpt2'  # change to your SFT / base model (e.g., 'EleutherAI/pythia-2.8b' or a local path)
    ref_model_name: Optional[str] = None  # reference model (if None, uses the same model as `model_name`)
    tokenizer_name: Optional[str] = None  # if None, will use model_name

    # Training
    lr: float = 1e-6
    weight_decay: float = 0.0
    batch_size: int = 8
    micro_batch_size: Optional[int] = None  # reserved for gradient accumulation (not used by the trainer below)
    epochs: int = 1
    max_length: int = 512

    # DPO / MallowsPO specifics
    beta: float = 0.1  # temperature-like KL weight in DPO
    mallows_variant: str = 'theta'  # 'theta' or 'phi' or 'dpo' for vanilla DPO
    phi_scale: float = 1.0  # ϕ* scaling used in the paper for the dispersion estimator (ϕ*)
    use_ref_logits: bool = True  # whether to use a reference model for KL term

    # Optimization / accelerator
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_deepspeed: bool = False  # placeholder flag: the trainer below does not initialize a DeepSpeed engine
    deepspeed_config: Optional[dict] = None  # DeepSpeed config dict, kept for a later integration

    # Logging and saving
    out_dir: str = './mallowspo_checkpoints'
    save_every_steps: int = 5000
    seed: int = 42

    # Misc
    clip_grad_norm: float = 1.0

    def __post_init__(self):
        if self.tokenizer_name is None:
            self.tokenizer_name = self.model_name
        if self.ref_model_name is None:
            self.ref_model_name = self.model_name


## 3 — Utilities: tokenization, log-probs, entropy & dispersion estimator

In [None]:

def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def build_tokenizer(model_name: str, config: MallowsPOConfig, use_fast=True) -> PreTrainedTokenizerBase:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
    # Ensure tokenizer has pad token for batching
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer

def compute_logprobs(model, tokenizer, prompts: List[str], responses: List[str], device='cuda', max_length=512):
    # Sequence-level log-probs of each response given its prompt under a causal LM.
    # Gradients are tracked only while the model is in training mode (model.train()); a model in eval mode
    # (e.g. a frozen reference) is scored without building a graph. The model's mode is left unchanged.
    all_logprobs = []
    with torch.set_grad_enabled(model.training):
        for prompt, response in zip(prompts, responses):
            # concatenate prompt + response (no separator is inserted; include one in the prompt if needed)
            text = prompt + response
            inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length).to(device)
            # Locate the response by tokenizing the prompt alone. This is approximate when the tokenizer
            # merges characters across the prompt/response boundary.
            prompt_input = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=max_length)
            # start >= 1 because the first token has no preceding context to be predicted from
            start = max(1, prompt_input.input_ids.shape[1])
            outputs = model(**inputs, output_hidden_states=False, return_dict=True)
            logits = outputs.logits  # (1, seq_len, vocab)
            response_ids = inputs.input_ids[0, start:]
            if response_ids.nelement() == 0:
                all_logprobs.append(torch.tensor(0.0, device=device))
                continue
            logits_resp = logits[0, start-1:-1, :]  # logits at position t-1 predict token t
            # log-softmax in float32 for stability, then pick each response token's log-prob and sum
            log_probs = F.log_softmax(logits_resp.float(), dim=-1)
            token_logps = log_probs.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1).sum()
            all_logprobs.append(token_logps)
    return torch.stack(all_logprobs)  # shape (N,)

def sequence_entropy_from_logits(logits: torch.Tensor) -> float:
    # logits: (seq_len, vocab) or (1, seq_len, vocab)
    if logits.dim() == 3:
        logits = logits[0]
    logp = F.log_softmax(logits.float(), dim=-1)  # float32 avoids precision loss with fp16 models
    ent = -(logp.exp() * logp).sum(dim=-1)  # entropy per token position
    # return mean entropy across positions as a Python float
    return ent.mean().item()

def estimate_dispersion_via_entropy(model, tokenizer, prompts: List[str], responses_w: List[str], responses_l: List[str],
                                    config: MallowsPOConfig, n_samples_per_pair:int=1, max_length:int=512) -> torch.Tensor:
    """Estimate negative-log dispersion proxy as in the paper (Eq (13)-(14)).
    We'll estimate predictive entropy of next-token predictions averaged across pairs.
    Returns a positive scalar per prompt (neg-log dispersion estimate)."""
    model.eval()
    device = config.device
    entropies = []
    with torch.no_grad():
        for prompt, yw, yl in zip(prompts, responses_w, responses_l):
            # token-level entropy approximation: compute next-token predictive entropy across the sequence using model logits
            # We follow the paper's simplified estimator: average next-token entropies for tokens from both sequences.
            text_w = prompt + yw
            text_l = prompt + yl
            inputs_w = tokenizer(text_w, return_tensors='pt', truncation=True, max_length=max_length).to(device)
            inputs_l = tokenizer(text_l, return_tensors='pt', truncation=True, max_length=max_length).to(device)
            out_w = model(**inputs_w, return_dict=True)
            out_l = model(**inputs_l, return_dict=True)
            ent_w = sequence_entropy_from_logits(out_w.logits)
            ent_l = sequence_entropy_from_logits(out_l.logits)
            entropies.append(0.5 * (ent_w + ent_l))
    # Keep the result on the same device as the log-probs it will be multiplied with.
    entropies = torch.tensor(entropies, dtype=torch.float32, device=device)
    # Normalize by log(vocab_size), the maximum possible per-token entropy, so the ratio lies in (0, 1];
    # phi_scale plays the role of the paper's ϕ* scaling.
    vocab_size = model.config.vocab_size if hasattr(model, 'config') else tokenizer.vocab_size
    denom = math.log(vocab_size)
    # per-prompt neg-log dispersion estimator: -phi_scale * log(H / log k); clamp H to avoid log(0)
    H = entropies.clamp(min=1e-12)
    neg_log_disp = -config.phi_scale * torch.log(H / denom)
    return neg_log_disp  # tensor shape (N,)
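
To make the estimator's arithmetic concrete, here is a plain-Python sketch on toy next-token distributions. It mirrors the formula in the comments above, `-phi_scale * log(H / log k)`, under the assumption of a single token position and a 4-word vocabulary; the helper names are illustrative only.

```python
import math

def token_entropy(probs):
    """Shannon entropy (in nats) of one next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def neg_log_dispersion(mean_entropy, vocab_size, phi_scale=1.0, eps=1e-12):
    """-phi_scale * log(H / log k), with H clamped away from zero."""
    ratio = max(mean_entropy, eps) / math.log(vocab_size)
    return -phi_scale * math.log(ratio)

uniform = [0.25, 0.25, 0.25, 0.25]   # maximally uncertain model
peaked = [0.97, 0.01, 0.01, 0.01]    # confident model

nld_uniform = neg_log_dispersion(token_entropy(uniform), vocab_size=4)
nld_peaked = neg_log_dispersion(token_entropy(peaked), vocab_size=4)
print('uniform:', nld_uniform)
print('peaked :', nld_peaked)
```

A uniform distribution reaches the maximum entropy `log k`, so its estimate is approximately zero, while a confident (low-entropy) prediction yields a larger positive value.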


## 4 — Loss functions: DPO, MallowsPO-θ, MallowsPO-ϕ

In [None]:

def dpo_pair_loss(logp_w: torch.Tensor, logp_l: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    # logp_* are sequence-level log-ratios log(pi / pi_ref) when a reference model is used,
    # otherwise raw policy log-probs. DPO objective: -log sigma(beta * (logp_w - logp_l))
    diff = beta * (logp_w - logp_l)
    loss = -F.logsigmoid(diff)
    return loss.mean()

def mallows_theta_pair_loss(logp_w: torch.Tensor, logp_l: torch.Tensor, neg_log_disp: torch.Tensor, beta: float = 0.1):
    # neg_log_disp: per-example positive scale c(x), e.g. -2 log phi(x) in the paper's notation,
    # already computed for each example
    diff = beta * (logp_w - logp_l)
    scaled = neg_log_disp * diff  # positive scale: a larger winner margin must lower the loss
    loss = -F.logsigmoid(scaled)
    return loss.mean()

def g_phi(s: torch.Tensor, t: float):
    # Link function g_{ϕ,t}(s) for Mallows-phi, as we read the paper's Eq (11) variant; t is phi(x) in (0, 1).
    #   s > 0:  g(s) = (s + 1) / (1 - t**(s + 1)) - s / (1 - t**s)
    #   s < 0:  g(s) = 1 - g(-s)
    #   s = 0:  g(0) = 1/2
    # For numeric stability we clamp t into (1e-6, 1-1e-6) and keep denominators away from zero.
    eps = 1e-6
    t = max(eps, min(1 - eps, float(t)))
    a = torch.abs(s)
    denom1 = (1 - t ** (a + 1)).clamp(min=1e-8)
    denom2 = (1 - t ** a).clamp(min=1e-8)
    g_abs = (a + 1) / denom1 - a / denom2  # value of g at |s|
    out = torch.where(s > 0, g_abs, 1 - g_abs)
    return torch.where(s == 0, torch.full_like(s, 0.5), out)

def mallows_phi_pair_loss(logp_w: torch.Tensor, logp_l: torch.Tensor, phi_vals: torch.Tensor, beta: float = 0.1):
    # phi_vals: per-example phi(x) in (0,1]; we convert to floats and compute link g_phi for each example
    diff = beta * (logp_w - logp_l)
    losses = []
    for d, t in zip(diff, phi_vals):
        # g_phi expects scalar t and returns probability in (0,1); we minimize -log g_phi(d)
        g = g_phi(d.unsqueeze(0), float(t.item() if isinstance(t, torch.Tensor) else t))
        # Avoid log(0)
        g = g.clamp(min=1e-12, max=1-1e-12)
        losses.append(-torch.log(g))
    return torch.stack(losses).mean()
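
The three objectives can be compared on a single preference pair with plain Python. This is a sketch under stated assumptions: `margin` stands for the winner-minus-loser log-ratio, `c` is a hypothetical dispersion scale, and `g_phi_scalar` encodes our reading of the Mallows-phi link (power form, with `g(0) = 1/2` and `g(-s) = 1 - g(s)`), not a verbatim copy of the paper's code.

```python
import math

def neg_log_sigmoid(z):
    """Numerically stable -log(sigmoid(z))."""
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))

def g_phi_scalar(s, phi):
    """Pairwise link for Mallows-phi with dispersion phi in (0, 1)."""
    if s == 0:
        return 0.5
    a = abs(s)
    g = (a + 1) / (1 - phi ** (a + 1)) - a / (1 - phi ** a)
    return g if s > 0 else 1 - g

beta, margin = 0.1, 5.0   # margin = log-ratio(winner) - log-ratio(loser)
c = 2.0                   # hypothetical positive dispersion scale c(x)

dpo = neg_log_sigmoid(beta * margin)
theta = neg_log_sigmoid(c * beta * margin)
phi_loss = -math.log(g_phi_scalar(beta * margin, phi=0.5))
print('DPO          :', dpo)
print('MallowsPO-θ :', theta)
print('MallowsPO-ϕ :', phi_loss)
```

With a positive margin, a scale `c > 1` lowers the θ loss relative to DPO, and `c = 1` recovers DPO exactly.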


## 5 — Dataset & dataloader (pairwise preferences JSONL)

In [None]:

class PreferencePairsDataset(Dataset):
    """Expect lines (JSONL) with keys: { 'prompt': str, 'winner': str, 'loser': str }."""
    def __init__(self, path_or_list, tokenizer: PreTrainedTokenizerBase, max_length:int=512):
        # path_or_list: either path to jsonl or an in-memory list of dicts
        if isinstance(path_or_list, str):
            assert os.path.exists(path_or_list), f"File not found: {path_or_list}"
            with open(path_or_list, 'r', encoding='utf-8') as f:
                # skip blank lines (e.g. a trailing newline at end of file)
                self.data = [json.loads(line) for line in f if line.strip()]
        else:
            self.data = list(path_or_list)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = item['prompt']
        return {
            'prompt': prompt,
            'winner': item['winner'],
            'loser': item['loser']
        }

def collate_pairs(batch, tokenizer: PreTrainedTokenizerBase, max_length: int = 512):
    prompts = [b['prompt'] for b in batch]
    winners = [b['winner'] for b in batch]
    losers = [b['loser'] for b in batch]
    return prompts, winners, losers
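
For reference, a JSONL file in the expected format can be produced and read back with the standard library alone. The sketch below writes to a temporary directory; the file name `toy_pairs.jsonl` and the helper names are illustrative.

```python
import json
import os
import tempfile

pairs = [
    {'prompt': 'Translate to French: Hello', 'winner': 'Bonjour', 'loser': 'Salut test'},
    {'prompt': '2+2 = ?', 'winner': '4', 'loser': '3'},
]

def write_jsonl(path, records):
    """Write one JSON object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + '\n')

def read_jsonl(path):
    """Read one JSON object per non-empty line."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

path = os.path.join(tempfile.mkdtemp(), 'toy_pairs.jsonl')
write_jsonl(path, pairs)
loaded = read_jsonl(path)
print(len(loaded), 'pairs loaded from', path)
```

The resulting path can be passed to `PreferencePairsDataset` in place of an in-memory list.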


## 6 — Trainer (plain PyTorch loop; Accelerate / DeepSpeed not wired in)

In [None]:

class MallowsPOTrainer:
    def __init__(self, config: MallowsPOConfig):
        self.cfg = config
        set_seed(config.seed)
        self.device = torch.device(config.device)
        self.tokenizer = build_tokenizer(config.tokenizer_name, config)
        # load model and reference model (if requested)
        # The policy is loaded in full precision: updating float16 weights directly with AdamW is numerically fragile.
        self.model = AutoModelForCausalLM.from_pretrained(config.model_name)
        if self.model.get_input_embeddings() is None:
            raise RuntimeError('Model missing embeddings?')
        # resize token embeddings if tokenizer added tokens
        self.model.resize_token_embeddings(len(self.tokenizer))

        if config.use_ref_logits:
            # Frozen reference policy. It is loaded even when ref_model_name == model_name, because the
            # reference must stay fixed at the initial weights while the policy is updated.
            self.ref_model = AutoModelForCausalLM.from_pretrained(config.ref_model_name, torch_dtype=torch.float16 if 'cuda' in config.device else None)
            self.ref_model.resize_token_embeddings(len(self.tokenizer))
            self.ref_model.requires_grad_(False)
        else:
            self.ref_model = None

        self.model.to(self.device)
        if self.ref_model is not None:
            self.ref_model.to(self.device)
            self.ref_model.eval()

        # optimizer & scheduler
        no_decay = ['bias', 'LayerNorm.weight']
        params = [
            {'params': [p for n,p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': self.cfg.weight_decay},
            {'params': [p for n,p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        self.optimizer = AdamW(params, lr=self.cfg.lr)
        self.scheduler = None
        self.global_step = 0

    def _compute_pair_logprobs(self, prompts: List[str], winners: List[str], losers: List[str]):
        # compute sequence logprobs under current model and reference model (if any)
        # returns four tensors: logp_w, logp_l, ref_logp_w (or None), ref_logp_l (or None)
        logp_w = compute_logprobs(self.model, self.tokenizer, prompts, winners, device=self.device, max_length=self.cfg.max_length)
        logp_l = compute_logprobs(self.model, self.tokenizer, prompts, losers, device=self.device, max_length=self.cfg.max_length)
        if self.ref_model is not None:
            ref_w = compute_logprobs(self.ref_model, self.tokenizer, prompts, winners, device=self.device, max_length=self.cfg.max_length)
            ref_l = compute_logprobs(self.ref_model, self.tokenizer, prompts, losers, device=self.device, max_length=self.cfg.max_length)
        else:
            ref_w = None
            ref_l = None
        return logp_w, logp_l, ref_w, ref_l

    def train(self, train_dataset: PreferencePairsDataset):
        dataloader = DataLoader(train_dataset, batch_size=self.cfg.batch_size, shuffle=True, collate_fn=lambda b: collate_pairs(b, self.tokenizer, self.cfg.max_length))
        total_steps = len(dataloader) * self.cfg.epochs
        # setup scheduler
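        # cap warmup at 10% of training so short runs (e.g. the toy demo) do not spend every step at near-zero lr
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=min(50, total_steps // 10), num_training_steps=total_steps)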
        self.model.train()
        for epoch in range(self.cfg.epochs):
            for prompts, winners, losers in dataloader:
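                # helper functions may switch the policy to eval mode; restore training mode each step
                self.model.train()
                logp_w, logp_l, ref_w, ref_l = self._compute_pair_logprobs(prompts, winners, losers)
                # DPO-style objectives act on log(pi / pi_ref): subtract the reference log-probs when available.
                # Without a reference model the raw policy log-probs are used, which drops the implicit KL anchor.
                if ref_w is not None:
                    logp_w = logp_w - ref_w
                    logp_l = logp_l - ref_l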
                if self.cfg.mallows_variant.lower() == 'dpo':
                    loss = dpo_pair_loss(logp_w, logp_l, beta=self.cfg.beta)
                elif self.cfg.mallows_variant.lower() == 'theta':
                    # estimate dispersion per prompt
                    neg_log_disp = estimate_dispersion_via_entropy(self.model, self.tokenizer, prompts, winners, losers, self.cfg)
                    loss = mallows_theta_pair_loss(logp_w, logp_l, neg_log_disp, beta=self.cfg.beta)
                elif self.cfg.mallows_variant.lower() == 'phi':
                    phi_vals = estimate_dispersion_via_entropy(self.model, self.tokenizer, prompts, winners, losers, self.cfg)
                    # convert neg-log-disp back to phi approx: phi = exp(-neg_log_disp) clipped to (eps,1)
                    phi_vals = torch.exp(-phi_vals).clamp(min=1e-6, max=0.999999).to(self.device)
                    loss = mallows_phi_pair_loss(logp_w, logp_l, phi_vals, beta=self.cfg.beta)
                else:
                    raise ValueError('Unknown mallows variant: ' + str(self.cfg.mallows_variant))

                # backprop
                loss.backward()
                # gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.clip_grad_norm)
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

                if self.global_step % 100 == 0:
                    print(f"step {self.global_step} epoch {epoch} loss {loss.item():.6f}")

                # optional save
                if self.global_step % self.cfg.save_every_steps == 0:
                    os.makedirs(self.cfg.out_dir, exist_ok=True)
                    save_path = os.path.join(self.cfg.out_dir, f'checkpoint-step-{self.global_step}')
                    self.model.save_pretrained(save_path)
                    self.tokenizer.save_pretrained(save_path)

        # final save
        os.makedirs(self.cfg.out_dir, exist_ok=True)
        final_path = os.path.join(self.cfg.out_dir, 'final')
        self.model.save_pretrained(final_path)
        self.tokenizer.save_pretrained(final_path)
        print('Training finished — model saved to', final_path)


## 7 — Example usage

Below is a minimal example showing how to prepare the dataset and run training. Adjust model names and paths for your environment.

In [None]:

# Example: prepare config and run
cfg = MallowsPOConfig(
    model_name='gpt2',  # replace with a larger model path if available
    ref_model_name=None,
    tokenizer_name=None,
    lr=1e-6,
    batch_size=2,
    epochs=1,
    beta=0.1,
    mallows_variant='theta',  # choose 'theta'|'phi'|'dpo'
    device='cuda' if torch.cuda.is_available() else 'cpu',
    out_dir='./mallowspo_demo_ckpt',
    save_every_steps=1000000  # turn off frequent saving in demo
)

# simple toy dataset (in-memory)
toy_pairs = [
    {'prompt': 'Translate to French: Hello', 'winner': 'Bonjour', 'loser': 'Salut test'},
    {'prompt': '2+2 = ?', 'winner': '4', 'loser': '3'}
]

trainer = MallowsPOTrainer(cfg)
dataset = PreferencePairsDataset(toy_pairs, trainer.tokenizer, max_length=128)
trainer.train(dataset)


## 8 — Notes, caveats and recommended next steps

- This notebook intentionally favors clarity and portability over micro-optimizations. For large-scale fine-tuning
  (billions of parameters), prefer a production code path with DeepSpeed/ZeRO stage 3 or Hugging Face Accelerate.
- The dispersion estimator used here follows the paper's heuristic but is approximate. The authors compute a
  normalized average of predictive entropies across tokens from the output sequences; this notebook averages
  next-token entropies over every position of prompt + response and normalizes by log(vocab size).
- For exact reproduction of the authors' experiments (Pythia 2.8B / Llama 3 8B), adapt the trainer to use
  `trl`'s DPO helpers and your cluster's DeepSpeed/Accelerate configuration. The original repo's scripts are succinct
  and were used as guidance; this notebook unifies them and explains the choices.
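
As a starting point for the DeepSpeed path mentioned above, a config dict for the `deepspeed_config` field could be assembled as follows. This is a minimal sketch: `make_deepspeed_config` is a hypothetical helper, the accumulation arithmetic assumes a single GPU, and the trainer in this notebook does not consume the result.

```python
def make_deepspeed_config(batch_size, micro_batch_size, zero_stage=3, clip_grad_norm=1.0):
    """Build a minimal DeepSpeed config dict (hypothetical helper; tune for your cluster)."""
    if batch_size % micro_batch_size != 0:
        raise ValueError('batch_size must be a multiple of micro_batch_size')
    return {
        'train_micro_batch_size_per_gpu': micro_batch_size,
        # number of micro-batches accumulated before each optimizer step
        'gradient_accumulation_steps': batch_size // micro_batch_size,
        'gradient_clipping': clip_grad_norm,
        'zero_optimization': {'stage': zero_stage},
        'bf16': {'enabled': True},
    }

ds_cfg = make_deepspeed_config(batch_size=8, micro_batch_size=2)
print(ds_cfg)
```

With more than one GPU, the effective batch size is the micro-batch size times the accumulation steps times the number of devices, so the values should be adjusted accordingly.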

