In [None]:
import torch
import triton
import triton.language as tl

In [None]:
@triton.jit
def kmeans_kernel(
    X_ptr, centroids_ptr, new_centroids_ptr, cluster_counts_ptr, assignments_ptr,
    B, H, L, C, D: tl.constexpr,
    stride_x_b, stride_x_h, stride_x_l, stride_x_d,
    stride_c_b, stride_c_h, stride_c_c, stride_c_d,
    stride_nc_b, stride_nc_h, stride_nc_c, stride_nc_d,
    stride_cc_b, stride_cc_h, stride_cc_c,
    stride_a_b, stride_a_h, stride_a_l,
    BLOCK_SIZE_L: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr
):
    """
    Fused K-Means step: assign points to centroids and accumulate cluster sums.

    Each program handles one (batch, head) pair and one tile of BLOCK_SIZE_L
    points (grid axes 0, 1 and 2 respectively):
    1. Assign each point (B, H, L, D) to the closest centroid (B, H, C, D)
       by squared Euclidean distance and store the index in `assignments_ptr`.
    2. Atomically add each point into `new_centroids_ptr` (per-cluster sums)
       and add 1 per point into `cluster_counts_ptr`.

    The caller must zero the sum/count buffers before launch and divide the
    sums by the counts afterwards.

    D is a compile-time constant because it is used as a `tl.arange` bound;
    it (like both block sizes) must therefore be a power of two.
    """
    b_idx = tl.program_id(0)
    h_idx = tl.program_id(1)
    l_idx = tl.program_id(2) * BLOCK_SIZE_L + tl.arange(0, BLOCK_SIZE_L)
    d_idx = tl.arange(0, D)

    mask_l = l_idx < L
    # Full (BLOCK_SIZE_L, D) mask; the d-term is always true and only gives
    # the mask the same shape as the pointer tile.
    mask_ld = mask_l[:, None] & (d_idx[None, :] < D)

    x_offset = b_idx * stride_x_b + h_idx * stride_x_h + l_idx * stride_x_l
    c_offset = b_idx * stride_c_b + h_idx * stride_c_h

    # Load input points as float32 so the running minimum keeps a fixed dtype.
    x = tl.load(
        X_ptr + x_offset[:, None] + d_idx[None, :] * stride_x_d,
        mask=mask_ld,
        other=0.0
    ).to(tl.float32)

    # Track closest cluster (both stay 1-D of length BLOCK_SIZE_L in the loop).
    min_dist = tl.full([BLOCK_SIZE_L], float("inf"), dtype=tl.float32)
    best_cluster = tl.zeros([BLOCK_SIZE_L], dtype=tl.int32)

    for c in range(0, C, BLOCK_SIZE_C):
        c_idx = c + tl.arange(0, BLOCK_SIZE_C)
        mask_c = c_idx < C

        # Load a tile of centroids: (BLOCK_SIZE_C, D)
        centroids = tl.load(
            centroids_ptr + c_offset + c_idx[:, None] * stride_c_c + d_idx[None, :] * stride_c_d,
            mask=mask_c[:, None] & (d_idx[None, :] < D),
            other=0.0
        ).to(tl.float32)

        # Squared Euclidean distance: (BLOCK_SIZE_L, BLOCK_SIZE_C)
        diff = x[:, None, :] - centroids[None, :, :]
        dists = tl.sum(diff * diff, axis=2)
        # Padded centroid slots were loaded as zeros; make sure they never win.
        dists = tl.where(mask_c[None, :], dists, float("inf"))

        # Reduce over the centroid tile, then merge with the running best.
        block_min = tl.min(dists, axis=1)
        block_arg = (tl.argmin(dists, axis=1) + c).to(tl.int32)
        closer = block_min < min_dist
        min_dist = tl.where(closer, block_min, min_dist)
        best_cluster = tl.where(closer, block_arg, best_cluster)

    # Store assignments
    tl.store(
        assignments_ptr + b_idx * stride_a_b + h_idx * stride_a_h + l_idx * stride_a_l,
        best_cluster,
        mask=mask_l
    )

    # Accumulate per-cluster sums with one vectorised atomic add per tile.
    nc_ptrs = (
        new_centroids_ptr + b_idx * stride_nc_b + h_idx * stride_nc_h
        + best_cluster[:, None] * stride_nc_c + d_idx[None, :] * stride_nc_d
    )
    tl.atomic_add(nc_ptrs, x, mask=mask_ld)

    # Count each valid point exactly once (not once per feature dimension).
    cc_ptrs = cluster_counts_ptr + b_idx * stride_cc_b + h_idx * stride_cc_h + best_cluster * stride_cc_c
    ones = tl.full([BLOCK_SIZE_L], 1.0, dtype=tl.float32)
    tl.atomic_add(cc_ptrs, ones, mask=mask_l)

