## Conformer
[Original paper](https://arxiv.org/pdf/2005.08100)

This implementation is a JAX / Flax NNX port of lucidrains' PyTorch implementation.

[PyTorch Repo](https://github.com/lucidrains/conformer/blob/master/conformer/conformer.py)

In [28]:
import jax
import jax.numpy as jnp
from flax import nnx

from einops import rearrange
from einops.layers.flax import Rearrange

Define helper functions/classes


In [29]:
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d

def calc_same_padding(kernel_size: int) -> tuple[int, int]:
    """Return the (left, right) padding that keeps a stride-1 conv's length unchanged.

    The two amounts always sum to ``kernel_size - 1``; for an even kernel the
    right side receives one element less than the left side.
    """
    left = kernel_size // 2
    # Odd kernels pad symmetrically; even kernels drop one element on the right.
    right = left - 1 if kernel_size % 2 == 0 else left
    return left, right

In [30]:
# GLU and Swish don't have to be implemented manually; the nnx versions are used below.
class DepthWiseConv1d(nnx.Module):
    """Depthwise 1-D convolution (one filter group per input channel).

    Args:
        chan_in: number of input channels; also used as ``feature_group_count``.
        chan_out: number of output channels.
        kernel_size: width of the convolution kernel.
        padding: ``(left, right)`` padding along the sequence axis. ``None``
            falls back to ``'SAME'``; a string or int is forwarded to
            ``nnx.Conv`` unchanged.
        rngs: NNX random number streams used to initialise the kernel.
    """

    def __init__(self, chan_in, chan_out, kernel_size, padding, rngs: nnx.Rngs):
        self.padding = padding

        # nnx.Conv defaults to 'SAME' padding. Previously `padding` was stored
        # but never used, so a left-only (causal) padding request was silently
        # replaced by symmetric padding. Forward the requested padding instead.
        if padding is None:
            conv_padding = 'SAME'
        elif isinstance(padding, (str, int)):
            conv_padding = padding
        else:
            left, right = padding
            # One (low, high) pair per spatial dimension; there is one here.
            conv_padding = [(left, right)]

        self.conv = nnx.Conv(
            chan_in,
            chan_out,
            kernel_size,
            padding=conv_padding,
            feature_group_count=chan_in,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array):
        # Padding is handled inside nnx.Conv, so no manual jnp.pad is needed.
        # Assumes channels-last input (batch, seq, channels) - TODO confirm.
        return self.conv(x)

In [31]:
# RoPE without caching.
# Neither cached keys nor a cached RoPE sin/cos table is considered here.
def _rotate_half(x: jnp.ndarray) -> jnp.ndarray:
    """Rotate every adjacent pair on the last axis: ``(a, b) -> (-b, a)``.

    The last axis is read as interleaved pairs (even index, odd index); the
    result has the same shape as ``x``.
    """
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    # Shape (..., d/2, 2): each trailing pair is (-odd, even).
    pairs = jnp.stack((-odds, evens), axis=-1)
    # Merging the two trailing axes restores the original interleaved layout.
    return jnp.reshape(pairs, x.shape)

def _apply_rope(q, k, sin, cos):
    """Apply rotary position embeddings to the query and key tensors.

    Both tensors get the same transform, ``t * cos + rotate_half(t) * sin``.
    Returns the rotated ``(q, k)`` pair.
    """
    def rotate(t):
        return (t * cos) + (_rotate_half(t) * sin)

    return rotate(q), rotate(k)

class RotaryLookup(nnx.Module):
    """Helper module that builds the RoPE sine / cosine tables.

    Args:
        dim: size of the dimension being rotated; one frequency is created
            per pair of features (``dim / 2`` frequencies for even ``dim``).
        theta: base of the geometric frequency progression.
    """

    def __init__(self, dim: int, theta: int = 10000):
        # inv_freq[i] = theta ** (-2i / dim)
        self.inv_freq = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))

    def __call__(self, seq_len: int, dtype=jnp.float32):
        """Return ``(sin, cos)`` tables for positions ``0 .. seq_len - 1``.

        Both tables are reshaped to ``(1, 1, seq_len, n_freqs)`` so they
        broadcast over leading batch and head axes.
        """
        # Build the position indices in at least float32. Creating them
        # directly in a half-precision dtype rounds larger positions (bfloat16
        # represents integers exactly only up to 256) before the angles are
        # computed. The product with `inv_freq` was already promoted, so the
        # dtype of the returned tables does not change.
        if dtype is None:
            pos_dtype = None
        else:
            pos_dtype = jnp.promote_types(dtype, jnp.float32)
        t = jnp.arange(seq_len, dtype=pos_dtype)
        freqs = jnp.outer(t, self.inv_freq)
        sin, cos = jnp.sin(freqs), jnp.cos(freqs)
        # Add leading singleton axes: (n d) -> (1 1 n d).
        return rearrange(sin, 'n d -> 1 1 n d'), rearrange(cos, 'n d -> 1 1 n d')

