In [2]:
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.nn.functional as F

##############################################
# Utility: Tokenizer and vocabulary creation #
##############################################
# Source and target example sentences
# Parallel English/French example pair; tokens are separated by single spaces.
sentence_en = "I love AI ."
sentence_fr = "J' adore l'IA ."

# Vocabularies map each token to an integer id ("<pad>" is id 0).
# Every whitespace-separated token of the sentences above must appear here;
# extend the token lists whenever the sentences change.
word_map_en = {token: idx for idx, token in enumerate(["<pad>", "I", "love", "AI", "."])}
word_map_fr = {token: idx for idx, token in enumerate(["<pad>", "J'", "adore", "l'IA", "."])}

def tokenize(sentence, word_map):
    """Convert a whitespace-separated sentence into a 1-D tensor of token ids.

    Args:
        sentence: Text whose tokens are separated by whitespace.
        word_map: Mapping from token string to integer id.

    Returns:
        A ``torch.long`` tensor of shape ``(num_tokens,)``. An empty (or
        all-whitespace) sentence yields an empty ``torch.long`` tensor.

    Raises:
        KeyError: If one or more tokens are missing from ``word_map``; the
            message lists every missing token.
    """
    words = sentence.split()
    unknown = [word for word in words if word not in word_map]
    if unknown:
        raise KeyError(f"Tokens not found in vocabulary: {unknown}")

    tokens = [word_map[word] for word in words]
    print(f"Tokens for '{sentence}': {tokens}")
    # Explicit dtype: torch.tensor([]) would otherwise infer float32, which
    # nn.Embedding does not accept as an index tensor.
    return torch.tensor(tokens, dtype=torch.long)

