[Trainer] update distributed dataloader #8426

Merged: 2 commits, merged on May 13, 2024
193 changes: 68 additions & 125 deletions paddlenlp/data/dist_dataloader.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.distributed import fleet

from paddlenlp.utils.log import logger

_MAX_DATA_DIM = 64
from paddlenlp.utils.nested import (
nested_broadcast_tensor,
nested_copy_place,
nested_empty_tensor,
nested_reduce_tensor,
)


class DummyDataset(paddle.io.Dataset):
@@ -53,6 +56,7 @@ def __init__(
timeout=0,
worker_init_fn=None,
persistent_workers=False,
eval=False,
):

if dataset is None:
@@ -62,12 +66,15 @@ def __init__(
super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)

self._hcg = fleet.get_hybrid_communicate_group()
self.eval = eval

# Init pp data comm group.
if self._hcg.get_pipe_parallel_world_size() > 1:
self._pp_data_group = self._init_dataloader_comm_group()
self._pp_group = self._hcg.get_pipe_parallel_group()
else:
self._pp_data_group = None
self._pp_group = None

self.mp_group = self._hcg.get_model_parallel_group()
self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -78,10 +85,6 @@ def __init__(
sharding_rank = self._hcg.get_sharding_parallel_rank()
self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

# When needed other data types, we can modify dtype_list.
self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
self._data_keys_list, self._data_keys_size = None, None

if self._need_data:
self._dataloader = paddle.io.DataLoader(
dataset,
@@ -127,7 +130,6 @@ def _init_dataloader_comm_group(self):
parallel_groups = topo.get_comm_list("pipe")

for group in parallel_groups:
# only first rank and last rank
ranks = [group[0], group[-1]]
comm_group = paddle.distributed.new_group(ranks=ranks)
if paddle.distributed.get_rank() in ranks:
@@ -137,127 +139,68 @@ def _init_dataloader_comm_group(self):
def __iter__(self):
return self

def __next__(self):
data_keys_size = [0 for i in range(len(self.dtype_list))]
if self._need_data:
data = next(self._dataloader_iter)
data_keys = list(data.keys())

for key in data_keys:
if data[key].dtype not in self.dtype_list:
raise ValueError(
f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
def _broadcast_data(self, data):
process_rank = paddle.distributed.get_rank()
if self.mp_group.nranks > 1:
if process_rank == self.mp_src_rank:
fake_data = [nested_reduce_tensor(data)]
else:
if data is not None:
logger.warning(
f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
)

data_list, data_keys_list = [], []
for i, dtype in enumerate(self.dtype_list):
data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
data_keys_size = [len(keys) for keys in data_keys_list]

# Broadcast data keys size.
if self._data_keys_size is None:
if self.mp_group.nranks > 1 and self.pp_rank == 0:
paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
if self._pp_data_group is not None:
paddle.distributed.broadcast_object_list(
data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
)
self._data_keys_size = data_keys_size

if not self._need_data:
data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]

# Broadcast data keys name.
if self._data_keys_list is None:
if self.mp_group.nranks > 1 and self.pp_rank == 0:
paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
if self._pp_data_group is not None:
paddle.distributed.broadcast_object_list(
data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
)
self._data_keys_list = data_keys_list

# Broadcast data.
if not self._need_data:
data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]

if self.mp_group.nranks > 1 and self.pp_rank == 0:
for i, dtype in enumerate(self.dtype_list):
if self._data_keys_size[i] > 0:
data_list[i] = broadcast_data_list(
data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
fake_data = [None]
if self._pp_group is not None:
if process_rank == self._pp_group.ranks[0]:
fake_data = [nested_reduce_tensor(data)]
else:
if data is not None:
logger.warning(
f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
)
fake_data = [None]
if self.mp_group.nranks > 1 and self.pp_rank == 0:
paddle.distributed.broadcast_object_list(
fake_data,
src=self.mp_src_rank,
group=self.mp_group,
)
if self._pp_group is not None:
paddle.distributed.broadcast_object_list(
fake_data,
src=self._pp_group.ranks[0],
group=self._pp_group,
)

if self._pp_data_group is not None:
# Note(daisimng): In the last stage of pp, we don't need input_ids.
# It will be removed in the future.
for i, dtype in enumerate(self.dtype_list):
if self._data_keys_size[i] > 0:
data_list[i] = broadcast_data_list(
data_list[i],
dtype,
self.pp_rank,
self._pp_data_group,
self._pp_data_group.ranks[0],
)

out_data = {}
for keys, datas in zip(self._data_keys_list, data_list):
out_data.update([(k, d) for k, d in zip(keys, datas)])

return out_data


def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
"""
Broadcast data from src_rank to all ranks in comm_group.
"""
# Move to GPU and broadcast.
size_cpu = []
if comm_rank == 0:
for data in data_list:
size_cpu.append(len(data.shape))
size_cpu += data.shape
size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
size_cuda = paddle.to_tensor(size_cpu)
paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()

size_cpu = size_cuda.tolist()
i = 0
numel = 0
sizes = []
while size_cpu[i] > 0:
rank = size_cpu[i]
this_size = size_cpu[i + 1 : i + 1 + rank]
numel += int(np.prod(this_size))
sizes.append(this_size)
i += rank + 1

if comm_rank == 0:
assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
data.dtype, datatype
)
if paddle.is_compiled_with_cuda():
data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
else:
data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
fake_data = fake_data[0]
if fake_data is None:
raise StopIteration

assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
else:
if paddle.is_compiled_with_cuda():
data_b = paddle.empty([numel], dtype=datatype).cuda()
else:
data_b = paddle.empty([numel], dtype=datatype)
dst_pp_group = self._pp_group if self.eval else self._pp_data_group
if self.mp_group.nranks > 1:
if process_rank != self.mp_src_rank:
data = nested_empty_tensor(fake_data)
if dst_pp_group is not None:
if process_rank != dst_pp_group.ranks[0]:
data = nested_empty_tensor(fake_data)

# Broadcast
paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
if self.mp_group.nranks > 1 and self.pp_rank == 0:
data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
if dst_pp_group is not None:
data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
# For pp ranks 1 to n-1, Paddle needs to receive an empty dict for pipeline parallel.
if data is None:
data = {}

ret = []
offset = 0
for size in sizes:
numel = int(np.prod(size))
ret.append(data_b[offset : offset + numel].reshape(size))
offset += numel
return data

return ret
def __next__(self):
data = None
if self._need_data:
try:
data = next(self._dataloader_iter)
data = nested_copy_place(data, place=paddle.framework._current_expected_place())
except:
pass
data = self._broadcast_data(data)
return data
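
To summarize the new flow: the rank that actually loads the batch reduces it to a tensor-free description, ships that description with broadcast_object_list, and the receiving ranks allocate empty tensors and join a tensor broadcast; in eval mode the broadcast targets the full self._pp_group rather than only the first/last-rank self._pp_data_group. The sketch below is an illustrative stand-in for the paddlenlp.utils.nested helpers, not the library's implementation; the sketch_* names and the exact leaf handling are assumptions for illustration only.

# Illustrative sketch only (not paddlenlp.utils.nested): metadata-based
# broadcasting of a dict of paddle tensors. Assumes the communication group
# has already been initialized, e.g. via fleet hybrid parallel setup.
import paddle
import paddle.distributed as dist

def sketch_reduce_tensor(data):
    # Replace each tensor leaf with plain (shape, dtype) metadata so the
    # structure can travel through broadcast_object_list.
    if isinstance(data, dict):
        return {k: sketch_reduce_tensor(v) for k, v in data.items()}
    if isinstance(data, paddle.Tensor):
        return {"shape": list(data.shape), "dtype": data.dtype}
    return data

def sketch_empty_tensor(meta):
    # Rebuild uninitialized tensors on ranks that did not load the batch.
    if isinstance(meta, dict) and set(meta.keys()) == {"shape", "dtype"}:
        return paddle.empty(meta["shape"], dtype=meta["dtype"])
    if isinstance(meta, dict):
        return {k: sketch_empty_tensor(v) for k, v in meta.items()}
    return meta

def sketch_broadcast_tensor(data, src, group):
    # Broadcast every tensor leaf from `src` to all ranks in `group`.
    if isinstance(data, dict):
        return {k: sketch_broadcast_tensor(v, src, group) for k, v in data.items()}
    if isinstance(data, paddle.Tensor):
        dist.broadcast(data, src=src, group=group)
    return data

Compared with the previous broadcast_data_list path, this drops the fixed dtype whitelist (int64/float32/int32) and the _MAX_DATA_DIM flattening of shapes into a single size tensor; the new eval flag only decides which pipe group the broadcast spans.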
21 changes: 1 addition & 20 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -62,6 +62,7 @@
SAFE_WEIGHTS_NAME,
)
from paddlenlp.utils.log import logger
from paddlenlp.utils.nested import nested_copy, nested_copy_place

if is_safetensors_available():
from safetensors import safe_open
@@ -1876,26 +1877,6 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys):
return new_actions


def nested_copy(inputs):
if isinstance(inputs, dict):
outputs = {}
for key in list(inputs.keys()):
outputs[key] = nested_copy(inputs[key])
return outputs
return inputs


def nested_copy_place(inputs, place=None, blocking=False):
if isinstance(inputs, dict):
outputs = {}
for key in list(inputs.keys()):
outputs[key] = nested_copy_place(inputs[key], place, blocking)
return outputs
if isinstance(inputs, paddle.Tensor):
inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
return inputs


def flatten_list(nested_list):
flattened_list = []
for item in nested_list:
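
The nested_copy and nested_copy_place helpers removed here are not dropped but relocated to paddlenlp.utils.nested, as the new import at the top of this file shows. A minimal usage sketch, based only on the signature visible above (place and blocking defaults assumed unchanged):

import paddle
from paddlenlp.utils.nested import nested_copy_place  # new location per this PR

# Move every tensor leaf of a batch to the currently expected device,
# mirroring how __next__ in dist_dataloader.py calls it in this diff.
batch = {"input_ids": paddle.to_tensor([[1, 2, 3]], dtype="int64")}
batch = nested_copy_place(batch, place=paddle.framework._current_expected_place())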