In [None]:
# put this near the top of your file
def set_all_seeds(seed: int, deterministic: bool = False):
    """Seed Python, NumPy and torch (and Accelerate, if installed) for a reproducible run.

    Args:
        seed: value forwarded to ``transformers.trainer_utils.set_seed`` and, when the
            package is importable, ``accelerate.utils.set_seed``.
        deterministic: forwarded to ``torch.use_deterministic_algorithms``. Defaults to
            False (the previously hard-coded value); pass True for strict determinism,
            which makes nondeterministic torch ops raise when they run.

    cuDNN is always switched to deterministic, non-benchmark mode.
    """
    # Local imports: this cell must not depend on `torch` having been imported by a later cell.
    import warnings

    import torch
    from transformers.trainer_utils import set_seed as hf_set_seed

    # Seeds python `random`, numpy and torch (CPU + all CUDA devices).
    # It does NOT touch cuDNN flags unless called with deterministic=True.
    hf_set_seed(seed)

    try:
        from accelerate.utils import set_seed as accel_set_seed
    except ImportError:
        pass  # accelerate is optional; hf_set_seed already seeded the core RNGs
    else:
        accel_set_seed(seed)

    try:
        torch.use_deterministic_algorithms(deterministic)
    except (AttributeError, RuntimeError) as exc:  # very old torch / unsupported config
        warnings.warn(f"torch.use_deterministic_algorithms({deterministic}) failed: {exc}")

    # Extra determinism for cuDNN convolutions/RNNs (harmless on CPU-only builds).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [None]:
import json, os, time

class _JsonlWriter:
    def __init__(self, path):
        self.path = path
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        self.f = open(path, "a", encoding="utf-8")
    def write(self, record: dict):
        self.f.write(json.dumps(record, ensure_ascii=False) + "\n"); self.f.flush()
    def close(self):
        try: self.f.close()
        except: pass


In [None]:
import copy, torch
import torch.nn as nn
import torch.nn.functional as F
from trl import DPOTrainer
from tqdm import tqdm
import math
import torch
from contextlib import nullcontext
from tqdm import tqdm

def _fmt_logs(d):
    parts = []
    for k, v in d.items():
        if isinstance(v, float):
            parts.append(f"{k}={v:.6f}")
        else:
            parts.append(f"{k}={v}")
    return " | ".join(parts)


# --------- utils: concat + continuation-only mask ----------
def _concat_ids(a, b): return torch.cat([a, b], dim=1)

def _cont_mask(prompt_mask, resp_mask):
    # 0 over prompt, 1 over response; then shift by 1 for next-token prediction
    zeros = torch.zeros_like(prompt_mask)
    return torch.cat([zeros, resp_mask], dim=1)[:, 1:]  # [B, T-1]

# --------- core: response-only log-prob SUM per sample ----------
def _response_logprob_sum_from_ids(model, full_ids, full_attn, cont_mask):
    """
    Sum of log p(y_t | x, y_<t>) over RESPONSE tokens only (masked by cont_mask).
    cont_mask: [B, T-1] with 1 over response positions and 0 elsewhere.
    """
    m = getattr(model, "module", model)
    emb = m.get_input_embeddings()
    e_dev = emb.weight.device

    # Start on the embedding device
    full_ids  = full_ids.to(e_dev, non_blocking=True)
    full_attn = full_attn.to(e_dev, dtype=torch.long, non_blocking=True)

    # Build embeddings explicitly on the correct device
    inputs_embeds = emb(full_ids)  # [B, T, H] on e_dev

    # Forward: keep everything starting on e_dev so the dispatcher can shard correctly
    outputs = m(inputs_embeds=inputs_embeds, attention_mask=full_attn, use_cache=False)
    logits  = outputs.logits                          # [B, T, V] (may live on another GPU)
    logp    = F.log_softmax(logits[:, :-1, :], dim=-1)

    # Move gather inputs to the logits' device before use
    dev     = logits.device
    targets = full_ids[:, 1:].to(dev, non_blocking=True)       # [B, T-1]
    msk     = cont_mask.to(dev, dtype=logp.dtype, non_blocking=True)

    tok_lp  = torch.gather(logp, 2, targets.unsqueeze(-1)).squeeze(-1)  # [B, T-1]
    return (tok_lp * msk).sum(dim=1)                                    # [B]



