In [1]:
import random
import os
import re
import unicodedata
import zipfile
 
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tokenizers
import tqdm

In [2]:
import requests

url = "https://www.manythings.org/anki/pol-eng.zip"
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                  "AppleWebKit/537.36 (KHTML, like Gecko) "
                  "Chrome/126.0 Safari/537.36"
}

r = requests.get(url, headers=headers)
r.raise_for_status()

with open("pol-eng.zip", "wb") as f:
    f.write(r.content)


In [3]:

import os
import unicodedata
import zipfile
import requests

url = "https://www.manythings.org/anki/pol-eng.zip"
headers = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                  "AppleWebKit/537.36 (KHTML, like Gecko) "
                  "Chrome/126.0 Safari/537.36"
}

r = requests.get(url, headers=headers)
r.raise_for_status()

with open("pol-eng.zip", "wb") as f:
    f.write(r.content)

In [4]:
# Normalize text
# each line of the file is in the format "<english>\t<french>"
# We convert text to lowercase, normalize unicode (UFKC)
def normalize(line):
    """Normalize a line of text and split into two at the tab character"""
    line = unicodedata.normalize("NFKC", line.strip().lower())
    parts = line.split("\t")
    if len(parts) < 2:
        return None  # pomiń niepoprawne linie

    # niektóre linie mają więcej niż jedno tłumaczenie – weź tylko pierwsze dwa
    eng, pl = parts[0], parts[1]

    return eng.strip(), pl.strip()

text_pairs = []
with zipfile.ZipFile("pol-eng.zip", "r") as zip_ref:
    for line in zip_ref.read("pol.txt").decode("utf-8").splitlines():
        eng, pol = normalize(line)
        text_pairs.append((eng, pol))

In [5]:
import tokenizers
 
if os.path.exists("en_tokenizer.json") and os.path.exists("pl_tokenizer.json"):
    en_tokenizer = tokenizers.Tokenizer.from_file("en_tokenizer.json")
    pl_tokenizer = tokenizers.Tokenizer.from_file("pl_tokenizer.json")
else:
    en_tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())
    pl_tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE())
 
    # Configure pre-tokenizer to split on whitespace and punctuation, add space at beginning of the sentence
    en_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)
    pl_tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.ByteLevel(add_prefix_space=True)
 
    # Configure decoder: So that word boundary symbol "Ġ" will be removed
    en_tokenizer.decoder = tokenizers.decoders.ByteLevel()
    pl_tokenizer.decoder = tokenizers.decoders.ByteLevel()
 
    # Train BPE for English and French using the same trainer
    VOCAB_SIZE = 8000
    trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=["[start]", "[end]", "[pad]"],
        show_progress=True
    )
    en_tokenizer.train_from_iterator([x[0] for x in text_pairs], trainer=trainer)
    pl_tokenizer.train_from_iterator([x[1] for x in text_pairs], trainer=trainer)
 
    en_tokenizer.enable_padding(pad_id=en_tokenizer.token_to_id("[pad]"), pad_token="[pad]")
    pl_tokenizer.enable_padding(pad_id=pl_tokenizer.token_to_id("[pad]"), pad_token="[pad]")
 
    # Save the trained tokenizers
    en_tokenizer.save("en_tokenizer.json", pretty=True)
    pl_tokenizer.save("pl_tokenizer.json", pretty=True)

In [7]:
def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    return (x * cos) + (rotate_half(x) * sin)

