<a href="https://colab.research.google.com/github/AmritSDutta/colab_ml/blob/main/model_optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Following TensorFlow model optimization , https://www.tensorflow.org/model_optimization/guide


In [2]:
import tempfile
import os

import tensorflow as tf


In [3]:
!pip install tensorflow_model_optimization



In [4]:
import tensorflow_model_optimization as tfmot


In [5]:

from tensorflow_model_optimization.python.core.keras.compat import keras

In [6]:
# Load MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

In [7]:
# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

In [8]:
# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=4, validation_split=0.1)

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model, keras_file, include_optimizer=False)
print('Saved baseline model to:', keras_file)

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
Baseline test accuracy: 0.9771000146865845
Saved baseline model to: /tmp/tmpm9p7u5wa.h5


  keras.models.save_model(model, keras_file, include_optimizer=False)


In [9]:
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 12)        0         
 D)                                                              
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20410 (79.73 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 0 (0.00 Byte)
____________________

In [12]:
!pip install tf2onnx

Collecting tf2onnx
  Downloading tf2onnx-1.16.1-py3-none-any.whl.metadata (1.3 kB)
Collecting onnx>=1.4.1 (from tf2onnx)
  Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting protobuf~=3.20 (from tf2onnx)
  Downloading protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Downloading tf2onnx-1.16.1-py3-none-any.whl (455 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m455.8/455.8 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m91.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.1/162.1 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf, onnx, tf2onnx
  Attempting uninstall: pr

ONNX quatization

In [17]:
import tf2onnx
import onnx
import numpy as np

spec = (tf.TensorSpec((1, 28, 28, 1), np.float32, name="input"),)
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path="tf_mnist.onnx")

onnx_model = onnx.load("tf_mnist.onnx")
onnx.checker.check_model(onnx_model)

In [12]:
!pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.21.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.21.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m58.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected pack

In [20]:
import onnxruntime as ort

ort_session = ort.InferenceSession("tf_mnist.onnx")
outputs = ort_session.run(None, {"input": np.random.randn(1, 28, 28,1).astype(np.float32)})
print("Output shape:", outputs[0].shape)


Output shape: (1, 10)


In [25]:
def get_gzipped_model_size(file)-> int:
    # Returns size of gzipped model, in bytes.
    import os
    import zipfile

    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(file)

    return os.path.getsize(zipped_file)

In [27]:
print(f'Size of gzipped pruned model without stripping:{get_gzipped_model_size(keras_file)} bytes')
print(f'Size of gzipped pruned model with stripping:{get_gzipped_model_size("tf_mnist.onnx")} bytes')


Size of gzipped pruned model without stripping:78177 bytes
Size of gzipped pruned model with stripping:76907 bytes


In [34]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude


In [22]:
import numpy as np
# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

In [23]:
model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

model_for_pruning.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                        

In [24]:
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit(train_images, train_labels,
                  batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                  callbacks=callbacks)

Epoch 1/2
  5/422 [..............................] - ETA: 12s - loss: 0.0631 - accuracy: 0.9797



Epoch 2/2


<tf_keras.src.callbacks.History at 0x784ee064f410>

In [35]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
   test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

Baseline test accuracy: 0.9678999781608582
Pruned test accuracy: 0.9678999781608582


In [36]:
def get_gzipped_model_size(file)-> int:
    # Returns size of gzipped model, in bytes.
    import os
    import zipfile

    _, zipped_file = tempfile.mkstemp('.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
      f.write(file)

    return os.path.getsize(zipped_file)

strip_pruning reduces the model size, not prior steps


In [37]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

print("final model")
model_for_export.summary()
_, pruned_keras_file = tempfile.mkstemp('.h5')
keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

print("\n")
print(f'Size of gzipped pruned model without stripping:{get_gzipped_model_size(keras_file)} bytes')
print(f'Size of gzipped pruned model with stripping:{get_gzipped_model_size(pruned_keras_file)} bytes')


final model
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 12)        0         
 D)                                                              
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20410 (79.73 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 0 (0.00 Byte)
________

  keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)


Saved pruned Keras model to: /tmp/tmpyjjncs0a.h5


Size of gzipped pruned model without stripping:78177 bytes
Size of gzipped pruned model with stripping:25798 bytes


In [38]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)

print("\n")
print(f'Size of gzipped keras model without stripping:{get_gzipped_model_size(keras_file)} bytes')
print(f'Size of gzipped pruned model with stripping:{get_gzipped_model_size(pruned_keras_file)} bytes')
print(f'Size of gzipped pruned TFlite  model:{get_gzipped_model_size(pruned_tflite_file)} bytes')

