In [1]:
import torch
from torch.nn import functional as F

# torch.compile wrappers of the chunk helpers below, created once per helper and
# reused across ChunkedDPO.forward calls.
_COMPILED_HELPERS = {}


def _maybe_compile(fn, compiled):
    """Return a cached ``torch.compile(fn)`` when ``compiled`` is truthy, else ``fn`` itself."""
    if not compiled:
        return fn
    wrapper = _COMPILED_HELPERS.get(fn)
    if wrapper is None:
        wrapper = torch.compile(fn)
        _COMPILED_HELPERS[fn] = wrapper
    return wrapper


def _logprob_accumulator(inputs):
    """0-dim zero on ``inputs.device`` for summing log-probs.

    Accumulates in at least float32 (low-precision inputs are summed in fp32) and
    keeps float64 when ``inputs`` is float64.
    """
    return torch.zeros(
        (), dtype=torch.promote_types(inputs.dtype, torch.float32), device=inputs.device
    )


def _chunk_logprob_sum(input_chunk, weight, target_chunk, ignore_index=-100):
    """Sum over rows of ``log_softmax(input_chunk @ weight.T)[row, target_chunk[row]]``.

    Args:
        input_chunk: 2-D activations, one row per token.
        weight: 2-D projection whose rows are the output classes.
        target_chunk: 1-D integer class ids; rows equal to ``ignore_index`` contribute 0.
        ignore_index: target value that marks a row to skip.

    Returns:
        0-dim tensor with the summed log-probabilities.

    NOTE: intentionally NOT wrapped in ``torch.no_grad``. ``torch.func.grad_and_value``
    honours a no_grad that is active inside the transformed function, so decorating
    this helper makes every gradient computed from it silently zero.
    """
    logits = input_chunk @ weight.T
    loss_mask = target_chunk != ignore_index
    # Ignored labels are swapped for a valid index so ``gather`` stays in bounds;
    # their contribution is zeroed by ``loss_mask`` below.
    label_chunk = torch.where(loss_mask, target_chunk, 0)
    logits_y = logits.gather(1, label_chunk.unsqueeze(1)).squeeze(1)
    lse = torch.logsumexp(logits, dim=1)
    per_token_logps = (logits_y - lse) * loss_mask
    return per_token_logps.sum()


def _chunked_logprob_and_grads(inputs, targets, weight, chunk_size=512, ignore_index=-100):
    """Summed target log-prob of ``inputs`` and its gradients, ``chunk_size`` rows at a time.

    Returns:
        ``(grad_inputs, grad_weight, sum_log_prob)``: ``grad_inputs`` is shaped like
        ``inputs``, ``grad_weight`` like ``weight``, and ``sum_log_prob`` is 0-dim.
    """
    num_rows = inputs.shape[0]
    grad_weight = torch.zeros_like(weight)
    sum_log_prob = _logprob_accumulator(inputs)
    grad_chunks = []
    for start in range(0, num_rows, chunk_size):
        end = min(start + chunk_size, num_rows)
        (chunk_grad_input, chunk_grad_weight), chunk_log_prob = torch.func.grad_and_value(
            _chunk_logprob_sum, argnums=(0, 1)
        )(inputs[start:end], weight, targets[start:end], ignore_index)
        grad_weight.add_(chunk_grad_weight)
        sum_log_prob.add_(chunk_log_prob)
        grad_chunks.append(chunk_grad_input)
    # torch.cat([]) raises, so an empty input gets an explicit zero gradient.
    grad_inputs = torch.cat(grad_chunks, dim=0) if grad_chunks else torch.zeros_like(inputs)
    return grad_inputs, grad_weight, sum_log_prob


@torch.no_grad()
def _chunked_logprob(inputs, targets, weight, chunk_size=512, ignore_index=-100):
    """Summed target log-prob of ``inputs`` (no gradients), ``chunk_size`` rows at a time."""
    num_rows = inputs.shape[0]
    sum_log_prob = _logprob_accumulator(inputs)
    for start in range(0, num_rows, chunk_size):
        end = min(start + chunk_size, num_rows)
        sum_log_prob.add_(
            _chunk_logprob_sum(inputs[start:end], weight, targets[start:end], ignore_index)
        )
    return sum_log_prob


