Deep clustering updates #92
Conversation
Thanks a bunch for the PR!
I didn't spend much time reviewing the training script because it looked like WIP; let me know when I can review it.
Oh, and the test for the DC loss breaks because of the code change. The test will have to be updated before merging, but that's not the priority.
if log:
    #TODO: Use pytorch lightning logger here
    print('Using log spectrum as input')
It's actually fine not to print anything here; it will be printed in the conf dictionary anyway. I would name the argument differently though, maybe take_log or log_spec?
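A minimal sketch of the rename, assuming the ChimeraPP signature shown in the diff below; take_log is just one of the suggested names, and the epsilon is an illustrative choice, not code from this PR:

import torch
from torch import nn

class ChimeraPP(nn.Module):
    # Sketch: only the renamed flag is shown, the recurrent stack is elided.
    def __init__(self, in_chan, n_src, rnn_type='lstm', embedding_dim=20,
                 n_layers=2, hidden_size=600, dropout=0, bidirectional=True,
                 take_log=False):
        super().__init__()
        self.take_log = take_log  # no print needed, the value shows up in the conf dict

    def forward(self, spec):
        if self.take_log:
            # small epsilon avoids log(0) = -inf on silent bins (assumed stabilizer)
            spec = torch.log(spec + 1e-8)
        return spec  # the real model would run the BLSTM stack here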
@@ -433,7 +433,7 @@ class ChimeraPP(nn.Module):
     """
     def __init__(self, in_chan, n_src, rnn_type = 'lstm',
                  embedding_dim=20, n_layers=2, hidden_size=600,
-                 dropout=0, bidirectional=True):
+                 dropout=0, bidirectional=True, log=False):
Update docstring as well
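One way the new argument could be documented, following the Args style of the other asteroid models (the wording is a suggestion, the rest of the docstring is elided):

from torch import nn

class ChimeraPP(nn.Module):
    """Chimera++ mask network.

    Args:
        log (bool, optional): If True, take the log of the input magnitude
            spectra before feeding them to the network. Defaults to False.
    """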
n_filters: 256
kernel_size: 256
stride: 64
log: True  # Use log spectra as input
If the log is taken in the mask network, then it belongs in a mask network config
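A sketch of the corresponding conf.yml layout, moving the flag out of the filterbank section (the masknet section name mirrors other asteroid recipes and is illustrative):

filterbank:
  n_filters: 256
  kernel_size: 256
  stride: 64
masknet:
  log: True  # take the log of the spectra inside the mask network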
tr_wav_len_list: exp/tr.wavid.samples
cv_wav_len_list: exp/cv.wavid.samples
tt_wav_len_list: exp/tt.wavid.samples
wav_base_path: /srv/storage/talc3@talc-data.nancy/multispeech/calcul/users/ssivasankaran/experiments/data/speech_separation/wsj0-mix/2speakers/wav8k/min/
The config file shouldn't need modification; absolute paths have to be in the run.sh.
The text files containing info about the dataset have to be under ./data/
  masker = ChimeraPP(int(enc.filterbank.n_feats_out/2), 2,
                     embedding_dim=20, n_layers=2, hidden_size=600, \
-                    dropout=0, bidirectional=True)
+                    dropout=0.5, bidirectional=True, \
+                    log=conf['filterbank']['log'])
These constants should go in conf.yml, and the ones that we'll experiment with should also be exposed in the run.sh.
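For instance, the ChimeraPP constants could be grouped in conf.yml like this (the key names simply mirror the keyword arguments; illustrative, not the merged config):

masknet:
  embedding_dim: 20
  n_layers: 2
  hidden_size: 600
  dropout: 0.5
  bidirectional: True

The model construction then reduces to ChimeraPP(int(enc.filterbank.n_feats_out/2), 2, **conf['masknet']), and run.sh only needs to expose the keys that will actually be swept.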
# Removing additional saved info
checkpoint['state_dict'].pop('enc.filterbank._filters')
Why was this necessary? Did it crash?
@@ -33,8 +35,31 @@ def make_model_and_optimizer(conf):
     enc = fb.Encoder(fb.STFTFB(**conf['filterbank']))
     masker = ChimeraPP(int(enc.filterbank.n_feats_out/2), 2,
                        embedding_dim=20, n_layers=2, hidden_size=600, \
-                       dropout=0, bidirectional=True)
+                       dropout=0.5, bidirectional=True, \
+                       log=conf['filterbank']['log'])
     model = Model(enc, masker)
Even if the STFT is not inverted at train time, the goal will eventually be to go back to the time domain. You might as well attach the iSTFT to the model; it will make the rest easier IMO.
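A sketch of what attaching the decoder could look like, using the usual asteroid filterbank API; ChimeraPP and Model come from this recipe's model.py, and passing a third module to Model assumes a signature change that is not in this diff:

from asteroid import filterbanks as fb

# filterbank settings as in conf.yml above (log handled separately, see earlier comment)
fb_conf = dict(n_filters=256, kernel_size=256, stride=64)
enc = fb.Encoder(fb.STFTFB(**fb_conf))
dec = fb.Decoder(fb.STFTFB(**fb_conf))  # iSTFT back to the time domain
masker = ChimeraPP(int(enc.filterbank.n_feats_out / 2), 2,
                   embedding_dim=20, n_layers=2, hidden_size=600,
                   dropout=0.5, bidirectional=True, log=True)
# hypothetical: Model would have to store dec and apply it after masking
model = Model(enc, masker, dec)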
source: github.com/funcwj/deep-clustering.git
'''
# to dB
spectra_db = 20 * torch.log10(spectra)
Stabilize the log here, otherwise silent bins give -inf.
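One common way to stabilize it is a small additive constant; the epsilon value below is a typical choice, not taken from this PR:

import torch

EPS = 1e-8  # keeps log10 finite on silent time-frequency bins
spectra = torch.rand(1, 129, 100)  # dummy magnitude spectrogram for illustration
spectra_db = 20 * torch.log10(spectra + EPS)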
max_magnitude_db = torch.max(spectra_db)
threshold = 10**((max_magnitude_db - threshold_db) / 20)
mask = spectra > threshold
return mask.double()
Is float not sufficient?
@@ -0,0 +1,31 @@
import os
Could this follow the same format as in the WHAM recipes? It would make it much easier to switch from one to the other (especially for wsj0-3mix, which is not covered in WHAM).
Deep clustering updates
=> Create collate functions to ensure the seq lengths of batch elements are the same (a sketch follows this list)
=> Bucketing sampler to ensure the elements of a batch have approximately the same seq length; this avoids chopping off large parts of the dataset
=> Fix the bug where different random samples were taken for the mixture and the sources
=> Python code to create wav-id sample-count files
=> Introduce a VAD mask for deep clustering
=> Normalize the deep clustering loss
=> Fix a bug in the chimera++ model where the projection view mixes up the seq and source indexes
=> Take the log of the spectra as input to train the model
=> Simple eval script, which needs to be enhanced
=> Training script without pytorch lightning (to be made compatible with the core asteroid)
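As referenced in the first item above, a minimal sketch of such a padding collate function; the function name and the zero-padding policy are assumptions, not the PR's exact code:

import torch
import torch.nn.functional as F

def pad_collate(batch):
    # batch: list of (mixture, sources) pairs, mixture (time,), sources (n_src, time)
    mixtures, sources = zip(*batch)
    max_len = max(m.shape[-1] for m in mixtures)
    # zero-pad every element on the last (time) axis up to the longest one
    mix_out = torch.stack([F.pad(m, (0, max_len - m.shape[-1])) for m in mixtures])
    src_out = torch.stack([F.pad(s, (0, max_len - s.shape[-1])) for s in sources])
    return mix_out, src_out

# usage: torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=pad_collate)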