In [None]:
import torch
import torch.nn as nn

class LMF(nn.Module):
    """Low-rank fusion of text, audio and video feature vectors.

    Each modality is linearly projected to a shared ``rank``-dimensional
    space, the three projections are multiplied element-wise, and the
    product is mapped linearly (plus a bias) to ``output_dim`` values.

    Args:
        t_dim: Size of the last dimension of the text input.
        a_dim: Size of the last dimension of the audio input.
        v_dim: Size of the last dimension of the video input.
        output_dim: Number of output values per sample.
        rank: Number of low-rank components shared by all modalities.
    """

    def __init__(
        self,
        t_dim: int,
        a_dim: int,
        v_dim: int,
        output_dim: int = 1,
        rank: int = 8,
    ) -> None:
        super().__init__()
        self.rank = rank

        # Modality projections (low-rank factors), one row per component.
        self.U_t = nn.Parameter(torch.empty(rank, t_dim))
        self.U_a = nn.Parameter(torch.empty(rank, a_dim))
        self.U_v = nn.Parameter(torch.empty(rank, v_dim))

        # Fusion weights mapping the fused components to the output.
        self.W_fusion = nn.Parameter(torch.empty(output_dim, rank))

        # Bias term added to the output.
        self.bias = nn.Parameter(torch.empty(output_dim))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise all parameters in place.

        The forward pass multiplies three projections together, so the
        output scale is the product of the three projection scales.
        Unit-variance weights therefore make the initial output grow with
        ``sqrt(t_dim * a_dim * v_dim)``; Xavier-scaled weights keep each
        projection near unit scale instead. The bias starts at zero so it
        does not add a random offset.
        """
        for weight in (self.U_t, self.U_a, self.U_v, self.W_fusion):
            nn.init.xavier_normal_(weight)
        nn.init.zeros_(self.bias)

    def forward(
        self, t: torch.Tensor, a: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Fuse one batch of text, audio and video features.

        Args:
            t: Text features, shape ``(B, t_dim)``.
            a: Audio features, shape ``(B, a_dim)``.
            v: Video features, shape ``(B, v_dim)``.

        Returns:
            Tensor of shape ``(B, output_dim)``.
        """
        # Per-modality projections, each (B, rank).
        t_proj = nn.functional.linear(t, self.U_t)
        a_proj = nn.functional.linear(a, self.U_a)
        v_proj = nn.functional.linear(v, self.U_v)

        # Element-wise fusion across modalities: (B, rank).
        fused = t_proj * a_proj * v_proj

        # Final output: (B, output_dim).
        return nn.functional.linear(fused, self.W_fusion, self.bias)


In [None]:
# Smoke test: push one random batch through the fusion module.
batch = 32
t_dim, a_dim, v_dim = 300, 74, 35  # text / audio / video feature sizes

# Random stand-ins for the text, audio and video features (drawn in that order).
t, a, v = (torch.randn(batch, dim) for dim in (t_dim, a_dim, v_dim))

model = LMF(t_dim, a_dim, v_dim, output_dim=1, rank=8)
out = model(t, a, v)

print(out.shape)   # → torch.Size([32, 1])