In [2]:
import time
import torch
import numpy as np
import jax
import jax.numpy as jnp
import tensorflow as tf

2024-08-09 18:22:22.701763: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-09 18:22:22.775704: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-09 18:22:22.796005: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Check device availability
if tf.config.list_physical_devices('GPU'):
    print("TensorFlow GPU")
else:
    print("TensorFlow CPU")

if torch.cuda.is_available():
    print("PyTorch GPU")
else:
    print("PyTorch CPU")

if jax.devices()[0].device_kind == "gpu":
    print("JAX GPU")
else:
    print("JAX CPU")

I0000 00:00:1723207945.456681   21733 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-08-09 18:22:25.666668: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2343] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


TensorFlow CPU
PyTorch CPU
JAX CPU


In [4]:
def measure_time(func, *args, avg_number=100):
    start_time = time.time()
    for _ in range(avg_number):
        func(*args)
    end_time = time.time()
    return ((end_time - start_time) / avg_number) * 1000

In [5]:
def comparsion_operations(size=1000):
    np_a = np.random.rand(size, size).astype(np.float32)
    np_b = np.random.rand(size, size).astype(np.float32)
    pt_a = torch.from_numpy(np_a).to(torch.float32)
    pt_b = torch.from_numpy(np_b).to(torch.float32)
    jax_a = jnp.array(np_a, dtype=jnp.float32)
    jax_b = jnp.array(np_b, dtype=jnp.float32)
    tf_a = tf.convert_to_tensor(np_a)
    tf_b = tf.convert_to_tensor(np_b)

    operations = {
        "Matrix multiplication": [
            (lambda: np.dot(np_a, np_b)),
            (lambda: torch.matmul(pt_a, pt_b)),
            (lambda: jax.numpy.dot(jax_a, jax_b)),
            (lambda: tf.matmul(tf_a, tf_b))
        ],
        "Element wise addition": [
            (lambda: np_a + np_b),
            (lambda: pt_a + pt_b),
            (lambda: jax_a + jax_b),
            (lambda: tf_a + tf_b)
        ],
        "Element wise multiplication": [
            (lambda: np_a * np_b),
            (lambda: pt_a * pt_b),
            (lambda: jax_a * jax_b),
            (lambda: tf_a * tf_b)
        ]
    }

    results = []

    for operation_name, funcs in operations.items():
        np_time = measure_time(funcs[0])
        pt_time = measure_time(funcs[1])
        jax_time = measure_time(funcs[2])
        tf_time = measure_time(funcs[3])

        results.append(f"{operation_name}: NUMPY: {np_time:.3f}ms, PYTORCH: {pt_time:.3f}ms, JAX: {jax_time:.3f}ms, TENSORFLOW: {tf_time:.3f}ms")
    
    return results

In [6]:
comparison = comparsion_operations()
for result in comparison:
    print(result)

Matrix multiplication: NUMPY: 19.012ms, PYTORCH: 13.998ms, JAX: 16.316ms, TENSORFLOW: 17.444ms
Element wise addition: NUMPY: 1.980ms, PYTORCH: 1.739ms, JAX: 2.656ms, TENSORFLOW: 2.040ms
Element wise multiplication: NUMPY: 2.128ms, PYTORCH: 1.522ms, JAX: 2.015ms, TENSORFLOW: 2.091ms


In [11]:
np_a = np.random.rand(10, 10).astype(np.float32)
np_b = np.random.rand(10, 10).astype(np.float32)
pt_a = torch.from_numpy(np_a).to(torch.float32)
pt_b = torch.from_numpy(np_b).to(torch.float32)
jax_a = jnp.array(np_a, dtype=jnp.float32)
jax_b = jnp.array(np_b, dtype=jnp.float32)
tf_a = tf.convert_to_tensor(np_a)
tf_b = tf.convert_to_tensor(np_b)


print(f"DataType of matrix :\n NumPy: {type(np_a)}\n PyTorch: {type(pt_a)}\n JAX: {type(jax_a)}\n TensorFlow: {type(tf_a)}")


DataType of matrix :
 NumPy: <class 'numpy.ndarray'>
 PyTorch: <class 'torch.Tensor'>
 JAX: <class 'jaxlib.xla_extension.ArrayImpl'>
 TensorFlow: <class 'tensorflow.python.framework.ops.EagerTensor'>