def kmeans_triton(X, num_clusters, num_iters=10):
    """
    Run K-Means with the fused Triton kernel `kmeans_kernel`.

    X: Tensor of shape (B, H, L, D)
    num_clusters: Number of clusters C
    num_iters: Number of assignment/update iterations
    Returns: centroids of shape (B, H, C, D)

    Centroids are initialised from C randomly chosen points of each (B, H)
    slice. A cluster that receives no points in an iteration keeps its
    previous centroid.
    """
    B, H, L, D = X.shape
    C = num_clusters
    device = X.device

    # Random permutation of the L points; the first C become the initial centroids.
    shuffled_indices = torch.rand(B, H, L, device=device).argsort(dim=-1)  # (B, H, L)
    indices = shuffled_indices[:, :, :C]  # (B, H, C)
    centroids = torch.gather(X, 2, indices.unsqueeze(-1).expand(-1, -1, -1, D)).clone()

    # Allocate storage for assignments, per-cluster sums, and counts
    assignments = torch.zeros((B, H, L), dtype=torch.int32, device=device)
    new_centroids = torch.zeros((B, H, C, D), device=device)
    cluster_counts = torch.zeros((B, H, C), device=device)

    # The tile sizes passed to the kernel must be the same ones used to size
    # the grid, so they are defined once here.
    BLOCK_SIZE_L = 64
    BLOCK_SIZE_C = 32
    grid = (B, H, triton.cdiv(L, BLOCK_SIZE_L))

    for _ in range(num_iters):
        # The kernel accumulates atomically, so the buffers must start at zero.
        new_centroids.zero_()
        cluster_counts.zero_()

        # Assign points and accumulate per-cluster sums and counts
        kmeans_kernel[grid](
            X, centroids, new_centroids, cluster_counts, assignments,
            B, H, L, C, D,
            X.stride(0), X.stride(1), X.stride(2), X.stride(3),
            centroids.stride(0), centroids.stride(1), centroids.stride(2), centroids.stride(3),
            new_centroids.stride(0), new_centroids.stride(1), new_centroids.stride(2), new_centroids.stride(3),
            cluster_counts.stride(0), cluster_counts.stride(1), cluster_counts.stride(2),
            assignments.stride(0), assignments.stride(1), assignments.stride(2),
            BLOCK_SIZE_L=BLOCK_SIZE_L, BLOCK_SIZE_C=BLOCK_SIZE_C
        )

        # Counts are (B, H, C); add a trailing axis so they broadcast over D.
        counts = cluster_counts.unsqueeze(-1)
        means = (new_centroids / counts.clamp_min(1.0)).to(centroids.dtype)
        # Empty clusters keep their old centroid instead of collapsing to 0.
        centroids = torch.where(counts > 0, means, centroids)

    return centroids

In [None]:
import torch
import triton
import triton.language as tl