Saved pruned TFLite model to: /tmp/tmpf5iaft3t.tflite


Size of gzipped keras model without stripping:78177 bytes
Size of gzipped pruned model with stripping:25798 bytes
Size of gzipped pruned TFlite  model:24759 bytes


In [40]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)

print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)

print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print(f'Size of gzipped pruned model with stripping:{get_gzipped_model_size(pruned_keras_file)} bytes')
print(f'Size of gzipped pruned TFlite  model:{get_gzipped_model_size(pruned_tflite_file)} bytes')
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))

Saved quantized and pruned TFLite model to: /tmp/tmpj24zbqwl.tflite
Size of gzipped baseline Keras model: 78177.00 bytes
Size of gzipped pruned model with stripping:25798 bytes
Size of gzipped pruned TFlite  model:24759 bytes
Size of gzipped pruned and quantized TFlite model: 8473.00 bytes


In [41]:
import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on ever y image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with highest
    # probability.
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy

In [42]:
interpreter = tf.lite.Interpreter(model_content=quantized_and_pruned_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', test_accuracy)
print('Pruned TF test accuracy:', model_for_pruning_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy: 0.9681
Pruned TF test accuracy: 0.9678999781608582


In [43]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms
import torchvision.transforms.functional as F

In [44]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [45]:
model_torch = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # [B, 1, 28, 28] -> [B, 32, 28, 28]
    nn.ReLU(),
    nn.MaxPool2d(2, 2),                                   # -> [B, 32, 14, 14]

    nn.Conv2d(32, 64, kernel_size=3),                     # -> [B, 64, 12, 12]
    nn.ReLU(),
    nn.MaxPool2d(2, 2),                                   # -> [B, 64, 6, 6]

    nn.Flatten(),                                         # -> [B, 64*6*6 = 2304]
    nn.Linear(64 * 6 * 6, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1)
)
model_torch = model_torch.to(device)

In [46]:
# Load data
from torch.utils.data import DataLoader
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)

100%|██████████| 9.91M/9.91M [00:00<00:00, 15.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.41MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.28MB/s]


In [47]:
import torch.fx


traced = torch.fx.symbolic_trace(model_torch)

print(traced.graph)  # Prints the ops and tensors flowing through

graph():
    %input_1 : [num_users=1] = placeholder[target=input]
    %_0 : [num_users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    %_1 : [num_users=1] = call_module[target=1](args = (%_0,), kwargs = {})
    %_2 : [num_users=1] = call_module[target=2](args = (%_1,), kwargs = {})
    %_3 : [num_users=1] = call_module[target=3](args = (%_2,), kwargs = {})
    %_4 : [num_users=1] = call_module[target=4](args = (%_3,), kwargs = {})
    %_5 : [num_users=1] = call_module[target=5](args = (%_4,), kwargs = {})
    %_6 : [num_users=1] = call_module[target=6](args = (%_5,), kwargs = {})
    %_7 : [num_users=1] = call_module[target=7](args = (%_6,), kwargs = {})
    %_8 : [num_users=1] = call_module[target=8](args = (%_7,), kwargs = {})
    %_9 : [num_users=1] = call_module[target=9](args = (%_8,), kwargs = {})
    %_10 : [num_users=1] = call_module[target=10](args = (%_9,), kwargs = {})
    return _10


In [48]:
optimizer = torch.optim.Adam(model_torch.parameters(), lr=0.001)
criterion = nn.NLLLoss()

# Training loop
def train(model_t, loader):
    model_t.train()
    epoch_loss = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model_t(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
    epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch Loss: {avg_loss:.4f}")
for epochs in range(4):
  train(model_torch, train_loader)
torch.save(model_torch.state_dict(), "mnist_torch_model.pth")

Epoch Loss: 0.0001
Epoch Loss: 0.0000
Epoch Loss: 0.0001
Epoch Loss: 0.0000


In [49]:
import os

file_path = "mnist_torch_model.pth"
file_size = os.path.getsize(file_path) / 1e6  # size in MB
print(f"Model size: {file_size:.2f} MB")

Model size: 1.26 MB


In [50]:
# 5. Evaluation
# Recreate model architecture as only weights were saved
t_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1, 1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 3),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 6 * 6, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1)
)

# Load weights
t_model.load_state_dict(torch.load("mnist_torch_model.pth"))
t_model = t_model.to(device)
t_model.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        output = t_model(images)
        pred = output.argmax(dim=1)
        correct += pred.eq(labels).sum().item()