class AuxDPOTrainer(DPOTrainer):
    """
    Batchwise AuxDPO on top of TRL's DPOTrainer:
      - Uses response-only log-probs given the prompt (`_pair_scores_ids`).
      - DPO margin augmented by (delta_chosen - delta_rejected), where
        delta = delta_cap * tanh(model.aux_delta_raw) is a learned per-response offset,
        laid out as [chosen_0, rejected_0, chosen_1, rejected_1, ...] (size 2N).
      - Null penalty: ||A_batch^T delta_batch||^2, where rows of A are ref-model grads of log p(y|x).
      - Anti-collapse: a negative mean-square term on the batch deltas (tanh-bounded).

    Every training batch must contain "idx" (integer tensor [B]) with values in
    [0, len(train_dataset)); delta entries 2*idx / 2*idx+1 belong to chosen / rejected.
    """
    def __init__(self,
                 *args,
                 lambda_null: float = 1.0,     # weight on ||A^T δ||^2
                 lambda_amp: float = 0.01,     # weight on -||δ||^2 (anti-collapse)
                 delta_cap: float = 1.0,       # bound |δ_i| ≤ delta_cap via tanh
                 aux_lr: float = 5e-3,         # LR for δ
                 seed: int = 42,
                 **kwargs):
        super().__init__(*args, **kwargs)
        set_all_seeds(int(seed))
        self.experiment_seed = int(seed)

        # Hyper-parameters (create_optimizer reads aux_lr)
        self.lambda_null = float(lambda_null)
        self.lambda_amp  = float(lambda_amp)
        self.delta_cap   = float(delta_cap)
        self.aux_lr      = float(aux_lr)

        # 1) Frozen reference (only if the parent did not already provide one)
        if self.ref_model is None:
            self.ref_model = copy.deepcopy(self.model)
            for p in self.ref_model.parameters():
                p.requires_grad = False
            self.ref_model.eval()
        self._ref_on_device = False

        # 2) Global delta (size 2N), registered ON THE MODEL so DDP/ZeRO/Accelerate see it
        assert self.train_dataset is not None, "train_dataset required to size δ"
        self.N = len(self.train_dataset)
        dev = self.accelerator.device
        self.model.register_parameter(
            "aux_delta_raw",
            nn.Parameter(torch.zeros(2 * self.N, dtype=torch.float32, device=dev))
        )
        self.delta_raw = self.model.aux_delta_raw
        # small noise so the gradient at step 0 isn't exactly zero
        torch.nn.init.normal_(self.delta_raw, mean=0.0, std=1e-3)

        # 3) Ref params to differentiate in _AT_delta: ref params whose names are trainable
        #    in the policy (aux_delta_raw has no ref counterpart); fallback to 'lora_' params.
        ref_named = dict(self.ref_model.named_parameters())
        self._grad_params = [
            ref_named[n] for n, p in self.model.named_parameters()
            if p.requires_grad and n in ref_named
        ]
        if not self._grad_params:
            self._grad_params = [p for n, p in self.ref_model.named_parameters() if "lora_" in n]

    def _ensure_ref_on_device(self):
        """Once per trainer: freeze the ref model and (unless it is device-mapped) move it
        to the accelerator device in the dtype of the policy's first parameter."""
        if self._ref_on_device:
            return
        dtype = next(self.model.parameters()).dtype
        # If ref is already sharded, leave placement alone
        if getattr(self.ref_model, "hf_device_map", None):
            self.ref_model.eval()
            for p in self.ref_model.parameters():
                p.requires_grad = False
        else:
            dev = self.accelerator.device
            self.ref_model.to(dev, dtype=dtype)
            self.ref_model.eval()
            for p in self.ref_model.parameters():
                p.requires_grad = False
        self._ref_on_device = True

    @staticmethod
    def _grad_norm(model) -> float:
        """Global L2 norm over all existing `.grad` tensors of `model` (0.0 if there are none)."""
        sq = []
        for p in model.parameters():
            if p.grad is not None:
                g = p.grad.detach().float()
                sq.append((g.norm(2) ** 2))
        return float(torch.sqrt(torch.stack(sq).sum()).item()) if sq else 0.0

    @staticmethod
    def _optimizer_has_param(optimizer, param) -> bool:
        """True if `param` (by identity, not equality) is in any of the optimizer's groups."""
        for g in optimizer.param_groups:
            for p in g.get("params", []):
                if p is param:
                    return True
        return False

    def create_optimizer(self):
        """
        Build the parent's optimizer, then give δ its own param group.

        `aux_delta_raw` is registered on the model, so the parent already puts it into
        one of its groups (main LR, possibly weight decay). It is removed from there and
        re-added with `lr=self.aux_lr` and no weight decay; otherwise `aux_lr` would be
        silently ignored. Returns the optimizer, like `Trainer.create_optimizer`.
        """
        super().create_optimizer()
        assert hasattr(self, "delta_raw")
        optimizer = self.optimizer
        if self._optimizer_has_param(optimizer, self.delta_raw):
            # Fresh optimizer: no state yet, so regrouping is safe.
            for group in optimizer.param_groups:
                group["params"] = [p for p in group["params"] if p is not self.delta_raw]
            optimizer.param_groups[:] = [g for g in optimizer.param_groups if g["params"]]
        optimizer.add_param_group({
            "params": [self.delta_raw],
            "lr": self.aux_lr,
            "weight_decay": 0.0,
        })
        return optimizer

    @staticmethod
    def _input_device_for(model):
        """Device of the (unwrapped) model's input embeddings; first parameter's device as fallback."""
        m = getattr(model, "module", model)
        try:
            return m.get_input_embeddings().weight.device
        except Exception:
            return next(m.parameters()).device  # last resort

    # --- response-only scores for chosen/rejected under a model ---
    def _pair_scores_ids(self, model, batch, need_grad: bool):
        """
        Response-only log-prob sums of chosen and rejected responses under `model`.

        Reads prompt/chosen/rejected `*_input_ids` and `*_attention_mask` from `batch`,
        scores [prompt | response] sequences, and returns two [B] tensors (chosen, rejected).
        Autograd is enabled only when `need_grad` is True.
        """
        dev = self._input_device_for(model)
        p_ids = batch["prompt_input_ids"].to(dev, non_blocking=True)
        p_msk = batch["prompt_attention_mask"].to(dev, non_blocking=True)
        c_ids = batch["chosen_input_ids"].to(dev, non_blocking=True)
        c_msk = batch["chosen_attention_mask"].to(dev, non_blocking=True)
        r_ids = batch["rejected_input_ids"].to(dev, non_blocking=True)
        r_msk = batch["rejected_attention_mask"].to(dev, non_blocking=True)

        ch_full, ch_mask = _concat_ids(p_ids, c_ids), _concat_ids(p_msk, c_msk)
        rj_full, rj_mask = _concat_ids(p_ids, r_ids), _concat_ids(p_msk, r_msk)
        ch_cont = _cont_mask(p_msk, c_msk)   # [B, T_ch-1], 1 over response tokens
        rj_cont = _cont_mask(p_msk, r_msk)   # [B, T_rj-1]

        ctx = torch.enable_grad() if need_grad else torch.no_grad()
        with ctx:
            ch = _response_logprob_sum_from_ids(model, ch_full, ch_mask, ch_cont)
            rj = _response_logprob_sum_from_ids(model, rj_full, rj_mask, rj_cont)
        return ch, rj

    # --- compute ||A_batch^T δ_batch||^2 from per-sample ref-model grads ---
    def _AT_delta(self, batch, delta_chosen, delta_rejected):
        """
        Return ||A_batch^T delta_batch||^2 with gradients flowing ONLY to delta.

        Let g_i = d log p_ref(y_i | x_i) / d theta (detached) for the 2B responses of the
        batch, ordered [ch_0, rj_0, ch_1, rj_1, ...]. Then
            ||sum_i delta_i g_i||^2 = delta^T G delta,   G_ij = <g_i, g_j>.
        G is assembled in float32 from the detached gradients. The graph back to delta is
        therefore a single 2B x 2B quadratic form: no per-parameter autograd nodes and no
        bf16 accumulation.
        """
        pen_dev = self.accelerator.device
        if not self._grad_params:
            return torch.tensor(0.0, device=pen_dev)

        # Temporarily enable grads on the tracked ref params so autograd.grad works;
        # always restore them, even if the forward/backward raises (e.g. OOM).
        toggled = [p for p in self._grad_params if not p.requires_grad]
        for p in toggled:
            p.requires_grad_(True)
        try:
            ch_ref, rj_ref = self._pair_scores_ids(self.ref_model, batch, need_grad=True)  # [B], [B]
            per_sample = []  # one list of detached grads (or None) per response
            for j in range(ch_ref.shape[0]):
                for score in (ch_ref[j], rj_ref[j]):
                    grads = torch.autograd.grad(
                        score, self._grad_params,
                        retain_graph=True, create_graph=False, allow_unused=True,
                    )
                    per_sample.append([None if g is None else g.detach() for g in grads])
            del ch_ref, rj_ref  # release the ref forward graph before building G
        finally:
            for p in toggled:
                p.requires_grad_(False)

        n = len(per_sample)
        gram = torch.zeros((n, n), device=pen_dev, dtype=torch.float32)
        for a in range(n):
            for b in range(a, n):
                dot = torch.zeros((), device=pen_dev, dtype=torch.float32)
                for ga, gb in zip(per_sample[a], per_sample[b]):
                    if ga is None or gb is None:  # param unused by this sample => zero grad
                        continue
                    dot = dot + torch.dot(ga.reshape(-1).float(), gb.reshape(-1).float()).to(pen_dev)
                gram[a, b] = dot
                gram[b, a] = dot

        # Interleave to match per_sample order: [ch_0, rj_0, ch_1, rj_1, ...]
        d = torch.stack([delta_chosen, delta_rejected], dim=1).reshape(-1).to(pen_dev, dtype=torch.float32)
        return d @ gram @ d

    # --- main loss ---
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        AuxDPO loss = DPO(margin + δ_ch - δ_rj) + lambda_null * ||A^T δ||^2 + lambda_amp * (-mean δ^2).

        With `return_outputs=True`, returns `(loss, logs)`; `logs` holds detached component
        values and the δ grad norm left over from the PREVIOUS backward (0.0 if cleared).
        """
        beta = getattr(self, "beta", 0.1)
        dev  = self.accelerator.device

        # current model margin WITH grad
        ch, rj = self._pair_scores_ids(model, inputs, need_grad=True)
        ch, rj = ch.to(dev), rj.to(dev)

        # reference margin NO grad for the main loss
        with torch.no_grad():
            ch_ref, rj_ref = self._pair_scores_ids(self.ref_model, inputs, need_grad=False)
            ch_ref, rj_ref = ch_ref.to(dev), rj_ref.to(dev)

        # δ selection. Use a local device view: never rebind self.delta_raw, which must stay
        # the registered Parameter the optimizer updates.
        idx = inputs["idx"].to(dev).long()
        delta_vec = self.delta_cap * torch.tanh(self.delta_raw.to(dev))
        delta_chosen   = delta_vec.index_select(0, 2 * idx)
        delta_rejected = delta_vec.index_select(0, 2 * idx + 1)

        # DPO + δ
        margin   = (ch - ch_ref) - (rj - rj_ref) + (delta_chosen - delta_rejected)
        dpo_loss = -F.logsigmoid(beta * margin).mean()

        # Null penalty depends ONLY on δ; skip the expensive per-sample grads when unused
        if self.lambda_null != 0.0:
            null_pen = self._AT_delta(inputs, delta_chosen, delta_rejected)
        else:
            null_pen = torch.zeros((), device=dev, dtype=torch.float32)

        # Anti-collapse: negative mean square of the batch deltas
        amp_pen = -(delta_chosen.pow(2).mean() + delta_rejected.pow(2).mean()) * 0.5

        loss = dpo_loss + self.lambda_null * null_pen + self.lambda_amp * amp_pen

        if return_outputs:
            all_delta = torch.cat([delta_chosen, delta_rejected]).abs()
            logs = {
                "aux/dpo": dpo_loss.detach(),
                "aux/null": (self.lambda_null * null_pen).detach(),
                "aux/amp": (self.lambda_amp * amp_pen).detach(),
                "aux/delta_mean_abs": all_delta.mean().detach(),
                "aux/delta_max_abs": all_delta.max().detach(),
            }
            # δ grad debug (from the previous backward, if not yet cleared)
            if self.delta_raw.grad is not None:
                logs["aux/delta_grad_norm"] = float(self.delta_raw.grad.detach().norm().cpu())
            else:
                logs["aux/delta_grad_norm"] = 0.0
            return loss, logs
        return loss

    def training_step(self, model, inputs, num_items_in_batch=None):
        """HF Trainer hook (used by Trainer.train(); manual_train does not call it).

        NOTE(review): backward runs on the unscaled loss while the returned value is
        divided by gradient_accumulation_steps — confirm this matches the intended scaling.
        """
        model.train()
        loss = self.compute_loss(model, inputs)
        self.accelerator.backward(loss)

        step_idx = self.state.global_step + 1
        do_log = (step_idx % self.args.logging_steps == 0)
        do_step = (step_idx % self.args.gradient_accumulation_steps == 0) or (self.args.gradient_accumulation_steps == 1)

        if do_step and do_log:
            self.log({"grad_norm": self._grad_norm(model)})

        return loss.detach() / self.args.gradient_accumulation_steps

    @torch.no_grad()
    def _maybe_eval_and_log(self, global_step):
        """Run `evaluate()` on `eval_steps` boundaries (it logs its own metrics)."""
        if self.args.eval_strategy == "steps" and (global_step % self.args.eval_steps == 0):
            self.model.eval()
            self.evaluate()  # manual_evaluate already calls self.log; don't log twice
            self.model.train()

    def manual_train(self):
        """
        Minimal, NCCL-safe manual loop; call this instead of Trainer.train().

          - uses self.compute_loss (policy needs grad; ref is no-grad except inside _AT_delta),
          - supports gradient accumulation, grad clipping and the HF LR scheduler,
          - logs train metrics every `logging_steps` optimizer steps (Trainer history,
            `train_metrics.jsonl`, console); evaluates every `eval_steps` when
            `eval_strategy == "steps"`, or once at the end when it is "epoch".

        `self.state.global_step` / `self.state.epoch` are kept in sync, so entries recorded
        by `self.log` carry the real step.

        Returns:
            {"global_step": int, "train_loss": mean accumulation-scaled loss per optimizer step}
        """
        accelerator = self.accelerator
        model = self.model
        args = self.args
        gas = args.gradient_accumulation_steps

        # dataloader / horizon / optimizer / scheduler
        train_loader = self.get_train_dataloader()
        num_update_steps_per_epoch = math.ceil(len(train_loader) / gas)
        total_train_steps = (
            args.max_steps if args.max_steps > 0
            else int(args.num_train_epochs * num_update_steps_per_epoch)
        )
        if self.optimizer is None:
            self.create_optimizer()
        if self.lr_scheduler is None:
            # same horizon as the loop / progress bar below
            self.create_scheduler(num_training_steps=total_train_steps)
        optimizer = self.optimizer
        scheduler = self.lr_scheduler

        # put models in proper modes
        model.train()
        if self.ref_model is not None:
            self.ref_model.eval()   # we never optimize ref

        # writers (main process only)
        train_writer = eval_writer = None
        if accelerator.is_main_process:
            train_writer = _JsonlWriter(os.path.join(args.output_dir, "train_metrics.jsonl"))
            eval_writer = _JsonlWriter(os.path.join(args.output_dir, "eval_metrics.jsonl"))

        global_step = 0
        running_loss = 0.0  # sum of loss / gas over micro-batches (HF `tr_loss` convention)
        self.state.global_step = 0

        pbar = tqdm(
            total=total_train_steps,
            disable=not accelerator.is_local_main_process,
            desc="Manual training"
        )

        try:
            for epoch in range(int(args.num_train_epochs)):
                # IMPORTANT: do not re-create sampler here; HuggingFace handles DDP sampler reseed internally.
                for step, inputs in enumerate(train_loader):
                    inputs = self._prepare_inputs(inputs)

                    with accelerator.accumulate(model):
                        out = self.compute_loss(model, inputs, return_outputs=True)
                        loss, logs = out if isinstance(out, tuple) else (out, {})

                        # NOTE(review): depending on how the Accelerator was configured,
                        # accelerator.backward may also divide by its own accumulation
                        # steps — confirm the effective scaling when gas > 1.
                        loss_to_backprop = loss / gas
                        accelerator.backward(loss_to_backprop)
                        # count every micro-batch, including the one that hits max_steps
                        running_loss += loss_to_backprop.detach().float().item()

                        # δ grad norm AFTER backward (accumulated so far)
                        delta_grad = getattr(self.delta_raw, "grad", None)
                        delta_gn = float(delta_grad.detach().norm().item()) if delta_grad is not None else 0.0

                        # step only when gradients are synced across processes
                        if accelerator.sync_gradients:
                            if args.max_grad_norm is not None and args.max_grad_norm > 0:
                                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                            optimizer.step()
                            scheduler.step()
                            optimizer.zero_grad(set_to_none=True)
                            global_step += 1
                            epoch_float = epoch + (step + 1) / max(1, len(train_loader))
                            # keep Trainer state in sync: Trainer.log stamps entries with it
                            self.state.global_step = global_step
                            self.state.epoch = epoch_float
                            pbar.update(1)

                            # ---- TRAIN LOGGING ----
                            if (args.logging_strategy == "steps") and (global_step % args.logging_steps == 0):
                                log_dict = {
                                    "step": global_step,
                                    "epoch": epoch_float,
                                    "lr": scheduler.get_last_lr()[0] if scheduler is not None else None,
                                    "loss": float(loss.detach().cpu()),
                                    "time": time.time(),
                                    **{k: float(v) for k, v in logs.items()},
                                    "aux/delta_grad_norm": delta_gn,
                                }
                                no_time = {k: v for k, v in log_dict.items() if k != "time"}
                                # to HF Trainer history (and TB/W&B if enabled)
                                self.log(no_time)
                                if train_writer:
                                    train_writer.write(log_dict)
                                if accelerator.is_main_process:
                                    pbar.write(f"[train] {_fmt_logs(no_time)}")
                                    pbar.refresh()
                                    # best-effort trainer state snapshot (one writer only)
                                    try:
                                        self.state.save_to_json(os.path.join(args.output_dir, "trainer_state.json"))
                                    except Exception:
                                        pass

                            # ---- EVALUATION ----
                            if args.eval_strategy == "steps" and (global_step % args.eval_steps == 0):
                                eval_rows_path = os.path.join(args.output_dir, f"eval_step{global_step}.jsonl")
                                # manual_evaluate logs its metrics itself; don't log them twice
                                metrics = self.manual_evaluate(save_jsonl=eval_rows_path)
                                model.train()  # make sure training mode is back on
                                if metrics and eval_writer:
                                    eval_writer.write({"step": global_step, "time": time.time(), **metrics})

                            if args.max_steps > 0 and global_step >= args.max_steps:
                                break

                if args.max_steps > 0 and global_step >= args.max_steps:
                    break

            pbar.close()

            # final eval if strategy="epoch"
            if args.eval_strategy == "epoch":
                metrics = self.manual_evaluate(save_jsonl=os.path.join(args.output_dir, "eval_final.jsonl"))
                if metrics and eval_writer:
                    eval_writer.write({"step": global_step, "time": time.time(), **metrics})
        finally:
            pbar.close()  # idempotent
            if train_writer:
                train_writer.close()
            if eval_writer:
                eval_writer.close()

        # mimic Trainer return value minimalism
        return {"global_step": global_step, "train_loss": running_loss / max(1, global_step)}

    @torch.no_grad()
    def manual_evaluate(self, eval_dataset=None, max_batches=None, save_jsonl=None):
        """
        Pure inference eval:
          - computes response-only log-probs for model and ref,
          - margin = (ch - ch_ref) - (rj - rj_ref),
          - acc = sigmoid(margin) > 0.5,
          - aggregates across processes with accelerator.gather_for_metrics,
          - optionally dumps per-example rows to `save_jsonl` (main process only),
          - logs the metrics via self.log and returns them ({} if there is no data).
        Never builds a graph and never touches ref_model grads. The model's previous
        train/eval mode is restored on exit.
        """
        self._ensure_ref_on_device()
        model        = self.model
        ref          = self.ref_model
        accelerator  = self.accelerator
        dev          = accelerator.device

        ds = eval_dataset if eval_dataset is not None else self.eval_dataset
        if ds is None:
            return {}

        dl = self.get_eval_dataloader(ds)

        was_training = model.training
        model.eval()
        ref.eval()

        all_acc     = []
        all_margin  = []
        all_ch      = []
        all_rj      = []
        all_ch_ref  = []
        all_rj_ref  = []

        # optional writer
        writer = None
        if save_jsonl and accelerator.is_main_process:
            os.makedirs(os.path.dirname(save_jsonl) or ".", exist_ok=True)
            writer = open(save_jsonl, "w", encoding="utf-8")

        pbar = tqdm(
            disable=not accelerator.is_local_main_process,
            total=len(dl),
            desc="Eval"
        )

        try:
            for step, batch in enumerate(dl):
                if max_batches is not None and step >= max_batches:
                    break

                # move to device the same way Trainer does
                batch = self._prepare_inputs(batch)

                # response-only scores (no grad), forced onto one device
                ch_m, rj_m = self._pair_scores_ids(model, batch, need_grad=False)   # [B], [B]
                ch_r, rj_r = self._pair_scores_ids(ref,   batch, need_grad=False)   # [B], [B]
                ch_m, rj_m = ch_m.to(dev), rj_m.to(dev)
                ch_r, rj_r = ch_r.to(dev), rj_r.to(dev)

                # margin & accuracy
                margin = (ch_m - ch_r) - (rj_m - rj_r)                      # [B]
                acc    = (torch.sigmoid(margin) > 0.50).to(torch.float32)   # [B]

                # gather across processes (collective: every rank must reach this)
                g_acc    = accelerator.gather_for_metrics(acc)
                g_margin = accelerator.gather_for_metrics(margin)
                g_ch     = accelerator.gather_for_metrics(ch_m)
                g_rj     = accelerator.gather_for_metrics(rj_m)
                g_chref  = accelerator.gather_for_metrics(ch_r)
                g_rjref  = accelerator.gather_for_metrics(rj_r)

                all_acc.append(g_acc.cpu())
                all_margin.append(g_margin.cpu())
                all_ch.append(g_ch.cpu())
                all_rj.append(g_rj.cpu())
                all_ch_ref.append(g_chref.cpu())
                all_rj_ref.append(g_rjref.cpu())

                # optional per-example dumps (writer only exists on the main process)
                if writer is not None:
                    for i in range(g_acc.numel()):
                        row = {
                            "acc":         float(g_acc[i].item()),
                            "margin":      float(g_margin[i].item()),
                            "chosen":      float(g_ch[i].item()),
                            "reject":      float(g_rj[i].item()),
                            "chosen_ref":  float(g_chref[i].item()),
                            "reject_ref":  float(g_rjref[i].item()),
                        }
                        writer.write(json.dumps(row, ensure_ascii=False) + "\n")

                pbar.update(1)
        finally:
            pbar.close()
            if writer is not None:
                writer.close()
            model.train(was_training)

        if not all_acc:
            return {}

        acc     = torch.cat(all_acc)
        margin  = torch.cat(all_margin)
        ch      = torch.cat(all_ch)
        rj      = torch.cat(all_rj)
        ch_ref  = torch.cat(all_ch_ref)
        rj_ref  = torch.cat(all_rj_ref)

        metrics = {
            "eval/accuracy":         float(acc.mean().item()),
            "eval/margin_mean":      float(margin.mean().item()),
            "eval/margin_std":       float(margin.std(unbiased=False).item()),
            "eval/chosen_mean":      float(ch.mean().item()),
            "eval/reject_mean":      float(rj.mean().item()),
            "eval/chosen_ref_mean":  float(ch_ref.mean().item()),
            "eval/reject_ref_mean":  float(rj_ref.mean().item()),
        }

        # mirror Trainer behavior: log on every rank (callbacks gate on rank 0), print on main
        self.log(metrics)
        if accelerator.is_main_process:
            msg = " | ".join(f"{k}={v:.6f}" for k, v in metrics.items())
            print(f"[eval] {msg}")

        return metrics

    # Override HF's evaluate to call our safe path
    def evaluate(self, eval_dataset=None, **kwargs):
        """Route Trainer.evaluate to manual_evaluate; accepts extra `save_jsonl` / `max_batches` kwargs."""
        save_jsonl = kwargs.pop("save_jsonl", None)
        max_batches = kwargs.pop("max_batches", None)
        return self.manual_evaluate(eval_dataset=eval_dataset, max_batches=max_batches, save_jsonl=save_jsonl)




In [None]:
from trl.trainer.dpo_trainer import DataCollatorForPreference
import torch

class AuxDataCollatorForPreference(DataCollatorForPreference):
    """Preference collator that additionally forwards each example's integer "idx".

    The parent's batch is returned as-is. When the first example carries "idx", every
    example's "idx" is stacked into a LongTensor [B] under batch["idx"].
    """
    def torch_call(self, examples):
        collated = super().torch_call(examples)
        if "idx" not in examples[0]:
            return collated
        row_ids = [example["idx"] for example in examples]
        collated["idx"] = torch.tensor(row_ids, dtype=torch.long)
        return collated


from transformers import AutoTokenizer, AutoModelForCausalLM
# Tokenizer for the base checkpoint, loaded from the local HF cache only.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    local_files_only=True,
    use_fast=True,
)
# If the tokenizer defines no pad token, reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

# Preference collator that also forwards the per-example "idx" column.
collator = AuxDataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)

In [None]:
from datasets import Dataset, load_dataset
import numpy as np

def attach_idx_column(ds: Dataset) -> Dataset:
    """Return `ds` with an extra "idx" column holding each row's position 0..N-1.

    The ids are assigned once here, so they stay attached to their rows through any
    later shuffle.
    """
    def _row_position(_example, position):
        return {"idx": position}

    return ds.map(_row_position, with_indices=True, batched=False)

# --- params ---
N = 10000    # NOTE: currently unused — the whole file is split below, not downsampled to N
SEED = 2025  # seed for the train/eval split

# Alternative source (HH-RLHF), kept for reference:
#   load_dataset("json", data_files="hhrlhf_train.jsonl")["train"]

# Load the whole preference file, then split it 80/20 into train/eval.
raw_ds = load_dataset("json", data_files="mmlu_pro_prefs/test_pref_single.jsonl")["train"]
split = raw_ds.train_test_split(test_size=0.2, seed=SEED)

# Attach contiguous row ids (0..len-1) to each split.
train_ds = attach_idx_column(split["train"])
eval_ds  = attach_idx_column(split["test"])

# AuxDPOTrainer sizes δ as 2 * len(train_ds) and indexes it with 2*idx / 2*idx+1,
# so the train ids must be exactly 0..len(train_ds)-1.
assert list(train_ds["idx"]) == list(range(len(train_ds))), "train idx must be contiguous 0..N-1"


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel
import torch, os

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    device_map="auto",     # optional: auto place on GPU(s)
    dtype="auto",          # or torch.float16 / bfloat16
    local_files_only=True
)

# Optionally freeze the first `num_frozen_layers` decoder layers (0 = train every layer).
num_frozen_layers = 0

layers = model.model.layers
num_layers = len(layers)
if not 0 <= num_frozen_layers <= num_layers:
    raise ValueError(f"num_frozen_layers must be in [0, {num_layers}], got {num_frozen_layers}")

print(f"Freezing first {num_frozen_layers} layers out of {num_layers}")

for layer in layers[:num_frozen_layers]:
    layer.requires_grad_(False)




In [None]:
# def get_lora_target_modules(model, target_layer_ids, module_suffixes=["v_proj", "o_proj", "k_proj", "q_proj"]):
#         """
#         Extracts target module names for LoRA injection from selected GPT2 transformer layers.
#         """
#         targets = set()
#         for name, _ in model.named_modules():
#             for layer_id in target_layer_ids:
#                 if f"layers.{layer_id}." in name and any(name.endswith(suffix) for suffix in module_suffixes):
#                     targets.add(name)
#         return list(targets)

