In [2]:
import torch
from tqdm.notebook import tqdm
from IPython.display import clear_output
from torch.utils.data import Dataset, DataLoader
from torchtext.vocab import build_vocab_from_iterator
from torch import nn
from IPython.display import clear_output
import numpy as np
from torch.nn.utils.rnn import pad_sequence

from torchtext.vocab import vocab

In [3]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [4]:
class LangDataset(Dataset):
    """Whitespace-tokenised German sentences with optional line-aligned English ones.

    An item is the German sentence as a list of vocabulary indices wrapped in
    ``<bos>`` / ``<eos>``.  When an English file was supplied, an item is the
    pair ``(german_indices, english_indices)``.
    """

    def __init__(self, source_file_de, vocab_de, max_len, source_file_en=None, vocab_en=None, cut=False):
        """Load the sentence file(s) and remember the vocabularies.

        Args:
            source_file_de: path to the German text, one sentence per line.
            vocab_de: German vocabulary; must support ``vocab[token]`` and ``get_itos()``
                and contain the four special tokens.
            max_len: stored as ``self.max_len``; not otherwise used by this class.
            source_file_en: optional path to the English text, one sentence per line.
            vocab_en: English vocabulary; required when ``source_file_en`` is given.
            cut: when True, keep a random half of the sentences (the same rows
                are kept for both languages).

        Raises:
            ValueError: if ``source_file_en`` is given without ``vocab_en``.
        """
        if source_file_en is not None and vocab_en is None:
            raise ValueError('vocab_en is required when source_file_en is given')

        self.specials = ['<pad>', '<bos>', '<eos>', '<unk>']
        self.max_len = max_len

        self.text_de = self._read_sentences(source_file_de)
        self.vocab_de = vocab_de
        self.itos_de = self.vocab_de.get_itos()

        self.pad_index = self.vocab_de['<pad>']
        self.bos_index = self.vocab_de['<bos>']
        self.eos_index = self.vocab_de['<eos>']
        self.unk_index = self.vocab_de['<unk>']

        self.text_en = None
        self.vocab_en = None
        self.itos_en = None
        if source_file_en is not None:
            self.text_en = self._read_sentences(source_file_en)
            self.vocab_en = vocab_en
            self.itos_en = self.vocab_en.get_itos()

        if cut:
            total = len(self.text_de)
            # replace=False: a "cut" must be a true subset, not a sample with duplicates.
            keep = np.random.choice(total, total // 2, replace=False)
            self.text_de = self.text_de[keep]
            if self.text_en is not None:
                self.text_en = self.text_en[keep]

    @staticmethod
    def _read_sentences(path):
        """Return a 1-D object array holding one list of tokens per line of ``path``."""
        with open(path) as file:
            # Strip the line terminator before splitting; otherwise the last token
            # of every sentence carries a '\n' and can never match the vocabulary.
            sentences = [line.rstrip('\n').split(' ') for line in file]
        # Fill element by element so equal-length sentences are not merged into a 2-D array.
        out = np.empty(len(sentences), dtype=object)
        for row, tokens in enumerate(sentences):
            out[row] = tokens
        return out

    def __len__(self):
        """Number of (German) sentences."""
        return len(self.text_de)

    def str_to_idx(self, text, lng='de'):
        """Map tokens to indices with the German (``lng='de'``) or English vocabulary."""
        if lng == 'de':
            return [self.vocab_de[word] for word in text]
        return [self.vocab_en[word] for word in text]

    def idx_to_str(self, idx, lng='de'):
        """Map indices back to tokens with the German (``lng='de'``) or English vocabulary."""
        if lng == 'de':
            return [self.itos_de[index] for index in idx]
        return [self.itos_en[index] for index in idx]

    def encode(self, chars, lng='de'):
        """Wrap the tokens in ``<bos>`` / ``<eos>`` and convert them to indices."""
        chars = ['<bos>'] + list(chars) + ['<eos>']
        return self.str_to_idx(chars, lng)

    def decode(self, idx, lng='de'):
        """Convert indices to a space-joined string, dropping the special tokens."""
        chars = self.idx_to_str(idx, lng)
        return ' '.join(char for char in chars if char not in self.specials)

    def __getitem__(self, item):
        """Return the encoded German sentence, or a (German, English) pair when English text is loaded."""
        encoded_de = self.encode(self.text_de[item])

        if self.text_en is not None:
            encoded_en = self.encode(self.text_en[item], lng='en')
            return encoded_de, encoded_en

        return encoded_de

In [14]:
def training_epoch(model, optimizer, criterion, train_loader, tqdm_desc):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Relies on the notebook-level names ``device``, ``create_mask`` and ``tqdm``.

    Args:
        model: module called as ``model(src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask)``.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to flattened logits and flattened target indices.
        train_loader: iterable yielding ``(de_text, en_text)`` index tensors
            whose dimension 1 is the sequence dimension.
        tqdm_desc: label shown on the progress bar.

    Returns:
        float: the sum of per-batch losses divided by the number of batches
        (``0.0`` when the loader yields nothing).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for de_text, en_text in tqdm(train_loader, desc=tqdm_desc):
        de_text = de_text.to(device)
        en_text = en_text.to(device)

        # Teacher forcing: the decoder is fed the target without its last token ...
        tgt_input = en_text[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(de_text, tgt_input)
        src_mask = src_mask.to(device)
        tgt_mask = tgt_mask.to(device)
        src_padding_mask = src_padding_mask.to(device)
        tgt_padding_mask = tgt_padding_mask.to(device)

        optimizer.zero_grad()
        # The source padding mask is passed a second time as the memory key padding mask.
        logits = model(de_text, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # ... and is scored against the target shifted left by one position.
        tgt_out = en_text[:, 1:]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Batches are counted in the loop: `len(list(train_loader))` would iterate
    # (and re-collate) the entire dataset a second time just to get this number.
    return total_loss / max(num_batches, 1)


@torch.no_grad()
def validation_epoch(model, criterion, test_loader, tqdm_desc):
    """Evaluate ``model`` on ``test_loader`` without gradients and return the mean batch loss.

    Relies on the notebook-level names ``device``, ``create_mask`` and ``tqdm``.

    Args:
        model: module called as ``model(src, tgt_input, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask, memory_key_padding_mask)``.
        criterion: loss applied to flattened logits and flattened target indices.
        test_loader: iterable yielding ``(de_text, en_text)`` index tensors
            whose dimension 1 is the sequence dimension.
        tqdm_desc: label shown on the progress bar.

    Returns:
        float: the sum of per-batch losses divided by the number of batches
        (``0.0`` when the loader yields nothing).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    for de_text, en_text in tqdm(test_loader, desc=tqdm_desc):
        de_text = de_text.to(device)
        en_text = en_text.to(device)

        # Decoder input is the target without its last token.
        tgt_input = en_text[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(de_text, tgt_input)
        src_mask = src_mask.to(device)
        tgt_mask = tgt_mask.to(device)
        src_padding_mask = src_padding_mask.to(device)
        tgt_padding_mask = tgt_padding_mask.to(device)

        # The source padding mask is passed a second time as the memory key padding mask.
        logits = model(de_text, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # Expected output is the target shifted left by one position.
        tgt_out = en_text[:, 1:]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))

        total_loss += loss.item()
        num_batches += 1

    # Batches are counted in the loop: `len(list(test_loader))` would iterate
    # (and re-collate) the entire dataset a second time just to get this number.
    return total_loss / max(num_batches, 1)


def train(model, optimizer, scheduler, criterion, train_loader, test_loader, num_epochs):
    """Alternate one training pass and one validation pass for ``num_epochs`` epochs.

    ``scheduler.step()`` is called once per epoch, after validation, when a
    scheduler is supplied.

    Returns:
        tuple: ``(train_losses, test_losses)`` - two lists with one mean loss per epoch.
    """
    train_losses, test_losses = [], []

    for epoch in range(1, num_epochs + 1):
        epoch_train_loss = training_epoch(
            model, optimizer, criterion, train_loader,
            tqdm_desc=f'Training {epoch}/{num_epochs}'
        )
        epoch_test_loss = validation_epoch(
            model, criterion, test_loader,
            tqdm_desc=f'Validating {epoch}/{num_epochs}'
        )

        if scheduler is not None:
            scheduler.step()

        train_losses.append(epoch_train_loss)
        test_losses.append(epoch_test_loss)
        # torch.save(model.state_dict(), "weights.pt")

    return train_losses, test_losses


@torch.no_grad()
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol=None):
    """Generate target tokens one at a time, always taking the highest-scoring token.

    Runs under ``torch.no_grad()`` so that no autograd graph is accumulated
    across decoding steps.  Relies on the notebook-level names ``device`` and
    ``generate_square_subsequent_mask``.

    Args:
        model: object exposing ``encode(src, src_mask)``, ``decode(tokens, memory, tgt_mask)``
            and a ``linear`` projection to vocabulary scores.
        src: source index tensor; a single sentence is expected because
            ``.item()`` is called on the predicted token.
        src_mask: attention mask passed to the encoder.
        max_len: upper bound on the output length, start symbol included.
        start_symbol: index of the token that starts the output.
        end_symbol: index that stops generation; when ``None`` it falls back to
            the notebook-level ``en_vocab['<eos>']``.

    Returns:
        Tensor of shape ``(1, generated_length)`` beginning with ``start_symbol``.
    """
    if end_symbol is None:
        # Backward-compatible default; looked up once instead of on every step.
        end_symbol = en_vocab['<eos>']

    memory = model.encode(src, src_mask).to(device)
    tokens = torch.tensor([start_symbol]).unsqueeze(0).to(device)
    for _ in range(max_len - 1):
        # The float mask holds -inf at blocked positions and 0 elsewhere, so the
        # bool cast yields True exactly where attention must be blocked.
        tgt_mask = generate_square_subsequent_mask(tokens.shape[1]).type(torch.bool).to(device)

        out = model.decode(tokens, memory, tgt_mask)
        scores = model.linear(out[:, -1])
        new_token = scores.argmax(dim=1)

        tokens = torch.cat([tokens, new_token.unsqueeze(0)], dim=1)
        if new_token.item() == end_symbol:
            break
    return tokens


@torch.no_grad()
def translate(model, src, start_symbol, vocab):
    """Greedy-translate every encoded sentence in ``src`` and return the text.

    Runs under ``torch.no_grad()``; inference does not need an autograd graph.
    Relies on the notebook-level names ``device``, ``tqdm`` and ``greedy_decode``.

    Args:
        model: translation model; switched to eval mode here.
        src: iterable of encoded source sentences (sequences of token indices).
        start_symbol: index of the token that starts each translation.
        vocab: object providing ``decode(indices, lng='en')`` - despite the
            name, the notebook passes a ``LangDataset`` here.

    Returns:
        str: one translated sentence per line, each terminated by ``'\\n'``.
    """
    model.eval()
    translated = []

    for line in tqdm(src):
        # Build an integer tensor directly instead of round-tripping the ids through float32.
        line = torch.as_tensor(line, dtype=torch.long).unsqueeze(0).to(device)
        src_len = line.shape[1]
        src_mask = torch.zeros((src_len, src_len)).type(torch.bool).to(device)

        # Allow the output to be up to five tokens longer than the input.
        trans_line = greedy_decode(model, line, src_mask, src_len + 5, start_symbol)

        translated.append(vocab.decode(trans_line.reshape(-1).tolist(), lng='en') + '\n')

    # Join once at the end rather than growing a string inside the loop.
    return ''.join(translated)


def generate_square_subsequent_mask(sz):
    """Return an ``sz x sz`` float mask: ``-inf`` strictly above the diagonal, ``0.0`` elsewhere.

    Row ``i`` therefore leaves columns ``0..i`` at zero and sets every later
    column to ``-inf``.
    """
    future = torch.ones((sz, sz)).triu(diagonal=1).bool()
    return torch.zeros((sz, sz)).masked_fill(future, float('-inf'))

def create_mask(src, tgt, pad_idx=0):
    """Build the attention and padding masks for one source/target batch.

    Args:
        src: source index tensor; dimension 1 is the sequence dimension.
        tgt: target (decoder input) index tensor; dimension 1 is the sequence dimension.
        pad_idx: index of the padding token. Defaults to ``0``, the value that
            was previously hard-coded.

    Returns:
        tuple ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
            * ``src_mask`` - all-False bool tensor of shape ``(src_len, src_len)``.
            * ``tgt_mask`` - float causal mask of shape ``(tgt_len, tgt_len)``
              from ``generate_square_subsequent_mask``.
            * ``src_padding_mask`` / ``tgt_padding_mask`` - bool tensors, True
              where the corresponding input equals ``pad_idx``.
    """
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool)

    src_padding_mask = (src == pad_idx)
    tgt_padding_mask = (tgt == pad_idx)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


def collate_fn(batch, pad_idx=0):
    """Pad a list of ``(source, target)`` index sequences into two batch-first long tensors.

    Args:
        batch: iterable of ``(src_sample, tgt_sample)`` pairs of token indices.
        pad_idx: value used to fill the shorter sequences. Defaults to ``0``,
            the value that was previously hard-coded.

    Returns:
        tuple ``(src, tgt)`` of ``int64`` tensors shaped ``(batch, longest_sequence)``.
    """
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        # Create integer tensors directly: going through float32 (torch.Tensor)
        # silently corrupts indices above 2**24.
        src_batch.append(torch.as_tensor(src_sample, dtype=torch.long))
        tgt_batch.append(torch.as_tensor(tgt_sample, dtype=torch.long))

    return (
        pad_sequence(src_batch, batch_first=True, padding_value=pad_idx),
        pad_sequence(tgt_batch, batch_first=True, padding_value=pad_idx),
    )

In [6]:
def create_vocab(files: list, t=1):
    """Build a vocabulary from space-separated text files.

    Args:
        files: paths of the text files to read, one sentence per line.
        t: minimum frequency passed to ``build_vocab_from_iterator`` as ``min_freq``.

    Returns:
        tuple ``(vocabulary, max_len)``: the vocabulary, with the four special
        tokens and ``<unk>`` as its default index, and the token count of the
        longest line plus 2.
    """
    sentences = []
    for path in files:
        with open(path) as handle:
            sentences.extend(raw.replace('\n', '').split(' ') for raw in handle.readlines())

    longest = max((len(tokens) for tokens in sentences), default=0)

    vocabulary = build_vocab_from_iterator(
        sentences,
        specials=['<pad>', '<bos>', '<eos>', '<unk>'],
        min_freq=t,
    )
    vocabulary.set_default_index(vocabulary['<unk>'])
    return vocabulary, longest + 2

In [7]:
class MyEmbedder(nn.Module):
    """Token embedding scaled by ``sqrt(emb_size)`` plus a fixed sinusoidal position table, then dropout."""

    def __init__(self, vocab_size, emb_size, pad_id, max_len=5000):
        """
        Args:
            vocab_size: number of rows in the embedding table.
            emb_size: embedding width.
            pad_id: index passed to ``nn.Embedding`` as ``padding_idx``.
            max_len: number of positions in the precomputed table.
        """
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_size, padding_idx=pad_id)
        self.dropout = nn.Dropout(p=0.1)
        # A buffer follows the module across devices and is saved in state_dict,
        # but is not a trainable parameter.
        self.register_buffer('pos_embedding', self._sinusoid_table(max_len, emb_size))

    @staticmethod
    def _sinusoid_table(max_len, emb_size):
        """Return a ``(1, max_len, emb_size)`` table: sine in even columns, cosine in odd columns."""
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2) * np.log(10000) / emb_size)
        positions = torch.arange(0, max_len).unsqueeze(1)
        angles = positions * inv_freq

        table = torch.zeros((max_len, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        return table.unsqueeze(0)

    def forward(self, tokens):
        """Embed ``tokens`` (cast to long) and add the table rows for positions ``0..seq_len-1``."""
        scaled = self.embedding(tokens.long()) * np.sqrt(self.emb_size)
        seq_len = scaled.size(1)
        return self.dropout(scaled + self.pos_embedding[:, :seq_len])

    
class MyTransformer(nn.Module):
    """Encoder-decoder ``nn.Transformer`` with separate source/target embedders and an output projection."""

    def __init__(self, src_vocab, trg_vocab, nhead, pad_id, max_len, d_model=512, num_enc_lay=3, num_dec_lay=3, dim_feedforward=512, dropout=0.1):
        """
        Args:
            src_vocab: size of the source vocabulary.
            trg_vocab: size of the target vocabulary (also the output width).
            nhead: number of attention heads.
            pad_id: padding index handed to both embedders.
            max_len: accepted but not used by this class; the embedders keep
                their own default table length.
            d_model: model / embedding width.
            num_enc_lay: number of encoder layers.
            num_dec_lay: number of decoder layers.
            dim_feedforward: width of the feed-forward sublayers.
            dropout: dropout rate inside the transformer.
        """
        super().__init__()
        self.embedding_src = MyEmbedder(src_vocab, d_model, pad_id)
        self.embedding_trg = MyEmbedder(trg_vocab, d_model, pad_id)

        self.trans = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_enc_lay,
            num_decoder_layers=num_dec_lay,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.linear = nn.Linear(d_model, trg_vocab)

        self.nhead = nhead

    def forward(self, src, trg, src_mask, trg_mask, src_padding_mask, trg_padding_mask, memory_key_padding_mask):
        """Return vocabulary scores for every target position; no memory mask is applied."""
        hidden = self.trans(
            src=self.embedding_src(src),
            tgt=self.embedding_trg(trg),
            src_mask=src_mask,
            tgt_mask=trg_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=trg_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.linear(hidden)

    def encode(self, src, src_mask):
        """Embed ``src`` and run only the encoder stack."""
        embedded_src = self.embedding_src(src)
        return self.trans.encoder(embedded_src, mask=src_mask)

    def decode(self, tgt, memory, tgt_mask):
        """Embed ``tgt`` and run only the decoder stack against ``memory``."""
        embedded_tgt = self.embedding_trg(tgt)
        return self.trans.decoder(embedded_tgt, memory, tgt_mask=tgt_mask)

In [20]:
# English keeps every token (min frequency 1); German drops tokens seen fewer than 8 times.
en_vocab, max_en = create_vocab(files=['train.de-en.en'], t=1)
de_vocab, max_de = create_vocab(files=['train.de-en.de'], t=8)

# Longest training sentence over both languages, as reported by create_vocab.
max_str_size = max_en if max_en >= max_de else max_de
max_str_size

82

In [21]:
# Parallel train / validation corpora share the vocabularies built from the training files.
train_Dataset = LangDataset(
    source_file_de='train.de-en.de',
    vocab_de=de_vocab,
    max_len=max_str_size,
    source_file_en='train.de-en.en',
    vocab_en=en_vocab,
    cut=False,
)

val_Dataset = LangDataset(
    source_file_de='val.de-en.de',
    vocab_de=de_vocab,
    max_len=max_str_size,
    source_file_en='val.de-en.en',
    vocab_en=en_vocab,
)

# The test split is German-only, so its items are single index lists rather than pairs.
test_Dataset = LangDataset(source_file_de='test1.de-en.de', vocab_de=de_vocab, max_len=max_str_size)

In [22]:
# Reshuffle the training data every epoch: with shuffle=False each epoch replays
# the same batches in file order.
train_loader = DataLoader(train_Dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)

# Validation order does not influence the averaged loss, so it stays fixed.
val_loader = DataLoader(val_Dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)

# One German sentence per batch, default collation.
test_de_loader = DataLoader(test_Dataset, batch_size=1, shuffle=False)

In [23]:
num_epochs = 10

model = MyTransformer(
    src_vocab=len(de_vocab),
    trg_vocab=len(en_vocab),
    nhead=8,
    pad_id=train_Dataset.pad_index,
    max_len=max_str_size,
    d_model=512,
    num_enc_lay=3,
    num_dec_lay=3,
    dim_feedforward=512,
    dropout=0.1,
)
# Xavier-initialise every parameter with more than one dimension;
# 1-D parameters keep their default initialisation.
for p in model.parameters():
    if p.dim() <= 1:
        continue
    nn.init.xavier_uniform_(p)
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)
# Padding targets are excluded from the loss; label smoothing is 0.1.
criterion = torch.nn.CrossEntropyLoss(
    ignore_index=train_Dataset.pad_index,
    label_smoothing=0.1,
)

In [None]:
# train() returns exactly two lists (per-epoch train and validation losses).
# Unpacking four names raised ValueError only after every epoch had already run.
train_losses, test_losses = train(
    model, optimizer, None, criterion, train_loader, val_loader, num_epochs
)

In [19]:
# import gc

# torch.cuda.empty_cache()
# gc.collect()

24

In [25]:
# Translate the German test set. train_Dataset is passed as `vocab` because it
# carries the English index-to-token table that decode(..., lng='en') needs.
result = translate(
    model=model,
    src=test_Dataset,
    start_symbol=en_vocab['<bos>'],
    vocab=train_Dataset,
)
with open('test.de-en.en', mode='w') as file:
    file.write(result)

  0%|          | 0/2998 [00:00<?, ?it/s]