Skip to content

Commit

Permalink
fix position bias in tensor parallel (#1714)
Browse files Browse the repository at this point in the history
* fix position bias in tensor parallel

* add symbol ncclCommFinalize
Loading branch information
minhthuc2502 committed May 30, 2024
1 parent 3b248f1 commit 5eb5d5a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/cuda/nccl_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ extern "C" {
return func(comm);
}

// Lazy-loading stub for ncclCommFinalize: resolves the symbol from the NCCL
// shared library on first call, then forwards the communicator handle to it.
// (This commit renames the stub from ncclCommAbort to ncclCommFinalize so the
// teardown path in devices.cc can finalize communicators before destroying
// them; the interleaved old lines were unified-diff residue and are dropped.)
//
// @param comm  NCCL communicator to finalize.
// @return     The ncclResult_t reported by the real ncclCommFinalize.
ncclResult_t ncclCommFinalize(ncclComm_t comm) {
  using Signature = ncclResult_t(*)(ncclComm_t comm);
  // `static` caches the resolved function pointer: load_symbol runs once,
  // on the first invocation, under the thread-safe static-init guarantee.
  static auto func = ctranslate2::load_symbol<Signature>("ncclCommFinalize");
  return func(comm);
}

Expand Down
2 changes: 1 addition & 1 deletion src/devices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ namespace ctranslate2 {
for (auto* comm : _nccl_comms) {
//finalizing NCCL
if (*comm) {
NCCL_CHECK(ncclCommAbort(*comm));
NCCL_CHECK(ncclCommFinalize(*comm));
NCCL_CHECK(ncclCommDestroy(*comm));
}
}
Expand Down
15 changes: 13 additions & 2 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "ctranslate2/layers/attention.h"
#include "ctranslate2/ops/split.h"
#include "ctranslate2/utils.h"


#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -210,11 +212,20 @@ namespace ctranslate2 {
is_decoder,
with_cache ? key_length - 1 : 0);
}
StorageView* position_bias_per_gpu = position_bias;
StorageView position_bias_tmp(position_bias->dtype(), position_bias->device());
if (ScopedMPISetter::getCurRank() != 0) {
const dim_t num_head_per_gpu = SAFE_DIVIDE(position_bias->dim(0), ScopedMPISetter::getNRanks());
ops::Slide slide_ops(0, num_head_per_gpu * ScopedMPISetter::getCurRank(),
num_head_per_gpu, true);
slide_ops(*position_bias, position_bias_tmp);
position_bias_per_gpu = &position_bias_tmp;
}

DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
primitives<D>::add_batch_broadcast(position_bias->data<T>(),
primitives<D>::add_batch_broadcast(position_bias_per_gpu->data<T>(),
output.data<T>(),
position_bias->size(),
position_bias_per_gpu->size(),
output.size()));
}

Expand Down

0 comments on commit 5eb5d5a

Please sign in to comment.