// =============================================================================
// Reviewer commentary. Everything after this banner is the patch text, unchanged.
//
// NOTE(review): this copy of the patch has lost its line structure (all hunks
// are run together) and everything that was written between angle brackets
// (`#include` with no header name, `std::wstring_convert> myconv`,
// `static_cast( errorCode )`, `std::vector& binaryOut`, `const_cast( ... )`).
// It has to be regenerated from the repository before it can be applied.
//
// Files touched: Test/OrochiUtils.cpp, Test/OrochiUtils.h,
// Test/ParallelPrimitives/RadixSort.cpp, Test/ParallelPrimitives/RadixSort.h.
//
// What the patch does
//  * FileStat (new): on _WIN32 opens the path with CreateFileW (read access,
//    OPEN_EXISTING), reports found() when the handle is valid, and getTime()
//    packs the FILETIME obtained from GetFileTime( h, NULL, NULL, &t ) into a
//    64-bit value (0 on failure). Elsewhere it keeps the path and calls stat():
//    found() is `stat() == 0`, getTime() is st_mtime (0 on failure).
//  * OrochiUtilsImpl::getCacheFileName (new): builds
//    s_cacheDirectory + "/" + moduleHash + "-" + optionHash + ".v." + deviceName
//    + "." + driverVersion + "_" + pointer-width-in-bits + ".bin".
//    moduleHash hashes the module name after everything up to the last '\\' and
//    then the last '/' has been skipped; deviceName is cut at the first ':'.
//  * isFileUpToDate (new): false when the binary is missing, true when the
//    source is missing, otherwise true only if the source time is strictly
//    older than the binary time.
//  * createDirectory (new): CreateDirectoryW / mkdir( 0775 ); an already
//    existing directory counts as success.
//  * checksum / getCheckSumFileName / loadCacheFileToBinary / cacheBinaryToFile
//    (new): the binary is stored in <cache>.bin and a `long long` checksum of
//    its bytes in <cache>.bin.check; the loader re-computes and compares it.
//  * OrochiUtils::getFunction / getFunctionFromFile: gain a leading
//    `oroDevice device` parameter; getFunction loads the cached binary when
//    isFileUpToDate() says so, otherwise compiles with orortc and stores the
//    result. "-std=c++14" is added to the compile options.
//  * OrochiUtils::s_cacheDirectory (new static, "./cache/").
//  * RadixSort: compileKernels() becomes compileKernels( oroDevice ) and is
//    called at the end of configure() instead of from the constructor.
//
// Review notes (confirm against the full files; only the hunks are visible here)
//  1. loadCacheFileToBinary returns 0 on every path (no checksum file, checksum
//     mismatch, success) and getFunction ignores the result. After a mismatch
//     `binaryOut` still holds the rejected bytes; with no checksum file it
//     stays empty. Either way getFunction goes on to
//     oroModuleLoadData( &module, codec.data() ) with no fallback to compiling.
//  2. In the compile branch a failed orortcCompileProgram only prints the log;
//     orortcGetCodeSize / orortcGetCode / cacheBinaryToFile still run, so the
//     outcome of a failed build is written to the cache.
//  3. getCacheFileName never reads `functionName`, and getFunction passes 0 for
//     `options`, so optionHash stays "0x0". `optsIn` is not referenced in the
//     visible part of getFunction. Changing compile options therefore does not
//     change the cache key.
//  4. getCacheFileName guards `moduleName && options` for the option hash but
//     then calls strip( moduleName, ... ) unconditionally; a null moduleName
//     reaches strstr().
//  5. When `path` does not name an existing file (getFunction called with
//     in-memory code), isFileUpToDate returns true as soon as a cache file
//     exists, so edited code keeps loading the old binary.
//  6. The top of the file tests `_WIN32`; createDirectory, loadCacheFileToBinary
//     and cacheBinaryToFile test `WIN32`. Presumably both are defined once the
//     Windows header is included - verify for every toolchain in use.
//  7. `char* OrochiUtils::s_cacheDirectory = "./cache/";` initialises a
//     non-const char* from a string literal, which C++11 and later reject.
//     The value already ends in '/', and getCacheFileName appends another "/".
//  8. checksum() computes an `unsigned int` and returns it as `long long`; the
//     loader treats a stored value of 0 as "no checksum", so a binary whose
//     checksum happens to be 0 is never accepted. The loop index is
//     `unsigned int` while `size` is `long long`.
//  9. hashBin()/checksum() add plain `char` values; the result for bytes
//     >= 0x80 depends on whether `char` is signed on the target.
// 10. The non-Windows FileStat stores the int returned by stat() in a `bool`;
//     the `== 0` / `!= 0` tests still behave, but the type is misleading.
// 11. The comment inside the `strip` lambda says it counts occurrences; the
//     lambda actually returns the pointer just past the last occurrence and
//     `patcnt` is never read.
// 12. loadCacheFileToBinary stores ftell()'s result in a size_t without
//     checking for failure, and the fread/fwrite results are not checked.
// 13. RadixSort::configure() now recompiles (or reloads) all kernels on every
//     call; presumably configure() is called once per sorter - confirm.
// =============================================================================
diff --git a/Test/OrochiUtils.cpp b/Test/OrochiUtils.cpp index e01b00c..957c872 100644 --- a/Test/OrochiUtils.cpp +++ b/Test/OrochiUtils.cpp @@ -2,7 +2,110 @@ #include #include #include +#include +#if defined( _WIN32 ) +#define NOMINMAX +#include +#else +#include +#include +#endif + +inline std::wstring utf8_to_wstring( const std::string& str ) +{ + std::wstring_convert> myconv; + std::wstring out1 = myconv.from_bytes( str ); + return out1; +} + +class FileStat +{ +#if defined( _WIN32 ) + public: + FileStat( const char* filePath ) + { + m_file = 0; + std::wstring filePathW = utf8_to_wstring( filePath ); + m_file = CreateFileW( filePathW.c_str(), GENERIC_READ, FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0 ); + if( m_file == INVALID_HANDLE_VALUE ) + { + DWORD errorCode; + errorCode = GetLastError(); + switch( errorCode ) + { + case ERROR_FILE_NOT_FOUND: + { +#ifdef _DEBUG + printf( "File not found %s\n", filePath ); +#endif + break; + } + case ERROR_PATH_NOT_FOUND: + { +#ifdef _DEBUG + printf( "File path not found %s\n", filePath ); +#endif + break; + } + default: + { + printf( "Failed reading file with errorCode = %d\n", static_cast( errorCode ) ); + printf( "%s\n", filePath ); + } + } + } + } + ~FileStat() + { + if( m_file != INVALID_HANDLE_VALUE ) CloseHandle( m_file ); + } + + bool found() const { return ( m_file != INVALID_HANDLE_VALUE ); } + + unsigned long long getTime() + { + if( m_file == INVALID_HANDLE_VALUE ) return 0; + + unsigned long long t = 0; + FILETIME exeTime; + if( GetFileTime( m_file, NULL, NULL, &exeTime ) == 0 ) + { + } + else + { + unsigned long long u = exeTime.dwHighDateTime; + t = ( u << 32 ) | exeTime.dwLowDateTime; + } + return t; + } + + private: + HANDLE m_file; +#else + public: + FileStat( const char* filePath ) { m_file = filePath; } + + bool found() const + { + struct stat binaryStat; + bool e = stat( m_file.c_str(), &binaryStat ); + return e == 0; + } + + unsigned long long getTime() + { + struct stat binaryStat; +
bool e = stat( m_file.c_str(), &binaryStat ); + if( e != 0 ) return 0; + unsigned long long t = binaryStat.st_mtime; + return t; + } + + private: + std::string m_file; +#endif +}; struct OrochiUtilsImpl { @@ -48,14 +151,233 @@ struct OrochiUtilsImpl } } + static void getCacheFileName( oroDevice device, const char* moduleName, const char* functionName, const char* options, std::string& binFileName ) + { + auto hashBin = []( const char* s, const size_t size ) + { + unsigned int hash = 0; + + for( unsigned int i = 0; i < size; ++i ) + { + hash += *s++; + hash += ( hash << 10 ); + hash ^= ( hash >> 6 ); + } + + hash += ( hash << 3 ); + hash ^= ( hash >> 11 ); + hash += ( hash << 15 ); + + return hash; + }; + + auto hashString =[&]( const char* ss, const size_t size, char buf[9] ) + { + const unsigned int hash = hashBin( ss, size ); + + sprintf( buf, "%08x", hash ); + }; + + auto strip = []( const char* name, const char* pattern ) + { + size_t const patlen = strlen( pattern ); + size_t patcnt = 0; + const char* oriptr; + const char* patloc; + // find how many times the pattern occurs in the original string + for( oriptr = name; ( patloc = strstr( oriptr, pattern ) ); oriptr = patloc + patlen ) + { + patcnt++; + } + return oriptr; + }; + + oroDeviceProp props; + oroGetDeviceProperties( &props, device ); + int v; + oroDriverGetVersion( &v ); + std::string deviceName = props.name; + std::string driverVersion = std::to_string( v ); + char optionHash[9] = "0x0"; + + if( moduleName && options ) + { + std::string tmp = moduleName; + tmp += options; + + hashString( tmp.c_str(), strlen( tmp.c_str() ), optionHash ); + } + + char moduleHash[9] = "0x0"; + const char* strippedModuleName = strip( moduleName, "\\" ); + strippedModuleName = strip( strippedModuleName, "/" ); + hashString( strippedModuleName, strlen( strippedModuleName ), moduleHash ); + + using namespace std::string_literals; + + deviceName = deviceName.substr( 0, deviceName.find( ":" ) ); + binFileName =
OrochiUtils::s_cacheDirectory + "/"s + moduleHash + "-"s + optionHash + ".v."s + deviceName + "."s + driverVersion + "_"s + std::to_string( 8 * sizeof( void* ) ) + ".bin"s; + } + static + bool isFileUpToDate( const char* binaryFileName, const char* srcFileName ) + { + FileStat b( binaryFileName ); + + if( !b.found() ) return false; + + FileStat s( srcFileName ); + + if( !s.found() ) + { + // todo. compare with exe time + return true; + } + + if( s.getTime() < b.getTime() ) return true; + + return false; + } + + static bool createDirectory( const char* cacheDirName ) + { +#if defined( WIN32 ) + std::wstring cacheDirNameW = utf8_to_wstring( cacheDirName ); + bool error = CreateDirectoryW( cacheDirNameW.c_str(), 0 ); + if( error == false && GetLastError() != ERROR_ALREADY_EXISTS ) + { + printf( "Cache folder path not found!\n" ); + return false; + } + return true; +#else + int error = mkdir( cacheDirName, 0775 ); + if( error == -1 && errno != EEXIST ) + { + printf( "Cache folder path not found!\n" ); + return false; + } + return true; +#endif + } + static std::string getCheckSumFileName( const std::string& binaryName ) + { + const std::string dst = binaryName + ".check"; + return dst; + } + static inline long long checksum( const char* data, long long size ) + { + unsigned int hash = 0; + + for( unsigned int i = 0; i < size; ++i ) + { + hash += *data++; + hash += ( hash << 10 ); + hash ^= ( hash >> 6 ); + } + + hash += ( hash << 3 ); + hash ^= ( hash >> 11 ); + hash += ( hash << 15 ); + + return hash; + } + + static int loadCacheFileToBinary( const std::string& cacheName, std::vector& binaryOut ) + { + long long checksumValue = 0; + { + const std::string csFileName = getCheckSumFileName( cacheName ); +#if defined( WIN32 ) + std::wstring csFileNameW = utf8_to_wstring( csFileName ); + FILE* csfile = _wfopen( csFileNameW.c_str(), L"rb" ); +#else + FILE* csfile = fopen( csFileName.c_str(), "rb" ); +#endif + if( csfile ) + { + fread( &checksumValue, sizeof( long long ), 1,
csfile ); + fclose( csfile ); + } + } + + if( checksumValue == 0 ) return 0; + +#if defined( WIN32 ) + std::wstring binaryFileNameW = utf8_to_wstring( cacheName ); + FILE* file = _wfopen( binaryFileNameW.c_str(), L"rb" ); +#else + FILE* file = fopen( cacheName.c_str(), "rb" ); +#endif + if( file ) + { + fseek( file, 0L, SEEK_END ); + size_t binarySize = ftell( file ); + rewind( file ); + + binaryOut.resize( binarySize ); + size_t dummy = fread( const_cast( binaryOut.data() ), sizeof( char ), binarySize, file ); + fclose( file ); + + long long s = checksum( binaryOut.data(), binarySize ); + if( s != checksumValue ) + { + printf( "checksum doesn't match %llx : %llx\n", s, checksumValue ); + return 0; + } + } + return 0; + } + + static int cacheBinaryToFile( std::vector binary, const std::string& cacheName ) + { + const size_t binarySize = binary.size(); + { +#ifdef WIN32 + std::wstring binaryFileNameW = utf8_to_wstring( cacheName ); + FILE* file = _wfopen( binaryFileNameW.c_str(), L"wb" ); +#else + FILE* file = fopen( cacheName.c_str(), "wb" ); +#endif + + if( file ) + { +#ifdef _DEBUG + printf( "Cached file created %s\n", cacheName.c_str() ); +#endif + fwrite( binary.data(), sizeof( char ), binarySize, file ); + fclose( file ); + } + } + + long long s = checksum( const_cast( binary.data() ), binarySize ); + const std::string filename = getCheckSumFileName( cacheName ); + + { +#ifdef WIN32 + std::wstring filenameW = utf8_to_wstring( filename ); + FILE* file = _wfopen( filenameW.c_str(), L"wb" ); +#else + FILE* file = fopen( filename.c_str(), "wb" ); +#endif + + if( file ) + { + fwrite( &s, sizeof( long long ), 1, file ); + fclose( file ); + } + } + return 0; + } }; -oroFunction OrochiUtils::getFunctionFromFile( const char* path, const char* funcName, std::vector* optsIn ) +char* OrochiUtils::s_cacheDirectory = "./cache/"; + +oroFunction OrochiUtils::getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector* optsIn ) { std::string
source; OrochiUtilsImpl::readSourceCode( path, source, 0 ); - return getFunction( source.c_str(), path, funcName, optsIn ); + return getFunction( device, source.c_str(), path, funcName, optsIn ); /* const char* code = source.c_str(); oroFunction function; @@ -93,37 +415,54 @@ oroFunction OrochiUtils::getFunctionFromFile( const char* path, const char* func */ } -oroFunction OrochiUtils::getFunction( const char* code, const char* path, const char* funcName, std::vector* optsIn ) +oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector* optsIn ) { - oroFunction function; - - orortcProgram prog; - orortcResult e; - e = orortcCreateProgram( &prog, code, path, 0, 0, 0 ); std::vector opts; opts.push_back( "-I ../" ); + opts.push_back( "-std=c++14" ); -// if( oroGetCurAPI(0) == ORO_API_CUDA ) -// opts.push_back( "-G" ); + // if( oroGetCurAPI(0) == ORO_API_CUDA ) + // opts.push_back( "-G" ); - e = orortcCompileProgram( prog, opts.size(), opts.data() ); - if( e != ORORTC_SUCCESS ) + oroFunction function; + std::vector codec; + + std::string cacheFile; + OrochiUtilsImpl::getCacheFileName( device, path, funcName, 0, cacheFile ); + if( OrochiUtilsImpl::isFileUpToDate( cacheFile.c_str(), path ) ) { - size_t logSize; - orortcGetProgramLogSize( prog, &logSize ); - if( logSize ) - { - std::string log( logSize, '\0' ); - orortcGetProgramLog( prog, &log[0] ); - std::cout << log << '\n'; - }; + //load cache + OrochiUtilsImpl::loadCacheFileToBinary( cacheFile, codec ); } - size_t codeSize; - e = orortcGetCodeSize( prog, &codeSize ); + else + { + orortcProgram prog; + orortcResult e; + e = orortcCreateProgram( &prog, code, path, 0, 0, 0 ); - std::vector codec( codeSize ); - e = orortcGetCode( prog, codec.data() ); - e = orortcDestroyProgram( &prog ); + e = orortcCompileProgram( prog, opts.size(), opts.data() ); + if( e != ORORTC_SUCCESS ) + { + size_t logSize; + orortcGetProgramLogSize( prog, &logSize ); + if( logSize
) + { + std::string log( logSize, '\0' ); + orortcGetProgramLog( prog, &log[0] ); + std::cout << log << '\n'; + }; + } + size_t codeSize; + e = orortcGetCodeSize( prog, &codeSize ); + + codec.resize( codeSize ); + e = orortcGetCode( prog, codec.data() ); + e = orortcDestroyProgram( &prog ); + + //store cache + OrochiUtilsImpl::createDirectory( s_cacheDirectory ); + OrochiUtilsImpl::cacheBinaryToFile( codec, cacheFile ); + } oroModule module; oroError ee = oroModuleLoadData( &module, codec.data() ); ee = oroModuleGetFunction( &function, module, funcName ); diff --git a/Test/OrochiUtils.h b/Test/OrochiUtils.h index 99127c0..4c4f387 100644 --- a/Test/OrochiUtils.h +++ b/Test/OrochiUtils.h @@ -12,8 +12,8 @@ class OrochiUtils int x, y, z, w; }; - static oroFunction getFunctionFromFile( const char* path, const char* funcName, std::vector* opts ); - static oroFunction getFunction( const char* code, const char* path, const char* funcName, std::vector* opts ); + static oroFunction getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector* opts ); + static oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector* opts ); static void launch1D( oroFunction func, int nx, const void** args, int wgSize = 64, unsigned int sharedMemBytes = 0 ); @@ -55,4 +55,7 @@ class OrochiUtils auto e = oroDeviceSynchronize(); OROASSERT( e == oroSuccess, 0 ); } + +public: + static char* s_cacheDirectory; }; diff --git a/Test/ParallelPrimitives/RadixSort.cpp b/Test/ParallelPrimitives/RadixSort.cpp index 8cced60..941cb74 100644 --- a/Test/ParallelPrimitives/RadixSort.cpp +++ b/Test/ParallelPrimitives/RadixSort.cpp @@ -67,8 +67,6 @@ RadixSort::RadixSort() { m_flags = (Flag)0; - compileKernels(); - if( selectedScanAlgo == ScanAlgo::SCAN_GPU_PARALLEL ) { OrochiUtils::malloc( m_partialSum, m_nWGsToExecute ); @@ -86,28 +84,28 @@ RadixSort::~RadixSort() } } -void RadixSort::compileKernels() +void
RadixSort::compileKernels( oroDevice device ) { constexpr auto kernelPath{ "../Test/ParallelPrimitives/RadixSortKernels.h" }; printf( "compiling kernels ... \n" ); - oroFunctions[Kernel::COUNT] = OrochiUtils::getFunctionFromFile( kernelPath, "CountKernel", 0 ); + oroFunctions[Kernel::COUNT] = OrochiUtils::getFunctionFromFile( device, kernelPath, "CountKernel", 0 ); if( m_flags & FLAG_LOG ) RadixSortImpl::printKernelInfo( oroFunctions[Kernel::COUNT] ); - oroFunctions[Kernel::COUNT_REF] = OrochiUtils::getFunctionFromFile( kernelPath, "CountKernelReference", 0 ); + oroFunctions[Kernel::COUNT_REF] = OrochiUtils::getFunctionFromFile( device, kernelPath, "CountKernelReference", 0 ); if( m_flags & FLAG_LOG ) RadixSortImpl::printKernelInfo( oroFunctions[Kernel::COUNT_REF] ); - oroFunctions[Kernel::SCAN_SINGLE_WG] = OrochiUtils::getFunctionFromFile( kernelPath, "ParallelExclusiveScanSingleWG", 0 ); + oroFunctions[Kernel::SCAN_SINGLE_WG] = OrochiUtils::getFunctionFromFile( device, kernelPath, "ParallelExclusiveScanSingleWG", 0 ); if( m_flags & FLAG_LOG ) RadixSortImpl::printKernelInfo( oroFunctions[Kernel::SCAN_SINGLE_WG] ); - oroFunctions[Kernel::SCAN_PARALLEL] = OrochiUtils::getFunctionFromFile( kernelPath, "ParallelExclusiveScanAllWG", 0 ); + oroFunctions[Kernel::SCAN_PARALLEL] = OrochiUtils::getFunctionFromFile( device, kernelPath, "ParallelExclusiveScanAllWG", 0 ); if( m_flags & FLAG_LOG ) RadixSortImpl::printKernelInfo( oroFunctions[Kernel::SCAN_PARALLEL] ); - oroFunctions[Kernel::SORT] = OrochiUtils::getFunctionFromFile( kernelPath, "SortKernel2", 0 ); + oroFunctions[Kernel::SORT] = OrochiUtils::getFunctionFromFile( device, kernelPath, "SortKernel2", 0 ); if( m_flags & FLAG_LOG ) RadixSortImpl::printKernelInfo( oroFunctions[Kernel::SORT] ); - oroFunctions[Kernel::SORT_REF] = OrochiUtils::getFunctionFromFile( kernelPath, "SortKernelReference", 0 ); + oroFunctions[Kernel::SORT_REF] = OrochiUtils::getFunctionFromFile( device, kernelPath, "SortKernelReference", 0 ); if(
m_flags & FLAG_LOG ) RadixSortImpl::printKernelInfo( oroFunctions[Kernel::SORT_REF] ); } @@ -130,6 +128,8 @@ void RadixSort::configure( oroDevice device, u32& tempBufferSizeOut ) m_nWGsToExecute = newWGsToExecute; tempBufferSizeOut = BIN_SIZE * m_nWGsToExecute; + + compileKernels( device ); } void RadixSort::setFlag( Flag flag ) { m_flags = flag; } diff --git a/Test/ParallelPrimitives/RadixSort.h b/Test/ParallelPrimitives/RadixSort.h index e9ddb24..54b71de 100644 --- a/Test/ParallelPrimitives/RadixSort.h +++ b/Test/ParallelPrimitives/RadixSort.h @@ -35,7 +35,7 @@ class RadixSort private: void sort1pass( u32* src, u32* dst, int n, int startBit, int endBit, int* tmps ); - void compileKernels(); + void compileKernels( oroDevice device ); private: int m_nWGsToExecute{ 4 };