<a href="https://colab.research.google.com/github/VladimirApter/ml4se_course_2025/blob/main/homeworks/ml4se_course_hw4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Модель: ChatGPT 5**

**Промпт:**
Реализуй transformer Encoder с помощью PyTorch

In [14]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------
# Positional Encoding
# ---------------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sines and
    odd feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature (last) dimension. May be odd.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the cosine half only takes the first d_model // 2 frequencies.
        # For an even d_model this slice is the whole `angles` tensor.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer follows the module across devices and is saved in the
        # state dict, but is not returned by parameters().
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension equals ``d_model``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the addition fails with an opaque broadcast
            # error, or silently broadcasts when max_len == 1.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :]


# ---------------------------
# Scaled Dot-Product Attention
# ---------------------------
def attention(q, k, v):
    """Compute scaled dot-product attention.

    The similarity logits ``q @ k^T`` are divided by the square root of the
    size of ``q``'s last dimension, normalised with a softmax over the last
    axis, and used to take a weighted sum of ``v``.

    Args:
        q: Query tensor; its last two dimensions are used as (length, feature).
        k: Key tensor; its last two dimensions are swapped before the product.
        v: Value tensor multiplied by the attention weights.

    Returns:
        A ``(output, weights)`` tuple: the weighted values and the softmax
        attention weights.
    """
    scale = math.sqrt(q.size(-1))
    logits = (q @ k.transpose(-2, -1)) / scale
    weights = logits.softmax(dim=-1)
    return weights @ v, weights


# ---------------------------
# Multi-Head Attention
# ---------------------------
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        d_model: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Explicit checks instead of `assert`: assertions are stripped when
        # Python runs with -O, and num_heads == 0 would otherwise surface as
        # a bare ZeroDivisionError.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # One linear layer produces queries, keys and values in a single pass.
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            A ``(output, attn)`` tuple: the projected attention output with
            the same shape as ``x``, and the per-head attention weights
            returned by ``attention``.
        """
        batch, seq_len, d_model = x.size()
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        def split_heads(t):
            # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        out, attn = attention(q, k, v)
        # Undo the head split; transpose makes the tensor non-contiguous, so
        # it has to be made contiguous before `view`.
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.out(out), attn


# ---------------------------
# Feed-Forward Network
# ---------------------------
class FeedForward(nn.Module):
    """Two-layer feed-forward network with a ReLU in between.

    Maps ``d_model`` features to ``hidden_dim`` and back to ``d_model``.

    Args:
        d_model: Size of the input and output feature dimension.
        hidden_dim: Size of the intermediate feature dimension.
    """

    def __init__(self, d_model, hidden_dim=128):
        super().__init__()
        # The layers are created in the order they are applied, so parameter
        # initialisation draws from the RNG in that order.
        expand = nn.Linear(d_model, hidden_dim)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_dim, d_model)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Run ``x`` through the linear -> ReLU -> linear stack."""
        transformed = self.net(x)
        return transformed


# ---------------------------
# Transformer Encoder Layer
# ---------------------------
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-layer is wrapped in a residual connection, and layer
    normalisation is applied after each residual sum.

    Args:
        d_model: Size of the feature dimension.
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        hidden_dim: Hidden width passed to ``FeedForward``.
    """

    def __init__(self, d_model, num_heads, hidden_dim=128):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, hidden_dim=hidden_dim)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return the block output for ``x``; the attention weights are discarded."""
        attended, _ = self.mha(x)
        x = self.norm1(x + attended)
        transformed = self.ffn(x)
        return self.norm2(x + transformed)


# ---------------------------
# Transformer Encoder
# ---------------------------
class TransformerEncoder(nn.Module):
    """Token embedding, positional encoding and a stack of encoder layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Size of the embedding / feature dimension.
        num_heads: Number of attention heads in every layer.
        num_layers: Number of stacked ``EncoderLayer`` blocks.
        max_len: Number of positions precomputed by ``PositionalEncoding``.
        hidden_dim: Hidden width of the feed-forward network in every layer.
            Defaults to 128, the value the layers used before this parameter
            was exposed.
    """

    def __init__(self, vocab_size, d_model=32, num_heads=4, num_layers=2, max_len=100,
                 hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids, indexed as ``(batch, seq_len)``.

        Returns:
            The output of the last encoder layer, with ``d_model`` features
            per token.
        """
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x)
        return x


if __name__ == "__main__":
    # Demo: push one random batch through a default-sized encoder and print
    # the resulting tensor.
    batch_size, seq_len = 2, 10
    vocab_size = 50

    # Random token ids, just for the example.
    x = torch.randint(0, vocab_size, (batch_size, seq_len))

    encoder = TransformerEncoder(vocab_size)
    out = encoder(x)
    print(out)


tensor([[[ 1.2748,  1.5467,  1.0017,  1.0257,  0.4898,  0.7739, -1.0106,
          -0.7630, -1.3428,  0.9685, -0.4474,  0.6119, -0.6583,  0.2419,
          -2.1787,  1.1806,  1.2875,  1.0062, -0.2527, -1.3480, -0.6815,
          -0.0757, -0.9921, -0.1013, -0.2332,  1.8357, -0.3432, -0.2364,
          -1.2520,  0.3052, -1.4978, -0.1353],
         [ 0.9318,  0.0581, -1.1232,  0.8338,  0.7476, -0.3701, -0.8286,
          -1.0097, -1.1500,  1.1058, -1.4011,  0.4856, -0.5064,  1.8823,
          -1.1411, -0.4995, -0.9857,  1.1176,  0.3492, -0.6444, -0.2599,
          -0.3085,  0.6084, -0.7561,  0.7005,  1.0891,  0.1033,  1.2254,
          -2.3769,  1.9839, -0.0472,  0.1860],
         [ 1.9675, -0.5258, -0.8209, -0.8997, -0.2780,  1.1369, -1.1226,
           0.5430,  1.1129,  1.1498, -3.1179, -0.2546,  0.3334,  1.0349,
          -1.5578, -0.1944,  0.6038, -0.7521, -0.2527, -0.9540,  0.5158,
          -0.0922,  0.2762,  0.4092,  0.0435,  1.0747,  0.2018, -0.6792,
          -0.5334, -0.6019,  0