In [24]:
import numpy as np 
import ast

In [40]:
def load_matrix(path):
    """Parse a bracketed matrix literal out of a text file.

    Every line of the file that contains a "[" or a "]" is kept; all other
    lines are discarded. The kept lines are concatenated and evaluated with
    ``ast.literal_eval``, so the surviving text must form a valid Python
    literal (e.g. a nested list with comma-separated values).

    Note that a row which wraps onto a continuation line without any bracket
    on it would lose that continuation line.

    Parameters
    ----------
    path : str or path-like
        File to read.

    Returns
    -------
    numpy.ndarray
        The parsed literal converted to an array and cast to ``int``.
    """
    kept = []
    with open(path) as handle:
        for raw in handle:
            # Bracket-bearing lines are treated as array text; the rest is noise.
            if "[" in raw or "]" in raw:
                kept.append(raw)

    literal = "".join(kept)
    parsed = ast.literal_eval(literal)
    return np.array(parsed).astype(int)



def cute_swizzle_byte(x: int, b: int, m: int, s: int) -> int:
  """
  Apply a CUTE-style Swizzle<BBits=b, MBase=m, SShift=s> to an address.

  A field of ``b`` bits is selected from ``x`` with

    yyy_mask = ((1 << b) - 1) << (m + max(0, s))

  and XOR-ed back into ``x`` after being moved by ``|s|`` positions:

    x' = x XOR ((x & yyy_mask) >> s)     when s >= 0
    x' = x XOR ((x & yyy_mask) << -s)    when s <  0

  Only integer bit operations are used; make_swizzled in this file also
  passes a NumPy integer array as ``x``, in which case the result is
  computed elementwise.

  Raises AssertionError when ``abs(s) < b``.
  """
  assert abs(s) >= b

  # b consecutive one-bits, positioned over the source field.
  field = ((1 << b) - 1) << (m + max(0, s))
  picked = x & field

  # Non-negative shifts move the field down, negative shifts move it up.
  moved = picked >> s if s >= 0 else picked << (-s)
  return x ^ moved

      


def make_byte_tensor(shape, order, n_bytes_per_elem, n_elems_per_vector):
    """
    Build a 2-D tensor whose entries are byte vectors filled with unique byte IDs.

    Each logical (i, j) entry holds ``n_elems_per_vector`` elements of
    ``n_bytes_per_elem`` bytes, i.e.

        entry_bytes = n_elems_per_vector * n_bytes_per_elem

    Parameters
    ----------
    shape : (M, N)
        Logical extent of the tensor.
    order : str
        Memory order forwarded to ``ndarray.reshape`` (e.g. "C" or "F").
    n_bytes_per_elem : int
        Size of one element in bytes.
    n_elems_per_vector : int
        Number of elements stored per logical entry.

    Returns
    -------
    byte_tensor : (M, N, entry_bytes) uint32 array
        ``flat_bytes`` reshaped so that every (i, j) entry is its own byte chunk.
    flat_bytes : (M*N*entry_bytes,) uint32 array
        Flat backing storage; value k is simply the ID of byte k.
    """
    n_rows, n_cols = shape
    bytes_per_entry = n_elems_per_vector * n_bytes_per_elem
    n_bytes = n_rows * n_cols * bytes_per_entry

    # Every byte is labelled with its own linear position.
    flat_bytes = np.arange(n_bytes, dtype=np.uint32)

    # View the flat storage as a tensor of per-entry byte vectors.
    byte_tensor = flat_bytes.reshape((n_rows, n_cols, bytes_per_entry), order=order)

    return byte_tensor, flat_bytes


def make_swizzled(shape, order,
                  n_bytes_per_elem,
                  n_elems_per_vec,
                  b_bits, m_base, s_shift):
  """
  Compute, for every (i, j) entry of a 2-D tensor, the entry index its first
  byte lands on after swizzling the byte addresses.

  Steps:
    1. make_byte_tensor labels every byte with its linear address.
    2. cute_swizzle_byte(b_bits, m_base, s_shift) is applied to those addresses.
    3. The swizzled addresses are reshaped back to (M, N, entry_bytes) using
       the same ``order``; the first byte of each entry is taken and
       floor-divided by entry_bytes = n_bytes_per_elem * n_elems_per_vec.

  Parameters
  ----------
  shape : (M, N)
      Logical tensor extent; must have exactly two dimensions.
  order : str
      Memory order forwarded to the reshapes (e.g. "C" or "F").
  n_bytes_per_elem, n_elems_per_vec : int
      Element size in bytes and number of elements per logical entry.
  b_bits, m_base, s_shift : int
      Swizzle parameters; cute_swizzle_byte asserts abs(s_shift) >= b_bits.

  Returns
  -------
  (M, N) integer array of swizzled entry indices.

  Raises
  ------
  AssertionError
      If ``shape`` is not 2-D.
  """
  assert len(shape) == 2
  entry_bytes = n_bytes_per_elem * n_elems_per_vec

  byte_tensor, flat_bytes = make_byte_tensor(shape, order, n_bytes_per_elem, n_elems_per_vec)
  swizzled_bytes = cute_swizzle_byte(flat_bytes, b_bits, m_base, s_shift)

  # Only the first byte of each entry is needed to identify where it landed.
  first_bytes = swizzled_bytes.reshape(byte_tensor.shape, order=order)[:, :, 0]
  return first_bytes // entry_bytes




