# Cross-Modal Lip-Sync Detection: Model Inference Guide

This notebook implements the complete inference pipeline for the cross-modal phoneme–viseme alignment model.

**Key Concept**: Detect lip-sync manipulation (deepfakes) by checking, window by window, whether the speech audio agrees with the visible mouth movements.

In [None]:
# Example usage (uncomment and adapt to your environment):
#
#   # Load model
#   model_path = Path("task3_alignment_model.pt")
#   model = load_model_from_checkpoint(model_path)
#
#   # Analyze a single window
#   video_path = Path("sample_video.mp4")
#   prob = predict_with_model(model, video_path, start_s=0.0, end_s=1.0)
#   print(f"Window [0.0–1.0s] probability: {prob:.4f}")
#
#   # Analyze entire video
#   scores, metadata = predict_video_windows(
#       model,
#       video_path,
#       window_size=1.0,
#       stride=0.5,
#   )
#   print(f"Overall decision: {'FAKE' if metadata['mean_score'] >= 0.5 else 'REAL'}")

# Short quick-start banner shown when this cell runs.
print(
    """
✓ Cross-modal lip-sync inference pipeline ready!

To use in your MVP:
1. Load model: model = load_model_from_checkpoint(checkpoint_path)
2. Single window: prob = predict_with_model(model, video_path, start_s, end_s)
3. Full video: scores, meta = predict_video_windows(model, video_path)

Outputs:
- Probability ≥ 0.5 → FAKE (lip-sync manipulated)
- Probability < 0.5 → REAL (authentic)
"""
)

## Example Usage

Below is a complete example showing how to use the inference pipeline.

In [None]:
def load_model_from_checkpoint(checkpoint_path: Path) -> nn.Module:
    """
    Load a CrossModalPhonemeVisemeAlignmentModel from a saved checkpoint.

    The checkpoint may be either a bare ``state_dict`` or a dict that stores
    the weights under the ``"model_state_dict"`` key.

    Args:
        checkpoint_path: Path to model weights (.pt file)

    Returns:
        Model in eval mode, moved to ``DEVICE``.

    Raises:
        Re-raises any exception from loading the file or applying the weights
        (e.g. a missing file or a state_dict key/shape mismatch) after
        logging it together with its traceback.
    """
    logger.info(f"Loading model from {checkpoint_path}...")

    try:
        # SECURITY NOTE(review): torch.load unpickles the file, and unpickling
        # can execute arbitrary code. Only load checkpoints from trusted
        # sources (or pass weights_only=True on a torch version supporting it).
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

        # Instantiate model
        model = CrossModalPhonemeVisemeAlignmentModel()

        # Training checkpoints wrap the weights; plain saves are the state_dict itself.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        model.eval()  # disable dropout for inference
        model.to(DEVICE)

        logger.info(f"✓ Model loaded successfully on {DEVICE}")
        return model

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error would drop.
        logger.exception(f"Failed to load model: {e}")
        raise


def _probe_duration_seconds(video_path: Path) -> float:
    """
    Return the container duration of ``video_path`` in seconds, read with ffprobe.

    Raises:
        ValueError: if ffprobe is not installed, fails, or does not print a
            parseable duration.
    """
    import subprocess

    # Argument list with no shell: a path containing quotes or other shell
    # metacharacters can neither break the command nor inject into it.
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "csv=p=0",
        str(video_path),
    ]
    try:
        result = subprocess.run(probe_cmd, capture_output=True, text=True)
    except FileNotFoundError as err:
        raise ValueError("ffprobe executable not found on PATH") from err

    raw_duration = result.stdout.strip()
    try:
        return float(raw_duration)
    except ValueError as err:
        raise ValueError(
            f"Could not read duration of {video_path}: ffprobe exited with "
            f"{result.returncode}, stdout={raw_duration!r}, "
            f"stderr={result.stderr.strip()!r}"
        ) from err


def _window_bounds(duration: float, window_size: float, stride: float) -> list:
    """
    Return ``(start_s, end_s)`` pairs for sliding windows over ``[0, duration)``.

    A window starts every ``stride`` seconds. Windows are truncated at
    ``duration``, so the last ones may be shorter than ``window_size``.
    """
    bounds = []
    index = 0
    while True:
        # Multiply instead of accumulating so floating-point error cannot
        # drift the start times or add a spurious window just below duration.
        start_s = index * stride
        if start_s >= duration:
            return bounds
        bounds.append((start_s, min(start_s + window_size, duration)))
        index += 1


def predict_video_windows(
    model: nn.Module,
    video_path: Path,
    window_size: float = 1.0,
    stride: float = 0.5,
) -> Tuple[np.ndarray, dict]:
    """
    Run sliding window inference on entire video.

    Each window is scored with ``predict_with_model``. A window that raises is
    logged, recorded as NaN in ``scores`` and excluded from the statistics.
    The video-level decision is FAKE when the mean valid score is >= 0.5.

    Args:
        model: Loaded model
        video_path: Path to video file
        window_size: Window duration in seconds (default 1.0s), must be > 0
        stride: Stride between windows in seconds (default 0.5s), must be > 0

    Returns:
        - scores: numpy array of probabilities per window (NaN for failed windows)
        - metadata: dict with ``duration``, ``windows`` (list of
          ``(start_s, end_s)``), ``num_valid_windows`` and
          ``mean_score``/``std_score``/``max_score``/``min_score`` over the
          valid windows

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive, the
            duration cannot be read, or no window could be scored.
    """
    if window_size <= 0 or stride <= 0:
        # stride <= 0 would never advance; window_size <= 0 gives empty windows.
        raise ValueError(
            f"window_size and stride must be positive (got {window_size}, {stride})"
        )

    duration = _probe_duration_seconds(video_path)

    logger.info(f"Analyzing video: {video_path.name} (duration: {duration:.1f}s)")

    windows = _window_bounds(duration, window_size, stride)
    scores = []
    for start_s, end_s in windows:
        try:
            scores.append(predict_with_model(model, video_path, start_s, end_s))
        except Exception as e:
            # Best effort: one bad window should not abort the whole video.
            logger.warning(f"Skipping window [{start_s:.2f}–{end_s:.2f}s]: {e}")
            scores.append(np.nan)

    scores_array = np.array(scores, dtype=float)

    # Aggregate statistics over the windows that were scored successfully.
    valid_scores = scores_array[~np.isnan(scores_array)]
    if valid_scores.size == 0:
        raise ValueError(
            f"No window of {video_path.name} could be scored "
            f"({len(windows)} window(s) attempted)"
        )

    metadata = {
        "duration": duration,
        "windows": windows,
        "num_valid_windows": int(valid_scores.size),
        "mean_score": float(np.mean(valid_scores)),
        "std_score": float(np.std(valid_scores)),
        "max_score": float(np.max(valid_scores)),
        "min_score": float(np.min(valid_scores)),
    }

    # Final decision
    final_prob = metadata["mean_score"]
    if final_prob >= 0.5:
        decision = "FAKE (manipulated)"
    else:
        decision = "REAL (authentic)"

    logger.info("✓ Analysis complete.")
    logger.info(f"  Mean probability: {final_prob:.4f}")
    logger.info(f"  Decision: {decision}")

    return scores_array, metadata


print("✓ Full pipeline functions defined")

## Section 7: End-to-End Prediction Pipeline

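Complete workflow for analyzing a full video:
1. Load the model from a checkpoint
2. Split the video into sliding windows (default: 1.0 s windows every 0.5 s)
3. Run inference on each window (a window that fails is recorded as NaN and skipped)
4. Aggregate the scores: if the mean probability over the successfully scored windows is ≥ 0.5 the video is FAKE, otherwise REAL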

In [None]:
def predict_with_model(
    model: nn.Module,
    video_path: Path,
    start_s: float,
    end_s: float,
) -> float:
    """
    Score a single time window of a video for lip-sync manipulation.

    Steps: extract the audio and the mouth frames of ``[start_s, end_s]``,
    check that both modalities cover the same span, preprocess each one,
    run the model without gradient tracking and map the logit through a
    sigmoid.

    Args:
        model: Loaded CrossModalPhonemeVisemeAlignmentModel
        video_path: Path to video file
        start_s: Window start time in seconds
        end_s: Window end time in seconds

    Returns:
        Probability of lip-sync manipulation in [0.0, 1.0]
        - p ≥ 0.5: likely FAKE (manipulated)
        - p < 0.5: likely REAL (authentic)

    Raises:
        Any error from extraction, the alignment check, preprocessing or the
        forward pass is logged and re-raised.
    """
    try:
        logger.info(f"Predicting window [{start_s:.2f}–{end_s:.2f}s] from {video_path.name}")

        logger.debug("  → Extracting audio...")
        raw_audio = extract_audio_segment(video_path, start_s, end_s)

        logger.debug("  → Extracting mouth frames...")
        raw_frames = extract_mouth_frames(video_path, start_s, end_s)

        # Fail fast when the two modalities do not describe the same time span.
        verify_temporal_alignment(raw_audio, raw_frames, start_s, end_s)

        logger.debug("  → Preprocessing audio...")
        model_audio = preprocess_audio(raw_audio)

        logger.debug("  → Preprocessing video...")
        model_frames = preprocess_video(raw_frames)

        logger.debug("  → Running model forward pass...")
        with torch.no_grad():
            score_logits = model(model_audio, model_frames)

        # A sigmoid already lies in [0, 1]; the clip only guards against float noise.
        fake_prob = float(np.clip(torch.sigmoid(score_logits).squeeze().item(), 0.0, 1.0))

        logger.info(f"✓ Inference complete. Probability: {fake_prob:.4f}")
        return fake_prob

    except Exception as e:
        logger.error(f"Model inference failed: {e}")
        raise


print("✓ Inference function defined")

## Section 6: Model Inference Function

The core per-window inference pipeline: extract → verify alignment → preprocess → forward → sigmoid

In [None]:
def verify_temporal_alignment(
    audio_waveform: Tensor,
    video_frames: Tensor,
    start_s: float,
    end_s: float,
    audio_sr: int = 16000,
    video_fps: int = 25,
    tolerance: float = 0.1,
) -> None:
    """
    Verify that the audio and video segments cover the same time window.

    The expected counts are ``int((end_s - start_s) * audio_sr)`` samples and
    ``int((end_s - start_s) * video_fps)`` frames. A count passes when it is
    within ``tolerance`` (relative) of the expectation, or off by at most one.
    The one-unit slack absorbs integer truncation of the window boundaries,
    which would otherwise reject short windows.

    Args:
        audio_waveform: Audio tensor whose last axis is time, e.g. [1, num_samples]
        video_frames: Video tensor whose first axis is time, e.g. [num_frames, 3, H, W]
        start_s: Window start time
        end_s: Window end time
        audio_sr: Audio sample rate (16 kHz)
        video_fps: Video frame rate (25 fps)
        tolerance: Allowed relative deviation (default 0.1 = 10%)

    Raises:
        AssertionError: if temporal alignment is invalid. It is raised
            explicitly rather than via ``assert``, so the check still runs
            under ``python -O``.
    """
    window_duration = end_s - start_s

    # Expected audio samples
    expected_audio_samples = int(window_duration * audio_sr)
    actual_audio_samples = audio_waveform.shape[-1]

    # Expected video frames
    expected_frames = int(window_duration * video_fps)
    actual_frames = video_frames.shape[0]

    logger.info(
        f"Temporal alignment check for window [{start_s:.2f}–{end_s:.2f}s]:"
    )
    logger.info(f"  Audio: {actual_audio_samples} samples (expected ~{expected_audio_samples})")
    logger.info(f"  Video: {actual_frames} frames (expected ~{expected_frames})")

    def _within_tolerance(actual: int, expected: int) -> bool:
        deviation = abs(actual - expected)
        return deviation < tolerance * expected or deviation <= 1

    if not _within_tolerance(actual_audio_samples, expected_audio_samples):
        raise AssertionError(
            f"Audio samples out of expected range: got {actual_audio_samples}, "
            f"expected ~{expected_audio_samples}"
        )
    if not _within_tolerance(actual_frames, expected_frames):
        raise AssertionError(
            f"Video frames out of expected range: got {actual_frames}, "
            f"expected ~{expected_frames}"
        )

    logger.info("✓ Temporal alignment verified")


print("✓ Temporal alignment verification defined")

## Section 5: Temporal Alignment and Synchronization

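Critical: Audio and video segments must be temporally synchronized.
Both modalities use the same time window [start_s, end_s] for alignment. `verify_temporal_alignment` checks that the extracted audio has about (end_s − start_s) × 16000 samples and the video about (end_s − start_s) × 25 frames, and rejects segments that deviate beyond a ~10% tolerance.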

In [None]:
# ImageNet normalization constants
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
MOUTH_ROI_SIZE = (112, 112)


def extract_mouth_roi_simple(frame: np.ndarray) -> Tensor:
    """
    Crop an approximate mouth region from a frame using a fixed center crop.

    The crop is the central 30% x 30% of the frame. No face or landmark
    detection is performed, so this only finds the mouth when it happens to
    sit near the frame center. In production, use MediaPipe or dlib for
    robust face/mouth detection.

    Args:
        frame: Input frame in BGR format [H, W, 3]

    Returns:
        Float tensor [3, 112, 112] in RGB channel order. Pixel values keep the
        input's scale (0-255 for 8-bit frames); no normalization is applied here.
    """
    height, width = frame.shape[:2]

    half_h = int(height * 0.3) // 2
    half_w = int(width * 0.3) // 2
    mid_y, mid_x = height // 2, width // 2

    top, bottom = max(0, mid_y - half_h), min(height, mid_y + half_h)
    left, right = max(0, mid_x - half_w), min(width, mid_x + half_w)

    rgb_crop = cv2.cvtColor(frame[top:bottom, left:right, :], cv2.COLOR_BGR2RGB)
    rgb_crop = cv2.resize(rgb_crop, MOUTH_ROI_SIZE, interpolation=cv2.INTER_LINEAR)

    # OpenCV images are HWC; PyTorch conv layers expect CHW.
    return torch.from_numpy(rgb_crop).float().permute(2, 0, 1)


def _target_frame_indices(
    start_s: float,
    end_s: float,
    native_fps: float,
    target_fps: int,
    start_frame: int,
    end_frame: int,
) -> list:
    """
    Map each output timestamp (one per ``1 / target_fps`` s) to a native frame index.

    Returns a non-decreasing list of indices clamped to ``[start_frame, end_frame - 1]``.
    Indices repeat when ``target_fps`` exceeds ``native_fps`` (frames are
    duplicated) and skip values when it is lower (frames are dropped).
    """
    # Output count: window duration x target rate, at least one frame.
    num_out = max(1, int((end_s - start_s) * target_fps))
    step = native_fps / target_fps
    last = end_frame - 1
    # The small epsilon keeps exact products such as 3.0 from truncating to 2
    # because of float rounding when native_fps == target_fps.
    return [
        min(max(int(start_s * native_fps + k * step + 1e-6), start_frame), last)
        for k in range(num_out)
    ]


def extract_mouth_frames(
    video_path: Path,
    start_s: float,
    end_s: float,
    target_fps: int = 25,
) -> Tensor:
    """
    Extract mouth ROI frames from a video segment, resampled to a fixed FPS.

    Frames are decoded at the file's native rate, then selected (or repeated)
    so that the output holds one frame per ``1 / target_fps`` seconds,
    whatever the source frame rate.

    Args:
        video_path: Path to video file
        start_s: Start time in seconds
        end_s: End time in seconds
        target_fps: Output frame rate (25 fps standard)

    Returns:
        Frame tensor [num_frames, 3, 112, 112], with
        num_frames ~= (end_s - start_s) * target_fps (fewer if the stream ends early)

    Raises:
        RuntimeError: if the video cannot be opened, reports no usable FPS,
            or yields no frames for the window.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            raise RuntimeError(f"Failed to open video: {video_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # Without a frame rate, seconds cannot be mapped to frame indices.
            raise RuntimeError(f"Video reports no valid FPS: {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Native frame range covered by the window.
        start_frame = max(0, int(start_s * fps))
        end_frame = int(end_s * fps)
        if total_frames > 0:
            # Clamp to the container's frame count only when it is known.
            start_frame = min(start_frame, total_frames - 1)
            end_frame = min(end_frame, total_frames)
        end_frame = max(start_frame + 1, end_frame)

        sample_indices = _target_frame_indices(
            start_s, end_s, fps, target_fps, start_frame, end_frame
        )

        frames = []
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        decoded_idx = start_frame - 1  # index of the most recently decoded frame
        roi = None
        for wanted_idx in sample_indices:
            # Decode forward to the wanted frame; only that frame is cropped.
            while decoded_idx < wanted_idx:
                ret, frame = cap.read()
                if not ret:
                    break
                decoded_idx += 1
                if decoded_idx == wanted_idx:
                    roi = extract_mouth_roi_simple(frame)
            if decoded_idx < wanted_idx or roi is None:
                break  # stream ended before the window did
            frames.append(roi)

        if not frames:
            raise RuntimeError(f"No frames extracted from window {start_s}-{end_s}s")

        # Stack into tensor [num_frames, 3, H, W]
        return torch.stack(frames)

    except Exception as e:
        logger.error(f"Video frame extraction failed: {e}")
        raise
    finally:
        # Release the decoder on every path, including errors.
        if cap is not None:
            cap.release()


def preprocess_video(video_frames: Tensor) -> Tensor:
    """
    Scale frames to [0, 1] and normalize them with ImageNet statistics.

    Args:
        video_frames: Frame tensor [num_frames, 3, H, W]. Integer tensors
            (e.g. uint8) are always treated as 0-255. Float tensors are
            treated as 0-255 when their maximum exceeds 1.0, and as already
            in [0, 1] otherwise.

    Returns:
        Normalized float frames on ``DEVICE``.
    """
    if not torch.is_floating_point(video_frames):
        # transforms.Normalize requires float input; integer pixels are 0-255.
        video_frames = video_frames.float() / 255.0
    elif video_frames.numel() > 0 and video_frames.max() > 1.0:
        # Float heuristic: values above 1 can only be on the 0-255 scale.
        # NOTE(review): a very dark 0-255 float clip (max <= 1) is left unscaled.
        video_frames = video_frames / 255.0

    # Apply ImageNet normalization (broadcasts over the leading frame axis).
    normalize = transforms.Normalize(
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD
    )

    video_frames = normalize(video_frames)
    return video_frames.to(DEVICE)


print("✓ Video preprocessing functions defined")

## Section 4: Video Preprocessing Pipeline

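Extract and preprocess mouth ROI frames:
- Load video frames
- Extract the mouth region of interest (ROI), which is currently a fixed center crop (30% of the frame) rather than a face/landmark detector
- Convert BGR → RGB and resize to 112×112 (model input size)
- Normalize using ImageNet mean/std statistics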

In [None]:
def extract_audio_segment(
    video_path: Path,
    start_s: float,
    end_s: float,
    target_sr: int = 16000,
) -> Tensor:
    """
    Extract the audio of ``[start_s, end_s]`` from a video as a normalized tensor.

    Args:
        video_path: Path to video file
        start_s: Start time in seconds
        end_s: End time in seconds
        target_sr: Target sample rate (16 kHz for Wav2Vec2)

    Returns:
        torch.Tensor of shape [1, num_samples] (mono, ``target_sr`` Hz),
        peak-normalized to [-1, 1] unless the segment is silent.

    Raises:
        ValueError: if the window is empty or negative, or contains no audio
            samples (e.g. it starts past the end of the track).
    """
    try:
        if start_s < 0 or end_s <= start_s:
            raise ValueError(f"Invalid audio window [{start_s}, {end_s}]s")

        # Ask librosa for just this window (offset/duration) instead of
        # materializing the entire track on every call. The sliding-window
        # caller invokes this once per window.
        audio_segment, sr = librosa.load(
            str(video_path),
            sr=None,  # keep the native rate; resampling happens explicitly below
            mono=True,
            offset=start_s,
            duration=end_s - start_s,
        )

        if audio_segment.size == 0:
            raise ValueError(
                f"No audio samples in window [{start_s}, {end_s}]s of {video_path}"
            )

        # Resample if needed
        if sr != target_sr:
            audio_segment = librosa.resample(
                audio_segment, orig_sr=sr, target_sr=target_sr
            )

        # Peak-normalize to [-1, 1]; silent segments are left as zeros.
        max_val = np.abs(audio_segment).max()
        if max_val > 0:
            audio_segment = audio_segment / (max_val + 1e-8)

        # Convert to tensor [1, T]
        audio_tensor = torch.from_numpy(audio_segment).float()
        audio_tensor = audio_tensor.unsqueeze(0)

        return audio_tensor

    except Exception as e:
        logger.error(f"Audio extraction failed: {e}")
        raise


def preprocess_audio(audio_waveform: Tensor) -> Tensor:
    """
    Prepare a waveform for the model: add a channel axis, peak-normalize, move to DEVICE.

    Args:
        audio_waveform: Waveform tensor of shape [T] or [1, T].

    Returns:
        The waveform (1-D input gains a leading axis, giving [1, T]) divided
        by its peak absolute value (left as-is when that peak is not
        positive), placed on ``DEVICE``.
    """
    waveform = audio_waveform.unsqueeze(0) if audio_waveform.dim() == 1 else audio_waveform

    # Peak normalization again; input that is already normalized changes only by the 1e-8 epsilon.
    peak = waveform.abs().max()
    if not peak > 0:
        return waveform.to(DEVICE)
    return (waveform / (peak + 1e-8)).to(DEVICE)


print("✓ Audio preprocessing functions defined")

## Section 3: Audio Preprocessing Pipeline

Extract and preprocess audio segments from video files:
- Load from video file
- Resample to 16 kHz (Wav2Vec2 standard)
- Convert to mono
- Peak-normalize each segment to the [-1, 1] range

In [None]:
class CrossModalPhonemeVisemeAlignmentModel(nn.Module):
    """
    Cross-modal model for detecting lip-sync manipulation (deepfakes).

    Architecture:
    - Audio encoder: pretrained Wav2Vec2 (facebook/wav2vec2-base), applied to
      the raw waveform, with its convolutional feature encoder frozen
    - Visual encoder: ImageNet-pretrained ResNet-18 (final FC layer removed)
      applied to every mouth ROI frame
    - Fusion: cross-modal multi-head attention in both directions
      (audio→visual and visual→audio)
    - Head: temporal mean pooling → bottleneck MLP → classifier → logits
      (apply a sigmoid to get the probability of FAKE)

    Args:
        audio_dim: Hidden size of the Wav2Vec2 encoder (768 for wav2vec2-base).
        visual_dim: Nominal visual feature size. NOTE(review): not used by any
            layer; the projection input size is read from the ResNet-18 backbone.
        shared_dim: Size of the shared embedding space used for fusion.
        num_classes: Number of output logits (1 for binary real/fake).
    """

    def __init__(
        self,
        audio_dim: int = 768,  # Wav2Vec2 hidden size
        visual_dim: int = 512,  # ResNet-18 feature dim
        shared_dim: int = 256,  # Fusion embedding size
        num_classes: int = 1,  # Binary: real vs fake
    ):
        super().__init__()

        # Audio encoder (using pretrained Wav2Vec2)
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        # transformers renamed freeze_feature_extractor -> freeze_feature_encoder
        # and deprecated the old name; use the new one, fall back on old releases.
        freeze_cnn = getattr(self.audio_encoder, "freeze_feature_encoder", None)
        if freeze_cnn is None:
            freeze_cnn = self.audio_encoder.freeze_feature_extractor
        freeze_cnn()
        self.audio_proj = nn.Linear(audio_dim, shared_dim)

        # Visual encoder (ResNet-18 backbone without its classification layer)
        from torchvision.models import resnet18
        try:
            # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
            # IMAGENET1K_V1 is the checkpoint that `pretrained=True` loads.
            from torchvision.models import ResNet18_Weights
            resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        except ImportError:
            resnet = resnet18(pretrained=True)
        backbone_dim = resnet.fc.in_features  # 512 for ResNet-18
        self.visual_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.visual_proj = nn.Linear(backbone_dim, shared_dim)

        # Cross-modal attention
        self.attention_audio = nn.MultiheadAttention(
            shared_dim, num_heads=4, batch_first=True
        )
        self.attention_visual = nn.MultiheadAttention(
            shared_dim, num_heads=4, batch_first=True
        )

        # Bottleneck (applied after temporal pooling + concatenation)
        self.bottleneck = nn.Sequential(
            nn.Linear(shared_dim * 2, shared_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(shared_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(
        self,
        audio_waveform: Tensor,  # [1, num_samples]
        video_frames: Tensor,  # [num_frames, 3, H, W]
    ) -> Tensor:
        """
        Forward pass: extract features, fuse, classify.

        Only a single clip is supported: the visual features are given a batch
        axis of size 1 to match the [1, num_samples] audio input.

        Returns:
            logits: [1, num_classes] raw scores (before sigmoid)
        """
        # Audio processing
        audio_features = self.audio_encoder(audio_waveform).last_hidden_state  # [1, T_audio, 768]
        audio_features = self.audio_proj(audio_features)  # [1, T_audio, 256]

        # Visual processing: run all frames through the CNN as one batch.
        # NOTE: equivalent to per-frame passes in eval mode; in train mode
        # BatchNorm statistics are then computed over the whole clip.
        visual_features = self.visual_encoder(video_frames)  # [num_frames, 512, 1, 1]
        visual_features = torch.flatten(visual_features, 1)  # [num_frames, 512]
        visual_features = self.visual_proj(visual_features)  # [num_frames, 256]
        visual_features = visual_features.unsqueeze(0)  # [1, num_frames, 256]

        # Cross-modal attention (query audio, key/value visual)
        audio_attended, _ = self.attention_audio(
            audio_features, visual_features, visual_features
        )  # [1, T_audio, 256]

        # Cross-modal attention (query visual, key/value audio)
        visual_attended, _ = self.attention_visual(
            visual_features, audio_features, audio_features
        )  # [1, num_frames, 256]

        # Temporal pooling (mean over time)
        audio_pooled = audio_attended.mean(dim=1)  # [1, 256]
        visual_pooled = visual_attended.mean(dim=1)  # [1, 256]

        # Fusion bottleneck
        fused = torch.cat([audio_pooled, visual_pooled], dim=-1)  # [1, 512]
        bottleneck_out = self.bottleneck(fused)  # [1, 256]

        # Classification
        logits = self.classifier(bottleneck_out)  # [1, 1]

        return logits


print("✓ Model architecture defined: CrossModalPhonemeVisemeAlignmentModel")

## Section 2: Define Model Architecture

The **CrossModalPhonemeVisemeAlignmentModel** fuses audio and visual representations:

1. **Audio Branch**: Wav2Vec2 encoder → audio embeddings (T_audio, D_audio)
2. **Visual Branch**: ResNet-18 CNN → visual embeddings (T_video, D_visual)  
3. **Fusion**: Linear projections to shared embedding space
4. **Cross-Modal Attention**: Align audio-visual features
5. **Bottleneck**: Temporal mean pooling of both attended streams, concatenation, then a shared MLP layer
6. **Classifier**: MLP head → a single raw logit; `sigmoid(logit)` is the probability that the clip is FAKE

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
import cv2
import librosa
import soundfile as sf
from pathlib import Path
from typing import Tuple, Optional
from transformers import Wav2Vec2Model, Wav2Vec2Processor
from torchvision import transforms
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✓ Using device: {DEVICE}")

## Section 1: Import Required Libraries