
Commit

add meta embeddings
king-menin committed Feb 17, 2019
1 parent 4efaa7c commit 8450294
Showing 3 changed files with 38 additions and 316 deletions.
172 changes: 4 additions & 168 deletions modules/layers/embedders.py
@@ -1,16 +1,8 @@
from modules.layers import bert_modeling
import torch
from gensim.models import KeyedVectors
import os
import codecs
import logging
import json
from torch import nn
from elmoformanylangs.modules.embedding_layer import EmbeddingLayer
from elmoformanylangs.frontend import Model


# TODO: add from_config to other embedders
class BertEmbedder(nn.Module):

# @property
@@ -66,25 +58,10 @@ def forward(self, *batch):
return self.bert_gamma * torch.sum(all_encoder_layers, dim=0)

def freeze(self):
for param in self.model.parameters():
param.requires_grad = False
self.model.eval()

def unfreeze(self):
for param in self.model.parameters():
param.requires_grad = True

def freeze_to(self, to=-1):
idx = 0
if to < 0:
to = len(self.model.encoder.layer) + to + 1
for idx in range(to):
for param in self.model.encoder.layer[idx].parameters():
param.requires_grad = False
print("Embeddings freezed to {}".format(to))
to = len(self.model.encoder.layer)
for idx in range(idx, to):
for param in self.model.encoder.layer[idx].parameters():
param.requires_grad = True
self.model.train()

def get_n_trainable_params(self):
pp = 0
@@ -98,8 +75,8 @@ def get_n_trainable_params(self):

@classmethod
def create(cls,
bert_config_file, init_checkpoint_pt, embedding_dim=768, use_cuda=True, bert_mode="weighted",
freeze=True):
bert_config_file, init_checkpoint_pt, embedding_dim=768, use_cuda=True,
bert_mode="weighted", freeze=True):
bert_config = bert_modeling.BertConfig.from_json_file(bert_config_file)
model = bert_modeling.BertModel(bert_config)
if use_cuda:
@@ -115,144 +92,3 @@ def create(cls,
if freeze:
model.freeze()
return model


class Word2VecEmbedder(nn.Module):
def __init__(self,
vocab_size,
embedding_dim=300):
super(Word2VecEmbedder, self).__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.model = nn.Embedding(vocab_size, embedding_dim, padding_idx=self.pad_id)
if normalize:
weight = self.embedding.weight
norms = weight.data.norm(2, 1)
if norms.dim() == 1:
norms = norms.unsqueeze(1)
weight.data.div_(norms.expand_as(weight.data))

if trainable:
self.embedding.weight.requires_grad = True
else:
self.embedding.weight.requires_grad = False
self.loaded = False
self.path = None
self.init_weights()

def init_weights(self):
for p in self.embedding.parameters():
torch.nn.init.xavier_normal(p)

def forward(self, *batch):
input_ids = batch[0]
return self.model(input_ids)

def load_gensim_word2vec(self, path, words, binary=False):
self.loaded = True
word_vectors = KeyedVectors.load_word2vec_format(path, binary=binary)
for word, idx in words.items():
if word in word_vectors:
if idx < self.vocab_size:
self.embedding.weight.data[idx].set_(torch.FloatTensor(word_vectors[word]))
return self

@classmethod
def create(cls, path, words, binary=False, embedding_dim=300, padding_idx=0, trainable=True, normalize=True):
model = cls(
vocab_size=len(words), embedding_dim=embedding_dim, padding_idx=padding_idx,
trainable=trainable, normalize=normalize)
model = model.load_gensim_word2vec(path, words, binary)
return model


class ElmoEmbedder(nn.Module):
def __init__(self, model, config, embedding_dim=1024, use_cuda=True, elmo_mode="avg"):
super(ElmoEmbedder, self).__init__()
self.model = model
self.embedding_dim = embedding_dim
self.model = model
self.use_cuda = use_cuda
self.config = config
self.elmo_mode = elmo_mode

if self.elmo_mode == "weighted":
self.elmo_weights = nn.Parameter(torch.FloatTensor(3, 1))
self.elmo_gamma = nn.Parameter(torch.FloatTensor(1, 1))

if use_cuda:
self.cuda()

self.init_weights()

def init_weights(self):
if self.elmo_mode == "weighted":
nn.init.xavier_normal(self.elmo_weights)
nn.init.xavier_normal(self.elmo_gamma)

def forward(self, *batch):
w, c, masks = batch[:3]
all_encoder_layers = self.model.forward(w, c, masks)
if self.elmo_mode == "avg":
return all_encoder_layers.mean(0)
elif self.bert_mode == "weighted":
all_encoder_layers = torch.stack([a * b for a, b in zip(all_encoder_layers, self.elmo_weights)])
return self.elmo_gamma * torch.sum(all_encoder_layers, dim=0)

def freeze(self):
for param in self.parameters():
param.requires_grad = False

def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

@classmethod
def create(
cls, model_dir, config_name, embedding_dim=1024, use_cuda=True, elmo_mode="avg", freeze=True):
with open(os.path.join(model_dir, config_name), 'r') as fin:
config = json.load(fin)
# For the model trained with character-based word encoder.
if config['token_embedder']['char_dim'] > 0:
char_lexicon = {}
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
for line in fpi:
tokens = line.strip().split('\t')
if len(tokens) == 1:
tokens.insert(0, '\u3000')
token, i = tokens
char_lexicon[token] = int(i)
char_emb_layer = EmbeddingLayer(
config['token_embedder']['char_dim'], char_lexicon, fix_emb=False, embs=None)
logging.info('char embedding size: ' +
str(len(char_emb_layer.word2id)))
else:
char_emb_layer = None

# For the model trained with word form word encoder.
if config['token_embedder']['word_dim'] > 0:
word_lexicon = {}
with codecs.open(os.path.join(model_dir, 'word.dic'), 'r', encoding='utf-8') as fpi:
for line in fpi:
tokens = line.strip().split('\t')
if len(tokens) == 1:
tokens.insert(0, '\u3000')
token, i = tokens
word_lexicon[token] = int(i)
word_emb_layer = EmbeddingLayer(
config['token_embedder']['word_dim'], word_lexicon, fix_emb=False, embs=None)
logging.info('word embedding size: ' +
str(len(word_emb_layer.word2id)))
else:
word_emb_layer = None

# instantiate the model
model = Model(config, word_emb_layer, char_emb_layer, use_cuda)

model.load_model(model_dir)

model.eval()
model = cls(model, config, embedding_dim, use_cuda, elmo_mode=elmo_mode)
if freeze:
model.freeze()
return model
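
For reference, a minimal usage sketch of the BertEmbedder.create factory kept in this file; the config and checkpoint paths below are hypothetical placeholders, not files from this repository.

from modules.layers.embedders import BertEmbedder

# Build the embedder from a BERT config and a PyTorch checkpoint.
# Both paths are placeholders for illustration only.
embedder = BertEmbedder.create(
    bert_config_file="bert/bert_config.json",
    init_checkpoint_pt="bert/pytorch_model.bin",
    embedding_dim=768,
    use_cuda=True,
    bert_mode="weighted",  # mix the encoder layers with learned weights (see forward above)
    freeze=True)           # freeze() sets requires_grad=False on all BERT parameters
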
164 changes: 23 additions & 141 deletions modules/layers/encoders.py
@@ -1,162 +1,44 @@
from torch import nn
import torch
from .embedders import BertEmbedder


class BertBiLSTMEncoder(nn.Module):

# @property
def get_config(self):
config = {
"name": "BertBiLSTMEncoder",
"params": {
"hidden_dim": self.hidden_dim,
"rnn_layers": self.rnn_layers,
"use_cuda": self.use_cuda,
"embeddings": self.embeddings.get_config()
}
}
return config

def __init__(self, embeddings,
hidden_dim=128, rnn_layers=1, use_cuda=True):
super(BertBiLSTMEncoder, self).__init__()
self.embeddings = embeddings
self.hidden_dim = hidden_dim
self.rnn_layers = rnn_layers
self.use_cuda = use_cuda
self.lstm = nn.LSTM(
self.embeddings.embedding_dim, hidden_dim // 2,
rnn_layers, batch_first=True, bidirectional=True)
self.hidden = None
if use_cuda:
self.cuda()
self.init_weights()
self.output_dim = hidden_dim

@classmethod
def from_config(cls, config):
if config["embeddings"]["name"] == "BertEmbedder":
embeddings = BertEmbedder.create(**config["embeddings"]["params"])
else:
raise NotImplemented("form_config is implemented only for BertEmbedder now :(")
return cls.create(embeddings, config["hidden_dim"], config["rnn_layers"], config["use_cuda"])

def init_weights(self):
# for p in self.lstm.parameters():
# nn.init.xavier_normal(p)
pass

def forward(self, batch):
input, input_mask = batch[0], batch[1]
output = self.embeddings(*batch)
# output = self.dropout(output)
lens = input_mask.sum(-1)
output = nn.utils.rnn.pack_padded_sequence(
output, lens.tolist(), batch_first=True)
output, self.hidden = self.lstm(output)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
return output, self.hidden

def get_n_trainable_params(self):
pp = 0
for p in list(self.parameters()):
if p.requires_grad:
num = 1
for s in list(p.size()):
num = num * s
pp += num
return pp

@classmethod
def create(cls, embeddings, hidden_dim=128, rnn_layers=1, use_cuda=True):
model = cls(
embeddings=embeddings, hidden_dim=hidden_dim, rnn_layers=rnn_layers, use_cuda=use_cuda)
return model


class ElmoBiLSTMEncoder(nn.Module):

def __init__(self, embeddings,
hidden_dim=128, rnn_layers=1, use_cuda=True):
super(ElmoBiLSTMEncoder, self).__init__()
self.embeddings = embeddings
self.hidden_dim = hidden_dim
self.rnn_layers = rnn_layers
self.use_cuda = use_cuda
self.lstm = nn.LSTM(
self.embeddings.embedding_dim, hidden_dim // 2,
rnn_layers, batch_first=True, bidirectional=True)
self.hidden = None
if use_cuda:
self.cuda()
self.init_weights()
self.output_dim = hidden_dim

def init_weights(self):
# for p in self.lstm.parameters():
# nn.init.xavier_normal(p)
pass

def forward(self, batch):
input, input_mask = batch[0], batch[-2]
output = self.embeddings(*batch)
# output = self.dropout(output)
lens = input_mask.sum(-1)
output = nn.utils.rnn.pack_padded_sequence(
output, lens.tolist(), batch_first=True)
output, self.hidden = self.lstm(output)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
return output, self.hidden

def get_n_trainable_params(self):
pp = 0
for p in list(self.parameters()):
if p.requires_grad:
num = 1
for s in list(p.size()):
num = num * s
pp += num
return pp

@classmethod
def create(cls, embeddings, hidden_dim=128, rnn_layers=1, use_cuda=True):
model = cls(
embeddings=embeddings, hidden_dim=hidden_dim, rnn_layers=rnn_layers, use_cuda=use_cuda)
return model


class BertMetaBiLSTMEncoder(nn.Module):

def __init__(self, embeddings, meta_dim,
hidden_dim=128, rnn_layers=1, use_cuda=True):
def __init__(self, embeddings, meta_embeddings=None,
hidden_dim=128, rnn_layers=1, dropout=0.5, use_cuda=True):
super(BertMetaBiLSTMEncoder, self).__init__()
self.embeddings = embeddings
self.meta_embeddings = meta_embeddings
self.hidden_dim = hidden_dim
self.rnn_layers = rnn_layers
self.use_cuda = use_cuda
self.dropout = nn.Dropout(dropout)
meta_dim = 0
if self.meta_embeddings:
meta_dim = meta_embeddings.embedding_dim
self.lstm = nn.LSTM(
self.embeddings.embedding_dim, hidden_dim // 2,
self.embeddings.embedding_dim + meta_dim,
hidden_dim // 2,
rnn_layers, batch_first=True, bidirectional=True)
self.hidden = None
if use_cuda:
self.cuda()
self.init_weights()
self.meta_dim = meta_dim
# + meta_dim
self.output_dim = hidden_dim + meta_dim
self.output_dim = hidden_dim
self.hidden = None

def init_weights(self):
# for p in self.lstm.parameters():
# nn.init.xavier_normal(p)
pass
for param in self.parameters():
if len(param.shape) >= 2:
nn.init.orthogonal_(param.data)
else:
nn.init.normal_(param.data)

def forward(self, batch):
input, input_mask = batch[0], batch[1]
input_mask = batch[1]
output = self.embeddings(*batch)
# output = torch.cat((self.embeddings(*batch), batch[3]), dim=-1)
# print(output.shape)
# output = self.dropout(output)
if self.meta_embeddings:
output = torch.cat((output, self.meta_embeddings(*batch)), dim=-1)
output = self.dropout(output)
lens = input_mask.sum(-1)
output = nn.utils.rnn.pack_padded_sequence(
output, lens.tolist(), batch_first=True)
@@ -176,8 +58,8 @@ def get_n_trainable_params(self):
return pp

@classmethod
def create(cls, embeddings, meta_dim, hidden_dim=128, rnn_layers=1, use_cuda=True):
def create(cls, embeddings, meta_embeddings=None,
hidden_dim=128, rnn_layers=1, dropout=0.5, use_cuda=True):
model = cls(
embeddings=embeddings, meta_dim=meta_dim,
hidden_dim=hidden_dim, rnn_layers=rnn_layers, use_cuda=use_cuda)
embeddings, meta_embeddings, hidden_dim, rnn_layers, dropout, use_cuda=use_cuda)
return model
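
A sketch of how the updated BertMetaBiLSTMEncoder.create signature accepts an optional meta embedder; bert_embedder and meta_embedder are assumed objects here, with meta_embedder standing for any nn.Module that exposes an embedding_dim attribute and returns per-token features from forward(*batch), which is how meta_embeddings is used in the encoder's forward pass above.

from modules.layers.encoders import BertMetaBiLSTMEncoder

# bert_embedder could be built as in the BertEmbedder example above;
# meta_embedder is a hypothetical module whose output is concatenated
# with the BERT features along the last dimension.
encoder = BertMetaBiLSTMEncoder.create(
    embeddings=bert_embedder,
    meta_embeddings=meta_embedder,  # pass None to train on BERT features only
    hidden_dim=128,
    rnn_layers=1,
    dropout=0.5,
    use_cuda=True)
# The LSTM input size becomes embeddings.embedding_dim + meta_embeddings.embedding_dim,
# while output_dim stays equal to hidden_dim.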
