In [1]:
import torch
import triton
import triton.language as tl
from torch.nn import functional as F

MAX_FUSED_SIZE = 65536 // 2

In [2]:
# https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L11
inputs_chosen = torch.randn(10000, 5120, dtype=torch.bfloat16).cuda()
inputs_rejected = torch.randn(5000, 5120, dtype=torch.bfloat16).cuda()
refs_chosen = torch.randn(10000, 5120, dtype=torch.bfloat16).cuda()
refs_rejected = torch.randn(5000, 5120, dtype=torch.bfloat16).cuda()

# https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L29
targets_chosen = torch.randint(low=0, high=151936, size=(10000,)).cuda()
targets_rejected = torch.randint(low=0, high=151936, size=(5000,)).cuda()

# assumed packing 5 sequences
num_seqs = 5

In [3]:
inputs_weight = torch.nn.Linear(5120, 151936).cuda()
refs_weight = torch.nn.Linear(5120, 151936).cuda()

In [4]:
def get_sum_logprob(inputs, targets, weight, chunk_size=512, ignore_index=-100):
    sum_log_prob = 0.0

    BT, H = inputs.shape
    # V = weight_inputs.weight.shape[0]
    # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    
    # inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
    # chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))
    for start_idx in range(0, BT, chunk_size):
        end_idx = min(start_idx + chunk_size, BT)
        _inputs_chunk = inputs[start_idx:end_idx]
        _targets_chunk = targets[start_idx:end_idx]
    
        logits = _inputs_chunk.to(torch.float32) @ weight.T
        log_probs_chunk = F.log_softmax(logits.float(), dim=-1)
        
        loss_mask = _targets_chunk != ignore_index
        label_chunk = torch.where(loss_mask, _targets_chunk, 0)
        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        log_prob = (per_token_logps * loss_mask).sum(-1)
        sum_log_prob += log_prob
    
    return sum_log_prob / (targets != ignore_index).sum()

In [5]:
def get_logprobs(
    inputs_chosen, 
    inputs_rejected, 
    refs_chosen, 
    refs_rejected, 
    targets_chosen, 
    targets_rejected, 
    inputs_weight, 
    refs_weight,
):
    get_sum_log_prob = torch.compile(get_sum_logprob)

    torch._dynamo.mark_dynamic(inputs_chosen, 0)
    torch._dynamo.mark_dynamic(inputs_rejected, 0)
    torch._dynamo.mark_dynamic(refs_chosen, 0)
    torch._dynamo.mark_dynamic(refs_rejected, 0)

    logpprob_inputs_chosen = get_sum_logprob(inputs_chosen, targets_chosen, inputs_weight)
    logpprob_inputs_rejected = get_sum_logprob(inputs_rejected, targets_rejected, inputs_weight)

    logpprob_refs_chosen = get_sum_logprob(refs_chosen, targets_chosen, refs_weight)
    logpprob_refs_rejected = get_sum_logprob(refs_rejected, targets_rejected, refs_weight)

    return logpprob_inputs_chosen, logpprob_inputs_rejected, logpprob_refs_chosen, logpprob_refs_rejected

In [6]:
out = get_logprobs(
    inputs_chosen=inputs_chosen,
    inputs_rejected=inputs_rejected,
    refs_chosen=refs_chosen,
    refs_rejected=refs_rejected,
    targets_chosen=targets_chosen,
    targets_rejected=targets_rejected,
    inputs_weight=inputs_weight.weight,
    refs_weight=refs_weight.weight,
)

In [7]:
logpprob_inputs_chosen, logpprob_inputs_rejected, logpprob_refs_chosen, logpprob_refs_rejected = out

In [8]:
beta = 0.1
chosen_logratios = logpprob_inputs_chosen - logpprob_refs_chosen
rejected_logratios = logpprob_inputs_rejected - logpprob_refs_rejected

chosen_rewards = beta * chosen_logratios
rejected_rewards = beta * rejected_logratios
logits_diff = beta * (chosen_logratios - rejected_logratios)
loss = -F.logsigmoid(logits_diff) / num_seqs

In [9]:
logpprob_inputs_chosen, loss

(tensor(-12.0930, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.1386, device='cuda:0', grad_fn=<DivBackward0>))

In [10]:
loss.backward()

In [11]:
!nvidia-smi

Sun Nov 30 08:12:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          Off |   00000000:09:00.0 Off |                    0 |
| N/A   55C    P0            645W /  700W |   37893MiB /  81559MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          Off |   00