In [None]:
lm_eval --model hf \
--model_args pretrained=openai-community/gpt2-medium,trust_remote_code=True,dtype="bfloat16" \
--tasks hellaswag,wikitext \
--device cuda:0 \
--batch_size auto

In [None]:
class Router(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.top_k = config.n_activated_experts
        self.gate = nn.Linear(config.d_embed, config.n_experts)

    def forward(self, x):
        logits = self.gate(x)  # [batch_size, seq_len, n_experts]
        scores = F.softmax(logits, dim=-1)
        scores, indices = torch.topk(scores, self.top_k, dim=-1)  # [batch_size, seq_len, top_k]
        return scores, indices


class MoE(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.top_k = config.n_activated_experts
        if config.n_experts is None or config.n_activated_experts is None:
            raise ValueError("n_experts and n_activated_experts must be specified for MoE")
        if config.n_experts < config.n_activated_experts:
            raise ValueError("n_experts must be greater than or equal to n_activated_experts")
        self.router = Router(config)
        self.experts = nn.ModuleList([FeedForward(config) for _ in range(config.n_experts)])
        if config.n_shared_experts is not None:
            self.shared_experts = nn.ModuleList([FeedForward(config) for _ in range(config.n_shared_experts)])
        else:
            self.shared_experts = None

    def forward(self, x):
        batch_size, seq_len, d_embed = x.size()
        scores, indices = self.router(x)

        x_flat = x.view(-1, d_embed)  # [batch * seq, d_embed]
        scores_flat = scores.view(-1, self.top_k)  # [batch * seq, top_k]
        indices_flat = indices.view(-1, self.top_k)  # [batch * seq, top_k]

        y_flat = torch.zeros_like(x_flat)

        for expert_idx, expert in enumerate(self.experts):
            mask = (indices_flat == expert_idx)
            selected_pos = mask.any(dim=-1)  # [batch * seq]

            if selected_pos.any():
                expert_input = x_flat[selected_pos]  # [top_k, d_embed]
                expert_output = expert(expert_input)

                selected_mask_idx, selected_topk_idx = torch.where(mask)
                gating_scores = scores_flat[selected_mask_idx, selected_topk_idx].unsqueeze(1)  # [top_k, 1]

                weighted_output = expert_output * gating_scores  # [top_k, d_embed]
                y_flat[selected_pos] += weighted_output  # [batch * seq, d_embed]

        if self.shared_experts is not None:
            shared_output = sum([expert(x_flat) for expert in self.shared_experts]) / len(self.shared_experts)
            y_flat += shared_output  # [batch * seq, d_embed]

        y = y_flat.view(batch_size, seq_len, d_embed)
        return y


class RouterFreeMoE(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.experts = nn.ModuleList([FeedForward(config) for _ in range(config.n_experts)])

    def forward(self, x):
        batch_size, seq_len, d_embed = x.size()
        x_flat = x.view(-1, d_embed)  # [batch * seq, d_embed]

        # Step 1: Expert selection
        deltas = []
        with torch.no_grad():
            # if Device compute capabilities is 8.9 or higher, use fp8_autocast
            if torch.cuda.get_device_capability()[0] >= 8 and torch.cuda.get_device_capability()[1] >= 9:
                #ctx = fp8_autocast(enabled=True)
                ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            else:
                ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            with ctx:
                for expert in self.experts:
                    expert_output = expert(x_flat)  # [batch * seq, d_embed]
                    delta = (expert_output - x_flat).norm(dim=-1)  # [batch * seq]
                    deltas.append(delta)
            deltas = torch.stack(deltas, dim=0)  # [n_experts, batch * seq]
            best_expert_idx = deltas.argmax(dim=0)  # [batch * seq]

        # Step 2: Apply the best expert
        y_flat = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            mask = (best_expert_idx == expert_idx)  # [batch * seq]
            if mask.any():
                selected_input = x_flat[mask]  # [num_selected, d_embed]
                selected_output = expert(selected_input)  # [num_selected, d_embed]
                y_flat[mask] = selected_output  # [batch * seq, d_embed]
        y = y_flat.view(batch_size, seq_len, d_embed)
        return y


In [None]:
    def get_flops(self, x):
        B, T = x.size()
        D = self.config.d_embed

        # ---------- Accelerator Intensity in Nvidia -------------------------------------------------------------------
        # FLOPs = 52.22 TFLOPs
        # Memory Bandwidth = 736.6GB/s
        # TensorCore Accelerator Intensity = 70.65

        # ---------- Matrix multiplication FLOPs and Memory read/writes ------------------------------------------------
        # x = [B, T, D], W = [D, F]
        # 1. Read x to SRAM
        # bytes = 2 x B X T X D = 2BTD
        # 2. Read W to SRAM
        # bytes = 2 x D x F = 2DF
        # 3. Compute Y = x @ W
        # FLOPS = 2 x B x T x D x F = 2BTDF
        # 4. Write Y to HBM
        # bytes = 2BTF
        # Arithmetic Intensity
        # AI = 2BTDF / (2BTD + 2DF + 2BTF) = 4BTD / (5BT + 4D)
        #    = BT

        # ---------- Attention FLOPs and Memory read/writes ------------------------------------------------------------
        # 1. Read x from HBM
        # bytes =
        # 2. Read Wq Wk Wv from HBM
        # bytes =
        # 3. Compute
        #
        #

        ########## Multi Head Attention ################################################################################
        if self.config.rank is None:
            ## Prefill
            ######### Flash Attention ##########
            # 1. Read Q, K, V from HBM
            # bytes = 3 x (2 x B x T x D) = 6BTD
            # 2. Compute Q @ K
            # FLOPS = 2 x B x T x T x D = 2BT^2D
            # 3. Compute A @ V
            # FLOPS = 2 × B × T × T × D = 2BT^2D
            # 4. Write attn_out
            # bytes = 2 x B x T x D = 2BTD
            ####################################
            # attn bytes = 6BTD + 2BTD = 8BTD
            # attn FLOPS = 2BT^2D + 2BT^2D = 4BT^2D
            # attn AI = 4BT^2D / 8BTD = T/2
            #         = T
            attn_prefill_flops = 4 * B * T * T * D
            attn_prefill_ai = T / 2

            ## Decoding
            ######### Flash Attention ##########
            # 1. Read Q, K, V from HBM
            # bytes = 2 x B x 1 x D + 2 x (2 x B x S x D) = 2BD + 4BSD
            # 2. Compute Q @ K
            # FLOPS = 2 x B x 1 x S x D = 2BSD
            # 3. Compute A @ V
            # FLOPS = 2 × B × S × 1 × D = 2BSD
            # 4. Write attn_out
            # bytes = 2 x B x 1 x D = 2BD
            ####################################
            # attn bytes = 2BD + 4BSD + 2BD = 4BD(1 + S)
            # attn FLOPS = 2BSD + 2BSD = 4BSD
            # attn AI = 4BSD / 4BD(1 + S) = S / (1 + S) (ignore 1)
            #         = 1
            attn_decoding_flops = 4 * B * T * D
            attn_decoding_ai = T / (1 + T)
            # Why so low ai?
            # FLOPS decreased by T, bytes remain the same

        ########## Multi Head Latent Attention #########################################################################
        else:
            R = self.config.rank
            ## Prefill
            # 1. Read KV_latent from HBM
            # bytes = 2 x (2 x B x T x R) = 4BTR
            # 2. Read Wkv_up from HBM
            # bytes = 2 x R x 2D = 4RD
            # 3. Compute kv_latent @ Wkv_up
            # FLOPS = 2 x B x T x R x 2D = 4BTRD
            # 4. Write k, v to HBM
            # bytes = 2 x (2 x B x T x D) = 4BTD
            ######### Flash Attention ##########
            # 1. Read Q, K, V from HBM
            # bytes = 3 x (2 x B x T x D) = 6BTD
            # 2. Compute Q @ K
            # FLOPS = 2 x B x T x T x D = 2BT^2D
            # 3. Compute A @ V
            # FLOPS = 2 × B × T × T × D = 2BT^2D
            # 4. Write y
            # bytes = 2 x B x T x D = 2BTD
            ####################################
            # attn bytes = 4BTR + 4RD + 4BTD + 6BTD + 2BTD = 4(RD + BTR + 3BTD)
            # attn FLOPS = 4BTRD + 2BT^2D + 2BT^2D = 4BTD(R + T)
            # attn AI = 4BTD(R + T) / 4(RD + BTR + 3BTD) = BTD(R + T) / RD + BTR + 3BTD
            #         =
            attn_prefill_flops = 4 * B * T * D * (R + T)
            attn_prefill_ai = T / 2

            ## Decoding
            # 1. KV_latent from HBM
            # bytes = 2 x (2 x B x S x R) = 4BSR
            # 2. Read Wkv_up from HBM
            # bytes = 2 x R x 2D = 4RD
            # 3. Compute kv_latent @ Wkv_up
            # FLOPS = 2 x B x S x R x 2D = 4BSRD
            # 4. Write k, v to HBM
            # bytes = 2 x (2 x B x S x D) = 4BSD
            ######### Flash Attention ##########
            # 1. Read Q_latent, KV_latent from HBM
            # bytes = 2 x B x S x R x D / d + 2 x B x S x R = 2BSR(D/d + 1)
            # 2. Compute Q_latent @ K_latent
            # FLOPS = 2 x B x 1 x S x R = 2BSR
            # 3. Compute A @ KV_latent
            # FLOPS = 2 x B x S x 1 x R = 2BSR
            # 4. Write y
            # bytes = 2 x B x 1 X R = 2BR
            ####################################
            # 5. Compute
            # attn bytes =
            # attn FLOPS =
            # attn AI = 2S / (1 + S(D/d + 1))
            #         =
            attn_decoding_flops = 4 * B * T * D
            attn_decoding_ai = T / (1 + T)
        ################################################################################################################

        ## KV cache
        ##### MultiHeadAttention ###########
        # size = 2 x (2 x B x S x D) = 4BSD
        # AI = 1
        ##### GroupedQueryAttention ########
        # size = 2 x (2 x B x S x D / n_groups) = 4BSD / n_groups
        # AI = G (group_size)
        ##### MultiQueryAttention ##########
        # size = 2 x (2 x B X S x d) = 4BSD / n_heads
        # AI = n_heads
        ##### MultiHeadLatentAttention #####
        # size = 2 x (B x S x R) = 2BSR = 4BSD / (2D/R)
        # AI = 2D/R

        # ---------- FeedForward FLOPs and Memory read/writes ----------------------------------------------------------
        # 1. Read x from HBM
        # bytes = 2 x B x T x D = 2BTD
        # 2. Read Wup, Wdown from HBM
        # bytes = 2 x (2 x D x 4D) = 16D^2
        # 3. Compute x @ Wup
        # FLOPS = 2 x B x T x D x 4D = 8BTD^2
        # 4. Compute x @ Wdown
        # FLOPS = 2 x B x T x 4D x D = 8BTD^2
        # 5. Write x to HBM
        # bytes = 2 x B x T x D = 2BTD
        ####################################
        # FF bytes = 2BTD + 16D^2 + 2BTD = 4D(BT + 4D)
        # FF FLOPS = 8BTD^2 + 8BTD^2 = 16BTD^2
        # FF AI = 16BTD^2 / 4D(BT + 4D) = 4BTD / (BT + 4D) (ignore 2BT)
        #       = BT
        feedforward_flops = 16 * B * T * D * D
        feedforward_ai = 4 * B * T * D / (B * T + 4 * D)

        flops = {
            'attn_prefill_flops': attn_prefill_flops, 'attn_prefill_ai': attn_prefill_ai,
            'attn_decoding_flops': attn_decoding_flops, 'attn_decoding_ai': attn_decoding_ai,
            'feedforward_flops': feedforward_flops, 'feedforward_ai': feedforward_ai
        }

        return flops

In [1]:
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import math
import time


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 16, 'BLOCK_D': 64}, num_warps=2, num_stages=2),
        triton.Config({'BLOCK_N': 32, 'BLOCK_D': 64}, num_warps=2, num_stages=2),
        triton.Config({'BLOCK_N': 64, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_N': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_N': 256, 'BLOCK_D': 64}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 512, 'BLOCK_D': 64}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 32, 'BLOCK_D': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_N': 64, 'BLOCK_D': 128}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_N': 128, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 64, 'BLOCK_D': 256}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 128, 'BLOCK_D': 256}, num_warps=8, num_stages=2),
    ],
    key=['N', 'D'],  # 시퀀스 길이와 d_head에 따라 자동 최적화
)
@triton.jit
def autotuned_flash_decode_kernel(
        Q, K, V, Out,
        stride_qb, stride_qh, stride_qd,
        stride_kb, stride_kh, stride_kn, stride_kd,
        stride_vb, stride_vh, stride_vn, stride_vd,
        stride_ob, stride_oh, stride_od,
        B, H, N, D,
        scale,
        BLOCK_N: tl.constexpr,
        BLOCK_D: tl.constexpr,
):
    """
    Triton Autotuned Flash Decode Kernel
    - 자동으로 최적 블록 크기 선택
    - 다양한 하드웨어에서 최적 성능
    """
    # 작업 ID
    pid = tl.program_id(0)
    batch_id = pid // H
    head_id = pid % H

    # Q 벡터 로드
    q_offset = batch_id * stride_qb + head_id * stride_qh
    d_range = tl.arange(0, BLOCK_D)
    q = tl.load(Q + q_offset + d_range, mask=d_range < D)

    # 누적 변수
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)
    max_score = -float('inf')
    sum_exp = 0.0

    # KV 블록 처리
    for start_n in tl.range(0, N, BLOCK_N):
        end_n = tl.minimum(start_n + BLOCK_N, N)

        # K 블록 로드
        k_offset = batch_id * stride_kb + head_id * stride_kh + start_n * stride_kn
        n_range = tl.arange(0, BLOCK_N)
        k_ptrs = K + k_offset + n_range[:, None] * stride_kn + d_range[None, :]
        k_mask = (n_range[:, None] < (end_n - start_n)) & (d_range[None, :] < D)
        k_vals = tl.load(k_ptrs, mask=k_mask, other=0.0)

        # Attention scores
        scores = tl.sum(q[None, :] * k_vals, axis=1) * scale
        scores = tl.where(n_range < (end_n - start_n), scores, -float('inf'))

        # Online softmax
        block_max = tl.max(scores)
        new_max = tl.maximum(max_score, block_max)

        if max_score > -float('inf'):
            exp_diff = tl.exp(max_score - new_max)
            acc = acc * exp_diff
            sum_exp = sum_exp * exp_diff

        # Softmax weights
        weights = tl.exp(scores - new_max)
        weights = tl.where(n_range < (end_n - start_n), weights, 0.0)
        block_sum = tl.sum(weights)

        # V 블록 로드 및 누적
        v_offset = batch_id * stride_vb + head_id * stride_vh + start_n * stride_vn
        v_ptrs = V + v_offset + n_range[:, None] * stride_vn + d_range[None, :]
        v_mask = (n_range[:, None] < (end_n - start_n)) & (d_range[None, :] < D)
        v_vals = tl.load(v_ptrs, mask=v_mask, other=0.0)

        weighted_v = tl.sum(weights[:, None] * v_vals, axis=0)
        acc = acc + weighted_v
        sum_exp = sum_exp + block_sum
        max_score = new_max

    # 출력
    result = acc / tl.maximum(sum_exp, 1e-8)
    out_offset = batch_id * stride_ob + head_id * stride_oh
    tl.store(Out + out_offset + d_range, result, mask=d_range < D)


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32, 'BLOCK_D': 64, 'BLOCK_BH': 1}, num_warps=2, num_stages=2),
        triton.Config({'BLOCK_N': 64, 'BLOCK_D': 64, 'BLOCK_BH': 1}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_N': 128, 'BLOCK_D': 64, 'BLOCK_BH': 2}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_N': 256, 'BLOCK_D': 64, 'BLOCK_BH': 2}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 512, 'BLOCK_D': 64, 'BLOCK_BH': 4}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 64, 'BLOCK_D': 128, 'BLOCK_BH': 1}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_N': 128, 'BLOCK_D': 128, 'BLOCK_BH': 2}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 256, 'BLOCK_D': 128, 'BLOCK_BH': 4}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_N': 128, 'BLOCK_D': 256, 'BLOCK_BH': 2}, num_warps=8, num_stages=2),
    ],
    key=['B', 'H', 'N', 'D'],  # 모든 차원을 고려한 최적화
)
@triton.jit
def mega_autotuned_flash_kernel(
        Q, K, V, Out,
        stride_qb, stride_qh, stride_qd,
        stride_kb, stride_kh, stride_kn, stride_kd,
        stride_vb, stride_vh, stride_vn, stride_vd,
        stride_ob, stride_oh, stride_od,
        B, H, N, D,
        scale,
        BLOCK_N: tl.constexpr,
        BLOCK_D: tl.constexpr,
        BLOCK_BH: tl.constexpr,
):
    """
    Mega Autotuned Flash Kernel
    - 배치/헤드 블로킹도 자동 최적화
    - 대규모 워크로드 특화
    """
    pid = tl.program_id(0)

    # 블록당 여러 배치×헤드 처리
    for local_idx in range(BLOCK_BH):
        bh_id = pid * BLOCK_BH + local_idx
        if bh_id >= B * H:
            break

        batch_id = bh_id // H
        head_id = bh_id % H

        # Q 로드
        q_offset = batch_id * stride_qb + head_id * stride_qh
        d_range = tl.arange(0, BLOCK_D)
        q = tl.load(Q + q_offset + d_range, mask=d_range < D)

        # Flash attention 계산
        acc = tl.zeros([BLOCK_D], dtype=tl.float32)
        max_score = -float('inf')
        sum_exp = 0.0

        for start_n in tl.range(0, N, BLOCK_N):
            end_n = tl.minimum(start_n + BLOCK_N, N)
            block_size = end_n - start_n

            # K 로드
            k_offset = batch_id * stride_kb + head_id * stride_kh + start_n * stride_kn
            n_range = tl.arange(0, BLOCK_N)
            k_ptrs = K + k_offset + n_range[:, None] * stride_kn + d_range[None, :]
            k_mask = (n_range[:, None] < block_size) & (d_range[None, :] < D)
            k_vals = tl.load(k_ptrs, mask=k_mask, other=0.0)

            # Scores
            scores = tl.sum(q[None, :] * k_vals, axis=1) * scale
            scores = tl.where(n_range < block_size, scores, -float('inf'))

            # Softmax
            block_max = tl.max(scores)
            new_max = tl.maximum(max_score, block_max)

            if max_score > -float('inf'):
                scale_factor = tl.exp(max_score - new_max)
                acc = acc * scale_factor
                sum_exp = sum_exp * scale_factor

            weights = tl.exp(scores - new_max)
            weights = tl.where(n_range < block_size, weights, 0.0)
            block_sum = tl.sum(weights)

            # V 로드 및 누적
            v_offset = batch_id * stride_vb + head_id * stride_vh + start_n * stride_vn
            v_ptrs = V + v_offset + n_range[:, None] * stride_vn + d_range[None, :]
            v_mask = (n_range[:, None] < block_size) & (d_range[None, :] < D)
            v_vals = tl.load(v_ptrs, mask=v_mask, other=0.0)

            weighted_v = tl.sum(weights[:, None] * v_vals, axis=0)
            acc = acc + weighted_v
            sum_exp = sum_exp + block_sum
            max_score = new_max

        # 출력 저장
        result = acc / tl.maximum(sum_exp, 1e-8)
        out_offset = batch_id * stride_ob + head_id * stride_oh
        tl.store(Out + out_offset + d_range, result, mask=d_range < D)


