From c960b579a8a873f2ac58bd74d7377aa9d23493b0 Mon Sep 17 00:00:00 2001 From: LiyuanLucasLiu Date: Thu, 14 Sep 2017 15:42:47 -0500 Subject: [PATCH] deb --- model/crf.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/model/crf.py b/model/crf.py index 75dd2b1..197dce0 100644 --- a/model/crf.py +++ b/model/crf.py @@ -35,10 +35,10 @@ def rand_init(self): def forward(self, feats): """ - input: + inputs: - **feats** (batch_size, seq_len, hidden_dim) : input score from previous layers - output: + outputs: - **crf_score** (batch_size, seq_len, tag_size, tag_size) : output from crf layer """ return self.hidden2tag(feats).view(-1, self.tagset_size, self.tagset_size) @@ -68,10 +68,10 @@ def rand_init(self): def forward(self, feats): """ - input: + inputs: - **feats** (batch_size, seq_len, hidden_dim) : input score from previous layers - output: + outputs: - **crf_score** (batch_size, seq_len, tag_size, tag_size) : output from crf layer """ @@ -98,11 +98,11 @@ def __init__(self, tagset_size, if_cuda): def repack_vb(self, feature, target, mask): """packer for viterbi loss - input: + inputs: - **feature** (Seq_len, Batch_size): input feature - **target** (Seq_len, Batch_size): output target - **mask** (Seq_len, Batch_size): padding mask - output: + outputs: - **feature** (Seq_len, Batch_size) : input feature - **target** (Seq_len, Batch_size) : output target - **mask** (Seq_len, Batch_size) : padding mask @@ -121,11 +121,11 @@ def repack_vb(self, feature, target, mask): def repack_gd(self, feature, target, current): """packer for greedy loss - args: + argss: feature: input feature, of size Seq_len * Batch_size target: output target, of size Seq_len * Batch_size current: current state, of size Seq_len * Batch_size - output: + outputs: feature: input feature, of size Batch_size * Seq_len target: output target, of size (Seq_len * Batch_size) current: current state, of size (Seq_len * Batch_size) * 1 * 1 @@ -141,6 +141,8 @@ def repack_gd(self, feature, target, current): return fea_v, ts_v, cs_v def convert_for_eval(self, target): + """convert target to original decoding + """ return target % self.tagset_size