In [None]:
!which python

In [4]:
import jax 
import jax.numpy as jnp
import flax.linen as nn
import einops
# force jax to use only device 0 with os library "CUDA_VISIBLE_DEVICES" 
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [7]:
!pip3 freeze | grep jax 
!pip3 freeze | grep optax
!pip3 freeze | grep flax

  pid, fd = os.forkpty()


jax==0.4.25
jaxlib==0.4.25+cuda11.cudnn86
optax==0.2.2
flax==0.8.4


In [8]:
shared_params = False
if shared_params:
    BatchDense = nn.vmap(nn.Dense, in_axes=0, out_axes=0, variable_axes={'params': None}, split_rngs={'params': False})
else:
    BatchDense = nn.vmap(nn.Dense, in_axes=0, out_axes=0, variable_axes={'params': 0}, split_rngs={'params': True})

In [9]:
batch_size = 10
embedding_dim = 5
x = jnp.ones((batch_size, embedding_dim))
dense = BatchDense(features=3)
params = dense.init(jax.random.PRNGKey(0), x)

In [10]:
# display GPU used by jax 
jax.devices()

[cuda(id=0)]

In [11]:
# write a function that prints recursively the keys and shapes of a nested dictionary
def map_nested_fn(fn):
    """Recursively apply `fn to the key-value pairs of a nested dict / pytree."""

    def map_fn(k_in, nested_dict):
        if k_in != None: print(f'{k_in}', end=' ')
        return {
            k: (map_fn(k, v) if hasattr(v, "keys") else fn(k, v))
            for k, v in nested_dict.items()
        }
    
    return map_fn

map_nested_fn(lambda k, v: print(f'\n {k, v.shape}'))(None, params)

params 
 ('bias', (10, 3))

 ('kernel', (10, 5, 3))


{'params': {'bias': None, 'kernel': None}}

In [12]:
y = dense.apply(params, x)

In [13]:
y

Array([[ 1.9910171 ,  0.2838356 , -1.0479922 ],
       [ 1.3863921 ,  0.5475616 ,  1.2088562 ],
       [-1.3427643 , -0.5816939 , -0.02100414],
       [-1.6752064 , -0.6873529 , -1.4142421 ],
       [ 0.7120553 ,  0.13763508, -0.4718809 ],
       [ 0.5991874 , -0.1374764 ,  0.4084237 ],
       [-0.18032219, -0.48730594, -1.0494049 ],
       [-0.04323459, -0.40114638,  0.6318799 ],
       [ 0.50989735,  0.20760432,  0.13128321],
       [ 0.5380559 , -1.664613  ,  1.1062526 ]], dtype=float32)

In [14]:
def step(x_k_1, u_k):
    x_k = x_k_1 + u_k
    y_k = jnp.tanh(x_k)
    return x_k, y_k
x_last, y_entire = jax.lax.scan(step, 0, jnp.ones((5,)))

In [15]:
x_last

Array(5., dtype=float32)

In [16]:
y_entire

Array([0.7615942 , 0.9640275 , 0.9950547 , 0.9993292 , 0.99990916],      dtype=float32)

In [17]:
print([float(jnp.tanh(i)) for i in range(5)])

[0.0, 0.7615941762924194, 0.9640275239944458, 0.9950547218322754, 0.9993292093276978]
