From 300f538fc10216b95697ee19a9743a9b815d54a5 Mon Sep 17 00:00:00 2001 From: TadaoYamaoka Date: Sat, 9 Sep 2023 14:04:50 +0900 Subject: [PATCH] =?UTF-8?q?=E3=83=90=E3=83=83=E3=83=81=E3=82=B5=E3=82=A4?= =?UTF-8?q?=E3=82=BA1=E3=81=AE=E5=A0=B4=E5=90=88=E4=B8=A6=E5=88=97?= =?UTF-8?q?=E5=8C=96=E3=81=97=E3=81=AA=E3=81=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cppshogi/python_module.cpp | 52 +++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/cppshogi/python_module.cpp b/cppshogi/python_module.cpp index 99fc2f6b..58bb73e2 100644 --- a/cppshogi/python_module.cpp +++ b/cppshogi/python_module.cpp @@ -32,12 +32,12 @@ inline T is_nyugyoku(const uint8_t result) { } void __hcpe_decode_with_value(const size_t len, char* ndhcpe, char* ndfeatures1, char* ndfeatures2, char* ndmove, char* ndresult, char* ndvalue) { - HuffmanCodedPosAndEval *hcpe = reinterpret_cast(ndhcpe); + HuffmanCodedPosAndEval* hcpe = reinterpret_cast(ndhcpe); features1_t* features1 = reinterpret_cast(ndfeatures1); features2_t* features2 = reinterpret_cast(ndfeatures2); int64_t* move = reinterpret_cast(ndmove); - float *result = reinterpret_cast(ndresult); - float *value = reinterpret_cast(ndvalue); + float* result = reinterpret_cast(ndresult); + float* value = reinterpret_cast(ndvalue); // set all zero std::fill_n((float*)features1, sizeof(features1_t) / sizeof(float) * len, 0.0f); @@ -62,7 +62,7 @@ void __hcpe_decode_with_value(const size_t len, char* ndhcpe, char* ndfeatures1, } void __hcpe2_decode_with_value(const size_t len, char* ndhcpe2, char* ndfeatures1, char* ndfeatures2, char* ndmove, char* ndresult, char* ndvalue, char* ndaux) { - HuffmanCodedPosAndEval2 *hcpe = reinterpret_cast(ndhcpe2); + HuffmanCodedPosAndEval2* hcpe = reinterpret_cast(ndhcpe2); features1_t* features1 = reinterpret_cast(ndfeatures1); features2_t* features2 = reinterpret_cast(ndfeatures2); int64_t* move = reinterpret_cast(ndmove); @@ -144,17 +144,16 @@ void __hcpe3_create_cache(const std::string& filepath) { } // hcpe3キャッシュ -std::ifstream cache; +std::ifstream* cache; std::vector cache_pos; -std::mutex cache_mutex; size_t __hcpe3_load_cache(const std::string& filepath) { - cache.open(filepath, std::ios::binary); + cache = new std::ifstream(filepath, std::ios::binary); size_t num; - cache.read((char*)&num, sizeof(num)); + cache->read((char*)&num, sizeof(num)); cache_pos.resize(num + 1); - cache.read((char*)cache_pos.data(), sizeof(size_t) * num); - cache.seekg(0, std::ios_base::end); - cache_pos[num] = cache.tellg(); + cache->read((char*)cache_pos.data(), sizeof(size_t) * num); + cache->seekg(0, std::ios_base::end); + cache_pos[num] = cache->tellg(); return num; } @@ -165,14 +164,27 @@ size_t __hcpe3_get_cache_num() { TrainingData get_cache(const size_t i) { const size_t pos = cache_pos[i]; const size_t candidateNum = ((cache_pos[i + 1] - pos) - sizeof(Hcpe3CacheBody)) / sizeof(Hcpe3CacheCandidate); - cache_mutex.lock(); - cache.seekg(pos, std::ios_base::beg); struct Hcpe3CacheBuf { Hcpe3CacheBody body; Hcpe3CacheCandidate candidates[MaxLegalMoves]; } buf; - cache.read((char*)&buf, sizeof(Hcpe3CacheBody) + sizeof(Hcpe3CacheCandidate) * candidateNum); - cache_mutex.unlock(); + cache->seekg(pos, std::ios_base::beg); + cache->read((char*)&buf, sizeof(Hcpe3CacheBody) + sizeof(Hcpe3CacheCandidate) * candidateNum); + return TrainingData(buf.body, buf.candidates, candidateNum); +} + +TrainingData get_cache_with_lock(const size_t i) { + const size_t pos = cache_pos[i]; + const size_t candidateNum = ((cache_pos[i + 1] - pos) - sizeof(Hcpe3CacheBody)) / sizeof(Hcpe3CacheCandidate); + struct Hcpe3CacheBuf { + Hcpe3CacheBody body; + Hcpe3CacheCandidate candidates[MaxLegalMoves]; + } buf; + #pragma omp critical + { + cache->seekg(pos, std::ios_base::beg); + cache->read((char*)&buf, sizeof(Hcpe3CacheBody) + sizeof(Hcpe3CacheCandidate) * candidateNum); + } return TrainingData(buf.body, buf.candidates, candidateNum); } @@ -418,9 +430,9 @@ void __hcpe3_decode_with_value(const size_t len, char* ndindex, char* ndfeatures std::fill_n((float*)features2, sizeof(features2_t) / sizeof(float) * len, 0.0f); std::fill_n((float*)probability, 9 * 9 * MAX_MOVE_LABEL_NUM * len, 0.0f); - #pragma omp parallel for num_threads(2) - for (int64_t i = 0; i < len; i++) { - const auto& hcpe3 = cache.is_open() ? get_cache(index[i]) : trainingData[index[i]]; + #pragma omp parallel for num_threads(2) if (len > 1) + for (size_t i = 0; i < len; i++) { + const auto& hcpe3 = cache ? (len > 1 ? get_cache_with_lock(index[i]) : get_cache(index[i])) : trainingData[index[i]]; Position position; position.set(hcpe3.hcp); @@ -447,10 +459,10 @@ void __hcpe3_decode_with_value(const size_t len, char* ndindex, char* ndfeatures void __hcpe3_get_hcpe(const size_t index, char* ndhcpe) { HuffmanCodedPosAndEval* hcpe = reinterpret_cast(ndhcpe); - const auto& hcpe3 = cache.is_open() ? get_cache(index) : trainingData[index]; + const auto& hcpe3 = cache ? get_cache(index) : trainingData[index]; hcpe->hcp = hcpe3.hcp; - float max_prob = FLT_MIN ; + float max_prob = FLT_MIN; for (const auto kv : hcpe3.candidates) { const auto& move16 = kv.first; const auto& prob = kv.second;