# Torch

In [1]:
import torch


def create_torch_tensors(device, shape=(10000, 10000)):
    """Create two random float32 tensors on ``device``.

    Args:
        device: Target device, e.g. ``torch.device("cpu")`` or
            ``torch.device("mps")``.
        shape: Shape of each tensor. Defaults to ``(10000, 10000)``, the size
            used by the benchmarks in this notebook.

    Returns:
        Tuple ``(x, y)`` of two independently sampled tensors with values
        drawn uniformly from ``[0, 1)``.
    """
    # Sample directly on the target device rather than allocating on the host
    # and copying with ``.to(device)``; this avoids a second full-size buffer
    # and a host-to-device transfer for every tensor.
    x = torch.rand(shape, dtype=torch.float32, device=device)
    y = torch.rand(shape, dtype=torch.float32, device=device)
    return x, y
# CPU baseline: bind the global operands x and y that the timing cell below multiplies.
device = torch.device("cpu")
x, y = create_torch_tensors(device)


In [2]:
%%timeit
# Element-wise product of whatever tensors are currently bound to x and y.
x * y

33.9 ms ± 7.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [3]:
# Rebind the global operands x and y to tensors placed on the "mps" device
# so the next timing cell measures that backend instead of the CPU.
device = torch.device("mps")
x, y = create_torch_tensors(device)

In [4]:
%%timeit
x * y

1.66 ms ± 3.87 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# TensorFlow

In [6]:
import tensorflow as tf
def create_tf_tensors(device=None, shape=(10000, 10000)):
    """Create two random float32 tensors, optionally pinned to a device.

    Args:
        device: Optional device string such as ``"/cpu:0"`` or ``"/gpu:0"``.
            When ``None`` (the default) no device scope is opened and
            TensorFlow decides where the tensors live.
        shape: Shape of each tensor. Defaults to ``(10000, 10000)``, the size
            used by the benchmarks in this notebook.

    Returns:
        Tuple ``(x, y)`` of two tensors produced by ``tf.random.uniform``.
    """

    def _sample():
        # One draw per call so x and y are sampled independently.
        return tf.random.uniform(shape, dtype=tf.float32)

    if device is None:
        return _sample(), _sample()

    # Placement is controlled with a device scope; tensors have no method for
    # moving themselves to another device.
    with tf.device(device):
        return _sample(), _sample()
# Rebind the global operands x and y to TensorFlow tensors; no device scope is
# active here, so their placement is left to TensorFlow.
x, y = create_tf_tensors()

In [7]:
%%timeit
# Element-wise product with the op pinned to the first CPU device.
# NOTE(review): x and y were created outside this device scope, so this may
# also time a copy of the operands to the CPU — confirm their placement.
with tf.device("/cpu:0"):
    x * y

27.8 ms ± 360 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
%%timeit
# Element-wise product with the op pinned to the first GPU device.
# NOTE(review): the result is never read back; if GPU execution is
# asynchronous this may under-report the kernel time — confirm.
with tf.device("/gpu:0"):
    x * y

1.64 ms ± 3.38 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# JAX

In [14]:
import os
# Optional backend override, intentionally left disabled.
# NOTE(review): presumably this has to be set before `import jax` to have any
# effect — confirm before re-enabling.
# os.environ["JAX_PLATFORMS"] = "gpu"

In [1]:
import jax
import jax.numpy as jnp
def create_jax_tensors(shape=(10000, 10000)):
    """Create two deterministic random float32 arrays.

    Args:
        shape: Shape of each array. Defaults to ``(10000, 10000)``, the size
            used by the benchmarks in this notebook.

    Returns:
        Tuple ``(x, y)`` of arrays sampled with ``jax.random.uniform``. The
        PRNG keys are fixed (0 for ``x``, 1 for ``y``), so repeated calls with
        the same shape return the same values.
    """
    # Distinct keys so the two operands are not identical.
    key_x = jax.random.PRNGKey(0)
    key_y = jax.random.PRNGKey(1)

    x = jax.random.uniform(key_x, shape, dtype=jnp.float32)
    y = jax.random.uniform(key_y, shape, dtype=jnp.float32)
    return x, y
# Rebind the global operands x and y to JAX arrays; no device is requested,
# so they live on JAX's default backend.
x, y = create_jax_tensors()

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1716594468.706163  562218 service.cc:145] XLA service 0x600003534300 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1716594468.706185  562218 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1716594468.707329  562218 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1716594468.707339  562218 mps_client.cc:384] XLA backend will use up to 103078821888 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M2 Ultra


In [2]:
%%timeit
x * y

1.64 ms ± 928 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
# Number of elements in each 10000 x 10000 benchmark operand.
10_000 * 10_000

100000000