In [1]:
import torch
import os
import json
import librosa
import glob
import subprocess
import nemo.collections.asr as nemo_asr
from omegaconf import OmegaConf, DictConfig, open_dict
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

import lightning.pytorch as pl
# Correct import
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

# Report whether this PyTorch build can reach a CUDA device and which CUDA
# version it was built against.
print(f"PyTorch CUDA available: {torch.cuda.is_available()} CUDA version: {torch.version.cuda}")

# Directory holding the dataset manifests: "<current working directory>/files".
# (An alternative location, "<cwd>/asr/notebooks/files/", was used previously.)
DATA_DIR = f"{os.getcwd()}/files"
# Export the same path through the process environment as well.
os.environ.update({"DATA_DIR": DATA_DIR})

  from .autonotebook import tqdm as notebook_tqdm


PyTorch CUDA available: True CUDA version: 12.8


In [2]:
# Utterance-duration window applied to the training manifest
# (NeMo interprets max_duration / min_duration in seconds).
# A wider upper bound tried earlier: 73.038375.
maxi, mini = (45, 1.33225)

# Training data-loader configuration.
# NOTE: NeMo's manifest-based ASR dataset builder reads the key `trim_silence`;
# a plain `trim` key is not consumed, so only `trim_silence` is set here.
train_config = DictConfig({
    'manifest_filepath': f'{DATA_DIR}/train_manifest.json',
    'sample_rate': 16000,
    'batch_size': 2,  # Reduced batch size for stability
    'shuffle': True,
    'num_workers': 2,  # Reduced for stability
    'pin_memory': True,
    'trim_silence': True,
    'max_duration': maxi,
    'min_duration': mini,
})

# Validation data-loader configuration (no duration filtering, no shuffling).
# Uses `trim_silence` as well, so validation audio is preprocessed the same way
# as training audio; the previous `trim` key was ignored by the dataset builder,
# which left validation audio untrimmed.
val_config = DictConfig({
    'manifest_filepath': f'{DATA_DIR}/test_manifest.json',
    'sample_rate': 16000,
    'batch_size': 4,
    'shuffle': False,
    'num_workers': 2,
    'pin_memory': True,
    'trim_silence': True,
})

In [3]:
# Restore the pre-trained Parakeet-TDT 0.6B v2 model from a local .nemo archive
# (path is relative to the notebook's working directory).
# Alternatives that were tried and are left disabled:
#   nemo_asr.models.EncDecRNNTBPEModel.restore_from("parakeet-tdt-0.6b-v2.nemo")
#   nemo_asr.models.ASRModel.from_pretrained(model_name="nvidia/parakeet-tdt-0.6b-v2")
model = nemo_asr.models.ASRModel.restore_from(
    restore_path="parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo",
)
# Casting the whole model with model.to(torch.bfloat16) is left disabled here.

[NeMo I 2025-06-05 10:33:00 nemo_logging:393] Tokenizer SentencePieceTokenizer initialized with 1024 tokens


[NeMo W 2025-06-05 10:33:01 nemo_logging:405] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    use_lhotse: true
    skip_missing_manifest_entries: true
    input_cfg: null
    tarred_audio_filepaths: null
    manifest_filepath: null
    sample_rate: 16000
    shuffle: true
    num_workers: 2
    pin_memory: true
    max_duration: 40.0
    min_duration: 0.1
    text_field: answer
    batch_duration: null
    use_bucketing: true
    bucket_duration_bins: null
    bucket_batch_size: null
    num_buckets: 30
    bucket_buffer_size: 20000
    shuffle_buffer_size: 10000
    
[NeMo W 2025-06-05 10:33:01 nemo_logging:405] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config :

[NeMo I 2025-06-05 10:33:01 nemo_logging:393] PADDING: 0
[NeMo I 2025-06-05 10:33:05 nemo_logging:393] Using RNNT Loss : tdt
    Loss tdt_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0, 'durations': [0, 1, 2, 3, 4], 'sigma': 0.02, 'omega': 0.1}
[NeMo I 2025-06-05 10:33:05 nemo_logging:393] Using RNNT Loss : tdt
    Loss tdt_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0, 'durations': [0, 1, 2, 3, 4], 'sigma': 0.02, 'omega': 0.1}
[NeMo I 2025-06-05 10:33:05 nemo_logging:393] Using RNNT Loss : tdt
    Loss tdt_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0, 'durations': [0, 1, 2, 3, 4], 'sigma': 0.02, 'omega': 0.1}
[NeMo I 2025-06-05 10:33:08 nemo_logging:393] Model EncDecRNNTBPEModel was successfully restored from /home/raid/cognition/til/asr/notebooks/parakeet-tdt-0.6b-v2/parakeet-tdt-0.6b-v2.nemo.


