# Rotary Position Embedding (RoPE)

In [107]:
import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

In [115]:
# Toy problem sizes used to exercise the RoPE helpers below.
n_head, n_embed = 6, 48
block_size, batch_size = 8, 4

# Number of rotated features: a quarter of the per-head feature width.
rope_n_elem = (n_embed // n_head) // 4
base = 10000

# Use the MPS backend when it is available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
dtype = torch.float16

# Uniform random input laid out as (batch, head, position, per-head feature).
x = torch.tensor(
    np.random.rand(batch_size, n_head, block_size, n_embed // n_head),
    device=device,
    dtype=dtype,
)

## LL1

In [116]:
def build_rope_cache_ll1(
    block_size: int,
    n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Precompute the cosine and sine tables for rotary position embeddings.

    Args:
        block_size: Number of positions (rows) to generate.
        n_elem: Number of features that will be rotated.
        base: Base of the geometric frequency progression.
        condense_ratio: Divisor applied to the position indices.
        device: Device on which the tables are created.
        dtype: dtype for the frequency vector. ``torch.float16`` and
            ``torch.bfloat16`` are widened to ``torch.float32`` so the
            frequencies are not quantized.

    Returns:
        A ``(cos, sin)`` tuple. Each table has ``block_size`` rows and two
        side-by-side copies of the per-frequency angles as columns.
    """
    # Half-precision thetas carry only ~3 significant digits; multiplied by a
    # large position index the phase error becomes visible. Positions below
    # are already created in the default float dtype, so the result dtype is
    # the same as before for half-precision requests.
    theta_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    exponents = (
        torch.arange(0, n_elem, 2, device=device, dtype=theta_dtype) / n_elem
    )
    theta = 1.0 / (base**exponents)
    seq_idx = torch.arange(block_size, device=device) / condense_ratio
    # Outer product gives one angle per (position, frequency); the columns are
    # duplicated so the table lines up with both halves of the rotated slice.
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    return torch.cos(idx_theta), torch.sin(idx_theta)


def apply_rope_ll1(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Rotate the features in the last dimension of ``x`` by the cached angles.

    The last dimension is cut at its midpoint and the result is
    ``x * cos + cat(-second_half, first_half) * sin``, cast back to the type
    of ``x``. ``cos`` and ``sin`` must broadcast against ``x``.
    """
    half = x.size(-1) // 2
    front = x[..., :half]  # B, nh, T, hs/2
    back = x[..., half:]  # B, nh, T, hs/2
    swapped = torch.cat((-back, front), dim=-1)  # B, nh, T, hs
    cos_part = x * cos
    sin_part = swapped * sin
    return (cos_part + sin_part).type_as(x)


# Precompute the rotation tables for every position in the block.
cos, sin = build_rope_cache_ll1(
    block_size, rope_n_elem, base=base, device=device, dtype=dtype
)

# Rotate only the leading rope_n_elem features and inspect the result shape.
apply_rope_ll1(x[..., 0:rope_n_elem], cos, sin).shape

torch.Size([4, 6, 8, 2])

## LL2

In [117]:
def build_rope_cache_ll2(
    block_size: int,
    rope_n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Precompute the cosine and sine tables for rotary position embeddings.

    Args:
        block_size: Number of positions (rows) to generate.
        rope_n_elem: Number of features that will be rotated.
        base: Base of the geometric frequency progression.
        condense_ratio: Divisor applied to the position indices.
        device: Device on which the tables are created.
        dtype: dtype of the returned tables.

    Returns:
        A ``(cos, sin)`` tuple. Each table has ``block_size`` rows and two
        side-by-side copies of the per-frequency angles as columns.
    """
    # float16 cannot represent integers above 2048 exactly (bfloat16: above
    # 256), so building positions directly in those dtypes makes distinct
    # positions share an angle. Compute in float32 and cast at the end.
    compute_dtype = (
        torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    )
    theta = 1 / base ** (
        torch.arange(0, rope_n_elem, 2, device=device, dtype=compute_dtype)
        / rope_n_elem
    )
    seq_idx = (
        torch.arange(block_size, device=device, dtype=compute_dtype) / condense_ratio
    )
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if compute_dtype is not dtype:
        # Only the widened half-precision path needs casting back.
        cos, sin = cos.to(dtype), sin.to(dtype)
    return cos, sin


def apply_rope_ll2(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply the rotary transform to the last dimension of ``x``.

    Computes ``x * cos + cat(-upper, lower) * sin`` where ``lower`` and
    ``upper`` are the two halves of the last dimension, then casts the result
    back to the type of ``x``.
    """
    midpoint = x.shape[-1] // 2
    lower, upper = x[..., :midpoint], x[..., midpoint:]
    # Swap the halves and negate the one that moves to the front.
    quarter_turn = torch.cat((-upper, lower), dim=-1)
    result = (x * cos) + (quarter_turn * sin)
    return result.type_as(x)


# Build the cos/sin tables once for all block_size positions.
cos, sin = build_rope_cache_ll2(
    block_size, rope_n_elem, base=base, device=device, dtype=dtype
)

# Only the first rope_n_elem features of each head are rotated.
apply_rope_ll2(x[..., 0:rope_n_elem], cos, sin).shape

torch.Size([4, 6, 8, 2])

## LL25

In [118]:
def build_rope_cache_ll25(
    block_size: int,
    rope_n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Build the RoPE cosine and sine lookup tables.

    Args:
        block_size: Number of positions (rows) to generate.
        rope_n_elem: Number of features that will be rotated.
        base: Base of the geometric frequency progression.
        condense_ratio: Divisor applied to the position indices.
        device: Device on which the tables are created.
        dtype: dtype of the returned tables.

    Returns:
        A ``(cos, sin)`` tuple. Each table has ``block_size`` rows and two
        side-by-side copies of the per-frequency angles as columns.
    """
    # Angles are computed in float32 when a half-precision dtype is requested:
    # float16/bfloat16 cannot hold large position indices exactly, so distinct
    # positions would otherwise collapse onto the same angle.
    is_half = dtype in (torch.float16, torch.bfloat16)
    work_dtype = torch.float32 if is_half else dtype

    exponents = (
        torch.arange(0, rope_n_elem, 2, device=device, dtype=work_dtype) / rope_n_elem
    )
    theta = 1 / base**exponents
    seq_idx = torch.arange(block_size, device=device, dtype=work_dtype) / condense_ratio
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)  # block_size, rope_n_elem

    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if is_half:
        # Hand back the dtype the caller asked for.
        cos, sin = cos.to(dtype), sin.to(dtype)
    return cos, sin


def apply_rope_ll25(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Rotate the features in the last dimension of ``x`` by the cached angles.

    Args:
        x: Tensor whose last dimension holds the features to rotate.
        cos: Cosine table; must broadcast against ``x``.
        sin: Sine table; must broadcast against ``x``.

    Returns:
        ``x * cos + cat(-second_half, first_half) * sin``, cast back to the
        type of ``x``.
    """
    hs = x.size(-1)
    # Split at the midpoint. Slicing at ``hs`` itself would leave the second
    # half empty and make ``rotated`` a plain copy of ``x``.
    x1 = x[..., : hs // 2]
    x2 = x[..., hs // 2 :]
    rotated = torch.cat((-x2, x1), dim=-1)
    roped = (x * cos) + (rotated * sin)
    return roped.type_as(x)


# Rotation tables for all block_size positions.
cos, sin = build_rope_cache_ll25(
    block_size, rope_n_elem, base=base, device=device, dtype=dtype
)

# Rotate the leading rope_n_elem features and check the output shape.
apply_rope_ll25(x[..., 0:rope_n_elem], cos, sin).shape

torch.Size([4, 6, 8, 2])

## LL29

In [119]:
def build_rope_cache_ll29(
    block_size: int,
    rope_n_elem: int,
    condense_ratio: int = 1,
    base: int = 10000,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Precompute the cosine and sine tables for rotary position embeddings.

    Args:
        block_size: Number of positions (rows) to generate.
        rope_n_elem: Number of features that will be rotated.
        condense_ratio: Divisor applied to the position indices.
        base: Base of the geometric frequency progression.
        device: Device on which the tables are created.
        dtype: dtype of the returned tables.

    Returns:
        A ``(cos, sin)`` tuple. Each table has ``block_size`` rows and two
        side-by-side copies of the per-frequency angles as columns.
    """
    # Half-precision dtypes lose integer exactness quickly (float16 beyond
    # 2048), so positions and frequencies are built in float32 and only the
    # finished tables are cast to the requested dtype.
    compute_dtype = (
        torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    )
    theta = 1 / base ** (
        torch.arange(0, rope_n_elem, 2, device=device, dtype=compute_dtype)
        / rope_n_elem
    )
    seq_idx = (
        torch.arange(block_size, device=device, dtype=compute_dtype) / condense_ratio
    )
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if compute_dtype is not dtype:
        cos, sin = cos.to(dtype), sin.to(dtype)
    return cos, sin


def apply_rope_ll29(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply the rotary transform to the last dimension of ``x``.

    The last dimension is split at its midpoint into ``a`` and ``b``; the
    output is ``x * cos + cat(-b, a) * sin`` cast back to the type of ``x``.
    """
    split_at = x.size(-1) // 2
    a = x[..., :split_at]
    b = x[..., split_at:]
    x_cos = x * cos
    x_sin = torch.cat((-b, a), dim=-1) * sin
    return (x_cos + x_sin).type_as(x)


# base is passed by keyword: in this variant it follows condense_ratio.
cos, sin = build_rope_cache_ll29(
    block_size, rope_n_elem, base=base, device=device, dtype=dtype
)

# Rotate the leading rope_n_elem features and check the output shape.
apply_rope_ll29(x[..., 0:rope_n_elem], cos, sin).shape

torch.Size([4, 6, 8, 2])

## LL3

In [120]:
def build_rope_cache_ll3(
    block_size: int,
    rope_n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cosine and sine tables for rotary position embeddings.

    Args:
        block_size: Number of positions (rows) to generate.
        rope_n_elem: Number of features that will be rotated.
        base: Base of the geometric frequency progression.
        condense_ratio: Divisor applied to the position indices.
        device: Device on which the tables are created.
        dtype: dtype of the returned tables.

    Returns:
        A ``(cos, sin)`` tuple. Each table has ``block_size`` rows and two
        side-by-side copies of the per-frequency angles as columns.
    """
    # float16/bfloat16 cannot represent large position indices exactly, which
    # makes neighbouring positions share an angle. Do the arithmetic in
    # float32 and cast only the finished tables.
    compute_dtype = (
        torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    )
    theta = 1 / base ** (
        torch.arange(0, rope_n_elem, 2, device=device, dtype=compute_dtype)
        / rope_n_elem
    )
    seq_idx = (
        torch.arange(block_size, device=device, dtype=compute_dtype) / condense_ratio
    )
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if compute_dtype is not dtype:
        cos, sin = cos.to(dtype), sin.to(dtype)
    return cos, sin


def apply_rope_ll3(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Rotate the features in the last dimension of ``x`` by the cached angles.

    Returns ``x * cos + cat(-tail, head) * sin`` in the type of ``x``, where
    ``head`` and ``tail`` are the two halves of the last dimension.
    """
    n_half = x.shape[-1] // 2
    head, tail = x[..., :n_half], x[..., n_half:]
    shifted = torch.cat((-tail, head), dim=-1)
    combined = (x * cos) + (shifted * sin)
    return combined.type_as(x)


# Precompute the rotation tables for every position in the block.
cos, sin = build_rope_cache_ll3(
    block_size, rope_n_elem, base=base, device=device, dtype=dtype
)

# Only the first rope_n_elem features of each head are rotated.
apply_rope_ll3(x[..., 0:rope_n_elem], cos, sin).shape

torch.Size([4, 6, 8, 2])

# MultiHeadAttention With RoPE

In [129]:
# Use the MPS backend when it is available, otherwise fall back to the CPU,
# and make that device the default for tensors created without one.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
torch.set_default_device(device)

dtype = torch.float16

# Toy problem sizes for the attention layers below.
n_head, n_embed = 6, 48
block_size, batch_size = 8, 4

# Number of rotated features: a quarter of the per-head feature width.
rope_n_elem = (n_embed // n_head) // 4
base = 10000

# Uniform random input laid out as (batch, position, embedding).
x = torch.tensor(
    np.random.rand(batch_size, block_size, n_embed), device=device, dtype=dtype
)

## LL1

In [122]:
def build_rope_cache_mhall1(
    block_size: int,
    rope_n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cosine and sine tables for rotary position embeddings.

    Args:
        block_size: Number of positions (rows) to generate.
        rope_n_elem: Number of features that will be rotated.
        base: Base of the geometric frequency progression.
        condense_ratio: Divisor applied to the position indices.
        device: Device on which the tables are created.
        dtype: dtype of the returned tables.

    Returns:
        A ``(cos, sin)`` tuple. Each table has ``block_size`` rows and two
        side-by-side copies of the per-frequency angles as columns.
    """
    # Positions above 2048 are not exactly representable in float16 (above
    # 256 in bfloat16), so the angles are computed in float32 and the tables
    # are cast to the requested dtype afterwards.
    compute_dtype = (
        torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    )
    theta = 1 / base ** (
        torch.arange(0, rope_n_elem, 2, device=device, dtype=compute_dtype)
        / rope_n_elem
    )
    seq_idx = (
        torch.arange(block_size, device=device, dtype=compute_dtype) / condense_ratio
    )
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if compute_dtype is not dtype:
        cos, sin = cos.to(dtype), sin.to(dtype)
    return cos, sin


def apply_rope_mhall1(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply the rotary transform to the last dimension of ``x``.

    The last dimension is split at its midpoint into ``left`` and ``right``;
    the output is ``x * cos + cat(-right, left) * sin`` cast back to the type
    of ``x``. ``cos`` and ``sin`` must broadcast against ``x``.
    """
    mid = x.size(-1) // 2
    left = x[..., :mid]
    right = x[..., mid:]
    turned = torch.cat((-right, left), dim=-1)
    cos_term = x * cos
    sin_term = turned * sin
    return (cos_term + sin_term).type_as(x)


class MultiHeadAttentionLL1(nn.Module):
    """Causal multi-head self-attention with rotary embeddings on q and k.

    Queries, keys and values come from one fused linear projection. The
    leading features of every query/key head are rotated with
    ``apply_rope_mhall1`` before scaled dot-product attention with a
    lower-triangular mask.

    Args:
        block_size: Maximum sequence length; sets the size of the causal mask.
        n_head: Number of attention heads; must divide ``n_embed``.
        n_embed: Embedding width of the input and output.
        dropout: Probability used by both the attention and output dropout.
        bias: Whether the linear projections carry a bias term.
    """

    def __init__(
        self,
        block_size: int,
        n_head: int,
        n_embed: int,
        dropout: float = 0.20,
        bias: bool = False,
    ):
        super().__init__()
        assert n_embed % n_head == 0

        self.block_size = block_size
        self.n_head = n_head
        self.n_embed = n_embed
        self.dropout = dropout
        self.bias = bias

        # One projection produces q, k and v side by side.
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)

        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_residual = nn.Dropout(dropout)

        # Lower-triangular ones; zeros mark the positions to mask out.
        ltm = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(
        self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(B, T, n_embed)`` with ``T <= block_size``.
            cos: Cosine table with positions on the second-to-last axis and
                the rotated feature count on the last axis.
            sin: Sine table with the same layout as ``cos``.

        Returns:
            Tensor of shape ``(B, T, n_embed)``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T, C = x.size()
        assert C == self.n_embed
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        head_size = C // self.n_head

        qkv = self.c_attn(x)
        # Use the instance width rather than a module-level global.
        q, k, v = qkv.split(self.n_embed, dim=2)

        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        # The rotated width is the last-axis size of the cache; the tables
        # are cropped so sequences shorter than the cache still broadcast.
        rope_n_elem = cos.size(-1)
        cos = cos[..., :T, :]
        sin = sin[..., :T, :]

        q_roped = apply_rope_mhall1(q[..., :rope_n_elem], cos, sin)
        k_roped = apply_rope_mhall1(k[..., :rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)
        k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1)

        attn = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
        # Crop the mask to the actual sequence length before applying it.
        attn = attn.masked_fill(
            self.causal_mask[:, :, :T, :T] == 0, float("-inf")
        )
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout_attn(attn)

        y = attn @ v
        # Merge the heads back into a single embedding axis.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.dropout_residual(self.c_proj(y))

        return y

In [130]:
cos, sin = build_rope_cache_mhall1(
    block_size=block_size, rope_n_elem=rope_n_elem, device=device, dtype=dtype
)

# nn.Linear parameters are created in the default float dtype while x is
# float16. Linear layers reject mixed dtypes on the CPU fallback path, so the
# module is cast to the input's dtype and device before it is called.
y = MultiHeadAttentionLL1(
    block_size=block_size, n_embed=n_embed, n_head=n_head
).to(device=device, dtype=dtype)(x, cos, sin)
y.shape

torch.Size([4, 8, 48])

## LL3

In [134]:
def build_rope_cache_mhall3(
    block_size: int,
    rope_n_elem: int,
    base: int = 10000,
    condense_ratio: int = 1,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cosine and sine tables for rotary position embeddings.

    Args:
        block_size: Number of positions (rows) to generate.
        rope_n_elem: Number of features that will be rotated.
        base: Base of the geometric frequency progression.
        condense_ratio: Divisor applied to the position indices.
        device: Device on which the tables are created.
        dtype: dtype of the returned tables.

    Returns:
        A ``(cos, sin)`` tuple. Each table has ``block_size`` rows and two
        side-by-side copies of the per-frequency angles as columns.
    """
    # NOTE: the ``dtype`` annotation must be ``torch.dtype``. Annotating with
    # a bare ``dtype`` name would be evaluated at definition time and bind to
    # whatever module-level variable happens to carry that name.
    #
    # float16/bfloat16 cannot represent large position indices exactly, so
    # the angles are computed in float32 and cast to ``dtype`` at the end.
    compute_dtype = (
        torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    )
    theta = 1 / base ** (
        torch.arange(0, rope_n_elem, 2, device=device, dtype=compute_dtype)
        / rope_n_elem
    )
    seq_idx = (
        torch.arange(block_size, device=device, dtype=compute_dtype) / condense_ratio
    )
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)
    if compute_dtype is not dtype:
        cos, sin = cos.to(dtype), sin.to(dtype)
    return cos, sin


def apply_rope_mhall3(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Rotate the features in the last dimension of ``x`` by the cached angles.

    Returns ``x * cos + cat(-second, first) * sin`` in the type of ``x``,
    where ``first`` and ``second`` are the two halves of the last dimension.
    """
    cut = x.shape[-1] // 2
    first, second = x[..., :cut], x[..., cut:]
    # Second half, negated, moves to the front; first half moves to the back.
    partner = torch.cat((-second, first), dim=-1)
    out = (x * cos) + (partner * sin)
    return out.type_as(x)


class MultiHeadAttentionLL3(nn.Module):
    """Causal multi-head self-attention with rotary embeddings on q and k.

    Queries, keys and values come from one fused linear projection. The
    leading features of every query/key head are rotated with
    ``apply_rope_mhall3`` before scaled dot-product attention with a
    lower-triangular mask.

    Args:
        block_size: Maximum sequence length; sets the size of the causal mask.
        n_embed: Embedding width of the input and output.
        n_head: Number of attention heads; must divide ``n_embed``.
        dropout: Probability used by both the attention and output dropout.
        bias: Whether the linear projections carry a bias term.
    """

    def __init__(
        self, block_size: int, n_embed: int, n_head: int, dropout: float, bias: bool
    ):
        super().__init__()
        assert n_embed % n_head == 0
        self.n_embed = n_embed
        self.n_head = n_head
        self.block_size = block_size
        self.dropout = dropout
        self.bias = bias

        # One projection produces q, k and v side by side.
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=bias)
        self.c_proj = nn.Linear(n_embed, n_embed, bias=bias)
        self.dropout_attn = nn.Dropout(dropout)
        self.dropout_residual = nn.Dropout(dropout)

        # Lower-triangular ones; zeros mark the positions to mask out.
        ltm = torch.tril(torch.ones(block_size, block_size)).view(
            1, 1, block_size, block_size
        )
        self.register_buffer("causal_mask", ltm)

    def forward(
        self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Run attention over ``x``.

        Args:
            x: Input of shape ``(B, T, n_embed)`` with ``T <= block_size``.
            cos: Cosine table with positions on the second-to-last axis and
                the rotated feature count on the last axis.
            sin: Sine table with the same layout as ``cos``.

        Returns:
            Tensor of shape ``(B, T, n_embed)``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T, C = x.size()
        assert C == self.n_embed
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        head_size = C // self.n_head

        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)

        # The rotated width is the last-axis size of the cache, not a
        # module-level global; the tables are cropped so sequences shorter
        # than the cache still broadcast.
        rope_n_elem = cos.size(-1)
        cos = cos[..., :T, :]
        sin = sin[..., :T, :]

        q_roped = apply_rope_mhall3(q[..., :rope_n_elem], cos, sin)
        k_roped = apply_rope_mhall3(k[..., :rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)
        k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1)

        attn = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
        # Crop the mask to the actual sequence length before applying it.
        attn = attn.masked_fill(
            self.causal_mask[:, :, :T, :T] == 0, float("-inf")
        )
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout_attn(attn)

        y = attn @ v
        # Merge the heads back into a single embedding axis.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.dropout_residual(self.c_proj(y))

        return y

In [137]:
cos, sin = build_rope_cache_mhall3(block_size=block_size, rope_n_elem=rope_n_elem)

# nn.Linear parameters are created in the default float dtype while x is
# float16. Linear layers reject mixed dtypes on the CPU fallback path, so the
# module is cast to the input's dtype and device before it is called.
y = MultiHeadAttentionLL3(
    block_size=block_size, n_embed=n_embed, n_head=n_head, dropout=0.2, bias=False
).to(device=device, dtype=dtype)(x, cos, sin)

y.shape

torch.Size([4, 8, 48])