# Pig Latin, again

Inspired by CSC413 assignment

## Data prep

In [None]:
"""
import pandas as pd
import numpy as np
import os


with open('data\pig_latin_small.txt','r') as f, \
     open('data\small_src.txt','w') as src,\
     open('data\small_tgt.txt','w') as tgt:
     for line in f:
            parts = line.strip().split()
            if len(parts) >=2:
                src.write(parts[0] + '\n')
                tgt.write(parts[1] + '\n')


with open('data\pig_latin_large.txt','r') as f, \
     open('data\large_src.txt','w') as src,\
     open('data\large_tgt.txt','w') as tgt:
     for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                src.write(parts[0] + '\n')
                tgt.write(parts[1] + '\n')
"""

  with open('data\pig_latin_small.txt','r') as f, \
  open('data\small_src.txt','w') as src,\
  open('data\small_tgt.txt','w') as tgt:
  with open('data\pig_latin_large.txt','r') as f, \
  open('data\large_src.txt','w') as src,\
  open('data\large_tgt.txt','w') as tgt:


## Construct Basic Components

### Scaled Dot Attention

In [1]:
import torch
import torch.nn as nn

class Scaled_Dot_Attention(nn.Module):
    def __init__(self, d):
        super(Scaled_Dot_Attention, self).__init__()
        self.d = d
        self.Q = nn.Linear(d, d)
        self.K = nn.Linear(d, d)
        self.V = nn.Linear(d, d)
        self.softmax = nn.Softmax(dim=2)
        self.scaling_factor = torch.rsqrt(
            torch.tensor(self.d, dtype=torch.float)
        )
    def forward(self, queries, keys, values):
        Q = self.Q(queries)
        K = self.K(keys).transpose(1, 2)
        V = self.V(values)
        attention_scores = torch.bmm(Q, K) * self.scaling_factor
        attention_weights = self.softmax(attention_scores)
        attention = torch.bmm(attention_weights, V)
        return attention, attention_weights


### Multi-Head Attention

In [2]:
import torch
import torch.nn as nn

class Multihead_Attention(nn.Module):
    def __init__(self, d_model, h):
        super(Multihead_Attention, self).__init__()
        self.d = d_model//h
        self.Q = nn.ModuleList()
        self.K = nn.ModuleList()
        self.V = nn.ModuleList()
        self.WO = nn.Linear(h*self.d, d_model)
        self.h = h
        for i in range(h):
            self.Q.append(nn.Linear(d_model, self.d))
            self.K.append(nn.Linear(d_model, self.d))
            self.V.append(nn.Linear(d_model, self.d))

        self.softmax = nn.Softmax(dim=2)
        self.scaling_factor = torch.rsqrt(
            torch.tensor(self.d, dtype=torch.float)
        )
    def forward(self, queries, keys, values):
        attention_heads = []
        attention_weights = []
        for i in range(self.h):
            Q = self.Q[i](queries)
            K = self.K[i](keys).transpose(1, 2)
            V = self.V[i](values)
            attention_scores = torch.bmm(Q, K) * self.scaling_factor
            attention_w = self.softmax(attention_scores)
            attention_weights.append(attention_w)
            attention = torch.bmm(attention_w, V)
            attention_heads.append(attention)
        
        attention = torch.cat(attention_heads, dim=2)
        attention = self.WO(attention)
        return attention, attention_weights


### Test

In [3]:
test = Multihead_Attention(512, 8)
x = torch.randn(10, 6, 512)
test.forward(x, x, x)

