In [None]:
import copy
import itertools
import warnings

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

In [None]:
def tabular_fp16(model: nn.Module) -> nn.Module:
    """Return a float16, eval-mode copy of ``model``.

    Conversion is applied to a deep copy, so the module passed in keeps its
    original dtype and training flag.
    """
    half_model = copy.deepcopy(model).eval()
    half_model = half_model.half()
    return half_model

def tabular_int8_dynamic(model: nn.Module) -> nn.Module:
    """Dynamically quantize the Linear and GRU layers of ``model`` to int8.

    A CPU, eval-mode deep copy is quantized, so the caller's module is left
    unchanged. Returns the quantized copy.
    """
    target_layers = {nn.Linear, nn.GRU}
    cpu_copy = copy.deepcopy(model).eval().cpu()
    return torch.ao.quantization.quantize_dynamic(
        cpu_copy,
        qconfig_spec=target_layers,
        dtype=torch.qint8,
    )

def _fuse_best_effort(module: nn.Module, groups) -> None:
    """Fuse ``groups`` of submodules of ``module`` in place, warning instead of failing.

    Fusion is an optional optimization here, so a layout that does not match
    (missing submodule, unsupported pair) only emits a RuntimeWarning. Note that
    ``fuse_modules`` handles groups in order, so earlier groups may already be
    fused when a later one fails.
    """
    try:
        torch.ao.quantization.fuse_modules(module, groups, inplace=True)
    except Exception as exc:  # narrower than bare except: lets KeyboardInterrupt/SystemExit through
        warnings.warn(
            f"Module fusion skipped for {groups}: {exc!r}",
            RuntimeWarning,
            stacklevel=3,
        )


def tabular_int8_static(model: nn.Module, calibration_loader, num_calib_batches: int = 20, backend: str = "fbgemm") -> nn.Module:
    """Eager-mode post-training static int8 quantization of a tabular model.

    A CPU, eval-mode deep copy of ``model`` is (optionally) fused, wrapped in
    ``QuantWrapper`` so it takes and returns float tensors, calibrated, and
    converted. The caller's model is not modified.

    Args:
        model: float ``nn.Module`` to quantize.
        calibration_loader: iterable of batches. A batch may be a tensor, a
            list/tuple (first element is the input) or a dict (first value is
            the input).
        num_calib_batches: maximum number of batches fed through the observers.
        backend: quantized engine name; set process-wide on
            ``torch.backends.quantized.engine`` and used for the default qconfig.

    Returns:
        The converted, quantized module.

    Raises:
        ValueError: if no calibration batch was consumed (empty loader or
            ``num_calib_batches == 0``); converting unobserved observers would
            yield meaningless quantization parameters.
    """
    m = copy.deepcopy(model).eval().cpu()

    # Best-effort fusion for two layouts recognised by attribute name.
    if hasattr(m, "conv") and hasattr(m, "relu"):
        _fuse_best_effort(m, [["conv", "relu"]])
    if hasattr(m, "net") and isinstance(m.net, nn.Sequential):
        # NOTE(review): assumes net[0..3] are (layer, activation) pairs that
        # fuse_modules supports — confirm against the model definitions.
        _fuse_best_effort(m.net, [["0", "1"], ["2", "3"]])

    # QuantWrapper adds QuantStub/DeQuantStub around the whole model.
    m = torch.ao.quantization.QuantWrapper(m)
    torch.backends.quantized.engine = backend  # process-wide side effect
    m.qconfig = torch.ao.quantization.get_default_qconfig(backend)
    prepared = torch.ao.quantization.prepare(m, inplace=False)

    seen = 0
    with torch.no_grad():
        # islice stops after exactly num_calib_batches items (no extra pull).
        for batch in itertools.islice(calibration_loader, num_calib_batches):
            if isinstance(batch, (list, tuple)):
                x = batch[0]
            elif isinstance(batch, dict):
                # NOTE(review): first dict value is assumed to be the input.
                x = next(iter(batch.values()))
            else:
                x = batch
            prepared(x.cpu().float())
            seen += 1
    if seen == 0:
        raise ValueError("calibration_loader yielded no batches to calibrate with")

    return torch.ao.quantization.convert(prepared, inplace=False)

