RPP Tensor Audio Support - Pre Emphasis Filter #259
Merged: kiritigowda merged 112 commits into ROCm:develop from r-abishek:ar/audio_support_3_pre_emphasis_filter on Dec 21, 2023
Changes from 92 commits
6627464
Initial commit - Non slient region detection
snehaa8 dcd9833
Initial commit - To Decibels
snehaa8 eb9e6eb
Intial commit - pre_emphasis_filter
HazarathKumarM 3720c8f
Replace vectors with arrays
snehaa8 ed1e425
Cleanup
snehaa8 055fd59
Minor cleanup
snehaa8 f7f51c8
Replace Rpp64s with Rpp32s
snehaa8 676de7c
Optimize and precompute cutOff
snehaa8 46b8fb5
Fix buffer used
snehaa8 11fb3ab
Fix buffer used
snehaa8 8e05043
Additional Cleanup
snehaa8 195ccfc
Optimize post increment operation
snehaa8 79d9e6f
Update testsuite for Audio
snehaa8 5719db9
code cleanup
HazarathKumarM 4aad4d6
Add Readme file for Audio test suite
HazarathKumarM abf15d8
changes based on review comments
HazarathKumarM 71ef2f5
minor change
HazarathKumarM adad92b
Remove unittest folders and updated README.md
HazarathKumarM 8e01e0e
Remove unit tests
HazarathKumarM 69fba3e
minor change
HazarathKumarM a7b1b22
code cleanup
sampath1117 441bfac
added common header file for audio helper functions
sampath1117 c8803c2
Merge remote-tracking branch 'abishek_rpp/master' into sn/nsr_host_te…
sampath1117 e7169fe
removed unncessary audio wav files
sampath1117 11b4709
removed log file
sampath1117 d64bc7b
added doxygen support for audio
sampath1117 23ba5a5
Merge branch 'sn/nsr_host_tensor' into sn/to_decibels
sampath1117 337ccc1
added doxygen changes for to_decibels
sampath1117 d1b5b41
updated test suite support for to_decibels
sampath1117 223de61
minor change
sampath1117 bf0a4e0
added doxygen changes for preemphasis filter
sampath1117 f7d7589
Merge branch 'sn/to_decibels' into sn/pre_emphasis_filter
sampath1117 3626fc0
updated changes for preemphasis filter in test suite
sampath1117 512a79f
removed the usage of getMax function and used std::max_element
sampath1117 6a18957
modularized code in test suite
sampath1117 c413836
merge with latest changes
sampath1117 9f6b6d4
minor change
sampath1117 2247696
minor change
sampath1117 1ba334f
Merge branch 'sn/to_decibels' into sn/pre_emphasis_filter
sampath1117 527ed18
minor change
sampath1117 87b0138
Merge pull request #149 from snehaa8/sn/nsr_host_tensor
r-abishek 8e2975d
resolved codacy warnings
sampath1117 b9f8c12
Merge pull request #174 from snehaa8/sn/nsr_host_tensor
r-abishek b22bf93
Codacy fix - Remove unused cpuTime
r-abishek 4a5c357
CMakeLists - Version Update
kiritigowda 2c8a78b
CHANGELOG Updates
kiritigowda f07786a
merge with latest changes
sampath1117 54cfa29
resolved issue with file_system dependency in test suite
sampath1117 c8bd726
Doxygen changes
sampath1117 92f8ed7
Merge pull request #182 from sampath1117/sr/nsr_pr_changes
r-abishek c4d8f3d
Merge branch 'develop' into ar/audio_support_1_non_silent_region
r-abishek 1db1425
RPP RICAP Tensor for HOST and HIP (#213)
r-abishek 1d3e7ce
Documentation - Readme & changelog updates (#251)
LisaDelaney 13995ad
Merge branch 'ar/audio_support_1_non_silent_region' into sn/nsr_host_…
sampath1117 e85f581
added ctests for audio test suite for CI
sampath1117 296ed72
Merge pull request #187 from snehaa8/sn/nsr_host_tensor
r-abishek 0aec6e1
Cmake mods for ctest
r-abishek 3140717
HOST-only build error bugfix
r-abishek fe4ef51
Merge branch 'ar/audio_support_1_non_silent_region' into sn/nsr_host_…
sampath1117 0713890
added qa mode paramter to python audio script
sampath1117 e98a4e8
minor change
sampath1117 a1f7366
Documentation - Bump rocm-docs-core[api_reference] from 0.26.0 to 0.2…
dependabot[bot] 749a552
RPP Resize Mirror Normalize Bugfix (#252)
r-abishek 3f5aec6
Merge pull request #189 from snehaa8/sn/nsr_host_tensor
r-abishek 1b466bc
added example for MMS calculation in comments for better understanding
sampath1117 38119f3
Sphinx - updates (#257)
kiritigowda b98bb99
updated info used to for running audio test suite
sampath1117 3795f37
removed bitdepth variable from audio test suite
sampath1117 58e1ff5
added more information on computing NSR outputs in the example added
sampath1117 64c52cd
Merge pull request #191 from snehaa8/sn/nsr_host_tensor
r-abishek 6b2add5
Merge branch 'ar/audio_support_1_non_silent_region' of https://github…
r-abishek 7753fda
Merge branch 'ar/audio_support_2_to_decibels' into sn/to_decibels
r-abishek 072cc1e
Merge pull request #150 from snehaa8/sn/to_decibels
r-abishek e04371d
Merge branch 'ar/audio_support_1_non_silent_region' of https://github…
r-abishek 1f25169
Merge branch 'ar/audio_support_2_to_decibels' of https://github.com/r…
r-abishek 8e16be8
Merge branch 'ar/audio_support_3_pre_emphasis_filter' into sn/pre_emp…
r-abishek e261ce3
Merge pull request #151 from snehaa8/sn/pre_emphasis_filter
r-abishek d205055
Fix doxygen for decibels
snehaa8 f9c66a6
Merge pull request #195 from snehaa8/sn/to_decibels
r-abishek da528d3
Fix build errors and qa tests in Audio Test suite
snehaa8 ce13b82
Fix build errors and qa tests in Audio Test suite
snehaa8 a90e280
Merge pull request #197 from snehaa8/sn/to_decibels
r-abishek 5713d1d
Merge branch 'ar/audio_support_2_to_decibels' of https://github.com/r…
r-abishek 8ae8673
Merge branch 'ar/audio_support_3_pre_emphasis_filter' into sn/pre_emp…
r-abishek 9c8ac7f
Merge pull request #198 from snehaa8/sn/pre_emphasis_filter
r-abishek 8227726
Merge branch 'master' of https://github.com/GPUOpen-ProfessionalCompu…
r-abishek 783ee98
Remove auto-merge repeated funcs
r-abishek 900672a
Merge branch 'ar/audio_support_2_to_decibels' of https://github.com/r…
r-abishek 7e91a91
Improve clarity of header docs
r-abishek 3b03aad
Remove blank line
r-abishek 2be82b1
Improve clarity on header docs
r-abishek 934276a
Merge branch 'ar/audio_support_2_to_decibels' of https://github.com/r…
r-abishek b8bcb04
modified the branch statements used in kernel with ternary operator
sampath1117 1875d85
made changes based on review comments
sampath1117 132eaa6
stored golden outputs of to_decibels in binary file
sampath1117 9322cd0
removed unused parameter in verify_output function
sampath1117 22a78f2
Merge branch 'sr/to_decibels_pr_changes' into sr/pre_emphasis_pr_changes
sampath1117 e5886e3
stored pre emphasis filter golden output to binary file
sampath1117 777a9f9
updated list of cases supported in python script
sampath1117 b650f02
updated list of cases supported in python script
sampath1117 f3d0a7d
added error handling for opening golden output file
sampath1117 7790cf6
added error handling for opening golden output file
sampath1117 1135cbc
Merge pull request #202 from sampath1117/sr/to_decibels_pr_changes
r-abishek 40f2cc9
Codacy fix and tests warning fix
r-abishek 3b6d418
Merge pull request #201 from sampath1117/sr/pre_emphasis_pr_changes
r-abishek 5aa53cb
Codacy fix
r-abishek 8dc04bb
Codacy fix trial
r-abishek 331f160
merge with latest changes
sampath1117 4ec92c6
codacy fix for checking boundaries of fstream
sampath1117 b7cf24f
Merge pull request #205 from sampath1117/sr/to_decibels_pr_changes
r-abishek 8e1b374
Merge branch 'ar/audio_support_2_to_decibels' of https://github.com/r…
r-abishek 8f2775d
Merge branch 'develop' into ar/audio_support_3_pre_emphasis_filter
r-abishek
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* | ||
Copyright (c) 2019 - 2023 Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. | ||
*/ | ||
|
||
#include "rppdefs.h" | ||
#include "rpp_cpu_simd.hpp" | ||
#include "rpp_cpu_common.hpp" | ||
|
||
RppStatus pre_emphasis_filter_host_tensor(Rpp32f *srcPtr, | ||
RpptDescPtr srcDescPtr, | ||
Rpp32f *dstPtr, | ||
RpptDescPtr dstDescPtr, | ||
Rpp32s *srcLengthTensor, | ||
Rpp32f *coeffTensor, | ||
Rpp32u borderType, | ||
rpp::Handle& handle) | ||
{ | ||
Rpp32u numThreads = handle.GetNumThreads(); | ||
|
||
omp_set_dynamic(0); | ||
#pragma omp parallel for num_threads(numThreads) | ||
for(int batchCount = 0; batchCount < srcDescPtr->n; batchCount++) | ||
{ | ||
Rpp32f *srcPtrTemp = srcPtr + batchCount * srcDescPtr->strides.nStride; | ||
Rpp32f *dstPtrTemp = dstPtr + batchCount * dstDescPtr->strides.nStride; | ||
Rpp32s bufferLength = srcLengthTensor[batchCount]; | ||
Rpp32f coeff = coeffTensor[batchCount]; | ||
|
||
if(borderType == RpptAudioBorderType::ZERO) | ||
dstPtrTemp[0] = srcPtrTemp[0]; | ||
else if(borderType == RpptAudioBorderType::CLAMP) | ||
{ | ||
Rpp32f border = srcPtrTemp[0]; | ||
dstPtrTemp[0] = srcPtrTemp[0] - coeff * border; | ||
} | ||
else if(borderType == RpptAudioBorderType::REFLECT) | ||
{ | ||
Rpp32f border = srcPtrTemp[1]; | ||
dstPtrTemp[0] = srcPtrTemp[0] - coeff * border; | ||
} | ||
|
||
Rpp32s vectorIncrement = 8; | ||
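// Stop the 8-wide loop one vector early: it starts at index 1, so each load/store covers indices i..i+7 and must stay below bufferLength | ||
Rpp32s alignedLength = (bufferLength / 8) * 8 - 8; | ||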
__m256 pCoeff = _mm256_set1_ps(coeff); | ||
|
||
Rpp32s vectorLoopCount = 1; | ||
dstPtrTemp++; | ||
srcPtrTemp++; | ||
for(; vectorLoopCount < alignedLength; vectorLoopCount += vectorIncrement) | ||
{ | ||
__m256 pSrc[2]; | ||
pSrc[0] = _mm256_loadu_ps(srcPtrTemp); | ||
pSrc[1] = _mm256_loadu_ps(srcPtrTemp - 1); | ||
pSrc[1] = _mm256_sub_ps(pSrc[0], _mm256_mul_ps(pSrc[1], pCoeff)); | ||
_mm256_storeu_ps(dstPtrTemp, pSrc[1]); | ||
srcPtrTemp += vectorIncrement; | ||
dstPtrTemp += vectorIncrement; | ||
} | ||
|
||
for(; vectorLoopCount < bufferLength; vectorLoopCount++) | ||
{ | ||
*dstPtrTemp++ = *srcPtrTemp - coeff * (*(srcPtrTemp - 1)); | ||
srcPtrTemp++; | ||
} | ||
|
||
} | ||
|
||
return RPP_SUCCESS; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
/* | ||
Copyright (c) 2019 - 2023 Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. | ||
*/ | ||
|
||
#include "rppdefs.h" | ||
#include <omp.h> | ||
#include <algorithm> | ||
#include <cmath> | ||
#include <limits> | ||
|
||
RppStatus to_decibels_host_tensor(Rpp32f *srcPtr, | ||
RpptDescPtr srcDescPtr, | ||
Rpp32f *dstPtr, | ||
RpptDescPtr dstDescPtr, | ||
RpptImagePatchPtr srcDims, | ||
Rpp32f cutOffDB, | ||
Rpp32f multiplier, | ||
Rpp32f referenceMagnitude, | ||
rpp::Handle& handle) | ||
{ | ||
Rpp32u numThreads = handle.GetNumThreads(); | ||
|
||
// Calculate the intermediate values needed for DB conversion | ||
Rpp32f minRatio = std::pow(10, cutOffDB / multiplier); | ||
if(minRatio == 0.0f) | ||
minRatio = std::nextafter(0.0f, 1.0f); | ||
|
||
const Rpp32f log10Factor = 0.3010299956639812; // log10(2), i.e. 1 / std::log2(10); converts the std::log2 results below to log10 | ||
multiplier *= log10Factor; | ||
|
||
omp_set_dynamic(0); | ||
#pragma omp parallel for num_threads(numThreads) | ||
for(int batchCount = 0; batchCount < srcDescPtr->n; batchCount++) | ||
{ | ||
Rpp32f *srcPtrCurrent = srcPtr + batchCount * srcDescPtr->strides.nStride; | ||
Rpp32f *dstPtrCurrent = dstPtr + batchCount * dstDescPtr->strides.nStride; | ||
|
||
Rpp32u height = srcDims[batchCount].height; | ||
Rpp32u width = srcDims[batchCount].width; | ||
Rpp32f refMag = referenceMagnitude; | ||
|
||
// Compute maximum value in the input buffer | ||
if(!referenceMagnitude) | ||
{ | ||
refMag = -std::numeric_limits<Rpp32f>::max(); | ||
Rpp32f *srcPtrTemp = srcPtrCurrent; | ||
if(width == 1) | ||
refMag = std::max(refMag, *(std::max_element(srcPtrTemp, srcPtrTemp + height))); | ||
else | ||
{ | ||
for(int i = 0; i < height; i++) | ||
{ | ||
refMag = std::max(refMag, *(std::max_element(srcPtrTemp, srcPtrTemp + width))); | ||
srcPtrTemp += srcDescPtr->strides.hStride; | ||
} | ||
} | ||
} | ||
|
||
// Avoid division by zero | ||
if(!refMag) | ||
refMag = 1.0f; | ||
|
||
Rpp32f invReferenceMagnitude = 1.f / refMag; | ||
// Interpret as 1D array | ||
if(width == 1) | ||
{ | ||
for(Rpp32s vectorLoopCount = 0; vectorLoopCount < height; vectorLoopCount++) | ||
*dstPtrCurrent++ = multiplier * std::log2(std::max(minRatio, (*srcPtrCurrent++) * invReferenceMagnitude)); | ||
} | ||
else | ||
{ | ||
Rpp32f *srcPtrRow, *dstPtrRow; | ||
srcPtrRow = srcPtrCurrent; | ||
dstPtrRow = dstPtrCurrent; | ||
for(int i = 0; i < height; i++) | ||
{ | ||
Rpp32f *srcPtrTemp, *dstPtrTemp; | ||
Review comment: You can get rid of the Temp variables and use srcPtrRow and dstPtrRow directly. For the stride increments on lines 99 and 100, just add (stride - w). |
||
srcPtrTemp = srcPtrRow; | ||
dstPtrTemp = dstPtrRow; | ||
Rpp32s vectorLoopCount = 0; | ||
for(; vectorLoopCount < width; vectorLoopCount++) | ||
*dstPtrTemp++ = multiplier * std::log2(std::max(minRatio, (*srcPtrTemp++) * invReferenceMagnitude)); | ||
|
||
srcPtrRow += srcDescPtr->strides.hStride; | ||
dstPtrRow += dstDescPtr->strides.hStride; | ||
} | ||
} | ||
} | ||
|
||
return RPP_SUCCESS; | ||
} |
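The row loop above and the review comment on it can be sketched in isolation. This is a minimal standalone illustration, not the RPP implementation: `toDecibelsRows` and its parameter list are hypothetical, the reference magnitude is passed in rather than computed, and no RPP headers are used. It shows the two points under discussion: the `log2` result is scaled by log10(2) to obtain a base-10 logarithm, and after the inner loop has advanced the pointers by `width`, only `(stride - width)` more elements are needed to reach the next row.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone helper, not part of RPP: converts a strided 2D
// buffer of magnitudes to decibels relative to refMag.
inline void toDecibelsRows(const float *src, float *dst,
                           std::size_t height, std::size_t width,
                           std::size_t srcStride, std::size_t dstStride,
                           float cutOffDB, float multiplier, float refMag)
{
    assert(width <= srcStride && width <= dstStride);
    assert(refMag != 0.0f);

    // Smallest ratio allowed through, derived from the cutoff in dB.
    float minRatio = std::pow(10.0f, cutOffDB / multiplier);
    if (minRatio == 0.0f)
        minRatio = std::nextafter(0.0f, 1.0f);

    // multiplier * log10(x) == (multiplier * log10(2)) * log2(x)
    const float scale = multiplier * 0.3010299956639812f; // log10(2)
    const float invRefMag = 1.0f / refMag;

    for (std::size_t row = 0; row < height; row++)
    {
        for (std::size_t col = 0; col < width; col++)
            *dst++ = scale * std::log2(std::max(minRatio, (*src++) * invRefMag));

        // The inner loop advanced both pointers by width, so skipping the
        // row padding needs only (stride - width) more elements.
        src += srcStride - width;
        dst += dstStride - width;
    }
}
```

With this layout the per-row Temp pointers are unnecessary, at the cost of the row pointers no longer marking the start of the current row after the inner loop.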
Please replace all the if-else conditions with more compact code:
Rpp32f border = (borderType == RpptAudioBorderType::CLAMP) ? srcPtrTemp[0] : (borderType == RpptAudioBorderType::REFLECT) ? srcPtrTemp[1] : 0;
dstPtrTemp[0] = srcPtrTemp[0] - coeff * border;
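The suggestion can be checked against a scalar reference. The sketch below is a standalone illustration, not the code in this PR: `BorderType` is a hypothetical stand-in for `RpptAudioBorderType`, `preEmphasisReference` is an invented name, and the AVX path is left out. It selects the border sample with a chained ternary using `==` comparisons, then applies `dst[i] = src[i] - coeff * src[i - 1]` to the remaining samples.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for RpptAudioBorderType (the real enum is in rppdefs.h).
enum class BorderType { ZERO, CLAMP, REFLECT };

// Scalar reference: dst[0] = src[0] - coeff * border,
//                   dst[i] = src[i] - coeff * src[i - 1] for i >= 1.
inline void preEmphasisReference(const float *src, float *dst,
                                 std::size_t length, float coeff,
                                 BorderType borderType)
{
    assert(length >= 1);
    // REFLECT reads src[1], so it needs at least two samples.
    assert(borderType != BorderType::REFLECT || length >= 2);

    // ZERO falls through to a border of 0, which leaves dst[0] == src[0].
    const float border = (borderType == BorderType::CLAMP)   ? src[0]
                       : (borderType == BorderType::REFLECT) ? src[1]
                                                             : 0.0f;
    dst[0] = src[0] - coeff * border;

    for (std::size_t i = 1; i < length; i++)
        dst[i] = src[i] - coeff * src[i - 1];
}
```

A reference like this is also a convenient baseline for comparing the vectorized loop's output on buffer lengths that are and are not multiples of 8.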