class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        N = 10000
        inv_freq = 1. / (N ** (torch.arange(0, dim, 2).float() / dim))
        position = torch.arange(max_seq_len).float()
        inv_freq = torch.cat((inv_freq, inv_freq), dim=-1)
        sinusoid_inp = torch.outer(position, inv_freq)
        self.register_buffer("cos", sinusoid_inp.cos())
        self.register_buffer("sin", sinusoid_inp.sin())

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.size(1)
        cos = self.cos[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin[:seq_len].view(1, seq_len, 1, -1)
        return apply_rotary_pos_emb(x, cos, sin)

In [8]:
class GQA(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads=None, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = hidden_dim // num_heads
        self.num_groups = num_heads // num_kv_heads
        self.dropout = dropout
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, q, k, v, mask=None, rope=None):
        q_batch_size, q_seq_len, hidden_dim = q.shape
        k_batch_size, k_seq_len, hidden_dim = k.shape
        v_batch_size, v_seq_len, hidden_dim = v.shape

        # projection
        q = self.q_proj(q).view(q_batch_size, q_seq_len, -1, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(k_batch_size, k_seq_len, -1, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(v_batch_size, v_seq_len, -1, self.head_dim).transpose(1, 2)

        # apply rotary positional encoding
        if rope:
            q = rope(q)
            k = rope(k)

        # compute grouped query attention
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        output = F.scaled_dot_product_attention(q, k, v,
                                                attn_mask=mask,
                                                dropout_p=self.dropout,
                                                enable_gqa=True)
        output = output.transpose(1, 2).reshape(q_batch_size, q_seq_len, hidden_dim).contiguous()
        output = self.out_proj(output)
        return output

In [9]:
class SwiGLU(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, intermediate_dim)
        self.up = nn.Linear(hidden_dim, intermediate_dim)
        self.down = nn.Linear(intermediate_dim, hidden_dim)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.act(self.gate(x)) * self.up(x)
        x = self.down(x)
        return x

In [10]:
class EncoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads=None, dropout=0.1):
        super().__init__()
        self.self_attn = GQA(hidden_dim, num_heads, num_kv_heads, dropout)
        self.mlp = SwiGLU(hidden_dim, 4 * hidden_dim)
        self.norm1 = nn.RMSNorm(hidden_dim)
        self.norm2 = nn.RMSNorm(hidden_dim)

    def forward(self, x, mask=None, rope=None):
        # self-attention sublayer
        out = x
        out = self.norm1(x)
        out = self.self_attn(out, out, out, mask, rope)
        x = out + x
        # MLP sublayer
        out = self.norm2(x)
        out = self.mlp(out)
        return out + x

In [11]:
class DecoderLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, num_kv_heads=None, dropout=0.1):
        super().__init__()
        self.self_attn = GQA(hidden_dim, num_heads, num_kv_heads, dropout)
        self.cross_attn = GQA(hidden_dim, num_heads, num_kv_heads, dropout)
        self.mlp = SwiGLU(hidden_dim, 4 * hidden_dim)
        self.norm1 = nn.RMSNorm(hidden_dim)
        self.norm2 = nn.RMSNorm(hidden_dim)
        self.norm3 = nn.RMSNorm(hidden_dim)

    def forward(self, x, enc_out, mask=None, rope=None):
        # self-attention sublayer
        out = x
        out = self.norm1(out)
        out = self.self_attn(out, out, out, mask, rope)
        x = out + x
        # cross-attention sublayer
        out = self.norm2(x)
        out = self.cross_attn(out, enc_out, enc_out, None, rope)
        x = out + x
        # MLP sublayer
        x = out + x
        out = self.norm3(x)
        out = self.mlp(out)
        return out + x

In [12]:
class Transformer(nn.Module):
    def __init__(self, num_layers, num_heads, num_kv_heads, hidden_dim,
                 max_seq_len, vocab_size_src, vocab_size_tgt, dropout=0.1):
        super().__init__()
        self.rope = RotaryPositionalEncoding(hidden_dim // num_heads, max_seq_len)
        self.src_embedding = nn.Embedding(vocab_size_src, hidden_dim)
        self.tgt_embedding = nn.Embedding(vocab_size_tgt, hidden_dim)
        self.encoders = nn.ModuleList([
            EncoderLayer(hidden_dim, num_heads, num_kv_heads, dropout) for _ in range(num_layers)
        ])
        self.decoders = nn.ModuleList([
            DecoderLayer(hidden_dim, num_heads, num_kv_heads, dropout) for _ in range(num_layers)
        ])
        self.out = nn.Linear(hidden_dim, vocab_size_tgt)

    def forward(self, src_ids, tgt_ids, src_mask=None, tgt_mask=None):
        # Encoder
        x = self.src_embedding(src_ids)
        for encoder in self.encoders:
            x = encoder(x, src_mask, self.rope)
        enc_out = x
        # Decoder
        x = self.tgt_embedding(tgt_ids)
        for decoder in self.decoders:
            x = decoder(x, enc_out, tgt_mask, self.rope)
        return self.out(x)

In [13]:
model_config = {
    "num_layers": 4,
    "num_heads": 8,
    "num_kv_heads": 4,
    "hidden_dim": 128,
    "max_seq_len": 768,
    "vocab_size_src": len(en_tokenizer.get_vocab()),
    "vocab_size_tgt": len(pl_tokenizer.get_vocab()),
    "dropout": 0.1,
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Transformer(**model_config).to(device)

In [14]:
import torch
from torch.utils.data import Dataset, DataLoader

class TranslationDataset(torch.utils.data.Dataset):
    def __init__(self, text_pairs):
        self.text_pairs = text_pairs

    def __len__(self):
        return len(self.text_pairs)

    def __getitem__(self, idx):
        eng, pol = self.text_pairs[idx]
        return eng, "[start] " + pol + " [end]"


def collate_fn(batch):
    en_str, pl_str = zip(*batch)
    en_enc = en_tokenizer.encode_batch(en_str, add_special_tokens=True)
    pl_enc = pl_tokenizer.encode_batch(pl_str, add_special_tokens=True)
    en_ids = [enc.ids for enc in en_enc]
    pl_ids = [enc.ids for enc in pl_enc]
    return torch.tensor(en_ids), torch.tensor(pl_ids)

BATCH_SIZE = 32
dataset = TranslationDataset(text_pairs)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [15]:
for en_ids, pl_ids in dataloader:
    print(f"English: {en_ids}")
    print(f"French: {pl_ids}")
    break

English: tensor([[ 261,  286,   74, 2352,  127,   24,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2],
        [  61,  168,  111,  147,  105,  113,   11,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2],
        [  93,  273,  796,  239, 2080,  153, 3421,   11,    2,    2,    2,    2,
            2,    2,    2,    2,    2],
        [ 243,  121,   75, 4655,  305,  232,   24,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2],
        [ 114,  387,   92,  406,   11,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2],
        [ 113,  121,   64, 1164, 2778,   11,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2],
        [  80, 1343,   92,  320,  372,  114,   93,  121, 5094,   65,  105,  114,
          681,   11,    2,    2,    2],
        [ 812,  637,  127,   11,    2,    2,    2,    2,    2,    2,    2,    2,
            2,    2,    2,    2,    2],
        [  61,  142,   

In [16]:
def create_causal_mask(seq_len, device):
    mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=device), diagonal=1)
    return mask

In [17]:
def create_padding_mask(batch, padding_token_id):
    batch_size, seq_len = batch.shape
    device = batch.device
    padded = torch.zeros_like(batch, device=device).float().masked_fill(batch == padding_token_id, float('-inf'))
    mask = torch.zeros(batch_size, seq_len, seq_len, device=device) + padded[:,:,None] + padded[:,None,:]
    return mask[:, None, :, :]

In [18]:
N_EPOCHS = 60
LR = 0.005
WARMUP_STEPS = 1000
CLIP_NORM = 5.0
best_loss = float('inf')

loss_fn = nn.CrossEntropyLoss(ignore_index=pl_tokenizer.token_to_id("[pad]"))

optimizer = optim.Adam(model.parameters(), lr=LR)
warmup_scheduler = optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, end_factor=1.0, total_iters=WARMUP_STEPS)
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=N_EPOCHS * len(dataloader) - WARMUP_STEPS, eta_min=0)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[WARMUP_STEPS])

for epoch in range(N_EPOCHS):
    model.train()
    epoch_loss = 0
    for en_ids, pl_ids in dataloader:
        # Move the "sentences" to device
        en_ids = en_ids.to(device)
        pl_ids = pl_ids.to(device)
        # create source mask as padding mask, target mask as causal mask
        src_mask = create_padding_mask(en_ids, en_tokenizer.token_to_id("[pad]"))
        tgt_mask = create_causal_mask(pl_ids.shape[1], device).unsqueeze(0)
        tgt_mask = tgt_mask + create_padding_mask(pl_ids, pl_tokenizer.token_to_id("[pad]"))
        # zero the grad, then forward pass
        optimizer.zero_grad()
        outputs = model(en_ids, pl_ids, src_mask, tgt_mask)
        # compute the loss: compare 3D logits to 2D targets
        loss = loss_fn(outputs[:, :-1, :].reshape(-1, outputs.shape[-1]), pl_ids[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM, error_if_nonfinite=False)
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{N_EPOCHS}; Avg loss {epoch_loss/len(dataloader)}; Latest loss {loss.item()}")
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for en_ids, pl_ids in tqdm.tqdm(dataloader, desc="Evaluating"):
            en_ids = en_ids.to(device)
            pl_ids = pl_ids.to(device)
            src_mask = create_padding_mask(en_ids, en_tokenizer.token_to_id("[pad]"))
            tgt_mask = create_causal_mask(pl_ids.shape[1], device).unsqueeze(0) + create_padding_mask(pl_ids, pl_tokenizer.token_to_id("[pad]"))
            outputs = model(en_ids, pl_ids, src_mask, tgt_mask)
            loss = loss_fn(outputs[:, :-1, :].reshape(-1, outputs.shape[-1]), pl_ids[:, 1:].reshape(-1))
            epoch_loss += loss.item()
    print(f"Eval loss: {epoch_loss/len(dataloader)}")
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), f"transformer-epoch-{epoch+1}.pth")



