【Hackathon 6th No.34】support return micro batch loss for dygraph train_batch #64218

Status: Merged (6 commits, May 16, 2024)
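For context, a minimal usage sketch of the flag this PR adds. It is not part of the diff below; the hybrid/pipeline strategy values and the `build_model_optimizer_and_reader` helper are placeholders standing in for an existing `PipelineLayer` setup, not APIs introduced by this change.

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
fleet.init(is_collective=True, strategy=strategy)

# Placeholder helper (not part of this PR): returns a PipelineLayer model,
# an optimizer and a data reader built by the user.
model, optimizer, train_reader = build_model_optimizer_and_reader()
model = fleet.distributed_model(model)        # wrapped as PipelineParallel when pp_degree > 1
optimizer = fleet.distributed_optimizer(optimizer)

for batch in train_reader():
    # With return_micro_batch_loss=True, train_batch returns the detached loss of
    # every micro batch (a tensor of shape [accumulate_steps]) instead of a single
    # aggregated scalar; leave it at the default False to keep the old behaviour.
    micro_losses = model.train_batch(
        batch, optimizer, return_micro_batch_loss=True
    )
```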
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py (79 changes: 57 additions & 22 deletions)
@@ -457,7 +457,11 @@ def _flush_records(self):
         self._records = []
 
     def forward_backward_pipeline(
-        self, data, scaler=None, static_scheduler=False
+        self,
+        data,
+        scaler=None,
+        static_scheduler=False,
+        return_micro_batch_loss=False,
     ):
         # use the 1f1b scheduling strategy.
         # this strategy is inspired by:
@@ -625,7 +629,7 @@ def forward_backward_pipeline(
             self.timers("allreduce_shared_weight_gradients").stop()
             self.timers("broadcast_final_loss").start()
         with paddle.amp.auto_cast(enable=False):
-            train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss(return_micro_batch_loss)
         if self._enable_timer:
             self.timers("broadcast_final_loss").stop()
 
@@ -691,7 +695,13 @@ def _wrap_data(self, data):
         return micro_dataset
 
     def train_batch(
-        self, data, optimizer, lr_scheduler=None, scaler=None, loss_fn_idx=0
+        self,
+        data,
+        optimizer,
+        lr_scheduler=None,
+        scaler=None,
+        loss_fn_idx=0,
+        return_micro_batch_loss=False,
     ):
         data = self._prepare_training(data, optimizer, lr_scheduler)
 
@@ -703,7 +713,9 @@ def train_batch(
         self.loss_fn_idx = loss_fn_idx
 
         # 1f1b scheduler for pipeline parallel
-        train_loss = self.forward_backward_pipeline(data, scaler)
+        train_loss = self.forward_backward_pipeline(
+            data, scaler, return_micro_batch_loss=return_micro_batch_loss
+        )
 
         # optimizer
         with paddle.amp.auto_cast(enable=False):
@@ -823,10 +835,8 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
                             self.total_loss = []
                         # when self.total_loss length is less than idx, append a new tensor
                         if len(self.total_loss) <= idx:
-                            self.total_loss.append(
-                                paddle.zeros_like(loss_tensor)
-                            )
-                        self.total_loss[idx] += loss_tensor.detach()
+                            self.total_loss.append([])
+                        self.total_loss[idx].append(loss_tensor.detach())
 
                     if idx == self.loss_fn_idx:
                         backward_loss_tensor = loss_tensor
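The hunk above is the core bookkeeping change: instead of adding each detached micro-batch loss into a single running-sum tensor, `self.total_loss[idx]` now keeps a plain Python list with one detached entry per micro batch. A standalone sketch of that accumulation pattern, with the step and loss-function counts made up for illustration:

```python
import paddle

accumulate_steps, num_loss_fns = 4, 2
total_loss = None  # mirrors self.total_loss

for _micro_step in range(accumulate_steps):
    for idx in range(num_loss_fns):
        loss_tensor = paddle.rand([])      # stand-in for loss_fn(output_tensor, labels)
        if total_loss is None:
            total_loss = []
        if len(total_loss) <= idx:
            total_loss.append([])          # one list per loss function (new behaviour)
        total_loss[idx].append(loss_tensor.detach())

# Each inner list now holds accumulate_steps detached losses, so the caller can
# either reduce them (sum/mean) or hand them back per micro batch.
print([len(per_fn) for per_fn in total_loss])  # [4, 4]
```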
@@ -883,19 +893,26 @@ def _check_micro_batch_data_valid(self, micro_batch_data):
         elif micro_batch_data is not None:
             assert isinstance(micro_batch_data, paddle.Tensor)
 
-    def _broadcast_final_loss(self):
+    def _broadcast_final_loss(self, return_micro_batch_loss=False):
         # Since the last backward run in interleave will set the virtual rank to 0,
         # here we need to check last stage ignoring virtual stage.
         if self.is_pipeline_last_stage(ignore_virtual=True):
             assert (
                 self.total_loss is not None
             ), "train_batch() in last stage should obtain valid loss"
-            losses = [
-                self.total_loss[idx].detach()
-                if not self._delay_scale_loss
-                else self.total_loss[idx] / self.accumulate_steps
-                for idx in range(len(self._layers._loss_fn))
-            ]
+            losses = []
+            for idx in range(len(self._layers._loss_fn)):
+                self.total_loss[idx] = paddle.to_tensor(self.total_loss[idx])
+                if not return_micro_batch_loss:
+                    if not self._delay_scale_loss:
+                        losses.append(paddle.sum(self.total_loss[idx]).detach())
+                    else:
+                        losses.append(
+                            paddle.mean(self.total_loss[idx]).detach()
+                        )
+                else:
+                    losses.append(self.total_loss[idx].detach())
+
             for idx in range(len(self._layers._loss_fn)):
                 is_fp32 = (
                     paddle.full([], 1, 'int64')
@@ -924,10 +941,14 @@ def _broadcast_final_loss(self):
                     sync_op=True,
                     group=self.pp_group,
                 )
+                if return_micro_batch_loss:
+                    loss_shape = [self.accumulate_steps]
+                else:
+                    loss_shape = [1]
                 losses.append(
-                    paddle.zeros(shape=[1], dtype="float32")
+                    paddle.zeros(shape=loss_shape, dtype="float32")
                     if is_fp32.item()
-                    else paddle.zeros(shape=[1], dtype="float16")
+                    else paddle.zeros(shape=loss_shape, dtype="float16")
                 )
                 paddle.distributed.broadcast(
                     losses[idx],
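Taken together, the two `_broadcast_final_loss` hunks above decide the shape of what gets broadcast from the last stage: the per-micro-batch list is stacked into a tensor and then either summed (non-delayed scaling, assuming, as in the existing forward-step path, that each micro-batch loss has already been divided by `accumulate_steps`), averaged (`_delay_scale_loss`), or returned per micro batch when `return_micro_batch_loss=True`. A rough sketch of just that reduction, with the broadcast itself omitted:

```python
import paddle

def reduce_final_loss(per_micro_losses, return_micro_batch_loss, delay_scale_loss):
    # per_micro_losses: list of detached scalar losses, one per micro batch
    stacked = paddle.to_tensor(per_micro_losses)   # shape [accumulate_steps]
    if return_micro_batch_loss:
        return stacked.detach()                    # shape [accumulate_steps]
    if delay_scale_loss:
        return paddle.mean(stacked).detach()       # scalar
    return paddle.sum(stacked).detach()            # scalar (terms pre-scaled upstream)

micro = [paddle.rand([]) / 4 for _ in range(4)]
print(reduce_final_loss(micro, True, False).shape)   # [4]
print(reduce_final_loss(micro, False, False).shape)  # []
```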
@@ -1202,6 +1223,7 @@ def forward_backward_pipeline(
         forward_only=False,
         compute_loss=True,
         static_scheduler=False,
+        return_micro_batch_loss=False,
     ):
         # use interleave scheduling strategy.
         # this strategy is inspired by:
@@ -1740,7 +1762,7 @@ def _process_bwd_buffer(step_id, tensor):
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss()
+                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
@@ -1754,7 +1776,13 @@ def _process_bwd_buffer(step_id, tensor):
         return train_loss
 
     def train_batch(
-        self, data, optimizer, lr_scheduler=None, scaler=None, loss_fn_idx=0
+        self,
+        data,
+        optimizer,
+        lr_scheduler=None,
+        scaler=None,
+        loss_fn_idx=0,
+        return_micro_batch_loss=False,
     ):
         data = self._prepare_training(data, optimizer, lr_scheduler)
 
@@ -1766,7 +1794,9 @@ def train_batch(
         self.loss_fn_idx = loss_fn_idx
 
         # interleave scheduler for pipeline parallel
-        train_loss = self.forward_backward_pipeline(data, scaler)
+        train_loss = self.forward_backward_pipeline(
+            data, scaler, return_micro_batch_loss=return_micro_batch_loss
+        )
 
         # optimizer
         with paddle.amp.auto_cast(enable=False):
@@ -1857,7 +1887,12 @@ def _sync_overlap_grads(self):
             buffer.scale_grads()
 
     def forward_backward_pipeline(
-        self, data, scaler, forward_only=False, compute_loss=True
+        self,
+        data,
+        scaler,
+        forward_only=False,
+        compute_loss=True,
+        return_micro_batch_loss=False,
     ):
         if not compute_loss:
             assert (
@@ -2009,7 +2044,7 @@ def forward_backward_pipeline(
         if self._enable_timer:
             self.timers("broadcast_final_loss").start()
         with paddle.amp.auto_cast(enable=False):
-            train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss(return_micro_batch_loss)
         if self._enable_timer:
             self.timers("broadcast_final_loss").stop()
         else: