# Transformer from scratch — Blank Notebook (Aing)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aing-gachon/26-Spring-Transformer-Study/blob/main/Week2/Aing_Transformer_from_scratch_blank.ipynb)

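This notebook is a hands-on worksheet based on `transformer_from_scratch.py`. You fill in the `____` blanks for the **core Transformer components** (Self-Attention / Multi-Head / Add&Norm / FFN / Encoder-Decoder / Mask)  
and, along the way, connect the flow of **Eq.(1), Eq.(2), and Fig.2** from the paper to code.

- Blanks appear **only at points that are key to understanding** (no filler blanks).
- The comments next to each blank describe only the **role and intent**, so the answer **cannot simply be copy-pasted**.
- The **practice code (training loop)** further down is used exactly as provided; **that part has no blanks.**

> How to use  
> 1) Work from top to bottom, filling in only the `____` blanks  
> 2) Check frequently that your tensors match the shape comments  
> 3) Confirm end-to-end behavior with the final `__main__` test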


## Installation

In [1]:
# Required: PyTorch (tensor operations / model definition)
!pip install torch
!pip -q install spacy datasets sacrebleu
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm


## Learning Objectives

1. Implement **Scaled Dot-Product Attention (Eq.1)** in code and explain the meaning of $QK^T / \sqrt{d_k}$.  
2. Implement the **Multi-Head Attention** reshape from `(N, seq, d_model)` to `(N, heads, seq, d_k)` yourself.  
3. Explain and implement the **Add & Norm + FFN (Eq.2)** block from the residual-connection point of view.  
4. Tell apart the **three kinds of attention in the Encoder / Decoder** (self / cross / masked self) by following the code.  
5. Explain, with code, why the **source mask and target (causal) mask** are needed and where they are applied.


## (To avoid confusion) CheatSheet ↔ Notebook Mapping

- **Eq.(1) Attention(Q,K,V)** ↔ in `SelfAttention.forward`:  
  `attention_logits_QK` → `attention_weights` → `out_heads`
- **Multi-Head** ↔ the `reshape/permute` that splits `values/keys/queries` into heads, plus the final `W_O` projection
- **Eq.(2) FFN** ↔ `ffn_linear1_output` → `ffn_relu_output` → `ffn_linear2_output` in `TransformerBlock.forward`
- **Fig.2 Encoder-Decoder** ↔ `Encoder`, `DecoderBlock`, `Decoder`, `Transformer.forward`
- **Mask** ↔ `make_src_mask`, `make_trg_mask`, and the `masked_fill` applied to the attention logits


In [3]:
"""
A from scratch implementation of Transformer network,
following the paper Attention is all you need with a
few minor differences. I tried to make it as clear as
possible to understand and also went through the code
on my youtube channel!
"""

import torch
import torch.nn as nn


## 1) Scaled Dot-Product Attention (Eq.1) ↔ `SelfAttention.forward`
## 2) Multi-Head Attention (MHA) ↔ split → attention → concat → projection

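- **Core formula (Eq.1)**: $\text{Attention}(Q,K,V)=\text{softmax}\left(QK^T/\sqrt{d_k}\right)V$
- `attention_logits_QK` is the per-head table of `(query, key)` scores.
- `attention_weights` holds the probabilities after softmax; the last step takes the weighted sum of `V` to produce `out_heads`.

### Thinking Questions
- (why) Why are the dot-product scores scaled by $1/\sqrt{d_k}$?  
- (how) Why does splitting into heads increase expressive power? (Compare with a single head.)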
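
One way to probe the scaling question, as a minimal sketch rather than the notebook's own code: for random vectors with unit-variance components, the dot product of a query and a key has a standard deviation of about $\sqrt{d_k}$, so unscaled logits grow with $d_k$ and push softmax toward a near one-hot output. The sizes, the seed, and all variable names below are assumptions chosen only for illustration.

```python
import torch

torch.manual_seed(0)

d_k = 64        # per-head dimension (illustrative value)
num_keys = 10   # keys attended to by each query (illustrative value)

# Random queries/keys with unit-variance components (an assumption, not real model activations).
q = torch.randn(2000, d_k)            # 2000 independent queries
k = torch.randn(2000, num_keys, d_k)  # 10 keys per query

raw_logits = torch.einsum("nd,nkd->nk", q, k)  # (2000, num_keys), unscaled q.k
scaled_logits = raw_logits / d_k ** 0.5        # Eq.(1) scaling

raw_std = raw_logits.std().item()        # roughly sqrt(d_k)
scaled_std = scaled_logits.std().item()  # roughly 1

# Mean of the largest softmax weight per row: values near 1.0 indicate saturation.
raw_peak = torch.softmax(raw_logits, dim=-1).max(dim=-1).values.mean().item()
scaled_peak = torch.softmax(scaled_logits, dim=-1).max(dim=-1).values.mean().item()

print(f"std   raw={raw_std:.2f}  scaled={scaled_std:.2f}")
print(f"peak  raw={raw_peak:.2f}  scaled={scaled_peak:.2f}")
```

A saturated softmax has very small gradients for most entries, which is the usual argument for the $1/\sqrt{d_k}$ factor.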


In [None]:
# =========================================================
# 1) SelfAttention (Scaled Dot-Product + Multi-Head)
# - CheatSheet §1~2 | CookBook step.2, 6~12
# - Eq.(1): softmax(QK^T / sqrt(d_k)) V
# =========================================================

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.d_model = embed_size
        self.h = heads
        # [CheatSheet §2 Multi-Head Attention | Step 1/2]
        # Hint: transformer_cookbook.md [step.2]
        self.d_k = self.d_model // _____________
        # [CheatSheet §2 Multi-Head Attention | Step 2/2]
        assert (
            self.h * self.d_k == self.d_model
        ), "Embedding size needs to be divisible by heads"

        # [CheatSheet §1 Scaled Dot-Product Attention | Step 1/2] Define the Q, K, V linear projections (Eq.1)
        # Hint: transformer_cookbook.md [step.6]
        # ================================================================================
        # Multi-Head Self-Attention: in the paper, each head projects Q, K, V from d_model down to d_k.
        # Following that structure literally would mean running the projection once per head
        # (8 times in the paper), which is inefficient.
        # So why not project to the original size once and then split the result into one piece per head?
        # If that makes sense to you, you can fill in the blanks!
        # ================================================================================
        self.W_V = nn.Linear(__________, ____________)  
        self.W_K = nn.Linear(__________, ____________)  
        self.W_Q = nn.Linear(__________, ____________)
        # [CheatSheet §2 Multi-Head Attention | Step 2/2] W^O applied after Concat(heads) (Fig.2)
        # Hint: transformer_cookbook.md [step.12]
        self.W_O = nn.Linear(__________, ____________)

    def forward(self, values, keys, query, mask):
        # [CheatSheet Shape rules | Step 1/9] Extract the batch size N and the sequence lengths
        # ================================================================================
        # In NLP models, tensors usually arrive with shape (batch size, sequence length, embedding dim).
        # If you have read the CheatSheet shape rules, you know what N stands for.
        # ================================================================================
        N = ____________

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # [CheatSheet §1 Scaled Dot-Product Attention | Step 2/9] Build Q, K, V (Eq.1)
        # Hint: transformer_cookbook.md [step.6]
        V = self.W_V(values)
        K = self.W_K(keys)
        Q = self.W_Q(query)

        # [CheatSheet §2 Multi-Head Attention | Step 3/9]
        # Hint: transformer_cookbook.md [step.7]
        # ================================================================================
        # Here Q, K, V are split into one piece per head for Multi-Head Attention.
        # Before the split, Q, K, V each have shape (batch size (N), sequence length (len), d_model).
        # One variable in this class was computed using the number of heads.
        # Now try filling in the blanks!
        # ================================================================================
        V = V.reshape(N, value_len, ____________, ____________)  
        K = K.reshape(N, key_len, ____________, ____________)  
        Q = Q.reshape(N, query_len, ____________, ____________)

        # [CheatSheet §1 Scaled Dot-Product Attention | Step 4/9] attention_logits = QK^T (Eq.1)
        # Hint: transformer_cookbook.md [step.8]  einsum("nqhd,nkhd->nhqk")
        attention_logits_QK = torch.einsum("____________________", [Q, K])  # note: per-head QK^T score table

        # [CheatSheet §4 Mask | Step 5/9] Apply the mask (block padding / future positions)
        # Hint: transformer_cookbook.md [step.9]  masked_fill(mask==0, -inf)
        if mask is not None:
            attention_logits_QK = attention_logits_QK.masked_fill(__________ == 0, float("-1e20"))  # note: so the weight becomes 0 after softmax

        # [CheatSheet §1 Scaled Dot-Product Attention | Step 6/9] scale (+ softmax) (Eq.1)
        # - dividing by sqrt(d_k) keeps softmax from saturating
        # - softmax runs over the key_len axis (the last axis)
        # Hint: transformer_cookbook.md [step.10]
        attention_weights = torch.softmax(attention_logits_QK / (self.d_k ** (1 / 2)), dim=-1)

        # [CheatSheet §1 Scaled Dot-Product Attention | Step 7/9] out_heads = attention_weights @ V (Eq.1)
        # Hint: transformer_cookbook.md [step.11]
        out_heads = torch.einsum("____________________", [attention_weights, V])  # note: the weighted sum builds the new representation
        # [CheatSheet §2 Multi-Head Attention | Step 8/9]
        out = out_heads.reshape(N, query_len, ____________ * ____________)  # note: merge the head axis

        # [CheatSheet §2 Multi-Head Attention | Step 9/9] W^O output projection (Fig.2)
        out = self.W_O(out)

        return out

### ✅ Check 1: SelfAttention output shape test
(Run after filling in the blanks)

In [5]:
import torch

torch.manual_seed(0)

N = 2
q_len = 5
d_model = 32
h = 4

self_attn = SelfAttention(embed_size=d_model, heads=h)

query = torch.randn(N, q_len, d_model)  # (N, q_len, d_model)
keys = query  # (N, q_len, d_model)
values = query  # (N, q_len, d_model)

mask = torch.ones(N, 1, 1, q_len)  # (N, 1, 1, k_len)
mask[0, :, :, -1] = 0  # (N, 1, 1, k_len)

out = self_attn(values=values, keys=keys, query=query, mask=mask)  # (N, q_len, d_model)
assert out.shape == (N, q_len, d_model), f"shape mismatch: {out.shape}"
assert torch.isfinite(out).all(), "NaN/Inf detected in SelfAttention output"
print("[OK] SelfAttention output shape:", tuple(out.shape))


[OK] SelfAttention output shape: (2, 5, 32)


## 5) Add & Norm + FFN (Eq.2) ↔ `TransformerBlock.forward`

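- **Residual (Add)**: keeps the original input intact and adds the transformed result to it.
- **LayerNorm (Norm)**: normalizes each token's feature vector, which stabilizes the activation distribution.
- **FFN (Eq.2)**:  $\text{FFN}(x)=\max(0, xW_1+b_1)W_2+b_2$

### Thinking Questions
- (why) Why is an FFN placed right after attention?  
- (how) Without the residual connection, what goes wrong during backpropagation?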
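
The "position-wise" nature of the FFN in Eq.(2) can be checked directly. The sketch below assumes toy sizes (`d_model=8`, `d_ff=32`) and a stand-alone `ffn` module that mirrors the Linear → ReLU → Linear layout; it is an illustration, not the notebook's `TransformerBlock`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, d_ff = 8, 32  # toy sizes, chosen only for illustration
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(2, 5, d_model)  # (N, seq_len, d_model)

# Apply the FFN to the whole tensor at once ...
out_batched = ffn(x)  # (N, seq_len, d_model)

# ... and to each position separately, then stack along the sequence axis.
out_per_position = torch.stack(
    [ffn(x[:, t, :]) for t in range(x.shape[1])], dim=1
)

same = torch.allclose(out_batched, out_per_position, atol=1e-5)
print("position-wise:", same)
```

Because the FFN never mixes positions, all token-to-token interaction comes from attention; the FFN contributes a per-token non-linear transformation on top of it.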


In [None]:
# =========================================================
# 5) TransformerBlock (Add & Norm + FFN)
# - CheatSheet §5 | CookBook step.13~14
# =========================================================

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # [CheatSheet §5 Add & Norm + FFN | Step 1/2] 
        # Hint: transformer_cookbook.md [step.14]
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        # [CheatSheet §5 Add & Norm + FFN | Step 1/4] (Sublayer) Multi-Head Attention
        # Hint: transformer_cookbook.md [step.13]
        multihead_attention_output = self.attention(value, key, query, mask)

        # [CheatSheet §5 Add & Norm + FFN | Step 2/4] Add & Norm (Residual + LayerNorm + Dropout)
        # ================================================================================
        # Remember the skip connection from the ResNet study session?
        # The ResNet paper set out to solve the degradation problem, but studies report that skip connections also help with vanishing gradients.
        # This step finds the values most relevant to the query, so what should be added to the attention output?
        # ================================================================================
        attention_residual_add = ____________ + ____________  # note: residual = attention_output + query
        post_attention_layernorm = self.norm1(attention_residual_add)
        x = self.dropout(post_attention_layernorm)

        # [CheatSheet §5 Add & Norm + FFN | Step 3/4] Position-wise FFN (run layer by layer so the shapes stay visible)
        # Hint: transformer_cookbook.md [step.14]
        ffn_linear1_output = self.feed_forward[0](x)
        ffn_relu_output = self.feed_forward[1](ffn_linear1_output)
        ffn_linear2_output = self.feed_forward[2](ffn_relu_output)

        # [CheatSheet §5 Add & Norm + FFN | Step 4/4] Add & Norm after FFN
        # ================================================================================
        # Remember the skip connection from the ResNet study session?
        # The ResNet paper set out to solve the degradation problem, but studies report that skip connections also help with vanishing gradients.
        # This step outputs the FFN result for x, so you can probably see how to build the skip connection.
        # ================================================================================
        ffn_residual_add = ____________ + ____________  # note: residual = ffn_output + x
        post_ffn_layernorm = self.norm2(ffn_residual_add)
        out = self.dropout(post_ffn_layernorm)

        return out

### ✅ Check 2: TransformerBlock output shape test
(Run after filling in the blanks)

In [None]:
import torch

torch.manual_seed(0)

N = 2
seq_len = 6
d_model = 32
h = 4

block = TransformerBlock(embed_size=d_model, heads=h, dropout=0.0, forward_expansion=4)

x = torch.randn(N, seq_len, d_model)  # (N, seq_len, d_model)
mask = torch.ones(N, 1, 1, seq_len)  # (N, 1, 1, seq_len)

out = block(value=x, key=x, query=x, mask=mask)  # (N, seq_len, d_model)
assert out.shape == (N, seq_len, d_model), f"shape mismatch: {out.shape}"
assert torch.isfinite(out).all(), "NaN/Inf detected in TransformerBlock output"
print("[OK] TransformerBlock output shape:", tuple(out.shape))


## 6) Embedding + Positional Encoding ↔ `Encoder/Decoder.forward`

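- `word_embedding` turns the input tokens from **(N, seq)** into **(N, seq, d_model)**.
- Adding `position_embedding` injects **order information**.
- The sum then passes through dropout and enters the block stack.

### Thinking Questions
- (why) What goes wrong in Self-Attention without order information?  
- (how) What are the trade-offs between a learned position embedding and a sinusoidal encoding?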
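
To make the first question concrete, here is a small sketch showing that self-attention without position information is permutation-equivariant: shuffling the input tokens only shuffles the output rows in the same way, so word order carries no signal. The helper `plain_self_attention` is hypothetical and uses `Q = K = V = x` with no learned weights, purely for illustration.

```python
import torch

torch.manual_seed(0)

def plain_self_attention(x):
    """Single-head scaled dot-product self-attention with Q = K = V = x (no learned weights)."""
    d_k = x.shape[-1]
    logits = x @ x.transpose(-2, -1) / d_k ** 0.5  # (seq, seq)
    weights = torch.softmax(logits, dim=-1)
    return weights @ x                             # (seq, d)

x = torch.randn(4, 8)              # 4 tokens, 8-dim embeddings, no position info
perm = torch.tensor([2, 0, 3, 1])  # an arbitrary reordering of the tokens

out = plain_self_attention(x)
out_shuffled = plain_self_attention(x[perm])

# Shuffling the input rows only shuffles the output rows the same way.
equivariant = torch.allclose(out[perm], out_shuffled, atol=1e-5)
print("permutation-equivariant:", equivariant)
```

Adding a position-dependent vector to each token embedding breaks this symmetry, which is what `position_embedding` is for.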


In [None]:
# =========================================================
# 6) Encoder (Embedding + Positional Encoding + Encoder Stack)
# - CheatSheet §6, §3(Encoder Self-Attention) | CookBook step.1, 4~5, 15
# =========================================================

class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, src_token_ids, src_padding_mask):
        # [CheatSheet §0 Shape rules | Step 1/3] src_token_ids: (N, src_len)
        # Hint: transformer_cookbook.md [step.1]
        N, src_len = src_token_ids.shape

        # [CheatSheet §6 Embedding + Positional Encoding | Step 2/3] Build positions
        # Hint: transformer_cookbook.md [step.5]  arange(0, L).expand(N, L)
        # ================================================================================
        # If you already know positional encoding, feel free to skip to "Hints!".
        # An RNN consumes words one at a time in order, which leads to the long-range dependency problem. To avoid it, the Transformer processes all words in parallel.
        # But then "나 너 좋아해" ("I like you") and "너 나 좋아해" ("you like me") would look like the same sentence. So each token in the sentence needs its own position information added.

        # Hints!
        # 1. arange(0, x) builds a 1-D index [0, 1, 2, ..., sentence length - 1].
        # 2. expand(a, b) broadcasts that 1-D index (of length b) across a rows.
        # That is, the shape goes from (b,) to (a, b), so every sentence in the batch gets the same index row.
        # So how should the blanks be filled?
        # ================================================================================
        
        position_ids = torch.arange(0, ____________).to(self.device)  
        positions = position_ids.expand(__________, ____________)  

        # [CheatSheet §6 Embedding + Positional Encoding | Step 3/3] token_emb + pos_emb (+ dropout)
        # Hint: transformer_cookbook.md [step.4~5]
        token_embedding_d_model = self.word_embedding(src_token_ids)
        positional_embedding_d_model = self.position_embedding(positions)
        out = self.dropout(____________ + ____________) 
   

        # [CheatSheet §3 Encoder Self-Attention | Step 1/1] Loop over the encoder stack
        # Hint: transformer_cookbook.md [step.15]  layer(out,out,out,src_mask)
        # ================================================================================
        # Here `layer` is a TransformerBlock object from self.layers.
        # Many people miss this: calling an nn.Module object like a function, layer(...), runs its forward method.
        # Put those two sentences together. Which function are the four inputs below needed for?
        # If you can answer that, you can fill in the blanks!
        # ================================================================================
        for layer in self.layers:
            out = layer(____________, ____________, ____________, ____________)  

        return out

### ✅ Check 3: Encoder output shape test
(Run after filling in the blanks)

In [None]:
import torch

torch.manual_seed(0)

N = 2
src_len = 7
src_vocab_size = 50
src_pad_idx = 0
d_model = 32
h = 4

device = torch.device("cpu")

encoder = Encoder(
    src_vocab_size=src_vocab_size,
    embed_size=d_model,
    num_layers=2,
    heads=h,
    device=device,
    forward_expansion=4,
    dropout=0.0,
    max_length=100,
).to(device)

src_token_ids = torch.randint(1, src_vocab_size, (N, src_len)).to(device)  # (N, src_len)
src_token_ids[0, -2:] = src_pad_idx  # (N, src_len)

src_padding_mask = (src_token_ids != src_pad_idx).unsqueeze(1).unsqueeze(2)  # (N, 1, 1, src_len)

enc_out = encoder(src_token_ids=src_token_ids, src_padding_mask=src_padding_mask)  # (N, src_len, d_model)
assert enc_out.shape == (N, src_len, d_model), f"shape mismatch: {enc_out.shape}"
assert torch.isfinite(enc_out).all(), "NaN/Inf detected in Encoder output"
print("[OK] Encoder output shape:", tuple(enc_out.shape))


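## 3) Mapping the three kinds of attention in Encoder / Decoder ↔ `DecoderBlock.forward`

A DecoderBlock contains two attention sublayers (items 1 and 2 below); the Encoder uses only the third kind.

1. **Masked Self-Attention**: inside the Decoder, a causal mask keeps each position from seeing future tokens  
2. **Cross-Attention (Encoder-Decoder Attention)**: Query comes from the decoder, Key/Value from the encoder output  
3. (Encoder only) **Self-Attention** without a causal mask

### Thinking Questions
- (why) Why does decoder self-attention always need a causal mask?  
- (how) Pinpoint in the code exactly where Q/K/V come from in cross-attention.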
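
The shape bookkeeping of cross-attention can be sketched as follows. This is a single-head illustration with no learned projections; `enc_out` and `dec_state` are random stand-ins (assumptions) for the encoder output and the decoder's masked self-attention output, and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)

N, src_len, trg_len, d_model = 2, 7, 5, 16  # toy sizes for illustration

enc_out = torch.randn(N, src_len, d_model)    # stand-in for the encoder output
dec_state = torch.randn(N, trg_len, d_model)  # stand-in for the decoder-side query source

# Cross-attention: Query from the decoder, Key/Value from the encoder.
Q, K, V = dec_state, enc_out, enc_out

logits = Q @ K.transpose(-2, -1) / d_model ** 0.5  # (N, trg_len, src_len)
weights = torch.softmax(logits, dim=-1)            # each target position distributes weight over source positions
out = weights @ V                                  # (N, trg_len, d_model)

print(tuple(weights.shape), tuple(out.shape))
```

The output keeps the query's length (`trg_len`), while the softmax axis has the key's length (`src_len`); this is why the source padding mask, not the causal mask, applies to cross-attention.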


In [None]:
# =========================================================
# 3) DecoderBlock (Masked Self-Attention + Cross-Attention)
# - CheatSheet §3 | CookBook step.16~17
# =========================================================

class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        # [CheatSheet §3 Decoder Masked Self-Attention | Step 1/3] masked self-attn (block future tokens)
        # Hint: transformer_cookbook.md [step.16]
        # ================================================================================
        # self.attention is a SelfAttention object, so calling it runs its forward method.
        # And since this is the Decoder, it must perform Masked Multi-Head Self-Attention.
        # So you know where to look, right?
        # ================================================================================
        
        masked_self_attention_output = self.attention(____________, ____________, ____________, ____________)  

        # [CheatSheet §5 Add & Norm | Step 2/3] Residual + LayerNorm + Dropout (builds the query)
        residual_add = ____________ + ____________  
        query = self.dropout(self.norm(residual_add))

        # [CheatSheet §3 Encoder-Decoder Attention | Step 3/3] Cross-Attention(+FFN) via TransformerBlock
        # Hint: transformer_cookbook.md [step.17]
        out = self.transformer_block(____________, ____________, ____________, ____________)  

        return out

### ✅ Check 4: DecoderBlock output shape test
(Run after filling in the blanks)

In [None]:
import torch

torch.manual_seed(0)

N = 2
src_len = 7
trg_len = 5
d_model = 32
h = 4

device = torch.device("cpu")

decoder_block = DecoderBlock(
    embed_size=d_model,
    heads=h,
    forward_expansion=4,
    dropout=0.0,
    device=device,
).to(device)

x = torch.randn(N, trg_len, d_model).to(device)  # (N, trg_len, d_model)
value = torch.randn(N, src_len, d_model).to(device)  # (N, src_len, d_model)
key = value  # (N, src_len, d_model)

src_mask = torch.ones(N, 1, 1, src_len).to(device)  # (N, 1, 1, src_len)

trg_mask_base = torch.tril(torch.ones(trg_len, trg_len)).to(device)  # (trg_len, trg_len)
trg_mask = trg_mask_base.expand(N, 1, trg_len, trg_len)  # (N, 1, trg_len, trg_len)

out = decoder_block(x=x, value=value, key=key, src_mask=src_mask, trg_mask=trg_mask)  # (N, trg_len, d_model)
assert out.shape == (N, trg_len, d_model), f"shape mismatch: {out.shape}"
assert torch.isfinite(out).all(), "NaN/Inf detected in DecoderBlock output"
print("[OK] DecoderBlock output shape:", tuple(out.shape))


## 6) Embedding + Positional Encoding ↔ `Decoder.forward`

- The Decoder embeds the target tokens the same way as the Encoder: `word_embedding` turns **(N, seq)** into **(N, seq, d_model)**.
- Adding `position_embedding` injects **order information**.
- The sum passes through dropout and enters the DecoderBlock stack; `fc_out` then projects the result to vocabulary logits.

The thinking questions from the Encoder section apply here as well.


In [None]:
# =========================================================
# 6) Decoder (Embedding + Positional Encoding + Decoder Stack + Vocab Projection)
# - CheatSheet §6, §7 | CookBook step.1, 4~5, 18
# =========================================================

class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, trg_token_ids, enc_out, src_padding_mask, trg_causal_mask):
        # [CheatSheet §0 Shape rules | Step 1/4] trg_token_ids: (N, trg_len)
        # Hint: transformer_cookbook.md [step.1]
        N, trg_len = trg_token_ids.shape

        # [CheatSheet §6 Embedding + Positional Encoding | Step 2/4] Build positions
        # Hint: transformer_cookbook.md [step.5]
        position_ids = torch.arange(0, ____________).to(self.device) 
        positions = position_ids.expand(__________, ____________) 

        # [CheatSheet §6 Embedding + Positional Encoding | Step 3/4] token_emb + pos_emb (+ dropout)
        token_embedding_d_model = self.word_embedding(trg_token_ids)
        positional_embedding_d_model = self.position_embedding(positions)
        x = self.dropout(____________ + ____________) 

        # [CheatSheet §3 Decoder Stack | Step 1/1] Loop over the DecoderBlocks
        # Hint: transformer_cookbook.md [step.18]
        for layer in self.layers:
            x = layer(____________, ____________, ____________, ____________, ____________) 

        # [CheatSheet §7 Output projection | Step 4/4] Produce the vocab logits
        # Hint: transformer_cookbook.md [step.18]
        out = self.fc_out(____________)  

        return out

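## 4) Masks ↔ `make_src_mask`, `make_trg_mask`

- **Source mask**: makes attention ignore PAD tokens (`src_pad_idx`)
- **Target mask (causal)**: a **lower-triangular mask** (`torch.tril`) that keeps each position from seeing future tokens; the blocked entries are the ones above the diagonal

Where it is applied: the attention logits are covered with a large negative value, as in `masked_fill(mask==0, -1e20)`, so that the corresponding weights become 0 after softmax.

### Thinking Questions
- (why) Why does masking have to happen at the logits stage for the weights to be 0 after softmax?  
- (how) Trace how the target mask's shape is matched up to the head and batch dimensions.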
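
The effect of masking at the logits stage can be sketched on a tiny example. Combining the padding and causal masks into one tensor is done here only for illustration (the notebook applies the padding mask on the source side and the causal mask on the target side), the head axis is omitted, and uniform zero logits are an assumption that keeps the output easy to read.

```python
import torch

pad_idx = 0
tokens = torch.tensor([[5, 3, 9, 0]])  # (N=1, seq_len=4); the last token is PAD

padding_mask = (tokens != pad_idx)[:, None, :]           # (1, 1, 4): True where a key may be attended
causal_mask = torch.tril(torch.ones(4, 4)).bool()[None]  # (1, 4, 4): lower triangle allowed
combined = padding_mask & causal_mask                    # broadcasts to (1, 4, 4)

logits = torch.zeros(1, 4, 4)  # uniform scores, for readability
masked_logits = logits.masked_fill(~combined, float("-1e20"))
weights = torch.softmax(masked_logits, dim=-1)  # blocked entries get weight 0, rows still sum to 1

# For contrast: zeroing entries after softmax leaves rows that no longer sum to 1.
late = torch.softmax(logits, dim=-1) * combined

print(weights[0])
print(late[0].sum(dim=-1))
```

Masking before softmax lets the normalization redistribute probability over the allowed keys only, whereas masking afterward would require renormalizing each row by hand.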


In [None]:
# =========================================================
# 0) Transformer (Encoder + Decoder + two masks)
# - CheatSheet §0 (overall), §4 (Mask) | CookBook step.3, 19
# =========================================================
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cpu",
        max_length=100,
    ):

        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src_token_ids):
        # [CheatSheet §4 Mask | Step 1/2] src padding mask: (src != pad) -> unsqueeze(1)->unsqueeze(2)
        # Hint: transformer_cookbook.md [step.3]
        # ================================================================================
        # This one was quite hard when I test-solved it. Let's start by reading the variable name.
        # "src_is_not_pad" means the parts of src (the source) that are not padding. So what should src_token_ids be compared against for inequality?
        # ================================================================================
        src_is_not_pad = (src_token_ids != ____________)  # note: pad-token positions become False
        src_padding_mask = src_is_not_pad.unsqueeze(____)  
        src_padding_mask = src_padding_mask.unsqueeze(____)  
        return src_padding_mask.to(self.device)

    def make_trg_mask(self, trg_token_ids):
        # [CheatSheet §4 Mask | Step 2/2] trg causal mask: tril(ones(L,L)).expand(N, 1, L, L)
        # Hint: transformer_cookbook.md [step.3]
        N, trg_len = trg_token_ids.shape
        trg_ones = torch.ones((__________, ____________))  # note: (trg_len, trg_len) ones
        trg_lower_triangular = torch.tril(trg_ones)  # note: lower triangle (blocks future tokens)
        trg_causal_mask = trg_lower_triangular.expand(__________, 1, ____________, ____________)  

        return trg_causal_mask.to(self.device)

    def forward(self, src_token_ids, trg_token_ids):
        # [CheatSheet overall structure | Step 1/3] Build the masks
        # Hint: transformer_cookbook.md [step.19]
        src_padding_mask = self.make_src_mask(____________)
        trg_causal_mask = self.make_trg_mask(____________)

        # [CheatSheet overall structure | Step 2/3] Encoder
        # ================================================================================
        # As noted many times above, calling a module object automatically invokes one particular method inside it.
        # So let's go find it!
        # ================================================================================
        enc_src = self.encoder(____________, ____________)
        # [CheatSheet overall structure | Step 3/3] Decoder
        out = self.decoder(____________, ____________, ____________, ____________)

        return out

### ✅ Check 5: Transformer forward + mask shape test
(Run after filling in the blanks)

In [None]:
import torch
import torch.nn as nn

torch.manual_seed(0)

N = 2
src_len = 8
trg_len = 6

src_vocab_size = 100
trg_vocab_size = 120

src_pad_idx = 0
trg_pad_idx = 0

device = torch.device("cpu")

model = Transformer(
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    embed_size=32,
    num_layers=2,
    forward_expansion=4,
    heads=4,
    dropout=0.0,
    device=device,
    max_length=100,
).to(device)

src_token_ids = torch.randint(1, src_vocab_size, (N, src_len)).to(device)  # (N, src_len)
trg_token_ids = torch.randint(1, trg_vocab_size, (N, trg_len)).to(device)  # (N, trg_len)

src_token_ids[0, -2:] = src_pad_idx  # (N, src_len)
trg_token_ids[0, -1] = trg_pad_idx  # (N, trg_len)

trg_input_ids = trg_token_ids[:, :-1]  # (N, trg_len-1)
trg_target_ids = trg_token_ids[:, 1:]  # (N, trg_len-1)

src_padding_mask = model.make_src_mask(src_token_ids)  # (N, 1, 1, src_len)
trg_causal_mask = model.make_trg_mask(trg_input_ids)  # (N, 1, trg_len-1, trg_len-1)

assert src_padding_mask.shape == (N, 1, 1, src_len), f"src_mask shape mismatch: {src_padding_mask.shape}"
assert trg_causal_mask.shape == (N, 1, trg_len - 1, trg_len - 1), f"trg_mask shape mismatch: {trg_causal_mask.shape}"

logits = model(src_token_ids, trg_input_ids)  # (N, trg_len-1, trg_vocab_size)
assert logits.shape == (N, trg_len - 1, trg_vocab_size), f"logits shape mismatch: {logits.shape}"
assert torch.isfinite(logits).all(), "NaN/Inf detected in Transformer logits"

criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)

logits_2d = logits.reshape(-1, trg_vocab_size)  # (N*(trg_len-1), trg_vocab_size)
targets_1d = trg_target_ids.reshape(-1)  # (N*(trg_len-1),)

loss = criterion(logits_2d, targets_1d)
loss.backward()

print("[OK] Transformer forward + loss backward:", loss.detach().item())


In [None]:
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
        device
    )
    # (N, src_len)
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)
    # (N, trg_len)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10
    model = Transformer(
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        device=device,
    ).to(device)

    trg_input_ids = trg[:, :-1]
    # (N, trg_len-1)
    out = model(x, trg_input_ids)
    # (N, trg_len-1, trg_vocab_size)
    print(out.shape)


### ✅ Check 6: Smoke test of a **one-step training loop** on a dummy batch (no data download)

The code below is a minimal test that the Transformer training loop runs **without any dataset**.

- **Teacher-forcing shift pattern (standard):**  
  `trg_input = trg[:, :-1]` / `trg_y = trg[:, 1:]`
- **CrossEntropy reshape pattern (standard):**  
  `logits: (N, T, V) -> (N*T, V)` / `trg_y: (N, T) -> (N*T)`
- Run it after filling in all the blanks above (`SelfAttention/Block/Encoder/Decoder/...`).
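
The two patterns above can be sketched on a toy batch. The special-token ids (`sos=1`, `eos=2`, `pad_idx=0`) and the random `logits` tensor standing in for model output are assumptions made only for this illustration.

```python
import torch
import torch.nn as nn

pad_idx, sos, eos = 0, 1, 2  # assumed special-token ids for this toy example
vocab_size = 10

trg = torch.tensor([[sos, 7, 4, eos],
                    [sos, 5, eos, pad_idx]])  # (N=2, trg_len=4)

trg_input = trg[:, :-1]  # what the decoder sees:    <sos> 7 4   /  <sos> 5 <eos>
trg_y = trg[:, 1:]       # what it should predict:   7 4 <eos>   /  5 <eos> <pad>

# Stand-in for the model output; a real model would produce these logits.
logits = torch.randn(2, trg_input.shape[1], vocab_size)  # (N, T, V)

# CrossEntropyLoss expects (N*T, V) logits and (N*T,) class indices.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
loss = criterion(logits.reshape(-1, vocab_size), trg_y.reshape(-1))

print(trg_input.tolist(), trg_y.tolist(), loss.item())
```

Position `t` of `trg_input` is paired with position `t` of `trg_y`, so each step is trained to predict the next token; the `<pad>` target in the second row is skipped through `ignore_index`.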


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- tiny toy vocab ---
src_vocab_size = 50
trg_vocab_size = 60
src_pad_idx = 0
trg_pad_idx = 0

# model hyperparams (in the spirit of the paper's base model, but tiny for the smoke test)
model = Transformer(
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    embed_size=128,
    num_layers=2,
    forward_expansion=4,
    heads=4,
    dropout=0.1,
    device=device,
    max_length=64,
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# --- dummy batch (variable lengths + padding) ---
src_batch = [
    torch.tensor([1, 5, 6, 4, 3, 9, 2]),
    torch.tensor([1, 8, 7, 3, 4, 5]),
]
trg_batch = [
    torch.tensor([1, 7, 4, 3, 5, 9, 2, 2]),
    torch.tensor([1, 5, 6, 2, 4, 7]),
]

src = pad_sequence(src_batch, batch_first=True, padding_value=src_pad_idx).to(device)
# (N, src_len)
trg = pad_sequence(trg_batch, batch_first=True, padding_value=trg_pad_idx).to(device)
# (N, trg_len)

trg_input = trg[:, :-1]
# (N, trg_len-1)
trg_y = trg[:, 1:]
# (N, trg_len-1)

logits = model(src, trg_input)
# (N, trg_len-1, trg_vocab_size)

logits_flat = logits.reshape(-1, logits.size(-1))
# (N*(trg_len-1), trg_vocab_size)
trg_y_flat = trg_y.reshape(-1)
# (N*(trg_len-1),)

optimizer.zero_grad()
loss = criterion(logits_flat, trg_y_flat)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

print("✅ smoke loss:", float(loss.item()))


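## 8) Practice: **actually** train on Multi30k (aligned with the paper's setup)

The code below follows the flow of `seq2seq_transformer.py`, restructured to use **this notebook's Transformer implementation (Encoder–Decoder, Fig.2)** as-is.  
It is a **Multi30k (De→En) training skeleton**.

- This section has **no blanks.**  
  (However, it only runs once **the blanks above have been filled in.**)
- Multi30k is far smaller than the WMT14 data used in the paper, so **do not expect to reproduce the paper's BLEU scores.**
- Even so, the "training recipe" below is aligned with the paper's description as closely as possible.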

### Paper alignment points (training recipe)
- Optimizer: **Adam(β1=0.9, β2=0.98, ε=1e-9)**
- Learning rate schedule: **Noam warmup(4000) + inverse sqrt decay**
- Regularization: **dropout=0.1**, **label smoothing=0.1** (PAD positions are ignored)
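
To get a feel for the learning rate schedule in the recipe above, the Noam formula can be evaluated at a few steps in plain Python. This is a small sketch assuming the paper-base values `d_model=512` and `warmup_steps=4000`; the helper name `noam_lr` is only for illustration.

```python
def noam_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)"""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# Linear warmup up to step 4000, then inverse square root decay
for step in (1, 1000, 4000, 16000):
    print(f"step={step:>6d}  lr={noam_lr(step):.6e}")

peak = noam_lr(4000)  # the two branches of min() meet at step == warmup_steps
```

The peak is reached at `step == warmup_steps`. At step 1000 the rate is a quarter of the peak (linear warmup), and at step 16000 it is half of the peak (inverse square root decay). Because the optimizer in this notebook is created with `lr=1.0` and wrapped in `LambdaLR`, this multiplier is the effective learning rate.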

### Differences between this notebook's implementation and the paper (important)
- (Paper) **Sinusoidal Positional Encoding** vs (Here) **Learned positional embedding**
- (Paper) **BPE/word-piece** vs (Here) **spaCy word tokenizer**
- (Paper) **weight tying** (shared embedding/output weights) vs (Here) not applied
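
For reference, the paper's sinusoidal positional encoding from the first difference above could be sketched as follows. This is not what this notebook does (it uses a learned positional embedding); the function name is hypothetical and the sketch assumes an even `d_model`.

```python
import math

import torch


def sinusoidal_positional_encoding(max_length: int, d_model: int) -> torch.Tensor:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)); PE[pos, 2i+1] = cos(same angle)."""
    positions = torch.arange(max_length, dtype=torch.float32).unsqueeze(1)
    # (max_length, 1)

    # 1 / 10000^(2i/d_model), computed in log space for numerical stability
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    # (d_model/2,)

    pe = torch.zeros(max_length, d_model)
    pe[:, 0::2] = torch.sin(positions * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(positions * div_term)  # odd dimensions
    return pe
    # (max_length, d_model)


pe = sinusoidal_positional_encoding(max_length=64, d_model=128)
print(pe.shape)
```

Unlike a learned embedding table, this tensor has no trainable parameters. It would be added to the token embeddings in the same place the positional embedding is added here.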

### Required packages (first run only)

In [None]:
!pip install datasets sacrebleu spacy
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm

> ✅ Start with `num_epochs=1` just to confirm that everything runs, then increase the number of epochs.

In [None]:
# train_multi30k.py (paper-aligned-ish) — notebook friendly
# - Based on: seq2seq_transformer.py
# - Changes: (1) use HF datasets instead of torchtext legacy (more reliable Multi30k download/loading)
#            (2) use this notebook's from-scratch Transformer as-is
#            (3) apply the paper's training recipe (Adam betas/eps, Noam LR, label smoothing)

import math
import random
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import spacy

# --- optional: BLEU (sacrebleu) ---
try:
    import sacrebleu
except ImportError:
    sacrebleu = None

# --- HF datasets for Multi30k ---
try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None


SPECIAL_TOKENS = {
    "pad": "<pad>",
    "unk": "<unk>",
    "sos": "<sos>",
    "eos": "<eos>",
}


def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def build_vocab(sentences, tokenize_fn, max_size: int = 10000, min_freq: int = 2):
    """Word-level vocab builder (toy / study-friendly)."""
    counter = Counter()
    for s in sentences:
        counter.update(tokenize_fn(s))

    # reserve specials at the beginning so indices are stable
    itos = [
        SPECIAL_TOKENS["pad"],
        SPECIAL_TOKENS["unk"],
        SPECIAL_TOKENS["sos"],
        SPECIAL_TOKENS["eos"],
    ]

    for tok, freq in counter.most_common():
        if freq < min_freq:
            continue
        if tok in itos:
            continue
        itos.append(tok)
        if len(itos) >= max_size:
            break

    stoi = {tok: i for i, tok in enumerate(itos)}
    return stoi, itos


class Multi30kWordDataset(Dataset):
    """(de, en) sentence pairs -> (src_ids, trg_ids)"""

    def __init__(
        self,
        split,
        src_tokenize_fn,
        trg_tokenize_fn,
        src_stoi,
        trg_stoi,
        src_max_len: int = 100,
        trg_max_len: int = 100,
    ):
        self.split = split
        self.src_tokenize_fn = src_tokenize_fn
        self.trg_tokenize_fn = trg_tokenize_fn
        self.src_stoi = src_stoi
        self.trg_stoi = trg_stoi
        self.src_max_len = src_max_len
        self.trg_max_len = trg_max_len

        self.src_unk = src_stoi[SPECIAL_TOKENS["unk"]]
        self.trg_unk = trg_stoi[SPECIAL_TOKENS["unk"]]
        self.src_sos = src_stoi[SPECIAL_TOKENS["sos"]]
        self.src_eos = src_stoi[SPECIAL_TOKENS["eos"]]
        self.trg_sos = trg_stoi[SPECIAL_TOKENS["sos"]]
        self.trg_eos = trg_stoi[SPECIAL_TOKENS["eos"]]

    def __len__(self):
        return len(self.split)

    def __getitem__(self, idx):
        example = self.split[idx]
        src_text = example["de"]
        trg_text = example["en"]

        src_tokens = self.src_tokenize_fn(src_text)[: self.src_max_len - 2]
        trg_tokens = self.trg_tokenize_fn(trg_text)[: self.trg_max_len - 2]

        src_ids = [self.src_sos] + [self.src_stoi.get(t, self.src_unk) for t in src_tokens] + [
            self.src_eos
        ]
        trg_ids = [self.trg_sos] + [self.trg_stoi.get(t, self.trg_unk) for t in trg_tokens] + [
            self.trg_eos
        ]

        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(trg_ids, dtype=torch.long)


def make_collate_fn(src_pad_idx: int, trg_pad_idx: int):
    def collate_fn(batch):
        src_list = [b[0] for b in batch]
        trg_list = [b[1] for b in batch]
        src = pad_sequence(src_list, batch_first=True, padding_value=src_pad_idx)
        # (N, src_len)
        trg = pad_sequence(trg_list, batch_first=True, padding_value=trg_pad_idx)
        # (N, trg_len)
        return src, trg

    return collate_fn


def noam_lr_lambda(step: int, d_model: int, warmup_steps: int = 4000):
    """Paper: lr = d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})"""
    step = max(step, 1)
    return (d_model ** -0.5) * min(step ** -0.5, step * (warmup_steps ** -1.5))


@torch.no_grad()
def greedy_decode(
    model,
    src_ids_1d: torch.Tensor,
    src_pad_idx: int,
    trg_sos_idx: int,
    trg_eos_idx: int,
    max_len: int,
    device: str,
):
    """Greedy decoding for quick sanity check (not beam search)."""
    model.eval()

    src = src_ids_1d.unsqueeze(0).to(device)
    # (N=1, src_len)

    generated = [trg_sos_idx]

    for _ in range(max_len):
        trg = torch.tensor(generated, dtype=torch.long, device=device).unsqueeze(0)
        # (N=1, trg_len)

        logits = model(src, trg)
        # (N=1, trg_len, trg_vocab_size)

        next_token = int(logits[0, -1].argmax(dim=-1).item())
        generated.append(next_token)

        if next_token == trg_eos_idx:
            break

    return generated


def train_one_epoch(model, loader, optimizer, scheduler, criterion, device: str):
    model.train()
    total_loss = 0.0

    for src, trg in loader:
        src = src.to(device)
        # (N, src_len)
        trg = trg.to(device)
        # (N, trg_len)

        # --- Teacher forcing shift (standard pattern) ---
        trg_input = trg[:, :-1]
        # (N, trg_len-1)
        trg_y = trg[:, 1:]
        # (N, trg_len-1)

        logits = model(src, trg_input)
        # (N, trg_len-1, trg_vocab_size)

        logits_flat = logits.reshape(-1, logits.size(-1))
        # (N*(trg_len-1), trg_vocab_size)
        trg_y_flat = trg_y.reshape(-1)
        # (N*(trg_len-1),)

        optimizer.zero_grad()
        loss = criterion(logits_flat, trg_y_flat)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        total_loss += float(loss.item())

    return total_loss / max(1, len(loader))


@torch.no_grad()
def evaluate(model, loader, criterion, device: str):
    model.eval()
    total_loss = 0.0

    for src, trg in loader:
        src = src.to(device)
        trg = trg.to(device)

        trg_input = trg[:, :-1]
        trg_y = trg[:, 1:]

        logits = model(src, trg_input)

        logits_flat = logits.reshape(-1, logits.size(-1))
        trg_y_flat = trg_y.reshape(-1)

        loss = criterion(logits_flat, trg_y_flat)
        total_loss += float(loss.item())

    return total_loss / max(1, len(loader))


def main(
    num_epochs: int = 1,
    batch_size: int = 64,
    max_vocab_size: int = 10000,
    min_freq: int = 2,
    max_len: int = 100,
    d_model: int = 512,
    num_layers: int = 6,
    num_heads: int = 8,
    forward_expansion: int = 4,
    dropout: float = 0.1,
    warmup_steps: int = 4000,
):
    if load_dataset is None:
        raise ImportError("❌ datasets is not installed. Run `pip install datasets` first.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    set_seed(42)

    # --- Load Multi30k (bentrevett subset on HF hub) ---
    raw = load_dataset("bentrevett/multi30k")
    train_raw = raw["train"]
    valid_raw = raw["validation"]
    test_raw = raw["test"]

    # --- Tokenizers (spaCy) ---
    # spaCy v3+ requires the full model name, not the 'de'/'en' shortcuts.
    spacy_de = spacy.load("de_core_news_sm")
    spacy_en = spacy.load("en_core_web_sm")

    def tokenize_de(text: str):
        return [tok.text.lower() for tok in spacy_de.tokenizer(text)]

    def tokenize_en(text: str):
        return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

    # --- Build vocab (word-level) ---
    src_stoi, src_itos = build_vocab(train_raw["de"], tokenize_de, max_size=max_vocab_size, min_freq=min_freq)
    trg_stoi, trg_itos = build_vocab(train_raw["en"], tokenize_en, max_size=max_vocab_size, min_freq=min_freq)

    src_pad_idx = src_stoi[SPECIAL_TOKENS["pad"]]
    trg_pad_idx = trg_stoi[SPECIAL_TOKENS["pad"]]
    trg_sos_idx = trg_stoi[SPECIAL_TOKENS["sos"]]
    trg_eos_idx = trg_stoi[SPECIAL_TOKENS["eos"]]

    # --- Datasets / Loaders ---
    train_ds = Multi30kWordDataset(train_raw, tokenize_de, tokenize_en, src_stoi, trg_stoi, max_len, max_len)
    valid_ds = Multi30kWordDataset(valid_raw, tokenize_de, tokenize_en, src_stoi, trg_stoi, max_len, max_len)
    test_ds = Multi30kWordDataset(test_raw, tokenize_de, tokenize_en, src_stoi, trg_stoi, max_len, max_len)

    collate_fn = make_collate_fn(src_pad_idx, trg_pad_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # --- Model (this notebook's from-scratch Transformer) ---
    model = Transformer(
        src_vocab_size=len(src_itos),
        trg_vocab_size=len(trg_itos),
        src_pad_idx=src_pad_idx,
        trg_pad_idx=trg_pad_idx,
        embed_size=d_model,
        num_layers=num_layers,
        forward_expansion=forward_expansion,
        heads=num_heads,
        dropout=dropout,
        device=device,
        max_length=max_len,
    ).to(device)

    # --- Paper-aligned optimizer + schedule ---
    optimizer = optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)

    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: noam_lr_lambda(step + 1, d_model=d_model, warmup_steps=warmup_steps),
    )

    criterion = nn.CrossEntropyLoss(
        ignore_index=trg_pad_idx,
        label_smoothing=0.1,  # paper: ε_ls = 0.1
    )

    # --- Train ---
    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        valid_loss = evaluate(model, valid_loader, criterion, device)

        print(f"epoch={epoch:02d} train_loss={train_loss:.4f} valid_loss={valid_loss:.4f}")

        # quick qualitative check: translate a random validation sample
        sample = valid_raw[random.randrange(len(valid_raw))]
        src_text = sample["de"]
        trg_text = sample["en"]

        src_tokens = tokenize_de(src_text)[: max_len - 2]
        src_ids = [src_stoi[SPECIAL_TOKENS["sos"]]] + [src_stoi.get(t, src_stoi[SPECIAL_TOKENS["unk"]]) for t in src_tokens] + [src_stoi[SPECIAL_TOKENS["eos"]]]
        src_ids = torch.tensor(src_ids, dtype=torch.long)

        pred_ids = greedy_decode(
            model=model,
            src_ids_1d=src_ids,
            src_pad_idx=src_pad_idx,
            trg_sos_idx=trg_sos_idx,
            trg_eos_idx=trg_eos_idx,
            max_len=50,
            device=device,
        )

        pred_tokens = [trg_itos[i] for i in pred_ids]
        pred_tokens = [t for t in pred_tokens if t not in {SPECIAL_TOKENS["sos"], SPECIAL_TOKENS["eos"], SPECIAL_TOKENS["pad"]}]

        print("DE:", src_text)
        print("GT:", trg_text)
        print("PR:", " ".join(pred_tokens))
        print("-" * 80)

    # --- BLEU (optional) ---
    if sacrebleu is None:
        print("sacrebleu is not installed, skipping BLEU. (pip install sacrebleu)")
        return

    # quick BLEU on a small subset (speed)
    model.eval()
    preds = []
    refs = []

    for i in range(min(200, len(test_raw))):
        ex = test_raw[i]
        src_text = ex["de"]
        ref_text = ex["en"]

        src_tokens = tokenize_de(src_text)[: max_len - 2]
        src_ids = [src_stoi[SPECIAL_TOKENS["sos"]]] + [src_stoi.get(t, src_stoi[SPECIAL_TOKENS["unk"]]) for t in src_tokens] + [src_stoi[SPECIAL_TOKENS["eos"]]]
        src_ids = torch.tensor(src_ids, dtype=torch.long)

        pred_ids = greedy_decode(
            model=model,
            src_ids_1d=src_ids,
            src_pad_idx=src_pad_idx,
            trg_sos_idx=trg_sos_idx,
            trg_eos_idx=trg_eos_idx,
            max_len=50,
            device=device,
        )

        pred_tokens = [trg_itos[i] for i in pred_ids]
        pred_tokens = [t for t in pred_tokens if t not in {SPECIAL_TOKENS["sos"], SPECIAL_TOKENS["eos"], SPECIAL_TOKENS["pad"]}]
        preds.append(" ".join(pred_tokens))
        refs.append(ref_text)

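    # predictions come from a lowercased vocab while references are raw text, so compare case-insensitively
    bleu = sacrebleu.corpus_bleu(preds, [refs], lowercase=True).score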
    print(f"BLEU (greedy, first 200 test samples) = {bleu:.2f}")


# ✅ Example run (start with a small number of epochs!)
main(num_epochs=1)


## (Answer key) Check only at the very end

In [None]:
# ===========================
# ANSWERS: only the contents of the blanks (check at the very end)
# ===========================

# --- SelfAttention ---
# [__init__]
# self.d_k = self.d_model // self.h
# self.d_k * self.h == self.d_model
# self.W_V = nn.Linear(self.d_model, self.d_model)
# self.W_K = nn.Linear(self.d_model, self.d_model)
# self.W_Q = nn.Linear(self.d_model, self.d_model)
# self.W_O = nn.Linear(self.d_model, self.d_model)

# [forward]
# N = query.shape[0]
# V = V.reshape(N, value_len, self.h, self.d_k)  # (N, value_len, h, d_k)
# K = K.reshape(N, key_len, self.h, self.d_k)  # (N, key_len, h, d_k)
# Q = Q.reshape(N, query_len, self.h, self.d_k)  # (N, query_len, h, d_k)
# attention_logits_QK = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
# attention_logits_QK = attention_logits_QK.masked_fill(mask == 0, float("-1e20"))
# attention_weights = torch.softmax(attention_logits_QK / (self.d_k ** (1 / 2)), dim=3)
# out_heads = torch.einsum("nhql,nlhd->nqhd", [attention_weights, V])
# out = out_heads.reshape(N, query_len, self.h * self.d_k)


# --- TransformerBlock ---
# [forward]
# attention_residual_add = multihead_attention_output + query
# ffn_residual_add = ffn_linear2_output + x


# --- Encoder ---
# [forward]
# position_ids = torch.arange(0, src_len).to(self.device)
# positions = position_ids.expand(N, src_len)
# out = self.dropout(token_embedding_d_model + positional_embedding_d_model)
# out = layer(out, out, out, src_padding_mask)


# --- DecoderBlock ---
# [forward]
# masked_self_attention_output = self.attention(x, x, x, trg_mask)
# residual_add = masked_self_attention_output + x
# out = self.transformer_block(value, key, query, src_mask)


# --- Decoder ---
# [forward]
# position_ids = torch.arange(0, trg_len).to(self.device)
# positions = position_ids.expand(N, trg_len)
# x = self.dropout(token_embedding_d_model + positional_embedding_d_model)
# x = layer(x, enc_out, enc_out, src_padding_mask, trg_causal_mask)
# out = self.fc_out(x)


# --- Transformer ---
# [make_src_mask]
# src_is_not_pad = (src_token_ids != self.src_pad_idx)
# src_padding_mask = src_is_not_pad.unsqueeze(1)
# src_padding_mask = src_padding_mask.unsqueeze(2)

# [make_trg_mask]
# trg_ones = torch.ones((trg_len, trg_len))
# trg_causal_mask = trg_lower_triangular.expand(N, 1, trg_len, trg_len)

# [forward]
# src_padding_mask = self.make_src_mask(src_token_ids)
# trg_causal_mask = self.make_trg_mask(trg_token_ids)
# enc_src = self.encoder(src_token_ids, src_padding_mask)
# out = self.decoder(trg_token_ids, enc_src, src_padding_mask, trg_causal_mask)
