In [1]:
!pip install -U -q sentence-transformers

In [None]:
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2Model

# Load the pretrained wav2vec 2.0 base checkpoint and its matching processor.
# Both objects are passed explicitly to extract_audio_features below.
# NOTE: the names `model_name`, `model` and `processor` are rebound by later
# cells (ViT, then SentenceTransformer); re-run this cell before re-running the
# audio example if those cells have been executed since.
model_name="facebook/wav2vec2-base"
model = Wav2Vec2Model.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)
def extract_audio_features(audio_path, processor, model,
                           amplitude_threshold=0.01, silence_threshold=90):
    """Encode an audio file into a sequence of wav2vec 2.0 hidden states.

    Args:
        audio_path: Path of the audio file, read with ``torchaudio.load``.
        processor: Object exposing ``feature_extractor.sampling_rate`` and
            callable on a 1-D waveform (a ``Wav2Vec2Processor`` in this notebook).
        model: Audio encoder exposing ``config.hidden_size`` and returning an
            output with ``last_hidden_state`` (a ``Wav2Vec2Model`` in this notebook).
        amplitude_threshold: Samples whose absolute amplitude is below this
            value are counted as silent.
        silence_threshold: Percentage (0-100). When more than this share of the
            samples is silent the clip is treated as silent.

    Returns:
        A tensor of shape ``[N, hidden_size]`` where N depends on the audio
        length. Silent or zero-length clips yield an empty ``[0, hidden_size]``
        tensor instead of being sent through the encoder.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Down-mix multi-channel audio to one channel, keeping the channel axis.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Bring the audio to the sampling rate the processor was configured for.
    target_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
        waveform = resampler(waveform)

    # A zero-length clip would make the mean below NaN, and NaN compares False
    # against the threshold, so it has to be rejected explicitly.
    if waveform.numel() == 0:
        return torch.zeros((0, model.config.hidden_size))

    # Share of samples (in percent) whose amplitude is below the threshold.
    silent_samples = torch.abs(waveform) < amplitude_threshold
    silent_percentage = silent_samples.float().mean().item() * 100
    if silent_percentage > silence_threshold:
        return torch.zeros((0, model.config.hidden_size))

    # squeeze(0) removes only the channel axis; a bare squeeze() would also
    # collapse a one-sample clip into a 0-d array.
    input_values = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=target_rate,
        return_tensors="pt"
    ).input_values

    with torch.no_grad():
        features = model(input_values).last_hidden_state

    # Drop the batch axis: [1, N, hidden] -> [N, hidden].
    return features.squeeze(0)

In [None]:
# Smoke test on a single clip. Relies on `processor` and `model` still being
# the wav2vec objects created in the cell above.
extract_audio_features("/kaggle/input/mmhate/HateMM/audio/non_hate_video_73.wav", processor, model)

In [None]:
import torch
import h5py
import numpy as np
from transformers import ViTModel, ViTImageProcessor

# Load the pretrained ViT checkpoint and its image processor.
# NOTE: this rebinds `model_name`, `model` and `processor`, replacing the
# wav2vec objects defined in the audio cell.
model_name = "google/vit-base-patch16-224"
model = ViTModel.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)

def extract_image_features(h5_path, processor, model, batch_size=32):
    """Encode every frame stored in an HDF5 file into one pooled vector.

    Args:
        h5_path: Path of an HDF5 file containing a ``'frames'`` dataset.
        processor: Image processor called as ``processor(images=..., return_tensors="pt")``
            and returning an object with ``pixel_values``
            (a ``ViTImageProcessor`` in this notebook).
        model: Vision encoder exposing ``config.hidden_size`` and returning an
            output with ``pooler_output`` (a ``ViTModel`` in this notebook).
        batch_size: Number of frames pushed through the encoder at once.

    Returns:
        A tensor of shape ``[N, hidden_size]`` with one row per frame; an empty
        ``[0, hidden_size]`` tensor when the file holds no frames.
    """
    with h5py.File(h5_path, 'r') as f:
        frames = f['frames'][:]  # materialise the dataset before the file closes

    # A 3-D array is taken to be a stack of single-channel frames
    # (assumed layout [N, H, W] - TODO confirm); replicate it into 3 channels.
    if frames.ndim == 3:
        frames = np.stack([frames] * 3, axis=-1)

    # torch.cat() raises on an empty list, so handle "no frames" up front.
    if len(frames) == 0:
        return torch.zeros((0, model.config.hidden_size))

    all_features = []
    with torch.no_grad():
        for start in range(0, len(frames), batch_size):
            batch_frames = frames[start:start + batch_size]

            # `padding` is not an image-processor argument, so it is not passed.
            inputs = processor(images=batch_frames, return_tensors="pt")

            outputs = model(inputs.pixel_values)
            all_features.append(outputs.pooler_output)

    return torch.cat(all_features, dim=0)

In [None]:
# Smoke test on one frame file; only the shape of the feature matrix is shown.
# Relies on `processor` and `model` being the ViT objects from the cell above.
extract_image_features("/kaggle/input/mmhate/HateMM/image_sequences/hate_video_122.h5",processor, model).shape

In [None]:
import torch

import numpy as np
from sentence_transformers import SentenceTransformer

# Load the sentence-embedding model used for transcripts.
# NOTE: this rebinds `model_name` and `model`, replacing the ViT model from the
# image cell; `processor` still refers to the ViT image processor.
model_name = "sentence-transformers/all-mpnet-base-v2"
model = SentenceTransformer(model_name)

def extract_text_features(txt_path, model):
    """Embed each sentence of a transcript file.

    Args:
        txt_path: Path of a UTF-8 text file holding the transcript.
        model: Sentence encoder providing ``encode(sentences)`` and
            ``get_sentence_embedding_dimension()`` (a ``SentenceTransformer``
            in this notebook).

    Returns:
        Whatever ``model.encode`` returns for the sentence list (a NumPy array
        of shape ``[N, dim]`` with default ``encode`` settings), or an empty
        float32 NumPy array of shape ``[0, dim]`` when the transcript contains
        no sentences (e.g. silent or music-only audio).
    """
    with open(txt_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Sentences are split on '. ' only; blank fragments are discarded.
    sentences = [s.strip() for s in text.split('. ') if s.strip()]

    if not sentences:
        # SentenceTransformer has no `.config`; the embedding width comes from
        # its own accessor. Return the same array type as the non-empty path.
        dim = model.get_sentence_embedding_dimension()
        return np.zeros((0, dim), dtype=np.float32)

    return model.encode(sentences)

In [None]:
# Smoke test on one transcript; only the shape of the embedding matrix is shown.
# Relies on `model` being the SentenceTransformer from the cell above.
extract_text_features("/kaggle/input/mmhate/HateMM/transcripts/hate_video_95.txt", model).shape

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    Wav2Vec2Processor, 
    Wav2Vec2Model,
    ViTModel, 
    ViTImageProcessor
)
from sentence_transformers import SentenceTransformer
import torchaudio
import h5py
import numpy as np

class MultimodalHateClassifier(nn.Module):
    """Binary hate / non-hate classifier over audio, video frames and transcript.

    Each modality is turned into a feature sequence by a frozen pretrained
    extractor (wav2vec 2.0, ViT, SentenceTransformer). Every sequence is passed
    through its own transformer encoder, pooled to a single vector, and the
    three vectors are concatenated and fed to an MLP that outputs one
    probability.

    Args:
        hidden_size: Feature width of the modality encoders. It must equal the
            width produced by all three extractors (768 for the checkpoints
            loaded in ``_init_extractors``).
        fusion_hidden_size: Feed-forward width of the modality encoders and
            width of the first fusion layer.
        num_heads: Number of attention heads.
        dropout: Dropout probability used in the encoders and the fusion MLP.
    """

    # Samples with an absolute amplitude below this value count as silent.
    SILENCE_AMPLITUDE = 0.01
    # A clip is treated as silent when more than this percentage of its samples is silent.
    SILENCE_PERCENTAGE = 90
    # Attribute names of the frozen feature extractors.
    _EXTRACTOR_ATTRS = ("_wav2vec_model", "_vit_model", "_text_model")

    def __init__(
        self,
        hidden_size=768,  # must match the extractors' feature width
        fusion_hidden_size=512,
        num_heads=8,
        dropout=0.1
    ):
        super().__init__()

        # Initialize feature extractors
        self._init_extractors()

        # Modality-specific encoders to process variable-length sequences
        self.audio_encoder = self._build_encoder(hidden_size, num_heads, fusion_hidden_size, dropout)
        self.vision_encoder = self._build_encoder(hidden_size, num_heads, fusion_hidden_size, dropout)
        self.text_encoder = self._build_encoder(hidden_size, num_heads, fusion_hidden_size, dropout)

        # Cross-modal attention layer. NOTE: forward() does not use it; it is
        # kept so the module's parameter names (state_dict keys) stay unchanged.
        self.cross_modal_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Final classification layers: three pooled vectors -> one logit
        self.fusion_layer = nn.Sequential(
            nn.Linear(hidden_size * 3, fusion_hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden_size, fusion_hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_hidden_size // 2, 1)
        )

    @staticmethod
    def _build_encoder(hidden_size, num_heads, feedforward_size, dropout):
        """Create one two-layer, batch-first transformer encoder."""
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=feedforward_size,
            dropout=dropout,
            batch_first=True
        )
        return nn.TransformerEncoder(layer, num_layers=2)

    def _init_extractors(self):
        """Load the three pretrained feature extractors, freeze them and put them in eval mode."""
        # Audio extractor
        self._wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self._wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

        # Vision extractor
        self._vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self._vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

        # Text extractor
        self._text_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

        # Freeze extractors: no gradients, no train-time stochastic layers
        for name in self._EXTRACTOR_ATTRS:
            extractor = getattr(self, name)
            for param in extractor.parameters():
                param.requires_grad = False
            extractor.eval()

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen extractors in eval mode.

        The extractors are registered submodules, so the default
        ``nn.Module.train`` would also put them in training mode and enable
        their dropout, making the extracted features non-deterministic.
        """
        super().train(mode)
        for name in self._EXTRACTOR_ATTRS:
            extractor = getattr(self, name, None)
            if extractor is not None:
                extractor.eval()
        return self

    def _extract_audio_features(self, audio_path):
        """Return wav2vec 2.0 hidden states ``[N, hidden]`` for an audio file.

        Silent or zero-length clips yield an empty ``[0, hidden]`` tensor.
        """
        # Load and preprocess audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Down-mix multi-channel audio to one channel, keeping the channel axis
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the rate the processor was configured for
        target_rate = self._wav2vec_processor.feature_extractor.sampling_rate
        if sample_rate != target_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, target_rate)
            waveform = resampler(waveform)

        hidden = self._wav2vec_model.config.hidden_size

        # A zero-length clip would give a NaN silence ratio, which compares
        # False against the threshold, so reject it explicitly.
        if waveform.numel() == 0:
            return torch.zeros((0, hidden))

        # Check for silence
        silent_samples = torch.abs(waveform) < self.SILENCE_AMPLITUDE
        silent_percentage = silent_samples.float().mean().item() * 100
        if silent_percentage > self.SILENCE_PERCENTAGE:
            return torch.zeros((0, hidden))

        # squeeze(0) removes only the channel axis, so a one-sample clip stays 1-D
        input_values = self._wav2vec_processor(
            waveform.squeeze(0).numpy(),
            sampling_rate=target_rate,
            return_tensors="pt"
        ).input_values

        # Run the extractor on whatever device it currently lives on
        with torch.no_grad():
            outputs = self._wav2vec_model(input_values.to(self._wav2vec_model.device))

        # Drop the batch axis: [1, N, hidden] -> [N, hidden]
        return outputs.last_hidden_state.squeeze(0)

    def _extract_image_features(self, h5_path, batch_size=32):
        """Return one ViT CLS vector per frame, ``[N, hidden]``, for an HDF5 frame file.

        A file without frames yields an empty ``[0, hidden]`` tensor.
        """
        # Load images from h5py file
        with h5py.File(h5_path, 'r') as f:
            frames = f['frames'][:]

        # A 3-D array is taken to be a stack of single-channel frames
        # (assumed layout [N, H, W] - TODO confirm); replicate it into 3 channels
        if frames.ndim == 3:
            frames = np.stack([frames] * 3, axis=-1)

        # torch.cat() raises on an empty list, so handle "no frames" up front
        if len(frames) == 0:
            return torch.zeros((0, self._vit_model.config.hidden_size))

        # Process frames in batches
        all_features = []
        with torch.no_grad():
            for start in range(0, len(frames), batch_size):
                batch_frames = frames[start:start + batch_size]

                # `padding` is not an image-processor argument, so it is not passed
                inputs = self._vit_processor(images=batch_frames, return_tensors="pt")

                outputs = self._vit_model(inputs.pixel_values.to(self._vit_model.device))
                all_features.append(outputs.last_hidden_state[:, 0, :])  # Use CLS token

        # Concatenate all features
        return torch.cat(all_features, dim=0)

    def _extract_text_features(self, txt_path):
        """Return one sentence embedding per transcript sentence, ``[N, dim]``.

        A transcript without sentences yields an empty ``[0, dim]`` tensor.
        """
        # Load and preprocess text
        with open(txt_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Sentences are split on '. ' only; blank fragments are discarded
        sentences = [s.strip() for s in text.split('. ') if s.strip()]

        if not sentences:
            # SentenceTransformer has no `.config`; use its own accessor
            return torch.zeros((0, self._text_model.get_sentence_embedding_dimension()))

        # encode() returns a NumPy array with default settings
        embeddings = self._text_model.encode(sentences)
        return torch.tensor(embeddings)

    def _apply_attention_pooling(self, features, encoder):
        """Encode a feature sequence and pool it into a single vector.

        Args:
            features: Tensor of shape ``[N, D]`` (N may be 0).
            encoder: Module applied to the ``[1, N, D]`` batched sequence.

        Returns:
            Tensor of shape ``[1, D]``; all zeros when the sequence is empty.
        """
        hidden = features.size(-1)
        if features.size(0) == 0:  # Handle empty sequences
            return torch.zeros(1, hidden, device=features.device, dtype=features.dtype)

        # Sinusoidal positional encoding over the sequence axis:
        # even columns hold sin, odd columns hold cos of position * frequency.
        positions = torch.arange(
            features.size(0), device=features.device, dtype=torch.float32
        ).unsqueeze(1)                                            # [N, 1]
        even_dims = torch.arange(0, hidden, 2, device=features.device, dtype=torch.float32)
        inv_freq = 10000.0 ** (-even_dims / hidden)               # [ceil(D/2)]
        angles = positions * inv_freq                             # [N, ceil(D/2)]

        pos_encoding = torch.zeros_like(features)
        pos_encoding[:, 0::2] = torch.sin(angles)
        pos_encoding[:, 1::2] = torch.cos(angles[:, : hidden // 2])  # one column fewer when D is odd

        features = features + pos_encoding

        # Apply transformer encoding on a batch of one sequence
        encoded = encoder(features.unsqueeze(0))                  # [1, N, D]

        # Scaled dot-product self-attention, then average over the sequence
        scores = torch.matmul(encoded, encoded.transpose(-2, -1)) / (hidden ** 0.5)
        attention_weights = torch.softmax(scores, dim=-1)
        pooled = torch.matmul(attention_weights, encoded).mean(dim=1)

        return pooled

    def forward(self, audio_path, image_path, text_path):
        """Classify one sample given the paths of its audio, frame and transcript files.

        Returns:
            Tensor of shape ``[1, 1]`` holding the sigmoid output (probability
            of the positive, i.e. 'hate', class as used by ``predict``).
        """
        # Trainable layers may live on a different device than the tensors the
        # extractors hand back (text embeddings arrive via NumPy, i.e. on CPU).
        device = next(self.fusion_layer.parameters()).device

        # Extract features from each modality
        audio_features = self._extract_audio_features(audio_path).to(device)
        image_features = self._extract_image_features(image_path).to(device)
        text_features = self._extract_text_features(text_path).to(device)

        # Process each modality sequence
        audio_encoded = self._apply_attention_pooling(audio_features, self.audio_encoder)
        vision_encoded = self._apply_attention_pooling(image_features, self.vision_encoder)
        text_encoded = self._apply_attention_pooling(text_features, self.text_encoder)

        # Concatenate encoded features: [1, 3 * D]
        combined_features = torch.cat([audio_encoded, vision_encoded, text_encoded], dim=-1)

        # Final classification
        logits = self.fusion_layer(combined_features)
        return torch.sigmoid(logits)

    def predict(self, audio_path, image_path, text_path):
        """Convenience wrapper returning a label and the model's output probability.

        Puts the module in eval mode (it is left in eval mode afterwards).

        Returns:
            dict with ``'prediction'`` ('hate' when the output is >= 0.5, else
            'non-hate') and ``'confidence'`` (the raw sigmoid output, i.e. the
            score of the 'hate' class rather than of the predicted label).
        """
        self.eval()
        with torch.no_grad():
            output = self(audio_path, image_path, text_path)
            confidence = output.item()
        return {
            'prediction': 'hate' if confidence >= 0.5 else 'non-hate',
            'confidence': confidence
        }