In [None]:
!pip install -U -q sentence-transformers

In [None]:
import gc
import math
import os
import random

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from sentence_transformers import SentenceTransformer
from torch.nn.functional import pad
from transformers import Wav2Vec2Processor, Wav2Vec2Model, ViTModel, ViTImageProcessor
# Reproducibility setup for the whole notebook.
# CUBLAS_WORKSPACE_CONFIG is the cuBLAS workspace setting; PyTorch's
# reproducibility notes list ':4096:8' as one of the values that makes cuBLAS
# deterministic on CUDA >= 10.2.  It is set here at import time, before any
# model is built.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

# Seed Python's and NumPy's global RNGs.  The torch RNGs are seeded separately
# inside HateVideoClassifier.__init__.
global_seed = 0
random.seed(global_seed)
np.random.seed(global_seed)

In [None]:
class HateVideoClassifier(nn.Module):
    """Multimodal (audio + video frames + transcript) binary classifier.

    For one sample, pretrained Wav2Vec2, ViT and SentenceTransformer encoders
    each produce a feature sequence.  Every sequence is linearly projected,
    tagged with a learnable per-modality embedding, layer-normalised and
    concatenated along the sequence axis.  The joint sequence goes through
    ``num_layers`` self-attention blocks, is mean-pooled over the positions
    of the modalities that are present, and a small MLP with a sigmoid maps
    the pooled vector to a probability.

    Args:
        hidden_size: Width of every feature vector.  The projections are
            ``Linear(hidden_size, hidden_size)`` applied directly to encoder
            outputs, so this must equal the encoders' output width
            (presumably 768 for the default checkpoints -- confirm if the
            checkpoints are changed).
        num_heads: Attention heads per fusion layer.
        num_layers: Number of fusion (self-attention) layers.
        dropout: Dropout probability used in attention, the residual path
            and the classification head.
        train_option: ``"finetune"`` or ``"transfer"``.  ``"transfer"`` sets
            ``requires_grad=False`` on all three backbones.
        seed: Seed applied to the torch CPU and CUDA RNGs.

    Raises:
        ValueError: If ``train_option`` is not one of the accepted values.

    NOTE(review): every ``_extract_*`` method runs its backbone under
    ``torch.no_grad()``, so the backbones never receive gradients regardless
    of ``train_option``; only the fusion layers and the head are trained.
    """

    def __init__(self, hidden_size=768, num_heads=2, num_layers=2, dropout=0.1, train_option="finetune", seed=0):
        # Validate with a real exception: `assert` is stripped under `python -O`
        # and would surface as AssertionError instead of the intended ValueError.
        if train_option not in ("finetune", "transfer"):
            raise ValueError("Not a correct training option")

        # Ensure reproducible initialisation of the layers created below.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        super().__init__()

        self.hidden_size = hidden_size
        # Preferred device at construction time.  Kept for backward
        # compatibility; tensors are placed via `_runtime_device()`, which
        # follows the module when it is moved with `.to(...)`.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize feature extractors
        self._init_extractors(train_option == "transfer")

        # Modality-specific projections to ensure consistent dimensionality
        self.audio_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.image_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.text_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        # One learnable embedding per modality (shared by every position of
        # that modality, not a per-position encoding).
        self.audio_pos_embed = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.image_pos_embed = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.text_pos_embed = nn.Parameter(torch.randn(1, 1, hidden_size))

        # Learnable placeholders used when a modality is absent
        self.missing_audio_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.missing_image_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.missing_text_token = nn.Parameter(torch.randn(1, 1, hidden_size))

        # Cross-modal attention layers
        self.cross_attention = nn.ModuleList([
            nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)
            for _ in range(num_layers)
        ])

        # Post-attention layer norms, one per fusion layer
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_size)
            for _ in range(num_layers)
        ])

        # Modality-specific layer norms
        self.audio_norm = nn.LayerNorm(hidden_size)
        self.image_norm = nn.LayerNorm(hidden_size)
        self.text_norm = nn.LayerNorm(hidden_size)

        self.dropout = nn.Dropout(dropout)

        # Final classification layers (sigmoid output, i.e. a probability)
        self.output_head = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def _runtime_device(self):
        """Return the device the fusion parameters currently live on."""
        return self.missing_audio_token.device

    def _init_extractors(self, transfer=False):
        """Load the pretrained encoders; freeze them when ``transfer`` is true."""
        # Audio
        self.wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

        # Image
        self.vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

        # Text
        self.text_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

        if transfer:
            for backbone in (self.wav2vec_model, self.vit_model, self.text_model):
                for param in backbone.parameters():
                    param.requires_grad = False

    def _extract_audio_features(self, audio_path):
        """Encode an audio file with Wav2Vec2.

        Returns:
            The encoder's ``last_hidden_state`` with the leading batch axis
            squeezed out, or ``None`` when more than 90% of the samples have
            an absolute amplitude below 0.01 (treated as silence).
        """
        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # Resample to the rate the processor expects
        target_rate = self.wav2vec_processor.feature_extractor.sampling_rate
        if sample_rate != target_rate:
            waveform = torchaudio.transforms.Resample(sample_rate, target_rate)(waveform)

        # Check for silence
        threshold = 0.01
        silent_percentage = (torch.abs(waveform) < threshold).float().mean().item() * 100
        if silent_percentage > 90:
            return None

        input_values = self.wav2vec_processor(
            waveform.squeeze().numpy(),
            sampling_rate=target_rate,
            return_tensors="pt"
        ).input_values.to(self._runtime_device())

        # Backbone is used as a fixed feature extractor (no gradients).
        with torch.no_grad():
            outputs = self.wav2vec_model(input_values)
            features = outputs.last_hidden_state.squeeze(0)

        return features

    def _extract_image_features(self, h5_path, batch_size=32):
        """Encode the frames stored under the ``'frames'`` key of an HDF5 file.

        Frames are pushed through ViT in chunks of ``batch_size``.

        Returns:
            The ``pooler_output`` rows of all chunks concatenated along the
            first axis (one row per frame), or ``None`` if there are no frames.
        """
        with h5py.File(h5_path, 'r') as f:
            frames = f['frames'][:]

        # A 3-D array has no channel axis: replicate it into 3 channels.
        if frames.ndim == 3:
            frames = np.stack([frames] * 3, axis=-1)

        device = self._runtime_device()
        all_features = []
        for start in range(0, len(frames), batch_size):
            inputs = self.vit_processor(
                images=frames[start:start + batch_size],
                return_tensors="pt"
            ).pixel_values.to(device)

            with torch.no_grad():
                all_features.append(self.vit_model(inputs).pooler_output)

        return torch.cat(all_features, dim=0) if all_features else None

    def _extract_text_features(self, txt_path):
        """Encode a transcript, one embedding per ``'. '``-separated sentence.

        Returns:
            A tensor with one row per non-empty sentence, or ``None`` when the
            file contains no non-empty sentence.
        """
        with open(txt_path, 'r', encoding='utf-8') as f:
            text = f.read()

        sentences = [s.strip() for s in text.split('. ') if s.strip()]
        if not sentences:
            return None

        with torch.no_grad():
            embeddings = self.text_model.encode(
                sentences, convert_to_tensor=True
            ).to(self._runtime_device())

        return embeddings

    def _apply_modality_projection(self, features, proj_layer, pos_embed):
        """Project ``features`` and add the modality embedding (``None`` passes through).

        ``pos_embed`` has shape ``(1, 1, hidden)``, so adding it to a 2-D
        ``(seq, hidden)`` input broadcasts the result to ``(1, seq, hidden)``;
        this is where the batch axis used by the fusion step comes from.
        """
        if features is None:
            return None

        return proj_layer(features) + pos_embed

    def _handle_missing_modality(self, features, missing_token, norm_layer):
        """Return ``missing_token`` for an absent modality, else the normalised features."""
        if features is None:
            return missing_token
        return norm_layer(features)

    def _fuse_modalities(self, audio_features, image_features, text_features):
        """Fuse the modality sequences with self-attention and mean-pool them.

        Each argument is either ``None`` (modality absent) or a tensor indexed
        as ``(batch, seq, hidden)``.  Absent modalities are replaced by their
        learnable placeholder token; placeholder positions are excluded from
        the attention keys and from the pooling.  When *no* modality is
        present the placeholders are kept instead: masking every key would
        make the attention softmax (and therefore the output and the loss)
        NaN.

        Returns:
            Tensor of shape ``(batch, hidden)``.
        """
        modalities = (
            (audio_features, self.missing_audio_token, self.audio_norm),
            (image_features, self.missing_image_token, self.image_norm),
            (text_features, self.missing_text_token, self.text_norm),
        )

        # Batch size comes from the first modality that is present.
        present = [feats for feats, _, _ in modalities if feats is not None]
        batch_size = present[0].size(0) if present else 1
        any_present = bool(present)

        segments = []
        mask_parts = []
        for feats, missing_token, norm_layer in modalities:
            segment = self._handle_missing_modality(
                feats, missing_token.repeat(batch_size, 1, 1), norm_layer)
            # True marks a position that takes part in attention and pooling.
            usable = feats is not None or not any_present
            mask_parts.append(torch.full(
                (batch_size, segment.size(1)), usable,
                dtype=torch.bool, device=segment.device))
            segments.append(segment)

        # Concatenate all modalities along the sequence axis
        features = torch.cat(segments, dim=1)
        attention_mask = torch.cat(mask_parts, dim=1)

        # Apply cross-modal attention layers
        for attention, norm in zip(self.cross_attention, self.layer_norms):
            # Kept from the original implementation to limit peak GPU memory.
            torch.cuda.empty_cache()
            gc.collect()
            attended_features, _ = attention(
                query=features,
                key=features,
                value=features,
                key_padding_mask=~attention_mask,  # True in key_padding_mask means "ignore"
                need_weights=False
            )

            # Residual connection, dropout, then post-norm
            features = norm(features + self.dropout(attended_features))

        # Masked mean over the usable positions
        mask_expanded = attention_mask.unsqueeze(-1).float()
        masked_sum = (features * mask_expanded).sum(dim=1)
        mask_sum = mask_expanded.sum(dim=1).clamp(min=1.0)  # Prevent division by zero
        return masked_sum / mask_sum

    def forward(self, audio_path, image_path, text_path):
        """Classify one sample given the paths of its three modality files.

        Returns:
            The sigmoid output of the head with the leading axis squeezed
            (shape ``(1,)`` for a single sample).
        """
        # Extract features for each modality
        audio_features = self._extract_audio_features(audio_path)
        image_features = self._extract_image_features(image_path)
        text_features = self._extract_text_features(text_path)

        # Apply projections and embeddings
        audio_features = self._apply_modality_projection(
            audio_features, self.audio_proj, self.audio_pos_embed)
        image_features = self._apply_modality_projection(
            image_features, self.image_proj, self.image_pos_embed)
        text_features = self._apply_modality_projection(
            text_features, self.text_proj, self.text_pos_embed)

        # Fuse modalities using cross-attention
        fused_features = self._fuse_modalities(audio_features, image_features, text_features)
        # Final classification
        output = self.output_head(fused_features).squeeze(0)
        # Drop references early and release cached GPU blocks between samples.
        del audio_features, image_features, text_features, fused_features
        torch.cuda.empty_cache()
        return output

    def predict(self, audio_path, image_path, text_path):
        """Return ``(is_hate, probability)`` for one sample.

        Switches the module to eval mode (and leaves it there).  The label is
        ``True`` when the probability exceeds 0.5.
        """
        self.eval()
        with torch.no_grad():
            output = self.forward(audio_path, image_path, text_path)
            return (output > 0.5).item(), output.item()

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from tqdm import tqdm
import os
import gc
import time