In [44]:
# Brute-force search over swizzle parameters (b_bits, m_base, s_shift) for the
# combinations whose computed layout equals the matrix stored in
# "bf16_32B_16_16.txt" (16x16 entries, 2 bytes per element, 1 element per entry).
A_bf16_32B_16_16 = load_matrix("bf16_32B_16_16.txt")
for b_bits in range(10):
  for m_base in range(10):
    for s_shift in range(10):
      # cute_swizzle_byte asserts |s| >= b, so skip combinations it would reject.
      if abs(s_shift) < b_bits:
        continue
      swiz = make_swizzled((16, 16), "C", 2, 1, b_bits, m_base, s_shift)
      if not np.all(swiz == A_bf16_32B_16_16):
        continue
      print(b_bits, m_base, s_shift)
      # Exits only the s_shift loop; the (b_bits, m_base) sweep keeps going.
      break

1 4 3


In [46]:
# Brute-force search over swizzle parameters (b_bits, m_base, s_shift) for the
# combinations whose computed layout equals the matrix stored in
# "bf16_64B_4_32.txt" (4x32 entries, 2 bytes per element, 1 element per entry).
A_bf16_64B_4_32 = load_matrix("bf16_64B_4_32.txt")
for b_bits in range(10):
  for m_base in range(10):
    for s_shift in range(10):
      # cute_swizzle_byte asserts |s| >= b, so skip combinations it would reject.
      if abs(s_shift) < b_bits:
        continue
      swiz = make_swizzled((4, 32), "C", 2, 1, b_bits, m_base, s_shift)
      if not np.all(swiz == A_bf16_64B_4_32):
        continue
      print(b_bits, m_base, s_shift)
      # Exits only the s_shift loop; the (b_bits, m_base) sweep keeps going,
      # which is why several b_bits values can be reported.
      break

1 4 3
2 4 3
3 4 3


In [47]:
# Brute-force search over swizzle parameters (b_bits, m_base, s_shift) for the
# combinations whose computed layout equals the matrix stored in
# "bf16_128B_2_64.txt" (2x64 entries, 2 bytes per element, 1 element per entry).
A_bf16_128B_2_64 = load_matrix("bf16_128B_2_64.txt")
for b_bits in range(10):
  for m_base in range(10):
    for s_shift in range(10):
      # cute_swizzle_byte asserts |s| >= b, so skip combinations it would reject.
      if abs(s_shift) < b_bits:
        continue
      swiz = make_swizzled((2, 64), "C", 2, 1, b_bits, m_base, s_shift)
      if not np.all(swiz == A_bf16_128B_2_64):
        continue
      print(b_bits, m_base, s_shift)
      # Exits only the s_shift loop; the (b_bits, m_base) sweep keeps going,
      # which is why several b_bits values can be reported.
      break

1 4 3
2 4 3
3 4 3


In [49]:
# Brute-force search over swizzle parameters (b_bits, m_base, s_shift) for the
# combinations whose computed layout equals the matrix stored in
# "f32_128B_128_32.txt" (128x32 entries, 4 bytes per element, 1 element per entry).
A_f32_128B_128_32 = load_matrix("f32_128B_128_32.txt")
for b_bits in range(10):
  for m_base in range(10):
    for s_shift in range(10):
      # cute_swizzle_byte asserts |s| >= b, so skip combinations it would reject.
      if abs(s_shift) < b_bits:
        continue
      swiz = make_swizzled((128, 32), "C", 4, 1, b_bits, m_base, s_shift)
      if not np.all(swiz == A_f32_128B_128_32):
        continue
      print(b_bits, m_base, s_shift)
      # Exits only the s_shift loop; the (b_bits, m_base) sweep keeps going.
      break

3 4 3


In [50]:
# Brute-force search over swizzle parameters (b_bits, m_base, s_shift) for the
# combinations whose computed layout equals the matrix stored in
# "f32_64B_128_16.txt" (128x16 entries, 4 bytes per element, 1 element per entry).
A_f32_64B_128_16 = load_matrix("f32_64B_128_16.txt")
for b_bits in range(10):
  for m_base in range(10):
    for s_shift in range(10):
      # cute_swizzle_byte asserts |s| >= b, so skip combinations it would reject.
      if abs(s_shift) < b_bits:
        continue
      swiz = make_swizzled((128, 16), "C", 4, 1, b_bits, m_base, s_shift)
      if not np.all(swiz == A_f32_64B_128_16):
        continue
      print(b_bits, m_base, s_shift)
      # Exits only the s_shift loop; the (b_bits, m_base) sweep keeps going.
      break

2 4 3


In [51]:
# Brute-force search over swizzle parameters (b_bits, m_base, s_shift) for the
# combinations whose computed layout equals the matrix stored in
# "f32_32B_128_8.txt" (128x8 entries, 4 bytes per element, 1 element per entry).
A_f32_32B_128_8 = load_matrix("f32_32B_128_8.txt")
for b_bits in range(10):
  for m_base in range(10):
    for s_shift in range(10):
      # cute_swizzle_byte asserts |s| >= b, so skip combinations it would reject.
      if abs(s_shift) < b_bits:
        continue
      swiz = make_swizzled((128, 8), "C", 4, 1, b_bits, m_base, s_shift)
      if not np.all(swiz == A_f32_32B_128_8):
        continue
      print(b_bits, m_base, s_shift)
      # Exits only the s_shift loop; the (b_bits, m_base) sweep keeps going.
      break

1 4 3
