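# Source: https://www.youtube.com/watch?v=cGEIEnekmRM&t=301s

Element-wise multiplication of two 10000 × 10000 `float32` tensors, timed on CPU and GPU (Apple M2, Metal) with PyTorch, TensorFlow and JAX. Note that GPU work in PyTorch (MPS) and JAX is dispatched asynchronously, so a meaningful GPU timing has to wait for the result to be ready.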

# PyTorch

In [1]:
import torch

def create_torch_tensors(device, shape=(10000, 10000), dtype=torch.float32):
    """Create two uniform-random tensors on ``device`` for the benchmark.

    Args:
        device: Target device, either a ``torch.device`` or a string such as
            ``"cpu"`` or ``"mps"``.
        shape: Shape of each tensor. Defaults to 10000 x 10000, the size used
            throughout this notebook.
        dtype: Element type. Defaults to ``torch.float32``.

    Returns:
        A pair ``(x, y)`` of independently sampled tensors with values in [0, 1).
    """
    # Sample directly on the target device. Generating on the CPU and then
    # calling ``.to(device)`` allocates a host copy first and pays for a
    # host-to-device transfer of each (400 MB at the default size) tensor.
    x = torch.rand(shape, dtype=dtype, device=device)
    y = torch.rand(shape, dtype=dtype, device=device)
    return x, y

In [2]:
# CPU baseline: build both benchmark operands in host memory.
device = torch.device("cpu")
x, y = create_torch_tensors(device=device)

In [3]:
%%timeit
# Baseline: element-wise multiply of the two CPU tensors created in the previous cell.
x * y

36.7 ms ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Fail fast with a clear message on machines without Apple's Metal (MPS) backend,
# rather than with a less obvious error from inside the tensor-creation helper.
if not torch.backends.mps.is_available():
    raise RuntimeError("MPS backend is not available on this machine; skip the MPS cells.")
device = torch.device("mps")
x, y = create_torch_tensors(device)

In [5]:
%%timeit
x * y

121 μs ± 51.1 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# TensorFlow

In [6]:
import tensorflow as tf

def create_tf_tensors(shape=(10000, 10000), dtype=tf.float32, device=None):
    """Create two uniform-random TensorFlow tensors for the benchmark.

    Args:
        shape: Shape of each tensor (default 10000 x 10000, matching the
            PyTorch and JAX sections).
        dtype: Element type (default ``tf.float32``).
        device: Optional device string such as ``"/CPU:0"`` or ``"/GPU:0"``.
            When ``None`` (the default), TensorFlow's own placement is used.
            Presumably that is the GPU when one is visible (TODO confirm via
            ``x.device``).

    Returns:
        A pair ``(x, y)`` of independently sampled tensors.
    """
    import contextlib

    # Without an explicit scope the tensors land on TensorFlow's default device.
    # A later ``with tf.device(...)`` around an op that reads them may then have
    # to copy them across devices on every call.
    scope = tf.device(device) if device is not None else contextlib.nullcontext()
    with scope:
        x = tf.random.uniform(shape, dtype=dtype)
        y = tf.random.uniform(shape, dtype=dtype)
    return x, y

x, y = create_tf_tensors()

In [7]:
%%timeit

with tf.device("/CPU:0"):
    x * y

The slowest run took 6.68 times longer than the fastest. This could mean that an intermediate result is being cached.
53.4 ms ± 53.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
# x and y were created under TensorFlow's default placement, presumably already the GPU (TODO confirm via x.device).
# If so, no input transfer is included in this timing.
with tf.device("/GPU:0"):
    x * y

14.1 ms ± 317 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# JAX

In [9]:
import os
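# Uncomment the following line to force JAX onto the CPU; leaving it commented out lets JAX use the GPU (Metal) backend. It only takes effect if set before JAX is imported.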
# os.environ['JAX_PLATFORMS'] = 'cpu'

In [10]:
import jax
import jax.numpy as jnp

def create_jax_tensors(seed=0, shape=(10000, 10000), dtype=jnp.float32):
    """Create two independent uniform-random JAX arrays for the benchmark.

    Args:
        seed: Integer seed for ``jax.random.PRNGKey`` (default 0).
        shape: Shape of each array (default 10000 x 10000, matching the
            PyTorch and TensorFlow sections).
        dtype: Element type (default ``jnp.float32``).

    Returns:
        A pair ``(x, y)`` of independently sampled arrays with values in [0, 1).
        The same ``seed`` always yields the same pair.
    """
    # Reusing one key for both draws would make x and y identical (x * y would
    # then really be x ** 2). Split it so each array gets its own random stream.
    key_x, key_y = jax.random.split(jax.random.PRNGKey(seed))
    x = jax.random.uniform(key_x, shape, dtype=dtype)
    y = jax.random.uniform(key_y, shape, dtype=dtype)
    return x, y

# Build the two benchmark operands on JAX's default backend (Metal in the run recorded below, unless JAX_PLATFORMS forces CPU).
x, y = create_jax_tensors()

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1720390904.054307 5528513 service.cc:145] XLA service 0x146b46b50 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1720390904.054324 5528513 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1720390904.055371 5528513 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1720390904.055388 5528513 mps_client.cc:384] XLA backend will use up to 2378711040 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M2

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



In [11]:
%%timeit
x * y

The slowest run took 4.38 times longer than the fastest. This could mean that an intermediate result is being cached.
272 μs ± 156 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
