In [3]:
import pickle 

%matplotlib inline

from io import open

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%matplotlib inline

import nltk
nltk.download('punkt')

from nltk.tokenize import word_tokenize

[nltk_data] Downloading package punkt to /Users/niejiayi/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# Encoder

In [4]:
class EncoderLSTM(nn.Module):
    """Embedding + LSTM encoder for the source (French) token sequence.

    Args:
        input_size: source vocabulary size (number of embedding rows).
        hidden_size: embedding width and LSTM hidden size.
        n_layers: number of stacked LSTM layers.
        drop_prob: dropout between stacked LSTM layers (only meaningful when
            n_layers > 1).
    """

    def __init__(self, input_size, hidden_size, n_layers=1, drop_prob=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_size, hidden_size)
        # nn.LSTM only applies dropout between stacked layers and warns when it
        # is non-zero with a single layer, so pass it through only when it matters.
        self.lstm = nn.LSTM(hidden_size, hidden_size, n_layers,
                            dropout=drop_prob if n_layers > 1 else 0,
                            batch_first=True)

    def forward(self, inputs, hidden):
        """Encode a batch-first tensor of token indices.

        Returns (outputs, (h_n, c_n)) exactly as produced by nn.LSTM with
        batch_first=True: outputs is (batch, seq_len, hidden_size) and
        h_n / c_n are (n_layers, batch, hidden_size).
        """
        embedded = self.embedding(inputs)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def init_hidden(self, batch_size=1):
        """Return zero (h_0, c_0), each of shape (n_layers, batch_size, hidden_size).

        The tensors are created on the same device and dtype as the model's
        parameters, so this works wherever the encoder has been moved,
        independently of any notebook-level `device` variable.
        """
        weight = self.embedding.weight
        shape = (self.n_layers, batch_size, self.hidden_size)
        return (weight.new_zeros(shape), weight.new_zeros(shape))


# Decoder

In [5]:
class LuongDecoder(nn.Module):
    """Single-step LSTM decoder with Luong-style (post-LSTM) attention.

    Args:
        hidden_size: embedding width and LSTM hidden size.
        output_size: target vocabulary size.
        attention: module called as attention(decoder_hidden, encoder_outputs)
            that returns alignment scores, e.g. an `Attention` instance.
        n_layers: number of stacked LSTM layers. It must match the layer count
            of the (h, c) state passed to forward (e.g. the encoder's final state).
        drop_prob: dropout on the input embedding, and between stacked LSTM
            layers when n_layers > 1.
    """

    def __init__(self, hidden_size, output_size, attention, n_layers=1, drop_prob=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.drop_prob = drop_prob

        # The attention mechanism is defined in a separate class.
        self.attention = attention

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.dropout = nn.Dropout(self.drop_prob)
        # Honour n_layers so a multi-layer encoder state can seed this LSTM.
        # Inter-layer dropout is only valid (warning-free) with > 1 layer.
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size, self.n_layers,
                            dropout=self.drop_prob if self.n_layers > 1 else 0)
        # Classifier sees [lstm_out ; context_vector], hence 2 * hidden_size.
        self.classifier = nn.Linear(self.hidden_size * 2, self.output_size)

    def forward(self, inputs, hidden, encoder_outputs):
        """Run one decoding step for a batch of one.

        Args:
            inputs: the previous target token index (a single token; its
                embedding is reshaped to (1, 1, hidden_size)).
            hidden: (h, c) LSTM state, each (n_layers, 1, hidden_size).
            encoder_outputs: (1, src_len, hidden_size) batch-first encoder outputs.

        Returns:
            (log_probs, hidden, attn_weights): log-probabilities over the
            target vocabulary of shape (1, output_size), the new (h, c) state,
            and attention weights of shape (1, src_len).
        """
        embedded = self.embedding(inputs).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # Feed the previous output word (embedded) and hidden state to the LSTM.
        lstm_out, hidden = self.lstm(embedded, hidden)

        # Score against the top layer's new hidden state. `[-1:]` keeps the
        # (1, 1, hidden_size) shape the attention module reshapes, for any n_layers.
        alignment_scores = self.attention(hidden[0][-1:], encoder_outputs)
        attn_weights = F.softmax(alignment_scores.view(1, -1), dim=1)

        # Attention-weighted sum of encoder outputs -> (1, 1, hidden_size).
        context_vector = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs)

        # Concatenate LSTM output with the context vector and classify.
        output = torch.cat((lstm_out, context_vector), -1)
        output = F.log_softmax(self.classifier(output[0]), dim=1)
        return output, hidden, attn_weights

