In [1]:
from tqdm.notebook import tqdm
from conformer.tokenizer import Tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
import grain
from pathlib import Path
from flax import nnx
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
tokenizer = Tokenizer.load_tokenizer(Path('/home/penguin/data/tinyvoice/tokenizer/tokenizer.pkl'))

In [3]:
train_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/data/tinyvoice/data/data.array_record')
# test_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/Data/processed/test.array_record')

In [4]:
map_train_audio_dataset = grain.MapDataset.source(train_audio_source)
# map_test_audio_dataset = grain.MapDataset.source(test_audio_source)

In [5]:
processed_train_dataset = (
    map_train_audio_dataset
    .shuffle(seed=42)
    .map(ProcessAudioData(tokenizer))
    .batch(batch_size=48, batch_fn=batch_fn)
    .repeat(1)
)

# processed_test_dataset = (
#     map_test_audio_dataset
#     .map(ProcessAudioData(tokenizer))
#     .batch(batch_size=8, batch_fn=batch_fn)
# )

In [6]:
from conformer.model import ConformerModel
from tqdm import tqdm

In [7]:
model = ConformerModel(token_count=len(tokenizer.id_to_char))

W1220 18:17:10.568221  146291 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1220 18:17:10.569891  146143 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


In [8]:
import optax

In [9]:
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-7,
    peak_value=5e-4,
    warmup_steps=1000,
    decay_steps=10000,
    end_value=1e-6
)

optimizer = nnx.Optimizer(
    model,
    optax.adamw(
        learning_rate=lr_schedule,
        b1=0.9,
        b2=0.98,
        weight_decay=1e-2
    ),
    wrt=nnx.Param
)

In [10]:
@nnx.jit
def jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths):
    def loss_fn(model):
        logits = model(padded_audios, mask=mask, training=True)
        
        audio_time_mask = jnp.arange(logits.shape[1]) >= real_times[:, None]
        label_mask = jnp.arange(padded_labels.shape[1]) >= label_lengths[:, None]
        
        loss = optax.ctc_loss(logits, audio_time_mask, padded_labels, label_mask).mean()

        return loss
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model=model, grads=grads)

    return loss

In [11]:
padded_audios, frames, padded_labels, label_lengths = processed_train_dataset[43]

In [12]:
#
def compute_mask(frames):
    # MelSpectrogram: hop_length=160, win_length=400, padded=False
    # T_mel = (T_audio - win_length) // hop_length + 1
    # Conv2dSubSampler: two layers of kernel=3, stride=2, padding='VALID'
    # T_out = (T_in - 3) // 2 + 1
    # T_final = (T_out - 3) // 2 + 1
    
    t_mel = (frames - 400) // 160 + 1
    t_conv1 = (t_mel - 3) // 2 + 1
    t_final = (t_conv1 - 3) // 2 + 1
    
    max_frames = 235008
    max_t_mel = (max_frames - 400) // 160 + 1
    max_t_conv1 = (max_t_mel - 3) // 2 + 1
    max_t_final = (max_t_conv1 - 3) // 2 + 1

    real_times = t_final
    
    # Square mask for attention
    mask = jnp.arange(max_t_final) < real_times[:, None]
    mask = jnp.expand_dims(mask, axis=1).repeat(max_t_final, axis=1)
    
    # MultiHeadAttention mask: (batch, num_heads, q_len, k_len)
    mask = jnp.expand_dims(mask, axis=1).repeat(4, axis=1)

    return mask, real_times

In [13]:
mask, real_times = compute_mask(frames)

In [14]:
z = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

In [15]:
avg_loss = 0
for i, element in enumerate(tqdm(processed_train_dataset)):
    padded_audios, frames, padded_labels, label_lengths = element
    mask, real_times = compute_mask(frames)

    loss = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

    avg_loss += loss    
    if (i + 1) % 20 == 0:
        print(f"avg loss: {avg_loss // 20}")
        avg_loss = 0

  1%|          | 21/2444 [00:05<09:28,  4.26it/s]

avg loss: 355.0


  2%|▏         | 41/2444 [00:09<08:55,  4.49it/s]

avg loss: 237.0


  2%|▏         | 61/2444 [00:13<08:42,  4.56it/s]

avg loss: 201.0


  3%|▎         | 81/2444 [00:17<08:33,  4.60it/s]

avg loss: 183.0


  4%|▍         | 101/2444 [00:22<08:03,  4.85it/s]

avg loss: 168.0


  5%|▍         | 121/2444 [00:26<07:57,  4.87it/s]

avg loss: 174.0


  6%|▌         | 141/2444 [00:30<07:52,  4.87it/s]

