diff --git a/pytorch_binding/warpctc_pytorch/__init__.py b/pytorch_binding/warpctc_pytorch/__init__.py index b5bd390..081197b 100644 --- a/pytorch_binding/warpctc_pytorch/__init__.py +++ b/pytorch_binding/warpctc_pytorch/__init__.py @@ -42,7 +42,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, size_average=False, if length_average: # Compute the avg. log-probability per batch sample and frame. - total_length = torch.prod(act_lens) + total_length = torch.sum(act_lens) grads = grads / total_length costs = costs / total_length elif size_average: