In [9]:
import sentencepiece as spm
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim


def _load_spm(model_path):
    """Create a SentencePieceProcessor and load the model file at *model_path*."""
    processor = spm.SentencePieceProcessor()
    processor.load(model_path)
    return processor


# Source-side (Urdu script) and target-side (Roman Urdu) subword tokenizers.
sp_ur = _load_spm("/content/urdu_tokenizer.model")
sp_en = _load_spm("/content/roman_tokenizer.model")

# Target-side padding id; used as the loss's ignore_index further down.
PAD_ID = sp_en.pad_id()

class TranslationDataset(Dataset):
    """Line-aligned parallel corpus turned into fixed-length token-id pairs.

    Each item is ``(src, trg)``, two int64 tensors of length ``max_len`` laid
    out as ``<bos> tokens... <eos> <pad>...``. Sentences too long to fit are
    truncated *before* ``<eos>`` is appended, so every item still carries its
    end-of-sentence marker.

    NOTE(review): assumes both tokenizers define a pad id >= 0. SentencePiece
    reports -1 when no pad piece was configured, which would yield invalid ids.
    """

    def __init__(self, src_file, trg_file, src_sp, trg_sp, max_len=25):
        """
        src_file, trg_file: UTF-8 text files with one sentence per line; line i
            of one file is the translation of line i of the other.
        src_sp, trg_sp: tokenizers exposing encode(text, out_type=int),
            bos_id(), eos_id() and pad_id().
        max_len: length of every returned sequence, including <bos>/<eos>.

        Raises:
          ValueError: if max_len < 2 (no room for <bos> and <eos>) or the two
            files have a different number of lines.
        """
        if max_len < 2:
            raise ValueError(f"max_len must be at least 2 to hold <bos> and <eos>, got {max_len}")

        with open(src_file, encoding="utf-8") as f:
            self.src_sentences = [line.strip() for line in f]
        with open(trg_file, encoding="utf-8") as f:
            self.trg_sentences = [line.strip() for line in f]

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.src_sentences) != len(self.trg_sentences):
            raise ValueError(
                f"source/target line counts differ: "
                f"{len(self.src_sentences)} vs {len(self.trg_sentences)}"
            )

        self.src_sp = src_sp
        self.trg_sp = trg_sp
        self.max_len = max_len

    def __len__(self):
        return len(self.src_sentences)

    @staticmethod
    def _encode(sp, sentence, max_len):
        """Tokenize *sentence*, wrap it in <bos>/<eos> and right-pad to *max_len*."""
        # Reserve two slots so truncation never drops the <eos> marker.
        body = sp.encode(sentence, out_type=int)[: max_len - 2]
        ids = [sp.bos_id(), *body, sp.eos_id()]
        return ids + [sp.pad_id()] * (max_len - len(ids))

    def __getitem__(self, idx):
        src = self._encode(self.src_sp, self.src_sentences[idx], self.max_len)
        trg = self._encode(self.trg_sp, self.trg_sentences[idx], self.max_len)
        return torch.tensor(src), torch.tensor(trg)


# Line-aligned Urdu-script / Roman-Urdu corpus, tokenized lazily per item.
full_dataset = TranslationDataset(
    src_file="/content/all_urdu_no_empty_lines.txt",
    trg_file="/content/all_english_clean_no_empty_lines.txt",
    src_sp=sp_ur,
    trg_sp=sp_en,
)

# 80% train / 10% validation; the test split takes whatever rounding leaves.
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - (train_size + val_size)

# Seeded generator so the same split is produced on every run.
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Only the training loader reshuffles between epochs.
train_loader, val_loader, test_loader = (
    DataLoader(train_dataset, batch_size=32, shuffle=True),
    DataLoader(val_dataset, batch_size=32, shuffle=False),
    DataLoader(test_dataset, batch_size=32, shuffle=False),
)



