In [1]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import ViTModel, ViTFeatureExtractor
from transformers import BertTokenizer, BertModel
import re
import glob

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load dataset
df = pd.read_csv('/kaggle/input/multilingual-meme-datasets/final_datasets.csv')
print(f"Dataset loaded with shape: {df.shape}")

# Image directory
image_dir = "/kaggle/input/multilingual-meme-datasets/datasets/datasets"

2025-04-19 19:08:05.981104: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745089686.265330      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745089686.343432      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda:0
Dataset loaded with shape: (25600, 11)


In [2]:
# Create a mapping between dataset 'name' and actual image filenames
def create_image_mapping(dataframe, image_directory):
    """
    Create a mapping between dataset 'name' column and actual image filenames in the directory.
    This handles the case where image filenames might have different formats.
    """
    # Get all image files in the directory
    image_files = glob.glob(os.path.join(image_directory, '*.*'))
    image_mapping = {}
    
    # Create a set of all available image filenames (without path)
    available_images = {os.path.basename(f) for f in image_files}
    print(f"Found {len(available_images)} images in directory")
    
    # Print some sample images to understand the naming pattern
    print("Sample image filenames:", list(available_images)[:5])
    
    # Print some sample names from the dataset
    print("Sample names from dataset:", dataframe['name'].iloc[:5].tolist())
    
    # Method 1: Exact match
    for name in dataframe['name'].unique():
        if name in available_images:
            image_mapping[name] = name
    
    # Method 2: Check if the id is part of the filename
    unmapped_names = set(dataframe['name']) - set(image_mapping.keys())
    for name in unmapped_names:
        # Extract ID from name (assuming name has some ID pattern)
        id_match = re.search(r'\d+', name)
        if id_match:
            id_value = id_match.group()
            # Look for files containing this ID
            matching_files = [f for f in available_images if id_value in f]
            if matching_files:
                image_mapping[name] = matching_files[0]
    
    # Method 3: Try matching using 'ids' or 'id' column if available
    if 'ids' in dataframe.columns or 'id' in dataframe.columns:
        id_col = 'ids' if 'ids' in dataframe.columns else 'id'
        id_to_name = dict(zip(dataframe[id_col], dataframe['name']))
        
        for id_value, name in id_to_name.items():
            if name not in image_mapping:
                # Look for files containing this ID
                matching_files = [f for f in available_images if str(id_value) in f]
                if matching_files:
                    image_mapping[name] = matching_files[0]
    
    print(f"Successfully mapped {len(image_mapping)} out of {len(dataframe['name'].unique())} unique names")
    return image_mapping

# Custom dataset class with image mapping
class HarmfulContentDataset(Dataset):
    def __init__(self, dataframe, image_dir, feature_extractor, tokenizer, image_mapping=None, max_len=128):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.image_mapping = image_mapping or {}
        
        # Create a default blank image to use when an image is not found
        self.blank_image = Image.new('RGB', (224, 224), color='white')
        
        # Image transformation
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
    
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        
        # Text features
        text_features = f"{row['gender']} {row['age']} {row['age_bucket']} {row['dominant_emotion']} {row['dominant_race']} {row['translated_text']}"
        
        # Tokenize text
        encoding = self.tokenizer.encode_plus(
            text_features,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        # Image processing
        try:
            # Get the correct image filename using the mapping
            image_filename = self.image_mapping.get(row['name'], row['name'])
            image_path = os.path.join(self.image_dir, image_filename)
            
            # Check if the file exists
            if os.path.exists(image_path):
                image = Image.open(image_path).convert('RGB')
            else:
                # Try alternate approach: check if the file exists with different extensions
                base_name = os.path.splitext(image_path)[0]
                for ext in ['.jpg', '.jpeg', '.png', '.gif']:
                    alt_path = base_name + ext
                    if os.path.exists(alt_path):
                        image = Image.open(alt_path).convert('RGB')
                        break
                else:
                    # If still not found, use blank image
                    image = self.blank_image
                    if idx % 100 == 0:  # Limit logging to avoid flooding
                        print(f"Image not found for {row['name']}, using blank image")
            
            # Process image for ViT
            image_features = self.feature_extractor(images=image, return_tensors="pt")
            pixel_values = image_features.pixel_values.squeeze()
            
        except Exception as e:
            if idx % 100 == 0:  # Limit logging
                print(f"Error processing image for {row['name']}: {e}")
            # Create blank image features
            pixel_values = torch.zeros((3, 224, 224))
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'pixel_values': pixel_values,
            'labels': torch.tensor(row['label'], dtype=torch.long)
        }

