Rotary positional embedding

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension of ``x`` is split into two chunks ``(a, b)`` and the
    result is the concatenation ``(-b, a)`` along that same dimension.
    """
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the same cos/sin rotation to a query tensor and a key tensor.

    Each input ``t`` is mapped to ``t * cos + rotate_half(t) * sin``, so
    ``cos`` and ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        A ``(rotated_q, rotated_k)`` tuple.
    """

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    rotated_q = _rotate(q)
    rotated_k = _rotate(k)
    return rotated_q, rotated_k

class RotaryPositionalEmbedding(nn.Module):
    """Precompute rotary cos/sin tables and apply them to query/key tensors.

    Args:
        dim: Size of the last dimension of the tensors to rotate. Must be
            even, because the rotation pairs lane ``i`` with lane
            ``i + dim // 2``.
        max_seq_len: Number of positions for which tables are precomputed.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, max_seq_len):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"dim must be even for rotary embeddings, got {dim}")
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)  # (max_seq_len, dim // 2)
        # rotate_half pairs lane i with lane i + dim // 2, so both lanes of a
        # pair must share one angle: repeat the dim // 2 frequencies to span
        # the full width. Without this the tables cannot broadcast against a
        # dim-wide q/k and the transform is not a rotation.
        angles = torch.cat((freqs, freqs), dim=-1)  # (max_seq_len, dim)
        self.register_buffer('cos', angles.cos())
        self.register_buffer('sin', angles.sin())

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` by their positions.

        The sequence axis is taken to be the second-to-last dimension of
        ``q`` (i.e. ``(..., seq_len, dim)``), which is the layout the
        ``(max_seq_len, dim)`` tables broadcast against. ``k`` is rotated with
        the same table slice, so it is expected to have the same sequence
        length as ``q``.

        Returns:
            A ``(rotated_q, rotated_k)`` tuple.

        Raises:
            ValueError: If the sequence length exceeds ``max_seq_len``.
        """
        seq_len = q.shape[-2]
        max_seq_len = self.cos.shape[0]
        if seq_len > max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        # Slice so sequences shorter than max_seq_len are supported.
        return apply_rotary_pos_emb(q, k, self.cos[:seq_len], self.sin[:seq_len])


Grouped-query attention

In [None]:
class GroupQueryAttention(nn.Module):
    """Multi-head self-attention whose query heads share per-group projections.

    The input embedding is split into ``num_heads`` slices of ``head_dim``
    features. Head ``h`` is projected by ``query_projections[h % num_groups]``
    (a ``head_dim -> head_dim`` linear layer), while keys and values come from
    full ``embed_size -> embed_size`` projections split into heads.

    NOTE(review): this shares *query* projections across heads; the commonly
    published grouped-query attention shares key/value heads instead. The
    parameter layout here is kept as the original author defined it.

    Args:
        embed_size: Size of the input/output feature dimension.
        num_heads: Number of attention heads; must divide ``embed_size``.
        num_groups: Number of query projection groups; must divide
            ``num_heads``.

    Raises:
        ValueError: If the divisibility constraints above are violated.
    """

    def __init__(self, embed_size, num_heads, num_groups):
        super().__init__()
        if num_heads <= 0 or embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
            )
        if num_groups <= 0 or num_heads % num_groups != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by num_groups ({num_groups})"
            )
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_size // num_heads
        self.scale = self.head_dim ** -0.5

        self.query_projections = nn.ModuleList([
            nn.Linear(self.head_dim, self.head_dim) for _ in range(num_groups)
        ])
        self.key_projection = nn.Linear(embed_size, embed_size)
        self.value_projection = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """Run self-attention over the sequence axis of ``x``.

        Args:
            x: Tensor of shape ``(N, seq_length, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, num_heads, seq_length, seq_length)``; entries equal to 0
                are excluded from attention. A ``(N, seq_length)`` padding
                mask should be passed as ``mask[:, None, None, :]``.

        Returns:
            Tensor of shape ``(N, seq_length, embed_size)``.
        """
        N, seq_length, _ = x.shape
        heads_per_group = self.num_heads // self.num_groups

        # View heads as (heads_per_group, num_groups) so that index g on the
        # group axis selects heads g, g + num_groups, ... (same assignment as
        # the strided slice ``[:, :, g::num_groups]``).
        grouped = x.reshape(
            N, seq_length, heads_per_group, self.num_groups, self.head_dim
        )
        queries = torch.stack(
            [
                projection(grouped[:, :, :, group])
                for group, projection in enumerate(self.query_projections)
            ],
            dim=3,
        ).reshape(N, seq_length, self.num_heads, self.head_dim)

        keys = self.key_projection(x).reshape(
            N, seq_length, self.num_heads, self.head_dim
        )
        values = self.value_projection(x).reshape(
            N, seq_length, self.num_heads, self.head_dim
        )

        # Move heads ahead of the sequence axis so the matmul contracts over
        # head_dim and produces (seq, seq) scores per head, not (head, head).
        queries = queries.transpose(1, 2)  # (N, H, S, D)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        attention_scores = (queries @ keys.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Use the dtype's own minimum: a literal like -1e20 overflows in
            # half precision.
            attention_scores = attention_scores.masked_fill(
                mask == 0, torch.finfo(attention_scores.dtype).min
            )

        attention = torch.softmax(attention_scores, dim=-1)
        out = attention @ values  # (N, H, S, D)
        out = out.transpose(1, 2).reshape(N, seq_length, -1)
        return self.fc_out(out)