In [32]:
from typing import Callable

# Attention, feed-forward, and convolution modules that make up a Conformer block
class Scale(nnx.Module):
    """Wrap a callable and multiply its output by a constant factor."""

    def __init__(self, scale, fn: Callable):
        self.fn = fn
        self.scale = scale

    def __call__(self, x, **kwargs):
        # Keyword arguments are forwarded untouched to the wrapped callable.
        result = self.fn(x, **kwargs)
        return result * self.scale

class PreNorm(nnx.Module):
    """Normalise the input with RMSNorm, then call the wrapped callable."""

    def __init__(self, dim, fn: Callable, rng: nnx.Rngs):
        # RMSNorm is used here in place of LayerNorm.
        self.norm = nnx.RMSNorm(dim, rngs=rng)
        self.fn = fn

    def __call__(self, x, **kwargs):
        # Keyword arguments (e.g. a mask) go straight to the wrapped callable.
        return self.fn(self.norm(x), **kwargs)

class Attention(nnx.Module):
    """Multi-head self-attention using RoPE instead of lookup-based positional embeddings.

    Args:
        dim: feature size of the input and output.
        rngs: NNX random number streams for parameter init and dropout.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout rate applied to the projected output.
    """

    def __init__(self,
                 dim: int,
                 rngs: nnx.Rngs,
                 heads: int = 8,
                 dim_head: int = 64,
                 dropout: float = 0.0):
        self.heads = heads
        self.dim_head = dim_head
        inner = heads * dim_head

        # A single fused projection produces Q, K and V.
        self.qkv = nnx.Linear(dim, inner * 3, use_bias=False, rngs=rngs)
        self.out_proj = nnx.Linear(inner, dim, use_bias=False, rngs=rngs)

        self.rope = RotaryLookup(dim_head)
        self.dropout = nnx.Dropout(dropout, rngs=rngs)

        # Scaling constant applied to the attention logits.
        self.scale = dim_head ** -0.5

    def __call__(self,
                 x: jnp.ndarray,
                 mask: jnp.ndarray | None = None,
                 *,
                 rngs: nnx.Rngs | None = None,
                 deterministic: bool = True) -> jnp.ndarray:
        """
        x     : (batch, seq, dim)
        mask  : (batch, seq) padding mask or (batch, seq, seq) pairwise mask;
                truthy entries are attended to, falsy entries are masked out.
        rngs  : optional RNG streams forwarded to dropout.
        deterministic : dropout is applied only when this is False.

        The original code did not use context nor context_mask, thus this is ignored.

        Raises ValueError when ``mask`` is neither 2-D nor 3-D.
        """
        _, n, _ = x.shape

        # 1) Project to Q, K, V and split into heads.
        qkv = self.qkv(x)                               # (b n 3*inner)
        qkv = rearrange(qkv, 'b n (t h d) -> t b h n d',
                        t=3, h=self.heads, d=self.dim_head)
        q, k, v = qkv                                   # each (b h n d)

        # 2) Apply RoPE. Each frequency is repeated twice so the table lines
        #    up with the interleaved feature pairs rotated by _rotate_half.
        sin, cos = self.rope(n, dtype=x.dtype)          # (1 1 n d/2)
        sin = jnp.repeat(sin, repeats=2, axis=-1)       # (1 1 n d)
        cos = jnp.repeat(cos, repeats=2, axis=-1)
        q, k = _apply_rope(q, k, sin, cos)

        # 3) Similarity scores and masking.
        attn_logits = jnp.einsum('b h i d, b h j d -> b h i j', q, k)
        attn_logits = attn_logits * self.scale

        if mask is not None:
            # Accept 0/1 integer or float masks as well as boolean ones.
            mask = mask.astype(jnp.bool_)
            if mask.ndim == 2:
                # Padding mask -> pairwise mask, broadcast over heads: (b 1 i j).
                mask = mask[:, None, :, None] & mask[:, None, None, :]
            elif mask.ndim == 3:
                # Insert the head axis so (b i j) lines up with (b h i j);
                # without it the batch axis would be matched against heads.
                mask = mask[:, None, :, :]
            else:
                raise ValueError(
                    f"mask must have 2 or 3 dimensions, got {mask.ndim}")
            fill_value = jnp.finfo(attn_logits.dtype).min
            attn_logits = jnp.where(mask, attn_logits, fill_value)

        # 4) Attention weights.
        attn = jax.nn.softmax(attn_logits, axis=-1)                          # (b h i j)

        # 5) Aggregate the values, then project back to the model dimension.
        out = jnp.einsum('b h i j, b h j d -> b h i d', attn, v)             # (b h n d)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.out_proj(out)
        if not deterministic:
            out = self.dropout(out, rngs=rngs)
        return out

