In [43]:
import functools
from typing import Any, Callable

from flax import linen as nn
import gin
from internal import mip, utils  # pylint: disable=g-multiple-import
import jax
from jax import random
import jax.numpy as jnp
from flax.core import freeze, unfreeze
from typing import Any, Callable, Sequence
from jax import lax



In [44]:
model = nn.Dense(features=5)


rng = random.PRNGKey(20200823)
# key1, key2 = random.split(rng)
x = random.normal(rng, (10,)) # Dummy input data
params = model.init(rng, x) # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes

FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [45]:
try:
    params['new_key'] = jnp.ones((2,2))
except ValueError as e:
    print("Error: ", e)

Error:  FrozenDict is immutable.


In [46]:
model.apply(params, x)

DeviceArray([-0.18553361, -0.1401707 , -0.3689404 , -0.11251645,
             -2.8037481 ], dtype=float32)

In [47]:
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [48]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  # loss function here, personal defined
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

In [49]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023639798
Loss step 0:  31.742035
Loss step 10:  0.5414006
Loss step 20:  0.12205205
Loss step 30:  0.040281765
Loss step 40:  0.019850202
Loss step 50:  0.014116081
Loss step 60:  0.0123820845
Loss step 70:  0.011833681
Loss step 80:  0.011655833
Loss step 90:  0.011597362
Loss step 100:  0.011578011


In [50]:
import optax
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)


for i in range(101):
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.011577001
Loss step 10:  0.24471128
Loss step 20:  0.072736286
Loss step 30:  0.039242655
Loss step 40:  0.023648137
Loss step 50:  0.01604911
Loss step 60:  0.012986958
Loss step 70:  0.012020969
Loss step 80:  0.01174228
Loss step 90:  0.011646231
Loss step 100:  0.011585994


In [51]:
class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292755e-02 -4.3807048e-02  2.9323749e-02  6.5492429e-03
  -1.7147159e-02]
 [ 1.2967803e-01 -1.4551786e-01  9.4432145e-02  1.2521381e-02
  -4.5417286e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024079e-04  2.7864418e-05  2.4478836e-04  8.1344321e-04
  -1.0110774e-03]]


In [52]:
class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292755e-02 -4.3807048e-02  2.9323749e-02  6.5492429e-03
  -1.7147159e-02]
 [ 1.2967803e-01 -1.4551786e-01  9.4432145e-02  1.2521381e-02
  -4.5417286e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024079e-04  2.7864418e-05  2.4478836e-04  8.1344321e-04
  -1.0110774e-03]]


In [53]:
class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)

AttributeError: module 'flax.linen.initializers' has no attribute 'zeros_init'

In [55]:
import functools

def multiply(x, y):
    return x * y

# Create a new function using partial application
double = functools.partial(multiply, y=2)

# Call the new function
result = double(5)

print(result)  # Output: 10



10