##############################################
# Positional Encoding (used by both encoder and decoder)
##############################################
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once: even
    feature columns hold ``sin(pos * freq)`` and odd feature columns hold
    ``cos(pos * freq)``. Both even and odd ``d_model`` are supported.

    Args:
        d_model: Size of the feature (last) dimension of the embeddings.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) *
                             -(torch.log(torch.tensor(10000.0)) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        encoding[:, 0::2] = torch.sin(angles)  # Even indices
        # With an odd d_model there is one fewer odd column than even columns,
        # so the last frequency is dropped for the cosine half. For an even
        # d_model the slice keeps every column and nothing changes.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # Odd indices

        # Register as buffer so it's saved in state_dict but not updated by optimizer.
        self.register_buffer("pe", encoding.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``. The encoding slice and the
            summed result are printed as a side effect.
        """
        pos_enc = self.pe[:, :x.size(1)]
        print("\nPositional Encoding Values:")
        print(pos_enc)
        x = x + pos_enc
        print("\nEmbeddings After Adding Positional Encoding:")
        print(x)
        return x

##############################################
# Multi-Head Attention (supports separate Q/K/V for both encoder and decoder)
##############################################
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with verbose tracing.

    The same module serves self-attention (pass only ``query``) and
    cross-attention (pass ``key``/``value`` taken from another sequence).
    Every intermediate tensor is printed so the computation can be followed
    step by step.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        head_dim = d_model // num_heads

        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = head_dim
        self.d_v = head_dim

        # Input projections. They are created in Q, K, V order so that a fixed
        # seed always produces the same initial weights.
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are concatenated.
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, query, key=None, value=None, mask=None):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(batch, seq_len_q, d_model)``.
            key: Tensor of shape ``(batch, seq_len_k, d_model)``; defaults to
                ``query`` when omitted.
            value: Tensor reshaped with the key's sequence length; defaults
                to ``query`` when omitted.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len_q, seq_len_k)``. Positions where
                it equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, seq_len_q, d_model)``.
        """
        # Self-attention shortcut: anything not supplied falls back to query.
        key = query if key is None else key
        value = query if value is None else value

        batch_size = query.size(0)
        len_q, len_k = query.size(1), key.size(1)

        def split_heads(projected, length, head_dim):
            # (batch, length, d_model) -> (batch, num_heads, length, head_dim)
            return projected.view(batch_size, length, self.num_heads, head_dim).transpose(1, 2)

        q = split_heads(self.linear_q(query), len_q, self.d_k)
        k = split_heads(self.linear_k(key), len_k, self.d_k)
        v = split_heads(self.linear_v(value), len_k, self.d_v)

        for banner, tensor in (
            ("\nQuery Matrix (Q):", q),
            ("\nKey Matrix (K):", k),
            ("\nValue Matrix (V):", v),
        ):
            print(banner)
            print(tensor)

        # Similarity of every query with every key, scaled by sqrt(d_k).
        scale = torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        print("\nRaw Attention Scores (before softmax):")
        print(attn_scores)

        if mask is not None:
            # -inf scores become zero weight after the softmax.
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        print("\nAttention Weights (after softmax):")
        print(attn_weights)

        # Weighted sum of values, then merge heads back into d_model features.
        per_head = torch.matmul(attn_weights, v)
        merged = per_head.transpose(1, 2).contiguous().view(batch_size, len_q, self.d_model)
        print("\nAttention Output (before final FC layer):")
        print(merged)

        output = self.fc(merged)
        print("\nFinal Output from Multi-Head Attention (after FC):")
        print(output)

        return output

##############################################
# Feed-Forward Network
##############################################
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    Args:
        d_model: Feature size of the input and of the output.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff=512):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand to the hidden width
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to d_model

    def forward(self, x):
        """Apply the block to ``x`` and print the tensor after every stage."""
        stages = (
            ("\nAfter first linear layer (fc1):", self.fc1),
            ("\nAfter ReLU Activation:", F.relu),
            ("\nAfter second linear layer (fc2):", self.fc2),
        )

        print("\nInput to FeedForward Network:")
        print(x)
        for banner, stage in stages:
            x = stage(x)
            print(banner)
            print(x)
        return x

##############################################
# Encoder Layer (one layer for simplicity)
##############################################
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``, i.e. the
    residual sum is normalised after the sub-layer runs.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff=512):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feedforward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)  # after the attention residual
        self.norm2 = nn.LayerNorm(d_model)  # after the feed-forward residual

    def forward(self, x, mask=None):
        """Run the layer on ``x``; ``mask`` is forwarded to self-attention."""
        print("\n=== Encoder Layer Input ===")
        print(x)

        # Sub-layer 1: self-attention with residual connection and norm.
        attended = self.norm1(x + self.self_attn(x, mask=mask))
        print("\n=== After Self-Attention and Norm (Encoder) ===")
        print(attended)

        # Sub-layer 2: feed-forward with residual connection and norm.
        result = self.norm2(attended + self.feedforward(attended))
        print("\n=== After FeedForward and Norm (Encoder) ===")
        print(result)
        return result

##############################################
# Decoder Layer (with masked self-attention and cross-attention)
##############################################
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        d_model: Feature size of the layer input and output.
        num_heads: Number of attention heads in both attention sub-layers.
        d_ff: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff=512):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feedforward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)  # after self-attention
        self.norm2 = nn.LayerNorm(d_model)  # after cross-attention
        self.norm3 = nn.LayerNorm(d_model)  # after feed-forward

    def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
        """Run the layer.

        Args:
            x: Decoder-side input tensor.
            encoder_output: Encoder result, used as keys and values of the
                cross-attention.
            self_mask: Mask for the self-attention (e.g. a causal mask).
            cross_mask: Mask for the cross-attention.

        Returns:
            The transformed decoder tensor.
        """
        print("\n=== Decoder Layer Input ===")
        print(x)

        # Step 1: self-attention over the decoder sequence; self_mask decides
        # which positions each query may see.
        after_self = self.norm1(x + self.self_attn(x, x, x, mask=self_mask))
        print("\n=== After Masked Self-Attention and Norm (Decoder) ===")
        print(after_self)

        # Step 2: cross-attention -- decoder states query the encoder output.
        attended_memory = self.cross_attn(after_self, encoder_output, encoder_output, mask=cross_mask)
        after_cross = self.norm2(after_self + attended_memory)
        print("\n=== After Cross-Attention and Norm (Decoder) ===")
        print(after_cross)

        # Step 3: position-wise feed-forward network.
        result = self.norm3(after_cross + self.feedforward(after_cross))
        print("\n=== After FeedForward and Norm (Decoder) ===")
        print(result)
        return result

