In [33]:
import sys
import os
import sys
import os
import torch
import numpy as np
import onnx
import onnxruntime

from third_party.vitfly.models.ViTsubmodules import *
from third_party.vitfly.models.model import *

print("Successfully imported from ViTsubmodules.py and model.py.")

Successfully imported from ViTsubmodules.py and model.py.


In [34]:
# --- Quantization Specific Wrapper ---
class QuantizableLSTMNetVIT(LSTMNetVIT):
    """
    A wrapper class that prepares the LSTMNetVIT model for static quantization.
    It adds Quant/DeQuant stubs and handles the model's control flow for ONNX export.
    """
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, X):
        # Unpack the input tuple and apply the QuantStub to the image tensor
        img, des_vel, quat, hidden_state = X
        img = self.quant(img)

        # Execute the original model's logic, but follow a single, traceable path
        # by assuming the hidden state is always provided.
        embeds = [img]
        for block in self.encoder_blocks:
            embeds.append(block(embeds[-1]))
        
        out = embeds[1:]
        out = torch.cat([self.pxShuffle(out[1]), self.up_sample(out[0])], dim=1)
        out = self.down_sample(out)
        out = self.decoder(out.flatten(1))

        out = torch.cat([out, des_vel / 10, quat], dim=1).float()
        
        # Explicitly use the execution path with a hidden state for ONNX compatibility
        out, h = self.lstm(out, hidden_state)
        
        out = self.nn_fc2(out)

        # Apply the DeQuantStub before returning the final output
        out = self.dequant(out)
        
        return out, h

In [35]:
# --- Helper Functions ---
def get_dummy_input(batch_size=1):
    """Generates a valid dummy input tuple for the model."""
    image = torch.randn(batch_size, 1, 60, 90)
    des_vel = torch.randn(batch_size, 3)
    quat = torch.randn(batch_size, 2)
    # LSTM hidden state: (num_layers, batch_size, hidden_size)
    hidden = (torch.randn(3, batch_size, 128), torch.randn(3, batch_size, 128))
    return (image, des_vel, quat, hidden)

def print_model_size(model, label):
    """Prints the size of a model's state_dict in MB."""
    torch.save(model.state_dict(), "temp.p")
    size_mb = os.path.getsize("temp.p") / 1.e6
    print(f"{label}: {size_mb:.2f} MB")
    os.remove("temp.p")

In [36]:
# --- Step 1: Prepare the Model ---
print("🚀 Step 1: Preparing the model...")
# Instantiate the quantizable wrapper
model_fp32 = QuantizableLSTMNetVIT()

# IMPORTANT: Load your pre-trained weights here
#model_fp32.load_state_dict(torch.load('/Projects/Drone-ViT-HW-Accelerator/weights/lstm_vit.pth'))

# Remove spectral normalization, which interferes with quantization
torch.nn.utils.remove_spectral_norm(model_fp32.decoder)
torch.nn.utils.remove_spectral_norm(model_fp32.nn_fc2)

model_fp32.eval()
print("✅ Model prepared.")
print_model_size(model_fp32, "Original FP32 model size")

🚀 Step 1: Preparing the model...
✅ Model prepared.
Original FP32 model size: 14.28 MB


In [37]:
# --- Step 2: Configure Quantization ---
print("\n⚙️ Step 2: Configuring quantization...")
# Use 'fbgemm' for x86 CPUs. Use 'qnnpack' for ARM.
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Fuse modules for better performance (Conv-BN-ReLU, etc.)
# Note: Your model doesn't have obvious patterns like Conv-BN, but this is best practice.
#model_fp32_fused = torch.quantization.fuse_modules(model_fp32, []) # No modules to fuse in this case

# Insert observers to collect activation statistics
model_fp32_prepared = torch.quantization.prepare(model_fp32)
print("✅ Quantization configured.")


⚙️ Step 2: Configuring quantization...
✅ Quantization configured.


In [38]:
# --- Step 3: Calibrate the Model ---
print("\n🔬 Step 3: Calibrating the model...")
# Pass representative data through the model. For this example, we use dummy data.
with torch.no_grad():
    for _ in range(10):
        dummy_input = get_dummy_input()
        model_fp32_prepared(dummy_input)
print("✅ Calibration complete.")


🔬 Step 3: Calibrating the model...


RuntimeError: shape '[3, 1, 517, 128]' is invalid for input of size 384