# Profiling Different Attention Implementations on Intel Gaudi HPU

## Overview

This notebook demonstrates how to profile and compare different attention mechanism implementations on Intel Gaudi AI accelerators. Attention is the core component of Transformer models, and its performance is critical for LLM inference.

We compare three attention implementations:

1. **Fused SDPA (Scaled Dot-Product Attention)**: Uses Habana's optimized `FusedSDPA` kernel for maximum performance
2. **Initial PagedAttention**: Loop-based attention that fetches KV-cache blocks with `index_select`
3. **Flat PagedAttention**: Paged attention over a flattened block layout, which removes internal fragmentation within a batch

## Key Concepts

### PagedAttention
PagedAttention is a memory-efficient attention mechanism that stores the KV cache in non-contiguous memory blocks, similar to virtual memory paging. This enables:
- Efficient memory utilization for variable-length sequences
- Better batching of requests with different context lengths
- Reduced memory fragmentation
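
To make the paging analogy concrete, here is a minimal sketch in plain Python (no torch, so it runs anywhere). The helpers `locate` and `blocks_needed` and the example block table are hypothetical and are not part of the notebook's code; the block size is shrunk to 4 for readability, whereas the notebook uses 128.

```python
BLOCK_SIZE = 4  # tokens per physical block (illustrative; the notebook uses 128)

def locate(block_table, token_pos, block_size=BLOCK_SIZE):
    """Map a logical token position to (physical block id, offset inside the block)."""
    logical_block, offset = divmod(token_pos, block_size)
    return block_table[logical_block], offset

def blocks_needed(context_len, block_size=BLOCK_SIZE):
    """Number of blocks a sequence occupies (ceiling division)."""
    return -(-context_len // block_size)

# Hypothetical block table: the sequence's blocks sit at scattered cache slots.
block_table = [7, 2, 9]
print(locate(block_table, 0))  # (7, 0)
print(locate(block_table, 5))  # (2, 1)
print(blocks_needed(9))        # 3
```

Because a sequence only reserves whole blocks as it grows, at most one partially filled block per sequence is wasted.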


## Setup: Import Libraries and Define Helper Functions

This section sets up:
- **Profiler configuration**: Wait, warmup, and active step counts
- **Attention parameters**: Number of heads, head dimension
- **Helper functions**: Synchronization, profiler setup, benchmarking utilities
- **Input data preparation**: Functions that prepare the inputs for each attention implementation
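
As a rough illustration of what the wait, warmup, and active counts mean, the sketch below labels each step of a single profiler cycle. It is plain Python that mirrors the documented behaviour of `torch.profiler.schedule` with `repeat=1`; the function `phase` is a hypothetical helper, not a PyTorch API.

```python
def phase(step, wait, warmup, active):
    """Label one step of a single wait/warmup/active cycle (repeat=1)."""
    if step < wait:
        return "wait"      # profiler idle
    if step < wait + warmup:
        return "warmup"    # profiler tracing, results discarded
    if step < wait + warmup + active:
        return "active"    # profiler recording
    return "done"          # cycle finished, nothing more is recorded

# Values used in this notebook: wait 0, warmup 1, active 3.
phases = [phase(s, wait=0, warmup=1, active=3) for s in range(4)]
print(phases)  # ['warmup', 'active', 'active', 'active']
```

With these values, four steps per scenario cover exactly one cycle: one discarded warmup step followed by three recorded steps.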


In [None]:
from collections import namedtuple
import os
import time

os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['HABANA_PROFILE'] = 'profile_api'

import habana_frameworks.torch.hpu as ht
import habana_frameworks.torch as htorch
from habana_frameworks.torch.hpex.kernels import FusedSDPA
import torch

PROF_WAIT = 0
PROF_WARMUP = 1
PROF_ACTIVE = 3
STEPS = PROF_WAIT + PROF_WARMUP + PROF_ACTIVE  # one full profiler cycle
DEVICE = os.environ.get('DEVICE', 'hpu')

MAX_BLOCKS = 1024

Q_NUM_HEADS = 32
K_NUM_HEADS = 32

HEAD_DIM = 128
BLOCK_SIZE = 128
PA_SPLIT_VALUE = 1

INDEX_DTYPE = torch.int32 if DEVICE == 'hpu' else torch.int64

Scenario = namedtuple("Scenario", ["name", "fn", "data_fn", "params"])

def sync():
    if DEVICE == 'hpu':
        htorch.core.mark_step()
        htorch.hpu.synchronize()

def pt_profiler(schedule, attn_name):
    activities = [torch.profiler.ProfilerActivity.CPU]
    activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == 'hpu' else [])

    profiler = torch.profiler.profile(
        schedule=schedule,
        activities=activities,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(f'./{attn_name}', use_gzip=True),
        record_shapes=False,
        with_stack=True)
    return profiler

def setup_profiler(method, attn_name):
    schedule = torch.profiler.schedule(wait=PROF_WAIT, warmup=PROF_WARMUP, active=PROF_ACTIVE, repeat=1)
    return method(schedule, attn_name)

def params_str(params):
    sep = ' '
    return sep.join(f'{k}:{v}' for k, v in params.items() if k != 'context_lens')

def define(*args, **params):
    return Scenario(*args, params=params)

def run(step, scenario, data):
    name = f'{scenario.name} | {params_str(scenario.params)}'
    args, kwargs = data
    start_ts = time.perf_counter()
    res = scenario.fn(*args, **kwargs)
    sync()
    diff_ts = 1000 * (time.perf_counter() - start_ts)
    print(f'{step} | {name} | {diff_ts:.3f}ms')
    return diff_ts

def benchmark(scenarios):
    data = [s.data_fn(**s.params) for s in scenarios]
    sync()
    for s, d in zip(scenarios, data):
        prof = setup_profiler(pt_profiler, s.name)
        prof.start()
        for i in range(STEPS):
            sync()
            run(i, s, d)
            prof.step()
        prof.stop()

def round_up(x, k):
    return (x + k - 1) // k * k

def fetch_from_cache(cache, blocks, permutations):
    if permutations is not None:
        return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
    else:
        return [cache.index_select(0, blocks[:, i]) for i in range(blocks.size(1))]

def flat_pa_data(batch_size, num_blocks):
    cache_shape = (MAX_BLOCKS, BLOCK_SIZE, K_NUM_HEADS, HEAD_DIM)
    params = {
        'query': torch.empty(batch_size, Q_NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE),
        'key_cache': torch.empty(*cache_shape, dtype=torch.bfloat16, device=DEVICE),
        'value_cache': torch.empty(*cache_shape, dtype=torch.bfloat16, device=DEVICE),
        'scale': 1.0,
        'block_list': torch.zeros(num_blocks, dtype=INDEX_DTYPE, device=DEVICE),
        'block_mapping': torch.zeros(num_blocks, batch_size, dtype=torch.bfloat16, device=DEVICE),
        'block_bias': torch.zeros(num_blocks, BLOCK_SIZE, dtype=torch.bool, device=DEVICE),
    }
    return (), params

def pa_v0_2_data(batch_size, seq_len, context_lens=None):
    num_blocks = round_up(seq_len, BLOCK_SIZE) // BLOCK_SIZE
    cache_shape = (MAX_BLOCKS, K_NUM_HEADS, HEAD_DIM, BLOCK_SIZE)
    params = {
        'query': torch.empty(batch_size, Q_NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE),
        'key_cache': torch.empty(*cache_shape, dtype=torch.bfloat16, device=DEVICE),
        'value_cache': torch.empty(*cache_shape, dtype=torch.bfloat16, device=DEVICE),
        'head_mapping': None,
        'scale': 1.0,
        'block_tables': torch.zeros(batch_size, num_blocks, dtype=INDEX_DTYPE, device=DEVICE),
        'context_lens': torch.ones(batch_size, dtype=INDEX_DTYPE, device=DEVICE) * seq_len if context_lens is None else context_lens,
        'block_size': BLOCK_SIZE,
        'alibi_slopes': None,
    }
    return (), params

def fetch_from_cache_pa_v0_3(cache, blocks):
    return [cache.index_select(0, blocks[:, i]) for i in range(blocks.size(1))]

---

## Attention Implementation 1: Fused SDPA

### Fused Scaled Dot-Product Attention

Habana's `FusedSDPA` is a highly optimized kernel that fuses multiple attention operations into a single efficient kernel:

```
Attention(Q, K, V) = softmax(Q @ K^T / √d_k) @ V
```
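
For reference, the formula can be written out unfused for a single query vector. This is a plain-Python sketch meant only to show the arithmetic; the helper names `softmax` and `sdpa` are illustrative and unrelated to the `FusedSDPA` kernel, which performs these steps in one fused pass on the device.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def sdpa(q, keys, values):
    """q: [d], keys: [n][d], values: [n][d_v] -> attention output [d_v]."""
    d_k = len(q)
    # Q @ K^T / sqrt(d_k): one score per key
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
    weights = softmax(scores)
    # Weighted sum of the value rows
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

q = [1.0, 0.0]
keys = [[1.0, 0.0], [1.0, 0.0]]       # identical keys -> equal weights
values = [[2.0, 0.0], [4.0, 2.0]]
print(sdpa(q, keys, values))  # [3.0, 1.0]
```
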

**Advantages:**
- Single fused kernel reduces memory bandwidth requirements
- Optimized for Intel Gaudi's Matrix Multiplication Engine (MME)
- Typically the fastest option for prefill, where all prompt tokens are processed at once

In [None]:
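def attn_baseline(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: None,
) -> torch.Tensor:
    # `scale` and `attn_mask` are accepted so the data dict can be passed as
    # keyword arguments, but they are not forwarded: the kernel is called with
    # no mask, dropout 0.0, and is_causal=False.
    with ht.sdp_kernel(enable_recompute=False):
        sdpa_out = FusedSDPA.apply(query, key, value, None, 0.0, False)
    return sdpa_out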

def attn_baseline_data(batch_size, seq_len):
    params = {
        'query': torch.ones(batch_size, Q_NUM_HEADS, 1, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE),
        'key': torch.ones(batch_size, K_NUM_HEADS, seq_len, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE),
        'value': torch.ones(batch_size, K_NUM_HEADS, seq_len, HEAD_DIM, dtype=torch.bfloat16, device=DEVICE),
        'scale': 1.0,
        'attn_mask': None,
    }
    return (), params

---

## Attention Implementation 2: Initial PagedAttention

### Initial PagedAttention

**Algorithm:**
1. Create attention mask based on `context_lens` for variable-length sequences
2. Fetch Key blocks from cache using `block_tables`
3. Compute attention scores: `Q @ K^T` for each block
4. Apply masking and softmax across all blocks
5. Fetch Value blocks and compute weighted sum

**Key Features:**
- `index_select`-based cache fetching
- Loop-based attention computation
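
The five steps above can be sketched for a single head and a single sequence in plain Python. This is an illustration under simplifying assumptions, not the notebook's implementation: the function `paged_attention`, the tiny block size, and the toy cache contents are all hypothetical, and it masks with `-inf` where the notebook's code uses the dtype's minimum value.

```python
import math

BLOCK_SIZE = 2  # illustrative; the notebook uses 128

def paged_attention(q, key_cache, value_cache, block_table, context_len, scale):
    # Steps 2-3: fetch each key block via the block table and score it.
    scores = []
    for block_id in block_table:
        for k in key_cache[block_id]:
            scores.append(scale * sum(qi * ki for qi, ki in zip(q, k)))
    # Steps 1 and 4: mask positions at or beyond context_len, softmax over all blocks.
    scores = [s if pos < context_len else -math.inf for pos, s in enumerate(scores)]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Step 5: fetch value blocks and accumulate one partial sum per block.
    out = [0.0] * len(value_cache[0][0])
    for i, block_id in enumerate(block_table):
        block_w = weights[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
        for w, v in zip(block_w, value_cache[block_id]):
            out = [o + w * vj for o, vj in zip(out, v)]
    return out

# Toy cache: 3 physical blocks x 2 tokens x head_dim 1.
key_cache = [[[0.0], [0.0]], [[0.0], [0.0]], [[0.0], [0.0]]]
value_cache = [[[9.0], [100.0]], [[-1.0], [-1.0]], [[3.0], [6.0]]]
block_table = [2, 0]  # logical block 0 lives in slot 2, logical block 1 in slot 0
out = paged_attention([1.0], key_cache, value_cache, block_table, context_len=3, scale=1.0)
print(out)  # the mean of 3, 6 and 9; the masked token (100.0) contributes nothing
```

Every sequence in the batch iterates over the same number of blocks, so shorter sequences still pay for padded, fully masked blocks.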


In [None]:
def pa_v0_3(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, attn_masks=None) -> torch.Tensor:
    if alibi_slopes is not None:
        raise NotImplementedError
    if attn_masks is not None:
        raise NotImplementedError

    seq_len = block_tables.size(1)
    batch_size, query_heads, _ = query.shape
    _, kv_heads, _, _ = key_cache.shape
    min_inf = torch.finfo(query.dtype).min
    mask = (torch.arange(0, seq_len * block_size, dtype=torch.int32, device=key_cache.device)
            .view(1, -1)
            .expand(batch_size, -1)
            .ge(context_lens.view(-1, 1))
            .view(batch_size, 1, 1, -1))
    query = query.unsqueeze(-2)
    keys = fetch_from_cache_pa_v0_3(key_cache, block_tables)
    if query_heads != kv_heads:
        query = query.unflatten(1, (kv_heads, -1))
        keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
        mask = mask.unsqueeze(2)

    attn_weights = [torch.matmul(query, k) for k in keys]
    attn_weights = (torch.cat(attn_weights, dim=-1)
                    .mul_(scale)
                    .masked_fill(mask, min_inf)
                    .softmax(dim=-1))

    values = fetch_from_cache_pa_v0_3(value_cache, block_tables)
    if PA_SPLIT_VALUE:
        attn_weights = attn_weights.split(block_size, dim=-1)
    else:
        values = [torch.cat(values, dim=-1)]
        attn_weights = [attn_weights]
    if query_heads != kv_heads:
        values = [v.unflatten(1, (kv_heads, 1)) for v in values]
    attn_weights = [torch.matmul(a, v.transpose(-1, -2)).squeeze(-2) for a, v in zip(attn_weights, values)]
    if query_heads != kv_heads:
        attn_weights = [a.flatten(1, 2) for a in attn_weights]
    attn_weights = sum(attn_weights)

    return attn_weights

---

## Attention Implementation 3: Flat PagedAttention

### Flat PagedAttention

This implementation uses a "flattened" approach where batch and block dimensions are merged, removing internal fragmentation within a batch.

**Key Operations:**
- `batch2block`: Transforms batch-indexed tensors to block-indexed using `block_mapping`
- `block2batch`: Inverse transformation from block to batch indexing
- `block_softmax`: Custom softmax that correctly normalizes across blocks belonging to the same batch item

**Algorithm:**
1. Transform query from batch space to block space
2. Fetch Key/Value from cache using `block_list`
3. Compute attention scores with block-level bias
4. Apply custom `block_softmax` for correct normalization
5. Compute weighted sum and transform back to batch space
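
To show how `block_mapping` ties blocks back to their sequences, here is a small plain-Python sketch of the batch-to-block transforms and the per-sequence normalization. It assumes a one-hot mapping of shape `[num_blocks, batch_size]`; the list-based `matmul` and `transpose` helpers and the example mapping are illustrative stand-ins for the tensor operations in the notebook.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(r) for r in zip(*m)]

# Two sequences: sequence 0 owns blocks 0 and 1, sequence 1 owns block 2.
block_mapping = [[1.0, 0.0],
                 [1.0, 0.0],
                 [0.0, 1.0]]  # shape [num_blocks, batch_size]

def batch2block(x):
    """[batch, d] -> [num_blocks, d]: copy each sequence's row to its blocks."""
    return matmul(block_mapping, x)

def block2batch(x):
    """[num_blocks, d] -> [batch, d]: sum the rows of blocks owned by each sequence."""
    return matmul(transpose(block_mapping), x)

def block_softmax(scores):
    """scores: [num_blocks][block_size] -> weights that sum to 1 per sequence."""
    exps = [[math.exp(s) for s in row] for row in scores]
    block_sums = [[sum(row)] for row in exps]
    seq_sums = batch2block(block2batch(block_sums))  # per-sequence total, per block
    return [[e / total[0] for e in row] for row, total in zip(exps, seq_sums)]

weights = block_softmax([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
print(weights)  # [[0.25, 0.25], [0.25, 0.25], [0.5, 0.5]]
```

Sequence 0 spreads its weight over four tokens in two blocks, while sequence 1 spreads it over two tokens in one block, so each sequence's weights sum to 1 even though no sequence is padded to a common block count.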


In [None]:
def batch2block(tensor, block_mapping):
    shape = tuple(tensor.shape)
    return (block_mapping @ tensor.view(shape[0], -1)).view(-1, *shape[1:])

def block2batch(tensor, block_mapping):
    shape = tuple(tensor.shape)
    return (block_mapping.t() @ tensor.view(shape[0], -1)).view(-1, *shape[1:])

def block_softmax(batch_size, attn, block_mapping):
    attn = attn.exp_()
    sums = attn.sum(dim=-1).unsqueeze(-1)
    sums = block2batch(sums, block_mapping)
    sums = batch2block(sums, block_mapping)
    attn.div_(sums)
    return attn

def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias, scale):
    batch_size = query.size(0)
    q_heads = query.size(1)
    kv_heads = key_cache.size(2)

    query = batch2block(scale * query, block_mapping).unsqueeze(-2)
    key = torch.index_select(key_cache, 0, block_list).transpose(1, 2)
    value = torch.index_select(value_cache, 0, block_list).transpose(1, 2)
    block_bias = block_bias.view(key.size(0), 1, 1, -1)

    if kv_heads != q_heads:
        block_bias = block_bias.unsqueeze(1)
        query = query.unflatten(1, (kv_heads, -1))
        key = key.unflatten(1, (kv_heads, 1))
        value = value.unflatten(1, (kv_heads, 1))
        key = key.transpose(3, 4)
    else:
        key = key.transpose(2, 3)

    attn = (query @ key) + block_bias
    attn = block_softmax(batch_size, attn, block_mapping)
    attn = attn @ value
    attn = block2batch(attn, block_mapping)
    attn = attn.squeeze(-2)
    if kv_heads != q_heads:
        attn = attn.flatten(1, 2)
    return attn

---

## Run Profiling Benchmark

### Benchmark Configuration

The benchmark runs all three attention implementations with the following parameters:
- **Batch size**: 32 sequences
- **Sequence length**: 1024 tokens

Each implementation will be profiled separately, and traces will be saved to individual directories.
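
The benchmark draws a random context length per sequence, so the two paged variants process different numbers of blocks. The sketch below compares the two counts in plain Python; the helper names and the example lengths are hypothetical, chosen only to show the arithmetic.

```python
BLOCK_SIZE = 128

def ceil_div(a, b):
    return -(-a // b)

def padded_blocks(batch_size, seq_len):
    """Blocks touched when every sequence is padded to seq_len (initial PagedAttention)."""
    return batch_size * ceil_div(seq_len, BLOCK_SIZE)

def flat_blocks(context_lens):
    """Blocks touched when each sequence only counts its own blocks (flat PagedAttention)."""
    return sum(ceil_div(n, BLOCK_SIZE) for n in context_lens)

context_lens = [1024, 300, 129, 1]  # hypothetical per-sequence lengths
print(padded_blocks(len(context_lens), 1024))  # 32
print(flat_blocks(context_lens))               # 8 + 3 + 2 + 1 = 14
```
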

In [None]:
batch_size = 32
seq_len = 1024
torch.manual_seed(42)
context_lens = torch.randint(low=1, high=seq_len+1, size=(batch_size,), dtype=INDEX_DTYPE, device=DEVICE)
num_blocks = int(torch.sum(torch.ceil(context_lens/BLOCK_SIZE)).item())
scenarios = [
    define('attn_baseline', attn_baseline, attn_baseline_data, batch_size=batch_size, seq_len=seq_len),
    define('initial_pa', pa_v0_3, pa_v0_2_data, batch_size=batch_size, seq_len=seq_len, context_lens=context_lens),
    define('flat_pa', flat_pa, flat_pa_data, batch_size=batch_size, num_blocks=num_blocks),
]
benchmark(scenarios)


---

## 🔍 Viewing the Trace Results

### Generated Trace Files

After running the benchmark, trace files are saved in the following directories:
- `./attn_baseline/` - Fused SDPA traces
- `./initial_pa/` - Initial PagedAttention traces
- `./flat_pa/` - Flat PagedAttention traces

### Habana Perfetto Viewer

**Upload the file to https://perfetto.habana.ai and view the API calls and hardware trace events.**

Steps:
1. Download the generated `.json.gz` trace files to your local machine
2. Open https://perfetto.habana.ai in your browser
3. Click "Open trace file" and select your downloaded trace file
4. Explore the trace timeline to analyze:
   - CPU and HPU activity overlap
   - Kernel execution times (TPC and MME utilization)
   - Memory transfer operations
   - Attention kernel performance differences
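
For a quick local summary before opening the viewer, the traces can also be read with the standard library. This is a hedged sketch: it assumes the files follow the Chrome trace event format (a top-level `traceEvents` list whose complete events have `ph == "X"` and a `dur` in microseconds), and `total_durations` is a hypothetical helper, not part of any Habana or PyTorch tooling.

```python
import glob
import gzip
import json
from collections import Counter

def total_durations(trace_dir, top=5):
    """Sum 'dur' (microseconds) per event name over all gzipped traces in trace_dir."""
    totals = Counter()
    for path in glob.glob(f"{trace_dir}/*.json.gz"):
        with gzip.open(path, "rt") as f:
            events = json.load(f).get("traceEvents", [])
        for ev in events:
            if ev.get("ph") == "X":  # only complete events carry a duration
                totals[ev.get("name", "?")] += ev.get("dur", 0)
    return totals.most_common(top)

# Example usage after the benchmark has run:
# for name, dur_us in total_durations("./flat_pa"):
#     print(f"{dur_us:>10} us  {name}")
```
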


### Performance Tips

- **Fused SDPA** typically has the lowest latency for standard attention patterns
- **PagedAttention variants** excel when memory efficiency is critical (e.g., long contexts)
- Compare the **MME utilization** in traces to understand compute efficiency