In [None]:
def audio_fp16(model: nn.Module) -> nn.Module:
    """Return a float16, eval-mode copy of ``model``.

    The conversion runs on a deep copy. ``Module.eval()`` and ``Module.half()``
    both act in place, so applying them to ``model`` directly would silently
    turn the caller's model into a float16 eval-mode module. That would also
    affect any later int8 quantization of the same model.
    """
    half_model = copy.deepcopy(model)
    half_model.eval()
    return half_model.half()

def audio_int8_dynamic(model):
    """Dynamically quantize Linear and recurrent layers of ``model`` to int8.

    Layers of type Linear, LSTM, GRU and RNN are targeted; types without a
    dynamic-quantized counterpart in PyTorch's default mapping are left as-is.
    A CPU, eval-mode deep copy is quantized, so the caller's model is not
    modified. Uses the ``torch.ao.quantization`` namespace, like
    ``tabular_int8_dynamic``. ``torch.quantization`` is only a legacy alias.
    """
    cpu_copy = copy.deepcopy(model).eval().cpu()
    return torch.ao.quantization.quantize_dynamic(
        cpu_copy,
        qconfig_spec={nn.Linear, nn.LSTM, nn.GRU, nn.RNN},
        dtype=torch.qint8,
    )

def audio_int8_static(model, calibration_loader, backend="fbgemm", num_calib_batches: int = 8):
    """Post-training static int8 quantization of ``model`` via FX graph mode.

    A CPU, eval-mode deep copy of ``model`` is prepared with ``prepare_fx``. It
    is calibrated on up to ``num_calib_batches`` batches and then converted
    with ``convert_fx``. The caller's model is not modified.

    Args:
        model: float ``nn.Module``; must be symbolically traceable by torch.fx.
        calibration_loader: iterable of batches (iterated once). A batch may be
            a tensor, a list/tuple (first element is the input) or a dict
            (first value is the input). Inputs are moved to CPU and cast to
            float32.
        backend: ``"qnnpack"`` selects QNNPACK; any other value falls back to
            ``"fbgemm"``. The normalized value is used for BOTH the process-wide
            ``torch.backends.quantized.engine`` and the qconfig mapping.
        num_calib_batches: maximum number of batches used for calibration
            (default 8). Must be >= 1.

    Returns:
        The converted, quantized ``GraphModule``.

    Raises:
        ValueError: if ``num_calib_batches < 1`` or the loader yields no batch.
    """
    if num_calib_batches < 1:
        raise ValueError(f"num_calib_batches must be >= 1, got {num_calib_batches}")

    # Normalize once so the engine and the qconfig mapping can never disagree.
    backend = "qnnpack" if backend == "qnnpack" else "fbgemm"
    torch.backends.quantized.engine = backend  # process-wide side effect

    m = copy.deepcopy(model).cpu()
    m.eval()

    def get_x(batch):
        if isinstance(batch, (list, tuple)):
            return batch[0]
        if isinstance(batch, dict):
            # NOTE(review): first dict value is assumed to be the model input.
            return next(iter(batch.values()))
        return batch

    # A single iterator serves both the example input and calibration. This
    # supports one-shot generators and avoids starting a DataLoader twice.
    batches = iter(calibration_loader)
    no_batch = object()
    first_batch = next(batches, no_batch)
    if first_batch is no_batch:
        raise ValueError("calibration_loader yielded no batches")
    x0 = get_x(first_batch).cpu().float()

    prepared = prepare_fx(m, get_default_qconfig_mapping(backend), (x0,))

    with torch.no_grad():
        prepared(x0)
        for batch in itertools.islice(batches, num_calib_batches - 1):
            prepared(get_x(batch).cpu().float())

    return convert_fx(prepared)