(tensor([[[-1.0173e-01, -7.3404e-02, -1.1476e-01,  ...,  2.5383e-01,
           -1.2380e-01, -3.0524e-02],
          [-1.4823e-01, -5.5232e-02, -8.5166e-02,  ...,  2.7197e-01,
           -2.3911e-02,  7.8952e-02],
          [-1.5832e-01, -6.5937e-02, -5.7866e-02,  ...,  2.2423e-01,
           -4.0030e-02,  1.8797e-03],
          [-1.7705e-01, -7.8383e-02, -6.4205e-02,  ...,  2.8287e-01,
           -9.7089e-02, -8.6408e-03],
          [-1.5542e-01, -7.7695e-02, -1.2241e-01,  ...,  2.5611e-01,
           -1.5427e-01,  7.1391e-02],
          [-1.3411e-01,  1.4788e-03, -6.6327e-02,  ...,  2.7231e-01,
           -7.7905e-02,  1.6624e-02]],
 
         [[-7.2817e-02,  2.0150e-01,  6.0139e-02,  ...,  7.8401e-02,
            6.5397e-02, -1.1457e-01],
          [-8.6510e-02,  2.1970e-01,  5.5662e-02,  ...,  1.3764e-01,
            6.9616e-02, -1.8051e-01],
          [-1.6608e-01,  1.9520e-01,  7.2893e-02,  ...,  7.2950e-02,
            6.7888e-03, -9.5728e-02],
          [-1.3589e-01,  2.5303e-0

In [4]:
import torch
import torch.nn.functional as F

# Example dummy attention scores (batch_size=1, query_len=3, key_len=3)
attention_scores = torch.tensor([[[1.0, 2.0, 3.0],
                                  [4.0, 5.0, 6.0],
                                  [7.0, 8.0, 9.0]]])

print("Original attention scores:\n", attention_scores)

# Create a mask: 1 allows attention, 0 blocks attention
mask = torch.tensor([[[1, 0, 1],
                      [1, 1, 0],
                      [0, 1, 1]]])

# Apply mask using masked_fill: wherever mask == 0, set to -1e9 (large negative)
masked_attention_scores = attention_scores.masked_fill(mask == 0, float('-1e9'))

print("\nMasked attention scores:\n", masked_attention_scores)

# Apply softmax over last dimension (key_len)
attention_probs = F.softmax(masked_attention_scores, dim=-1)

print("\nAttention probabilities after softmax:\n", attention_probs)


Original attention scores:
 tensor([[[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])

Masked attention scores:
 tensor([[[ 1.0000e+00, -1.0000e+09,  3.0000e+00],
         [ 4.0000e+00,  5.0000e+00, -1.0000e+09],
         [-1.0000e+09,  8.0000e+00,  9.0000e+00]]])

Attention probabilities after softmax:
 tensor([[[0.1192, 0.0000, 0.8808],
         [0.2689, 0.7311, 0.0000],
         [0.0000, 0.2689, 0.7311]]])


In [5]:
import torch
import torch.nn as nn

class Multihead_Attention_Masked(nn.Module):
    def __init__(self, d_model, h):
        super(Multihead_Attention_Masked, self).__init__()
        self.d = d_model//h
        self.Q = nn.ModuleList()
        self.K = nn.ModuleList()
        self.V = nn.ModuleList()
        self.WO = nn.Linear(h*self.d, d_model)
        self.h = h
        for i in range(h):
            self.Q.append(nn.Linear(d_model, self.d))
            self.K.append(nn.Linear(d_model, self.d))
            self.V.append(nn.Linear(d_model, self.d))

        self.softmax = nn.Softmax(dim=2)
        self.scaling_factor = torch.rsqrt(
            torch.tensor(self.d, dtype=torch.float)
        )
    def forward(self, queries, keys, values, mask=None):
        attention_heads = []
        attention_weights = []
        for i in range(self.h):
            Q = self.Q[i](queries)
            K = self.K[i](keys).transpose(1, 2)
            V = self.V[i](values)
            attention_scores = torch.bmm(Q, K) * self.scaling_factor

            if mask is not None:
                attention_scores = attention_scores.masked_fill(mask == 0, float('-1e9'))
                
            attention_w = self.softmax(attention_scores)
            attention_weights.append(attention_w)
            attention = torch.bmm(attention_w, V)
            attention_heads.append(attention)
        
        attention = torch.cat(attention_heads, dim=2)
        attention = self.WO(attention)
        return attention, attention_weights


In [6]:
import torch

# Define your attention module
mha = Multihead_Attention_Masked(d_model=32, h=4)  # Example small model
batch_size = 2
seq_len = 5
d_model = 32

# Create dummy inputs
x = torch.randn(batch_size, seq_len, d_model)  # (batch, seq_len, d_model)

# Create causal mask (batch_size, seq_len, seq_len)
def generate_causal_mask(batch_size, seq_len, device):
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).unsqueeze(0).repeat(batch_size, 1, 1)
    return mask

mask = generate_causal_mask(batch_size, seq_len, x.device)

# Forward pass with mask
output, attn_weights = mha(x, x, x, mask=mask)

print("Output shape:", output.shape)  # Expected: (batch_size, seq_len, d_model)
print("Attention weights shape per head:", attn_weights[0].shape)  # (batch_size, seq_len, seq_len)
print("Causal mask:\n", mask[0])
print("First head attention weights:\n", attn_weights[0][0])  # Visualize head 0, batch 0


Output shape: torch.Size([2, 5, 32])
Attention weights shape per head: torch.Size([2, 5, 5])
Causal mask:
 tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])
First head attention weights:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3651, 0.6349, 0.0000, 0.0000, 0.0000],
        [0.4139, 0.2928, 0.2933, 0.0000, 0.0000],
        [0.1590, 0.2242, 0.3177, 0.2991, 0.0000],
        [0.1350, 0.1600, 0.3227, 0.1777, 0.2047]], grad_fn=<SelectBackward0>)


### Transformer Encoder

In [None]:
N = 6
d_model = 512
n_head = 8
class Transformer_Encoder_naive(nn.Module):
    def __init__(self, N, d_model, n_head, vocab_size):
        super(Transformer_Encoder_naive, self).__init__()
        self.N = N
        self.d_model = d_model
        self.n_head = n_head
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pe = self.positional_encoding()

        self.attentions = nn.ModuleList(
            [Multihead_Attention(d_model, n_head) for i in range(N)]
        )

        self.mlps = nn.ModuleList(
            [   nn.Sequential(
                    nn.Linear(d_model, 4*d_model),
                    nn.ReLU(),
                    nn.Linear(4*d_model, d_model),
                )
                for i in range(self.N)
            ]
        )

    def forward(self, x):

        seq_len = x.shape[1]
        embedded = self.embedding(x)
        embedded = embedded + self.pe[:seq_len]
        for i in range(self.N):
            attention, _ = self.attentions[i](embedded, embedded, embedded)
            attention = attention + embedded
            attention = nn.LayerNorm(attention.shape)(attention)
            mlp = self.mlps[i](attention)
            embedded = mlp + attention
            embedded = nn.LayerNorm(embedded.shape)(embedded)
        return embedded

    def positional_encoding(self, seq_len = 1000):
        d_model = self.d_model
        position = torch.arange(seq_len).unsqueeze(1)
        exp_term = torch.arange(d_model//2).unsqueeze(0) / d_model
        sin_term = torch.sin(position / (10000 ** exp_term))
        cos_term = torch.cos(position / (10000 ** exp_term))
        positional_encoding = torch.zeros(seq_len, d_model)
        positional_encoding[:, 0::2] = sin_term
        positional_encoding[:, 1::2] = cos_term
        positional_encoding.to('cuda')
        return positional_encoding
        

In [25]:
"""N = 6
d_model = 512
n_head = 8
class Transformer_Encoder(nn.Module):
    def __init__(self, N, d_model, n_head, vocab_size):
        super(Transformer_Encoder, self).__init__()
        self.N = N
        self.d_model = d_model
        self.n_head = n_head
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pe = self.positional_encoding()
        self.pe.to('cuda')

        self.attentions = nn.ModuleList(
            [Multihead_Attention_Masked(d_model, n_head) for i in range(N)]
        )

        self.mlps = nn.ModuleList(
            [   nn.Sequential(
                    nn.Linear(d_model, 4*d_model),
                    nn.ReLU(),
                    nn.Linear(4*d_model, d_model),
                )
                for i in range(self.N)
            ]
        )

    def forward(self, x):
        x.to('cuda')
        seq_len = x.shape[1]
        embedded = self.embedding(x)
        embedded = embedded + self.pe[:seq_len].to(embedded.device)
        for i in range(self.N):
            attention, _ = self.attentions[i](embedded, embedded, embedded)
            attention = attention + embedded
            attention = nn.LayerNorm(attention.shape)(attention)
            mlp = self.mlps[i](attention)
            embedded = mlp + attention
            embedded = nn.LayerNorm(embedded.shape)(embedded)
        return embedded

    def positional_encoding(self, seq_len = 1000):
        d_model = self.d_model
        position = torch.arange(seq_len).unsqueeze(1)
        exp_term = torch.arange(d_model//2).unsqueeze(0) / d_model
        sin_term = torch.sin(position / (10000 ** exp_term))
        cos_term = torch.cos(position / (10000 ** exp_term))
        positional_encoding = torch.zeros(seq_len, d_model)
        positional_encoding[:, 0::2] = sin_term
        positional_encoding[:, 1::2] = cos_term
        return positional_encoding
        """
class Transformer_Encoder(nn.Module):
    def __init__(self, N, d_model, n_head, vocab_size):
        super(Transformer_Encoder, self).__init__()
        self.N = N
        self.d_model = d_model
        self.n_head = n_head
        self.vocab_size = vocab_size

        # Embedding
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional Encoding registered as buffer (device safe)
        self.register_buffer('pe', self.positional_encoding())

        # Attention modules
        self.attentions = nn.ModuleList([
            Multihead_Attention_Masked(d_model, n_head) for _ in range(N)
        ])

        # Feedforward (MLP) modules
        self.mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.ReLU(),
                nn.Linear(4 * d_model, d_model),
            ) for _ in range(N)
        ])

        # LayerNorms: 2 per block (after attn, after mlp)
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(2 * N)
        ])

    def forward(self, x):
        device = x.device  # Get device dynamically

        seq_len = x.shape[1]

        # Embedding + positional encoding
        embedded = self.embedding(x)  # (batch_size, seq_len, d_model)
        pe = self.pe[:seq_len, :].to(device)  # Make sure PE is on same device
        embedded = embedded + pe

        # Encoder layers
        for i in range(self.N):
            # Self-attention (masked or not, depends on design)
            attention, _ = self.attentions[i](embedded, embedded, embedded)
            embedded = self.layer_norms[2 * i](embedded + attention)

            # Feedforward (MLP)
            mlp_output = self.mlps[i](embedded)
            embedded = self.layer_norms[2 * i + 1](embedded + mlp_output)

        return embedded  # (batch_size, seq_len, d_model)

    def positional_encoding(self, seq_len=1000):
        d_model = self.d_model
        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe  # (seq_len, d_model)


