Skip to content

Commit

Permalink
Update pytorch frequency decomposition to only perform a single trans…
Browse files Browse the repository at this point in the history
…form, slicing coefficients for the inverse
  • Loading branch information
JohnVinyard committed Jun 24, 2018
1 parent a90fab2 commit 639ed77
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions zounds/learn/dct_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def idct(self, x, axis=-1):
basis = self._variable(self.dct_basis(x.shape[axis]))
return self._base_dct_transform(x, basis.t(), axis)

def dct_resample(self, x, factor, axis=-1):

def dct_resample(self, x, factor, axis=-1, zero_slice=None):
# figure out how many samples our resampled signal will have
n_samples = int(factor * x.shape[axis])

Expand All @@ -83,19 +82,38 @@ def dct_resample(self, x, factor, axis=-1):

new_coeffs[new_coeffs_slices] = coeffs[old_coeffs_slices]

if zero_slice is not None:
slce = [slice(None)] * x.dim()
slce[axis] = zero_slice
slce = tuple(slce)
new_coeffs[slce] = 0

return self.idct(new_coeffs)

def frequency_decomposition(self, x, factors, axis=-1):
bands = []

full_size = x.shape[axis]
factors = sorted(factors)
for f in factors:
if f == 1:
bands.append(x)
else:
rs = self.dct_resample(x, f, axis)
bands.append(rs)
us = self.dct_resample(rs, (1. / f), axis)
x = x - us
factors = [0] + list(factors)
coeffs = self.dct(x, axis=axis)

bands = []
for i in xrange(0, len(factors) - 1):
start_index = int(factors[i] * full_size)
stop_index = int(factors[i + 1] * full_size)

new_shape = list(coeffs.shape)
new_shape[axis] = stop_index
new_coeffs = self._variable(torch.zeros(new_shape))

slce = [slice(None)] * x.dim()
slce[axis] = slice(start_index, stop_index)

new_coeffs[slce] = coeffs[slce]

resampled = self.idct(new_coeffs)
bands.append(resampled)

return bands

def short_time_dct(self, x, size, step, window):
Expand Down

0 comments on commit 639ed77

Please sign in to comment.