In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import math

##############################
# 1. Helper Functions & Dataset
##############################

def build_vocab(sentences):
    words = set()
    for sentence in sentences:
        words.update(sentence.lower().split())
    vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
    for word in sorted(words):
        vocab[word] = len(vocab)
    return vocab

def tokenize_sentence(sentence, vocab, max_len):
    tokens = sentence.lower().split()
    tokens = ["<sos>"] + tokens + ["<eos>"]
    token_ids = [vocab.get(token, vocab["<unk>"]) for token in tokens]
    if len(token_ids) < max_len:
        token_ids += [vocab["<pad>"]] * (max_len - len(token_ids))
    else:
        token_ids = token_ids[:max_len]
    return token_ids

dummy_data = [
    {"src": "I am a student", "tgt": "Je suis un étudiant"},
    {"src": "Hello world", "tgt": "Bonjour le monde"},
    {"src": "Good morning", "tgt": "Bonjour"},
    {"src": "How are you", "tgt": "Comment ça va"},
]

src_sentences = [item["src"] for item in dummy_data]
tgt_sentences = [item["tgt"] for item in dummy_data]
src_vocab = build_vocab(src_sentences)
tgt_vocab = build_vocab(tgt_sentences)

# Reverse mapping for decoding target tokens back to words
inv_tgt_vocab = {v: k for k, v in tgt_vocab.items()}

class TranslationDataset(Dataset):
    def __init__(self, data, src_vocab, tgt_vocab, max_len):
        self.data = data
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src_sentence = self.data[idx]["src"]
        tgt_sentence = self.data[idx]["tgt"]
        src_ids = tokenize_sentence(src_sentence, self.src_vocab, self.max_len)
        tgt_ids = tokenize_sentence(tgt_sentence, self.tgt_vocab, self.max_len)
        return torch.tensor(src_ids, dtype=torch.long), torch.tensor(tgt_ids, dtype=torch.long)

max_len = 10
dataset = TranslationDataset(dummy_data, src_vocab, tgt_vocab, max_len)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

##############################
# 2. Transformer Model Components
##############################

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, heads, dropout, forward_expansion):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, forward_expansion * embed_dim),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_dim, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, src, src_mask=None):
        attn_output, _ = self.self_attn(src, src, src, attn_mask=src_mask)
        src = self.norm1(src + self.dropout(attn_output))
        ff_output = self.feed_forward(src)
        src = self.norm2(src + self.dropout(ff_output))
        return src

class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_dim, heads, dropout, forward_expansion):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout)
        self.enc_dec_attn = nn.MultiheadAttention(embed_dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, forward_expansion * embed_dim),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_dim, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        self_attn_output, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = self.norm1(tgt + self.dropout(self_attn_output))
        enc_dec_output, _ = self.enc_dec_attn(tgt, memory, memory, attn_mask=memory_mask)
        tgt = self.norm2(tgt + self.dropout(enc_dec_output))
        ff_output = self.feed_forward(tgt)
        tgt = self.norm3(tgt + self.dropout(ff_output))
        return tgt

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim, num_encoder_layers,
                 num_decoder_layers, heads, dropout, forward_expansion, max_len=100):
        super(Transformer, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim, max_len)
        self.pos_decoder = PositionalEncoding(embed_dim, max_len)
        
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, heads, dropout, forward_expansion)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, heads, dropout, forward_expansion)
            for _ in range(num_decoder_layers)
        ])
        self.fc_out = nn.Linear(embed_dim, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        src = self.src_embedding(src)       # (batch_size, src_seq_len, embed_dim)
        tgt = self.tgt_embedding(tgt)       # (batch_size, tgt_seq_len, embed_dim)
        src = self.pos_encoder(src)
        tgt = self.pos_decoder(tgt)
        src = self.dropout(src)
        tgt = self.dropout(tgt)
        
        src = src.transpose(0, 1)
        tgt = tgt.transpose(0, 1)
        
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        memory = src
        
        for layer in self.decoder_layers:
            tgt = layer(tgt, memory, tgt_mask, src_mask)
        
        out = tgt.transpose(0, 1)  # (batch_size, tgt_seq_len, embed_dim)
        out = self.fc_out(out)
        return out

##############################
# 3. Training Function
##############################

def train_transformer():
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)
    embed_dim = 32
    num_encoder_layers = 2
    num_decoder_layers = 2
    heads = 2
    dropout = 0.1
    forward_expansion = 2
    max_len_model = max_len
    epochs = 10
    lr = 1e-3

    model = Transformer(src_vocab_size, tgt_vocab_size, embed_dim,
                        num_encoder_layers, num_decoder_layers, heads,
                        dropout, forward_expansion, max_len_model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=src_vocab["<pad>"])

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (src_batch, tgt_batch) in enumerate(dataloader):
            src_batch = src_batch.to(device)
            tgt_batch = tgt_batch.to(device)
            # Prepare target input and output (teacher forcing)
            tgt_input = tgt_batch[:, :-1]
            tgt_output = tgt_batch[:, 1:]

            optimizer.zero_grad()
            output = model(src_batch, tgt_input)
            output = output.reshape(-1, tgt_vocab_size)
            tgt_output = tgt_output.reshape(-1)
            loss = criterion(output, tgt_output)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 1 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {loss.item():.4f}")
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

    torch.save(model.state_dict(), "transformer_translation.pth")
    print("Training complete. Model saved as transformer_translation.pth")

