# Using CuTe Layout Algebra With Python DSL

Drawing on the CuTe C++ documentation in [01_layout.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md) and [02_layout_algebra.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md), we summarize:

In CuTe, a `Layout`:
- is defined by a pair of `Shape` and `Stride`,
- maps coordinate space(s) to an index space,
- supports both static (compile-time) and dynamic (runtime) values.
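To make the shape/stride mapping concrete, here is a pure-Python sketch (no CuTe dependency) that models a layout as a pair of nested tuples. The helper names `flatten`, `layout_size`, `index_to_coord`, and `layout_eval` are hypothetical illustrations, not part of the CuTe DSL. A 1-D index is first unpacked colexicographically (leftmost mode varies fastest, as in CuTe), and the resulting coordinate is then dotted with the flattened stride.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_size(shape):
    """Size of the domain: the product of all leaves of the shape."""
    return prod(flatten(shape))


def index_to_coord(i, shape):
    """Unpack a 1-D index into a flat coordinate, leftmost mode varying fastest."""
    coord = []
    for s in flatten(shape):
        coord.append(i % s)
        i //= s
    return tuple(coord)


def layout_eval(shape, stride, i):
    """Map a 1-D index to an offset: inner product of its coordinate with the stride."""
    coord = index_to_coord(i, shape)
    return sum(c * d for c, d in zip(coord, flatten(stride)))


# (6,2):(8,2), the layout A used in the composition example below
offsets = [layout_eval((6, 2), (8, 2), i) for i in range(layout_size((6, 2)))]
print(offsets)  # [0, 8, 16, 24, 32, 40, 2, 10, 18, 26, 34, 42]
```

Nested shapes such as `(2,(1,6))` are handled the same way, because only the flattened leaves matter when a layout is evaluated at a 1-D index.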

CuTe also provides a powerful set of operations—the *Layout Algebra*—for combining and manipulating layouts, including:
- Layout composition: Functional composition of layouts,
- Layout "divide": Splitting a layout into two component layouts,
- Layout "product": Reproducing a layout according to another layout.

In this notebook, we will demonstrate:
1. How to use CuTe’s key layout algebra operations with the Python DSL.
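2. How static and dynamic layouts behave when printed or manipulated within the Python DSL. Python's `print` runs at compile time and shows dynamic values as `?` (lines prefixed `>>>` below), whereas `cute.printf` runs at runtime and shows their actual values (lines prefixed `>??`).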

We use examples from [02_layout_algebra.md](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md), which we recommend to the reader for additional details.

```python
import cutlass
import cutlass.cute as cute
```

## Layout Algebra Operations

These operations form the foundation of CuTe's layout manipulation capabilities, enabling:
- Efficient data tiling and partitioning,
- Separation of thread and data layouts with a canonical type to represent both,
- Native description and manipulation of hierarchical tensors of threads and data crucial for tensor core programs,
- Mixed static/dynamic layout transformations,
- Seamless integration of layout algebra with tensor operations,
- Expression of complex MMAs and copies as canonical loops.

### 1. Coalesce

The `coalesce` operation simplifies a layout by flattening and combining modes when possible, without changing its size or behavior as a function on the integers.

It ensures the post-conditions:
- Size is preserved: `cute.size(result) == cute.size(layout)`,
- Result is flat: `cute.depth(result) <= 1`,
- Function is preserved: for all `i` with `0 <= i < cute.size(layout)`, `result(i) == layout(i)`.
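The three post-conditions can be made concrete with a small pure-Python model of coalesce, assuming a layout is represented as nested `(shape, stride)` tuples. This is an illustrative sketch of the merge rule, not CuTe's implementation: after flattening, size-1 modes are dropped, and adjacent modes `s0:d0` and `s1:d1` are merged into `(s0*s1):d0` whenever `s0*d0 == d1`. The helper names are hypothetical, and the test layout is the one from the post-conditions example below.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


def coalesce(shape, stride):
    """Sketch of coalesce: returns flat (shape, stride) tuples of depth <= 1."""
    out_shape, out_stride = [], []
    for s, d in zip(flatten(shape), flatten(stride)):
        if s == 1:
            continue  # a size-1 mode always contributes offset 0
        if out_shape and out_shape[-1] * out_stride[-1] == d:
            out_shape[-1] *= s  # this mode continues exactly where the previous one ends
        else:
            out_shape.append(s)
            out_stride.append(d)
    if not out_shape:
        return (1,), (0,)
    return tuple(out_shape), tuple(out_stride)


shape = ((2, (3, 4)), (3, 2), 1)
stride = ((4, (8, 24)), (2, 6), 12)
c_shape, c_stride = coalesce(shape, stride)
print(c_shape, c_stride)  # (24, 6) (4, 2)

# Post-conditions: same size, flat result, same function on [0, size)
size = prod(flatten(shape))
assert prod(c_shape) == size
assert all(isinstance(s, int) for s in c_shape)
assert all(layout_eval(shape, stride, i) == layout_eval(c_shape, c_stride, i) for i in range(size))
```

This sketch always returns tuples, so a rank-1 result such as CuTe's `12:1` appears here as `((12,), (1,))`.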

#### Examples

- Basic Coalesce Example :

```python
@cute.jit
def coalesce_example():
    """
    Demonstrates coalesce operation flattening and combining modes
    """
    layout = cute.make_layout((2, (1, 6)), stride=(1, (cutlass.Int32(6), 2)))  # Dynamic stride
    result = cute.coalesce(layout)

    print(">>> Original:", layout)
    cute.printf(">?? Original: {}", layout)
    print(">>> Coalesced:", result)
    cute.printf(">?? Coalesced: {}", result)

coalesce_example()
```

```text
>>> Original: (2,(1,6)):(1,(?,2))
>>> Coalesced: 12:1
>?? Original: (2,(1,6)):(1,(6,2))
>?? Coalesced: 12:1
```


```python
@cute.jit
def coalesce_post_conditions():
    """
    Demonstrates coalesce operation's 3 post-conditions:
    1. size(result) == size(layout)
    2. depth(result) <= 1
    3. for all i, 0 <= i < size(layout), result(i) == layout(i)
    """
    layout = cute.make_layout(
        ((2, (3, 4)), (3, 2), 1),
        stride=((4, (8, 24)), (2, 6), 12)
    )
    result = cute.coalesce(layout)

    print(">>> Original:", layout)
    print(">>> Coalesced:", result)

    print(">>> Checking post-conditions:")
    print(">>> 1. Checking size remains the same after the coalesce operation:")
    original_size = cute.size(layout)
    coalesced_size = cute.size(result)
    print(f"Original size: {original_size}, Coalesced size: {coalesced_size}")
    assert coalesced_size == original_size, \
        f"Size mismatch: original {original_size}, coalesced {coalesced_size}"

    print(">>> 2. Checking depth of coalesced layout <= 1:")
    depth = cute.depth(result)
    print(f"Depth of coalesced layout: {depth}")
    assert depth <= 1, f"Depth of coalesced layout should be <= 1, got {depth}"

    print(">>> 3. Checking layout functionality remains the same after the coalesce operation:")
    for i in range(original_size):
        original_value = layout(i)
        coalesced_value = result(i)
        print(f"Index {i}: original {original_value}, coalesced {coalesced_value}")
        assert coalesced_value == original_value, \
            f"Value mismatch at index {i}: original {original_value}, coalesced {coalesced_value}"

coalesce_post_conditions()
```

```text
>>> Original: ((2,(3,4)),(3,2),1):((4,(8,24)),(2,6),12)
>>> Coalesced: (24,6):(4,2)
>>> Checking post-conditions:
>>> 1. Checking size remains the same after the coalesce operation:
Original size: 144, Coalesced size: 144
>>> 2. Checking depth of coalesced layout <= 1:
Depth of coalesced layout: 1
>>> 3. Checking layout functionality remains the same after the coalesce operation:
Index 0: original 0, coalesced 0
Index 1: original 4, coalesced 4
Index 2: original 8, coalesced 8
Index 3: original 12, coalesced 12
Index 4: original 16, coalesced 16
Index 5: original 20, coalesced 20
Index 6: original 24, coalesced 24
Index 7: original 28, coalesced 28
Index 8: original 32, coalesced 32
Index 9: original 36, coalesced 36
Index 10: original 40, coalesced 40
Index 11: original 44, coalesced 44
Index 12: original 48, coalesced 48
Index 13: original 52, coalesced 52
Index 14: original 56, coalesced 56
Index 15: original 60, coalesced 60
Index 16: original 64, coalesced 64
Index 17: original 68
...
```

- By-mode Coalesce Example :

```python
@cute.jit
def bymode_coalesce_example():
    """
    Demonstrates by-mode coalescing
    """
    layout = cute.make_layout((2, (1, 6)), stride=(1, (6, 2)))

    # target_profile=(1, 1): coalesce each of the two top-level modes independently
    result = cute.coalesce(layout, target_profile=(1, 1))

    # Print results
    print(">>> Original: ", layout)
    print(">>> Coalesced Result: ", result)

bymode_coalesce_example()
```

```text
>>> Original:  (2,(1,6)):(1,(6,2))
>>> Coalesced Result:  (2,6):(1,2)
```
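The by-mode output above is consistent with coalescing each top-level mode on its own while keeping the rank. A minimal pure-Python sketch of that reading, restricted to a profile of all 1s and using a hypothetical `bymode_coalesce` helper that applies the flat merge rule (drop size-1 modes, merge `s0:d0` with `s1:d1` when `s0*d0 == d1`) per mode:

```python
def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def coalesce(shape, stride):
    """Flat coalesce: drop size-1 modes, merge s0:d0 with s1:d1 when s0*d0 == d1."""
    out_shape, out_stride = [], []
    for s, d in zip(flatten(shape), flatten(stride)):
        if s == 1:
            continue
        if out_shape and out_shape[-1] * out_stride[-1] == d:
            out_shape[-1] *= s
        else:
            out_shape.append(s)
            out_stride.append(d)
    if not out_shape:
        return 1, 0
    if len(out_shape) == 1:
        return out_shape[0], out_stride[0]  # a single mode is shown as a scalar, like CuTe
    return tuple(out_shape), tuple(out_stride)


def bymode_coalesce(shape, stride):
    """Coalesce each top-level mode independently, so the rank is preserved."""
    modes = [coalesce(s, d) for s, d in zip(shape, stride)]
    return tuple(m[0] for m in modes), tuple(m[1] for m in modes)


print(bymode_coalesce((2, (1, 6)), (1, (6, 2))))  # ((2, 6), (1, 2))
print(coalesce((2, (1, 6)), (1, (6, 2))))         # (12, 1)
```

The contrast between the two calls mirrors the notebook: full coalescing merges everything into `12:1`, while the by-mode version stops at the mode boundary and yields `(2,6):(1,2)`.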


### 2. Composition

`Composition` of Layout `A` with Layout `B` creates a new layout `R = A ◦ B` where:
- The shape of `B` is compatible with the shape of `R` so that all coordinates of `B` can also be used as coordinates of `R`,
- `R(c) = A(B(c))` for all coordinates `c` in `B`'s domain.

Layout composition is very useful for reshaping and reordering layouts.
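The defining property `R(c) = A(B(c))` can be checked numerically. The sketch below is plain Python (no CuTe dependency) with a hypothetical `layout_eval` helper. It takes `A`, `B`, and `R` from the basic composition example that follows, with the dynamic stride of `A` replaced by its runtime value 8, and compares `R(i)` with `A(B(i))` at every 1-D index in `B`'s domain.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


A = ((6, 2), (8, 2))                 # (6,2):(8,2)
B = ((4, 3), (3, 1))                 # (4,3):(3,1)
R = (((2, 2), 3), ((24, 2), 8))      # ((2,2),3):((24,2),8), the printed result

size_B = prod(flatten(B[0]))
composed = [layout_eval(*A, layout_eval(*B, i)) for i in range(size_B)]
direct = [layout_eval(*R, i) for i in range(size_B)]
print(composed)  # [0, 24, 2, 26, 8, 32, 10, 34, 16, 40, 18, 42]
assert composed == direct
```

`R` has the same size as `B` (12), and its shape `((2,2),3)` refines `B`'s shape `(4,3)`, which is what allows every coordinate of `B` to be used as a coordinate of `R`.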

#### Examples

- Basic Composition Example :

```python
@cute.jit
def composition_example():
    """
    Demonstrates basic layout composition R = A ◦ B
    """
    A = cute.make_layout((6, 2), stride=(cutlass.Int32(8), 2))  # Dynamic stride
    B = cute.make_layout((4, 3), stride=(3, 1))
    R = cute.composition(A, B)

    # Print static and dynamic information
    print(">>> Layout A:", A)
    cute.printf(">?? Layout A: {}", A)
    print(">>> Layout B:", B)
    cute.printf(">?? Layout B: {}", B)
    print(">>> Composition R = A ◦ B:", R)
    cute.printf(">?? Composition R: {}", R)

composition_example()
```

```text
>>> Layout A: (6,2):(?,2)
>>> Layout B: (4,3):(3,1)
>>> Composition R = A ◦ B: ((2,2),3):((?{div=3},2),?)
>?? Layout A: (6,2):(8,2)
>?? Layout B: (4,3):(3,1)
>?? Composition R: ((2,2),3):((24,2),8)
```


- Comparing Composition with static and dynamic layouts :

In this case, the results may look different but are mathematically the same. The size-1 modes in the dynamic result's shape don't affect the layout as a function on the integers, because a size-1 mode always receives coordinate 0. In the dynamic case, CuTe cannot coalesce these size-1 modes away to "simplify" the layout, because doing so is not valid for every value the dynamic parameters could take at runtime.
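To see that the static result `(5,(2,2)):(16,(80,4))` and the dynamic result `((5,1),(2,2)):((16,4),(80,4))` printed by the cell below are the same function, one can evaluate both, together with `A(B(i))`, at every 1-D index. A pure-Python sketch with a hypothetical `layout_eval` helper (no CuTe dependency):

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d  # for s == 1 the coordinate is always 0
        i //= s
    return offset


A = ((10, 2), (16, 4))
B = ((5, 4), (1, 5))
R_static = ((5, (2, 2)), (16, (80, 4)))
R_dynamic = (((5, 1), (2, 2)), ((16, 4), (80, 4)))

n = prod(flatten(B[0]))  # size(B) == 20
static_vals = [layout_eval(*R_static, i) for i in range(n)]
dynamic_vals = [layout_eval(*R_dynamic, i) for i in range(n)]
composed = [layout_eval(*A, layout_eval(*B, i)) for i in range(n)]
assert static_vals == dynamic_vals == composed
```

The extra size-1 mode with stride 4 in the dynamic result never contributes to any offset, which is why the two printed layouts agree.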

```python
@cute.jit
def composition_static_vs_dynamic_layout():
    """
    Shows difference between static and dynamic composition results
    """
    # Static version - using compile-time values
    A_static = cute.make_layout(
        (10, 2),
        stride=(16, 4)
    )
    B_static = cute.make_layout(
        (5, 4),
        stride=(1, 5)
    )
    R_static = cute.composition(A_static, B_static)

    # Static print shows compile-time info
    print(">>> Static composition:")
    print(">>> A_static: ", A_static)
    print(">>> B_static: ", B_static)
    print(">>> R_static: ", R_static)

    # Dynamic version - using runtime Int32 values
    A_dynamic = cute.make_layout(
        (cutlass.Int32(10), cutlass.Int32(2)),
        stride=(cutlass.Int32(16), cutlass.Int32(4))
    )
    B_dynamic = cute.make_layout(
        (cutlass.Int32(5), cutlass.Int32(4)),
        stride=(cutlass.Int32(1), cutlass.Int32(5))
    )
    R_dynamic = cute.composition(A_dynamic, B_dynamic)

    # Dynamic printf shows runtime values
    cute.printf(">?? Dynamic composition:")
    cute.printf(">?? A_dynamic: {}", A_dynamic)
    cute.printf(">?? B_dynamic: {}", B_dynamic)
    cute.printf(">?? R_dynamic: {}", R_dynamic)

composition_static_vs_dynamic_layout()
```

```text
>>> Static composition:
>>> A_static:  (10,2):(16,4)
>>> B_static:  (5,4):(1,5)
>>> R_static:  (5,(2,2)):(16,(80,4))
>?? Dynamic composition:
>?? A_dynamic: (10,2):(16,4)
>?? B_dynamic: (5,4):(1,5)
>?? R_dynamic: ((5,1),(2,2)):((16,4),(80,4))
```


-  By-mode Composition Example :

By-mode composition allows us to apply composition operations to individual modes of a layout. This is particularly useful when you want to manipulate specific modes of a layout independently (e.g., rows and columns).

In CuTe, by-mode composition is achieved by using a `Tiler`, which can be a layout or a tuple of layouts. The leaves of the `Tiler` tuple specify how the corresponding mode of the target layout should be composed, so each sublayout is treated independently. An integer `n` in a `Tiler` tuple stands for the layout `n:1`.
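A by-mode composition can be read as one ordinary composition per mode. The pure-Python sketch below (hypothetical `layout_eval` helper, no CuTe dependency) takes `A` and the result from the by-mode composition example that follows, using the runtime values printed there, treats the integer tiler entries 3 and 8 as the layouts `3:1` and `8:1`, and checks each result mode against the matching mode of `A` composed with its tiler entry.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


A_modes = [(12, 59), ((4, 8), (13, 1))]   # mode-0 and mode-1 of A as (shape, stride)
tiler = [(3, 1), (8, 1)]                  # integer tiler entries n read as n:1
R_modes = [(3, 59), ((4, 2), (13, 1))]    # modes of the printed result

for (a_shape, a_stride), (t_shape, t_stride), (r_shape, r_stride) in zip(A_modes, tiler, R_modes):
    n = prod(flatten(t_shape))
    assert prod(flatten(r_shape)) == n   # each result mode has the tiler's size
    for i in range(n):
        expected = layout_eval(a_shape, a_stride, layout_eval(t_shape, t_stride, i))
        assert layout_eval(r_shape, r_stride, i) == expected
print("each result mode equals the matching mode of A composed with its tiler entry")
```

Mode-0 keeps the first 3 elements of `12:59`, and mode-1 keeps the first 8 elements of `(4,8):(13,1)`, which is why its shape becomes `(4,2)`.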

```python
@cute.jit
def bymode_composition_example():
    """
    Demonstrates by-mode composition using a tiler
    """
    # Define the original layout A
    A = cute.make_layout(
        (cutlass.Int32(12), (cutlass.Int32(4), cutlass.Int32(8))),
        stride=(cutlass.Int32(59), (cutlass.Int32(13), cutlass.Int32(1)))
    )

    # Define the tiler for by-mode composition
    tiler = (3, 8)  # Apply 3:1 to mode-0 and 8:1 to mode-1

    # Apply by-mode composition
    result = cute.composition(A, tiler)

    # Print static and dynamic information
    print(">>> Layout A:", A)
    cute.printf(">?? Layout A: {}", A)
    print(">>> Tiler:", tiler)
    cute.printf(">?? Tiler: {}", tiler)
    print(">>> By-mode Composition Result:", result)
    cute.printf(">?? By-mode Composition Result: {}", result)

bymode_composition_example()
```

```text
>>> Layout A: (?,(?,?)):(?,(?,?))
>>> Tiler: (3, 8)
>>> By-mode Composition Result: (3,(?,?)):(?,(?,?))
>?? Layout A: (12,(4,8)):(59,(13,1))
>?? Tiler: (3,8)
>?? By-mode Composition Result: (3,(4,2)):(59,(13,1))
```


### 3. Division (Splitting into Tiles)

The Division operation in CuTe is used to split a layout into tiles, which is particularly useful for partitioning data across threads or memory hierarchies.

#### Examples

- Logical divide :

When applied to two Layouts, `logical_divide` splits a layout into two modes: the first mode contains the elements pointed to by the tiler, and the second mode contains the remaining elements.
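CuTe C++ defines `logical_divide(A, B)` as `A` composed with the layout `(B, complement(B, size(A)))`, so the first mode is `A ◦ B` and the second mode is `A` composed with the complement of `B`. The pure-Python sketch below checks that reading against the 1-D example that follows. The `complement` here is a simplified, hypothetical helper that assumes the tiler is injective with positive strides that divide evenly (true for `4:2` with a cotarget of 24); it is not CuTe's implementation.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


def complement(shape, stride, cotarget):
    """Simplified complement of an injective layout with positive strides."""
    modes = sorted((d, s) for s, d in zip(flatten(shape), flatten(stride)) if s != 1)
    out_shape, out_stride, current = [], [], 1
    for d, s in modes:
        out_shape.append(d // current)  # fill the gap below this mode's stride
        out_stride.append(current)
        current = s * d                 # first offset past this mode's span
    out_shape.append(-(-cotarget // current))  # ceil division up to the cotarget
    out_stride.append(current)
    return tuple(out_shape), tuple(out_stride)


A = ((4, 2, 3), (2, 1, 8))   # (4,2,3):(2,1,8)
B = (4, 2)                   # 4:2
comp = complement(*B, cotarget=prod(flatten(A[0])))
print(comp)  # ((2, 3), (1, 8))

tile_mode = ((2, 2), (4, 1))   # first mode of the printed result
rest_mode = ((2, 3), (2, 8))   # second mode of the printed result
for i in range(4):
    assert layout_eval(*tile_mode, i) == layout_eval(*A, layout_eval(*B, i))
for j in range(6):
    assert layout_eval(*rest_mode, j) == layout_eval(*A, layout_eval(*comp, j))

# Together the two modes visit every offset of A exactly once
all_offsets = [layout_eval(*tile_mode, i) + layout_eval(*rest_mode, j) for j in range(6) for i in range(4)]
assert sorted(all_offsets) == list(range(24))
```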

```python
@cute.jit
def logical_divide_1d_example():
    """
    Demonstrates 1D logical divide
    """
    # Define the original layout
    layout = cute.make_layout((4, 2, 3), stride=(2, 1, 8))  # (4,2,3):(2,1,8)

    # Define the tiler
    tiler = cute.make_layout(4, stride=2)  # Apply to layout 4:2

    # Apply logical divide
    result = cute.logical_divide(layout, tiler=tiler)

    # Print results
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler)
    print(">>> Logical Divide Result:", result)
    cute.printf(">?? Logical Divide Result: {}", result)

logical_divide_1d_example()
```

```text
>>> Layout: (4,2,3):(2,1,8)
>>> Tiler : 4:2
>>> Logical Divide Result: ((2,2),(2,3)):((4,1),(2,8))
>?? Logical Divide Result: ((2,2),(2,3)):((4,1),(2,8))
```


When applied to a Layout and a `Tiler` tuple, `logical_divide` applies itself to each leaf of the `Tiler` and the corresponding mode of the target Layout. This means that the sublayouts are split independently according to the layouts within the `Tiler`.
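Applied by mode, the same `A ◦ (B, complement(B, size(A)))` rule runs independently on each mode of the layout with the matching entry of the tiler. This pure-Python sketch, using hypothetical helpers and the same simplifying assumptions as a plain 1-D divide (injective tiler entries with positive, evenly dividing strides), checks the 2-D example that follows mode by mode: mode-0 `9:59` is divided by `3:3`, and mode-1 `(4,8):(13,1)` by `(2,4):(1,8)`.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


def complement(shape, stride, cotarget):
    """Simplified complement of an injective layout with positive strides."""
    modes = sorted((d, s) for s, d in zip(flatten(shape), flatten(stride)) if s != 1)
    out_shape, out_stride, current = [], [], 1
    for d, s in modes:
        out_shape.append(d // current)
        out_stride.append(current)
        current = s * d
    out_shape.append(-(-cotarget // current))  # ceil division
    out_stride.append(current)
    return tuple(out_shape), tuple(out_stride)


cases = [
    # (mode k of A, tiler entry k, tile part of result mode k, rest part of result mode k)
    ((9, 59), (3, 3), (3, 177), (3, 59)),
    (((4, 8), (13, 1)), ((2, 4), (1, 8)), ((2, 4), (13, 2)), ((2, 2), (26, 1))),
]
for a, b, tile, rest in cases:
    comp = complement(*b, cotarget=prod(flatten(a[0])))
    for i in range(prod(flatten(b[0]))):
        assert layout_eval(*tile, i) == layout_eval(*a, layout_eval(*b, i))
    for j in range(prod(flatten(comp[0]))):
        assert layout_eval(*rest, j) == layout_eval(*a, layout_eval(*comp, j))
print("both modes match ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1)))")
```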

```python
@cute.jit
def logical_divide_2d_example():
    """
    Demonstrates 2D logical divide :
    Layout Shape : (M, N, L, ...)
    Tiler Shape  : <TileM, TileN>
    Result Shape : ((TileM,RestM), (TileN,RestN), L, ...)
    """
    # Define the original layout
    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))

    # Define the tiler
    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3
             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)

    # Apply logical divide
    result = cute.logical_divide(layout, tiler=tiler)

    # Print results (print each tiler mode; printing the tuple itself shows object addresses)
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler[0], tiler[1])
    print(">>> Logical Divide Result:", result)
    cute.printf(">?? Logical Divide Result: {}", result)

logical_divide_2d_example()
```

```text
>>> Layout: (9,(4,8)):(59,(13,1))
>>> Tiler : 3:3 (2,4):(1,8)
>>> Logical Divide Result: ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1)))
>?? Logical Divide Result: ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1)))
```


Zipped, tiled, and flat divide are flavors of `logical_divide` that potentially rearrange modes into more convenient forms.
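Since these flavors differ only in how the modes are grouped, the regrouping can be sketched as plain tuple manipulation on a `logical_divide` result whose first `k` modes are `(Tile, Rest)` pairs. The hypothetical helpers below apply to shapes and strides alike. The input is the 2-D `logical_divide` result shown above, and the outputs match the zipped, tiled, and flat results printed below.

```python
def zipped(modes, k):
    """((Tile_0, ..., Tile_k-1), (Rest_0, ..., Rest_k-1, remaining modes...))"""
    tiles = tuple(m[0] for m in modes[:k])
    rests = tuple(m[1] for m in modes[:k]) + tuple(modes[k:])
    return (tiles, rests)


def tiled(modes, k):
    """((Tile_0, ..., Tile_k-1), Rest_0, ..., Rest_k-1, remaining modes...)"""
    tiles = tuple(m[0] for m in modes[:k])
    return (tiles,) + tuple(m[1] for m in modes[:k]) + tuple(modes[k:])


def flat(modes, k):
    """(Tile_0, ..., Tile_k-1, Rest_0, ..., Rest_k-1, remaining modes...)"""
    return tuple(m[0] for m in modes[:k]) + tuple(m[1] for m in modes[:k]) + tuple(modes[k:])


# logical_divide result ((3,3),((2,4),(2,2))):((177,59),((13,2),(26,1))), with k = 2 tiler modes
shape = ((3, 3), ((2, 4), (2, 2)))
stride = ((177, 59), ((13, 2), (26, 1)))

for name, f in [("zipped", zipped), ("tiled", tiled), ("flat", flat)]:
    print(name, f(shape, 2), f(stride, 2))
```

Modes beyond the tiler's rank (the `L, ...` in the docstrings below) are simply appended after the rest modes.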

- Zipped Divide :

```python
@cute.jit
def zipped_divide_example():
    """
    Demonstrates zipped divide :
    Layout Shape : (M, N, L, ...)
    Tiler Shape  : <TileM, TileN>
    Result Shape : ((TileM,TileN), (RestM,RestN,L,...))
    """
    # Define the original layout
    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))

    # Define the tiler
    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3
             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)

    # Apply zipped divide
    result = cute.zipped_divide(layout, tiler=tiler)

    # Print results (print each tiler mode; printing the tuple itself shows object addresses)
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler[0], tiler[1])
    print(">>> Zipped Divide Result:", result)
    cute.printf(">?? Zipped Divide Result: {}", result)

zipped_divide_example()
```

```text
>>> Layout: (9,(4,8)):(59,(13,1))
>>> Tiler : 3:3 (2,4):(1,8)
>>> Zipped Divide Result: ((3,(2,4)),(3,(2,2))):((177,(13,2)),(59,(26,1)))
>?? Zipped Divide Result: ((3,(2,4)),(3,(2,2))):((177,(13,2)),(59,(26,1)))
```


- Tiled Divide :

```python
@cute.jit
def tiled_divide_example():
    """
    Demonstrates tiled divide :
    Layout Shape : (M, N, L, ...)
    Tiler Shape  : <TileM, TileN>
    Result Shape : ((TileM,TileN), RestM, RestN, L, ...)
    """
    # Define the original layout
    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))

    # Define the tiler
    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3
             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)

    # Apply tiled divide
    result = cute.tiled_divide(layout, tiler=tiler)

    # Print results (print each tiler mode; printing the tuple itself shows object addresses)
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler[0], tiler[1])
    print(">>> Tiled Divide Result:", result)
    cute.printf(">?? Tiled Divide Result: {}", result)

tiled_divide_example()
```

```text
>>> Layout: (9,(4,8)):(59,(13,1))
>>> Tiler : 3:3 (2,4):(1,8)
>>> Tiled Divide Result: ((3,(2,4)),3,(2,2)):((177,(13,2)),59,(26,1))
>?? Tiled Divide Result: ((3,(2,4)),3,(2,2)):((177,(13,2)),59,(26,1))
```


- Flat Divide :

```python
@cute.jit
def flat_divide_example():
    """
    Demonstrates flat divide :
    Layout Shape : (M, N, L, ...)
    Tiler Shape  : <TileM, TileN>
    Result Shape : (TileM, TileN, RestM, RestN, L, ...)
    """
    # Define the original layout
    layout = cute.make_layout((9, (4, 8)), stride=(59, (13, 1)))  # (9,(4,8)):(59,(13,1))

    # Define the tiler
    tiler = (cute.make_layout(3, stride=3),            # Apply to mode-0 layout 3:3
             cute.make_layout((2, 4), stride=(1, 8)))  # Apply to mode-1 layout (2,4):(1,8)

    # Apply flat divide
    result = cute.flat_divide(layout, tiler=tiler)

    # Print results (print each tiler mode; printing the tuple itself shows object addresses)
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler[0], tiler[1])
    print(">>> Flat Divide Result:", result)
    cute.printf(">?? Flat Divide Result: {}", result)

flat_divide_example()
```

```text
>>> Layout: (9,(4,8)):(59,(13,1))
>>> Tiler : 3:3 (2,4):(1,8)
>>> Flat Divide Result: (3,(2,4),3,(2,2)):(177,(13,2),59,(26,1))
>?? Flat Divide Result: (3,(2,4),3,(2,2)):(177,(13,2),59,(26,1))
```


### 4. Product (Reproducing a Tile)

The Product operation in CuTe is used to reproduce one layout according to another layout. It creates a new layout where:
- The first mode is the original layout A.
- The second mode is a restrided layout B that points to the origin of a "unique replication" of A.

This is particularly useful for repeating a layout of threads across a tile of data, creating "repeat" patterns.
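In CuTe C++, `logical_product(A, B)` is defined as the layout `(A, complement(A, size(A)*cosize(B)) ◦ B)`, where `cosize(B)` is one past the largest offset of `B`. The pure-Python sketch below checks this for the 1-D example that follows, with `A = (2,2):(4,1)` and `B = 6:1`. The `complement` helper is a simplified, hypothetical version that assumes an injective layout with positive, evenly dividing strides; it is not CuTe's implementation.

```python
from math import prod


def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


def complement(shape, stride, cotarget):
    """Simplified complement of an injective layout with positive strides."""
    modes = sorted((d, s) for s, d in zip(flatten(shape), flatten(stride)) if s != 1)
    out_shape, out_stride, current = [], [], 1
    for d, s in modes:
        out_shape.append(d // current)
        out_stride.append(current)
        current = s * d
    out_shape.append(-(-cotarget // current))  # ceil division
    out_stride.append(current)
    return tuple(out_shape), tuple(out_stride)


A = ((2, 2), (4, 1))
B = (6, 1)
size_A = prod(flatten(A[0]))
size_B = prod(flatten(B[0]))
cosize_B = layout_eval(*B, size_B - 1) + 1        # one past B's largest offset
comp = complement(*A, cotarget=size_A * cosize_B)
print(comp)  # ((1, 2, 3), (1, 2, 8))

second_mode = ((2, 3), (2, 8))  # second mode of the printed result
for j in range(size_B):
    assert layout_eval(*second_mode, j) == layout_eval(*comp, layout_eval(*B, j))

# The product is injective: its size(A) * size(B) offsets are all distinct
offsets = {layout_eval(*A, i) + layout_eval(*second_mode, j) for i in range(size_A) for j in range(size_B)}
assert len(offsets) == size_A * size_B
```

The complement's leading size-1 mode has no effect, so as a function it matches the printed second mode `(2,3):(2,8)`.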

#### Examples

- Logical Product :

```python
@cute.jit
def logical_product_1d_example():
    """
    Demonstrates 1D logical product
    """
    # Define the original layout
    layout = cute.make_layout((2, 2), stride=(4, 1))  # (2,2):(4,1)

    # Define the tiler
    tiler = cute.make_layout(6, stride=1)  # Apply to layout 6:1

    # Apply logical product
    result = cute.logical_product(layout, tiler=tiler)

    # Print results
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler)
    print(">>> Logical Product Result:", result)
    cute.printf(">?? Logical Product Result: {}", result)

logical_product_1d_example()
```

```text
>>> Layout: (2,2):(4,1)
>>> Tiler : 6:1
>>> Logical Product Result: ((2,2),(2,3)):((4,1),(2,8))
>?? Logical Product Result: ((2,2),(2,3)):((4,1),(2,8))
```


- Blocked and Raked Product :
  
  - Blocked Product: Combines the modes of A and B in a block-like fashion, preserving the semantic meaning of the modes by reassociating them after the product.
  - Raked Product: Combines the modes of A and B in an interleaved or "raked" fashion, creating a cyclic distribution of the tiles.
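One way to see the difference is to read both as regroupings of the same logical product `(A, B*)`, where `B*` is the restrided tiler. For the example that follows, `B* = (3,4):(10,30)`, read off the second mode of the zipped product output later in this notebook. Under that reading, mode `i` of a blocked product is `(A_i, B*_i)` and mode `i` of a raked product is `(B*_i, A_i)`. The pure-Python sketch below (hypothetical helper names, no CuTe dependency) reproduces both printed results and then labels each element of the 6x20 result with the copy of `A` it belongs to (`offset // 10`, valid here because `A` spans offsets 0..9 and every stride of `B*` is a multiple of 10).

```python
def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


def blocked(a_shape, a_stride, b_shape, b_stride):
    """Mode i of the result is (A_i, B*_i): each copy of A stays contiguous."""
    return tuple(zip(a_shape, b_shape)), tuple(zip(a_stride, b_stride))


def raked(a_shape, a_stride, b_shape, b_stride):
    """Mode i of the result is (B*_i, A_i): copies of A are interleaved."""
    return tuple(zip(b_shape, a_shape)), tuple(zip(b_stride, a_stride))


A = ((2, 5), (5, 1))
B_star = ((3, 4), (10, 30))   # restrided tiler, from the zipped product output

blk = blocked(*A, *B_star)
rak = raked(*A, *B_star)
print(blk)  # (((2, 3), (5, 4)), ((5, 10), (1, 30)))
print(rak)  # (((3, 2), (4, 5)), ((10, 5), (30, 1)))


def copy_ids(layout, rows=6, cols=20):
    """Index of the copy of A (offset // 10) found at each (row, col) of the result."""
    (s0, s1), (d0, d1) = layout
    return [[(layout_eval(s0, d0, r) + layout_eval(s1, d1, c)) // 10 for c in range(cols)]
            for r in range(rows)]


print(copy_ids(blk)[0])  # [0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 6, 6, 6, 6, 6, 9, 9, 9, 9, 9]
print(copy_ids(rak)[0])  # [0, 3, 6, 9, 0, 3, 6, 9, 0, 3, 6, 9, 0, 3, 6, 9, 0, 3, 6, 9]
```

Along a row, the blocked product keeps each copy of `A` in a contiguous run of 5 columns, while the raked product cycles through the copies column by column.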

```python
@cute.jit
def blocked_raked_product_example():
    """
    Demonstrates blocked and raked products
    """
    # Define the original layout
    layout = cute.make_layout((2, 5), stride=(5, 1))

    # Define the tiler
    tiler = cute.make_layout((3, 4), stride=(1, 3))

    # Apply blocked product
    blocked_result = cute.blocked_product(layout, tiler=tiler)

    # Apply raked product
    raked_result = cute.raked_product(layout, tiler=tiler)

    # Print results
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler)
    print(">>> Blocked Product Result:", blocked_result)
    print(">>> Raked Product Result:", raked_result)
    cute.printf(">?? Blocked Product Result: {}", blocked_result)
    cute.printf(">?? Raked Product Result: {}", raked_result)

blocked_raked_product_example()
```

```text
>>> Layout: (2,5):(5,1)
>>> Tiler : (3,4):(1,3)
>>> Blocked Product Result: ((2,3),(5,4)):((5,10),(1,30))
>>> Raked Product Result: ((3,2),(4,5)):((10,5),(30,1))
>?? Blocked Product Result: ((2,3),(5,4)):((5,10),(1,30))
>?? Raked Product Result: ((3,2),(4,5)):((10,5),(30,1))
```


- Zipped, tiled, and flat product :
  
  - Similar to divide operations, zipped, tiled, and flat product are flavors of `logical_product` that potentially rearrange modes into more convenient forms.
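Because only the grouping of modes changes, the three results printed below flatten to the same shape and stride and therefore agree as functions of a 1-D index. What differs is which modes can be indexed separately: the zipped form has two top-level modes, the tiled form three, and the flat form four. A short pure-Python check with a hypothetical `layout_eval` helper (no CuTe dependency):

```python
def flatten(t):
    """Flatten a nested tuple (or a bare int) into a flat tuple of ints."""
    if isinstance(t, int):
        return (t,)
    return tuple(x for e in t for x in flatten(e))


def layout_eval(shape, stride, i):
    """Offset of 1-D index i: unpack i colexicographically, dot with the stride."""
    offset = 0
    for s, d in zip(flatten(shape), flatten(stride)):
        offset += (i % s) * d
        i //= s
    return offset


# Printed results of the example below, as (shape, stride) pairs
results = {
    "zipped": (((2, 5), (3, 4)), ((5, 1), (10, 30))),
    "tiled": (((2, 5), 3, 4), ((5, 1), 10, 30)),
    "flat": ((2, 5, 3, 4), (5, 1, 10, 30)),
}
n = 2 * 5 * 3 * 4
reference = [layout_eval(*results["zipped"], i) for i in range(n)]
for name, (shape, stride) in results.items():
    assert [layout_eval(shape, stride, i) for i in range(n)] == reference
    print(name, "rank", len(shape))
```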

```python
@cute.jit
def zipped_tiled_flat_product_example():
    """
    Demonstrates zipped, tiled, and flat products
    Layout Shape : (M, N, L, ...)
    Tiler Shape  : <TileM, TileN>

    zipped_product  : ((M,N), (TileM,TileN,L,...))
    tiled_product   : ((M,N), TileM, TileN, L, ...)
    flat_product    : (M, N, TileM, TileN, L, ...)
    """
    # Define the original layout
    layout = cute.make_layout((2, 5), stride=(5, 1))

    # Define the tiler
    tiler = cute.make_layout((3, 4), stride=(1, 3))

    # Apply zipped product
    zipped_result = cute.zipped_product(layout, tiler=tiler)

    # Apply tiled product
    tiled_result = cute.tiled_product(layout, tiler=tiler)

    # Apply flat product
    flat_result = cute.flat_product(layout, tiler=tiler)

    # Print results
    print(">>> Layout:", layout)
    print(">>> Tiler :", tiler)
    print(">>> Zipped Product Result:", zipped_result)
    print(">>> Tiled Product Result:", tiled_result)
    print(">>> Flat Product Result:", flat_result)
    cute.printf(">?? Zipped Product Result: {}", zipped_result)
    cute.printf(">?? Tiled Product Result: {}", tiled_result)
    cute.printf(">?? Flat Product Result: {}", flat_result)

zipped_tiled_flat_product_example()
```

```text
>>> Layout: (2,5):(5,1)
>>> Tiler : (3,4):(1,3)
>>> Zipped Product Result: ((2,5),(3,4)):((5,1),(10,30))
>>> Tiled Product Result: ((2,5),3,4):((5,1),10,30)
>>> Flat Product Result: (2,5,3,4):(5,1,10,30)
>?? Zipped Product Result: ((2,5),(3,4)):((5,1),(10,30))
>?? Tiled Product Result: ((2,5),3,4):((5,1),10,30)
>?? Flat Product Result: (2,5,3,4):(5,1,10,30)
```
