Commit

[Trainer] update distributed dataloader (#8426)
* [DistDataloader] Update implementation, add nested.py (#8380)
* fix distdataloader, fix eval with dp group (#8420)
DesmonDay committed May 13, 2024
1 parent 9e4a4f4 commit debb2ad
Showing 5 changed files with 216 additions and 222 deletions.
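In short, the commit replaces DistDataLoader's fixed dtype-list broadcast with the nested_* helpers from the new paddlenlp/utils/nested.py, and adds an eval flag so that evaluation broadcasts over the full pipeline-parallel group instead of only its first and last ranks. A hypothetical construction of the loader after this change, assuming the class exported by paddlenlp/data/dist_dataloader.py is DistDataLoader (eval_dataset, eval_sampler and data_collator are placeholder names, not taken from this diff):

from paddlenlp.data.dist_dataloader import DistDataLoader

eval_dataloader = DistDataLoader(
    dataset=eval_dataset,        # placeholder: any paddle.io.Dataset
    batch_sampler=eval_sampler,  # placeholder: a distributed batch sampler
    collate_fn=data_collator,    # placeholder: callable returning a dict of tensors
    num_workers=0,
    eval=True,                   # new flag: broadcast over the full pp group during eval
)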
193 changes: 68 additions & 125 deletions paddlenlp/data/dist_dataloader.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.distributed import fleet

from paddlenlp.utils.log import logger

_MAX_DATA_DIM = 64
from paddlenlp.utils.nested import (
nested_broadcast_tensor,
nested_copy_place,
nested_empty_tensor,
nested_reduce_tensor,
)


class DummyDataset(paddle.io.Dataset):
@@ -53,6 +56,7 @@ def __init__(
timeout=0,
worker_init_fn=None,
persistent_workers=False,
eval=False,
):

if dataset is None:
@@ -62,12 +66,15 @@ def __init__(
super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)

self._hcg = fleet.get_hybrid_communicate_group()
self.eval = eval

# Init pp data comm group.
if self._hcg.get_pipe_parallel_world_size() > 1:
self._pp_data_group = self._init_dataloader_comm_group()
self._pp_group = self._hcg.get_pipe_parallel_group()
else:
self._pp_data_group = None
self._pp_group = None

self.mp_group = self._hcg.get_model_parallel_group()
self.mp_rank = self._hcg.get_model_parallel_rank()
@@ -78,10 +85,6 @@ def __init__(
sharding_rank = self._hcg.get_sharding_parallel_rank()
self._need_data = (self.mp_rank == 0) and (self.pp_rank == 0)

# When needed other data types, we can modify dtype_list.
self.dtype_list = [paddle.int64, paddle.float32, paddle.int32]
self._data_keys_list, self._data_keys_size = None, None

if self._need_data:
self._dataloader = paddle.io.DataLoader(
dataset,
@@ -127,7 +130,6 @@ def _init_dataloader_comm_group(self):
parallel_groups = topo.get_comm_list("pipe")

for group in parallel_groups:
# only first rank and last rank
ranks = [group[0], group[-1]]
comm_group = paddle.distributed.new_group(ranks=ranks)
if paddle.distributed.get_rank() in ranks:
@@ -137,127 +139,68 @@ def _init_dataloader_comm_group(self):
def __iter__(self):
return self

def __next__(self):
data_keys_size = [0 for i in range(len(self.dtype_list))]
if self._need_data:
data = next(self._dataloader_iter)
data_keys = list(data.keys())

for key in data_keys:
if data[key].dtype not in self.dtype_list:
raise ValueError(
f"Dist dataloader requires dtype as `int64`, `float32` or `int32` currently, but got: {data[key].dtype}"
def _broadcast_data(self, data):
process_rank = paddle.distributed.get_rank()
if self.mp_group.nranks > 1:
if process_rank == self.mp_src_rank:
fake_data = [nested_reduce_tensor(data)]
else:
if data is not None:
logger.warning(
f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
)

data_list, data_keys_list = [], []
for i, dtype in enumerate(self.dtype_list):
data_list.append([data[key] for key in data_keys if data[key].dtype == dtype])
data_keys_list.append([key for key in data_keys if data[key].dtype == dtype])
data_keys_size = [len(keys) for keys in data_keys_list]

# Broadcast data keys size.
if self._data_keys_size is None:
if self.mp_group.nranks > 1 and self.pp_rank == 0:
paddle.distributed.broadcast_object_list(data_keys_size, src=self.mp_src_rank, group=self.mp_group)
if self._pp_data_group is not None:
paddle.distributed.broadcast_object_list(
data_keys_size, src=self._pp_data_group.ranks[0], group=self._pp_data_group
)
self._data_keys_size = data_keys_size

if not self._need_data:
data_keys_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]

# Broadcast data keys name.
if self._data_keys_list is None:
if self.mp_group.nranks > 1 and self.pp_rank == 0:
paddle.distributed.broadcast_object_list(data_keys_list, src=self.mp_src_rank, group=self.mp_group)
if self._pp_data_group is not None:
paddle.distributed.broadcast_object_list(
data_keys_list, src=self._pp_data_group.ranks[0], group=self._pp_data_group
)
self._data_keys_list = data_keys_list

# Broadcast data.
if not self._need_data:
data_list = [[None for i in range(keys_size)] for keys_size in self._data_keys_size]

if self.mp_group.nranks > 1 and self.pp_rank == 0:
for i, dtype in enumerate(self.dtype_list):
if self._data_keys_size[i] > 0:
data_list[i] = broadcast_data_list(
data_list[i], dtype, self.mp_rank, self.mp_group, self.mp_src_rank
fake_data = [None]
if self._pp_group is not None:
if process_rank == self._pp_group.ranks[0]:
fake_data = [nested_reduce_tensor(data)]
else:
if data is not None:
logger.warning(
f"Your local rank {paddle.distributed.get_rank()} are forbidden to have a state_dict."
)
fake_data = [None]
if self.mp_group.nranks > 1 and self.pp_rank == 0:
paddle.distributed.broadcast_object_list(
fake_data,
src=self.mp_src_rank,
group=self.mp_group,
)
if self._pp_group is not None:
paddle.distributed.broadcast_object_list(
fake_data,
src=self._pp_group.ranks[0],
group=self._pp_group,
)

if self._pp_data_group is not None:
# Note(daisimng): In the last stage of pp, we don't need input_ids.
# It will be removed in the future.
for i, dtype in enumerate(self.dtype_list):
if self._data_keys_size[i] > 0:
data_list[i] = broadcast_data_list(
data_list[i],
dtype,
self.pp_rank,
self._pp_data_group,
self._pp_data_group.ranks[0],
)

out_data = {}
for keys, datas in zip(self._data_keys_list, data_list):
out_data.update([(k, d) for k, d in zip(keys, datas)])

return out_data


def broadcast_data_list(data_list, datatype, comm_rank=0, comm_group=None, src_rank=0):
"""
Broadcast data from src_rank to all ranks in comm_group.
"""
# Move to GPU and broadcast.
size_cpu = []
if comm_rank == 0:
for data in data_list:
size_cpu.append(len(data.shape))
size_cpu += data.shape
size_cpu = size_cpu + [0] * (_MAX_DATA_DIM - len(size_cpu))
size_cuda = paddle.to_tensor(size_cpu)
paddle.distributed.broadcast(size_cuda, src_rank, group=comm_group).wait()

size_cpu = size_cuda.tolist()
i = 0
numel = 0
sizes = []
while size_cpu[i] > 0:
rank = size_cpu[i]
this_size = size_cpu[i + 1 : i + 1 + rank]
numel += int(np.prod(this_size))
sizes.append(this_size)
i += rank + 1

if comm_rank == 0:
assert data.dtype == datatype, "input has data type {} which " "is different than {}".format(
data.dtype, datatype
)
if paddle.is_compiled_with_cuda():
data_b = paddle.concat([d.cuda().reshape([-1]) for d in data_list], 0)
else:
data_b = paddle.concat([d.reshape([-1]) for d in data_list], 0)
fake_data = fake_data[0]
if fake_data is None:
raise StopIteration

assert numel == sum([d.numel().item() for d in data_list]), (numel, [d.numel().item() for d in data_list])
else:
if paddle.is_compiled_with_cuda():
data_b = paddle.empty([numel], dtype=datatype).cuda()
else:
data_b = paddle.empty([numel], dtype=datatype)
dst_pp_group = self._pp_group if self.eval else self._pp_data_group
if self.mp_group.nranks > 1:
if process_rank != self.mp_src_rank:
data = nested_empty_tensor(fake_data)
if dst_pp_group is not None:
if process_rank != dst_pp_group.ranks[0]:
data = nested_empty_tensor(fake_data)

# Broadcast
paddle.distributed.broadcast(data_b, src_rank, group=comm_group).wait()
if self.mp_group.nranks > 1 and self.pp_rank == 0:
data = nested_broadcast_tensor(data, src=self.mp_src_rank, group=self.mp_group)
if dst_pp_group is not None:
data = nested_broadcast_tensor(data, src=dst_pp_group.ranks[0], group=dst_pp_group)
# For pp stages 1 to n-1, Paddle needs to receive an empty dict for pipeline parallel.
if data is None:
data = {}

ret = []
offset = 0
for size in sizes:
numel = int(np.prod(size))
ret.append(data_b[offset : offset + numel].reshape(size))
offset += numel
return data

return ret
def __next__(self):
data = None
if self._need_data:
try:
data = next(self._dataloader_iter)
data = nested_copy_place(data, place=paddle.framework._current_expected_place())
except:
# The local iterator is exhausted (or failed): keep data as None so that
# _broadcast_data raises StopIteration on every rank together.
pass
data = self._broadcast_data(data)
return data
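Note on the new code path above: _broadcast_data follows a metadata-then-payload pattern. The rank that owns the batch first broadcasts a lightweight description of it (nested_reduce_tensor via broadcast_object_list), the other ranks allocate matching empty tensors (nested_empty_tensor), and only then are the real tensors broadcast (nested_broadcast_tensor), first across the model-parallel group and then across the pipeline group; broadcasting None as the description is how StopIteration reaches every rank. A minimal, self-contained sketch of the same idea for a flat dict of tensors on a single communication group (describe, allocate and broadcast_batch are illustrative names, not paddlenlp APIs):

import paddle
import paddle.distributed as dist

def describe(batch):
    # Source rank: reduce each tensor to (shape, dtype-name) metadata.
    # The dtype is serialized as a plain string so it pickles cleanly.
    return {k: (v.shape, str(v.dtype).replace("paddle.", "")) for k, v in batch.items()}

def allocate(meta):
    # Non-source ranks: allocate empty tensors matching the metadata.
    return {k: paddle.empty(shape, dtype=dtype) for k, (shape, dtype) in meta.items()}

def broadcast_batch(batch, src, group):
    # Step 1: ship only the metadata (cheap) with broadcast_object_list.
    meta = [describe(batch) if batch is not None else None]
    if dist.get_rank() != src:
        meta = [None]
    dist.broadcast_object_list(meta, src=src, group=group)
    if meta[0] is None:
        # The source rank ran out of data: stop on every rank together.
        raise StopIteration
    # Step 2: non-source ranks allocate buffers, then the payload is broadcast.
    if dist.get_rank() != src:
        batch = allocate(meta[0])
    for key in sorted(batch):  # deterministic key order on every rank
        dist.broadcast(batch[key], src=src, group=group)
    return batch

Keeping the pickled broadcast_object_list call limited to shapes and dtypes, and sending the tensor payload with plain broadcast, is what lets the real helpers drop the old _MAX_DATA_DIM/dtype_list bookkeeping.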
21 changes: 1 addition & 20 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -62,6 +62,7 @@
SAFE_WEIGHTS_NAME,
)
from paddlenlp.utils.log import logger
from paddlenlp.utils.nested import nested_copy, nested_copy_place

if is_safetensors_available():
from safetensors import safe_open
@@ -1876,26 +1877,6 @@ def mapping_optimizer_tp_actions(tp_actions, optimizer_loaded_keys):
return new_actions


def nested_copy(inputs):
if isinstance(inputs, dict):
outputs = {}
for key in list(inputs.keys()):
outputs[key] = nested_copy(inputs[key])
return outputs
return inputs


def nested_copy_place(inputs, place=None, blocking=False):
if isinstance(inputs, dict):
outputs = {}
for key in list(inputs.keys()):
outputs[key] = nested_copy_place(inputs[key], place, blocking)
return outputs
if isinstance(inputs, paddle.Tensor):
inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
return inputs


def flatten_list(nested_list):
flattened_list = []
for item in nested_list:
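Note: nested_copy and nested_copy_place are deleted here because the commit moves them into paddlenlp/utils/nested.py (hence the new import near the top of this file), so unified_checkpoint.py and dist_dataloader.py now share one implementation. A rough usage sketch of the relocated helper, mirroring how the new __next__ uses it; the batch contents are made up for illustration:

import paddle
from paddlenlp.utils.nested import nested_copy_place

batch = {"input_ids": paddle.to_tensor([[1, 2, 3]], dtype="int64", place=paddle.CPUPlace())}
# Copy every tensor in the (possibly nested) dict to the current device,
# leaving tensors that already live there untouched.
batch = nested_copy_place(batch, place=paddle.framework._current_expected_place())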
