In [1]:
import itertools
import os
import time
from functools import partial
from pathlib import Path

import grain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from flax import nnx
from tqdm import tqdm

from conformer.config import AudioConfig, ConformerConfig, TrainingConfig
from conformer.conformer_block import ConformerEncoder
from conformer.dataset import AudioDataSource, batch_fn
from conformer.tokenizer import build_tokenizer

In [2]:
# Root of the Common Voice 22.0 "ka" split. The default below is machine-specific;
# set the CV_KA_ROOT environment variable to point at a different copy instead of
# editing the notebook.
ROOT_PATH = os.environ.get(
    "CV_KA_ROOT",
    '/home/penguin/Data/cv-corpus-22.0-2025-06-20-ka/cv-corpus-22.0-2025-06-20/ka',
)

# Training manifest (tab-separated); the tokenizer vocabulary is built from its
# `sentence` column.
df = pd.read_csv(Path(ROOT_PATH) / 'train_mod.tsv', delimiter='\t')
tokenizer = build_tokenizer(df['sentence'].values)
audio_config = AudioConfig()
conformer_config = ConformerConfig()
train_config = TrainingConfig()

Vocabulary size: 41


In [3]:
# Per-example data source over the rows of `df`. It is wrapped below with
# grain.MapDataset.source. `tokenizer` is presumably used to map each
# transcript to label ids (TODO: confirm in conformer.dataset).
audio_source = AudioDataSource(df, tokenizer)

In [4]:
batch_size = 64
# Shuffle all examples once with a fixed seed, then group them into batches.
# Collation (presumably padding to a common length) is done by the project's
# batch_fn.
# NOTE(review): the final batch can hold fewer than batch_size examples. Its
# different leading dimension makes jax.jit recompile train_step; consider
# drop_remainder=True if that shows up in profiles.
dataset = (
    grain.MapDataset
    .source(audio_source)
    .shuffle(seed=42)
    .batch(batch_fn=batch_fn, batch_size=batch_size)
)

In [5]:
# Iterate the batched MapDataset through grain's threaded reader. It uses
# 4 read threads and a prefetch buffer of 64 elements, where each element
# here is a whole batch.
iter_dataset = dataset.to_iter_dataset(
    grain.ReadOptions(prefetch_buffer_size=64, num_threads=4)
)

In [6]:
# Output layer size for CTC.
# NOTE(review): the tokenizer reported 41 symbols. The extra class is
# presumably the CTC blank (optax.ctc_loss defaults to blank_id=0); confirm
# against build_tokenizer's id assignment.
model = ConformerEncoder(
    conformer_config,
    num_classes=42,
    rngs=nnx.Rngs(0),
)

In [7]:
def create_padding_mask(lengths: jnp.ndarray, max_len: int, dtype=jnp.float32) -> jnp.ndarray:
    """Build a padding mask from per-sequence valid lengths.

    Args:
        lengths: Integer array (or sequence) holding one valid length per
            example; it is flattened, so shape (B,) or (B, 1) both work.
        max_len: Size of the time axis of the returned mask.
        dtype: Output dtype. Defaults to float32, the form in which this
            notebook passes `logit_paddings` / `label_paddings` to
            `optax.ctc_loss`.

    Returns:
        Array of shape (B, max_len). It is 1 (True) at positions >= the
        example's length (padding) and 0 (False) at valid positions.
    """
    # asarray lets plain Python lists/tuples of lengths be passed as well.
    lengths = jnp.asarray(lengths).reshape(-1, 1)
    positions = jnp.arange(max_len)[None, :]
    return (positions >= lengths).astype(dtype)

In [8]:
def create_learning_rate_fn(warmup_steps: int, model_size: int):
    """Return the Transformer ("Noam") warm-up / inverse-sqrt-decay schedule.

    The returned callable maps optax's 0-based update count `step` to

        lr = model_size**-0.5 * min(n**-0.5, n * warmup_steps**-1.5),  n = step + 1

    It rises linearly for the first `warmup_steps` updates, peaks at
    n == warmup_steps (step == warmup_steps - 1), then decays as n**-0.5.

    Args:
        warmup_steps: Number of linear warm-up updates; must be positive.
        model_size: Model width used to scale the schedule (here the encoder
            dimension); must be positive.

    Returns:
        A function `step -> lr` usable as an optax `learning_rate` schedule.

    Raises:
        ValueError: if `warmup_steps` or `model_size` is not positive.
    """
    if warmup_steps <= 0:
        raise ValueError(f"warmup_steps must be positive, got {warmup_steps}")
    if model_size <= 0:
        raise ValueError(f"model_size must be positive, got {model_size}")

    scale = model_size ** -0.5
    warmup_factor = warmup_steps ** -1.5

    def lr_fn(step):
        # optax hands the schedule a 0-based count, but the formula is defined
        # for step numbers starting at 1. Without the shift, the first update
        # would get lr == 0 and 1/sqrt(0) would need an epsilon guard.
        step_num = jnp.asarray(step, dtype=jnp.float32) + 1.0
        return scale * jnp.minimum(1.0 / jnp.sqrt(step_num), step_num * warmup_factor)

    return lr_fn