In [9]:
test_encoder = Transformer_Encoder_naive(N, d_model, n_head, 1000)
test_tensor = torch.randint(0, 1000, (10, 6))
print(test_tensor.shape)
test_encoder.forward(test_tensor)

torch.Size([10, 6])


tensor([[[ 0.6070, -0.4748,  0.0108,  ..., -0.1395, -0.5308,  0.1014],
         [ 1.6450,  0.1393, -0.8619,  ...,  0.0583, -0.6279,  0.7517],
         [ 0.0480, -0.3419,  0.6967,  ...,  0.3572, -1.1321,  0.1739],
         [-0.4232, -3.2014, -0.6346,  ...,  0.1541, -0.6211,  0.2021],
         [ 0.6266, -0.2097, -1.1710,  ..., -0.2623, -0.0604,  0.4459],
         [-0.5671, -0.4261, -0.5070,  ..., -0.1120, -0.8331,  0.8379]],

        [[ 0.4379,  0.3000,  0.6080,  ...,  0.8164, -0.3258,  1.1259],
         [-0.0790, -1.1930,  0.2627,  ...,  1.0599, -1.9356, -0.0343],
         [-0.2700, -1.2963,  0.3298,  ...,  0.6309,  1.8400, -1.2053],
         [-0.8559, -0.1037, -0.9344,  ..., -0.3237,  0.7930,  0.0665],
         [-0.8746, -1.6965, -1.1340,  ...,  1.5759,  0.8942,  1.0444],
         [-1.6168, -0.9863, -0.3203,  ...,  1.0035,  1.2694,  1.2060]],

        [[ 0.0726, -1.1976,  0.6546,  ..., -1.0682, -1.2797,  1.5777],
         [-0.5232, -0.4907,  1.3249,  ...,  0.2323, -1.4941,  2.0197],
  

### Decoder

In [27]:
import torch
import torch.nn as nn

class Transformer_Decoder(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, vocab_size, max_seq_len=1000):
        super(Transformer_Decoder, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, d_model)

        self.pe = self.positional_encoding(max_seq_len).to(device='cuda')

        self.self_attentions = nn.ModuleList([
            Multihead_Attention_Masked(d_model, num_heads) for _ in range(num_layers)
        ])

        self.cross_attentions = nn.ModuleList([
            Multihead_Attention_Masked(d_model, num_heads) for _ in range(num_layers)
        ])

        self.mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.ReLU(),
                nn.Linear(4 * d_model, d_model)
            ) for _ in range(num_layers)
        ])

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(3 * num_layers) 
        ])

    def forward(self, target_input, encoder_output, mask=None, encoder_mask=None):

        seq_len = target_input.size(1)
        embedded = self.embedding(target_input) + self.pe[:seq_len, :]

        decoder_output = embedded

        for i in range(self.num_layers):
            causal_mask = self.generate_causal_mask(decoder_output.size(0), seq_len, decoder_output.device)
            self_attn, _ = self.self_attentions[i](decoder_output, decoder_output, decoder_output, mask=causal_mask)
            decoder_output = self.layer_norms[i * 3](decoder_output + self_attn)

            cross_attn, _ = self.cross_attentions[i](decoder_output, encoder_output, encoder_output, mask=encoder_mask)
            decoder_output = self.layer_norms[i * 3 + 1](decoder_output + cross_attn)

            mlp_output = self.mlps[i](decoder_output)
            decoder_output = self.layer_norms[i * 3 + 2](decoder_output + mlp_output)

        return decoder_output  # (batch_size, target_seq_len, d_model)

    def generate_causal_mask(self, batch_size, seq_len, device):
        # Generates a lower triangular mask (causal)
        return torch.tril(torch.ones(seq_len, seq_len, device=device)).unsqueeze(0).repeat(batch_size, 1, 1)

    def positional_encoding(self, max_seq_len):
        position = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2) * -(torch.log(torch.tensor(10000.0)) / self.d_model))
        pe = torch.zeros(max_seq_len, self.d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe


### Combine

In [11]:
class TransformerEncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, d_model, vocab_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, src_input, tgt_input, src_mask=None, tgt_mask=None):

        encoder_output = self.encoder(src_input)

        decoder_output = self.decoder(tgt_input, encoder_output, mask=tgt_mask, encoder_mask=src_mask) 

        logits = self.output_projection(decoder_output) 
        return logits


