In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class RoPE(nn.Module):
    def __init__(self, base, max_len, head_dim):
        super(RoPE, self).__init__()
        assert head_dim % 2 == 0
        self.thetas = 1 / (torch.pow(base, torch.arange(0, head_dim, 2) / head_dim)).float()
        self.seqs = torch.arange(max_len)
        self.matrix = torch.outer(self.seqs, self.thetas)
        self.complex = torch.polar(torch.ones_like(self.matrix), self.matrix).unsqueeze(0).unsqueeze(2)
        self.register_buffer('freqs_complex', self.complex)
    def forward(self, x):
        B, T, n_heads, head_dim = x.shape
        complex_x = torch.view_as_complex(x.float().reshape(B, T, n_heads, head_dim // 2, 2))
        rotated_x = complex_x * self.freqs_complex[:, :T, :, :]
        rotated_x = torch.view_as_real(rotated_x).reshape(*x.shape).type_as(x)
        return rotated_x

In [4]:
class RMSNorm(nn.Module):
    def __init__(self, n_embed, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.gammas = nn.Parameter(torch.ones(n_embed))
        self.eps = eps
    def forward(self, x):
        rms = torch.pow(torch.mean(torch.pow(x, 2), dim=2), 1/2).unsqueeze(-1)
        x = x / (rms + self.eps)
        return x * self.gammas

In [5]:
class GeGlu(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(GeGlu, self).__init__()
        self.up_proj = nn.Linear(in_dim, out_dim)
        self.down_proj = nn.Linear(out_dim // 2, in_dim)
        nn.init.xavier_uniform_(self.up_proj.weight)
        nn.init.xavier_uniform_(self.down_proj.weight)
    def forward(self, x):
        x, gate = self.up_proj(x).chunk(2, dim=-1)
        return self.down_proj(x * F.gelu(gate))

In [6]:
class GlobalAttention(nn.Module):
    def __init__(self, embed_dim : int, n_heads : int, dropout : float, rope):
        super(GlobalAttention, self).__init__()
        self.rope = rope
        self.head_size = embed_dim // n_heads
        self.n_heads = n_heads
        self.up_proj = nn.Linear(embed_dim, self.head_size * n_heads * 3)
        self.down_proj = nn.Linear(self.head_size * n_heads, embed_dim)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.up_proj.weight)
        nn.init.xavier_uniform_(self.down_proj.weight)
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.up_proj(x).split(self.n_heads * self.head_size, dim=-1)
        q = self.rope(q.view(B, T, self.n_heads, self.head_size))
        k = self.rope(k.view(B, T, self.n_heads, self.head_size))
        v = v.view(B, T, self.n_heads, self.head_size)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.reshape(B, T, self.n_heads * self.head_size)
        return self.dropout(self.down_proj(attn))

In [7]:
def cumsum(lst):
    cum_sum = [0]
    total = 0
    for num in lst:
        total += num
        cum_sum.append(total)
    return cum_sum

def unpadding(inputs, attention_masks, device):
    assert inputs.dim() == 2
    inputs = inputs.flatten()
    indices = attention_masks.bool()
    cu_lens = [mask.sum().item() for mask in indices]
    cu_lens = cumsum(cu_lens)
    indices = indices.flatten()
    return inputs[indices].to(device), cu_lens
    
def padding(inputs, attention_mask, batch_size, seq_len, device):
    assert inputs.dim() == 1
    outputs = torch.zeros(batch_size * seq_len, dtype=inputs.dtype, device=inputs.device)
    outputs[attention_mask.bool().flatten()] = inputs
    outputs = outputs.reshape(batch_size, seq_len)
    return outputs

In [8]:
x = torch.randint(high=100, size=(4, 32)).to(device)
a = torch.randint(high=2, size=(4, 32)).to(device)
out, lens = unpadding(x, a, device)
padded = padding(out, a, 4, 32, device)

In [11]:
q1, q2, q3 = torch.rand(6, 32), torch.rand(8, 32), torch.rand(5, 32)
k1, k2, k3 = torch.rand(6, 32), torch.rand(8, 32), torch.rand(5, 32)
v1, v2, v3 = torch.rand(6, 32), torch.rand(8, 32), torch.rand(5, 32)

In [13]:
q = torch.nested.nested_tensor([q1, q2, q3], layout=torch.jagged).to(device)
k = torch.nested.nested_tensor([k1, k2, k3], layout=torch.jagged).to(device)
v = torch.nested.nested_tensor([v1, v2, v3], layout=torch.jagged).to(device)

In [15]:
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

W0505 03:36:00.548000 29704 site-packages\torch\nested\_internal\sdpa.py:294] Memory efficient kernel not used because:
W0505 03:36:00.550000 29704 site-packages\torch\nested\_internal\sdpa.py:122] Fused kernels do not support ragged num_head_dims, query has a ragged num_heads.
W0505 03:36:00.552000 29704 site-packages\torch\nested\_internal\sdpa.py:297] Flash attention kernel not used because:
W0505 03:36:00.553000 29704 site-packages\torch\nested\_internal\sdpa.py:122] Fused kernels do not support ragged num_head_dims, query has a ragged num_heads.
W0505 03:36:00.554000 29704 site-packages\torch\nested\_internal\sdpa.py:300] Math attention kernel not used because:
W0505 03:36:00.558000 29704 site-packages\torch\nested\_internal\sdpa.py:252] If inputs are nested tensors they must be contiguous after transposing.
  out = F.scaled_dot_product_attention(q, k, v)
  out = F.scaled_dot_product_attention(q, k, v)


RuntimeError: No viable backend for scaled_dot_product_attention was found.