# Custom model class (ViT + BERT + classification head)
class MultimodalClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(MultimodalClassifier, self).__init__()
        
        # Load pre-trained Vision Transformer
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
        
        # Load pre-trained BERT
        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        
        # Freeze ViT weights
        for param in self.vit.parameters():
            param.requires_grad = False
            
        # Unfreeze the last few layers of ViT
        for param in self.vit.encoder.layer[-2:].parameters():
            param.requires_grad = True
        
        # Freeze BERT weights
        for param in self.bert.parameters():
            param.requires_grad = False
            
        # Unfreeze the last few layers of BERT
        for param in self.bert.encoder.layer[-2:].parameters():
            param.requires_grad = True
        
        # Dimensionality of ViT and BERT embeddings
        vit_hidden_size = self.vit.config.hidden_size
        bert_hidden_size = self.bert.config.hidden_size
        
        # Enhanced classification head
        self.classifier = nn.Sequential(
            nn.Linear(vit_hidden_size + bert_hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, input_ids, attention_mask, pixel_values):
        # Process image with ViT
        vit_outputs = self.vit(pixel_values=pixel_values)
        vit_embeddings = vit_outputs.last_hidden_state[:, 0, :]  # CLS token
        
        # Process text with BERT
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        bert_embeddings = bert_outputs.last_hidden_state[:, 0, :]  # CLS token
        
        # Concatenate image and text features
        combined_embeddings = torch.cat((vit_embeddings, bert_embeddings), dim=1)
        
        # Classification
        logits = self.classifier(combined_embeddings)
        
        return logits

# Training function
def train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs=10):
    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []
    
    best_valid_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        
        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)
            
            optimizer.zero_grad()
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()
            
            # Print batch progress
            if (batch_idx + 1) % 10 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}')
        
        train_loss = running_loss / len(train_loader)
        train_accuracy = correct_train / total_train
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        
        # Validation
        model.eval()
        running_valid_loss = 0.0
        correct_valid = 0
        total_valid = 0
        
        with torch.no_grad():
            for batch in valid_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                pixel_values = batch['pixel_values'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
                loss = criterion(outputs, labels)
                
                running_valid_loss += loss.item()
                
                _, predicted = torch.max(outputs.data, 1)
                total_valid += labels.size(0)
                correct_valid += (predicted == labels).sum().item()
        
        valid_loss = running_valid_loss / len(valid_loader)
        valid_accuracy = correct_valid / total_valid
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)
        
        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, '
              f'Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}')
        
        # Save best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best_multimodal_model.pth')
            print(f'Saved model with validation loss: {valid_loss:.4f}')
    
    return train_losses, valid_losses, train_accuracies, valid_accuracies

# Function for evaluation and metrics
def evaluate_model(model, test_loader):
    model.eval()
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
            _, predicted = torch.max(outputs.data, 1)
            
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    
    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    conf_matrix = confusion_matrix(y_true, y_pred)
    class_report = classification_report(y_true, y_pred, target_names=['Non-Harmful', 'Harmful'])
    
    print(f'Test Accuracy: {accuracy:.4f}')
    print('\nConfusion Matrix:')
    print(conf_matrix)
    print('\nClassification Report:')
    print(class_report)
    
    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['Non-Harmful', 'Harmful'], 
                yticklabels=['Non-Harmful', 'Harmful'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')
    plt.close()
    
    return accuracy, conf_matrix, class_report

# Function to plot training curves
def plot_training_curves(train_losses, valid_losses, train_accuracies, valid_accuracies):
    # Plot loss
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training vs Validation Loss')
    plt.legend()
    
    # Plot accuracy
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Training Accuracy')
    plt.plot(valid_accuracies, label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Training vs Validation Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()

# Function to inspect dataset
def inspect_dataset(df, image_dir):
    """
    Inspects the dataset and the image directory to understand the structure.
    """
    print("\n=== Dataset Inspection ===")
    print(f"Dataset columns: {df.columns.tolist()}")
    print(f"Number of rows: {len(df)}")
    print(f"Label distribution: {df['label'].value_counts().to_dict()}")
    
    # Check if image directory exists
    if not os.path.exists(image_dir):
        print(f"WARNING: Image directory does not exist: {image_dir}")
        return
    
    # Count image files
    image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f)) and 
                   f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
    
    print(f"Number of image files in directory: {len(image_files)}")
    if len(image_files) > 0:
        print(f"Sample image filenames: {image_files[:5]}")
    
    # Check name column format
    print("\nSample names from dataset:")
    print(df['name'].head(5).tolist())
    
    # Check if any dataset names exactly match image filenames
    matching_names = [name for name in df['name'].unique() if name in image_files]
    print(f"Number of exact matches between 'name' column and image filenames: {len(matching_names)}")
    
    # Check if IDs are in the name column
    if 'id' in df.columns or 'ids' in df.columns:
        id_col = 'ids' if 'ids' in df.columns else 'id'
        print(f"\nSample values from '{id_col}' column:")
        print(df[id_col].head(5).tolist())
        
        # Check if IDs are in image filenames
        sample_ids = df[id_col].astype(str).head(5).tolist()
        for id_val in sample_ids:
            matches = [f for f in image_files if id_val in f]
            if matches:
                print(f"Found matches for ID {id_val}: {matches[:2]}")
                
    return image_files

