Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/flairNLPGH-2120-imporved-tensorb…
Browse files Browse the repository at this point in the history
…oard-logging' into flairNLPGH-2120-imporved-tensorboard-logging
  • Loading branch information
braunefe committed Mar 19, 2021
2 parents 065cf02 + cc32c6f commit 38eb8e6
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def train(
eval_on_train_fraction=0.0,
eval_on_train_shuffle=False,
save_model_each_k_epochs: int = 0,
main_score_type=("micro avg", 'f1-score'),
classification_main_metric=("micro avg", 'f1-score'),
tensorboard_comment='',
**kwargs,
) -> dict:
Expand Down Expand Up @@ -138,16 +138,16 @@ def train(
:param save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will
be saved each 5 epochs. Default is 0 which means no model saving.
:param save_model_epoch_step: Each save_model_epoch_step'th epoch the thus far trained model will be saved
:param main_score_type: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model
:param classification_main_metric: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model
:param tensorboard_comment: Comment to use for tensorboard logging
:param kwargs: Other arguments for the Optimizer
:return:
"""
if isinstance(self.model, TextClassifier):
self.main_score_type=main_score_type
self.main_score_type=classification_main_metric
else:
if main_score_type is not None:
warnings.warn("Choosing a main score type during training is currently only possible for text_classification_model. Will use default main score type instead of specified one.")
if classification_main_metric is not None:
warnings.warn("Specification of main score type only implemented for text classifier. Defaulting to main score type of selected model.")
self.main_score_type = None
if self.use_tensorboard:
try:
Expand Down

0 comments on commit 38eb8e6

Please sign in to comment.