In [1]:
import torch
import triton
import triton.language as tl
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

In [2]:
# Memory profile of Liger's fused linear + DPO loss at Qwen3-32B dimensions.
# One (chosen, rejected) pair goes through a forward and backward pass while the
# CUDA caching allocator records its history; the history is then written to
# "liger.pickle".

# Every tensor below is random, so seed first to make the printed loss repeatable.
SEED = 0
torch.manual_seed(SEED)

# https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L11
HIDDEN_SIZE = 5120
# https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json#L29
VOCAB_SIZE = 151936
# Token counts of the two sequences; the shorter one is padded up to the longer.
CHOSEN_LEN = 10000
REJECTED_LEN = 5000

# Start recording before anything is allocated so the inputs show up in the trace.
torch.cuda.memory._record_memory_history(
    max_entries=100000
)
try:
    # Batch of 2 hidden-state sequences for the policy and the reference model.
    # Presumably row 0 is the chosen sequence and row 1 the rejected one, matching
    # the stacking order of `targets` below — confirm against the Liger docs.
    inputs = torch.randn(2, CHOSEN_LEN, HIDDEN_SIZE, dtype=torch.bfloat16).cuda()
    refs = torch.randn(2, CHOSEN_LEN, HIDDEN_SIZE, dtype=torch.bfloat16).cuda()

    targets_chosen = torch.randint(low=0, high=VOCAB_SIZE, size=(CHOSEN_LEN,))
    targets_rejected = torch.randint(low=0, high=VOCAB_SIZE, size=(REJECTED_LEN,))

    # Pad the shorter (rejected) sequence with -100; presumably this matches the
    # loss's ignore index so padded positions do not contribute — TODO confirm.
    targets = pad_sequence(
        [targets_chosen, targets_rejected],
        batch_first=True,
        padding_value=-100,
    ).cuda()

    # The Linear modules only serve as containers for an initialised
    # (VOCAB_SIZE, HIDDEN_SIZE) weight; their biases are never passed to the loss.
    inputs_weight = torch.nn.Linear(HIDDEN_SIZE, VOCAB_SIZE).cuda()
    refs_weight = torch.nn.Linear(HIDDEN_SIZE, VOCAB_SIZE).cuda()

    loss = LigerFusedLinearDPOLoss()
    # Call the module rather than .forward() so nn.Module.__call__ (hooks) runs.
    out = loss(
        inputs_weight.weight,
        inputs.to(torch.float32),
        targets,
        ref_input=refs.to(torch.float32),
        ref_weight=refs_weight.weight,
    )
    out[0].backward()
finally:
    # Dump even when the workload fails (e.g. CUDA OOM): that trace is the one
    # worth inspecting. Then stop recording so later cells are not traced.
    torch.cuda.memory._dump_snapshot("liger.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)

In [3]:
# Display the value returned by the loss call: a scalar loss tensor followed by
# a tuple of auxiliary tensors (presumably log-probs / rewards — see Liger docs).
out

(tensor(0.5823, device='cuda:0', grad_fn=<LigerFusedLinearDPOFunctionBackward>),
 (tensor([-120998.0391], device='cuda:0'),
  tensor([-60436.0430], device='cuda:0'),
  tensor(7.5361e-06, device='cuda:0'),
  tensor(4.2012e-07, device='cuda:0'),
  tensor(0., device='cuda:0'),
  tensor([1.0633], device='cuda:0'),
  tensor([0.8277], device='cuda:0')))

In [4]:
# Ask the driver for the current per-GPU state (memory in use, utilisation, power).
!nvidia-smi

Sun Nov 30 14:43:15 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          Off |   00000000:09:00.0 Off |                    0 |
| N/A   42C    P0            150W /  700W |   46363MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          Off |   00

In [5]:
# Confirm the allocator snapshot file was written and check its size on disk.
!ls -lha liger.pickle

-rw-r--r-- 1 root root 121K Nov 30 14:42 liger.pickle