# Main execution
def main():
    # Initialize feature extractor and tokenizer
    feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
    tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
    
    # Inspect dataset and image directory
    image_files = inspect_dataset(df, image_dir)
    
    # Create mapping between dataset names and image filenames
    image_mapping = create_image_mapping(df, image_dir)
    
    # Split dataset
    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42, stratify=df['label'])
    valid_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['label'])
    
    print(f"Train set: {len(train_df)}, Validation set: {len(valid_df)}, Test set: {len(test_df)}")
    
    # Create datasets with image mapping
    train_dataset = HarmfulContentDataset(train_df, image_dir, feature_extractor, tokenizer, image_mapping)
    valid_dataset = HarmfulContentDataset(valid_df, image_dir, feature_extractor, tokenizer, image_mapping)
    test_dataset = HarmfulContentDataset(test_df, image_dir, feature_extractor, tokenizer, image_mapping)
    
    # Create dataloaders with appropriate batch size based on available GPU memory
    batch_size = 8  # Reduced batch size to handle larger model
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=2)
    
    # Initialize model
    model = MultimodalClassifier(num_classes=2).to(device)
    
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
    
    # Train model
    print("Starting training...")
    train_losses, valid_losses, train_accuracies, valid_accuracies = train_model(
        model, train_loader, valid_loader, criterion, optimizer, num_epochs=2
    )
    
    # Plot training curves
    plot_training_curves(train_losses, valid_losses, train_accuracies, valid_accuracies)
    
    # Load best model for evaluation
    model.load_state_dict(torch.load('best_multimodal_model.pth'))
    
    # Evaluate model
    print("\nEvaluating model on test set...")
    accuracy, conf_matrix, class_report = evaluate_model(model, test_loader)
    
    # Save final model with metadata
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': accuracy,
        'conf_matrix': conf_matrix,
        'class_report': class_report,
    }, 'multimodal_harmful_content_classifier.pth')
    
    print("Model training and evaluation complete!")

if __name__ == "__main__":
    main()

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]


=== Dataset Inspection ===
Dataset columns: ['ids', 'name', 'text', 'label', 'id', 'gender', 'age', 'age_bucket', 'dominant_emotion', 'dominant_race', 'translated_text']
Number of rows: 25600
Label distribution: {1: 17388, 0: 8212}
Number of image files in directory: 25716
Sample image filenames: ['eng476.png', 'meme_184.png', 'tangaila (166).jpg', 'Image- (178).jpg', 'Image- (2026).jpg']

Sample names from dataset:
['tangaila (1).jpg', 'tangaila (2).jpg', 'tangaila (3).jpg', 'tangaila (4).jpg', 'tangaila (5).jpg']
Number of exact matches between 'name' column and image filenames: 18670