def autotuned_flash_attn_decode(q, k, v, scale=None):
    """Autotuned Flash Attention Decode"""
    batch_size, n_heads, q_seq_len, d_head = q.shape
    _, _, kv_seq_len, _ = k.shape

    assert q_seq_len == 1, "Decode mode requires q_seq_len=1"

    if scale is None:
        scale = 1.0 / math.sqrt(d_head)

    out = torch.empty_like(q)
    grid = (batch_size * n_heads,)

    autotuned_flash_decode_kernel[grid](
        q, k, v, out,
        q.stride(0), q.stride(1), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(3),
        batch_size, n_heads, kv_seq_len, d_head,
        scale,
    )

    return out


def mega_autotuned_flash_attn_decode(q, k, v, scale=None):
    """Mega Autotuned Flash Attention - 대규모 배치 특화"""
    batch_size, n_heads, q_seq_len, d_head = q.shape
    _, _, kv_seq_len, _ = k.shape

    if scale is None:
        scale = 1.0 / math.sqrt(d_head)

    out = torch.empty_like(q)

    # 자동 그리드 크기 계산
    total_bh = batch_size * n_heads
    # Autotune이 BLOCK_BH를 결정하므로 충분히 큰 그리드 설정
    grid_size = min(total_bh, 1024)  # 최대 1024 블록

    mega_autotuned_flash_kernel[(grid_size,)](
        q, k, v, out,
        q.stride(0), q.stride(1), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(3),
        batch_size, n_heads, kv_seq_len, d_head,
        scale,
    )

    return out


