In [None]:
import torch

import os
os.environ['TRITON_INTERPRET'] = '1'
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

In [2]:
def tile_tensor(tensor, tile_size):
    """
    Разбивает тензор на тайлы и собирает их в новый тензор.
    
    Args:
        tensor: torch.Tensor размером [M, N]
        tile_size: размер тайла K
        
    Returns:
        torch.Tensor размером [M//K, N//K, K**2]
    """
    M, N = tensor.shape
    # Проверяем, что размеры делятся нацело
    assert M % tile_size == 0 and N % tile_size == 0, \
        f"Размеры {M}x{N} должны делиться на {tile_size} нацело"
    
    # Разбиваем на тайлы
    tiles = tensor.unfold(0, tile_size, tile_size).unfold(1, tile_size, tile_size)
    # tiles.shape = [M//K, N//K, K, K]
    
    # Преобразуем в [M//K, N//K, K**2]
    tiles_flat = tiles.contiguous().view(M // tile_size, N // tile_size, tile_size ** 2)
    
    return tiles_flat


# Альтернативная версия обратной функции с использованием reshape и permute
def untile_tensor_fast(tiled_tensor, original_shape, tile_size):
    """
    Быстрая версия восстановления тензора из тайлов.
    
    Args:
        tiled_tensor: torch.Tensor размером [M//K, N//K, K**2]
        original_shape: кортеж (M, N)
        tile_size: размер тайла K
        
    Returns:
        torch.Tensor размером [M, N]
    """
    M, N = original_shape
    num_tiles_h = M // tile_size
    num_tiles_w = N // tile_size
    
    # Преобразуем в [M//K, N//K, K, K]
    tiles = tiled_tensor.view(num_tiles_h, num_tiles_w, tile_size, tile_size)
    
    # Переупорядочиваем в [M, N]
    reconstructed = tiles.permute(0, 2, 1, 3).contiguous().view(M, N)
    
    return reconstructed


# M, N, K = 6, 8, 2
# original_tensor = torch.arange(M * N).float().view(M, N)
# print("Исходный тензор:")
# print(original_tensor)
# print(f"Форма: {original_tensor.shape}")

# # Прямое преобразование
# tiled = tile_tensor(original_tensor, K)
# print(f"\nТайлы:")
# print(tiled)
# print(f"Форма тайлов: {tiled.shape}")

# # Обратное преобразование
# reconstructed = untile_tensor_fast(tiled, (M, N), K)
# print(f"\nВосстановленный тензор:")
# print(reconstructed)
# print(f"Форма: {reconstructed.shape}")

# # Проверка корректности
# print(f"\nВосстановление корректно: {torch.allclose(original_tensor, reconstructed)}")

In [3]:
_INT16_BITS = 16

M = 16 # 512
N = 16 # 512

M_TILE_SIZE = 16
N_TILE_SIZE = 16

L = 16
BIT_SHIFT = 8

BITS_PER_VALUE = L // BIT_SHIFT
TILE_BITS = M_TILE_SIZE * N_TILE_SIZE * BITS_PER_VALUE

print("BITS_PER_VALUE:", BITS_PER_VALUE)
print("TILE_BITS:", TILE_BITS)

compessed = torch.randint(
    0, 
    2**15 - 1, 
    (M // M_TILE_SIZE, N // N_TILE_SIZE, TILE_BITS // _INT16_BITS), 
    dtype=torch.uint16
)

B = 16
x = torch.randn(B, N)
print("x.shape:", x.shape)

compessed.shape

BITS_PER_VALUE: 2
TILE_BITS: 512
x.shape: torch.Size([16, 16])


torch.Size([1, 1, 32])

In [4]:
compessed.shape

torch.Size([1, 1, 32])

In [14]:
def decode_fp4_python(bits):
    val0 = (bits >> 12) & 0xF
    val1 = (bits >> 8) & 0xF
    val2 = (bits >> 4) & 0xF
    val3 = bits & 0xF
    
    sign0 = (val0 & 0x8) != 0
    sign1 = (val1 & 0x8) != 0
    sign2 = (val2 & 0x8) != 0
    sign3 = (val3 & 0x8) != 0
    
    mag0 = (val0 & 0x7).to(torch.float16)
    mag1 = (val1 & 0x7).to(torch.float16)
    mag2 = (val2 & 0x7).to(torch.float16)
    mag3 = (val3 & 0x7).to(torch.float16)
    
    w0 = torch.where(sign0, -mag0, mag0)
    w1 = torch.where(sign1, -mag1, mag1)
    w2 = torch.where(sign2, -mag2, mag2)
    w3 = torch.where(sign3, -mag3, mag3)

    return torch.stack([w0, w1, w2, w3], dim=-1)


def matmul_python(bits, activations):
    N = bits.shape[0]
    fp4_matrix = decode_fp4_python(bits).reshape(N, 16, 16)
    result = torch.matmul(activations, fp4_matrix)
    return result

@triton.jit
def decode_fp4_triton(bits):
    """Декодирует 16-битное значение в 4 значения FP4"""
    val0 = (bits >> 12) & 0xF
    val1 = (bits >> 8) & 0xF
    val2 = (bits >> 4) & 0xF
    val3 = bits & 0xF
    
    sign0 = (val0 & 0x8) != 0
    sign1 = (val1 & 0x8) != 0
    sign2 = (val2 & 0x8) != 0
    sign3 = (val3 & 0x8) != 0
    
    mag0 = (val0 & 0x7)
    mag1 = (val1 & 0x7)
    mag2 = (val2 & 0x7)
    mag3 = (val3 & 0x7)
    
    w0 = tl.where(sign0, -mag0.to(tl.float16), mag0.to(tl.float16))
    w1 = tl.where(sign1, -mag1.to(tl.float16), mag1.to(tl.float16))
    w2 = tl.where(sign2, -mag2.to(tl.float16), mag2.to(tl.float16))
    w3 = tl.where(sign3, -mag3.to(tl.float16), mag3.to(tl.float16))
    
    return w0, w1, w2, w3


@triton.jit
def matmul_kernel(
    bits_ptr,
    activations_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    
    mask = offsets < n_elements

    pass


def matmul_triton(bits, activations):
    B, N, M = activations.shape
    output = torch.empty_like(activations)
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
    matmul_kernel[grid](bits, activations, output, N, BLOCK_SIZE=1024)
    return output



torch.manual_seed(0)
N = 256
bits = torch.randint(
    0, 
    2**15 - 1, 
    (N, 64,), 
    dtype=torch.int16
).cuda()
activations = torch.randn(N, 16, 16, dtype=torch.float16).cuda()

res = matmul_python(bits, activations)

In [11]:
@triton.jit
def decode_fp4_triton(bits):
    """Decodes 16-bit value into 4 FP4 values"""
    val0 = (bits >> 12) & 0xF
    val1 = (bits >> 8) & 0xF
    val2 = (bits >> 4) & 0xF
    val3 = bits & 0xF
    
    sign0 = (val0 & 0x8) != 0
    sign1 = (val1 & 0x8) != 0
    sign2 = (val2 & 0x8) != 0
    sign3 = (val3 & 0x8) != 0
    
    mag0 = (val0 & 0x7)
    mag1 = (val1 & 0x7)
    mag2 = (val2 & 0x7)
    mag3 = (val3 & 0x7)
    
    w0 = tl.where(sign0, -mag0, mag0)
    w1 = tl.where(sign1, -mag1, mag1)
    w2 = tl.where(sign2, -mag2, mag2)
    w3 = tl.where(sign3, -mag3, mag3)
    
    return w0, w1, w2, w3

@triton.jit
def matmul_kernel(
    bits_ptr,
    activations_ptr,
    output_ptr,
    N,  # Number of blocks in weight matrix
    M,  # Activation dimension
    B,  # Batch size
    BLOCK_SIZE: tl.constexpr,
):
    # 2D grid: batch_id x block_id
    batch_id = tl.program_id(axis=0)
    block_id = tl.program_id(axis=1)
    
    # Each block processes 16x16 submatrix
    block_start = block_id * 16
    row_offsets = tl.arange(0, 16)
    col_offsets = tl.arange(0, 16)
    
    # Load the bits for this block
    bits_offset = batch_id * N + block_id
    bits = tl.load(bits_ptr + bits_offset)
    
    # Decode FP4 values
    w0, w1, w2, w3 = decode_fp4_triton(bits)
    
    # Create 16x16 weight matrix from decoded values
    # We need to properly arrange the 4 values into a 4x4 pattern repeated 4 times
    # First, create a 4x4 block from the 4 values
    weight_block = tl.zeros((16, 16), dtype=tl.float16)
    
    # Fill the diagonal blocks with the decoded values
    # This creates a block diagonal matrix where each 4x4 block has the same value
    for i in range(4):
        for j in range(4):
            row = i * 4 + j
            col = i * 4 + j
            if i == 0:
                weight_block = tl.where((row_offsets[:, None] == row) & (col_offsets[None, :] == col), 
                                      w0, weight_block)
            elif i == 1:
                weight_block = tl.where((row_offsets[:, None] == row) & (col_offsets[None, :] == col), 
                                      w1, weight_block)
            elif i == 2:
                weight_block = tl.where((row_offsets[:, None] == row) & (col_offsets[None, :] == col), 
                                      w2, weight_block)
            else:
                weight_block = tl.where((row_offsets[:, None] == row) & (col_offsets[None, :] == col), 
                                      w3, weight_block)
    
    # Load activations for this block (16 elements)
    act_offsets = batch_id * M + block_start + row_offsets
    activations = tl.load(activations_ptr + act_offsets, mask=row_offsets < M)
    
    # Perform the matrix multiplication for this block
    # activations: [16], weight_block: [16, 16] -> result: [16]
    result = tl.dot(activations, weight_block)
    
    # Store the result
    output_offsets = batch_id * M + block_start + col_offsets
    tl.store(output_ptr + output_offsets, result, mask=col_offsets < M)

def matmul_triton(bits, activations):
    B, N, M = activations.shape
    # Reshape bits to match the expected format: [B, N//16]
    bits_reshaped = bits.reshape(B, -1)
    
    output = torch.zeros_like(activations)
    
    # Grid: batch_size x number_of_blocks
    grid = lambda meta: (B, triton.cdiv(M, 16))
    
    matmul_kernel[grid](
        bits_reshaped, 
        activations, 
        output, 
        bits_reshaped.shape[1],  # N_blocks
        M,  # M dimension
        B,  # Batch size
        BLOCK_SIZE=16
    )
    return output

In [12]:
res_triton = matmul_triton(bits, activations)

InterpreterError: AssertionError("Both inputs must be either 2D or 3D; (lhs: ['constexpr[16]'] vs rhs: ['constexpr[16]', 'constexpr[16]'])")