##############################################
# Full encoder and decoder stacks (embedding + positional encoding + N layers)
##############################################
class Encoder(nn.Module):
    """Encoder stack: token embedding, positional encoding, then N layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding / feature size.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each layer's feed-forward network.
        num_layers: Number of stacked ``EncoderLayer`` modules.
        max_len: Number of positions precomputed by the positional encoding.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        stack = []
        for _ in range(num_layers):
            stack.append(EncoderLayer(d_model, num_heads, d_ff))
        self.layers = nn.ModuleList(stack)

    def forward(self, x, mask=None):
        """Encode a batch of token ids; ``mask`` is passed to every layer."""
        print("\n=== Encoder Input Tokens ===")
        print(x)

        embedded = self.embedding(x)
        print("\n=== Token Embeddings ===")
        print(embedded)

        hidden = self.pos_encoding(embedded)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)

        print("\n=== Final Encoder Output ===")
        print(hidden)
        return hidden

class Decoder(nn.Module):
    """Decoder stack: embedding, positional encoding, N layers, vocab projection.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output logits per position.
        d_model: Embedding / feature size.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each layer's feed-forward network.
        num_layers: Number of stacked ``DecoderLayer`` modules.
        max_len: Number of positions precomputed by the positional encoding.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        stack = []
        for _ in range(num_layers):
            stack.append(DecoderLayer(d_model, num_heads, d_ff))
        self.layers = nn.ModuleList(stack)
        # Maps each d_model-sized decoder state to one score per vocabulary entry.
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, encoder_output, self_mask=None, cross_mask=None):
        """Decode target token ids against ``encoder_output``.

        Returns:
            Unnormalised logits whose last dimension is ``vocab_size``.
        """
        print("\n=== Decoder Input Tokens ===")
        print(x)

        embedded = self.embedding(x)
        print("\n=== Decoder Token Embeddings ===")
        print(embedded)

        hidden = self.pos_encoding(embedded)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, self_mask, cross_mask)

        logits = self.fc_out(hidden)
        print("\n=== Final Decoder Output (logits) ===")
        print(logits)
        return logits

##############################################
# Function to generate a causal mask for decoder self-attention.
##############################################
def generate_square_subsequent_mask(sz, device=None):
    """Build a causal (lower-triangular) attention mask.

    The mask uses a keep/drop convention: 1 marks a position that may be
    attended to and 0 marks a blocked position. This differs from
    ``torch.nn.Transformer.generate_square_subsequent_mask``, which returns
    an additive mask filled with ``-inf`` above the diagonal.

    Args:
        sz: Sequence length; the mask covers ``sz`` query and ``sz`` key
            positions.
        device: Optional device on which to allocate the mask. Pass the
            device of the attention scores when the model does not run on
            the default device. ``None`` keeps PyTorch's default placement.

    Returns:
        Float tensor of shape ``(1, 1, sz, sz)`` where entry ``[i, j]`` is 1
        when ``j <= i`` and 0 otherwise. The two leading singleton axes
        broadcast over batch and heads.
    """
    lower_triangle = torch.tril(torch.ones(sz, sz, device=device))
    return lower_triangle.unsqueeze(0).unsqueeze(0)  # shape: (1, 1, sz, sz)

