In [10]:
from typing import Dict, Tuple, List, Any
import jax
import jax.numpy as jnp
import flax.linen as nn
import flax
# import torch
# import torch.nn as nn
import math

In [16]:
class Model(nn.Module):
    def setup(self):
        self.linear = nn.Dense(64)
        self.layernorm = nn.LayerNorm()
        
    @nn.compact
    def __call__(self, x):
        x = self.linear(x)
        x = self.layernorm(x)
        return x

In [20]:
model = Model()

N = 8
D = 128

rnd_key = jax.random.PRNGKey(42)
x1 = jax.random.normal(rnd_key, shape=(N,D))

params = model.init(rnd_key, x1)
y = model.apply(params, x1)
print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output shape:\n', y.shape)

initialized parameter shapes:
 {'params': {'layernorm': {'bias': (64,), 'scale': (64,)}, 'linear': {'bias': (64,), 'kernel': (128, 64)}}}
output shape:
 (8, 64)


In [39]:
x = jax.random.normal(rnd_key, shape=(5, 17))
a,b = jnp.array_split(x, 2, axis=-1)

In [40]:
a.shape

(5, 9)

In [41]:
b.shape

(5, 8)

In [44]:
c = jnp.einsum("ng,nh->gh", a, b)

In [45]:
c.shape

(9, 8)

In [51]:
# jax.lax.top_k(c)
jax.lax.top_k(c, 3)

[DeviceArray([[2.8488002 , 2.7218099 , 2.0955062 ],
              [4.0057507 , 1.7198598 , 0.49998915],
              [2.44787   , 1.9903433 , 0.9540967 ],
              [2.2788448 , 1.2575146 , 0.20737867],
              [2.8847215 , 1.7629296 , 1.4312001 ],
              [2.188009  , 2.0436435 , 1.9311047 ],
              [4.0429177 , 1.6797887 , 0.5984382 ],
              [2.5068424 , 2.1373606 , 2.1111934 ],
              [2.5799189 , 1.8154304 , 1.7440134 ]], dtype=float32),
 DeviceArray([[5, 1, 4],
              [2, 0, 4],
              [4, 6, 5],
              [0, 2, 4],
              [7, 2, 0],
              [2, 0, 3],
              [7, 5, 2],
              [1, 7, 2],
              [1, 3, 5]], dtype=int32)]

In [53]:
out = nn.softmax(c, axis=1)
out.shape

(9, 8)

In [54]:
out.sum(1)

DeviceArray([1.0000001 , 0.99999994, 1.        , 0.99999994, 1.        ,
             1.        , 1.        , 1.0000002 , 1.        ],            dtype=float32)

In [59]:
variance_init = jax.nn.initializers.variance_scaling(scale=1, mode="fan_in", distribution="normal")
variance_init(rnd_key, (9,8))

DeviceArray([[ 0.01799773,  0.3240893 ,  0.16049889, -0.39378947,
               0.03847735, -0.05789944,  0.14026228, -0.01109238],
             [-0.31371105, -0.0454727 , -0.28461683,  0.09464972,
              -0.2718889 , -0.34081125, -0.06649639,  0.07113177],
             [-0.21927719, -0.0969271 , -0.31637388, -0.11878338,
               0.26330873, -0.03214517, -0.00101734,  0.05531672],
             [ 0.09329666, -0.315701  ,  0.23672032, -0.39212704,
               0.6521596 ,  0.00694316, -0.10352549, -0.07599135],
             [ 0.3715672 ,  0.09549747,  0.52732676, -0.22619566,
              -0.32916254,  0.44663152,  0.26371616, -0.08919866],
             [ 0.26806706, -0.08028515,  0.1099793 ,  0.15923974,
               0.24198954, -0.14503574,  0.22950482, -0.01157813],
             [ 0.39442417,  0.6945726 , -0.07789575,  0.12051268,
               0.09601855,  0.739609  , -0.20942426, -0.20421992],
             [ 0.22844616,  0.08924569,  0.24862719,  0.62727344,
   

In [63]:
layer = nn.Embed(4, 8, embedding_init=variance_init)

In [64]:
layer

Embed(
    # attributes
    num_embeddings = 4
    features = 8
    dtype = float32
    param_dtype = float32
    embedding_init = init
    embedding = None
)

In [62]:

variance_init(layer)

TypeError: init() missing 1 required positional argument: 'shape'