# Sliding Window Attention (SWA)

- Standard self-attention can be viewed as a global attention mechanism, because every sequence element can access every other sequence element.
- SWA can be viewed as a local attention mechanism, because it restricts the context to the neighborhood of the current query position.

![SWA](https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/swa-memory/1.webp)
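The local attention pattern can be written as a boolean mask. Below is a minimal pure-Python sketch. The function name `sliding_window_mask` is hypothetical, and it assumes the causal variant used in the code further down, where a query attends to itself and the `window - 1` tokens before it:

```python
def sliding_window_mask(num_tokens, window):
    """Causal sliding-window mask as a list of rows; True means 'masked out'.

    A query at position q may attend to a key at position k only if
    0 <= q - k < window (the same rule as mask_bool in the MultiHeadAttentionWithSWA cell).
    """
    return [[not (0 <= q - k < window) for k in range(num_tokens)]
            for q in range(num_tokens)]


# 'x' = attended, '.' = masked; each row is one query position
for row in sliding_window_mask(5, window=3):
    print("".join("." if masked else "x" for masked in row))
# x....
# xx...
# xxx..
# .xxx.
# ..xxx
```

With full causal attention, row `q` would contain `q + 1` attended keys. With the window, the count is capped at `window`, and that cap is what bounds the KV cache.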

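- Instead of attending to all previous tokens, **each token attends only to a fixed-size local window ending at its current position** (the current token and the tokens just before it, in the causal setting used here).
  - This localized attention **greatly reduces the size of the KV cache**, since a sliding window layer only needs to keep the keys and values inside the window.
- Sliding window attention was first proposed in the [2020 Longformer paper](https://arxiv.org/abs/2004.05150).
  - This study material focuses on the Gemma models because they are excellent open-weight models that show SWA is a practical approach in recent models.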
- [Gemma 2](https://arxiv.org/abs/2408.00118) used a hybrid approach that combines local (sliding window) and global attention layers in a 1:1 ratio, where each token in a local layer attends to a context window of 4096 tokens.
- [Gemma 3](https://arxiv.org/abs/2503.19786) refined the design further for efficiency.
  - It uses a 5:1 ratio of sliding window layers to full attention layers (i.e., one global attention layer for every five local attention layers).
  - The sliding window size was reduced from 4096 tokens in Gemma 2 to 1024 tokens.
- Even so, the ablation studies show that this has only a very small effect on overall model quality.
  - In other words, the **sliding window approach saves substantial memory and compute** while **keeping the loss in modeling performance minimal**.
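To see why the Gemma 3 design saves memory, here is a rough KV-cache count under the following assumptions. The model is a hypothetical 12-layer model with a 32,768-token context. Every layer caches the same number of values per token, so head count and head dimension cancel out of the ratios. Local layers are assigned with the rule `(i % (K + 1)) < K`, which the `GPTModel` cell below also uses. Real Gemma models have different layer counts and layouts, so the numbers are only illustrative. The helper `kv_cache_tokens` is hypothetical.

```python
def kv_cache_tokens(context_len, num_layers, local_per_global, window):
    """Cached token positions summed over all layers (hypothetical helper).

    Layer i is a local (sliding window) layer if (i % (local_per_global + 1)) < local_per_global;
    local layers keep at most `window` tokens, global layers keep the whole context.
    """
    total = 0
    for i in range(num_layers):
        is_local = (i % (local_per_global + 1)) < local_per_global
        total += min(context_len, window) if is_local else context_len
    return total


ctx, layers = 32_768, 12
full = kv_cache_tokens(ctx, layers, local_per_global=0, window=ctx)          # all global
gemma2_like = kv_cache_tokens(ctx, layers, local_per_global=1, window=4096)  # 1:1, 4096 window
gemma3_like = kv_cache_tokens(ctx, layers, local_per_global=5, window=1024)  # 5:1, 1024 window
print(f"{gemma2_like / full:.4f}")  # 0.5625
print(f"{gemma3_like / full:.4f}")  # 0.1927
```

Under these assumptions, the 1:1 / 4096 layout keeps about 56% of the full-attention cache and the 5:1 / 1024 layout keeps about 19%. Most of what remains belongs to the few global layers.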

### Code Examples

- SWA can also be combined with other attention mechanisms such as GQA.
- As before, the model is untrained, so it generates meaningless text.
- The `kv-cache` functionality written earlier is also used here.
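SWA limits *how many tokens* are cached, and GQA limits *how many key/value heads* are cached, so their savings multiply. Below is a back-of-the-envelope sketch for a single layer. The sizes are hypothetical: 12 query heads, GQA with 4 key/value heads, `head_dim=64`, a 4096-token sequence, and a 1024-token window. The helper `kv_cache_elements` is also hypothetical.

```python
def kv_cache_elements(num_tokens, num_kv_heads, head_dim, window=None):
    """Cached K and V elements for one layer and one sequence (hypothetical helper)."""
    cached_tokens = num_tokens if window is None else min(num_tokens, window)
    return 2 * cached_tokens * num_kv_heads * head_dim  # factor 2: keys and values


tokens, head_dim = 4096, 64
mha = kv_cache_elements(tokens, num_kv_heads=12, head_dim=head_dim)
gqa = kv_cache_elements(tokens, num_kv_heads=4, head_dim=head_dim)
swa = kv_cache_elements(tokens, num_kv_heads=12, head_dim=head_dim, window=1024)
gqa_swa = kv_cache_elements(tokens, num_kv_heads=4, head_dim=head_dim, window=1024)
print(f"GQA: {gqa / mha:.4f}, SWA: {swa / mha:.4f}, GQA+SWA: {gqa_swa / mha:.4f}")
# GQA: 0.3333, SWA: 0.2500, GQA+SWA: 0.0833
```

The class below uses plain multi-head attention, so only the SWA factor applies to it.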

In [10]:
import time
import tiktoken
import torch
import torch.nn as nn

In [11]:
class MultiHeadAttentionWithSWA(nn.Module):
    def __init__(self, d_in, d_out, dropout, num_heads, qkv_bias=False, sliding_window_size=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)    # Query weight
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)      # Key weight
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)    # Value weight
        self.out_projection = nn.Linear(d_out, d_out)           # final output projection (applied to the concatenated head outputs)
        self.dropout = nn.Dropout(dropout)

        # new: sliding window size (None = regular causal attention)
        self.sliding_window_size = sliding_window_size

        # KV-cache
        self.register_buffer('cache_k', None, persistent=False)
        self.register_buffer('cache_v', None, persistent=False)
        self.ptr_current_pos = 0
    
    def forward(self, x, use_cache=False):
        b, num_tokens, d_in = x.shape

        # Q, K, V projection
        Q, K_new, V_new = self.W_query(x), self.W_key(x), self.W_value(x)  # [b, num_tokens, d_out]
    
        # Reshape (replaces the earlier split())
            # [b, num_tokens, d_out] -> [b, num_tokens, num_heads, head_dim]
        Q = Q.view(b, num_tokens, self.num_heads, self.head_dim)
        K_new = K_new.view(b, num_tokens, self.num_heads, self.head_dim)
        V_new = V_new.view(b, num_tokens, self.num_heads, self.head_dim)

        # KV-cache update
        if use_cache:
            old_len = 0 if self.cache_k is None else self.cache_k.size(1)
            if self.cache_k is None:
                K, V = K_new, V_new
            else:
                K = torch.cat([self.cache_k, K_new], dim=1)
                V = torch.cat([self.cache_v, V_new], dim=1)

            # absolute start positions for masking: the cache holds the old_len tokens
            # directly before ptr_current_pos, followed by the num_tokens new ones
            k_start_pos_abs = self.ptr_current_pos - old_len
            q_start_pos_abs = self.ptr_current_pos

            # apply the sliding window to the stored cache only after K, V for this step are built,
            # so queries in a chunk longer than the window (e.g. a long prompt) keep their own keys
            if self.sliding_window_size is not None and K.size(1) > self.sliding_window_size:
                self.cache_k = K[:, -self.sliding_window_size:, :, :]
                self.cache_v = V[:, -self.sliding_window_size:, :, :]
            else:
                self.cache_k, self.cache_v = K, V

        else:
            K, V = K_new, V_new
        

        # transpose for the attention-score computation
            # [b, num_tokens, num_heads, head_dim] -> [b, num_heads, num_tokens, head_dim]
        Q, K, V = Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2)

        # attention scores
        attn_scores = Q @ K.transpose(2, 3)

        # causal masking + sliding window masking
        num_tokens_Q = Q.shape[-2]
        num_tokens_K = K.shape[-2]
        device = Q.device

        # absolute positions of the queries and keys
        if use_cache:
            q_start, k_start = q_start_pos_abs, k_start_pos_abs
        else:
            q_start, k_start = 0, 0
        
        q_positions = torch.arange(q_start, q_start + num_tokens_Q, device=device, dtype=torch.long)
        k_positions = torch.arange(k_start, k_start + num_tokens_K, device=device, dtype=torch.long)

        # window width W; without a sliding window, W exceeds every possible q - k,
        # so only the causal mask applies
        W = num_tokens_K + 1 if self.sliding_window_size is None else int(self.sliding_window_size)
        diff = q_positions.unsqueeze(-1) - k_positions.unsqueeze(0)  # [num_tokens_Q, num_tokens_K]
        
        # boolean mask, True = masked: future keys (diff < 0) and keys outside the window (diff >= W)
        mask_bool = (diff < 0) | (diff >= W)

        if use_cache:
            self.ptr_current_pos += num_tokens_Q
        else:
            self.ptr_current_pos = 0

        # apply the mask
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

        # scaled softmax and dropout
        attn_weights = torch.softmax(attn_scores / K.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # weighted sum of the values: softmax(Q K^T / sqrt(d_head)) @ V
            # [B, num_heads, n_tokens, d_head] -> [B, n_tokens, num_heads, d_head]
        context_vector = (attn_weights @ V).transpose(1, 2)

        # concatenate the heads again
            # [B, n_tokens, num_heads, d_head] -> [B, n_tokens, d_out]
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
        context_vector = self.out_projection(context_vector)  # output projection

        return context_vector

    def reset_cache(self):
        self.cache_k, self.cache_v = None, None
        self.ptr_current_pos = 0
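The cache bookkeeping in this cell can be checked without tensors: which absolute positions stay in the cache, and where `k_start_pos_abs` starts. The pure-Python sketch below is hypothetical. It stores token positions instead of key/value tensors, assumes `window` is an integer and each step adds at least one token, and uses the window rule `0 <= q - k < window` from `mask_bool`. It builds the keys for the current chunk first and trims the cache afterward. Trimming before attention would drop keys that early queries still need whenever a single chunk is longer than the window.

```python
class PositionOnlySWACache:
    """Tracks which absolute positions a sliding-window KV cache holds (hypothetical helper)."""

    def __init__(self, window):
        self.window = window
        self.cached = []   # absolute positions currently in the cache
        self.ptr = 0       # absolute position of the next incoming token

    def step(self, num_new):
        new = list(range(self.ptr, self.ptr + num_new))
        keys = self.cached + new                 # keys visible to this chunk's queries
        k_start = self.ptr - len(self.cached)    # same formula as ptr_current_pos - old_len
        assert keys[0] == k_start
        # for each query, the key positions that survive the causal + window mask
        visible = {q: [k for k in keys if 0 <= q - k < self.window] for q in new}
        self.cached = keys[-self.window:]        # keep only the last `window` positions
        self.ptr += num_new
        return visible


cache = PositionOnlySWACache(window=3)
print(cache.step(5))   # prefill: {0: [0], 1: [0, 1], 2: [0, 1, 2], 3: [1, 2, 3], 4: [2, 3, 4]}
print(cache.cached)    # [2, 3, 4]
print(cache.step(1))   # decode: {5: [3, 4, 5]}
```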

In [12]:
from _04_gpt import LayerNorm, FeedForward

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.attention = MultiHeadAttentionWithSWA(
            d_in = cfg['embed_dim'],
            d_out = cfg['embed_dim'],
            num_heads = cfg['num_heads'],
            dropout = cfg['drop_rate'],
            qkv_bias = cfg['qkv_bias'],
            sliding_window_size = cfg['sliding_window_size']
        )

        self.ffn = FeedForward(cfg)

        self.norm1 = LayerNorm(cfg['embed_dim'])
        self.norm2 = LayerNorm(cfg['embed_dim'])

        self.drop_shortcut = nn.Dropout(cfg['drop_rate'])

    def forward(self, x, use_cache=False):
        # attention with skip connection
        residual = x
        x = self.norm1(x)           # LayerNorm
        x = self.attention(x, use_cache=use_cache)       # MHA, [batch, context_length, embed_dim]

        x = self.drop_shortcut(x)   # Dropout
        x = x + residual            # skip(residual) connection

        # FFN with skip connection
        residual = x
        x = self.norm2(x)           # LayerNorm
        x = self.ffn(x)             # FeedForward
        x = self.drop_shortcut(x)   # Dropout
        x = x + residual            # skip(residual) connection

        return x

In [13]:
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.token_embedding = nn.Embedding(cfg['vocab_size'], cfg['embed_dim'])
        self.position_embedding = nn.Embedding(cfg['context_length'], cfg['embed_dim'])
        self.drop_embedding = nn.Dropout(cfg['drop_rate'])

        # sliding window settings
            # sliding_window_stride: number of SWA layers per regular (global) layer; sliding_window_size: window size
        blocks = []
        window_stride = cfg["sliding_window_stride"]
        window_size = cfg["sliding_window_size"] if "sliding_window_size" in cfg else None

        for i in range(cfg['num_layers']):
            transformer_block = TransformerBlock(cfg)

            # use K SWA layers for every regular (global) layer
            K = int(window_stride)
            if K <= 0:
                # K == 0: no SWA layers; K < 0: every layer uses SWA
                use_swa = False if K == 0 else True
            else:
                group = K + 1
                use_swa = (i % group) < K

            transformer_block.attention.sliding_window_size = window_size if use_swa else None
            blocks.append(transformer_block)
        
        self.transformer_blocks = nn.ModuleList(blocks)

        self.current_pos = 0

        self.final_norm = LayerNorm(cfg['embed_dim'])
        self.out_head = nn.Linear(cfg['embed_dim'], cfg['vocab_size'], bias=False)
    
    def forward(self, in_idx, use_cache=False):
        batch_size, seq_length = in_idx.shape

        # input embedding = token embedding + positional embedding
        token_embeddings = self.token_embedding(in_idx)
        
        if use_cache:
            pos_ids = torch.arange(
                self.current_pos, self.current_pos + seq_length, device=in_idx.device, dtype=torch.long
            )
            self.current_pos += seq_length
        else:
            pos_ids = torch.arange(
                0, seq_length, device=in_idx.device, dtype=torch.long
            )
        pos_embeddings = self.position_embedding(pos_ids).unsqueeze(0)

        x = token_embeddings + pos_embeddings   # [batch_size, num_tokens, embed_dim]

        x = self.drop_embedding(x)

        # Transformer block forward pass
        for block in self.transformer_blocks:
            x = block(x, use_cache=use_cache)

        # last layer norm
        x = self.final_norm(x)

        logits = self.out_head(x)

        return logits
    
    def reset_kv_cache(self):
        for block in self.transformer_blocks:
            block.attention.reset_cache()
        self.current_pos = 0
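Despite its name, `sliding_window_stride` does not control how far the window moves. In this cell it is the number of SWA layers per regular (global) layer. The small pure-Python sketch below replays the same assignment rule to show the resulting layer layout. The helper `layer_pattern` is hypothetical. With the config used below (`'sliding_window_stride': 2`, 12 layers), the model gets 8 SWA layers and 4 global layers. A Gemma 3-style 5:1 ratio would correspond to a value of 5.

```python
def layer_pattern(num_layers, sliding_window_stride):
    """Label each layer 'L' (local/SWA) or 'G' (global) with GPTModel's rule (hypothetical helper)."""
    k = int(sliding_window_stride)
    labels = []
    for i in range(num_layers):
        if k <= 0:
            use_swa = k != 0               # 0 -> no SWA layers, negative -> all SWA
        else:
            use_swa = (i % (k + 1)) < k    # k SWA layers, then one global layer
        labels.append("L" if use_swa else "G")
    return "".join(labels)


print(layer_pattern(12, 2))   # LLGLLGLLGLLG
print(layer_pattern(12, 5))   # LLLLLGLLLLLG
print(layer_pattern(4, 0))    # GGGG
print(layer_pattern(4, -1))   # LLLL
```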

In [14]:
from _04_gpt import generate_text_simple_cached

start_context = "O say can you see,"

tokenizer = tiktoken.get_encoding("gpt2")
encoded = tokenizer.encode(start_context)

GPT_CONFIG_124M = {
        'vocab_size': 50257,                    # Vocabulary size
        'context_length': 200 + len(encoded),   # Context(max sequence) length
        'embed_dim': 768,                       # Embedding dimension
        'num_heads': 12,                        # Number of attention heads
        'num_layers': 12,                       # Number of layers(transformer blocks)
        'drop_rate': 0.1,                       # Dropout rate
        'qkv_bias': False,                      # Q,K,V bias
        'sliding_window_size': 1024,            # Sliding window size
        'sliding_window_stride': 2              # Sliding window stride
    }

torch.manual_seed(62)
model = GPTModel(GPT_CONFIG_124M)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

model.to(device)
model.eval()  # disable dropout

encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)

print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
print("\nInput text:", start_context)
print("Encoded input text:", encoded)
print("encoded_tensor.shape:", encoded_tensor.shape)

if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.time()

token_ids = generate_text_simple_cached(
    model = model,
    idx = encoded_tensor,
    max_new_tokens = 200
)

if torch.cuda.is_available():
    torch.cuda.synchronize()
total_time = time.time() - start

decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
print("\nOutput:", token_ids)
print("Output length:", len(token_ids[0]))
print("Output text:", decoded_text)

print(f"\nTime: {total_time:.2f} sec")
print(f"{int(len(token_ids[0])/total_time)} tokens/sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

Using device: cuda

                      IN

Input text: O say can you see,
Encoded input text: [46, 910, 460, 345, 766, 11]
encoded_tensor.shape: torch.Size([1, 6])


                      OUT

Output: tensor([[   46,   910,   460,   345,   766,    11, 22226,   925, 34232,  5168,
         27785, 14186, 32605,  7357, 47936, 24290, 48499, 30992, 39717, 20059,
          1097, 24607, 30354, 14055, 15616, 24603, 27361, 29567, 48629, 11703,
         24760, 13227,  5212, 37031, 12895, 24726,  6632, 32664, 17420, 20044,
         48934, 31498, 23915,  7478, 16821, 23507, 28713, 33450, 33852, 18144,
         38192, 27193, 15106, 41251, 30723, 25910, 18118, 15063, 14396, 11831,
         13480, 10588,  3519, 34547, 28987, 46372, 19753,   677,  8555, 19954,
          5332, 49938, 35240, 31508,  1448, 17127, 29945, 22865, 10322, 21679,
          6455, 27560, 45560,  2424, 18761, 31089, 41822, 40607,  5058, 15455,
         22056,  9750,  1147, 26224, 27054, 44730, 26320, 22049,  9370, 36206,
      

In [15]:
from _04_gpt_with_kv_cache import main

main()

Using device: cuda

                      IN

Input text: O say can you see,
Encoded input text: [46, 910, 460, 345, 766, 11]
encoded_tensor.shape: torch.Size([1, 6])


                      OUT

Output: tensor([[   46,   910,   460,   345,   766,    11,  2213,  8344, 31206,  5306,
         21397, 33975, 27504, 11423, 48095, 27370, 22947,  5887, 41265, 22451,
         29247, 17442, 19084, 11193, 36335, 47439, 13261,  4605,  6148, 33883,
          5357, 48165, 14897,  2718, 36130, 30210, 35545, 44590, 17908, 25627,
         44931, 18421, 10639, 29587, 35404,  9167, 36746, 45355, 43651,  5747,
         35434, 13811, 19969, 33151, 10422, 28924, 19252, 38586, 13270, 44607,
         21911, 18656, 41031, 33233, 23345, 16265, 46451,  9519, 43042, 15282,
         32318, 10110, 32940, 19949, 39966, 44046,  7773, 10712, 12995,  3246,
            25, 30030, 44784, 40684, 37284,    66, 39577, 25370,  6908, 18202,
         33103,  5048,  6171, 42269, 27385,  1218, 49331, 23697, 36703, 20438,
      

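- Throughput is higher than with standard MHA with a KV cache (49 vs. 33 tokens/sec).
- The SWA run also finished faster and used less GPU memory:
  - 4.15s (SWA) vs. 6.13s (MHA)
  - 1.24GB (SWA) vs. 1.91GB (MHA)
- However, **SWA is still slower and uses more GPU memory than GQA, which explicitly reduces the number of key/value heads**.
  - 3.77s (GQA) vs. 4.15s (SWA)
  - 0.58GB (GQA) vs. 1.24GB (SWA)
  - This run uses only 6 + 200 = 206 tokens, which never fills the 1024-token window. The SWA layers therefore attend to exactly the same keys as full attention, and no cache entries are dropped. The differences above cannot come from the windowing itself and more likely reflect implementation or measurement differences. For example, `torch.cuda.max_memory_allocated()` reports the peak since the start of the process unless `torch.cuda.reset_peak_memory_stats()` is called, so runs in the same notebook kernel are not isolated. SWA only shrinks the KV cache once the sequence is longer than the window, whereas GQA shrinks it at every sequence length.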
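A rough KV-cache size estimate for the run above helps put these numbers in context. The sketch makes three assumptions. Values are float32, since the model in this notebook is not cast to a lower precision. Keys and values have width `embed_dim = 768` per layer. SWA layers follow the layout from the `GPTModel` cell. The helper `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(num_tokens, num_layers, embed_dim, bytes_per_elem=4,
                   window=None, local_per_global=0):
    """KV-cache size in bytes for one sequence (hypothetical helper).

    Each layer caches keys and values of width embed_dim (= num_heads * head_dim) per token;
    layer i is an SWA layer if (i % (local_per_global + 1)) < local_per_global, as in the
    GPTModel cell, and SWA layers keep at most `window` tokens.
    """
    total_tokens = 0
    for i in range(num_layers):
        is_swa = window is not None and (i % (local_per_global + 1)) < local_per_global
        total_tokens += min(num_tokens, window) if is_swa else num_tokens
    return 2 * total_tokens * embed_dim * bytes_per_elem  # factor 2: keys and values


seq_len = 6 + 200  # prompt tokens + max_new_tokens in the run above
full = kv_cache_bytes(seq_len, num_layers=12, embed_dim=768)
swa = kv_cache_bytes(seq_len, num_layers=12, embed_dim=768, window=1024, local_per_global=2)
print(full == swa)                   # True: 206 tokens never fill the 1024-token window
print(f"{full / 1024**2:.1f} MiB")   # 14.5 MiB

long_swa = kv_cache_bytes(4096, 12, 768, window=1024, local_per_global=2)
long_full = kv_cache_bytes(4096, 12, 768)
print(long_swa / long_full)          # 0.5
```

The whole cache for this run is about 14.5 MiB and is identical with or without the window. That is small next to the measured gaps in peak memory. At 4096 tokens, this layout would keep half of the full cache. That figure is pure arithmetic: the notebook's `context_length` is 206, so the model itself cannot process that many tokens.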

[YouTube video on SWA as first proposed in Longformer](https://youtu.be/it0iZ93aLs4)  <br>
[Google Gemma 3 presentation](https://youtu.be/FagNt06rSBk)