In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel, BertModel, Wav2Vec2Model
from torch.nn.functional import normalize

class AdaptiveModalityEncoder(nn.Module):
    def __init__(self, input_dim=768):
        super(AdaptiveModalityEncoder, self).__init__()
        self.transformer_block = nn.TransformerEncoderLayer(d_model=input_dim, nhead=8)
        self.gate = nn.Linear(input_dim, 1)  # Gating mechanism to assign relevance to each modality
    
    def forward(self, modality_features):
        # modality_features is a list of tensors (img_features, text_features, audio_features)
        stacked_features = torch.stack(modality_features, dim=1)
        transformed = self.transformer_block(stacked_features)  # (batch, modalities, dim)
        gated_weights = torch.softmax(self.gate(transformed).squeeze(-1), dim=1)
        weighted_features = (gated_weights.unsqueeze(-1) * transformed).sum(dim=1)
        return weighted_features, gated_weights

class HierarchicalCrossModalityAttention(nn.Module):
    def __init__(self, input_dim=768):
        super(HierarchicalCrossModalityAttention, self).__init__()
        self.cross_attention1 = nn.MultiheadAttention(embed_dim=input_dim, num_heads=8)
        self.cross_attention2 = nn.MultiheadAttention(embed_dim=input_dim, num_heads=8)
    
    def forward(self, img_features, text_features, audio_features):
        # First level of cross-modality attention with text as the pivot modality
        text_img_attn, _ = self.cross_attention1(text_features, img_features, img_features)
        text_audio_attn, _ = self.cross_attention1(text_features, audio_features, audio_features)
        fused_text = (text_img_attn + text_audio_attn) / 2
        
        # Second level of cross-modality attention with image as the pivot modality
        img_text_attn, _ = self.cross_attention2(img_features, fused_text, fused_text)
        img_audio_attn, _ = self.cross_attention2(img_features, audio_features, audio_features)
        fused_img = (img_text_attn + img_audio_attn) / 2
        
        # Final fusion
        fused_features = (fused_text + fused_img + audio_features) / 3
        return fused_features

class MACAST(nn.Module):
    def __init__(self):
        super(MACAST, self).__init__()
        
        # Modality encoders
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

        
        
        # Adaptive Modality Encoder
        self.adaptive_modality_encoder = AdaptiveModalityEncoder()
        
        # Hierarchical Cross-Modality Attention
        self.hierarchical_attention = HierarchicalCrossModalityAttention()
        
        # Sentiment classification head
        self.classification_head = nn.Linear(768, 3)  # 3 sentiment classes
        
        # Auxiliary losses for self-supervised modality contrastive alignment
        self.contrastive_temp = 0.07
        self.contrastive_loss_fn = nn.CrossEntropyLoss()
    
    def forward(self, image, text, audio):
        # Encode each modality
        img_features = self.image_encoder(pixel_values=image).last_hidden_state[:, 0]
        text_features = self.text_encoder(input_ids=text["input_ids"], attention_mask=text["attention_mask"]).last_hidden_state[:, 0]
        audio_features = audio.view(audio.size(0), -1)  # Flatten everything except the batch dimension

        
        # Adaptive Modality Encoding
        weighted_features, gated_weights = self.adaptive_modality_encoder([img_features, text_features, audio_features])
        
        # Hierarchical Cross-Modality Attention Fusion
        fused_features = self.hierarchical_attention(img_features, text_features, audio_features)
        
        # Classification head
        logits = self.classification_head(fused_features)
        
        # Contrastive alignment for modality embeddings
        img_proj, text_proj, audio_proj = map(normalize, [img_features, text_features, audio_features])
        contrastive_loss = self.modality_contrastive_loss(img_proj, text_proj, audio_proj)
        
        return logits, contrastive_loss

    def modality_contrastive_loss(self, img_proj, text_proj, audio_proj):
        # Implementing contrastive loss for cross-modal alignment
        batch_size = img_proj.size(0)
        labels = torch.arange(batch_size).to(img_proj.device)
        
        logits_img_text = torch.mm(img_proj, text_proj.T) / self.contrastive_temp
        logits_img_audio = torch.mm(img_proj, audio_proj.T) / self.contrastive_temp
        logits_text_audio = torch.mm(text_proj, audio_proj.T) / self.contrastive_temp
        
        loss_img_text = self.contrastive_loss_fn(logits_img_text, labels)
        loss_img_audio = self.contrastive_loss_fn(logits_img_audio, labels)
        loss_text_audio = self.contrastive_loss_fn(logits_text_audio, labels)
        
        return (loss_img_text + loss_img_audio + loss_text_audio) / 3



