In [46]:
from pathlib import Path

import torch
import torch.nn as nn
from transformers import ViTModel, BertModel, Wav2Vec2Model

# Pick the Apple-silicon GPU backend (MPS) when PyTorch reports it as
# available; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [47]:
class UniConNet(nn.Module):
    """Gated fusion head on top of pretrained image, text and audio encoders.

    Each modality is encoded to a single vector, a softmax gate computed from
    the concatenation of the three vectors assigns one weight per modality,
    and the weighted sum of the vectors is passed through a linear output
    layer.

    Args:
        num_outputs: Width of the final linear layer (defaults to 3).
        image_model_name: Checkpoint passed to ``ViTModel.from_pretrained``.
        text_model_name: Checkpoint passed to ``BertModel.from_pretrained``.
        audio_model_name: Checkpoint passed to ``Wav2Vec2Model.from_pretrained``.

    Raises:
        ValueError: If the three encoders do not share the same hidden size;
            the gated sum adds their features element-wise.
    """

    def __init__(
        self,
        num_outputs=3,
        image_model_name="google/vit-base-patch16-224-in21k",
        text_model_name="bert-base-uncased",
        audio_model_name="facebook/wav2vec2-base-960h",
    ):
        super().__init__()
        self.image_encoder = ViTModel.from_pretrained(image_model_name)
        self.text_encoder = BertModel.from_pretrained(text_model_name)
        self.audio_encoder = Wav2Vec2Model.from_pretrained(audio_model_name)

        # Read the feature width from the loaded encoders instead of
        # hard-coding 768, and make sure the weighted sum is well defined.
        hidden_sizes = (
            self.image_encoder.config.hidden_size,
            self.text_encoder.config.hidden_size,
            self.audio_encoder.config.hidden_size,
        )
        if len(set(hidden_sizes)) != 1:
            raise ValueError(
                "All encoders must share one hidden size for gated fusion, "
                f"got image/text/audio = {hidden_sizes}"
            )
        hidden_size = hidden_sizes[0]

        # One gate logit per modality, computed from all three feature vectors.
        self.gate = nn.Linear(hidden_size * 3, 3)
        self.fc = nn.Linear(hidden_size, num_outputs)

    def forward(self, image, text_input_ids, text_attention_mask, audio_input_values):
        """Fuse the three modalities and return the output-layer activations.

        Args:
            image: Forwarded to the image encoder as ``pixel_values``.
            text_input_ids: Forwarded to the text encoder as ``input_ids``.
            text_attention_mask: Forwarded to the text encoder as
                ``attention_mask``.
            audio_input_values: Forwarded to the audio encoder as
                ``input_values``.

        Returns:
            Tensor with one row per batch element and ``num_outputs`` columns.
        """
        # ViT and BERT: position 0 of the sequence is the [CLS] summary token.
        img_features = self.image_encoder(pixel_values=image).last_hidden_state[:, 0]
        text_features = self.text_encoder(
            input_ids=text_input_ids,
            attention_mask=text_attention_mask,
        ).last_hidden_state[:, 0]

        # Wav2Vec2 has no classification token: index 0 would be only the
        # first audio frame. Average over the time axis so the whole clip
        # contributes to the audio representation.
        audio_features = self.audio_encoder(
            input_values=audio_input_values
        ).last_hidden_state.mean(dim=1)

        all_features = torch.cat([img_features, text_features, audio_features], dim=1)
        gate_values = torch.softmax(self.gate(all_features), dim=-1)

        # Slices keep a trailing dimension of 1 so each weight broadcasts
        # across the feature dimension.
        fused_features = (
            gate_values[:, 0:1] * img_features
            + gate_values[:, 1:2] * text_features
            + gate_values[:, 2:3] * audio_features
        )
        return self.fc(fused_features)


In [48]:
# Example inputs used only to trace the model for ONNX export (batch of 1).
# The exported graph takes its input dtypes from these tensors, so they must
# match what will be fed at inference time.

# One RGB image, 224x224.
dummy_image = torch.randn(1, 3, 224, 224)

# 32 random token ids drawn from [0, 30522).
dummy_text_input_ids = torch.randint(0, 30522, (1, 32))

# Attention mask with every position enabled. Created as int64 (like the token
# ids) rather than the float32 default of torch.ones, so the exported ONNX
# input accepts the integer masks that tokenizers produce.
dummy_text_attention_mask = torch.ones(1, 32, dtype=torch.long)

# 16000 raw audio samples.
dummy_audio_input_values = torch.randn(1, 16000)


In [49]:
# Destination file for the exported ONNX graph (relative to the working directory).
onnx_model_path = "../onnx_exports/UniConNet.onnx"

# Build the network with its default constructor arguments.
model = UniConNet()

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Switch to inference mode (disables dropout etc.) before tracing.
model.eval()

# The export writes straight to onnx_model_path; create the destination folder
# first so a fresh checkout without it does not fail at the very end.
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

# Export the model
torch.onnx.export(
    model,
    # Example inputs, in the same order as the arguments of model.forward
    (dummy_image, dummy_text_input_ids, dummy_text_attention_mask, dummy_audio_input_values),
    onnx_model_path,
    input_names=['image', 'text_input_ids', 'text_attention_mask', 'audio_input_values'],  # Input names
    output_names=['output'],  # Output name
    # Only the batch dimension (axis 0) is marked dynamic; every other
    # dimension is fixed to the shape of the example inputs.
    dynamic_axes={
        'image': {0: 'batch_size'},
        'text_input_ids': {0: 'batch_size'},
        'text_attention_mask': {0: 'batch_size'},
        'audio_input_values': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
    opset_version=14,
)

print(f"Model has been exported to {onnx_model_path}")

Model has been exported to ../onnx_exports/UniConNet.onnx
