In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
# Load the English-Russian parallel text data
with open('data/Tatoeba.en-ru.en', 'r', encoding='utf-8') as f:
    en_texts = f.read().strip().split('\n')

with open('data/Tatoeba.en-ru.ru', 'r', encoding='utf-8') as f:
    ru_texts = f.read().strip().split('\n')

print(f"Loaded {len(en_texts)} English sentences")
print(f"Loaded {len(ru_texts)} Russian sentences")
print(f"\nExample pair:")
print(f"EN: {en_texts[0]}")
print(f"RU: {ru_texts[0]}")

Loaded 540674 English sentences
Loaded 540674 Russian sentences

Example pair:
EN: For once in my life I'm doing a good deed... And it is useless. 
RU: Один раз в жизни я делаю хорошее дело... И оно бесполезно. 


In [None]:
en_chars = sorted(list(set(''.join(en_texts))))
ru_chars = sorted(list(set(''.join(ru_texts))))

SOS_TOKEN = '<SOS>'
EOS_TOKEN = '<EOS>'
PAD_TOKEN = '<PAD>'

en_chars = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + en_chars
ru_chars = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN] + ru_chars

en_vocab_size = len(en_chars)
ru_vocab_size = len(ru_chars)

print(f"English vocab size: {en_vocab_size}")
print(f"Russian vocab size: {ru_vocab_size}")

en_stoi = {ch: i for i, ch in enumerate(en_chars)}
en_itos = {i: ch for i, ch in enumerate(en_chars)}
ru_stoi = {ch: i for i, ch in enumerate(ru_chars)}
ru_itos = {i: ch for i, ch in enumerate(ru_chars)}

def encode_en(s):
    return [en_stoi[SOS_TOKEN]] + [en_stoi[c] for c in s] + [en_stoi[EOS_TOKEN]]

def encode_ru(s):
    return [ru_stoi[SOS_TOKEN]] + [ru_stoi[c] for c in s] + [ru_stoi[EOS_TOKEN]]

def decode_en(l):
    return ''.join([en_itos[i] for i in l if i not in [en_stoi[PAD_TOKEN], en_stoi[SOS_TOKEN], en_stoi[EOS_TOKEN]]])

def decode_ru(l):
    return ''.join([ru_itos[i] for i in l if i not in [ru_stoi[PAD_TOKEN], ru_stoi[SOS_TOKEN], ru_stoi[EOS_TOKEN]]])

English vocab size: 223
Russian vocab size: 231


In [4]:
# Hyperparameters
embd_dim = 256
max_seq_length = 128
batch_size = 64
n_heads = 8
n_layers = 8
device = "cuda" if torch.cuda.is_available() else "cpu"
dropout = 0.1

print(f"Using device: {device}")

Using device: cuda


In [None]:
def prepare_data(en_texts, ru_texts, max_len):
    pairs = []
    for en, ru in zip(en_texts, ru_texts):
        en_encoded = encode_en(en)
        ru_encoded = encode_ru(ru)
        
        if len(en_encoded) > max_len or len(ru_encoded) > max_len:
            continue
            
        pairs.append((en_encoded, ru_encoded))
    
    return pairs

pairs = prepare_data(en_texts, ru_texts, max_seq_length)
print(f"Prepared {len(pairs)} pairs (filtered by max length)")

split_idx = int(len(pairs) * 0.9)
train_pairs = pairs[:split_idx]
val_pairs = pairs[split_idx:]

print(f"Train pairs: {len(train_pairs)}")
print(f"Val pairs: {len(val_pairs)}")

Prepared 538850 pairs (filtered by max length)
Train pairs: 484965
Val pairs: 53885


In [None]:
def get_batch(pairs, batch_size, device):
    indices = torch.randint(len(pairs), (batch_size,))
    
    src_batch = []
    tgt_batch = []
    
    for idx in indices:
        src, tgt = pairs[idx]
        src_batch.append(src)
        tgt_batch.append(tgt)
    
    max_src_len = max(len(s) for s in src_batch)
    max_tgt_len = max(len(t) for t in tgt_batch)
    
    src_padded = torch.full((batch_size, max_src_len), en_stoi[PAD_TOKEN], dtype=torch.long)
    tgt_padded = torch.full((batch_size, max_tgt_len), ru_stoi[PAD_TOKEN], dtype=torch.long)
    
    for i, (src, tgt) in enumerate(zip(src_batch, tgt_batch)):
        src_padded[i, :len(src)] = torch.tensor(src)
        tgt_padded[i, :len(tgt)] = torch.tensor(tgt)
   
    tgt_input = tgt_padded[:, :-1]
    tgt_output = tgt_padded[:, 1:]
    
    return src_padded.to(device), tgt_input.to(device), tgt_output.to(device)