##############################################
# Full Transformer Model combining the Encoder and Decoder
##############################################
class Transformer(nn.Module):
    """Encoder-decoder model assembled from ``Encoder`` and ``Decoder``.

    Args:
        src_vocab_size: Vocabulary size for the encoder embedding.
        tgt_vocab_size: Vocabulary size for the decoder embedding and logits.
        d_model: Feature size shared by both stacks.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of the feed-forward networks.
        num_enc_layers: Number of encoder layers.
        num_dec_layers: Number of decoder layers.
        max_len: Number of positions precomputed by the positional encodings.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=8, num_heads=2, d_ff=32, num_enc_layers=1, num_dec_layers=1, max_len=5000):
        super().__init__()
        # The encoder is built before the decoder so that seeded weight
        # initialisation stays the same from run to run.
        self.encoder = Encoder(
            src_vocab_size, d_model, num_heads, d_ff, num_enc_layers, max_len
        )
        self.decoder = Decoder(
            tgt_vocab_size, d_model, num_heads, d_ff, num_dec_layers, max_len
        )

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, cross_mask=None):
        """Encode ``src`` and decode ``tgt`` against the result.

        ``src_mask`` goes to the encoder; ``tgt_mask`` and ``cross_mask`` go
        to the decoder's self-attention and cross-attention respectively.
        Returns the decoder's logits.
        """
        memory = self.encoder(src, src_mask)
        return self.decoder(tgt, memory, self_mask=tgt_mask, cross_mask=cross_mask)

##############################################
# Example Usage: Run full Transformer on tokenized sentences
##############################################
if __name__ == "__main__":
    # Seed first so the randomly initialised weights (and therefore every
    # printed tensor) are the same on each run.
    torch.manual_seed(0)

    # Tokenize both sentences and add a leading batch axis:
    # (seq_len,) -> (1, seq_len), e.g. [[1, 2, 3, 4]].
    src_tokens = tokenize(sentence_en, word_map_en).unsqueeze(0)
    tgt_tokens = tokenize(sentence_fr, word_map_fr).unsqueeze(0)

    # Deliberately tiny hyperparameters so the printed tensors stay readable.
    d_model, num_heads, d_ff = 8, 2, 32
    num_enc_layers = num_dec_layers = 1

    transformer_model = Transformer(
        src_vocab_size=len(word_map_en),
        tgt_vocab_size=len(word_map_fr),
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_enc_layers=num_enc_layers,
        num_dec_layers=num_dec_layers,
    )

    # Causal mask for decoder self-attention: target position i may attend
    # to itself and to earlier positions only.
    tgt_seq_len = tgt_tokens.size(1)
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)

    # No masking is applied in the encoder or in cross-attention.
    src_mask = cross_mask = None

    # NOTE(review): the decoder receives the full target sequence as-is (no
    # start-of-sequence shift); fine for inspecting tensors, but confirm
    # before reusing this as a training setup.
    # The forward pass prints every intermediate tensor along the way.
    output_logits = transformer_model(
        src_tokens,
        tgt_tokens,
        src_mask=src_mask,
        tgt_mask=tgt_mask,
        cross_mask=cross_mask,
    )

    print("\n=== Final Output Logits from Transformer ===")
    print(output_logits)

Tokens for 'I love AI .': [1, 2, 3, 4]
Tokens for 'J' adore l'IA .': [1, 2, 3, 4]

=== Encoder Input Tokens ===
tensor([[1, 2, 3, 4]])

=== Token Embeddings ===
tensor([[[ 0.3223, -1.2633,  0.3500,  0.3081,  0.1198,  1.2377,  1.1168,
          -0.2473],
         [-1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,
           1.8530],
         [-0.2159, -0.7425,  0.5627,  0.2596, -0.1740, -0.6787,  0.9383,
           0.4889],
         [ 1.2032,  0.0845, -1.2001, -0.0048, -0.5181, -0.3067, -1.5810,
           1.7066]]], grad_fn=<EmbeddingBackward0>)

Positional Encoding Values:
tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e

In [4]:
# Shell escape: print the NVIDIA driver version and current GPU status for this machine.
!nvidia-smi

Sat Feb  8 07:25:48 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        Off | 00000000:01:00.0 Off |                  N/A |
|  0%   47C    P8               9W / 320W |   6040MiB / 16376MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    