In [1]:
import torch
import triton
import triton.language as tl
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

# 32768, i.e. half of 65536. No cell visible in this part of the notebook reads
# this constant. NOTE(review): presumably a block-size cap for a Triton kernel
# defined elsewhere -- confirm before removing.
MAX_FUSED_SIZE = 1 << 15

In [2]:
# Synthetic batch for the fused-linear DPO loss, sized after Qwen3-32B.
# HIDDEN_SIZE: https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L11
# VOCAB_SIZE:  https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L29
HIDDEN_SIZE = 5120
VOCAB_SIZE = 151936
CHOSEN_LEN = 10000
REJECTED_LEN = 5000
SEED = 0

# Seed before the first random draw so that Restart & Run All reproduces the
# same tensors (torch.manual_seed seeds the generators on all devices).
torch.manual_seed(SEED)

# Hidden states for the trained model (`inputs`) and the reference model
# (`refs`). Allocated directly on the GPU to skip a ~200 MB host tensor and a
# host-to-device copy per tensor.
inputs = torch.randn(2, CHOSEN_LEN, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda")
refs = torch.randn(2, CHOSEN_LEN, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda")

# Random token ids in [0, VOCAB_SIZE); the rejected sequence is half as long.
targets_chosen = torch.randint(low=0, high=VOCAB_SIZE, size=(CHOSEN_LEN,))
targets_rejected = torch.randint(low=0, high=VOCAB_SIZE, size=(REJECTED_LEN,))

# Row 0 = chosen, row 1 = rejected. pad_sequence right-pads the shorter row
# with -100 up to CHOSEN_LEN.
# NOTE(review): -100 is presumably the loss's ignore index -- confirm.
targets = pad_sequence(
    [targets_chosen, targets_rejected],
    batch_first=True,
    padding_value=-100,
).cuda()

# Assume the batch packs 5 sequences (not read by the cells visible here).
num_seqs = 5

In [3]:
# Output projections (5120 -> 151936) whose `.weight` tensors are handed to the
# loss as the trained weight and the reference weight respectively. The layers
# are built in the same order as before so the RNG draws are unchanged.
inputs_weight = torch.nn.Linear(in_features=5120, out_features=151936).to("cuda")
refs_weight = torch.nn.Linear(in_features=5120, out_features=151936).to("cuda")

In [4]:
# Fused linear + DPO loss from Liger Kernel, built with its default settings.
loss = LigerFusedLinearDPOLoss()

In [5]:
# Forward pass of the fused linear + DPO loss.
# Call the module instance rather than `.forward(...)`: `nn.Module.__call__`
# runs any registered hooks, which a direct `.forward` call silently skips.
# The bfloat16 activations are upcast to float32 before being passed in.
# `out` is a tuple: element 0 is the scalar loss (used for backward in the next
# cell), element 1 is a tuple of auxiliary tensors returned by the loss.
out = loss(
    inputs_weight.weight,
    inputs.to(torch.float32),
    targets,
    ref_input=refs.to(torch.float32),
    ref_weight=refs_weight.weight,
)
print(out)

(tensor(0.1247, device='cuda:0', grad_fn=<LigerFusedLinearDPOFunctionBackward>), (tensor([-120938.7969], device='cuda:0'), tensor([-60440.4844], device='cuda:0'), tensor(-1.5490e-05, device='cuda:0'), tensor(-2.1861e-05, device='cuda:0'), tensor(0., device='cuda:0'), tensor([6.6695], device='cuda:0'), tensor([4.6508], device='cuda:0')))


In [6]:
# Backpropagate through the scalar loss, which is the first element of `out`.
loss_value = out[0]
loss_value.backward()

In [7]:
# Shell escape: print the driver's view of the GPUs (including memory in use)
# after the forward and backward passes above.
!nvidia-smi

Sun Nov 30 08:30:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          Off |   00000000:09:00.0 Off |                    0 |
| N/A   45C    P0            407W /  700W |   46363MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          Off |   00