FastSpeech2 training with MFA and Phoneme-based #107

Closed · ZDisket opened this issue Jul 10, 2020 · 169 comments

Labels: Discussion 😁 · enhancement 🚀 · FastSpeech · Feature Request 🤗 · question ❓
ZDisket (Collaborator) commented Jul 10, 2020

When training FastSpeech2 (fastspeech2_v2) with phonetic alignments extracted from MFA, I get the following error:

/content/TensorflowTTS/tensorflow_tts/trainers/base_trainer.py in run(self)
     65         )
     66         while True:
---> 67             self._train_epoch()
     68 
     69             if self.finish_train:

/content/TensorflowTTS/tensorflow_tts/trainers/base_trainer.py in _train_epoch(self)
     87         for train_steps_per_epoch, batch in enumerate(self.train_data_loader, 1):
     88             # one step training
---> 89             self._train_step(batch)
     90 
     91             # check interval

<ipython-input-39-dd452e77975e> in _train_step(self, batch)
     75         """Train model one step."""
     76         charactor, duration, f0, energy, mel = batch
---> 77         self._one_step_fastspeech2(charactor, duration, f0, energy, mel)
     78 
     79         # update counts

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    642         # Lifting succeeded, so variables are initialized and we can run the
    643         # stateless function.
--> 644         return self._stateless_fn(*args, **kwds)
    645     else:
    646       canon_args, canon_kwds = \

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2418     with self._lock:
   2419       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2420     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2421 
   2422   @property

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs)
   1663          if isinstance(t, (ops.Tensor,
   1664                            resource_variable_ops.BaseResourceVariable))),
-> 1665         self.captured_inputs)
   1666 
   1667   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1744       # No tape is watching; skip to running the function.
   1745       return self._build_call_outputs(self._inference_function.call(
-> 1746           ctx, args, cancellation_manager=cancellation_manager))
   1747     forward_backward = self._select_forward_and_backward_functions(
   1748         args,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    596               inputs=args,
    597               attrs=attrs,
--> 598               ctx=ctx)
    599         else:
    600           outputs = execute.execute_with_cancellation(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Incompatible shapes: [16,823,80] vs. [16,867,80]
	 [[node mean_absolute_error/sub (defined at <ipython-input-39-dd452e77975e>:115) ]] [Op:__inference__one_step_fastspeech2_341496]

Errors may have originated from an input operation.
Input Source operations connected to node mean_absolute_error/sub:
 mel (defined at <ipython-input-39-dd452e77975e>:77)	
 tf_fast_speech2_2/mel_before/BiasAdd (defined at /content/TensorflowTTS/tensorflow_tts/models/fastspeech2.py:196)

Function call stack:
_one_step_fastspeech2

I did everything I could think of to rule out my durations as the problem, including verifying that the lengths are the same, so I don't know what happened.
Interestingly, when training with mixed_precision off, the same error happens but with different values:

InvalidArgumentError:  Incompatible shapes: [16,763,80] vs. [16,806,80]
	 [[node mean_absolute_error/sub (defined at <ipython-input-39-dd452e77975e>:115) ]] [Op:__inference__one_step_fastspeech2_449871]

Errors may have originated from an input operation.
Input Source operations connected to node mean_absolute_error/sub:
 tf_fast_speech2_3/mel_before/BiasAdd (defined at /content/TensorflowTTS/tensorflow_tts/models/fastspeech2.py:196)	
 mel (defined at <ipython-input-39-dd452e77975e>:77)

Function call stack:
_one_step_fastspeech2

Am I missing something?

dathudeptrai (Collaborator) commented Jul 10, 2020

@ZDisket see here https://github.com/TensorSpeech/TensorflowTTS/blob/master/examples/fastspeech2/train_fastspeech2.py#L116-L117. The f0/duration/energy losses are fine, but the mel-spectrogram loss has a length mismatch. I'm pretty sure that your sum(duration) != len(mel), maybe for just some samples or maybe for all of them :D. Please refer to this comment (#46 (comment)).

trfnhle (Collaborator) commented Jul 10, 2020

Which duration values do you use when training, ground truth or predicted? If you use ground-truth durations, then the only reason that comes to mind is that some samples have a sum of durations different from the length of the mel-spectrogram. You could use the following code to skip those outlier samples while loading data:

difference = []
for i in tqdm(range(len(self.characters))):
    len_mel = len(np.load(self.mel_files[i]))
    duration = np.load(self.duration_files[i])
    total_duration = np.sum(duration)
    if len_mel != total_duration:
        difference.append(i)

Then filter out the difference ids.
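
For example, a rough sketch of that filtering step (assuming the loader keeps parallel lists such as self.characters, self.mel_files, and self.duration_files, as in the snippet above):

keep = set(range(len(self.characters))) - set(difference)
self.characters = [c for i, c in enumerate(self.characters) if i in keep]
self.mel_files = [f for i, f in enumerate(self.mel_files) if i in keep]
self.duration_files = [f for i, f in enumerate(self.duration_files) if i in keep]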

ZDisket (Collaborator, Author) commented Jul 10, 2020

@l4zyf9x @dathudeptrai Thanks, I'll see how deep the problem goes.

@dathudeptrai dathudeptrai added bug 🐛 Something isn't working FastSpeech FastSpeech related problems. question ❓ Further information is requested stat:awaiting response ☏ Waiting Response labels Jul 10, 2020
@dathudeptrai dathudeptrai added this to In progress in FastSpeech Jul 10, 2020
ZDisket (Collaborator, Author) commented Jul 10, 2020

It seems that the length of the normalized mels is slightly greater than the sum of the durations.
match(1).txt
Plotting one mel, it looks like the difference is in silence, so I would have to go through all the mels and cut them off where the durations stop.
In this image, for example, the mel has length 586 while the duration sum is 554:
[plot: mel-spectrogram with trailing silence]
Or I could pad the durations by adding a SIL token (see here) at the end.
Which approach is more optimal?
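
For anyone following along, a minimal sketch of the first option (cutting the mel where the durations stop) at data-loading time; the function and file names here are illustrative, not the repo's:

import numpy as np

def trim_mel_to_duration(mel, duration):
    """Cut trailing frames so that len(mel) == sum(duration)."""
    total = int(np.sum(duration))
    return mel[:total] if len(mel) >= total else mel  # only trim, never pad here

# illustrative usage
# mel = np.load("LJ001-0001-norm-feats.npy")       # shape [frames, 80]
# duration = np.load("LJ001-0001-durations.npy")   # one frame count per phone
# mel = trim_mel_to_duration(mel, duration)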

dathudeptrai (Collaborator) commented Jul 10, 2020

@abylouw do you have any thoughts? @ZDisket it seems the length mismatch is too large; in my preprocessing for F0/energy I also pad or drop the last few elements to make the lengths equal, but the mismatch is only 1 frame. cc: @azraelkuan

ZDisket (Collaborator, Author) commented Jul 10, 2020

@dathudeptrai I don't think it's too much considering that these are phonetic durations. I'm considering just adding the difference to the last duration for every utterance, but I want to hear what you guys think about it first.
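
A rough sketch of that idea, assuming the durations are integer frame counts and mel_len is the mel length in frames (names are illustrative):

import numpy as np

def absorb_length_diff(duration, mel_len):
    """Add the mel/duration length difference to the last phone's duration."""
    duration = np.array(duration, dtype=np.int32)
    diff = mel_len - int(np.sum(duration))
    duration[-1] += diff  # diff may be negative if the mel is shorter than the durations
    return duration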

dathudeptrai (Collaborator) commented Jul 10, 2020

hmm, I don't have any experience with this case @ZDisket. Maybe try just adding the difference to the last duration and see what happens during training :D. I think @abylouw succeeded in training FastSpeech2 with durations extracted from MFA (or something like that), so I will wait for his comment on this.

ZDisket (Collaborator, Author) commented Jul 10, 2020

@dathudeptrai I've gone ahead and it's training fine; the very first loss values look fine and the predictions at 3000 steps.zip (very early) look normal. ming024 stated in his README.md that it only takes an hour of training on a GTX 1080 to start producing decent samples.

dathudeptrai (Collaborator):

@ZDisket let's see; I don't know what the proper way to solve the length mismatch is in this case. I hope the model doesn't have any problems at the end of the mel :D

trfnhle (Collaborator) commented Jul 10, 2020

@ZDisket One question: do you train Tacotron and extract durations with the same mels that you use to train FastSpeech2?

ZDisket (Collaborator, Author) commented Jul 10, 2020

@l4zyf9x Those are durations extracted with the Montreal Forced Aligner from ground-truth LJSpeech audio. I didn't use a teacher model.

trfnhle (Collaborator) commented Jul 10, 2020

As @dathudeptrai mentioned before, I think you should reference this comment: #46 (comment). I think the solution is to configure MFA with the same time resolution. For example, with frame_rate=16000 and hop_length=256, you need to set MFA's rate to 1000 * 256 / 16000 = 16 ms. Then the mismatch will be at most 1 to 2 frames, and you just need to pad or drop the last frame.
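
To make the arithmetic concrete, a quick sketch (plain Python, not code from the repo):

sample_rate = 22050  # LJSpeech
hop_length = 256

# one mel frame spans hop_length / sample_rate seconds,
# so the aligner's frame shift (in ms) should match this value
frame_shift_ms = 1000 * hop_length / sample_rate  # ~11.61 ms here, 16 ms for the 16 kHz example above

# converting an aligned phone duration (in seconds) to mel frames
phone_seconds = 0.12  # example value from a TextGrid interval
phone_frames = round(phone_seconds * sample_rate / hop_length)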

ZDisket (Collaborator, Author) commented Jul 10, 2020

@l4zyf9x If that were the problem, it would be much worse. I already took that into account in my seconds-to-frames conversion, taking it from here.

ZDisket (Collaborator, Author) commented Jul 10, 2020

Although it has some problems, the model at 30k steps is capable of generating speech. Here's a notebook.
@dathudeptrai do you like it?

dathudeptrai (Collaborator):

@ZDisket the mel looks good at 30k steps; let it train to around 150k steps and compare the performance. Note that fastspeech2_v2 is the small version.

ZDisket (Collaborator, Author) commented Jul 10, 2020

@dathudeptrai Does the small version lower the quality, or is it supposed to be the same or better?

dathudeptrai (Collaborator):

@ZDisket It needs tuning :)). I think we can tune it and find a smaller version without worse quality. Let's see the performance after 150k steps.

trfnhle (Collaborator) commented Jul 10, 2020

@ZDisket I just had a look at your code and the MFA documentation. Following your code, when configuring MFA, frame_shift should be 1000 * 256 / 22050 ≈ 11.61 ms. What is your MFA frame_shift?

ZDisket (Collaborator, Author) commented Jul 10, 2020

@l4zyf9x My relevant code is:

import textgrid  # the "textgrid" pip package

tg = textgrid.TextGrid.fromFile("./TextGrids/" + tgp)
pha = tg[1]  # the phones tier of the MFA TextGrid
durations = []
phs = "{"
for interval in pha.intervals:
    mark = interval.mark
    if mark in sil_phones:
        mark = "SIL"
    # interval.duration() is in seconds; convert to mel frames (sarate = sample rate, hopsz = hop size)
    dur = interval.duration() * (sarate / hopsz)
    durations.append(int(dur))

I didn't read that part of the documentation, and since both ming024 and I are getting results, it's not necessary.

abylouw (Contributor) commented Jul 10, 2020

Is trim_silence true in your config? I did not use librosa to trim the silences, but rather trimmed them from the labels as aligned by our HTK implementation. We have a SIL start and end phone for all utterances, and in the pre-processing script, instead of trimming with librosa as here:

https://github.com/TensorSpeech/TensorflowTTS/blob/84c004172d604ecb966e1a905b81d6e1db2fc2e0/tensorflow_tts/bin/preprocess.py#L206-L212

we basically trim the first and last labels of the phones:

import numpy as np

def trim_silences(audio, durs, samplerate):
    # durs holds the phone durations in seconds; drop the leading and trailing SIL phones
    audio_start = durs[0] * samplerate
    audio_end = (np.sum(durs) - durs[-1]) * samplerate
    trimmed_audio = audio[int(round(audio_start)):int(round(audio_end))]
    return trimmed_audio

where the durs list contains the durations of the phones in seconds.
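
A small usage sketch under the same assumptions (loading with soundfile is my choice here, and the file names are illustrative):

import numpy as np
import soundfile as sf

audio, samplerate = sf.read("LJ001-0001.wav")
durs = np.load("LJ001-0001-durations-seconds.npy")  # phone durations in seconds, SIL first and last
audio = trim_silences(audio, durs, samplerate)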

ZDisket (Collaborator, Author) commented Jul 11, 2020

@dathudeptrai After 110k steps, the performance is pretty bad; you can see samples in the notebook. I think I'll instead cut the mels at data-loading time and train fastspeech2_v1. I also added training instructions here.

dathudeptrai (Collaborator) commented Jul 11, 2020

@ZDisket thanks, I will take a look; phoneme-based should be better than the character-based fastspeech2 v1 here. I'm reviewing everything now and then releasing the first version. After that I will train with phonemes and support multi-GPU/TPU...

ZDisket (Collaborator, Author) commented Jul 11, 2020

@dathudeptrai Are you going to train the current version, which adds the difference to the last duration? There are a few problems at the ends of the audio samples, although I suspect it could just be fastspeech2 v2 being bad.

ZDisket (Collaborator, Author) commented Jul 22, 2020

@dathudeptrai
Yes, I saved the weights from that notebook you gave me to my Google Drive (you should publish the model, by the way). The only flaw is that the audio is a little heavy on the low frequencies, but that's nothing some bass and treble manipulation can't fix. @Dicksonchin93 You can grab the link and see it in action with my notebook: https://colab.research.google.com/drive/1wXdeTQQdMdhkpvto7hDVgEfCanreNoE9?usp=sharing

dathudeptrai (Collaborator):

@Dicksonchin93 you can combine them :)), I can't since I don't have enough disk space to train on VCTK :v

Dicksonchin93:

@dathudeptrai @ZDisket Do you guys have any idea how to reduce the metallic effect to make it more natural? I'm looking into some manual audio processing to remove some of these metallic sounds.

ZDisket (Collaborator, Author) commented Jul 22, 2020

@Dicksonchin93 If you can isolate the metallic noise to use as a noise profile, then a conventional noise-reduction filter (like the one found in Audacity) might be enough. Otherwise, try training it longer or with a better dataset, or train v1, which is heavier and higher quality, if you're using v2.

dathudeptrai (Collaborator):

Can anyone help me write correct code for phoneme_to_sequence in https://github.com/TensorSpeech/TensorflowTTS/blob/master/tensorflow_tts/processor/ljspeech.py :(. I just want to train tacotron-2 with phonemes now :(. Here is my code :(

    def phoneme_to_sequence(self, text):
        sequence = []
        clean_text = _clean_text(text, [self.cleaner_names])
        phoneme = G2P(clean_text)  # also yields spaces and punctuation tokens
        for p in phoneme:
            if p in valid_symbols:
                p = "@" + p  # ARPAbet symbols are prefixed with "@" in the symbol table
            if p in _symbol_to_id:
                sequence.append(_symbol_to_id[p])
        return sequence
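
For context, the snippet assumes G2P is an instance of g2p_en's G2p and that _clean_text, valid_symbols, and _symbol_to_id come from the existing LJSpeech processor; a minimal setup sketch:

from g2p_en import G2p

G2P = G2p()  # callable instance; returns ARPAbet phones plus spaces and punctuation
# ids = processor.phoneme_to_sequence("hello world.")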

ZDisket (Collaborator, Author) commented Jul 23, 2020

@dathudeptrai I just take the phonetic equivalent of the transcription (I do it from the MFA output; you'll have to run G2P on yours) and modify metadata.csv so those are wrapped in curly braces {AA1, ...} before doing any preprocessing steps. The LJSpeech preprocessor is capable of handling the rest.
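
A rough sketch of that metadata.csv rewrite, assuming the LJSpeech id|text|normalized_text layout (the helper and variable names are made up, and space-separated phones are shown inside the braces):

def to_braced(phones):
    # e.g. ["P", "R", "IH1", "N", "T"] -> "{P R IH1 N T}"
    return "{" + " ".join(phones) + "}"

# utterances: iterable of (utt_id, [phone, ...]) pairs from the MFA/G2P output
utterances = [("LJ001-0001", ["P", "R", "IH1", "N", "T", "IH0", "NG"])]

with open("metadata.csv", "w", encoding="utf-8") as f:
    for utt_id, phones in utterances:
        f.write(f"{utt_id}|{to_braced(phones)}|{to_braced(phones)}\n")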

dathudeptrai (Collaborator):

@ZDisket OK, I'm training phoneme-based with G2P now :D

manmay-nakhashi:

@dathudeptrai I have an espeak-based phoneme-to-sequence converter; will it work?

dathudeptrai (Collaborator):

@manmay-nakhashi I think G2P is faster than espeak :v.

manmay-nakhashi:

@dathudeptrai But it has less language support; I couldn't find any way to load a custom g2p model into Python.

dathudeptrai (Collaborator) commented Jul 23, 2020

@manmay-nakhashi As you can see, every language/dataset should have its own preprocessor.py. The important point is not that G2P has less language support; this is a divide-and-conquer strategy. Almost every language has its own framework for converting characters to phonemes, so we don't need a general framework for this :v. In my view, in terms of model implementation, more general is good. But in terms of preprocessing code, more general means less flexible :v, and also less readable.

manmay-nakhashi:

@dathudeptrai You are right in terms of flexibility, but if we can make the system support a custom G2P model with user-defined phoneme output, then we can make it both flexible and generic.

dathudeptrai (Collaborator):

@manmay-nakhashi Yes, that is in my plan: make the preprocessing stage class-based so users can inherit a base_preprocessing class to implement their own preprocessing class, as with base_trainer here :D
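
Roughly something like this hypothetical sketch (names made up, mirroring how base_trainer is subclassed):

import abc

class BaseProcessor(abc.ABC):
    """Hypothetical shape of such a base class (not actual repo code)."""

    def __init__(self, symbol_to_id):
        self.symbol_to_id = symbol_to_id

    @abc.abstractmethod
    def text_to_phonemes(self, text):
        """Each language/dataset plugs in its own G2P here (g2p_en, espeak, an MFA lexicon, ...)."""

    def phonemes_to_sequence(self, phonemes):
        # shared id mapping; unknown symbols are silently skipped
        return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id]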

manmay-nakhashi:

@dathudeptrai Exactly, that would make this framework more robust and scalable.

janson91:

@dathudeptrai
Hi, I have trained phoneme-based FastSpeech2 and I get results like this. It seems to be overfitting, and the audio generated by Griffin-Lim sounds bad. Can you give me some suggestions?
[training/eval curve screenshots]

Thanks

dathudeptrai (Collaborator):

@janson91 Everything is OK; the mel looks good. GL is just for checking whether the output is correct or not; it always generates bad audio. You should train a vocoder such as MB-MelGAN.

janson91:

@dathudeptrai
OK, MB-MelGAN is training, but it is very slow. It will take hundreds of hours to train on 8 GPUs (1080 Ti).

dathudeptrai (Collaborator) commented Aug 10, 2020

@janson91 The batch_size in the config is per GPU, and it's 64. If you are training with 8 GPUs, I suggest training with batch_size = 8 or 16 :)). If you use batch_size = 64, that means the global_batch_size is 64 * 8 = 512 :))). MB-MelGAN needs to be trained for around 1M steps with batch_size 64 :D.

janson91:

OK, thanks @dathudeptrai

janson91:

@dathudeptrai Hi,
I just set sampling_rate: 48000 in the config yaml because the sampling rate of my raw data is 48000; however, the synthesized audio is half the length of the original audio.
Am I leaving out other parameters?
[config screenshot]

Thanks

dathudeptrai (Collaborator):

@janson91 The sampling_rate in the preprocessing config should be the same as the sampling_rate in the training config :D.

janson91:

@dathudeptrai The sampling_rate in the two configs is the same. Are there other parameters that could lead to this incorrect result?

dathudeptrai (Collaborator):

@janson91 Could you create a new issue? And of course, please give us your preprocessing config and training config :D.

dathudeptrai (Collaborator):

FastSpeech + MFA is supported in https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/fastspeech2_multispeaker. I will close this issue; thanks for all your help :D

@dathudeptrai dathudeptrai moved this from In progress to Done in FastSpeech Aug 13, 2020
ykzj commented Apr 14, 2021

> @ZDisket thanks, I will take a look; phoneme-based should be better than the character-based fastspeech2 v1 here. I'm reviewing everything now and then releasing the first version. After that I will train with phonemes and support multi-GPU/TPU...

@dathudeptrai What's the progress on migrating FastSpeech to TPU?

binbinxue:

#46 (comment)

I have to ask: I checked the phoneme conversions of MFA vs. the g2p_en phoneme conversions, and they do not match. I can see that the base processor in your repo uses g2p_en. MFA produces different phonemes for the same word, and g2p_en preserves punctuation while MFA doesn't. Can you comment on this?
