In [10]:
from typing import Dict, Tuple, List, Any
import jax
import jax.numpy as jnp
import flax.linen as nn
import flax
# import torch
# import torch.nn as nn
import math

In [16]:
class Model(nn.Module):
    """Dense projection followed by layer normalization.

    Submodules are declared explicitly in ``setup`` under the names
    ``linear`` and ``layernorm``. Those names become the keys of the
    parameter tree. Because ``setup`` is used, ``__call__`` is a plain
    method: the previous ``@nn.compact`` decorator was redundant and mixed
    Flax's two submodule-definition styles.

    Attributes:
        features: output width of the Dense layer, and therefore of the
            LayerNorm and of the module output. Defaults to 64, the value
            that was previously hard-coded.
    """

    features: int = 64

    def setup(self):
        self.linear = nn.Dense(self.features)
        self.layernorm = nn.LayerNorm()

    def __call__(self, x):
        """Apply ``Dense(features)`` and then ``LayerNorm`` to ``x``.

        Args:
            x: input array. Its last axis is projected to ``features``
               (e.g. shape (8, 128) -> (8, 64) with the default width).

        Returns:
            The layer-normalized projection of ``x``.
        """
        x = self.linear(x)
        return self.layernorm(x)

In [19]:
# Instantiate the module and run one forward pass on random input.
model = Model()

N = 8    # number of input rows (leading axis of x1)
D = 128  # input feature dimension (last axis, projected by the Dense layer)

rnd_key = jax.random.PRNGKey(42)
# Split the root key so data sampling and parameter initialization draw
# independent randomness; reusing one key for both would correlate them.
data_key, init_key = jax.random.split(rnd_key)
x1 = jax.random.normal(data_key, shape=(N, D))

params = model.init(init_key, x1)
y = model.apply(params, x1)

In [20]:
# Show the shape of every parameter leaf and of the forward-pass output.
# jax.tree_map is deprecated (removed in recent JAX releases);
# jax.tree_util.tree_map is the stable spelling.
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output shape:\n', y.shape)

initialized parameter shapes:
 {'params': {'layernorm': {'bias': (64,), 'scale': (64,)}, 'linear': {'bias': (64,), 'kernel': (128, 64)}}}
output shape:
 (8, 64)
