<a href="https://colab.research.google.com/github/artiomka/flax/blob/main/Learning_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install -q --upgrade jax jaxlib
!pip install -q flax

[?25l[K     |▌                               | 10 kB 25.8 MB/s eta 0:00:01[K     |█                               | 20 kB 13.5 MB/s eta 0:00:01[K     |█▍                              | 30 kB 10.1 MB/s eta 0:00:01[K     |█▉                              | 40 kB 9.2 MB/s eta 0:00:01[K     |██▎                             | 51 kB 5.3 MB/s eta 0:00:01[K     |██▊                             | 61 kB 5.8 MB/s eta 0:00:01[K     |███▎                            | 71 kB 5.8 MB/s eta 0:00:01[K     |███▊                            | 81 kB 6.5 MB/s eta 0:00:01[K     |████▏                           | 92 kB 6.3 MB/s eta 0:00:01[K     |████▋                           | 102 kB 5.4 MB/s eta 0:00:01[K     |█████                           | 112 kB 5.4 MB/s eta 0:00:01[K     |█████▌                          | 122 kB 5.4 MB/s eta 0:00:01[K     |██████                          | 133 kB 5.4 MB/s eta 0:00:01[K     |██████▌                         | 143 kB 5.4 MB/s eta 0:00:01[K  

In [1]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [4]:
import flax.linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence, List, Tuple, Dict
from jax import random

In [3]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [5]:
class MLP(nn.Module):
  """Fully connected network: ReLU hidden layers followed by a linear output layer.

  Attributes:
    features: output width of each Dense layer, in order. Every entry except
      the last is followed by a ReLU; the final Dense layer is left linear.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    # Split the widths up front: hidden layers get a ReLU, the last one does not.
    hidden_widths, out_width = self.features[:-1], self.features[-1]
    for width in hidden_widths:
      dense = nn.Dense(width)
      x = nn.relu(dense(x))
    return nn.Dense(out_width)(x)

In [6]:
model = MLP([12, 8, 4])

In [7]:
x = jnp.ones((32, 10))

In [8]:
variables = model.init(jax.random.PRNGKey(0), x)

In [9]:
list(variables.keys())

['params']

In [10]:
output = model.apply(variables, x)

In [11]:
output.shape

(32, 4)

In [12]:
class CNN(nn.Module):
  """Small convolutional classifier.

  Two Conv(3x3) -> ReLU -> 2x2 average-pool stages (32 then 64 filters), then
  flatten -> Dense(256) -> ReLU -> Dense(10) -> log_softmax. The result is a
  (batch, 10) array of log-probabilities.

  Expects channels-last input (Flax's convention), e.g. shape (32, 128, 128, 3).
  """

  @nn.compact
  def __call__(self, x):
    # Feature extractor: the same stage repeated with a wider conv each time.
    for num_filters in (32, 64):
      x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    # Flatten everything except the batch axis before the dense head.
    flat = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256)(flat))
    logits = nn.Dense(features=10)(hidden)
    return nn.log_softmax(logits)




In [13]:
model2 = CNN()

In [14]:
x = jnp.ones((32, 128, 128, 3))

In [15]:
variables = model2.init(jax.random.PRNGKey(0), x)

In [16]:
output = model2.apply(variables, x)

In [17]:
output.shape

(32, 10)

In [18]:
from jax import grad

In [19]:
def f(x):
  """Piecewise function: 2 * x**3 when x > 0, otherwise 3 * x.

  The branch is picked with ordinary Python control flow on the value of `x`.
  This is differentiated with `grad` in the next cells, where a concrete value
  is available.
  NOTE: the name `f` is rebound later in this notebook by the pmap/psum
  example, so this definition is only live until that cell runs.
  """
  return 2 * x ** 3 if x > 0 else 3 * x

In [20]:
key = random.PRNGKey(0)
x = random.normal(key, ())

In [21]:
print(grad(f)(x))
print(grad(f)(-x))

3.0
0.25422648


# JIT

In [22]:
from jax import jit

In [23]:
# Sample a 500x500 standard-normal matrix and run a jit-compiled computation
# on it, so this section actually exercises `jit` (imported above).
key = random.PRNGKey(0)
x = random.normal(key, (500, 500))


@jit
def gram(m):
  """Return m @ m.T; XLA-compiled on the first call for a given shape/dtype."""
  return jnp.dot(m, m.T)


# block_until_ready() waits for the asynchronously dispatched result.
x_gram = gram(x).block_until_ready()


# VMAP

In [24]:
from jax import vmap

In [25]:
print(vmap(lambda x: x**2)(jnp.arange(8)))

[ 0  1  4  9 16 25 36 49]


In [26]:
from jax import make_jaxpr

In [27]:
make_jaxpr(jnp.dot)(jnp.ones(8), jnp.ones(8))

{ lambda  ; a b.
  let c = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] a b
  in (c,) }

In [28]:
make_jaxpr(vmap(jnp.dot))(jnp.ones((10, 8)), jnp.ones((10, 8)))

{ lambda  ; a b.
  let c = dot_general[ dimension_numbers=(((1,), (1,)), ((0,), (0,)))
                       precision=None
                       preferred_element_type=None ] a b
  in (c,) }

# PMAP

In [29]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [30]:
from jax import pmap

In [31]:
y = pmap(lambda x: x**2)(jnp.arange(8))

In [32]:
keys = random.split(random.PRNGKey(0), 8)

In [33]:
mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)

In [34]:
mats.shape

(8, 5000, 5000)

In [35]:
result = pmap(jnp.dot)(mats, mats)

In [36]:
result.shape

(8, 5000, 5000)

## Communication operations

In [37]:
from functools import partial
from jax.lax import psum

In [38]:
@partial(pmap, axis_name='i')
def f(x):
  """Divide each device's value by the sum over all devices on mapped axis 'i'.

  Returns:
    A pair (x / total, total). `total` is the cross-device `psum` of `x`, so
    every device receives the same total.

  NOTE: this rebinds `f`, shadowing the earlier scalar example used with `grad`.
  """
  across_devices = psum(x, 'i')
  normalized_share = x / across_devices
  return normalized_share, across_devices

In [39]:
normalized, total = f((jnp.arange(8.)))

In [40]:
print('normalized', normalized)

normalized [0.         0.03571429 0.07142857 0.10714287 0.14285715 0.17857143
 0.21428573 0.25      ]


In [41]:
print(total)

[28. 28. 28. 28. 28. 28. 28. 28.]


## PJIT

In [44]:
from jax.experimental.pjit import pjit, PartitionSpec as P
from jax.experimental.maps import mesh
from jax import lax

In [45]:
def conv(image, kernel):
  """2-D convolution with unit stride and 'SAME' padding.

  Uses `lax.conv`'s default layouts: `image` is NCHW and `kernel` is OIHW.
  For example, an image of shape (1, 16, H, W) and a kernel of shape
  (8, 16, 5, 5) give an output of shape (1, 8, H, W).

  Defined with `def` rather than an assigned lambda (PEP 8, E731), so the
  function has a real name in tracebacks and traced programs.
  """
  return lax.conv(image, kernel, window_strides=(1, 1), padding='SAME')

In [48]:
import numpy as np

# Input in NCHW layout (lax.conv's default): 1 image, 16 channels, 2000 x 1000.
image = jnp.ones((1, 16, 2000, 1000), dtype=jnp.float32)
# Kernel in OIHW layout: 8 output channels, 16 input channels, 5x5 window.
# A seeded generator keeps the kernel (and so the conv output) reproducible
# across Restart & Run All; the unseeded global np.random did not.
kernel = jnp.asarray(
    np.random.default_rng(0).random((8, 16, 5, 5), dtype=np.float32))
np.set_printoptions(edgeitems=1)

In [49]:
out = conv(image, kernel)
out.shape

(1, 8, 2000, 1000)

In [50]:
mesh_shape = (4,2)
mesh_devices = np.array(jax.devices()).reshape(mesh_shape)

In [51]:
mesh_axis_names = ('x', 'y')

In [52]:
image_partitions = P(None, None, 'x', 'y')
parallel_conv = pjit(conv, in_axis_resources=(image_partitions, None), 
                     out_axis_resources=image_partitions)

  warn("pjit is an experimental feature and probably has bugs!")


In [56]:
# Run the pjit-partitioned conv inside the device mesh. This binds the 'x'/'y'
# names used in the PartitionSpecs to the (4, 2) grid of devices.
# NOTE(review): this cell raised a RuntimeError when last executed. The saved
# output only shows "RuntimeError: ignored", so the root cause is not visible
# here. TODO: investigate before relying on `r`.
with mesh(mesh_devices, mesh_axis_names):
  r = parallel_conv(image, kernel)
r.shape

RuntimeError: ignored