In [1]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [2]:
import jax
import jax.numpy as jnp
from jax import random
from jax import device_put
from jax.experimental import sparse

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [41]:
import os
from tensorflow.python.profiler import profiler_client

tpu_profile_service_address = os.environ['COLAB_TPU_ADDR'].replace('8470', '8466')
print(profiler_client.monitor(tpu_profile_service_address, 100, 2))

  Timestamp: 17:30:40
  TPU type: TPU v2
  Utilization of TPU Matrix Units (higher is better): 0.000%




In [38]:
size = 8

In [40]:
x = random.uniform(random.PRNGKey(0), (100, 100))
w = random.uniform(random.PRNGKey(0), (100, 100))

def mul(x, w):
  return jnp.matmul(x, w)

%timeit -o y = mul(x,w).block_until_ready()

The slowest run took 21.41 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 1.78 ms per loop


<TimeitResult : 1000 loops, best of 5: 1.78 ms per loop>

In [37]:
x = random.uniform(random.PRNGKey(0), (size, 100, 100))
w = random.uniform(random.PRNGKey(0), (size, 100, 100))

def mul(x, w):
  return jnp.matmul(x, w)

%timeit -o y = jax.pmap(mul)(x,w).block_until_ready()

The slowest run took 12.57 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 11.6 ms per loop


<TimeitResult : 100 loops, best of 5: 11.6 ms per loop>

In [5]:
@jax.jit
def f(w,x):
    return w @ x

data = random.uniform(random.PRNGKey(0), (size*3,))
idxs = random.randint(random.PRNGKey(0), (size*3,2), 0, size)
w = sparse.BCOO((data, idxs), shape=(size, size))
%timeit -o f(w,x).block_until_ready()

The slowest run took 2538.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 75.3 µs per loop


<TimeitResult : 10000 loops, best of 5: 75.3 µs per loop>

In [36]:
x = random.uniform(random.PRNGKey(0), (size, 800))

data = random.uniform(random.PRNGKey(0), (size*3,))
idxs = random.randint(random.PRNGKey(0), (size*3,2), 0, size)
w = sparse.BCOO((data, idxs), shape=(size, size))

my_sp_mat_list_jax = [sparse.BCOO((data, idxs), shape=(size, size)) for i in range(800)]

from jax import lax
from jax import vmap

@jax.jit
def matvec(mat, vec): 
    return mat @ vec 

def concat(arr_list, axis=0):

  data = lax.concatenate([a.data[None] for a in arr_list], dimension=0)
  idxs = lax.concatenate([a.idxs[None] for a in arr_list], dimension=0)

  return sparse.BCOO((data, idxs), shape=(len(arr_list), *arr_list[0].shape))

stacked = concat(my_sp_mat_list_jax)

@jax.jit
def vmap_matvec(stacked, x):
    return vmap(matvec, in_axes=[0, 1], out_axes=1)(stacked, x)

%timeit y = vmap_matvec(stacked, x).block_until_ready()

The slowest run took 41.15 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 8.04 ms per loop


In [None]:
@jax.jit
def f(w,x):
    return w @ x

data = random.uniform(random.PRNGKey(0), (size*3,))
idxs = random.randint(random.PRNGKey(0), (size*3,2), 0, size)
w = sparse.BCOO((data, idxs), shape=(size, size))
f(w,x).block_until_ready()
%timeit f(w,x).block_until_ready()

10 loops, best of 5: 78.7 ms per loop


In [None]:
@jax.jit
def f(w,x):
    return w @ x

data = random.uniform(random.PRNGKey(0), (size*3,))
idxs = random.randint(random.PRNGKey(0), (size*3,2), 0, size)
w = sparse.BCOO((data, idxs), shape=(size, size))
f(w,x).block_until_ready()
%time f(w,x).block_until_ready()

CPU times: user 818 µs, sys: 0 ns, total: 818 µs
Wall time: 81 ms


DeviceArray([[0.7951084],
             [1.2279414],
             [0.       ],
             ...,
             [1.3993609],
             [1.9278833],
             [1.6789727]], dtype=float32)

In [None]:
x = random.uniform(random.PRNGKey(0), (size, 1))

data = random.uniform(random.PRNGKey(0), (size*3,))
idxs = random.randint(random.PRNGKey(0), (size*3,2), 0, size)
w = sparse.BCOO((data, idxs), shape=(size, size))

my_sp_mat_list_jax = [sparse.BCOO((data, idxs), shape=(size, size)) for i in range(2)]

from jax import lax
from jax import vmap

@jax.jit
def matvec(mat, vec): 
    return mat @ vec 

def concat(arr_list, axis=0):

  data = lax.concatenate([a.data[None] for a in arr_list], dimension=0)
  idxs = lax.concatenate([a.idxs[None] for a in arr_list], dimension=0)

  return sparse.BCOO((data, idxs), shape=(len(arr_list), *arr_list[0].shape))

stacked = concat(my_sp_mat_list_jax)

@jax.jit
def vmap_matvec(stacked, x):
    return vmap(matvec, in_axes=(0, None))(stacked, x)

vmap_matvec(stacked, x).block_until_ready()
%timeit -o vmap_matvec(stacked, x).block_until_ready()

KeyboardInterrupt: ignored

In [None]:
import torch
import jax
import jax.numpy as jnp
import jax.dlpack

x = jnp.zeros((1000, 1000))
dlpack = jax.dlpack.to_dlpack(x)
y = torch.utils.dlpack.from_dlpack(dlpack)

x = torch.zeros((1000, 1000)).cuda()
dlpack = torch.utils.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)