@triton.jit
def kmeans_assign_kernel(
    X_ptr, centroids_ptr, assignments_ptr,
    B, H, L, C, D: tl.constexpr,
    stride_x_b, stride_x_h, stride_x_l, stride_x_d,
    stride_c_b, stride_c_h, stride_c_c, stride_c_d,
    stride_a_b, stride_a_h, stride_a_l,
    BLOCK_SIZE_L: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr
):
    """
    Kernel to assign each point to the closest centroid.

    Each program handles one (batch, head) pair and one tile of BLOCK_SIZE_L
    points (grid axes 0, 1 and 2). It scans the C centroids in tiles of
    BLOCK_SIZE_C, keeps the index of the smallest squared Euclidean distance
    per point, and writes that index to `assignments_ptr`.

    D is a compile-time constant because it is used as a `tl.arange` bound;
    it (like both block sizes) must therefore be a power of two.
    """
    b_idx = tl.program_id(0)
    h_idx = tl.program_id(1)
    l_idx = tl.program_id(2) * BLOCK_SIZE_L + tl.arange(0, BLOCK_SIZE_L)
    d_idx = tl.arange(0, D)
    mask_l = l_idx < L

    x_offset = b_idx * stride_x_b + h_idx * stride_x_h + l_idx * stride_x_l
    c_offset = b_idx * stride_c_b + h_idx * stride_c_h

    # Load input points as float32 so the running minimum keeps a fixed dtype.
    x = tl.load(
        X_ptr + x_offset[:, None] + d_idx[None, :] * stride_x_d,
        mask=mask_l[:, None] & (d_idx[None, :] < D),
        other=0.0
    ).to(tl.float32)

    # Track closest cluster (both stay 1-D of length BLOCK_SIZE_L in the loop).
    min_dist = tl.full([BLOCK_SIZE_L], float("inf"), dtype=tl.float32)
    best_cluster = tl.zeros([BLOCK_SIZE_L], dtype=tl.int32)

    for c in range(0, C, BLOCK_SIZE_C):
        c_idx = c + tl.arange(0, BLOCK_SIZE_C)
        mask_c = c_idx < C

        # Load a tile of centroids: (BLOCK_SIZE_C, D)
        centroids = tl.load(
            centroids_ptr + c_offset + c_idx[:, None] * stride_c_c + d_idx[None, :] * stride_c_d,
            mask=mask_c[:, None] & (d_idx[None, :] < D),
            other=0.0
        ).to(tl.float32)

        # Squared Euclidean distance: (BLOCK_SIZE_L, BLOCK_SIZE_C)
        diff = x[:, None, :] - centroids[None, :, :]
        dists = tl.sum(diff * diff, axis=2)
        # Padded centroid slots were loaded as zeros; make sure they never win.
        dists = tl.where(mask_c[None, :], dists, float("inf"))

        # Reduce over the centroid tile, then merge with the running best.
        block_min = tl.min(dists, axis=1)
        block_arg = (tl.argmin(dists, axis=1) + c).to(tl.int32)
        closer = block_min < min_dist
        min_dist = tl.where(closer, block_min, min_dist)
        best_cluster = tl.where(closer, block_arg, best_cluster)

    # Store assignments
    tl.store(
        assignments_ptr + b_idx * stride_a_b + h_idx * stride_a_h + l_idx * stride_a_l,
        best_cluster,
        mask=mask_l
    )

@triton.jit
def kmeans_update_kernel(
    X_ptr, new_centroids_ptr, cluster_counts_ptr, assignments_ptr,
    B, H, L, C, D: tl.constexpr,
    stride_x_b, stride_x_h, stride_x_l, stride_x_d,
    stride_nc_b, stride_nc_h, stride_nc_c, stride_nc_d,
    stride_cc_b, stride_cc_h, stride_cc_c,
    stride_a_b, stride_a_h, stride_a_l,
    BLOCK_SIZE_L: tl.constexpr,
    BLOCK_SIZE_C: tl.constexpr
):
    """
    Kernel to compute per-cluster sums and counts from the assignments,
    handling BLOCK_SIZE_C centroids at once.

    Each program handles one (batch, head) pair and one tile of centroids
    (grid axes 0, 1 and 2). It walks all L points in tiles of BLOCK_SIZE_L,
    sums the points assigned to each centroid of its tile, and writes the
    sums to `new_centroids_ptr` and the member counts to
    `cluster_counts_ptr`. The stored values are sums, not means; the caller
    divides by the counts.

    D is a compile-time constant because it is used as a tensor shape and a
    `tl.arange` bound; it (like both block sizes) must be a power of two.
    """
    b_idx = tl.program_id(0)
    h_idx = tl.program_id(1)
    c_start = tl.program_id(2) * BLOCK_SIZE_C
    c_idx = c_start + tl.arange(0, BLOCK_SIZE_C)
    d_idx = tl.arange(0, D)

    mask_c = c_idx < C  # Mask for valid centroids
    sum_x = tl.zeros([BLOCK_SIZE_C, D], dtype=tl.float32)
    count = tl.zeros([BLOCK_SIZE_C], dtype=tl.float32)

    for l in range(0, L, BLOCK_SIZE_L):
        l_idx = l + tl.arange(0, BLOCK_SIZE_L)
        mask_l = l_idx < L

        # Load assignments; padded rows get -1, which matches no centroid.
        cluster = tl.load(
            assignments_ptr + b_idx * stride_a_b + h_idx * stride_a_h + l_idx * stride_a_l,
            mask=mask_l, other=-1
        )

        # Load points: (BLOCK_SIZE_L, D), accumulated in float32.
        x = tl.load(
            X_ptr + b_idx * stride_x_b + h_idx * stride_x_h + l_idx[:, None] * stride_x_l + d_idx[None, :] * stride_x_d,
            mask=mask_l[:, None] & (d_idx[None, :] < D),
            other=0.0
        ).to(tl.float32)

        # One-hot membership (BLOCK_SIZE_L, BLOCK_SIZE_C). Cast to float32
        # explicitly so the multiply and the count reduction do not depend on
        # how booleans are promoted.
        belongs_to_cluster = (cluster[:, None] == c_idx[None, :]).to(tl.float32)
        sum_x += tl.sum(x[:, None, :] * belongs_to_cluster[:, :, None], axis=0)
        count += tl.sum(belongs_to_cluster, axis=0)

    # Store per-cluster sums
    tl.store(
        new_centroids_ptr + b_idx * stride_nc_b + h_idx * stride_nc_h + c_idx[:, None] * stride_nc_c + d_idx[None, :] * stride_nc_d,
        sum_x,
        mask=mask_c[:, None] & (d_idx[None, :] < D)
    )
    # Store per-cluster member counts
    tl.store(
        cluster_counts_ptr + b_idx * stride_cc_b + h_idx * stride_cc_h + c_idx * stride_cc_c,
        count,
        mask=mask_c
    )