In [12]:
# Create encoder and decoder (reuse your naive ones)
encoder = Transformer_Encoder(N=6, d_model=512, n_head=8, vocab_size=1000)
decoder = Transformer_Decoder(d_model=512, num_heads=8, num_layers=6, vocab_size=1000)

# Create wrapper
transformer = TransformerEncoderDecoder(encoder, decoder, d_model=512, vocab_size=1000)

# Dummy input tokens
src_input = torch.randint(0, 1000, (4, 10))  # (batch_size, src_seq_len)
tgt_input = torch.randint(0, 1000, (4, 6))   # (batch_size, tgt_seq_len)

# Causal mask for decoder (optional but recommended)
def generate_causal_mask(batch_size, seq_len, device):
    return torch.tril(torch.ones(seq_len, seq_len, device=device)).unsqueeze(0).repeat(batch_size, 1, 1)

tgt_mask = generate_causal_mask(batch_size=4, seq_len=6, device=src_input.device)

# Forward pass
logits = transformer(src_input, tgt_input, src_mask=None, tgt_mask=tgt_mask)
print("Logits shape:", logits.shape)  # Expected: (batch_size, tgt_seq_len, vocab_size)


Logits shape: torch.Size([4, 6, 1000])


## Training

In [28]:
# Load source (English) and target (Pig Latin) words
with open("data/small_src.txt", "r") as f:
    src_sentences = f.read().strip().split("\n")

