【Hackathon 6th No.34】support return micro batch loss for dygraph train_batch (#64218)

* support return micro batch loss

* fix codestyle

* recover some code
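
For context, the sketch below shows how the new keyword would be passed through train_batch. The fleet setup, the placeholder helpers build_pipeline_model() and load_batch(), and the degree/accumulate_steps values are assumptions for illustration only (they are not part of this commit), and the code is only meaningful under a multi-rank launch such as python -m paddle.distributed.launch with two pipeline stages.

import paddle
from paddle.distributed import fleet

# Assumed hybrid-parallel setup: 2 pipeline stages, 4 micro batches per global batch.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 1}
fleet.init(is_collective=True, strategy=strategy)

model = build_pipeline_model()  # placeholder: returns a PipelineLayer-based model
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

inputs, labels = load_batch()  # placeholder: one global batch

# Default behaviour: a single aggregated loss for the whole batch.
batch_loss = model.train_batch([inputs, labels], optimizer)

# New in this commit: keep every micro batch's loss, shape [accumulate_steps].
micro_batch_losses = model.train_batch(
    [inputs, labels], optimizer, return_micro_batch_loss=True
)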
AndSonder committed May 16, 2024
1 parent 57c7b0d commit 84fb07d
Showing 4 changed files with 317 additions and 22 deletions.
79 changes: 57 additions & 22 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -457,7 +457,11 @@ def _flush_records(self):
         self._records = []
 
     def forward_backward_pipeline(
-        self, data, scaler=None, static_scheduler=False
+        self,
+        data,
+        scaler=None,
+        static_scheduler=False,
+        return_micro_batch_loss=False,
     ):
         # use the 1f1b scheduling strategy.
         # this strategy is inspired by:
@@ -625,7 +629,7 @@ def forward_backward_pipeline(
             self.timers("allreduce_shared_weight_gradients").stop()
             self.timers("broadcast_final_loss").start()
         with paddle.amp.auto_cast(enable=False):
-            train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss(return_micro_batch_loss)
         if self._enable_timer:
             self.timers("broadcast_final_loss").stop()
 
@@ -691,7 +695,13 @@ def _wrap_data(self, data):
         return micro_dataset
 
     def train_batch(
-        self, data, optimizer, lr_scheduler=None, scaler=None, loss_fn_idx=0
+        self,
+        data,
+        optimizer,
+        lr_scheduler=None,
+        scaler=None,
+        loss_fn_idx=0,
+        return_micro_batch_loss=False,
     ):
         data = self._prepare_training(data, optimizer, lr_scheduler)
 
@@ -703,7 +713,9 @@ def train_batch(
         self.loss_fn_idx = loss_fn_idx
 
         # 1f1b scheduler for pipeline parallel
-        train_loss = self.forward_backward_pipeline(data, scaler)
+        train_loss = self.forward_backward_pipeline(
+            data, scaler, return_micro_batch_loss=return_micro_batch_loss
+        )
 
         # optimizer
         with paddle.amp.auto_cast(enable=False):
@@ -823,10 +835,8 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
                         self.total_loss = []
                     # when self.total_loss length is less than idx, append a new tensor
                     if len(self.total_loss) <= idx:
-                        self.total_loss.append(
-                            paddle.zeros_like(loss_tensor)
-                        )
-                    self.total_loss[idx] += loss_tensor.detach()
+                        self.total_loss.append([])
+                    self.total_loss[idx].append(loss_tensor.detach())
 
                     if idx == self.loss_fn_idx:
                         backward_loss_tensor = loss_tensor
@@ -883,19 +893,26 @@ def _check_micro_batch_data_valid(self, micro_batch_data):
         elif micro_batch_data is not None:
             assert isinstance(micro_batch_data, paddle.Tensor)
 
-    def _broadcast_final_loss(self):
+    def _broadcast_final_loss(self, return_micro_batch_loss=False):
         # Since the last backward run in interleave will set the virtual rank to 0,
         # here we need to check last stage ignoring virtual stage.
         if self.is_pipeline_last_stage(ignore_virtual=True):
             assert (
                 self.total_loss is not None
             ), "train_batch() in last stage should obtain valid loss"
-            losses = [
-                self.total_loss[idx].detach()
-                if not self._delay_scale_loss
-                else self.total_loss[idx] / self.accumulate_steps
-                for idx in range(len(self._layers._loss_fn))
-            ]
+            losses = []
+            for idx in range(len(self._layers._loss_fn)):
+                self.total_loss[idx] = paddle.to_tensor(self.total_loss[idx])
+                if not return_micro_batch_loss:
+                    if not self._delay_scale_loss:
+                        losses.append(paddle.sum(self.total_loss[idx]).detach())
+                    else:
+                        losses.append(
+                            paddle.mean(self.total_loss[idx]).detach()
+                        )
+                else:
+                    losses.append(self.total_loss[idx].detach())
+
             for idx in range(len(self._layers._loss_fn)):
                 is_fp32 = (
                     paddle.full([], 1, 'int64')
@@ -924,10 +941,14 @@ def _broadcast_final_loss(self):
                     sync_op=True,
                     group=self.pp_group,
                 )
+                if return_micro_batch_loss:
+                    loss_shape = [self.accumulate_steps]
+                else:
+                    loss_shape = [1]
                 losses.append(
-                    paddle.zeros(shape=[1], dtype="float32")
+                    paddle.zeros(shape=loss_shape, dtype="float32")
                     if is_fp32.item()
-                    else paddle.zeros(shape=[1], dtype="float16")
+                    else paddle.zeros(shape=loss_shape, dtype="float16")
                 )
                 paddle.distributed.broadcast(
                     losses[idx],
@@ -1202,6 +1223,7 @@ def forward_backward_pipeline(
         forward_only=False,
         compute_loss=True,
         static_scheduler=False,
+        return_micro_batch_loss=False,
     ):
         # use interleave scheduling strategy.
         # this strategy is inspired by:
@@ -1740,7 +1762,7 @@ def _process_bwd_buffer(step_id, tensor):
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss()
+                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
@@ -1754,7 +1776,13 @@ def _process_bwd_buffer(step_id, tensor):
         return train_loss
 
     def train_batch(
-        self, data, optimizer, lr_scheduler=None, scaler=None, loss_fn_idx=0
+        self,
+        data,
+        optimizer,
+        lr_scheduler=None,
+        scaler=None,
+        loss_fn_idx=0,
+        return_micro_batch_loss=False,
     ):
         data = self._prepare_training(data, optimizer, lr_scheduler)
 
@@ -1766,7 +1794,9 @@ def _process_bwd_buffer(step_id, tensor):
         self.loss_fn_idx = loss_fn_idx
 
         # interleave scheduler for pipeline parallel
-        train_loss = self.forward_backward_pipeline(data, scaler)
+        train_loss = self.forward_backward_pipeline(
+            data, scaler, return_micro_batch_loss=return_micro_batch_loss
+        )
 
         # optimizer
         with paddle.amp.auto_cast(enable=False):
@@ -1857,7 +1887,12 @@ def _sync_overlap_grads(self):
                 buffer.scale_grads()
 
     def forward_backward_pipeline(
-        self, data, scaler, forward_only=False, compute_loss=True
+        self,
+        data,
+        scaler,
+        forward_only=False,
+        compute_loss=True,
+        return_micro_batch_loss=False,
     ):
         if not compute_loss:
             assert (
@@ -2009,7 +2044,7 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss()
+                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
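To make the behavioural change concrete, here is a standalone sketch of the aggregation that the updated _broadcast_final_loss performs on the last stage. It is not the class method itself: the loss values, the aggregate() helper, and its flags are illustrative assumptions that mirror the branches added in this commit.

import paddle

# Stand-ins for the detached per-micro-batch losses that self.total_loss[idx]
# now collects over accumulate_steps = 4 micro batches (values are made up).
micro_losses = paddle.to_tensor([0.90, 0.70, 0.65, 0.60])

def aggregate(losses, return_micro_batch_loss=False, delay_scale_loss=False):
    # Mirrors the new branch: either reduce to one value per loss function,
    # or hand back the whole per-micro-batch vector unchanged.
    if return_micro_batch_loss:
        return losses  # shape: [accumulate_steps]
    if delay_scale_loss:
        # micro losses were stored unscaled, so average them here
        return paddle.mean(losses)
    # otherwise the stored micro losses are expected to be pre-scaled
    # by 1/accumulate_steps, so summing recovers the batch loss
    return paddle.sum(losses)

print(aggregate(micro_losses))                                # scalar batch loss
print(aggregate(micro_losses, return_micro_batch_loss=True))  # per-micro-batch losses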