def kmeans_triton(X, num_clusters, num_iters=10):
    """
    Run K-Means with separate Triton assign and update kernels.

    X: Tensor of shape (B, H, L, D)
    num_clusters: Number of clusters C (must not exceed L)
    num_iters: Number of assignment/update iterations
    Returns: tuple (centroids, cluster_counts) where centroids has shape
        (B, H, C, D) and cluster_counts has shape (B, H, C) and holds the
        member counts from the last iteration.

    Centroids are initialised from C randomly chosen points of each (B, H)
    slice. Clusters that end up empty in an iteration are re-seeded from
    another random window of the same shuffled points.
    """
    B, H, L, D = X.shape
    C = num_clusters
    device = X.device

    X = X.contiguous()

    # Initialize centroids from the first C entries of a random permutation.
    shuffled_indices = torch.rand(B, H, L, device=device).argsort(dim=-1)
    indices = shuffled_indices[:, :, :C]
    centroids = torch.gather(X, 2, indices.unsqueeze(-1).expand(-1, -1, -1, D)).clone()

    assignments = torch.zeros((B, H, L), dtype=torch.int32, device=device)
    new_centroids = torch.zeros((B, H, C, D), device=device)
    cluster_counts = torch.zeros((B, H, C), device=device)

    # The tile sizes passed to the kernels must be the same ones used to size
    # the grids, so they are defined once here.
    BLOCK_SIZE_L = 64
    BLOCK_SIZE_C = 32  # Number of centroids processed in parallel
    grid_assign = (B, H, triton.cdiv(L, BLOCK_SIZE_L))
    grid_update = (B, H, triton.cdiv(C, BLOCK_SIZE_C))

    for _ in range(num_iters):
        new_centroids.zero_()
        cluster_counts.zero_()

        # Assign points
        kmeans_assign_kernel[grid_assign](
            X, centroids, assignments,
            B, H, L, C, D,
            X.stride(0), X.stride(1), X.stride(2), X.stride(3),
            centroids.stride(0), centroids.stride(1), centroids.stride(2), centroids.stride(3),
            assignments.stride(0), assignments.stride(1), assignments.stride(2),
            BLOCK_SIZE_L=BLOCK_SIZE_L, BLOCK_SIZE_C=BLOCK_SIZE_C
        )

        # Accumulate per-cluster sums and counts
        kmeans_update_kernel[grid_update](
            X, new_centroids, cluster_counts, assignments,
            B, H, L, C, D,
            X.stride(0), X.stride(1), X.stride(2), X.stride(3),
            new_centroids.stride(0), new_centroids.stride(1), new_centroids.stride(2), new_centroids.stride(3),
            cluster_counts.stride(0), cluster_counts.stride(1), cluster_counts.stride(2),
            assignments.stride(0), assignments.stride(1), assignments.stride(2),
            BLOCK_SIZE_L=BLOCK_SIZE_L, BLOCK_SIZE_C=BLOCK_SIZE_C
        )

        # Mean of each non-empty cluster. The masked sums are (N, D) and the
        # masked counts are (N,), so the counts need a trailing axis to
        # broadcast over D.
        valid_clusters = cluster_counts > 0
        means = new_centroids[valid_clusters] / cluster_counts[valid_clusters].unsqueeze(-1)
        centroids[valid_clusters] = means.to(centroids.dtype)

        empty_clusters = ~valid_clusters
        if empty_clusters.any():
            # Exclusive upper bound L - C + 1 keeps start + C <= L and is
            # still a valid range when L == C.
            start = torch.randint(0, L - C + 1, (1,), device=device).item()
            reinit_indices = shuffled_indices[:, :, start:start + C]
            # Kept in its own variable so the float32 kernel output buffer
            # `new_centroids` is not replaced by a tensor of X's dtype.
            reinit_points = torch.gather(X, 2, reinit_indices.unsqueeze(-1).expand(-1, -1, -1, D))
            centroids[empty_clusters] = reinit_points[empty_clusters]

    return centroids, cluster_counts


