diff --git a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py index d7bc750292b8..62fe80a663ed 100644 --- a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py @@ -65,9 +65,14 @@ def main(cfg) -> None: trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,) # Get the T5 Base configuration. - t5_cfg = MegatronT5GLUEModel.restore_from( - restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True - ) + if hasattr(cfg.model.data.validation_ds, 'task_name'): + t5_cfg = MegatronT5GLUEModel.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True + ) + else: + t5_cfg = MegatronT5FinetuneModel.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True + ) # Override the T5 configuration with the one from the config file. # NOTE: Only data can be overriden here since this the file being restored here should already correspond to a GLUE/XNLI finetuned model. 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index ea5f627a294c..8d8e70a10e5f 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -99,7 +99,7 @@ def setup_metric(self, data_cfg): metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes) if metric_name != 'exact_string_match' else metric() - for _ in range(len(self.cfg.data.test_ds.src_file_name)) + for _ in range(len(data_cfg.src_file_name)) ] else: metric = [metric(average=data_cfg.metric.average, num_classes=data_cfg.metric.num_classes)] @@ -387,7 +387,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): metric = metric['acc'] else: self.log(metric_log_key, metric) - logging.info(f"{mode} {metric_name}: {metric}") + logging.info(f"{metric_log_key}: {metric}") metric_object.reset() averaged_loss.append(loss) @@ -663,11 +663,11 @@ def build_train_valid_test_datasets(self, stage): if stage != 'test': self._validation_ds = self._build_eval_dataset(self.cfg.data.validation_ds) - if stage != 'validation': + if stage != 'validate': if hasattr(self.cfg.data, 'test_ds'): self._test_ds = self._build_eval_dataset(self.cfg.data.test_ds) - if stage == 'validation' or stage == 'test': + if stage == 'validate' or stage == 'test': return self._train_ds = self._build_train_dataset(self.cfg.data.train_ds) logging.info(f'Finished building datasets ...')