From 66ad506a67dc45bbc3a9fc9d3d426764a14f155b Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Thu, 6 Jun 2024 13:10:26 +0800
Subject: [PATCH] Update sequence_parallel for predict (#8551)

Guard two optional Paddle imports so that importing the modules does not
fail when they cannot be resolved:

* trainer.py: `paddle.base.core` falls back to `core = None` when the
  import raises ImportError.
* linear_utils.py: `sequence_parallel_utils` falls back to None when the
  import raises ImportError, and ColumnSequenceParallelLinear /
  RowSequenceParallelLinear fall back to None when the attributes cannot
  be looked up (AttributeError, which also covers the None fallback).

Only the specific exception types are caught; a bare `except:` would also
swallow KeyboardInterrupt/SystemExit and mask unrelated failures.

---
 paddlenlp/trainer/trainer.py           |  6 +++++-
 paddlenlp/transformers/linear_utils.py | 14 +++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index ff6b71cae4d..2f3aad5e712 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -40,7 +40,11 @@
 import paddle.nn as nn
 from packaging import version
 from paddle import framework
-from paddle.base import core
+
+try:
+    from paddle.base import core
+except ImportError:
+    core = None
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
     HybridParallelOptimizer,
diff --git a/paddlenlp/transformers/linear_utils.py b/paddlenlp/transformers/linear_utils.py
index de1a0f886b7..469e7c45985 100644
--- a/paddlenlp/transformers/linear_utils.py
+++ b/paddlenlp/transformers/linear_utils.py
@@ -18,7 +18,11 @@
 
 import paddle.distributed.fleet.meta_parallel as mpu
 from paddle import nn
-from paddle.distributed.fleet.utils import sequence_parallel_utils
+
+try:
+    from paddle.distributed.fleet.utils import sequence_parallel_utils
+except ImportError:
+    sequence_parallel_utils = None
 
 from paddlenlp.transformers.mc2_parallel_linear import (
     MC2ColumnSeqParallelLinear,
@@ -29,8 +33,12 @@
 Linear = nn.Linear
 ColumnParallelLinear = mpu.ColumnParallelLinear
 RowParallelLinear = mpu.RowParallelLinear
-ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
-RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
+try:
+    ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear
+    RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear
+except AttributeError:
+    ColumnSequenceParallelLinear = None
+    RowSequenceParallelLinear = None
 
 if get_env_device() == "npu":
     if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None: