In [1]:
import time
import torch 
import pickle
import matplotlib.pyplot as plt 
import matplotlib.ticker as ticker
from torch import nn 
from tqdm import tqdm
from torchtext.vocab import Vectors
from torchtext.datasets import TranslationDataset
from torchtext.data import Field, BucketIterator, Iterator

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

MAX_LENGTH = 220
BATCH_SIZE = 128
CLIP = 1
EMB_SIZE = 300
NUM_LAYERS = 1

cuda


In [17]:
def create_mask(src, idx):
    mask = (src != idx).unsqueeze(2)
    return mask

class EncoderRNN(nn.Module):    
    def __init__(self, input_size, hidden_size, TEXT):
        super(EncoderRNN,self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, EMB_SIZE).to(device)
        self.gru = nn.GRU(EMB_SIZE, hidden_size, bidirectional = True, batch_first = True).to(device)
        self.fc = nn.Linear(2*hidden_size, hidden_size).to(device)

    def forward(self, x, src_len):
        embedded = self.embedding(x)
        #in order to avoid <pad>
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths = src_len, batch_first = True)
        packed_output, hidden = self.gru(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first = True)
        output = self.fc(output) 
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))

        return output, hidden

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, TEXT):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, EMB_SIZE).to(device)
        self.gru = nn.GRU(hidden_size + EMB_SIZE, hidden_size, batch_first = True).to(device)
        self.attn_scores1 = nn.Linear(hidden_size + MAX_LENGTH, hidden_size, bias = True).to(device)
        self.attn_scores2 = nn.Linear(hidden_size, 1, bias = False).to(device)
        self.fc = nn.Linear(hidden_size, output_size).to(device)
        self.softmax = nn.Softmax(dim = 1).to(device)
        self.logsoftmax = nn.LogSoftmax(dim = 1).to(device)

    def forward(self, x, hidden, encoder_hidden_states, coverage, mask):
        embedded = self.embedding(x)
        coverage_vec = torch.cat((hidden, coverage), dim = 1)

        scores = self.attn_scores2(torch.tanh(torch.add(self.attn_scores1(coverage_vec).unsqueeze(1), encoder_hidden_states)))
        attn_weights = scores.masked_fill(mask == 0, -1e10)
        attn_weights = self.softmax(attn_weights)

        coverage[:, :attn_weights.shape[1]] += attn_weights.squeeze(2)

        context_vec = torch.matmul(attn_weights.permute(0, 2, 1), encoder_hidden_states)
        attn_hidden = torch.cat((torch.relu(embedded.unsqueeze(dim = 1)), context_vec), dim = 2)

        output, hidden = self.gru(attn_hidden, hidden.unsqueeze(dim = 0))
        output = self.fc(output)
        output = self.logsoftmax(output.squeeze(dim = 1))
        return output, hidden.squeeze(0), attn_weights, coverage

In [18]:
class Model(nn.Module):
    def __init__(self, encoder, decoder):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, src_len, trg):
        outputs, hidden = self.encoder(src, src_len)
        mask = create_mask(src, SRC_PAD_IDX)
        coverage = torch.zeros(src.shape[0], MAX_LENGTH).to(device)
        decoder_input = trg[:, 0]
        decoder_outputs = torch.zeros(trg.shape[0], output_size, trg.shape[1]).to(device)

        for k in range(1, trg.shape[1]):
            output, hidden, _, coverage = self.decoder(decoder_input, hidden, outputs, coverage, mask)
            decoder_input = trg[:, k]
            decoder_outputs[:, :, k] = output

        return decoder_outputs

In [19]:
SRC = Field(init_token = '<SOS>', eos_token = '<EOS>', lower = True, sequential = True, pad_token = "<PAD>", batch_first = True, include_lengths = True)
TRG = Field(init_token = '<SOS>', eos_token = '<EOS>', lower = True, sequential = True, pad_token = "<PAD>", batch_first = True)

In [20]:
train_data, valid_data = TranslationDataset.splits(path= "./project/data/", train = "train", validation = "dev", test = None, exts = (".hi", ".en"), fields = (SRC, TRG))

