# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_

class ModalityEncoder(nn.Module):
    """
    Feedforward subnet used for Audio and Video before fusion.

    Batch-normalises and applies dropout to the input, then passes it through
    three Linear + ReLU layers that all produce ``hidden_dimenstion`` features.

    Args:
        input_dimenstion: number of input features per sample.
        hidden_dimenstion: width of every hidden layer (and of the output).
        dropout_probability: dropout probability applied to the normalised input.
    """

    def __init__(self, input_dimenstion, hidden_dimenstion, dropout_probability):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dimenstion)
        self.dropout    = nn.Dropout(p=dropout_probability)
        self.fc_layer1 = nn.Linear(input_dimenstion, hidden_dimenstion)
        self.fc_layer2 = nn.Linear(hidden_dimenstion, hidden_dimenstion)
        self.fc_layer3 = nn.Linear(hidden_dimenstion, hidden_dimenstion)

    def forward(self, inputs):
        """
        Encode a batch of modality features.

        Args:
            inputs: ``(batch, input_dimenstion)`` tensor. In training mode
                BatchNorm1d needs more than one sample per batch.

        Returns:
            ``(batch, hidden_dimenstion)`` non-negative (post-ReLU) features.
        """
        normalized = self.batch_norm(inputs)
        dropped    = self.dropout(normalized)
        hidden1    = F.relu(self.fc_layer1(dropped))
        hidden2    = F.relu(self.fc_layer2(hidden1))
        hidden3    = F.relu(self.fc_layer3(hidden2))

        return hidden3



