In [1]:
# transformer_modules.py

import torch
import torch.nn as nn
import math

# 1. Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    The table is precomputed once for ``max_len`` positions and stored as a
    buffer named ``pe`` with shape ``(1, max_len, d_model)``. Even feature
    columns hold ``sin(pos * w_i)`` and odd columns hold ``cos(pos * w_i)``.

    Args:
        d_model: Feature size of the inputs. Odd values are supported.
        max_len: Longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        # One frequency per even column: 10000^(-2i / d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column,
        # so the frequency vector is trimmed to fit. Without the trim the
        # assignment raises a shape-mismatch error. For even d_model the slice
        # is the whole vector, so the result is unchanged.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading axis of size 1 lets the table broadcast over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]

# 2. Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention split across ``num_heads`` heads.

    A single linear layer produces queries, keys and values together. Its
    ``3 * d_model`` outputs are viewed as ``(num_heads, 3 * d_k)``, so each
    head owns one contiguous ``[q | k | v]`` slice. This layout belongs to
    this module. Keep it unchanged if previously trained weights must keep
    working.

    Args:
        d_model: Model width. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x`` of shape ``(B, T, d_model)``.

        Args:
            x: Input tensor of shape ``(B, T, d_model)``.
            mask: Optional tensor broadcastable to ``(B, num_heads, T, T)``.
                A bool mask marks allowed query/key pairs with True and
                blocked pairs with False. A floating-point mask is added to
                the attention scores before the softmax. A query row whose
                keys are all blocked produces NaN. When ``mask`` is None
                every position may attend to every other, which was the only
                behavior before this parameter existed.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.shape
        # (B, T, 3*C) -> (B, T, heads, 3*d_k) -> (B, heads, T, 3*d_k)
        qkv = self.qkv_proj(x).view(B, T, self.num_heads, 3 * self.d_k).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, heads, T, d_k)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, heads, T, T)
        if mask is not None:
            if mask.dtype == torch.bool:
                scores = scores.masked_fill(~mask, float('-inf'))
            else:
                scores = scores + mask
        attn = scores.softmax(dim=-1)
        context = attn @ v  # (B, heads, T, d_k)
        # Merge the heads back into the feature axis: (B, T, heads * d_k) == (B, T, C).
        context = context.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(context)

# 3. Feed Forward Network
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden layer width.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Submodule names and creation order are kept as-is. The names are the
        # state_dict keys, and the order fixes how the RNG is consumed when
        # the weights are initialised.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Map ``x`` (last dim ``d_model``) through the MLP, preserving its shape."""
        hidden = self.linear1(x)
        hidden = self.relu(hidden)
        projected = self.linear2(hidden)
        return projected

# 4. Transformer Block
class TransformerBlock(nn.Module):
    """Pre-norm Transformer layer made of two residual sub-layers.

    ``h = x + attn(norm1(x))`` followed by ``h + ff(norm2(h))``.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        # Construction order matches the original so parameter initialisation
        # draws from the RNG in the same sequence.
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the attention and feed-forward residual sub-layers to ``x``."""
        attended = self.attn(self.norm1(x))
        residual = x + attended
        transformed = self.ff(self.norm2(residual))
        return residual + transformed

# 5. Full Transformer
class MiniTransformer(nn.Module):
    """Token embedding, positional encoding, a stack of pre-norm blocks, then vocab logits.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        d_model: Model width.
        num_heads: Attention heads per block.
        d_ff: Feed-forward hidden width per block.
        num_layers: Number of stacked ``TransformerBlock``s.
        max_len: Longest sequence the positional table supports.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, d_ff=2048, num_layers=6, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        blocks = [TransformerBlock(d_model, num_heads, d_ff) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(d_model)  # final norm, needed because the blocks are pre-norm
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to logits with a trailing ``vocab_size`` dimension."""
        hidden = self.pos_encoder(self.embedding(x))
        for block in self.layers:
            hidden = block(hidden)
        normalized = self.norm(hidden)
        return self.output(normalized)
# --- Smoke test: run one forward pass through MiniTransformer --------------

# Dummy input: a batch of 2 sequences, each holding 10 random token ids.
vocab_size, seq_len, batch_size = 1000, 10, 2
dummy_input = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

# Build the model with its default hyper-parameters.
model = MiniTransformer(vocab_size=vocab_size)

# Forward pass: token ids -> per-position vocabulary logits.
output = model(dummy_input)

# Inspect the result.
print("✅ 出力のshape:", output.shape)  # expected: torch.Size([2, 10, 1000])
print("✅ 出力の一部:", output[0][0][:5])  # first 5 logits for the first token of the first sequence


✅ 出力のshape: torch.Size([2, 10, 1000])
✅ 出力の一部: tensor([-0.2014, -0.6102, -0.1346,  0.7078,  0.0482], grad_fn=<SliceBackward0>)
