In [1]:
from tqdm.notebook import tqdm
from conformer.tokenizer import Tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
import grain
from pathlib import Path
from flax import nnx
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
tokenizer = Tokenizer.load_tokenizer(Path('/home/penguin/data/tinyvoice/tokenizer/tokenizer.pkl'))

In [3]:
train_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/data/tinyvoice/data/data.array_record')
# test_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/Data/processed/test.array_record')

In [4]:
map_train_audio_dataset = grain.MapDataset.source(train_audio_source)
# map_test_audio_dataset = grain.MapDataset.source(test_audio_source)

In [5]:
processed_train_dataset = (
    map_train_audio_dataset
    .shuffle(seed=42)
    .map(ProcessAudioData(tokenizer))
    .batch(batch_size=48, batch_fn=batch_fn)
    .repeat(1)
)

# processed_test_dataset = (
#     map_test_audio_dataset
#     .map(ProcessAudioData(tokenizer))
#     .batch(batch_size=8, batch_fn=batch_fn)
# )

In [6]:
from conformer.model import ConformerModel
from tqdm import tqdm

In [7]:
model = ConformerModel(token_count=len(tokenizer.id_to_char))

W1220 17:22:29.814147  121998 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1220 17:22:29.816246  121823 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


In [8]:
import optax

In [9]:
lr_schedule = optax.linear_schedule(
    init_value=1e-7, 
    end_value=5e-4, 
    transition_steps=300
)

optimizer = nnx.Optimizer(
    model,
    optax.adamw(
        learning_rate=lr_schedule,
        b1=0.9,
        b2=0.98,
        weight_decay=1e-2
    ),
    wrt=nnx.Param
)

In [10]:
@nnx.jit
def jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths):
    def loss_fn(model):
        logits = model(padded_audios, mask=mask, training=True)
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        
        audio_time_mask = jnp.arange(log_probs.shape[1]) >= real_times[:, None]
        label_mask = jnp.arange(padded_labels.shape[1]) >= label_lengths[:, None]
        
        loss = optax.ctc_loss(log_probs, audio_time_mask, padded_labels, label_mask).mean()

        return loss
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model=model, grads=grads)

    return loss

In [11]:
padded_audios, frames, padded_labels, label_lengths = processed_train_dataset[43]

In [12]:
def compute_mask(frames):
    # MelSpectrogram: hop_length=160, win_length=400, padded=False
    # T_mel = (T_audio - win_length) // hop_length + 1
    # Conv2dSubSampler: two layers of kernel=3, stride=2, padding='VALID'
    # T_out = (T_in - 3) // 2 + 1
    # T_final = (T_out - 3) // 2 + 1
    
    t_mel = (frames - 400) // 160 + 1
    t_conv1 = (t_mel - 3) // 2 + 1
    t_final = (t_conv1 - 3) // 2 + 1
    
    max_frames = 235008
    max_t_mel = (max_frames - 400) // 160 + 1
    max_t_conv1 = (max_t_mel - 3) // 2 + 1
    max_t_final = (max_t_conv1 - 3) // 2 + 1

    real_times = t_final
    
    # Square mask for attention
    mask = jnp.arange(max_t_final) < real_times[:, None]
    mask = jnp.expand_dims(mask, axis=1).repeat(max_t_final, axis=1)
    
    # MultiHeadAttention mask: (batch, num_heads, q_len, k_len)
    mask = jnp.expand_dims(mask, axis=1).repeat(4, axis=1)

    return mask, real_times

In [13]:
mask, real_times = compute_mask(frames)

In [14]:
z = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

In [15]:
avg_loss = 0
for i, element in enumerate(tqdm(processed_train_dataset)):
    padded_audios, frames, padded_labels, label_lengths = element
    mask, real_times = compute_mask(frames)

    loss = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

    avg_loss += loss    
    if (i + 1) % 20 == 0:
        print(f"avg loss: {avg_loss // 20}")
        avg_loss = 0

  1%|          | 21/2444 [00:05<10:58,  3.68it/s]

avg loss: 297.0


  2%|▏         | 41/2444 [00:11<10:31,  3.81it/s]

avg loss: 231.0


  2%|▏         | 61/2444 [00:16<10:19,  3.85it/s]

avg loss: 223.0


  3%|▎         | 81/2444 [00:21<09:37,  4.09it/s]

avg loss: 214.0


  4%|▍         | 101/2444 [00:26<09:31,  4.10it/s]

avg loss: 202.0


  5%|▍         | 121/2444 [00:31<09:23,  4.12it/s]