@nnx.jit(donate_argnums=0)
def train_step(model: ConformerEncoder, optimizer: nnx.Optimizer, batch: dict):
    """Run one jitted optimisation step and return the batch-mean CTC loss.

    `batch` is read for the keys "inputs", "input_lengths", "labels" and
    "label_lengths". Gradients are taken w.r.t. `model`, and `optimizer.update`
    applies them to `model` in place. The model argument is donated
    (donate_argnums=0).
    """
    labels = batch["labels"]
    label_lengths = batch["label_lengths"]

    def ctc_objective(net: ConformerEncoder):
        # The encoder returns per-frame log-probabilities plus the number of
        # valid output frames per example.
        log_probs, output_lengths = net(
            batch["inputs"], batch["input_lengths"], training=True
        )
        per_example_loss = optax.ctc_loss(
            log_probs,
            create_padding_mask(output_lengths, log_probs.shape[1]),
            labels,
            create_padding_mask(label_lengths, labels.shape[1]),
        )
        return per_example_loss.mean()

    loss_value, grads = nnx.value_and_grad(ctc_objective)(model)
    optimizer.update(model=model, grads=grads)
    return loss_value

In [9]:
lr_schedule = create_learning_rate_fn(
    train_config.warmup_steps,
    conformer_config.encoder_dim,
)
adamw_tx = optax.adamw(
    learning_rate=lr_schedule,
    b1=train_config.beta1,
    b2=train_config.beta2,
    weight_decay=train_config.weight_decay,
)
# Only variables of type nnx.Param are optimised.
optimizer = nnx.Optimizer(model, adamw_tx, wrt=nnx.Param)

In [10]:
# The first call traces and compiles train_step. Later calls whose arguments
# have the same shapes/dtypes reuse the compiled executable.
# NOTE: this is a real training step, not a dry run. The model and optimizer
# state are updated using this batch.
warmup_batch = dataset[12]
train_step(model, optimizer, warmup_batch)

Array(35469.72, dtype=float32)

In [None]:
# Reading `loss` on the host (formatting it) blocks until the device has
# finished that step, which stalls JAX's asynchronous dispatch. Only pay for
# that sync every LOG_EVERY steps.
LOG_EVERY = 50
with tqdm(iter_dataset, unit="batch") as pbar:
    for step, element in enumerate(pbar):
        loss = train_step(model, optimizer, element)
        if step % LOG_EVERY == 0:
            pbar.set_postfix(loss=f"{float(loss):.4f}")

In [None]:
# Capture a profiler trace of a fixed number of training steps. Losses are not
# printed per step: each print forces a device->host sync, which would distort
# the captured timeline.
NUM_PROFILE_STEPS = 50
loss = None
with jax.profiler.trace("/tmp/profile-data"):
    for element in tqdm(
        itertools.islice(iter_dataset, NUM_PROFILE_STEPS), total=NUM_PROFILE_STEPS
    ):
        loss = train_step(model, optimizer, element)
    # train_step dispatches asynchronously; wait for the last step so its device
    # work falls inside the trace.
    jax.block_until_ready(loss)
print("Loss:", loss)


In [None]:
# Raw step-throughput benchmark on one fixed batch (excludes data loading).
# NOTE(review): every iteration is a real optimizer update. This trains the
# model NUM_BENCH_STEPS more times on the same batch, so re-create
# model/optimizer before using them for anything else.
NUM_BENCH_STEPS = 1000
bench_batch = dataset[0]

# Warm-up: compiles for this batch's shapes if they differ from earlier calls.
loss = train_step(model, optimizer, bench_batch)
jax.block_until_ready(loss)

start = time.perf_counter()
for _ in tqdm(range(NUM_BENCH_STEPS)):
    loss = train_step(model, optimizer, bench_batch)
# Dispatch is asynchronous: wait for the device before stopping the clock.
jax.block_until_ready(loss)
elapsed = time.perf_counter() - start
print(f"{NUM_BENCH_STEPS / elapsed:.1f} steps/s")