class TextEncoder(nn.Module):
    """
    LSTM-based Text processing subnet.

    Runs an LSTM over the input sequence and takes the final hidden state of
    the last LSTM layer. When ``bidirectional`` is set, the forward and
    backward final states are concatenated. That vector goes through dropout
    and is then projected to ``output_dim`` features.

    Args:
        input_dim: feature size of each sequence step.
        hidden_dim: LSTM hidden size (per direction).
        output_dim: size of the returned feature vector.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability. It is applied between LSTM layers only
            when ``num_layers > 1``, and always to the final hidden state.
        bidirectional: whether the LSTM is bidirectional.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.2, bidirectional=False):

        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        # A bidirectional LSTM contributes one final state per direction.
        num_directions = 2 if bidirectional else 1
        self.output_layer = nn.Linear(hidden_dim * num_directions, output_dim)

    def forward(self, text_sequence):
        """
        Encode a batch of sequences.

        Args:
            text_sequence: ``(batch, seq_len, input_dim)`` tensor (the LSTM
                is built with ``batch_first=True``).

        Returns:
            ``(batch, output_dim)`` features.
        """
        _, (hidden_states, _) = self.lstm(text_sequence)

        # h_n stacks (num_layers * num_directions) states, with the directions
        # of each layer adjacent. The last layer therefore owns the final entry
        # (unidirectional) or the final two entries, forward then backward.
        if self.lstm.bidirectional:
            last_hidden = torch.cat([hidden_states[-2], hidden_states[-1]], dim=-1)
        else:
            last_hidden = hidden_states[-1]
        last_hidden = self.dropout(last_hidden)

        output_features = self.output_layer(last_hidden)
        return output_features



class LMF(nn.Module):
    """
    Low-rank Multimodal Fusion (LMF)

    Encodes audio, video and text with separate subnets and prepends a
    constant 1 to each modality's encoded features. Each result is projected
    through ``rank`` modality-specific factor tensors. The three projections
    are multiplied element-wise, dropout is applied, and the rank axis is
    collapsed with learned per-rank weights before adding an output bias.

    Args:
        input_dimenstion: ``(audio, video, text)`` raw input feature sizes.
        hidden_dimenstion: ``(audio, video, text)`` encoder hidden sizes.
        text_output_dimenstion: feature size produced by the text encoder.
        dropouts: ``(audio, video, text, post_fusion)`` dropout probabilities.
        output_dimenstion: size of the returned logits.
        rank: number of low-rank factors per modality.
    """

    def __init__(self, input_dimenstion, hidden_dimenstion, text_output_dimenstion, dropouts, output_dimenstion, rank):

        super().__init__()
        audio_raw_input, video_raw_input, text_raw_input = input_dimenstion
        audio_hidden, video_hidden, text_hidden = hidden_dimenstion

        self.rank = rank
        self.output_dimenstion = output_dimenstion

        # Stored under the *_subnet names that forward() reads.
        self.audio_subnet = ModalityEncoder(audio_raw_input, audio_hidden, dropouts[0])
        self.video_subnet = ModalityEncoder(video_raw_input, video_hidden, dropouts[1])
        self.text_subnet  = TextEncoder(text_raw_input, text_hidden, text_output_dimenstion, dropout=dropouts[2])

        self.post_fusion_dropout = nn.Dropout(p=dropouts[3])

        # One extra input row per factor multiplies the constant-one column
        # that forward() prepends to each modality's features.
        bias_term = 1

        self.audio_factor = nn.Parameter(
            torch.Tensor(rank, audio_hidden + bias_term, output_dimenstion)
        )
        self.video_factor = nn.Parameter(
            torch.Tensor(rank, video_hidden + bias_term, output_dimenstion)
        )
        self.text_factor = nn.Parameter(
            torch.Tensor(rank, text_output_dimenstion + bias_term, output_dimenstion)
        )

        self.rank_weights = nn.Parameter(torch.Tensor(rank))
        self.output_bias  = nn.Parameter(torch.zeros(1, output_dimenstion))

        xavier_normal_(self.audio_factor)
        xavier_normal_(self.video_factor)
        xavier_normal_(self.text_factor)
        # xavier_normal_ needs >= 2 dims, so init the 1-D weights through a (1, rank) view.
        xavier_normal_(self.rank_weights.view(1, -1))

    def forward(self, audio_raw_input, video_raw_input, text_raw_input):
        """
        Fuse one batch of audio, video and text inputs.

        Args:
            audio_raw_input: input for the audio ModalityEncoder.
            video_raw_input: input for the video ModalityEncoder.
            text_raw_input: input sequence for the text TextEncoder.

        Returns:
            ``(batch, output_dimenstion)`` unnormalised logits.
        """
        batch_size = audio_raw_input.size(0)
        device     = audio_raw_input.device

        audio_extracted_features = self.audio_subnet(audio_raw_input)
        video_extracted_features = self.video_subnet(video_raw_input)
        text_extracted_features  = self.text_subnet(text_raw_input)

        # Match the encoders' dtype so the concatenation below stays in one precision.
        bias_column = torch.ones(batch_size, 1, device=device, dtype=audio_extracted_features.dtype)
        audio_with_bias = torch.cat([bias_column, audio_extracted_features], dim=1)
        video_with_bias = torch.cat([bias_column, video_extracted_features], dim=1)
        text_with_bias  = torch.cat([bias_column, text_extracted_features],  dim=1)

        audio_rank_projections = torch.einsum("bf,rfo->bro", audio_with_bias, self.audio_factor)
        video_rank_projections = torch.einsum("bf,rfo->bro", video_with_bias, self.video_factor)
        text_rank_projections  = torch.einsum("bf,rfo->bro", text_with_bias,  self.text_factor)

        fused_all_modalities_tensor = audio_rank_projections * video_rank_projections * text_rank_projections
        fused_all_modalities_tensor = self.post_fusion_dropout(fused_all_modalities_tensor)

        logits = torch.einsum("bro,r->bo", fused_all_modalities_tensor, self.rank_weights)
        logits = logits + self.output_bias

        # Return raw logits. A softmax over dim=1 is constant 1.0 (with zero
        # gradient) when output_dimenstion == 1. Apply softmax/sigmoid or a
        # logits-based loss downstream instead.
        return logits

# In [None]:
batch = 32
seq_len = 20
t = torch.randn(batch, seq_len, 300)   # text features: (batch, seq_len, embedding_dim) for the LSTM
a = torch.randn(batch, 74)             # audio features
v = torch.randn(batch, 35)             # video features

# LMF takes (audio, video, text) tuples for input and hidden sizes, plus four
# dropout probabilities (audio, video, text, post-fusion).
model = LMF(
    input_dimenstion=(74, 35, 300),
    hidden_dimenstion=(32, 32, 128),
    text_output_dimenstion=64,
    dropouts=(0.3, 0.3, 0.3, 0.5),
    output_dimenstion=1,
    rank=8,
)
out = model(a, v, t)   # forward signature is (audio, video, text)

print(out.shape)   # → torch.Size([32, 1])

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_

class ModalityEncoder(nn.Module):
    """
    Feedforward subnet used for Audio and Video before fusion.

    Batch-normalises and applies dropout to the input, then passes it through
    three Linear + ReLU layers that all produce ``hidden_dimenstion`` features.

    Args:
        input_dimenstion: number of input features per sample.
        hidden_dimenstion: width of every hidden layer (and of the output).
        dropout_probability: dropout probability applied to the normalised input.
    """

    def __init__(self, input_dimenstion, hidden_dimenstion, dropout_probability):
        super().__init__()

        self.batch_norm = nn.BatchNorm1d(input_dimenstion)
        self.dropout    = nn.Dropout(p=dropout_probability)

        self.fc_layer1 = nn.Linear(input_dimenstion, hidden_dimenstion)
        self.fc_layer2 = nn.Linear(hidden_dimenstion, hidden_dimenstion)
        self.fc_layer3 = nn.Linear(hidden_dimenstion, hidden_dimenstion)

    def forward(self, inputs):
        """
        Encode a batch of modality features.

        Args:
            inputs: ``(batch, input_dimenstion)`` tensor. In training mode
                BatchNorm1d needs more than one sample per batch.

        Returns:
            ``(batch, hidden_dimenstion)`` non-negative (post-ReLU) features.
        """
        normalized = self.batch_norm(inputs)
        dropped    = self.dropout(normalized)
        hidden1    = F.relu(self.fc_layer1(dropped))
        hidden2    = F.relu(self.fc_layer2(hidden1))
        hidden3    = F.relu(self.fc_layer3(hidden2))
        return hidden3


class TextEncoder(nn.Module):
    """
    LSTM-based Text processing subnet.

    Runs an LSTM over the input sequence and takes the final hidden state of
    the last LSTM layer. When ``bidirectional`` is set, the forward and
    backward final states are concatenated. That vector goes through dropout
    and is then projected to ``output_dim`` features.

    Args:
        input_dim: feature size of each sequence step.
        hidden_dim: LSTM hidden size (per direction).
        output_dim: size of the returned feature vector.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability. It is applied between LSTM layers only
            when ``num_layers > 1``, and always to the final hidden state.
        bidirectional: whether the LSTM is bidirectional.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.2, bidirectional=False):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        # A bidirectional LSTM contributes one final state per direction.
        num_directions = 2 if bidirectional else 1
        self.output_layer = nn.Linear(hidden_dim * num_directions, output_dim)

    def forward(self, text_sequence):
        """
        Encode a batch of sequences.

        Args:
            text_sequence: ``(batch, seq_len, input_dim)`` tensor (the LSTM
                is built with ``batch_first=True``).

        Returns:
            ``(batch, output_dim)`` features.
        """
        _, (hidden_states, _) = self.lstm(text_sequence)

        # h_n stacks (num_layers * num_directions) states, with the directions
        # of each layer adjacent. The last layer therefore owns the final entry
        # (unidirectional) or the final two entries, forward then backward.
        if self.lstm.bidirectional:
            last_hidden = torch.cat([hidden_states[-2], hidden_states[-1]], dim=-1)
        else:
            last_hidden = hidden_states[-1]
        last_hidden = self.dropout(last_hidden)

        output_features = self.output_layer(last_hidden)
        return output_features