class Encoder(nn.Module):
    """Bidirectional LSTM encoder over right-padded source token ids.

    Trailing padding is excluded from the recurrence via packed sequences, so
    the returned final states summarize only each row's real tokens rather
    than whatever the LSTM computed while reading pad positions.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1, pad_idx=0, dropout=0.1):
        """
        vocab_size: source vocabulary size
        embed_dim: token embedding size
        hidden_dim: LSTM hidden size per direction
        num_layers: number of stacked bidirectional LSTM layers
        pad_idx: id used for right-padding (zero embedding; trimmed before the LSTM)
        dropout: dropout between stacked LSTM layers (unused when num_layers == 1)
        """
        super().__init__()
        # Kept so forward() can recover each row's true length.
        self.pad_idx = pad_idx

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=pad_idx
        )

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers.
            dropout=dropout if num_layers > 1 else 0.0
        )

    def _lengths(self, src):
        """Per-row length up to and including the last non-pad token.

        Only *trailing* pad ids count as padding, so a pad-valued id in the
        middle of a row is still fed to the LSTM. All-pad rows get length 1
        because pack_padded_sequence rejects zero lengths.
        """
        positions = torch.arange(1, src.size(1) + 1, device=src.device)
        is_token = (src != self.pad_idx).long()
        return (is_token * positions).max(dim=1).values.clamp(min=1)

    def forward(self, src):
        """
        src: [batch_size, src_len]  (token IDs, right-padded with pad_idx)

        Returns:
          outputs: [batch_size, src_len, 2 * hidden_dim]; positions after a
            row's last real token are zero.
          (hidden, cell): each [num_layers * 2, batch_size, hidden_dim], taken
            at each row's last real token (forward) / first token (backward).
        """
        embedded = self.embedding(src)
        # pack_padded_sequence requires the lengths tensor on the CPU.
        lengths = self._lengths(src).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        packed_outputs, (hidden, cell) = self.lstm(packed)
        # Restore the original [batch, src_len, ...] layout for callers.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outputs, batch_first=True, total_length=src.size(1)
        )
        return outputs, (hidden, cell)






class DecoderLSTM(nn.Module):
    """Unidirectional LSTM decoder that emits target-token logits step by step.

    Initial states are derived from an encoder's final (hidden, cell) states by
    `init_hidden_from_encoder`.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim,
        hidden_dim,
        num_layers=1,
        pad_idx=0,
        dropout=0.1,
        encoder_bidirectional=False
    ):
        """
        vocab_size: target vocab size (roman vocab)
        embed_dim: embedding dim for target tokens
        hidden_dim: decoder hidden size (per layer)
        num_layers: stacked LSTM layers in decoder
        pad_idx: id used for padding (for embedding)
        dropout: dropout between stacked LSTM layers (unused when num_layers == 1)
        encoder_bidirectional: True if the encoder was bidirectional; its two
            directions are then concatenated and projected back to hidden_dim.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.encoder_bidirectional = encoder_bidirectional

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0
        )

        self.fc_out = nn.Linear(hidden_dim, vocab_size)

        if self.encoder_bidirectional:
            # Map the concatenated [forward; backward] states back to hidden_dim.
            self.reduce_hidden = nn.Linear(hidden_dim * 2, hidden_dim)
            self.reduce_cell = nn.Linear(hidden_dim * 2, hidden_dim)
        else:
            self.reduce_hidden = None
            self.reduce_cell = None

    def init_hidden_from_encoder(self, enc_hidden, enc_cell):
        """Convert encoder final states into initial decoder states.

        enc_hidden, enc_cell: [num_enc_layers * num_dirs, batch, hidden_enc],
            the h_n / c_n layout returned by nn.LSTM.

        Bidirectional encoder: each layer's two directions are concatenated and
        passed through tanh(Linear) down to hidden_dim.
        Unidirectional encoder: states are used unchanged, so hidden_enc must
        equal hidden_dim.

        The layer count is then matched to the decoder. Extra decoder layers
        (the top ones) start from zeros. If the encoder has more layers than
        the decoder, only its top `num_layers` are kept.

        Returns:
          (hidden, cell): each [num_layers, batch, hidden_dim]

        Raises:
          ValueError: if the encoder state shape is incompatible with this decoder.
        """
        num_enc_layers_times_dirs, batch, hidden_enc = enc_hidden.size()
        num_dirs = 2 if self.encoder_bidirectional else 1
        if num_enc_layers_times_dirs % num_dirs != 0:
            raise ValueError(
                f"bidirectional encoder states need an even first dimension, "
                f"got {num_enc_layers_times_dirs}"
            )
        num_enc_layers = num_enc_layers_times_dirs // num_dirs

        # nn.LSTM stores h_n layer-major: index = layer * num_dirs + direction.
        # reshape (not view) also accepts non-contiguous inputs.
        enc_hidden = enc_hidden.reshape(num_enc_layers, num_dirs, batch, hidden_enc)
        enc_cell = enc_cell.reshape(num_enc_layers, num_dirs, batch, hidden_enc)

        if self.encoder_bidirectional:
            hidden_cat = torch.cat([enc_hidden[:, 0], enc_hidden[:, 1]], dim=2)
            cell_cat = torch.cat([enc_cell[:, 0], enc_cell[:, 1]], dim=2)
            hidden_proj = torch.tanh(self.reduce_hidden(hidden_cat))
            cell_proj = torch.tanh(self.reduce_cell(cell_cat))
        else:
            # No reduce layers exist in this mode (they are None); pass through.
            if hidden_enc != self.hidden_dim:
                raise ValueError(
                    f"unidirectional encoder hidden size {hidden_enc} does not "
                    f"match decoder hidden size {self.hidden_dim}"
                )
            hidden_proj = enc_hidden[:, 0]
            cell_proj = enc_cell[:, 0]

        if self.num_layers > num_enc_layers:
            pad_shape = (self.num_layers - num_enc_layers, batch, self.hidden_dim)
            # new_zeros keeps the device and dtype of the projected states.
            hidden_proj = torch.cat([hidden_proj, hidden_proj.new_zeros(pad_shape)], dim=0)
            cell_proj = torch.cat([cell_proj, cell_proj.new_zeros(pad_shape)], dim=0)
        elif self.num_layers < num_enc_layers:
            hidden_proj = hidden_proj[-self.num_layers:]
            cell_proj = cell_proj[-self.num_layers:]

        return hidden_proj, cell_proj

    def forward_step(self, input_tokens, hidden, cell):
        """
        Run the decoder for one time-step.

        input_tokens: [batch] or [batch, 1] tensor of token ids (usually the previous token)
        hidden: [num_layers, batch, hidden_dim]
        cell:   [num_layers, batch, hidden_dim]

        Returns:
          prediction: [batch, vocab_size]  (logits for next token)
          hidden, cell: updated states
        """
        if input_tokens.dim() == 1:
            input_tokens = input_tokens.unsqueeze(1)

        embedded = self.embedding(input_tokens)
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))

        logits = self.fc_out(output.squeeze(1))
        return logits, hidden, cell

    def forward(self, trg, hidden, cell, teacher_forcing_ratio=1.0):
        """
        Decode a full target sequence (used in training with teacher forcing).
        trg: [batch, trg_len]  (token ids; trg[:,0] should be <sos>)
        hidden, cell: initial decoder states
        teacher_forcing_ratio: float in [0,1], prob of using ground-truth token at each step
            (one draw per time-step, shared by the whole batch)

        Returns:
          outputs: [batch, trg_len, vocab_size] (logits for each step; column 0 is left as zeros)
          hidden, cell: final states
        """
        batch_size, trg_len = trg.size()
        device = trg.device
        vocab_size = self.vocab_size

        outputs = torch.zeros(batch_size, trg_len, vocab_size, device=device)

        input_tok = trg[:, 0]

        for t in range(1, trg_len):
            logits, hidden, cell = self.forward_step(input_tok, hidden, cell)
            outputs[:, t, :] = logits

            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            next_input = trg[:, t] if teacher_force else logits.argmax(1)
            input_tok = next_input

        return outputs, hidden, cell

    def generate(self, hidden, cell, sos_id, eos_id, max_len=50):
        """
        Greedy generation (inference).

        Stops early once every row has produced eos_id at the same step;
        rows that finished earlier keep generating until then.

        Returns:
          generated: [batch, gen_len] token ids (without initial <sos>)
        """
        batch_size = hidden.size(1)
        device = hidden.device

        input_tok = torch.full((batch_size,), sos_id, dtype=torch.long, device=device)
        generated = []

        for _ in range(max_len):
            logits, hidden, cell = self.forward_step(input_tok, hidden, cell)
            next_tok = logits.argmax(1)
            generated.append(next_tok.unsqueeze(1))
            input_tok = next_tok
            if (next_tok == eos_id).all():
                break

        if len(generated) == 0:
            return torch.empty((batch_size, 0), dtype=torch.long, device=device)
        return torch.cat(generated, dim=1)


class Seq2Seq(nn.Module):
    """Glue module: encode the source, seed the decoder, then decode.

    The decoder's initial state always comes from the encoder's final
    (hidden, cell) pair via `decoder.init_hidden_from_encoder`. The per-position
    encoder outputs are not used.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def _decoder_start_state(self, src):
        """Run the encoder on *src* and return the decoder's initial (hidden, cell)."""
        _, (final_hidden, final_cell) = self.encoder(src)
        return self.decoder.init_hidden_from_encoder(final_hidden, final_cell)

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Teacher-forced decoding of *trg* conditioned on *src*.

        src: [batch, src_len] source (Urdu) token ids
        trg: [batch, trg_len] target (Roman) token ids
        teacher_forcing_ratio: per-step probability of feeding the ground-truth
            token instead of the model's own prediction

        Returns the decoder logits, [batch, trg_len, vocab_size].
        """
        start_hidden, start_cell = self._decoder_start_state(src)
        logits, _, _ = self.decoder(
            trg, start_hidden, start_cell, teacher_forcing_ratio=teacher_forcing_ratio
        )
        return logits

    def translate(self, src, sos_id, eos_id, max_len=50):
        """Greedy, free-running decoding (no teacher forcing).

        src: [batch, src_len] source token ids
        sos_id / eos_id: start / end token ids of the target vocabulary
        max_len: upper bound on generated length

        Returns the generated token ids, [batch, gen_len].
        """
        start_hidden, start_cell = self._decoder_start_state(src)
        return self.decoder.generate(start_hidden, start_cell, sos_id, eos_id, max_len=max_len)






# Hyperparameters
print(sp_ur.get_piece_size())
print(sp_en.get_piece_size())
INPUT_DIM  = sp_ur.get_piece_size()
OUTPUT_DIM = sp_en.get_piece_size()
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HIDDEN_DIM = 512
ENC_LAYERS = 2
DEC_LAYERS = 4
ENC_DROPOUT = 0.3
DEC_DROPOUT = 0.3
LEARNING_RATE = 5e-4
N_EPOCHS = 100
CLIP = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(
    vocab_size=INPUT_DIM,
    embed_dim=ENC_EMB_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=ENC_LAYERS,
    # Use the id the dataset actually pads source rows with (src_sp.pad_id())
    # instead of silently relying on the constructor default of 0.
    pad_idx=sp_ur.pad_id(),
    dropout=ENC_DROPOUT
)


# DEC_LAYERS > ENC_LAYERS: init_hidden_from_encoder starts the extra (top)
# decoder layers from zero states.
decoder = DecoderLSTM(
    vocab_size=OUTPUT_DIM,
    embed_dim=DEC_EMB_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=DEC_LAYERS,
    # Same pad id the target sequences are padded with and the loss ignores.
    pad_idx=PAD_ID,
    dropout=DEC_DROPOUT,
    encoder_bidirectional=True
)


model = Seq2Seq(encoder, decoder, DEVICE).to(DEVICE)

print(model)
print(PAD_ID)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
# Padded target positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)


def train_one_epoch(model, dataloader, optimizer, criterion, clip):
    """Run one optimisation pass over *dataloader* and return the mean batch loss.

    Each batch is moved to the module-level DEVICE and decoded with the model's
    default teacher-forcing ratio. Gradients are clipped to a global norm of
    *clip* before every optimizer step.
    """
    model.train()
    running_loss = 0

    for src_batch, trg_batch in dataloader:
        src_batch, trg_batch = src_batch.to(DEVICE), trg_batch.to(DEVICE)

        optimizer.zero_grad()
        logits = model(src_batch, trg_batch)
        vocab_size = logits.shape[-1]

        # Step 0 is never predicted (the decoder leaves logits[:, 0] as zeros and
        # trg[:, 0] is the <bos> input), so both sides drop that column.
        flat_logits = logits[:, 1:].reshape(-1, vocab_size)
        flat_targets = trg_batch[:, 1:].reshape(-1)

        batch_loss = criterion(flat_logits, flat_targets)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)


def evaluate(model, dataloader, criterion):
    """Return the mean per-batch loss of *model* on *dataloader*, without training.

    Decoding is free-running (teacher_forcing_ratio=0): every step is fed the
    model's own previous argmax prediction rather than the reference token.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for src_batch, trg_batch in dataloader:
            src_batch, trg_batch = src_batch.to(DEVICE), trg_batch.to(DEVICE)

            logits = model(src_batch, trg_batch, teacher_forcing_ratio=0)

            # Skip step 0 (<bos> input, no prediction) on both sides.
            flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
            flat_targets = trg_batch[:, 1:].reshape(-1)
            running_loss += criterion(flat_logits, flat_targets).item()

    return running_loss / len(dataloader)


