In [None]:
%load_ext autoreload
%autoreload 2

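## Post-Training Quantization

Quantize the OpenVINO model at `cfg.quant_ovn_model_path` with NNCF, calibrating on batches from the training dataloader, and save the quantized model under `cfg.output_path["quantization"]`.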

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, BackboneFinetuning, EarlyStopping
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import gc
import json
import importlib
from pathlib import Path
import numpy as np
import glob
import timm
from ast import literal_eval
import pandas as pd
import torchaudio as ta

from modules.preprocess import preprocess,prepare_cfg
from modules.dataset import get_train_dataloader
from modules.model import load_model
import modules.inception_next_nano

In [None]:
# Run everything from the repository root: if the working directory has no
# `notebooks/` sub-directory we are presumably inside it, so step one level up.
cur_dir = Path().resolve()
at_repo_root = (cur_dir / "notebooks").exists()
if not at_repo_root:
    os.chdir(os.path.abspath("../"))
print(Path().resolve())

## Config

Set `model_name` (the config module loaded from `configs/`) and the training `stage` for the model to calibrate. Settings for the available models:

#### 2021-2nd CNN Model (seresnext26ts)
```python
model_name = "cnn_v1"
stage = "train_bce"
```

#### 2021-2nd CNN Model (rexnet_150)
```python
model_name = "cnn_v3_rexnet"
stage = "train_bce"
```

#### Simple CNN Model (inception_next_nano)
```python
model_name = "simple_cnn_v1"
stage = "train_bce"
```

In [None]:
# Which config module (`configs/<model_name>.py`) and training stage to calibrate.
model_name = "simple_cnn_v1"
stage = "train_bce"

cfg = importlib.import_module(f'configs.{model_name}').basic_cfg
cfg = prepare_cfg(cfg, stage)
# Calibration batches use the dedicated quantization batch size.
cfg.batch_size = cfg.quant_batch_size
# Number of waveform samples per clip (cfg.SR * cfg.infer_duration).
# Cast to int: it is used as a slice bound in `transform_fn`, and slicing
# with a float (e.g. infer_duration=5.0) raises TypeError.
infer_len = int(cfg.SR * cfg.infer_duration)


In [None]:
# Seed all RNGs (including dataloader workers) so the calibration subset is reproducible.
pl.seed_everything(cfg.seed[stage], workers=True)

df_train, df_valid, df_label_train, df_label_valid, transforms = preprocess(cfg, stage)
# Tag every row with version "2023".
# NOTE(review): presumably read by the dataset code — confirm.
for _split_df in (df_train, df_valid):
    _split_df["version"] = "2023"
len(df_train), len(df_valid)

In [None]:
class Compose:
    """Chain waveform transforms.

    Each transform is called as ``transform(y, sr)``; its output becomes the
    input of the next transform, and the last output is returned.

    Args:
        transforms: sequence of callables taking ``(y, sr)``.
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, y: np.ndarray, sr):
        out = y
        for transform in self.transforms:
            out = transform(out, sr)
        return out


class AudioTransform:
    """Base class for randomly applied waveform transforms.

    Subclasses implement :meth:`apply`. Calling the instance:

    * applies the transform unconditionally when ``always_apply`` is truthy
      (no random number is drawn);
    * otherwise draws one ``np.random.rand()`` and applies the transform only
      when the draw is below ``p``. If it is skipped, ``y`` is returned as-is.
    """

    def __init__(self, always_apply=False, p=0.5):
        self.always_apply = always_apply
        self.p = p

    def __call__(self, y: np.ndarray, sr):
        # `or` short-circuits, so no RNG draw happens when always_apply is set.
        should_apply = self.always_apply or np.random.rand() < self.p
        return self.apply(y, sr=sr) if should_apply else y

    def apply(self, y: np.ndarray, **params):
        """Transform ``y`` (``sr`` arrives in ``params``); subclasses must override."""
        raise NotImplementedError


class Normalize(AudioTransform):
    """Peak-normalize a waveform so its largest absolute value becomes 1.

    With the default ``p=1`` the base class always applies it, because
    ``np.random.rand() < 1`` always holds. A single global peak is computed
    over the whole array, so for a batched input every item shares the same
    scale factor. The result is returned as a Fortran-ordered array.
    """

    def __init__(self, always_apply=False, p=1):
        super().__init__(always_apply, p)

    def apply(self, y: np.ndarray, **params):
        max_vol = np.abs(y).max() if np.size(y) else 0
        if max_vol == 0:
            # Empty or all-silent input: `.max()` would raise on empty arrays and
            # 0/0 would produce NaNs, so return the input unscaled.
            return np.asfortranarray(y)
        return np.asfortranarray(y / max_vol)

In [None]:
# Build the train/valid dataloaders; only `dl_train` feeds calibration below.
pseudo = None  # NOTE(review): presumably "no pseudo-labels" — confirm in get_train_dataloader
dl_train, dl_val, ds_train, ds_val = get_train_dataloader(
    df_train, df_valid, df_label_train, df_label_valid, cfg, pseudo, transforms
)

# Calibration features: "tmt_reshape" consumes peak-normalized raw waveforms;
# every other model reuses the torch model's own melspec / dB transforms on CPU.
if model_name == "tmt_reshape":
    wave_transform = Compose([Normalize(p=1)])
else:
    torch_model = load_model(cfg, stage, train=False).to("cpu")
    melspec_transform = torch_model.melspec_transform
    db_transform = torch_model.db_transform

## Quantization

In [None]:
import nncf
import openvino as ov

# Load the OpenVINO model to be quantized (path from cfg.quant_ovn_model_path).
core = ov.Core()
model = core.read_model(cfg.quant_ovn_model_path)

In [None]:
# prepare calibration dataset

def transform_fn(batch):
    """Turn one training-dataloader batch into the OpenVINO model's input.

    Used by ``nncf.Dataset`` to produce calibration samples. Only the first of
    the three batch elements (the waveform tensor) is used. It is cropped to
    the first ``infer_len`` entries of axis 2, then:

    * every model except ``"tmt_reshape"``:
      - passed through the torch model's ``melspec_transform`` and
        ``db_transform``;
      - rescaled according to ``cfg.normal`` (80 -> ``(x + 80) / 80``,
        255 -> ``x / 255``);
      - returned as a numpy array.
    * ``"tmt_reshape"``:
      - peak-normalized with ``wave_transform``;
      - reduced to channel 0 with every second sample kept.

    Raises:
        NotImplementedError: if ``cfg.normal`` is neither 80 nor 255.
    """
    x, _, _ = batch
    # NOTE(review): an earlier comment described x as (batch, ch, seg, len),
    # but `x[:, 0, ::2]` below treats it as 3-D (batch, ch, len) — confirm the
    # dataloader's layout, since that decides which axis this crop applies to.
    x = x[:, :, :infer_len]

    if model_name == "tmt_reshape":
        x = wave_transform(x.numpy(), sr=cfg.SR)
        # Channel 0, stride-2 decimation along the last axis (no low-pass filter).
        return x[:, 0, ::2]

    if cfg.normal not in (80, 255):
        raise NotImplementedError(
            f"Unsupported cfg.normal={cfg.normal!r}; expected 80 or 255"
        )

    # Pure feature extraction: no autograd graph is needed. Disabling it also
    # guarantees `.numpy()` below cannot fail on a tensor that requires grad.
    with torch.no_grad():
        x = melspec_transform(x)
        x = db_transform(x)

    if cfg.normal == 80:
        x = (x + 80) / 80
    else:  # cfg.normal == 255
        x = x / 255

    return x.numpy()

calibration_dataset = nncf.Dataset(dl_train, transform_fn)

In [None]:
# Op-set workaround for the global-average-pool layer: if the IR contains an op
# whose name includes "/global_pool/AveragePool", rename the config's
# "/global_pool/GlobalAveragePool" ignored-scope entry to match.
# Only the first occurrence in the list is replaced.
_gap_name_cfg = "/global_pool/GlobalAveragePool"
_gap_name_ir = "/global_pool/AveragePool"
model_gap_has_AP = [op.friendly_name for op in model.get_ops() if _gap_name_ir in op.friendly_name]
if model_gap_has_AP and _gap_name_cfg in cfg.quant_ignore_layer_names:
    cfg_gap_idx = cfg.quant_ignore_layer_names.index(_gap_name_cfg)
    cfg.quant_ignore_layer_names[cfg_gap_idx] = _gap_name_ir
    print(cfg.quant_ignore_layer_names)

In [None]:
# Post-training quantization calibrated on `calibration_dataset`; the layers
# named in cfg.quant_ignore_layer_names are excluded from quantization.
ignored_scope = nncf.IgnoredScope(names=cfg.quant_ignore_layer_names)
quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    subset_size=cfg.quant_subset_size,
    ignored_scope=ignored_scope,
    fast_bias_correction=cfg.quant_fast_bias_correction,
)

### Save model

In [None]:
# Write the quantized model to <cfg.output_path["quantization"]>/quant.xml.
quant_dir = cfg.output_path["quantization"]
os.makedirs(quant_dir, exist_ok=True)
save_path = os.path.join(quant_dir, "quant.xml")
ov.save_model(quantized_model, save_path)