In [11]:
import torch
import torch.nn as nn
from transformers import ViTModel, BertModel, Wav2Vec2Model

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [12]:
class UniConNet(nn.Module):
    def __init__(self):
        super(UniConNet, self).__init__()
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.gate = nn.Linear(768 * 3, 3)
        self.fc = nn.Linear(768, 3)
    
    def forward(self, image, text_input_ids, text_attention_mask, audio_input_values):
        img_features = self.image_encoder(pixel_values=image).last_hidden_state[:, 0]
        text_features = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask).last_hidden_state[:, 0]
        audio_features = self.audio_encoder(input_values=audio_input_values).last_hidden_state[:, 0]
        all_features = torch.cat([img_features, text_features, audio_features], dim=1)
        gate_values = torch.softmax(self.gate(all_features), dim=-1)
        fused_features = (gate_values[:, 0:1] * img_features + 
                          gate_values[:, 1:2] * text_features + 
                          gate_values[:, 2:3] * audio_features)
        output = self.fc(fused_features)
        return output


In [13]:
dummy_image = torch.randn(1, 3, 224, 224)
dummy_text_input_ids = torch.randint(0, 30522, (1, 32))
dummy_text_attention_mask = torch.ones(1, 32)
dummy_audio_input_values = torch.randn(1, 16000)


In [14]:
model = UniConNet()

onnx_model_path = "../UniConNet.onnx"

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
torch.onnx.export(
    model,
    (dummy_image, dummy_text_input_ids, dummy_text_attention_mask, dummy_audio_input_values),
    onnx_model_path,
    export_params=True,
    opset_version=14,  # Change from 12 to 14
    input_names=['image', 'text_input_ids', 'text_attention_mask', 'audio_input_values'],
    output_names=['output'],
    dynamic_axes={
        'image': {0: 'batch_size'},
        'text_input_ids': {0: 'batch_size', 1: 'seq_len'},
        'text_attention_mask': {0: 'batch_size', 1: 'seq_len'},
        'audio_input_values': {0: 'batch_size', 1: 'audio_len'},
        'output': {0: 'batch_size'}
    }
)


  if num_channels != self.num_channels:
  if height != self.image_size[0] or width != self.image_size[1]:
  if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):


In [27]:
class m2(nn.Module):
    def __init__(self):
        super(m2, self).__init__()
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.gate = nn.Linear(768 * 3, 3)
        self.fc1 = nn.Linear(2304, 768)  # Added layer to reduce the features to 768
        self.fc2 = nn.Linear(768, 3)  # Final output layer
    
    def forward(self, image, text_input_ids, text_attention_mask, audio_input_values):
        img_features = self.image_encoder(pixel_values=image).last_hidden_state[:, 0]
        text_features = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask).last_hidden_state[:, 0]
        audio_features = self.audio_encoder(input_values=audio_input_values).last_hidden_state[:, 0]
        
        # Concatenate the features
        all_features = torch.cat([img_features, text_features, audio_features], dim=1)
        
        # Pass through the gate layer to get importance values for each modality
        gate_values = torch.softmax(self.gate(all_features), dim=-1)
        
        # Fuse the features based on the gate values
        fused_features = (gate_values[:, 0:1] * img_features + 
                          gate_values[:, 1:2] * text_features + 
                          gate_values[:, 2:3] * audio_features)
        
        # Pass through the first fully connected layer to reduce the size
        reduced_features = self.fc1(fused_features)
        
        # Final output prediction
        output = self.fc2(reduced_features)
        
        return output


In [29]:
model = m2()

onnx_model_path = "../m2.onnx"

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:
import torch
import torch.onnx

# Assuming your UniConNet model is already defined as `model`
model = UniConNet()  # Replace with your model initialization if necessary

# Set the model to evaluation mode (important for some layers like dropout)
model.eval()

# Prepare example inputs that match the expected input shapes
# Example input shapes based on the model
# image: [batch_size, 3, 224, 224] (assuming RGB image)
# text_input_ids: [batch_size, max_seq_len] (for BERT model)
# text_attention_mask: [batch_size, max_seq_len]
# audio_input_values: [batch_size, sequence_length] (for Wav2Vec2)

batch_size = 1  # You can adjust this based on your use case
image = torch.randn(batch_size, 3, 224, 224)  # Random image tensor
text_input_ids = torch.randint(0, 1000, (batch_size, 128))  # Random token ids (max_seq_len=128)
text_attention_mask = torch.ones(batch_size, 128)  # Attention mask
audio_input_values = torch.randn(batch_size, 16000)  # Random audio input (assuming 16kHz audio)

# Specify the path where you want to save the ONNX model
onnx_model_path = "../m2.onnx"

# Export the model to ONNX
torch.onnx.export(
    model, 
    (image, text_input_ids, text_attention_mask, audio_input_values),  # Example inputs
    onnx_model_path,  # Output path
    input_names=['image', 'text_input_ids', 'text_attention_mask', 'audio_input_values'],  # Names of input nodes
    output_names=['output'],  # Name of output node
    dynamic_axes={
        'image': {0: 'batch_size'},  # Allow dynamic batch size for image input
        'text_input_ids': {0: 'batch_size'},  # Allow dynamic batch size for text input
        'text_attention_mask': {0: 'batch_size'},  # Allow dynamic batch size for text input
        'audio_input_values': {0: 'batch_size'},  # Allow dynamic batch size for audio input
        'output': {0: 'batch_size'}  # Allow dynamic batch size for output
    },
    opset_version=12  # The ONNX opset version to use, 12 should work for most use cases
)

print(f"Model successfully exported to {onnx_model_path}")


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  if num_channels != self.num_channels:
  if height != self.image_size[0] or width != self.image_size[1]:
  if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x768 and 2304x768)