In [1]:
# Yernar Shambayev, DL-2
# Решить задачу перевода с помощью механизма внимания
# 1. Возьмите англо-русскую пару фраз (https://www.manythings.org/anki/)
# 2. Обучите на них seq2seq with attention (на основе скалярного произведения, на основе MLP)
# 3. Оцените качество

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import spacy
from spacy.lang.en import English
from spacy.lang.nl import Dutch
import random
from collections import Counter

In [2]:
# For variety, use the English-Dutch corpus: one pair per line, tab-separated,
# column 0 = English, column 1 = Dutch.
DATA_PATH = "nld.txt"

# Open read-only with explicit UTF-8 (Dutch text has non-ASCII characters) and
# strip only the trailing newline: slicing off the last character unconditionally
# would eat a real character on a final line without '\n'. Blank lines are skipped
# because they have no tab-separated columns.
with open(DATA_PATH, "r", encoding="utf-8") as file:
    holland = [line.rstrip("\n") for line in file if line.strip()]

en = []
nld = []
for line in holland:
    columns = line.split("\t")
    en.append(columns[0])
    nld.append(columns[1])

training_examples = 10000
spacy_en = English()
spacy_nld = Dutch()

en_words = Counter()
nld_words = Counter()
en_inputs = []
nld_inputs = []

# Tokenise the first `training_examples` pairs (fewer if the corpus is shorter).
# Pairs where either side produces no tokens are dropped; every kept sentence
# is lower-cased and terminated with '_EOS'.
for i in range(min(training_examples, len(en))):
    en_tokens = [token.text.lower() for token in spacy_en(en[i])]
    nld_tokens = [token.text.lower() for token in spacy_nld(nld[i])]
    if not en_tokens or not nld_tokens:
        continue
    en_words.update(en_tokens)
    nld_words.update(nld_tokens)
    en_inputs.append(en_tokens + ['_EOS'])
    nld_inputs.append(nld_tokens + ['_EOS'])

# Vocabularies: special tokens first (indices 0, 1, 2), then words by descending
# frequency (ties keep first-seen order because sorted() is stable).
en_words = ['_SOS', '_EOS', '_UNK'] + sorted(en_words, key=en_words.get, reverse=True)
en_w2i = {o: i for i, o in enumerate(en_words)}
en_i2w = {i: o for i, o in enumerate(en_words)}
nld_words = ['_SOS', '_EOS', '_UNK'] + sorted(nld_words, key=nld_words.get, reverse=True)
nld_w2i = {o: i for i, o in enumerate(nld_words)}
nld_i2w = {i: o for i, o in enumerate(nld_words)}

# Replace token strings with vocabulary indices.
en_inputs = [[en_w2i[word] for word in sentence] for sentence in en_inputs]
nld_inputs = [[nld_w2i[word] for word in sentence] for sentence in nld_inputs]

print(len(en_words), len(nld_words))

2623 3321


In [3]:
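# Классы: EncoderLSTM — кодировщик; MLP_Decoder — декодер с аддитивным (MLP) вниманием; Dot_Decoder — декодер Луонга, функция внимания (dot / general / concat) задаётся модулем Attention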

class EncoderLSTM(nn.Module):
    """Embeds source token indices and runs them through an LSTM.

    Args:
        input_size: source vocabulary size.
        hidden_size: width of both the embedding and the LSTM state.
        n_layers: number of stacked LSTM layers.
        drop_prob: dropout between stacked layers. PyTorch applies it only
            between layers (and warns if it is non-zero for a single layer),
            so it is passed through only when n_layers > 1.
    """

    def __init__(self, input_size, hidden_size, n_layers=1, drop_prob=0):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers,
                            dropout=drop_prob if n_layers > 1 else 0,
                            batch_first=True)

    def forward(self, inputs, hidden):
        """Encode ``inputs`` of shape (batch, seq) (the LSTM is batch_first).

        Returns ``(outputs, (h_n, c_n))``; ``outputs`` is (batch, seq, hidden_size).
        """
        embedded = self.embedding(inputs)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return zero (h_0, c_0), each of shape (n_layers, batch_size, hidden_size).

        The tensors are created on the device of the module's own parameters,
        so the encoder does not depend on a global ``device`` variable.
        """
        device = self.embedding.weight.device
        return (torch.zeros(self.n_layers, batch_size, self.hidden_size, device=device),
                torch.zeros(self.n_layers, batch_size, self.hidden_size, device=device))

class MLP_Decoder(nn.Module):
    """Single-step decoder with additive (MLP, Bahdanau-style) attention.

    Each call consumes one target token and:
      1. scores every encoder output with w^T tanh(W_h h + W_e e_s), where h is
         the previous LSTM hidden state;
      2. builds a context vector as the softmax-weighted sum of encoder outputs;
      3. feeds [embedding ; context] to a single-layer LSTM;
      4. returns log-probabilities over the target vocabulary.
    The reshapes (``view(1, -1)`` and friends) assume a batch of one sentence.

    Args:
        hidden_size: width of embeddings, LSTM state and encoder outputs.
        output_size: target vocabulary size.
        n_layers: stored only; the LSTM below is always single-layer.
        drop_prob: dropout applied to the target embedding.
    """

    def __init__(self, hidden_size, output_size, n_layers=1, drop_prob=0.1):
        super(MLP_Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop_prob = drop_prob

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)

        self.fc_hidden = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.fc_encoder = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        # Scoring vector w. torch.FloatTensor(...) would leave it as
        # uninitialised memory (arbitrary, possibly NaN values), so initialise it.
        self.weight = nn.Parameter(torch.empty(1, hidden_size))
        nn.init.xavier_uniform_(self.weight)
        self.dropout = nn.Dropout(self.drop_prob)
        self.lstm = nn.LSTM(self.hidden_size * 2, self.hidden_size, batch_first=True)
        self.classifier = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, inputs, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            inputs: index tensor holding the current target token (one element).
            hidden: LSTM state ``(h, c)``, each presumably (1, 1, hidden_size).
            encoder_outputs: assumed (1, S, hidden_size) from a batch_first encoder.

        Returns:
            ``(log_probs (1, output_size), new_hidden, attn_weights (1, S))``.
        """
        # Remove only the batch axis: (1, S, H) -> (S, H). A bare squeeze()
        # would also collapse S when the source sequence has length 1.
        encoder_outputs = encoder_outputs.squeeze(0)

        embedded = self.embedding(inputs).view(1, -1)
        embedded = self.dropout(embedded)

        # Additive score for every source position: w^T tanh(W_h h + W_e e_s).
        energy = torch.tanh(self.fc_hidden(hidden[0]) + self.fc_encoder(encoder_outputs))
        alignment_scores = energy.bmm(self.weight.unsqueeze(2))

        attn_weights = F.softmax(alignment_scores.view(1, -1), dim=1)

        context_vector = torch.bmm(attn_weights.unsqueeze(0),
                                   encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded, context_vector[0]), 1).unsqueeze(0)
        output, hidden = self.lstm(output, hidden)
        output = F.log_softmax(self.classifier(output[0]), dim=1)
        return output, hidden, attn_weights

class Dot_Decoder(nn.Module):
    """Luong-style single-step decoder: run the LSTM first, then attend.

    The scoring function is delegated to ``attention``, which must map
    ``(lstm_out, encoder_outputs)`` to one score per source position. Because
    it is stored as a submodule, its parameters are returned by
    ``parameters()`` and trained together with the decoder.

    Args:
        hidden_size: width of embeddings, LSTM state and encoder outputs.
        output_size: target vocabulary size.
        attention: scoring module (e.g. ``Attention`` with dot/general/concat).
        n_layers: stored only; the LSTM below is always single-layer.
        drop_prob: dropout applied to the target embedding.
    """

    def __init__(self, hidden_size, output_size, attention, n_layers=1, drop_prob=0.1):
        super(Dot_Decoder, self).__init__()
        self.hidden_size, self.output_size = hidden_size, output_size
        self.n_layers, self.drop_prob = n_layers, drop_prob

        self.attention = attention

        # Keep this creation order: it fixes how parameter init draws from the RNG.
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.dropout = nn.Dropout(drop_prob)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        # The classifier sees [lstm_out ; context], hence twice the hidden width.
        self.classifier = nn.Linear(2 * hidden_size, output_size)

    def forward(self, inputs, hidden, encoder_outputs):
        """One decoding step; returns ``(log_probs, new_hidden, attn_weights)``.

        ``encoder_outputs`` is presumably (1, S, hidden_size) (batch of one,
        batch_first); the embedding is reshaped to (1, 1, hidden_size).
        """
        token_vec = self.dropout(self.embedding(inputs).view(1, 1, -1))
        step_out, hidden = self.lstm(token_vec, hidden)

        scores = self.attention(step_out, encoder_outputs)
        weights = F.softmax(scores.view(1, -1), dim=1)
        context = weights.unsqueeze(0).bmm(encoder_outputs)

        features = torch.cat([step_out, context], dim=-1)
        log_probs = F.log_softmax(self.classifier(features[0]), dim=1)
        return log_probs, hidden, weights

class Attention(nn.Module):
    """Alignment scores between one decoder state and every encoder output.

    Supported ``method`` values (h = decoder state, e_s = encoder output s):
      * ``"dot"``     -- score = e_s . h
      * ``"general"`` -- score = e_s . (W h)
      * ``"concat"``  -- score = v . tanh(W (h + e_s)); the two vectors are
        summed before one shared projection rather than concatenated.

    ``forward`` returns unnormalised scores; the caller applies the softmax.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """

    _METHODS = ("dot", "general", "concat")

    def __init__(self, hidden_size, method="dot"):
        super(Attention, self).__init__()
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown attention method {method!r}; expected one of {self._METHODS}")
        self.method = method
        self.hidden_size = hidden_size

        if method in ("general", "concat"):
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
        if method == "concat":
            # torch.FloatTensor(...) would leave v as uninitialised memory
            # (arbitrary, possibly NaN values); give it a proper random init.
            self.weight = nn.Parameter(torch.empty(1, hidden_size))
            nn.init.xavier_uniform_(self.weight)

    def forward(self, decoder_hidden, encoder_outputs):
        """Return one score per source position.

        The ``view(1, -1, 1)`` reshape assumes a batch of one: ``decoder_hidden``
        holds hidden_size elements (e.g. (1, 1, H)) and ``encoder_outputs`` is
        presumably (1, S, H). The result is (1, S).
        """
        if self.method == "concat":
            energy = torch.tanh(self.fc(decoder_hidden + encoder_outputs))
            return energy.bmm(self.weight.unsqueeze(-1)).squeeze(-1)

        query = self.fc(decoder_hidden) if self.method == "general" else decoder_hidden
        return encoder_outputs.bmm(query.view(1, -1, 1)).squeeze(-1)

In [4]:
# Seed every RNG the notebook uses so Restart & Run All reproduces the results:
# `random` drives teacher forcing and the evaluation sample, and torch drives
# parameter initialisation and dropout masks.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = 'cpu'
hidden_size = 256
encoder = EncoderLSTM(len(en_words), hidden_size).to(device)
# "concat" scoring plugged into the Luong-style decoder; "dot" and "general"
# are the other methods Attention accepts.
attn = Attention(hidden_size, "concat")
decoder = Dot_Decoder(hidden_size, len(nld_words), attn).to(device)

lr = 0.001
encoder_optimizer = optim.Adam(encoder.parameters(), lr=lr)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=lr)

In [5]:
# Training: one sentence per optimisation step; the loss is the per-token NLL
# averaged over the target sentence.

EPOCHS = 10
teacher_forcing_prob = 0.5
# The decoder consumes Dutch indices, so take '_SOS' from the Dutch vocabulary.
sos_index = nld_w2i['_SOS']
encoder.train()
decoder.train()

for epoch in range(1, EPOCHS + 1):
    avg_loss = 0.
    for i, sentence in enumerate(en_inputs):
        target = nld_inputs[i]
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        h = encoder.init_hidden()
        inp = torch.tensor(sentence, device=device).unsqueeze(0)
        encoder_outputs, h = encoder(inp, h)

        decoder_input = torch.tensor([sos_index], device=device)
        decoder_hidden = h
        # One coin flip per sentence decides the decoder inputs for all its steps.
        teacher_forcing = random.random() < teacher_forcing_prob

        loss = 0.
        for target_token in target:
            decoder_output, decoder_hidden, attn_weights = decoder(decoder_input, decoder_hidden, encoder_outputs)
            gold = torch.tensor([target_token], device=device)
            loss += F.nll_loss(decoder_output.view(1, -1), gold)
            if teacher_forcing:
                decoder_input = gold
            else:
                # Feed back the greedy prediction as a fresh (graph-free) index tensor.
                decoder_input = torch.tensor([decoder_output.argmax(dim=1).item()], device=device)
        loss = loss / len(target)
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()
        avg_loss += loss.item() / len(en_inputs)
    print(f'Эпоха: {epoch}, потери: {avg_loss}')

Эпоха: 1, потери: 3.1886707433819774
Эпоха: 2, потери: 2.1686106621755212
Эпоха: 3, потери: 1.6647057241053291
Эпоха: 4, потери: 1.328551610455762
Эпоха: 5, потери: 1.0579434494096769
Эпоха: 6, потери: 0.8824573078539446
Эпоха: 7, потери: 0.7620667414881795
Эпоха: 8, потери: 0.6655963821187009
Эпоха: 9, потери: 0.6352166065754693
Эпоха: 10, потери: 0.5822985703860916


In [6]:
# Quality check: greedy-decode one random sentence and compare with the reference.
# NOTE: the sentence is drawn from en_inputs, i.e. the training data, so this
# shows how well the model fits the training set, not how it generalises.

MAX_DECODE_LEN = 50  # safety cap: stop even if the model never emits '_EOS'

encoder.eval()
decoder.eval()

i = random.randint(0, len(en_inputs) - 1)
output = []
attentions = []
with torch.no_grad():  # inference only, no autograd graph needed
    h = encoder.init_hidden()
    inp = torch.tensor(en_inputs[i], device=device).unsqueeze(0)
    encoder_outputs, h = encoder(inp, h)

    # The decoder consumes Dutch indices, so take '_SOS' from the Dutch vocabulary.
    decoder_input = torch.tensor([nld_w2i['_SOS']], device=device)
    decoder_hidden = h
    for _ in range(MAX_DECODE_LEN):
        decoder_output, decoder_hidden, attn_weights = decoder(decoder_input, decoder_hidden, encoder_outputs)
        top_index = decoder_output.argmax(dim=1).item()
        if top_index == nld_w2i["_EOS"]:
            break
        output.append(top_index)
        attentions.append(attn_weights.squeeze().cpu().numpy())
        decoder_input = torch.tensor([top_index], device=device)

print("Входная фраза на английском языке: "+ " ".join([en_i2w[x] for x in en_inputs[i]]))
print("Предсказание: " + " ".join([nld_i2w[x] for x in output]))
print("Выходная фраза на голландском языке: " + " ".join([nld_i2w[x] for x in nld_inputs[i]]))

Входная фраза на английском языке: let 's stop here . _EOS
Предсказание: laten we hier stoppen .
Выходная фраза на голландском языке: laten we hier stoppen . _EOS
