In [1]:
import torch
import torch.nn as nn
import os
from DataLoader import prepare_data_loaders
import QATMobileNetV2

[INFO] loading the training and validation dataset...
[INFO] training dataset contains 19998 samples...
[INFO] validation dataset contains 5000 samples...


In [2]:
def load_model(num_classes = 2):
    """Build a QATMobileNetV2.MobileNetV2 classifier placed on the CPU.

    Args:
        num_classes: number of output classes forwarded to the
            ``MobileNetV2`` constructor (previously ignored and always 2).

    Returns:
        The ``QATMobileNetV2.MobileNetV2`` module, moved to the CPU.
    """
    # NOTE(review): no checkpoint is loaded here. Unless MobileNetV2's
    # constructor loads pretrained weights itself, QAT starts from a random
    # initialization — confirm against QATMobileNetV2.
    model = QATMobileNetV2.MobileNetV2(num_classes=num_classes)
    # nn.Module.to() moves parameters in place and returns the module itself.
    return model.to('cpu')

def print_size_of_model(model):
    """Print the on-disk size of ``model``'s state_dict in megabytes (1e6 bytes).

    The state_dict is written with ``torch.save`` to a file named ``temp.p``
    inside a private temporary directory, measured with ``os.path.getsize``,
    and removed together with that directory.

    Args:
        model: any ``torch.nn.Module``.
    """
    import tempfile

    # A private temp dir means an existing ./temp.p in the working directory
    # is never overwritten or deleted, and the file is cleaned up even if
    # torch.save raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        print('Size (MB):', os.path.getsize(path)/1e6)

In [3]:
# Build the two loaders used below. Assumed order is (training, test/validation),
# going by the names used in this notebook — confirm against
# DataLoader.prepare_data_loaders.
(data_loader, data_loader_test) = prepare_data_loaders()

