From 0a8e5bcdc00efcec28f3b64ff4e289ba166c6ec8 Mon Sep 17 00:00:00 2001
From: Sandeep Subramanian
Date: Mon, 27 Feb 2017 19:11:21 -0500
Subject: [PATCH] Added neural summarization components, minor refactor

---
 data_utils.py    |  32 +++++++-
 evaluate.py      |  63 ++++++++++-----
 model.py         | 157 ++++++++++++++++++++++++++++++++++++-
 nmt.py           |  22 ++++++
 summarization.py | 196 +++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 445 insertions(+), 25 deletions(-)
 create mode 100644 summarization.py

diff --git a/data_utils.py b/data_utils.py
index 41d52e0..c060691 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -71,13 +71,30 @@ def construct_vocab(lines, vocab_size):
 
 def read_nmt_data(src, trg=None):
     """Read data from files."""
-    src_lines = [line.strip().split() for line in open(src, 'r')]
+    print 'Reading source data ...'
+    src_lines = []
+    with open(src, 'r') as f:
+        for ind, line in enumerate(f):
+            if ind % 100000 == 0:
+                print ind
+            src_lines.append(line.strip().split())
+
+    print 'Constructing source vocabulary ...'
     src_word2id, src_id2word = construct_vocab(src_lines, 30000)
+
     src = {'data': src_lines, 'word2id': src_word2id, 'id2word': src_id2word}
+    del src_lines
 
     if trg is not None:
-        trg_lines = [line.strip().split() for line in open(trg, 'r')]
+        print 'Reading target data ...'
+        trg_lines = []
+        with open(trg, 'r') as f:
+            for line in f:
+                trg_lines.append(line.strip().split())
+
+        print 'Constructing target vocabulary ...'
         trg_word2id, trg_id2word = construct_vocab(trg_lines, 30000)
+
         trg = {'data': trg_lines, 'word2id': trg_word2id, 'id2word': trg_id2word}
     else:
         trg = None
@@ -85,6 +102,17 @@
     return src, trg
 
 
+def read_summarization_data(src, trg):
+    """Read data from files with a shared source/target vocabulary."""
+    src_lines = [line.strip().split() for line in open(src, 'r')]
+    trg_lines = [line.strip().split() for line in open(trg, 'r')]
+    word2id, id2word = construct_vocab(src_lines + trg_lines, 30000)
+    src = {'data': src_lines, 'word2id': word2id, 'id2word': id2word}
+    trg = {'data': trg_lines, 'word2id': word2id, 'id2word': id2word}
+
+    return src, trg
+
+
 def get_minibatch(
     lines, word2ind, index, batch_size, max_len, add_start=True, add_end=True
 ):
diff --git a/evaluate.py b/evaluate.py
index 56be305..4e276fb 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -71,6 +71,31 @@ def get_bleu_moses(hypotheses, reference):
     return pipe.stdout.read()
 
 
+def decode_minibatch(
+    config,
+    model,
+    input_lines_src,
+    input_lines_trg,
+    output_lines_trg_gold
+):
+    """Decode a minibatch greedily, one token at a time."""
+    for i in xrange(config['data']['max_trg_length']):
+
+        decoder_logit = model(input_lines_src, input_lines_trg)
+        word_probs = model.decode(decoder_logit)
+        decoder_argmax = word_probs.data.cpu().numpy().argmax(axis=-1)
+        next_preds = Variable(
+            torch.from_numpy(decoder_argmax[:, -1])
+        ).cuda()
+
+        input_lines_trg = torch.cat(
+            (input_lines_trg, next_preds.unsqueeze(1)),
+            1
+        )
+
+    return input_lines_trg
+
+
 def evaluate_model(
     model, src, src_test, trg,
     trg_test, config, src_valid=None, trg_valid=None,
@@ -81,16 +106,22 @@
     ground_truths = []
     for j in xrange(0, len(src_test['data']), config['data']['batch_size']):
 
+        # Get source minibatch
         input_lines_src, output_lines_src, lens_src, mask_src = get_minibatch(
             src_test['data'], src['word2id'], j,
             config['data']['batch_size'], config['data']['max_src_length'],
             add_start=True, add_end=True
         )
-        input_lines_trg_gold, output_lines_trg_gold, lens_src, mask_src = get_minibatch(
-            trg_test['data'], trg['word2id'], j, config['data']['batch_size'],
-            config['data']['max_src_length'], add_start=True, add_end=True
+        # Get target minibatch
+        input_lines_trg_gold, output_lines_trg_gold, lens_trg, mask_trg = (
+            get_minibatch(
+                trg_test['data'], trg['word2id'], j,
+                config['data']['batch_size'], config['data']['max_trg_length'],
+                add_start=True, add_end=True
+            )
         )
 
+        # Initialize target with <s> for every sentence
         input_lines_trg = Variable(torch.LongTensor(
             [
                 [trg['word2id']['<s>']]
@@ -98,32 +129,27 @@
             ]
         )).cuda()
 
-        for i in xrange(config['data']['max_src_length']):
-
-            decoder_logit = model(input_lines_src, input_lines_trg)
-            word_probs = model.decode(decoder_logit)
-            decoder_argmax = word_probs.data.cpu().numpy().argmax(axis=-1)
-            next_preds = Variable(
-                torch.from_numpy(decoder_argmax[:, -1])
-            ).cuda()
-
-            input_lines_trg = torch.cat(
-                (input_lines_trg, next_preds.unsqueeze(1)),
-                1
-            )
+        # Decode a minibatch greedily __TODO__ add beam search decoding
+        input_lines_trg = decode_minibatch(
+            config, model, input_lines_src,
+            input_lines_trg, output_lines_trg_gold
+        )
 
+        # Copy minibatch outputs to cpu and convert ids to words
         input_lines_trg = input_lines_trg.data.cpu().numpy()
         input_lines_trg = [
             [trg['id2word'][x] for x in line]
             for line in input_lines_trg
         ]
 
+        # Do the same for gold sentences
         output_lines_trg_gold = output_lines_trg_gold.data.cpu().numpy()
         output_lines_trg_gold = [
             [trg['id2word'][x] for x in line]
             for line in output_lines_trg_gold
         ]
 
+        # Process outputs
         for sentence_pred, sentence_real, sentence_real_src in zip(
             input_lines_trg,
             output_lines_trg_gold,
@@ -148,11 +174,6 @@
                 print '--------------------------------------'
             ground_truths.append(['<s>'] + sentence_real[:index + 1])
 
-        if '</s>' in sentence_real_src:
-            index = sentence_real_src.index('</s>')
-        else:
-            index = len(sentence_real_src)
-
     return get_bleu(preds, ground_truths)
 
diff --git a/model.py b/model.py
index fe38f11..83eeb37 100644
--- a/model.py
+++ b/model.py
@@ -770,12 +770,18 @@ def __init__(
             dropout=self.dropout
         )
 
-        self.decoder = LSTMAttentionDot(
+        self.decoder1 = LSTMAttentionDot(
             trg_emb_dim,
             trg_hidden_dim,
             batch_first=True
         )
 
+        self.decoder2 = LSTMAttentionDot(
+            trg_hidden_dim,
+            trg_hidden_dim,
+            batch_first=True
+        )
+
         self.encoder2decoder = nn.Linear(
             self.src_hidden_dim * self.num_directions,
             trg_hidden_dim
@@ -830,12 +836,20 @@ def forward(self, input_src, input_trg, trg_mask=None, ctx_mask=None):
         ctx = src_h.transpose(0, 1)
 
-        trg_h, (_, _) = self.decoder(
+        trg_h, (_, _) = self.decoder1(
             trg_emb,
             (decoder_init_state, c_t),
             ctx,
             ctx_mask
         )
+
+        trg_h, (_, _) = self.decoder2(
+            trg_h,
+            (decoder_init_state, c_t),
+            ctx,
+            ctx_mask
+        )
+
         trg_h_reshape = trg_h.contiguous().view(
             trg_h.size()[0] * trg_h.size()[1],
             trg_h.size()[2]
         )
@@ -858,6 +872,145 @@
         return word_probs
 
 
+class Seq2SeqAttentionSharedEmbedding(nn.Module):
+    """Container module with an encoder, decoder, shared embeddings."""
+
+    def __init__(
+        self,
+        emb_dim,
+        vocab_size,
+        src_hidden_dim,
+        trg_hidden_dim,
+        ctx_hidden_dim,
+        attention_mode,
+        batch_size,
+        pad_token_src,
+        pad_token_trg,
+        bidirectional=True,
+        nlayers=2,
+        nlayers_trg=2,
+        dropout=0.,
+    ):
+        """Initialize model."""
+        super(Seq2SeqAttentionSharedEmbedding, self).__init__()
+        self.vocab_size = vocab_size
+        self.emb_dim = emb_dim
+        self.src_hidden_dim = src_hidden_dim
+        self.trg_hidden_dim = trg_hidden_dim
+        self.ctx_hidden_dim = ctx_hidden_dim
+        self.attention_mode = attention_mode
+        self.batch_size = batch_size
+        self.bidirectional = bidirectional
+        self.nlayers = nlayers
+        self.dropout = dropout
+        self.num_directions = 2 if bidirectional else 1
+        self.pad_token_src = pad_token_src
+        self.pad_token_trg = pad_token_trg
+
+        self.embedding = nn.Embedding(
+            vocab_size,
+            emb_dim,
+            self.pad_token_src
+        )
+
+        self.src_hidden_dim = src_hidden_dim // 2 \
+            if self.bidirectional else src_hidden_dim
+        self.encoder = nn.LSTM(
+            emb_dim,
+            self.src_hidden_dim,
+            nlayers,
+            bidirectional=bidirectional,
+            batch_first=True,
+            dropout=self.dropout
+        )
+
+        self.decoder = LSTMAttentionDot(
+            emb_dim,
+            trg_hidden_dim,
+            batch_first=True
+        )
+
+        self.encoder2decoder = nn.Linear(
+            self.src_hidden_dim * self.num_directions,
+            trg_hidden_dim
+        )
+        self.decoder2vocab = nn.Linear(trg_hidden_dim, vocab_size)
+
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize weights."""
+        initrange = 0.1
+        self.embedding.weight.data.uniform_(-initrange, initrange)
+        self.encoder2decoder.bias.data.fill_(0)
+        self.decoder2vocab.bias.data.fill_(0)
+
+    def get_state(self, input):
+        """Get cell states and hidden states."""
+        batch_size = input.size(0) \
+            if self.encoder.batch_first else input.size(1)
+        h0_encoder = Variable(torch.zeros(
+            self.encoder.num_layers * self.num_directions,
+            batch_size,
+            self.src_hidden_dim
+        ), requires_grad=False)
+        c0_encoder = Variable(torch.zeros(
+            self.encoder.num_layers * self.num_directions,
+            batch_size,
+            self.src_hidden_dim
+        ), requires_grad=False)
+
+        return h0_encoder.cuda(), c0_encoder.cuda()
+
+    def forward(self, input_src, input_trg, trg_mask=None, ctx_mask=None):
+        """Propagate input through the network."""
+        src_emb = self.embedding(input_src)
+        trg_emb = self.embedding(input_trg)
+
+        self.h0_encoder, self.c0_encoder = self.get_state(input_src)
+
+        src_h, (src_h_t, src_c_t) = self.encoder(
+            src_emb, (self.h0_encoder, self.c0_encoder)
+        )
+
+        if self.bidirectional:
+            h_t = torch.cat((src_h_t[-1], src_h_t[-2]), 1)
+            c_t = torch.cat((src_c_t[-1], src_c_t[-2]), 1)
+        else:
+            h_t = src_h_t[-1]
+            c_t = src_c_t[-1]
+        decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t))
+
+        ctx = src_h.transpose(0, 1)
+
+        trg_h, (_, _) = self.decoder(
+            trg_emb,
+            (decoder_init_state, c_t),
+            ctx,
+            ctx_mask
+        )
+        trg_h_reshape = trg_h.contiguous().view(
+            trg_h.size()[0] * trg_h.size()[1],
+            trg_h.size()[2]
+        )
+        decoder_logit = self.decoder2vocab(trg_h_reshape)
+        decoder_logit = decoder_logit.view(
+            trg_h.size()[0],
+            trg_h.size()[1],
+            decoder_logit.size()[1]
+        )
+        return decoder_logit
+
+    def decode(self, logits):
+        """Return probability distribution over words."""
+        logits_reshape = logits.view(-1, self.vocab_size)
+        word_probs = F.softmax(logits_reshape)
+        word_probs = word_probs.view(
+            logits.size()[0], logits.size()[1], logits.size()[2]
+        )
+        return word_probs
+
+
 class Seq2SeqFastAttention(nn.Module):
     """Container module with an encoder, deocder, embeddings."""
 
diff --git a/nmt.py b/nmt.py
index 6c701da..2bc18d7 100644
--- a/nmt.py
+++ b/nmt.py
@@ -217,11 +217,33 @@
                     logging.info('Real : %s ' % (' '.join(sentence_real)))
                     logging.info('===============================================')
 
+        if j % config['management']['checkpoint_freq'] == 0:
+
+            logging.info('Evaluating model ...')
+            bleu = evaluate_model(
+                model, src, src_test, trg,
+                trg_test, config, verbose=False,
+                metric='bleu',
+            )
+
+            logging.info('Epoch : %d Minibatch : %d : BLEU : %.5f ' % (i, j, bleu))
+
+            logging.info('Saving model ...')
+
+            torch.save(
+                model.state_dict(),
+                open(os.path.join(
+                    save_dir,
+                    experiment_name + '__epoch_%d__minibatch_%d' % (i, j) + '.model'), 'wb'
+                )
+            )
+
     bleu = evaluate_model(
         model, src, src_test, trg,
         trg_test, config, verbose=False,
         metric='bleu',
     )
+
     logging.info('Epoch : %d : BLEU : %.5f ' % (i, bleu))
 
     torch.save(
diff --git a/summarization.py b/summarization.py
new file mode 100644
index 0000000..df55b2c
--- /dev/null
+++ b/summarization.py
@@ -0,0 +1,196 @@
+#!/u/subramas/miniconda2/bin/python
+"""Main script to run summarization experiments."""
+import sys
+
+sys.path.append('/u/subramas/Research/nmt-pytorch/')
+
+from data_utils import read_nmt_data, get_minibatch, read_config, hyperparam_string, read_summarization_data
+from model import Seq2Seq, Seq2SeqAttention, Seq2SeqFastAttention, Seq2SeqAttentionSharedEmbedding
+from evaluate import evaluate_model
+import math
+import numpy as np
+import logging
+import argparse
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--config",
+    help="path to json config",
+    required=True
+)
+args = parser.parse_args()
+config_file_path = args.config
+config = read_config(config_file_path)
+experiment_name = hyperparam_string(config)
+save_dir = config['data']['save_dir']
+load_dir = config['data']['load_dir']
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    filename='log/%s' % (experiment_name),
+    filemode='w'
+)
+
+# Define a new handler to log to the console as well,
+# using the same format as the file handler above
+console = logging.StreamHandler()
+console.setLevel(logging.INFO)
+formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+console.setFormatter(formatter)
+# Add the handler to the root logger
+logging.getLogger('').addHandler(console)
+
+print 'Reading data ...'
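+# NOTE: summarization uses a single vocabulary shared between the article
+# and summary sides (built over both by read_summarization_data), which is
+# what allows Seq2SeqAttentionSharedEmbedding to tie its embeddings.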
+
+src, trg = read_summarization_data(
+    src=config['data']['src'],
+    trg=config['data']['trg']
+)
+
+src_test, trg_test = read_summarization_data(
+    src=config['data']['test_src'],
+    trg=config['data']['test_trg']
+)
+
+batch_size = config['data']['batch_size']
+max_length = config['data']['max_src_length']
+vocab_size = len(src['word2id'])
+
+logging.info('Model Parameters : ')
+logging.info('Task : %s ' % (config['data']['task']))
+logging.info('Model : %s ' % (config['model']['seq2seq']))
+logging.info('Language : %s ' % (config['model']['src_lang']))
+logging.info('Embedding Dim : %s' % (config['model']['dim_word_src']))
+logging.info('Source RNN Hidden Dim : %s' % (config['model']['dim']))
+logging.info('Target RNN Hidden Dim : %s' % (config['model']['dim']))
+logging.info('Source RNN Depth : %d ' % (config['model']['n_layers_src']))
+logging.info('Target RNN Depth : %d ' % (1))
+logging.info('Source RNN Bidirectional : %s' % (config['model']['bidirectional']))
+logging.info('Batch Size : %d ' % (config['data']['batch_size']))
+logging.info('Optimizer : %s ' % (config['training']['optimizer']))
+logging.info('Learning Rate : %f ' % (config['training']['lrate']))
+
+logging.info('Found %d words ' % (vocab_size))
+
+# Zero out the loss contribution of padding tokens
+weight_mask = torch.ones(vocab_size).cuda()
+weight_mask[trg['word2id']['<pad>']] = 0
+loss_criterion = nn.CrossEntropyLoss(weight=weight_mask).cuda()
+
+model = Seq2SeqAttentionSharedEmbedding(
+    emb_dim=config['model']['dim_word_src'],
+    vocab_size=vocab_size,
+    src_hidden_dim=config['model']['dim'],
+    trg_hidden_dim=config['model']['dim'],
+    ctx_hidden_dim=config['model']['dim'],
+    attention_mode='dot',
+    batch_size=batch_size,
+    bidirectional=config['model']['bidirectional'],
+    pad_token_src=src['word2id']['<pad>'],
+    pad_token_trg=trg['word2id']['<pad>'],
+    nlayers=config['model']['n_layers_src'],
+    nlayers_trg=config['model']['n_layers_trg'],
+    dropout=0.,
+).cuda()
+
+if load_dir:
+    model.load_state_dict(torch.load(
+        open(load_dir)
+    ))
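+# Evaluate once before training to get a baseline BLEU for the
+# freshly initialized (or loaded) model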
+bleu = evaluate_model(
+    model, src, src_test, trg,
+    trg_test, config, verbose=False,
+    metric='bleu',
+)
+
+# __TODO__ Make this more flexible for other learning methods.
+if config['training']['optimizer'] == 'adam':
+    lr = config['training']['lrate']
+    optimizer = optim.Adam(model.parameters(), lr=lr)
+elif config['training']['optimizer'] == 'adadelta':
+    optimizer = optim.Adadelta(model.parameters())
+elif config['training']['optimizer'] == 'sgd':
+    lr = config['training']['lrate']
+    optimizer = optim.SGD(model.parameters(), lr=lr)
+else:
+    raise NotImplementedError("Learning method not recommended for task")
+
+for i in xrange(1000):
+    losses = []
+    for j in xrange(0, len(src['data']), batch_size):
+
+        input_lines_src, _, lens_src, mask_src = get_minibatch(
+            src['data'], src['word2id'], j,
+            batch_size, max_length, add_start=True, add_end=True
+        )
+        input_lines_trg, output_lines_trg, lens_trg, mask_trg = get_minibatch(
+            trg['data'], trg['word2id'], j,
+            batch_size, max_length, add_start=True, add_end=True
+        )
+
+        decoder_logit = model(input_lines_src, input_lines_trg)
+        optimizer.zero_grad()
+
+        loss = loss_criterion(
+            decoder_logit.contiguous().view(-1, vocab_size),
+            output_lines_trg.view(-1)
+        )
+        losses.append(loss.data[0])
+        loss.backward()
+        optimizer.step()
+
+        if j % config['management']['monitor_loss'] == 0:
+            logging.info('Epoch : %d Minibatch : %d Loss : %.5f' % (
+                i, j, np.mean(losses))
+            )
+            losses = []
+
+        if (
+            config['management']['print_samples'] and
+            j % config['management']['print_samples'] == 0
+        ):
+            word_probs = model.decode(
+                decoder_logit
+            ).data.cpu().numpy().argmax(axis=-1)
+
+            output_lines_trg = output_lines_trg.data.cpu().numpy()
+            for sentence_pred, sentence_real in zip(
+                word_probs[:5], output_lines_trg[:5]
+            ):
+                sentence_pred = [trg['id2word'][x] for x in sentence_pred]
+                sentence_real = [trg['id2word'][x] for x in sentence_real]
+
+                if '</s>' in sentence_real:
+                    index = sentence_real.index('</s>')
+                    sentence_real = sentence_real[:index]
+                    sentence_pred = sentence_pred[:index]
+
+                logging.info('Predicted : %s ' % (' '.join(sentence_pred)))
+                logging.info('-----------------------------------------------')
+                logging.info('Real : %s ' % (' '.join(sentence_real)))
+                logging.info('===============================================')
+
+    torch.save(
+        model.state_dict(),
+        open(os.path.join(
+            save_dir,
+            experiment_name + '__epoch_%d' % (i) + '.model'), 'wb'
+        )
+    )
+
+    bleu = evaluate_model(
+        model, src, src_test, trg,
+        trg_test, config, verbose=False,
+        metric='bleu',
+    )
+    logging.info('Epoch : %d : BLEU : %.5f ' % (i, bleu))
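
For reviewers: below is a minimal smoke test for the new Seq2SeqAttentionSharedEmbedding
forward pass, not part of the patch. All dimensions and token ids are made up for
illustration, and a CUDA device is assumed because get_state() puts the initial
encoder states on the GPU.

    import numpy as np
    import torch
    from torch.autograd import Variable

    from model import Seq2SeqAttentionSharedEmbedding

    # Toy sizes, chosen only to exercise the tensor shapes
    model = Seq2SeqAttentionSharedEmbedding(
        emb_dim=32,
        vocab_size=100,
        src_hidden_dim=64,
        trg_hidden_dim=64,
        ctx_hidden_dim=64,
        attention_mode='dot',
        batch_size=4,
        pad_token_src=0,
        pad_token_trg=0,
        bidirectional=True,
        nlayers=2,
        nlayers_trg=2,
        dropout=0.,
    ).cuda()

    # One dummy minibatch: batch of 4, source length 7, target length 5
    input_src = Variable(torch.from_numpy(
        np.random.randint(0, 100, (4, 7))
    )).cuda()
    input_trg = Variable(torch.from_numpy(
        np.random.randint(0, 100, (4, 5))
    )).cuda()

    logits = model(input_src, input_trg)
    print logits.size()  # expect (4, 5, 100), i.e. batch x trg_len x vocab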