In [1]:
import os

import torch
import torch.nn.utils.prune as prune
from torch.quantization import quantize_dynamic, get_default_qconfig, prepare, convert, quantize_fx
from transformers import AutoModelForImageClassification, AutoFeatureExtractor



In [2]:
# 1. Load the pre-trained model and its feature extractor
model_name = './mobilenet_v2_affectnethq-fer2013_model_fixed_labels'
output_dir = "./mobilenet_v2_affectnethq-fer2013_quantized_pruned" #Replace with your output directory

# NOTE(review): the recorded output of this cell warns that classifier.weight /
# classifier.bias were found with 1001 classes in the checkpoint and were
# re-initialised for 7 classes. With ignore_mismatched_sizes=True that happens
# without raising, so the classification head is untrained unless this
# checkpoint path is wrong -- confirm before relying on predictions.
original_model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=7,
    ignore_mismatched_sizes=True,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# 2. Apply dynamic quantization.
# Only Linear layers are requested: dynamic quantization does not convert
# Conv2d (the module tree printed in the next cell still shows plain Conv2d
# layers), so listing Conv2d here had no effect and was misleading.
# `quantize_dynamic` is already imported in the imports cell.
quantized_model = quantize_dynamic(
    original_model,      # float model to quantize
    {torch.nn.Linear},   # layer types to quantize
    dtype=torch.qint8,   # quantized weight dtype
)

Some weights of MobileNetV2ForImageClassification were not initialized from the model checkpoint at ./mobilenet_v2_affectnethq-fer2013_model_fixed_labels and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1001]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1001, 1280]) in the checkpoint and torch.Size([7, 1280]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Collect every (qualified_name, module) pair of the quantized model.
# The first entry is the root module itself, registered under the empty name,
# so printing it shows the whole module tree.
models = list(quantized_model.named_modules())

print(models[0])

('', MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): Mob

In [4]:
# Fraction of each convolution's weight entries to prune; adjust as needed.
pruning_amount = 0.25

# Gather the convolution layers first, then prune each one in turn.
conv_layers = [
    layer for layer in quantized_model.modules()
    if isinstance(layer, torch.nn.Conv2d)
]
for conv in conv_layers:
    # L1 unstructured pruning of the layer's `weight` tensor ...
    prune.l1_unstructured(conv, name='weight', amount=pruning_amount)
    # ... then drop the pruning re-parametrization so `weight` is a plain
    # parameter again.
    prune.remove(conv, "weight")

In [5]:
# torch.save does not create missing parent directories, so make sure the
# output directory exists before the first write; otherwise a fresh
# "Restart & Run All" fails here when the directory is not already present.
os.makedirs(output_dir, exist_ok=True)

# State dict of the quantized + pruned model.
torch.save(quantized_model.state_dict(), f"{output_dir}/quantized_model.pth")

# NOTE(review): this saves `original_model`, not `quantized_model`. The pruning
# loop ran on `quantized_model`, and `quantize_dynamic` was called without
# inplace=True, so presumably these weights are neither quantized nor pruned --
# confirm this is intended (e.g. only to ship config.json next to the .pth).
original_model.save_pretrained(output_dir)
feature_extractor.save_pretrained(output_dir)

['./mobilenet_v2_affectnethq-fer2013_quantized_pruned\\preprocessor_config.json']