In [1]:
import torch
import torch.nn as nn

class AverageFusion(nn.Module):
    """Late fusion that averages the three unimodal predictions equally.

    Each unimodal sub-model is expected to return a ``(B, 1)`` prediction;
    the fused output has the same shape. Every sub-model receives gradient
    through the mean.

    Args:
        t_dim: Feature size of the text input.
        a_dim: Feature size of the audio input.
        v_dim: Feature size of the video input.
    """

    def __init__(self, t_dim, a_dim, v_dim):
        super().__init__()
        # Registration order (text, audio, video) fixes parameter order.
        self.text_model = TextModel(t_dim)
        self.audio_model = AudioModel(a_dim)
        self.video_model = VideoModel(v_dim)

    def forward(self, t, a, v):
        """Return the unweighted mean of the text, audio and video predictions."""
        predictions = (
            self.text_model(t),
            self.audio_model(a),
            self.video_model(v),
        )
        return sum(predictions) / len(predictions)


class MaxFusion(nn.Module):
    """Late fusion that keeps, element-wise, the largest unimodal prediction.

    The three predictions are stacked along a new leading axis and reduced
    with ``max``; gradients reach only the modality selected at each position.

    Args:
        t_dim: Feature size of the text input.
        a_dim: Feature size of the audio input.
        v_dim: Feature size of the video input.
    """

    def __init__(self, t_dim, a_dim, v_dim):
        super().__init__()
        self.text_model = TextModel(t_dim)
        self.audio_model = AudioModel(a_dim)
        self.video_model = VideoModel(v_dim)

    def forward(self, t, a, v):
        """Return the element-wise maximum over the three modality predictions."""
        stacked = torch.stack(
            (self.text_model(t), self.audio_model(a), self.video_model(v)),
            dim=0,
        )  # (3, *prediction_shape)
        largest, _ = stacked.max(dim=0)
        return largest


class MajorityVoteFusion(nn.Module):
    """Late fusion by majority vote over per-modality class predictions.

    Each modality has its own linear classifier producing ``num_classes``
    logits and casts one vote (its argmax class); the class with the most
    votes wins. When all three modalities disagree, the tie is broken by the
    largest summed softmax probability across the modalities, instead of
    relying on ``torch.mode``'s unspecified tie-breaking (which on CPU
    favours the lowest class index).

    Args:
        t_dim: Feature size of the text input.
        a_dim: Feature size of the audio input.
        v_dim: Feature size of the video input.
        num_classes: Number of output classes.
    """

    def __init__(self, t_dim, a_dim, v_dim, num_classes):
        super().__init__()
        self.text_model = nn.Linear(t_dim, num_classes)
        self.audio_model = nn.Linear(a_dim, num_classes)
        self.video_model = nn.Linear(v_dim, num_classes)

    def forward(self, t, a, v):
        """Return the fused class index for each sample.

        Args:
            t, a, v: Text, audio and video inputs, expected to be 2-D
                ``(B, feature_dim)`` tensors.

        Returns:
            Int64 tensor of shape ``(B,)`` with predicted class indices. The
            result is produced by ``argmax`` and carries no gradient, so the
            module cannot be trained end-to-end through this output.
        """
        logits = [
            self.text_model(t),
            self.audio_model(a),
            self.video_model(v),
        ]
        num_classes = logits[0].size(1)

        # One vote per modality: (B, 3) indices -> (B, C) per-class counts.
        votes = torch.stack([x.argmax(dim=1) for x in logits], dim=1)
        counts = nn.functional.one_hot(votes, num_classes).sum(dim=1)

        # Tie-breaker: the summed softmax probabilities lie in [0, 3], so
        # scaling by 1/4 keeps them below 1. They can never outweigh a whole
        # vote and only decide between classes with equal vote counts.
        confidence = sum(x.softmax(dim=1) for x in logits) / 4
        return (counts.to(confidence.dtype) + confidence).argmax(dim=1)

