In [7]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn
import optax

In [24]:
# Load iris dataset from sklearn
from sklearn.datasets import load_iris
iris = load_iris()

# Extract features and labels
X = iris.data
y = iris.target

# Convert X and y into jnp arrays
X = jnp.array(X)
y = jnp.array(y)

# Convert y into one hot
y = jnp.eye(3)[y]

print(X.shape, y.shape)

(150, 4) (150, 3)


In [28]:
class MLP(nn.Module):
    hidden1_size: int
    hidden2_size: int
    hidden3_size: int
    output_size: int
    
    dropout_rate: float

    @nn.compact
    def __call__(self, x, train=False):
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        x = nn.Dense(features=self.hidden1_size)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        x = nn.Dense(features=self.hidden2_size)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        x = nn.Dense(features=self.hidden3_size)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_size)(x)
        x = nn.log_softmax(x)
        return x

In [33]:
key = jax.random.PRNGKey(0)
input_shape = (768,)

model = MLP(hidden1_size=64, hidden2_size=64, hidden3_size=16, output_size=3, dropout_rate=0.2)
params = model.init(key, jnp.ones(input_shape, jnp.float32))

print(model)

MLP(
    # attributes
    hidden1_size = 64
    hidden2_size = 64
    hidden3_size = 16
    output_size = 3
    dropout_rate = 0.2
)


In [34]:
output = model.apply(params, X, train=True)
print(output.shape)

InvalidRngError: Dropout_0 needs PRNG for "dropout" (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError)