Conversation

ev-br
Collaborator

@ev-br ev-br commented Apr 18, 2023

Try testing our torch_np wrapper on real-world examples, starting with several toy problems:

Several examples from N. Rougier, From python to numpy, https://www.labri.fr/perso/nrougier/from-python-to-numpy/


The strategy is to replace `import numpy as np` with `import torch_np as np` (i.e., run everything in eager mode). See e2e/tests.md for some notes.
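As a sketch of the drop-in strategy, the comment below marks the one-line swap; everything else here is plain NumPy standing in for one of Rougier's examples:

```python
import numpy as np  # the e2e strategy: swap for `import torch_np as np`

# A typical "From python to numpy"-style computation runs unchanged.
Z = np.zeros((4, 4))
Z[1:-1, 1:-1] = 1.0
print(int(Z.sum()))  # 4
```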

Collaborator

@lezcano lezcano left a comment


This is neat! The only comment I have (not for this PR): perhaps we can wire np.random.seed to torch.manual_seed, together with a one-time warning saying "careful, the random numbers are not the same!". That would allow us to not use _np at all and still run these models. The point is that these examples only set a seed for reproducibility, so any seed (and hopefully any generator) will work just fine.
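A minimal sketch of that wiring, assuming a hypothetical torch_np.random.seed; the warn-once bookkeeping is hand-rolled here, not an existing torch API:

```python
import warnings
import torch

_seed_warned = False

def seed(s):
    # Hypothetical torch_np.random.seed: forward to torch.manual_seed,
    # warning once that the generated streams differ from NumPy's.
    global _seed_warned
    if not _seed_warned:
        warnings.warn("careful, the random numbers are not the same as NumPy's!")
        _seed_warned = True
    torch.manual_seed(int(s))
```

Re-seeding with the same value then reproduces the same torch stream, which is all these examples need for reproducibility.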

@ev-br
Collaborator Author

ev-br commented Apr 18, 2023

_np.random is me, not them. It's only to check that the results are exactly the same.

@ev-br
Collaborator Author

ev-br commented Apr 18, 2023

@honno could you give me a hand with excluding everything under ./e2e from pre-commit please?

EDIT: figured it out.

@ev-br
Collaborator Author

ev-br commented Apr 18, 2023

OK, a first fix to ndarray.__setitem__: 5c8f568

Here's a minimal reproducer:

In [48]: ts = torch.empty(3, dtype=torch.float32)

In [49]: idx = [2, 0, 1]

In [50]: tf = torch.as_tensor([2, 2, 2], dtype=float)

In [51]: ts[idx] = tf
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [51], in <cell line: 1>()
----> 1 ts[idx] = tf

RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Double for the source.

Note that in PyTorch this is specific to advanced indexing:

In [52]: ts[...] = tf

In [53]: ts
Out[53]: tensor([2., 2., 2.], dtype=torch.float32)

So maybe this can be considered a bug in PyTorch @lezcano ?
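For reference, a minimal sketch of the kind of workaround a wrapper can apply here (an assumption, not necessarily the exact fix in 5c8f568): cast the source to the destination dtype before the advanced-indexing assignment.

```python
import torch

ts = torch.empty(3, dtype=torch.float32)
idx = [2, 0, 1]
tf = torch.as_tensor([2, 2, 2], dtype=torch.float64)

# `ts[idx] = tf` raises "Index put requires the source and destination
# dtypes match", so cast the source to the destination dtype first.
ts[idx] = tf.to(ts.dtype)
print(ts)  # tensor([2., 2., 2.])
```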

@lezcano
Collaborator

lezcano commented Apr 18, 2023

Not a bug, as it throws a clean error, so this was accounted for. The fix LGTM.

@ev-br
Collaborator Author

ev-br commented Apr 19, 2023

@lezcano I wonder if you could clarify something for me: calling abs on a complex64 tensor overflows or not depending on the tensor size. Is this expected?
(Found while playing with the Mandelbrot set example, https://github.com/rougier/from-python-to-numpy/blob/master/code/mandelbrot_numpy_1.py)

In [61]: z2
Out[61]: 
tensor([[ 1.6183e+21-2.6955e+21j, -3.5169e+19-9.3859e+18j, -2.3807e+17+4.3302e+17j,
          2.8027e+15+7.9634e+15j,  1.2818e+14+1.3907e+14j,  3.8502e+12+3.9319e+12j,
         -2.2755e+22+2.8512e+22j, -3.6647e+19-2.8135e+19j],
        [ 3.6003e+19+5.4754e+18j,  3.4639e+14-2.2696e+17j, -1.5240e+15-7.9048e+14j,
         -1.7409e+13+9.2506e+11j,  6.3255e+22-1.4787e+22j,  2.8390e+19+8.0909e+18j,
          9.4602e+15+1.7841e+16j, -6.7203e+12+8.3905e+12j],
        [ 3.2646e+17+2.1984e+17j,  6.3642e+14-9.3251e+14j, -7.6484e+11-4.1867e+12j,
         -4.6086e+20+3.2900e+20j, -4.4375e+16-1.3281e+16j,  4.9347e+12-1.5831e+13j,
          1.8095e+20-1.7608e+20j,  2.7710e+13-1.5209e+13j],
        [ 4.0589e+15+2.4785e+14j,  1.1455e+12-3.8486e+12j, -4.3807e+19-2.2136e+18j,
         -3.1675e+14-2.6598e+14j,  1.8239e+19-2.0438e+19j, -4.6712e+12-2.3371e+11j,
         -8.8752e+17-2.3184e+18j,  4.4035e+16+1.5998e+16j],
        [-1.5219e+13-4.2478e+13j,  4.7741e+19+2.2545e+19j, -1.4828e+13+1.1509e+13j,
         -5.3467e+16+2.6890e+16j,  3.9061e+13+4.2480e+13j, -1.3268e+18-1.1204e+19j,
          4.2014e+14-1.1959e+14j,  8.6712e+15-7.6094e+15j],
        [ 1.1673e+12+2.4728e+10j, -4.2480e+12-1.2609e+12j, -3.7508e+18+5.3464e+17j,
          3.3215e+18-1.6886e+18j, -1.1585e+14+3.4563e+13j,  1.6029e-01+8.6051e-02j,
         -2.2068e-02+1.1888e-01j, -3.5691e+16+2.6905e+16j],
        [ 1.1673e+12-2.4728e+10j, -4.2480e+12+1.2609e+12j, -3.7508e+18-5.3464e+17j,
          3.3215e+18+1.6886e+18j, -1.1585e+14-3.4563e+13j,  1.6029e-01-8.6051e-02j,
         -2.2068e-02-1.1888e-01j, -3.5691e+16-2.6905e+16j],
        [-1.5219e+13+4.2478e+13j,  4.7741e+19-2.2545e+19j, -1.4828e+13-1.1509e+13j,
         -5.3467e+16-2.6890e+16j,  3.9061e+13-4.2480e+13j, -1.3268e+18+1.1204e+19j,
          4.2014e+14+1.1959e+14j,  8.6712e+15+7.6094e+15j]])

In [62]: z2.dtype
Out[62]: torch.complex64

In [63]: abs(z2)
Out[63]: 
tensor([[       inf,        inf, 4.9415e+17, 8.4422e+15, 1.8913e+14, 5.5031e+12,        inf,
                inf],
        [       inf, 2.2696e+17, 1.7168e+15, 1.7434e+13,        inf,        inf, 2.0194e+16,
         1.0750e+13],
        [3.9358e+17, 1.1290e+15, 4.2560e+12,        inf, 4.6320e+16, 1.6582e+13,        inf,
         3.1609e+13],
        [4.0665e+15, 4.0155e+12,        inf, 4.1361e+14,        inf, 4.6770e+12, 2.4825e+18,
         4.6851e+16],
        [4.5122e+13,        inf, 1.8770e+13, 5.9848e+16, 5.7709e+13, 1.1282e+19, 4.3683e+14,
         1.1537e+16],
        [1.1676e+12, 4.4312e+12, 3.7887e+18, 3.7261e+18, 1.2090e+14, 1.8193e-01, 1.2091e-01,
         4.4696e+16],
        [1.1676e+12, 4.4312e+12, 3.7887e+18, 3.7261e+18, 1.2090e+14, 1.8193e-01, 1.2091e-01,
         4.4696e+16],
        [4.5122e+13,        inf, 1.8770e+13, 5.9848e+16, 5.7709e+13, 1.1282e+19, 4.3683e+14,
         1.1537e+16]])

In [64]: abs(z2[:7, :7])
Out[64]: 
tensor([[3.1440e+21, 3.6400e+19, 4.9415e+17, 8.4422e+15, 1.8913e+14, 5.5031e+12, 3.6479e+22],
        [3.6417e+19, 2.2696e+17, 1.7168e+15, 1.7434e+13, 6.4960e+22, 2.9520e+19, 2.0194e+16],
        [3.9358e+17, 1.1290e+15, 4.2560e+12, 5.6624e+20, 4.6320e+16, 1.6582e+13, 2.5248e+20],
        [4.0665e+15, 4.0155e+12, 4.3863e+19, 4.1361e+14, 2.7393e+19, 4.6770e+12, 2.4825e+18],
        [4.5122e+13, 5.2797e+19, 1.8770e+13, 5.9848e+16, 5.7709e+13, 1.1282e+19, 4.3683e+14],
        [1.1676e+12, 4.4312e+12, 3.7887e+18, 3.7261e+18, 1.2090e+14, 1.8193e-01, 1.2091e-01],
        [1.1676e+12, 4.4312e+12, 3.7887e+18, 3.7261e+18, 1.2090e+14, 1.8193e-01, 1.2091e-01]])

@lezcano
Collaborator

lezcano commented Apr 19, 2023

Ugh, yeah, that's possible. On CPU, vectorisation kicks in once you have 8 or more elements, and it looks like the vectorised implementation of abs(complex64) is not as numerically stable as it could be:
https://github.com/pytorch/pytorch/blob/e605b5df744cf5b694b99a206baaccb8eb0d4e0b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h#L682-L683
(It naïvely computes sqrt(a*a + b*b).)
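An illustration in plain NumPy float32 (not the actual vectorised kernel) of why sqrt(a*a + b*b) overflows for these values, and the standard scaling trick that avoids it:

```python
import numpy as np

# First element of z2 above: well inside float32 range...
a = np.float32(1.6183e21)
b = np.float32(-2.6955e21)

# ...but a*a is ~2.6e42, past float32's max (~3.4e38), so the
# naive formula overflows to inf.
with np.errstate(over="ignore"):
    naive = np.sqrt(a * a + b * b)

# Standard trick: factor out the larger magnitude so the squared
# ratio stays <= 1 and cannot overflow.
big, small = max(abs(a), abs(b)), min(abs(a), abs(b))
r = small / big
stable = big * np.sqrt(np.float32(1) + r * r)
print(naive, stable)  # inf vs ~3.144e+21, matching the [:7, :7] output
```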

lezcano added a commit to pytorch/pytorch that referenced this pull request Apr 19, 2023
@ev-br found in
Quansight-Labs/numpy_pytorch_interop#117 (comment)
that the precision of `abs()` for large values in the vectorised case is less-than-good.
This PR fixes this issue. While doing that, we are able to comment out a
few tests on extremal values.

[ghstack-poisoned]
@ev-br
Collaborator Author

ev-br commented Apr 20, 2023

OK, we have four examples, and they all basically work (in eager mode). Let's merge, so we have a link to add to the RFC.

@ev-br ev-br merged commit f7bc826 into main Apr 20, 2023
@ev-br ev-br deleted the e2e_tests branch April 20, 2023 07:41
@lezcano
Collaborator

lezcano commented Apr 20, 2023

This is great! Could you add a short readme, similar to the OP, in the folder?

@ev-br
Collaborator Author

ev-br commented Apr 20, 2023

Is this what you meant: #120

@jisaacso

I noticed your Mandelbrot example has some timeit benchmarking, was there anything interesting to report on performance differences?

@ev-br
Collaborator Author

ev-br commented Apr 20, 2023

Nothing really interesting beyond "sloth-slow in eager mode". How slow exactly I did not time; it would be more interesting once this is integrated with torch.dynamo, I guess.

And it's not mine really, all credit goes to Nicolas Rougier :-).

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Apr 24, 2023
@ev-br found in
Quansight-Labs/numpy_pytorch_interop#117 (comment)
that the precision of `abs()` for large values in the vectorised case is less-than-good.
This PR fixes this issue. While doing that, we are able to comment out a
few tests on extremal values.

Fixes #53958 #48486

Pull Request resolved: #99550
Approved by: https://github.com/ngimel, https://github.com/peterbell10