In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first embeddings.

    Feature ``2i`` of position ``pos`` receives ``sin(pos * 10000^(-2i/d_model))``
    and feature ``2i + 1`` the matching cosine.

    Args:
        d_model: embedding width of the inputs (odd values are supported).
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half only uses the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Kept in the original (max_len, 1, d_model) layout so that existing
        # state_dicts containing the 'pe' buffer still load.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encoding of each sequence position.

        Args:
            x: tensor of shape (batch, seq_len, d_model) -- the batch-first layout
                produced by ``nn.Embedding`` applied to (batch, seq_len) token ids.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        # pe[:seq_len] is (seq_len, 1, d_model); transposing to (1, seq_len, d_model)
        # makes it broadcast over the batch axis instead of being indexed by it.
        return x + self.pe[:seq_len].transpose(0, 1)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    ``Q``, ``K`` and ``V`` are projected by separate ``nn.Linear(d_model, d_model)``
    layers and split into ``num_heads`` heads of width ``d_k = d_model // num_heads``.
    The heads are attended independently, concatenated, and passed through ``W_o``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``dropout(softmax(Q K^T / sqrt(d_k))) V``.

        Args:
            Q, K, V: per-head tensors whose last two dims are (length, d_k);
                ``forward`` passes (batch, num_heads, length, d_k).
            mask: optional tensor broadcastable to the (..., q_len, k_len) scores;
                positions where ``mask == 0`` are excluded from the softmax.

        Returns:
            ``(output, attn_weights)``; ``attn_weights`` are taken after dropout.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the most negative finite value of the score dtype. A fixed -1e9
            # does not fit in float16 (max magnitude 65504), so masked_fill would
            # raise under half precision or autocast.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, Q, K, V, mask=None):
        """Attend from the ``Q`` positions to the ``K``/``V`` positions.

        Args:
            Q: (batch, q_len, d_model) queries.
            K, V: (batch, k_len, d_model) keys and values (same batch size as Q).
            mask: optional mask broadcastable to (batch, num_heads, q_len, k_len);
                entries equal to 0 are masked out.

        Returns:
            ``(output, attn_weights)`` of shapes (batch, q_len, d_model) and
            (batch, num_heads, q_len, k_len).
        """
        batch_size = Q.size(0)

        # Linear projections, then split d_model into (num_heads, d_k) and move
        # the head axis in front of the sequence axis.
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads back into (batch, q_len, d_model).
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        output = self.W_o(attn_output)
        return output, attn_weights

class PositionwiseFeedForward(nn.Module):
    """Per-position two-layer MLP: ``linear2(dropout(relu(linear1(x))))``.

    ``linear1`` widens the last dimension from ``d_model`` to ``d_ff`` and
    ``linear2`` projects it back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Creation order (linear1 before linear2) fixes both parameter order and
        # the RNG draws used by the default Linear initialisation.
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        expanded = self.linear1(x)
        expanded = self.activation(expanded)
        expanded = self.dropout(expanded)
        return self.linear2(expanded)

class EncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention, then a feed-forward network.

    Each sublayer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sublayer 1: self-attention over the input (attention weights unused).
        attended = self.self_attn(x, x, x, mask)[0]
        hidden = self.norm1(x + self.dropout(attended))

        # Sublayer 2: position-wise feed-forward.
        return self.norm2(hidden + self.dropout(self.feed_forward(hidden)))

class DecoderLayer(nn.Module):
    """Post-norm decoder layer: masked self-attention, encoder-decoder attention,
    then a feed-forward network.

    Each sublayer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.
    ``tgt_mask`` is used by the self-attention and ``src_mask`` by the
    cross-attention over ``enc_output``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, sublayer_out):
        # Residual connection followed by layer normalisation.
        return norm(residual + self.dropout(sublayer_out))

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        self_out = self.self_attn(x, x, x, tgt_mask)[0]
        x = self._add_and_norm(self.norm1, x, self_out)

        cross_out = self.cross_attn(x, enc_output, enc_output, src_mask)[0]
        x = self._add_and_norm(self.norm2, x, cross_out)

        return self._add_and_norm(self.norm3, x, self.feed_forward(x))

