In [7]:
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from conformer.tokenizer import build_tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, create_array_record_dataset
import grain
from functools import partial

In [8]:
# Root of the corpus split on disk.
# NOTE(review): absolute, machine-specific path — a fresh checkout on another
# machine will need this changed before the notebook can run.
ROOT_PATH = '/home/penguin/Data/cv-corpus-22.0-2025-06-20-ka/cv-corpus-22.0-2025-06-20/ka'
# Training metadata as a tab-separated table; the 'sentence' column is read here,
# and 'label_token_count' / 'duration' are read by a later cell.
df = pd.read_csv(Path(ROOT_PATH) / 'train_mod.tsv', delimiter='\t')
# Build the tokenizer from every transcript in the training split.
# The cell output below reports "Vocabulary size: 41" (presumably printed by build_tokenizer).
tokenizer = build_tokenizer(df['sentence'].values)

Vocabulary size: 41


In [9]:
# Persist the tokenizer to 'tokenizer.json' (relative to the kernel's working directory).
tokenizer.save('tokenizer.json')

In [4]:
# Wipe the character-to-id mapping in place; presumably done so that the load in
# the next cell can be seen to restore it (save/load round-trip check) — TODO confirm.
# NOTE(review): this cell's execution count (4) precedes the cells above it (7-9),
# so the saved outputs do not reflect a top-to-bottom run.
tokenizer.char_to_id = None

In [5]:
# Reload the vocabulary from the file written by tokenizer.save('tokenizer.json').
# The previous target, 'test.json', is never produced by this notebook, so a
# Restart & Run All would depend on a stray file from an earlier session.
# NOTE(review): assumes load() restores state on this instance in place; if it
# returns a new tokenizer instead, rebind `tokenizer` to the result.
tokenizer.load('tokenizer.json')

In [None]:
# Build the on-disk dataset from the training table under the name/path 'data'.
# NOTE(review): presumably this writes the 'data.array_record' file opened by the
# next cell — confirm the output location against conformer.dataset.
create_array_record_dataset(df, 'data')

In [None]:
# Open the ArrayRecord file as a grain data source.
# NOTE(review): absolute, machine-specific path — must be adjusted on other machines.
audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/Desktop/research/data/data.array_record')
# Pre-bind the tokenizer so the collate function can be handed to .batch() as batch_fn.
with_tokenizer_batch_fn = partial(batch_fn, tokenizer=tokenizer)

In [None]:
# Upper bounds over the training table: longest label sequence and longest clip.
samples_per_second = 16_000  # presumably the audio sample rate in Hz — TODO confirm
max_label_tokens = df['label_token_count'].max()
max_audio_samples = df['duration'].max() * samples_per_second
max_label_tokens, max_audio_samples

In [None]:
# Wrap the data source in a grain MapDataset so transformations can be chained on it.
map_audio_dataset = grain.MapDataset.source(audio_source)

In [None]:
# Input pipeline, one stage per statement.
# 1) Shuffle with a fixed seed so the record order is reproducible.
shuffled_records = map_audio_dataset.shuffle(seed=42)
# 2) Turn each raw record into a training example using the tokenizer.
processed_examples = shuffled_records.map(ProcessAudioData(tokenizer))
# 3) Collate 24 examples per batch with the tokenizer-aware batch function.
example_datasets = processed_examples.batch(
    batch_size=24, batch_fn=with_tokenizer_batch_fn
)

In [None]:
# Third-party numerical / training libraries.
import jax
import jax.numpy as jnp
import optax
from flax import nnx

# Project-local model and configuration classes.
from conformer.conformer_block import ConformerEncoder
from conformer.config import ConformerConfig, TrainingConfig

# Model and training hyper-parameters, both constructed with their defaults.
conformer_config = ConformerConfig()
train_config = TrainingConfig()


In [None]:
# Instantiate the encoder with a fixed RNG seed (0) for parameter initialisation.
# NOTE(review): num_classes=42 is presumably the 41-symbol vocabulary reported
# when the tokenizer was built plus one CTC blank — TODO confirm, and confirm the
# blank id matches what the loss expects.
model = ConformerEncoder(conformer_config, num_classes=42, rngs=nnx.Rngs(0))

In [None]:
def create_padding_mask(lengths: jnp.ndarray, max_len: int) -> jnp.ndarray:
    """Build a float padding mask from per-sequence valid lengths.

    Args:
        lengths: Number of valid positions for each sequence. It is flattened
            to one entry per row, so a 0-d value yields a single-row mask.
        max_len: Number of positions (columns) in the mask.

    Returns:
        A ``float32`` array of shape ``(len(lengths), max_len)`` holding ``1.0``
        at padded positions (index >= length) and ``0.0`` at valid positions.
    """
    # Row vector of position indices, broadcast against a column of lengths.
    positions = jnp.arange(max_len).reshape(1, -1)
    is_padding = positions >= lengths.reshape(-1, 1)
    return is_padding.astype(jnp.float32)

def create_learning_rate_fn(warmup_steps: int, model_size: int):
    """Return a ``step -> learning rate`` schedule with warmup then decay.

    The schedule is
    ``model_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``:
    it grows linearly in ``step`` while the second term is the smaller one and
    decays with the inverse square root of ``step`` afterwards.

    Args:
        warmup_steps: Step at which the two terms cross over.
        model_size: Model width used to scale the whole schedule.

    Returns:
        A callable mapping a step count to a learning rate.
    """
    def schedule(step):
        # The small epsilon keeps the reciprocal finite at step 0; the warmup
        # term is 0 there, so the resulting rate at step 0 is 0.
        decay_term = 1 / jnp.sqrt(step + 1e-9)
        warmup_term = step * (warmup_steps ** -1.5)
        scale = 1 / jnp.sqrt(model_size)
        return scale * jnp.minimum(decay_term, warmup_term)

    return schedule

