Skip to content

Commit

Permalink
バッチサイズ1の場合並列化しない
Browse files Browse the repository at this point in the history
  • Loading branch information
TadaoYamaoka committed Sep 9, 2023
1 parent 5384f9f commit 300f538
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions cppshogi/python_module.cpp
Expand Up @@ -32,12 +32,12 @@ inline T is_nyugyoku(const uint8_t result) {
}

void __hcpe_decode_with_value(const size_t len, char* ndhcpe, char* ndfeatures1, char* ndfeatures2, char* ndmove, char* ndresult, char* ndvalue) {
HuffmanCodedPosAndEval *hcpe = reinterpret_cast<HuffmanCodedPosAndEval *>(ndhcpe);
HuffmanCodedPosAndEval* hcpe = reinterpret_cast<HuffmanCodedPosAndEval*>(ndhcpe);
features1_t* features1 = reinterpret_cast<features1_t*>(ndfeatures1);
features2_t* features2 = reinterpret_cast<features2_t*>(ndfeatures2);
int64_t* move = reinterpret_cast<int64_t*>(ndmove);
float *result = reinterpret_cast<float *>(ndresult);
float *value = reinterpret_cast<float *>(ndvalue);
float* result = reinterpret_cast<float*>(ndresult);
float* value = reinterpret_cast<float*>(ndvalue);

// set all zero
std::fill_n((float*)features1, sizeof(features1_t) / sizeof(float) * len, 0.0f);
Expand All @@ -62,7 +62,7 @@ void __hcpe_decode_with_value(const size_t len, char* ndhcpe, char* ndfeatures1,
}

void __hcpe2_decode_with_value(const size_t len, char* ndhcpe2, char* ndfeatures1, char* ndfeatures2, char* ndmove, char* ndresult, char* ndvalue, char* ndaux) {
HuffmanCodedPosAndEval2 *hcpe = reinterpret_cast<HuffmanCodedPosAndEval2 *>(ndhcpe2);
HuffmanCodedPosAndEval2* hcpe = reinterpret_cast<HuffmanCodedPosAndEval2*>(ndhcpe2);
features1_t* features1 = reinterpret_cast<features1_t*>(ndfeatures1);
features2_t* features2 = reinterpret_cast<features2_t*>(ndfeatures2);
int64_t* move = reinterpret_cast<int64_t*>(ndmove);
Expand Down Expand Up @@ -144,17 +144,16 @@ void __hcpe3_create_cache(const std::string& filepath) {
}

// hcpe3 cache: file-level state shared by __hcpe3_load_cache, get_cache and
// get_cache_with_lock.
// NOTE(review): `cache` appears below both as a stream object and as a
// pointer to a stream; presumably these are the before/after forms of the
// same declaration and only one is meant to be live - confirm.
std::ifstream cache;
std::ifstream* cache;
// Byte offsets of the cache entries. __hcpe3_load_cache resizes this to
// (entry count + 1) and stores the end-of-file position in the last slot, so
// cache_pos[i + 1] - cache_pos[i] is the size in bytes of entry i.
std::vector<size_t> cache_pos;
// Locked and unlocked around the seek/read pair in the `cache.` form of
// get_cache.
std::mutex cache_mutex;
// Opens the hcpe3 cache file at `filepath` in binary mode and loads its
// offset table into cache_pos.
//
// Data read from the start of the file, in order:
//   1. one size_t `num`
//   2. `num` size_t values, stored into cache_pos[0 .. num-1]
// cache_pos[num] is then set to the end-of-file position, so the extent of
// every entry can be computed from two consecutive offsets.
//
// Returns `num`.
//
// NOTE(review): each stream statement below appears twice, once through the
// `cache` object (`cache.`) and once through the `cache` pointer (`cache->`);
// presumably these are the before/after forms of the same statement - confirm
// which set is live.
// NOTE(review): no check is visible here that the file opened or that the
// reads succeeded, and no matching `delete` for the `new std::ifstream` is
// visible in this view.
size_t __hcpe3_load_cache(const std::string& filepath) {
cache.open(filepath, std::ios::binary);
cache = new std::ifstream(filepath, std::ios::binary);
size_t num;
cache.read((char*)&num, sizeof(num));
cache->read((char*)&num, sizeof(num));
// One extra slot for the end-of-file sentinel written below.
cache_pos.resize(num + 1);
cache.read((char*)cache_pos.data(), sizeof(size_t) * num);
cache.seekg(0, std::ios_base::end);
cache_pos[num] = cache.tellg();
cache->read((char*)cache_pos.data(), sizeof(size_t) * num);
cache->seekg(0, std::ios_base::end);
cache_pos[num] = cache->tellg();
return num;
}

Expand All @@ -165,14 +164,27 @@ size_t __hcpe3_get_cache_num() {
// Reads cache entry `i` from the cache file and returns it as a TrainingData.
//
// The entry starts at byte offset cache_pos[i]; its size is taken from the
// distance to cache_pos[i + 1]. The number of candidates is that size minus
// sizeof(Hcpe3CacheBody), divided by sizeof(Hcpe3CacheCandidate).
//
// NOTE(review): the seek/read pair appears twice below - once through
// `cache.` bracketed by cache_mutex.lock()/unlock(), and once through
// `cache->` with no locking. Presumably these are the before/after forms of
// the same statements - confirm which set is live.
// NOTE(review): `buf.candidates` holds MaxLegalMoves elements and
// candidateNum is not checked against that bound here.
TrainingData get_cache(const size_t i) {
const size_t pos = cache_pos[i];
const size_t candidateNum = ((cache_pos[i + 1] - pos) - sizeof(Hcpe3CacheBody)) / sizeof(Hcpe3CacheCandidate);
cache_mutex.lock();
cache.seekg(pos, std::ios_base::beg);
// Stack buffer laid out as the body followed by the candidate array, so the
// whole entry can be filled by a single read.
struct Hcpe3CacheBuf {
Hcpe3CacheBody body;
Hcpe3CacheCandidate candidates[MaxLegalMoves];
} buf;
cache.read((char*)&buf, sizeof(Hcpe3CacheBody) + sizeof(Hcpe3CacheCandidate) * candidateNum);
cache_mutex.unlock();
cache->seekg(pos, std::ios_base::beg);
cache->read((char*)&buf, sizeof(Hcpe3CacheBody) + sizeof(Hcpe3CacheCandidate) * candidateNum);
return TrainingData(buf.body, buf.candidates, candidateNum);
}

// Reads cache entry `i` from the cache file and returns it as a TrainingData,
// performing the seek and the read inside an OpenMP critical section so that
// only one thread at a time repositions and reads the shared stream.
//
// The entry starts at byte offset cache_pos[i]; the candidate count is derived
// from the distance to cache_pos[i + 1] after subtracting the fixed-size body.
//
// NOTE(review): the candidate count is not checked against MaxLegalMoves, and
// the result of the read is not checked.
TrainingData get_cache_with_lock(const size_t i) {
	const size_t entryBegin = cache_pos[i];
	const size_t entryBytes = cache_pos[i + 1] - entryBegin;
	const size_t candidateNum = (entryBytes - sizeof(Hcpe3CacheBody)) / sizeof(Hcpe3CacheCandidate);

	// Body followed by the candidate array: matches the on-disk order, so one
	// read fills both.
	struct Hcpe3CacheBuf {
		Hcpe3CacheBody body;
		Hcpe3CacheCandidate candidates[MaxLegalMoves];
	};
	Hcpe3CacheBuf buf;

	// Only the bytes actually used by this entry are read.
	const size_t readBytes = sizeof(Hcpe3CacheBody) + sizeof(Hcpe3CacheCandidate) * candidateNum;
	char* const dst = reinterpret_cast<char*>(&buf);

	// Seek and read must not be interleaved with another thread's seek/read on
	// the same stream.
	#pragma omp critical
	{
		cache->seekg(entryBegin, std::ios_base::beg);
		cache->read(dst, readBytes);
	}

	return TrainingData(buf.body, buf.candidates, candidateNum);
}

Expand Down Expand Up @@ -418,9 +430,9 @@ void __hcpe3_decode_with_value(const size_t len, char* ndindex, char* ndfeatures
std::fill_n((float*)features2, sizeof(features2_t) / sizeof(float) * len, 0.0f);
std::fill_n((float*)probability, 9 * 9 * MAX_MOVE_LABEL_NUM * len, 0.0f);

#pragma omp parallel for num_threads(2)
for (int64_t i = 0; i < len; i++) {
const auto& hcpe3 = cache.is_open() ? get_cache(index[i]) : trainingData[index[i]];
#pragma omp parallel for num_threads(2) if (len > 1)
for (size_t i = 0; i < len; i++) {
const auto& hcpe3 = cache ? (len > 1 ? get_cache_with_lock(index[i]) : get_cache(index[i])) : trainingData[index[i]];

Position position;
position.set(hcpe3.hcp);
Expand All @@ -447,10 +459,10 @@ void __hcpe3_decode_with_value(const size_t len, char* ndindex, char* ndfeatures
void __hcpe3_get_hcpe(const size_t index, char* ndhcpe) {
HuffmanCodedPosAndEval* hcpe = reinterpret_cast<HuffmanCodedPosAndEval*>(ndhcpe);

const auto& hcpe3 = cache.is_open() ? get_cache(index) : trainingData[index];
const auto& hcpe3 = cache ? get_cache(index) : trainingData[index];

hcpe->hcp = hcpe3.hcp;
float max_prob = FLT_MIN ;
float max_prob = FLT_MIN;
for (const auto kv : hcpe3.candidates) {
const auto& move16 = kv.first;
const auto& prob = kv.second;
Expand Down

0 comments on commit 300f538

Please sign in to comment.