In [32]:
import torch
import torch.onnx
from transformers import ViTFeatureExtractor, BertTokenizer, Wav2Vec2FeatureExtractor

# Initialize the MACAST model
model = MACAST()
model.eval()

# Define ONNX opset version (14 or above to support all operations)
opset_version = 14

# Prepare dummy inputs
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

# Dummy inputs for each modality

dummy_image = torch.rand(1, 3, 224, 224)  # Generate a random tensor within [0, 1]
dummy_image = feature_extractor(images=dummy_image, return_tensors="pt")["pixel_values"]

dummy_text = tokenizer("This is a test sentence.", return_tensors="pt")
dummy_audio = audio_feature_extractor(torch.randn(1, 16000), return_tensors="pt")

# ONNX output path
onnx_model_path = "MACAST.onnx"

# Processed inputs, extracting only the tensor values from the BatchEncoding objects
dummy_image_tensor = dummy_image
dummy_text_ids = dummy_text["input_ids"]
dummy_text_attention_mask = dummy_text["attention_mask"]
dummy_audio_input_values = dummy_audio["input_values"]

# Export the model to ONNX
torch.onnx.export(
    model,
    (dummy_image_tensor, 
     {"input_ids": dummy_text_ids, "attention_mask": dummy_text_attention_mask}, 
     dummy_audio_input_values),
    onnx_model_path,
    export_params=True,
    opset_version=14,
    input_names=["image", "text_input", "audio_input_values"],  # Changed input names to reflect the dictionary structure
    output_names=["logits", "contrastive_loss"],
    dynamic_axes={
        "image": {0: "batch_size"},
        "text_input": {0: "batch_size", 1: "seq_len"},
        "audio_input_values": {0: "batch_size", 1: "audio_len"},
        "logits": {0: "batch_size"},
        "contrastive_loss": {0: "batch_size"}
    }
)

print(f"Model successfully exported to {onnx_model_path}")



Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
  if num_channels != self.num_channels:
  if height != self.image_size[0] or width != self.image_size[1]:


RuntimeError: stack expects each tensor to be equal size, but got [1, 768] at entry 0 and [1, 16000] at entry 2

In [33]:
import torch
import torch.nn as nn
from transformers import ViTModel, BertModel, Wav2Vec2Model
from torch.nn.functional import normalize

# Adaptive Modality Encoder: Modality features are weighted
class AdaptiveModalityEncoder(nn.Module):
    def __init__(self, input_dim=768):
        super(AdaptiveModalityEncoder, self).__init__()
        self.transformer_block = nn.TransformerEncoderLayer(d_model=input_dim, nhead=8)
        self.gate = nn.Linear(input_dim, 1)  # Gating mechanism to assign relevance to each modality
    
    def forward(self, modality_features):
        # modality_features is a list of tensors (img_features, text_features, audio_features)
        stacked_features = torch.stack(modality_features, dim=1)
        transformed = self.transformer_block(stacked_features)  # (batch, modalities, dim)
        gated_weights = torch.softmax(self.gate(transformed).squeeze(-1), dim=1)
        weighted_features = (gated_weights.unsqueeze(-1) * transformed).sum(dim=1)
        return weighted_features, gated_weights


# Hierarchical Cross-Modality Attention Fusion Layer
class HierarchicalCrossModalityAttention(nn.Module):
    def __init__(self, input_dim=768):
        super(HierarchicalCrossModalityAttention, self).__init__()
        self.cross_attention1 = nn.MultiheadAttention(embed_dim=input_dim, num_heads=8)
        self.cross_attention2 = nn.MultiheadAttention(embed_dim=input_dim, num_heads=8)
    
    def forward(self, img_features, text_features, audio_features):
        # First level of cross-modality attention with text as the pivot modality
        text_img_attn, _ = self.cross_attention1(text_features, img_features, img_features)
        text_audio_attn, _ = self.cross_attention1(text_features, audio_features, audio_features)
        fused_text = (text_img_attn + text_audio_attn) / 2
        
        # Second level of cross-modality attention with image as the pivot modality
        img_text_attn, _ = self.cross_attention2(img_features, fused_text, fused_text)
        img_audio_attn, _ = self.cross_attention2(img_features, audio_features, audio_features)
        fused_img = (img_text_attn + img_audio_attn) / 2
        
        # Final fusion
        fused_features = (fused_text + fused_img + audio_features) / 3
        return fused_features