avg loss: 205.0


  6%|▌         | 141/2444 [00:36<09:20,  4.11it/s]

avg loss: 194.0


  7%|▋         | 161/2444 [00:41<09:13,  4.12it/s]

avg loss: 179.0


  7%|▋         | 181/2444 [00:46<09:32,  3.95it/s]

avg loss: 165.0


  8%|▊         | 201/2444 [00:51<09:23,  3.98it/s]

avg loss: 161.0


  9%|▉         | 221/2444 [00:56<08:59,  4.12it/s]

avg loss: 163.0


 10%|▉         | 241/2444 [01:01<08:55,  4.12it/s]

avg loss: 173.0


 11%|█         | 261/2444 [01:06<08:59,  4.04it/s]

avg loss: 166.0


 11%|█▏        | 281/2444 [01:11<08:59,  4.01it/s]

avg loss: 168.0


 12%|█▏        | 301/2444 [01:17<08:55,  4.00it/s]

avg loss: 161.0


 13%|█▎        | 321/2444 [01:22<08:44,  4.05it/s]

avg loss: 154.0


 14%|█▍        | 341/2444 [01:27<08:31,  4.11it/s]

avg loss: 151.0


 15%|█▍        | 361/2444 [01:32<08:30,  4.08it/s]

avg loss: 140.0


 16%|█▌        | 381/2444 [01:37<08:20,  4.12it/s]

avg loss: 128.0


 16%|█▋        | 400/2444 [01:42<10:05,  3.38it/s]

avg loss: 115.0


 17%|█▋        | 421/2444 [01:47<08:16,  4.07it/s]

avg loss: 114.0


 18%|█▊        | 441/2444 [01:52<08:18,  4.02it/s]

avg loss: 112.0


 19%|█▉        | 461/2444 [01:57<08:12,  4.03it/s]

avg loss: 104.0


 20%|█▉        | 481/2444 [02:02<08:11,  3.99it/s]

avg loss: 97.0


 20%|██        | 501/2444 [02:08<08:06,  4.00it/s]

avg loss: 100.0


 21%|██▏       | 521/2444 [02:13<07:56,  4.03it/s]

avg loss: 95.0


 22%|██▏       | 541/2444 [02:18<07:45,  4.09it/s]

avg loss: 93.0


 23%|██▎       | 561/2444 [02:23<07:40,  4.09it/s]

avg loss: 92.0


 24%|██▍       | 581/2444 [02:28<07:43,  4.02it/s]

avg loss: 88.0


 25%|██▍       | 601/2444 [02:33<07:37,  4.03it/s]

avg loss: 88.0


 25%|██▌       | 621/2444 [02:38<07:32,  4.03it/s]

avg loss: 86.0


 26%|██▌       | 641/2444 [02:43<07:28,  4.02it/s]

avg loss: 85.0


 27%|██▋       | 661/2444 [02:48<07:21,  4.03it/s]

avg loss: 84.0


 28%|██▊       | 681/2444 [02:53<07:17,  4.03it/s]

avg loss: 84.0


 29%|██▊       | 701/2444 [02:59<07:15,  4.00it/s]

avg loss: 81.0


 30%|██▉       | 721/2444 [03:04<07:06,  4.04it/s]

avg loss: 83.0


 30%|███       | 741/2444 [03:09<07:00,  4.05it/s]

avg loss: 78.0


 31%|███       | 761/2444 [03:14<06:55,  4.05it/s]

avg loss: 78.0


 32%|███▏      | 781/2444 [03:19<06:50,  4.05it/s]

avg loss: 74.0


 33%|███▎      | 801/2444 [03:24<06:46,  4.04it/s]

avg loss: 76.0


 34%|███▎      | 821/2444 [03:29<06:40,  4.05it/s]

avg loss: 75.0


 34%|███▍      | 841/2444 [03:34<06:34,  4.07it/s]

avg loss: 74.0


 35%|███▌      | 861/2444 [03:39<06:33,  4.03it/s]

avg loss: 72.0


 36%|███▌      | 881/2444 [03:44<06:26,  4.04it/s]

avg loss: 72.0


 37%|███▋      | 901/2444 [03:50<06:19,  4.06it/s]

avg loss: 73.0


 38%|███▊      | 921/2444 [03:55<06:17,  4.03it/s]

avg loss: 69.0


 39%|███▊      | 941/2444 [04:00<06:12,  4.04it/s]

avg loss: 71.0


 39%|███▉      | 961/2444 [04:05<06:04,  4.06it/s]

avg loss: 68.0


 40%|████      | 981/2444 [04:10<06:01,  4.04it/s]

