In [13]:
import copy
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
import pytorch_lightning as pl
from torch import nn

In [14]:
class MPJPELoss(nn.Module):
    @staticmethod
    def forward(pred_joints, y_joints):
        x = torch.sum(pred_joints - y_joints, dim=-1)
        distance_per_image = torch.mean(x.pow(2), dim=1)
        return torch.mean(distance_per_image)

class HandKeypointDetector(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.backbone = MogaNet(
            in_channels=3,
            out_indices=[1, 2, 3],
            depths=[3, 3, 10, 2],
            widths=[32, 64, 96, 192],
            stem_act_type='GELU',
            stem_norm_type='LN',
            moga_ratio=[1, 3, 4],
            moga_dilations=[1, 2, 3],
            drop_path_rate=0.1,
            drop_rate=0.1,
            ffn_scales=[8, 8, 4, 4],
            ffn_act_type='GELU',
            fd_act_type='GELU',
            moga_act_type='SiLU'
        )
        self.head = TransformerFCN(
            in_channels_layers=[[64, 96, 192], [108, 152, 144]],
            fused_channels_layers=[[216, 304, 288], [-1, 368, 480]],
            out_channels_layers=[[108, 152, 144], [108, 184, 240]],
            depths_layers=[[3, 3, 3], [2, 3, 3]],
            mlp_ratio_layers=[[4, 4, 4], [2, 2, 2]],
            transformer_norm_type='LN',
            mlp_drop_rate=0.1,
            mlp_act_type='GELU',
            attn_proj_act_type='ReLU',
            attn_norm_type='LN',
            drop_path_rate=0.1,
            avg_pool_outputs=[2, 4, 6],
            num_joints=21,
            num_classes=0,
        )
        self.init_parameters()
        self.config = config
        self.save_hyperparameters()
        self.criterion = torch.nn.L1Loss()
        self.metric = MPJPELoss()

    def init_parameters(self):
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.00001, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.00001)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)

    def forward(self, images):
        assert not torch.isnan(images).any(), 'Input creates nan'
        img_features = self.backbone(images)
        assert not torch.isnan(img_features[-1]).any(), 'Backbones creates nan'
        outputs = self.head(img_features)
        assert not torch.isnan(outputs).any(), 'Neck or Head creates nan'
        return outputs

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), **self.config['optimizer'])
        scheduler = CosineAnnealingWarmRestarts(optimizer, **self.config['scheduler'])
        return [optimizer], [scheduler]

    def _step(self, batch):
        x, y = batch
        assert not torch.isnan(y).any(), 'Input creates nan'
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        metric = self.metric(outputs, y)
        assert not torch.isnan(loss).any(), 'Loss calculates nan'
        return loss, metric

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        loss, metric = self._step(batch)
        self.log("train_L1", loss, sync_dist=True)
        self.log("train_MPJPE", metric, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        loss, metric = self._step(batch)
        self.log("val_L1", loss, sync_dist=True)
        self.log("val_MPJPE", metric, sync_dist=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        loss, metric = self._step(batch)
        self.log("test_L1", loss, sync_dist=True)
        self.log("test_MPJPE", metric, sync_dist=True)

In [19]:
fp32_model = torch.load('model_last.pt')

NotImplementedError: cannot instantiate 'PosixPath' on your system

In [None]:
# `qconfig` means quantization configuration, it specifies how should we
# observe the activation and weight of an operator
# `qconfig_dict`, specifies the `qconfig` for each operator in the model
# we can specify `qconfig` for certain types of modules
# we can specify `qconfig` for a specific submodule in the model
# we can specify `qconfig` for some functioanl calls in the model
# we can also set `qconfig` to None to skip quantization for some operators
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
# `prepare_fx` inserts observers in the model based on the configuration in `qconfig_dict`
model_prepared = prepare_fx(model, qconfig_dict)

In [None]:
# calibration runs the model with some sample data, which allows observers to record the statistics of
# the activation and weigths of the operators
calibration_data = [torch.randn(1, 3, 224, 224) for _ in range(100)]
for i in range(len(calibration_data)):
   model_prepared(calibration_data[i])
# `convert_fx` converts a calibrated model to a quantized model, this includes inserting
# quantize, dequantize operators to the model and swap floating point operators with quantized operators
model_quantized = convert_fx(copy.deepcopy(model_prepared))

In [None]:
# benchmark
x = torch.randn(1, 3, 224, 224)
%timeit fp32_model(x)
%timeit model_quantized(x)