class HateVideoDataset(Dataset):
    """Dataset yielding file paths (not loaded data) and the label of each video.

    ``dataframe`` needs the columns ``audio``, ``imageseq``, ``transcript``
    (file names) and ``label``.  The file names are resolved against the
    ``audio``, ``image_sequences`` and ``transcripts`` sub-folders of
    ``base_path``.
    """

    def __init__(self, dataframe, base_path):
        self.data = dataframe
        self.base_path = base_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with the three absolute paths and a ``(1,)`` float label."""
        record = self.data.iloc[idx]

        def resolve(subfolder, column):
            # <base_path>/<subfolder>/<file name stored in the row>
            return os.path.join(self.base_path, subfolder, record[column])

        return {
            'audio_path': resolve('audio', 'audio'),
            'image_path': resolve('image_sequences', 'imageseq'),
            'text_path': resolve('transcripts', 'transcript'),
            'label': torch.tensor([record['label']], dtype=torch.float),
        }

class EarlyStopping:
    """Track validation loss and flag when it has stopped improving.

    Call the instance once per epoch with the validation loss.  A loss counts
    as an improvement unless it is greater than ``best_loss - min_delta``.
    After ``patience`` non-improving calls since the last improvement,
    ``early_stop`` becomes ``True`` (it is never reset).
    """

    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        # First observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        stalled = val_loss > self.best_loss - self.min_delta
        if not stalled:
            # Improvement: remember it and restart the patience window.
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def validate(model, val_loader, criterion):
    """Return the mean loss of ``model`` over ``val_loader``.

    Each item of ``val_loader`` is a dict with ``audio_path``, ``image_path``,
    ``text_path`` and ``label``.  The model is switched to eval mode and is
    left in eval mode on return; callers that continue training must call
    ``model.train()`` again.

    Raises:
        ZeroDivisionError: If ``val_loader`` is empty.
    """
    torch.cuda.empty_cache()
    gc.collect()
    model.eval()
    running_loss = 0.0

    with torch.no_grad():
        for data in tqdm(val_loader):
            # Release cached GPU memory between samples.
            torch.cuda.empty_cache()
            gc.collect()
            outputs = model(data['audio_path'], data['image_path'], data['text_path'])
            # Follow the model's device instead of hard-coding "cuda", so the
            # function also works on CPU-only machines.
            label = data['label'].to(outputs.device)
            loss = criterion(outputs, label)

            running_loss += loss.item()

    return running_loss / len(val_loader)

def train_model(model, train_loader, val_loader, criterion, optimizer, 
                scheduler=None, num_epochs=50, 
              patience=3, checkpoint_path="best_model.pth"):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    Args:
        model: Module called as ``model(audio_path, image_path, text_path)``.
        train_loader: Iterable of dicts with ``audio_path``, ``image_path``,
            ``text_path`` and ``label``.
        val_loader: Same format, used by ``validate`` after every epoch.
        criterion: Loss called as ``criterion(outputs, label)``.
        optimizer: Optimizer stepped once per training item.
        scheduler: Optional scheduler stepped with the validation loss
            (``scheduler.step(val_loss)``).
        num_epochs: Maximum number of epochs.
        patience: Non-improving epochs tolerated before stopping early.
        checkpoint_path: Where the ``state_dict`` with the lowest validation
            loss is saved.

    Returns:
        The model with the weights of the last executed epoch.  The best
        weights are only on disk at ``checkpoint_path``.
    """
    early_stopping = EarlyStopping(patience=patience)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('-' * 10)

        # validate() leaves the model in eval mode, so training mode has to be
        # re-enabled at the start of every epoch; otherwise dropout would be
        # disabled from the second epoch on.
        model.train()

        running_loss = 0.0
        for data in tqdm(train_loader):
            # Release cached GPU memory between samples.
            torch.cuda.empty_cache()
            gc.collect()
            optimizer.zero_grad()
            outputs = model(data['audio_path'], data['image_path'], data['text_path'])
            # Follow the model's device instead of hard-coding "cuda".
            label = data['label'].to(outputs.device)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(train_loader)
        print(f"Training Loss: {epoch_loss}")

        val_loss = validate(model, val_loader, criterion)
        print(f"Validation Loss: {val_loss}")

        if scheduler is not None:
            scheduler.step(val_loss)

        early_stopping(val_loss)

        if early_stopping.early_stop:
            print("Early stopping")
            break

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("Model saved")

    return model