def naive_attention(q, k, v, scale=None):
    """PyTorch 기본 구현"""
    if scale is None:
        scale = 1.0 / math.sqrt(q.size(-1))

    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    attn_weights = F.softmax(scores, dim=-1)
    out = torch.matmul(attn_weights, v)
    return out


def comprehensive_scale_benchmark():
    """포괄적 스케일 벤치마크 - batch_size=1 포함"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not torch.cuda.is_available():
        print("❌ CUDA 필요")
        return

    # 다양한 스케일 테스트 (batch_size=1 포함)
    test_configs = [
        # Single batch 시나리오 (추론 서버)
        {"batch_size": 1, "n_heads": 12, "kv_seq_len": 2048, "d_head": 64},
        {"batch_size": 1, "n_heads": 32, "kv_seq_len": 8192, "d_head": 128},
        {"batch_size": 1, "n_heads": 40, "kv_seq_len": 16384, "d_head": 128},
        {"batch_size": 1, "n_heads": 64, "kv_seq_len": 32768, "d_head": 128},

        # Small batch 시나리오
        {"batch_size": 4, "n_heads": 96, "kv_seq_len": 4096, "d_head": 128},
        {"batch_size": 8, "n_heads": 128, "kv_seq_len": 8192, "d_head": 128},

        # Medium batch 시나리오
        {"batch_size": 32, "n_heads": 64, "kv_seq_len": 4096, "d_head": 128},
        {"batch_size": 64, "n_heads": 80, "kv_seq_len": 8192, "d_head": 128},

        # Large batch 시나리오
        {"batch_size": 128, "n_heads": 96, "kv_seq_len": 4096, "d_head": 128},
        {"batch_size": 256, "n_heads": 64, "kv_seq_len": 2048, "d_head": 128},

        # 초장문 컨텍스트
        { "batch_size": 1, "n_heads": 32, "kv_seq_len": 128000, "d_head": 128},
        { "batch_size": 2, "n_heads": 32, "kv_seq_len": 128000, "d_head": 128},
        # 128K tokens
    ]

    print("🚀 포괄적 Autotuned Flash Attention 벤치마크")
    print("=" * 100)

    for config in test_configs:
        batch_size = config["batch_size"]
        n_heads = config["n_heads"]
        kv_seq_len = config["kv_seq_len"]
        d_head = config["d_head"]

        # 메모리 계산
        kv_memory_gb = (batch_size * n_heads * kv_seq_len * d_head * 2 * 4) / (1024 ** 3)

        print(f"📊 B={batch_size}, H={n_heads}, KV_len={kv_seq_len:,}, D={d_head}")
        print(f"💾 KV Cache: {kv_memory_gb:.2f} GB")

        # 메모리 제한 체크
        if kv_memory_gb > 24:  # 24GB 제한
            print("⚠️ 메모리 제한으로 스킵")
            continue

        try:
            # 데이터 준비
            dtype = torch.float16 if kv_memory_gb > 4 else torch.float32
            q = torch.randn(batch_size, n_heads, 1, d_head, device=device, dtype=dtype)
            k = torch.randn(batch_size, n_heads, kv_seq_len, d_head, device=device, dtype=dtype)
            v = torch.randn(batch_size, n_heads, kv_seq_len, d_head, device=device, dtype=dtype)

            scale = 1.0 / math.sqrt(d_head)

            methods = []
            warmup = 3
            runs = 10

            # 1. Naive PyTorch
            try:
                for _ in range(warmup):
                    _ = naive_attention(q, k, v, scale)
                torch.cuda.synchronize()

                start = time.time()
                for _ in range(runs):
                    out_naive = naive_attention(q, k, v, scale)
                torch.cuda.synchronize()

                naive_time = (time.time() - start) / runs
                methods.append(("Naive PyTorch", naive_time, out_naive))
            except Exception as e:
                print(f"❌ Naive 실패: {e}")

            # 2. PyTorch Flash Attention 2
            try:
                for _ in range(warmup):
                    _ = F.scaled_dot_product_attention(q, k, v, scale=scale)
                torch.cuda.synchronize()

                start = time.time()
                for _ in range(runs):
                    out_fa2 = F.scaled_dot_product_attention(q, k, v, scale=scale)
                torch.cuda.synchronize()

                fa2_time = (time.time() - start) / runs
                methods.append(("PyTorch FA2", fa2_time, out_fa2))
            except Exception as e:
                print(f"❌ PyTorch FA2 실패: {e}")

            # 3. Autotuned Flash Decode
            try:
                print("🔧 Triton autotune 진행 중...")
                for _ in range(warmup):
                    _ = autotuned_flash_attn_decode(q, k, v, scale)
                torch.cuda.synchronize()

                start = time.time()
                for _ in range(runs):
                    out_auto = autotuned_flash_attn_decode(q, k, v, scale)
                torch.cuda.synchronize()

                auto_time = (time.time() - start) / runs
                methods.append(("Autotuned Triton", auto_time, out_auto))
            except Exception as e:
                print(f"❌ Autotuned 실패: {e}")

            # 4. Mega Autotuned (배치 크기 >= 4만)
            if batch_size >= 4:
                try:
                    print("🔧 Mega autotune 진행 중...")
                    for _ in range(warmup):
                        _ = mega_autotuned_flash_attn_decode(q, k, v, scale)
                    torch.cuda.synchronize()

                    start = time.time()
                    for _ in range(runs):
                        out_mega = mega_autotuned_flash_attn_decode(q, k, v, scale)
                    torch.cuda.synchronize()

                    mega_time = (time.time() - start) / runs
                    methods.append(("Mega Autotuned", mega_time, out_mega))
                except Exception as e:
                    print(f"❌ Mega Autotuned 실패: {e}")

            # 결과 출력
            if methods:
                print(f"\n{'구현':<18} {'시간(ms)':<10} {'속도향상':<10} {'처리량(M tok/s)':<15} {'정확도'}")
                print("-" * 78)

                baseline_time = methods[0][1]
                baseline_output = methods[0][2]

                for name, exec_time, output in methods:
                    speedup = f"{baseline_time / exec_time:.1f}x"

                    # 토큰 처리량 (M tokens/sec)
                    total_tokens = batch_size * kv_seq_len
                    throughput_m = (total_tokens / exec_time) / 1e6

                    # 정확도 체크
                    try:
                        if torch.allclose(output, baseline_output, rtol=1e-2, atol=1e-2):
                            accuracy = "✅"
                        else:
                            diff = torch.max(torch.abs(output - baseline_output)).item()
                            accuracy = f"⚠️{diff:.1e}"
                    except:
                        accuracy = "❓"

                    print(f"{name:<18} {exec_time * 1000:<9.1f} {speedup:<10} {throughput_m:<14.1f} {accuracy}")

        except torch.cuda.OutOfMemoryError:
            print("❌ GPU 메모리 부족")
        except Exception as e:
            print(f"❌ 테스트 실패: {e}")


if __name__ == "__main__":
    comprehensive_scale_benchmark()

🚀 포괄적 Autotuned Flash Attention 벤치마크
📊 B=1, H=12, KV_len=2,048, D=64
💾 KV Cache: 0.01 GB
🔧 Triton autotune 진행 중...

구현                 시간(ms)     속도향상       처리량(M tok/s)    정확도
------------------------------------------------------------------------------
Naive PyTorch      0.2       1.0x       9.8            ✅
PyTorch FA2        0.3       0.8x       7.5            ✅
Autotuned Triton   0.1       3.6x       35.0           ✅
📊 B=1, H=32, KV_len=8,192, D=128
💾 KV Cache: 0.25 GB
🔧 Triton autotune 진행 중...

구현                 시간(ms)     속도향상       처리량(M tok/s)    정확도
------------------------------------------------------------------------------
Naive PyTorch      0.4       1.0x       19.4           ✅
PyTorch FA2        1.4       0.3x       6.0            ✅
Autotuned Triton   0.2       2.0x       38.5           ⚠️6.1e-02
📊 B=1, H=40, KV_len=16,384, D=128
💾 KV Cache: 0.62 GB
🔧 Triton autotune 진행 중...

구현                 시간(ms)     속도향상       처리량(M tok/s)    정확도
--------------------------------