class FeedForward(nnx.Module):
    """Position-wise feed-forward network: Linear -> Swish -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        rngs: NNX random number streams for parameter init and dropout.
        mult: the hidden layer has ``dim * mult`` features.
        dropout: rate shared by both dropout layers.
    """

    def __init__(self, dim, rngs: nnx.Rngs, mult = 4, dropout = 0.0):
        hidden_dim = dim * mult
        # Layers are created in application order.
        layers = [
            nnx.Linear(dim, hidden_dim, rngs=rngs),
            nnx.swish,
            nnx.Dropout(dropout, rngs=rngs),
            nnx.Linear(hidden_dim, dim, rngs=rngs),
            nnx.Dropout(dropout, rngs=rngs),
        ]
        self.net = nnx.Sequential(*layers)

    def __call__(self, x: jax.Array):
        return self.net(x)

class ConformerConvModule(nnx.Module):
    """Conformer convolution module.

    Pipeline: RMSNorm -> pointwise conv (to ``2 * inner_dim``) -> GLU ->
    depthwise conv -> BatchNorm (identity when ``causal``) -> Swish ->
    pointwise conv (back to ``dim``) -> Dropout.

    Args:
        dim: input and output feature size.
        rngs: NNX random number streams for parameter init and dropout.
        causal: selects ``(kernel_size - 1, 0)`` padding and replaces
            BatchNorm with an identity function.
        expansion_factor: ``inner_dim = dim * expansion_factor``.
        kernel_size: kernel width of the depthwise convolution.
        dropout: rate of the final dropout layer.
    """

    def __init__(
        self,
        dim,
        rngs: nnx.Rngs,
        causal: bool = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        inner_dim = dim * expansion_factor

        # (left, right) padding handed to the depthwise convolution.
        if causal:
            padding = (kernel_size - 1, 0)
        else:
            padding = calc_same_padding(kernel_size)

        # No transpose to (B, C, N) is performed before the convolutions.
        # Layers are created in application order.
        layers = [
            nnx.RMSNorm(dim, rngs=rngs),
            nnx.Conv(dim, inner_dim * 2, 1, rngs=rngs),
            nnx.glu,
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding, rngs=rngs),
        ]
        layers.append(jax.nn.identity if causal else nnx.BatchNorm(inner_dim, rngs=rngs))
        layers.extend([
            nnx.swish,
            nnx.Conv(inner_dim, dim, 1, rngs=rngs),
            nnx.Dropout(dropout, rngs=rngs),
        ])
        self.net = nnx.Sequential(*layers)

    def __call__(self, x):
        return self.net(x)

Build the actual model

