Skip to content

Commit

Permalink
More sampling rate checks and tests.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
  • Loading branch information
mzient committed Dec 19, 2019
1 parent 526c6a9 commit 979f5fd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
3 changes: 3 additions & 0 deletions dali/operators/decoder/audio/audio_decoder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ void AudioDecoderCpu::DecodeBatch(workspace_t<Backend> &ws) {
tp.DoWorkWithID([&, i](int thread_id) {
try {
DecodeSample<OutputType>(decoded_output[i], thread_id, i);
sample_rate_output[i].data[0] = use_resampling_
? target_sample_rates_[i]
: sample_meta_[i].sample_rate;
} catch (const DALIException &e) {
DALI_FAIL(make_string("Error decoding file.\nError: ", e.what(), "\nFile: ",
files_names_[i], "\n"));
Expand Down
16 changes: 11 additions & 5 deletions dali/test/python/test_operator_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ def __init__(self):

def define_graph(self):
self.raw_file = self.file_source()
dec_plain,_ = self.plain_decoder(self.raw_file)
dec_res,_ = self.resampling_decoder(self.raw_file)
dec_mix,_ = self.downmixing_decoder(self.raw_file)
dec_res_mix,_ = self.resampling_downmixing_decoder(self.raw_file)
out = [dec_plain, dec_res, dec_mix, dec_res_mix]
dec_plain, rates_plain = self.plain_decoder(self.raw_file)
dec_res, rates_res = self.resampling_decoder(self.raw_file)
dec_mix, rates_mix = self.downmixing_decoder(self.raw_file)
dec_res_mix, rates_res_mix = self.resampling_downmixing_decoder(self.raw_file)
out = [dec_plain, dec_res, dec_mix, dec_res_mix,
rates_plain, rates_res, rates_mix, rates_res_mix]
return out

def iter_setup(self):
Expand Down Expand Up @@ -113,6 +114,11 @@ def test_decoded_vs_generated():
ref3 = generate_waveforms(ref_len[3], freqs[idx] * (rates[idx] / rate2))
ref3 = ref3.mean(axis = 1, keepdims = 1)

assert(out[4].at(i)[0] == rates[idx])
assert(out[5].at(i)[0] == rate1)
assert(out[6].at(i)[0] == rates[idx])
assert(out[7].at(i)[0] == rate2)

# just reading - allow only for rounding
assert np.allclose(plain, ref0, rtol = 0, atol=0.5)
# resampling - allow for 1e-3 dynamic range error
Expand Down

0 comments on commit 979f5fd

Please sign in to comment.