diff --git a/train.py b/train.py index 18c07156..7266d785 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from logger import VisdomLogger, TensorBoardLogger from model import DeepSpeech, supported_rnns from test import evaluate -from utils import reduce_tensor, check_loss +from utils import reduce_tensor, check_loss, remove_parallel_wrapper parser = argparse.ArgumentParser(description='DeepSpeech training') parser.add_argument('--train-manifest', metavar='DIR', @@ -53,8 +53,10 @@ parser.add_argument('--continue-from', default='', help='Continue from checkpoint model') parser.add_argument('--finetune', dest='finetune', action='store_true', help='Finetune the model from checkpoint "continue_from"') -parser.add_argument('--speed-volume-perturb', dest='speed_volume_perturb', action='store_true', help='Use random tempo and gain perturbations.') -parser.add_argument('--spec-augment', dest='spec_augment', action='store_true', help='Use simple spectral augmentation on mel spectograms.') +parser.add_argument('--speed-volume-perturb', dest='speed_volume_perturb', action='store_true', + help='Use random tempo and gain perturbations.') +parser.add_argument('--spec-augment', dest='spec_augment', action='store_true', + help='Use simple spectral augmentation on mel spectograms.') parser.add_argument('--noise-dir', default=None, help='Directory to inject noise into audio. If default, noise Inject not added') parser.add_argument('--noise-prob', default=0.4, help='Probability of noise being added per sample') @@ -188,7 +190,8 @@ def update(self, val, n=1): decoder = GreedyDecoder(labels) train_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.train_manifest, labels=labels, - normalize=True, speed_volume_perturb=args.speed_volume_perturb, spec_augment=args.spec_augment) + normalize=True, speed_volume_perturb=args.speed_volume_perturb, + spec_augment=args.spec_augment) test_dataset = SpectrogramDataset(audio_conf=audio_conf, manifest_filepath=args.val_manifest, labels=labels, normalize=True, speed_volume_perturb=False, spec_augment=False) if not args.distributed: @@ -287,9 +290,15 @@ def update(self, val, n=1): if args.checkpoint_per_batch > 0 and i > 0 and (i + 1) % args.checkpoint_per_batch == 0 and main_proc: file_path = '%s/deepspeech_checkpoint_epoch_%d_iter_%d.pth' % (save_folder, epoch + 1, i + 1) print("Saving checkpoint model to %s" % file_path) - torch.save(DeepSpeech.serialize(model, optimizer=optimizer, amp=amp, epoch=epoch, iteration=i, + torch.save(DeepSpeech.serialize(remove_parallel_wrapper(model), + optimizer=optimizer, + amp=amp, + epoch=epoch, + iteration=i, loss_results=loss_results, - wer_results=wer_results, cer_results=cer_results, avg_loss=avg_loss), + wer_results=wer_results, + cer_results=cer_results, + avg_loss=avg_loss), file_path) del loss, out, float_out @@ -332,8 +341,13 @@ def update(self, val, n=1): if main_proc and args.checkpoint: file_path = '%s/deepspeech_%d.pth.tar' % (save_folder, epoch + 1) - torch.save(DeepSpeech.serialize(model, optimizer=optimizer, amp=amp, epoch=epoch, loss_results=loss_results, - wer_results=wer_results, cer_results=cer_results), + torch.save(DeepSpeech.serialize(remove_parallel_wrapper(model), + optimizer=optimizer, + amp=amp, + epoch=epoch, + loss_results=loss_results, + wer_results=wer_results, + cer_results=cer_results), file_path) # anneal lr for g in optimizer.param_groups: @@ -342,8 +356,12 @@ def update(self, val, n=1): if main_proc and (best_wer is None or best_wer > wer): print("Found better validated model, saving to %s" % args.model_path) - torch.save(DeepSpeech.serialize(model, optimizer=optimizer, amp=amp, epoch=epoch, loss_results=loss_results, - wer_results=wer_results, cer_results=cer_results) + torch.save(DeepSpeech.serialize(remove_parallel_wrapper(model), + optimizer=optimizer, + amp=amp, epoch=epoch, + loss_results=loss_results, + wer_results=wer_results, + cer_results=cer_results) , args.model_path) best_wer = wer avg_loss = 0 diff --git a/utils.py b/utils.py index de508b07..3cb9af3a 100644 --- a/utils.py +++ b/utils.py @@ -38,3 +38,14 @@ def load_model(device, model_path, use_half): if use_half: model = model.half() return model + + +def remove_parallel_wrapper(model): + """ + Return the model or extract the model out of the parallel wrapper + :param model: The training model + :return: The model without parallel wrapper + """ + # Take care of distributed/data-parallel wrapper + model_no_wrapper = model.module if hasattr(model, "module") else model + return model_no_wrapper