with open("data/small_tgt.txt", "r") as f:
    tgt_sentences = f.read().strip().split("\n")

# Basic check
print(f"Loaded {len(src_sentences)} samples")
print(f"Example:\nSRC: {src_sentences[0]}\nTGT: {tgt_sentences[0]}")

from collections import Counter

# Get all characters from source and target
all_chars = set("".join(src_sentences + tgt_sentences))
char2idx = {ch: idx + 4 for idx, ch in enumerate(sorted(all_chars))}
char2idx["<pad>"] = 0
char2idx["<sos>"] = 1
char2idx["<eos>"] = 2
char2idx["<unk>"] = 3

idx2char = {idx: ch for ch, idx in char2idx.items()}

vocab_size = len(char2idx)
print(f"Vocabulary size: {vocab_size}")

def tokenize(sentence, char2idx, add_sos_eos=True):
    tokens = [char2idx.get(c, char2idx["<unk>"]) for c in sentence]
    if add_sos_eos:
        return [char2idx["<sos>"]] + tokens + [char2idx["<eos>"]]
    else:
        return tokens

# Example:
src_indices = tokenize(src_sentences[0], char2idx)
tgt_indices = tokenize(tgt_sentences[0], char2idx)

print("SRC indices:", src_indices)
print("TGT indices:", tgt_indices)



from torch.utils.data import Dataset, DataLoader

class PigLatinDataset(Dataset):
    def __init__(self, src_sentences, tgt_sentences, char2idx):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.char2idx = char2idx

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src = tokenize(self.src_sentences[idx], self.char2idx)
        tgt = tokenize(self.tgt_sentences[idx], self.char2idx)
        return torch.tensor(src), torch.tensor(tgt)

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src_lens = [len(seq) for seq in src_batch]
    tgt_lens = [len(seq) for seq in tgt_batch]
    src_padded = nn.utils.rnn.pad_sequence(src_batch, batch_first=True, padding_value=char2idx["<pad>"])
    tgt_padded = nn.utils.rnn.pad_sequence(tgt_batch, batch_first=True, padding_value=char2idx["<pad>"])
    return src_padded, tgt_padded