class WeightedFusion(nn.Module):
    """Late fusion by a fixed weighted sum of the three unimodal predictions.

    The weights are normalized to sum to 1 and stored as a non-trainable
    buffer ``w``. As a buffer it moves with the module across devices and is
    saved in its ``state_dict``.

    Args:
        t_dim: Feature size of the text input.
        a_dim: Feature size of the audio input.
        v_dim: Feature size of the video input.
        weights: Three (not necessarily normalized) weights for text, audio
            and video, in that order. Their sum must be positive and finite.

    Raises:
        ValueError: If ``weights`` is not exactly three values, or if their
            sum is not a positive finite number. Such a sum would otherwise
            make the normalization produce NaN/inf.
    """

    def __init__(self, t_dim, a_dim, v_dim, weights=(0.5, 0.3, 0.2)):
        super().__init__()
        # Validate first so a bad config fails before sub-models are built.
        w = self._normalize_weights(weights)
        self.text_model = TextModel(t_dim)
        self.audio_model = AudioModel(a_dim)
        self.video_model = VideoModel(v_dim)
        self.register_buffer("w", w)

    @staticmethod
    def _normalize_weights(weights):
        """Return ``weights`` as a float32 tensor of shape (3,) summing to 1."""
        w = torch.tensor(weights, dtype=torch.float32)
        if w.shape != (3,):
            raise ValueError(
                "expected exactly 3 weights (text, audio, video), "
                f"got shape {tuple(w.shape)}"
            )
        total = float(w.sum())
        # Written as a chained comparison so NaN also fails the check.
        if not (0.0 < total < float("inf")):
            raise ValueError(f"weights must have a positive finite sum, got {total}")
        return w / total

    def forward(self, t, a, v):
        """Return ``w[0] * text + w[1] * audio + w[2] * video`` predictions."""
        pred_t = self.text_model(t)
        pred_a = self.audio_model(a)
        pred_v = self.video_model(v)

        fused_pred = (
            self.w[0] * pred_t +
            self.w[1] * pred_a +
            self.w[2] * pred_v
        )
        return fused_pred


class ConcatFusion(nn.Module):
    """Learned late fusion: concatenate unimodal predictions, then apply an MLP.

    Each unimodal sub-model is expected to emit a ``(B, 1)`` score, so the
    concatenation is ``(B, 3)``. The MLP ``3 -> 16 -> ReLU -> 1`` maps it to
    a single ``(B, 1)`` fused output.

    Args:
        t_dim: Feature size of the text input.
        a_dim: Feature size of the audio input.
        v_dim: Feature size of the video input.
    """

    def __init__(self, t_dim, a_dim, v_dim):
        super().__init__()
        self.text_model = TextModel(t_dim)
        self.audio_model = AudioModel(a_dim)
        self.video_model = VideoModel(v_dim)

        # One scalar per modality feeds the fusion head.
        n_modalities, hidden = 3, 16
        self.classifier = nn.Sequential(
            nn.Linear(n_modalities, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, t, a, v):
        """Concatenate text/audio/video scores along dim 1 and classify them."""
        scores = [
            self.text_model(t),
            self.audio_model(a),
            self.video_model(v),
        ]
        return self.classifier(torch.cat(scores, dim=1))


# Smoke test: random batch of 32 samples with feature sizes 300 (text),
# 74 (audio) and 35 (video), matching the (t_dim, a_dim, v_dim) order above.
t = torch.randn(32, 300)
a = torch.randn(32, 74)
v = torch.randn(32, 35)

# NOTE(review): DecisionFusion is not defined in this cell. It is presumably
# defined elsewhere, or is a leftover name for one of the classes above
# (AverageFusion / MaxFusion / WeightedFusion / ConcatFusion all take these
# three arguments) — TODO confirm. MajorityVoteFusion would not fit this
# call: it also requires num_classes and returns a (B,) tensor of indices.
model = DecisionFusion(300, 74, 35)
out = model(t, a, v)
print(out.shape)  # torch.Size([32, 1]) when each unimodal model returns (B, 1)

torch.Size([32, 1])
