# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_


class SubNet(nn.Module):
    """
    Feed-forward encoder applied to the audio and video features before fusion.

    Pipeline: BatchNorm1d -> Dropout -> three (Linear + ReLU) layers.

    Args:
        in_size: number of input features.
        hidden_size: width of every Linear layer, and of the returned encoding.
        dropout: dropout probability applied to the normalized input.
    """

    def __init__(self, in_size, hidden_size, dropout):
        super().__init__()
        # Attribute names double as state_dict keys and creation order fixes
        # the seeded init, so both are kept stable.
        self.norm = nn.BatchNorm1d(in_size)
        self.drop = nn.Dropout(p=dropout)
        self.linear_1 = nn.Linear(in_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Return the ReLU-activated ``hidden_size`` encoding of ``x``."""
        out = self.drop(self.norm(x))
        for layer in (self.linear_1, self.linear_2, self.linear_3):
            out = F.relu(layer(out))
        return out


class TextSubNet(nn.Module):
    """
    LSTM-based text subnet.

    Runs an LSTM over the input sequence and projects the final hidden state
    of the top layer to ``out_size`` features.

    Args:
        in_size: feature size of each time step.
        hidden_size: LSTM hidden size (per direction).
        out_size: size of the returned feature vector.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers (only used when
            ``num_layers > 1``) and on the final hidden state.
        bidirectional: if True, the forward and backward final states of the
            top layer are concatenated before the projection.
    """

    def __init__(self, in_size, hidden_size, out_size, num_layers=1, dropout=0.2, bidirectional=False):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=in_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True
        )
        self.num_directions = 2 if bidirectional else 1
        self.dropout = nn.Dropout(dropout)
        # A bidirectional LSTM yields one final state per direction.
        self.linear_1 = nn.Linear(hidden_size * self.num_directions, out_size)

    def forward(self, x):
        """
        Encode a sequence into a fixed-size vector.

        Args:
            x: ``(batch, seq_len, in_size)`` tensor (the LSTM is
                ``batch_first``).

        Returns:
            ``(batch, out_size)`` tensor.
        """
        _, (h_n, _) = self.rnn(x)
        # h_n stacks num_layers * num_directions final states, ordered by
        # layer and then direction (forward, backward), so the top layer
        # occupies the last num_directions entries.
        top_layer = h_n[-self.num_directions:]
        h = torch.cat(tuple(top_layer), dim=-1)
        h = self.dropout(h)
        return self.linear_1(h)

class LMF(nn.Module):
    """
    Low-rank Multimodal Fusion over audio, video and text features.

    Each modality is encoded by its own subnet and a constant 1 is prepended
    to every encoding. The three augmented vectors are then projected through
    per-modality low-rank factor tensors, multiplied element-wise, and
    combined by a learned weighted sum over the rank dimension.

    Args:
        input_dims: ``(audio_in, video_in, text_in)`` input feature sizes.
        hidden_dims: ``(audio_h, video_h, text_h)``; ``text_h`` is the hidden
            size of the text LSTM.
        text_out: size of the text subnet's output vector.
        dropouts: dropout rates indexed as audio (0), video (1), text (2),
            post-fusion (3).
        output_dim: number of output features.
        rank: number of rank-one factors.
        use_softmax: if True, apply softmax over the output features.
    """

    def __init__(self, input_dims, hidden_dims, text_out, dropouts, output_dim, rank, use_softmax=False):
        super().__init__()

        audio_in, video_in, text_in = input_dims
        audio_h, video_h, text_h = hidden_dims

        self.use_softmax = use_softmax
        self.rank = rank
        self.output_dim = output_dim

        # Per-modality encoders
        self.audio_subnet = SubNet(audio_in, audio_h, dropouts[0])
        self.video_subnet = SubNet(video_in, video_h, dropouts[1])
        self.text_subnet = TextSubNet(text_in, text_h, text_out, dropout=dropouts[2])

        # NOTE(review): this module is never applied in forward(), so
        # dropouts[3] currently has no effect; confirm where (if anywhere)
        # post-fusion dropout is meant to be applied.
        self.post_fusion_dropout = nn.Dropout(p=dropouts[3])

        # Low-rank factors, shape (rank, features + 1, output_dim); the extra
        # row pairs with the constant 1 that forward() prepends.
        self.audio_factor = nn.Parameter(torch.empty(rank, audio_h + 1, output_dim))
        self.video_factor = nn.Parameter(torch.empty(rank, video_h + 1, output_dim))
        self.text_factor = nn.Parameter(torch.empty(rank, text_out + 1, output_dim))

        self.fusion_weights = nn.Parameter(torch.empty(rank))
        self.fusion_bias = nn.Parameter(torch.zeros(1, output_dim))

        # Xavier initialization
        xavier_normal_(self.audio_factor)
        xavier_normal_(self.video_factor)
        xavier_normal_(self.text_factor)
        # xavier_normal_ requires >= 2 dims; the (1, rank) view shares storage
        # with fusion_weights, so this initializes the parameter in place.
        xavier_normal_(self.fusion_weights.unsqueeze(0))

    def forward(self, audio_x, video_x, text_x):
        """
        Fuse one batch of audio, video and text features.

        Args:
            audio_x: ``(batch, audio_in)`` tensor.
            video_x: ``(batch, video_in)`` tensor.
            text_x: ``(batch, seq_len, text_in)`` tensor (the text LSTM is
                ``batch_first``).

        Returns:
            ``(batch, output_dim)`` tensor, softmax-normalized over
            ``output_dim`` when ``use_softmax`` is set.
        """
        batch_size = audio_x.size(0)

        # Feature extraction
        audio_h = self.audio_subnet(audio_x)
        video_h = self.video_subnet(video_x)
        text_h = self.text_subnet(text_x)

        # Constant-1 column built with the encoders' dtype and device, so the
        # concatenations and einsums below stay in one dtype (e.g. after
        # .half()) instead of mixing in a float32 tensor.
        ones = audio_h.new_ones(batch_size, 1)

        audio_h = torch.cat([ones, audio_h], dim=1)   # (B, A+1)
        video_h = torch.cat([ones, video_h], dim=1)   # (B, V+1)
        text_h = torch.cat([ones, text_h], dim=1)     # (B, T+1)

        # Project every modality through its rank-R factor: (B, R, O)
        fusion_audio = torch.einsum("ba, rao -> bro", audio_h, self.audio_factor)
        fusion_video = torch.einsum("bv, rvo -> bro", video_h, self.video_factor)
        fusion_text = torch.einsum("bt, rto -> bro", text_h, self.text_factor)

        fusion = fusion_audio * fusion_video * fusion_text  # (B, R, O)

        # Weighted sum across rank
        output = torch.einsum("bro, r -> bo", fusion, self.fusion_weights)
        output = output + self.fusion_bias

        if self.use_softmax:
            output = F.softmax(output, dim=1)

        return output


# In [None]:
# Smoke test: push one random batch through LMF and check the output shape.
batch, seq_len = 32, 20
a = torch.randn(batch, 74)             # audio features: (batch, audio_in)
v = torch.randn(batch, 35)             # video features: (batch, video_in)
t = torch.randn(batch, seq_len, 300)   # text features: (batch, seq_len, text_in) for the LSTM

# LMF takes per-modality tuples ordered (audio, video, text) plus four dropout rates.
model = LMF(
    input_dims=(74, 35, 300),
    hidden_dims=(32, 32, 128),
    text_out=64,
    dropouts=(0.3, 0.3, 0.3, 0.5),
    output_dim=1,
    rank=8,
)
out = model(a, v, t)   # forward order is (audio, video, text)

print(out.shape)   # -> torch.Size([32, 1])