# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import time
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when PyTorch can see one; the rest of the script moves tensors here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class TranslationDataset(Dataset):
    """Whitespace-tokenized corpus, parallel or source-only, read from UTF-8 text files.

    Each line of ``src_file`` (and of ``tgt_file`` when given) becomes one list of
    tokens. Items are ``{"src": [...], "tgt": [...]}`` when target data was loaded,
    otherwise ``{"src": [...]}``.

    Args:
        src_file: path to the source-side text, one sentence per line.
        tgt_file: optional path to the line-aligned target-side text.
        src_vocab, tgt_vocab: stored on the instance; not used by this class.
        max_len: stored on the instance; no truncation is applied here.

    Raises:
        ValueError: if ``tgt_file`` is given and its line count differs from
            ``src_file`` (a misaligned parallel corpus).
    """

    def __init__(self, src_file, tgt_file=None, src_vocab=None, tgt_vocab=None, max_len=100):
        self.src_data = self._read_tokens(src_file)
        self.tgt_data = self._read_tokens(tgt_file) if tgt_file else []

        # Fail fast on misaligned files instead of an IndexError deep inside training
        # (or silently ignoring extra target lines).
        if tgt_file and len(self.tgt_data) != len(self.src_data):
            raise ValueError(
                f"{src_file} has {len(self.src_data)} lines but "
                f"{tgt_file} has {len(self.tgt_data)}; parallel files must be aligned"
            )

        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len

    @staticmethod
    def _read_tokens(path):
        """Return one whitespace-split token list per line of the UTF-8 file at *path*."""
        with open(path, "r", encoding="utf-8") as f:
            return [line.split() for line in f]

    def __len__(self):
        return len(self.src_data)

    def __getitem__(self, idx):
        src = self.src_data[idx]
        if self.tgt_data:
            return {"src": src, "tgt": self.tgt_data[idx]}
        return {"src": src}

class Vocab:
    """Token <-> id mapping with four reserved special tokens.

    Ids 0-3 are assigned to the pad, unk, sos and eos tokens in that order.
    ``build_vocab`` accumulates token counts in ``freq`` and appends every token
    seen at least ``min_freq`` times that is not already mapped, in first-seen order.
    """

    def __init__(self, pad_token="<pad>", unk_token="<unk>", sos_token="<s>", eos_token="</s>", min_freq=2):
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.min_freq = min_freq

        specials = (pad_token, unk_token, sos_token, eos_token)
        self.stoi = {tok: i for i, tok in enumerate(specials)}
        self.itos = dict(enumerate(specials))
        self.freq = {}

    def build_vocab(self, sentences):
        """Count the tokens of *sentences* (an iterable of token lists) and map frequent ones."""
        counts = self.freq
        for sentence in sentences:
            for tok in sentence:
                counts[tok] = counts.get(tok, 0) + 1

        next_id = len(self.stoi)
        for tok, count in counts.items():
            if count < self.min_freq or tok in self.stoi:
                continue
            self.stoi[tok] = next_id
            self.itos[next_id] = tok
            next_id += 1

    def __len__(self):
        return len(self.stoi)

def _encode_and_pad(sentences, vocab):
    """Turn token lists into right-padded id lists plus matching validity masks.

    Every sentence is wrapped as <s> ... </s>; tokens missing from ``vocab.stoi``
    map to the unk id. Masks are True on real positions (including <s>/</s>)
    and False on padding.
    """
    unk_id = vocab.stoi[vocab.unk_token]
    sos_id = vocab.stoi[vocab.sos_token]
    eos_id = vocab.stoi[vocab.eos_token]
    encoded = [
        [sos_id, *(vocab.stoi.get(tok, unk_id) for tok in sent), eos_id]
        for sent in sentences
    ]
    width = max(len(ids) for ids in encoded)
    pad_id = vocab.stoi[vocab.pad_token]
    padded = [ids + [pad_id] * (width - len(ids)) for ids in encoded]
    masks = [[True] * len(ids) + [False] * (width - len(ids)) for ids in encoded]
    return padded, masks


