# In [None]:
import torch
import torch.nn as nn


class TFN(nn.Module):
    """Fuse three feature vectors via a 3-way outer product, then project.

    Each modality vector gets a constant ``1`` appended, so the outer product
    contains the trimodal terms as well as every bimodal and unimodal term.
    The flattened product is mapped to ``output_dim`` by one linear layer.

    Args:
        text_dim: Size of the text feature vector.
        audio_dim: Size of the audio feature vector.
        visual_dim: Size of the visual feature vector.
        output_dim: Size of the prediction produced by the linear layer.
    """

    def __init__(self, text_dim, audio_dim, visual_dim, output_dim=1):
        super().__init__()
        # +1 per modality accounts for the constant appended in ``forward``.
        fused_feature_dim = (text_dim + 1) * (audio_dim + 1) * (visual_dim + 1)
        self.fc = nn.Linear(fused_feature_dim, output_dim)

    @staticmethod
    def _append_ones(features, batch_size):
        """Return ``features`` with a trailing column of ones along dim 1."""
        # ``new_ones`` inherits dtype and device from ``features``; a plain
        # ``torch.ones`` would be float32 and silently up-cast half/bfloat16
        # inputs, which then clash with the dtype of ``self.fc``.
        ones = features.new_ones(batch_size, 1)
        return torch.cat([features, ones], dim=1)

    def forward(self, text_features, audio_features, visual_features):
        """Compute the fused prediction for a batch.

        Args:
            text_features: Tensor of shape ``(B, text_dim)``.
            audio_features: Tensor of shape ``(B, audio_dim)``.
            visual_features: Tensor of shape ``(B, visual_dim)``.

        Returns:
            Tensor of shape ``(B, output_dim)``.
        """
        # The batch size always comes from the text tensor, so a modality with
        # a different batch size fails at ``cat`` instead of broadcasting.
        batch_size = text_features.size(0)

        text_with_bias = self._append_ones(text_features, batch_size)
        audio_with_bias = self._append_ones(audio_features, batch_size)
        visual_with_bias = self._append_ones(visual_features, batch_size)

        # Broadcasted product of (B, T+1, 1, 1) * (B, 1, A+1, 1) * (B, 1, 1, V+1)
        # gives the per-sample outer product with shape (B, T+1, A+1, V+1).
        fused_tensor = (
            text_with_bias[:, :, None, None]
            * audio_with_bias[:, None, :, None]
            * visual_with_bias[:, None, None, :]
        )
        # Row-major flatten (text, audio, visual) matches the layout the
        # linear layer's weights were sized for.
        fused_vector = fused_tensor.reshape(batch_size, -1)

        return self.fc(fused_vector)