In [None]:
from model import InputEmbedding, DecoderBlock, CrossMultiHead, EncoderBlock


In [None]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, embd_dim, max_seq_length, n_heads, n_layers, dropout):
        super().__init__()
        
        # Embeddings
        self.src_embedding = InputEmbedding(src_vocab_size, embd_dim, max_seq_length)
        self.tgt_embedding = InputEmbedding(tgt_vocab_size, embd_dim, max_seq_length)
        
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(embd_dim, n_heads, dropout) for _ in range(n_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(embd_dim, n_heads, dropout) for _ in range(n_layers)
        ])
        
        self.cross_attention_layers = nn.ModuleList([
            CrossMultiHead(embd_dim, n_heads, dropout) for _ in range(n_layers)
        ])
        self.cross_ln = nn.ModuleList([
            nn.LayerNorm(embd_dim) for _ in range(n_layers)
        ])
        
        self.ln_out = nn.LayerNorm(embd_dim)
        self.output_proj = nn.Linear(embd_dim, tgt_vocab_size)
        
    def encode(self, src):
        x = self.src_embedding(src)
        for layer in self.encoder_layers:
            x = layer(x)
        return x
    
    def decode(self, tgt, enc_output):
        x = self.tgt_embedding(tgt)
        for decoder_layer, cross_attn, cross_ln in zip(
            self.decoder_layers, self.cross_attention_layers, self.cross_ln
        ):
            x = decoder_layer(x)
            x = x + cross_attn(enc_output, cross_ln(x))
        return x
    
    def forward(self, src, tgt):
        enc_output = self.encode(src)
        dec_output = self.decode(tgt, enc_output)
        logits = self.output_proj(self.ln_out(dec_output))
        return logits

In [None]:
def compute_loss(targets, logits):
    B, T, C = logits.shape
    logits = logits.reshape(B * T, C)
    targets = targets.reshape(B * T)
    
    loss = F.cross_entropy(logits, targets, ignore_index=ru_stoi[PAD_TOKEN])
    return loss

In [None]:
model = Seq2SeqTransformer(
    src_vocab_size=en_vocab_size,
    tgt_vocab_size=ru_vocab_size,
    embd_dim=embd_dim,
    max_seq_length=max_seq_length,
    n_heads=n_heads,
    n_layers=n_layers,
    dropout=dropout
)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

Total parameters: 14,968,807


In [None]:
import os

def save_checkpoint(model, optimizer, step, train_loss, val_loss, checkpoint_dir='checkpoints'):
    """Save model checkpoint with training state"""
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'config': {
            'src_vocab_size': en_vocab_size,
            'tgt_vocab_size': ru_vocab_size,
            'embd_dim': embd_dim,
            'max_seq_length': max_seq_length,
            'n_heads': n_heads,
            'n_layers': n_layers,
            'dropout': dropout
        }
    }
    
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{step}.pt')
    torch.save(checkpoint, checkpoint_path)
    
    latest_path = os.path.join(checkpoint_dir, 'checkpoint_latest.pt')
    torch.save(checkpoint, latest_path)
    
    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path