class ChunkedDPO(torch.autograd.Function):
    """Chunked policy/reference log-probs for DPO, with policy gradients precomputed in forward.

    ``forward`` returns four 0-dim tensors: the summed target log-probs of the policy
    (``inputs_*`` projected by ``inputs_weight``) and of the reference model (``refs_*``
    projected by ``refs_weight``) for the chosen and rejected rows. The policy gradients
    are computed chunk by chunk during forward and saved; ``backward`` only scales them
    by the incoming scalar gradients. The reference outputs receive no gradient.
    """

    @staticmethod
    @torch.no_grad()
    def forward(
        ctx,
        inputs_chosen,
        inputs_rejected,
        refs_chosen,
        refs_rejected,
        targets_chosen,
        targets_rejected,
        inputs_weight,
        refs_weight,
        compiled=True,
        chunk_size=128,
        ignore_index=-100,
    ):
        """Compute the four summed log-probs and stash the policy gradients on ``ctx``.

        Args:
            inputs_chosen / inputs_rejected: policy activations, one row per token.
            refs_chosen / refs_rejected: reference activations, row-aligned with the targets.
            targets_chosen / targets_rejected: 1-D integer class ids per row.
            inputs_weight: policy projection (receives a gradient).
            refs_weight: reference projection (treated as frozen).
            compiled: wrap the chunk loops in ``torch.compile`` when True.
            chunk_size: rows processed per chunk, used for policy and reference passes.
            ignore_index: target value whose rows are excluded from the sums.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        if compiled:
            # Hint dynamo that dim 0 (the token count) varies between calls.
            for tensor in (
                inputs_chosen,
                inputs_rejected,
                refs_chosen,
                refs_rejected,
                targets_chosen,
                targets_rejected,
            ):
                torch._dynamo.maybe_mark_dynamic(tensor, 0)

        logprob_and_grads = _maybe_compile(_chunked_logprob_and_grads, compiled)
        logprob_only = _maybe_compile(_chunked_logprob, compiled)

        grad_inputs_chosen, grad_weight_chosen, logprob_inputs_chosen = logprob_and_grads(
            inputs_chosen,
            targets_chosen,
            inputs_weight,
            chunk_size=chunk_size,
            ignore_index=ignore_index,
        )
        grad_inputs_rejected, grad_weight_rejected, logprob_inputs_rejected = logprob_and_grads(
            inputs_rejected,
            targets_rejected,
            inputs_weight,
            chunk_size=chunk_size,
            ignore_index=ignore_index,
        )

        logprob_refs_chosen = logprob_only(
            refs_chosen, targets_chosen, refs_weight, chunk_size=chunk_size, ignore_index=ignore_index
        )
        logprob_refs_rejected = logprob_only(
            refs_rejected, targets_rejected, refs_weight, chunk_size=chunk_size, ignore_index=ignore_index
        )

        ctx.save_for_backward(
            grad_inputs_chosen,
            grad_inputs_rejected,
            grad_weight_chosen,
            grad_weight_rejected,
        )

        return logprob_inputs_chosen, logprob_inputs_rejected, logprob_refs_chosen, logprob_refs_rejected

    @staticmethod
    @torch.no_grad()
    def backward(ctx, grad_logprob_chosen, grad_logprob_rejected, grad_logprob_ref_chosen, grad_logprob_ref_rejected):
        """Scale the saved policy gradients by the incoming scalar output gradients.

        The reference-output gradients are ignored because the reference side is frozen.
        Returns one entry per ``forward`` input (11 total), ``None`` where no gradient flows.
        """
        (
            grad_inputs_chosen,
            grad_inputs_rejected,
            grad_weight_chosen,
            grad_weight_rejected,
        ) = ctx.saved_tensors
        needs_grad = ctx.needs_input_grad

        grad_inputs_chosen_out = grad_inputs_chosen * grad_logprob_chosen if needs_grad[0] else None
        grad_inputs_rejected_out = grad_inputs_rejected * grad_logprob_rejected if needs_grad[1] else None

        # inputs_weight is shared by chosen and rejected, so both contributions add up.
        grad_inputs_weight = None
        if needs_grad[6]:
            grad_inputs_weight = (
                grad_weight_chosen * grad_logprob_chosen
                + grad_weight_rejected * grad_logprob_rejected
            )

        return (
            grad_inputs_chosen_out,    # inputs_chosen
            grad_inputs_rejected_out,  # inputs_rejected
            None,                      # refs_chosen (frozen)
            None,                      # refs_rejected (frozen)
            None,                      # targets_chosen (integer indices)
            None,                      # targets_rejected (integer indices)
            grad_inputs_weight,        # inputs_weight
            None,                      # refs_weight (frozen)
            None,                      # compiled
            None,                      # chunk_size
            None,                      # ignore_index
        )

In [2]:
# Run ChunkedDPO once at Qwen3-32B LM-head scale and dump a CUDA memory snapshot
# (the .pickle can be opened at https://pytorch.org/memory_viz).
torch.manual_seed(0)  # reproducible activations/targets for the Liger comparison below

torch.cuda.memory._record_memory_history(
   max_entries=100000
)

# https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L11
HIDDEN_SIZE = 5120
# https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L29
VOCAB_SIZE = 151936
NUM_CHOSEN_TOKENS = 10000
NUM_REJECTED_TOKENS = 5000
CHUNK_SIZE = 128
IGNORE_INDEX = -100

inputs_chosen = torch.randn(NUM_CHOSEN_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda")
inputs_rejected = torch.randn(NUM_REJECTED_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda")
refs_chosen = torch.randn(NUM_CHOSEN_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda")
refs_rejected = torch.randn(NUM_REJECTED_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda")

targets_chosen = torch.randint(low=0, high=VOCAB_SIZE, size=(NUM_CHOSEN_TOKENS,), device="cuda")
targets_rejected = torch.randint(low=0, high=VOCAB_SIZE, size=(NUM_REJECTED_TOKENS,), device="cuda")

# The chosen and rejected rows are each treated as one packed sequence, so the
# loss is normalised by num_seqs = 1 (original note: Liger divides by batch size // 2).
num_seqs = 1

inputs_weight = torch.nn.Linear(HIDDEN_SIZE, VOCAB_SIZE).cuda()
refs_weight = torch.nn.Linear(HIDDEN_SIZE, VOCAB_SIZE).cuda()

logprob_inputs_chosen, logprob_inputs_rejected, logprob_refs_chosen, logprob_refs_rejected = ChunkedDPO.apply(
    inputs_chosen.float(),
    inputs_rejected.float(),
    refs_chosen.float(),
    refs_rejected.float(),
    targets_chosen,
    targets_rejected,
    inputs_weight.weight,
    refs_weight.weight,
    True,          # compiled
    CHUNK_SIZE,    # chunk_size
    IGNORE_INDEX,  # ignore_index
)

beta = 0.1
chosen_logratios = logprob_inputs_chosen - logprob_refs_chosen
rejected_logratios = logprob_inputs_rejected - logprob_refs_rejected

# Implicit DPO rewards; not used by the loss below, kept for inspection.
chosen_rewards = beta * chosen_logratios
rejected_rewards = beta * rejected_logratios
logits_diff = beta * (chosen_logratios - rejected_logratios)
loss = -F.logsigmoid(logits_diff) / num_seqs
loss.backward()

torch.cuda.memory._dump_snapshot("chunk-pytorch-v2.pickle")
# Stop recording so allocations made by later cells are not traced as well.
torch.cuda.memory._record_memory_history(enabled=None)

In [None]:
# Keep ChunkedDPO's weight gradient for the comparison further down, then reset
# .grad to None so the Liger backward does not accumulate onto it.
first_grad = torch.clone(inputs_weight.weight.grad)
inputs_weight.weight.grad = None

In [None]:
# ChunkedDPO's gradient w.r.t. inputs_weight.weight, snapshotted before the Liger run.
first_grad

In [None]:
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
from torch.nn.utils.rnn import pad_sequence

def pad_dim1(tensors, padding_value=0):
    """Right-pad tensors along dim 1 to a common length and concatenate them on dim 0.

    1-D ``(T,)`` and 2-D ``(T, H)`` tensors are treated as a single sequence and get a
    leading batch dim first, so ``[(T1, H), (T2, H)] -> (2, max(T1, T2), H)`` and
    ``[(T1,), (T2,)] -> (2, max(T1, T2))``. Tensors with 3+ dims are assumed to already
    carry a batch dim 0 and are padded on dim 1 directly.

    Args:
        tensors: non-empty iterable of tensors; after batching, every dim except 0 and 1
            must match.
        padding_value: fill value for the padded positions.

    Returns:
        The padded tensors concatenated along dim 0.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    batched = [t.unsqueeze(0) if t.dim() in (1, 2) else t for t in tensors]
    if not batched:
        raise ValueError("pad_dim1 needs at least one tensor")

    max_len = max(t.shape[1] for t in batched)

    padded = []
    for t in batched:
        pad_len = max_len - t.shape[1]
        if pad_len > 0:
            # F.pad takes (left, right) pairs starting from the LAST dim, so dim 1 is
            # reached after (ndim - 2) zero pairs; only its right side is padded.
            t = F.pad(t, (0, 0) * (t.dim() - 2) + (0, pad_len), value=padding_value)
        padded.append(t)

    return torch.cat(padded, dim=0)

