In [9]:
source = "I am a boy."
target = "I ch bin ein Junge."

# -----------
#Encoder ["I am a boy."] ->
# Next Token Prediction
# 1. step Decoder [h, <bos>] -> "ich"
# 2. step Decoder [h, (<bos>, "ich")] -> "bin"
# 3. step Decoder [h, (<bos>, "ich", "bin")] -> "ein"
# 4. step Decoder [h, (<bos>, "ich", "bin", "ein")] -> "Junge"
# 5. step Decoder [h, (<bos>, "ich", "bin", "ein", "Junge")] -> "<eos>"

# X = (h, (<bos>, "ich", "bin", "ein", "Junge" )) -> Input for Decoder
# y = ("ich", "bin", "ein", "Junge", <eos>) -> target/labels for loss

In [10]:
import re, json, torch, torch.nn as nn
from torch.utils.data import DataLoader
 
path = "./deu.txt"
 
lines = open(path, encoding="utf-8").read().strip().split("\n")
lines = lines[:20000]
 
pairs = [ln.split("\t")[:2] for ln in lines] 
src_texts, tgt_texts = zip(*pairs)

In [11]:
PAD, UNK, BOS, EOS = 0, 1, 2, 3 # special tokens
# PAD = Padding, UNK = Unknown
# BOS, EOS

VOCAB_SIZE = 20004

def tokenize(s): return re.findall(r"\b\w+\b", s.lower())
def build_vocab(texts, max_tokens=VOCAB_SIZE):
    from collections import Counter
    freq = Counter(tok for t in texts for tok in tokenize(t))
    itos = ["<pad>", "<unk>", "<bos>", "<eos>"] + [w for w,_ in freq.most_common(max_tokens-4)]
    return {w:i for i,w in enumerate(itos)}, itos
src_texts_vocab, src_itos = build_vocab(src_texts)
tgt_texts_vocab, tgt_itos = build_vocab(tgt_texts)


def vectorize(text, stoi, max_len, add_bos_eos=False):
    ids = [stoi.get(tok, UNK) for tok in tokenize(text)]
    if add_bos_eos: ids = [BOS] + ids + [EOS]
    ids = ids[:max_len]
    if len(ids) < max_len: ids += [PAD]*(max_len-len(ids))
    return ids

vectorize(src_texts[60], src_texts_vocab, 30)


max_src, max_tgt = 30, 30  
 
def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src = torch.tensor([vectorize(t, src_texts_vocab, max_src) for t in src_batch])
    tgt = torch.tensor([vectorize(t, tgt_texts_vocab, max_tgt, add_bos_eos=True) for t in tgt_batch])
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
    return src, tgt_in, tgt_out
 
dataset = list(zip(src_texts, tgt_texts))
loader = DataLoader(dataset, batch_size= 64, shuffle=True, collate_fn=collate_fn)

In [12]:
len(src_texts_vocab) #number of 
len(tgt_texts_vocab) #number of tokens

5594

In [13]:
sentence = "this is sample sentence of embedding"
setence2 = "this is sentence embedding" 


dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
dc

{'embedding': 0, 'is': 1, 'of': 2, 'sample': 3, 'sentence': 4, 'this': 5}

In [14]:
vocab_size_tmp = len(dc)
emb = torch.nn.Embedding(vocab_size_tmp, 3)
emb.weight.data

tensor([[ 0.8064, -2.6200, -1.1094],
        [ 0.6000,  0.7072,  1.1946],
        [-0.6399, -1.8828,  1.0682],
        [-0.3114, -0.7501,  0.4061],
        [ 0.0199,  0.5681, -0.1465],
        [ 0.6575,  0.0779,  1.1704]])

In [17]:
emb_dim = 128
hid_dim = 256

class Encoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD)
        self.add_module("rnn", nn.GRU(emb_dim,hid_dim,batch_first=True))
    def forward(self, src):
        x = self.embedding(src)
        _, hidden = self.rnn(x)
        return hidden
class Decoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD)
        self.add_module("rnn", nn.GRU(emb_dim,hid_dim,batch_first=True))
        self.fc = nn.Linear(hid_dim, vocab_size)
    def forward(self, x, h):
        x = self.embedding(x)
        out, _= self.rnn(x,h)
        return self.fc(out)
        
class Seq25Seq(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
    
    def forward(self, src, tgt_in_dec):
        # src ... source (english sentences)
        # tgt_in_dec ... actual german sentences that are also input to the decoder
        hidden_enc = self.enc(src)
        logits = self.dec(tgt_in_dec, hidden_enc)
        return logits
    
device = "mps"

model = Seq25Seq(
    Encoder(len(src_texts_vocab)),
    Decoder(len(tgt_texts_vocab))
).to(device)
model


Seq25Seq(
  (enc): Encoder(
    (embedding): Embedding(3455, 128, padding_idx=0)
    (rnn): GRU(128, 256, batch_first=True)
  )
  (dec): Decoder(
    (embedding): Embedding(5594, 128, padding_idx=0)
    (rnn): GRU(128, 256, batch_first=True)
    (fc): Linear(in_features=256, out_features=5594, bias=True)
  )
)

In [18]:
crit = nn.CrossEntropyLoss(ignore_index=PAD) #padding is an artificial token and it keeps the sequence long
optimizer = torch.optim.Adam(model.parameters(), lr= 1e-3)
epochs = 20

@torch.no_grad()
def translate(prompt, max_len=max_tgt):
    model.eval()
    src = torch.tensor([vectorize(prompt, src_texts_vocab, max_src)], device=device)
    h = model.enc(src)
    ys = torch.tensor([[BOS]], device=device)
    out_tokens = []
    for _ in range(max_len):
        logits = model.dec(ys, h)
        next_id = logits[0, -1].argmax().item()
        if next_id in (EOS, PAD): break
        out_tokens.append(next_id)
        ys = torch.cat([ys, torch.tensor([[next_id]], device=device)], dim=1)
    return " ".join(tgt_itos[t] for t in out_tokens)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for src, tgt_in, tgt_out in loader:
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device) 
        logits = model(src,tgt_in)
        loss = crit(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch+1}: loss {running_loss/len(loader):.4f}")
    print(translate("I will do my best."))


epoch 1: loss 4.4135
ich bin ein buch
epoch 2: loss 3.2131
ich habe das gesagt
epoch 3: loss 2.6003
ich habe mich geirrt
epoch 4: loss 2.1796
ich habe mein buch
epoch 5: loss 1.8581
ich werde mein auto holen
epoch 6: loss 1.5962
ich werde mein auto holen
epoch 7: loss 1.3697
ich werde mein bestes
epoch 8: loss 1.1729
ich werde mein bestes
epoch 9: loss 1.0044
ich werde mein bestes
epoch 10: loss 0.8599
ich werde mein bestes
epoch 11: loss 0.7351
ich werde mein bestes
epoch 12: loss 0.6337
ich bin mein eigener chef
epoch 13: loss 0.5480
ich werde mein bestes
epoch 14: loss 0.4832
ich werde mein bestes
epoch 15: loss 0.4295
ich werde mein bestes geben
epoch 16: loss 0.3865
ich werde mein bestes
epoch 17: loss 0.3520
ich werde mein bestes geben
epoch 18: loss 0.3272
ich bin mein eigener chef
epoch 19: loss 0.3077
ich bin wie in ordnung
epoch 20: loss 0.2923
ich werde mein bestes geben
