In [1]:
import os
import shutil
import torch
import torch.utils.data
import torch.nn as nn 

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T

from torchvision import models
from tqdm import tqdm

# Quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor


In [2]:
DATA_ROOT = "./data"

# if os.path.exists(DATA_ROOT):
#     print(f"# Remove {DATA_ROOT} directory")
#     shutil.rmtree(DATA_ROOT)

if not os.path.exists(DATA_ROOT):
    os.mkdir('./data')

In [3]:
training_data = datasets.FashionMNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=T.ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=T.ToTensor(),
)

In [4]:
BATCH_SIZE = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break


Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


In [5]:
# Creating Models

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using {DEVICE} device")

Using cuda device


In [6]:
class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        # Change code MLP to Convolution Layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
        
    def forward(self, x:torch.Tensor):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.mlp(x)
        return x

# net = MyNetwork()
# rand_input = torch.randn(64, 1, 28, 28)
# rand_output = net(rand_input)
# print(rand_output.shape)


In [7]:
model = MyNetwork().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(DEVICE), y.to(DEVICE)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [8]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, criterion, optimizer)
    test(test_dataloader, model, criterion)
print("Done!")

Epoch 1
-------------------------------
loss: 2.432016  [   64/60000]
loss: 0.325746  [ 6464/60000]
loss: 0.599731  [12864/60000]
loss: 0.465818  [19264/60000]
loss: 0.550881  [25664/60000]
loss: 0.160811  [32064/60000]
loss: 0.348585  [38464/60000]
loss: 0.215534  [44864/60000]
loss: 0.174195  [51264/60000]
loss: 0.218845  [57664/60000]
Test Error: 
 Accuracy: 89.5%, Avg loss: 0.276772 

Epoch 2
-------------------------------
loss: 0.140370  [   64/60000]
loss: 0.310377  [ 6464/60000]
loss: 0.240193  [12864/60000]
loss: 0.273823  [19264/60000]
loss: 0.137869  [25664/60000]
loss: 0.242828  [32064/60000]
loss: 0.331275  [38464/60000]
loss: 0.184776  [44864/60000]
loss: 0.355628  [51264/60000]
loss: 0.160914  [57664/60000]
Test Error: 
 Accuracy: 90.2%, Avg loss: 0.260740 

Epoch 3
-------------------------------
loss: 0.119917  [   64/60000]
loss: 0.217225  [ 6464/60000]
loss: 0.105070  [12864/60000]
loss: 0.197803  [19264/60000]
loss: 0.255179  [25664/60000]
loss: 0.260698  [32064/600

In [9]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
x, y = test_data[0][0].to(DEVICE), test_data[0][1]
with torch.no_grad():
    pred = model(x.unsqueeze(dim=0))
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "Ankle boot", Actual: "Ankle boot"


In [10]:
print(model)

MyNetwork(
  (conv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv3): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (mlp): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=6272, out_features=1000, bias=True)
    (2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Linear(in_features=1000, out_features=128, bias=True)
    (5): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, tra

In [11]:
# save model weights...
model_weights_path = "mynet_weights.pth"
torch.save(model.state_dict(), model_weights_path)

In [12]:
# Adding quantized modules 
from pytorch_quantization import quant_modules
quant_modules.initialize()

quant_desc_input = QuantDescriptor(calib_method='histogram')
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)


In [13]:
new_network = MyNetwork()
print(new_network)

MyNetwork(
  (conv1): Sequential(
    (0): QuantConv2d(
      1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): QuantConv2d(
      64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv3): Sequential(
    (0): QuantConv2d(
      128, 128, k

In [14]:
# load pretrained weights
new_network.load_state_dict(torch.load(model_weights_path))
new_network = new_network.to(DEVICE)

In [15]:
def collect_stats(model, data_loader, num_batches):
    """
        Feed data to the network and collect statistics
    """
    for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module.disable_quant()
                    module.enable_calib()
                else:
                    module.disable()

    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
                
def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                     module.load_calib_amax()
                else:
                     module.load_calib_amax(**kwargs)
            print(F"{name:40}: {module}")
    model.cuda()

with torch.no_grad():
    collect_stats(new_network, test_dataloader, 64)
    compute_amax(new_network, method="percentile", percentile=99.99) 

100%|██████████| 64/64 [00:07<00:00,  8.06it/s]
W0220 15:01:35.131776 139728225096832 tensor_quantizer.py:174] Disable HistogramCalibrator
W0220 15:01:35.132268 139728225096832 tensor_quantizer.py:174] Disable MaxCalibrator
W0220 15:01:35.132450 139728225096832 tensor_quantizer.py:174] Disable HistogramCalibrator
W0220 15:01:35.132585 139728225096832 tensor_quantizer.py:174] Disable MaxCalibrator
W0220 15:01:35.132717 139728225096832 tensor_quantizer.py:174] Disable HistogramCalibrator
W0220 15:01:35.132845 139728225096832 tensor_quantizer.py:174] Disable MaxCalibrator
W0220 15:01:35.132996 139728225096832 tensor_quantizer.py:174] Disable HistogramCalibrator
W0220 15:01:35.133119 139728225096832 tensor_quantizer.py:174] Disable MaxCalibrator
W0220 15:01:35.133255 139728225096832 tensor_quantizer.py:174] Disable HistogramCalibrator
W0220 15:01:35.133371 139728225096832 tensor_quantizer.py:174] Disable MaxCalibrator
W0220 15:01:35.133493 139728225096832 tensor_quantizer.py:174] Disable H

