Deep clustering/Chimera recipe #96

Merged (9 commits) on May 11, 2020
Conversation

mpariente (Collaborator)
Adding the Deep clustering / Chimera++ recipe on wsj2mix and wsj3mix, mainly based on @sunits' initial work.

Data prep, dataloader, training and evaluation script.

The evaluation script will use the mask-inference head whenever possible, and the DC head only if loss_alpha is equal to 1.

The first successful iteration on the DC head alone gets a 10.1 dB SDR improvement (not SI-SDR).

Things left to do:

  • Upload some more results on DC alone, Chimera++ (maybe even 3mix)
  • Try to use kmeans in native PyTorch instead of sklearn because evaluation is really slow.
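As a sketch of what k-means in native PyTorch could look like (a hypothetical `kmeans_torch` helper, not part of this PR; the recipe currently calls sklearn in the eval script):

```python
import torch

def kmeans_torch(x, n_clusters, n_iters=20):
    """Naive Lloyd's k-means on a (n_points, emb) tensor; runs on GPU if x does."""
    # Initialise centroids from randomly chosen points
    centroids = x[torch.randperm(x.shape[0])[:n_clusters]].clone()
    for _ in range(n_iters):
        # Assign each embedding to its nearest centroid
        assignments = torch.cdist(x, centroids).argmin(dim=1)
        # Recompute each centroid as the mean of its assigned points
        for k in range(n_clusters):
            mask = assignments == k
            if mask.any():
                centroids[k] = x[mask].mean(dim=0)
    return assignments
```

Whether this actually beats sklearn would need benchmarking; the expected win is keeping the embeddings on GPU instead of round-tripping through NumPy.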

Comment on lines +65 to +67
proj = proj.view(batch, n_frames, -1, self.embedding_dim).transpose(1, 2)
# (batch, freq * frames, emb)
proj = proj.reshape(batch, -1, self.embedding_dim)
mpariente (Collaborator Author)
The bug was here. Without the transpose, the time bins were not aligned with each other and training was impossible.
I added a note about it in the DC loss.
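A toy shape walk-through of the fix (sizes are made up; only the flattening order matters): both paths produce a (batch, freq * frames, emb) shaped tensor, but only the transposed one flattens the bins frequency-major, matching the ordering the comment in the code describes.

```python
import torch

batch, n_frames, n_freq, emb = 1, 3, 2, 4  # toy sizes for illustration
proj = torch.arange(batch * n_frames * n_freq * emb, dtype=torch.float32)
proj = proj.view(batch, n_frames, n_freq * emb)  # layout as output by the network

# With the transpose: bins flattened frequency-major, as the loss expects
aligned = proj.view(batch, n_frames, -1, emb).transpose(1, 2).reshape(batch, -1, emb)
# Without it: same values, but flattened frame-major
misaligned = proj.view(batch, n_frames, -1, emb).reshape(batch, -1, emb)

assert aligned.shape == misaligned.shape == (batch, n_freq * n_frames, emb)
assert not torch.equal(aligned, misaligned)  # rows pair up with different TF bins
```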

Comment on lines +146 to +155
try:
    # Last best model summary
    with open(os.path.join(exp_dir, 'best_k_models.json'), "r") as f:
        best_k = json.load(f)
    best_model_path = min(best_k, key=best_k.get)
except FileNotFoundError:
    # Get last checkpoint
    all_ckpt = os.listdir(os.path.join(exp_dir, 'checkpoints/'))
    all_ckpt.sort()
    best_model_path = os.path.join(exp_dir, 'checkpoints', all_ckpt[-1])
mpariente (Collaborator Author)
Also, this would be a way to bypass the best_k_models.json.
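For reference, best_k_models.json maps checkpoint paths to their validation losses, so min over the dict with key=best_k.get picks the lowest-loss checkpoint. A self-contained sketch (paths and loss values are invented):

```python
import json
import os
import tempfile

# Invented contents; the real file is written from checkpoint.best_k_models
best_k = {
    "checkpoints/epoch=10.ckpt": 0.42,
    "checkpoints/epoch=14.ckpt": 0.37,
    "checkpoints/epoch=17.ckpt": 0.39,
}
with tempfile.TemporaryDirectory() as exp_dir:
    with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
        json.dump(best_k, f, indent=0)
    with open(os.path.join(exp_dir, "best_k_models.json"), "r") as f:
        loaded = json.load(f)
    # min iterates over the keys (paths), compared by their loss values
    best_model_path = min(loaded, key=loaded.get)
```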

with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
    json.dump(checkpoint.best_k_models, f, indent=0)
# Save last model for convenience
torch.save(system.model.state_dict(),
           os.path.join(exp_dir, 'final.pth'))
mpariente (Collaborator Author)
And this would be another one. But if training didn't finish, this model is never saved.
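The convenience of the final state_dict is that evaluation can reload it into a fresh model without any Lightning checkpoint machinery. A minimal sketch with a stand-in nn.Linear (the real model is the recipe's Chimera network):

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for system.model
with tempfile.TemporaryDirectory() as exp_dir:
    path = os.path.join(exp_dir, "final.pth")
    torch.save(model.state_dict(), path)
    # Reload into a freshly constructed instance for evaluation
    fresh = nn.Linear(4, 2)
    fresh.load_state_dict(torch.load(path))
```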