avg loss: 171.0


  7%|▋         | 161/2444 [00:35<07:53,  4.82it/s]

avg loss: 168.0


  7%|▋         | 181/2444 [00:39<08:31,  4.43it/s]

avg loss: 164.0


  8%|▊         | 201/2444 [00:44<08:27,  4.42it/s]

avg loss: 161.0


  9%|▉         | 221/2444 [00:49<08:47,  4.22it/s]

avg loss: 164.0


 10%|▉         | 241/2444 [00:53<08:28,  4.33it/s]

avg loss: 174.0


 11%|█         | 261/2444 [00:58<08:20,  4.36it/s]

avg loss: 167.0


 11%|█▏        | 281/2444 [01:03<08:02,  4.48it/s]

avg loss: 170.0


 12%|█▏        | 301/2444 [01:08<08:38,  4.14it/s]

avg loss: 164.0


 13%|█▎        | 321/2444 [01:12<07:46,  4.55it/s]

avg loss: 163.0


 14%|█▍        | 341/2444 [01:17<07:40,  4.57it/s]

avg loss: 168.0


 15%|█▍        | 361/2444 [01:21<07:36,  4.56it/s]

avg loss: 165.0


 16%|█▌        | 381/2444 [01:26<07:38,  4.50it/s]

avg loss: 160.0


 16%|█▋        | 401/2444 [01:31<07:36,  4.48it/s]

avg loss: 155.0


 17%|█▋        | 421/2444 [01:35<07:23,  4.56it/s]

avg loss: 158.0


 18%|█▊        | 441/2444 [01:40<07:52,  4.24it/s]

avg loss: 155.0


 19%|█▉        | 461/2444 [01:45<07:19,  4.51it/s]

avg loss: 145.0


 20%|█▉        | 481/2444 [01:49<07:13,  4.53it/s]

avg loss: 136.0


 20%|██        | 501/2444 [01:54<07:09,  4.53it/s]

avg loss: 139.0


 21%|██▏       | 521/2444 [01:58<07:23,  4.34it/s]

avg loss: 130.0


 22%|██▏       | 540/2444 [02:03<08:39,  3.67it/s]

avg loss: 126.0


 23%|██▎       | 561/2444 [02:08<07:05,  4.43it/s]

avg loss: 123.0


 24%|██▍       | 581/2444 [02:13<07:08,  4.35it/s]

avg loss: 118.0


 25%|██▍       | 601/2444 [02:18<07:29,  4.10it/s]

avg loss: 116.0


 25%|██▌       | 621/2444 [02:23<07:16,  4.18it/s]

avg loss: 114.0


 26%|██▌       | 641/2444 [02:28<07:13,  4.16it/s]

avg loss: 113.0


 27%|██▋       | 661/2444 [02:33<07:09,  4.15it/s]

avg loss: 111.0


 28%|██▊       | 681/2444 [02:37<06:57,  4.22it/s]

avg loss: 110.0


 29%|██▊       | 701/2444 [02:42<07:02,  4.12it/s]

avg loss: 106.0


 30%|██▉       | 721/2444 [02:47<06:51,  4.19it/s]

avg loss: 111.0


 30%|███       | 741/2444 [02:52<06:44,  4.21it/s]

avg loss: 103.0


 31%|███       | 761/2444 [02:57<06:46,  4.14it/s]

avg loss: 104.0


 32%|███▏      | 781/2444 [03:02<06:37,  4.18it/s]

avg loss: 99.0


 33%|███▎      | 801/2444 [03:07<06:24,  4.27it/s]

avg loss: 102.0


 34%|███▎      | 821/2444 [03:12<06:00,  4.50it/s]

avg loss: 100.0


 34%|███▍      | 841/2444 [03:17<06:20,  4.21it/s]

avg loss: 99.0


 35%|███▌      | 861/2444 [03:21<06:23,  4.13it/s]

avg loss: 96.0


 36%|███▌      | 881/2444 [03:26<05:51,  4.45it/s]

avg loss: 97.0


 37%|███▋      | 901/2444 [03:31<05:43,  4.49it/s]

avg loss: 99.0


 38%|███▊      | 921/2444 [03:35<05:40,  4.47it/s]

avg loss: 94.0


 39%|███▊      | 941/2444 [03:40<05:40,  4.41it/s]

avg loss: 96.0


 39%|███▉      | 961/2444 [03:45<05:31,  4.48it/s]

avg loss: 93.0


 40%|████      | 981/2444 [03:49<05:27,  4.47it/s]

avg loss: 94.0


 41%|████      | 1001/2444 [03:54<05:54,  4.07it/s]