In [6]:
class Attention(nn.Module):
    """Alignment scoring between a decoder state and the encoder outputs.

    Args:
        hidden_size: size of the decoder hidden state and of each encoder output.
        method: one of "dot", "general" or "concat".

    Raises:
        ValueError: if `method` is not one of the supported names.

    forward(decoder_hidden, encoder_outputs) returns unnormalised alignment
    scores of shape (1, seq_len). It assumes a batch of one: decoder_hidden
    holds hidden_size values (e.g. shape (1, 1, hidden_size)) and
    encoder_outputs is (1, seq_len, hidden_size).
    """

    METHODS = ("dot", "general", "concat")

    def __init__(self, hidden_size, method="dot"):
        super().__init__()
        if method not in self.METHODS:
            # Previously an unknown method made forward() silently return None.
            raise ValueError(
                f"Unknown attention method {method!r}; expected one of {self.METHODS}")
        self.method = method
        self.hidden_size = hidden_size

        # Layers/weights required by the chosen scoring method.
        if method == "general":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
        elif method == "concat":
            self.fc = nn.Linear(hidden_size, hidden_size, bias=False)
            # torch.FloatTensor(1, hidden_size) holds uninitialised memory, so
            # give the scoring vector a proper random init. Loading a
            # checkpoint overwrites it either way (same name and shape).
            self.weight = nn.Parameter(torch.empty(1, hidden_size))
            nn.init.xavier_uniform_(self.weight)

    def forward(self, decoder_hidden, encoder_outputs):
        if self.method == "dot":
            # score_i = <encoder_output_i, decoder_hidden>; no learned parameters.
            return encoder_outputs.bmm(decoder_hidden.view(1, -1, 1)).squeeze(-1)

        if self.method == "general":
            # score_i = <encoder_output_i, fc(decoder_hidden)>.
            out = self.fc(decoder_hidden)
            return encoder_outputs.bmm(out.view(1, -1, 1)).squeeze(-1)

        # "concat": score_i = weight . tanh(fc(decoder_hidden + encoder_output_i)).
        # NOTE: the decoder state is *added* (broadcast) to each encoder output
        # rather than concatenated as in Luong et al.; switching to a true
        # concatenation would change fc's shape and break saved checkpoints.
        out = torch.tanh(self.fc(decoder_hidden + encoder_outputs))
        return out.bmm(self.weight.unsqueeze(-1)).squeeze(-1)

# Demo

In [8]:
# Load the vocabulary lookup tables saved at training time.
# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# vocabulary files from a trusted source (e.g. this project's own training run).
path_to_dic = "../util/vocabulary_lstm/"


def load_vocab(filename):
    """Unpickle and return the object stored at path_to_dic + filename."""
    with open(path_to_dic + filename, 'rb') as vocab_file:
        return pickle.load(vocab_file)


loaded_en_w2i = load_vocab('en_word2index.pkl')
loaded_en_i2w = load_vocab('en_index2word.pkl')
loaded_fr_w2i = load_vocab('fr_word2index.pkl')
loaded_fr_i2w = load_vocab('fr_index2word.pkl')

# Special tokens that the encoding and decoding cells look up directly.
assert '_UNK' in loaded_fr_w2i, "French vocabulary is missing '_UNK'"
assert '_SOS' in loaded_en_w2i, "English vocabulary is missing '_SOS'"

# Load the preprocessed French test articles; each line is one article.
path_to_data = "../data/"
with open(path_to_data + 'fr_text_preprocessed_test1000.txt', encoding='utf-8') as test_file:
    test_articles = test_file.readlines()

In [9]:
# Encode each French test article as a list of vocabulary indices:
# tokenize, lowercase, append the '_EOS' marker, and map words missing from the
# vocabulary to the '_UNK' index. The '_UNK' lookup is loop-invariant, so do it once.
fr_unk_id = loaded_fr_w2i['_UNK']


def encode_fr_article(text):
    """Return the French vocabulary index sequence for one raw article."""
    tokens = [token.lower() for token in word_tokenize(text)] + ['_EOS']
    return [loaded_fr_w2i.get(token, fr_unk_id) for token in tokens]


fr_inputs = [encode_fr_article(article) for article in test_articles]




In [10]:
# Sanity-check the encoded test set: report its size and map one article's
# indices back to words.
num_samples = len(fr_inputs)
print("There are ", num_samples, "testing samples")

preview_idx = 10
preview_words = [loaded_fr_i2w[token_id] for token_id in fr_inputs[preview_idx]]
print("Visualize one article for testing: " + " ".join(preview_words))