EncDecRNNTBPEModel(
  (preprocessor): AudioToMelSpectrogramPreprocessor(
    (featurizer): FilterbankFeatures()
  )
  (encoder): ConformerEncoder(
    (pre_encode): ConvSubsampling(
      (out): Linear(in_features=4096, out_features=1024, bias=True)
      (conv): Sequential(
        (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        (3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (4): ReLU(inplace=True)
        (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        (6): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (7): ReLU(inplace=True)
      )
    )
    (pos_enc): RelPositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-23): 24 x ConformerLayer(
        (norm_feed_forward1): LayerNorm((1024,), eps=1e-05, element

In [4]:
from lightning.pytorch import LightningModule

# Sanity check: show the concrete class of the restored model, then confirm it
# is a LightningModule from the same `lightning.pytorch` namespace as the Trainer.
print(type(model))
print(isinstance(model, LightningModule))

<class 'nemo.collections.asr.models.rnnt_bpe_models.EncDecRNNTBPEModel'>
True


In [5]:
# Attach the data loaders built from the training and validation configs.
model.setup_training_data(train_data_config=train_config)
model.setup_validation_data(val_data_config=val_config)

# A further preparation step was considered but is left disabled:
# model.prepare_for_training()

[NeMo I 2025-06-05 10:33:08 nemo_logging:393] Dataset loaded with 4500 files totalling 31.91 hours
[NeMo I 2025-06-05 10:33:08 nemo_logging:393] 0 files were filtered totalling 0.00 hours
[NeMo I 2025-06-05 10:33:09 nemo_logging:393] Dataset loaded with 8 files totalling 0.06 hours
[NeMo I 2025-06-05 10:33:09 nemo_logging:393] 0 files were filtered totalling 0.00 hours


In [7]:
# Override optimizer / scheduler settings on the restored model's config.
# open_dict temporarily lifts struct mode so keys can be (re)assigned.
with open_dict(model.cfg.optim) as optim_cfg:
    optim_cfg.lr = 1e-4
    # Use an absolute warm-up length; clear the ratio-based setting.
    optim_cfg.sched.warmup_ratio = None
    optim_cfg.sched.warmup_steps = 200

In [8]:
# TensorBoard logger writing under ./tb_logs/parakeet_finetune.
tb_logger = TensorBoardLogger(name="parakeet_finetune", save_dir="./tb_logs")

# Keep only the single best checkpoint, ranked by the lowest validation WER.
checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints",
    filename="best_val_wer",
    monitor="val_wer",
    mode="min",
    save_top_k=1,
)

# Stop training once validation WER has not improved for 5 validation checks.
early_stop_callback = EarlyStopping(monitor="val_wer", patience=5, mode="min")

# Trainer: bf16 mixed precision on GPU index 1, at most 10 epochs, with
# gradients accumulated over 16 batches per optimizer step.
trainer = Trainer(
    accelerator="gpu",
    devices=[1],
    precision="bf16",
    max_epochs=10,
    accumulate_grad_batches=16,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=tb_logger,
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Run fine-tuning; the data loaders were already attached to the model.
trainer.fit(model=model)
# To resume from the most recent checkpoint instead:
# trainer.fit(model, ckpt_path="last")

In [10]:
# Export the in-memory model to a .nemo archive in the working directory.
# NOTE(review): the file name says "e20" while the Trainer is configured with
# max_epochs=10 — confirm the intended name. Presumably this writes the weights
# as they are at the end of training, not the "best_val_wer" checkpoint — verify.
model.save_to(save_path='ft-parakeet-tdt-0.6b-v2-e20.nemo')

In [14]:
# Display the per-module summary table (module name, type, parameter count, mode)
# together with the trainable / total parameter totals.
model.summarize()

  | Name              | Type                              | Params | Mode 
--------------------------------------------------------------------------------
0 | preprocessor      | AudioToMelSpectrogramPreprocessor | 0      | train
1 | encoder           | ConformerEncoder                  | 608 M  | train
2 | decoder           | RNNTDecoder                       | 7.2 M  | train
3 | joint             | RNNTJoint                         | 1.7 M  | train
4 | loss              | RNNTLoss                          | 0      | train
5 | spec_augmentation | SpectrogramAugmentation           | 0      | train
6 | wer               | WER                               | 0      | train
--------------------------------------------------------------------------------
617 M     Trainable params
0         Non-trainable params
617 M     Total params
2,471.304 Total estimated model params size (MB)
706       Modules in train mode
0         Modules in eval mode