In [1]:
from conformer.tokenizer import build_tokenizer
from conformer.dataset import AudioDataSource, batch_fn
import grain
from flax import nnx
from conformer.config import AudioConfig, ConformerConfig
from conformer.conv_subsampler import ConvolutionSubsampling
from functools import partial
import jax.numpy as jnp

In [2]:
ROOT_PATH = '/home/penguin/Data/cv-corpus-22.0-2025-06-20-ka/cv-corpus-22.0-2025-06-20/ka'
tokenizer = build_tokenizer(ROOT_PATH)
audio_config = AudioConfig()
audio_source = AudioDataSource(ROOT_PATH, tokenizer)

Vocabulary size: 49


In [3]:
train_batch_fn = partial(batch_fn, tokenizer=tokenizer, audio_config=audio_config)

In [4]:
batch_size=32
dataset = (
    grain.MapDataset.source(audio_source)
    .shuffle(seed=42)
    .batch(batch_size=batch_size, batch_fn=train_batch_fn)
)

In [5]:
iter_dataset = dataset.to_iter_dataset(
    grain.ReadOptions(num_threads=8, prefetch_buffer_size=64)
)

In [6]:
conformer_config = ConformerConfig()

In [7]:
subsampler = ConvolutionSubsampling(output_dim=conformer_config.encoder_dim, rngs=nnx.Rngs(0))

In [8]:
x = subsampler(jnp.ones((4, 720, 80)), training=False)
x.shape

(4, 180, 144)

In [9]:
from conformer.conformer_block import ConformerEncoder
from conformer.config import ConformerConfig

In [10]:
conformer_encoder = ConformerEncoder(conformer_config, num_classes=42, rngs=nnx.Rngs(0))

In [11]:
inputs, inputs_length = dataset[0]['inputs'], dataset[0]['input_lengths']

In [12]:
inputs, inputs_length = dataset[0]['inputs'][0], dataset[0]['input_lengths'][0]
x = jnp.expand_dims(inputs, axis=0)
x2 = jnp.expand_dims(inputs_length, axis=0)
conformer_encoder(x, x2, training=False)[0].shape

(1, 206, 42)

In [13]:
dataset[0]['inputs'].devices()

{CudaDevice(id=0)}

In [14]:
from tqdm import tqdm

In [15]:
x.shape, x2.shape

((1, 821, 80), (1,))

In [16]:
@nnx.jit
def train_step(model: ConformerEncoder, batch: dict):
    
    def loss_fn(model: ConformerEncoder):
        log_probs, output_lengths = model(
            batch["inputs"], batch["input_lengths"], training=False
        )
        
        return log_probs
    
    log_probs = loss_fn(model)
    return log_probs

In [17]:
train_step(conformer_encoder, dataset[0])

Array([[[-3.7503004, -4.005907 , -1.8450456, ..., -3.0749722,
         -4.797942 , -3.0506206],
        [-3.4508882, -4.191522 , -1.7587198, ..., -3.4552402,
         -5.19321  , -3.713553 ],
        [-3.8005195, -3.28516  , -1.4204216, ..., -3.8842711,
         -4.4585037, -4.411375 ],
        ...,
        [-3.2474875, -4.156693 , -1.4352329, ..., -3.5550833,
         -5.009884 , -3.6632829],
        [-3.237942 , -4.154594 , -1.3675549, ..., -3.5617847,
         -5.072281 , -3.671265 ],
        [-3.15376  , -4.308777 , -1.2650965, ..., -3.7040615,
         -5.1331043, -4.6156125]],

       [[-4.000767 , -5.05459  , -2.4097311, ..., -2.891604 ,
         -4.9167185, -3.0327392],
        [-3.8264103, -4.7299333, -1.9516429, ..., -3.6300182,
         -5.3756022, -4.0153537],
        [-3.9146795, -4.1548595, -1.6208906, ..., -3.225204 ,
         -4.5990534, -3.8018863],
        ...,
        [-3.270958 , -4.8632526, -1.6055392, ..., -3.4922915,
         -5.0999346, -3.8166304],
        [-3.

In [None]:
for element in tqdm(iter_dataset):
    train_step(conformer_encoder, element)

In [None]:
inputs.shape

In [None]:
y = conformer_encoder(inputs, inputs_length, training=False)

In [None]:
nnx.display(conformer_encoder)