In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import torch
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# Original function

- 깃헙에 있는 함수를 그대로 가져온 것입니다.
- 이 스터디 덕분에 rotary position embedding에 대해서도 공부하게 되었는데, 그게 어떻게 구현되었는지 살펴보고자 했습니다.
- 원본의 input, output 뿐만 아니라 중간 data shape와 같은 흐름 또한 확인하고자했습니다.

In [3]:
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
    """
    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_pytorch(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[s, b, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)

### Prepare input sample

In [4]:
s, b, h, d = 10, 2, 4, 16
t = torch.randn(s, b, h, d)

### hypothsis. s2 = s, d2 = d/2
s2, d2 = s, d//2
freqs = torch.randn(s2, 1, 1, d2)

tensor_format = "sbhd"

fused = False
cu_seqlens = None

# t, freqs, tensor_format, fused, cu_seqlens

t = t.to('cuda')
freqs = freqs.to('cuda')

### Result

In [5]:
original_result = apply_rotary_pos_emb_pytorch(
    t=t,
    freqs=freqs,
    tensor_format=tensor_format,
)

original_result.shape, original_result.mean(), original_result.std()

(torch.Size([10, 2, 4, 16]),
 tensor(-0.0223, device='cuda:0'),
 tensor(1.0054, device='cuda:0'))

# Migrate 'apply_rotary_pos_emb' function to triton
- _rotate_half_kernel 함수는 이전 노트북에서 구현했던 것을 그대로 가져왔습니다.
- 가급적이면 사용자 입장에서 원본 apply_rotary_pos_emb 함수를 호출하는 것과 동일하게 사용하는게 좋겠다 싶어서, 높은 추상화 레벨의 함수에서 apply_rotary_pos_emb_kernel을 호출하는 식으로 구현하고자 했습니다.
- 정말 아쉽지만 작성 도중에 제출 마감 시간이 되어 이 함수는 완성하지 못 하였습니다.

In [6]:
import triton
import triton.language as tl

In [7]:
@triton.jit
def _rotate_half_kernel(
    output_ptr, input_ptr,
    batch_size: tl.constexpr, seq_len: tl.constexpr, head_num: tl.constexpr, d_model: tl.constexpr,
    BLOCK_SIZE: tl.constexpr
):
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    total_elements = batch_size * seq_len * head_num * d_model
    mask = idx < total_elements

    dim_idx = idx % d_model
    half_dim = d_model // 2
    is_second_half = dim_idx >= half_dim

    swapped_idx = tl.where(is_second_half, idx - half_dim, idx + half_dim)
    data = tl.load(input_ptr + idx, mask=mask)
    data_swapped = tl.where(is_second_half, -data, data)

    tl.store(output_ptr + swapped_idx, data_swapped, mask=mask)

        
@triton.jit
def apply_rotary_pos_emb_kernel(
    t_ptr, freqs_ptr, output_ptr, 
    s: tl.constexpr, b: tl.constexpr, h: tl.constexpr, d: tl.constexpr, d2: tl.constexpr, 
    BLOCK_SIZE: tl.constexpr
):
    row_idx = tl.program_id(0)
    col_idx = tl.arange(0, BLOCK_SIZE)
    
    freq = tl.load(freqs_ptr + row_idx * d2 + col_idx, mask=col_idx < d2, other=0.0)
    
    cos_val, sin_val = tl.cos(freq), tl.sin(freq)
    
    input_val = tl.load(t_ptr + row_idx * h * d + col_idx)
    ### TODO: make rotated_input_val using __rotate_half_kernel
    rotated_input_val = input_val
#     __rotate_half_kernel[(grid_size,)](
#         rotated_input_val, input_val, 
#         batch_size=batch_size, seq_len=seq_len, head_num=head_num, d_model=d_model,
#         BLOCK_SIZE=BLOCK_SIZE
#     )

    rotary_emb = cos_val * input_val + sin_val * rotated_input_val

    tl.store(output_ptr + row_idx * h * d + col_idx, rotary_emb, mask=col_idx < d)


def apply_rotary_pos_emb_triton(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )
    
    s, b, h, d = t.shape
    cur_seq_len = b if tensor_format == "bshd" else s
    max_seq_len, _, _, d2 = freqs.shape
    
    # Only apply the rotary embeddings up to the sequence length of the running
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    
    output = torch.empty_like(t)
    BLOCK_SIZE = d
    grid = (s * b * h, )
    apply_rotary_pos_emb_kernel[grid](
        output, t, freqs, 
        s, b, h, d, d2, 
        BLOCK_SIZE=BLOCK_SIZE
    )
    return output

In [8]:
triton_result = apply_rotary_pos_emb_triton(
    t=t,
    freqs=freqs,
    tensor_format=tensor_format,
)

triton_result.shape, triton_result.mean(), triton_result.std()

(torch.Size([10, 2, 4, 16]),
 tensor(0.0003, device='cuda:0'),
 tensor(0.9146, device='cuda:0'))

In [9]:
torch.testing.assert_close(original_result, triton_result)

AssertionError: Tensor-likes are not close!

Mismatched elements: 1280 / 1280 (100.0%)
Greatest absolute difference: 4.691768646240234 at index (8, 0, 1, 6) (up to 1e-05 allowed)
Greatest relative difference: inf at index (1, 0, 0, 0) (up to 1.3e-06 allowed)

- 스터디 모집 문제에 cuda kernel로 구현한 것까지 첨부해주신 것으로 보아, 해당 코드를 참고해 triton에서 low하게 구현하여 성능 향상을 하는 결과물을 제출하는 것이 목표인 듯 합니다.
- 사실 저에게 2~3일로는 이 과제가 쉽지 않을 것 같았지만, 정말 매력적인 분야라 꼭 공부하고 싶었고, 그냥 포기하면 후회할 것 같아서 부딫혀보게 되었습니다.
- 제 한계가 고작 이 정도인가 싶어서 자괴감도 들고, 최선을 다 했으니 한편으로는 후련한 복잡한 심정입니다...