Commit 8fad924: debugged
LiyuanLucasLiu committed Sep 14, 2017
1 parent 8e34eb5 commit 8fad924
Showing 1 changed file with 76 additions and 40 deletions.
116 changes: 76 additions & 40 deletions model/lm_lstm_crf.py
@@ -15,6 +15,25 @@
import model.highway as highway

class LM_LSTM_CRF(nn.Module):
"""LM_LSTM_CRF model
args:
tagset_size: size of label set
char_size: size of char dictionary
char_dim: size of char embedding
char_hidden_dim: size of char-level lstm hidden dim
char_rnn_layers: number of char-level lstm layers
embedding_dim: size of word embedding
word_hidden_dim: size of word-level blstm hidden dim
word_rnn_layers: number of word-level lstm layers
vocab_size: size of word dictionary
dropout_ratio: dropout ratio
large_CRF: use CRF_L or not, refer model.crf.CRF_L and model.crf.CRF_S for more details
if_highway: use highway layers or not
in_doc_words: number of words that occurred in the corpus (used for language model prediction)
highway_layers: number of highway layers
"""

    def __init__(self, tagset_size, char_size, char_dim, char_hidden_dim, char_rnn_layers, embedding_dim, word_hidden_dim, word_rnn_layers, vocab_size, dropout_ratio, large_CRF=True, if_highway=False, in_doc_words=2, highway_layers=1):

        super(LM_LSTM_CRF, self).__init__()
@@ -59,21 +78,44 @@ def __init__(self, tagset_size, char_size, char_dim, char_hidden_dim, char_rnn_l
        self.word_seq_length = 1

    def set_batch_size(self, bsize):
        """
        set batch size
        """
        self.batch_size = bsize

    def set_batch_seq_size(self, sentence):
        """
        set batch size and sequence length from a (word_seq_len, batch_size) tensor
        """
        tmp = sentence.size()
        self.word_seq_length = tmp[0]
        self.batch_size = tmp[1]
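As a side note, a minimal sketch of the (word_seq_len, batch_size) layout this method assumes; the tensor below is a hypothetical stand-in:

import torch

# a batch of 4 sentences, each padded to 20 words, in (word_seq_len, batch_size) layout
word_seq = torch.zeros(20, 4).long()
# after model.set_batch_seq_size(word_seq): word_seq_length == 20, batch_size == 4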

    def rand_init_embedding(self):
        """
        randomly initialize the char-level embedding
        """
        utils.init_embedding(self.char_embeds.weight)

    def load_pretrained_word_embedding(self, pre_word_embeddings):
        """
        load a pre-trained word embedding

        args:
            pre_word_embeddings (self.word_size, self.word_dim) : pre-trained embedding
        """
        assert (pre_word_embeddings.size()[1] == self.word_dim)
        self.word_embeds.weight = nn.Parameter(pre_word_embeddings)
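A minimal usage sketch, assuming `model` is a built LM_LSTM_CRF instance; the numpy matrix is a random stand-in for real pre-trained vectors (e.g. GloVe), and its second dimension must equal embedding_dim or the assert above fires:

import numpy as np
import torch

pretrained = np.random.rand(5000, 100)  # hypothetical (vocab_size, embedding_dim) matrix
model.load_pretrained_word_embedding(torch.FloatTensor(pretrained))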

    def rand_init(self, init_char_embedding=True, init_word_embedding=False):
        """
        random initialization

        args:
            init_char_embedding: whether to randomly initialize the char embedding
            init_word_embedding: whether to randomly initialize the word embedding
        """
        if init_char_embedding:
            utils.init_embedding(self.char_embeds.weight)
        if init_word_embedding:
@@ -91,41 +133,19 @@ def rand_init(self, init_char_embedding=True, init_word_embedding=False):
        utils.init_linear(self.word_pre_train_out)
        self.crf.rand_init()
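The usual call order, as a hedged sketch continuing the snippet above: load the pre-trained word vectors first, then randomize everything else with init_word_embedding=False so the loaded vectors survive:

model.load_pretrained_word_embedding(torch.FloatTensor(pretrained))
model.rand_init(init_char_embedding=True, init_word_embedding=False)  # leaves the loaded word vectors intact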

    def char_pre_train_forward(self, sentence, hidden=None):
        # sentence: char_seq_len * batch, original order
        embeds = self.char_embeds(sentence)
        d_embeds = self.dropout(embeds)
        lstm_out, hidden = self.forw_char_lstm(d_embeds)
        lstm_out = lstm_out.view(-1, self.char_hidden_dim)
        d_lstm_out = self.dropout(lstm_out)
        if self.if_highway:
            char_out = self.forw2char(d_lstm_out)
            d_char_out = self.dropout(char_out)
        else:
            d_char_out = d_lstm_out
        pre_score = self.char_pre_train_out(d_char_out)
        return pre_score, hidden

    def char_pre_train_backward(self, sentence, hidden=None):
        # sentence: char_seq_len * batch, reversed order
        embeds = self.char_embeds(sentence)
        d_embeds = self.dropout(embeds)
        lstm_out, hidden = self.back_char_lstm(d_embeds)
        lstm_out = lstm_out.view(-1, self.char_hidden_dim)
        d_lstm_out = self.dropout(lstm_out)
        if self.if_highway:
            char_out = self.back2char(d_lstm_out)  # backward highway layer (the original called forw2char here, an apparent copy-paste bug)
            d_char_out = self.dropout(char_out)
        else:
            d_char_out = d_lstm_out
        pre_score = self.char_pre_train_out(d_char_out)
        return pre_score, hidden

    def word_pre_train_forward(self, sentence, position, hidden=None):
        """
        output of the forward language model

        args:
            sentence (char_seq_len, batch_size): char-level representation of the sentence
            position (word_seq_len, batch_size): positions of the blank spaces in the char-level representation
            hidden: initial hidden state

        return:
            language model output (word_seq_len, in_doc_words), hidden
        """
        embeds = self.char_embeds(sentence)
        d_embeds = self.dropout(embeds)
        lstm_out, hidden = self.forw_char_lstm(d_embeds)
@@ -145,8 +165,17 @@ def word_pre_train_forward(self, sentence, position, hidden=None):
        return pre_score, hidden
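To make the `position` argument concrete, a toy sketch of how the inputs might be built for a single sentence; the char vocabulary and batch of one are hypothetical, and the real pipeline builds these tensors during corpus preprocessing:

import torch

char2idx = {' ': 0, 't': 1, 'h': 2, 'e': 3, 'c': 4, 'a': 5}  # hypothetical toy vocabulary
text = "the cat "
sentence = torch.LongTensor([[char2idx[c]] for c in text])  # (char_seq_len, batch_size=1)
position = torch.LongTensor([[3], [7]])  # indices of the spaces that close "the" and "cat"
# pre_score, hidden = model.word_pre_train_forward(sentence, position)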

    def word_pre_train_backward(self, sentence, position, hidden=None):
        """
        output of the backward language model

        args:
            sentence (char_seq_len, batch_size): char-level representation of the sentence (reversed order)
            position (word_seq_len, batch_size): positions of the blank spaces in the reversed char-level representation
            hidden: initial hidden state

        return:
            language model output (word_seq_len, in_doc_words), hidden
        """
        embeds = self.char_embeds(sentence)
        d_embeds = self.dropout(embeds)
        lstm_out, hidden = self.back_char_lstm(d_embeds)
@@ -166,11 +195,18 @@ def word_pre_train_backward(self, sentence, position, hidden=None):
        return pre_score, hidden

    def forward(self, forw_sentence, forw_position, back_sentence, back_position, word_seq, hidden=None):
        '''
        args:
            forw_sentence (char_seq_len, batch_size) : char-level representation of the sentence
            forw_position (word_seq_len, batch_size) : positions of the blank spaces in the char-level representation
            back_sentence (char_seq_len, batch_size) : char-level representation of the sentence (reversed order)
            back_position (word_seq_len, batch_size) : positions of the blank spaces in the reversed char-level representation
            word_seq (word_seq_len, batch_size) : word-level representation of the sentence
            hidden: initial hidden state

        return:
            crf output (word_seq_len, batch_size, tag_size, tag_size), hidden
        '''

        self.set_batch_seq_size(forw_position)

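To tie the diff together, a hedged end-to-end sketch; every hyperparameter is illustrative, the random tensors stand in for a real preprocessed batch, and it assumes forward returns (crf output, hidden) as documented above, so treat this as a shape check rather than a training recipe:

import torch
from model.lm_lstm_crf import LM_LSTM_CRF

model = LM_LSTM_CRF(tagset_size=10, char_size=60, char_dim=30, char_hidden_dim=50,
                    char_rnn_layers=1, embedding_dim=100, word_hidden_dim=100,
                    word_rnn_layers=1, vocab_size=5000, dropout_ratio=0.5,
                    if_highway=True, in_doc_words=4000)
model.rand_init()

char_len, word_len, batch = 40, 8, 2
forw_sentence = torch.randint(0, 60, (char_len, batch)).long()
back_sentence = torch.randint(0, 60, (char_len, batch)).long()
forw_position = torch.randint(0, char_len, (word_len, batch)).long()
back_position = torch.randint(0, char_len, (word_len, batch)).long()
word_seq = torch.randint(0, 5000, (word_len, batch)).long()

crf_out, hidden = model(forw_sentence, forw_position, back_sentence, back_position, word_seq)
# crf_out: (word_seq_len, batch_size, tagset_size, tagset_size), ready for the CRF loss/decoder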
