In [1]:
from tqdm.notebook import tqdm
from conformer.tokenizer import Tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
import grain
from functools import partial
from conformer.conformer_block import ConformerEncoder
from conformer.config import ConformerConfig, TrainingConfig
from flax import nnx
import jax.numpy as jnp
import optax
import jax

In [2]:
# Keep JAX's internal frames in tracebacks: easier to debug jit/tracing errors.
jax.config.update("jax_traceback_filtering", "off")

In [3]:
# Model architecture hyper-parameters and optimisation settings (project defaults).
conformer_config, train_config = ConformerConfig(), TrainingConfig()

In [4]:
# NOTE(review): absolute, machine-specific path — consider moving to a config cell / env var.
TOKENIZER_PATH = '/home/penguin/data/tokenizer/tokenizer.json'
tokenizer = Tokenizer.load(TOKENIZER_PATH)

In [5]:
# NOTE(review): absolute, machine-specific paths — consider moving to a config cell / env var.
TRAIN_RECORDS_PATH = '/home/penguin/data/packed_dataset/train/data.array_record'
TEST_RECORDS_PATH = '/home/penguin/data/packed_dataset/test/data.array_record'

train_audio_source = grain.sources.ArrayRecordDataSource(TRAIN_RECORDS_PATH)
test_audio_source = grain.sources.ArrayRecordDataSource(TEST_RECORDS_PATH)

# Pre-bind the tokenizer so the callable can be handed to grain's `.batch(batch_fn=...)`.
tokenizer_batch_fn = partial(batch_fn, tokenizer=tokenizer)

In [6]:
# Pipeline knobs shared by the train and test splits.
BATCH_SIZE = 24
SHUFFLE_SEED = 42

map_train_audio_dataset = grain.MapDataset.source(train_audio_source)
map_test_audio_dataset = grain.MapDataset.source(test_audio_source)


def build_batched_dataset(map_dataset, *, batch_size, seed):
    """Shuffle, preprocess and batch a grain MapDataset of packed audio records.

    Args:
        map_dataset: a `grain.MapDataset` over raw records.
        batch_size: number of examples per batch.
        seed: seed passed to grain's `.shuffle`.

    Returns:
        A grain MapDataset whose elements are batches built by `tokenizer_batch_fn`.
    """
    return (
        map_dataset
        .shuffle(seed=seed)
        .map(ProcessAudioData(tokenizer))
        .batch(batch_size=batch_size, batch_fn=tokenizer_batch_fn)
    )


processed_train_dataset = build_batched_dataset(
    map_train_audio_dataset, batch_size=BATCH_SIZE, seed=SHUFFLE_SEED
)
# NOTE(review): shuffling the eval split is not required for correctness; it is kept
# so batch composition (and thus per-batch padding) matches the original setup.
processed_test_dataset = build_batched_dataset(
    map_test_audio_dataset, batch_size=BATCH_SIZE, seed=SHUFFLE_SEED
)

In [7]:
# NOTE(review): the number of output classes is hard-coded; presumably it must match
# the tokenizer's vocabulary size (plus any special tokens) — confirm, and derive it
# from `tokenizer` instead of repeating the constant.
NUM_CLASSES = 42
model_rngs = nnx.Rngs(0)  # fixed seed for reproducible parameter init
model = ConformerEncoder(conformer_config, num_classes=NUM_CLASSES, rngs=model_rngs)

In [8]:
# Render the model object for interactive inspection of its module structure.
nnx.display(model)

In [9]:
from conformer.train_utils import (
    create_learning_rate_fn,
    train_step,
    eval_step,
    loss_fn
)

In [10]:
# Warmup schedule parameterised by the encoder width
# (presumably the Transformer "Noam" schedule — confirm in conformer.train_utils).
lr_schedule = create_learning_rate_fn(train_config.warmup_steps, conformer_config.encoder_dim)

# NOTE(review): optax.adamw applies weight decay to every optimised leaf
# (biases / norm scales included) unless a `mask` is given.
adamw_tx = optax.adamw(
    learning_rate=lr_schedule,
    b1=train_config.beta1,
    b2=train_config.beta2,
    weight_decay=train_config.weight_decay,
)

