# Flash Attention 2 & Fused Kernel Benchmark

이 노트북은 동일한 가중치에서 eager(기본)와 fused(최적화) 커널을 일관되게 비교하고, 최종적으로 실 서비스에 가까운 사전 학습 언어 모델에서 속도와 정확도를 확인합니다. 모든 실험은 GPU에서 실행되며, 양쪽 모델 모두 `torch.compile`을 적용해 공정하게 측정합니다.


## 실험 개요
- Flash Attention 2 vs. Eager Multi-Head Attention
- RMSNorm / SwiGLU MLP / Linear+CrossEntropy (eager vs. fused)
- 전체 Transformer (무작위 가중치) 비교
- `meta-llama/Llama-3.2-1B-Instruct` 체크포인트를 활용한 엔드-투-엔드 추론 성능 비교


In [None]:
import math
import time
from functools import partial
from time import perf_counter

import torch
import transformers

# Fail fast: every benchmark in this notebook runs on a CUDA device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    raise RuntimeError('이 노트북은 GPU(CUDA) 환경에서만 실행할 수 있습니다.')

# Use bfloat16 when the GPU reports support for it, otherwise fall back to float16.
if torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float16
torch.set_float32_matmul_precision('high')

# Report the environment so benchmark numbers can be tied to exact versions.
print('Device: {}'.format(device))
print('Dtype: {}'.format(dtype))
print('PyTorch: {}'.format(torch.__version__))
print('Transformers: {}'.format(transformers.__version__))


## 설정 및 모듈 로드
기본(eager) 커널과 fused 커널을 동시에 가져와 비교합니다. 둘 다 동일한 `ExampleConfig`로 생성하며, 이후 셀에서 `load_state_dict`로 eager 쪽 가중치를 fused 쪽에 복사해 동일한 가중치를 사용합니다.


In [None]:
from models.default_config import TransformerConfig
from models.example_config import ExampleConfig
from models.default_layers import MultiHeadAttention, RMSNorm as DefaultRMSNorm, SwiGLUMLP as DefaultSwiGLUMLP, fused_linear_cross_entropy as default_linear_ce, RotaryEmbedding
from models.example_layers import FlashMultiHeadAttention, RMSNorm as FusedRMSNorm, SwiGLUMLP as FusedSwiGLUMLP, fused_linear_cross_entropy as fused_linear_ce
from models.default_model import TransformerForCausalLM
from models.example_model import ExampleTransformerForCausalLM

# Workload shape shared by the benchmark cells below.
BATCH_SIZE = 4
SEQ_LEN = 1024
# sample_attention_inputs masks out positions from PAD_TO onward (the last 128 tokens).
PAD_TO = SEQ_LEN - 128
# Default ``mode`` that compile_module forwards to torch.compile.
COMPILE_MODE = 'reduce-overhead'

# Single-layer configuration used to build both the eager and the fused modules.
config = ExampleConfig(
    num_hidden_layers=1,
    hidden_size=1024,
    intermediate_size=4096,
    num_attention_heads=16,
    num_key_value_heads=16,
    attention_dropout=0.0,
    # At least 2048; grows with SEQ_LEN if that constant is raised above it.
    max_position_embeddings=max(SEQ_LEN, 2048),
)


## 공용 유틸리티
- `compile_module`: `torch.compile`을 적용(없을 경우 경고)
- `benchmark_callable`: GPU 싱크를 포함한 일관된 벤치마크
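- `sample_attention_inputs`: 시드를 고정한 attention 입력(hidden states, rotary cos/sin, padding 마스크) 생성
- `diff_stats`: 두 텐서 간 최대/평균 절대 오차 계산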


In [None]:
def compile_module(module, mode=COMPILE_MODE):
    """Return ``module`` wrapped by ``torch.compile``, falling back to ``module`` itself.

    Args:
        module: Object handed to ``torch.compile``.
        mode: Value forwarded as the ``mode`` argument of ``torch.compile``.

    Returns:
        The result of ``torch.compile(module, mode=mode)``. The unmodified
        ``module`` is returned, after printing a notice, when ``torch`` has no
        ``compile`` attribute or when the call raises ``RuntimeError``.
    """
    # Guard clause: this torch build does not expose torch.compile at all.
    if not hasattr(torch, 'compile'):
        print('torch.compile이 지원되지 않아 원본 모듈을 사용합니다.')
        return module

    try:
        compiled = torch.compile(module, mode=mode)
    except RuntimeError as err:
        # Best effort: keep the benchmark runnable with the uncompiled module.
        print(f'torch.compile 실패: {err}. 원본 모듈을 반환합니다.')
        return module
    return compiled

def benchmark_callable(callable_obj, *call_args, warmup=10, iters=50, **call_kwargs):
    """Return the mean wall-clock seconds per call of ``callable_obj``.

    The callable is invoked ``warmup`` times untimed and then ``iters`` times
    timed, all under ``torch.inference_mode()``. ``torch.cuda.synchronize()``
    is called before the timer starts and before it stops, so queued GPU work
    is included in the measurement.

    Args:
        callable_obj: Callable to benchmark.
        *call_args: Positional arguments forwarded on every call.
        warmup: Number of untimed calls made before timing starts.
        iters: Number of timed calls; must be positive.
        **call_kwargs: Keyword arguments forwarded on every call.

    Returns:
        Elapsed ``perf_counter`` seconds of the timed loop divided by ``iters``.

    Raises:
        ValueError: If ``iters`` is not positive (the mean would be undefined).
    """
    # Validate before touching the GPU so a bad argument fails immediately
    # instead of surfacing as ZeroDivisionError after the warm-up has run.
    if iters <= 0:
        raise ValueError(f'iters must be a positive integer, got {iters!r}')

    with torch.inference_mode():
        torch.cuda.synchronize()
        for _ in range(warmup):
            callable_obj(*call_args, **call_kwargs)
        # Drain warm-up work so it is not attributed to the timed region.
        torch.cuda.synchronize()
        start = perf_counter()
        for _ in range(iters):
            callable_obj(*call_args, **call_kwargs)
        # Wait for the last queued kernels before stopping the clock.
        torch.cuda.synchronize()
        end = perf_counter()
    return (end - start) / iters

def sample_attention_inputs(config, batch_size=BATCH_SIZE, seq_len=SEQ_LEN, pad_to=PAD_TO, seed=0):
    """Build seeded random inputs for the attention benchmarks.

    Args:
        config: Configuration object; ``config.hidden_size`` sets the last
            dimension of the hidden states and the object is passed to
            ``RotaryEmbedding``.
        batch_size: Number of rows in every returned tensor.
        seq_len: Number of positions per row.
        pad_to: Positions from this index onward are marked ``False`` in the
            mask. No masking happens when it is ``None`` or ``>= seq_len``.
        seed: Value passed to ``torch.manual_seed`` before sampling.

    Returns:
        ``(hidden_states, (cos, sin), attention_mask)``: a random
        ``(batch_size, seq_len, config.hidden_size)`` tensor, the two
        ``RotaryEmbedding`` outputs cast to ``dtype``, and a boolean
        ``(batch_size, seq_len)`` mask.
    """
    torch.manual_seed(seed)
    hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device=device, dtype=dtype)

    # One row of positions 0..seq_len-1, broadcast across the batch.
    positions = torch.arange(seq_len, device=device)
    position_ids = positions.unsqueeze(0).expand(batch_size, -1)

    rope = RotaryEmbedding(config).to(device=device)
    cos, sin = rope(hidden_states, position_ids)
    rope_tables = (cos.to(dtype=dtype), sin.to(dtype=dtype))

    attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.bool)
    needs_padding = pad_to is not None and pad_to < seq_len
    if needs_padding:
        attention_mask[:, pad_to:] = False
    return hidden_states, rope_tables, attention_mask

def diff_stats(a, b):
    """Return ``(max, mean)`` of the element-wise absolute difference of ``a`` and ``b``.

    Both statistics are converted to Python scalars with ``.item()``.
    """
    abs_err = (a - b).abs()
    worst = abs_err.max().item()
    average = abs_err.mean().item()
    return worst, average


## Flash Attention 2 vs Eager Multi-Head Attention
동일한 가중치를 공유하는 두 모듈을 `torch.compile`로 컴파일한 뒤, padding이 포함된 배치에서 속도와 수치 차이를 비교합니다.


In [None]:
# Seeded inputs: hidden states, rotary (cos, sin) pair and a boolean mask with padding.
hidden_states, position_embeddings, attention_mask = sample_attention_inputs(config)

# Build both attention variants and copy the eager weights into the flash module.
baseline_attn = MultiHeadAttention(config, layer_idx=0, is_causal=True).to(device=device, dtype=dtype)
flash_attn = FlashMultiHeadAttention(config, layer_idx=0, is_causal=True).to(device=device, dtype=dtype)
flash_attn.load_state_dict(baseline_attn.state_dict(), strict=False)
baseline_attn.eval()
flash_attn.eval()

baseline_attn, flash_attn = [compile_module(attn) for attn in (baseline_attn, flash_attn)]

# One forward pass each; these outputs feed the numerical comparison below.
with torch.inference_mode():
    baseline_out, flash_out = [
        attn(hidden_states, position_embeddings, attention_mask=attention_mask)
        for attn in (baseline_attn, flash_attn)
    ]

# The difference is computed before the timing runs re-invoke the modules.
max_diff, mean_diff = diff_stats(baseline_out, flash_out)
baseline_time, flash_time = [
    benchmark_callable(attn, hidden_states, position_embeddings, attention_mask=attention_mask)
    for attn in (baseline_attn, flash_attn)
]
# Clamp the denominator so a zero timing cannot divide by zero.
speedup = baseline_time / max(flash_time, 1e-12)

print(f'Eager attention:  {baseline_time * 1e3:.2f} ms')
print(f'Flash attention: {flash_time * 1e3:.2f} ms')
print(f'Speedup:         {speedup:.2f}x')
print(f'Max abs diff:    {max_diff:.3e}')
print(f'Mean abs diff:   {mean_diff:.3e}')


## RMSNorm / SwiGLU MLP / Linear+CE (Eager vs Fused)
RMSNorm과 SwiGLU MLP는 동일한 가중치로 초기화한 뒤 `torch.compile`을 적용하고, Linear+CrossEntropy는 동일한 입력·가중치로 함수를 직접 호출하여 속도와 오차를 비교합니다.


In [None]:
# Build eager/fused RMSNorm and copy the eager parameters into the fused module.
rms_default = DefaultRMSNorm(config.hidden_size, eps=config.rms_norm_eps).to(device=device, dtype=dtype)
rms_fused = FusedRMSNorm(config.hidden_size, eps=config.rms_norm_eps).to(device=device, dtype=dtype)
rms_fused.load_state_dict(rms_default.state_dict(), strict=False)

# Same for the SwiGLU MLP pair.
mlp_default = DefaultSwiGLUMLP(config).to(device=device, dtype=dtype)
mlp_fused = FusedSwiGLUMLP(config).to(device=device, dtype=dtype)
mlp_fused.load_state_dict(mlp_default.state_dict(), strict=False)

rms_default = compile_module(rms_default)
rms_fused = compile_module(rms_fused)
mlp_default = compile_module(mlp_default)
mlp_fused = compile_module(mlp_fused)

# Random inputs: (batch, seq, hidden) for norm/MLP, flattened (batch * seq, hidden) for linear+CE.
norm_inputs = torch.randn(BATCH_SIZE, SEQ_LEN, config.hidden_size, device=device, dtype=dtype)
mlp_inputs = norm_inputs.clone()
ce_hidden = torch.randn(BATCH_SIZE * SEQ_LEN, config.hidden_size, device=device, dtype=dtype)
ce_labels = torch.randint(0, config.vocab_size, (BATCH_SIZE * SEQ_LEN,), device=device)
lm_head_weight = torch.randn(config.vocab_size, config.hidden_size, device=device, dtype=dtype)

# NOTE(review): the two linear+CE functions are called directly, without
# compile_module, unlike the modules above - confirm this is intended.
with torch.inference_mode():
    rms_eager = rms_default(norm_inputs)
    rms_fused_out = rms_fused(norm_inputs)
    mlp_eager = mlp_default(mlp_inputs)
    mlp_fused_out = mlp_fused(mlp_inputs)
    ce_eager = default_linear_ce(ce_hidden, ce_labels, lm_head_weight)
    ce_fused = fused_linear_ce(ce_hidden, ce_labels, lm_head_weight)

# Numerical comparison is done before the timing runs re-invoke the modules.
rms_diff = diff_stats(rms_eager, rms_fused_out)
mlp_diff = diff_stats(mlp_eager, mlp_fused_out)
ce_gap = abs(ce_fused.item() - ce_eager.item())

rms_time_eager = benchmark_callable(rms_default, norm_inputs)
rms_time_fused = benchmark_callable(rms_fused, norm_inputs)
mlp_time_eager = benchmark_callable(mlp_default, mlp_inputs)
mlp_time_fused = benchmark_callable(mlp_fused, mlp_inputs)
ce_time_eager = benchmark_callable(default_linear_ce, ce_hidden, ce_labels, lm_head_weight)
ce_time_fused = benchmark_callable(fused_linear_ce, ce_hidden, ce_labels, lm_head_weight)

print('=== RMSNorm (Eager vs Fused) ===')
print(f'Eager: {rms_time_eager * 1e3:.2f} ms | Fused: {rms_time_fused * 1e3:.2f} ms | Speedup: {rms_time_eager / max(rms_time_fused, 1e-12):.2f}x')
print(f'Max diff: {rms_diff[0]:.3e} | Mean diff: {rms_diff[1]:.3e}')

# The leading newline is written as an escape: a single-quoted literal cannot
# span two physical lines.
print('\n=== SwiGLU MLP (Eager vs Fused) ===')
print(f'Eager: {mlp_time_eager * 1e3:.2f} ms | Fused: {mlp_time_fused * 1e3:.2f} ms | Speedup: {mlp_time_eager / max(mlp_time_fused, 1e-12):.2f}x')
print(f'Max diff: {mlp_diff[0]:.3e} | Mean diff: {mlp_diff[1]:.3e}')

print('\n=== Linear + CrossEntropy ===')
print(f'Eager: {ce_time_eager * 1e3:.2f} ms | Fused: {ce_time_fused * 1e3:.2f} ms | Speedup: {ce_time_eager / max(ce_time_fused, 1e-12):.2f}x')
print(f'Loss gap: {ce_gap:.3e}')


## 전체 Transformer (Synthetic) 비교
무작위 초기화된 동일한 가중치에서 전체 모델 forward와 loss 계산을 비교합니다.


In [None]:
# Randomly initialised eager model; its state dict is copied into the fused model.
default_model = TransformerForCausalLM(config).to(device=device, dtype=dtype)
optimized_model = ExampleTransformerForCausalLM(config).to(device=device, dtype=dtype)
optimized_model.load_state_dict(default_model.state_dict(), strict=False)

default_model.eval()
optimized_model.eval()
default_model, optimized_model = [
    compile_module(model) for model in (default_model, optimized_model)
]

# Random token ids; the labels are a copy of the inputs.
input_ids = torch.randint(0, config.vocab_size, (BATCH_SIZE, SEQ_LEN), device=device)
labels = input_ids.clone()

with torch.inference_mode():
    eager_out = default_model(input_ids=input_ids, labels=labels)
    fused_out = optimized_model(input_ids=input_ids, labels=labels)

# Compare outputs before the timing runs re-invoke the models.
logits_diff = diff_stats(eager_out.logits, fused_out.logits)
loss_gap = (fused_out.loss - eager_out.loss).abs().item()

# benchmark_callable forwards extra keyword arguments to the callable on every call.
forward_time_eager = benchmark_callable(default_model, input_ids=input_ids, labels=labels)
forward_time_fused = benchmark_callable(optimized_model, input_ids=input_ids, labels=labels)
speedup = forward_time_eager / max(forward_time_fused, 1e-12)

print(f'Eager forward:   {forward_time_eager * 1e3:.2f} ms')
print(f'Fused forward:  {forward_time_fused * 1e3:.2f} ms')
print(f'Speedup:         {speedup:.2f}x')
print(f'Max logit diff:  {logits_diff[0]:.3e}')
print(f'Mean logit diff: {logits_diff[1]:.3e}')
print(f'Loss gap:        {loss_gap:.3e}')


## 실제 체크포인트 로드: meta-llama/Llama-3.2-1B-Instruct
다음 셀은 Hugging Face 허브에서 사전 학습 모델과 토크나이저를 로드한 뒤, eager Transformer와 fused Transformer에 동일한 가중치를 로드하고 성능을 측정합니다. (인터넷 접근이 필요합니다.)


In [None]:
from transformers import AutoTokenizer

# The batch below is padded and then passed to generate(). Pad on the left so
# the real prompt tokens end each row and new tokens are appended directly
# after them; with right padding, pad tokens would sit between the shorter
# prompt and its continuation.
tokenizer = AutoTokenizer.from_pretrained(
    'meta-llama/Llama-3.2-1B-Instruct',
    use_fast=True,
    padding_side='left',
)
# Padding needs a pad token; reuse EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

prompts = [
    'Explain the Flash Attention algorithm to a senior ML engineer in three bullet points.',
    'Summarize the benefits of fused MLP kernels in transformer inference.',
]
# padding=True pads every prompt to the longest one in the batch.
batch = tokenizer(prompts, return_tensors='pt', padding=True).to(device)
prefill_length = batch['input_ids'].shape[-1]
print(f'Prompt length: {prefill_length}')


In [None]:
# Load the same Hub checkpoint into the eager and the fused model classes,
# with identical dtype and device_map arguments.
default_pretrained = TransformerForCausalLM.from_pretrained(
    'meta-llama/Llama-3.2-1B-Instruct',
    device_map={'': str(device)},
    torch_dtype=dtype,
)
example_pretrained = ExampleTransformerForCausalLM.from_pretrained(
    'meta-llama/Llama-3.2-1B-Instruct',
    device_map={'': str(device)},
    torch_dtype=dtype,
)

# Switch both models to eval mode before handing them to compile_module.
default_pretrained.eval()
example_pretrained.eval()

default_pretrained, example_pretrained = [
    compile_module(model) for model in (default_pretrained, example_pretrained)
]

def benchmark_generation(model, input_ids, attention_mask, max_new_tokens=64, warmup=2, iters=5):
    """Return the mean wall-clock seconds per greedy ``model.generate`` call.

    ``model.generate`` is invoked ``warmup`` times untimed and then ``iters``
    times timed, under ``torch.inference_mode()``, with
    ``torch.cuda.synchronize()`` before the timer starts and before it stops.

    Args:
        model: Object exposing a ``generate`` method that accepts the keyword
            arguments used below.
        input_ids: Prompt token ids forwarded to ``generate``.
        attention_mask: Mask forwarded to ``generate``.
        max_new_tokens: Number of tokens generated per sequence. It is passed
            as both ``max_new_tokens`` and ``min_new_tokens`` so every call
            does the same amount of decoding.
        warmup: Number of untimed calls made before timing starts.
        iters: Number of timed calls; must be positive.

    Returns:
        Elapsed ``perf_counter`` seconds of the timed loop divided by ``iters``.

    Raises:
        ValueError: If ``iters`` is not positive (the mean would be undefined).
    """
    # Validate before touching the GPU so a bad argument fails immediately.
    if iters <= 0:
        raise ValueError(f'iters must be a positive integer, got {iters!r}')

    def _generate_once():
        """Run one greedy, cache-enabled generation of exactly max_new_tokens tokens."""
        model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            # Without a minimum, generation may stop early at EOS, so two
            # models could be timed over different numbers of tokens.
            min_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )

    with torch.inference_mode():
        torch.cuda.synchronize()
        for _ in range(warmup):
            _generate_once()
        # Drain warm-up work so it is not attributed to the timed region.
        torch.cuda.synchronize()
        start = perf_counter()
        for _ in range(iters):
            _generate_once()
        # Wait for the last queued kernels before stopping the clock.
        torch.cuda.synchronize()
        end = perf_counter()
    return (end - start) / iters

# Single source of truth for the generation length: it is passed to the
# benchmark and reused in the throughput arithmetic, so the two cannot drift.
MAX_NEW_TOKENS = 64

default_time = benchmark_generation(
    default_pretrained,
    batch['input_ids'],
    batch['attention_mask'],
    max_new_tokens=MAX_NEW_TOKENS,
)
example_time = benchmark_generation(
    example_pretrained,
    batch['input_ids'],
    batch['attention_mask'],
    max_new_tokens=MAX_NEW_TOKENS,
)

# NOTE(review): this assumes every row generates exactly MAX_NEW_TOKENS tokens;
# confirm generation cannot stop earlier, otherwise throughput is overstated.
tokens_generated = batch['input_ids'].shape[0] * MAX_NEW_TOKENS
# Clamp the denominators so a zero timing cannot divide by zero.
throughput_default = tokens_generated / max(default_time, 1e-12)
throughput_example = tokens_generated / max(example_time, 1e-12)

print(f'Eager model:  {default_time:.3f} s per batch | Throughput: {throughput_default:.1f} tokens/s')
print(f'Fused model: {example_time:.3f} s per batch | Throughput: {throughput_example:.1f} tokens/s')
print(f'Speedup:      {default_time / max(example_time, 1e-12):.2f}x')


## 요약 및 다음 단계
- 모듈 단위 커널(Attention, RMSNorm, SwiGLU MLP, 전체 모델)을 `torch.compile`로 감싼 상태에서 GPU에서 일관되게 비교했습니다. (Linear+CrossEntropy 함수는 컴파일 없이 직접 호출합니다.)
- Flash Attention 2 및 fused 커널이 eager 대비 어떤 속도 향상을 주는지 즉시 확인할 수 있습니다.
- Hugging Face 체크포인트를 이용해 실제 언어 모델 추론에서도 성능 차이를 측정할 수 있습니다.

### 권장 후속 작업
1. 실제 워크로드(프롬프트 길이, 배치 크기)를 반영해 파라미터를 수정
2. `torch.compile` 모드를 `max-autotune` 등으로 변경해 최적 지점을 탐색
3. 추가 fused 커널(예: attention mask fusing, layernorm)을 연결해 전체 파이프라인을 튜닝
