In [1]:
import time
import torch 
import pickle

import matplotlib.pyplot as plt 
import matplotlib.ticker as ticker
from torch import nn 
from tqdm import tqdm
from torchtext.vocab import Vectors
from torchtext.datasets import TranslationDataset
from torchtext.data import Field, BucketIterator, Iterator

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

MAX_LENGTH = 220
BATCH_SIZE = 128
CLIP = 5
EMB_SIZE = 500
NUM_LAYERS = 1

cuda


In [3]:
def create_mask(src, idx):
    mask = (src != idx).unsqueeze(2)
    return mask

class EncoderRNN(nn.Module):    
    def __init__(self, input_size, hidden_size, TEXT):
        super(EncoderRNN,self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, EMB_SIZE).to(device)
        self.gru = nn.GRU(EMB_SIZE, hidden_size, bidirectional = True, batch_first = True).to(device)
        self.fc = nn.Linear(2*hidden_size, hidden_size).to(device)
        self.layer_norm = nn.LayerNorm(hidden_size).to(device)

    def forward(self, x, src_len):
        embedded = self.embedding(x)
        #in order to avoid <pad>
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths = src_len, batch_first = True)
        packed_output, hidden = self.gru(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first = True)
        output = self.fc(output) 
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
        output = self.layer_norm(output)
        hidden = self.layer_norm(hidden)
        return output, hidden

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, TEXT):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, EMB_SIZE).to(device)
        self.gru = nn.GRU(hidden_size + EMB_SIZE, hidden_size, batch_first = True).to(device)
        self.attn_scores1 = nn.Linear(hidden_size + MAX_LENGTH, hidden_size, bias = True).to(device)
        self.attn_scores2 = nn.Linear(hidden_size, 1, bias = False).to(device)
        self.fc = nn.Linear(hidden_size, output_size).to(device)
        self.softmax = nn.Softmax(dim = 1).to(device)
        self.logsoftmax = nn.LogSoftmax(dim = 1).to(device)
        self.layer_norm = nn.LayerNorm(hidden_size).to(device)

    def forward(self, x, hidden, encoder_hidden_states, coverage, mask):
        embedded = self.embedding(x)
        coverage_vec = torch.cat((hidden, coverage), dim = 1)

        scores = self.attn_scores2(torch.tanh(torch.add(self.attn_scores1(coverage_vec).unsqueeze(1), encoder_hidden_states)))
        attn_weights = scores.masked_fill(mask == 0, -1e10)
        attn_weights = self.softmax(attn_weights)

        coverage[:, :attn_weights.shape[1]] += attn_weights.squeeze(2)

        context_vec = torch.matmul(attn_weights.permute(0, 2, 1), encoder_hidden_states)
        attn_hidden = torch.cat((torch.relu(embedded.unsqueeze(dim = 1)), context_vec), dim = 2)

        output, hidden = self.gru(attn_hidden, hidden.unsqueeze(dim = 0))
        output = self.layer_norm(output)
        output = self.fc(output)
        output = self.logsoftmax(output.squeeze(dim = 1))
        return output, self.layer_norm(hidden.squeeze(0)), attn_weights, coverage

In [4]:
class Model(nn.Module):
    def __init__(self, encoder, decoder):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_len, trg):
        outputs, hidden = self.encoder(src, src_len)
        mask = create_mask(src, SRC_PAD_IDX)
        coverage = torch.zeros(src.shape[0], MAX_LENGTH).to(device)
        decoder_input = trg[:, 0]
        decoder_outputs = torch.zeros(trg.shape[0], output_size, trg.shape[1]).to(device)

        for k in range(1, trg.shape[1]):
            output, hidden, _, coverage = self.decoder(decoder_input, hidden, outputs, coverage, mask)
            decoder_input = trg[:, k]
            decoder_outputs[:, :, k] = output

        return decoder_outputs

In [5]:
SRC = Field(init_token = '<SOS>', eos_token = '<EOS>', lower = True, sequential = True, pad_token = "<PAD>", batch_first = True, include_lengths = True)
TRG = Field(init_token = '<SOS>', eos_token = '<EOS>', lower = True, sequential = True, pad_token = "<PAD>", batch_first = True)

