In [None]:
from flax import nnx
import jax.numpy as jnp
import jax

In [None]:
class ConvSubsampling(nnx.Module):
    """Convolutional front-end that downsamples a spectrogram and projects it.

    Two 3x3 convolutions with stride 2 shrink both the time axis and the
    frequency axis by roughly a factor of four. The remaining frequency bins
    and channels are flattened per frame, projected to ``output_dim`` by a
    linear layer, and passed through dropout.

    Args:
        output_dim: Number of convolution channels and size of the output
            feature vector (D).
        input_dim: Number of frequency bins (F) in the input. Defaults to 80,
            the value the projection size was previously hard-coded for.
        dropout_rate: Dropout probability applied after the projection.
            Defaults to 0.1, the previously hard-coded value.
        rngs: NNX random number generators used for parameter initialisation
            and dropout.
    """

    def __init__(
        self,
        output_dim: int,
        *,
        input_dim: int = 80,
        dropout_rate: float = 0.1,
        rngs: nnx.Rngs,
    ):
        # "SAME" is nnx.Conv's default padding; it is spelled out because the
        # projection size computed below depends on it.
        self.conv1 = nnx.Conv(
            in_features=1,
            out_features=output_dim,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="SAME",
            rngs=rngs,
        )
        self.conv2 = nnx.Conv(
            in_features=output_dim,
            out_features=output_dim,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="SAME",
            rngs=rngs,
        )

        # With "SAME" padding each stride-2 convolution maps n bins to
        # ceil(n / 2), so two of them leave ceil(ceil(F / 2) / 2) bins
        # (80 -> 40 -> 20 for the default).
        freq_bins = input_dim
        for _ in range(2):
            freq_bins = (freq_bins + 1) // 2

        # The linear layer sees the flattened (freq_bins, D) slab of each frame.
        self.linear = nnx.Linear(in_features=freq_bins * output_dim, out_features=output_dim, rngs=rngs)
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

    def __call__(self, x: jax.Array, *, is_train: bool) -> jax.Array:
        """Subsample and project a batch of features.

        Args:
            x: Input of shape (B, T, F, 1), where F must equal the
                ``input_dim`` the module was built with.
            is_train: When True dropout is active; when False it is disabled.

        Returns:
            Array of shape (B, T', output_dim), where T' is the time length
            left after the two stride-2 convolutions (about T/4).
        """
        x = nnx.relu(self.conv1(x))  # (B, T/2, F/2, D)
        x = nnx.relu(self.conv2(x))  # (B, T/4, F/4, D)

        # Merge the frequency and channel axes so each frame is one vector.
        batch, frames = x.shape[0], x.shape[1]
        x = x.reshape(batch, frames, -1)  # (B, T/4, F/4 * D)

        x = self.linear(x)
        return self.dropout(x, deterministic=not is_train)


In [None]:
class FeedForwardModule(nnx.Module):
    """Position-wise feed-forward block: norm, expand, swish, project back.

    The input is layer-normalised, expanded to ``embed_dim * expand_factor``
    features, passed through a swish activation and dropout, projected back
    to ``embed_dim`` features, and passed through a second dropout.

    Args:
        embed_dim: Size of the input and output feature dimension.
        expand_factor: Multiplier giving the width of the hidden layer.
        rngs: NNX random number generators used for parameter initialisation
            and dropout.
    """

    def __init__(self, embed_dim, expand_factor, *, rngs):
        hidden_dim = embed_dim * expand_factor
        self.norm = nnx.LayerNorm(embed_dim, rngs=rngs)
        self.linear1 = nnx.Linear(embed_dim, hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dim, embed_dim, rngs=rngs)
        self.dropout1 = nnx.Dropout(0.1, rngs=rngs)
        self.dropout2 = nnx.Dropout(0.1, rngs=rngs)

    def __call__(self, x: jax.Array, *, train: bool):
        """Apply the block to ``x``; dropout is active only when ``train`` is True."""
        deterministic = not train

        # Expansion half: normalise, widen, activate, regularise.
        hidden = nnx.swish(self.linear1(self.norm(x)))
        hidden = self.dropout1(hidden, deterministic=deterministic)

        # Projection half: back to embed_dim, then the second dropout.
        projected = self.linear2(hidden)
        return self.dropout2(projected, deterministic=deterministic)

In [None]:
class MultiHeadSelfAttention(nnx.Module):
    """Pre-norm multi-head self-attention block.

    The input is layer-normalised, attended over itself with
    ``nnx.MultiHeadAttention``, and passed through dropout.

    Args:
        embed_dim: Size of the input and output feature dimension.
            ``nnx.MultiHeadAttention`` requires it to be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        rngs: NNX random number generators used for parameter initialisation
            and dropout.
        dropout_rate: Dropout probability applied to the attention output.
            Defaults to 0.1, matching the other modules in this notebook.
    """

    def __init__(self, embed_dim, num_heads, rngs: nnx.Rngs, dropout_rate: float = 0.1):
        self.norm = nnx.LayerNorm(embed_dim, rngs=rngs)
        # The original call referenced the undefined name `embed` and omitted
        # the required `rngs` argument. `decode=False` is set here because
        # nnx.MultiHeadAttention refuses to run until the flag is chosen.
        self.attention = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            decode=False,
            rngs=rngs,
        )
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)

    def __call__(self, x: jax.Array, *, train: bool, mask=None) -> jax.Array:
        """Apply self-attention to ``x``.

        Args:
            x: Input whose last axis has ``embed_dim`` features; presumably
                (batch, time, embed_dim) -- confirm against the caller.
            train: When True dropout is active; when False it is disabled.
            mask: Optional attention mask forwarded unchanged to
                ``nnx.MultiHeadAttention``.

        Returns:
            Array with the same shape as ``x``.
        """
        x = self.norm(x)
        # Only the query input is given, so keys and values default to it.
        x = self.attention(x, mask=mask)
        return self.dropout(x, deterministic=not train)

In [None]:
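# TODO: ConformerEncoder is an unfinished sketch. It stays commented out
# because it needs a ConformerBlock, which is not defined in the cells above.
# class ConformerEncoder(nnx.Module):
#     def __init__(self, num_layers: int):
#         self.blocks = [
#             ConformerBlock(

#             ) for _ in range(num_layers)
#         ]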

In [None]:
# Build the subsampling front-end with 80 output features, seeded with 0.
subsampling = ConvSubsampling(80, rngs=nnx.Rngs(0))

In [None]:
# Dummy all-ones input; presumably (batch, frames, mel bins, channel) -- TODO confirm.
data = jnp.ones((1, 142, 80, 1))
data.shape

In [None]:
# is_train=True keeps dropout active inside the module.
res = subsampling(data, is_train=True)

In [None]:
# Sanity check: the last axis should equal the 80 output features requested above.
res.shape

In [None]:
# Feed-forward block with embed_dim=80 (matching res) and a 4x hidden expansion.
ffn = FeedForwardModule(80, 4, rngs=nnx.Rngs(1))

In [None]:
# The block projects back to embed_dim, so the shape should match res.shape.
ffn(res, train=True).shape