Epoch 1/60; Avg loss 4.096698644128886; Latest loss 3.0250887870788574


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.79it/s]


Eval loss: 2.8337806796175045
Epoch 2/60; Avg loss 2.788626352309005; Latest loss 2.6558470726013184


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.27it/s]


Eval loss: 2.1281219177710584
Epoch 3/60; Avg loss 2.289740744560185; Latest loss 2.501023530960083


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.08it/s]


Eval loss: 1.8111035327729108
Epoch 4/60; Avg loss 1.9999193435238558; Latest loss 2.5427656173706055


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.77it/s]


Eval loss: 1.5360182169847807
Epoch 5/60; Avg loss 1.7943998007827269; Latest loss 1.8478368520736694


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.88it/s]


Eval loss: 1.3832164406041476
Epoch 6/60; Avg loss 1.6288506450400253; Latest loss 1.8708763122558594


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.02it/s]


Eval loss: 1.2374786511155444
Epoch 7/60; Avg loss 1.4774724625776787; Latest loss 1.6931941509246826


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.45it/s]


Eval loss: 1.1248808801909116
Epoch 8/60; Avg loss 1.3545997662035663; Latest loss 1.6242321729660034


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.99it/s]


Eval loss: 1.0742572821612422
Epoch 9/60; Avg loss 1.2462908121425331; Latest loss 1.088580846786499


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.28it/s]


