tests/pytorch/test_fused_rope.py (136 changes: 58 additions, 78 deletions)
@@ -11,52 +11,6 @@
)


def _get_thd_freqs_on_this_cp_rank(
    cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    if cp_size > 1:
        cp_seg = x.size(0) // 2
        full_seqlen = cp_size * x.size(0)
        return torch.cat(
            [
                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
            ]
        )
    else:
        return freqs[: x.size(0)]
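

# A quick, hypothetical sanity check (not part of the original test file)
# showing which frequency rows the helper above selects: each rank takes one
# segment from the front of the full sequence and the mirrored segment from
# the back, which is the load-balanced token layout used by context parallelism.
def _demo_cp_freq_slicing() -> None:
    freqs = torch.arange(8).view(8, 1, 1, 1)  # stand-in for a [max_s, 1, 1, d] freqs tensor
    x = torch.zeros(4, 1, 1)  # per-rank chunk: 8 tokens / cp_size=2 ranks
    for rank in range(2):
        picked = _get_thd_freqs_on_this_cp_rank(rank, 2, x, freqs)
        print(rank, picked.flatten().tolist())
    # rank 0 -> [0, 1, 6, 7]; rank 1 -> [2, 3, 4, 5]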


def apply_rotary_pos_emb_thd(
    t: torch.Tensor,
    cu_seqlens: torch.Tensor,
    freqs: torch.Tensor,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """A baseline implementation of applying RoPE for the `thd` format.

    Args:
        t (Tensor): Input tensor of shape [t, h, d].
        cu_seqlens (Tensor): Cumulative sequence lengths in a batch for `t`,
            with shape [b + 1] and dtype torch.int32.
        freqs (Tensor): Rotary positional embedding tensor of shape [max_s, 1, 1, d].
        cp_size (int): Context-parallel world size.
        cp_rank (int): Context-parallel rank of this process.

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    cu_seqlens = cu_seqlens // cp_size
    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return torch.cat(
        [
            apply_rotary_pos_emb(
                x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
            )
            for x in torch.split(t, seqlens)
        ]
    ).squeeze(1)
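

# For orientation, a minimal sketch of the rotation that apply_rotary_pos_emb
# performs on the rotary portion of the hidden dimension, assuming the
# conventional non-interleaved "rotate half" formulation; this is an
# illustration, not the library's actual implementation.
def _rope_rotate_half_sketch(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    rot_dim = freqs.shape[-1]
    # Only the first rot_dim channels are rotated; the rest pass through unchanged.
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    # Pair channel i with channel i + rot_dim // 2 and rotate by this position's
    # angle: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
    x1, x2 = torch.chunk(t_rot, 2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return torch.cat((t_rot * torch.cos(freqs) + rotated * torch.sin(freqs), t_pass), dim=-1)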


# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2
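

# The two gradient shapes exercise different backward paths: with
# _overlapping_grad every output element receives the same broadcast scalar
# gradient, while _non_overlapping_grad (collapsed in this diff) presumably
# feeds a distinct gradient per element, so the fused backward kernel is
# checked under both uniform and per-element gradients.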
@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
@@ -85,6 +41,8 @@ def test_fused_rope(
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
@@ -99,50 +57,63 @@
    t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    emb = rotary_pos_emb(seq_length)

    # unfused
    # The fused kernel computes in float32 internally, so we force the unfused func to use float32
    # for more accurate comparison
    output_unfused = apply_rotary_pos_emb(
        t.float(), emb, tensor_format=tensor_format, fused=False
    ).to(dtype)
    loss_unfused = loss_func(output_unfused)
    loss_unfused.backward()
    grad_unfused = t.grad.detach().clone()
    t.grad = None

    # fused
    output_fused = apply_rotary_pos_emb(
        t,
        emb,
        tensor_format=tensor_format,
        fused=True,
    )
    loss_fused = loss_func(output_fused)
    loss_fused.backward()
    grad_fused = t.grad.detach().clone()
    t.grad = None
    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(seq_length * cp_size)
    assert emb.is_contiguous()

    torch.testing.assert_close(output_fused, output_unfused)
    torch.testing.assert_close(grad_fused, grad_unfused)
    assert output_fused.is_contiguous()
    for cp_rank in range(cp_size):
        # unfused
        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
        # for more accurate comparison
        output_unfused = apply_rotary_pos_emb(
            t.float(),
            emb,
            tensor_format=tensor_format,
            interleaved=interleaved,
            fused=False,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused = loss_func(output_unfused)
        loss_unfused.backward()
        grad_unfused = t.grad.detach().clone()
        t.grad = None

        # fused
        output_fused = apply_rotary_pos_emb(
            t,
            emb,
            tensor_format=tensor_format,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        )
        loss_fused = loss_func(output_fused)
        loss_fused.backward()
        grad_fused = t.grad.detach().clone()
        t.grad = None

        torch.testing.assert_close(output_fused, output_unfused)
        torch.testing.assert_close(grad_fused, grad_unfused)
        assert output_fused.is_contiguous()
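
    # Checking every cp_rank against the unfused baseline exercises the
    # per-rank frequency slicing on a single GPU, without needing an actual
    # context-parallel (multi-process) run. Example invocation (assuming this
    # repo layout) to run all of the parametrizations above:
    #   pytest tests/pytorch/test_fused_rope.py::test_fused_rope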


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2, 3])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope_thd(
    dtype: torch.dtype,
    hidden_size: int,
    rotary_percent: float,
    transpose: Union[Tuple, None],
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
@@ -170,15 +141,23 @@ def test_fused_rope_thd(
    t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True

    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(cu_seqlens_padded[-1])
    assert emb.is_contiguous()

    for cp_rank in range(cp_size):
        # unfused
        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
        # for more accurate comparison
        output_unfused = apply_rotary_pos_emb_thd(
            t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
        output_unfused = apply_rotary_pos_emb(
            t.float(),
            emb,
            tensor_format="thd",
            interleaved=interleaved,
            fused=False,
            cu_seqlens=cu_seqlens_padded,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused = loss_func(output_unfused)
        loss_unfused.backward()
@@ -189,6 +168,7 @@
        output_fused = apply_rotary_pos_emb(
            t,
            emb,
            interleaved=interleaved,
            fused=True,
            tensor_format="thd",
            cu_seqlens=cu_seqlens_padded,