In [1]:
from tqdm.notebook import tqdm
from conformer.tokenizer import Tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
import grain
from pathlib import Path
from flax import nnx
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
tokenizer = Tokenizer.load_tokenizer(Path('/home/penguin/data/tinyvoice/tokenizer/tokenizer.pkl'))

In [3]:
train_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/data/tinyvoice/data/data.array_record')
# test_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/Data/processed/test.array_record')

In [4]:
map_train_audio_dataset = grain.MapDataset.source(train_audio_source)
# map_test_audio_dataset = grain.MapDataset.source(test_audio_source)

In [5]:
processed_train_dataset = (
    map_train_audio_dataset
    .shuffle(seed=42)
    .map(ProcessAudioData(tokenizer))
    .batch(batch_size=48, batch_fn=batch_fn)
    .repeat(1)
)

# processed_test_dataset = (
#     map_test_audio_dataset
#     .map(ProcessAudioData(tokenizer))
#     .batch(batch_size=8, batch_fn=batch_fn)
# )

In [6]:
from conformer.model import ConformerModel
from tqdm import tqdm

In [7]:
model = ConformerModel(token_count=len(tokenizer.id_to_char))

W1220 16:44:35.024277  107510 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1220 16:44:35.026201  107260 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.


In [8]:
import optax

In [9]:
lr_schedule = optax.linear_schedule(
    init_value=1e-7, 
    end_value=5e-4, 
    transition_steps=300
)

optimizer = nnx.Optimizer(
    model,
    optax.adamw(
        learning_rate=lr_schedule,
        b1=0.9,
        b2=0.98,
        weight_decay=1e-2
    ),
    wrt=nnx.Param
)

In [10]:
@nnx.jit
def jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths):
    def loss_fn(model):
        logits = model(padded_audios, mask=mask, training=True)
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        
        audio_time_mask = jnp.arange(log_probs.shape[1]) >= real_times[:, None]
        label_mask = jnp.arange(padded_labels.shape[1]) >= label_lengths[:, None]
        
        loss = optax.ctc_loss(log_probs, audio_time_mask, padded_labels, label_mask).mean()

        return loss
    
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model=model, grads=grads)

    return loss

In [11]:
padded_audios, frames, padded_labels, label_lengths = processed_train_dataset[43]

In [12]:
def compute_mask(frames):
    # MelSpectrogram: hop_length=160, win_length=400, padded=False
    # T_mel = (T_audio - win_length) // hop_length + 1
    # Conv2dSubSampler: two layers of kernel=3, stride=2, padding='VALID'
    # T_out = (T_in - 3) // 2 + 1
    # T_final = (T_out - 3) // 2 + 1
    
    t_mel = (frames - 400) // 160 + 1
    t_conv1 = (t_mel - 3) // 2 + 1
    t_final = (t_conv1 - 3) // 2 + 1
    
    max_frames = 235008
    max_t_mel = (max_frames - 400) // 160 + 1
    max_t_conv1 = (max_t_mel - 3) // 2 + 1
    max_t_final = (max_t_conv1 - 3) // 2 + 1

    real_times = t_final
    
    # Square mask for attention
    mask = jnp.arange(max_t_final) < real_times[:, None]
    mask = jnp.expand_dims(mask, axis=1).repeat(max_t_final, axis=1)
    
    # MultiHeadAttention mask: (batch, num_heads, q_len, k_len)
    mask = jnp.expand_dims(mask, axis=1).repeat(4, axis=1)

    return mask, real_times

In [13]:
mask, real_times = compute_mask(frames)

In [14]:
z = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

In [39]:
avg_loss = 0
for i, element in enumerate(tqdm(processed_train_dataset)):
    padded_audios, frames, padded_labels, label_lengths = element
    mask, real_times = compute_mask(frames)

    loss = jitted_train(model, optimizer, padded_audios, padded_labels, mask, real_times, label_lengths)

    avg_loss += loss    
    if (i + 1) % 20 == 0:
        print(f"avg loss: {avg_loss // 20}")
        avg_loss = 0

  1%|          | 21/2444 [00:05<10:22,  3.89it/s]

avg loss: 42.0


  2%|▏         | 40/2444 [00:10<10:59,  3.64it/s]

avg loss: 41.0


  2%|▏         | 61/2444 [00:16<09:55,  4.00it/s]

avg loss: 45.0


  3%|▎         | 81/2444 [00:21<09:37,  4.09it/s]

avg loss: 41.0


  4%|▍         | 101/2444 [00:26<09:28,  4.12it/s]

avg loss: 40.0


  5%|▍         | 121/2444 [00:31<09:21,  4.14it/s]

