Skip to content

Commit

Permalink
🐛 Fix labeling dict bug: add the PAD token to the sequence label dict.
Browse files — browse the repository at this point in the history
  • Loading branch information
BrikerMan committed May 22, 2019
1 parent cc6d4ea commit d2af50c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 3 additions & 1 deletion kashgari/pre_processors/labeling_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def _build_label_dict(self,
Args:
label_list: corpus label list
"""
label2idx: Dict[str: int] = {}
label2idx: Dict[str: int] = {
self.token_pad: 0
}

token2count = {}
for label_set in label_list:
Expand Down
5 changes: 3 additions & 2 deletions kashgari/tasks/labeling/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def compile_model(self, **kwargs):
weight = np.full((len(idx2label),), 50)
for idx, label in idx2label.items():
if label == self.embedding.processor.token_pad:
weight[idx] = 0
weight[idx] = 0.1
if label in ['O']:
weight[idx] = 1
weight_dict = {}
for idx, label in idx2label.items():
weight_dict[label] = weight[idx]
logging.debug(f"label weights set to {weight_dict}")
logging.debug(f"label weights set to {weight_dict}")
kwargs['loss'] = weighted_categorical_crossentropy(weight)
super(BaseLabelingModel, self).compile_model(**kwargs)

Expand Down Expand Up @@ -125,6 +125,7 @@ def build_model_arc(self):


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
from kashgari.tasks.labeling import CNNLSTMModel
from kashgari.corpus import ChineseDailyNerCorpus

Expand Down

0 comments on commit d2af50c

Please sign in to comment.