# Optimise only nnx.Param variables of the model.
optimizer = nnx.Optimizer(model, adamw_tx, wrt=nnx.Param)

In [11]:
# Shape check on the first training batch: (batch_size, padded_num_samples).
first_train_batch = processed_train_dataset[0]
first_train_batch['inputs'].shape

(24, 235008)

In [None]:
# Sanity check: evaluate the training loss on one batch without taking an optimiser step.
sanity_batch = processed_train_dataset[0]
loss_fn(model, sanity_batch, training=True)

In [12]:
# Smoke test: run a single train_step on the first batch.
# NOTE(review): this presumably applies a real optimiser update, so the training loop
# below starts from an already-updated model and sees batch 0 again.
smoke_test_batch = processed_train_dataset[0]
loss = train_step(model, optimizer, smoke_test_batch)

W0927 13:55:31.310143  157952 bfc_allocator.cc:310] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.49GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
W0927 13:55:35.296677  157952 bfc_allocator.cc:310] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.21GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
E0927 13:55:38.324361  158175 slow_operation_alarm.cc:73] Trying algorithm eng28{k2=3,k3=0} for conv (f32[24,256,735,64]{3,2,1,0}, u8[0]{0}) custom-call(f32[24,256,737,66]{3,2,1,0}, f32[256,256,3,3]{3,2,1,0}), window={size=3x3}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale"

In [13]:
# Display the loss returned by the smoke-test train_step (a float32 JAX scalar array).
loss

Array(1021.5361, dtype=float32)

In [14]:
import tensorflow as tf

# Hide all GPUs from TensorFlow (presumably imported by clu's metric writers) so it
# does not reserve GPU memory JAX needs. Must run before TF initialises its devices.
tf.config.set_visible_devices([], 'GPU')

E0000 00:00:1758988566.269431  157952 cuda_executor.cc:1309] INTERNAL: CUDA Runtime error: Failed call to cudaGetRuntimeVersion: Error loading CUDA libraries. GPU will not be used.: Error loading CUDA libraries. GPU will not be used.
W0000 00:00:1758988566.275887  157952 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [15]:
from clu import metric_writers

logdir = './metrics'
writer = metric_writers.create_default_writer(logdir)

n_steps_to_save_avg_train_loss = 20  # train loss is averaged over this many steps before logging
n_steps_for_eval = 1000  # full pass over the test split every N training steps


def evaluate(model, eval_dataset):
    """Average the `eval_step` loss over every batch of `eval_dataset`.

    Args:
        model: model passed straight through to `eval_step`.
        eval_dataset: iterable of batches in the format `eval_step` expects.

    Returns:
        Mean of the per-batch losses as a Python float, or nan when
        `eval_dataset` yields no batches.
    """
    total_eval_loss = 0.0
    n_eval_batches = 0
    for eval_batch in tqdm(eval_dataset, desc='eval loop', colour='blue', leave=False):
        # Evaluate on `eval_batch` -- not the outer training `batch` -- otherwise the
        # logged "eval_loss" is just the current train batch's loss.
        total_eval_loss += float(eval_step(model, eval_batch))
        n_eval_batches += 1
    # Divide by batches actually seen rather than len(), and guard an empty split.
    return total_eval_loss / n_eval_batches if n_eval_batches else float('nan')


total_loss_accumulator = 0.0
try:
    for step_count, batch in tqdm(enumerate(processed_train_dataset, 1),
                                   total=len(processed_train_dataset),
                                   desc="training loop",
                                   colour="green"):
        loss = train_step(model, optimizer, batch)
        total_loss_accumulator += loss.item()

        if step_count % n_steps_to_save_avg_train_loss == 0:
            avg_loss = total_loss_accumulator / n_steps_to_save_avg_train_loss
            writer.write_scalars(step_count, {'train_loss': avg_loss})
            total_loss_accumulator = 0.0

        if step_count % n_steps_for_eval == 0:
            avg_eval_loss = evaluate(model, processed_test_dataset)
            writer.write_scalars(step_count, {'eval_loss': avg_eval_loss})
finally:
    # Push buffered scalars to disk even if the loop is interrupted
    # (e.g. KeyboardInterrupt from the notebook).
    writer.flush()

training loop:   0%|          | 0/4615 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

KeyboardInterrupt: 