def create_batch(data, src_vocab, tgt_vocab=None):
    """Collate dataset items into padded id tensors and boolean masks on ``device``.

    Returns ``{"src", "src_mask"}`` when ``tgt_vocab`` is None, otherwise
    ``{"src", "tgt", "src_mask", "tgt_mask"}``. Each tensor is
    (batch, longest sequence in that side of the batch).
    """
    src_ids, src_valid = _encode_and_pad([item["src"] for item in data], src_vocab)
    src_tensor = torch.LongTensor(src_ids).to(device)
    src_mask = torch.BoolTensor(src_valid).to(device)

    if tgt_vocab is None:
        return {"src": src_tensor, "src_mask": src_mask}

    tgt_ids, tgt_valid = _encode_and_pad([item["tgt"] for item in data], tgt_vocab)
    return {
        "src": src_tensor,
        "tgt": torch.LongTensor(tgt_ids).to(device),
        "src_mask": src_mask,
        "tgt_mask": torch.BoolTensor(tgt_valid).to(device),
    }

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The table is precomputed for ``max_len`` positions and stored as the buffer
    ``pe`` with shape (max_len, 1, embed_dim), so ``forward`` expects
    ``x`` shaped (seq_len, batch, embed_dim) and broadcasts over the batch.

    Even columns hold sin(pos * w_i), odd columns cos(pos * w_i), with
    w_i = 10000^(-2i / embed_dim). Odd ``embed_dim`` values are supported.
    """

    def __init__(self, embed_dim, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_dim there is one fewer odd column than even column,
        # so only the first embed_dim // 2 frequencies are used for cos.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, embed_dim)

        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x + pe[:seq_len]``.

        Raises:
            ValueError: if ``x.size(0)`` exceeds the precomputed table length.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention then a ReLU feed-forward.

    Each sub-layer's output is dropped out, added to its input, and layer-normed.
    Inputs are sequence-first (seq_len, batch, embed_dim), matching
    ``nn.MultiheadAttention``'s default ``batch_first=False``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.self_attn_norm = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )
        self.ff_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        """Encode ``src``.

        Args:
            src: (seq_len, batch, embed_dim) input.
            src_mask: optional (batch, seq_len) bool tensor, True on real tokens and
                False on padding. ``None`` means no padding (attend everywhere).
        """
        # MultiheadAttention's key_padding_mask marks positions to IGNORE, the
        # inverse of our "is real token" mask.
        key_padding_mask = None if src_mask is None else ~src_mask

        attn_out, _ = self.self_attn(src, src, src, key_padding_mask=key_padding_mask)
        src = self.self_attn_norm(src + self.dropout(attn_out))

        ff_out = self.ff(src)
        return self.ff_norm(src + self.dropout(ff_out))

class DecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    Applies masked self-attention, then attention over the encoder output, then a
    ReLU feed-forward. Each sub-layer is followed by dropout, a residual add, and
    LayerNorm. All tensors are sequence-first, as ``nn.MultiheadAttention``
    defaults to ``batch_first=False``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.self_attn_norm = nn.LayerNorm(embed_dim)

        self.enc_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.enc_attn_norm = nn.LayerNorm(embed_dim)

        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )
        self.ff_norm = nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, enc_src, tgt_mask=None, src_mask=None, tgt_is_causal=True):
        """Decode one layer.

        Args:
            tgt: (tgt_len, batch, embed_dim) decoder input.
            enc_src: (src_len, batch, embed_dim) encoder output.
            tgt_mask / src_mask: optional (batch, len) bool tensors, True on real
                tokens and False on padding; ``None`` means no padding.
            tgt_is_causal: when True, position i may only attend to positions <= i.
        """
        causal_mask = None
        if tgt_is_causal:
            seq_len = tgt.size(0)
            # Bool attn_mask: True = "may not attend". Same masking as
            # generate_square_subsequent_mask's -inf entries, but the same dtype as
            # the bool key_padding_mask (mixing float/bool masks is deprecated).
            # Built on tgt's device rather than a global.
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=tgt.device), diagonal=1
            )
        # key_padding_mask marks positions to IGNORE: the inverse of our masks.
        tgt_padding = None if tgt_mask is None else ~tgt_mask
        src_padding = None if src_mask is None else ~src_mask

        attn_out, _ = self.self_attn(tgt, tgt, tgt, attn_mask=causal_mask, key_padding_mask=tgt_padding)
        tgt = self.self_attn_norm(tgt + self.dropout(attn_out))

        attn_out, _ = self.enc_attn(tgt, enc_src, enc_src, key_padding_mask=src_padding)
        tgt = self.enc_attn_norm(tgt + self.dropout(attn_out))

        ff_out = self.ff(tgt)
        return self.ff_norm(tgt + self.dropout(ff_out))

class Encoder(nn.Module):
    """Scaled token embedding + positional encoding, then a stack of EncoderLayers.

    ``forward`` receives batch-first ids (batch, src_len), transposes them to
    sequence-first, and returns hidden states (src_len, batch, embed_dim).
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_len, dropout=0.1):
        super().__init__()
        # Submodules are created in this exact order; it fixes RNG consumption and state_dict keys.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList(
            [EncoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)
        # Embeddings are multiplied by sqrt(embed_dim) before positions are added.
        self.scale = math.sqrt(embed_dim)

    def forward(self, src, src_mask=None):
        hidden = self.token_embedding(src.transpose(0, 1)) * self.scale
        hidden = self.dropout(self.pos_encoding(hidden))
        for enc_layer in self.layers:
            hidden = enc_layer(hidden, src_mask)
        return hidden