def load_checkpoint(checkpoint_path, model, optimizer=None):
    """Load model checkpoint and restore training state"""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    step = checkpoint.get('step', 0)
    train_loss = checkpoint.get('train_loss', 0.0)
    val_loss = checkpoint.get('val_loss', 0.0)
    
    print(f"Checkpoint loaded from: {checkpoint_path}")
    print(f"Resuming from step: {step}")
    print(f"Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    return step, train_loss, val_loss

In [None]:
num_steps = 15000
eval_interval = 200
checkpoint_interval = 1000
model.train()

for step in range(num_steps):
    src, tgt_input, tgt_output = get_batch(train_pairs, batch_size, device)
    
    logits = model(src, tgt_input)
    loss = compute_loss(tgt_output, logits)
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    if step % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            src_val, tgt_input_val, tgt_output_val = get_batch(val_pairs, batch_size, device)
            logits_val = model(src_val, tgt_input_val)
            val_loss = compute_loss(tgt_output_val, logits_val)
        
        print(f"step: {step:5d}  train loss: {loss.item():.4f}  val loss: {val_loss.item():.4f}")
        model.train()
    
    if step % checkpoint_interval == 0 and step > 0:
        model.eval()
        with torch.no_grad():
            src_val, tgt_input_val, tgt_output_val = get_batch(val_pairs, batch_size, device)
            logits_val = model(src_val, tgt_input_val)
            val_loss = compute_loss(tgt_output_val, logits_val)
        save_checkpoint(model, optimizer, step, loss.item(), val_loss.item())
        model.train()

model.eval()
with torch.no_grad():
    src_val, tgt_input_val, tgt_output_val = get_batch(val_pairs, batch_size, device)
    logits_val = model(src_val, tgt_input_val)
    val_loss = compute_loss(tgt_output_val, logits_val)
save_checkpoint(model, optimizer, num_steps, loss.item(), val_loss.item())

print("\nTraining complete!")

step:     0  train loss: 5.6714  val loss: 4.4644
step:   200  train loss: 2.0392  val loss: 2.0664
step:   400  train loss: 1.7420  val loss: 1.8793
step:   600  train loss: 1.5836  val loss: 1.5922
step:   800  train loss: 1.3906  val loss: 1.4513
step:  1000  train loss: 1.3544  val loss: 1.3768
Checkpoint saved: checkpoints/checkpoint_step_1000.pt
step:  1200  train loss: 1.2802  val loss: 1.3216
step:  1400  train loss: 1.3317  val loss: 1.2123
step:  1600  train loss: 1.1548  val loss: 1.1505
step:  1800  train loss: 1.0824  val loss: 1.2346
step:  2000  train loss: 1.0821  val loss: 1.2251
Checkpoint saved: checkpoints/checkpoint_step_2000.pt
step:  2200  train loss: 1.1241  val loss: 1.1878
step:  2400  train loss: 1.1403  val loss: 1.0965
step:  2600  train loss: 1.1245  val loss: 1.0006
step:  2800  train loss: 1.0171  val loss: 1.1429
step:  3000  train loss: 0.8734  val loss: 1.1200
Checkpoint saved: checkpoints/checkpoint_step_3000.pt
step:  3200  train loss: 0.9823  val l

In [None]:
checkpoint_path = 'checkpoints/checkpoint_latest.pt'  # or specify a specific checkpoint
step, train_loss, val_loss = load_checkpoint(checkpoint_path, model, optimizer)

  checkpoint = torch.load(checkpoint_path, map_location=device)


Checkpoint loaded from: checkpoints/checkpoint_latest.pt
Resuming from step: 15000
Train loss: 0.3937, Val loss: 0.4854


In [None]:
def translate(model, en_text, max_len=max_seq_length, temperature=1.0):
    model.eval()
    
    with torch.no_grad():
        src_tokens = encode_en(en_text)
        if len(src_tokens) > max_len:
            src_tokens = src_tokens[:max_len]
        
        src = torch.tensor(src_tokens, device=device).unsqueeze(0)  # [1, T]
        
        enc_output = model.encode(src)
        tgt = torch.tensor([[ru_stoi[SOS_TOKEN]]], device=device)
        
        for _ in range(max_len):
            dec_output = model.decode(tgt, enc_output)
            logits = model.output_proj(model.ln_out(dec_output))
            
            next_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            if next_token.item() == ru_stoi[EOS_TOKEN]:
                break
            
            tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)
        
        translated = decode_ru(tgt[0].tolist())
        return translated


print("Translation Examples:\n")
test_sentences = [
    "Hello, how are you?",
    "I love learning languages.",
    "Today is a beautiful day.",
    "Let's try something.",
    "I will be back soon."
]

for en_text in test_sentences:
    translation = translate(model, en_text, temperature=0.8)
    print(f"EN: {en_text}")
    print(f"RU: {translation}")
    print()

Translation Examples:

EN: Hello, how are you?
RU: Здрово, вы возьми, как ты? 

EN: I love learning languages.
RU: Я люблю изучать языки. 

EN: Today is a beautiful day.
RU: Сегодня только красивый день. 

EN: Let's try something.
RU: Давай попробуем что-то остарость. 

EN: I will be back soon.
RU: Я скоро вернусь. 



In [None]:
print("Validation Examples with Ground Truth:\n")

for i in range(5):
    src, tgt = val_pairs[i]
    en_text = decode_en(src)
    ru_true = decode_ru(tgt)
    ru_pred = translate(model, en_text, temperature=0.5)
    
    print(f"EN:        {en_text}")
    print(f"RU (true): {ru_true}")
    print(f"RU (pred): {ru_pred}")
    print()

Validation Examples with Ground Truth:

EN:        I'm just about to change that. 
RU (true): Я как раз собираюсь это изменить. 
RU (pred): Я просто просто измениться в этом меняти. 

EN:        The project was never completed. 
RU (true): Проект так и не был завершён. 
RU (pred): Проект никогда не было никогда не полностью. 

EN:        Tom didn't want to stand out. 
RU (true): Том не хотел выделяться. 
RU (pred): Том не хотел выходить из улицы. 

EN:        Tom did not want to stand out. 
RU (true): Том не хотел выделяться. 
RU (pred): Том не хотел выходить и не выходить. 

EN:        The first step is realizing that you have a problem. 
RU (true): Первый шаг заключается в том, чтобы понять, что у вас есть проблема. 
RU (pred): Первый полезна проколизиться тебе, что у тебя проблема. 

