In [1]:
from tqdm.notebook import tqdm
from conformer.tokenizer import Tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
import grain
from pathlib import Path
from flax import nnx
import numpy as np
import jax
import jax.numpy as jnp
from conformer.model import ConformerEncoder
from tqdm import tqdm
import optax
import orbax.checkpoint as ocp
import os
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1"

In [3]:
checkpoint_path = Path('/home/penguin/Documents/TinyVoice/checkpoint/checkpoints_fixed')

checkpointer = ocp.CheckpointManager(
    checkpoint_path.absolute(),
    options=ocp.CheckpointManagerOptions(max_to_keep=5)
)

W0102 23:40:25.149051  105155 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W0102 23:40:25.154219  104824 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


In [None]:
# Load the prebuilt tokenizer; its `id_to_char` table sizes the model's output layer.
# NOTE(review): the `.pkl` suffix suggests this unpickles the file — only load
# tokenizers from a trusted location.
tokenizer_file = Path('/home/penguin/data/ka/tokenizer/tokenizer.pkl')
tokenizer = Tokenizer.load_tokenizer(tokenizer_file)

In [None]:
# Conformer encoder whose output vocabulary matches the tokenizer.
vocab_size = len(tokenizer.id_to_char)
model = ConformerEncoder(token_count=vocab_size)

# Linear warmup 1e-7 -> 5e-4 over 1000 steps, then cosine decay down to 1e-6.
# NOTE(review): per the optax docs, `decay_steps` is the total schedule length
# *including* the warmup — confirm 10000 is the intended total.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-7,
    peak_value=5e-4,
    warmup_steps=1000,
    decay_steps=10000,
    end_value=1e-6,
)

# AdamW driven by the schedule above, applied only to nnx.Param leaves of the model.
adamw = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.98, weight_decay=1e-2)
optimizer = nnx.Optimizer(model, adamw, wrt=nnx.Param)

In [None]:
# Resume model weights from the newest checkpoint, if one exists.
# Query the manager once so the step that is reported is the step that is restored.
latest_step = checkpointer.latest_step()
if latest_step is not None:
    print(f"Restoring from checkpoint at step {latest_step}...")
    restored = checkpointer.restore(latest_step)
    nnx.update(model, restored['model'])
    # Optimizer state is NOT restored (the line below is commented out), so the AdamW
    # moments — and presumably the step count feeding `lr_schedule` — start from zero,
    # i.e. the warmup runs again after a resume.
    # nnx.update(optimizer, restored['optimizer'])


In [None]:
# Packed ArrayRecord files for the train / test splits.
train_audio_source = grain.sources.ArrayRecordDataSource(
    '/home/penguin/data/ka/packed_dataset/train.array_record'
)
test_audio_source = grain.sources.ArrayRecordDataSource(
    '/home/penguin/data/ka/packed_dataset/test.array_record'
)

# Wrap each source as a random-access MapDataset (train first, then test).
map_train_audio_dataset, map_test_audio_dataset = (
    grain.MapDataset.source(source) for source in (train_audio_source, test_audio_source)
)

In [None]:
# Training pipeline: shuffle (fixed seed 42), apply ProcessAudioData to each example,
# then collate batches of 48 with the project's `batch_fn`; repeat(1) = one epoch.
# NOTE(review): grain's `.batch` keeps a short final batch unless `drop_remainder=True`,
# which gives the jitted train step a new batch shape (one extra compilation).
shuffled_train = map_train_audio_dataset.shuffle(seed=42)
processed_train_dataset = (
    shuffled_train
    .map(ProcessAudioData(tokenizer))
    .batch(batch_size=48, batch_fn=batch_fn)
    .repeat(1)
)

# Evaluation pipeline: same per-example processing, no shuffling, batches of 24.
test_pipeline = map_test_audio_dataset.map(ProcessAudioData(tokenizer))
processed_test_dataset = test_pipeline.batch(batch_size=24, batch_fn=batch_fn)

In [None]:
def compute_mask(
    frames,
    max_frames: int = 235008,
    num_heads: int = 4,
    win_length: int = 400,
    hop_length: int = 160,
):
    """Build the key-padding attention mask for a batch of raw audio lengths.

    The encoder length is derived analytically from the model front end:
    an unpadded MelSpectrogram (``win_length`` / ``hop_length``), followed by two
    Conv2dSubSampler layers with kernel 3, stride 2 and 'VALID' padding.

    Args:
        frames: 1-D array of real (unpadded) audio lengths in samples, one per
            batch element.
        max_frames: padded audio length in samples; fixes the mask's time
            dimension (assumed to match what ``batch_fn`` pads to — confirm).
        num_heads: number of attention heads the mask is replicated over.
        win_length: mel window length in samples.
        hop_length: mel hop length in samples.

    Returns:
        ``(mask, real_times)``. ``mask`` is a bool array of shape
        ``(batch, num_heads, T, T)``, where ``T`` is the encoder length for
        ``max_frames``; ``mask[b, h, q, k]`` is True iff key frame ``k`` is a real
        (non-padded) frame of example ``b``, independent of ``h`` and ``q``.
        ``real_times`` holds the per-example encoder lengths.
    """

    def _encoder_length(n_samples):
        # T_mel = (T_audio - win) // hop + 1, then two VALID conv(kernel=3, stride=2):
        # T_out = (T_in - 3) // 2 + 1 applied twice.
        t = (n_samples - win_length) // hop_length + 1
        for _ in range(2):
            t = (t - 3) // 2 + 1
        return t

    real_times = _encoder_length(frames)
    max_t_final = _encoder_length(max_frames)

    key_is_real = jnp.arange(max_t_final) < real_times[:, None]  # (batch, T)
    # Same key mask for every head and query row: (batch, 1, 1, T) -> (batch, H, T, T).
    mask = jnp.broadcast_to(
        key_is_real[:, None, None, :],
        (key_is_real.shape[0], num_heads, max_t_final, max_t_final),
    )
    return mask, real_times

