In [19]:
from tqdm import tqdm
from conformer.tokenizer import Tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
import grain
from functools import partial
from conformer.conformer_block import ConformerEncoder
from conformer.config import ConformerConfig, TrainingConfig
from flax import nnx
import jax.numpy as jnp
import optax
import jax


In [20]:
conformer_config = ConformerConfig()
train_config = TrainingConfig()

In [21]:
tokenizer = Tokenizer.load('/home/penguin/data/tokenizer/tokenizer.json')

In [22]:
audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/data/packed_dataset/data.array_record')
with_tokenizer_batch_fn = partial(batch_fn, tokenizer=tokenizer)

In [25]:
map_audio_dataset = grain.MapDataset.source(audio_source)

In [26]:
example_datasets = (
    map_audio_dataset
    .shuffle(seed=42)
    .map(ProcessAudioData(tokenizer))
    .batch(batch_size=24, batch_fn=with_tokenizer_batch_fn)
)

In [34]:
model = ConformerEncoder(conformer_config, num_classes=42, rngs=nnx.Rngs(0))

In [35]:
def create_padding_mask(lengths: jnp.ndarray, max_len: int) -> jnp.ndarray:
    batch_size = lengths.shape[0]
    indices = jnp.arange(max_len).reshape(1, -1)
    mask = indices >= lengths.reshape(-1, 1)
    return mask.astype(jnp.float32)

def create_learning_rate_fn(warmup_steps: int, model_size: int):
    def lr_fn(step):
        arg1 = 1 / jnp.sqrt(step + 1e-9)
        arg2 = step * (warmup_steps ** -1.5)
        return (1 / jnp.sqrt(model_size)) * jnp.minimum(arg1, arg2)
    return lr_fn

@nnx.jit(donate_argnums=0)
def train_step(model: ConformerEncoder, optimizer: nnx.Optimizer, batch: dict):

    def loss_fn(model: ConformerEncoder):
        log_probs, output_lengths = model(
            batch["inputs"], batch["input_lengths"], training=True
        )

        max_logit_len = log_probs.shape[1]
        max_label_len = batch["labels"].shape[1]
        logit_paddings = create_padding_mask(output_lengths, max_logit_len)
        label_paddings = create_padding_mask(batch["label_lengths"], max_label_len)
        
        loss = optax.ctc_loss(
            log_probs, logit_paddings, batch["labels"], label_paddings
        ).mean()
        return loss
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model=model, grads=grads)
    return loss

In [39]:
lr_schedule = create_learning_rate_fn(train_config.warmup_steps, conformer_config.encoder_dim)
optimizer = nnx.Optimizer(
    model,
    optax.adamw(
        learning_rate=lr_schedule,
        b1=train_config.beta1,
        b2=train_config.beta2,
        weight_decay=train_config.weight_decay,
    ),
    wrt=nnx.Param
)

In [40]:
loss = train_step(model, optimizer, example_datasets[0])

In [41]:
total_loss_accumulator = 0
step_count = 0
n_steps_to_print_avg_loss = 20


for batch in tqdm(example_datasets):
    loss = train_step(model, optimizer, batch)
    total_loss_accumulator += loss.item()
    step_count += 1

    if step_count % n_steps_to_print_avg_loss == 0:
        average_loss = total_loss_accumulator / n_steps_to_print_avg_loss
        print(f"Step {step_count}: Average Loss = {average_loss:.4f}")
        
        # Reset the accumulator and counter for the next interval
        total_loss_accumulator = 0.0
        step_count = 0

        break

  0%|          | 19/4858 [00:07<33:39,  2.40it/s] 

Step 20: Average Loss = 908.2094





In [42]:
log_probs, output_lengths = model(
    example_datasets[0]["inputs"], example_datasets[0]["input_lengths"], training=False
)

In [85]:
from jiwer import wer

In [88]:
import numpy as np

In [89]:
def eval_step(model, batch, tokenizer_decode):
    """
    Evaluate a batch of inputs using the model and compute WER.

    Args:
        model: The Flax model (nnx.Module) to evaluate.
        batch: Dictionary containing 'inputs', 'input_lengths', and 'labels'.
        tokenizer_decode: Callable that decodes a single sequence of token IDs to a string.

    Returns:
        WER score for the batch.
    """
    # Forward pass through the model
    log_probs, output_lengths = model(
        batch["inputs"], batch["input_lengths"], training=False
    )

    # Get predicted token IDs
    prediction_tokens = jnp.argmax(log_probs, axis=-1)

    # Vectorized decoding using vmap
    predictions = jax.vmap(lambda x: np.array(tokenizer_decode(x)))(prediction_tokens)
    real_values = jax.vmap(lambda x: np.array(tokenizer_decode(x)))(batch["labels"])

    # Convert JAX arrays to Python lists for WER computation
    predictions = predictions.tolist()
    real_values = real_values.tolist()

    return wer(predictions, real_values)

In [92]:
# @nnx.jit(donate_argnums=0, static_argnames=('tokenizer'))
def eval_step(model, batch, tokenizer):
    log_probs, output_lengths = model(
        batch["inputs"], batch["input_lengths"], training=False
    )

    prediction_tokens = jnp.argmax(log_probs, axis=-1)
    predictions = []
    for element in prediction_tokens:
        predictions.append(tokenizer.decode(element.tolist()))

    real_values = []
    for element in batch['labels']:
        real_values.append(tokenizer.decode(element.tolist()))

    return wer(predictions, real_values)

In [93]:
eval_step(model, batch, tokenizer)

133

In [59]:
predictions = jnp.argmax(log_probs, axis=-1)

In [67]:
for element in predictions:
    text = tokenizer.decode(element.tolist())

In [71]:
for element in example_datasets[0]['labels']:
    text = tokenizer.decode(element.tolist())

In [74]:
from jiwer import wer

In [75]:
wer(text, ' მდგომარეობა პოლონეთში გავრცელების არეალის ძირითადი მონაკვეთია')

0.5714285714285714

In [73]:
text

'ამ მხრივ უკეთესი მდგომარეობა პოლონეთში გავრცელების არეალის ძირითადი მონაკვეთია რუსეთი სკანდინავია მათ შორის ფინეთი'

In [None]:

log_probs[0][0]

Array([-4, -4.5625, -1.67969, -3.9375, -3.125, -4.375, -4.1875, -3.01562,
       -5.9375, -4.65625, -3.79688, -5.8125, -3, -3.5, -3.45312, -3.70312,
       -3.65625, -3.875, -3.01562, -4.40625, -3.84375, -4.8125, -3.96875,
       -4.65625, -4.5625, -5.1875, -4.375, -5, -5.71875, -5.375, -4.71875,
       -4.125, -3.09375, -4.875, -3.26562, -5.46875, -3.1875, -4.8125,
       -5.59375, -3.78125, -3.125, -3.8125], dtype=bfloat16)

In [None]:
def eval_step(model, batch, tokenizer):
    