Skip to content

Commit

Permalink
Merge branch 'main' into megatron_nmt_sample_training
Browse files Browse the repository at this point in the history
  • Loading branch information
MaximumEntropy committed Jul 28, 2022
2 parents 9542e74 + 72d78d8 commit 18207e7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
11 changes: 8 additions & 3 deletions examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py
Expand Up @@ -65,9 +65,14 @@ def main(cfg) -> None:
trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)

# Get the T5 Base configuration.
t5_cfg = MegatronT5GLUEModel.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)
if hasattr(t5_cfg.data.validation_ds, 'task_name'):
t5_cfg = MegatronT5GLUEModel.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)
else:
t5_cfg = MegatronT5FinetuneModel.restore_from(
restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True
)

# Override the T5 configuration with the one from the config file.
# NOTE: Only data can be overriden here since this the file being restored here should already correspond to a GLUE/XNLI finetuned model.
Expand Down
Expand Up @@ -99,7 +99,7 @@ def setup_metric(self, data_cfg):
metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes)
if metric_name != 'exact_string_match'
else metric()
for _ in range(len(self.cfg.data.test_ds.src_file_name))
for _ in range(len(data_cfg.src_file_name))
]
else:
metric = [metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes)]
Expand Down Expand Up @@ -387,7 +387,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
metric = metric['acc']
else:
self.log(metric_log_key, metric)
logging.info(f"{mode} {metric_name}: {metric}")
logging.info(f"{metric_log_key}: {metric}")
metric_object.reset()

averaged_loss.append(loss)
Expand Down Expand Up @@ -663,11 +663,11 @@ def build_train_valid_test_datasets(self, stage):
if stage != 'test':
self._validation_ds = self._build_eval_dataset(self.cfg.data.validation_ds)

if stage != 'validation':
if stage != 'validate':
if hasattr(self.cfg.data, 'test_ds'):
self._test_ds = self._build_eval_dataset(self.cfg.data.test_ds)

if stage == 'validation' or stage == 'test':
if stage == 'validate' or stage == 'test':
return
self._train_ds = self._build_train_dataset(self.cfg.data.train_ds)
logging.info(f'Finished building datasets ...')
Expand Down

0 comments on commit 18207e7

Please sign in to comment.