<a href="https://colab.research.google.com/github/Rumeysakeskin/ASR-Quantization/blob/main/quantization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# If you're using Google Colab and not running locally, run this cell.
## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install text-unidecode
!pip install matplotlib>=3.3.2

## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]
!apt-get update && apt-get install -y libsndfile1 ffmpeg
!pip install Cython tensorflow==2.11.0 Pygments==2.6.1 pynini==2.1.5 nemo_toolkit[all]

In [None]:
import copy
import os

import nemo.collections.asr as nemo_asr
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.callbacks import ModelCheckpoint
from ruamel.yaml import YAML

if not os.path.exists("configs/config.yaml"):
   !wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml

def training_quartznet(epochs=100, config_path="configs/config.yaml", logger=None):
    """Fine-tune the pretrained QuartzNet15x5Base-En model on a Turkish character set.

    Reads the NeMo ASR YAML config, points its train/validation datasets at
    ``data/train_manifest.jsonl`` / ``data/val_manifest.jsonl``, replaces the
    decoder vocabulary with Turkish characters, and trains on CPU.

    Args:
        epochs: Maximum number of training epochs (previously hard-coded to 100).
        config_path: Path to the NeMo ASR YAML config (downloaded by the cell above).
        logger: A PyTorch Lightning logger (e.g. ``WandbLogger``). ``None`` uses
            Lightning's default logger.

    Returns:
        The fine-tuned ``EncDecCTCModel``, so later cells can quantize it.
    """
    yaml = YAML(typ='safe')
    with open(config_path) as f:
        params = yaml.load(f)

    params['model']['train_ds']['manifest_filepath'] = "data/train_manifest.jsonl"
    params['model']['validation_ds']['manifest_filepath'] = "data/val_manifest.jsonl"

    first_asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained("QuartzNet15x5Base-En")

    # Quantization boundary stubs for eager-mode quantization.
    # NOTE(review): these are registered as submodules but nothing here calls them from
    # the model's forward(), so by themselves they do not quantize/dequantize activations.
    first_asr_model.quant = torch.quantization.QuantStub()
    first_asr_model.encoder.quant = torch.quantization.QuantStub()
    first_asr_model.dequant = torch.quantization.DeQuantStub()
    first_asr_model.decoder.dequant = torch.quantization.DeQuantStub()

    first_asr_model.change_vocabulary(
        new_vocabulary=[" ", "a", "b", "c", "ç", "d", "e", "f", "g", "ğ", "h", "ı", "i", "j", "k", "l", "m",
                        "n", "o", "ö", "p", "q", "r", "s", "ş", "t", "u", "ü", "v", "w", "x", "y", "z", "'"])

    # Keep every checkpoint (save_top_k=-1), ranked by validation loss.
    save_path = os.path.join(os.getcwd(), "Quartznet15x5_models")
    checkpoint_callback = ModelCheckpoint(
        dirpath=save_path,
        save_top_k=-1,
        verbose=True,
        monitor='val_loss',
        mode='min',
    )

    # `enable_checkpointing` is a bool flag; the configured ModelCheckpoint must go in
    # `callbacks`, otherwise Lightning silently uses a default one (ignoring save_path).
    # `amp_level` is omitted: it is an apex-only mixed-precision option, not valid on CPU.
    trainer = pl.Trainer(accelerator='cpu',
                         max_epochs=epochs,
                         logger=logger if logger is not None else True,
                         log_every_n_steps=1,
                         val_check_interval=1.0,
                         enable_checkpointing=True,
                         callbacks=[checkpoint_callback])

    # Attach the trainer before data/optimizer setup so NeMo can read trainer settings
    # (e.g. max_epochs) when it builds the dataloaders and the LR schedule.
    first_asr_model.set_trainer(trainer)

    new_opt = copy.deepcopy(params['model']['optim'])
    new_opt['lr'] = 0.001
    # Point to the data we'll use for fine-tuning as the training set
    first_asr_model.setup_training_data(train_data_config=params['model']['train_ds'])
    # Point to the new validation data for fine-tuning
    first_asr_model.setup_validation_data(val_data_config=params['model']['validation_ds'])
    # assign optimizer config
    first_asr_model.setup_optimization(optim_config=DictConfig(new_opt))

    trainer.fit(first_asr_model)
    return first_asr_model


if __name__ == '__main__':
    # Keep the trained model in the notebook namespace; the quantization cell uses it.
    first_asr_model = training_quartznet()

In [None]:
# Post-training static quantization (eager mode) of the fine-tuned model, then save its state_dict.
# NOTE(review): the QuantStub/DeQuantStub modules attached in training_quartznet() are not
# called from the model's forward(), so the converted model presumably cannot run inference
# as-is; the inference cell below dequantizes these weights again before loading them.
first_asr_model.eval()  # PyTorch eager-mode static quantization expects an eval-mode model
first_asr_model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare(first_asr_model, inplace=True)
# TODO(review): no calibration batches are run between prepare() and convert(), so the
# inserted activation observers never record value ranges — confirm this is intended.
torch.quantization.convert(first_asr_model, inplace=True)

model_filepath = os.path.join("/Quartznet15x5_models", "quantized_model.ckpt")
# This absolute directory is not created by the training step (which writes under
# os.getcwd()), so create it here; otherwise torch.save raises FileNotFoundError.
os.makedirs(os.path.dirname(model_filepath), exist_ok=True)
torch.save(first_asr_model.state_dict(), model_filepath)

In [None]:
# Rebuild the fine-tuned architecture, load the saved quantized weights (dequantized back
# to float) and transcribe a test file.
model_to_load = "/Quartznet15x5_models/quantized_model.ckpt"

# The checkpoint comes from QuartzNet15x5Base-En with its vocabulary replaced in
# training_quartznet(), so rebuild exactly that architecture. A model built from
# configs/config.yaml keeps that config's own labels, and its decoder output size would not
# match the checkpoint (size mismatches raise even with strict=False).
first_asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained("QuartzNet15x5Base-En")
first_asr_model.change_vocabulary(
    new_vocabulary=[" ", "a", "b", "c", "ç", "d", "e", "f", "g", "ğ", "h", "ı", "i", "j", "k", "l", "m",
                    "n", "o", "ö", "p", "q", "r", "s", "ş", "t", "u", "ü", "v", "w", "x", "y", "z", "'"])

checkpoint = torch.load(model_to_load, map_location="cpu")

# Strip a 'module.' prefix from encoder keys (as produced by DataParallel-style wrappers).
for key in list(checkpoint.keys()):
    if 'module.encoder.encoder' in key:
        checkpoint[key.replace('module.', '')] = checkpoint.pop(key)
for key in list(checkpoint.keys()):
    if '.qconfig.' in key:
        checkpoint.pop(key)

# Dequantize quantized tensors so they fit the float model. Some state_dict entries are not
# tensors (e.g. a quantized conv without bias stores bias=None), so check the type first.
for key, value in checkpoint.items():
    if isinstance(value, torch.Tensor) and value.is_quantized:
        checkpoint[key] = value.dequantize()

incompatible = first_asr_model.load_state_dict(checkpoint, strict=False)
# strict=False ignores key mismatches silently; surface them so a bad load is noticed.
print(f"missing keys: {len(incompatible.missing_keys)}, "
      f"unexpected keys: {len(incompatible.unexpected_keys)}")
first_asr_model.eval()
audio_filepath = 'test_audio.wav'

print(first_asr_model.transcribe([audio_filepath], batch_size=1)[0])