In [None]:
# Pack chosen (row 0) and rejected (row 1) into one right-padded batch.
# Padded target slots use -100, matching ChunkedDPO's default ignore_index.
targets = pad_sequence(
    [targets_chosen, targets_rejected],
    padding_value=-100,
    batch_first=True,
).cuda()
inputs_padded, refs_padded = (
    pad_dim1(pair)
    for pair in ([inputs_chosen, inputs_rejected], [refs_chosen, refs_rejected])
)

In [None]:
# Pin the hyperparameters to the ChunkedDPO reference run instead of relying on
# library defaults: the same beta, and -100 to skip the padded target positions.
liger_loss = LigerFusedLinearDPOLoss(ignore_index=-100, beta=beta)

In [None]:
# Liger forward on the padded batch (row 0 = chosen, row 1 = rejected), using the
# same policy/reference weights and fp32 activations as the ChunkedDPO run.
out = liger_loss(
    inputs_weight.weight,
    inputs_padded.float(),
    targets,
    ref_weight=refs_weight.weight,
    ref_input=refs_padded.float(),
)
print(out)

In [None]:
# Backprop Liger's loss; inputs_weight.weight.grad now holds Liger's gradient.
torch.autograd.backward(out[0])

In [None]:
# Loss parity between ChunkedDPO and Liger: absolute tolerance only (rtol=0).
# The message is only evaluated on failure and reports the actual gap.
assert torch.allclose(loss, out[0], atol=0.125, rtol=0), (
    f"ChunkedDPO vs Liger loss mismatch: max |diff| = {(loss - out[0]).abs().max().item():.4g}"
)

In [None]:
# Element-wise parity with a loose absolute tolerance (kept from the original check).
assert torch.allclose(first_grad, inputs_weight.weight.grad, atol=0.125, rtol=0), (
    "ChunkedDPO vs Liger weight-gradient mismatch: max |diff| = "
    f"{(first_grad - inputs_weight.weight.grad).abs().max().item():.4g}"
)
# atol=0.125 may exceed typical gradient entries, in which case an all-zero
# first_grad would still pass the check above; also bound the relative error of
# the whole gradient.
weight_grad_rel_err = (
    (first_grad - inputs_weight.weight.grad).norm()
    / inputs_weight.weight.grad.norm().clamp_min(1e-12)
).item()
assert weight_grad_rel_err < 1e-2, f"relative weight-gradient error {weight_grad_rel_err:.3g}"

In [None]:
# Liger's gradient w.r.t. inputs_weight.weight (.grad was reset to None before the Liger backward).
inputs_weight.weight.grad