In [1]:
import sys
import os

# Add the project root directory to sys.path
project_root = os.path.abspath("..")  # adjust if needed
if project_root not in sys.path:
    sys.path.insert(-1, project_root)
print(f"Project root added to sys.path: {project_root}")
import torch
#from .ViTSubmodules import *

from models.quantized.quant_ready_LSTMNetVIT import QuantReadyLSTMNetViT
from models.quantized.quant_ready_ITAConformerLSTM import QuantReadyITALSTM
from models.ITAConformerLSTM import ITALSTM
from third_party.vitfly.models.model import LSTMNetVIT



Project root added to sys.path: /Users/denizonat/REPOS/neuroTUM/Drone-ViT-HW-Accelerator


In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
def generate_dummy_input(traj_len=10):
    # Simulate a sequence of depth images (T, 1, 60, 90)
    depth_images = torch.randn(traj_len, 1, 60, 90)

    # Simulate desired velocities (T, 1)
    control_input = torch.rand(traj_len, 1)

    # Simulate quaternion orientations (T, 4)
    orientation = torch.tensor([[1.0, 0.0, 0.0, 0.0]] * traj_len).float()

    return [depth_images, control_input, orientation]

# 1 Prepering the pretrained weights
We need to remove the spectral_norm wrappers from the models since they arent supported by the torch QAT API

We start by loading the model 

In [None]:
model = ITALSTM()
model.load_state_dict(torch.load("../models/pretrained_models/ITALSTM.pth", map_location='cpu'))
model.eval()

In [4]:
ita_lstm = ITALSTM()
quant_ita_lstm = QuantReadyITALSTM()

In [5]:
X = generate_dummy_input()
print(f"Input shape: {[x.shape for x in X]}")

Input shape: [torch.Size([10, 1, 60, 90]), torch.Size([10, 1]), torch.Size([10, 4])]


In [6]:
ita_out = ita_lstm(X)
quant_ita_out = quant_ita_lstm(X)

In [7]:
print(f"ita_out shape: {ita_out[0].shape}, quant_ita_out shape: {quant_ita_out[0].shape}")

ita_out shape: torch.Size([10, 3]), quant_ita_out shape: torch.Size([10, 3])


In [8]:
# QAT configuration
quant_ita_lstm.train()
quant_ita_lstm.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
quant_ita_lstm.fuse_model()  
torch.quantization.prepare_qat(quant_ita_lstm, inplace=True)



QuantReadyITALSTM(
  (encoder_blocks): ModuleList(
    (0): QuantReadyITAEncoderLayer(
      (patchMerge): OverlapPatchMerging(
        (cn1): Conv2d(
          1, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3)
          (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
            fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_symmetric, reduce_range=False
            (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
          )
          (activation_post_process): FusedMovingAvgObsFakeQuantize(
            fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=255, qscheme=torch.per_tensor_affine, reduce_range=False
            (activation_post_process): MovingAverageMinMaxObserver(

In [9]:
quant_ita_out = quant_ita_lstm(X)

In [10]:
print(quant_ita_out[0].shape)

torch.Size([10, 3])


Then we need to generate a dummy input for a single forward pass that normalises the weights

In [None]:
X = generate_dummy_input()
print(f"Input shape: {[x.shape for x in X]}")

In [None]:
x = generate_dummy_input()
# Perform a forward pass to normalize the weights
quantized_lstmnetvit.eval()
with torch.no_grad():
    x1 = quantized_lstmnetvit(x)

In [None]:
print(f"Output shape: {x1.shape}")

In [None]:
x = model(x)

In [None]:
print(len(x1))

print(x[1].shape)
print(x[0].shape)

Now we can remove the spectral norm wrappers and save the model

In [None]:
torch.nn.utils.remove_spectral_norm(model.decoder)  
torch.nn.utils.remove_spectral_norm(model.nn_fc2)


torch.save(model.state_dict(), "../models/pretrained_models/checkpoints_for_qat/ITALSTM.pth")

In [None]:
lstmnetvit = LSTMNetVIT()
italstm = ITALSTM()
quantized_lstmnetvit = QuantReadyLSTMNetViT()

In [None]:
lstmnetvit.eval()
with torch.no_grad():
    x1 = lstmnetvit(x)
print(f"Output shape: {x1.shape}")

In [None]:
quantized_lstmnetvit.eval()

In [None]:


pretrained_weights = torch.load("../models/pretrained_models/checkpoints_for_qat/ITALSTM.pth", map_location='cpu')
# Load the pretrained weights
quantized_lstmnetvit.load_state_dict(pretrained_weights, strict=False)
