# Multi-Head Attention

Multi-head attention is a core building block of Transformer-based models.
This tutorial shows how to implement it with `ttnn` and then how to optimize it.

If you're using a Wormhole card (N150/N300), you need to make the full Tensix grid available before continuing with this tutorial. The next cell does this by setting the `WH_ARCH_YAML` environment variable.

In [1]:
import os
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"

In [2]:
import time
import torch
import ttnn

torch.manual_seed(0)

device_id = 0
device = ttnn.open_device(device_id=device_id, l1_small_size=8192)

2024-08-21 05:38:58.907 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/thienluu/.cache/ttnn,model_cache_path=/home/thienluu/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}

Device | INFO | Opening user mode device driver
2024-08-21 05:38:59.032 | INFO | SiliconDriver - Detected 8 PCI devices : [0, 1, 2, 3, 4, 5, 6, 7]

## Enable program cache

In [3]:
ttnn.enable_program_cache(device)

Metal | INFO | Enabling program cache on device 0


## Write Multi-Head Attention using ttnn

Multi-head attention can be implemented in `torch` using just 6 operations:

1. `torch.matmul`
2. `torch.add`
3. `torch.reshape`
4. `torch.permute`
5. `torch.mul`
6. `torch.softmax`

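`ttnn` provides the same APIs, so multi-head attention can be implemented in a very similar fashion. The one difference is that, when using `ttnn`, the user should be mindful of the tensor layout: in the implementation that follows, tensors are converted to row-major layout before each reshape and back to tile layout afterwards.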
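For comparison, here is a minimal plain-`torch` sketch built from only the six operations listed above. The function name `torch_multi_head_attention` and the tiny demo shapes are assumptions made for illustration; the weights are assumed to be stored as `[in_features, out_features]`, matching how the tensors are created later in this tutorial.

```python
import torch


def torch_multi_head_attention(
    hidden_states, attention_mask,
    query_weight, query_bias, key_weight, key_bias,
    value_weight, value_bias, output_weight, output_bias,
    *, num_heads,
):
    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    def project_and_split_heads(weight, bias):
        # Linear projection, then split the hidden dimension into heads.
        projected = torch.add(torch.matmul(hidden_states, weight), bias)
        return torch.reshape(projected, (batch_size, sequence_size, num_heads, head_size))

    # Query and value become (batch, heads, seq, head_size);
    # key becomes (batch, heads, head_size, seq) so that query @ key works directly.
    query = torch.permute(project_and_split_heads(query_weight, query_bias), (0, 2, 1, 3))
    key = torch.permute(project_and_split_heads(key_weight, key_bias), (0, 2, 3, 1))
    value = torch.permute(project_and_split_heads(value_weight, value_bias), (0, 2, 1, 3))

    # Scaled scores, additive mask, softmax over the key positions.
    attention_scores = torch.mul(torch.matmul(query, key), 1 / (head_size**0.5))
    attention_scores = torch.add(attention_scores, attention_mask)
    attention_probs = torch.softmax(attention_scores, dim=-1)

    # Merge the heads back into the hidden dimension and apply the output projection.
    context_layer = torch.permute(torch.matmul(attention_probs, value), (0, 2, 1, 3))
    context_layer = torch.reshape(context_layer, (batch_size, sequence_size, hidden_size))
    return torch.add(torch.matmul(context_layer, output_weight), output_bias)


# Tiny illustrative shapes (not the tutorial's configuration).
# A local generator is used so the global torch seed is left untouched.
generator = torch.Generator().manual_seed(0)
demo_batch, demo_seq, demo_hidden, demo_heads = 2, 4, 8, 2
demo_weights = [torch.randn(demo_hidden, demo_hidden, generator=generator) for _ in range(4)]
demo_biases = [torch.randn(demo_hidden, generator=generator) for _ in range(4)]
reference_output = torch_multi_head_attention(
    torch.randn(demo_batch, demo_seq, demo_hidden, generator=generator),
    torch.zeros(demo_batch, 1, 1, demo_seq),
    demo_weights[0], demo_biases[0], demo_weights[1], demo_biases[1],
    demo_weights[2], demo_biases[2], demo_weights[3], demo_biases[3],
    num_heads=demo_heads,
)
print(reference_output.shape)
```

The output keeps the input's `(batch, sequence, hidden)` shape, which is the same contract the `ttnn` version follows.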

In [5]:
@ttnn.log_runtime
def multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    *,
    num_heads,
):
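    # reshape uses ttnn's fallback function; each call is wrapped in
    # ROW_MAJOR_LAYOUT -> TILE_LAYOUT conversions.
    fallback_reshape = ttnn.get_fallback_function(ttnn.reshape)
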
    batch_size, sequence_size, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads

    query = hidden_states @ query_weight
    query = query + query_bias
    query = ttnn.to_layout(query, layout=ttnn.ROW_MAJOR_LAYOUT)
    query = fallback_reshape(query, (batch_size, sequence_size, num_heads, head_size))
    query = ttnn.to_layout(query, layout=ttnn.TILE_LAYOUT)
    query = ttnn.permute(query, (0, 2, 1, 3))

    key = hidden_states @ key_weight
    key = key + key_bias
    key = ttnn.to_layout(key, layout=ttnn.ROW_MAJOR_LAYOUT)
    key = fallback_reshape(key, (batch_size, sequence_size, num_heads, head_size))
    key = ttnn.to_layout(key, layout=ttnn.TILE_LAYOUT)
    key = ttnn.permute(key, (0, 2, 3, 1))

    value = hidden_states @ value_weight
    value = value + value_bias
    value = ttnn.to_layout(value, layout=ttnn.ROW_MAJOR_LAYOUT)
    value = fallback_reshape(value, (batch_size, sequence_size, num_heads, head_size))
    value = ttnn.to_layout(value, layout=ttnn.TILE_LAYOUT)
    value = ttnn.permute(value, (0, 2, 1, 3))

    attention_scores = query @ key
    attention_scores = attention_scores * (1 / (head_size**0.5))
    attention_scores += attention_mask
    attention_probs = ttnn.softmax(attention_scores, dim=-1)

    context_layer = attention_probs @ value
    context_layer = ttnn.permute(context_layer, (0, 2, 1, 3))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.ROW_MAJOR_LAYOUT)
    context_layer = fallback_reshape(context_layer, (batch_size, sequence_size, hidden_size))
    context_layer = ttnn.to_layout(context_layer, layout=ttnn.TILE_LAYOUT)

    self_output = context_layer @ output_weight
    self_output = self_output + output_bias

    return self_output

Now that the model is written, let's create input tensors to run and test it.

## Configuration

In [6]:
batch_size = 8
sequence_size = 384
num_heads = 16
head_size = 64
hidden_size = num_heads * head_size
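To get a feel for the tensor sizes this configuration implies, the following sketch computes the element counts and the `bfloat16` footprint (2 bytes per element) of the main tensors. The helper `num_bytes` and the `tensor_shapes` dictionary are hypothetical names introduced only for this illustration, and the figures ignore any padding a device layout might add.

```python
import math

batch_size = 8
sequence_size = 384
num_heads = 16
head_size = 64
hidden_size = num_heads * head_size

BYTES_PER_BFLOAT16 = 2


def num_bytes(shape, bytes_per_element=BYTES_PER_BFLOAT16):
    # Unpadded size of a dense tensor with the given shape.
    return math.prod(shape) * bytes_per_element


tensor_shapes = {
    "hidden_states": (batch_size, sequence_size, hidden_size),
    "projection_weight": (hidden_size, hidden_size),
    "attention_scores": (batch_size, num_heads, sequence_size, sequence_size),
}

sizes = {name: num_bytes(shape) for name, shape in tensor_shapes.items()}
for name, size in sizes.items():
    print(f"{name}: shape={tensor_shapes[name]}, {size / 2**20:.1f} MiB")
```

The attention scores are the largest intermediate here because they grow with the square of the sequence length.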

## Initialize activations and weights using torch

In [7]:
torch_hidden_states = torch.randn((batch_size, sequence_size, hidden_size), dtype=torch.bfloat16)
torch_attention_mask = torch.randn((batch_size, 1, 1, sequence_size), dtype=torch.bfloat16)
torch_query_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_query_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_key_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_key_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_value_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_value_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)
torch_output_weight = torch.randn((hidden_size, hidden_size), dtype=torch.bfloat16)
torch_output_bias = torch.randn((hidden_size,), dtype=torch.bfloat16)

## Convert activations and weights to ttnn

In [8]:
hidden_states = ttnn.from_torch(torch_hidden_states, layout=ttnn.TILE_LAYOUT, device=device)
attention_mask = ttnn.from_torch(torch_attention_mask, layout=ttnn.TILE_LAYOUT, device=device)
query_weight = ttnn.from_torch(torch_query_weight, layout=ttnn.TILE_LAYOUT, device=device)
query_bias = ttnn.from_torch(torch_query_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
key_weight = ttnn.from_torch(torch_key_weight, layout=ttnn.TILE_LAYOUT, device=device)
key_bias = ttnn.from_torch(torch_key_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
value_weight = ttnn.from_torch(torch_value_weight, layout=ttnn.TILE_LAYOUT, device=device)
value_bias = ttnn.from_torch(torch_value_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.from_torch(torch_output_weight, layout=ttnn.TILE_LAYOUT, device=device)
output_bias = ttnn.from_torch(torch_output_bias, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG)


## Run the first iteration of Multi-Head Attention

In [9]:
start = time.time()
multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start

Function 'ttnn.matmul' executed in 0.4110 seconds


In [10]:
print(f"Multi-head attention ran in {duration} seconds for the first iteration")

Multi-head attention ran in 31.655242681503296 seconds for the first iteration
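The timing pattern used above (record a start time, call the function, subtract) can be wrapped in a small helper so that the first and subsequent iterations are measured the same way. This is a sketch only: `timed` is a hypothetical helper, not part of `ttnn`, and it measures host-side wall-clock time with the standard library's `time.perf_counter`.

```python
import time


def timed(fn, *args, **kwargs):
    # Return the function's result together with the elapsed wall-clock seconds.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Stand-in workload; in the tutorial this would be the attention function.
total, first_duration = timed(sum, range(1000))
total, second_duration = timed(sum, range(1000))
print(f"first call: {first_duration:.6f} s, second call: {second_duration:.6f} s")
```

Calling the helper twice with the same arguments gives a like-for-like comparison between a first call, which includes any one-time setup, and a later call.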


## Run a subsequent iteration of Multi-Head Attention

In [13]:
start = time.time()
output = multi_head_attention(
    hidden_states,
    attention_mask,
    query_weight,
    query_bias,
    key_weight,
    key_bias,
    value_weight,
    value_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start

Function 'ttnn.matmul' executed in 0.0002 seconds

In [None]:
print(f"Multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Multi-head attention ran in 35.035075664520264 seconds for the subsequent iteration because of the program cache

## Write optimized version of Multi-Head Attention

An optimized version of multi-head attention can be written by:

- Tilizing all of the tensors ahead of time
- Using more performant matmuls that fuse the bias and specify the number of cores they execute on
- Putting every tensor into L1
- Using the `bfloat8_b` data type
- Using custom `ttnn.transformer` operations instead of `ttnn.permute` and `ttnn.reshape`

The `ttnn.deallocate` calls are needed because, without them, the cores on the device would run out of L1 memory.

In [15]:
def optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    fused_qkv_weight,
    fused_qkv_bias,
    self_output_weight,
    self_output_bias,
    *,
    num_heads,
    num_cores_x=12,
):
    batch_size, _, hidden_size = hidden_states.shape
    head_size = hidden_size // num_heads
    
    hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)

    fused_qkv_output = ttnn.linear(
        hidden_states,
        fused_qkv_weight,
        bias=fused_qkv_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )

    (
        query,
        key,
        value,
    ) = ttnn.transformer.split_query_key_value_and_split_heads(
        fused_qkv_output,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        num_heads=num_heads,
    )
    ttnn.deallocate(fused_qkv_output)

    attention_scores = ttnn.matmul(
        query,
        key,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(query)
    ttnn.deallocate(key)

    attention_probs = ttnn.transformer.attention_softmax_(attention_scores, attention_mask=attention_mask, head_size=head_size)

    context_layer = ttnn.matmul(
        attention_probs,
        value,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat8_b,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(attention_probs)

    context_layer_after_concatenate_heads = ttnn.transformer.concatenate_heads(
        context_layer,
        memory_config=ttnn.L1_MEMORY_CONFIG,
    )
    ttnn.deallocate(context_layer)

    self_output = ttnn.linear(
        context_layer_after_concatenate_heads,
        self_output_weight,
        bias=self_output_bias,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        core_grid=ttnn.CoreGrid(y=batch_size, x=num_cores_x),
    )
    ttnn.deallocate(context_layer_after_concatenate_heads)

    return self_output
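In the optimized function, the single `ttnn.transformer.attention_softmax_` call replaces the separate scale, mask-add, and softmax steps of the first implementation. A plain-`torch` sketch of that computation follows, assuming the scale factor is `1 / sqrt(head_size)` as in the unoptimized version. The name `reference_attention_softmax` and the demo tensors are hypothetical.

```python
import torch


def reference_attention_softmax(attention_scores, attention_mask, head_size):
    # Scale the raw scores, add the (broadcast) additive mask, then normalize.
    scaled_scores = attention_scores * (1 / (head_size**0.5))
    return torch.softmax(scaled_scores + attention_mask, dim=-1)


# Tiny demo: 1 batch, 2 heads, 3 query positions, 3 key positions.
demo_generator = torch.Generator().manual_seed(0)
demo_scores = torch.randn(1, 2, 3, 3, generator=demo_generator)

# Additive mask that hides the last key position from every query.
demo_mask = torch.zeros(1, 1, 1, 3)
demo_mask[..., -1] = -1e9

demo_probs = reference_attention_softmax(demo_scores, demo_mask, head_size=64)
print(demo_probs.sum(dim=-1))
```

Each row of probabilities sums to one, and the masked key position receives essentially zero weight.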

## Pre-process the parameters of the optimized model

1. Fuse the QKV weights and biases
2. Reshape and tilize them for the optimized operations using `preprocess_linear_weight` and `preprocess_linear_bias`
3. Move them to the device
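The fusion in step 1 relies on a simple identity: concatenating the three projection weights along the output dimension and running one matmul gives the same result as three separate projections. The sketch below checks this in plain `torch` with small, hypothetical `demo_*` tensors in the same `[in_features, out_features]` orientation used in this tutorial. It does not model how `ttnn` later splits the fused output into heads.

```python
import torch

fusion_generator = torch.Generator().manual_seed(0)
demo_hidden_size = 8
demo_hidden_states = torch.randn(2, 4, demo_hidden_size, generator=fusion_generator)

# Three independent projections, each [in_features, out_features].
demo_q_weight, demo_k_weight, demo_v_weight = (
    torch.randn(demo_hidden_size, demo_hidden_size, generator=fusion_generator) for _ in range(3)
)
demo_q_bias, demo_k_bias, demo_v_bias = (
    torch.randn(demo_hidden_size, generator=fusion_generator) for _ in range(3)
)

# Fuse along the output dimension: one [in_features, 3 * out_features] weight.
demo_fused_weight = torch.cat([demo_q_weight, demo_k_weight, demo_v_weight], dim=-1)
demo_fused_bias = torch.cat([demo_q_bias, demo_k_bias, demo_v_bias], dim=-1)

# One matmul, then slice the result back into query, key, and value.
demo_fused_output = demo_hidden_states @ demo_fused_weight + demo_fused_bias
fused_query, fused_key, fused_value = torch.split(demo_fused_output, demo_hidden_size, dim=-1)

# Reference: three separate projections.
separate_query = demo_hidden_states @ demo_q_weight + demo_q_bias
separate_key = demo_hidden_states @ demo_k_weight + demo_k_bias
separate_value = demo_hidden_states @ demo_v_weight + demo_v_bias

print(torch.allclose(fused_query, separate_query, atol=1e-4))
```

One larger matmul replaces three smaller ones, which is what allows the optimized function to issue a single `ttnn.linear` call for the query, key, and value projections.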

In [16]:
from ttnn.model_preprocessing import (
    preprocess_linear_bias,
    preprocess_linear_weight,
)

torch_qkv_weight = torch.cat([torch_query_weight, torch_key_weight, torch_value_weight], dim=-1)
torch_qkv_bias = torch.cat([torch_query_bias, torch_key_bias, torch_value_bias], dim=-1)

qkv_weight = preprocess_linear_weight(torch_qkv_weight.T, dtype=ttnn.bfloat16)
qkv_bias = preprocess_linear_bias(torch_qkv_bias, dtype=ttnn.bfloat16)
output_weight = preprocess_linear_weight(torch_output_weight.T, dtype=ttnn.bfloat16)
output_bias = preprocess_linear_bias(torch_output_bias, dtype=ttnn.bfloat16)

qkv_weight = ttnn.to_device(qkv_weight, device)
qkv_bias = ttnn.to_device(qkv_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)
output_weight = ttnn.to_device(output_weight, device)
output_bias = ttnn.to_device(output_bias, device, memory_config=ttnn.L1_MEMORY_CONFIG)


## Run the first iteration of the optimized Multi-Head Attention

In [17]:
start = time.time()
hidden_states = ttnn.to_layout(hidden_states, ttnn.TILE_LAYOUT)
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start



In [18]:
print(f"Optimized multi-head attention ran in {duration} seconds for the first iteration")

Optimized multi-head attention ran in 3.070328712463379 seconds for the first iteration


## Run a subsequent iteration of the optimized Multi-Head Attention

In [None]:
start = time.time()
optimized_output = optimized_multi_head_attention(
    hidden_states,
    attention_mask,
    qkv_weight,
    qkv_bias,
    output_weight,
    output_bias,
    num_heads=num_heads,
)
end = time.time()
duration = end - start


In [20]:
print(f"Optimized multi-head attention ran in {duration} seconds for the subsequent iteration because of the program cache")

Optimized multi-head attention ran in 0.001035928726196289 seconds for the subsequent iteration because of the program cache


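Note that the optimized multi-head attention is orders of magnitude faster than the initial version: in the run recorded above, a subsequent iteration took about 0.001 seconds, versus about 35 seconds for the initial version.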

## Check that the output of the optimized version matches the output of the original implementation

In [21]:
torch_output = ttnn.to_torch(output)
torch_optimized_output = ttnn.to_torch(optimized_output)

assert torch.allclose(torch_output, torch_optimized_output)
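`torch.allclose` uses tight default tolerances (`rtol=1e-05`, `atol=1e-08`), which a pipeline that mixes `bfloat16` and `bfloat8_b` may not always meet. One alternative is to compare the two outputs by their Pearson correlation coefficient. The sketch below is an assumption-laden illustration: `pearson_correlation` is a hypothetical helper built on `torch.corrcoef`, and any acceptance threshold is a choice for the reader to make.

```python
import torch


def pearson_correlation(expected, actual):
    # Flatten both tensors and compute their correlation in float32.
    expected = expected.flatten().to(torch.float32)
    actual = actual.flatten().to(torch.float32)
    return torch.corrcoef(torch.stack([expected, actual]))[0, 1].item()


# Demo with synthetic data: a tensor and a slightly perturbed copy.
pcc_generator = torch.Generator().manual_seed(0)
demo_expected = torch.randn(4, 16, 32, generator=pcc_generator)
demo_actual = demo_expected + 0.01 * torch.randn(4, 16, 32, generator=pcc_generator)

demo_pcc = pearson_correlation(demo_expected, demo_actual)
print(f"PCC between the two tensors: {demo_pcc:.6f}")
```

In the tutorial, the same helper could be applied to `torch_output` and `torch_optimized_output`. A value close to 1 indicates that the two implementations agree up to small numerical differences.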



## Close the device

In [22]:
ttnn.close_device(device)

Metal | INFO | Closing device 0
Metal | INFO | DPRINT Server dettached device 0
Metal | INFO | Disabling and clearing program cache on device 0