@nnx.jit(donate_argnums=0)
def train_step(model: ConformerEncoder, optimizer: nnx.Optimizer, batch: dict):
    """Run one jitted optimisation step and return the batch-mean CTC loss.

    Args:
        model: Encoder to train; called with ``training=True`` and updated in
            place through ``optimizer.update``.
        optimizer: NNX optimizer that applies the gradients to ``model``.
        batch: Mapping with the keys ``"inputs"``, ``"input_lengths"``,
            ``"labels"`` and ``"label_lengths"``.

    Returns:
        The CTC loss averaged over the batch.
    """
    labels = batch["labels"]
    # Label padding does not depend on the model, so it is built outside the
    # differentiated function. Axis 1 of the labels is the label position axis.
    label_paddings = create_padding_mask(batch["label_lengths"], labels.shape[1])

    def mean_ctc_loss(encoder: ConformerEncoder):
        log_probs, frame_counts = encoder(
            batch["inputs"], batch["input_lengths"], training=True
        )
        # Mask frames beyond each utterance's output length (axis 1 is time).
        frame_paddings = create_padding_mask(frame_counts, log_probs.shape[1])
        ctc_losses = optax.ctc_loss(
            log_probs, frame_paddings, labels, label_paddings
        )
        return ctc_losses.mean()

    loss, grads = nnx.value_and_grad(mean_ctc_loss)(model)
    optimizer.update(model=model, grads=grads)
    return loss

In [None]:
# Learning-rate schedule driven by the configured warmup length and encoder width.
lr_schedule = create_learning_rate_fn(
    train_config.warmup_steps, conformer_config.encoder_dim
)

# AdamW gradient transformation using the configured betas and weight decay.
adamw_tx = optax.adamw(
    learning_rate=lr_schedule,
    b1=train_config.beta1,
    b2=train_config.beta2,
    weight_decay=train_config.weight_decay,
)

# Optimise only the model's nnx.Param variables.
optimizer = nnx.Optimizer(model, adamw_tx, wrt=nnx.Param)

In [None]:
# Smoke test: run a single training step on the first batch.
# Note that this is a real update — train_step applies gradients to `model`
# through `optimizer`, so both are modified by running this cell.
loss = train_step(model, optimizer, example_datasets[0])

In [None]:
# One pass over the dataset, reporting the mean loss every
# `n_steps_to_print_avg_loss` steps.
n_steps_to_print_avg_loss = 20
total_loss_accumulator = 0.0  # sum of losses since the last report
step_count = 0  # global step; never reset, so the log shows real progress


for batch in tqdm(example_datasets):
    loss = train_step(model, optimizer, batch)
    # .item() copies the scalar loss to the host as a Python float.
    total_loss_accumulator += loss.item()
    step_count += 1

    if step_count % n_steps_to_print_avg_loss == 0:
        average_loss = total_loss_accumulator / n_steps_to_print_avg_loss
        print(f"Step {step_count}: Average Loss = {average_loss:.4f}")

        # Reset only the accumulator for the next interval. Resetting the step
        # counter as well made every report read "Step 20".
        total_loss_accumulator = 0.0

In [None]:
# Materialise the first batch once. Indexing the dataset twice would run the
# pipeline twice and could pair inputs with lengths from a different evaluation
# if any stage is non-deterministic.
eval_batch = example_datasets[0]
# Forward pass in inference mode (training=False).
log_probs, output_lengths = model(
    eval_batch["inputs"], eval_batch["input_lengths"], training=False
)

In [None]:
# Inspect the output shape; presumably (batch, frames, num_classes) — TODO confirm.
log_probs.shape

In [None]:
# Shape of the per-position argmax over the last axis.
jnp.argmax(log_probs, axis=-1).shape

In [None]:
# Scores at the last position of batch item 0.
# NOTE(review): if item 0 is shorter than the batch maximum this is a padded position.
log_probs[0, -1]

In [None]:
# Greedy decode of batch item 7: take the argmax over the last axis at every position and decode the ids.
# NOTE(review): the ids are not truncated to output_lengths[7], and repeats/blanks are
# not collapsed here — confirm whether tokenizer.decode performs CTC collapsing.
tokenizer.decode(jnp.argmax(log_probs, -1)[7].tolist())

In [None]:
# Scores at position 4 of batch item 0.
log_probs[0, 4]

In [None]:
# NOTE(review): bare reference that only displays the function object — looks like
# leftover interactive debugging; consider deleting this cell.
jnp.argmax

In [None]:
# Per-position argmax over the last axis for batch item 10.
jnp.argmax(log_probs[10], axis=-1)

In [None]:
# Capture a profiler trace of a handful of training steps.
n_profile_steps = 10

with jax.profiler.trace("/tmp/profile-data"):
    loss = None
    # zip stops as soon as range() is exhausted, so exactly n_profile_steps
    # batches are pulled from the dataset (fewer if the dataset is shorter).
    for _, batch in zip(range(n_profile_steps), example_datasets):
        loss = train_step(model, optimizer, batch)

    # JAX dispatches work asynchronously: wait for the last step to finish
    # before leaving the trace context so its device work is captured.
    if loss is not None:
        loss.block_until_ready()