Add documentation for PerceptualLoss
JohnVinyard committed Jul 7, 2018
1 parent b3724ff commit de94b5a
Showing 2 changed files with 41 additions and 27 deletions.
docs/source/learn.rst: 10 changes (9 additions & 1 deletion)
@@ -12,6 +12,11 @@ The Basics
 .. autoclass:: PreprocessResult
 .. autoclass:: PipelineResult
 
+Custom Losses
+-------------
+.. autoclass:: PerceptualLoss
+    :members:
+
 Data Preparation
 ----------------
 .. autoclass:: UnitNorm
@@ -41,4 +46,7 @@ Hashing
 
 Learned Models in Audio Processing Graphs
 -----------------------------------------
-.. autoclass:: Learned
+.. autoclass:: Learned
+    :members:
zounds/learn/loss.py: 58 changes (32 additions & 26 deletions)
@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 from scipy.signal import gaussian
 from torch import nn
@@ -7,10 +6,36 @@
 from dct_transform import DctTransform
 from zounds.spectral import fir_filter_bank
-from zounds.timeseries import Picoseconds, SampleRate
+from zounds.timeseries import SampleRate
 
 
+class PerceptualLoss(object):
+    """
+    `PerceptualLoss` computes loss/distance in a feature space that roughly
+    approximates the early stages of the human audio processing pipeline,
+    instead of computing a raw sample loss. It decomposes a 1D (audio) signal
+    into frequency bands using an FIR filter bank whose frequencies are
+    centered according to a user-defined scale, performs half-wave
+    rectification, puts amplitudes on a log scale, and finally applies an
+    optional re-weighting of frequency bands.
+
+    Args:
+        scale (FrequencyScale): a scale defining the frequencies at which
+            the FIR filters will be centered
+        samplerate (SampleRate): the samplerate needed to construct the FIR
+            filter bank
+        frequency_window (ndarray): a window determining how narrow or wide
+            the filter responses should be
+        basis_size (int): the kernel size, or number of "taps", for each
+            filter
+        lap (int): the filter stride
+        log_factor (int): how much compression to apply in the log-amplitude
+            stage
+        frequency_weighting (FrequencyWeighting): an optional frequency
+            weighting applied after log-amplitude scaling
+        cosine_similarity (bool): if `True`, compute the cosine similarity
+            between spectrograms; otherwise, compute the mean squared error
+    """
 
 
-class PerceptualLoss(nn.MSELoss):
     def __init__(
             self,
             scale,
@@ -20,13 +45,10 @@ def __init__(
             lap=2,
             log_factor=100,
             frequency_weighting=None,
-            phase_locking_cutoff_hz=None,
-            phase_locking_taps=64,
             cosine_similarity=True):
 
         super(PerceptualLoss, self).__init__()
 
-        self.phase_locking_taps = phase_locking_taps
         self.cosine_similarity = cosine_similarity
         self.log_factor = log_factor
         self.scale = scale
@@ -37,25 +59,15 @@ def __init__(
         basis = fir_filter_bank(
             scale, basis_size, samplerate, frequency_window)
 
-        weights = Variable(torch.from_numpy(basis).float())
+        weights = torch.from_numpy(basis).float()
         # out channels x in channels x kernel width
         self.weights = weights.view(len(scale), 1, basis_size).contiguous()
 
         self.frequency_weights = None
         if frequency_weighting:
             fw = frequency_weighting._wdata(self.scale)
-            self.frequency_weights = Variable(torch.from_numpy(fw).float())
-            self.frequency_weights = self.frequency_weights.view(
-                1, len(self.scale), 1)
-
-        self.pool_amount = None
-
-        if phase_locking_cutoff_hz is not None:
-            sr = SampleRate(
-                frequency=samplerate.frequency / lap,
-                duration=samplerate.duration / lap)
-            one_cycle = Picoseconds(int(1e12)) / phase_locking_cutoff_hz
-            self.pool_amount = int(np.ceil(one_cycle / sr.frequency))
+            self.frequency_weights = torch.from_numpy(fw)\
+                .float().view(1, len(self.scale), 1)
 
     def cuda(self, device=None):
         self.weights = self.weights.cuda()
@@ -80,12 +92,6 @@ def _transform(self, x):
         if self.frequency_weights is not None:
             features = features * self.frequency_weights
 
-        # loss of phase locking
-        if self.pool_amount is not None:
-            features = features.view(x.shape[0], 1, len(self.scale), -1)
-            features = F.max_pool2d(
-                features, (1, self.pool_amount), stride=(1, 1))
-
         return features
 
     def forward(self, input, target):
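The stages described in the new docstring reduce to a handful of tensor operations. The sketch below is illustrative rather than lifted from zounds: the function name and defaults are assumptions, and the library's own _transform may differ in detail. It assumes a batch of mono signals shaped (batch, 1, samples) and a filter bank shaped (n_bands, 1, basis_size).

import torch
from torch.nn import functional as F


def perceptual_features(signal, filter_bank, stride=2, log_factor=100,
                        frequency_weights=None):
    # signal: (batch, 1, samples); filter_bank: (n_bands, 1, basis_size)
    # 1. decompose the signal into frequency bands with the FIR filter bank
    features = F.conv1d(signal, filter_bank, stride=stride)
    # 2. half-wave rectification: keep only positive deflections
    features = F.relu(features)
    # 3. log-scale amplitudes; log_factor controls the compression amount
    features = torch.log(1 + features * log_factor)
    # 4. optional per-band re-weighting, broadcast from shape (1, n_bands, 1)
    if frequency_weights is not None:
        features = features * frequency_weights
    return features

Comparing signals in this feature space, rather than sample by sample, makes the loss tolerant of small, inaudible time shifts that would dominate a raw sample-wise MSE.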

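Construction and use might look like the following. GeometricScale and the SR11025 samplerate helper exist in zounds, but the import path and the specific constructor values here are assumptions, not taken from this commit; since the class as committed derives from object, forward is invoked directly.

import torch
import zounds
from zounds.learn import PerceptualLoss

samplerate = zounds.SR11025()

# center frequencies for the FIR filter bank; the values are illustrative
scale = zounds.GeometricScale(
    start_center_hz=50,
    stop_center_hz=5000,
    bandwidth_ratio=0.115,
    n_bands=128)

loss = PerceptualLoss(scale, samplerate, cosine_similarity=True)

# two batches of mono audio: (batch, 1, samples)
target = torch.FloatTensor(4, 1, 8192).normal_(0, 1)
recon = torch.FloatTensor(4, 1, 8192).normal_(0, 1)

error = loss.forward(recon, target)

With cosine_similarity=True the loss compares the angle between the two spectrogram-like feature tensors, making it insensitive to overall gain; with False it computes the mean squared error in the feature space instead.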