# 05 - Late Fusion Multimodal Deepfake Detection

## Objective
Combine predictions from individual modality classifiers (late fusion).

## Approach
1. Train a separate classifier for each modality (image, audio, video)
2. Combine predictions using weighted voting or learned fusion
3. Compare with early fusion

## Fusion Strategies
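- Average voting: unweighted mean of the per-modality class probabilities
- Weighted voting: weighted sum of the class probabilities, with learned weights that are softmax-normalized so they sum to 1
- Meta-classifier on predictions: a small MLP over the concatenated class probabilities (note: it outputs logits, not probabilities)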

In [None]:
import torch
import torch.nn as nn


class LateFusionDetector(nn.Module):
    """Fuse the predictions of three per-modality classifiers (late fusion).

    Each wrapped model is called on its own input and its output is turned
    into class probabilities with a softmax over ``dim=1``. The three
    probability tensors are then combined according to ``fusion_type``:

    * ``'average'``  - unweighted mean of the three probability tensors.
    * ``'weighted'`` - weighted sum using three learned weights that are
      softmax-normalized, so they are positive and sum to 1.
    * ``'meta'``     - a small MLP applied to the concatenated
      probabilities. It is built for 3 models * 2 classes = 6 input
      features and returns raw 2-class logits.

    NOTE: ``'average'`` and ``'weighted'`` return probabilities, whereas
    ``'meta'`` returns logits. Pick the loss function accordingly.

    Args:
        image_model: Module called on the ``image`` input.
        audio_model: Module called on the ``audio`` input.
        video_model: Module called on the ``video`` input.
        fusion_type: One of ``'average'``, ``'weighted'`` or ``'meta'``.

    Raises:
        ValueError: If ``fusion_type`` is not a supported strategy.
    """

    FUSION_TYPES = ('average', 'weighted', 'meta')

    def __init__(self, image_model, audio_model, video_model, fusion_type='weighted'):
        super().__init__()
        # Fail at construction time rather than deep inside forward().
        if fusion_type not in self.FUSION_TYPES:
            raise ValueError(
                f"Unknown fusion_type {fusion_type!r}; "
                f"expected one of {self.FUSION_TYPES}"
            )

        self.image_model = image_model
        self.audio_model = audio_model
        self.video_model = video_model
        self.fusion_type = fusion_type

        if fusion_type == 'weighted':
            # Equal starting values -> uniform weights after the softmax
            # applied in forward().
            self.weights = nn.Parameter(torch.ones(3) / 3)
        elif fusion_type == 'meta':
            self.meta_classifier = nn.Sequential(
                nn.Linear(6, 32),  # 3 models * 2 classes
                nn.ReLU(),
                nn.Linear(32, 2)
            )

    def forward(self, image, audio, video):
        """Run all three models and fuse their class probabilities.

        Args:
            image: Input forwarded to ``image_model``.
            audio: Input forwarded to ``audio_model``.
            video: Input forwarded to ``video_model``.

        Returns:
            Fused probabilities for ``'average'`` / ``'weighted'``, or the
            meta-classifier's logits for ``'meta'``.

        Raises:
            ValueError: If ``self.fusion_type`` was changed to an
                unsupported value after construction.
        """
        img_pred = torch.softmax(self.image_model(image), dim=1)
        aud_pred = torch.softmax(self.audio_model(audio), dim=1)
        vid_pred = torch.softmax(self.video_model(video), dim=1)

        if self.fusion_type == 'average':
            return (img_pred + aud_pred + vid_pred) / 3

        if self.fusion_type == 'weighted':
            # Normalize so the fused output remains a probability distribution.
            weights = torch.softmax(self.weights, dim=0)
            return weights[0] * img_pred + weights[1] * aud_pred + weights[2] * vid_pred

        if self.fusion_type == 'meta':
            combined = torch.cat([img_pred, aud_pred, vid_pred], dim=1)
            return self.meta_classifier(combined)

        # Reachable only if the attribute was mutated after __init__.
        raise ValueError(
            f"Unknown fusion_type {self.fusion_type!r}; "
            f"expected one of {self.FUSION_TYPES}"
        )

# Notebook cell confirmation: reaching this line means the class definition above ran.
print("Late Fusion Model Defined!")