def evaluate_model(model, test_loader):
    """Compute binary classification metrics of ``model`` on ``test_loader``.

    Each item of ``test_loader`` is a dict with ``audio_path``,
    ``image_path``, ``text_path`` and ``label``.  Model outputs are treated
    as probabilities of the positive class and thresholded at 0.5 for
    accuracy, precision, recall and F1.  ROC AUC is computed from the raw
    probabilities: feeding it thresholded 0/1 predictions would collapse the
    ROC curve to a single operating point.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``roc_auc``.
    """
    model.eval()
    y_true = []
    y_score = []

    with torch.no_grad():
        for data in tqdm(test_loader):
            # Release cached GPU memory between samples.
            torch.cuda.empty_cache()
            gc.collect()
            outputs = model(data['audio_path'], data['image_path'], data['text_path'])
            # Labels are only collected, never fed to the model, so they do
            # not need to be moved to the GPU first.
            y_true.extend(data['label'].cpu().numpy())
            y_score.extend(outputs.cpu().numpy())

    y_true = np.array(y_true)
    y_score = np.array(y_score)

    y_pred = (y_score > 0.5).astype(int)

    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc
    }

In [None]:
# Root folder of the dataset.  HateVideoDataset joins it with the 'audio',
# 'image_sequences' and 'transcripts' sub-folders to locate each file.
base_path = "/kaggle/input/mmhate/HateMM"