def generate_example_translation(model, dataloader, src_sp, trg_sp, device):
    """Greedily translate the first example of the first batch of *dataloader*.

    Returns (source_text, reference_text, predicted_text). Pad, <bos> and <eos>
    ids are removed from each sequence before detokenizing.
    """
    model.eval()

    def _detokenize(sp, token_ids):
        special_ids = (sp.pad_id(), sp.bos_id(), sp.eos_id())
        return sp.decode([tok for tok in token_ids if tok not in special_ids])

    with torch.no_grad():
        first_src, first_trg = next(iter(dataloader))
        # Keep a batch dimension of 1 for the model.
        src = first_src[0:1].to(device)
        trg = first_trg[0:1].to(device)

        generated = model.translate(src, trg_sp.bos_id(), trg_sp.eos_id())

        return (
            _detokenize(src_sp, src[0].cpu().tolist()),
            _detokenize(trg_sp, trg[0].cpu().tolist()),
            _detokenize(trg_sp, generated[0].cpu().tolist()),
        )

best_valid_loss = float("inf")

for epoch in range(N_EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_loader, criterion)

    # Greedy translation of the first validation example, to eyeball progress.
    src_text, trg_text, gen_text = generate_example_translation(
        model, val_loader, sp_ur, sp_en, DEVICE
    )

    # Keep only the checkpoint with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "seq2seq_best.pt")

    print(
        f"Epoch: {epoch + 1:02}",
        f"\tTrain Loss: {train_loss:.3f}",
        f"\t Val. Loss: {valid_loss:.3f}",
        "\t Example:",
        f"\t   Urdu:    {src_text}",
        f"\t   Target:  {trg_text}",
        f"\t   Predicted: {gen_text}",
        "-" * 60,
        sep="\n",
    )