In [None]:
# Inspect one batch (index 12) and build its attention mask.
sample_batch = processed_train_dataset[12]
padded_audios, frames, padded_labels, label_lengths = sample_batch
mask, real_times = compute_mask(frames)

In [None]:
# Smoke-test a forward pass in inference mode on the sample batch.
res = model(padded_audios, mask=mask, training=False)

In [None]:
@nnx.jit
def jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths):
    """Run one CTC training step and return the batch-mean loss.

    Args:
        model: the encoder being trained; its params are updated by ``optimizer``.
        optimizer: ``nnx.Optimizer`` wrapping ``model``.
        padded_audios: padded audio batch fed to ``model``.
        padded_labels: ``(batch, L)`` label ids, padded past ``label_lengths``.
        mask: attention mask, as built by ``compute_mask``.
        real_times: per-example number of valid encoder frames.
        label_lengths: per-example number of valid label ids.

    Returns:
        Scalar mean CTC loss over the batch, evaluated with the pre-update params.
    """
    def loss_fn(model):
        logits = model(padded_audios, mask=mask, training=True)

        # optax.ctc_loss documents float paddings: 1.0 marks a padded position, 0.0 a
        # real one. Cast explicitly instead of relying on bool -> float promotion.
        logit_paddings = (
            jnp.arange(logits.shape[1]) >= real_times[:, None]
        ).astype(jnp.float32)
        label_paddings = (
            jnp.arange(padded_labels.shape[1]) >= label_lengths[:, None]
        ).astype(jnp.float32)

        # `tokenizer.blank_id` is read from the notebook global at trace time.
        per_example_loss = optax.ctc_loss(
            logits, logit_paddings, padded_labels, label_paddings, blank_id=tokenizer.blank_id
        )
        return per_example_loss.mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model=model, grads=grads)

    return loss

In [None]:
# Load batch 43 for the single-step smoke test below.
batch_43 = processed_train_dataset[43]
padded_audios, frames, padded_labels, label_lengths = batch_43

In [None]:
# Attention mask and per-example encoder lengths for the batch loaded above.
mask_and_lengths = compute_mask(frames)
mask, real_times = mask_and_lengths

In [None]:
# Compile `jitted_train` and take one step on the batch above before the full loop.
# NOTE: this is a real optimizer step — `model` and `optimizer` are updated.
# Named distinctly so it is not confused with the token list decoded later.
smoke_test_loss = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)
smoke_test_loss

In [None]:
# One pass over the training set, printing the mean loss of every LOG_EVERY steps.
LOG_EVERY = 20

window_loss = 0.0  # running *sum* of losses in the current logging window
for i, element in enumerate(tqdm(processed_train_dataset)):
    padded_audios, frames, padded_labels, label_lengths = element
    mask, real_times = compute_mask(frames)

    loss = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

    window_loss += loss
    if (i + 1) % LOG_EVERY == 0:
        # True division: `//` would floor the mean loss (e.g. 3.47 -> 3.0).
        print(f"avg loss: {float(window_loss) / LOG_EVERY}")
        window_loss = 0.0

In [None]:
def decode(ids: list[int], blank_id: int = 0) -> list[int]:
    """Greedy CTC collapse of a frame-level id sequence.

    Consecutive repeats are merged and ``blank_id`` is removed; a blank between
    two equal ids keeps both (``[1, 0, 1] -> [1, 1]``).

    Args:
        ids: per-frame (e.g. argmax) token ids.
        blank_id: id of the CTC blank token. Defaults to 0; pass
            ``tokenizer.blank_id`` to match the blank used in the CTC loss.

    Returns:
        The collapsed list of token ids (not a string — pass it to
        ``tokenizer.decode`` to get symbols).
    """
    decoded_ids = []
    previous = blank_id  # start as if preceded by a blank so the first id is kept
    for token_id in ids:
        if token_id != blank_id and token_id != previous:
            decoded_ids.append(token_id)
        previous = token_id
    return decoded_ids

In [None]:
# Greedy-decode one utterance from the *current* batch, so the hypothesis lines up
# with `padded_labels[SAMPLE_IDX]` shown below. Frames beyond the utterance's real
# encoder length are padding and are dropped before decoding.
# NOTE(review): assumes the current batch has more than SAMPLE_IDX examples (a short
# final batch from the loop above may not).
SAMPLE_IDX = 4
output = model(padded_audios, mask, training=False)
valid_frames = int(real_times[SAMPLE_IDX])
dds = decode(output[SAMPLE_IDX, :valid_frames].argmax(axis=-1).tolist())

In [None]:
# Map the collapsed CTC ids back to tokenizer symbols.
hypothesis_ids = dds
tokens = tokenizer.decode(hypothesis_ids)

In [None]:
# Print the decoded hypothesis as one contiguous string (no separators, no newline).
print(*tokens, sep='', end='')

In [None]:
# Reference transcript for the same batch position (index 4) of the current batch.
reference_ids = padded_labels[4].tolist()
z = tokenizer.decode(reference_ids)

In [None]:
# Print the reference transcript, skipping '<BLANK>' tokens (presumably label padding).
print(*(tok for tok in z if tok != '<BLANK>'), sep='', end='')