# JAX vs PyTorch

# 1. JAX

In [None]:
import os

# XLA reads these variables when the JAX backend is initialized, so set them
# before importing/using JAX to be sure they take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Fraction of GPU memory JAX preallocates for its allocator (JAX default: 75%).
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.95'

import jax
import jax.numpy as jnp

## Matrix multiplication: 

### Square matrices:

#### Flop:
`A (bs, N, N)` is a batch of `bs` square matrices and `B (N, N)` is a single square matrix, all of size `N x N`.

Then we have: 

$ bs \text{ matrix multiplications } * N \text{ rows } * N \text{ columns } * (N \text{ multiplications } + (N-1) \text{ additions})$

$= bs * N^2 * (2N-1) \text{ flop} $

#### Memory: 
When we use `vmap`, the `bs` matrix multiplications run in parallel. The memory footprint therefore grows from $3N^2$ elements (two inputs and one output) to $(2\,bs + 1)N^2$ elements. The batched input `A` and the batched output each hold $bs \cdot N^2$ elements, while the common matrix `B` ($N^2$ elements) is shared by all the multiplications. In float32 (4 bytes per element) this is $4(2\,bs + 1)N^2$ bytes.

### Rectangular matrices:
For `A (bs, N, M)` and `B (M, P)` we have:
$ bs \text{ matrix multiplications } * N \text{ rows } * P \text{ columns } * (M \text{ multiplications } + (M-1) \text{ additions}) $

$= bs * N * P * (2M-1) \text{ flop} $

### Results for Square matrices:

| N     | bs    | Flop              | Time     | Flops        | Memory       |
|-------|-------|-------------------|----------|--------------|--------------|
| 10000 | 1     | 1'999'900'000'000 |    24 ms | 83.33 TFLOPS |  1.12 GB     |
| 5000  | 1     |   249'975'000'000 |  3.39 ms | 73.74 TFLOPS |   286 MB     |
| 2048  | 1     |    17'175'674'880 |   210 us | 81.79 TFLOPS |    48 MB     |
| 2000  | 1     |    15'996'000'000 |   205 us | 78.03 TFLOPS |    46 MB     |
| 1500  | 1     |     6'747'750'000 |   160 us | 42.17 TFLOPS |    26 MB     |
| 1024  | 1     |     2'146'435'072 |    37 us | 58.01 TFLOPS |    12 MB     |
| 1000  | 1     |     1'999'000'000 |  34.5 us | 58.79 TFLOPS |    11 MB     |
| 500   | 1     |       249'750'000 |    45 us |  5.55 TFLOPS |     3 MB     |
| 1000  | 100   |   199'900'000'000 |  2.53 ms | 79.01 TFLOPS |   766 MB     |
| 1000  | 1000  | 1'999'000'000'000 |  24.5 ms | 81.59 TFLOPS |  7.45 GB     |
| 1024  | 1024  | 2'197'949'513'728 |  26.1 ms | 84.21 TFLOPS |     8 GB     |

In [None]:
def flop_compute(N, bs, M=None, P=None):
    """Floating-point operation count of a batched matrix product.

    A has shape (bs, N, M) and B has shape (M, P). Each of the bs products has
    N * P output entries, and each entry needs M multiplications and M - 1
    additions, so the total is bs * N * P * (2M - 1) flop.

    Args:
        N: rows of each A matrix. With M and P left as None this is the square
           case, bs * N**2 * (2N - 1).
        bs: batch size (number of independent matrix products).
        M: inner (contracted) dimension; defaults to N.
        P: number of columns of B; defaults to N.

    Returns:
        The flop count.
    """
    M = N if M is None else M
    P = N if P is None else P
    # An empty inner dimension needs no work (avoids a negative count for M = 0).
    per_entry = max(2 * M - 1, 0)
    return bs * N * P * per_entry
def memory(N, bs, bytes_per_elem=4):
    """Bytes held by the buffers of a vmapped square matmul.

    The batched input A (bs, N, N) and the batched output (bs, N, N) each hold
    bs * N**2 elements, while the shared matrix B (N, N) holds N**2, giving
    (2 * bs + 1) * N**2 elements in total.

    Args:
        N: matrix side length.
        bs: batch size.
        bytes_per_elem: size of one element in bytes; defaults to 4 (float32).

    Returns:
        Total size in bytes.
    """
    n_elements = (2 * bs + 1) * N**2
    return n_elements * bytes_per_elem
def format_nb(nb):
    """Render a number with apostrophes as thousands separators, e.g. 1'234'567."""
    grouped = format(nb, ',')
    return grouped.replace(',', "'")
def flops_compute(N, bs, time_in_s):
    """Print one row of the square-matmul benchmark table.

    Columns: N, batch size, total flop (apostrophe-grouped), the achieved
    throughput in TFLOPS for the measured `time_in_s` (seconds), and the
    buffer memory expressed in GB, MB and KB (powers of 1024).
    """
    total_flop = flop_compute(N, bs)
    total_bytes = memory(N, bs)
    tflops = total_flop / time_in_s / 1e12
    gib = total_bytes / (1024**3)
    mib = total_bytes / (1024**2)
    kib = total_bytes / (1024)
    row = f'{N:<6} | {bs:<5} | {format_nb(total_flop):>17} | {tflops:>6.4} | {gib:.2f} GB - {mib:.2f} MB - {kib:.2f} KB'
    print(row)

# One (N, bs, measured time in seconds) triple per row of the results table.
for bench_n, bench_bs, bench_time_s in [
    (10000,    1, 0.024),
    ( 5000,    1, 0.00339),
    ( 2048,    1, 0.000210),
    ( 2000,    1, 0.000205),
    ( 1500,    1, 0.000160),
    ( 1024,    1, 0.000037),
    ( 1000,    1, 0.000034),
    (  500,    1, 0.000045),
    ( 1000,  100, 0.00253),
    ( 1000, 1000, 0.0245),
    ( 1024, 1024, 0.0261),
]:
    flops_compute(bench_n, bench_bs, bench_time_s)



10000  | 1     | 1'999'900'000'000 |  83.33 | 1.12 GB - 1144.41 MB - 1171875.00 KB
5000   | 1     |   249'975'000'000 |  73.74 | 0.28 GB - 286.10 MB - 292968.75 KB
2048   | 1     |    17'175'674'880 |  81.79 | 0.05 GB - 48.00 MB - 49152.00 KB
2000   | 1     |    15'996'000'000 |  78.03 | 0.04 GB - 45.78 MB - 46875.00 KB
1500   | 1     |     6'747'750'000 |  42.17 | 0.03 GB - 25.75 MB - 26367.19 KB
1024   | 1     |     2'146'435'072 |  58.01 | 0.01 GB - 12.00 MB - 12288.00 KB
1000   | 1     |     1'999'000'000 |  58.79 | 0.01 GB - 11.44 MB - 11718.75 KB
500    | 1     |       249'750'000 |   5.55 | 0.00 GB - 2.86 MB - 2929.69 KB
1000   | 100   |   199'900'000'000 |  79.01 | 0.75 GB - 766.75 MB - 785156.25 KB
1000   | 1000  | 1'999'000'000'000 |  81.59 | 7.45 GB - 7633.21 MB - 7816406.25 KB
1024   | 1024  | 2'197'949'513'728 |  84.21 | 8.00 GB - 8196.00 MB - 8392704.00 KB


In [None]:
# Two N x N matrices of ones as matmul inputs.
N = 1024
A, B = jnp.ones((N, N)), jnp.ones((N, N))

In [None]:
def matmul(A, B):
    """Matrix product of A and B computed with jnp.dot."""
    product = jnp.dot(A, B)
    return product


# Compiled version; jax.jit traces and compiles on the first call for each input shape/dtype.
jit_matmul = jax.jit(matmul)

In [None]:
# Warm-up: the first call traces and compiles. JAX dispatches work asynchronously,
# so wait for the result before any timing starts.
_ = jit_matmul(A, B).block_until_ready()

In [None]:
%timeit jit_matmul(A, B)  

36.3 μs ± 9.16 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


**<font color='green'>SOLVED: ```jax.lib.xla_bridge.get_backend()```</font>**

<font color='red'>I DON'T UNDERSTAND</font>

This code runs:
```python
bs = 1024
N = 1024
B = jnp.ones((N, N))
C = jnp.ones((bs, N, N))
jit_vmap = jax.jit(jax.vmap(matmul, in_axes=(0, None), out_axes=0))
_ = jit_vmap(C, B)  # warmup
```

<font color='red'> But if I restart the kernel and run the following code, I get an OOM error in its second part (which is exactly the same as the code above):</font>
```python
bs = 2048
N = 1024
B = jnp.ones((N, N))
C = jnp.ones((bs, N, N))
jit_vmap = jax.jit(jax.vmap(matmul, in_axes=(0, None), out_axes=0))
_ = jit_vmap(C, B)  # warmup
>>> RuntimeError: Resource exhausted: Out of memory while trying to allocate 8.00GiB

bs = 1024
N = 1024
B = jnp.ones((N, N))
C = jnp.ones((bs, N, N))
jit_vmap = jax.jit(jax.vmap(matmul, in_axes=(0, None), out_axes=0))
_ = jit_vmap(C, B)  # warmup
>>> RuntimeError: Resource exhausted: Out of memory while trying to allocate 4.00GiB
```

In [None]:
# Free every device buffer JAX still holds (e.g. arrays left behind by a cell
# that hit OOM). jax.live_arrays() is the public replacement for the deprecated
# jax.lib.xla_bridge.get_backend().live_buffers().
# NOTE: this invalidates every existing array (A, B, C, ...); re-create them
# before using them again.
live = jax.live_arrays()
print(len(live))
for arr in live:
    arr.delete()

3


In [None]:
# Re-run the batched matmul with a larger batch after releasing what the
# previous run left on the GPU.
jax.clear_caches()  # drop compiled executables

# Delete the previous inputs and compiled function if they exist.
if 'C' in globals():
    print('Deleting C', C.shape)
    del C
globals().pop('B', None)  # re-created just below

if 'jit_vmap' in globals():
    print('Deleting jit_vmap')
    del jit_vmap

# `_` still references the previous warm-up result, a full (bs, N, N) device
# array; unless it is dropped as well, that buffer stays allocated.
globals().pop('_', None)

bs = 1024 + 512
N = 1024
B = jnp.ones((N, N))
C = jnp.ones((bs, N, N))
jit_vmap = jax.jit(jax.vmap(matmul, in_axes=(0, None), out_axes=0))
_ = jit_vmap(C, B).block_until_ready()  # warmup: compile, run and wait

Deleting C (1024, 1024, 1024)
Deleting jit_vmap


2025-03-01 12:28:09.618073: W external/tsl/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.02GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6459228160 bytes.

In [None]:
%timeit jit_vmap(C, B)

26.1 ms ± 4.47 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
# vmap applied outside jit (compare with jit_vmap, where jit wraps the vmap).
vmap_jit = jax.vmap(jax.jit(matmul), in_axes=(0, None), out_axes=0)
_ = vmap_jit(C, B).block_until_ready()  # warmup; wait for the async result

In [None]:
%timeit vmap_jit(C, B)  

26.1 ms ± 3.21 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
def mvm(A, b):
    """jnp.dot(A, b); used below as a matrix-vector product (A a matrix, b a vector)."""
    result = jnp.dot(A, b)
    return result

jit_mvm = jax.jit(mvm)
# Map over the rows of B (in_axes=(None, 0)): row i of the result is A @ B[i].
vmap_jit_mvm = jax.vmap(jit_mvm, in_axes=(None, 0), out_axes=0)
_ = vmap_jit_mvm(A, B).block_until_ready()  # warmup; wait for the async result

In [None]:
%timeit vmap_jit_mvm(A, B)

428 μs ± 12.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# jit wrapping vmap: the whole batched computation is traced and compiled as
# one function (compare the timing with vmap_jit_mvm above).
jit_vmap_mvm = jax.jit(jax.vmap(mvm, in_axes=(None, 0), out_axes=0))
_ = jit_vmap_mvm(A, B).block_until_ready()  # warmup; wait for the async result

In [None]:
%timeit jit_vmap_mvm(A, B)  

34.5 μs ± 31.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Benchmarking the softmax function of a vector
- x of size N
- exp(x) = N flop
- sum(exp(x)) = N-1 flop (can be highly optimized)
- exp(x) / sum(exp(x)) = N flop
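- Total flop = 3N flop (numerically stable implementations such as `jax.nn.softmax` also subtract max(x) first, which adds roughly 2N more operations)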

In [None]:
jit_softmax = jax.jit(jax.nn.softmax)
x = jax.random.normal(jax.random.PRNGKey(0), (500000,))
# Warm-up: compile, then wait for the asynchronously dispatched result.
a = jit_softmax(x).block_until_ready()
a.shape

(500000,)

In [None]:
%timeit jax.nn.softmax(x)
%timeit jit_softmax(x)  # 1.5 s

227 μs ± 45.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Softmax costs ~3N flop (see above); N = 500'000 elements, 227 us measured above.
n_elements = 500 * 1e3
elapsed_s = 227 * 1e-6
t = 3 * n_elements / elapsed_s
print(f'{t:.2e} FLOPS')

6.61e+09 FLOPS


: 

In [None]:
x = jax.random.normal(jax.random.PRNGKey(0), (500000,))
cpu_device = jax.devices("cpu")[0]
gpu_device = jax.devices("gpu")[0]

# jax.jit's `device=` argument is deprecated. Instead, place the input on the
# target device with jax.device_put; the jitted function runs where its input lives.
_placed_softmax = jax.jit(jax.nn.softmax)


def cpu_softmax(v):
    """Softmax of `v`, computed on the first CPU device (input copied there)."""
    return _placed_softmax(jax.device_put(v, cpu_device))


def gpu_softmax(v):
    """Softmax of `v`, computed on the first GPU device (input copied there)."""
    return _placed_softmax(jax.device_put(v, gpu_device))


result = cpu_softmax(x).block_until_ready()  # warmup on CPU
result = gpu_softmax(x).block_until_ready()  # warmup on GPU


In [None]:
%timeit cpu_softmax(x)  # 
%timeit gpu_softmax(x)  # 

645 μs ± 60.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
30.9 μs ± 1.54 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
def hand_softmax(x):
    """Softmax over all elements of `x`: exp(x) / sum(exp(x)).

    max(x) is subtracted before exponentiating. This leaves the result
    mathematically unchanged but prevents exp() from overflowing to inf, which
    would otherwise turn the output into nan for large inputs (e.g. x = 1000).
    """
    if jnp.size(x) == 0:  # nothing to normalize; jnp.max would raise on an empty array
        return jnp.exp(x)
    e = jnp.exp(x - jnp.max(x))  # exp computed once and reused
    return e / jnp.sum(e)
jit_hand_softmax = jax.jit(hand_softmax)
result = jit_hand_softmax(x).block_until_ready()  # warmup: compile, run and wait

In [None]:
%timeit hand_softmax(x)  
%timeit jit_hand_softmax(x)  # always 10% faster than jax.nn.softmax, surprisingly

98.5 μs ± 2.77 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.9 μs ± 1.89 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


: 

# 2. PyTorch

In [None]:
import os

# Choose the GPU for PyTorch before torch touches CUDA. CUDA reads this variable
# once per process when it initializes, so if the JAX cells above already ran in
# this kernel (with CUDA_VISIBLE_DEVICES="0") this assignment likely has no
# effect; restart the kernel to really switch GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch

if torch.cuda.is_available():
    print("CUDA is available")
    print(torch.cuda.get_device_name())
    device = torch.device("cuda")
    print(device)
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())
else:
    print("CUDA is not available")
    device = torch.device("cpu")

In [None]:
N = 1024
# Allocate directly on the device selected above (CUDA, or the CPU fallback)
# instead of building on the host and copying with .cuda().
A = torch.ones((N, N), device=device)
B = torch.ones((N, N), device=device)


def matmul(A, B):
    """Matrix product of A and B via torch.matmul."""
    return torch.matmul(A, B)


jit_matmul = torch.jit.script(matmul)  # TorchScript-compiled version
_ = jit_matmul(A, B)  # warmup


CUDA is available
NVIDIA GeForce RTX 4090
cuda
0
1


In [None]:
%timeit jit_matmul(A, B)

47 μs ± 44.2 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%timeit matmul(A, B)

47 μs ± 14.6 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
import time

# CUDA kernels launch asynchronously: synchronize before starting and before
# stopping the clock, otherwise the loop measures kernel launches rather than
# completed matmuls.
n = 10000
if A.is_cuda:
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n):
    matmul(A, B)
if A.is_cuda:
    torch.cuda.synchronize()
end = time.perf_counter()
print(f'{((end - start) / n)*1000:.3f}')  # result in ms

0.04579520225524902


In [None]:
bs = 1024
# Batched input allocated directly on the device: (bs, N, N) float32, i.e. 4 GiB
# for bs = N = 1024.
C = torch.ones((bs, N, N), device=device)

vmap_jit = torch.func.vmap(jit_matmul, in_dims=(0, None), out_dims=0)
vmap = torch.func.vmap(matmul, in_dims=(0, None), out_dims=0)
_ = vmap_jit(C, B)  # warmup
_ = vmap(C, B)  # warmup
if C.is_cuda:
    torch.cuda.synchronize()  # make sure the warm-up work has finished


In [None]:
import time

# Without synchronizing before reading the clock, this loop measures kernel
# launches, not finished work (the unsynchronized result implied a throughput
# above the GPU's FP32 peak).
n = 400
if C.is_cuda:
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n):
    vmap(C, B)
if C.is_cuda:
    torch.cuda.synchronize()
end = time.perf_counter()
print(f'{((end - start) / n)*1000:.3f}')  # result in ms

13.908361792564392


In [None]:
# torch.jit.script cannot compile the wrapper returned by torch.func.vmap (it
# takes *args/**kwargs), so this raises NotSupportedError. Catch the error so
# the notebook still runs top to bottom.
try:
    jit_vmap = torch.jit.script(vmap)
    _ = jit_vmap(C, B)  # warmup
except torch.jit.frontend.NotSupportedError as err:
    print(f'torch.jit.script(vmap) is not supported: {err}')

NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/tristan/miniconda3/envs/.jax_conda_env_LearningJAX/lib/python3.12/site-packages/torch/_functorch/apis.py", line 187
    def wrapped(*args, **kwargs):
                        ~~~~~~~ <--- HERE
        return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
