In [1]:
import torch

In [2]:
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import os
from ipywidgets import FloatProgress

In [3]:
class Args(object):
    def __init__(self,dict):
        for key in dict:
            setattr(self,key, dict[key]) 

### Training & Testing functions

In [47]:
def train(args, model, device, train_loader, optimizer, epoch,max_iter = float('inf')):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx > max_iter: break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
#         print (data.shape)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break

In [48]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
#             print (data.shape)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


### Data sets

In [6]:
args = Args(dict( batch_size =64,
                      test_batch_size= 1000,
                        epochs = 2, 
                       lr =1.0,
                    gamma=0.7,
                    no_cuda = True,
                    no_mps = True,
                    dry_run=False,
                     seed = 1,
                    log_interval =10,
                    save_model =True,
                    model_dir=  './float_model/')
               )
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
dataset1 = datasets.MNIST('../data', train=True, download=True,
                   transform=transform)
dataset2 = datasets.MNIST('../data', train=False,transform=transform)
train_kwargs = {'batch_size': args.batch_size}
test_kwargs = {'batch_size': args.test_batch_size}
 



train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### Models

In [74]:
class Model(nn.Module):
    def __init__(self, quant=False):
        # insert quant in conv2 and fc1
        super(Model, self).__init__()
        self.quant = quant
        self.conv1 = nn.Conv2d(1,32,3,1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32,64,3,1)
        self.bn = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.fc = nn.Linear(9216, 128)
        self.fc_relu = nn.ReLU()
        self.fc2 = nn.Linear(128,10)
        
        if self.quant:
            self.quant_conv = torch.quantization.QuantStub() 

            self.dequant_conv = torch.quantization.DeQuantStub()

            self.quant_fc = torch.quantization.QuantStub() 

            self.dequant_fc = torch.quantization.DeQuantStub()
    def forward(self,x):
       
        x = self.conv1 (x)
        x = self.relu1(x)
        
        # insert quant in conv2
        if  self.quant:
            x = self.quant_conv(x)
        x = self.conv2(x)
        x = self.bn(x)
        x = self.relu2(x)
        if self.quant:
            x = self.dequant_conv(x)
        
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x,1)
       
        # insert quant in fc
        if self.quant:
            x = self.quant_fc(x)
        x  = self.fc(x)
        x = self.fc_relu(x)
        if self.quant:
            x = self.dequant_fc(x)
        
        
        x = self.fc2 (x)
        output = F.log_softmax(x, dim=1)

        return output

### Train Float Model

In [75]:
def float_main():
    # Training settings
    
     
    
    if not os.path.exists(args.model_dir):
         
        os.system('mkdir -p {}'.format(args.model_dir))
    
    
    device = torch.device("cpu")

    

    model =  Model(quant=False).to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), os.path.join(args.model_dir,"mnist_cnn.pt"))
    


In [76]:
float_main()


Test set: Average loss: 0.0509, Accuracy: 9839/10000 (98%)




Test set: Average loss: 0.0369, Accuracy: 9880/10000 (99%)



### Quantization

#### init quant model with float model's params

In [77]:

    
def copy_float2quant(m_float,m_quant):
    m_quant_dict = m_quant.state_dict()
    for name , param in m_float.state_dict().items():
        m_quant_dict[name].copy_(param)

#### Post  training quant

In [91]:
def post_training_quant(m_float,m_quant):
    copy_float2quant(m_float,m_quant)
 
    torch.quantization.fuse_modules(m_quant, ['conv1','relu1'], inplace=True)

    torch.quantization.fuse_modules(m_quant, ['fc','fc_relu'], inplace=True)
 
 

    # """Prepare"""
    qconfig = torch.quantization.get_default_qconfig('qnnpack')
    for module_name, module in m_quant.named_children():
#         print (module_name)
        if module_name not in ['conv1','relu1','fc2']:
       
            module.qconfig = qconfig

    torch.quantization.prepare(m_quant, inplace=True)

    # """Calibrate
    # - This example uses random data for convenience. Use representative (validation) data instead.
    # """
    print ('calibrating...')
    with torch.inference_mode():
        for batch_idx, (data, target) in enumerate(train_loader):
            if batch_idx  > 100: break
            m_quant(data)


    # """Convert"""
    torch.quantization.convert(m_quant, inplace=True,remove_qconfig = False)
#     print (m_quant)
    print ('m_quant on test data:')
    test(m_quant, torch.device("cpu"), test_loader)