avg loss: 91.0


 42%|████▏     | 1021/2444 [03:59<05:21,  4.43it/s]

avg loss: 91.0


 43%|████▎     | 1040/2444 [04:04<06:24,  3.65it/s]

avg loss: 89.0


 43%|████▎     | 1061/2444 [04:09<05:19,  4.33it/s]

avg loss: 88.0


 44%|████▍     | 1081/2444 [04:13<05:04,  4.48it/s]

avg loss: 88.0


 45%|████▌     | 1101/2444 [04:18<05:00,  4.47it/s]

avg loss: 87.0


 46%|████▌     | 1121/2444 [04:23<04:59,  4.42it/s]

avg loss: 85.0


 47%|████▋     | 1141/2444 [04:27<04:50,  4.49it/s]

avg loss: 86.0


 48%|████▊     | 1161/2444 [04:32<04:46,  4.48it/s]

avg loss: 87.0


 48%|████▊     | 1181/2444 [04:36<04:47,  4.39it/s]

avg loss: 83.0


 49%|████▉     | 1201/2444 [04:41<04:48,  4.30it/s]

avg loss: 79.0


 50%|████▉     | 1221/2444 [04:46<04:45,  4.29it/s]

avg loss: 82.0


 51%|█████     | 1241/2444 [04:51<04:31,  4.43it/s]

avg loss: 85.0


 52%|█████▏    | 1261/2444 [04:55<04:30,  4.38it/s]

avg loss: 81.0


 52%|█████▏    | 1281/2444 [05:00<04:23,  4.42it/s]

avg loss: 82.0


 53%|█████▎    | 1300/2444 [05:05<05:10,  3.69it/s]

avg loss: 81.0


 54%|█████▍    | 1321/2444 [05:09<04:14,  4.41it/s]

avg loss: 80.0


 55%|█████▍    | 1341/2444 [05:14<04:23,  4.19it/s]

avg loss: 81.0


 56%|█████▌    | 1361/2444 [05:19<04:02,  4.46it/s]

avg loss: 82.0


 57%|█████▋    | 1381/2444 [05:24<03:57,  4.47it/s]

avg loss: 78.0


 57%|█████▋    | 1401/2444 [05:28<03:52,  4.49it/s]

avg loss: 79.0


 58%|█████▊    | 1421/2444 [05:33<04:08,  4.11it/s]

avg loss: 80.0


 59%|█████▉    | 1441/2444 [05:38<03:47,  4.42it/s]

avg loss: 78.0


 60%|█████▉    | 1461/2444 [05:43<03:58,  4.12it/s]

avg loss: 79.0


 61%|██████    | 1481/2444 [05:48<03:51,  4.17it/s]

avg loss: 79.0


 61%|██████▏   | 1501/2444 [05:53<03:40,  4.28it/s]

avg loss: 178.0


 62%|██████▏   | 1521/2444 [05:57<03:34,  4.30it/s]

avg loss: 77.0


 63%|██████▎   | 1541/2444 [06:02<03:34,  4.20it/s]

avg loss: 74.0


 64%|██████▍   | 1561/2444 [06:07<03:17,  4.47it/s]

avg loss: 77.0


 65%|██████▍   | 1581/2444 [06:12<03:12,  4.48it/s]

avg loss: 72.0


 66%|██████▌   | 1601/2444 [06:16<03:08,  4.47it/s]

avg loss: 73.0


 66%|██████▋   | 1621/2444 [06:21<03:09,  4.35it/s]

avg loss: 73.0


 67%|██████▋   | 1641/2444 [06:26<02:58,  4.50it/s]

avg loss: 73.0


 68%|██████▊   | 1661/2444 [06:31<03:09,  4.12it/s]

avg loss: 73.0


 69%|██████▊   | 1680/2444 [06:35<03:32,  3.59it/s]

avg loss: 72.0


 70%|██████▉   | 1701/2444 [06:40<02:57,  4.17it/s]

avg loss: 72.0


 70%|███████   | 1721/2444 [06:45<02:57,  4.08it/s]

avg loss: 71.0


 71%|███████   | 1741/2444 [06:50<02:50,  4.12it/s]

avg loss: 67.0


 72%|███████▏  | 1761/2444 [06:55<02:48,  4.05it/s]

avg loss: 73.0


 73%|███████▎  | 1781/2444 [07:01<02:39,  4.17it/s]

avg loss: 73.0


 74%|███████▎  | 1801/2444 [07:06<02:28,  4.32it/s]

avg loss: 73.0


 75%|███████▍  | 1821/2444 [07:10<02:30,  4.13it/s]

avg loss: 73.0


 75%|███████▌  | 1841/2444 [07:15<02:19,  4.33it/s]

