In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'

In [2]:
from tqdm.notebook import tqdm
from conformer.tokenizer import Tokenizer
from conformer.dataset import batch_fn, ProcessAudioData, unpack_speech_data
import grain
from functools import partial
from conformer.conformer_block import ConformerEncoder
from conformer.config import ConformerConfig, TrainingConfig, FeaturizerConfig
from flax import nnx
import jax.numpy as jnp
import optax
import jax

In [3]:
jax.config.update('jax_traceback_filtering', 'off')

In [4]:
conformer_config = ConformerConfig()
featurizer_config = FeaturizerConfig()
train_config = TrainingConfig()

In [5]:
tokenizer = Tokenizer.load('/home/penguin/data/tokenizer/tokenizer.json')

In [6]:
train_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/data/packed_dataset/train/data.array_record')
test_audio_source = grain.sources.ArrayRecordDataSource('/home/penguin/data/packed_dataset/test/data.array_record')
tokenizer_batch_fn = partial(batch_fn, tokenizer=tokenizer)

In [7]:
unpack_speech_data(train_audio_source[0])

({'label': 'აქვე გაწევრიანდა ქართველი ახალგაზრდობის პატრიოტული ორგანიზაცია „თეთრი გიორგის“ რიგებში',
  'frames': 91008},
 b'fLaC\x00\x00\x00"\x10\x00\x10\x00\x00\x01\xb8\x00\x14l\x03\xe8\x00\xf0\x00\x01c\x80;\xd8g\x15\xc64\x89\x91\x9b\xc1\xbb|\xc4\xb2W\xc4\x84\x00\x00( \x00\x00\x00reference libFLAC 1.4.3 20230623\x00\x00\x00\x00\xff\xf8\xc5\x08\x00oN\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xb4\xa0\xc5\xa8\x1e\x1f\x87}\xf0.\x80\xcf\xc7\x8a\x1f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xe1\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe\x1f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xe1\xff\xff\xff\xff\xff\xff\xff\xe4@\xa4\xe7\xff\xff\xfd\x02!\xc2\x81I\xcf\xa1\xa0p\xa1\x12\x10@1\x8a\x95biQE\x89\x94\x95(\x98N\x11\x84\xe2\x92\xa9`\x86\x17%\x95%bp\x98\xb2\xb2i2\xb0\xacP\xc0\x8e\x04d!\x824\xb8\xb4\xa4LT\x94\xb1\x0cTX\x9a#"ngm\xda)\xac\xd2\x7f%5\xa3>\x97V|\xcaD\xd9\xb4\xa6\x90\xd6Mj\x1a+/\xd4\xce\x84\xf0G\xe4_QR\xe4\x12]\x

In [8]:
map_train_audio_dataset = grain.MapDataset.source(train_audio_source)
map_test_audio_dataset = grain.MapDataset.source(test_audio_source)

processed_train_dataset = (
    map_train_audio_dataset
    .map(ProcessAudioData(tokenizer))
    .shuffle(seed=42)
    .repeat(2)
    .batch(batch_size=42, batch_fn=tokenizer_batch_fn)
)

processed_test_dataset = (
    map_test_audio_dataset
    .map(ProcessAudioData(tokenizer))
    .batch(batch_size=24, batch_fn=tokenizer_batch_fn)
)

In [9]:
def create_padding_mask(lengths: jnp.ndarray, max_len: int) -> jnp.ndarray:
    batch_size = lengths.shape[0]
    indices = jnp.arange(max_len).reshape(1, -1)
    mask = indices >= lengths.reshape(-1, 1)
    return mask.astype(jnp.float32)

In [10]:
model = ConformerEncoder(conformer_config, featurizer_config, num_classes=tokenizer.vocab_size,
                          rngs=nnx.Rngs(0))

In [11]:
batch = processed_train_dataset[0]
training = False

In [12]:
log_probs, output_lengths = model(
        batch["inputs"], batch["input_lengths"], training=training
    )

max_logit_len = log_probs.shape[1]
max_label_len = batch["labels"].shape[1]
logit_paddings = create_padding_mask(output_lengths, max_logit_len)

In [13]:
batch['input_lengths'][0]

np.int64(61056)

In [14]:
log_probs.shape

(42, 368, 42)

In [15]:
max_logit_len

368

In [16]:
output_lengths[0]

np.int64(15264)

In [17]:
logit_paddings[0]

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0.

In [18]:
nnx.display(model)

In [19]:
from conformer.train_utils import (
    create_learning_rate_fn,
    train_step,
    eval_step,
    loss_fn
)

In [20]:
from optax import warmup_cosine_decay_schedule
import optax

In [21]:
lr_schedule = warmup_cosine_decay_schedule(
    init_value=0.0,           
    peak_value=1e-3,          
    warmup_steps=len(processed_train_dataset) * 0.1,       
    decay_steps=len(processed_train_dataset) * 0.9,       
    end_value=1e-6            
)

In [22]:
adamw = optax.adamw(
        learning_rate=lr_schedule,
        b1=train_config.beta1,
        b2=train_config.beta2,
        weight_decay=train_config.weight_decay,
    )
# tx = optax.chain(optax.clip_by_global_norm(1.0), adamw)
tx = adamw

In [23]:
optimizer = nnx.Optimizer(
    model,
    tx,
    wrt=nnx.Param
)

In [24]:
graphdef, state = nnx.split((model, optimizer))

In [25]:
# loss_fn(model, processed_train_dataset[0], training=False)

In [26]:
# loss, grads = train_step(model, optimizer, processed_train_dataset[0])
loss, state = train_step(graphdef, state, processed_train_dataset[0])

In [27]:
# iterset = processed_train_dataset.to_iter_dataset()

In [28]:
def compute_gradient_metrics(grads: nnx.State):
    """Extract meaningful metrics from gradient State."""
    metrics = {}
    
    # Global gradient norm (L2 norm of all gradients)
    grad_norm = jnp.sqrt(sum(
        jnp.sum(jnp.square(g)) 
        for g in jax.tree.leaves(grads)
    ))
    metrics['global_grad_norm'] = float(grad_norm)
    
    # # Per-layer statistics
    # layer_norms = {}
    # layer_means = {}
    # layer_maxs = {}
    
    # # Iterate over the flat state correctly
    # flat_grads = grads.flat_state()
    # for i in range(len(flat_grads)):
    #     path = flat_grads._keys[i]
    #     value = flat_grads._values[i]
        
    #     # Extract layer name (customize based on your model structure)
    #     layer_name = '/'.join(path.split('/')[:-1]) if '/' in path else path
        
    #     if layer_name not in layer_norms:
    #         layer_norms[layer_name] = []
    #         layer_means[layer_name] = []
    #         layer_maxs[layer_name] = []
        
    #     # Compute metrics for this parameter
    #     layer_norms[layer_name].append(float(jnp.linalg.norm(value.ravel())))
    #     layer_means[layer_name].append(float(jnp.mean(jnp.abs(value))))
    #     layer_maxs[layer_name].append(float(jnp.max(jnp.abs(value))))
    
    # # Aggregate per-layer metrics
    # for layer_name in layer_norms:
    #     metrics[f'{layer_name}/grad_norm'] = sum(layer_norms[layer_name])
    #     metrics[f'{layer_name}/grad_mean'] = sum(layer_means[layer_name]) / len(layer_means[layer_name])
    #     metrics[f'{layer_name}/grad_max'] = max(layer_maxs[layer_name])
    
    return metrics

In [29]:
eval_loss = eval_step(graphdef, state, processed_test_dataset[0])

In [30]:
import tensorflow as tf
tf.config.experimental.set_visible_devices([], 'GPU')

In [31]:
# with jax.profiler.trace('trace_bf'):
#     for i, idx in enumerate(index_sampler):
#         if i == 3:
#             break
#         loss = train_step(model, optimizer, processed_train_dataset[idx.record_key])

In [32]:
from clu import metric_writers
logdir = './metrics_v17'

writer = metric_writers.create_default_writer(logdir)
total_loss_accumulator = 0
# total_grad_norm_accumulator = 0
lr_rate_accumulator = 0


n_steps_to_save_avg_train_loss = 20
n_steps_for_eval = 500

for step_count, batch in tqdm(enumerate(processed_train_dataset, 1),
                               total=len(processed_train_dataset),
                               desc="training loop",
                               colour="green"):
    # loss, grads = train_step(model, optimizer, batch)
    loss, state = train_step(graphdef, state, batch)
    total_loss_accumulator += loss.item()
    lr_rate_accumulator += lr_schedule(step_count)
    
    # total_grad_norm_accumulator += compute_gradient_metrics(grads)['global_grad_norm']

    if step_count % n_steps_to_save_avg_train_loss == 0:
        avg_loss = total_loss_accumulator / n_steps_to_save_avg_train_loss
        # avg_grad_norm = total_grad_norm_accumulator / n_steps_to_save_avg_train_loss
        avg_lr = lr_rate_accumulator / n_steps_to_save_avg_train_loss

        writer.write_scalars(step_count, {'train_loss': avg_loss})
        # writer.write_scalars(step_count, {'train_grad_norm': avg_grad_norm})
        writer.write_scalars(step_count, {'train_lr': avg_lr})

        
        total_loss_accumulator = 0
        # total_grad_norm_accumulator = 0
        lr_rate_accumulator = 0

    if step_count % n_steps_for_eval == 0:
        total_eval_loss_accumulator = 0
        for eval_batch in tqdm(processed_test_dataset, desc='eval loop', colour='blue', leave=False):
            eval_loss = eval_step(graphdef, state, eval_batch)
            total_eval_loss_accumulator += eval_loss

        avg_eval_loss = total_eval_loss_accumulator / len(processed_test_dataset)
        writer.write_scalars(step_count, {'eval_loss': avg_eval_loss})
        writer.flush()


training loop:   0%|          | 0/5275 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

eval loop:   0%|          | 0/243 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# total_eval_loss_accumulator = 0
# for eval_batch in tqdm(processed_test_dataset, desc='eval loop', colour='blue', leave=False):
#     eval_loss = eval_step(model, eval_batch)
#     total_eval_loss_accumulator += eval_loss

# avg_eval_loss = total_eval_loss_accumulator / len(processed_test_dataset)
# writer.write_scalars(step_count, {'eval_loss': avg_eval_loss})

In [None]:
nnx.update((model, optimizer), state)

In [33]:
test = eval_step(graphdef, state, processed_test_dataset[6])

In [34]:
test

Array(189.10191, dtype=float32)

In [35]:
inputs = processed_test_dataset[0]['inputs']
input_lengths = processed_test_dataset[0]['input_lengths']
labels = processed_test_dataset[0]['labels']

In [36]:
model, _ = nnx.merge(graphdef, state)

In [37]:
transcribed = model(inputs, input_lengths, training=False)

In [38]:
logits = jnp.argmax(transcribed[0], -1)

In [39]:
for i in range(len(transcribed[0])):
    print("MODEL PREDICTED: ", tokenizer.decode(logits[i].tolist()))
    print("TRU LABEL: ", tokenizer.decode(labels[i].tolist()))
    print("=======")


MODEL PREDICTED:  ი
TRU LABEL:  არა
MODEL PREDICTED:  ი
TRU LABEL:  არა
MODEL PREDICTED:  ი
TRU LABEL:  დიახ
MODEL PREDICTED:  ი
TRU LABEL:  ხუთი
MODEL PREDICTED:    ი ი
TRU LABEL:  უფრო ადრე ცნობილი იყო როგორც შოუმენი კომიკოსი რეჟისორი პროდიუსერი და სცენარისტი
MODEL PREDICTED:  ი
TRU LABEL:  რვა
MODEL PREDICTED:  ი
TRU LABEL:  დასაფლავებულია წითელ მოედანზე კრემლის კედელთან
MODEL PREDICTED:  ი
TRU LABEL:  არა
MODEL PREDICTED:  ი
TRU LABEL:  ხუთი
MODEL PREDICTED:  ი
TRU LABEL:  არა
MODEL PREDICTED:  ი
TRU LABEL:  ერთი
MODEL PREDICTED:  ი
TRU LABEL:  ოთხი
MODEL PREDICTED:  ი
TRU LABEL:  არა
MODEL PREDICTED:  ი
TRU LABEL:  დიახ
MODEL PREDICTED:  ი
TRU LABEL:  ხუთი
MODEL PREDICTED:  ი
TRU LABEL:  ოთხი
MODEL PREDICTED:  ი
TRU LABEL:  სამი
MODEL PREDICTED:  ი
TRU LABEL:  ქორწინება არის მონოგამური
MODEL PREDICTED:  ი
TRU LABEL:  ნული
MODEL PREDICTED:  ი
TRU LABEL:  არა
MODEL PREDICTED:  ი
TRU LABEL:  შერაცხულია წმინდანად რომის კათოლიკური ეკლესის მიერ
MODEL PREDICTED:  ი
TRU LABEL:  ორი
MODEL 

In [None]:
tokenizer.decode(labels[3].tolist())

In [None]:
tokenizer.decode(logits[3].tolist())

In [None]:
processed_test_dataset[0]

In [None]:
tokenizer.decode()