In [21]:
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

In [22]:
train_iter, valid_iter = BucketIterator.splits((train_data, valid_data), batch_size=BATCH_SIZE, sort_key=lambda x: len(x.src), shuffle=True, sort_within_batch = True, device = device)

input_size = len(SRC.vocab)
output_size = len(TRG.vocab)
hidden_size = 300

SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
criterion = nn.NLLLoss(ignore_index = TRG_PAD_IDX)

encoder = EncoderRNN(input_size, hidden_size, SRC).to(device)
decoder = DecoderRNN(hidden_size, output_size, TRG).to(device)
model = Model(encoder, decoder).to(device)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 3)
epochs = 100

In [23]:
def train(iterator, model, optimizer, clip, epoch):
    model.train()
    epoch_loss = 0
    loop = tqdm(iterator)
    for _, batch in enumerate(loop):
        src, src_len = batch.src
        trg = batch.trg
        optimizer.zero_grad()
        decoder_outputs = model(src, src_len, trg)
        loss = criterion(decoder_outputs[:, :, 1:], trg[:, 1:])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

        loop.set_description('Epoch {}/{}'.format(epoch + 1, epochs))
        loop.set_postfix(loss=loss.item())
        
    return epoch_loss/len(iterator)


In [24]:
def validate(iterator, model):
    model.eval()
    with torch.no_grad():
        loss = 0
        for _, batch in enumerate(iterator):
            src, src_len = batch.src
            trg = batch.trg
            decoder_outputs = model(src, src_len, trg)
            loss = criterion(decoder_outputs[:, :, 1:], trg[:, 1:])
            loss += loss.item()
        return loss/len(iterator)

In [25]:
def translate(sentence, model, replace_with_src, phrase_table):
    model.eval()
    tokens = sentence.lower().strip().split()

    tokens = [SRC.init_token] + tokens + [SRC.eos_token]
    src_indices = [SRC.vocab.stoi[token.strip()] for token in tokens]

    src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)
    src_len = torch.LongTensor([len(src_indices)]).to(device)

    coverage = torch.zeros(1, MAX_LENGTH).to(device)

    with torch.no_grad():
        outputs, hidden = model.encoder(src_tensor, src_len)

        mask = create_mask(src_tensor, SRC_PAD_IDX)
        trg_indices = [TRG.vocab.stoi[TRG.init_token]]

        attns = torch.zeros(MAX_LENGTH, len(src_indices)).to(device)
        replace_words = []
        for i in range(MAX_LENGTH):
            decoder_input = torch.LongTensor([trg_indices[-1]]).to(device)
            output,  hidden, attn, coverage = model.decoder(decoder_input, hidden, outputs, coverage,mask)
            attns[i, :] = attn.squeeze(0).squeeze(1)
            pred_token = output.argmax(dim = 1).item()
            if pred_token == TRG.vocab.stoi[TRG.unk_token]:
                if replace_with_src:
                    idx = attn.argmax(axis = 1).item()
                    src_token = SRC.vocab.itos[src_indices[idx]]
                    if src_token in phrase_table:
                        src_token = phrase_table[src_token][0]
                    replace_words.append((i, src_token))
                    
            trg_indices.append(pred_token)
            if pred_token == TRG.vocab.stoi[TRG.eos_token]:
                break

        trg_tokens = [TRG.vocab.itos[j] for j in trg_indices]

    return trg_tokens[1:], attns, replace_words

In [26]:
def translate_sentences(epoch, model, replace_with_src, phrase_table):
    f1 = open("/home/akshay.goindani/NLA/project/models/unk_result"+str(epoch)+".txt", "w")
    f2 = open("/home/akshay.goindani/NLA/project/models/phrase_table_result"+str(epoch)+".txt", "w")
    with open("/home/akshay.goindani/NLA/project/data/test.hi") as f:
        for line in f:
            translation, _, replace_words = translate(line.strip(), model, replace_with_src, phrase_table)
            f1.write(" ".join(translation[:-1]).strip() + "\n")
            for i in replace_words:
                translation[i[0]] = i[1].strip()
            f2.write(" ".join(translation[:-1]).strip() + "\n")
    f1.close()
    f2.close()
    
