# In [None]:
import torch
import torch.nn as nn

class EarlyFusionModel(nn.Module):
    """Early-fusion network: concatenate modality features, then apply an MLP.

    The three inputs (named ``t``, ``a`` and ``v`` -- presumably text, audio
    and video features) are joined along their last dimension and passed
    through a two-layer perceptron with a ReLU in between.

    Args:
        t_dim: Size of the last dimension of ``t``.
        a_dim: Size of the last dimension of ``a``.
        v_dim: Size of the last dimension of ``v``.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output units. The default of 1 suits scalar
            regression or a single binary-classification logit; use a larger
            value for multi-target regression or multi-class logits.
    """

    def __init__(
        self,
        t_dim: int,
        a_dim: int,
        v_dim: int,
        hidden_dim: int = 128,
        output_dim: int = 1,
    ) -> None:
        super().__init__()
        input_dim = t_dim + a_dim + v_dim

        # The attribute name `fc` and the layer order are kept as-is so that
        # state_dict keys stay `fc.0.*` / `fc.2.*` and old checkpoints load.
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            # No final activation: the model returns raw scores.
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, t: torch.Tensor, a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Fuse the three inputs and return the MLP output.

        Args:
            t: Tensor whose last dimension is ``t_dim``.
            a: Tensor whose last dimension is ``a_dim``.
            v: Tensor whose last dimension is ``v_dim``.

        Returns:
            Tensor with the inputs' leading dimensions and a last dimension
            of ``output_dim``.
        """
        # torch.cat requires all dimensions except the last to match.
        fused = torch.cat([t, a, v], dim=-1)
        return self.fc(fused)

# Demo: random features for a batch of 32 samples, one tensor per modality
# (text = 300-d, audio = 74-d, video = 35-d), created in that order.
t, a, v = (torch.randn(32, width) for width in (300, 74, 35))

# Size the model from the tensors themselves so the dimensions cannot drift.
model = EarlyFusionModel(t.shape[-1], a.shape[-1], v.shape[-1])
output = model(t, a, v)
print(output.shape)  # torch.Size([32, 1])