avg loss: 43.0


  6%|▌         | 141/2444 [00:36<09:17,  4.13it/s]

avg loss: 42.0


  7%|▋         | 161/2444 [00:41<09:07,  4.17it/s]

avg loss: 42.0


  7%|▋         | 181/2444 [00:46<09:04,  4.16it/s]

avg loss: 39.0


  8%|▊         | 201/2444 [00:51<09:01,  4.15it/s]

avg loss: 39.0


  9%|▉         | 221/2444 [00:56<08:58,  4.13it/s]

avg loss: 40.0


 10%|▉         | 241/2444 [01:01<08:56,  4.11it/s]

avg loss: 43.0


 11%|█         | 261/2444 [01:06<08:53,  4.09it/s]

avg loss: 41.0


 11%|█▏        | 281/2444 [01:11<08:43,  4.13it/s]

avg loss: 41.0


 12%|█▏        | 301/2444 [01:16<08:38,  4.13it/s]

avg loss: 39.0


 13%|█▎        | 321/2444 [01:21<08:35,  4.12it/s]

avg loss: 39.0


 14%|█▍        | 341/2444 [01:26<08:27,  4.14it/s]

avg loss: 41.0


 15%|█▍        | 361/2444 [01:31<08:34,  4.05it/s]

avg loss: 39.0


 16%|█▌        | 381/2444 [01:36<08:19,  4.13it/s]

avg loss: 39.0


 16%|█▋        | 401/2444 [01:41<08:24,  4.05it/s]

avg loss: 37.0


 17%|█▋        | 421/2444 [01:46<08:23,  4.02it/s]

avg loss: 39.0


 18%|█▊        | 441/2444 [01:51<08:17,  4.03it/s]

avg loss: 40.0


 19%|█▉        | 461/2444 [01:56<08:27,  3.91it/s]

avg loss: 39.0


 20%|█▉        | 481/2444 [02:01<07:57,  4.11it/s]

avg loss: 37.0


 20%|██        | 501/2444 [02:06<07:53,  4.10it/s]

avg loss: 40.0


 21%|██▏       | 521/2444 [02:11<07:47,  4.11it/s]

avg loss: 38.0


 22%|██▏       | 541/2444 [02:16<07:45,  4.09it/s]

avg loss: 39.0


 23%|██▎       | 561/2444 [02:21<07:37,  4.12it/s]

avg loss: 39.0


 24%|██▍       | 581/2444 [02:26<07:30,  4.14it/s]

avg loss: 38.0


 25%|██▍       | 601/2444 [02:31<07:28,  4.11it/s]

avg loss: 39.0


 25%|██▌       | 621/2444 [02:37<07:19,  4.15it/s]

avg loss: 39.0


 26%|██▌       | 641/2444 [02:42<07:16,  4.13it/s]

avg loss: 38.0


 27%|██▋       | 661/2444 [02:47<07:12,  4.13it/s]

avg loss: 39.0


 28%|██▊       | 681/2444 [02:52<07:08,  4.11it/s]

avg loss: 40.0


 29%|██▊       | 701/2444 [02:57<07:07,  4.07it/s]

avg loss: 38.0


 30%|██▉       | 721/2444 [03:02<07:01,  4.09it/s]

avg loss: 39.0


 30%|███       | 741/2444 [03:07<06:58,  4.06it/s]

avg loss: 37.0


 31%|███       | 761/2444 [03:12<06:51,  4.09it/s]

avg loss: 38.0


 32%|███▏      | 781/2444 [03:17<06:45,  4.10it/s]

avg loss: 36.0


 33%|███▎      | 801/2444 [03:22<06:31,  4.19it/s]

avg loss: 37.0


 34%|███▎      | 821/2444 [03:27<06:32,  4.13it/s]

avg loss: 38.0


 34%|███▍      | 841/2444 [03:32<06:34,  4.06it/s]

avg loss: 38.0


 35%|███▌      | 861/2444 [03:37<06:33,  4.02it/s]

avg loss: 36.0


 36%|███▌      | 881/2444 [03:42<06:29,  4.02it/s]

avg loss: 37.0


 37%|███▋      | 901/2444 [03:47<06:21,  4.04it/s]

avg loss: 39.0


 38%|███▊      | 921/2444 [03:52<06:15,  4.06it/s]

avg loss: 36.0


 39%|███▊      | 941/2444 [03:58<06:12,  4.03it/s]

avg loss: 37.0


 39%|███▉      | 961/2444 [04:03<06:11,  3.99it/s]

avg loss: 36.0


 40%|████      | 981/2444 [04:08<05:56,  4.11it/s]

