Add concatenative feature embeddings #147

Merged 8 commits on Aug 5, 2017
124 changes: 82 additions & 42 deletions onmt/Models.py
@@ -1,3 +1,4 @@
+from __future__ import division
import torch
import torch.nn as nn
from torch.autograd import Variable
@@ -7,56 +8,82 @@
from onmt.modules.Gate import ContextGateFactory
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
-import math


class Embeddings(nn.Module):
    def __init__(self, opt, dicts, feature_dicts=None):
+        super(Embeddings, self).__init__()
        self.positional_encoding = opt.position_encoding
        if self.positional_encoding:
            self.pe = self.make_positional_encodings(opt.word_vec_size, 5000)
            if len(opt.gpus) > 0:
                self.pe.cuda()

        self.word_vec_size = opt.word_vec_size

-        super(Embeddings, self).__init__()
-        self.word_lut = nn.Embedding(dicts.size(),
-                                     opt.word_vec_size,
-                                     padding_idx=onmt.Constants.PAD)
-        # Word embeddings.
-        self.dropout = nn.Dropout(p=opt.dropout)
-        self.feature_dicts = feature_dicts
-        # Feature embeddings.
-        if self.feature_dicts:
-            self.feature_luts = nn.ModuleList([
-                nn.Embedding(feature_dict.size(),
-                             opt.feature_vec_size,
-                             padding_idx=onmt.Constants.PAD)
-                for feature_dict in feature_dicts])

-            # MLP on features and words.
-            self.activation = nn.ReLU()
-            self.linear = onmt.modules.BottleLinear(
-                opt.word_vec_size +
-                len(feature_dicts) * opt.feature_vec_size,
-                opt.word_vec_size)
+        self.dropout = nn.Dropout(p=opt.dropout)

+        self.feat_merge = opt.feat_merge

+        feat_exp = opt.feat_vec_exponent

+        # vocab_sizes: sequence of vocab sizes for words and each feature
+        vocab_sizes = [dicts.size()]
+        # emb_sizes
+        emb_sizes = [opt.word_vec_size]
+        if feature_dicts:
+            vocab_sizes.extend(feat_dict.size() for feat_dict in feature_dicts)
+            if opt.feat_merge == 'concat':
+                # When concatenating, derive the size of each feature's
+                # embedding from the number of values the embedding
+                # takes
+                feat_sizes = (int(feat_dict.size() ** feat_exp)
+                              for feat_dict in feature_dicts)
+            else:
+                # sum feature merge (for now, the same as the old option
+                # for merging features in OpenNMT-py)
+                feat_sizes = (opt.feat_vec_size for feat_dict in feature_dicts)
+                # apply a layer of mlp to get it down to the right dim
+                self.mlp = nn.Sequential(onmt.modules.BottleLinear(
+                                             opt.word_vec_size +
+                                             len(feature_dicts) * opt.feat_vec_size,
+                                             opt.word_vec_size),
+                                         nn.ReLU())
+            emb_sizes.extend(feat_sizes)

+        self.emb_luts = nn.ModuleList([
+            nn.Embedding(vocab, dim,
+                         padding_idx=onmt.Constants.PAD)
+            for vocab, dim in zip(vocab_sizes, emb_sizes)])

+    @property
+    def word_lut(self):
+        return self.emb_luts[0]

+    @property
+    def embedding_size(self):
+        if self.feat_merge == 'concat' and len(self.emb_luts) > 1:
+            return sum(emb_lut.embedding_dim
+                       for emb_lut in self.emb_luts.children())
        else:
-            self.feature_luts = nn.ModuleList([])
+            return self.word_lut.embedding_dim

    def make_positional_encodings(self, dim, max_len):
-        pe = torch.FloatTensor(max_len, 1, dim).fill_(0)
-        for i in range(dim):
-            for j in range(max_len):
-                k = float(j) / (10000.0 ** (2.0*i / float(dim)))
-                pe[j, 0, i] = math.cos(k) if i % 2 == 1 else math.sin(k)
-        return pe
+        pe = torch.arange(0, max_len).unsqueeze(1).expand(max_len, dim)
+        div_term = 1 / torch.pow(10000, torch.arange(0, dim * 2, 2) / dim)
+        pe = pe * div_term.expand_as(pe)
+        pe[:, 0::2] = torch.sin(pe[:, 0::2])
+        pe[:, 1::2] = torch.cos(pe[:, 1::2])
+        return pe.unsqueeze(1)

    def load_pretrained_vectors(self, emb_file):
        if emb_file is not None:
            pretrained = torch.load(emb_file)
            self.word_lut.weight.data.copy_(pretrained)

+    def merge(self, features):
+        if self.feat_merge == 'concat':
+            return torch.cat(features, 2)
+        else:
+            return self.mlp(torch.cat(features, 2))

    def forward(self, src_input):
        """
        Embed the words or utilize features and MLP.
@@ -65,22 +92,34 @@ def forward(self, src_input):
            src_input (LongTensor): len x batch x nfeat

        Return:
-            emb (FloatTensor): len x batch x input_size
+            emb (FloatTensor): len x batch x emb_size
+                emb_size is word_vec_size if there are no
+                features or the merge action is sum.
+                It is the sum of all feature dimensions
+                if the merge action is concatenate.
        """
-        word = self.word_lut(src_input[:, :, 0])
-        emb = word
-        if self.feature_dicts:
-            features = [feature_lut(src_input[:, :, j+1])
-                        for j, feature_lut in enumerate(self.feature_luts)]
+        in_length, in_batch, nfeat = src_input.size()
+        aeq(nfeat, len(self.emb_luts))

-            # Apply one MLP layer.
-            emb = self.activation(
-                self.linear(torch.cat([word] + features, -1)))
+        if len(self.emb_luts) == 1:
+            emb = self.word_lut(src_input.squeeze(2))
+        else:
+            feat_inputs = (feat.squeeze(2)
+                           for feat in src_input.split(1, dim=2))
+            features = [lut(feat)
+                        for lut, feat in zip(self.emb_luts, feat_inputs)]
+            emb = self.merge(features)

        if self.positional_encoding:
            emb = emb + Variable(self.pe[:emb.size(0), :1, :emb.size(2)]
                                 .expand_as(emb))
            emb = self.dropout(emb)

+        out_length, out_batch, emb_size = emb.size()
+        aeq(in_length, out_length)
+        aeq(in_batch, out_batch)
+        aeq(emb_size, self.embedding_size)

        return emb


@@ -104,11 +143,12 @@ def __init__(self, opt, dicts, feature_dicts=None):

        # Size of the encoder RNN.
        self.hidden_size = opt.rnn_size // self.num_directions
-        input_size = opt.word_vec_size

        super(Encoder, self).__init__()
        self.embeddings = Embeddings(opt, dicts, feature_dicts)

+        input_size = self.embeddings.embedding_size

        # The Encoder RNN.
        self.encoder_layer = opt.encoder_layer

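A note for readers skimming the diff: the new Embeddings keeps the word lookup table and every feature lookup table in a single nn.ModuleList (emb_luts), sizes each feature embedding as int(N ** feat_vec_exponent) under -feat_merge concat, and under -feat_merge sum keeps fixed-size feature embeddings and projects the concatenation back to word_vec_size with an MLP. The following is a minimal standalone sketch of that merge logic, not the PR's code: a plain nn.Linear stands in for onmt.modules.BottleLinear, padding and positional encodings are omitted, and the class name ToyFeatureEmbeddings and its toy arguments are invented for illustration.

import torch
import torch.nn as nn


class ToyFeatureEmbeddings(nn.Module):
    """Illustrative sketch of the concat / 'sum' feature merge in this PR."""

    def __init__(self, word_vocab, feat_vocabs, word_vec_size=500,
                 feat_vec_size=20, feat_vec_exponent=0.7, feat_merge='concat'):
        super(ToyFeatureEmbeddings, self).__init__()
        self.feat_merge = feat_merge

        vocab_sizes = [word_vocab] + list(feat_vocabs)
        if feat_merge == 'concat':
            # Each feature's embedding size is derived from its vocabulary size.
            feat_sizes = [int(n ** feat_vec_exponent) for n in feat_vocabs]
        else:
            # 'sum' merge: fixed-size feature embeddings, then an MLP projects
            # the concatenation back down to word_vec_size.
            feat_sizes = [feat_vec_size] * len(feat_vocabs)
            self.mlp = nn.Sequential(
                nn.Linear(word_vec_size + sum(feat_sizes), word_vec_size),
                nn.ReLU())
        emb_sizes = [word_vec_size] + feat_sizes

        self.emb_luts = nn.ModuleList(
            [nn.Embedding(vocab, dim)
             for vocab, dim in zip(vocab_sizes, emb_sizes)])

    @property
    def embedding_size(self):
        # Mirrors the embedding_size property added in the diff.
        if self.feat_merge == 'concat':
            return sum(lut.embedding_dim for lut in self.emb_luts)
        return self.emb_luts[0].embedding_dim

    def forward(self, src):
        # src: len x batch x (1 + number of features), word index in column 0.
        embs = [lut(src[:, :, i]) for i, lut in enumerate(self.emb_luts)]
        merged = torch.cat(embs, 2)
        return merged if self.feat_merge == 'concat' else self.mlp(merged)


if __name__ == '__main__':
    emb = ToyFeatureEmbeddings(word_vocab=1000, feat_vocabs=[36, 6])
    src = torch.zeros(4, 2, 3, dtype=torch.long)  # len=4, batch=2, word + 2 feats
    out = emb(src)
    # concat: 500 + int(36 ** 0.7) + int(6 ** 0.7) = 500 + 12 + 3 = 515
    print(out.shape, emb.embedding_size)

As the comment in the diff says, the 'sum' option is (for now) the same as the old OpenNMT-py feature merge, an MLP over the concatenation, rather than an elementwise sum.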
11 changes: 9 additions & 2 deletions test/test_models.py
@@ -16,9 +16,16 @@
                    help='Size of LSTM hidden states')
parser.add_argument('-word_vec_size', type=int, default=500,
                    help='Word embedding sizes')
-parser.add_argument('-feature_vec_size', type=int, default=100,
+parser.add_argument('-feat_vec_size', type=int, default=100,
                    help='Feature vec sizes')

+parser.add_argument('-feat_merge', type=str, default='concat',
+                    choices=['concat', 'sum'],
+                    help='Merge action for the features embeddings')
+parser.add_argument('-feat_vec_exponent', type=float, default=0.7,
+                    help="""When features embedding sizes are not set and
+                    using -feat_merge concat, their dimension will be set
+                    to N^feat_vec_exponent where N is the number of values
+                    the feature takes""")
parser.add_argument('-input_feed', type=int, default=1,
                    help="""Feed the context vector at each time step as
                    additional input (via concatenation with the word
11 changes: 9 additions & 2 deletions train.py
@@ -37,9 +37,16 @@
                    help='Size of LSTM hidden states')
parser.add_argument('-word_vec_size', type=int, default=500,
                    help='Word embedding sizes')
-parser.add_argument('-feature_vec_size', type=int, default=100,
+parser.add_argument('-feat_vec_size', type=int, default=20,
                    help='Feature vec sizes')

+parser.add_argument('-feat_merge', type=str, default='concat',
+                    choices=['concat', 'sum'],
+                    help='Merge action for the features embeddings')
+parser.add_argument('-feat_vec_exponent', type=float, default=0.7,
+                    help="""When features embedding sizes are not set and
+                    using -feat_merge concat, their dimension will be set
+                    to N^feat_vec_exponent where N is the number of values
+                    the feature takes""")
parser.add_argument('-input_feed', type=int, default=1,
                    help="""Feed the context vector at each time step as
                    additional input (via concatenation with the word
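As a quick worked example of the -feat_vec_exponent rule in the help text added above (identically in test/test_models.py and train.py): under -feat_merge concat with the default exponent of 0.7, a feature taking N distinct values gets an embedding of int(N ** 0.7) dimensions. A tiny illustrative calculation, not part of the PR:

# Illustration only: default sizing of concatenated feature embeddings.
for n in (6, 36, 100, 1000):
    print(n, '->', int(n ** 0.7))
# 6 -> 3, 36 -> 12, 100 -> 25, 1000 -> 125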