avg loss: 70.0


 41%|████      | 1001/2444 [04:15<06:00,  4.00it/s]

avg loss: 66.0


 42%|████▏     | 1020/2444 [04:20<07:05,  3.34it/s]

avg loss: 66.0


 43%|████▎     | 1041/2444 [04:25<05:51,  3.99it/s]

avg loss: 65.0


 43%|████▎     | 1061/2444 [04:30<05:40,  4.06it/s]

avg loss: 64.0


 44%|████▍     | 1081/2444 [04:35<05:32,  4.10it/s]

avg loss: 64.0


 45%|████▌     | 1101/2444 [04:40<05:29,  4.07it/s]

avg loss: 63.0


 46%|████▌     | 1121/2444 [04:45<05:24,  4.08it/s]

avg loss: 61.0


 47%|████▋     | 1141/2444 [04:51<05:20,  4.06it/s]

avg loss: 62.0


 48%|████▊     | 1161/2444 [04:56<05:18,  4.02it/s]

avg loss: 64.0


 48%|████▊     | 1181/2444 [05:01<05:15,  4.00it/s]

avg loss: 60.0


 49%|████▉     | 1201/2444 [05:06<05:04,  4.09it/s]

avg loss: 58.0


 50%|████▉     | 1221/2444 [05:11<05:01,  4.05it/s]

avg loss: 60.0


 51%|█████     | 1241/2444 [05:16<04:57,  4.05it/s]

avg loss: 61.0


 52%|█████▏    | 1261/2444 [05:21<04:46,  4.13it/s]

avg loss: 57.0


 52%|█████▏    | 1281/2444 [05:26<04:42,  4.12it/s]

avg loss: 58.0


 53%|█████▎    | 1301/2444 [05:31<04:37,  4.11it/s]

avg loss: 58.0


 54%|█████▍    | 1321/2444 [05:36<04:33,  4.10it/s]

avg loss: 57.0


 55%|█████▍    | 1341/2444 [05:41<04:28,  4.10it/s]

avg loss: 57.0


 56%|█████▌    | 1361/2444 [05:46<04:26,  4.06it/s]

avg loss: 58.0


 57%|█████▋    | 1381/2444 [05:51<04:18,  4.12it/s]

avg loss: 56.0


 57%|█████▋    | 1401/2444 [05:56<04:13,  4.12it/s]

avg loss: 56.0


 58%|█████▊    | 1421/2444 [06:01<04:09,  4.10it/s]

avg loss: 57.0


 59%|█████▉    | 1440/2444 [06:06<04:57,  3.38it/s]

avg loss: 55.0


 60%|█████▉    | 1461/2444 [06:11<04:02,  4.05it/s]

avg loss: 57.0


 61%|██████    | 1481/2444 [06:16<03:57,  4.06it/s]

avg loss: 56.0


 61%|██████▏   | 1501/2444 [06:21<03:54,  4.02it/s]

avg loss: 156.0


 62%|██████▏   | 1521/2444 [06:27<03:52,  3.97it/s]

avg loss: 55.0


 63%|██████▎   | 1541/2444 [06:32<03:41,  4.08it/s]

avg loss: 51.0


 64%|██████▍   | 1561/2444 [06:37<03:36,  4.08it/s]

avg loss: 55.0


 65%|██████▍   | 1581/2444 [06:42<03:32,  4.05it/s]

avg loss: 51.0


 66%|██████▌   | 1601/2444 [06:47<03:30,  4.00it/s]

avg loss: 52.0


 66%|██████▋   | 1621/2444 [06:52<03:22,  4.07it/s]

avg loss: 51.0


 67%|██████▋   | 1641/2444 [06:57<02:55,  4.58it/s]

avg loss: 51.0


 68%|██████▊   | 1660/2444 [07:01<03:44,  3.49it/s]

avg loss: 51.0


 69%|██████▉   | 1681/2444 [07:06<02:51,  4.44it/s]

avg loss: 50.0


 70%|██████▉   | 1701/2444 [07:11<02:47,  4.43it/s]

avg loss: 50.0


 70%|███████   | 1721/2444 [07:16<02:42,  4.45it/s]

avg loss: 48.0


 71%|███████   | 1741/2444 [07:20<02:37,  4.47it/s]

avg loss: 47.0


 72%|███████▏  | 1761/2444 [07:25<02:48,  4.06it/s]

avg loss: 51.0


 73%|███████▎  | 1781/2444 [07:30<02:45,  4.00it/s]

avg loss: 50.0


 74%|███████▎  | 1801/2444 [07:35<02:37,  4.07it/s]