phrase_table = create_phrase_table()

In [27]:
best = 1e18
for epoch in range(epochs):
    train_loss = train(train_iter, model, optimizer, CLIP, epoch)
    valid_loss = validate(valid_iter, model)
    print("Train Loss:", train_loss)
    print("Validation Loss:", valid_loss.item())
    if valid_loss < best:
        best = valid_loss
        checkpoint = {
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "train_loss":train_loss,
            "valid_loss": valid_loss
        }
        translate_sentences(epoch + 1, model, True, phrase_table)
        torch.save(checkpoint, "/home/akshay.goindani/NLA/project/models/model"+str(epoch+1)+".tar")
        early_stop = 0
    else:
        early_stop += 1
    scheduler.step(valid_loss)
    if early_stop > 7:
        print('Epoch %d: early stopping' % (epoch + 1))
        break


Epoch 1/100: 100%|██████████| 363/363 [01:30<00:00,  3.54it/s, loss=4.43]


Train Loss: 5.587334969155388
Validation Loss: 0.9049055576324463


Epoch 2/100: 100%|██████████| 363/363 [01:31<00:00,  3.29it/s, loss=4.57]


Train Loss: 4.224245642827562
Validation Loss: 0.8239840269088745


Epoch 3/100: 100%|██████████| 363/363 [01:30<00:00,  4.09it/s, loss=3.87]


Train Loss: 3.4956279733621054
Validation Loss: 0.7910586595535278


Epoch 4/100: 100%|██████████| 363/363 [01:32<00:00,  3.45it/s, loss=3.42]


Train Loss: 2.9541097403229433
Validation Loss: 0.7725210189819336


Epoch 5/100: 100%|██████████| 363/363 [01:35<00:00,  3.30it/s, loss=1.98]


Train Loss: 2.5181669995147664
Validation Loss: 0.7673913240432739


Epoch 6/100: 100%|██████████| 363/363 [01:33<00:00,  3.05it/s, loss=2.24] 
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 2.167250365757745
Validation Loss: 0.7727333903312683


Epoch 7/100: 100%|██████████| 363/363 [01:35<00:00,  4.23it/s, loss=1.42] 
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.876606619719303
Validation Loss: 0.7718324065208435


Epoch 8/100: 100%|██████████| 363/363 [01:35<00:00,  2.81it/s, loss=2.64] 
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.6461994051112288
Validation Loss: 0.77394700050354


Epoch 9/100: 100%|██████████| 363/363 [01:32<00:00,  2.51it/s, loss=2.43] 
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.4549829670056167
Validation Loss: 0.786953866481781


Epoch 10/100: 100%|██████████| 363/363 [01:32<00:00,  3.44it/s, loss=0.646]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.1560942969046348
Validation Loss: 0.7835620641708374


Epoch 11/100: 100%|██████████| 363/363 [01:35<00:00,  3.24it/s, loss=0.756]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.0665269067957381
Validation Loss: 0.7863731384277344


Epoch 12/100: 100%|██████████| 363/363 [01:35<00:00,  3.67it/s, loss=0.695]
  0%|          | 0/363 [00:00<?, ?it/s]

Train Loss: 1.0171985656187226
Validation Loss: 0.7905864715576172


Epoch 13/100: 100%|██████████| 363/363 [01:32<00:00,  4.93it/s, loss=0.956]


Train Loss: 0.9783977206909296
Validation Loss: 0.7941483855247498
Epoch 13: early stopping


In [13]:
def create_phrase_table():
    dic = {}
    with open("./project/data/phrase-table.hi-en.onmt") as f:
        for line in f:
            temp = line.strip().split("|||")
            src = temp[0].strip()
            trg = temp[1].strip()
            if src not in dic:
                dic[src] = [trg]
            else:
                dic[src].append(trg)
                
    return dic