avg loss: 38.0


 41%|████      | 1001/2444 [04:13<05:56,  4.05it/s]

avg loss: 35.0


 42%|████▏     | 1021/2444 [04:18<05:53,  4.02it/s]

avg loss: 36.0


 43%|████▎     | 1040/2444 [04:23<07:01,  3.33it/s]

avg loss: 35.0


 43%|████▎     | 1061/2444 [04:28<05:42,  4.04it/s]

avg loss: 37.0


 44%|████▍     | 1081/2444 [04:34<05:35,  4.06it/s]

avg loss: 36.0


 45%|████▌     | 1101/2444 [04:39<05:29,  4.07it/s]

avg loss: 35.0


 46%|████▌     | 1121/2444 [04:44<05:28,  4.03it/s]

avg loss: 35.0


 47%|████▋     | 1141/2444 [04:49<05:21,  4.05it/s]

avg loss: 35.0


 47%|████▋     | 1160/2444 [04:54<06:25,  3.33it/s]

avg loss: 36.0


 48%|████▊     | 1181/2444 [04:59<05:13,  4.02it/s]

avg loss: 35.0


 49%|████▉     | 1201/2444 [05:04<05:07,  4.05it/s]

avg loss: 33.0


 50%|████▉     | 1221/2444 [05:09<04:59,  4.08it/s]

avg loss: 34.0


 51%|█████     | 1241/2444 [05:14<04:54,  4.09it/s]

avg loss: 35.0


 52%|█████▏    | 1261/2444 [05:19<04:50,  4.07it/s]

avg loss: 33.0


 52%|█████▏    | 1281/2444 [05:25<04:48,  4.03it/s]

avg loss: 34.0


 53%|█████▎    | 1301/2444 [05:30<04:42,  4.05it/s]

avg loss: 34.0


 54%|█████▍    | 1321/2444 [05:35<04:34,  4.09it/s]

avg loss: 34.0


 55%|█████▍    | 1341/2444 [05:40<04:27,  4.12it/s]

avg loss: 34.0


 56%|█████▌    | 1361/2444 [05:45<04:28,  4.03it/s]

avg loss: 36.0


 57%|█████▋    | 1381/2444 [05:50<04:20,  4.08it/s]

avg loss: 34.0


 57%|█████▋    | 1401/2444 [05:55<04:19,  4.02it/s]

avg loss: 34.0


 58%|█████▊    | 1421/2444 [06:00<04:08,  4.11it/s]

avg loss: 34.0


 59%|█████▉    | 1441/2444 [06:05<04:07,  4.05it/s]

avg loss: 33.0


 60%|█████▉    | 1460/2444 [06:10<04:55,  3.33it/s]

avg loss: 35.0


 61%|██████    | 1481/2444 [06:15<03:57,  4.06it/s]

avg loss: 34.0


 61%|██████▏   | 1501/2444 [06:21<03:55,  4.00it/s]

avg loss: 136.0


 62%|██████▏   | 1521/2444 [06:26<03:47,  4.05it/s]

avg loss: 35.0


 63%|██████▎   | 1541/2444 [06:31<03:43,  4.05it/s]

avg loss: 32.0


 64%|██████▍   | 1561/2444 [06:36<03:40,  4.01it/s]

avg loss: 35.0


 65%|██████▍   | 1581/2444 [06:41<03:32,  4.07it/s]

avg loss: 32.0


 66%|██████▌   | 1601/2444 [06:46<03:27,  4.06it/s]

avg loss: 33.0


 66%|██████▋   | 1621/2444 [06:51<03:21,  4.08it/s]

avg loss: 32.0


 67%|██████▋   | 1641/2444 [06:56<03:18,  4.04it/s]

avg loss: 33.0


 68%|██████▊   | 1661/2444 [07:01<03:13,  4.05it/s]

avg loss: 33.0


 69%|██████▉   | 1681/2444 [07:07<03:10,  4.00it/s]

avg loss: 32.0


 70%|██████▉   | 1701/2444 [07:12<03:04,  4.02it/s]

avg loss: 32.0


 70%|███████   | 1721/2444 [07:17<02:57,  4.07it/s]

avg loss: 31.0


 71%|███████   | 1741/2444 [07:22<02:53,  4.04it/s]

avg loss: 30.0


 72%|███████▏  | 1761/2444 [07:27<02:47,  4.07it/s]

avg loss: 33.0


 73%|███████▎  | 1781/2444 [07:32<02:45,  4.01it/s]

avg loss: 33.0


 74%|███████▎  | 1801/2444 [07:37<02:33,  4.19it/s]

avg loss: 34.0


 75%|███████▍  | 1821/2444 [07:42<02:31,  4.13it/s]

