In [None]:
"""
This notebook defines a multimodal classification model that combines visual and textual features, and exports it to ONNX format.
It first implements three components: `VisualCNN`, a MobileNetV3-based image feature extractor; `TextBERT`, a BERT-based text
feature extractor with a projection layer; and `SEClassifier`, which merges visual and text embeddings to perform classification.
A trained model is then loaded from disk, set to evaluation mode, and run once on dummy image and text inputs matching the expected
dimensions. Finally, the notebook uses `torch.onnx.export` to convert the model into an ONNX file (`model.onnx`) with dynamic batch
size support for future deployment or interoperability.
"""


In [1]:
import os
import time
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

import onnx

In [2]:
# Visual feature extractor from the original model
class VisualCNN(nn.Module):
    """Image feature extractor built around a MobileNetV3-Small backbone.

    The backbone is fetched through ``torch.hub`` with ``pretrained=True`` and
    its ``classifier`` head is replaced by ``nn.Identity``, so the forward pass
    returns whatever the backbone produces ahead of that head rather than
    class scores.
    """

    def __init__(self):
        super().__init__()
        mobilenet = torch.hub.load(
            'pytorch/vision:v0.10.0',
            'mobilenet_v3_small',
            pretrained=True,
        )
        # Neutralise the classification head so only features come out.
        mobilenet.classifier = nn.Identity()
        self.backbone = mobilenet

    def forward(self, x):
        """Run the image batch ``x`` through the backbone and return its output."""
        features = self.backbone(x)
        return features

In [3]:
# Text feature extractor from the original model
class TextBERT(nn.Module):
    """Text feature extractor: a pretrained encoder plus a linear projection.

    The encoder's ``pooler_output`` is projected from the encoder's hidden
    size down to a 128-dimensional vector.

    NOTE(review): ``AutoModel`` is not imported in any visible cell. It is
    presumably ``transformers.AutoModel`` -- confirm and add the import to the
    import cell; without it, constructing this class raises ``NameError``.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name="prajjwal1/bert-mini"):
        super().__init__()
        encoder = AutoModel.from_pretrained(model_name)
        self.bert = encoder
        # Map the encoder's hidden size onto the fixed 128-d text embedding.
        self.fc = nn.Linear(encoder.config.hidden_size, 128)

    def forward(self, input_ids, attention_mask):
        """Encode the token batch and return the projected pooled output."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = encoded.pooler_output
        return self.fc(pooled)

In [4]:
# Classifier from the original model
class SEClassifier(nn.Module):
    """Two-branch classifier that fuses image and text features.

    The outputs of the visual and text branches are concatenated along
    dimension 1 and passed through a small Linear -> ReLU -> Linear head.

    Args:
        visual_feat_dim: Width of the vector produced by the visual branch.
        text_feat_dim: Width of the vector produced by the text branch.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, visual_feat_dim=576, text_feat_dim=128, num_classes=2):
        super().__init__()
        self.visual_cnn = VisualCNN()
        self.text_bert = TextBERT()
        # The head consumes both feature vectors laid side by side.
        fused_dim = visual_feat_dim + text_feat_dim
        self.fc = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return the head's output for one batch of image/text pairs."""
        visual = self.visual_cnn(image)
        text = self.text_bert(input_ids, attention_mask)
        fused = torch.cat([visual, text], dim=1)
        return self.fc(fused)

In [5]:
# Load the trained model.
# The checkpoint is a whole pickled module rather than a state_dict, which is
# why weights_only=False is required.
# SECURITY: weights_only=False unpickles arbitrary Python objects and can
# execute code embedded in the file -- only load checkpoints you trust.
# map_location="cpu" remaps tensor storages while loading, so a checkpoint that
# was saved from a GPU can still be opened on a machine without CUDA (without
# it, loading fails before the .to("cpu") below is ever reached).
model = torch.load("../models/m33_ep4.pth", map_location="cpu", weights_only=False)
model.eval()  # switch to evaluation mode before running / exporting
model.to("cpu")

  from .autonotebook import tqdm as notebook_tqdm


SEClassifier(
  (visual_cnn): VisualCNN(
    (backbone): MobileNetV3(
      (features): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): InvertedResidual(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
              (activation): ReLU()
          

In [6]:
# Dummy inputs sized to match the test data. (An earlier variant of this cell
# used a 224x224 image and 128-token sequences.)
# Each tensor is placed on the CPU explicitly, right where it is created.
# NOTE(review): the attention mask is created with torch.ones' default
# floating dtype, so the exported graph will presumably expect a float mask --
# confirm this matches what the deployment code feeds in.
dummy_image = torch.randn((1, 3, 960, 540)).to("cpu")  # dummy image tensor
dummy_input_ids = torch.randint(low=0, high=30522, size=(1, 512)).to("cpu")  # random token ids
dummy_attention_mask = torch.ones(1, 512).to("cpu")  # all ones (no padding)


In [7]:
# Sanity-check forward pass on the dummy inputs before exporting.
# Gradients are never used here, so autograd tracking is switched off to avoid
# building (and holding in memory) a graph for this large input.
with torch.no_grad():
    output = model(dummy_image, dummy_input_ids, dummy_attention_mask)

In [8]:
# Export the model to ONNX, then validate the file that was written.
onnx_path = "model.onnx"

torch.onnx.export(
    model,
    (dummy_image, dummy_input_ids, dummy_attention_mask),  # positional args for forward()
    onnx_path,
    input_names=["image", "input_ids", "attention_mask"],
    output_names=["output"],
    # Only the batch dimension (axis 0) is declared dynamic; every other axis
    # is left out of dynamic_axes and so keeps the dummy inputs' size.
    dynamic_axes={
        "image": {0: "batch_size"},
        "input_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)

# Re-load the exported file and run ONNX's structural checker, so the success
# message below is only printed for a well-formed model.
onnx.checker.check_model(onnx.load(onnx_path))

print(f"Model exported to {onnx_path} successfully!")

Model exported to model.onnx successfully!