In [4]:
def train_epoch(model, criterion, optimizer, data_loader, device):
    """Run one training epoch and report the mean per-batch loss.

    Switches ``model`` to training mode, then for every ``(image, target)``
    batch moves both tensors to ``device``, computes
    ``criterion(model(image), target)``, back-propagates, and steps
    ``optimizer``.

    Args:
        model: module being trained (put into ``train()`` mode here).
        criterion: loss callable ``criterion(output, target)`` returning a
            scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(image, target)`` batches; it does not need
            to support ``len()``.
        device: device the batches are moved to. The model is assumed to
            already live there.

    Returns:
        The mean of the per-batch loss values, or ``nan`` if ``data_loader``
        yields no batches.
    """
    model.train()

    epoch_loss = 0.0
    num_batches = 0

    for image, target in data_loader:
        image, target = image.to(device), target.to(device)

        # Forward pass and loss
        output = model(image)
        loss = criterion(output, target)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() extracts a Python float so the autograd graph is not kept alive.
        epoch_loss += loss.item()
        num_batches += 1

    # Guard against an empty loader instead of dividing by zero.
    avg_epoch_loss = epoch_loss / num_batches if num_batches else float('nan')
    print(f"Epoch Training Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss
    

In [5]:
def evaluate(model, criterion, data_loader, device):
    """Evaluate ``model`` on ``data_loader``; print and return loss and accuracy.

    Gradients are disabled for the whole pass. A prediction is the index of
    the largest value along dim 1 of the model output.

    Args:
        model: module to evaluate (put into ``eval()`` mode here, mirroring
            ``train_epoch`` which puts it into ``train()`` mode).
        criterion: loss callable ``criterion(output, target)`` returning a
            scalar tensor.
        data_loader: iterable of ``(image, target)`` batches.
        device: device the batches are moved to. The model is assumed to
            already live there.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of samples predicted correctly. Both are ``nan`` when
        ``data_loader`` yields no batches.
    """
    # Previously left to the caller. Layers with mode-dependent behaviour
    # (e.g. dropout, batch norm) must run in inference mode here.
    model.eval()

    epoch_loss = 0.0
    num_batches = 0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for image, target in data_loader:
            image, target = image.to(device), target.to(device)
            output = model(image)

            # Accumulate batch loss
            epoch_loss += criterion(output, target).item()
            num_batches += 1

            # Predicted class index = position of the max value along dim 1.
            predicted = output.argmax(dim=1)
            correct_predictions += (predicted == target).sum().item()
            total_predictions += target.size(0)

    # Guard against an empty loader instead of dividing by zero.
    avg_epoch_loss = epoch_loss / num_batches if num_batches else float('nan')
    print(f"Epoch Test Loss: {avg_epoch_loss:.4f}")

    accuracy = correct_predictions / total_predictions if total_predictions else float('nan')
    print(f"Epoch Accuracy: {accuracy:.4f}")
    return avg_epoch_loss, accuracy
    

In [6]:
# Prepare the model for quantization-aware training (QAT). The order of the
# statements matters: fuse modules first, then attach the qconfig, then let
# prepare_qat() insert fake-quantization/observer modules based on that qconfig.
qat_model = load_model()
# Fuse module patterns for QAT. The printed block below shows Conv+BN+ReLU
# fused into a single ConvBnReLU2d. fuse_model comes from QATMobileNetV2.MobileNetV2.
qat_model.fuse_model(is_qat=True)
# The old 'fbgemm' is still available but 'x86' is the recommended default.
qat_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
# Modifies qat_model in place, adding weight_fake_quant and
# activation_post_process modules (visible in the printed output).
torch.ao.quantization.prepare_qat(qat_model, inplace=True)
# Show one inverted residual block so the inserted fake-quant modules can be inspected.
print('Inverted Residual Block: After preparation for QAT, note fake-quantization modules \n',qat_model.features[1].conv)



Inverted Residual Block: After preparation for QAT, note fake-quantization modules 
 Sequential(
  (0): ConvBNReLU(
    (0): ConvBnReLU2d(
      32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
        (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
      )
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_t

In [7]:
# Loss: cross-entropy averaged over the batch. Optimizer: AdamW over the
# QAT-prepared model's parameters. It must be created after prepare_qat so it
# sees the fused/fake-quant modules' parameters.
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.AdamW(params=qat_model.parameters(), lr=1e-3)

In [8]:

# QAT schedule (same values as before, now named):
#   - train for NUM_EPOCHS epochs;
#   - after epoch index DISABLE_OBSERVER_AFTER_EPOCH, stop updating quantizer observers;
#   - after epoch index FREEZE_BN_AFTER_EPOCH, freeze batch-norm running statistics.
NUM_EPOCHS = 20
DISABLE_OBSERVER_AFTER_EPOCH = 3
FREEZE_BN_AFTER_EPOCH = 2
device = torch.device('cpu')

# NOTE(review): the recorded outputs show the loss settling at ~0.6931 (= ln 2)
# with 0.50 accuracy, i.e. the 2-class model is not learning. Check whether
# load_model() starts from pretrained float weights and whether lr=1e-3 is appropriate.
for nepoch in range(NUM_EPOCHS):
    train_epoch(qat_model, criterion, optimizer, data_loader, device)
    if nepoch > DISABLE_OBSERVER_AFTER_EPOCH:
        # Freeze quantizer parameters
        qat_model.apply(torch.ao.quantization.disable_observer)
    if nepoch > FREEZE_BN_AFTER_EPOCH:
        # Freeze batch norm mean and variance estimates.
        # torch.ao.nn.intrinsic.qat replaces the deprecated torch.nn.intrinsic.qat alias.
        qat_model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)

    # Check the accuracy after each epoch on a converted (quantized) copy.
    # inplace=False leaves qat_model's modules untouched. qat_model.eval() does
    # switch its mode, but train_epoch() calls model.train() at the start of
    # the next epoch.
    quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    evaluate(quantized_model, criterion, data_loader_test, device)




Epoch Training Loss: 0.7080




Epoch Test Loss: 0.6988
Epoch Accuracy: 0.5362
Epoch Training Loss: 0.7015
Epoch Test Loss: 0.7043
Epoch Accuracy: 0.5064
Epoch Training Loss: 0.7071
Epoch Test Loss: 0.7027
Epoch Accuracy: 0.5454
Epoch Training Loss: 0.7152
Epoch Test Loss: 0.6955
Epoch Accuracy: 0.5068
Epoch Training Loss: 0.7046
Epoch Test Loss: 1.1232
Epoch Accuracy: 0.5014
Epoch Training Loss: 0.6939
Epoch Test Loss: 0.6931
Epoch Accuracy: 0.5000


KeyboardInterrupt: 

: 