class Decoder(nn.Module):
    """Scaled token embedding + positional encoding, then a stack of DecoderLayers.

    ``forward`` receives batch-first target ids (batch, tgt_len) and the
    sequence-first encoder output, and returns sequence-first hidden states
    (tgt_len, batch, embed_dim). No vocabulary projection happens here.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_len, dropout=0.1):
        super().__init__()
        # Submodules are created in this exact order; it fixes RNG consumption and state_dict keys.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList(
            [DecoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)
        # Embeddings are multiplied by sqrt(embed_dim) before positions are added.
        self.scale = math.sqrt(embed_dim)

    def forward(self, tgt, enc_src, tgt_mask=None, src_mask=None):
        hidden = self.token_embedding(tgt.transpose(0, 1)) * self.scale
        hidden = self.dropout(self.pos_encoding(hidden))
        for dec_layer in self.layers:
            hidden = dec_layer(hidden, enc_src, tgt_mask, src_mask)
        return hidden

class Transformer(nn.Module):
    """Encoder-decoder Transformer with a linear projection to target-vocabulary logits.

    All weight matrices (every parameter with more than one dimension) are
    re-initialised with Xavier-uniform after construction. ``forward`` returns
    batch-first logits (batch, tgt_len, tgt_vocab_size).
    """

    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 embed_dim=512,
                 num_layers=6,
                 num_heads=8,
                 ff_dim=2048,
                 max_len=5000,
                 dropout=0.1):
        super().__init__()

        # Creation order (encoder, decoder, fc_out) fixes parameter order and RNG use.
        self.encoder = Encoder(src_vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_len, dropout)
        self.decoder = Decoder(tgt_vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_len, dropout)
        self.fc_out = nn.Linear(embed_dim, tgt_vocab_size)

        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(tgt, memory, tgt_mask, src_mask)
        # Decoder output is sequence-first; swap back to batch-first for the caller.
        return self.fc_out(decoded).permute(1, 0, 2)

def translate_sentence(model, sentence, src_vocab, tgt_vocab, device, max_len=50):
    """Greedily decode a translation of a single sentence.

    Args:
        model: object exposing ``eval()``, ``encoder``, ``decoder`` and ``fc_out``
            (used as in the ``Transformer`` above).
        sentence: a whitespace-separated string or an already-split token list.
        src_vocab, tgt_vocab: vocabularies providing ``stoi``/``itos`` and the
            special-token attributes.
        device: device to place the input tensors on.
        max_len: maximum number of tokens to generate (0 yields an empty result).

    Returns:
        The generated target tokens, without the leading <s> and without a
        trailing </s> if one was produced.
    """
    model.eval()

    tokens = sentence.strip().split() if isinstance(sentence, str) else sentence

    unk_id = src_vocab.stoi[src_vocab.unk_token]
    src_indexes = (
        [src_vocab.stoi[src_vocab.sos_token]]
        + [src_vocab.stoi.get(token, unk_id) for token in tokens]
        + [src_vocab.stoi[src_vocab.eos_token]]
    )
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = torch.ones(src_tensor.shape, dtype=torch.bool, device=device)

    eos_id = tgt_vocab.stoi[tgt_vocab.eos_token]
    trg_indexes = [tgt_vocab.stoi[tgt_vocab.sos_token]]

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)
        for _ in range(max_len):
            # Re-run the decoder on the whole prefix each step (no KV cache).
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
            trg_mask = torch.ones(trg_tensor.shape, dtype=torch.bool, device=device)
            output = model.fc_out(model.decoder(trg_tensor, enc_src, trg_mask, src_mask))

            # Decoder output is sequence-first (tgt_len, 1, vocab): take the last step.
            pred_token = output[-1, -1].argmax().item()
            trg_indexes.append(pred_token)
            if pred_token == eos_id:
                break

    trg_tokens = [tgt_vocab.itos[i] for i in trg_indexes[1:]]
    # Guard the empty case (max_len == 0) that used to raise IndexError.
    if trg_tokens and trg_tokens[-1] == tgt_vocab.eos_token:
        trg_tokens.pop()
    return trg_tokens

def train(model, iterator, optimizer, criterion, clip):
    """Run one teacher-forced training epoch and return the mean per-batch loss.

    The decoder is fed ``tgt[:, :-1]`` and scored against ``tgt[:, 1:]``.
    Gradients are clipped to a global norm of ``clip`` before each optimizer step.
    """
    model.train()
    running_loss = 0

    for batch in tqdm(iterator):
        src, tgt = batch["src"].to(device), batch["tgt"].to(device)
        src_mask, tgt_mask = batch["src_mask"].to(device), batch["tgt_mask"].to(device)

        optimizer.zero_grad()

        logits = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :-1])
        vocab_size = logits.shape[-1]
        flat_logits = logits.contiguous().view(-1, vocab_size)
        flat_gold = tgt[:, 1:].contiguous().view(-1)

        loss = criterion(flat_logits, flat_gold)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(iterator)

def evaluate(model, iterator, criterion):
    """Return the mean per-batch teacher-forced loss over *iterator*, without gradients."""
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch in iterator:
            src, tgt = batch["src"].to(device), batch["tgt"].to(device)
            src_mask, tgt_mask = batch["src_mask"].to(device), batch["tgt_mask"].to(device)

            logits = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :-1])
            vocab_size = logits.shape[-1]
            flat_logits = logits.contiguous().view(-1, vocab_size)
            flat_gold = tgt[:, 1:].contiguous().view(-1)

            running_loss += criterion(flat_logits, flat_gold).item()

    return running_loss / len(iterator)

def main():
    """Train the de->en Transformer, keep the best checkpoint, and translate the test set.

    Reads the corpora at the hard-coded paths below. Saves the weights with the
    lowest validation loss to ``transformer-best-model.pt``, then writes one greedy
    translation per test-source line to ``translation.txt``.
    """
    import random

    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    train_src_path = "train.de-en.de"
    train_tgt_path = "train.de-en.en"
    valid_src_path = "val.de-en.de"
    valid_tgt_path = "val.de-en.en"
    test_src_path = "test1.de-en.de"
    output_path = "translation.txt"

    BATCH_SIZE = 64
    EMBED_DIM = 512
    NUM_HEADS = 4
    NUM_LAYERS = 4
    FF_DIM = 1024
    DROPOUT = 0.1
    MAX_LEN = 100
    LEARNING_RATE = 0.0001
    N_EPOCHS = 30
    CLIP = 1.0
    MIN_FREQ = 2

    print("Loading data...")
    train_dataset = TranslationDataset(train_src_path, train_tgt_path)
    valid_dataset = TranslationDataset(valid_src_path, valid_tgt_path)
    test_dataset = TranslationDataset(test_src_path)

    print("Building vocabulary...")
    src_vocab = Vocab(min_freq=MIN_FREQ)
    src_vocab.build_vocab(train_dataset.src_data)

    tgt_vocab = Vocab(min_freq=MIN_FREQ)
    tgt_vocab.build_vocab(train_dataset.tgt_data)

    print(f"Source vocabulary size: {len(src_vocab)}")
    print(f"Target vocabulary size: {len(tgt_vocab)}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=lambda x: create_batch(x, src_vocab, tgt_vocab)
    )

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=lambda x: create_batch(x, src_vocab, tgt_vocab)
    )

    # The positional-encoding table has a fixed length. Size it to cover the
    # longest sentence in any split (+2 for <s>/</s>). Otherwise a single line
    # longer than MAX_LEN - 2 tokens breaks PositionalEncoding.forward.
    longest = max(
        (
            len(sent)
            for ds in (train_dataset, valid_dataset, test_dataset)
            for split in (ds.src_data, ds.tgt_data)
            for sent in split
        ),
        default=0,
    )
    pos_max_len = max(MAX_LEN, longest + 2)

    print("Creating model...")
    model = Transformer(
        len(src_vocab),
        len(tgt_vocab),
        EMBED_DIM,
        NUM_LAYERS,
        NUM_HEADS,
        FF_DIM,
        pos_max_len,
        DROPOUT
    ).to(device)

    print(f"The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")

    criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab.stoi[tgt_vocab.pad_token])
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)

    # `verbose=` is deprecated on LR schedulers; learning-rate drops are logged manually below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    best_valid_loss = float('inf')

    for epoch in range(N_EPOCHS):
        start_time = time.time()

        train_loss = train(model, train_loader, optimizer, criterion, CLIP)
        valid_loss = evaluate(model, valid_loader, criterion)

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(valid_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after < lr_before:
            print(f"Learning rate reduced to {lr_after:.2e}")

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'transformer-best-model.pt')
            print("New best model saved!")

        print(f"Epoch: {epoch+1}/{N_EPOCHS} ｜ Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")

    # map_location lets a checkpoint saved on GPU be reloaded on whichever device is active.
    model.load_state_dict(torch.load('transformer-best-model.pt', map_location=device))

    print("Translating test set...")
    model.eval()
    test_translations = []
    with torch.no_grad():
        # translate_sentence does its own <s>/</s> wrapping and unk mapping, so the
        # raw token lists can be passed directly (no id -> token round-trip needed).
        for src_tokens in tqdm(test_dataset.src_data):
            translation = translate_sentence(model, src_tokens, src_vocab, tgt_vocab, device)
            test_translations.append(" ".join(translation))

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(translation + '\n' for translation in test_translations)

    print(f"Translations saved to {output_path}")


if __name__ == "__main__":
    main()
