# In [19]:
import json
import math
import traceback
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# -------------------------
# Config (adjust as needed)
# -------------------------
# Use CUDA when torch reports it available, otherwise run on CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
LEARNING_RATE = 1e-6   # default AdamW learning rate for the trainer
MAX_LENGTH = 512       # tokenizer truncation length, in tokens
GRAD_CLIP_NORM = 1.0   # default max gradient norm passed to clip_grad_norm_
BATCH_LOG_EVERY = 10   # print the running average loss every N steps

# -------------------------
# Safe get_batch_logps
# -------------------------
def get_batch_logps(model, input_ids: torch.Tensor, attention_mask: torch.Tensor, response_start_indices: List[int]) -> torch.Tensor:
    """
    Returns sum of log probabilities over response tokens for each batch element.
    input_ids: LongTensor [B, L]
    attention_mask: LongTensor [B, L]
    response_start_indices: list[int] length B; index (0-based) of first response token in unshifted input_ids
    """
    model_output = model(input_ids=input_ids, attention_mask=attention_mask)
    # logits: [B, L, V]
    logits = model_output.logits

    # Shift logits and targets for causal LM: predict token t from logits at t-1
    logits = logits[:, :-1, :].float()  # [B, L-1, V] as float32 for numerical stability
    targets = input_ids[:, 1:]           # [B, L-1]
    attn_shifted = attention_mask[:, 1:] # [B, L-1]

    # log-softmax over vocab
    log_probs = F.log_softmax(logits, dim=-1)  # [B, L-1, V]

    # gather token log probs
    token_log_probs = torch.gather(log_probs, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # [B, L-1]

    # Build response mask in the full (unshifted) space, then shift it to align with token_log_probs
    B, L = input_ids.shape
    full_indices = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)  # [B, L]
    start_tensor = torch.tensor(response_start_indices, device=input_ids.device).unsqueeze(1)  # [B, 1]

    response_mask_full = (full_indices >= start_tensor).float()  # 1 for response tokens, 0 otherwise; shape [B, L]

    # Shift mask to align with token_log_probs/targets (since they are input_ids[:,1:] etc.)
    response_mask_shifted = response_mask_full[:, 1:]  # [B, L-1]

    # Combine with attention mask (only keep non-padding tokens)
    final_mask = (response_mask_shifted * attn_shifted).to(token_log_probs.dtype)  # float

    # Zero-out non-response positions (keeping negative log_probs for true tokens)
    masked_token_log_probs = token_log_probs * final_mask  # [B, L-1]

    # Sum across sequence length to produce a scalar log-prob per example
    batch_logps = masked_token_log_probs.sum(dim=-1)  # [B]

    return batch_logps  # dtype: float32 on device

