<a href="https://colab.research.google.com/github/Calcifer777/learn-nlp/blob/main/custom_nmt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [235]:
%matplotlib inline

In [236]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
# NOTE(review): no visible cell moves a tensor or module to `device` — confirm it is used further down.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data

## Dataset

In [237]:
%%bash

# Fetch the English-French sentence pairs once; later runs reuse the local copy.
set -euo pipefail

mkdir -p data
if [ ! -s data/eng-fra.txt ]; then
  # Download under a temporary name so an interrupted transfer never leaves a
  # truncated data/eng-fra.txt behind.
  wget -q https://raw.githubusercontent.com/L1aoXingyu/seq2seq-translation/master/data/eng-fra.txt -O data/eng-fra.txt.part
  mv data/eng-fra.txt.part data/eng-fra.txt
fi

In [238]:
import unicodedata
import re

from torch.utils.data import Dataset, DataLoader

In [239]:
def unicodeToAscii(s):
    """Return `s` with combining marks removed.

    The string is NFD-decomposed and every character whose Unicode category
    is 'Mn' is dropped; all other characters are kept as they are.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Lowercase, trim, strip accents, pad .!? with a space, and collapse every
# other run of non-letter characters into a single space
def normalizeString(s):
    """Normalize one sentence for tokenization.

    Steps, in order: lowercase and strip surrounding whitespace, remove
    combining marks via `unicodeToAscii`, insert a space before each of
    `.`, `!`, `?`, then replace each run of characters outside
    `a-zA-Z.!?` with one space. The result is not stripped again, so it
    can start or end with a space.
    """
    ascii_text = unicodeToAscii(s.lower().strip())
    spaced = re.sub(r"([.!?])", r" \1", ascii_text)
    return re.sub(r"[^a-zA-Z.!?]+", r" ", spaced)

In [240]:
class EnFrDataset(Dataset):
  """Normalized sentence pairs read from a tab-separated text file.

  Each item is a two-element list `[src, tgt]` of strings that went through
  `normalizeString`.
  """

  def __init__(self, path: str, max_chars: int = 50):
    """Read `path` and keep the pairs whose two sides are both short enough.

    Args:
      path: UTF-8 text file with one pair per line, the two fields separated
        by a tab. Lines that do not split into exactly two fields are ignored.
      max_chars: exclusive upper bound on the normalized length, in
        characters, of each side of a pair. Defaults to 50.
    """
    self.lines = []
    with open(path, "r", encoding="utf-8") as fp:
      # Stream the file and split every line only once.
      for raw_line in fp:
        parts = raw_line.split("\t")
        if len(parts) != 2:
          continue
        pair = [normalizeString(part.strip()) for part in parts]
        if all(len(text) < max_chars for text in pair):
          self.lines.append(pair)
    self.fields = ["src", "tgt"]

  def __len__(self):
    """Number of retained sentence pairs."""
    return len(self.lines)

  def __getitem__(self, idx):
    """Return the pair stored at position `idx` as `[src, tgt]`."""
    return self.lines[idx]


In [241]:
# Load and normalize the sentence pairs from the file fetched into data/ above.
ds = EnFrDataset(path="data/eng-fra.txt")

In [242]:
# Peek at one arbitrary pair to sanity-check the normalization.
ds[50003]

['she had a new dress made .', 'elle s est fait faire une nouvelle robe .']

## Data Loader

In [243]:
!pip install -q transformers

In [244]:
from transformers import DistilBertTokenizer 
from functools import partial

In [245]:
# Number of sentence pairs per mini-batch produced by the DataLoader.
BATCH_SIZE = 8

In [246]:
# One pretrained tokenizer is shared by both languages (collate_fn applies it to
# source and target alike), so both sides use the same token-id vocabulary.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [247]:
def collate_fn(samples):
  """Turn a list of `[src, tgt]` string pairs into padded token-id batches.

  The pairs are transposed into one list of source sentences and one list of
  target sentences; each list is encoded by the module-level `tokenizer` with
  `return_tensors="pt"` and `padding=True`.

  Returns:
    dict with keys "src_tokens" and "tgt_tokens", each holding the
    `input_ids` of the corresponding encoding.
  """
  columns = [list(column) for column in zip(*samples)]

  def encode(texts):
    encoding = tokenizer.batch_encode_plus(texts, return_tensors="pt", padding=True)
    return encoding["input_ids"]

  return {
      "src_tokens": encode(columns[0]),
      "tgt_tokens": encode(columns[1]),
  }

In [248]:
# Batches of BATCH_SIZE pairs in shuffled order; collate_fn tokenizes and pads each batch.
# NOTE(review): no random seed is set in the visible cells, so batch contents differ between runs.
dl = DataLoader(
    dataset=ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn
)

In [249]:
# Grab a single batch to smoke-test the model components defined below.
sample_batch = next(iter(dl))

# Models

In [250]:
import random

from torch import nn
from torch.nn import functional as F

In [251]:
# Width passed as `hidden_size` to the encoder, decoder and seq2seq models below.
HIDDEN_SIZE = 128

In [252]:
class Encoder(nn.Module):
    """Embeds source token ids and runs them through a single-layer GRU."""

    def __init__(self, input_size, hidden_size):
        """
        Args:
            input_size: number of rows in the embedding table (vocabulary size).
            hidden_size: embedding width and GRU hidden width.
        """
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, input, hidden=None):
        """Encode a batch of token ids.

        Args:
            input: integer tensor of shape (batch, T_src).
            hidden: optional initial GRU state of shape (1, batch, H);
                an all-zero state is used when omitted.

        Returns:
            dict with "output" (batch, T_src, H) and "hidden" (1, batch, H).
        """
        if hidden is None:
            hidden = self.init_hidden(input.shape[0])
        o, h = self.gru(self.embedding(input), hidden)
        return {"output": o, "hidden": h}

    def init_hidden(self, batch_size, device=None):
        """Return an all-zero initial GRU state of shape (1, batch_size, H).

        The state takes the dtype of the embedding weights and, unless
        `device` is given, their device, so it stays usable after the module
        has been moved with `.to(...)`, `.cuda()` or `.double()`.
        """
        weight = self.embedding.weight
        return torch.zeros(
            1,
            batch_size,
            self.hidden_size,
            dtype=weight.dtype,
            device=weight.device if device is None else device,
        )

In [253]:
# Standalone encoder instance used only for the shape check in the next cell.
encoder = Encoder(input_size=tokenizer.vocab_size, hidden_size=HIDDEN_SIZE)

In [254]:
# Smoke test: encode the sample batch and print the resulting shapes.
enc_out = encoder(
    input=sample_batch["src_tokens"],
    # `hidden` is omitted on purpose: the encoder then starts from encoder.init_hidden(...).
)
print(f"{enc_out['output'].shape = }")
print(f"{enc_out['hidden'].shape = }")

enc_out['output'].shape = torch.Size([8, 11, 128])
enc_out['hidden'].shape = torch.Size([1, 8, 128])


In [255]:
class Decoder(nn.Module):
    """One-step GRU decoder that attends over the encoder outputs.

    At every call the previous hidden state queries the encoder outputs; the
    attended context is fused with the embedding of the current input token
    and fed to the GRU, whose output is projected to log-probabilities.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """
        Args:
            input_size: number of rows in the embedding table.
            hidden_size: width H of embeddings, attention and GRU state.
            output_size: size of the output vocabulary.
        """
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size

        # Creation order is kept as-is: it decides how the RNG is consumed
        # during parameter initialization.
        self.mha = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=1, batch_first=True)
        self.attn_combine = nn.Linear(hidden_size*2, hidden_size)
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, enc_output):
        """Run a single decoding step.

        Args:
            input: token ids for the current step, shape (batch, 1).
            hidden: previous GRU state, shape (1, batch, H).
            enc_output: encoder outputs, shape (batch, T_src, H).

        Returns:
            dict with "output" (log-probabilities, (batch, 1, output_size)),
            "hidden" ((1, batch, H)) and "attention_weights"
            ((batch, 1, T_src)).
        """
        # token_vecs: (batch, 1, H)
        token_vecs = self.embedding(input)

        # The previous state, made batch-first, is the attention query.
        # context: (batch, 1, H); attn_weights: (batch, 1, T_src)
        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.mha(query=query, key=enc_output, value=enc_output)

        # Fuse the token embedding with the attended context: (batch, 1, H)
        fused = F.relu(self.attn_combine(torch.cat([token_vecs, context], dim=-1)))

        # step_out: (batch, 1, H); next_hidden: (1, batch, H)
        step_out, next_hidden = self.gru(fused, hidden)

        # log_probs: (batch, 1, output_size)
        log_probs = F.log_softmax(self.out(step_out), dim=-1)

        return dict(
            output=log_probs,
            hidden=next_hidden,
            attention_weights=attn_weights,
        )