Eval loss: 0.9992654616823914
Epoch 10/60; Avg loss 1.1521053031619761; Latest loss 1.2594029903411865


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.83it/s]


Eval loss: 0.9367452514083406
Epoch 11/60; Avg loss 1.0866035470318707; Latest loss 1.6371737718582153


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.16it/s]


Eval loss: 0.8501602898911101
Epoch 12/60; Avg loss 1.0295887278202283; Latest loss 1.3390588760375977


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.70it/s]


Eval loss: 0.8296748725685914
Epoch 13/60; Avg loss 1.0087333970166898; Latest loss 1.283947229385376


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.20it/s]


Eval loss: 0.8488006032865527
Epoch 14/60; Avg loss 0.9580731604308; Latest loss 1.2222539186477661


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.73it/s]


Eval loss: 0.7968495941448741
Epoch 15/60; Avg loss 0.9069944125876914; Latest loss 1.189395546913147


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.15it/s]


Eval loss: 0.6963220210941152
Epoch 16/60; Avg loss 0.8531025157720035; Latest loss 0.7798734903335571


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.56it/s]


Eval loss: 0.6940147786156611
Epoch 17/60; Avg loss 0.8021865700786123; Latest loss 0.9306039810180664


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.99it/s]


Eval loss: 0.5992186551984817
Epoch 18/60; Avg loss 0.7541010420213645; Latest loss 0.776816725730896


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.16it/s]


Eval loss: 0.5584236017064307
Epoch 19/60; Avg loss 0.7034157690593259; Latest loss 0.843147337436676


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.95it/s]


Eval loss: 0.5554019195451395
Epoch 20/60; Avg loss 0.6643437010990559; Latest loss 0.5714995861053467


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.16it/s]


Eval loss: 0.5245650897006954
Epoch 21/60; Avg loss 0.6192800577743721; Latest loss 0.6121620535850525


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.95it/s]


Eval loss: 0.47248147761520415
Epoch 22/60; Avg loss 0.5773843680812751; Latest loss 0.7021769881248474


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.81it/s]