dataset = PigLatinDataset(src_sentences, tgt_sentences, char2idx)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)


import torch.optim as optim

# Initialize your transformer model
encoder = Transformer_Encoder(N=6, d_model=512, n_head=8, vocab_size=vocab_size)
decoder = Transformer_Decoder(d_model=512, num_heads=8, num_layers=6, vocab_size=vocab_size)
model = TransformerEncoderDecoder(encoder, decoder, d_model=512, vocab_size=vocab_size).to('cuda' if torch.cuda.is_available() else 'cpu')

# Loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=char2idx["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
device = next(model.parameters()).device
for epoch in range(10):
    model.train()
    total_loss = 0
    for src_batch, tgt_batch in dataloader:
        src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
        tgt_input = tgt_batch[:, :-1]
        tgt_target = tgt_batch[:, 1:]

        # Causal mask for decoder
        batch_size, tgt_seq_len = tgt_input.shape
        tgt_mask = torch.tril(torch.ones(tgt_seq_len, tgt_seq_len, device=device)).unsqueeze(0).repeat(batch_size, 1, 1)

        # Forward pass
        logits = model(src_batch, tgt_input, src_mask=None, tgt_mask=tgt_mask)
        logits = logits.view(-1, vocab_size)
        tgt_target = tgt_target.reshape(-1)

        loss = criterion(logits, tgt_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}: Loss = {total_loss / len(dataloader):.4f}")



Loaded 3255 samples
Example:
SRC: apprehensive
TGT: apprehensiveway
Vocabulary size: 42
SRC indices: [1, 16, 31, 31, 33, 20, 23, 20, 29, 34, 24, 37, 20, 2]
TGT indices: [1, 16, 31, 31, 33, 20, 23, 20, 29, 34, 24, 37, 20, 38, 16, 40, 2]
Epoch 1: Loss = 1.3842
Epoch 2: Loss = 0.3071
Epoch 3: Loss = 0.1006
Epoch 4: Loss = 0.0623
Epoch 5: Loss = 0.0429
Epoch 6: Loss = 0.0308
Epoch 7: Loss = 0.0427
Epoch 8: Loss = 0.0363
Epoch 9: Loss = 0.0230
Epoch 10: Loss = 0.0292


In [31]:
def greedy_decode(model, src_sentence, char2idx, idx2char, max_len=20, device='cpu'):
    model.eval()

    # Tokenize source sentence
    src_tokens = [char2idx.get(c, char2idx["<unk>"]) for c in src_sentence]
    src_tokens = [char2idx["<sos>"]] + src_tokens + [char2idx["<eos>"]]
    src_tensor = torch.tensor(src_tokens).unsqueeze(0).to(device)  # (1, src_seq_len)

    # Encoder output
    encoder_output = model.encoder(src_tensor)

    # Start decoding with <sos>
    tgt_tokens = [char2idx["<sos>"]]
    for _ in range(max_len):
        tgt_input = torch.tensor(tgt_tokens).unsqueeze(0).to(device)  # (1, current_tgt_len)

        # Causal mask for decoder
        tgt_seq_len = tgt_input.shape[1]
        tgt_mask = torch.tril(torch.ones(tgt_seq_len, tgt_seq_len, device=device)).unsqueeze(0)

        # Decoder output
        decoder_output = model.decoder(tgt_input, encoder_output, mask=tgt_mask)
        logits = model.output_projection(decoder_output)  # (1, tgt_seq_len, vocab_size)

        # Get the last time step prediction
        next_token_logits = logits[0, -1, :]  # (vocab_size,)
        next_token = torch.argmax(next_token_logits).item()

        # If <eos>, stop decoding
        if next_token == char2idx["<eos>"]:
            break

        # Append next token
        tgt_tokens.append(next_token)

    # Convert token IDs back to string
    output_chars = [idx2char[token] for token in tgt_tokens[1:]]  # Exclude <sos>
    output_word = "".join(output_chars)
    return output_word


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Example English word
src_sentence = "chimney-board"

# Translate to Pig Latin
piglatin_translation = greedy_decode(model, src_sentence, char2idx, idx2char, device=device)

print(f"English: {src_sentence}")
print(f"Pig Latin: {piglatin_translation}")


English: chimney-board
Pig Latin: imneychay-oardray