# MACAST Model
class MACAST(nn.Module):
    def __init__(self):
        super(MACAST, self).__init__()
        
        # Modality encoders
        self.image_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.text_encoder = BertModel.from_pretrained("bert-base-uncased")
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        
        # Adaptive Modality Encoder
        self.adaptive_modality_encoder = AdaptiveModalityEncoder()
        
        # Hierarchical Cross-Modality Attention
        self.hierarchical_attention = HierarchicalCrossModalityAttention()
        
        # Sentiment classification head
        self.classification_head = nn.Linear(768, 3)  # 3 sentiment classes
        
        # Auxiliary losses for self-supervised modality contrastive alignment
        self.contrastive_temp = 0.07
        self.contrastive_loss_fn = nn.CrossEntropyLoss()
    
    def forward(self, image, text, audio):
        # Encode each modality
        img_features = self.image_encoder(pixel_values=image).last_hidden_state[:, 0]
        text_features = self.text_encoder(input_ids=text["input_ids"], attention_mask=text["attention_mask"]).last_hidden_state[:, 0]
        audio_features = audio.view(audio.size(0), -1)  # Flatten everything except the batch dimension
        
        # Adaptive Modality Encoding
        weighted_features, gated_weights = self.adaptive_modality_encoder([img_features, text_features, audio_features])
        
        # Hierarchical Cross-Modality Attention Fusion
        fused_features = self.hierarchical_attention(img_features, text_features, audio_features)
        
        # Classification head
        logits = self.classification_head(fused_features)
        
        # Contrastive alignment for modality embeddings
        img_proj, text_proj, audio_proj = map(normalize, [img_features, text_features, audio_features])
        contrastive_loss = self.modality_contrastive_loss(img_proj, text_proj, audio_proj)
        
        return logits, contrastive_loss

    def modality_contrastive_loss(self, img_proj, text_proj, audio_proj):
        # Implementing contrastive loss for cross-modal alignment
        batch_size = img_proj.size(0)
        labels = torch.arange(batch_size).to(img_proj.device)
        
        logits_img_text = torch.mm(img_proj, text_proj.T) / self.contrastive_temp
        logits_img_audio = torch.mm(img_proj, audio_proj.T) / self.contrastive_temp
        logits_text_audio = torch.mm(text_proj, audio_proj.T) / self.contrastive_temp
        
        loss_img_text = self.contrastive_loss_fn(logits_img_text, labels)
        loss_img_audio = self.contrastive_loss_fn(logits_img_audio, labels)
        loss_text_audio = self.contrastive_loss_fn(logits_text_audio, labels)
        
        return (loss_img_text + loss_img_audio + loss_text_audio) / 3


# ONNX Export Code
import torch.onnx
from transformers import ViTFeatureExtractor, BertTokenizer, Wav2Vec2FeatureExtractor

# Initialize the MACAST model
model = MACAST()
model.eval()

# Define ONNX opset version (14 or above to support all operations)
opset_version = 14

# Prepare dummy inputs
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
audio_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

# Dummy inputs for each modality

dummy_image = torch.rand(1, 3, 224, 224)  # Generate a random tensor within [0, 1]
dummy_image = feature_extractor(images=dummy_image, return_tensors="pt")["pixel_values"]

dummy_text = tokenizer("This is a test sentence.", return_tensors="pt")
dummy_audio = audio_feature_extractor(torch.randn(1, 16000), return_tensors="pt")

# ONNX output path
onnx_model_path = "MACAST.onnx"

# Processed inputs, extracting only the tensor values from the BatchEncoding objects
dummy_image_tensor = dummy_image
dummy_text_ids = dummy_text["input_ids"]
dummy_text_attention_mask = dummy_text["attention_mask"]
dummy_audio_input_values = dummy_audio["input_values"]

# Export the model to ONNX
torch.onnx.export(
    model,
    (dummy_image_tensor, 
     {"input_ids": dummy_text_ids, "attention_mask": dummy_text_attention_mask}, 
     dummy_audio_input_values),
    onnx_model_path,
    export_params=True,
    opset_version=14,
    input_names=["image", "text_input", "audio_input_values"],  # Changed input names to reflect the dictionary structure
    output_names=["logits", "contrastive_loss"],
    dynamic_axes={
        "image": {0: "batch_size"},
        "text_input": {0: "batch_size", 1: "seq_len"},
        "audio_input_values": {0: "batch_size", 1: "audio_len"},
        "logits": {0: "batch_size"},
        "contrastive_loss": {0: "batch_size"}
    }
)

print(f"Model successfully exported to {onnx_model_path}")


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


RuntimeError: stack expects each tensor to be equal size, but got [1, 768] at entry 0 and [1, 16000] at entry 2