# PyTorch

In [1]:
import torch


def create_torch_tensors(device, shape=(10000, 10000), dtype=torch.float32):
    """Create two random benchmark operands directly on ``device``.

    Both tensors are filled with ``torch.rand`` (uniform samples on [0, 1)).
    Allocating on the target device avoids first materialising each tensor in
    host memory and then copying it over; at the default 10000x10000 float32
    shape that is ~400 MB per tensor.

    Args:
        device: ``torch.device`` (or a device string such as ``"cpu"`` or
            ``"mps"``) on which the tensors are allocated.
        shape: Shape of each tensor. Defaults to ``(10000, 10000)``.
        dtype: Element type. Defaults to ``torch.float32``.

    Returns:
        Tuple ``(x, y)`` of two independently sampled tensors on ``device``.
    """
    x = torch.rand(shape, dtype=dtype, device=device)
    y = torch.rand(shape, dtype=dtype, device=device)

    return x, y

In [2]:
# Build the PyTorch benchmark operands in host (CPU) memory.
device = torch.device("cpu")
x, y = create_torch_tensors(device=device)

In [3]:
%%timeit
# Elementwise product of the two CPU tensors built above; the result is discarded.
x * y

27.5 ms ± 84.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
# Rebuild the operands on the MPS backend; this rebinds the module-level
# x and y that the next timing cell multiplies.
device = torch.device("mps")
x, y = create_torch_tensors(device=device)

In [5]:
%%timeit
x * y

20.3 ms ± 125 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# TensorFlow

In [6]:
import tensorflow as tf


def create_tf_tensors(device=None, shape=(10000, 10000), dtype=tf.float32):
    """Create two random benchmark operands with ``tf.random.uniform``.

    Args:
        device: Optional TensorFlow device string (e.g. ``"/CPU:0"`` or
            ``"/GPU:0"``) to create the tensors on. ``None`` (the default)
            keeps TensorFlow's default placement, which prefers a registered
            GPU over the CPU; tensors placed that way and later used under
            ``tf.device("/CPU:0")`` have to be copied to the host first.
        shape: Shape of each tensor. Defaults to ``(10000, 10000)``.
        dtype: Element type. Defaults to ``tf.float32``.

    Returns:
        Tuple ``(x, y)`` of two independently sampled tensors with values
        in [0, 1).
    """
    import contextlib

    # nullcontext leaves placement untouched when no device was requested.
    placement = contextlib.nullcontext() if device is None else tf.device(device)
    with placement:
        x = tf.random.uniform(shape, dtype=dtype)
        y = tf.random.uniform(shape, dtype=dtype)

    return x, y


x, y = create_tf_tensors()

2024-10-07 13:03:29.336045: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2024-10-07 13:03:29.336073: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 8.00 GB
2024-10-07 13:03:29.336076: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 2.67 GB
2024-10-07 13:03:29.336389: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-10-07 13:03:29.336399: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [7]:
%%timeit

with tf.device("/CPU:0"):
    x * y

35.2 ms ± 4.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit

# Same elementwise multiply as the other benchmark cells, pinned to the Metal GPU.
# NOTE(review): GPU ops may be dispatched asynchronously, in which case this
# could time the launch rather than the finished multiply. TODO confirm for
# the Metal plugin before comparing against the synchronized PyTorch MPS timing.
with tf.device("/GPU:0"):
    x * y

20.5 ms ± 195 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# Use a dedicated loop name so this cell does not overwrite the module-level
# `device` variable from the PyTorch section.
for physical_device in tf.config.list_physical_devices():
    print(physical_device)

PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')


# JAX

In [10]:
# import os

# os.environ["JAX_PLATFORMS"] = "cpu"

In [11]:
# import jax
# import jax.numpy as jnp


# def create_jax_tensors():
#     x = jax.random.uniform(jax.random.PRNGKey(0), (10000, 10000), dtype=jnp.float32)
#     y = jax.random.uniform(jax.random.PRNGKey(1), (10000, 10000), dtype=jnp.float32)

#     return x, y


# x, y = create_jax_tensors()

In [12]:
# %%timeit
# x * y