diff --git a/include/rppdefs.h b/include/rppdefs.h index d35f82bfd..f7cd7a711 100644 --- a/include/rppdefs.h +++ b/include/rppdefs.h @@ -369,10 +369,13 @@ typedef enum */ typedef enum { - NCHW, - NHWC, - NCDHW, - NDHWC + NCHW, // BatchSize-Channels-Height-Width + NHWC, // BatchSize-Height-Width-Channels + NCDHW, // BatchSize-Channels-Depth-Height-Width + NDHWC, // BatchSize-Depth-Height-Width-Channels + NHW, // BatchSize-Height-Width + NFT, // BatchSize-Frequency-Time -> Frequency Major used for Spectrogram / MelfilterBank + NTF // BatchSize-Time-Frequency -> Time Major used for Spectrogram / MelfilterBank } RpptLayout; /*! \brief RPPT Tensor 2D ROI type enum @@ -434,6 +437,15 @@ typedef enum TF, //Time Major } RpptSpectrogramLayout; +/*! \brief RPPT Mel Scale Formula + * \ingroup group_rppdefs + */ +typedef enum +{ + SLANEY = 0, // Follows Slaney’s MATLAB Auditory Modelling Work behavior + HTK, // Follows O’Shaughnessy’s book formula, consistent with Hidden Markov Toolkit(HTK), m = 2595 * log10(1 + (f/700)) +} RpptMelScaleFormula; + /*! \brief RPPT Tensor 2D ROI LTRB struct * \ingroup group_rppdefs */ diff --git a/include/rppt_tensor_audio_augmentations.h b/include/rppt_tensor_audio_augmentations.h index 4e5f412db..13259cd22 100644 --- a/include/rppt_tensor_audio_augmentations.h +++ b/include/rppt_tensor_audio_augmentations.h @@ -132,6 +132,26 @@ RppStatus rppt_down_mixing_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_ */ RppStatus rppt_spectrogram_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcLengthTensor, bool centerWindows, bool reflectPadding, Rpp32f *windowFunction, Rpp32s nfft, Rpp32s power, Rpp32s windowLength, Rpp32s windowStep, RpptSpectrogramLayout layout, rppHandle_t rppHandle); +/*! 
\brief Mel filter bank augmentation HOST backend + * \details Mel filter bank augmentation for audio data + * \param[in] srcPtr source tensor in HOST memory + * \param[in] srcDescPtr source tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32, layout - NFT) + * \param[out] dstPtr destination tensor in HOST memory + * \param[in] dstDescPtr destination tensor descriptor (Restrictions - numDims = 3, offsetInBytes >= 0, dataType = F32, layout - NFT) + * \param[in] srcDims source tensor dimensions, stored per sample as (number of frequency bins, number of frames) (1D tensor in HOST memory, of size batchSize * 2) + * \param[in] maxFreq maximum frequency; if not provided, maxFreq = sampleRate / 2 + * \param[in] minFreq minimum frequency + * \param[in] melFormula formula used to convert frequencies from hertz to mel and from mel to hertz (SLANEY / HTK) + * \param[in] numFilter number of mel filters + * \param[in] sampleRate sampling rate of the audio + * \param[in] normalize boolean that determines whether to normalize the filter weights + * \param[in] rppHandle RPP HOST handle created with \ref rppCreateWithBatchSize() + * \return A \ref RppStatus enumeration. + * \retval RPP_SUCCESS Successful completion. + * \retval RPP_ERROR* Unsuccessful completion. + */ +RppStatus rppt_mel_filter_bank_host(RppPtr_t srcPtr, RpptDescPtr srcDescPtr, RppPtr_t dstPtr, RpptDescPtr dstDescPtr, Rpp32s *srcDims, Rpp32f maxFreq, Rpp32f minFreq, RpptMelScaleFormula melFormula, Rpp32s numFilter, Rpp32f sampleRate, bool normalize, rppHandle_t rppHandle); + /*! \brief Resample augmentation on HOST backend * \details Resample augmentation for audio data * \param[in] srcPtr source tensor in HOST memory diff --git a/src/modules/cpu/host_tensor_audio_augmentations.hpp b/src/modules/cpu/host_tensor_audio_augmentations.hpp index e38d2d843..82d43d082 100644 --- a/src/modules/cpu/host_tensor_audio_augmentations.hpp +++ b/src/modules/cpu/host_tensor_audio_augmentations.hpp @@ -30,6 +30,7 @@ SOFTWARE. 
#include "kernel/pre_emphasis_filter.hpp" #include "kernel/down_mixing.hpp" #include "kernel/spectrogram.hpp" +#include "kernel/mel_filter_bank.hpp" #include "kernel/resample.hpp" #endif // HOST_TENSOR_AUDIO_AUGMENTATIONS_HPP \ No newline at end of file diff --git a/src/modules/cpu/kernel/mel_filter_bank.hpp b/src/modules/cpu/kernel/mel_filter_bank.hpp new file mode 100644 index 000000000..9cc6d26d2 --- /dev/null +++ b/src/modules/cpu/kernel/mel_filter_bank.hpp @@ -0,0 +1,252 @@ +/* +MIT License + +Copyright (c) 2019 - 2024 Advanced Micro Devices, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+*/ + +#include "rppdefs.h" +#include "rpp_cpu_simd.hpp" +#include "rpp_cpu_common.hpp" + +struct BaseMelScale +{ + public: + virtual Rpp32f hz_to_mel(Rpp32f hz) = 0; + virtual Rpp32f mel_to_hz(Rpp32f mel) = 0; + virtual ~BaseMelScale() = default; +}; + +struct HtkMelScale : public BaseMelScale +{ + Rpp32f hz_to_mel(Rpp32f hz) { return 1127.0f * std::log(1.0f + (hz / 700.0f)); } + Rpp32f mel_to_hz(Rpp32f mel) { return 700.0f * (std::exp(mel / 1127.0f) - 1.0f); } + public: + ~HtkMelScale() {}; +}; + +struct SlaneyMelScale : public BaseMelScale +{ + const Rpp32f freqLow = 0; + const Rpp32f fsp = 200.0 / 3.0; + const Rpp32f minLogHz = 1000.0; + const Rpp32f minLogMel = (minLogHz - freqLow) / fsp; + const Rpp32f stepLog = 0.068751777; // Equivalent to std::log(6.4) / 27.0; + + const Rpp32f invMinLogHz = 1.0f / 1000.0; + const Rpp32f invStepLog = 1.0f / stepLog; + const Rpp32f invFsp = 1.0f / fsp; + + Rpp32f hz_to_mel(Rpp32f hz) + { + Rpp32f mel = 0.0f; + if (hz >= minLogHz) + mel = minLogMel + std::log(hz * invMinLogHz) * invStepLog; + else + mel = (hz - freqLow) * invFsp; + + return mel; + } + + Rpp32f mel_to_hz(Rpp32f mel) + { + Rpp32f hz = 0.0f; + if (mel >= minLogMel) + hz = minLogHz * std::exp(stepLog * (mel - minLogMel)); + else + hz = freqLow + mel * fsp; + return hz; + } + public: + ~SlaneyMelScale() {}; +}; + +RppStatus mel_filter_bank_host_tensor(Rpp32f *srcPtr, + RpptDescPtr srcDescPtr, + Rpp32f *dstPtr, + RpptDescPtr dstDescPtr, + Rpp32s *srcDimsTensor, + Rpp32f maxFreqVal, // check unused + Rpp32f minFreqVal, + RpptMelScaleFormula melFormula, + Rpp32s numFilter, + Rpp32f sampleRate, + bool normalize, + rpp::Handle& handle) +{ + BaseMelScale *melScalePtr; + switch(melFormula) + { + case RpptMelScaleFormula::HTK: + melScalePtr = new HtkMelScale; + break; + case RpptMelScaleFormula::SLANEY: + default: + melScalePtr = new SlaneyMelScale(); + break; + } + Rpp32u numThreads = handle.GetNumThreads(); + Rpp32u batchSize = srcDescPtr->n; + Rpp32f *scratchMem = 
handle.GetInitHandle()->mem.mcpu.scratchBufferHost; + + Rpp32f maxFreq = sampleRate / 2; + Rpp32f minFreq = minFreqVal; + + // Convert lower, higher frequencies to mel scale and find melStep + Rpp64f melLow = melScalePtr->hz_to_mel(minFreq); + Rpp64f melHigh = melScalePtr->hz_to_mel(maxFreq); + Rpp64f melStep = (melHigh - melLow) / (numFilter + 1); + + omp_set_dynamic(0); +#pragma omp parallel for num_threads(numThreads) + for(int batchCount = 0; batchCount < batchSize; batchCount++) + { + Rpp32f *srcPtrTemp = srcPtr + batchCount * srcDescPtr->strides.nStride; + Rpp32f *dstPtrTemp = dstPtr + batchCount * dstDescPtr->strides.nStride; + + // Extract nfft, number of Frames, numBins + Rpp32s nfft = (srcDimsTensor[batchCount * 2] - 1) * 2; + Rpp32s numBins = nfft / 2 + 1; + Rpp32s numFrames = srcDimsTensor[batchCount * 2 + 1]; + + // Find hzStep + Rpp64f hzStep = static_cast(sampleRate) / nfft; + Rpp64f invHzStep = 1.0 / hzStep; + + // Find fftBinStart and fftBinEnd + Rpp32s fftBinStart = std::ceil(minFreq * invHzStep); + Rpp32s fftBinEnd = std::ceil(maxFreq * invHzStep); + fftBinEnd = std::min(fftBinEnd, numBins); + + // Set/Fill normFactors, weightsDown and intervals + Rpp32f *normFactors = scratchMem + (batchCount * numFilter); + std::fill(normFactors, normFactors + numFilter, 1.f); // normFactors contain numFilter values of type float + Rpp32f *weightsDown = scratchMem + (batchSize * numFilter) + (batchCount * numBins); + memset(weightsDown, 0, sizeof(numBins * sizeof(Rpp32f))); // weightsDown contain numBins values of type float + Rpp32s *intervals = reinterpret_cast(weightsDown + (batchSize * numBins)); + std::fill(intervals, intervals + numBins, -1); // intervals contain numBins values of type integer + + Rpp32s fftBin = fftBinStart; + Rpp64f mel0 = melLow, mel1 = melLow + melStep; + Rpp64f fIter = fftBin * hzStep; + for (int interval = 0; interval < numFilter + 1; interval++, mel0 = mel1, mel1 += melStep) + { + Rpp64f f0 = melScalePtr->mel_to_hz(mel0); + Rpp64f 
f1 = melScalePtr->mel_to_hz(interval == numFilter ? melHigh : mel1); + Rpp64f slope = 1. / (f1 - f0); + + if (normalize && interval < numFilter) + { + Rpp64f f2 = melScalePtr->mel_to_hz(mel1 + melStep); + normFactors[interval] = 2.0 / (f2 - f0); + } + + for (; fftBin < fftBinEnd && fIter < f1; fftBin++, fIter = fftBin * hzStep) + { + weightsDown[fftBin] = (f1 - fIter) * slope; + intervals[fftBin] = interval; + } + } + + Rpp32u maxFrames = std::min(static_cast(numFrames + 8), dstDescPtr->strides.hStride); + Rpp32u maxAlignedLength = maxFrames & ~7; + Rpp32u vectorIncrement = 8; + + // Set ROI values in dst buffer to 0.0 + for(int i = 0; i < numFilter; i++) + { + Rpp32f *dstPtrRow = dstPtrTemp + i * dstDescPtr->strides.hStride; + Rpp32u vectorLoopCount = 0; + for(; vectorLoopCount < maxAlignedLength; vectorLoopCount += 8) + { + _mm256_storeu_ps(dstPtrRow, avx_p0); + dstPtrRow += 8; + } + for(; vectorLoopCount < maxFrames; vectorLoopCount++) + *dstPtrRow++ = 0.0f; + } + + Rpp32u alignedLength = numFrames & ~7; + __m256 pSrc, pDst; + Rpp32f *srcRowPtr = srcPtrTemp + fftBinStart * srcDescPtr->strides.hStride; + for (int64_t fftBin = fftBinStart; fftBin < fftBinEnd; fftBin++) + { + auto filterUp = intervals[fftBin]; + auto weightUp = 1.0f - weightsDown[fftBin]; + auto filterDown = filterUp - 1; + auto weightDown = weightsDown[fftBin]; + + if (filterDown >= 0) + { + Rpp32f *dstRowPtrTemp = dstPtrTemp + filterDown * dstDescPtr->strides.hStride; + Rpp32f *srcRowPtrTemp = srcRowPtr; + + if (normalize) + weightDown *= normFactors[filterDown]; + __m256 pWeightDown = _mm256_set1_ps(weightDown); + + int vectorLoopCount = 0; + for(; vectorLoopCount < alignedLength; vectorLoopCount += vectorIncrement) + { + pSrc = _mm256_loadu_ps(srcRowPtrTemp); + pSrc = _mm256_mul_ps(pSrc, pWeightDown); + pDst = _mm256_loadu_ps(dstRowPtrTemp); + pDst = _mm256_add_ps(pDst, pSrc); + _mm256_storeu_ps(dstRowPtrTemp, pDst); + dstRowPtrTemp += vectorIncrement; + srcRowPtrTemp += vectorIncrement; + } + 
+ for (; vectorLoopCount < numFrames; vectorLoopCount++) + (*dstRowPtrTemp++) += weightDown * (*srcRowPtrTemp++); + } + + if (filterUp >= 0 && filterUp < numFilter) + { + Rpp32f *dstRowPtrTemp = dstPtrTemp + filterUp * dstDescPtr->strides.hStride; + Rpp32f *srcRowPtrTemp = srcRowPtr; + + if (normalize) + weightUp *= normFactors[filterUp]; + __m256 pWeightUp = _mm256_set1_ps(weightUp); + + int vectorLoopCount = 0; + for(; vectorLoopCount < alignedLength; vectorLoopCount += vectorIncrement) + { + pSrc = _mm256_loadu_ps(srcRowPtrTemp); + pSrc = _mm256_mul_ps(pSrc, pWeightUp); + pDst = _mm256_loadu_ps(dstRowPtrTemp); + pDst = _mm256_add_ps(pDst, pSrc); + _mm256_storeu_ps(dstRowPtrTemp, pDst); + dstRowPtrTemp += vectorIncrement; + srcRowPtrTemp += vectorIncrement; + } + + for (; vectorLoopCount < numFrames; vectorLoopCount++) + (*dstRowPtrTemp++) += weightUp * (*srcRowPtrTemp++); + } + + srcRowPtr += srcDescPtr->strides.hStride; + } + } + delete melScalePtr; + + return RPP_SUCCESS; +} diff --git a/src/modules/rppt_tensor_audio_augmentations.cpp b/src/modules/rppt_tensor_audio_augmentations.cpp index e20211ec1..bafaf93fb 100644 --- a/src/modules/rppt_tensor_audio_augmentations.cpp +++ b/src/modules/rppt_tensor_audio_augmentations.cpp @@ -197,6 +197,46 @@ RppStatus rppt_spectrogram_host(RppPtr_t srcPtr, } } +/******************** mel_filter_bank ********************/ + +RppStatus rppt_mel_filter_bank_host(RppPtr_t srcPtr, + RpptDescPtr srcDescPtr, + RppPtr_t dstPtr, + RpptDescPtr dstDescPtr, + Rpp32s* srcDimsTensor, + Rpp32f maxFreq, + Rpp32f minFreq, + RpptMelScaleFormula melFormula, + Rpp32s numFilter, + Rpp32f sampleRate, + bool normalize, + rppHandle_t rppHandle) +{ + if (srcDescPtr->layout != RpptLayout::NFT) return RPP_ERROR_INVALID_SRC_LAYOUT; + if (dstDescPtr->layout != RpptLayout::NFT) return RPP_ERROR_INVALID_DST_LAYOUT; + + if ((srcDescPtr->dataType == RpptDataType::F32) && (dstDescPtr->dataType == RpptDataType::F32)) + { + 
mel_filter_bank_host_tensor(static_cast(srcPtr), + srcDescPtr, + static_cast(dstPtr), + dstDescPtr, + srcDimsTensor, + maxFreq, + minFreq, + melFormula, + numFilter, + sampleRate, + normalize, + rpp::deref(rppHandle)); + return RPP_SUCCESS; + } + else + { + return RPP_ERROR_NOT_IMPLEMENTED; + } +} + /******************** resample ********************/ RppStatus rppt_resample_host(RppPtr_t srcPtr, diff --git a/utilities/test_suite/HOST/Tensor_host_audio.cpp b/utilities/test_suite/HOST/Tensor_host_audio.cpp index ac05bcc90..6f2b888c4 100644 --- a/utilities/test_suite/HOST/Tensor_host_audio.cpp +++ b/utilities/test_suite/HOST/Tensor_host_audio.cpp @@ -31,7 +31,7 @@ int main(int argc, char **argv) if (argc < MIN_ARG_COUNT) { printf("\nImproper Usage! Needs all arguments!\n"); - printf("\nUsage: ./Tensor_host_audio \n"); + printf("\nUsage: ./Tensor_host_audio \n"); return -1; } @@ -144,7 +144,7 @@ int main(int argc, char **argv) int noOfIterations = (int)audioNames.size() / batchSize; double maxWallTime = 0, minWallTime = 500, avgWallTime = 0; string testCaseName; - printf("\nRunning %s %d times (each time with a batch size of %d images) and computing mean statistics...", func.c_str(), numRuns, batchSize); + printf("\nRunning %s %d times (each time with a batch size of %d audio files) and computing mean statistics...", func.c_str(), numRuns, batchSize); for (int iterCount = 0; iterCount < noOfIterations; iterCount++) { // read and decode audio and fill the audio dim values @@ -348,6 +348,56 @@ int main(int argc, char **argv) break; } + case 7: + { + testCaseName = "mel_filter_bank"; + + Rpp32f sampleRate = 16000; + Rpp32f minFreq = 0.0; + Rpp32f maxFreq = sampleRate / 2; + RpptMelScaleFormula melFormula = RpptMelScaleFormula::SLANEY; + Rpp32s numFilter = 80; + bool normalize = true; + Rpp32s srcDimsTensor[] = {257, 225, 257, 211, 257, 214}; // (height, width) for each tensor in a batch for given QA inputs. 
+ // Accepts outputs from FT layout of Spectrogram for QA + srcDescPtr->layout = dstDescPtr->layout = RpptLayout::NFT; + + maxDstHeight = 0; + maxDstWidth = 0; + maxSrcHeight = 0; + maxSrcWidth = 0; + for(int i = 0, j = 0; i < batchSize; i++, j += 2) + { + maxSrcHeight = std::max(maxSrcHeight, (int)srcDimsTensor[j]); + maxSrcWidth = std::max(maxSrcWidth, (int)srcDimsTensor[j + 1]); + dstDims[i].height = numFilter; + dstDims[i].width = srcDimsTensor[j + 1]; + maxDstHeight = std::max(maxDstHeight, (int)dstDims[i].height); + maxDstWidth = std::max(maxDstWidth, (int)dstDims[i].width); + } + + srcDescPtr->h = maxSrcHeight; + srcDescPtr->w = maxSrcWidth; + dstDescPtr->h = maxDstHeight; + dstDescPtr->w = maxDstWidth; + + set_audio_descriptor_dims_and_strides_nostriding(srcDescPtr, batchSize, maxSrcHeight, maxSrcWidth, maxSrcChannels, offsetInBytes); + set_audio_descriptor_dims_and_strides_nostriding(dstDescPtr, batchSize, maxDstHeight, maxDstWidth, maxDstChannels, offsetInBytes); + + // Set buffer sizes for src/dst + unsigned long long spectrogramBufferSize = (unsigned long long)srcDescPtr->h * (unsigned long long)srcDescPtr->w * (unsigned long long)srcDescPtr->c * (unsigned long long)srcDescPtr->n; + unsigned long long melFilterBufferSize = (unsigned long long)dstDescPtr->h * (unsigned long long)dstDescPtr->w * (unsigned long long)dstDescPtr->c * (unsigned long long)dstDescPtr->n; + inputf32 = (Rpp32f *)realloc(inputf32, spectrogramBufferSize * sizeof(Rpp32f)); + outputf32 = (Rpp32f *)realloc(outputf32, melFilterBufferSize * sizeof(Rpp32f)); + + // Read source data + read_from_bin_file(inputf32, srcDescPtr, srcDimsTensor, "spectrogram", scriptPath); + + startWallTime = omp_get_wtime(); + rppt_mel_filter_bank_host(inputf32, srcDescPtr, outputf32, dstDescPtr, srcDimsTensor, maxFreq, minFreq, melFormula, numFilter, sampleRate, normalize, handle); + + break; + } default: { missingFuncFlag = 1; diff --git a/utilities/test_suite/HOST/runAudioTests.py 
b/utilities/test_suite/HOST/runAudioTests.py index dec6ffa9c..94bd15251 100644 --- a/utilities/test_suite/HOST/runAudioTests.py +++ b/utilities/test_suite/HOST/runAudioTests.py @@ -36,7 +36,7 @@ outFolderPath = os.getcwd() buildFolderPath = os.getcwd() caseMin = 0 -caseMax = 6 +caseMax = 7 # Get a list of log files based on a flag for preserving output def get_log_file_list(): @@ -49,7 +49,7 @@ def run_unit_test(srcPath, case, numRuns, testType, batchSize, outFilePath): print("--------------------------------") print("Running a New Functionality...") print("--------------------------------") - print(f"./Tensor_host_audio {srcPath} {case} {numRuns} {testType} {numRuns} {batchSize}") + print(f"./Tensor_host_audio {srcPath} {case} {testType} {numRuns} {batchSize}") result = subprocess.run([buildFolderPath + "/build/Tensor_host_audio", srcPath, str(case), str(testType), str(numRuns), str(batchSize), outFilePath, scriptPath], stdout=subprocess.PIPE) # nosec print(result.stdout.decode()) @@ -61,7 +61,7 @@ def run_performance_test(loggingFolder, srcPath, case, numRuns, testType, batchS print("Running a New Functionality...") print("--------------------------------") with open("{}/Tensor_host_audio_raw_performance_log.txt".format(loggingFolder), "a") as log_file: - print(f"./Tensor_host_audio {srcPath} {case} {numRuns} {testType} {numRuns} {batchSize} ") + print(f"./Tensor_host_audio {srcPath} {case} {testType} {numRuns} {batchSize} ") process = subprocess.Popen([buildFolderPath + "/build/Tensor_host_audio", srcPath, str(case), str(testType), str(numRuns), str(batchSize), outFilePath, scriptPath], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) # nosec while True: output = process.stdout.readline() @@ -178,7 +178,7 @@ def rpp_test_suite_parser_and_validator(): subprocess.run(["make", "-j16"], cwd=".") # nosec # List of cases supported -supportedCaseList = ['0', '1', '2', '3', '4', '6'] +supportedCaseList = ['0', '1', '2', '3', '4', '6', '7'] if testType == 0: 
if batchSize != 3: diff --git a/utilities/test_suite/HOST/runTests.py b/utilities/test_suite/HOST/runTests.py index 8bedd5044..7fc946d46 100644 --- a/utilities/test_suite/HOST/runTests.py +++ b/utilities/test_suite/HOST/runTests.py @@ -21,7 +21,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import os import sys sys.dont_write_bytecode = True diff --git a/utilities/test_suite/REFERENCE_OUTPUTS_AUDIO/mel_filter_bank/mel_filter_bank.bin b/utilities/test_suite/REFERENCE_OUTPUTS_AUDIO/mel_filter_bank/mel_filter_bank.bin new file mode 100644 index 000000000..66cefe73f Binary files /dev/null and b/utilities/test_suite/REFERENCE_OUTPUTS_AUDIO/mel_filter_bank/mel_filter_bank.bin differ diff --git a/utilities/test_suite/rpp_test_suite_audio.h b/utilities/test_suite/rpp_test_suite_audio.h index 8ff5815c0..c70c6659f 100644 --- a/utilities/test_suite/rpp_test_suite_audio.h +++ b/utilities/test_suite/rpp_test_suite_audio.h @@ -42,7 +42,8 @@ std::map audioAugmentationMap = {3, "down_mixing"}, {4, "spectrogram"}, {5, "slice"}, - {6, "resample"} + {6, "resample"}, + {7, "mel_filter_bank"} }; // Golden outputs for Non Silent Region Detection @@ -149,6 +150,50 @@ void read_audio_batch_and_fill_dims(RpptDescPtr descPtr, Rpp32f *inputf32, vecto } } +void read_from_bin_file(Rpp32f *srcPtr, RpptDescPtr srcDescPtr, Rpp32s *srcDims, string testCase, string scriptPath) +{ + // read data from golden outputs + Rpp64u oBufferSize = srcDescPtr->n * srcDescPtr->strides.nStride; + Rpp32f *refInput = static_cast(malloc(oBufferSize * sizeof(float))); + string outFile = scriptPath + "/../REFERENCE_OUTPUTS_AUDIO/" + testCase + "/" + testCase + ".bin"; + std::fstream fin(outFile, std::ios::in | std::ios::binary); + if(fin.is_open()) + { + for(Rpp64u i = 0; i < oBufferSize; i++) + { + if(!fin.eof()) + fin.read(reinterpret_cast(&refInput[i]), sizeof(float)); + else + { + std::cout<<"\nUnable to read all data from golden outputs\n"; + return; + } + } + 
} + else + { + std::cout<<"\nCould not open the reference output. Please check the path specified\n"; + return; + } + for (int batchCount = 0; batchCount < srcDescPtr->n; batchCount++) + { + Rpp32f *srcPtrCurrent = srcPtr + batchCount * srcDescPtr->strides.nStride; + Rpp32f *refPtrCurrent = refInput + batchCount * srcDescPtr->strides.nStride; + Rpp32f *srcPtrRow = srcPtrCurrent; + Rpp32f *refPtrRow = refPtrCurrent; + for(int i = 0; i < srcDims[batchCount * 2]; i++) + { + Rpp32f *srcPtrTemp = srcPtrRow; + Rpp32f *refPtrTemp = refPtrRow; + for(int j = 0; j < srcDims[(batchCount * 2) + 1]; j++) + srcPtrTemp[j] = refPtrTemp[j]; + srcPtrRow += srcDescPtr->strides.hStride; + refPtrRow += srcDescPtr->strides.hStride; + } + } + free(refInput); +} + void verify_output(Rpp32f *dstPtr, RpptDescPtr dstDescPtr, RpptImagePatchPtr dstDims, string testCase, string dst, string scriptPath) { fstream refFile;