In [None]:
#| default_exp quantize.quantize_callback

In [None]:
#| include: false
import warnings
warnings.filterwarnings('ignore')
from nbdev.showdoc import *
from fastai.vision.all import *

In [None]:
#| export
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.vision.all import *
from fastai.callback.all import *

from torch.ao.quantization import get_default_qat_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

In [None]:
#| include: false
# Fetch the PETS dataset through `untar_data`, then collect the image files
# found under its `images` sub-folder.
path = untar_data(URLs.PETS)
images_dir = path/"images"
files = get_image_files(images_dir)

def label_func(f):
    "Return `True` when the first character of `f` is uppercase, `False` otherwise"
    first_char = f[0]
    return first_char.isupper()

# Build the dataloaders: every file is labelled by `label_func` applied to its
# name, and each item goes through `Resize(64)`.
dls = ImageDataLoaders.from_name_func(
    path,
    files,
    label_func,
    item_tfms=Resize(64),
)

In [None]:
#| export
class QuantizeCallback(Callback):
    "Prepare the model for quantization-aware training before `fit` and convert it to a quantized model after `fit`"
    def __init__(self,
        qconfig_mapping=None, # Mapping passed to `prepare_qat_fx`; when falsy, the default QAT mapping for `backend` is used
        backend='x86', # Backend name given to `get_default_qat_qconfig_mapping` to build the default mapping
    ):
        self.qconfig_mapping = qconfig_mapping or get_default_qat_qconfig_mapping(backend)

    def before_fit(self):
        "Replace the learner's model with its QAT-prepared version"
        # `one_batch` grabs a single training batch without spinning up a full iterator.
        batch = self.dls.train.one_batch()
        # Split inputs from targets the same way fastai's `Learner` does, so batches
        # with several inputs (or without a target) are handled too.
        n_inp = getattr(self.dls, 'n_inp', 1 if len(batch) == 1 else len(batch) - 1)
        # `prepare_qat_fx` documents `example_inputs` as a tuple of positional
        # arguments for the model's forward, not a bare tensor.
        example_inputs = tuple(batch[:n_inp])
        # The PyTorch QAT workflow prepares a model that is in training mode; the model
        # may have been left in eval mode by an earlier validation pass.
        float_model = self.learn.model.train()
        self.learn.model = quantize_fx.prepare_qat_fx(float_model, self.qconfig_mapping, example_inputs)

    def after_fit(self):
        "Convert the QAT-prepared model into a quantized model on the CPU"
        # Switch to eval mode and move to the CPU before handing the model to `convert_fx`.
        qat_model = self.learn.model.eval().to('cpu')
        self.learn.model = quantize_fx.convert_fx(qat_model)

In [None]:
# Render the nbdev documentation entry for `QuantizeCallback`.
show_doc(QuantizeCallback)

---

### QuantizeCallback

>      QuantizeCallback (qconfig_mapping=None, backend='x86')

Basic class handling tweaks of the training loop by changing a `Learner` in various events

In [None]:
# Train an unfrozen resnet18 for 5 epochs with `QuantizeCallback` attached to the
# fit; accuracy is reported as the metric.
learn = vision_learner(dls, resnet18, metrics=accuracy)
learn.unfreeze()
quantize_cb = QuantizeCallback()
learn.fit(n_epoch=5, cbs=quantize_cb)

epoch,train_loss,valid_loss,accuracy,time
0,0.571546,0.42405,0.789581,00:06
1,0.469406,0.363851,0.843708,00:06
2,0.407703,0.399239,0.817997,00:06
3,0.375065,0.309377,0.865359,00:06
4,0.323774,0.331475,0.873478,00:06


In [None]:
# Display the learner's model after the fit that ran with `QuantizeCallback`.
learn.model

GraphModule(
  (0): Module(
    (0): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.029317768290638924, zero_point=0, padding=(3, 3))
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.017887497320771217, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.0466480627655983, zero_point=66, padding=(1, 1))
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.017889995127916336, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.07470479607582092, zero_point=66, padding=(1, 1))
      )
    )
    (5): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.01743861660361