In [None]:
# from peft import LoraConfig, get_peft_model

# layers_to_include = [15]
# target_modules = get_lora_target_modules(model, layers_to_include)
        
# lora_config = LoraConfig(
#                 r=4,
#                 lora_alpha=8,
#                 lora_dropout=0.1,
#                 target_modules=target_modules,
#                 bias="none",
#                 )       

# print("Applying LoRA adapters to the main model...")
# peft_model = get_peft_model(model, lora_config)

In [None]:
from trl import DPOConfig, DPOTrainer

aux_args = DPOConfig(
    output_dir="./llama1_hhrlhf/",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    eval_strategy="steps",
    eval_steps=200,
    logging_strategy="steps",
    logging_steps=100,
    remove_unused_columns=False,   # keep non-model columns such as "idx" (read by the collator)
    report_to=[],
    fp16=False,
    bf16=True,
    # Same pad id the collator uses; previously hard-coded as 128001.
    padding_value=tokenizer.pad_token_id,
    seed=42,
)

aux_trainer = AuxDPOTrainer(
    model=model,
    ref_model=None,          # a frozen copy is created if none is provided
    args=aux_args,
    train_dataset=train_ds,  # must carry a contiguous "idx" column
    eval_dataset=eval_ds,
    processing_class=tokenizer,
    data_collator=collator,
    # AuxDPO knobs
    lambda_null=0.0001,      # weight on ||A^T δ||^2
    lambda_amp=1,            # weight on the anti-collapse term -(mean δ_ch^2 + mean δ_rj^2)/2
    delta_cap=1.0,           # |δ| bound via tanh
    aux_lr=5e-3,             # LR for the δ parameter group
)



In [None]:
# trainer = DPOTrainer(   # ← the class we built earlier
#     model=model,
#     ref_model=None,                 # will freeze a copy automatically
#     args=aux_args,
#     train_dataset=train_ds,  # helpful train (with idx)
#     eval_dataset=eval_ds,
#     processing_class=tokenizer,
# )

In [None]:
# Run the custom training loop (Trainer.train() is not used).
# The returned dict {"global_step", "train_loss"} is displayed as the cell output.
aux_trainer.manual_train()

In [None]:
# Inspect the pad id the collator was built with.
tokenizer.pad_token_id

In [None]:
# Print every record appended by self.log (train and eval), in logging order.
for record in aux_trainer.state.log_history:
    print(record)