Sliding window attention

In [None]:
class SlidingWindowAttention(nn.Module):
    """Causal multi-head self-attention restricted to a trailing window.

    Query position ``i`` attends only to key positions
    ``max(0, i - window_size + 1) .. i``. The query/key/value projections are
    ``head_dim -> head_dim`` linear layers shared by every head.

    Note: the implementation builds full ``(seq, seq)`` scores and masks
    everything outside the band, so memory grows quadratically with sequence
    length.

    Args:
        embed_size: Size of the input/output feature dimension.
        num_heads: Number of attention heads; must divide ``embed_size``.
        window_size: Number of positions (including the current one) each
            query may attend to; must be at least 1.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``num_heads`` or
            ``window_size`` is smaller than 1.
    """

    def __init__(self, embed_size, num_heads, window_size):
        super().__init__()
        if num_heads <= 0 or embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
            )
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = embed_size // num_heads

        self.query = nn.Linear(self.head_dim, self.head_dim)
        self.key = nn.Linear(self.head_dim, self.head_dim)
        self.value = nn.Linear(self.head_dim, self.head_dim)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """Run windowed causal self-attention over ``x``.

        Args:
            x: Tensor of shape ``(N, seq_length, embed_size)``.
            mask: Optional ``(N, seq_length)`` tensor; positions where it
                equals 0 are never attended to as keys.

        Returns:
            Tensor of shape ``(N, seq_length, embed_size)``.
        """
        N, seq_length, _ = x.shape
        heads = x.reshape(N, seq_length, self.num_heads, self.head_dim)

        # (N, S, H, D) -> (N, H, S, D) so matmul contracts over head_dim and
        # yields per-head (seq, seq) scores.
        queries = self.query(heads).transpose(1, 2)
        keys = self.key(heads).transpose(1, 2)
        values = self.value(heads).transpose(1, 2)

        attention_scores = (queries @ keys.transpose(-2, -1)) * self.head_dim ** -0.5

        # Band mask: key j is visible to query i iff 0 <= i - j < window_size.
        positions = torch.arange(seq_length, device=x.device)
        distance = positions[:, None] - positions[None, :]
        visible = (distance >= 0) & (distance < self.window_size)  # (S, S)

        if mask is not None:
            # Broadcast the per-token mask along the key axis.
            visible = visible & (mask[:, None, None, :] != 0)  # (N, 1, S, S)

        # Fill with the dtype minimum (not -inf) so a row with no visible key
        # yields a finite uniform distribution instead of NaN.
        attention_scores = attention_scores.masked_fill(
            ~visible, torch.finfo(attention_scores.dtype).min
        )

        attention = torch.softmax(attention_scores, dim=-1)
        out = attention @ values  # (N, H, S, D)
        out = out.transpose(1, 2).reshape(N, seq_length, -1)
        return self.fc_out(out)