# Restore the best checkpoint, then score it once on the held-out test split.
model.load_state_dict(torch.load("seq2seq_best.pt"))

test_loss = evaluate(model, test_loader, criterion)
print(f"Test Loss: {test_loss:.3f}")



2000
2000
Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(2000, 256, padding_idx=0)
    (lstm): LSTM(256, 512, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
  )
  (decoder): DecoderLSTM(
    (embedding): Embedding(2000, 256, padding_idx=0)
    (lstm): LSTM(256, 512, num_layers=4, batch_first=True, dropout=0.3)
    (fc_out): Linear(in_features=512, out_features=2000, bias=True)
    (reduce_hidden): Linear(in_features=1024, out_features=512, bias=True)
    (reduce_cell): Linear(in_features=1024, out_features=512, bias=True)
  )
)
0
Epoch: 01
	Train Loss: 5.839
	 Val. Loss: 5.760
	 Example:
	   Urdu:    غرور زہد نے سکھلا دیا ہے واعظ کو
	   Target:  ghurur-e-zohd ne sikhla diya hai vaaiz ko
	   Predicted: ye------------------
------------------------------------------------------------
Epoch: 02
	Train Loss: 5.526
	 Val. Loss: 5.658
	 Example:
	   Urdu:    غرور زہد نے سکھلا دیا ہے واعظ کو
	   Target:  ghurur-e-zohd ne sikhla diya hai vaaiz ko
	   Predicted: va 