Sample values from 'ids' column:
[1, 2, 3, 4, 5]
Found matches for ID 1: ['meme_184.png', 'tangaila (166).jpg']
Found matches for ID 2: ['Image- (2026).jpg', 'image_ (2682).jpg']
Found matches for ID 3: ['Bangla Thug Life (374).jpg', '37825.png']
Found matches for ID 4: ['eng476.png', 'meme_184.png']
Found matches for ID 5: ['52691.png', '37825.png']
Found 25716 images in directory
Sample image filena

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Starting training...
Epoch 1, Batch 10/2240, Loss: 0.6528
Epoch 1, Batch 20/2240, Loss: 0.7533
Epoch 1, Batch 30/2240, Loss: 0.5498
Epoch 1, Batch 40/2240, Loss: 0.6444
Epoch 1, Batch 50/2240, Loss: 0.7191
Epoch 1, Batch 60/2240, Loss: 0.6911
Epoch 1, Batch 70/2240, Loss: 0.4955
Epoch 1, Batch 80/2240, Loss: 0.6798
Epoch 1, Batch 90/2240, Loss: 0.5652
Epoch 1, Batch 100/2240, Loss: 0.6873
Epoch 1, Batch 110/2240, Loss: 0.5965
Epoch 1, Batch 120/2240, Loss: 0.5416
Epoch 1, Batch 130/2240, Loss: 0.3786
Epoch 1, Batch 140/2240, Loss: 0.5852
Epoch 1, Batch 150/2240, Loss: 0.4975
Epoch 1, Batch 160/2240, Loss: 0.6288
Epoch 1, Batch 170/2240, Loss: 0.5387
Epoch 1, Batch 180/2240, Loss: 0.6311
Epoch 1, Batch 190/2240, Loss: 0.6473
Epoch 1, Batch 200/2240, Loss: 0.4852
Epoch 1, Batch 210/2240, Loss: 0.5446
Epoch 1, Batch 220/2240, Loss: 0.4739
Epoch 1, Batch 230/2240, Loss: 0.9125
Epoch 1, Batch 240/2240, Loss: 0.6559
Epoch 1, Batch 250/2240, Loss: 0.4290
Epoch 1, Batch 260/2240, Loss: 0.6857




Epoch 1, Batch 1430/2240, Loss: 0.4627
Epoch 1, Batch 1440/2240, Loss: 0.2380
Epoch 1, Batch 1450/2240, Loss: 0.7317
Epoch 1, Batch 1460/2240, Loss: 0.5255
Epoch 1, Batch 1470/2240, Loss: 0.4367
Epoch 1, Batch 1480/2240, Loss: 0.6106
Epoch 1, Batch 1490/2240, Loss: 0.2603
Epoch 1, Batch 1500/2240, Loss: 0.3898
Epoch 1, Batch 1510/2240, Loss: 0.1618
Epoch 1, Batch 1520/2240, Loss: 0.7836
Epoch 1, Batch 1530/2240, Loss: 0.4737
Epoch 1, Batch 1540/2240, Loss: 0.6425
Epoch 1, Batch 1550/2240, Loss: 0.3202
Epoch 1, Batch 1560/2240, Loss: 0.3124
Epoch 1, Batch 1570/2240, Loss: 0.2516
Epoch 1, Batch 1580/2240, Loss: 0.5325
Epoch 1, Batch 1590/2240, Loss: 0.3164
Epoch 1, Batch 1600/2240, Loss: 0.3354
Epoch 1, Batch 1610/2240, Loss: 0.3916
Epoch 1, Batch 1620/2240, Loss: 0.4494
Epoch 1, Batch 1630/2240, Loss: 0.3839
Epoch 1, Batch 1640/2240, Loss: 0.4280




Epoch 1, Batch 1650/2240, Loss: 0.2828
Epoch 1, Batch 1660/2240, Loss: 0.4417
Epoch 1, Batch 1670/2240, Loss: 0.6498
Epoch 1, Batch 1680/2240, Loss: 0.6733
Epoch 1, Batch 1690/2240, Loss: 0.4434
Epoch 1, Batch 1700/2240, Loss: 0.3513
Epoch 1, Batch 1710/2240, Loss: 0.4260
Epoch 1, Batch 1720/2240, Loss: 0.8439
Epoch 1, Batch 1730/2240, Loss: 0.2810
Epoch 1, Batch 1740/2240, Loss: 0.3927
Epoch 1, Batch 1750/2240, Loss: 0.4831
Epoch 1, Batch 1760/2240, Loss: 0.6130
Epoch 1, Batch 1770/2240, Loss: 0.2510
Epoch 1, Batch 1780/2240, Loss: 0.2721
Epoch 1, Batch 1790/2240, Loss: 0.5156
Epoch 1, Batch 1800/2240, Loss: 0.2417
Epoch 1, Batch 1810/2240, Loss: 0.4087
Epoch 1, Batch 1820/2240, Loss: 0.1495
Epoch 1, Batch 1830/2240, Loss: 0.2831
Epoch 1, Batch 1840/2240, Loss: 0.2266
Epoch 1, Batch 1850/2240, Loss: 0.3234
Epoch 1, Batch 1860/2240, Loss: 0.6026
Epoch 1, Batch 1870/2240, Loss: 0.5248
Epoch 1, Batch 1880/2240, Loss: 0.5558
Epoch 1, Batch 1890/2240, Loss: 0.4657
Epoch 1, Batch 1900/2240,



