# Imports & open Device

In [1]:
import time
import torch
import ttnn

# Seed torch's RNG so the torch.randn inputs created below are reproducible.
torch.manual_seed(0)
# Index of the accelerator to open.
device_id = 0
# Device handle: passed to every ttnn.to_device call below and released by
# ttnn.close(device) in the last cell.
device = ttnn.open(device_id)

[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237mINFO    [0m | Initializing device 0
[38;2;000;128;000m                 Device[0m | [1m[38;2;100;149;237mINFO    [0m | Opening user mode device driver
[32m2024-01-03 20:44:25.631[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Detected 1 PCI device : {0}
[32m2024-01-03 20:44:25.641[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Using 1 Hugepages/NumHostMemChannels for TTDevice (logical_device_id: 0 pci_interface_id: 0 device_id: 0xfaca revision: 0)
[0;33m---- ttSiliconDevice::init_hugepage: bind_area_to_memory_nodeset() failed (physical_device_id: 0 ch: 0). Hugepage allocation is not on NumaNode matching TT Device. Side-Effect is decreased Device->Host perf (Issue #893).
[0m[32m2024-01-03 20:44:25.734[0m | [1m[38;2;100;149;237mINFO    [0m | [36mSiliconDriver  [0m - Disable PCIE DMA
[38;2;000;128;000m                  Metal[0m | [1m[38;2;100;149;237m

# Enable program cache

In [2]:
# Turn on ttnn's program cache for this session (the log line below confirms it).
# NOTE(review): presumably this lets repeated ops with the same configuration reuse
# already-compiled programs — confirm against the ttnn documentation.
ttnn.enable_program_cache()

[38;2;000;128;000m                     Op[0m | [1m[38;2;100;149;237mINFO    [0m | Program Cache: enabled.


# Matrix Multiplications 

# Constants

In [3]:
# Problem sizes shared by the cells below:
#   b (8)    - leading dimension of A, also the first core_grid entry
#   n (12)   - second core_grid entry
#   s (384)  - second dimension of A; C is (h, s) and D is (s, s)
#   h (1024) - last dimension of A; B is (h, h)
b, n, s, h = 8, 12, 384, 1024

In [4]:
# Build input A with shape (b, s, h). Reading the expression inside-out:
#   1. draw random bfloat16 values with torch,
#   2. wrap them as a ttnn tensor,
#   3. tilize (switch to TILE_LAYOUT) before matmul,
#   4. put the result in L1 memory on the device.
A = ttnn.to_device(
    ttnn.to_layout(
        ttnn.from_torch(torch.randn((b, s, h), dtype=torch.bfloat16)),
        ttnn.TILE_LAYOUT,
    ),
    device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

In [5]:
# Build input B with shape (h, h): random bfloat16 values from torch, wrapped as a
# ttnn tensor, converted to TILE_LAYOUT, then placed in L1 memory on the device.
B = ttnn.to_device(
    ttnn.to_layout(
        ttnn.from_torch(torch.randn((h, h), dtype=torch.bfloat16)),
        ttnn.TILE_LAYOUT,
    ),
    device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

In [6]:
# Build input C with shape (h, s): random bfloat16 values from torch, wrapped as a
# ttnn tensor, converted to TILE_LAYOUT, then placed in L1 memory on the device.
C = ttnn.to_device(
    ttnn.to_layout(
        ttnn.from_torch(torch.randn((h, s), dtype=torch.bfloat16)),
        ttnn.TILE_LAYOUT,
    ),
    device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

In [7]:
# Build input D with shape (s, s): random bfloat16 values from torch, wrapped as a
# ttnn tensor, converted to TILE_LAYOUT, then placed in L1 memory on the device.
D = ttnn.to_device(
    ttnn.to_layout(
        ttnn.from_torch(torch.randn((s, s), dtype=torch.bfloat16)),
        ttnn.TILE_LAYOUT,
    ),
    device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)

# Matmul1

In [8]:
# M1 = A @ B, where A was created with shape (b, s, h) and B with shape (h, h).
M1 = ttnn.matmul(
    A,
    B,
    core_grid=(b, n),  # specify grid cores to run matmul on
    dtype=ttnn.bfloat8_b,  # use float8 data type
    memory_config=ttnn.L1_MEMORY_CONFIG,  # same L1 memory config as the inputs
)

# Matmul2

In [9]:
# M2 = M1 @ C, chaining the first product with C, which was created with shape (h, s).
M2 = ttnn.matmul(
    M1,
    C,
    core_grid=(b, n),  # specify grid cores to run matmul on
    dtype=ttnn.bfloat8_b,  # use float8 data type
    memory_config=ttnn.L1_MEMORY_CONFIG,  # same L1 memory config as the inputs
)

# Matmul3

In [10]:
# M3 = M2 @ D, chaining the second product with D, which was created with shape (s, s).
M3 = ttnn.matmul(
    M2,
    D,
    core_grid=(b, n),  # specify grid cores to run matmul on
    dtype=ttnn.bfloat8_b,  # use float8 data type
    memory_config=ttnn.L1_MEMORY_CONFIG,  # same L1 memory config as the inputs
)

# Visualize results

In [None]:
# Print a small preview of M1, M2 and M3. Slicing needs the row-major layout, so each
# tensor is converted first; note that this rebinds M1/M2/M3 to the converted tensors.
# The index [1, :4, :4] picks entry 1 along the first axis and the first four
# positions along each of the next two axes.
M1 = ttnn.to_layout(M1, ttnn.ROW_MAJOR_LAYOUT)
print("M1: ", M1[1, :4,:4])

M2 = ttnn.to_layout(M2, ttnn.ROW_MAJOR_LAYOUT)
print("M2: ", M2[1, :4,:4])

M3 = ttnn.to_layout(M3, ttnn.ROW_MAJOR_LAYOUT)
print("M3: ", M3[1, :4,:4])

# Matmul optim

# Defragment L1 memory Space

In [None]:
# Release the device handle that was opened with ttnn.open(device_id) in the first cell.
ttnn.close(device)