In [6]:
train_data, valid_data = TranslationDataset.splits(path= "../data/", train = "train", validation = "dev", test = None, exts = (".hi", ".en"), fields = (SRC, TRG))

In [7]:
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

In [8]:
train_iter, valid_iter = BucketIterator.splits((train_data, valid_data), batch_size=BATCH_SIZE, sort_key=lambda x: len(x.src), shuffle=True, sort_within_batch = True, device = device)

input_size = len(SRC.vocab)
output_size = len(TRG.vocab)
hidden_size = 500

SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.NLLLoss(ignore_index = TRG_PAD_IDX)

encoder = EncoderRNN(input_size, hidden_size, SRC).to(device)
decoder = DecoderRNN(hidden_size, output_size, TRG).to(device)
model = Model(encoder, decoder).to(device)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 3)
epochs = 100

In [9]:
def train(iterator, model, optimizer, clip, epoch):
    model.train()
    epoch_loss = 0
    loop = tqdm(iterator)
    for _, batch in enumerate(loop):
        src, src_len = batch.src
        trg = batch.trg
        optimizer.zero_grad()
        decoder_outputs = model(src, src_len, trg)
        loss = criterion(decoder_outputs[:, :, 1:], trg[:, 1:])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

        loop.set_description('Epoch {}/{}'.format(epoch + 1, epochs))
        loop.set_postfix(loss=loss.item())
        
    return epoch_loss/len(iterator)


In [10]:
def validate(iterator, model):
    model.eval()
    with torch.no_grad():
        loss = 0
        for _, batch in enumerate(iterator):
            src, src_len = batch.src
            trg = batch.trg
            decoder_outputs = model(src, src_len, trg)
            loss = criterion(decoder_outputs[:, :, 1:], trg[:, 1:])
            loss += loss.item()
        return loss/len(iterator)

In [23]:
def translate(sentence, model, replace_with_src, phrase_table):
    model.eval()
    tokens = sentence.lower().strip().split()

    tokens = [SRC.init_token] + tokens + [SRC.eos_token]
    src_indices = [SRC.vocab.stoi[token.strip()] for token in tokens]

    src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)
    src_len = torch.LongTensor([len(src_indices)]).to(device)

    coverage = torch.zeros(1, MAX_LENGTH).to(device)

    with torch.no_grad():
        outputs, hidden = model.encoder(src_tensor, src_len)

        mask = create_mask(src_tensor, SRC_PAD_IDX)
        trg_indices = [TRG.vocab.stoi[TRG.init_token]]

        attns = torch.zeros(MAX_LENGTH, len(src_indices)).to(device)
        replace_words = []
        for i in range(MAX_LENGTH):
            decoder_input = torch.LongTensor([trg_indices[-1]]).to(device)
            output,  hidden, attn, coverage = model.decoder(decoder_input, hidden, outputs, coverage,mask)
            attns[i, :] = attn.squeeze(0).squeeze(1)
            pred_token = output.argmax(dim = 1).item()
            if pred_token == TRG.vocab.stoi[TRG.unk_token]:
                if replace_with_src:
                    idx = attn.argmax(axis = 1).item()
                    src_token = SRC.vocab.itos[src_indices[idx]]
                    if src_token in phrase_table:
                        src_token = phrase_table[src_token][0][-1]
                    replace_words.append((i, src_token))
                    
            trg_indices.append(pred_token)
            if pred_token == TRG.vocab.stoi[TRG.eos_token]:
                break

        trg_tokens = [TRG.vocab.itos[j] for j in trg_indices]

    return trg_tokens[1:], attns, replace_words


In [33]:
def translate_sentences(epoch, model, replace_with_src, phrase_table):
    f1 = open("/home/akshay.goindani/NLA/project/models/unk_result"+str(epoch)+".txt", "w")
    f2 = open("/home/akshay.goindani/NLA/project/models/phrase_table_result"+str(epoch)+".txt", "w")
    with open("/home/akshay.goindani/NLA/project/data/test.hi") as f:
        for line in f:
            translation, _, replace_words = translate(line.strip(), model, replace_with_src, phrase_table)
            f1.write(" ".join(translation[:-1]).strip() + "\n")
            for i in replace_words:
                translation[i[0]] = i[1].strip()
            f2.write(" ".join(translation[:-1]).strip() + "\n")
    f1.close()
    f2.close()
    
    