Epoch 2, Batch 590/2240, Loss: 0.4992
Epoch 2, Batch 600/2240, Loss: 0.2717
Epoch 2, Batch 610/2240, Loss: 0.7523
Epoch 2, Batch 620/2240, Loss: 0.6227
Epoch 2, Batch 630/2240, Loss: 0.5180
Epoch 2, Batch 640/2240, Loss: 0.2567
Epoch 2, Batch 650/2240, Loss: 0.3355
Epoch 2, Batch 660/2240, Loss: 0.4679
Epoch 2, Batch 670/2240, Loss: 0.5647
Epoch 2, Batch 680/2240, Loss: 0.1875
Epoch 2, Batch 690/2240, Loss: 0.8502
Epoch 2, Batch 700/2240, Loss: 0.6898
Epoch 2, Batch 710/2240, Loss: 0.4071
Epoch 2, Batch 720/2240, Loss: 0.3502
Epoch 2, Batch 730/2240, Loss: 0.2869
Epoch 2, Batch 740/2240, Loss: 0.3653
Epoch 2, Batch 750/2240, Loss: 0.2491
Epoch 2, Batch 760/2240, Loss: 0.4620
Epoch 2, Batch 770/2240, Loss: 0.2803
Epoch 2, Batch 780/2240, Loss: 0.4899
Epoch 2, Batch 790/2240, Loss: 0.3095
Epoch 2, Batch 800/2240, Loss: 0.0846
Epoch 2, Batch 810/2240, Loss: 0.2549
Epoch 2, Batch 820/2240, Loss: 0.2430
Epoch 2, Batch 830/2240, Loss: 0.2761
Epoch 2, Batch 840/2240, Loss: 0.2130
Epoch 2, Bat



Epoch 2, Batch 880/2240, Loss: 0.3016
Epoch 2, Batch 890/2240, Loss: 0.5420
Epoch 2, Batch 900/2240, Loss: 0.4342
Epoch 2, Batch 910/2240, Loss: 0.4521
Epoch 2, Batch 920/2240, Loss: 0.5895
Epoch 2, Batch 930/2240, Loss: 0.2728
Epoch 2, Batch 940/2240, Loss: 0.4232
Epoch 2, Batch 950/2240, Loss: 0.2512
Epoch 2, Batch 960/2240, Loss: 0.3862
Epoch 2, Batch 970/2240, Loss: 0.5139
Epoch 2, Batch 980/2240, Loss: 0.3752
Epoch 2, Batch 990/2240, Loss: 0.3571
Epoch 2, Batch 1000/2240, Loss: 0.1973
Epoch 2, Batch 1010/2240, Loss: 0.2843
Epoch 2, Batch 1020/2240, Loss: 0.5188
Epoch 2, Batch 1030/2240, Loss: 0.3690
Epoch 2, Batch 1040/2240, Loss: 0.3686
Epoch 2, Batch 1050/2240, Loss: 0.2986
Epoch 2, Batch 1060/2240, Loss: 0.4678
Epoch 2, Batch 1070/2240, Loss: 0.7500
Epoch 2, Batch 1080/2240, Loss: 0.1715
Epoch 2, Batch 1090/2240, Loss: 0.2815
Epoch 2, Batch 1100/2240, Loss: 0.1799
Epoch 2, Batch 1110/2240, Loss: 0.5078
Epoch 2, Batch 1120/2240, Loss: 0.3061
Epoch 2, Batch 1130/2240, Loss: 0.223

  model.load_state_dict(torch.load('best_multimodal_model.pth'))



