In [2]:
import numpy as np
import timeit
import torch
import torch.nn.functional as F

from torch import nn

torch.manual_seed(1)
np.random.seed(1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class MultiHeadAttention(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout_p (float, optional): Dropout probability. Default: 0.0
    """
    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
                 nheads: int, dropout_p: float = 0.0):
        super().__init__()
        self.nheads = nheads
        self.dropout_p = dropout_p
        self.query_proj = nn.Linear(E_q, E_total)
        self.key_proj = nn.Linear(E_k, E_total)
        self.value_proj = nn.Linear(E_v, E_total)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (N, L_t, E_q)
            key (torch.Tensor): key of shape (N, L_s, E_k)
            value (torch.Tensor): value of shape (N, L_s, E_v)

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        # TODO: demonstrate packed projection
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=True)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

In [14]:
class MultiHeadAttention2(nn.Module):
    """
    Computes multi-head attention. Supports nested or padded tensors.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout_p (float, optional): Dropout probability. Default: 0.0
    """
    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,
                 nheads: int, dropout_p: float = 0.0):
        super().__init__()
        self.nheads = nheads
        self.dropout_p = dropout_p
        self.query_proj = nn.Linear(E_q, E_total)
        self.key_proj = nn.Linear(E_k, E_total)
        self.value_proj = nn.Linear(E_v, E_total)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, start_inds: torch.Tensor, end_inds: torch.Tensor) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for SDPA
            3. Run SDPA
            4. Apply output projection

        Args:
            query (torch.Tensor): query of shape (N * ?, E_q)
            key (torch.Tensor): key of shape (N * ?, E_k)
            value (torch.Tensor): value of shape (N * ?, E_v)
            slice_inds (torch.Tensor): value of shape (N, 2)

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        # TODO: demonstrate packed projection
        q_f = lambda s, e: self.query_proj(query[s:e, :])
        q_f = torch.vmap(q_f)
        q_f(start_inds, end_inds)
        

        return
        
        # query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
            query, key, value, dropout_p=dropout_p, is_causal=True)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

In [5]:
N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8

In [6]:
dropout_p = 0.0

In [7]:
def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)

In [8]:
def gen_batch(N, E_q, E_k, E_v, device):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    query = torch.cat([
        torch.randn(l.item(), E_q, device=device)
        for l in sentence_lengths
    ])

    key = torch.cat([
        torch.randn(s.item(), E_k, device=device)
        for s in sentence_lengths
    ])

    value = torch.cat([
        torch.randn(s.item(), E_v, device=device)
        for s in sentence_lengths
    ])

    inds = [0]
    for s in sentence_lengths:
        inds.append(inds[-1] + s.item())

    slice_inds = torch.stack([torch.tensor(inds[:-1]), torch.tensor(inds[1:])], 1)
    
    return query, key, value, slice_inds, sentence_lengths

query, key, value, slice_inds, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)

In [11]:
def jagged_to_padded(jt, padding_val):
    # TODO: do jagged -> padded directly when this is supported
    return torch.nested.to_padded_tensor(
        torch.nested.nested_tensor(list(jt.unbind())),
        padding_val)

padded_query, padded_key, padded_value = (
    jagged_to_padded(t, 0.0) for t in (query, key, value)
)

In [15]:
mha = MultiHeadAttention2(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)

In [16]:
mha(query, key, value, slice_inds[:, 0], slice_inds[:, 1])

RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. If error is occurring somewhere inside PyTorch internals, please file a bug report.

In [18]:
def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin)

output_nested, time_nested = benchmark(mha, query, key, value)
output_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    output_padded[i, entry_length:] = 0.0

print("=== without torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())
print("nested tensor multi-head attention takes", time_nested, "seconds")
print("padded tensor multi-head attention takes", time_padded, "seconds")

# warm up compile first...
compiled_mha = torch.compile(mha)
compiled_mha(query, key, value)
# ...now benchmark
compiled_output_nested, compiled_time_nested = benchmark(
    compiled_mha, query, key, value)

# warm up compile first...
compiled_mha(padded_query, padded_key, padded_value)
# ...now benchmark
compiled_output_padded, compiled_time_padded = benchmark(
    compiled_mha, padded_query, padded_key, padded_value)

# padding-specific step: remove output projection bias from padded entries for fair comparison
for i, entry_length in enumerate(sentence_lengths):
    compiled_output_padded[i, entry_length:] = 0.0

print("=== with torch.compile ===")
print("nested and padded calculations differ by", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())
print("nested tensor multi-head attention takes", compiled_time_nested, "seconds")
print("padded tensor multi-head attention takes", compiled_time_padded, "seconds")

torch.Size([512, j1, 512])
torch.Size([512, j1, 512])
torch.Size([512, 128, 512])
torch.Size([512, 128, 512])
=== without torch.compile ===
nested and padded calculations differ by 0.0
nested tensor multi-head attention takes 0.01481871772557497 seconds
padded tensor multi-head attention takes 0.005708473734557629 seconds
torch.Size([512, j1, 512])
torch.Size([512, j1, 512])
torch.Size([512, j1, 512])
torch.Size([512, j1, 512])
torch.Size([512, 128, 512])
torch.Size([512, 128, 512])
torch.Size([512, 128, 512])
torch.Size([512, 128, 512])
=== with torch.compile ===
nested and padded calculations differ by 0.0
nested tensor multi-head attention takes 0.002335646189749241 seconds
padded tensor multi-head attention takes 0.005724301561713219 seconds


In [22]:
batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
x, y = torch.randn(3, 5), torch.randn(3, 5)
batched_dot(x, y)

tensor([-1.5005,  3.0300, -5.8522])