def get_duration(file):
    """Return the length of a WAV file in whole seconds, rounded up."""
    import wave

    with wave.open(file, 'r') as wav:
        seconds = wav.getnframes() / float(wav.getframerate())

    whole = int(seconds)
    # Any fractional remainder counts as one more full second.
    if seconds > whole:
        return whole + 1
    return whole

# Collect the WAV clips of at most 3 minutes.  The directory is derived from
# base_path instead of repeating the literal path, and the listing is sorted
# because os.listdir returns entries in arbitrary order, which would make the
# seeded train/val/test split depend on the file system.
audio_dir = os.path.join(base_path, 'audio')
files = [
    file
    for file in sorted(os.listdir(audio_dir))
    if file.endswith('.wav') and get_duration(os.path.join(audio_dir, file)) <= 180
]

# One row per clip.  The three modality files share the stem before the first
# '.', and a clip is labelled 0 when that stem contains "non", else 1.
# Built in a single DataFrame call; concatenating row by row copies the whole
# frame on every iteration.
df = pd.DataFrame(
    [
        {
            'audio': stem + '.wav',
            'imageseq': stem + '.h5',
            'transcript': stem + '.txt',
            'label': 0 if "non" in stem else 1,
        }
        for stem in (file.split('.')[0] for file in files)
    ],
    columns=['audio', 'imageseq', 'transcript', 'label'],
)
                                                                                                                                           

