Skip to content

Commit

Permalink
Update to Lightning 1.0.3 (#1323)
Browse files Browse the repository at this point in the history
* update add_state and loading checkpoints

Signed-off-by: Jason <jasoli@nvidia.com>

* syntax error

Signed-off-by: Jason <jasoli@nvidia.com>
  • Loading branch information
blisc committed Oct 21, 2020
1 parent 5cf0428 commit 4e4597a
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 22 deletions.
6 changes: 2 additions & 4 deletions examples/asr/speech_to_text_infer.py
Expand Up @@ -56,12 +56,10 @@ def main():

if args.asr_model.endswith('.nemo'):
logging.info(f"Using local ASR model from {args.asr_model}")
# TODO: Remove strict, when lightning has persistent parameter support for add_state()
asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model, strict=False)
asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
else:
logging.info(f"Using NGC cloud ASR model {args.asr_model}")
# TODO: Remove strict, when lightning has persistent parameter support for add_state()
asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model, strict=False)
asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
asr_model.setup_test_data(
test_data_config={
'sample_rate': 16000,
Expand Down
Expand Up @@ -49,8 +49,7 @@ def main(cfg: DictConfig) -> None:
model = PunctuationCapitalizationModel(cfg.model, trainer=trainer)
else:
logging.info(f'Loading pretrained model {cfg.pretrained_model}')
# TODO: Remove strict, when lightning has persistent parameter support for add_state()
model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model, strict=False)
model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model)
data_dir = cfg.model.dataset.get('data_dir', None)
if data_dir:
# we can also do finetunining of the pretrained model but it will require
Expand Down
3 changes: 1 addition & 2 deletions examples/nlp/token_classification/token_classification.py
Expand Up @@ -85,8 +85,7 @@ def main(cfg: DictConfig) -> None:
model = TokenClassificationModel(cfg.model, trainer=trainer)
else:
logging.info(f'Loading pretrained model {cfg.pretrained_model}')
# TODO: Remove strict, when lightning has persistent parameter support for add_state()
model = TokenClassificationModel.from_pretrained(cfg.pretrained_model, strict=False)
model = TokenClassificationModel.from_pretrained(cfg.pretrained_model)

data_dir = cfg.model.dataset.get('data_dir', None)
if data_dir:
Expand Down
3 changes: 1 addition & 2 deletions examples/tts/test_tts_infer.py
Expand Up @@ -75,8 +75,7 @@ def main():
logging.set_verbosity(logging.DEBUG)

logging.info(f"Using NGC cloud ASR model {args.asr_model}")
# TODO: Remove strict, when lightning has persistent parameter support for add_state()
asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model, strict=False)
asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
logging.info(f"Using NGC cloud TTS Spectrogram Generator model {args.tts_model_spec}")
tts_model_spec = SpectrogramGenerator.from_pretrained(model_name=args.tts_model_spec)
logging.info(f"Using NGC cloud TTS Vocoder model {args.tts_model_vocoder}")
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/metrics/wer.py
Expand Up @@ -111,8 +111,8 @@ def __init__(
self.ctc_decode = ctc_decode
self.log_prediction = log_prediction

self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)

def ctc_decoder_predictions_tensor(self, predictions: torch.Tensor) -> List[str]:
"""
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/metrics/wer_bpe.py
Expand Up @@ -74,8 +74,8 @@ def __init__(
self.ctc_decode = ctc_decode
self.log_prediction = log_prediction

self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)

def ctc_decoder_predictions_tensor(self, predictions: torch.Tensor) -> List[str]:
"""
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/common/metrics/classification_accuracy.py
Expand Up @@ -63,8 +63,10 @@ def __init__(self, top_k=None, dist_sync_on_step=False):
top_k = [1]

self.top_k = top_k
self.add_state("correct_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum')
self.add_state("total_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum')
self.add_state(
"correct_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False
)
self.add_state("total_counts_k", default=torch.zeros(len(self.top_k)), dist_reduce_fx='sum', persistent=False)

def update(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
Expand Down
10 changes: 6 additions & 4 deletions nemo/collections/nlp/metrics/classification_report.py
Expand Up @@ -76,10 +76,12 @@ def __init__(
self.ids_to_labels = None
self.mode = mode

self.add_state("tp", default=torch.zeros(num_classes), dist_reduce_fx='sum')
self.add_state("fn", default=torch.zeros(num_classes), dist_reduce_fx='sum')
self.add_state("fp", default=torch.zeros(num_classes), dist_reduce_fx='sum')
self.add_state("num_examples_per_class", default=torch.zeros(num_classes), dist_reduce_fx='sum')
self.add_state("tp", default=torch.zeros(num_classes), dist_reduce_fx='sum', persistent=False)
self.add_state("fn", default=torch.zeros(num_classes), dist_reduce_fx='sum', persistent=False)
self.add_state("fp", default=torch.zeros(num_classes), dist_reduce_fx='sum', persistent=False)
self.add_state(
"num_examples_per_class", default=torch.zeros(num_classes), dist_reduce_fx='sum', persistent=False
)

def update(self, predictions: torch.Tensor, labels: torch.Tensor):
TP = []
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/nlp/metrics/perplexity.py
Expand Up @@ -29,7 +29,7 @@ class Perplexity(Metric):

def __init__(self, dist_sync_on_step=False):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.add_state('perplexity', default=torch.tensor(0), dist_reduce_fx='mean')
self.add_state('perplexity', default=torch.tensor(0), dist_reduce_fx='mean', persistent=False)

def update(self, loss: torch.Tensor):
self.perplexity = torch.exp(loss)
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,6 +1,6 @@
numpy>=1.18.2
onnx>=1.7.0
pytorch-lightning==1.0.2
pytorch-lightning>=1.0.3
python-dateutil
torch
wget
Expand Down

0 comments on commit 4e4597a

Please sign in to comment.