In [67]:
from flax import linen as nn
from flax.linen import initializers
from jax.typing import ArrayLike
import jax
import matplotlib.pyplot as plt

class NN(nn.Module):
    """Small MLP returning unnormalised scores for 3 outputs.

    Layout: Dense(2) -> ReLU -> Dense(3). Submodules are auto-named
    ``Dense_0`` and ``Dense_1`` in creation order, which is the key layout
    the hand-set parameter dict later in the notebook relies on.
    """

    @nn.compact
    def __call__(self, x: ArrayLike) -> ArrayLike:
        # Dense(2) is constructed (and called) before Dense(3), so naming order is preserved.
        hidden = nn.relu(nn.Dense(features=2)(x))
        scores = nn.Dense(features=3)(hidden)
        return scores


class NNArgMax(nn.Module):
    """Same Dense(2) -> ReLU -> Dense(3) network as ``NN``, but returns a
    one-hot vector marking the highest-scoring output instead of the scores.

    Note: argmax is piecewise constant, so gradients through this head are
    zero; it is meant for inference/display, not for training.
    """

    @nn.compact
    def __call__(self, x: ArrayLike) -> ArrayLike:
        x = nn.Dense(features=2)(x)
        x = nn.relu(x)
        x = nn.Dense(features=3)(x)

        # Use `jax.*` (imported in this cell) instead of `jnp`, which the
        # original only imported in a later cell. Ties resolve to the first
        # maximal index, matching NumPy argmax semantics.
        argmax_idx = jax.numpy.argmax(x, axis=-1)

        # Equivalent to `eye(n)[argmax_idx]` without materialising an n x n matrix.
        return jax.nn.one_hot(argmax_idx, x.shape[-1])


class NNSoftMax(nn.Module):
    """Same Dense(2) -> ReLU -> Dense(3) network as ``NN``, with a softmax
    over the last axis so the 3 outputs are non-negative and sum to 1."""

    @nn.compact
    def __call__(self, x: ArrayLike) -> ArrayLike:
        hidden = nn.relu(nn.Dense(features=2)(x))
        logits = nn.Dense(features=3)(hidden)
        # axis=-1 is nn.softmax's default; spelled out for the reader.
        return nn.softmax(logits, axis=-1)

In [68]:
from jax import random
from jax import numpy as jnp

# One instance per output head; all three build the same Dense_0 / Dense_1 layers.
model = NN()
model_argmax = NNArgMax()
model_softmax = NNSoftMax()

# Fixed seed so the displayed initial parameters are reproducible.
random_state = random.PRNGKey(44)

# Only the dummy input's shape matters here: it fixes Dense_0's input size (2).
variables = model.init(random_state, jnp.ones(2))
variables

{'params': {'Dense_0': {'kernel': Array([[0.353222  , 0.6076666 ],
          [0.35691372, 0.29986838]], dtype=float32),
   'bias': Array([0., 0.], dtype=float32)},
  'Dense_1': {'kernel': Array([[-0.18841058, -0.87364864,  0.29732683],
          [ 0.13652378, -0.17848821,  0.55810565]], dtype=float32),
   'bias': Array([0., 0., 0.], dtype=float32)}}}

In [69]:
# Hand-picked parameters in the same layout `model.init` produced above:
# each Dense kernel is (in_features, out_features) and each bias is (out_features,).
final_res = {
    'Dense_0': dict(
        kernel=jnp.array([[-2.5, -1.5],
                          [0.6, 0.4]]),
        bias=jnp.array([1.6, 0.7]),
    ),
    'Dense_1': dict(
        kernel=jnp.array([[-0.1, 2.4, -2.2],
                          [1.5, -5.2, 3.7]]),
        bias=jnp.array([0.0, 0.0, 1.0]),
    ),
}

In [70]:
# Run one hand-picked sample through each head with the hand-set parameters:
# raw scores, one-hot argmax, then softmax probabilities (one line each).
sample_input = jnp.array([0.5, 0.37])
hand_set_variables = {'params': final_res}
for head in (model, model_argmax, model_softmax):
    print(head.apply(hand_set_variables, sample_input))

[0.08979999 0.8632002  0.10419989]
[0. 1. 0.]
[0.2391414  0.5182487  0.24260993]