In [92]:
m_float = Model(quant=False)
m_float.load_state_dict(torch.load(os.path.join(args.model_dir,'mnist_cnn.pt'),map_location='cpu'))
m_float.eval()
print ('m_float on test data:')
test(m_float, torch.device("cpu"), test_loader)

m_quant = Model(quant=True)
m_quant.eval()
post_training_quant(m_float,m_quant)

m_float on test data:

Test set: Average loss: 0.0369, Accuracy: 9880/10000 (99%)

calibrating...
m_quant on test data:

Test set: Average loss: 0.0369, Accuracy: 9878/10000 (99%)



#### QAT
points to note: https://pytorch.org/blog/quantization-in-practice/
![Screen Shot 2023-02-23 at 18 36 20](https://user-images.githubusercontent.com/20760190/220883386-1e85bf80-284c-4fac-8933-2323fc7e12f1.png)


In [95]:
def qat_main(model):
    # Training settings



    if not os.path.exists(args.model_dir):

        os.system('mkdir -p {}'.format(args.model_dir))


    device = torch.device("cpu")



    model.to(device)
    model.train()
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr*0.01) # qat typicall requires using 1% of original learning rate

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, 2):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), os.path.join(args.model_dir,"mnist_cnn_qat.pt"))


def QAT(m_float,m_quant,init_with_float = True):
    
    if init_with_float:
        print ("init with pre-trained float model")
        copy_float2quant(m_float,m_quant)
    else:
        print ('train qat from scratch...')

#    fuse model
    torch.quantization.fuse_modules(m_quant, ['conv1','relu1'], inplace=True)
#     torch.quantization.fuse_modules(m_quant, ['conv2','bn','relu2'], inplace=True)

    torch.quantization.fuse_modules(m_quant, ['fc','fc_relu'], inplace=True)

 

    # """Prepare"""
    qconfig  = torch.quantization.get_default_qat_qconfig('qnnpack')
    for module_name, module in m_quant.named_children():
        print (module_name)
        if module_name not in ['conv1','relu1','fc2']:
            print (module,'set config')
            module.qconfig = qconfig

    torch.quantization.prepare_qat(m_quant, inplace=True)
    
    qat_main(m_quant)
    

#     # """Convert"""
    torch.quantization.convert(m_quant, inplace=True,remove_qconfig = False)
    
    print ('QAT m_quant on test data:')
    test(m_quant, torch.device("cpu"), test_loader)

#### QAT training
with pre-trained float model initialization

In [96]:
m_float = Model(quant=False)
m_float.load_state_dict(torch.load(os.path.join(args.model_dir,'mnist_cnn.pt'),map_location='cpu'))
m_float.eval()
print ('m_float on test data:')
test(m_float, torch.device("cpu"), test_loader)

m_quant = Model(quant=True)
m_quant.eval()

QAT(m_float,m_quant)


m_float on test data:

Test set: Average loss: 0.0369, Accuracy: 9880/10000 (99%)

init with pre-trained float model
conv1
relu1
conv2
Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)) set config
bn
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) set config
relu2
ReLU() set config
fc
LinearReLU(
  (0): Linear(in_features=9216, out_features=128, bias=True)
  (1): ReLU()
) set config
fc_relu
Identity() set config
fc2
quant_conv
QuantStub() set config
dequant_conv
DeQuantStub() set config
quant_fc
QuantStub() set config
dequant_fc
DeQuantStub() set config

Test set: Average loss: 0.0305, Accuracy: 9904/10000 (99%)

QAT m_quant on test data:

Test set: Average loss: 0.0304, Accuracy: 9904/10000 (99%)



without float model initialization, train from scratch

In [97]:
m_quant = Model(quant=True)
m_quant.eval()

QAT(None,m_quant,False)

train qat from scratch...
conv1
relu1
conv2
Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)) set config
bn
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) set config
relu2
ReLU() set config
fc
LinearReLU(
  (0): Linear(in_features=9216, out_features=128, bias=True)
  (1): ReLU()
) set config
fc_relu
Identity() set config
fc2
quant_conv
QuantStub() set config
dequant_conv
DeQuantStub() set config
quant_fc
QuantStub() set config
dequant_fc
DeQuantStub() set config

Test set: Average loss: 0.1305, Accuracy: 9645/10000 (96%)

QAT m_quant on test data:

Test set: Average loss: 0.1307, Accuracy: 9645/10000 (96%)