conv1.0._input_quantizer                : TensorQuantizer(8bit fake per-tensor amax=0.9995 calibrator=HistogramCalibrator scale=1.0 quant)
conv1.0._weight_quantizer               : TensorQuantizer(8bit fake axis=0 amax=[0.1704, 0.5058](64) calibrator=MaxCalibrator scale=1.0 quant)
conv2.0._input_quantizer                : TensorQuantizer(8bit fake per-tensor amax=6.9654 calibrator=HistogramCalibrator scale=1.0 quant)
conv2.0._weight_quantizer               : TensorQuantizer(8bit fake axis=0 amax=[0.1598, 0.4379](128) calibrator=MaxCalibrator scale=1.0 quant)
conv3.0._input_quantizer                : TensorQuantizer(8bit fake per-tensor amax=5.3452 calibrator=HistogramCalibrator scale=1.0 quant)
conv3.0._weight_quantizer               : TensorQuantizer(8bit fake axis=0 amax=[0.1721, 0.3495](128) calibrator=MaxCalibrator scale=1.0 quant)
mlp.1._input_quantizer                  : TensorQuantizer(8bit fake per-tensor amax=5.5694 calibrator=HistogramCalibrator scale=1.0 quant)
mlp.1._weight

In [16]:
new_criterion = nn.CrossEntropyLoss()
test(test_dataloader, new_network, new_criterion)

Test Error: 
 Accuracy: 92.7%, Avg loss: 0.224558 



In [17]:
print(new_network)

MyNetwork(
  (conv1): Sequential(
    (0): QuantConv2d(
      1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=0.9995 calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=[0.1704, 0.5058](64) calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): QuantConv2d(
      64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=6.9654 calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=[0.1598, 0.4379](128) calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv3): Sequential(
    (0): Quant

In [18]:
from pytorch_quantization import nn as quant_nn
quant_nn.TensorQuantizer.use_fb_fake_quant = True

In [20]:
dummy_input = torch.randn(1, 1, 28, 28, device='cuda')
input_names = ['actual_input_1']
output_names = ['outptu1']

torch.onnx.export(
    new_network,
    dummy_input,
    'quant_my_net.onnx',
    verbose=True,
)

W0220 15:07:15.062594 139728225096832 tensor_quantizer.py:281] Use Pytorch's native experimental fake quantization.


Exported graph: graph(%inputs.1 : Float(1, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=0, device=cuda:0),
      %conv1.0.weight : Float(64, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=1, device=cuda:0),
      %conv1.0.bias : Float(64, strides=[1], requires_grad=1, device=cuda:0),
      %conv1.1.weight : Float(64, strides=[1], requires_grad=1, device=cuda:0),
      %conv1.1.bias : Float(64, strides=[1], requires_grad=1, device=cuda:0),
      %conv1.1.running_mean : Float(64, strides=[1], requires_grad=0, device=cuda:0),
      %conv1.1.running_var : Float(64, strides=[1], requires_grad=0, device=cuda:0),
      %conv2.0.weight : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cuda:0),
      %conv2.0.bias : Float(128, strides=[1], requires_grad=1, device=cuda:0),
      %conv2.1.weight : Float(128, strides=[1], requires_grad=1, device=cuda:0),
      %conv2.1.bias : Float(128, strides=[1], requires_grad=1, device=cuda:0),
      %conv2.1.running_mean : Float(128,

  if amax.numel() == 1:
  inputs, amax.item() / bound, 0,
  quant_dim = list(amax.shape).index(list(amax_sequeeze.shape)[0])
