Skip to content

Commit

Permalink
Added autocast to radtts UT
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
  • Loading branch information
borisfom committed Nov 19, 2022
1 parent 3403e0f commit d4f48e1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/tts/modules/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,8 +771,8 @@ def input_example(self, max_batch=1, max_dim=256):
"""
par = next(self.parameters())
sz = (max_batch, max_dim)
inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
inp = torch.randint(16, 32, sz, device=par.device, dtype=torch.int64)
lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), device=par.device, dtype=torch.int)
speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
inputs = {
'text': inp,
Expand Down
4 changes: 3 additions & 1 deletion tests/collections/tts/test_tts_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tempfile

import pytest
import torch
from omegaconf import OmegaConf

from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
Expand Down Expand Up @@ -79,4 +80,5 @@ def test_RadTTSModel_export_to_torchscript(self, radtts_model):
model = radtts_model.cuda()
with tempfile.TemporaryDirectory() as tmpdir:
filename = os.path.join(tmpdir, 'rad.ts')
model.export(output=filename, verbose=True, check_trace=True)
with torch.cuda.amp.autocast(enabled=True):
model.export(output=filename, verbose=True, check_trace=True)

0 comments on commit d4f48e1

Please sign in to comment.