class LMF(nn.Module):
    """
    Low-rank Multimodal Fusion (LMF) over audio, video and text.

    Each modality is encoded by its own subnet, and a column of ones is
    prepended to the encoded features. The augmented features are projected
    through ``rank`` modality-specific factor tensors. The three projections
    are multiplied element-wise and passed through dropout. The rank axis is
    then collapsed with learned per-rank weights and an output bias is added.

    Args:
        input_dimenstion: ``(audio, video, text)`` raw input sizes.
        hidden_dimenstion: ``(audio, video, text)`` encoder hidden sizes.
        text_output_dimenstion: feature size produced by the text encoder.
        dropouts: ``(audio, video, text, post_fusion)`` dropout probabilities.
        output_dimenstion: size of the returned logits.
        rank: number of low-rank factors per modality.
    """

    def __init__(self, input_dimenstion, hidden_dimenstion, text_output_dimenstion, dropouts, output_dimenstion, rank):
        super().__init__()

        audio_in, video_in, text_in = input_dimenstion
        audio_hid, video_hid, text_hid = hidden_dimenstion

        self.rank = rank
        self.output_dimenstion = output_dimenstion

        # Submodules are created audio -> video -> text so parameter
        # registration and random initialisation follow that order.
        self.audio_subnet = ModalityEncoder(audio_in, audio_hid, dropouts[0])
        self.video_subnet = ModalityEncoder(video_in, video_hid, dropouts[1])
        self.text_subnet = TextEncoder(text_in, text_hid, text_output_dimenstion, dropout=dropouts[2])
        self.post_fusion_dropout = nn.Dropout(p=dropouts[3])

        # The "+ 1" row in every factor pairs with the ones column added in forward().
        self.audio_factor = nn.Parameter(torch.Tensor(rank, audio_hid + 1, output_dimenstion))
        self.video_factor = nn.Parameter(torch.Tensor(rank, video_hid + 1, output_dimenstion))
        self.text_factor = nn.Parameter(torch.Tensor(rank, text_output_dimenstion + 1, output_dimenstion))
        self.rank_weights = nn.Parameter(torch.Tensor(rank))
        self.output_bias = nn.Parameter(torch.zeros(1, output_dimenstion))

        for factor in (self.audio_factor, self.video_factor, self.text_factor):
            xavier_normal_(factor)
        # xavier_normal_ needs a >= 2-D tensor, so reach the 1-D weights via a (1, rank) view.
        xavier_normal_(self.rank_weights.view(1, -1))

    def forward(self, audio_raw_input, video_raw_input, text_raw_input):
        """
        Fuse one batch of audio, video and text inputs.

        Returns:
            ``(batch, output_dimenstion)`` raw logits. No softmax is applied
            here; the downstream classifier/loss is expected to handle it.
        """
        encoded_with_factors = (
            (self.audio_subnet(audio_raw_input), self.audio_factor),
            (self.video_subnet(video_raw_input), self.video_factor),
            (self.text_subnet(text_raw_input), self.text_factor),
        )

        ones = torch.ones(audio_raw_input.size(0), 1, device=audio_raw_input.device)

        # Multiply the per-modality rank projections left to right: audio * video * text.
        fused = None
        for features, factor in encoded_with_factors:
            augmented = torch.cat([ones, features], dim=1)
            projection = torch.einsum("bf,rfo->bro", augmented, factor)
            fused = projection if fused is None else fused * projection

        fused = self.post_fusion_dropout(fused)

        return torch.einsum("bro,r->bo", fused, self.rank_weights) + self.output_bias
