# 07 - Contrastive Multimodal Deepfake Detection

## Objective
Use CLIP-style contrastive learning to align multimodal representations.

## Key Idea
Real videos have consistent representations across modalities (image, audio, video).
Fake videos show inconsistencies between modalities, e.g. audio features that do not match the visual features.

## Training Strategy
1. Contrastive loss to align the per-modality embeddings of real samples
2. Classification loss for real-vs-fake detection
3. Joint optimization of both losses

In [None]:
class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE (CLIP-style) loss between two batches of embeddings.

    Row ``i`` of ``features1`` and row ``i`` of ``features2`` form the positive
    pair; every other row of the opposite batch acts as an in-batch negative.
    Both inputs are L2-normalised along dim 1, so the logits are cosine
    similarities divided by ``temperature``.

    Args:
        temperature: Positive divisor applied to the similarity logits.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        if temperature <= 0:
            # Zero divides by zero; a negative value flips which pairs are rewarded.
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature

    def forward(self, features1, features2):
        """Return the mean of the row-wise and column-wise cross-entropy losses.

        Args:
            features1: 2-D tensor of shape (batch, dim).
            features2: 2-D tensor with the same shape as ``features1``.

        Returns:
            Scalar loss tensor. An empty batch yields a zero that stays
            connected to the inputs' autograd graph instead of NaN.

        Raises:
            ValueError: If the inputs are not 2-D tensors of identical shape.
        """
        if features1.dim() != 2 or features1.shape != features2.shape:
            raise ValueError(
                "features1 and features2 must be 2-D tensors of the same shape, "
                f"got {tuple(features1.shape)} and {tuple(features2.shape)}"
            )
        if features1.size(0) == 0:
            # cross_entropy's mean over zero rows is NaN; return a zero that
            # keeps both inputs in the graph so backward() still works.
            return (features1.sum() + features2.sum()) * 0.0

        features1 = nn.functional.normalize(features1, dim=1)
        features2 = nn.functional.normalize(features2, dim=1)

        # (batch, batch) matrix of scaled cosine similarities.
        similarity = torch.matmul(features1, features2.T) / self.temperature

        # The positive for row i is column i (the diagonal).
        labels = torch.arange(features1.size(0), device=features1.device)

        # Symmetric: features1 -> features2 and features2 -> features1.
        loss_1 = nn.functional.cross_entropy(similarity, labels)
        loss_2 = nn.functional.cross_entropy(similarity.T, labels)

        return (loss_1 + loss_2) / 2

class ContrastiveMultimodalDetector(nn.Module):
    """Multimodal deepfake detector with CLIP-style cross-modal alignment.

    Each modality's feature vector is projected by its own two-layer MLP into
    a shared ``feature_dim`` space. The three projections are concatenated and
    fed to a classifier head. Optionally, a symmetric contrastive loss is
    averaged over the three modality pairs.

    Args:
        feature_dim: Size of the shared embedding space.
        image_dim: Input size of the image features (default 512).
        audio_dim: Input size of the audio features (default 768).
        video_dim: Input size of the video features (default 512).
        hidden_dim: Hidden width of the classifier head.
        num_classes: Number of classifier outputs.
        dropout: Dropout probability in the classifier head.
        temperature: Temperature passed to ``ContrastiveLoss``.
    """

    def __init__(self, feature_dim=512, *, image_dim=512, audio_dim=768,
                 video_dim=512, hidden_dim=256, num_classes=2, dropout=0.3,
                 temperature=0.07):
        super().__init__()
        # Per-modality encoders (created in the same order as before so that
        # seeded parameter initialisation and state_dict keys are unchanged).
        self.image_encoder = self._make_encoder(image_dim, feature_dim)
        self.audio_encoder = self._make_encoder(audio_dim, feature_dim)
        self.video_encoder = self._make_encoder(video_dim, feature_dim)

        # Classifier on the concatenation of the three embeddings.
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

        self.contrastive_loss = ContrastiveLoss(temperature=temperature)

    @staticmethod
    def _make_encoder(in_dim, out_dim):
        """Build the Linear-ReLU-Linear projection used for every modality."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, img_feat, aud_feat, vid_feat, compute_contrastive=False,
                *, real_mask=None):
        """Encode all modalities, classify, and optionally compute alignment loss.

        Args:
            img_feat: Image features; assumed (batch, image_dim) — TODO confirm.
            aud_feat: Audio features; assumed (batch, audio_dim) — TODO confirm.
            vid_feat: Video features; assumed (batch, video_dim) — TODO confirm.
            compute_contrastive: If True, also return the contrastive loss
                averaged over the image/audio, image/video and audio/video pairs.
            real_mask: Optional 1-D boolean tensor of length batch selecting
                the samples that take part in the contrastive loss (intended:
                the real samples, so fakes are not pulled into alignment).
                When None, every sample is used.

        Returns:
            ``(logits, contrastive)``. ``logits`` has last dimension
            ``num_classes``. ``contrastive`` is a scalar tensor, or None when
            ``compute_contrastive`` is False.

        Raises:
            TypeError: If ``real_mask`` is not a boolean tensor.
            ValueError: If ``real_mask`` is not 1-D with one entry per sample.
        """
        img_enc = self.image_encoder(img_feat)
        aud_enc = self.audio_encoder(aud_feat)
        vid_enc = self.video_encoder(vid_feat)

        contrastive = None
        if compute_contrastive:
            contrastive = self._pairwise_contrastive(
                img_enc, aud_enc, vid_enc, real_mask
            )

        # Classification
        combined = torch.cat([img_enc, aud_enc, vid_enc], dim=1)
        logits = self.classifier(combined)

        return logits, contrastive

    def _pairwise_contrastive(self, img_enc, aud_enc, vid_enc, real_mask):
        """Average the contrastive loss over the three modality pairs."""
        if real_mask is not None:
            if not torch.is_tensor(real_mask) or real_mask.dtype != torch.bool:
                raise TypeError("real_mask must be a boolean tensor")
            if real_mask.dim() != 1 or real_mask.size(0) != img_enc.size(0):
                raise ValueError(
                    "real_mask must be 1-D with one entry per sample, got shape "
                    f"{tuple(real_mask.shape)} for batch size {img_enc.size(0)}"
                )
            mask = real_mask.to(img_enc.device)
            img_enc, aud_enc, vid_enc = img_enc[mask], aud_enc[mask], vid_enc[mask]

        if img_enc.size(0) == 0:
            # No selected samples: cross-entropy would average over zero rows
            # and return NaN. Return a zero that stays in the autograd graph.
            return (img_enc.sum() + aud_enc.sum() + vid_enc.sum()) * 0.0

        loss_img_aud = self.contrastive_loss(img_enc, aud_enc)
        loss_img_vid = self.contrastive_loss(img_enc, vid_enc)
        loss_aud_vid = self.contrastive_loss(aud_enc, vid_enc)
        return (loss_img_aud + loss_img_vid + loss_aud_vid) / 3

# Notebook confirmation that the definitions in this cell executed.
print('Contrastive Multimodal Model Defined!', end='\n')