avg loss: 33.0


 75%|███████▌  | 1841/2444 [07:47<02:31,  3.98it/s]

avg loss: 34.0


 76%|███████▌  | 1861/2444 [07:52<02:28,  3.91it/s]

avg loss: 33.0


 77%|███████▋  | 1881/2444 [07:58<02:19,  4.05it/s]

avg loss: 32.0


 78%|███████▊  | 1901/2444 [08:03<02:14,  4.04it/s]

avg loss: 31.0


 79%|███████▊  | 1921/2444 [08:08<02:09,  4.05it/s]

avg loss: 32.0


 79%|███████▉  | 1941/2444 [08:13<02:03,  4.08it/s]

avg loss: 31.0


 80%|████████  | 1961/2444 [08:18<01:57,  4.12it/s]

avg loss: 32.0


 81%|████████  | 1981/2444 [08:23<01:52,  4.12it/s]

avg loss: 32.0


 82%|████████▏ | 2001/2444 [08:28<01:50,  4.02it/s]

avg loss: 31.0


 83%|████████▎ | 2021/2444 [08:33<01:39,  4.26it/s]

avg loss: 32.0


 84%|████████▎ | 2041/2444 [08:38<01:38,  4.10it/s]

avg loss: 34.0


 84%|████████▍ | 2061/2444 [08:42<01:28,  4.32it/s]

avg loss: 33.0


 85%|████████▌ | 2081/2444 [08:47<01:19,  4.58it/s]

avg loss: 33.0


 86%|████████▌ | 2101/2444 [08:52<01:15,  4.54it/s]

avg loss: 32.0


 87%|████████▋ | 2121/2444 [08:56<01:11,  4.50it/s]

avg loss: 30.0


 88%|████████▊ | 2141/2444 [09:01<01:07,  4.52it/s]

avg loss: 32.0


 88%|████████▊ | 2161/2444 [09:06<01:07,  4.19it/s]

avg loss: 31.0


 89%|████████▉ | 2181/2444 [09:11<00:59,  4.40it/s]

avg loss: 31.0


 90%|█████████ | 2201/2444 [09:15<00:57,  4.25it/s]

avg loss: 32.0


 91%|█████████ | 2221/2444 [09:20<00:50,  4.38it/s]

avg loss: 31.0


 92%|█████████▏| 2241/2444 [09:25<00:44,  4.54it/s]

avg loss: 32.0


 93%|█████████▎| 2261/2444 [09:29<00:40,  4.55it/s]

avg loss: 31.0


 93%|█████████▎| 2281/2444 [09:34<00:38,  4.25it/s]

avg loss: 30.0


 94%|█████████▍| 2301/2444 [09:39<00:34,  4.20it/s]

avg loss: 31.0


 95%|█████████▍| 2321/2444 [09:44<00:27,  4.43it/s]

avg loss: 30.0


 96%|█████████▌| 2341/2444 [09:48<00:23,  4.39it/s]

avg loss: 31.0


 97%|█████████▋| 2361/2444 [09:53<00:18,  4.44it/s]

avg loss: 31.0


 97%|█████████▋| 2381/2444 [09:58<00:14,  4.47it/s]

avg loss: 32.0


 98%|█████████▊| 2401/2444 [10:02<00:10,  4.29it/s]

avg loss: 30.0


 99%|█████████▉| 2421/2444 [10:07<00:05,  4.21it/s]

avg loss: 30.0


100%|█████████▉| 2441/2444 [10:12<00:00,  4.41it/s]

avg loss: 32.0


100%|██████████| 2444/2444 [10:13<00:00,  3.99it/s]


In [42]:
output = model(padded_audios, mask=mask, training=False)

In [43]:
output.shape

(46, 366, 44)

In [59]:
def decode(ids: list[int]) -> str:
    last_char_id = 0
    decoded_chars = []
    for char_id in ids:
        if char_id != 0 and char_id != last_char_id:
            decoded_chars.append(char_id)
        last_char_id = char_id
    
    return decoded_chars

In [60]:
dds = decode(output[18].argmax(axis=-1).tolist())

In [61]:
tokens = tokenizer.decode(dds)

In [62]:
# tokens = tokenizer.decode(output[10].argmax(axis=-1).tolist())

for tok in tokens:
    if tok != '<BLANK>':
        print(tok, end='')

იგი იყო ყუბაძლო მეომარიდა მხედართ მთავარი

In [67]:
z = tokenizer.decode(padded_labels[18].tolist())

In [68]:
for tok in z:
    if tok != '<BLANK>':
        print(tok, end='')

იგი იყო უბადლო მეომარი და მხედართმთავარი