
Possible numerical error in log-norm computation #13

Open
maxwellzh opened this issue Feb 17, 2023 · 2 comments

Comments

@maxwellzh

In the current implementation, the emissions and the predictions each subtract their own maximum values. But consider this case:

emission[0, 0] = [0, -1000]
prediction[0, 0] = [-1000, 0]
->
# current impl (here maxEs = maxPs = 0)
logNorm[0, 0, 0] = log(exp(emission[0, 0] - maxEs) @ exp(prediction[0, 0] - maxPs)) + maxEs + maxPs
                 = log(exp([0, -1000]) @ exp([-1000, 0]))
                 = log([1, exp(-1000)] @ [exp(-1000), 1])  <-- exp(-1000) underflows to 0 in FP32
                 = log(0)
                 = -inf

# correct result
logNorm[0, 0, 0] = log(exp(-1000) + exp(-1000)) = log(2) - 1000
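
A minimal PyTorch snippet that reproduces the underflow for this example (the variable names and the single (t, u) position are illustrative):

import torch

emission = torch.tensor([0.0, -1000.0])     # emission[0, 0] over V = 2
prediction = torch.tensor([-1000.0, 0.0])   # prediction[0, 0]

# current scheme: each tensor subtracts its own max, then the product is taken in linear space
maxE, maxP = emission.max(), prediction.max()
logNorm = torch.log(torch.exp(emission - maxE) @ torch.exp(prediction - maxP)) + maxE + maxP
print(logNorm)  # tensor(-inf): exp(-1000) underflows to 0 in FP32

# reducing over the joint sum emission + prediction is stable
print(torch.logsumexp(emission + prediction, dim=-1))  # log(2) - 1000 ≈ -999.3069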

I also tried converting emission and prediction to FP64 before computing the logNorm, but it still didn't work in my ASR experiment (exp(-1000) underflows even in FP64, whose smallest positive normal value is about 1e-308).

The broadcast-sum way is more numerically stable, but it consumes O(B*T*U*V) memory:

logNorm = torch.log_softmax(emission.unsqueeze(2) + prediction.unsqueeze(1), dim=-1)

# current implementation, for reference
maxEs = emissions.max(dim=2, keepdim=True)[0]
maxPs = predictions.max(dim=2, keepdim=True)[0]
log_norms = torch.log(torch.bmm(
    torch.exp(emissions - maxEs),
    torch.exp(predictions - maxPs).transpose(1, 2)))
log_norms = log_norms + maxEs + maxPs.transpose(1, 2)

@maxwellzh

There is a similar loss function implementation from k2 (fast_rnnt):

https://github.com/danpovey/fast_rnnt/blob/2c2dc4b96a6b9a8c0dbedada94cdee53a9337402/fast_rnnt/python/fast_rnnt/rnnt_loss.py#L159-L162

It seems they just add a small value to avoid log(0), which also introduces errors into the calculation.
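
For instance (the epsilon below is only illustrative, not the value fast_rnnt actually uses), a small additive constant turns a true log-norm of -1000 into roughly log(eps):

import math

true_value = math.exp(-1000)          # underflows to 0.0
print(math.log(true_value + 1e-10))   # ≈ -23.03, instead of the true -1000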

@pkufool @csukuangfj Could you take a look at this? I believe the implementation from k2 would also face this issue.


maxwellzh commented Feb 17, 2023

The only way I could figure out is to implement a custom function logmmexp(a, b), where we need to compute a_k + b_k twice (since we don't want to store the huge intermediate tensor):

In the first pass, reduce a_k + b_k to obtain the max values, giving maxSum of shape (B, T, U);
In the second pass, compute logsumexp(a_k + b_k - maxSum) + maxSum at each position.
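
A minimal pure-PyTorch sketch of that two-pass idea (the function name and shapes are assumptions; a practical version would recompute a_k + b_k inside a fused kernel rather than materializing the broadcast sum as this reference does):

import torch

def logmmexp(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stable log(exp(a) @ exp(b).transpose(1, 2)):
    # a: (B, T, V), b: (B, U, V)  ->  result: (B, T, U)
    s = a.unsqueeze(2) + b.unsqueeze(1)             # (B, T, U, V) broadcast sum
    # pass 1: per-(b, t, u) maximum over the vocab dimension -> maxSum
    maxSum = s.max(dim=-1, keepdim=True).values     # (B, T, U, 1)
    # pass 2: shift by the joint max, reduce, shift back
    return torch.log(torch.exp(s - maxSum).sum(dim=-1)) + maxSum.squeeze(-1)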

Update:
Just found that people are working on this: pytorch/pytorch#54064