class Encoder(nn.Module):
    """Source-side stack: token embedding -> positional encoding -> dropout ->
    ``num_layers`` EncoderLayers, all sharing ``src_mask``."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_seq_length, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(EncoderLayer(d_model, num_heads, d_ff, dropout))
        self.layers = stack
        self.dropout = nn.Dropout(dropout)

    def forward(self, src_tokens, src_mask=None):
        hidden = self.dropout(self.pos_encoding(self.token_embedding(src_tokens)))
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden

class Decoder(nn.Module):
    """Target-side stack: token embedding -> positional encoding -> dropout ->
    ``num_layers`` DecoderLayers, each also receiving ``enc_output`` and both masks."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, max_seq_length, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(DecoderLayer(d_model, num_heads, d_ff, dropout))
        self.layers = stack
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt_tokens, enc_output, src_mask=None, tgt_mask=None):
        hidden = self.token_embedding(tgt_tokens)
        hidden = self.dropout(self.pos_encoding(hidden))
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, src_mask, tgt_mask)
        return hidden

class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping token ids to target-vocabulary logits.

    Token tensors are batch-first ``(batch, seq_len)`` ids. ``forward`` treats
    id 0 as padding (the default ``pad_token`` of the mask helpers).

    Args:
        src_vocab_size, tgt_vocab_size: encoder / decoder vocabulary sizes.
        d_model, num_layers, num_heads, d_ff, dropout: hyperparameters shared
            by the encoder and decoder stacks.
        max_seq_length: longest sequence the positional encodings cover.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_layers=6,
                 num_heads=8, d_ff=2048, max_seq_length=5000, dropout=0.1):
        super().__init__()

        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads,
                               d_ff, max_seq_length, dropout)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads,
                               d_ff, max_seq_length, dropout)
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        This covers weight matrices and embedding tables; 1-D parameters (biases,
        LayerNorm weights) keep their default initialisation.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def create_src_mask(self, src_tokens, pad_token=0):
        """Return a (batch, 1, 1, src_len) bool mask that is True at non-pad tokens."""
        return (src_tokens != pad_token).unsqueeze(1).unsqueeze(2)

    def create_tgt_mask(self, tgt_tokens, pad_token=0):
        """Return a (batch, 1, tgt_len, tgt_len) bool mask for decoder self-attention.

        Query position i may attend to key position j only when j <= i (causal)
        and token j is not ``pad_token``.
        """
        tgt_pad_mask = (tgt_tokens != pad_token).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)
        seq_len = tgt_tokens.size(1)
        # Build the lower-triangular causal mask directly on the tokens' device.
        tgt_sub_mask = torch.tril(torch.ones(seq_len, seq_len, device=tgt_tokens.device)).bool()
        return tgt_pad_mask & tgt_sub_mask

    def forward(self, src_tokens, tgt_tokens):
        """Return logits of shape (batch, tgt_len, tgt_vocab_size).

        ``tgt_tokens`` is the decoder input (e.g. the shifted target for
        teacher forcing); both inputs use id 0 as padding.
        """
        src_mask = self.create_src_mask(src_tokens)
        tgt_mask = self.create_tgt_mask(tgt_tokens)

        enc_output = self.encoder(src_tokens, src_mask)
        dec_output = self.decoder(tgt_tokens, enc_output, src_mask, tgt_mask)

        return self.output_layer(dec_output)

    def generate(self, src_tokens, max_length=50, start_token=1, end_token=2, pad_token=0):
        """Greedily decode a batch of target sequences.

        Args:
            src_tokens: (batch, src_len) source token ids.
            max_length: maximum number of tokens generated after ``start_token``.
            start_token: id placed at position 0 of every output row.
            end_token: id that finishes a row; decoding stops early once every
                row has produced it.
            pad_token: padding id used for the masks and to fill the positions of
                rows that have already finished while others are still decoding.

        Returns:
            (batch, generated_len) long tensor, each row starting with ``start_token``.

        Decoding runs in eval mode under ``torch.no_grad()``. The module's previous
        train/eval mode is restored before returning, so calling this between
        training steps does not leave dropout switched off.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Encode the source sequence once.
                src_mask = self.create_src_mask(src_tokens, pad_token)
                enc_output = self.encoder(src_tokens, src_mask)

                batch_size = src_tokens.size(0)
                device = src_tokens.device
                # Every row starts with start_token.
                generated = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
                finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

                for _ in range(max_length):
                    tgt_mask = self.create_tgt_mask(generated, pad_token)
                    dec_output = self.decoder(generated, enc_output, src_mask, tgt_mask)
                    logits = self.output_layer(dec_output[:, -1, :])  # (batch, vocab)

                    next_token = logits.argmax(-1)  # (batch,)
                    # Rows that already emitted end_token keep emitting padding.
                    next_token = next_token.masked_fill(finished, pad_token)
                    generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)

                    finished |= next_token == end_token
                    if finished.all():
                        break
            return generated
        finally:
            self.train(was_training)

In [2]:
def test_transformer():
    """Smoke-test the Transformer with one teacher-forced forward pass on a toy batch.

    Builds a model (d_model=512, 6 layers, 8 heads, d_ff=2048, vocab 1000) and
    prints the input/target/output shapes and the parameter count. Asserts that
    the logits have shape (batch, decoder_len, tgt_vocab_size), then returns the model.
    """
    # Hyperparameters
    src_vocab_size = 1000  # source-language vocabulary size
    tgt_vocab_size = 1000  # target-language vocabulary size
    d_model = 512
    num_layers = 6
    num_heads = 8
    d_ff = 2048
    max_seq_length = 100

    # Build the model
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        max_seq_length=max_seq_length
    )

    # Toy batch (batch_size=2, seq_len=10); trailing zeros are padding (id 0)
    src_tokens = torch.tensor([
        [1, 2, 3, 4, 5, 0, 0, 0, 0, 0],
        [6, 7, 8, 9, 10, 11, 12, 0, 0, 0]
    ])

    tgt_tokens = torch.tensor([
        [1, 2, 3, 4, 0, 0, 0, 0, 0, 0],
        [5, 6, 7, 8, 9, 10, 0, 0, 0, 0]
    ])

    # Teacher forcing: the decoder input is the target without its last position
    decoder_input = tgt_tokens[:, :-1]
    output = model(src_tokens, decoder_input)

    print(f"Input shape: {src_tokens.shape}")
    print(f"Target shape: {tgt_tokens.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Logits must cover every decoder position and every target-vocabulary entry.
    expected_shape = (tgt_tokens.size(0), decoder_input.size(1), tgt_vocab_size)
    assert tuple(output.shape) == expected_shape, (
        f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
    )

    return model

# Run the smoke test and keep the returned model
model = test_transformer()

Input shape: torch.Size([2, 10])
Target shape: torch.Size([2, 10])
Output shape: torch.Size([2, 9, 1000])
Model parameters: 45,675,496


In [5]:
def train_transformer_example(model=None, num_steps=3, batch_size=32, src_len=20, tgt_len=15):
    """Run a few optimisation steps on random token data (a smoke test, not real training).

    Args:
        model: Transformer to train. When None, one is built with
            ``test_transformer()``, which also prints its smoke-test report.
        num_steps: number of optimizer steps; each is reported as an "Epoch".
        batch_size, src_len, tgt_len: shape of the random batch drawn each step.
    """
    if model is None:
        model = test_transformer()

    pad_token = 0
    # Read vocabulary sizes from the model so random ids are always in range.
    src_vocab_size = model.encoder.token_embedding.num_embeddings
    tgt_vocab_size = model.output_layer.out_features

    # Padding targets (id 0) do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

    model.train()
    for epoch in range(num_steps):
        # Random data stands in for a DataLoader. Ids are drawn from [1, vocab) so
        # no position is accidentally treated as padding by the masks or the loss.
        src = torch.randint(pad_token + 1, src_vocab_size, (batch_size, src_len))
        tgt = torch.randint(pad_token + 1, tgt_vocab_size, (batch_size, tgt_len))

        optimizer.zero_grad()
        # Teacher forcing: feed tgt[:, :-1] and predict tgt[:, 1:].
        output = model(src, tgt[:, :-1])

        loss = criterion(output.reshape(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Run the random-data training example; it builds its own model via test_transformer()
train_transformer_example()

Input shape: torch.Size([2, 10])
Target shape: torch.Size([2, 10])
Output shape: torch.Size([2, 9, 1000])
Model parameters: 45,675,496
Epoch 1, Loss: 7.2894
Epoch 2, Loss: 7.2009
Epoch 3, Loss: 7.3108