There are  1000 testing samples
Visualize one article for testing: separation pouvoirs notion plus celebree nen guere daussi confuse 1789 declaration droits lhomme citoyen notait article quune societe laquelle separation pouvoirs netait garantie point constitution deux siecles plus tard quasi-totalite constitutions adoptees pays lex-bloc sovietique reprenaient presque mot mot meme enonce faits nul pourtant jamais envisage trois pouvoirs executif judiciaire legislatif puisque cest deux quil sagissait puissent fonctionner facon vraiment autonome derriere terme separation cest fait plutot lidee dun equilibre dune balance pouvoirs lon envisagee paradoxe regimes moins democratiques entendu sappuyer xixe siecle lexpression separation pouvoirs _UNK exemple lindependance pouvoir executif mettre labri toute velleite controle parlementaire cela ete cas second empire france concentrons-nous present posant deux questions avons-nous besoin dune separation dun equilibre pouvoirs oui selon quelles mo

In [None]:
# Must match the hidden size the checkpoint was trained with.
hidden_size = 512

# Path to the saved checkpoint.
model_path = "../weights/model_epoch_5_lstm.pth"

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and
# vice versa). The checkpoint is a dict holding "encoder" and "decoder" state dicts.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location=device)

encoder_state_dict = checkpoint["encoder"]
decoder_state_dict = checkpoint["decoder"]

# Rebuild the models with the same architecture as the saved ones.
# The attention module becomes a submodule of the decoder, so decoder.to(device)
# moves it as well.
encoder = EncoderLSTM(len(loaded_fr_i2w), hidden_size).to(device)
attn = Attention(hidden_size, "concat")
decoder = LuongDecoder(hidden_size, len(loaded_en_i2w), attn).to(device)

encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

# Evaluation mode disables dropout; the trailing ';' suppresses the model repr.
encoder.eval()
decoder.eval();


In [None]:
import os

path_to_evaluation = "../evaluation/baseline_lstm/"
MAX_SUMMARY_LEN = 25  # hard cap on generated tokens per article
# 0-based rank of the candidate emitted at each step: 0 = plain greedy decoding,
# 1 = the second most likely token, which is what this notebook has been using.
# NOTE(review): confirm rank 1 is intentional; greedy decoding normally uses 0.
TOKEN_RANK = 1


def decode_article(src_ids, max_len=MAX_SUMMARY_LEN, rank=TOKEN_RANK):
    """Decode one French index sequence (e.g. an entry of fr_inputs) into English ids.

    Args:
        src_ids: list of French vocabulary indices.
        max_len: maximum number of tokens to emit.
        rank: which candidate (0-based, by decoder score) to emit at each step.

    Returns:
        (output_ids, attentions): the emitted English vocabulary indices
        (excluding '_EOS') and, for each emitted token, its attention weights
        as a numpy array.
    """
    # The decoder predicts *English* indices, so stop on the English
    # end-of-sequence id, not the French vocabulary's.
    eos_id = loaded_en_w2i["_EOS"]
    output_ids, attentions = [], []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        inp = torch.tensor(src_ids, device=device).unsqueeze(0)
        encoder_outputs, decoder_hidden = encoder(inp, encoder.init_hidden())

        decoder_input = torch.tensor([loaded_en_w2i['_SOS']], device=device)
        while len(output_ids) < max_len:
            decoder_output, decoder_hidden, attn_weights = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            _, top_index = decoder_output.topk(rank + 1)
            next_id = top_index[0][rank].item()
            if next_id == eos_id:  # end of sequence: stop decoding
                break
            output_ids.append(next_id)
            attentions.append(attn_weights.squeeze().cpu().numpy())
            decoder_input = torch.tensor([next_id], device=device)
    return output_ids, attentions


os.makedirs(path_to_evaluation, exist_ok=True)
with open(path_to_evaluation + "demo.txt", "w") as f:
    for src_ids in fr_inputs:
        # `output` / `attentions` stay bound to the last article for inspection.
        output, attentions = decode_article(src_ids)
        summary = ' '.join(loaded_en_i2w[x] for x in output)
        print(summary)
        f.write(summary + "\n")


In [None]:
# Read the generated summaries back, one per line, stripping surrounding
# whitespace (including the trailing newline) from each.
with open(path_to_evaluation + 'demo.txt', 'r') as summary_file:
    lines = [raw_line.strip() for raw_line in summary_file]
print(lines[0])

in paris of a simple term are now department , is a simple term , where he was the victim . the government . .