avg loss: 73.0


 76%|███████▌  | 1861/2444 [07:20<02:16,  4.26it/s]

avg loss: 72.0


 77%|███████▋  | 1881/2444 [07:25<02:09,  4.36it/s]

avg loss: 70.0


 78%|███████▊  | 1901/2444 [07:30<02:04,  4.35it/s]

avg loss: 69.0


 79%|███████▊  | 1921/2444 [07:35<02:05,  4.17it/s]

avg loss: 70.0


 79%|███████▉  | 1940/2444 [07:40<02:26,  3.44it/s]

avg loss: 68.0


 80%|████████  | 1961/2444 [07:45<01:51,  4.32it/s]

avg loss: 69.0


 81%|████████  | 1981/2444 [07:49<01:46,  4.34it/s]

avg loss: 70.0


 82%|████████▏ | 2001/2444 [07:54<01:49,  4.06it/s]

avg loss: 69.0


 83%|████████▎ | 2021/2444 [07:59<01:38,  4.28it/s]

avg loss: 69.0


 84%|████████▎ | 2041/2444 [08:04<01:36,  4.17it/s]

avg loss: 73.0


 84%|████████▍ | 2061/2444 [08:09<01:36,  3.98it/s]

avg loss: 70.0


 85%|████████▌ | 2081/2444 [08:14<01:25,  4.25it/s]

avg loss: 70.0


 86%|████████▌ | 2101/2444 [08:19<01:19,  4.31it/s]

avg loss: 69.0


 87%|████████▋ | 2121/2444 [08:24<01:19,  4.06it/s]

avg loss: 66.0


 88%|████████▊ | 2141/2444 [08:29<01:10,  4.33it/s]

avg loss: 66.0


 88%|████████▊ | 2161/2444 [08:34<01:06,  4.24it/s]

avg loss: 65.0


 89%|████████▉ | 2181/2444 [08:38<01:03,  4.11it/s]

avg loss: 66.0


 90%|█████████ | 2201/2444 [08:43<00:56,  4.32it/s]

avg loss: 67.0


 91%|█████████ | 2221/2444 [08:48<00:51,  4.30it/s]

avg loss: 65.0


 92%|█████████▏| 2241/2444 [08:53<00:46,  4.33it/s]

avg loss: 67.0


 93%|█████████▎| 2261/2444 [08:58<00:41,  4.37it/s]

avg loss: 66.0


 93%|█████████▎| 2281/2444 [09:02<00:38,  4.26it/s]

avg loss: 65.0


 94%|█████████▍| 2301/2444 [09:07<00:33,  4.27it/s]

avg loss: 66.0


 95%|█████████▍| 2321/2444 [09:12<00:28,  4.31it/s]

avg loss: 63.0


 96%|█████████▌| 2341/2444 [09:17<00:24,  4.28it/s]

avg loss: 65.0


 97%|█████████▋| 2361/2444 [09:22<00:19,  4.35it/s]

avg loss: 64.0


 97%|█████████▋| 2381/2444 [09:27<00:15,  4.18it/s]

avg loss: 65.0


 98%|█████████▊| 2401/2444 [09:32<00:10,  4.27it/s]

avg loss: 63.0


 99%|█████████▉| 2421/2444 [09:37<00:05,  4.17it/s]

avg loss: 64.0


100%|█████████▉| 2441/2444 [09:41<00:00,  4.18it/s]

avg loss: 67.0


100%|██████████| 2444/2444 [10:09<00:00,  4.01it/s]


In [16]:
padded_audios, frames, padded_labels, label_lengths = processed_train_dataset[12]
mask, real_times = compute_mask(frames)

In [17]:
output = model(padded_audios, mask=mask, training=False)

In [18]:
def decode(ids: list[int]) -> str:
    last_char_id = 0
    decoded_chars = []
    for char_id in ids:
        if char_id != 0 and char_id != last_char_id:
            decoded_chars.append(char_id)
        last_char_id = char_id
    
    return decoded_chars

In [19]:
dds = decode(output[4].argmax(axis=-1).tolist())

In [25]:
tokens = tokenizer.decode(dds)

In [26]:
# tokens = tokenizer.decode(output[10].argmax(axis=-1).tolist())
for tok in tokens:
    print(tok, end='')

მდებარეობსქეყი სახეთ ამოსავლუბშიკიიიიკააი კკკკგიგკკ 

In [22]:
z = tokenizer.decode(padded_labels[4].tolist())

In [23]:
for tok in z:
    if tok != '<BLANK>':
        print(tok, end='')

მდებარეობს ქვეყნის სამხრეთაღმოსავლეთში