In [33]:
class ConformerBlock(nnx.Module):
    """One Conformer block.

    Residual sequence: half-scaled feed-forward -> self-attention ->
    convolution module -> half-scaled feed-forward, followed by a final
    RMSNorm. The feed-forward and attention branches are pre-normalised.
    """

    def __init__(
        self,
        *,
        dim,
        rngs: nnx.Rngs,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        # Build the raw sub-modules first (same creation order as before).
        ff_first = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, rngs=rngs)
        attention = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rngs=rngs)
        conv_module = ConformerConvModule(
            dim = dim,
            causal = conv_causal,
            expansion_factor = conv_expansion_factor,
            kernel_size = conv_kernel_size,
            dropout = conv_dropout,
            rngs=rngs,
        )
        ff_second = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, rngs=rngs)

        # Wrap them: pre-norm on attention, pre-norm plus 0.5 scaling on both
        # feed-forward branches.
        attention = PreNorm(dim, attention, rngs)
        ff_first = Scale(0.5, PreNorm(dim, ff_first, rngs))
        ff_second = Scale(0.5, PreNorm(dim, ff_second, rngs))

        self.ff1 = ff_first
        self.attn = attention
        self.conv = conv_module
        self.ff2 = ff_second

        self.post_norm = nnx.RMSNorm(dim, rngs=rngs)

    def __call__(self, x, mask = None):
        # Every sub-module is applied as a residual branch.
        x = self.ff1(x) + x
        x = self.attn(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        return self.post_norm(x)


In [34]:
class Conformer(nnx.Module):
    """Stack of ``depth`` identical ConformerBlock layers.

    Args:
        dim: feature size of the input and output.
        rngs: NNX random number streams shared by all blocks.
        depth: number of ConformerBlock layers.
        dim_head, heads, ff_mult, conv_expansion_factor, conv_kernel_size,
        attn_dropout, ff_dropout, conv_dropout, conv_causal: forwarded
            unchanged to every ConformerBlock.
    """

    def __init__(
        self,
        dim,
        rngs: nnx.Rngs,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.dim = dim
        self.layers = [
            ConformerBlock(
                dim=dim,
                rngs=rngs,
                dim_head=dim_head,
                heads=heads,
                ff_mult=ff_mult,
                conv_expansion_factor=conv_expansion_factor,
                conv_kernel_size=conv_kernel_size,
                attn_dropout=attn_dropout,
                ff_dropout=ff_dropout,
                conv_dropout=conv_dropout,
                conv_causal=conv_causal,
            )
            for _ in range(depth)
        ]

    def __call__(self, x, mask = None):
        """Run ``x`` through every block in order.

        ``mask`` is optional and is passed to each block unchanged, so padded
        batches can be masked from the top-level model; omitting it keeps the
        previous unmasked behaviour.
        """
        for layer in self.layers:
            x = layer(x, mask = mask)
        return x

In [35]:
# List the devices visible to JAX (the notebook displays the returned list).
jax.devices()

[CudaDevice(id=0)]

In [36]:
# Smoke test: push one random batch through a 4-block Conformer and compare
# the input and output shapes.
batch_size, seq_len, dim = 2, 128, 256

key = jax.random.PRNGKey(0)    # key for the random input batch
model_key = nnx.Rngs(42)       # RNG streams for model initialisation

x = jax.random.normal(key, (batch_size, seq_len, dim))
model = Conformer(dim=dim, depth=4, rngs=model_key)
y = model(x)

print("Input shape:", x.shape)
print("Output shape:", y.shape)

2025-06-12 00:01:09.311631: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng3{k11=2} for conv %cudnn-conv.1 = (f32[2,512,128]{2,1,0}, u8[0]{0}) custom-call(%transpose, %bitcast.12), window={size=31 pad=15_15}, dim_labels=bf0_oi0->bf0, feature_group_count=512, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated" source_file="/usr/programming/uv/machine_learning/lib/python3.10/site-packages/flax/nnx/nn/linear.py" source_line=800}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]} is taking a while...
2025-06-11 23:58:58.541737: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 581.677005ms
Trying algorithm eng3{k11=2} for conv %cudnn-conv.1 = (f32[2,512,128]{2,1,0}, u8[0]{0}) cust

Input shape: (2, 128, 256)
Output shape: (2, 128, 256)