print(f"Test Accuracy: {correct / len(test_loader.dataset):.4f}")

Test Accuracy: 0.9879


In [51]:
import torch.nn.utils.prune as prune

# Apply pruning to specific layers
prune.l1_unstructured(t_model[0], name='weight', amount=0.5)  # Conv2d(1, 32, ...)
prune.l1_unstructured(t_model[7], name='weight', amount=0.5)  # Linear(2304, 128)

# Make pruning permanent (remove hooks)
prune.remove(t_model[0], 'weight')
prune.remove(t_model[7], 'weight')


torch.save(t_model.state_dict(), "mnist_torch_model_pruned.pth")

**prunning does not decrease filesize automatically. it reduces active weights rather**


In [52]:
org_file_path = "mnist_torch_model.pth"
org_file_size = os.path.getsize(org_file_path) / 1e6  # size in MB
print(f"Model size: {org_file_size:.2f} MB")
prunned_file_path = 'mnist_torch_model_pruned.pth'
prunned_file_size = os.path.getsize(prunned_file_path) / 1e6  # size in MB
print(f"Model size: {prunned_file_size:.2f} MB")

Model size: 1.26 MB
Model size: 1.26 MB


In [53]:
quantized_model = torch.quantization.quantize_dynamic(
    t_model, {nn.Linear}, dtype=torch.qint8
)


In [54]:
torch.save(quantized_model.state_dict(), "mnist_pruned_quantized.pth")

# Check size
org_file_path = "mnist_torch_model.pth"
org_file_size = os.path.getsize(org_file_path) / 1e6  # size in MB
print(f"Model size: {org_file_size:.2f} MB")
prunned_file_path = 'mnist_torch_model_pruned.pth'
prunned_file_size = os.path.getsize(prunned_file_path) / 1e6  # size in MB
print(f"Model size: {prunned_file_size:.2f} MB")
size = os.path.getsize("mnist_pruned_quantized.pth") / 1e6
print(f"Pruned + Quantized model size: {size:.2f} MB")

Model size: 1.26 MB
Model size: 1.26 MB
Pruned + Quantized model size: 0.38 MB


In [55]:
model_clean = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, kernel_size=3),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(64 * 6 * 6, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1)
)

# Copy state_dict from pruned model
model_clean.load_state_dict(model_torch.state_dict())
model_clean.eval()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=2304, out_features=128, bias=True)
  (8): ReLU()
  (9): Linear(in_features=128, out_features=10, bias=True)
  (10): LogSoftmax(dim=1)
)

In [56]:
quantized_model = torch.quantization.quantize_dynamic(
    model_clean, {nn.Linear}, dtype=torch.qint8
)

In [57]:
torch.save(quantized_model.state_dict(), "mnist_pruned_quantized.pth")

# Check size
org_file_path = "mnist_torch_model.pth"
org_file_size = os.path.getsize(org_file_path) / 1e6  # size in MB
print(f"Model size: {org_file_size:.2f} MB")
prunned_file_path = 'mnist_torch_model_pruned.pth'
prunned_file_size = os.path.getsize(prunned_file_path) / 1e6  # size in MB
print(f"Model size: {prunned_file_size:.2f} MB")
size = os.path.getsize("mnist_pruned_quantized.pth") / 1e6
print(f"Pruned + Quantized model size: {size:.2f} MB")

Model size: 1.26 MB
Model size: 1.26 MB
Pruned + Quantized model size: 0.38 MB


In [58]:
quantized_model = quantized_model.to('cpu')
quantized_model.eval()
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        #images, labels = images.to(device), labels.to(device)
        output = quantized_model(images)
        pred = output.argmax(dim=1)
        correct += pred.eq(labels).sum().item()

print(f"Quantized model accuracy: {correct / len(test_loader.dataset):.4f}")

Quantized model accuracy: 0.9876


In [63]:
dummy_input = torch.randn(1, 1, 28, 28)

torch.onnx.export(
    model_torch,
    dummy_input,
    "model_torch_mnist_onnx.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=11
)

In [64]:
onnx_model = onnx.load("model_torch_mnist_onnx.onnx")
onnx.checker.check_model(onnx_model)

ort_session = ort.InferenceSession("model_torch_mnist_onnx")
outputs = ort_session.run(None, {"input": np.random.randn(1, 1, 28, 28).astype(np.float32)})
print("Output shape:", outputs[0].shape)

Output shape: (1, 10)