Eval loss: 0.46768921340842134
Epoch 23/60; Avg loss 0.538996307572072; Latest loss 0.6130213141441345


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.46it/s]


Eval loss: 0.45883150038604525
Epoch 24/60; Avg loss 0.504644463594392; Latest loss 0.6355393528938293


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.55it/s]


Eval loss: 0.40750095187588775
Epoch 25/60; Avg loss 0.4751349442010302; Latest loss 0.658591628074646


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.19it/s]


Eval loss: 0.37189820975653193
Epoch 26/60; Avg loss 0.4405467912609862; Latest loss 0.5668201446533203


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.66it/s]


Eval loss: 0.33505229517799123
Epoch 27/60; Avg loss 0.4164567585441981; Latest loss 0.6039367914199829


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.11it/s]


Eval loss: 0.3402984551855727
Epoch 28/60; Avg loss 0.39434592704186455; Latest loss 0.6600249409675598


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.54it/s]


Eval loss: 0.3062501291384973
Epoch 29/60; Avg loss 0.364822095092363; Latest loss 0.3183078169822693


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.01it/s]


Eval loss: 0.28139396769749103
Epoch 30/60; Avg loss 0.33754022480086243; Latest loss 0.7558216452598572


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.71it/s]


Eval loss: 0.26100009169461696
Epoch 31/60; Avg loss 0.3144687582060015; Latest loss 0.3184591233730316


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.27it/s]


Eval loss: 0.24360140074049058
Epoch 32/60; Avg loss 0.2869799555349365; Latest loss 0.30393359065055847


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.48it/s]


Eval loss: 0.2249272644712731
Epoch 33/60; Avg loss 0.26516515095563176; Latest loss 0.24410958588123322


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.22it/s]


Eval loss: 0.21952540115550914
Epoch 34/60; Avg loss 0.24535300271670563; Latest loss 0.18771925568580627


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.60it/s]


Eval loss: 0.2150392374471657
Epoch 35/60; Avg loss 0.23404450048996842; Latest loss 0.34812915325164795


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.15it/s]


Eval loss: 0.17556117159462592
Epoch 36/60; Avg loss 0.20910632311746608; Latest loss 0.1159706711769104


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.13it/s]


Eval loss: 0.17560243716121235
Epoch 37/60; Avg loss 0.20041046417909691; Latest loss 0.42961224913597107


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.36it/s]


Eval loss: 0.15469145014395183
Epoch 38/60; Avg loss 0.18037113707177296; Latest loss 0.2432195544242859


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.06it/s]


Eval loss: 0.13748216278547865
Epoch 39/60; Avg loss 0.16222046116773503; Latest loss 0.14209596812725067


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.75it/s]


Eval loss: 0.13211829097466096
Epoch 40/60; Avg loss 0.15025901100592534; Latest loss 0.11331888288259506


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.23it/s]


Eval loss: 0.11422492249762012
Epoch 41/60; Avg loss 0.13644030034707447; Latest loss 0.18769453465938568


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.95it/s]


Eval loss: 0.10588729173477864
Epoch 42/60; Avg loss 0.12541529278554545; Latest loss 0.2878486216068268


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.14it/s]


Eval loss: 0.09353603049997888
Epoch 43/60; Avg loss 0.114422532876795; Latest loss 0.057775065302848816


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.07it/s]


Eval loss: 0.0901099546041923
Epoch 44/60; Avg loss 0.10552363240753752; Latest loss 0.19279521703720093


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.92it/s]


Eval loss: 0.08391413144978584
Epoch 45/60; Avg loss 0.0954700652659215; Latest loss 0.12453530728816986


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.93it/s]


Eval loss: 0.0759484096040947
Epoch 46/60; Avg loss 0.08757146469026841; Latest loss 0.0896192267537117


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.38it/s]


Eval loss: 0.06742265241785589
Epoch 47/60; Avg loss 0.08056822297095555; Latest loss 0.025572726503014565


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.27it/s]


Eval loss: 0.06169689109183747
Epoch 48/60; Avg loss 0.07317174340480004; Latest loss 0.11764517426490784


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.93it/s]


Eval loss: 0.05825147658222471
Epoch 49/60; Avg loss 0.06772384938523265; Latest loss 0.1727825552225113


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.72it/s]


Eval loss: 0.05506358880081137
Epoch 50/60; Avg loss 0.06165326463687894; Latest loss 0.019851485267281532


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.73it/s]


