Use new batch-wise mean/std normalization. Add additional parameters which are passed on to morlet_filter_bank. Add documentation
JohnVinyard committed Jan 25, 2019
1 parent 41627cd commit a04c7b1
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions zounds/learn/spectral.py
@@ -4,18 +4,54 @@
from zounds.spectral import morlet_filter_bank, AWeighting, FrequencyDimension
from zounds.core import ArrayWithUnits
from zounds.timeseries import TimeDimension
from util import batchwise_mean_std_normalization


class FilterBank(nn.Module):
def __init__(self, samplerate, kernel_size, scale, scaling_factors):
"""
A torch module that convolves a 1D input signal with a bank of Morlet
filters.
Args:
samplerate (SampleRate): the samplerate of the input signal
kernel_size (int): the length in samples of each filter
scale (FrequencyScale): a scale whose center frequencies determine the
fundamental frequency of each filter
scaling_factors (float or list of float): Scaling factors for each band,
which determine the time-frequency resolution tradeoff.
The number(s) should fall between 0 and 1, with smaller numbers
achieving better frequency resolution, and larger numbers better
time resolution
normalize_filters (bool): When true, ensure that each filter in the bank
has unit norm
a_weighting (bool): When true, apply a perceptually-motivated weighting
of the filters
See Also:
:class:`~zounds.spectral.AWeighting`
:func:`~zounds.spectral.morlet_filter_bank`
"""

def __init__(
self,
samplerate,
kernel_size,
scale,
scaling_factors,
normalize_filters=True,
a_weighting=True):

super(FilterBank, self).__init__()

filter_bank = morlet_filter_bank(
samplerate,
kernel_size,
scale,
scaling_factors)
filter_bank *= AWeighting()
scaling_factors,
normalize=normalize_filters)

if a_weighting:
filter_bank *= AWeighting()

self.scale = scale
self.filter_bank = torch.from_numpy(filter_bank).float() \
@@ -46,17 +82,6 @@ def temporal_pooling(self, x, kernel_size, stride):
x = F.avg_pool1d(x, kernel_size, stride, padding=kernel_size // 2)
return x

def normalize(self, x):
"""
give each instance zero mean and unit variance
"""
orig_shape = x.shape
x = x.view(x.shape[0], -1)
x = x - x.mean(dim=1, keepdim=True)
x = x / (x.std(dim=1, keepdim=True) + 1e-8)
x = x.view(orig_shape)
return x

def transform(self, samples, pooling_kernel_size, pooling_stride):
# convert the raw audio samples to a PyTorch tensor
tensor_samples = torch.from_numpy(samples).float() \
@@ -85,6 +110,6 @@ def forward(self, x, normalize=True):
x = self.log_magnitude(x)

if normalize:
x = self.normalize(x)
x = batchwise_mean_std_normalization(x)

return x[..., :nsamples].contiguous()
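
Note on the normalization change: the per-instance normalize method removed above is replaced by batchwise_mean_std_normalization, which is imported from util but whose body is not part of this diff. Below is a minimal sketch of what such a batch-wise helper might look like, assuming it standardizes with a single mean and standard deviation computed over the whole batch; the epsilon mirrors the removed method, and everything else is an assumption rather than the actual zounds implementation.

import torch

def batchwise_mean_std_normalization_sketch(x, epsilon=1e-8):
    # Assumed behavior: unlike the removed FilterBank.normalize, which
    # standardized each instance independently, use one mean and one
    # standard deviation computed over the entire batch tensor.
    x = x - x.mean()
    x = x / (x.std() + epsilon)
    return x

# usage (shapes are illustrative):
# x = torch.randn(8, 1, 64, 512)
# x = batchwise_mean_std_normalization_sketch(x)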
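
For context, here is a hedged usage sketch of the updated constructor with the two new keyword arguments. The concrete sample rate, frequency scale, kernel size, and scaling factor are illustrative assumptions, not values taken from this commit; zounds.SR11025 and zounds.GeometricScale are assumed to be available from the library's public API.

import zounds
from zounds.learn.spectral import FilterBank

# Illustrative values only; any SampleRate / FrequencyScale pair will do.
samplerate = zounds.SR11025()
scale = zounds.GeometricScale(
    start_center_hz=50,
    stop_center_hz=5000,
    bandwidth_ratio=0.05,
    n_bands=64)

fb = FilterBank(
    samplerate,
    kernel_size=512,
    scale=scale,
    scaling_factors=0.1,      # single factor shared by every band
    normalize_filters=True,   # new: give each filter unit norm
    a_weighting=True)         # new: apply AWeighting to the bank

# fb can then be used like any other torch.nn.Module.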