In [256]:
# Standalone decoder instance used only for the shape check in the next cell.
decoder = Decoder(
    input_size=tokenizer.vocab_size,
    hidden_size=HIDDEN_SIZE,
    output_size=tokenizer.vocab_size,
)

In [257]:
# Smoke test: run one decoding step on the first target token of each sample.
decoder_out = decoder(
    input=sample_batch["tgt_tokens"][:, :1],
    # Random stand-in for the previous decoder state.
    hidden=torch.rand(1, BATCH_SIZE, HIDDEN_SIZE),
    enc_output=enc_out['output'],
)
print(f"{decoder_out['output'].shape = }")
print(f"{decoder_out['hidden'].shape = }")
print(f"{decoder_out['attention_weights'].shape = }")

decoder_out['output'].shape = torch.Size([8, 1, 30522])
decoder_out['hidden'].shape = torch.Size([1, 8, 128])
decoder_out['attention_weights'].shape = torch.Size([8, 1, 11])


In [378]:
class Seq2Seq(nn.Module):
  """Encoder-decoder wrapper that decodes one target position per step."""

  def __init__(self, input_size, hidden_size, output_size, teacher_forcing_ratio=0.5, max_length=50, cls_token=100):
    """
    Args:
      input_size: vocabulary size for the encoder and decoder embeddings.
      hidden_size: width shared by the encoder and the decoder.
      output_size: size of the decoder's output vocabulary.
      teacher_forcing_ratio: probability, drawn once per `forward` call in
        training mode, of feeding the reference tokens to the decoder instead
        of its own predictions.
      max_length: number of decoding steps when `forward` is called without
        `input_dec`.
      cls_token: token id fed to the decoder at the first step.
    """
    super(Seq2Seq, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size

    self.encoder = Encoder(input_size=input_size, hidden_size=hidden_size)
    self.decoder = Decoder(input_size=input_size, hidden_size=hidden_size, output_size=output_size)

    self.teacher_forcing_ratio = teacher_forcing_ratio
    self.max_length = max_length
    # NOTE(review): in the distilbert-base-uncased vocabulary id 100 looks like
    # [UNK] and [CLS] is 101 — confirm the intended start token.
    self.cls_token = cls_token

  def forward(self, input_enc, input_dec=None):
    """Encode `input_enc` and decode step by step.

    Args:
      input_enc: source token ids, shape (batch, T_src).
      input_dec: optional target token ids, shape (batch, T_tgt). They set the
        number of decoding steps and supply the teacher-forced inputs. When
        omitted, decoding is greedy for `max_length` steps.

    Returns:
      dict with "output" (batch, T_tgt, output_size) log-probabilities and
      "attention_weights" (batch, T_tgt, T_src).
    """
    enc_out = self.encoder(input_enc)

    # Teacher forcing is a training-time aid and needs reference tokens.
    use_teacher_forcing = (
        self.training
        and input_dec is not None
        and random.random() < self.teacher_forcing_ratio
    )

    # Build the start tokens on the same device as the inputs.
    batch_size = input_enc.shape[0]
    dec_in = torch.full(
        (batch_size, 1), self.cls_token, dtype=torch.long, device=input_enc.device
    )
    dec_h = enc_out["hidden"]

    target_length = self.max_length if input_dec is None else input_dec.shape[1]

    outputs = []
    attention_weights = []
    for idx in range(target_length):
      dec_out = self.decoder(
        input=dec_in,
        hidden=dec_h,
        enc_output=enc_out["output"],
      )
      dec_h = dec_out["hidden"]
      outputs.append(dec_out["output"])
      attention_weights.append(dec_out["attention_weights"])

      if use_teacher_forcing:
        # The next input is the reference token of the position just predicted.
        dec_in = input_dec[:, idx].unsqueeze(1)
      else:
        # Greedy decoding: feed back the most likely token, shape (batch, 1).
        dec_in = dec_out["output"].argmax(dim=-1).detach()

    return {
        # batch, T_tgt, vocab_size_tgt
        "output": torch.cat(outputs, dim=1),
        # batch, T_tgt, T_src
        "attention_weights": torch.cat(attention_weights, dim=1),
    }
    


In [379]:
# Full model: source and target share the tokenizer vocabulary.
seq2seq = Seq2Seq(
     input_size=tokenizer.vocab_size,
     hidden_size=HIDDEN_SIZE,
     output_size=tokenizer.vocab_size,
)
# Debugging alias: lets method-body snippets that reference `self` run at cell level.
self = seq2seq

In [380]:
input_enc = sample_batch["src_tokens"]
input_dec = sample_batch["tgt_tokens"]

# The encoder returns a dict; tuple-unpacking it would bind the key strings
# ("output", "hidden") instead of the tensors, so index it explicitly.
# Note that this rebinds `enc_out` to the encoder output tensor.
enc_result = seq2seq.encoder(input_enc)
enc_out, enc_hidden = enc_result["output"], enc_result["hidden"]

# End-to-end forward pass on the sample batch.
seq2seq_out = seq2seq(
    input_enc=input_enc,
    input_dec=input_dec,
)

In [381]:
# Expected layout: (batch, T_tgt, vocab_size).
seq2seq_out["output"].shape

torch.Size([8, 17, 30522])