Eval loss: 0.04897872859763691
Epoch 51/60; Avg loss 0.05755186770582151; Latest loss 0.05499405413866043


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.05it/s]


Eval loss: 0.045393065728864986
Epoch 52/60; Avg loss 0.05193503073482107; Latest loss 0.0427817665040493


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.52it/s]


Eval loss: 0.041708463221861905
Epoch 53/60; Avg loss 0.0480362647735547; Latest loss 0.0884770080447197


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.18it/s]


Eval loss: 0.039633023814710476
Epoch 54/60; Avg loss 0.044817577740900984; Latest loss 0.0705394372344017


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.18it/s]


Eval loss: 0.03692043560635156
Epoch 55/60; Avg loss 0.04154791339684126; Latest loss 0.04127172380685806


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.90it/s]


Eval loss: 0.036458278581862937
Epoch 56/60; Avg loss 0.039173656615799544; Latest loss 0.012574691325426102


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 116.66it/s]


Eval loss: 0.034782544891039914
Epoch 57/60; Avg loss 0.037184714837776064; Latest loss 0.06925081461668015


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.95it/s]


Eval loss: 0.034217104987263144
Epoch 58/60; Avg loss 0.03531523812051206; Latest loss 0.025644326582551003


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.12it/s]


Eval loss: 0.03411191303263454
Epoch 59/60; Avg loss 0.034735216888623216; Latest loss 0.006583702750504017


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 117.30it/s]


Eval loss: 0.033753974537184844
Epoch 60/60; Avg loss 0.033859962240922574; Latest loss 0.03819034621119499


Evaluating: 100%|██████████| 1622/1622 [00:13<00:00, 118.12it/s]

Eval loss: 0.03395433518871936





In [20]:
# Test for a few samples
model.eval()
N_SAMPLES = 5
MAX_LEN = 60
with torch.no_grad():
    start_token = torch.tensor([pl_tokenizer.token_to_id("[start]")]).to(device)
    for en, true_fr in random.sample(dataset.text_pairs, N_SAMPLES):
        en_ids = torch.tensor(en_tokenizer.encode(en).ids).unsqueeze(0).to(device)

        # get context from encoder
        src_mask = create_padding_mask(en_ids, en_tokenizer.token_to_id("[pad]"))
        x = model.src_embedding(en_ids)
        for encoder in model.encoders:
            x = encoder(x, src_mask, model.rope)
        enc_out = x

        # generate output from decoder
        pl_ids = start_token.unsqueeze(0)
        for _ in range(MAX_LEN):
            tgt_mask = create_causal_mask(pl_ids.shape[1], device).unsqueeze(0)
            tgt_mask = tgt_mask + create_padding_mask(pl_ids, pl_tokenizer.token_to_id("[pad]"))
            x = model.tgt_embedding(pl_ids)
            for decoder in model.decoders:
                x = decoder(x, enc_out, tgt_mask, model.rope)
            outputs = model.out(x)

            outputs = outputs.argmax(dim=-1)
            pl_ids = torch.cat([pl_ids, outputs[:, -1:]], axis=-1)
            if pl_ids[0, -1] == pl_tokenizer.token_to_id("[end]"):
                break

        # Decode the predicted IDs
        pred_fr = pl_tokenizer.decode(pl_ids[0].tolist())
        print(f"English: {en}")
        print(f"Polish: {true_fr}")
        print(f"Predicted: {pred_fr}")
        print()

English: i'm saving money in order to study abroad.
Polish: oszczędzam pieniądze na studia za granicą.
Predicted:  oszczędzam pieniądze na studia za granicą. 

English: you're too young to get married, aren't you?
Polish: jesteś za młoda żeby wyjść za mąż, nieprawdaż?
Predicted:  jesteś za młoda żeby wyjść za mąż, nieprawdaż? 

English: what's your favorite swear word?
Polish: jakie jest twoje ulubione przekleństwo?
Predicted:  co jest ukratne coś do pożyte? 

English: we've got so much to talk about.
Polish: mamy sporo do pogadania.
Predicted:  mamy sporo do pogadania. 

English: i just don't want to have people thinking i'm weak.
Polish: po prostu nie chcę, aby ludzie myśleli, że jestem słaby.
Predicted:  po prostu nie chcę, aby ludzie myśleli, że jestem słaby. 

