In [None]:
from pathlib import Path

import grain
import jax.numpy as jnp
import numpy as np
from flax import nnx
from tqdm.notebook import tqdm

from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
from conformer.tokenizer import Tokenizer

In [None]:
# Load the previously saved tokenizer from disk.
# NOTE(review): the `.pkl` extension suggests a pickle; only load it from a trusted location.
tokenizer_path = Path('/home/penguin/data/tinyvoice/tokenizer/tokenizer.pkl')
tokenizer = Tokenizer.load_tokenizer(tokenizer_path)

In [None]:
# ArrayRecord-backed data source for the training split.
train_record_path = '/home/penguin/data/tinyvoice/data/data.array_record'
train_audio_source = grain.sources.ArrayRecordDataSource(train_record_path)

# Test split is currently disabled.
# NOTE(review): this path uses a different root (`Data/processed`) than the training data; confirm before re-enabling.
# test_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/Data/processed/test.array_record')

In [None]:
# Wrap the raw source in a grain MapDataset so it can be shuffled, mapped and batched.
map_train_audio_dataset = grain.MapDataset.source(
    train_audio_source,
)

# map_test_audio_dataset = grain.MapDataset.source(test_audio_source)

In [None]:
# Training pipeline: shuffle records, run ProcessAudioData on each one,
# collate fixed-size batches with batch_fn, and make a single pass (repeat(1)).
TRAIN_SHUFFLE_SEED = 42
TRAIN_BATCH_SIZE = 48

processed_train_dataset = map_train_audio_dataset.shuffle(seed=TRAIN_SHUFFLE_SEED)
processed_train_dataset = processed_train_dataset.map(ProcessAudioData(tokenizer))
processed_train_dataset = processed_train_dataset.batch(batch_size=TRAIN_BATCH_SIZE, batch_fn=batch_fn)
processed_train_dataset = processed_train_dataset.repeat(1)

# Test pipeline (disabled together with the test source):
# processed_test_dataset = (
#     map_test_audio_dataset
#     .map(ProcessAudioData(tokenizer))
#     .batch(batch_size=8, batch_fn=batch_fn)
# )

In [None]:
from conformer.model import ConformerModel
from tqdm import tqdm

In [None]:
# Size the model's token output to the number of entries in the tokenizer's id table.
vocab_size = len(tokenizer.id_to_char)
model = ConformerModel(token_count=vocab_size)

In [None]:
import optax

In [None]:
# Linear warm-up from 1e-7 to 5e-4 over the first 300 optimizer steps;
# optax.linear_schedule returns `end_value` for every step after that.
lr_schedule = optax.linear_schedule(init_value=1e-7, end_value=5e-4, transition_steps=300)

# AdamW with beta2 = 0.98 and decoupled weight decay, driven by the schedule above.
adamw_tx = optax.adamw(
    learning_rate=lr_schedule,
    b1=0.9,
    b2=0.98,
    weight_decay=1e-2,
)

# Only nnx.Param leaves of the model are optimized.
optimizer = nnx.Optimizer(model, adamw_tx, wrt=nnx.Param)

In [None]:
@nnx.jit
def jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths):
    """Run one jitted CTC training step and apply the optimizer update in place.

    Args:
        model: NNX model; called as ``model(padded_audios, mask=mask, training=True)``.
        optimizer: ``nnx.Optimizer`` wrapping ``model``; updated with the gradients.
        padded_audios: batched model input.
        padded_labels: ``(B, L)`` integer label ids, right-padded.
        mask: attention mask passed to the model; its last axis is used as the
            output time length (as built by ``compute_mask``).
        real_times: ``(B,)`` number of valid output frames per example.
        label_lengths: ``(B,)`` number of real labels per row of ``padded_labels``
            (assumed to be a count, not a last index; TODO confirm against batch_fn).

    Returns:
        Scalar mean CTC loss over the batch, computed before the parameter update.
    """
    def loss_fn(model):
        outputs = model(padded_audios, mask=mask, training=True)
        # optax.ctc_loss wants paddings of 1.0 at padded positions and 0.0 at real
        # ones. Lengths are counts, so index == length is already padding: use `>=`
        # (the complement of compute_mask's `arange < real_times`). A `>` here would
        # feed one padded label / frame per example into the loss.
        # NOTE(review): compute_mask returns floor(frames_at_hop / 4) while its
        # attention mask marks ceil(...) keys valid; confirm which matches the model.
        audio_time_paddings = (
            jnp.arange(mask.shape[-1]) >= real_times[:, None]
        ).astype(jnp.float32)
        label_paddings = (
            jnp.arange(padded_labels.shape[-1]) >= label_lengths[:, None]
        ).astype(jnp.float32)
        loss = optax.ctc_loss(outputs, audio_time_paddings, padded_labels, label_paddings).mean()

        return loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model=model, grads=grads)

    return loss

