Radix Sort update & Optimization - part 2 #62

Merged 6 commits on Aug 30, 2023
178 changes: 100 additions & 78 deletions ParallelPrimitives/RadixSort.cpp
@@ -1,6 +1,7 @@
#include <ParallelPrimitives/RadixSort.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <iostream>
#include <numeric>

@@ -13,7 +14,7 @@
// clang-format on
#endif

#if defined(__GNUC__)
#if defined( __GNUC__ )
#include <dlfcn.h>
#endif

@@ -25,7 +26,7 @@ constexpr auto useBitCode = true;
constexpr auto useBitCode = false;
#endif

#if !defined(__GNUC__)
#if !defined( __GNUC__ )
const HMODULE GetCurrentModule()
{
HMODULE hModule = NULL;
@@ -34,55 +35,52 @@ const HMODULE GetCurrentModule()
return hModule;
}
#else
void GetCurrentModule1()
{
}
void GetCurrentModule1() {}
#endif



void printKernelInfo( oroFunction func )
void printKernelInfo( const std::string& name, oroFunction func )
{
std::cout << "Function: " << name;

int numReg{};
int sharedSizeBytes{};
int constSizeBytes{};
oroFuncGetAttribute( &numReg, ORO_FUNC_ATTRIBUTE_NUM_REGS, func );
oroFuncGetAttribute( &sharedSizeBytes, ORO_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func );
oroFuncGetAttribute( &constSizeBytes, ORO_FUNC_ATTRIBUTE_CONST_SIZE_BYTES, func );
std::cout << "vgpr : shared = " << numReg << " : "
<< " : " << sharedSizeBytes << " : " << constSizeBytes << '\n';
std::cout << ", vgpr : shared = " << numReg << " : " << sharedSizeBytes << " : " << constSizeBytes << '\n';
}

} // namespace

namespace Oro
{

RadixSort::RadixSort()
RadixSort::RadixSort( oroDevice device, OrochiUtils& oroutils ) : m_device{ device }, m_oroutils{ oroutils }
{
if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
{
m_partialSum.resize( m_nWGsToExecute );
OrochiUtils::malloc( m_isReady, m_nWGsToExecute );
OrochiUtils::memset( m_isReady, false, m_nWGsToExecute * sizeof( bool ) );
}
}
oroGetDeviceProperties( &m_props, device );

RadixSort::~RadixSort()
{
if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
{
OrochiUtils::free( m_isReady );
}
m_num_threads_per_block_for_count = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_COUNT_BLOCK_SIZE;
m_num_threads_per_block_for_scan = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SCAN_BLOCK_SIZE;
m_num_threads_per_block_for_sort = m_props.maxThreadsPerBlock > 0 ? m_props.maxThreadsPerBlock : DEFAULT_SORT_BLOCK_SIZE;

const auto warp_size = ( m_props.warpSize != 0 ) ? m_props.warpSize : DEFAULT_WARP_SIZE;

m_num_warps_per_block_for_sort = m_num_threads_per_block_for_sort / warp_size;

assert( m_num_threads_per_block_for_count % warp_size == 0 );
assert( m_num_threads_per_block_for_scan % warp_size == 0 );
assert( m_num_threads_per_block_for_sort % warp_size == 0 );

configure();
}
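The constructor now sizes every kernel launch from the queried device properties, falling back to compile-time defaults when the runtime reports zero. A condensed sketch of that guard pattern, using the same oroGetDeviceProperties call that appears in the diff (the fallback values 1024 and 32 are assumptions standing in for the DEFAULT_* constants):

#include <Orochi/Orochi.h>
#include <cassert>

void configureBlockSizes( oroDevice device )
{
	oroDeviceProp props{};
	oroGetDeviceProperties( &props, device );

	// Prefer the device-reported limits; fall back to a fixed default otherwise.
	const int threadsPerBlock = props.maxThreadsPerBlock > 0 ? props.maxThreadsPerBlock : 1024;
	const int warpSize = props.warpSize != 0 ? props.warpSize : 32;

	// The kernels assume each block holds a whole number of warps.
	assert( threadsPerBlock % warpSize == 0 );
}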

void RadixSort::exclusiveScanCpu( int* countsGpu, int* offsetsGpu, const int nWGsToExecute, oroStream stream ) noexcept
void RadixSort::exclusiveScanCpu( const Oro::GpuMemory<int>& countsGpu, Oro::GpuMemory<int>& offsetsGpu, oroStream stream ) const noexcept
{
std::vector<int> counts( Oro::BIN_SIZE * nWGsToExecute );
OrochiUtils::copyDtoHAsync( counts.data(), countsGpu, Oro::BIN_SIZE * nWGsToExecute, stream );
OrochiUtils::waitForCompletion( stream );
const auto buffer_size = countsGpu.size();

std::vector<int> offsets( Oro::BIN_SIZE * nWGsToExecute );
std::vector<int> counts = countsGpu.getData();
std::vector<int> offsets( buffer_size );

int sum = 0;
for( int i = 0; i < counts.size(); ++i )
@@ -91,56 +89,63 @@ void RadixSort::exclusiveScanCpu( int* countsGpu, int* offsetsGpu, const int nWG
sum += counts[i];
}

OrochiUtils::copyHtoDAsync( offsetsGpu, offsets.data(), Oro::BIN_SIZE * nWGsToExecute, stream );
OrochiUtils::waitForCompletion( stream );
offsetsGpu.copyFromHost( offsets.data(), std::size( offsets ) );
}
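For reference, the body above is a plain sequential exclusive prefix sum over the per-work-group counts. A minimal standalone sketch of the same scan logic (names here are illustrative, not the library's API):

#include <vector>

// Exclusive scan: out[i] holds the sum of in[0..i-1], and out[0] == 0.
std::vector<int> exclusiveScan( const std::vector<int>& in )
{
	std::vector<int> out( in.size() );
	int sum = 0;
	for( size_t i = 0; i < in.size(); ++i )
	{
		out[i] = sum; // write the running total *before* adding the current element
		sum += in[i];
	}
	return out;
}

// Example: {3, 1, 4} scans to {0, 3, 4}.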

void RadixSort::compileKernels( oroDevice device, OrochiUtils& oroutils, const std::string& kernelPath, const std::string& includeDir ) noexcept
void RadixSort::compileKernels( const std::string& kernelPath, const std::string& includeDir ) noexcept
{
constexpr auto defaultKernelPath{ "../ParallelPrimitives/RadixSortKernels.h" };
constexpr auto defaultIncludeDir{ "../" };
static constexpr auto defaultKernelPath{ "../ParallelPrimitives/RadixSortKernels.h" };
static constexpr auto defaultIncludeDir{ "../" };

const auto currentKernelPath{ ( kernelPath == "" ) ? defaultKernelPath : kernelPath };
const auto currentIncludeDir{ ( includeDir == "" ) ? defaultIncludeDir : includeDir };

auto getCurrentDir = []()
const auto getCurrentDir = []() noexcept
{
#if !defined(__GNUC__)
#if !defined( __GNUC__ )
HMODULE hm = GetCurrentModule();
char buff[MAX_PATH];
GetModuleFileName( hm, buff, MAX_PATH );
#else
Dl_info info;
dladdr( (const void*)GetCurrentModule1, &info );
const char* buff = info.dli_fname;
#endif
#endif
std::string::size_type position = std::string( buff ).find_last_of( "\\/" );
return std::string( buff ).substr( 0, position ) + "/";
};
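The lambda above resolves the directory of the loaded module so the precompiled binary can be located next to it. A minimal POSIX-only sketch of the same dladdr technique (on Windows the code uses GetModuleFileName instead; anchorSymbol is a hypothetical helper, not part of the library):

#include <dlfcn.h>
#include <string>

static void anchorSymbol() {} // any symbol defined in this module works as an anchor

std::string currentModuleDir()
{
	Dl_info info{};
	// dladdr reports, among other things, the path of the object containing the address.
	dladdr( (const void*)&anchorSymbol, &info );
	const std::string path{ info.dli_fname };
	const auto pos = path.find_last_of( "\\/" );
	return path.substr( 0, pos ) + "/";
}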

std::string binaryPath{};
std::string log{};
if constexpr( useBitCode )
{
const bool isAmd = oroGetCurAPI( 0 ) == ORO_API_HIP;
binaryPath = getCurrentDir();
binaryPath += isAmd ? "oro_compiled_kernels.hipfb" : "oro_compiled_kernels.fatbin";
if( m_flags == Flag::LOG )
{
std::cout << "loading pre-compiled kernels at path : " << binaryPath << '\n';
}
log = "loading pre-compiled kernels at path : " + binaryPath;
}
else
{
if( m_flags == Flag::LOG )
{
std::cout << "compiling kernels at path : " << currentKernelPath << " in : " << currentIncludeDir << '\n';
}
log = "compiling kernels at path : " + currentKernelPath + " in : " + currentIncludeDir;
}

if( m_flags == Flag::LOG )
{
std::cout << log << std::endl;
}

const auto includeArg{ "-I" + currentIncludeDir };
const auto count_block_size_param = "-DCOUNT_WG_SIZE=" + std::to_string( m_num_threads_per_block_for_count );
const auto scan_block_size_param = "-DSCAN_WG_SIZE=" + std::to_string( m_num_threads_per_block_for_scan );
const auto sort_block_size_param = "-DSORT_WG_SIZE=" + std::to_string( m_num_threads_per_block_for_sort );
const auto sort_num_warps_param = "-DSORT_NUM_WARPS_PER_BLOCK=" + std::to_string( m_num_warps_per_block_for_sort );

std::vector<const char*> opts;
opts.push_back( includeArg.c_str() );
opts.push_back( count_block_size_param.c_str() );
opts.push_back( scan_block_size_param.c_str() );
opts.push_back( sort_block_size_param.c_str() );
opts.push_back( sort_num_warps_param.c_str() );
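One subtlety in the option setup above: the std::string objects must stay alive for as long as the raw c_str() pointers in opts are in use. A small sketch of the pattern, with a hypothetical compileWithOptions() standing in for the real compiler entry point:

#include <string>
#include <vector>

void compileWithOptions( const std::vector<const char*>& opts ); // hypothetical stand-in

void buildAndCompile( int countBlockSize )
{
	// The strings own the memory; keep them in scope while the pointers are used.
	const std::string includeArg = "-I../";
	const std::string countDefine = "-DCOUNT_WG_SIZE=" + std::to_string( countBlockSize );

	std::vector<const char*> opts;
	opts.push_back( includeArg.c_str() );
	opts.push_back( countDefine.c_str() );

	compileWithOptions( opts ); // safe: the strings outlive this call
}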

struct Record
{
@@ -149,8 +154,8 @@ void RadixSort::compileKernels( oroDevice device, OrochiUtils& oroutils, const s
};

const std::vector<Record> records{
{ "CountKernel", Kernel::COUNT }, { "CountKernelReference", Kernel::COUNT_REF }, { "ParallelExclusiveScanSingleWG", Kernel::SCAN_SINGLE_WG }, { "ParallelExclusiveScanAllWG", Kernel::SCAN_PARALLEL },
{ "SortKernel", Kernel::SORT }, { "SortKVKernel", Kernel::SORT_KV }, { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
{ "CountKernel", Kernel::COUNT }, { "ParallelExclusiveScanSingleWG", Kernel::SCAN_SINGLE_WG }, { "ParallelExclusiveScanAllWG", Kernel::SCAN_PARALLEL }, { "SortKernel", Kernel::SORT },
{ "SortKVKernel", Kernel::SORT_KV }, { "SortSinglePassKernel", Kernel::SORT_SINGLE_PASS }, { "SortSinglePassKVKernel", Kernel::SORT_SINGLE_PASS_KV },
};

for( const auto& record : records )
Expand All @@ -161,60 +166,77 @@ void RadixSort::compileKernels( oroDevice device, OrochiUtils& oroutils, const s

if constexpr( useBitCode )
{
oroFunctions[record.kernelType] = oroutils.getFunctionFromPrecompiledBinary( binaryPath.c_str(), record.kernelName.c_str() );
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary( binaryPath.c_str(), record.kernelName.c_str() );
}
else
{

oroFunctions[record.kernelType] = oroutils.getFunctionFromFile( device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts );
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromFile( m_device, currentKernelPath.c_str(), record.kernelName.c_str(), &opts );
}

#endif
if( m_flags == Flag::LOG )
{
printKernelInfo( oroFunctions[record.kernelType] );
printKernelInfo( record.kernelName, oroFunctions[record.kernelType] );
}
}
}

int RadixSort::calculateWGsToExecute( oroDevice device ) noexcept
int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept
{
oroDeviceProp props{};
oroGetDeviceProperties( &props, device );

constexpr auto maxWGSize = std::max( { COUNT_WG_SIZE, SCAN_WG_SIZE, SORT_WG_SIZE } );
const int warpSize = ( props.warpSize != 0 ) ? props.warpSize : 32;
const int warpPerWG = maxWGSize / warpSize;
const int warpPerWGP = props.maxThreadsPerMultiProcessor / warpSize;
const int warpSize = ( m_props.warpSize != 0 ) ? m_props.warpSize : DEFAULT_WARP_SIZE;
const int warpPerWG = blockSize / warpSize;
const int warpPerWGP = m_props.maxThreadsPerMultiProcessor / warpSize;
const int occupancyFromWarp = ( warpPerWGP > 0 ) ? ( warpPerWGP / warpPerWG ) : 1;

// From the runtime measurements this yields better result.
const int occupancy = std::max( 1, occupancyFromWarp / 2 );
const int occupancy = std::max( 1, occupancyFromWarp );

if( m_flags == Flag::LOG )
{
std::cout << "Occupancy: " << occupancy << '\n';
}

static constexpr auto min_num_blocks = 16;
auto number_of_blocks = m_props.multiProcessorCount > 0 ? m_props.multiProcessorCount * occupancy : min_num_blocks;

if( m_flags == Flag::LOG ) std::cout << "Occupancy: " << occupancy << '\n';
if( m_num_threads_per_block_for_scan > BIN_SIZE )
{
// Note: both are divisible by 2
const auto base = m_num_threads_per_block_for_scan / BIN_SIZE;

// Floor
number_of_blocks = ( number_of_blocks / base ) * base;
}

return props.multiProcessorCount * occupancy;
return number_of_blocks;
}
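To make the arithmetic above concrete, here is the same computation with representative values; all numbers are assumed, not measured (warpSize = 32, count block size = 1024, maxThreadsPerMultiProcessor = 2048, multiProcessorCount = 60, BIN_SIZE = 256, scan block size = 1024):

#include <algorithm>
#include <iostream>

int main()
{
	const int warpPerWG = 1024 / 32;                             // 32 warps per work-group
	const int warpPerWGP = 2048 / 32;                            // 64 warps per multiprocessor
	const int occupancy = std::max( 1, warpPerWGP / warpPerWG ); // 2 resident work-groups
	int numberOfBlocks = 60 * occupancy;                         // 120 blocks
	const int base = 1024 / 256;                                 // scan block size / BIN_SIZE = 4
	numberOfBlocks = ( numberOfBlocks / base ) * base;           // floor to a multiple of 4 -> 120
	std::cout << numberOfBlocks << '\n';                         // prints 120
}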

RadixSort::u32 RadixSort::configure( oroDevice device, OrochiUtils& oroutils, const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept
void RadixSort::configure( const std::string& kernelPath, const std::string& includeDir, oroStream stream ) noexcept
{
compileKernels( device, oroutils, kernelPath, includeDir );
const auto newWGsToExecute{ calculateWGsToExecute( device ) };
compileKernels( kernelPath, includeDir );

m_num_blocks_for_count = calculateWGsToExecute( m_num_threads_per_block_for_count );

/// The tmp buffer size of the count kernel and the scan kernel.

const auto tmp_buffer_size = BIN_SIZE * m_num_blocks_for_count;

if( newWGsToExecute != m_nWGsToExecute && selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
/// @c tmp_buffer_size must be divisible by @c m_num_threads_per_block_for_scan
/// This is guaranteed since @c m_num_blocks_for_count will be adjusted accordingly

m_num_blocks_for_scan = tmp_buffer_size / m_num_threads_per_block_for_scan;

m_tmp_buffer.resize( tmp_buffer_size );

if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL )
{
m_partialSum.resize( newWGsToExecute );
OrochiUtils::free( m_isReady );
OrochiUtils::malloc( m_isReady, newWGsToExecute );
OrochiUtils::memsetAsync( m_isReady, false, newWGsToExecute * sizeof( bool ), stream );
// These are for the scan kernel
m_partial_sum.resize( m_num_blocks_for_scan );
m_is_ready.resize( m_num_blocks_for_scan );
}

m_nWGsToExecute = newWGsToExecute;
return static_cast<u32>( BIN_SIZE * m_nWGsToExecute );
}
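The divisibility guarantee mentioned in the comments above follows directly from the flooring step in calculateWGsToExecute. A quick check with the same assumed numbers (BIN_SIZE = 256, scan block size = 1024):

#include <cassert>

int main()
{
	// The block count was floored to a multiple of base = 1024 / 256 = 4, so
	// BIN_SIZE * m_num_blocks_for_count always divides evenly by the scan block size.
	const int numBlocksForCount = 120;                 // multiple of 4 by construction
	const int tmpBufferSize = 256 * numBlocksForCount; // BIN_SIZE * blocks = 30720
	assert( tmpBufferSize % 1024 == 0 );               // no remainder
	const int numBlocksForScan = tmpBufferSize / 1024; // 30 scan blocks
	(void)numBlocksForScan;
}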
void RadixSort::setFlag( Flag flag ) noexcept { m_flags = flag; }

void RadixSort::sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int startBit, int endBit, u32* tempBuffer, oroStream stream ) noexcept
void RadixSort::sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int startBit, int endBit, oroStream stream ) noexcept
{
// todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
// right now, setting this as large as possible is faster than multi pass sorting
@@ -231,7 +253,7 @@ void RadixSort::sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int s

for( int i = startBit; i < endBit; i += N_RADIX )
{
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), (int*)tempBuffer, stream );
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream );

std::swap( s, d );
}
@@ -243,7 +265,7 @@ void RadixSort::sort( const KeyValueSoA src, const KeyValueSoA dst, int n, int s
}
}
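The loop above consumes N_RADIX bits per pass and ping-pongs between the two buffers, so the buffer holding the final result depends on whether the pass count is odd or even. A self-contained sketch of the pattern (sortOnePass is a hypothetical stand-in for sort1pass, and N_RADIX = 8 is an assumed value):

#include <algorithm>
#include <cstdint>

using u32 = std::uint32_t;
constexpr int N_RADIX = 8; // bits consumed per pass (assumed value)

void sortOnePass( const u32* src, const u32* dst, int n, int startBit, int endBit ); // hypothetical

void multiPassSort( const u32* src, const u32* dst, int n, int startBit, int endBit )
{
	const u32* s = src;
	const u32* d = dst;
	for( int i = startBit; i < endBit; i += N_RADIX )
	{
		// Each pass sorts on the digit in [i, min(i + N_RADIX, endBit)).
		sortOnePass( s, d, n, i, i + std::min( N_RADIX, endBit - i ) );
		std::swap( s, d ); // this pass's output is the next pass's input
	}
	// After an odd number of passes the sorted data is already in dst; after an
	// even number it sits in src, which is why the functions above finish with a
	// device-to-device copy back into dst in that case.
}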

void RadixSort::sort( const u32* src, const u32* dst, int n, int startBit, int endBit, u32* tempBuffer, oroStream stream ) noexcept
void RadixSort::sort( const u32* src, const u32* dst, int n, int startBit, int endBit, oroStream stream ) noexcept
{
// todo. better to compute SINGLE_SORT_N_ITEMS_PER_WI which we use in the kernel dynamically rather than hard coding it to distribute the work evenly
// right now, setting this as large as possible is faster than multi pass sorting
@@ -260,7 +282,7 @@ void RadixSort::sort( const u32* src, const u32* dst, int n, int startBit, int e

for( int i = startBit; i < endBit; i += N_RADIX )
{
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), (int*)tempBuffer, stream );
sort1pass( *s, *d, n, i, i + std::min( N_RADIX, endBit - i ), stream );

std::swap( s, d );
}