In [1]:
# Set stuff up
import os

os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml"
#os.environ["TTNN_CONFIG_OVERRIDES"] = "{\"enable_fast_runtime_mode\": true}"

import ttnn
import numpy
import time 
import torch
#from ttnn.tracer import trace, visualize


torch.manual_seed(0)

device_id = 0
device = ttnn.open_device(device_id=device_id)

ttnn.enable_program_cache(device)



[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver

[32m2024-08-28 15:07:47.493[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 1 PCI device : [4]
[32m2024-08-28 15:07:47.557[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Software version 6.0.0, Ethernet FW version 6.9.0 (Device 0)
[32m2024-08-28 15:07:47.557[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Software version 6.0.0, Ethernet FW version 6.9.0 (Device 1)
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0. Program cache is NOT enabled
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | AI CLK for device 0 is:   1000 MHz
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | MMIO Device 0 : Tunnel 0 : Device 0
[38;2;000;128;000m                  Metal[0m

In [2]:
# Initialize Tensors

m = 4096
k = 4096
n = 4096

torch.manual_seed(0)

torch_a = torch.rand((m, k))
torch_b = torch.rand((k, n))

a = ttnn.from_torch(torch_a, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)
b = ttnn.from_torch(torch_b, layout=ttnn.TILE_LAYOUT, device=device, dtype=ttnn.bfloat16)


In [3]:
# Run multiplication 500 times, compare means of runs 0:99 with 400:499

output = ttnn.matmul(a, b, core_grid=ttnn.CoreGrid(y=8, x=8)) # run once to cache mult in on chip program mem

runtimes = [] #save the matmul time for each run
num_iters = 1 
for run in range(0, num_iters): 
    start = time.time()
    output = ttnn.matmul(a, b, core_grid=ttnn.CoreGrid(y=8, x=8))
    end = time.time()
    runtimes.append(end-start)
    ttnn.to_torch(output)

runtimes = [x/1e-6 for x in runtimes] # normalize values to microseconds 
print(f"All runtime array, microseconds: {runtimes}")
print(f"Mean of first 100 runs: {numpy.mean(runtimes[:100])} microseconds")
print(f"Mean of last 100 runs: {numpy.mean(runtimes[100:])} microseconds")

n_chips = 2
mem_bw = 288e9 * n_chips
flops_fp8 = 262e12 * n_chips
size_in_bytes = 2 * (m*k + k*n)
expected_runtime_membw = size_in_bytes/mem_bw/1e-6
expected_runtime_flops = 2*m*k*n/(flops_fp8/4)/1e-6 # assumes Fp16 flops are ~1/4th fp8 flops by implementation
print(f"Expected runtime if membw limited (approximately): {expected_runtime_membw} microseconds")
print(f"Expected runtime if compute limited (approximately): {expected_runtime_flops} microseconds")
print(f"Expected runtime (approximately): {max(expected_runtime_membw, expected_runtime_flops)} microseconds")


All runtime array, microseconds: [157.35626220703125, 212.19253540039062, 176.90658569335938, 179.05235290527344, 189.30435180664062, 174.99923706054688, 182.86705017089844, 190.73486328125, 176.42974853515625, 182.15179443359375, 182.62863159179688, 184.05914306640625, 178.33709716796875, 189.06593322753906, 175.95291137695312, 178.5755157470703, 189.30435180664062, 172.3766326904297, 171.661376953125, 176.90658569335938, 170.23086547851562, 178.81393432617188, 172.13821411132812, 174.2839813232422, 185.25123596191406, 170.23086547851562, 164.27040100097656, 186.44332885742188, 170.70770263671875, 181.1981201171875, 175.47607421875, 167.36984252929688, 169.27719116210938, 177.62184143066406, 171.18453979492188, 170.23086547851562, 181.1981201171875, 169.7540283203125, 173.5687255859375, 176.42974853515625, 170.70770263671875, 166.89300537109375, 181.43653869628906, 172.13821411132812, 168.08509826660156, 179.05235290527344, 171.661376953125, 162.36305236816406, 172.3766326904297, 160.

In [11]:
ttnn.close_device(device)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Closing device 0
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Disabling and clearing program cache on device 0