In [None]:
def image_fp16(model: nn.Module) -> nn.Module:
    """Return a float16, eval-mode copy of ``model``.

    The conversion runs on a deep copy. ``Module.eval()`` and ``Module.half()``
    both act in place, so applying them to ``model`` directly would silently
    turn the caller's model into a float16 eval-mode module. That would also
    affect any later int8 quantization of the same model.
    """
    half_model = copy.deepcopy(model)
    half_model.eval()
    return half_model.half()

def image_int8_dynamic(model):
    """Dynamically quantize Linear and recurrent layers of ``model`` to int8.

    Layers of type Linear, LSTM, GRU and RNN are targeted; types without a
    dynamic-quantized counterpart in PyTorch's default mapping are left as-is.
    A CPU, eval-mode deep copy is quantized, so the caller's model is not
    modified. Uses the ``torch.ao.quantization`` namespace, like
    ``tabular_int8_dynamic``. ``torch.quantization`` is only a legacy alias.
    """
    cpu_copy = copy.deepcopy(model).eval().cpu()
    return torch.ao.quantization.quantize_dynamic(
        cpu_copy,
        qconfig_spec={nn.Linear, nn.LSTM, nn.GRU, nn.RNN},
        dtype=torch.qint8,
    )

def image_int8_static(model, calibration_loader, backend="fbgemm", num_calib_batches: int = 8):
    """Post-training static int8 quantization of ``model`` via FX graph mode.

    A CPU, eval-mode deep copy of ``model`` is prepared with ``prepare_fx``. It
    is calibrated on up to ``num_calib_batches`` batches and then converted
    with ``convert_fx``. The caller's model is not modified.

    Args:
        model: float ``nn.Module``; must be symbolically traceable by torch.fx.
        calibration_loader: iterable of batches (iterated once). A batch may be
            a tensor, a list/tuple (first element is the input) or a dict
            (first value is the input). Inputs are moved to CPU and cast to
            float32.
        backend: ``"qnnpack"`` selects QNNPACK; any other value falls back to
            ``"fbgemm"``. The normalized value is used for BOTH the process-wide
            ``torch.backends.quantized.engine`` and the qconfig mapping.
        num_calib_batches: maximum number of batches used for calibration
            (default 8). Must be >= 1.

    Returns:
        The converted, quantized ``GraphModule``.

    Raises:
        ValueError: if ``num_calib_batches < 1`` or the loader yields no batch.
    """
    if num_calib_batches < 1:
        raise ValueError(f"num_calib_batches must be >= 1, got {num_calib_batches}")

    # Normalize once so the engine and the qconfig mapping can never disagree.
    backend = "qnnpack" if backend == "qnnpack" else "fbgemm"
    torch.backends.quantized.engine = backend  # process-wide side effect

    m = copy.deepcopy(model).cpu()
    m.eval()

    def get_x(batch):
        if isinstance(batch, (list, tuple)):
            return batch[0]
        if isinstance(batch, dict):
            # NOTE(review): first dict value is assumed to be the model input.
            return next(iter(batch.values()))
        return batch

    # A single iterator serves both the example input and calibration. This
    # supports one-shot generators and avoids starting a DataLoader twice.
    batches = iter(calibration_loader)
    no_batch = object()
    first_batch = next(batches, no_batch)
    if first_batch is no_batch:
        raise ValueError("calibration_loader yielded no batches")
    x0 = get_x(first_batch).cpu().float()

    prepared = prepare_fx(m, get_default_qconfig_mapping(backend), (x0,))

    with torch.no_grad():
        prepared(x0)
        for batch in itertools.islice(batches, num_calib_batches - 1):
            prepared(get_x(batch).cpu().float())

    return convert_fx(prepared)