In [None]:
# Fetch a single batch by index for a one-off step before the full loop.
sample_batch = processed_train_dataset[43]
padded_audios, frames, padded_labels, label_lengths = sample_batch

In [None]:
def compute_mask(frames, hop_length=160, max_samples=235008, num_heads=4):
    """Build a per-head attention padding mask and subsampled lengths for a batch.

    Args:
        frames: 1-D integer array, one entry per example; presumably the unpadded
            audio length in samples (TODO confirm against batch_fn).
        hop_length: samples per frame; the frame count is ``frames // hop_length + 1``.
        max_samples: presumably the padded audio length; fixes the mask size.
        num_heads: number of copies of the mask along axis 1 (presumably one per
            attention head).

    Returns:
        ``(mask, lengths)``:

        * ``mask``: bool array of shape ``(B, num_heads, T, T)``. ``mask[b, h, i, j]``
          is True iff subsampled key position ``j`` lies inside example ``b``; it is
          identical for every query row ``i`` and head ``h``. T is 366 with the defaults.
        * ``lengths``: ``(frames // hop_length + 1) // 4``.
    """
    real_times = (frames // hop_length) + 1
    full_len = (max_samples // hop_length) + 1

    # Original-resolution frame index kept at each subsampled position. Each
    # `[:-2:2]` is one 2x time reduction. Slicing this 1-D index vector gives the
    # same result as slicing the full (B, full_len, full_len) mask, without
    # materialising that large intermediate.
    key_positions = np.arange(full_len)[:-2:2][:-2:2]
    key_mask = key_positions < real_times[:, None]  # (B, T): key j is a real frame

    batch, n_keys = key_mask.shape
    # Every head and every query row share the same key mask; copy so the result
    # is a regular writable array rather than a read-only broadcast view.
    mask = np.broadcast_to(key_mask[:, None, None, :], (batch, num_heads, n_keys, n_keys)).copy()

    # NOTE(review): `mask` marks ceil(real_times / 4) keys valid (capped at T), but the
    # returned length is floor(real_times / 4); they differ by one unless real_times is
    # a multiple of 4. Confirm which matches the model's subsampled output length.
    return mask, real_times // 4

In [None]:
# Attention mask and subsampled lengths for the sample batch.
(mask, real_times) = compute_mask(frames)

In [None]:
# One training step on the sample batch (the first call also traces/compiles the step).
# NOTE(review): this is a real update; it changes `model` and `optimizer` state,
# so re-running this cell trains further.
z = jitted_train(
    model,
    optimizer,
    padded_audios,
    padded_labels,
    mask,
    real_times,
    label_lengths,
)

In [None]:
# Train for one pass over `processed_train_dataset`, printing the mean loss of
# each window of `log_every` consecutive steps.
log_every = 20
avg_loss = 0
for i, element in enumerate(tqdm(processed_train_dataset)):
    padded_audios, frames, padded_labels, label_lengths = element
    mask, real_times = compute_mask(frames)

    loss = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

    # Running sum of per-step losses; converted to a Python float only when printed.
    avg_loss += loss
    if (i + 1) % log_every == 0:
        # True division: floor division (`//`) would truncate the mean loss.
        print(f"avg loss: {float(avg_loss) / log_every}")
        avg_loss = 0