<a href="https://colab.research.google.com/github/HarounH/smol/blob/main/rl/learning_jax_flax_etc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# @title Part 1 JAX Foundations
import jax
import jax.numpy as jnp


def f(x):
    """Elementwise ``sin(x) + x**2`` written with jax.numpy instead of numpy."""
    squared = x**2
    return jnp.sin(x) + squared


x = jnp.arange(5.0)
y = f(x)
print(y)


[ 0.         1.841471   4.9092975  9.14112   15.243197 ]


In [2]:
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of the linear prediction ``jnp.dot(x, params)`` against ``y``."""
    residual = jnp.dot(x, params) - y
    return jnp.mean(residual ** 2)


params = jnp.array([0.1, 0.5])
x = jnp.array([[1., 2.],
               [2., 3.]])
y = jnp.array([1., 2.])

o = loss_fn(params, x, y)
print(o)

0.04999999


In [3]:
# Gradient of the MSE loss with respect to its first positional argument, `params`
# (argnums=0 is jax.grad's default, spelled out here for clarity).
grad_fn = jax.grad(loss_fn, argnums=0)
grads = grad_fn(params, x, y)
print(grads)


[-0.49999988 -0.6999998 ]


In [4]:
# Option 1: wrap an existing Python function with jax.jit.
jitted_loss = jax.jit(loss_fn)
print(jitted_loss)


# Option 2: apply jax.jit as a decorator. Note this rebinds the global
# `loss_fn` to the compiled version.
@jax.jit
def loss_fn(params, x, y):
    """Jitted mean squared error of ``jnp.dot(x, params)`` against ``y``."""
    return jnp.mean((jnp.dot(x, params) - y) ** 2)

# .lower() traces for these concrete arguments; .as_text() prints the StableHLO module.
print(loss_fn.lower(params, x, y).as_text())

<PjitFunction of <function loss_fn at 0x79d714921940>>
module @jit_loss_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2xf32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.dot_general %arg1, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2xf32>) -> tensor<2xf32>
    %1 = stablehlo.subtract %0, %arg2 : tensor<2xf32>
    %2 = stablehlo.multiply %1, %1 : tensor<2xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = stablehlo.reduce(%2 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
    %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<f32>
    %4 = stablehlo.divide %3, %cst_0 : tensor<f32>
    return %4 : tensor<f32>
  }
}



In [5]:
from functools import partial

import jax
import jax.numpy as jnp  # imported here so this cell does not rely on earlier cells


@jax.jit
@partial(jax.vmap, in_axes=(0, 1))
def dot(x, y):
    """Row-by-column dot products, batched with vmap and compiled with jit.

    ``in_axes=(0, 1)`` maps over axis 0 of ``x`` (its rows) and axis 1 of ``y``
    (its columns). For ``x`` of shape (n, k) and ``y`` of shape (k, n) the result
    has shape (n,), with entry i equal to ``jnp.dot(x[i], y[:, i])``.
    """
    return jnp.dot(x, y)


x = jnp.array([[1., 2., 3.0], [2., 3., 4.]])
y = x.T
print(x.shape)
print(y.shape)
print(dot(x, y))

(2, 3)
(3, 2)
[14. 29.]


In [6]:
# @title pmap requires multiple devices :/
import jax

# pmap shards work across the devices of JAX's default backend (GPU/TPU when
# present, otherwise CPU), so list those rather than only the CPU devices.
# pmap can only parallelize when this list has more than one entry.
devices = jax.devices()
devices

[CpuDevice(id=0)]

In [14]:
# @title Part 2 FLAX: 1/3: Linen
import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    """Two-layer bias-free MLP: Dense -> ReLU -> Dense.

    Attributes:
        features: output widths of the two Dense layers, in order. Exactly two
            entries are required because ``setup`` builds exactly two layers.
    """

    features: tuple[int, ...] = (32, 10)

    def setup(self):
        # Fail loudly instead of raising an opaque IndexError (too few widths)
        # or silently ignoring trailing widths (too many).
        if len(self.features) != 2:
            raise ValueError(
                f"MLP expects exactly 2 feature sizes, got {len(self.features)}: "
                f"{self.features!r}"
            )
        self.dense1 = nn.Dense(self.features[0], use_bias=False)
        self.dense2 = nn.Dense(self.features[1], use_bias=False)

    # alternative: declare the layers inline with @nn.compact instead of setup
    def __call__(self, x):
        """Apply dense1 -> relu -> dense2 to ``x``."""
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        return x

model = MLP()
x = jnp.ones((1, 5))
# Target width follows the model's output width instead of a hard-coded 10.
y_true = jnp.ones((1, model.features[-1]))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
print(y.shape)


(1, 10)


In [16]:
import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)


def loss_fn(params, x, y):
    """Mean squared error between ``model.apply(params, x)`` and targets ``y``."""
    residual = model.apply(params, x) - y
    return jnp.mean(residual ** 2)


def update(params, opt_state, x, y):
    """Take one Adam step on ``loss_fn``.

    Returns ``(new_params, new_opt_state, loss)``, where ``loss`` is evaluated
    at the incoming ``params`` (before the step is applied).
    """
    loss_and_grad = jax.value_and_grad(loss_fn)
    loss, grads = loss_and_grad(params, x, y)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss


new_params, new_opt_state, loss = update(params, opt_state, x, y_true)

In [22]:
# One state object per chained stage of the optimizer.
print([*map(type, new_opt_state)])
# `mu` is indexed with the same keys as the params tree, so its dense1 kernel entry
# has the kernel's shape.
print([*new_opt_state[0].mu["params"]["dense1"]["kernel"].shape])

[<class 'optax._src.transform.ScaleByAdamState'>, <class 'optax._src.base.EmptyState'>]
[5, 32]


In [8]:
# Kernel shape of every Dense layer in the Linen parameter tree.
for layer_name, layer_params in params["params"].items():
    kernel_shape = layer_params["kernel"].shape
    print(f"{layer_name} -> {kernel_shape}")

dense1 -> (5, 32)
dense2 -> (32, 10)


In [13]:
# optax.adam's state is a tuple with one entry per chained stage. The types printed
# for new_opt_state show index 0 is a ScaleByAdamState and index 1 is an EmptyState,
# i.e. the second stage stores no state.
opt_state[1]

optax._src.base.EmptyState

In [8]:
# @title Part 2 FLAX: 2/3: NNX