Skip to content

Commit

Permalink
deb
Browse files Browse the repository at this point in the history
  • Loading branch information
LiyuanLucasLiu committed Sep 14, 2017
1 parent dcc19b9 commit c960b57
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions model/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def rand_init(self):

def forward(self, feats):
"""
input:
inputs:
- **feats** (batch_size, seq_len, hidden_dim) : input score from previous layers
output:
outputs:
- **crf_score** (batch_size, seq_len, tag_size, tag_size) : output from crf layer
"""
return self.hidden2tag(feats).view(-1, self.tagset_size, self.tagset_size)
Expand Down Expand Up @@ -68,10 +68,10 @@ def rand_init(self):

def forward(self, feats):
"""
input:
inputs:
- **feats** (batch_size, seq_len, hidden_dim) : input score from previous layers
output:
outputs:
- **crf_score** (batch_size, seq_len, tag_size, tag_size) : output from crf layer
"""
Expand All @@ -98,11 +98,11 @@ def __init__(self, tagset_size, if_cuda):
def repack_vb(self, feature, target, mask):
"""packer for viterbi loss
input:
inputs:
- **feature** (Seq_len, Batch_size): input feature
- **target** (Seq_len, Batch_size): output target
- **mask** (Seq_len, Batch_size): padding mask
output:
outputs:
- **feature** (Seq_len, Batch_size) : input feature
- **target** (Seq_len, Batch_size) : output target
- **mask** (Seq_len, Batch_size) : padding mask
Expand All @@ -121,11 +121,11 @@ def repack_vb(self, feature, target, mask):
def repack_gd(self, feature, target, current):
"""packer for greedy loss
args:
argss:
feature: input feature, of size Seq_len * Batch_size
target: output target, of size Seq_len * Batch_size
current: current state, of size Seq_len * Batch_size
output:
outputs:
feature: input feature, of size Batch_size * Seq_len
target: output target, of size (Seq_len * Batch_size)
current: current state, of size (Seq_len * Batch_size) * 1 * 1
Expand All @@ -141,6 +141,8 @@ def repack_gd(self, feature, target, current):
return fea_v, ts_v, cs_v

def convert_for_eval(self, target):
"""convert target to original decoding
"""
return target % self.tagset_size


Expand Down

0 comments on commit c960b57

Please sign in to comment.