# Matrix Multiplication

If you are using a Wormhole card (N150/N300), set the environment variables below to make the full Tensix grid available before continuing with this tutorial.

In [1]:
import os
os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
os.environ["GS_ARCH_YAML"] = "grayskull_120_arch.yaml"

In [2]:
import torch
import ttnn

torch.manual_seed(0)

device_id = 0
device = ttnn.open_device(device_id=device_id)

2024-08-21 04:59:11.248 | DEBUG    | ttnn:<module>:82 - Initial ttnn.CONFIG:
Config{cache_path=/home/thienluu/.cache/ttnn,model_cache_path=/home/thienluu/.cache/ttnn/models,tmp_dir=/tmp/ttnn,enable_model_cache=false,enable_fast_runtime_mode=true,throw_exception_on_fallback=false,enable_logging=false,enable_graph_report=false,enable_detailed_buffer_report=false,enable_detailed_tensor_report=false,enable_comparison_mode=false,comparison_mode_pcc=0.9999,root_report_path=generated/ttnn/reports,report_name=std::nullopt,std::nullopt}


                 Device | INFO     | Opening user mode device driver
2024-08-21 04:59:11.379 | INFO     | SiliconDriver   - Detected 8 PCI devices : [0, 1, 2, 3, 4, 5, 6, 7]

## Enable program cache

Enabling the program cache speeds up operations that run repeatedly, because the compiled kernels are reused instead of being compiled again.
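As a conceptual analogy only, the effect of a program cache can be sketched in plain Python with `functools.lru_cache`. This is not how ttnn implements its cache; the function `compile_program` and its cache key (operation name plus input shapes) are hypothetical stand-ins chosen for illustration.

```python
from functools import lru_cache

compile_count = 0  # counts how often the expensive step actually runs


@lru_cache(maxsize=None)
def compile_program(op_name, shape_a, shape_b):
    """Hypothetical stand-in for an expensive kernel compilation step."""
    global compile_count
    compile_count += 1
    return f"{op_name}:{shape_a}x{shape_b}"


def run_matmul(shape_a, shape_b):
    # The "program" is looked up in the cache; it is built only on a miss.
    return compile_program("matmul", shape_a, shape_b)


run_matmul((1024, 1024), (1024, 1024))  # cache miss: compiles
run_matmul((1024, 1024), (1024, 1024))  # cache hit: reuses the program
print(compile_count)  # 1

run_matmul((512, 1024), (1024, 1024))   # different shapes: compiles again
print(compile_count)  # 2
```

The point of the sketch is only that a repeated operation with the same configuration skips the expensive step, while a new configuration pays for it once.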

In [3]:
ttnn.enable_program_cache(device)

                  Metal | INFO     | Enabling program cache on device 0


## Configuration

In [4]:
m = 1024
k = 1024
n = 1024

## Initialize tensors a and b with random values using torch

In [5]:
torch_a = torch.randn((m, k), dtype=torch.bfloat16)
torch_b = torch.randn((k, n), dtype=torch.bfloat16)

In [6]:
a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device)



## Matrix multiply tensor a and b
The operation takes longer the first time it runs because the kernels need to be compiled.

In [7]:
output = a @ b


Re-running the operation is significantly faster because the program cache is used.

In [8]:
output = a @ b



Function 'ttnn.matmul' executed in 0.0002 seconds
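
To compare the first and second run yourself, a small timing helper can be used. The sketch below is hardware-independent: `timed` and `workload` are hypothetical names, and `workload` only imitates a one-time compilation cost with `time.sleep`. It does not call ttnn.

```python
import time


def timed(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


_compiled = set()  # remembers which inputs have already paid the setup cost


def workload(n):
    """Stand-in operation: slow on the first call for a given n, fast afterwards."""
    if n not in _compiled:
        time.sleep(0.05)  # pretend this is kernel compilation
        _compiled.add(n)
    return n * n


_, first = timed(workload, 8)
_, second = timed(workload, 8)
print(f"first run: {first:.4f}s, second run: {second:.4f}s")
```

On a device, the same helper could wrap a call such as `lambda: ttnn.matmul(a, b)`. If the runtime executes operations asynchronously, the measurement would also need a synchronization step before the clock is read, which this sketch does not cover.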


## Inspect the layout of matrix multiplication output

In [9]:
print(output.layout)



Layout.TILE


As the output shows, matrix multiplication produces its result in tile layout. Computing matrix multiplications on Tenstorrent accelerators is much more efficient in tile layout than in row-major layout.

This is also why the logs show two tilize operations: inputs that are in row-major layout are automatically converted to tile layout.

Learn more about tile layout here: TODO
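
To make the difference between the two layouts concrete, here is a minimal pure-Python sketch of reordering a row-major matrix into tile order. It assumes 32x32 tiles and 2-byte bfloat16 elements, and it ignores any finer structure inside a tile that the hardware may use. The helper names `tilize` and `tile_stats` are hypothetical and are not ttnn functions.

```python
TILE_H = TILE_W = 32        # assumed tile dimensions
BYTES_PER_BFLOAT16 = 2


def tilize(matrix, tile_h, tile_w):
    """Flatten a row-major matrix so that each tile's elements are contiguous."""
    rows, cols = len(matrix), len(matrix[0])
    assert rows % tile_h == 0 and cols % tile_w == 0, "shape must be tile-aligned"
    flat = []
    for tile_row in range(0, rows, tile_h):          # walk tiles top to bottom
        for tile_col in range(0, cols, tile_w):      # then left to right
            for r in range(tile_row, tile_row + tile_h):
                flat.extend(matrix[r][tile_col:tile_col + tile_w])
    return flat


def tile_stats(rows, cols):
    """Return (number of tiles, bytes per tile) for a bfloat16 matrix."""
    num_tiles = (rows // TILE_H) * (cols // TILE_W)
    return num_tiles, TILE_H * TILE_W * BYTES_PER_BFLOAT16


# A 4x4 matrix with 2x2 tiles keeps the example small enough to read.
matrix = [[r * 4 + c for c in range(4)] for r in range(4)]
tiled = tilize(matrix, 2, 2)
print(tiled)  # [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]

print(tile_stats(1024, 1024))  # (1024, 2048)
```

Under these assumptions, each 1024x1024 bfloat16 tensor in this tutorial consists of 1024 tiles of 2048 bytes each.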

## Inspect the result of the matrix multiplication

To inspect the results we will first convert to row-major layout.

In [10]:
output = ttnn.to_layout(output, ttnn.ROW_MAJOR_LAYOUT)

print("Printing ttnn tensor")
print(f"shape: {output.shape}")
print(f"chunk of a tensor:\n{output[:1, :32]}")

Printing ttnn tensor
shape: ttnn.Shape([1024, 1024])
chunk of a tensor:
ttnn.Tensor([[33.50000,  9.00000,  ..., -38.750

## Matrix multiply tensors a and b using a more performant config
By default, matrix multiplication might not be as efficient as it could be. To speed it up further, you can specify the grid of cores that the matrix multiplication should use. This can speed up the operation significantly.
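
As a rough illustration of what a core grid means for the work per core, the sketch below computes how many output tiles each core would handle if the output were split evenly across the grid. It assumes 32x32 tiles and an even split. The actual partitioning chosen by `ttnn.matmul` is not described in this tutorial and may differ, and `tiles_per_core` is a hypothetical helper.

```python
import math

TILE = 32  # assumed tile edge length


def tiles_per_core(m, n, grid_y, grid_x, tile=TILE):
    """Return (tile rows, tile columns) of the output handled by one core,
    assuming the m x n output is divided evenly over a grid_y x grid_x grid."""
    tiles_m = math.ceil(m / tile)
    tiles_n = math.ceil(n / tile)
    return math.ceil(tiles_m / grid_y), math.ceil(tiles_n / grid_x)


# 1024x1024 output on an 8x8 grid: 32x32 tiles in total, 4x4 tiles per core.
per_core = tiles_per_core(1024, 1024, grid_y=8, grid_x=8)
print(per_core)  # (4, 4)
```

With the shapes used here (m = n = 1024), an 8x8 grid would leave each core responsible for a 4x4 block of output tiles under this even-split assumption.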

In [None]:
a = ttnn.from_torch(torch_a)
b = ttnn.from_torch(torch_b)

a = ttnn.to_device(a, device, memory_config=ttnn.L1_MEMORY_CONFIG)
b = ttnn.to_device(b, device, memory_config=ttnn.L1_MEMORY_CONFIG)

a = ttnn.to_layout(a, ttnn.TILE_LAYOUT)
b = ttnn.to_layout(b, ttnn.TILE_LAYOUT)



Run once to compile the kernels

In [13]:
print(f"Input shape: {a.shape}, {b.shape}")
print(f"Input layout: {a.layout}, {b.layout}")
output = ttnn.matmul(a, b, memory_config=ttnn.L1_MEMORY_CONFIG, core_grid=ttnn.CoreGrid(y=8, x=8))
print("Output shape: ", output.shape)
print("Output layout: ", output.layout)

Input shape: ttnn.Shape([1024, 1024]), ttnn.Shape([1024, 1024])
Input layout: Layout.TILE, Layout.TILE
Output shape:  ttnn.Shape([1024, 1024])
Output layout:  Layout.TILE


Enjoy a massive speedup on subsequent runs.

In [14]:
output = ttnn.matmul(a, b, memory_config=ttnn.L1_MEMORY_CONFIG, core_grid=ttnn.CoreGrid(y=8, x=8))

## Close the device

In [14]:
ttnn.close_device(device)

                  Metal | INFO     | Closing device 0
                  Metal | INFO     | Disabling and clearing program cache on device 0