# -------------------------
# Example Dataset & collate_fn
# -------------------------
class PairwisePreferenceDataset(Dataset):
    """Pairwise preference records loaded from a JSON Lines file.

    Each non-blank line of the file is parsed with ``json.loads`` and stored as-is.
    ``collate_fn`` expects every record to be a dict with ``'prompt'``, ``'chosen'``
    and ``'rejected'`` string fields.
    """

    def __init__(self, file_path):
        """Eagerly read and parse ``file_path`` (decoded as UTF-8) into memory.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        self.data = []
        # Explicit UTF-8 so non-ASCII text does not depend on the platform locale.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():  # skip empty / whitespace-only lines
                    self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch: List[Dict[str, str]]) -> Dict[str, Any]:
    """Turn a list of preference records into parallel lists of raw strings.

    Tokenization happens later in the trainer, so everything stays as text here.
    Each chosen/rejected entry is the record's prompt, a newline, then that response.

    Args:
        batch: records carrying ``'prompt'``, ``'chosen'`` and ``'rejected'`` keys.

    Returns:
        ``{"prompt": [...], "chosen": [...], "rejected": [...]}``, index-aligned with ``batch``.
    """
    prompts = [record['prompt'] for record in batch]

    def _with_prompt(field: str) -> List[str]:
        # Prefix each record's response for `field` with its prompt and a newline.
        return [prompt + "\n" + record[field] for prompt, record in zip(prompts, batch)]

    # Dict values are evaluated left to right: chosen first, then rejected.
    return {"prompt": prompts, "chosen": _with_prompt('chosen'), "rejected": _with_prompt('rejected')}

# -------------------------
# DPO Trainer
# -------------------------
class CustomDPOTrainer:
    """Minimal Direct Preference Optimization (DPO) trainer.

    Per batch, the loss is ``-log sigmoid(beta * (chosen_log_ratio - rejected_log_ratio))``
    averaged over the batch. Each log ratio is the policy's minus the reference's summed
    response log-probability from ``get_batch_logps``. Only ``policy_model`` is updated;
    ``ref_model`` is set to eval mode and frozen.
    """

    def __init__(
        self,
        policy_model: AutoModelForCausalLM,
        ref_model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        lr: float = LEARNING_RATE,
        beta: float = 1.0,
        device: torch.device = DEVICE,
        max_length: int = MAX_LENGTH,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        """Move both models to ``device``, freeze the reference and build the optimizer.

        Args:
            policy_model: model being trained.
            ref_model: reference model; frozen here.
            tokenizer: tokenizer for both models; given EOS as pad token if it has none.
            lr: AdamW learning rate, used only when ``optimizer`` is None.
            beta: DPO temperature scaling the log-ratio margin.
            device: device for both models and every tokenized batch.
            max_length: tokenizer truncation length.
            optimizer: optional pre-built optimizer (expected to cover the policy parameters).
        """
        self.policy_model = policy_model.to(device)
        self.ref_model = ref_model.to(device)
        self.tokenizer = tokenizer
        # Batched padding needs a pad token. Fall back to EOS only when none is defined
        # instead of clobbering an existing one; pads are masked out via attention_mask.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = device
        self.max_length = max_length
        self.beta = beta

        # freeze ref model
        self.ref_model.eval()
        for p in self.ref_model.parameters():
            p.requires_grad = False

        # optimizer
        self.optimizer = optimizer if optimizer is not None else optim.AdamW(self.policy_model.parameters(), lr=lr)

    @staticmethod
    def _response_start_indices(prompt_mask: torch.Tensor, full_mask: torch.Tensor) -> List[int]:
        """Return, per row, the absolute index of the first response token in a padded batch.

        The prompt length is the number of non-pad tokens in the separately tokenized
        prompt. It is offset by the count of leading pad tokens in the full sequence, so
        the index is right for both right padding (offset 0) and left padding.

        NOTE(review): assumes the prompt tokenizes to the same ids alone as it does as a
        prefix of the full text; a merge across the boundary would shift the split.
        """
        prompt_lens = prompt_mask.sum(dim=1)
        # Leading pads are the positions where the running count of real tokens is still 0.
        leading_pads = (full_mask.cumsum(dim=1) == 0).sum(dim=1)
        return (leading_pads + prompt_lens).tolist()

    def _dpo_loss(self, batch: Dict[str, Any]) -> torch.Tensor:
        """Tokenize one collated batch and return its mean DPO loss as a float32 scalar.

        Raises:
            ValueError: if the loss is NaN or Inf (diagnostics are printed first).
        """
        # Chosen, rejected and prompt share tokenizer settings so prompt lengths line up.
        enc_kwargs = dict(return_tensors="pt", padding=True, truncation=True, max_length=self.max_length)
        chosen_enc = self.tokenizer(batch["chosen"], **enc_kwargs).to(self.device)
        rejected_enc = self.tokenizer(batch["rejected"], **enc_kwargs).to(self.device)
        prompt_enc = self.tokenizer(batch["prompt"], **enc_kwargs).to(self.device)

        # Chosen and rejected are padded independently, so each gets its own start indices.
        chosen_starts = self._response_start_indices(prompt_enc.attention_mask, chosen_enc.attention_mask)
        rejected_starts = self._response_start_indices(prompt_enc.attention_mask, rejected_enc.attention_mask)

        # Policy log probs (with grad)
        policy_chosen_logps = get_batch_logps(self.policy_model, chosen_enc.input_ids, chosen_enc.attention_mask, chosen_starts)
        policy_rejected_logps = get_batch_logps(self.policy_model, rejected_enc.input_ids, rejected_enc.attention_mask, rejected_starts)

        # Reference log probs (no grad)
        with torch.no_grad():
            ref_chosen_logps = get_batch_logps(self.ref_model, chosen_enc.input_ids, chosen_enc.attention_mask, chosen_starts)
            ref_rejected_logps = get_batch_logps(self.ref_model, rejected_enc.input_ids, rejected_enc.attention_mask, rejected_starts)

        chosen_log_ratio = policy_chosen_logps - ref_chosen_logps        # [B]
        rejected_log_ratio = policy_rejected_logps - ref_rejected_logps  # [B]

        logits = self.beta * (chosen_log_ratio - rejected_log_ratio)  # [B]
        loss = -F.logsigmoid(logits.to(torch.float32)).mean()

        if not torch.isfinite(loss):
            print("NaN/Inf loss detected. Diagnostics:")
            print("policy_chosen_logps:", policy_chosen_logps)
            print("policy_rejected_logps:", policy_rejected_logps)
            print("ref_chosen_logps:", ref_chosen_logps)
            print("ref_rejected_logps:", ref_rejected_logps)
            raise ValueError("Loss is NaN or Inf")
        return loss

    def train_epoch(self, dataloader: DataLoader, clip_norm: float = GRAD_CLIP_NORM, scheduler=None):
        """Run one pass over ``dataloader``, taking one optimizer step per batch.

        A batch whose step raises (including a NaN/Inf loss) is reported and skipped
        rather than aborting the epoch.

        Args:
            dataloader: yields dicts of ``"prompt"``/``"chosen"``/``"rejected"`` string lists.
            clip_norm: max gradient norm for ``clip_grad_norm_``; ``None`` disables clipping.
            scheduler: optional LR scheduler, stepped after each successful optimizer step.

        Returns:
            Mean loss over successful batches. Returns ``0.0`` for an empty dataloader, and
            ``nan`` when every batch failed, so a broken epoch is not mistaken for zero loss.
        """
        self.policy_model.train()
        total_loss = 0.0
        total_batches = 0
        skipped_batches = 0

        for step, batch in enumerate(dataloader):
            try:
                loss = self._dpo_loss(batch)

                self.optimizer.zero_grad()
                loss.backward()
                if clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(), clip_norm)
                self.optimizer.step()
                if scheduler is not None:
                    scheduler.step()

                total_loss += loss.item()
                total_batches += 1
            except Exception as e:
                # Deliberate best-effort: report and skip this batch instead of crashing.
                skipped_batches += 1
                print("Exception during training step:", e)
                traceback.print_exc()
                # Drop gradients left by a partial backward pass so they do not hold memory.
                self.optimizer.zero_grad()
                continue

            if (step + 1) % BATCH_LOG_EVERY == 0:
                avg = total_loss / total_batches
                print(f"Step {step+1} | avg loss: {avg:.6f}")

        if skipped_batches:
            print(f"Skipped {skipped_batches} batch(es) because of errors.")
        if total_batches == 0:
            return math.nan if skipped_batches else 0.0
        return total_loss / total_batches

# -------------------------
# Example usage
# -------------------------
if __name__ == "__main__":
    # Tokenizer, policy and reference are all loaded from this one checkpoint; replace as needed.
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    policy = AutoModelForCausalLM.from_pretrained(model_name)
    ref = AutoModelForCausalLM.from_pretrained(model_name)  # separately loaded copy, frozen by the trainer

    dataset = PairwisePreferenceDataset('financial_rewards.jsonl')
    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_fn,
    )

    trainer = CustomDPOTrainer(policy, ref, tokenizer)
    avg_loss = trainer.train_epoch(loader)
    print("Avg epoch loss:", avg_loss)


# Cell output:
# Step 10 | avg loss: 0.070633
# Step 20 | avg loss: 0.035388
# Step 30 | avg loss: 0.023592
# Step 40 | avg loss: 0.017696
# Step 50 | avg loss: 0.014157
# Avg epoch loss: 0.013355556184959698