avg loss: 50.0


 75%|███████▍  | 1821/2444 [07:40<02:32,  4.07it/s]

avg loss: 50.0


 75%|███████▌  | 1841/2444 [07:45<02:24,  4.18it/s]

avg loss: 50.0


 76%|███████▌  | 1861/2444 [07:50<02:21,  4.11it/s]

avg loss: 50.0


 77%|███████▋  | 1881/2444 [07:55<02:17,  4.11it/s]

avg loss: 49.0


 78%|███████▊  | 1901/2444 [08:01<02:13,  4.06it/s]

avg loss: 47.0


 79%|███████▊  | 1921/2444 [08:06<02:09,  4.03it/s]

avg loss: 49.0


 79%|███████▉  | 1941/2444 [08:11<02:03,  4.06it/s]

avg loss: 46.0


 80%|████████  | 1961/2444 [08:16<01:59,  4.04it/s]

avg loss: 47.0


 81%|████████  | 1981/2444 [08:21<01:55,  4.01it/s]

avg loss: 48.0


 82%|████████▏ | 2001/2444 [08:26<01:49,  4.05it/s]

avg loss: 47.0


 83%|████████▎ | 2021/2444 [08:31<01:44,  4.04it/s]

avg loss: 47.0


 84%|████████▎ | 2041/2444 [08:36<01:38,  4.07it/s]

avg loss: 50.0


 84%|████████▍ | 2061/2444 [08:41<01:34,  4.06it/s]

avg loss: 48.0


 85%|████████▌ | 2081/2444 [08:47<01:29,  4.06it/s]

avg loss: 48.0


 86%|████████▌ | 2101/2444 [08:52<01:24,  4.04it/s]

avg loss: 47.0


 87%|████████▋ | 2121/2444 [08:57<01:19,  4.07it/s]

avg loss: 44.0


 88%|████████▊ | 2141/2444 [09:02<01:13,  4.10it/s]

avg loss: 46.0


 88%|████████▊ | 2161/2444 [09:07<01:09,  4.06it/s]

avg loss: 45.0


 89%|████████▉ | 2181/2444 [09:12<01:05,  4.04it/s]

avg loss: 45.0


 90%|█████████ | 2201/2444 [09:17<00:59,  4.07it/s]

avg loss: 45.0


 91%|█████████ | 2221/2444 [09:22<00:54,  4.09it/s]

avg loss: 44.0


 92%|█████████▏| 2241/2444 [09:27<00:49,  4.12it/s]

avg loss: 45.0


 93%|█████████▎| 2261/2444 [09:32<00:46,  3.97it/s]

avg loss: 45.0


 93%|█████████▎| 2281/2444 [09:37<00:40,  4.06it/s]

avg loss: 43.0


 94%|█████████▍| 2301/2444 [09:42<00:32,  4.40it/s]

avg loss: 45.0


 95%|█████████▍| 2321/2444 [09:47<00:30,  4.00it/s]

avg loss: 43.0


 96%|█████████▌| 2341/2444 [09:52<00:25,  4.03it/s]

avg loss: 45.0


 97%|█████████▋| 2361/2444 [09:58<00:20,  4.01it/s]

avg loss: 44.0


 97%|█████████▋| 2380/2444 [10:03<00:19,  3.30it/s]

avg loss: 44.0


 98%|█████████▊| 2401/2444 [10:08<00:10,  4.02it/s]

avg loss: 42.0


 99%|█████████▉| 2421/2444 [10:13<00:05,  4.01it/s]

avg loss: 42.0


100%|█████████▉| 2441/2444 [10:18<00:00,  4.03it/s]

avg loss: 45.0


100%|██████████| 2444/2444 [10:43<00:00,  3.80it/s]


In [None]:
output = model(padded_audios, mask=mask, training=False)

In [None]:
output.shape

In [None]:
def decode(ids: list[int]) -> str:
    last_char_id = 0
    decoded_chars = []
    for char_id in ids:
        if char_id != 0 and char_id != last_char_id:
            decoded_chars.append(char_id)
        last_char_id = char_id
    
    return decoded_chars

In [None]:
dds = decode(output[18].argmax(axis=-1).tolist())

In [None]:
tokens = tokenizer.decode(dds)

In [None]:
# tokens = tokenizer.decode(output[10].argmax(axis=-1).tolist())

for tok in tokens:
    if tok != '<BLANK>':
        print(tok, end='')

In [None]:
z = tokenizer.decode(padded_labels[18].tolist())

In [None]:
for tok in z:
    if tok != '<BLANK>':
        print(tok, end='')