# Transformer Full

# In[2]:
import torch
import torch.nn as nn
# import PositionalEncoding, TransformerEncoder, TransformerDecoder

class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping token ids to vocabulary logits.

    One ``nn.Embedding`` table and one positional-encoding module are shared
    by the source and target sides, so both sequences must use the same
    ``vocab_size`` vocabulary.

    Args:
        vocab_size: Number of token ids; sizes both the shared embedding table
            and the final output projection.
        d_model: Embedding width, also forwarded to the encoder/decoder stacks.
        num_heads: Forwarded to the encoder and decoder stacks.
        ffn_hidden: Forwarded to the encoder and decoder stacks.
        num_layers: Forwarded to both stacks (same depth for each).
        max_len: Passed to ``PositionalEncoding`` as ``max_len``; presumably the
            longest sequence it can encode — confirm against that class.

    NOTE(review): token embeddings are not multiplied by ``sqrt(d_model)``
    before positional encoding is applied.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, ffn_hidden=2048, num_layers=8, max_len=5000):
        super(Transformer, self).__init__()

        # Submodules are created in this exact order so parameter registration
        # (and therefore seeded initialisation / state_dict order) is stable.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=max_len)
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, ffn_hidden)
        self.decoder = TransformerDecoder(num_layers, d_model, num_heads, ffn_hidden)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def _embed(self, token_ids):
        """Embed ``token_ids`` with the shared table, then add positional encoding."""
        return self.positional_encoding(self.token_embedding(token_ids))

    def forward(self, src_tokens, tgt_tokens, look_ahead_mask=None):
        """Run the encoder-decoder pass and return vocabulary logits.

        Args:
            src_tokens: Source token ids, ``(src_seq_len,)`` or
                ``(batch, src_seq_len)``; embedded to ``(..., d_model)``.
            tgt_tokens: Target token ids, ``(tgt_seq_len,)`` or
                ``(batch, tgt_seq_len)``; embedded to ``(..., d_model)``.
            look_ahead_mask: Handed to the decoder unchanged. This method does
                not build a causal mask itself — callers that need one must
                supply it (how ``None`` is treated depends on the decoder).

        Returns:
            ``output_layer`` applied to the decoder output, i.e. the decoder's
            leading dimensions with a final axis of size ``vocab_size``.
        """
        # Keep this order (source embed, target embed, encoder, decoder):
        # if any submodule uses dropout, RNG consumption depends on it.
        src_embed = self._embed(src_tokens)
        tgt_embed = self._embed(tgt_tokens)
        memory = self.encoder(src_embed)
        hidden = self.decoder(tgt_embed, memory, look_ahead_mask)
        return self.output_layer(hidden)


# In[3]:
# ------------------------
# Full Transformer Model with Embeddings
# ------------------------
class TransformerWithEmbeddings(nn.Module):
    """Encoder-decoder Transformer with separate source/target vocabularies.

    Token ids are embedded with per-side ``nn.Embedding`` tables, multiplied by
    ``sqrt(d_model)``, passed through one shared positional-encoding module,
    and run through an encoder stack and a decoder stack. A final linear layer
    projects decoder states to target-vocabulary logits.

    Args:
        vocab_size_src: Size of the source embedding table.
        vocab_size_tgt: Size of the target embedding table and of the logits.
        d_model: Embedding width, also forwarded to the encoder/decoder stacks.
        num_heads: Forwarded to the encoder and decoder stacks.
        ffn_hidden: Forwarded to the encoder and decoder stacks.
        num_layers: Forwarded to both stacks (same depth for each).
        max_len: Passed to ``PositionalEncoding``; presumably the longest
            sequence it can encode — confirm against that class.
    """

    def __init__(self, vocab_size_src, vocab_size_tgt, d_model=512, num_heads=8, ffn_hidden=2048, num_layers=6, max_len=100):
        super().__init__()

        self.d_model = d_model

        # Embedding layers for source and target
        self.src_embedding = nn.Embedding(vocab_size_src, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size_tgt, d_model)

        # Shared positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        # Encoder/decoder stacks are built directly rather than by wrapping
        # ``Transformer``: that class requires ``vocab_size``, embeds its inputs
        # itself (``nn.Embedding`` rejects the float tensors produced in
        # ``forward`` below), and already returns vocab-sized logits that
        # ``output_linear`` (expecting ``d_model`` features) could not consume.
        self.encoder = TransformerEncoder(num_layers, d_model, num_heads, ffn_hidden)
        self.decoder = TransformerDecoder(num_layers, d_model, num_heads, ffn_hidden)

        # Final linear layer to project decoder output to vocabulary
        self.output_linear = nn.Linear(d_model, vocab_size_tgt)

    def forward(self, src_ids, tgt_ids):
        """Return target-vocabulary logits for a source/target id pair.

        Args:
            src_ids: Source token ids, ``(src_seq_len,)`` or ``(batch, src_seq_len)``.
            tgt_ids: Target token ids, ``(tgt_seq_len,)`` or ``(batch, tgt_seq_len)``.

        Returns:
            ``output_linear`` applied to the decoder output: the decoder's
            leading dimensions with a final axis of size ``vocab_size_tgt``.
        """
        # Plain Python float: avoids allocating a CPU tensor on every call.
        scale = self.d_model ** 0.5
        src = self.pos_encoding(self.src_embedding(src_ids) * scale)
        tgt = self.pos_encoding(self.tgt_embedding(tgt_ids) * scale)

        # The sequence axis is the last axis of the id tensor for both (seq,)
        # and (batch, seq) input; ``tgt.size(0)`` would be the batch size in
        # the batched case. Build the mask on the inputs' device.
        # NOTE(review): assumes the decoder broadcasts a (seq, seq) mask over
        # the batch — confirm against TransformerDecoder.
        look_ahead_mask = self.generate_look_ahead_mask(tgt_ids.size(-1), device=tgt_ids.device)

        encoder_output = self.encoder(src)
        decoder_output = self.decoder(tgt, encoder_output, look_ahead_mask)

        return self.output_linear(decoder_output)

    def generate_look_ahead_mask(self, size, device=None):
        """Return a ``(size, size)`` boolean causal mask.

        ``True`` (lower triangle, diagonal included) marks positions a query
        may attend to; ``False`` marks future positions.
        NOTE(review): this is the opposite of ``nn.MultiheadAttention``'s
        boolean ``attn_mask`` convention (True = blocked) — confirm it matches
        what ``TransformerDecoder`` expects.

        Args:
            size: Target sequence length.
            device: Device to allocate the mask on; ``None`` uses torch's
                current default device.
        """
        return torch.tril(torch.ones(size, size, device=device)).bool()