train.py (38 changes: 28 additions & 10 deletions)
@@ -16,7 +16,7 @@
 from logger import VisdomLogger, TensorBoardLogger
 from model import DeepSpeech, supported_rnns
 from test import evaluate
-from utils import reduce_tensor, check_loss
+from utils import reduce_tensor, check_loss, remove_parallel_wrapper
 
 parser = argparse.ArgumentParser(description='DeepSpeech training')
 parser.add_argument('--train-manifest', metavar='DIR',
@@ -53,8 +53,10 @@
 parser.add_argument('--continue-from', default='', help='Continue from checkpoint model')
 parser.add_argument('--finetune', dest='finetune', action='store_true',
                     help='Finetune the model from checkpoint "continue_from"')
-parser.add_argument('--speed-volume-perturb', dest='speed_volume_perturb', action='store_true', help='Use random tempo and gain perturbations.')
-parser.add_argument('--spec-augment', dest='spec_augment', action='store_true', help='Use simple spectral augmentation on mel spectograms.')
+parser.add_argument('--speed-volume-perturb', dest='speed_volume_perturb', action='store_true',
+                    help='Use random tempo and gain perturbations.')
+parser.add_argument('--spec-augment', dest='spec_augment', action='store_true',
+                    help='Use simple spectral augmentation on mel spectrograms.')
 parser.add_argument('--noise-dir', default=None,
                     help='Directory to inject noise into audio. If default, noise Inject not added')
 parser.add_argument('--noise-prob', default=0.4, help='Probability of noise being added per sample')
@@ -188,7 +190,8 @@ def update(self, val, n=1):
 
     decoder = GreedyDecoder(labels)
     train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels,
-                                       normalize=True, speed_volume_perturb=args.speed_volume_perturb, spec_augment=args.spec_augment)
+                                       normalize=True, speed_volume_perturb=args.speed_volume_perturb,
+                                       spec_augment=args.spec_augment)
     test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels,
                                       normalize=True, speed_volume_perturb=False, spec_augment=False)
     if not args.distributed:
@@ -287,9 +290,15 @@ def update(self, val, n=1):
             if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0 and main_proc:
                 file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth' % (save_folder, epoch + 1, i + 1)
                 print("Saving checkpoint model to %s" % file_path)
-                torch.save(DeepSpeech.serialize(model, optimizer=optimizer, amp=amp, epoch=epoch, iteration=i,
-                                                loss_results=loss_results,
-                                                wer_results=wer_results, cer_results=cer_results, avg_loss=avg_loss),
+                torch.save(DeepSpeech.serialize(remove_parallel_wrapper(model),
+                                                optimizer=optimizer,
+                                                amp=amp,
+                                                epoch=epoch,
+                                                iteration=i,
+                                                loss_results=loss_results,
+                                                wer_results=wer_results,
+                                                cer_results=cer_results,
+                                                avg_loss=avg_loss),
                            file_path)
             del loss, out, float_out
 
@@ -332,8 +341,13 @@ def update(self, val, n=1):
 
         if main_proc and args.checkpoint:
             file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1)
-            torch.save(DeepSpeech.serialize(model, optimizer=optimizer, amp=amp, epoch=epoch, loss_results=loss_results,
-                                            wer_results=wer_results, cer_results=cer_results),
+            torch.save(DeepSpeech.serialize(remove_parallel_wrapper(model),
+                                            optimizer=optimizer,
+                                            amp=amp,
+                                            epoch=epoch,
+                                            loss_results=loss_results,
+                                            wer_results=wer_results,
+                                            cer_results=cer_results),
                        file_path)
         # anneal lr
         for g in optimizer.param_groups:
@@ -342,8 +356,12 @@ def update(self, val, n=1):
 
         if main_proc and (best_wer is None or best_wer > wer):
             print("Found better validated model, saving to %s" % args.model_path)
-            torch.save(DeepSpeech.serialize(model, optimizer=optimizer, amp=amp, epoch=epoch, loss_results=loss_results,
-                                            wer_results=wer_results, cer_results=cer_results)
+            torch.save(DeepSpeech.serialize(remove_parallel_wrapper(model),
+                                            optimizer=optimizer,
+                                            amp=amp, epoch=epoch,
+                                            loss_results=loss_results,
+                                            wer_results=wer_results,
+                                            cer_results=cer_results)
                        , args.model_path)
             best_wer = wer
             avg_loss = 0
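
For context on the train.py change: torch.nn.DataParallel and DistributedDataParallel register the wrapped network as a submodule named "module", so every state_dict key picks up a "module." prefix that a bare DeepSpeech instance cannot load directly. A minimal sketch of that effect, using a stand-in nn.Linear in place of the real model (illustration only, not code from this PR):

import torch
import torch.nn as nn

net = nn.Linear(4, 2)                # stand-in for the DeepSpeech model
wrapped = nn.DataParallel(net)       # what the training loop actually holds

print(list(net.state_dict()))        # ['weight', 'bias']
print(list(wrapped.state_dict()))    # ['module.weight', 'module.bias']

# Serializing the unwrapped module keeps the original key names, so the
# checkpoint can be loaded into a model that was never wrapped.
torch.save(wrapped.module.state_dict(), 'checkpoint_keys_example.pth')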
utils.py (11 changes: 11 additions & 0 deletions)
@@ -38,3 +38,14 @@ def load_model(device, model_path, use_half):
     if use_half:
         model = model.half()
     return model
+
+
+def remove_parallel_wrapper(model):
+    """
+    Return the model or extract the model out of the parallel wrapper
+    :param model: The training model
+    :return: The model without parallel wrapper
+    """
+    # Take care of distributed/data-parallel wrapper
+    model_no_wrapper = model.module if hasattr(model, "module") else model
+    return model_no_wrapper
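
A quick usage sketch of the new helper (assumes the repository's utils.py is importable; nn.DataParallel registers its wrapped module even without a GPU, so this runs on CPU):

import torch.nn as nn

from utils import remove_parallel_wrapper

wrapped = nn.DataParallel(nn.Linear(4, 2))   # model as held during data-parallel training
plain = nn.Linear(4, 2)                      # model trained without any wrapper

assert remove_parallel_wrapper(wrapped) is wrapped.module   # wrapper stripped
assert remove_parallel_wrapper(plain) is plain              # returned unchanged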