# In [1]:
import torch
import torch.nn as nn

# ---- Individual unimodal models ----
class TextModel(nn.Module):
    """Two-layer MLP that maps a text feature vector to one scalar prediction.

    Args:
        t_dim: Size of the last dimension of the text feature input.
        hidden_dim: Width of the hidden layer. Defaults to 64, the value
            that was previously hard-coded.
    """

    def __init__(self, t_dim: int, hidden_dim: int = 64) -> None:
        super().__init__()
        # Kept as a Sequential named ``fc`` so parameter names
        # (``fc.0.*`` and ``fc.2.*``) stay unchanged.
        self.fc = nn.Sequential(
            nn.Linear(t_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the prediction for ``x``.

        Args:
            x: Tensor whose last dimension has size ``t_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        return self.fc(x)


class AudioModel(nn.Module):
    """Two-layer MLP that maps an audio feature vector to one scalar prediction.

    Args:
        a_dim: Size of the last dimension of the audio feature input.
        hidden_dim: Width of the hidden layer. Defaults to 64, the value
            that was previously hard-coded.
    """

    def __init__(self, a_dim: int, hidden_dim: int = 64) -> None:
        super().__init__()
        # Kept as a Sequential named ``fc`` so parameter names
        # (``fc.0.*`` and ``fc.2.*``) stay unchanged.
        self.fc = nn.Sequential(
            nn.Linear(a_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the prediction for ``x``.

        Args:
            x: Tensor whose last dimension has size ``a_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        return self.fc(x)


class VideoModel(nn.Module):
    """Two-layer MLP that maps a video feature vector to one scalar prediction.

    Args:
        v_dim: Size of the last dimension of the video feature input.
        hidden_dim: Width of the hidden layer. Defaults to 64, the value
            that was previously hard-coded.
    """

    def __init__(self, v_dim: int, hidden_dim: int = 64) -> None:
        super().__init__()
        # Kept as a Sequential named ``fc`` so parameter names
        # (``fc.0.*`` and ``fc.2.*``) stay unchanged.
        self.fc = nn.Sequential(
            nn.Linear(v_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the prediction for ``x``.

        Args:
            x: Tensor whose last dimension has size ``v_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        return self.fc(x)


# ---- Decision-Level Fusion ----
class DecisionFusion(nn.Module):
    """Late (decision-level) fusion of text, audio and video predictors.

    Each modality has its own model; their predictions are combined by an
    unweighted arithmetic mean.

    Args:
        t_dim: Feature size passed to the text model.
        a_dim: Feature size passed to the audio model.
        v_dim: Feature size passed to the video model.
    """

    def __init__(self, t_dim, a_dim, v_dim):
        super().__init__()
        # Registration order (text, audio, video) fixes both parameter
        # ordering and the order in which weights are initialised.
        self.text_model = TextModel(t_dim)
        self.audio_model = AudioModel(a_dim)
        self.video_model = VideoModel(v_dim)

    def forward(self, t, a, v):
        """Average the three per-modality predictions.

        Args:
            t: Text features for the text model.
            a: Audio features for the audio model.
            v: Video features for the video model.

        Returns:
            The element-wise mean of the three model outputs.
        """
        text_score = self.text_model(t)
        audio_score = self.audio_model(a)
        video_score = self.video_model(v)

        # Accumulate left to right, then divide once by the branch count.
        total = text_score + audio_score
        total = total + video_score
        return total / 3


# Example input: a batch of 32 random feature vectors per modality
# (text: 300 features, audio: 74, video: 35), drawn in that order.
t, a, v = (torch.randn(32, width) for width in (300, 74, 35))

# Size each branch from the feature width of its own input tensor.
model = DecisionFusion(t.shape[1], a.shape[1], v.shape[1])
out = model(t, a, v)
print(out.shape)  # torch.Size([32, 1])

# Output: torch.Size([32, 1])
