Added neural summarization components, minor refactor
MaximumEntropy committed Feb 28, 2017
1 parent d2922bf commit 0a8e5bc
Showing 5 changed files with 445 additions and 25 deletions.
32 changes: 30 additions & 2 deletions data_utils.py
@@ -71,20 +71,48 @@ def construct_vocab(lines, vocab_size):

def read_nmt_data(src, trg=None):
"""Read data from files."""
src_lines = [line.strip().split() for line in open(src, 'r')]
print 'Reading source data ...'
src_lines = []
with open(src, 'r') as f:
for ind, line in enumerate(f):
if ind % 100000 == 0:
print ind
src_lines.append(line.strip().split())

print 'Constructing source vocabulary ...'
src_word2id, src_id2word = construct_vocab(src_lines, 30000)

src = {'data': src_lines, 'word2id': src_word2id, 'id2word': src_id2word}
del src_lines

if trg is not None:
trg_lines = [line.strip().split() for line in open(trg, 'r')]
print 'Reading target data ...'
trg_lines = []
with open(trg, 'r') as f:
for line in f:
trg_lines.append(line.strip().split())

print 'Constructing target vocabulary ...'
trg_word2id, trg_id2word = construct_vocab(trg_lines, 30000)

trg = {'data': trg_lines, 'word2id': trg_word2id, 'id2word': trg_id2word}
else:
trg = None

return src, trg


def read_summarization_data(src, trg):
"""Read data from files."""
src_lines = [line.strip().split() for line in open(src, 'r')]
trg_lines = [line.strip().split() for line in open(trg, 'r')]
word2id, id2word = construct_vocab(src_lines + trg_lines, 30000)
src = {'data': src_lines, 'word2id': word2id, 'id2word': id2word}
trg = {'data': trg_lines, 'word2id': word2id, 'id2word': id2word}

return src, trg
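
For orientation, a minimal usage sketch of the new read_summarization_data helper. The file paths are placeholders (the repository's actual data layout is not shown here), the 30000-word vocabulary cap is the one hard-coded above, and, unlike read_nmt_data, both returned dicts share the same word2id/id2word mappings.

# Hypothetical call; paths are illustrative only.
src, trg = read_summarization_data(
    'data/train.article.txt',   # one tokenized article per line (assumed format)
    'data/train.title.txt'      # one tokenized summary per line (assumed format)
)
assert src['word2id'] is trg['word2id']   # single shared vocabulary
first_article_tokens = src['data'][0]     # list of tokens, e.g. ['police', 'arrest', ...]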


def get_minibatch(
lines, word2ind, index, batch_size,
max_len, add_start=True, add_end=True
63 changes: 42 additions & 21 deletions evaluate.py
@@ -71,6 +71,31 @@ def get_bleu_moses(hypotheses, reference):
return pipe.stdout.read()


def decode_minibatch(
config,
model,
input_lines_src,
input_lines_trg,
output_lines_trg_gold
):
"""Decode a minibatch."""
for i in xrange(config['data']['max_trg_length']):

decoder_logit = model(input_lines_src, input_lines_trg)
word_probs = model.decode(decoder_logit)
decoder_argmax = word_probs.data.cpu().numpy().argmax(axis=-1)
next_preds = Variable(
torch.from_numpy(decoder_argmax[:, -1])
).cuda()

input_lines_trg = torch.cat(
(input_lines_trg, next_preds.unsqueeze(1)),
1
)

return input_lines_trg
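
decode_minibatch factors the greedy (argmax) decoding loop out of evaluate_model: at each of max_trg_length steps the current target prefix is pushed through the model, the most probable word at the last time step is taken, and its id is appended as a new column before the next step. A self-contained sketch of the per-step append with toy tensors (batch size, vocabulary size and the <s> id are made up; no model is involved):

import torch

batch_size, vocab_size = 2, 5
prefix = torch.LongTensor([[1], [1]])                      # every sentence starts with an assumed <s> id of 1
fake_word_probs = torch.rand(batch_size, 1, vocab_size)    # stands in for model.decode(decoder_logit)
decoder_argmax = fake_word_probs.numpy().argmax(axis=-1)   # (batch, prefix_len) array of word ids
next_preds = torch.from_numpy(decoder_argmax[:, -1])       # most likely word at the last position
prefix = torch.cat((prefix, next_preds.unsqueeze(1)), 1)   # prefix grows by one column per step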


def evaluate_model(
model, src, src_test, trg,
trg_test, config, src_valid=None, trg_valid=None,
@@ -81,49 +106,50 @@ def evaluate_model(
ground_truths = []
for j in xrange(0, len(src_test['data']), config['data']['batch_size']):

# Get source minibatch
input_lines_src, output_lines_src, lens_src, mask_src = get_minibatch(
src_test['data'], src['word2id'], j, config['data']['batch_size'],
config['data']['max_src_length'], add_start=True, add_end=True
)

input_lines_trg_gold, output_lines_trg_gold, lens_src, mask_src = get_minibatch(
trg_test['data'], trg['word2id'], j, config['data']['batch_size'],
config['data']['max_src_length'], add_start=True, add_end=True
# Get target minibatch
input_lines_trg_gold, output_lines_trg_gold, lens_src, mask_src = (
get_minibatch(
trg_test['data'], trg['word2id'], j,
config['data']['batch_size'], config['data']['max_trg_length'],
add_start=True, add_end=True
)
)

# Initialize target with <s> for every sentence
input_lines_trg = Variable(torch.LongTensor(
[
[trg['word2id']['<s>']]
for i in xrange(input_lines_src.size(0))
]
)).cuda()

for i in xrange(config['data']['max_src_length']):

decoder_logit = model(input_lines_src, input_lines_trg)
word_probs = model.decode(decoder_logit)
decoder_argmax = word_probs.data.cpu().numpy().argmax(axis=-1)
next_preds = Variable(
torch.from_numpy(decoder_argmax[:, -1])
).cuda()

input_lines_trg = torch.cat(
(input_lines_trg, next_preds.unsqueeze(1)),
1
)
        # Decode a minibatch greedily. TODO: add beam search decoding.
input_lines_trg = decode_minibatch(
config, model, input_lines_src,
input_lines_trg, output_lines_trg_gold
)

# Copy minibatch outputs to cpu and convert ids to words
input_lines_trg = input_lines_trg.data.cpu().numpy()
input_lines_trg = [
[trg['id2word'][x] for x in line]
for line in input_lines_trg
]

# Do the same for gold sentences
output_lines_trg_gold = output_lines_trg_gold.data.cpu().numpy()
output_lines_trg_gold = [
[trg['id2word'][x] for x in line]
for line in output_lines_trg_gold
]

# Process outputs
for sentence_pred, sentence_real, sentence_real_src in zip(
input_lines_trg,
output_lines_trg_gold,
@@ -148,11 +174,6 @@ def evaluate_model(
print '--------------------------------------'
ground_truths.append(['<s>'] + sentence_real[:index + 1])

if '</s>' in sentence_real_src:
index = sentence_real_src.index('</s>')
else:
index = len(sentence_real_src)

return get_bleu(preds, ground_truths)
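
For reference, preds and ground_truths are lists of token lists (not joined strings) by the time they reach get_bleu; a made-up example of their shape (the contents and the truncation at '</s>' are illustrative only):

preds = [['<s>', 'police', 'arrest', 'suspect', '</s>']]
ground_truths = [['<s>', 'police', 'arrest', 'a', 'suspect', '</s>']]
bleu = get_bleu(preds, ground_truths)   # numeric BLEU score, logged with %.5f in nmt.py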


157 changes: 155 additions & 2 deletions model.py
@@ -770,12 +770,18 @@ def __init__(
dropout=self.dropout
)

self.decoder = LSTMAttentionDot(
self.decoder1 = LSTMAttentionDot(
trg_emb_dim,
trg_hidden_dim,
batch_first=True
)

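        # Second, stacked attention decoder; its input size is the first decoder's hidden size.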
self.decoder2 = LSTMAttentionDot(
trg_hidden_dim,
trg_hidden_dim,
batch_first=True
)

self.encoder2decoder = nn.Linear(
self.src_hidden_dim * self.num_directions,
trg_hidden_dim
@@ -830,12 +836,20 @@ def forward(self, input_src, input_trg, trg_mask=None, ctx_mask=None):

ctx = src_h.transpose(0, 1)

trg_h, (_, _) = self.decoder(
trg_h, (_, _) = self.decoder1(
trg_emb,
(decoder_init_state, c_t),
ctx,
ctx_mask
)

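        # Run the first decoder's hidden states through a second attention decoder over the same source context.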
trg_h, (_, _) = self.decoder2(
trg_h,
(decoder_init_state, c_t),
ctx,
ctx_mask
)

trg_h_reshape = trg_h.contiguous().view(
trg_h.size()[0] * trg_h.size()[1],
trg_h.size()[2]
@@ -858,6 +872,145 @@ def decode(self, logits):
return word_probs


class Seq2SeqAttentionSharedEmbedding(nn.Module):
"""Container module with an encoder, deocder, embeddings."""

def __init__(
self,
emb_dim,
vocab_size,
src_hidden_dim,
trg_hidden_dim,
ctx_hidden_dim,
attention_mode,
batch_size,
pad_token_src,
pad_token_trg,
bidirectional=True,
nlayers=2,
nlayers_trg=2,
dropout=0.,
):
"""Initialize model."""
super(Seq2SeqAttentionSharedEmbedding, self).__init__()
self.vocab_size = vocab_size
self.emb_dim = emb_dim
self.src_hidden_dim = src_hidden_dim
self.trg_hidden_dim = trg_hidden_dim
self.ctx_hidden_dim = ctx_hidden_dim
self.attention_mode = attention_mode
self.batch_size = batch_size
self.bidirectional = bidirectional
self.nlayers = nlayers
self.dropout = dropout
self.num_directions = 2 if bidirectional else 1
self.pad_token_src = pad_token_src
self.pad_token_trg = pad_token_trg

self.embedding = nn.Embedding(
vocab_size,
emb_dim,
self.pad_token_src
)

self.src_hidden_dim = src_hidden_dim // 2 \
if self.bidirectional else src_hidden_dim
self.encoder = nn.LSTM(
emb_dim,
self.src_hidden_dim,
nlayers,
bidirectional=bidirectional,
batch_first=True,
dropout=self.dropout
)

self.decoder = LSTMAttentionDot(
emb_dim,
trg_hidden_dim,
batch_first=True
)

self.encoder2decoder = nn.Linear(
self.src_hidden_dim * self.num_directions,
trg_hidden_dim
)
self.decoder2vocab = nn.Linear(trg_hidden_dim, vocab_size)

self.init_weights()

def init_weights(self):
"""Initialize weights."""
initrange = 0.1
self.embedding.weight.data.uniform_(-initrange, initrange)
self.encoder2decoder.bias.data.fill_(0)
self.decoder2vocab.bias.data.fill_(0)

def get_state(self, input):
"""Get cell states and hidden states."""
batch_size = input.size(0) \
if self.encoder.batch_first else input.size(1)
h0_encoder = Variable(torch.zeros(
self.encoder.num_layers * self.num_directions,
batch_size,
self.src_hidden_dim
), requires_grad=False)
c0_encoder = Variable(torch.zeros(
self.encoder.num_layers * self.num_directions,
batch_size,
self.src_hidden_dim
), requires_grad=False)

return h0_encoder.cuda(), c0_encoder.cuda()

def forward(self, input_src, input_trg, trg_mask=None, ctx_mask=None):
"""Propogate input through the network."""
src_emb = self.embedding(input_src)
trg_emb = self.embedding(input_trg)

self.h0_encoder, self.c0_encoder = self.get_state(input_src)

src_h, (src_h_t, src_c_t) = self.encoder(
src_emb, (self.h0_encoder, self.c0_encoder)
)

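        # Concatenate the final forward and backward encoder states when the encoder is bidirectional.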
if self.bidirectional:
h_t = torch.cat((src_h_t[-1], src_h_t[-2]), 1)
c_t = torch.cat((src_c_t[-1], src_c_t[-2]), 1)
else:
h_t = src_h_t[-1]
c_t = src_c_t[-1]
decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t))

ctx = src_h.transpose(0, 1)

trg_h, (_, _) = self.decoder(
trg_emb,
(decoder_init_state, c_t),
ctx,
ctx_mask
)
trg_h_reshape = trg_h.contiguous().view(
trg_h.size()[0] * trg_h.size()[1],
trg_h.size()[2]
)
decoder_logit = self.decoder2vocab(trg_h_reshape)
decoder_logit = decoder_logit.view(
trg_h.size()[0],
trg_h.size()[1],
decoder_logit.size()[1]
)
return decoder_logit

def decode(self, logits):
"""Return probability distribution over words."""
logits_reshape = logits.view(-1, self.vocab_size)
word_probs = F.softmax(logits_reshape)
word_probs = word_probs.view(
logits.size()[0], logits.size()[1], logits.size()[2]
)
return word_probs
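
A construction sketch for the new shared-embedding model, which mirrors Seq2SeqAttention but uses a single vocabulary and embedding table for both sides. All sizes and special-token ids are illustrative, attention_mode='dot' and the '<pad>'/'<s>'/'<unk>' entries are assumptions rather than values from the repository's config, and a CUDA device is assumed to match the hard-coded .cuda() calls above.

import torch
from torch.autograd import Variable

word2id = {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3}   # stand-in for the shared vocab from read_summarization_data
model = Seq2SeqAttentionSharedEmbedding(
    emb_dim=256, vocab_size=len(word2id),
    src_hidden_dim=512, trg_hidden_dim=512, ctx_hidden_dim=512,
    attention_mode='dot', batch_size=4,
    pad_token_src=word2id['<pad>'], pad_token_trg=word2id['<pad>'],
).cuda()
# One embedding table serves both the article (source) and the summary (target) side:
fake_src = Variable(torch.LongTensor(4, 10).fill_(word2id['<unk>'])).cuda()
fake_trg = Variable(torch.LongTensor(4, 7).fill_(word2id['<s>'])).cuda()
decoder_logit = model(fake_src, fake_trg)   # (4, 7, vocab_size) logits over the shared vocabulary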


class Seq2SeqFastAttention(nn.Module):
"""Container module with an encoder, deocder, embeddings."""

22 changes: 22 additions & 0 deletions nmt.py
@@ -217,11 +217,33 @@
logging.info('Real : %s ' % (' '.join(sentence_real)))
logging.info('===============================================')

if j % config['management']['checkpoint_freq'] == 0:

logging.info('Evaluating model ...')
bleu = evaluate_model(
model, src, src_test, trg,
trg_test, config, verbose=False,
metric='bleu',
)

logging.info('Epoch : %d Minibatch : %d : BLEU : %.5f ' % (i, j, bleu))

logging.info('Saving model ...')

torch.save(
model.state_dict(),
open(os.path.join(
save_dir,
experiment_name + '__epoch_%d__minibatch_%d' % (i, j) + '.model'), 'wb'
)
)
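
The checkpoint block above and the evaluate_model calls rely on a small set of experiment-config keys; a sketch of those keys follows (the values shown are placeholders, not the repository's defaults):

config = {
    'data': {
        'batch_size': 80,        # minibatch size used for training and evaluation
        'max_src_length': 50,    # source sentences truncated/padded to this length
        'max_trg_length': 50,    # maximum number of greedy decoding steps
    },
    'management': {
        'checkpoint_freq': 1000, # evaluate and save every N minibatches
    },
}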

bleu = evaluate_model(
model, src, src_test, trg,
trg_test, config, verbose=False,
metric='bleu',
)

logging.info('Epoch : %d : BLEU : %.5f ' % (i, bleu))

torch.save(