Evaluating model on test set...




Test Accuracy: 0.7990

Confusion Matrix:
[[ 972  260]
 [ 512 2096]]

Classification Report:
              precision    recall  f1-score   support

 Non-Harmful       0.65      0.79      0.72      1232
     Harmful       0.89      0.80      0.84      2608

    accuracy                           0.80      3840
   macro avg       0.77      0.80      0.78      3840
weighted avg       0.81      0.80      0.80      3840

Model training and evaluation complete!


In [14]:
import torch
from transformers import BertTokenizer, ViTFeatureExtractor
from PIL import Image
import torch.nn as nn
import re

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the MultimodalClassifier (update as per your training code)
class MultimodalClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(MultimodalClassifier, self).__init__()
        
        # Load pre-trained Vision Transformer
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
        
        # Load pre-trained BERT
        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        
        # Freeze ViT weights
        for param in self.vit.parameters():
            param.requires_grad = False
            
        # Unfreeze the last few layers of ViT
        for param in self.vit.encoder.layer[-2:].parameters():
            param.requires_grad = True
        
        # Freeze BERT weights
        for param in self.bert.parameters():
            param.requires_grad = False
            
        # Unfreeze the last few layers of BERT
        for param in self.bert.encoder.layer[-2:].parameters():
            param.requires_grad = True
        
        # Dimensionality of ViT and BERT embeddings
        vit_hidden_size = self.vit.config.hidden_size
        bert_hidden_size = self.bert.config.hidden_size
        
        # Enhanced classification head
        self.classifier = nn.Sequential(
            nn.Linear(vit_hidden_size + bert_hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )
        
    def forward(self, input_ids, attention_mask, pixel_values):
        # Process image with ViT
        vit_outputs = self.vit(pixel_values=pixel_values)
        vit_embeddings = vit_outputs.last_hidden_state[:, 0, :]  # CLS token
        
        # Process text with BERT
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        bert_embeddings = bert_outputs.last_hidden_state[:, 0, :]  # CLS token
        
        # Concatenate image and text features
        combined_embeddings = torch.cat((vit_embeddings, bert_embeddings), dim=1)
        
        # Classification
        logits = self.classifier(combined_embeddings)
        
        return logits

# Load tokenizer and feature extractor
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# Load model
model = MultimodalClassifier(num_classes=2)
model.load_state_dict(torch.load("/kaggle/working/best_multimodal_model.pth", map_location=device))
model.to(device)
model.eval()

# Function to parse input text and metadata
def parse_input(full_text):
    match = re.search(r"\[(.*?)\]$", full_text)
    meta_str = match.group(1) if match else ""
    main_text = full_text.split("[")[0].strip()

    # Extract metadata from bracketed string
    metadata = dict(re.findall(r"(\w+):\s*([^|]+)", meta_str))

    gender = metadata.get("Gender", "unknown")
    caste = metadata.get("Caste", "unknown")
    religion = metadata.get("Religion", metadata.get("Dominant Race", "unknown"))

    return main_text, gender.strip(), caste.strip(), religion.strip()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load("/kaggle/working/best_multimodal_model.pth", map_location=device))


In [15]:
import torch.nn.functional as F

def predict_single(full_input_text, image_path):
    # Parse text and metadata
    main_text, gender, caste, religion = parse_input(full_input_text)

    # Preprocess image
    image = Image.open(image_path).convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt")['pixel_values'].to(device)

    # Combine all text
    combined_text = f"{main_text} Gender: {gender} Caste: {caste} Religion: {religion}"
    encoding = tokenizer(combined_text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Predict
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, dim=1)

    return predicted.item(), confidence.item()

In [17]:
full_input_text = "when you're feeling horny asf but your habibi is on periods let's try a goat. [Label: 1 | ID: 1235 | Gender: man | Age: 43 | Age Bucket: 40-50 | Dominant Emotion: happy | Dominant Race: middle eastern]"
image_path = "/kaggle/input/multilingual-meme-datasets/datasets/datasets/37945.png"

prediction, confidence = predict_single(full_input_text, image_path)

label_map = {0: "Non-Hateful", 1: "Hateful"}

print("Prediction:", prediction)
print("Label:", label_map[prediction])
print(f"Confidence Score: {confidence:.4f}")

Prediction: 0
Label: Non-Hateful
Confidence Score: 0.7124
