Skip to content

Commit

Permalink
Small reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
mpariente committed Apr 7, 2020
1 parent 8d1ca2e commit aa3485b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 65 deletions.
10 changes: 6 additions & 4 deletions egs/wham/WaveSplit/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
### WaveSplit

things currently not clear:
---
#### Currently not clear:

- It is not clear whether different encoders are used for the separation and speaker stacks (the figure in the paper suggests so).
- What is the embedding dimension? It seems to be 512, but this is not stated explicitly in the paper.
- Which mask is used (sigmoid?)
Expand All @@ -10,8 +10,10 @@ things currently not clear:
- The loss is currently prone to going NaN, especially if we don't take the mean after computing the L2 distances.

---
structure:
- train.py contains training loop (nets instantiation lines 48-60, training loop lines 100- 116)
#### Structure:
- train.py contains the training loop (nets instantiation:
[lines 48-60](https://github.com/mpariente/asteroid/pull/70/files#diff-f69bcb61820a4a7cfc8fda9a554c251cR49), training loop: lines
[100-116](https://github.com/mpariente/asteroid/pull/70/files#diff-f69bcb61820a4a7cfc8fda9a554c251cR100))
- losses.py wavesplit losses
- wavesplit.py sep and speaker stacks nets
- wavesplitwham.py dataset parsing
26 changes: 3 additions & 23 deletions egs/wham/WaveSplit/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,21 @@


class ClippedSDR(nn.Module):
    """Loss wrapper that lower-bounds the output of ``MultiSrcNegSDR("snr")``.

    Args:
        clip_value: Lower bound applied to the wrapped loss value.
            It is converted to ``float`` on construction. Defaults to -30.
    """

    def __init__(self, clip_value=-30):
        super().__init__()
        self.snr = MultiSrcNegSDR("snr")
        self.clip_value = float(clip_value)

    def forward(self, est_targets, targets):
        """Compute the wrapped loss and clamp it from below.

        Args:
            est_targets: Estimated sources, forwarded unchanged to the
                wrapped loss.
            targets: Reference sources, forwarded unchanged to the
                wrapped loss.

        Returns:
            The wrapped loss output with every value below
            ``self.clip_value`` replaced by ``self.clip_value``.
        """
        # NOTE(review): presumably the wrapped loss is a negative SNR, so the
        # clamp stops very good estimates from dominating - confirm.
        raw_loss = self.snr(est_targets, targets)
        return torch.clamp(raw_loss, min=self.clip_value)


class SpeakerVectorLoss(nn.Module):

def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="global",
weight=10, distance_reg=0.3, gaussian_reg=0.2, return_oracle=True):
super(SpeakerVectorLoss, self).__init__()


# not clear how embeddings are initialized.

self.learnable_emb = learnable_emb
Expand All @@ -35,7 +31,6 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob
self.distance_reg = float(distance_reg)
self.gaussian_reg = float(gaussian_reg)
self.return_oracle = return_oracle

assert loss_type in ["distance", "global", "local"]

# I initialize embeddings to be on unit sphere as speaker stack uses euclidean normalization
Expand All @@ -53,7 +48,6 @@ def __init__(self, n_speakers, embed_dim=32, learnable_emb=True, loss_type="glob
self.alpha = nn.Parameter(torch.Tensor([1.])) # not clear how these are initialized...
self.beta = nn.Parameter(torch.Tensor([0.]))


### losses go to NaN if I strictly follow the formulas; maybe I am missing something...

@staticmethod
Expand Down Expand Up @@ -96,7 +90,9 @@ def _l_global_speaker(self, c_spk_vec_perm, spk_embeddings, spk_labels, spk_mask

def forward(self, speaker_vectors, spk_mask, spk_labels):

# spk_mask ideally would be the speaker activity at frame level. Because WHAM mixtures can be considered to always have two active speakers, we fix this for now.
# spk_mask ideally would be the speaker activity at frame level.
# Because WHAM mixtures can be considered to always have two active
# speakers, we fix this for now.
# mask with ones and zeros B, SRC, FRAMES

if self.gaussian_reg:
Expand Down Expand Up @@ -180,19 +176,3 @@ def forward(self, speaker_vectors, spk_mask, spk_labels):
c = ClippedSDR(-30)
a = torch.rand((2, 3, 200))
print(c(a, a))
















38 changes: 9 additions & 29 deletions egs/wham/WaveSplit/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,14 @@
parser.add_argument('--exp_dir', default='exp/tmp',
help='Full path to save best validation model')

warnings.simplefilter("ignore", UserWarning)


class Wavesplit(pl.LightningModule): # redefinition

def __init__(self, train_loader,
val_loader=None, scheduler=None, config=None):
super().__init__()

# instantiation of stacks optimizers etc
# NOTE: I use separated encoders for speaker and sep stack as it is not specified in the paper...
# NOTE: I use separate encoders for the speaker and sep stacks,
# as this is not specified in the paper...

self.enc_spk, self.dec = make_enc_dec("free", 512, 16, 8)
self.enc_sep = deepcopy(self.enc_spk)
Expand Down Expand Up @@ -75,7 +72,6 @@ def forward(self, *args, **kwargs):
Returns:
:class:`torch.Tensor`
"""

return self.model(*args, **kwargs)

def common_step(self, batch, batch_nb):
Expand All @@ -102,8 +98,13 @@ def common_step(self, batch, batch_nb):
spk_vectors = self.spk_stack(tf_rep)
B, src, embed, frames = spk_vectors.size()

# torch.ones ideally would be the speaker activity at frame level. Because WHAM mixtures can be considered to always have two active speakers, we fix this for now.
spk_loss, spk_vectors, oracle = self.spk_loss(spk_vectors, torch.ones((B, src, frames)).to(spk_vectors.device), spk_ids)
# torch.ones ideally would be the speaker activity at frame level.
# Because WHAM mixtures can be considered to always have two active
# speakers, we fix this for now.
spk_loss, spk_vectors, oracle = self.spk_loss(
spk_vectors, torch.ones((B, src, frames)).to(spk_vectors.device),
spk_ids
)
tf_rep = self.enc_sep(inputs)
B, n_filters, frames = tf_rep.size()
tf_rep = tf_rep[:, None, ...].expand(-1, src, -1, -1).reshape(B*src, n_filters, frames)
Expand Down Expand Up @@ -218,32 +219,11 @@ def train_dataloader(self):
def val_dataloader(self):
return self.val_loader

@pl.data_loader
def tng_dataloader(self): # pragma: no cover
""" Deprecated."""
pass

def on_save_checkpoint(self, checkpoint):
    """Attach the training configuration to a checkpoint being saved.

    Args:
        checkpoint: Mapping that is about to be saved. It is modified
            in place.

    Returns:
        The same ``checkpoint`` object, with ``self.config`` stored under
        the ``'training_config'`` key.
    """
    training_config = self.config
    checkpoint['training_config'] = training_config
    return checkpoint

def on_batch_start(self, batch):
""" Overwrite if needed. Called by pytorch-lightning"""
pass

def on_batch_end(self):
""" Overwrite if needed. Called by pytorch-lightning"""
pass

def on_epoch_start(self):
""" Overwrite if needed. Called by pytorch-lightning"""
pass

def on_epoch_end(self):
""" Overwrite if needed. Called by pytorch-lightning"""
pass

@staticmethod
def none_to_string(dic):
""" Converts `None` to ``'None'`` to be handled by torch summary writer.
Expand Down
10 changes: 1 addition & 9 deletions egs/wham/WaveSplit/wavesplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class Conv1DBlock(nn.Module):

def __init__(self, hid_chan, kernel_size, padding,
dilation, norm_type="gLN"):
super(Conv1DBlock, self).__init__()
Expand All @@ -24,12 +23,11 @@ def forward(self, x):

return self.out(x)

class SepConv1DBlock(nn.Module):

class SepConv1DBlock(nn.Module):
def __init__(self, in_chan_spk_vec, hid_chan, kernel_size, padding,
dilation, norm_type="gLN", use_FiLM=True):
super(SepConv1DBlock, self).__init__()

self.use_FiLM = use_FiLM
conv_norm = norms.get(norm_type)
self.depth_conv1d = nn.Conv1d(hid_chan, hid_chan, kernel_size,
Expand Down Expand Up @@ -61,7 +59,6 @@ def forward(self, x, spk_vec):

class SpeakerStack(nn.Module):
# basically this is a plain Conv-TasNet; remove this in future releases

def __init__(self, in_chan, n_src, n_blocks=14, n_repeats=1,
kernel_size=3,
norm_type="gLN"):
Expand Down Expand Up @@ -183,8 +180,3 @@ def get_config(self):
'norm_type': self.norm_type,
}
return config





0 comments on commit aa3485b

Please sign in to comment.