
Deep clustering updates #92

Closed
wants to merge 1 commit into from

Conversation

Collaborator

@sunits sunits commented May 4, 2020

Deep clustering updates
=> Create collate functions to ensure the seq lengths of batch elements are the same (see the sketch after this list)
=> Bucketing sampler to ensure the elements of a batch are approximately of the same seq length; this avoids chopping off large parts of the dataset
=> Fix the bug where different random samples were taken for the mixture and the sources
=> Python code to create the wav-id/sample-count files
=> Introduce a VAD mask for deep clustering
=> Normalize the deep clustering loss
=> Fix a bug in the chimera++ model where the projection view mixes up the seq and source indexes
=> Take the log of the spectra as input to train the model
=> Simple eval script, which needs to be enhanced
=> Training script without PyTorch Lightning (to be made compatible with the core asteroid)
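A minimal sketch of the two data-loading pieces mentioned above, assuming a dataset that returns (mixture, sources) waveform pairs; pad_collate, BucketingSampler and wav_lengths are illustrative names, not the PR's actual code:

import torch
from torch.utils.data import Sampler


def pad_collate(batch):
    # Zero-pad every (mixture, sources) pair to the longest mixture in the batch.
    max_len = max(mix.shape[-1] for mix, _ in batch)
    mixtures, sources = [], []
    for mix, src in batch:
        pad = max_len - mix.shape[-1]
        mixtures.append(torch.nn.functional.pad(mix, (0, pad)))
        sources.append(torch.nn.functional.pad(src, (0, pad)))
    return torch.stack(mixtures), torch.stack(sources)


class BucketingSampler(Sampler):
    # Yields batches of indices whose utterances have similar lengths,
    # so pad_collate only has to add a small amount of padding per batch.
    def __init__(self, wav_lengths, batch_size):
        sorted_idx = sorted(range(len(wav_lengths)), key=lambda i: wav_lengths[i])
        self.batches = [sorted_idx[i:i + batch_size]
                        for i in range(0, len(sorted_idx), batch_size)]

    def __iter__(self):
        # Shuffle the order of the batches, not their contents, so buckets stay intact.
        for b in torch.randperm(len(self.batches)).tolist():
            yield self.batches[b]

    def __len__(self):
        return len(self.batches)

Both would be plugged into the DataLoader as batch_sampler=BucketingSampler(lengths, batch_size) and collate_fn=pad_collate.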

@sunits sunits marked this pull request as draft May 4, 2020 07:55
@sunits sunits changed the title => Create collate functions to ensure the seq length of batch element… Deep clustering updates May 4, 2020
Collaborator

@mpariente mpariente left a comment

Thanks a bunch for the PR!
I didn't spend much time reviewing the training script because it looked like WIP; let me know when I can review it.

Oh, and the test for the DC loss breaks because of the code change. The test will have to be updated before merging, but that's not the priority.

Comment on lines +451 to +453
if log:
#TODO: Use pytorch lightning logger here
print('Using log spectrum as input')
Collaborator

It's fine not to print anything actually, it will be printed in the conf dictionary.
I would call it differently though, maybe take_log or log_spec?

@@ -433,7 +433,7 @@ class ChimeraPP(nn.Module):
     """
     def __init__(self, in_chan, n_src, rnn_type = 'lstm',
                  embedding_dim=20, n_layers=2, hidden_size=600,
-                 dropout=0, bidirectional=True):
+                 dropout=0, bidirectional=True, log=False):
Collaborator

Update docstring as well

n_filters: 256
kernel_size: 256
stride: 64
log: True # Use log spectra as input
Collaborator

If the log is taken in the mask network, then it belongs in a mask network config
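A hedged sketch of that restructuring of conf.yml; the masknet section name and the placement of the key are assumptions, not something agreed on in the PR:

filterbank:
  n_filters: 256
  kernel_size: 256
  stride: 64
masknet:
  log: True  # take the log of the spectra inside the mask network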

tr_wav_len_list: exp/tr.wavid.samples
cv_wav_len_list: exp/cv.wavid.samples
tt_wav_len_list: exp/tt.wavid.samples
wav_base_path: /srv/storage/talc3@talc-data.nancy/multispeech/calcul/users/ssivasankaran/experiments/data/speech_separation/wsj0-mix/2speakers/wav8k/min/
Collaborator

The config file shouldn't need modification; absolute paths have to be in the run.sh.
The text files containing info about the dataset have to be under ./data/.

Comment on lines 36 to +39
     masker = ChimeraPP(int(enc.filterbank.n_feats_out/2), 2,
                        embedding_dim=20, n_layers=2, hidden_size=600, \
-                       dropout=0, bidirectional=True)
+                       dropout=0.5, bidirectional=True, \
+                       log=conf['filterbank']['log'])
Collaborator

These constants should go in conf.yml; expose the ones that we'll experiment with in the run.sh as well.

Comment on lines +59 to +60
# Removing additional saved info
checkpoint['state_dict'].pop('enc.filterbank._filters')
Collaborator

Why was this necessary? Did it crash?

@@ -33,8 +35,31 @@ def make_model_and_optimizer(conf):
     enc = fb.Encoder(fb.STFTFB(**conf['filterbank']))
     masker = ChimeraPP(int(enc.filterbank.n_feats_out/2), 2,
                        embedding_dim=20, n_layers=2, hidden_size=600, \
-                       dropout=0, bidirectional=True)
+                       dropout=0.5, bidirectional=True, \
+                       log=conf['filterbank']['log'])
     model = Model(enc, masker)
Collaborator

Even if the STFT is not inverted at train time, the goal will eventually be to go back to the time domain. You might as well attach the iSTFT to the model, it will make the rest easier IMO.
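A hedged sketch of what attaching it could look like, reusing the names from the hunk above (fb, conf, ChimeraPP, Model); the extra decoder argument to Model is an assumed signature change, not the PR's code:

enc = fb.Encoder(fb.STFTFB(**conf['filterbank']))
dec = fb.Decoder(fb.STFTFB(**conf['filterbank']))  # iSTFT of the same filterbank
masker = ChimeraPP(int(enc.filterbank.n_feats_out/2), 2)
# The model can then map the masked TF representation back to waveforms with dec.
model = Model(enc, masker, dec)  # hypothetical: Model extended to hold the decoder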

source: github.com/funcwj/deep-clustering.git
'''
# to dB
spectra_db = 20 * torch.log10(spectra)
Collaborator

Stabilize log here
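For instance, as a sketch (the epsilon value is an assumption):

EPS = 1e-8
# Keep exact zeros out of the log so silent TF bins don't produce -inf.
spectra_db = 20 * torch.log10(spectra + EPS)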

max_magnitude_db = torch.max(spectra_db)
threshold = 10**((max_magnitude_db - threshold_db) / 20)
mask = spectra > threshold
return mask.double()
Collaborator

Is float not sufficient?

@@ -0,0 +1,31 @@
import os
Collaborator

Could this follow the same format as in the WHAM recipes? It would make it much easier to switch from one to the other (especially for the wsj0-3mix that is not covered in WHAM).

@mpariente
Collaborator

Closed by #95 and #96

@mpariente mpariente closed this May 11, 2020
@mpariente mpariente deleted the dc_recipe branch August 11, 2020 14:12