phrase_table = pickle.load(open("../data/threshold_0.0_phrase-table.pkl", "rb"))

In [13]:
best = 1e18
for epoch in range(epochs):
    train_loss = train(train_iter, model, optimizer, CLIP, epoch)
    valid_loss = validate(valid_iter, model)
    print("Train Loss:", train_loss)
    print("Validation Loss:", valid_loss.item())
    if valid_loss < best:
        best = valid_loss
        checkpoint = {
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "train_loss": train_loss,
            "valid_loss": valid_loss
        }
        translate_sentences(epoch + 1, model, True, phrase_table)
        torch.save(checkpoint, "/home/akshay.goindani/NLA/project/models/model"+str(epoch+1)+".tar")
        early_stop = 0
    else:
        early_stop += 1
    scheduler.step(valid_loss)
    if early_stop > 7:
        print('Epoch %d: early stopping' % (epoch + 1))
        break


Epoch 1/100: 100%|██████████| 363/363 [01:52<00:00,  2.81it/s, loss=4.21]


Train Loss: 4.8819980109033505
Validation Loss: 0.8256821036338806


Epoch 2/100: 100%|██████████| 363/363 [01:59<00:00,  2.64it/s, loss=3.52]


Train Loss: 3.243202854779141
Validation Loss: 0.7586845755577087


Epoch 3/100: 100%|██████████| 363/363 [01:57<00:00,  2.88it/s, loss=3.25]


Train Loss: 2.356930606621356
Validation Loss: 0.7348169088363647


Epoch 4/100: 100%|██████████| 363/363 [01:56<00:00,  2.30it/s, loss=2.93] 


Train Loss: 1.763206873550888
Validation Loss: 0.7334088683128357


Epoch 5/100: 100%|██████████| 363/363 [01:56<00:00,  3.47it/s, loss=0.784]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.3721112557186568
Validation Loss: 0.7413994669914246


Epoch 6/100: 100%|██████████| 363/363 [02:00<00:00,  3.38it/s, loss=2.08] 
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.0941263524297185
Validation Loss: 0.7737724781036377


Epoch 7/100: 100%|██████████| 363/363 [01:57<00:00,  3.14it/s, loss=0.544]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 0.9011978786182141
Validation Loss: 0.7824774980545044


Epoch 8/100: 100%|██████████| 363/363 [01:58<00:00,  2.93it/s, loss=0.556]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 0.7238291568294701
Validation Loss: 0.8128000497817993


Epoch 9/100: 100%|██████████| 363/363 [01:58<00:00,  2.74it/s, loss=0.936] 
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 0.4483032576616638
Validation Loss: 0.8157902956008911


Epoch 10/100: 100%|██████████| 363/363 [01:55<00:00,  3.48it/s, loss=0.0543]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 0.32858830058213434
Validation Loss: 0.8254267573356628


Epoch 11/100: 100%|██████████| 363/363 [01:59<00:00,  3.48it/s, loss=0.0474]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 0.27195435809612933
Validation Loss: 0.8382109999656677


Epoch 12/100: 100%|██████████| 363/363 [01:55<00:00,  3.70it/s, loss=0.0882]


Train Loss: 0.23139938787036363
Validation Loss: 0.8515743017196655
Epoch 12: early stopping


In [30]:
checkpoint = torch.load("../models/model4.tar")
encoder = EncoderRNN(input_size, hidden_size, SRC).to(device)
decoder = DecoderRNN(hidden_size, output_size, TRG).to(device)
model = Model(encoder, decoder).to(device)

model.load_state_dict(checkpoint["model"])


<All keys matched successfully>

In [31]:
print(len(phrase_table))
translate_sentences(4, model, True, phrase_table)

1023352
