Skip to content

Commit

Permalink
Update sequence_parallel for predict (#8551)
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Jun 6, 2024
1 parent 5a6d0fa commit 66ad506
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
6 changes: 5 additions & 1 deletion paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
import paddle.nn as nn
from packaging import version
from paddle import framework
from paddle.base import core

try:
    # paddle.base.core may be unimportable in some Paddle builds/versions;
    # fall back to None so callers can feature-detect before using it.
    from paddle.base import core
except ImportError:  # narrow: ModuleNotFoundError subclasses ImportError,
    # so a missing paddle is still handled, but real bugs (e.g. SystemExit,
    # KeyboardInterrupt) are no longer silently swallowed by a bare except.
    core = None
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
Expand Down
14 changes: 11 additions & 3 deletions paddlenlp/transformers/linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@

import paddle.distributed.fleet.meta_parallel as mpu
from paddle import nn
from paddle.distributed.fleet.utils import sequence_parallel_utils

try:
    # sequence_parallel_utils is not present in every Paddle distribution;
    # degrade to None and let downstream code check availability explicitly.
    from paddle.distributed.fleet.utils import sequence_parallel_utils
except ImportError:  # only import failures should trigger the fallback;
    # a bare except would also hide KeyboardInterrupt/SystemExit and typos.
    sequence_parallel_utils = None

from paddlenlp.transformers.mc2_parallel_linear import (
MC2ColumnSeqParallelLinear,
Expand All @@ -29,8 +33,12 @@
Linear = nn.Linear
ColumnParallelLinear = mpu.ColumnParallelLinear
RowParallelLinear = mpu.RowParallelLinear
ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
# sequence_parallel_utils may be None when its guarded import failed; bind
# the sequence-parallel linear classes only when the module is available.
# An explicit None check replaces the earlier bare try/except, which would
# also have masked genuine AttributeErrors on a present module.
if sequence_parallel_utils is not None:
    ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
    RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
else:
    ColumnSequenceParallelLinear = None
    RowSequenceParallelLinear = None

if get_env_device() == "npu":
if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
Expand Down

0 comments on commit 66ad506

Please sign in to comment.