In [1]:
import cupy as cp
import numpy as np

import time

In [2]:
tensor_mul = cp.ElementwiseKernel(in_params = "float32 x, float32 y", out_params = "float32 z", operation = "z = x * y", name = "tensor_mul")

tensor_mul(cp.array([1, 2, 3], dtype=cp.float32), cp.array([4, 5, 6], dtype=cp.float32))

array([ 4., 10., 18.], dtype=float32)

In [3]:
batched_a = cp.random.rand(1000, 100, 100)
batched_b = cp.random.rand(1000, 100, 100)

start_time = time.time()
# Perform batched matrix multiplication
batched_result = cp.matmul(batched_a, batched_b)
end_time = time.time()

print(f"Batched Result Shape: {batched_result.shape}")
print(f"CuPy Batched Time: {end_time - start_time} seconds")


start_time = time.time()
divided_result = cp.array([cp.dot(batched_a[i, ::], batched_b[i, ::]) for i in range(1000)])
end_time = time.time()
print(f"Divided Result Shape: {divided_result.shape}")
print(f"CuPy Divided Time: {end_time - start_time} seconds")

print("Batched result is equal to divided result:", cp.allclose(batched_result, divided_result))


Batched Result Shape: (1000, 100, 100)
CuPy Batched Time: 0.05359697341918945 seconds
Divided Result Shape: (1000, 100, 100)
CuPy Divided Time: 0.0977480411529541 seconds
Batched result is equal to divided result: True


In [4]:
batched_a = cp.random.rand(10, 100)

def corr_batch(A:cp.ndarray) -> cp.ndarray:

    numerator = cp.dot(A, A.T)
    denominator = cp.sqrt(
        cp.dot(
            cp.dot(
                (cp.dot(A, A.T) * cp.eye(A.shape[0])), cp.ones((A.shape[0], A.shape[0]))
            ),
            (cp.dot(A, A.T) * cp.eye(A.shape[0]))
        )
    )

    return numerator / denominator

corr_batch(batched_a).shape


(10, 10)

In [5]:
batched_a = np.random.rand(3, 100)

def corr_batch_np(A:np.ndarray) -> np.ndarray:

    numerator = np.dot(A, A.T)
    denominator = np.sqrt(
        np.dot(
            np.dot(
                (np.dot(A, A.T) * np.eye(A.shape[0])), np.ones((A.shape[0], A.shape[0]))
            ),
            (np.dot(A, A.T) * np.eye(A.shape[0]))
        )
    )

    return numerator / denominator

corr_batch_np(batched_a).shape

(3, 3)

In [6]:
def legacy_corr(A:np.ndarray) -> np.ndarray:

    rho = np.array([[np.corrcoef(A[i, ::], A[j, ::])[0, 1] for j in range(A.shape[0])] for i in range(A.shape[0])])

    return rho

In [7]:
A = np.random.rand(10, 100)
np.testing.assert_allclose(legacy_corr(A), corr_batch_np(A - np.mean(A, axis=1, keepdims=True)))

In [8]:
A = np.random.rand(1000, 100)

start_time = time.time()
_ = legacy_corr(A)
end_time = time.time()
print(f"Legacy implementation: {end_time - start_time} seconds")

start_time = time.time()
_ = corr_batch_np(A - np.mean(A, axis=1, keepdims=True))
end_time = time.time()
print(f"numpy implementation: {end_time - start_time} seconds")

A = cp.array(A)

start_time = time.time()
_ = corr_batch(A - cp.mean(A, axis=1, keepdims=True))
end_time = time.time()
print(f"CuPy implementation: {end_time - start_time} seconds")

Legacy implementation: 49.12901020050049 seconds
numpy implementation: 0.06891942024230957 seconds
CuPy implementation: 0.03535175323486328 seconds


In [None]:
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def foo(x: jnp.ndarray) -> jnp.ndarray:
    bar = jnp.sum(x) - 1.0
    return jnp.array([bar])

len(foo(jnp.array([1, 2, 3])))  # Should return 3

TypeError: iteration over a 0-d array