In [None]:
def kmeans_pytorch(X, num_clusters, num_iters=10):
    """
    Pure PyTorch implementation of k-means clustering.

    X: Tensor of shape (B, H, L, D)
    num_clusters: Number of clusters C (must not exceed L)
    num_iters: Number of assignment/update iterations
    Returns: tuple (centroids, cluster_counts) where centroids has shape
        (B, H, C, D) and cluster_counts has shape (B, H, C) and holds the
        number of points assigned to each cluster in the last iteration.

    Centroids are initialised from C randomly chosen points of each (B, H)
    slice. Clusters that receive no points in an iteration are re-seeded
    from another random window of the same shuffled points.
    """
    B, H, L, D = X.shape
    C = num_clusters
    device = X.device

    X = X.contiguous()

    # Initialize centroids from the first C entries of a random permutation.
    shuffled_indices = torch.rand(B, H, L, device=device).argsort(dim=-1)
    indices = shuffled_indices[:, :, :C]
    centroids = torch.gather(X, 2, indices.unsqueeze(-1).expand(-1, -1, -1, D)).clone()

    new_centroids = torch.zeros_like(centroids)
    cluster_counts = torch.zeros((B, H, C), dtype=torch.float32, device=device)

    for _ in range(num_iters):
        # Compute distances: (B, H, L, C)
        dists = torch.cdist(X.view(B * H, L, D), centroids.view(B * H, C, D)).view(B, H, L, C)

        # Assign each point to the closest centroid: (B, H, L)
        assignments = dists.argmin(dim=-1)

        # Accumulate per-cluster member counts and coordinate sums in one
        # pass each, scattering along the cluster axis.
        new_centroids.zero_()
        cluster_counts.zero_()
        cluster_counts.scatter_add_(
            2, assignments, torch.ones_like(assignments, dtype=cluster_counts.dtype)
        )
        new_centroids.scatter_add_(
            2, assignments.unsqueeze(-1).expand(-1, -1, -1, D), X
        )

        # Turn sums into means. Clamping the divisor to 1 avoids 0/0 for
        # empty clusters, whose (zero) sums are replaced just below.
        new_centroids.div_(cluster_counts.clamp(min=1).unsqueeze(-1))

        # Handle empty clusters by reinitializing them from data points.
        empty_clusters = cluster_counts == 0
        if empty_clusters.any():
            # Exclusive upper bound L - C + 1 keeps start + C <= L and is
            # still a valid range when L == C.
            start = torch.randint(0, L - C + 1, (1,), device=device).item()
            reinit_indices = shuffled_indices[:, :, start:start + C]
            reinit_points = torch.gather(X, 2, reinit_indices.unsqueeze(-1).expand(-1, -1, -1, D))
            new_centroids[empty_clusters] = reinit_points[empty_clusters]

        centroids = new_centroids.clone()

    return centroids, cluster_counts