# Stratified split: 15% test, then 17.6% of the remaining 85% (about 15% of
# the whole set) for validation.
train, test = train_test_split(df, test_size=0.15, random_state=0, stratify=df['label'])
train, val = train_test_split(train, test_size=0.176, random_state=0, stratify=train['label'])

num_epochs = 20
learning_rate = 1e-4
checkpoint_path = "best_model.pth"
# Fall back to CPU when no GPU is available instead of failing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"


# Create datasets
train_dataset = HateVideoDataset(train, base_path)
val_dataset = HateVideoDataset(val, base_path)
test_dataset = HateVideoDataset(test, base_path)

# Create dataloaders.  batch_size=None disables automatic batching, so every
# item is a single sample dict, which is what the model's forward expects.
train_loader = DataLoader(train_dataset, batch_size=None, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=None)
test_loader = DataLoader(test_dataset, batch_size=None)

# Initialize model and training components
model = HateVideoClassifier().to(device)

# The model ends in a sigmoid, so plain BCELoss on probabilities is used.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): `verbose` is reported as deprecated by recent PyTorch
# releases -- confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, verbose=True
)

# Train model
model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                    num_epochs=num_epochs, checkpoint_path=checkpoint_path)

# train_model returns the weights of the last epoch; the weights with the
# lowest validation loss are only in the checkpoint, so restore them before
# reporting test metrics.
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Evaluate model
results = evaluate_model(model, test_loader)
print(results)