##############################
# 4. Inference / Testing Function
##############################

def greedy_decode(model, src_sentence, src_vocab, tgt_vocab, max_len):
    model.eval()
    # Tokenize source sentence
    src_ids = tokenize_sentence(src_sentence, src_vocab, max_len)
    src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
    src_tensor = src_tensor.to(next(model.parameters()).device)
    
    # Start target with <sos>
    tgt_ids = [tgt_vocab["<sos>"]]
    for _ in range(max_len - 1):
        tgt_tensor = torch.tensor(tgt_ids, dtype=torch.long).unsqueeze(0).to(src_tensor.device)
        # Generate output
        with torch.no_grad():
            output = model(src_tensor, tgt_tensor)
        # Get the logits of the last token in sequence
        next_token_logits = output[:, -1, :]  # (1, vocab_size)
        next_token = next_token_logits.argmax(dim=-1).item()
        tgt_ids.append(next_token)
        if next_token == tgt_vocab["<eos>"]:
            break
    return tgt_ids

def decode_sentence(token_ids, inv_vocab):
    # Convert token IDs back to words and remove special tokens
    words = []
    for token in token_ids:
        word = inv_vocab.get(token, "<unk>")
        if word in ["<sos>", "<eos>", "<pad>"]:
            continue
        words.append(word)
    return " ".join(words)

##############################
# 5. Main Execution
##############################

if __name__ == "__main__":
    train_transformer()
    
    # Load the trained model for testing
    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)
    embed_dim = 32
    num_encoder_layers = 2
    num_decoder_layers = 2
    heads = 2
    dropout = 0.1
    forward_expansion = 2
    max_len_model = max_len

    model = Transformer(src_vocab_size, tgt_vocab_size, embed_dim,
                        num_encoder_layers, num_decoder_layers, heads,
                        dropout, forward_expansion, max_len_model)
    model.load_state_dict(torch.load("transformer_translation.pth"))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Test with a sample source sentence
    test_sentence = "Hello world"
    print("\nTest Source Sentence:", test_sentence)
    predicted_ids = greedy_decode(model, test_sentence, src_vocab, tgt_vocab, max_len)
    predicted_sentence = decode_sentence(predicted_ids, inv_tgt_vocab)
    print("Predicted Translation:", predicted_sentence)


Epoch 1, Batch 1, Loss: 2.5273
Epoch 1, Batch 2, Loss: 2.6432
Epoch 1 Average Loss: 2.5852
Epoch 2, Batch 1, Loss: 2.2309
Epoch 2, Batch 2, Loss: 2.5676
Epoch 2 Average Loss: 2.3993
Epoch 3, Batch 1, Loss: 2.4073
Epoch 3, Batch 2, Loss: 2.0881
Epoch 3 Average Loss: 2.2477
Epoch 4, Batch 1, Loss: 2.1631
Epoch 4, Batch 2, Loss: 2.1112
Epoch 4 Average Loss: 2.1372
Epoch 5, Batch 1, Loss: 2.1438
Epoch 5, Batch 2, Loss: 1.9478
Epoch 5 Average Loss: 2.0458
Epoch 6, Batch 1, Loss: 2.0849
Epoch 6, Batch 2, Loss: 1.8571
Epoch 6 Average Loss: 1.9710
Epoch 7, Batch 1, Loss: 2.1208
Epoch 7, Batch 2, Loss: 1.7018
Epoch 7 Average Loss: 1.9113
Epoch 8, Batch 1, Loss: 1.9595
Epoch 8, Batch 2, Loss: 1.5335
Epoch 8 Average Loss: 1.7465
Epoch 9, Batch 1, Loss: 1.5998
Epoch 9, Batch 2, Loss: 1.9508
Epoch 9 Average Loss: 1.7753
Epoch 10, Batch 1, Loss: 1.6418
Epoch 10, Batch 2, Loss: 1.6593
Epoch 10 Average Loss: 1.6505
Training complete. Model saved as transformer_translation.pth

Test Source Sentence: He

  model.load_state_dict(torch.load("transformer_translation.pth"))
