## Install Required Libraries

In [1]:
!pip install transformers datasets torchvision wordcloud matplotlib tensorboard gdown

Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch==2.6.0->torchvision)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch==2.6.0->torchvision)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch==2.6.0->torchvision)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collect

In [2]:
!ls /kaggle/input/facebook-hateful-meme-dataset/data

dev.jsonl  img	LICENSE.txt  README.md	test.jsonl  train.jsonl


# Implement a Custom Dataset Class


In [29]:
# @title Dataset Implementation

import os
import torch
import json
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from transformers import BertTokenizer
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import random
from collections import Counter
import re
import matplotlib.pyplot as plt
from wordcloud import WordCloud
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.utils.data import WeightedRandomSampler

# Download NLTK resources
nltk.download('stopwords')
nltk.download('punkt')

class HatefulMemesDataset(Dataset):
    def __init__(self, data_dir, split='train', transform=None, text_processor=None, 
                 max_length=128, augment=False):
        """
        Custom PyTorch Dataset for the Hateful Memes dataset
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.text_processor = text_processor
        self.max_length = max_length
        self.augment = augment
        
        # Load annotations
        json_file = os.path.join(data_dir, f"{split}.jsonl")
        self.data = []
        with open(json_file, 'r') as f:
            for line in f:
                self.data.append(json.loads(line))
                
        # Calculate balanced class weights
        if split == 'train':
            labels = [item['label'] for item in self.data]
            class_counts = Counter(labels)
            total = sum(class_counts.values())
            # Inverse frequency weighting
            self.class_weights = {
                0: total / (2 * class_counts[0]),  # non-hateful
                1: total / (2 * class_counts[1])   # hateful
            }
            self.sample_weights = [self.class_weights[label] for label in labels]
        else:
            # Initialize empty sample_weights for non-train splits
            self.sample_weights = []
        
        # Enhanced transforms with stronger augmentation
        if self.transform is None:
            if split == 'train':
                self.transform = transforms.Compose([
                    transforms.Resize((288, 288)),  # Even larger initial size
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(
                        brightness=0.3,
                        contrast=0.3,
                        saturation=0.3,
                        hue=0.1
                    ),
                    transforms.RandomAffine(
                        degrees=10, 
                        translate=(0.1, 0.1),
                        scale=(0.9, 1.1)
                    ),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    ),
                    transforms.RandomErasing(p=0.3)
                ])
            else:
                self.transform = transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]
                    )
                ])
                
        # Set up text processor with better augmentation
        if self.text_processor is None:
            self.text_processor = BertTokenizer.from_pretrained('bert-base-uncased')
            
        # Setup for text augmentation
        self.augment = augment
        if self.augment:
            try:
                import nltk
                from nltk.corpus import stopwords
                nltk.download('stopwords', quiet=True)
                self.stop_words = set(stopwords.words('english'))
            except:
                print("NLTK stopwords not available, using minimal stopwords")
                self.stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'is', 'are', 'was', 'were'}
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Load and transform image
        img_path = os.path.join(self.data_dir, item['img'])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
            
        # Get text and apply augmentation if needed
        text = item['text']
        if self.augment and random.random() < 0.5:  # 50% chance to augment text
            text = self._augment_text(text)
            
        # Tokenize text
        encoding = self.text_processor(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'image': image,
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'text': text,  # Return original/augmented text for debugging
            'label': torch.tensor(item['label'], dtype=torch.float)
        }
        
    def _augment_text(self, text):
        """Apply text augmentation techniques"""
        augmentation_type = random.choice(['synonym', 'deletion', 'swap', 'none'])
        
        if augmentation_type == 'none':
            return text
        
        words = text.split()
        
        if augmentation_type == 'deletion' and len(words) > 4:
            # Randomly delete non-stopwords (up to 15% of words)
            num_to_delete = max(1, int(len(words) * 0.15))
            delete_indices = random.sample(range(len(words)), num_to_delete)
            words = [w for i, w in enumerate(words) if i not in delete_indices or w.lower() in self.stop_words]
        
        elif augmentation_type == 'swap' and len(words) > 2:
            # Randomly swap adjacent words (up to 2 pairs)
            num_swaps = random.randint(1, min(2, len(words)//2))
            for _ in range(num_swaps):
                i = random.randint(0, len(words)-2)
                words[i], words[i+1] = words[i+1], words[i]
        
        elif augmentation_type == 'synonym':
            # Replace some words with synonyms (if available)
            try:
                from nltk.corpus import wordnet
                nltk.download('wordnet', quiet=True)
                
                for i, word in enumerate(words):
                    if random.random() < 0.2 and word.lower() not in self.stop_words:
                        synsets = wordnet.synsets(word)
                        if synsets:
                            synonyms = [lemma.name() for synset in synsets for lemma in synset.lemmas()]
                            if synonyms:
                                words[i] = random.choice(synonyms).replace('_', ' ')
            except:
                # If wordnet is not available, just do a simple character swap
                for i, word in enumerate(words):
                    if random.random() < 0.1 and len(word) > 3:
                        chars = list(word)
                        j = random.randint(0, len(chars)-2)
                        chars[j], chars[j+1] = chars[j+1], chars[j]
                        words[i] = ''.join(chars)
        
        return ' '.join(words)
    
    def get_sampler(self):
        """Get weighted sampler for balanced training"""
        if self.split != 'train':
            return None
            
        weights = torch.DoubleTensor(self.sample_weights)
        sampler = WeightedRandomSampler(
            weights=weights,
            num_samples=len(weights),
            replacement=True
        )
        return sampler

# Functions for text preprocessing

def preprocess_text_for_lstm(text):
    """Preprocess text for LSTM model"""
    # Convert to lowercase
    text = text.lower()
    # Remove special characters
    text = re.sub(r'[^\w\s]', '', text)
    # Tokenize
    tokens = word_tokenize(text)
    # Remove stopwords
    tokens = [word for word in tokens if word not in stopwords.words('english')]
    return tokens

# Data preprocessing and visualization functions

def analyze_class_distribution(dataset):
    """Analyze and visualize class distribution"""
    labels = [item['label'] for item in dataset.data]
    label_counts = Counter(labels)
    
    plt.figure(figsize=(8, 6))
    plt.bar(['Non-Hateful (0)', 'Hateful (1)'], [label_counts[0], label_counts[1]])
    plt.title('Class Distribution in Dataset')
    plt.ylabel('Count')
    plt.savefig('class_distribution.png')
    plt.close()
    
    print(f"Class distribution: Non-Hateful={label_counts[0]}, Hateful={label_counts[1]}")
    print(f"Class imbalance ratio: {label_counts[0]/label_counts[1]:.2f}:1")
    
    return label_counts

def generate_word_cloud(dataset):
    """Generate a word cloud of important words based on TF-IDF"""
    # Extract all text from dataset
    all_texts = [item['text'] for item in dataset.data]
    
    # Separate hateful and non-hateful texts
    hateful_texts = [item['text'] for item in dataset.data if item['label'] == 1]
    non_hateful_texts = [item['text'] for item in dataset.data if item['label'] == 0]
    
    # Calculate TF-IDF
    tfidf = TfidfVectorizer(stop_words='english', max_features=100)
    tfidf_matrix = tfidf.fit_transform(all_texts)
    
    # Get feature names and TF-IDF scores
    feature_names = tfidf.get_feature_names_out()
    tfidf_scores = tfidf_matrix.sum(axis=0).A1
    
    # Create a word cloud of high TF-IDF terms
    word_scores = {feature_names[i]: tfidf_scores[i] for i in range(len(feature_names))}
    
    # Generate word cloud for all texts
    wordcloud = WordCloud(width=800, height=400, background_color='white').generate_from_frequencies(word_scores)
    plt.figure(figsize=(10, 5))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis('off')
    plt.title('Word Cloud of Important Terms (High TF-IDF)')
    plt.savefig('wordcloud_all.png')
    plt.close()
    
    # Generate word clouds for hateful and non-hateful separately
    if hateful_texts:
        tfidf_hateful = TfidfVectorizer(stop_words='english', max_features=100)
        tfidf_matrix_hateful = tfidf_hateful.fit_transform(hateful_texts)
        feature_names_hateful = tfidf_hateful.get_feature_names_out()
        tfidf_scores_hateful = tfidf_matrix_hateful.sum(axis=0).A1
        word_scores_hateful = {feature_names_hateful[i]: tfidf_scores_hateful[i] for i in range(len(feature_names_hateful))}
        
        wordcloud_hateful = WordCloud(width=800, height=400, background_color='white').generate_from_frequencies(word_scores_hateful)
        plt.figure(figsize=(10, 5))
        plt.imshow(wordcloud_hateful, interpolation='bilinear')
        plt.axis('off')
        plt.title('Word Cloud of Important Terms in Hateful Memes')
        plt.savefig('wordcloud_hateful.png')
        plt.close()
    
    if non_hateful_texts:
        tfidf_non_hateful = TfidfVectorizer(stop_words='english', max_features=100)
        tfidf_matrix_non_hateful = tfidf_non_hateful.fit_transform(non_hateful_texts)
        feature_names_non_hateful = tfidf_non_hateful.get_feature_names_out()
        tfidf_scores_non_hateful = tfidf_matrix_non_hateful.sum(axis=0).A1
        word_scores_non_hateful = {feature_names_non_hateful[i]: tfidf_scores_non_hateful[i] for i in range(len(feature_names_non_hateful))}
        
        wordcloud_non_hateful = WordCloud(width=800, height=400, background_color='white').generate_from_frequencies(word_scores_non_hateful)
        plt.figure(figsize=(10, 5))
        plt.imshow(wordcloud_non_hateful, interpolation='bilinear')
        plt.axis('off')
        plt.title('Word Cloud of Important Terms in Non-Hateful Memes')
        plt.savefig('wordcloud_non_hateful.png')
        plt.close()

def visualize_sample_memes(dataset, num_samples=5):
    """Visualize sample memes from the dataset"""
    # Get indices of hateful and non-hateful memes
    hateful_indices = [i for i, item in enumerate(dataset.data) if item['label'] == 1]
    non_hateful_indices = [i for i, item in enumerate(dataset.data) if item['label'] == 0]
    
    # Sample from each class
    sampled_hateful = random.sample(hateful_indices, min(num_samples, len(hateful_indices)))
    sampled_non_hateful = random.sample(non_hateful_indices, min(num_samples, len(non_hateful_indices)))
    
    # Create a figure with subplots
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 8))
    
    # Plot hateful memes
    for i, idx in enumerate(sampled_hateful):
        item = dataset.data[idx]
        img_path = os.path.join(dataset.data_dir, item['img'])
        image = Image.open(img_path).convert('RGB')
        axes[0, i].imshow(image)
        axes[0, i].set_title(f"Hateful: {item['text']}", fontsize=8)
        axes[0, i].axis('off')
    
    # Plot non-hateful memes
    for i, idx in enumerate(sampled_non_hateful):
        item = dataset.data[idx]
        img_path = os.path.join(dataset.data_dir, item['img'])
        image = Image.open(img_path).convert('RGB')
        axes[1, i].imshow(image)
        axes[1, i].set_title(f"Non-Hateful: {item['text']}", fontsize=8)
        axes[1, i].axis('off')
    
    plt.tight_layout()
    plt.savefig('sample_memes.png')
    plt.close()

# Example usage:
if __name__ == "__main__":
    data_dir = "/kaggle/input/facebook-hateful-meme-dataset/data"
    
    # Create dataset
    train_dataset = HatefulMemesDataset(
        data_dir=data_dir,
        split='train',
        augment=True
    )
    
    # Create data loader with weighted sampling
    train_loader = DataLoader(
        train_dataset,
        batch_size=64,
        sampler=train_dataset.get_sampler(),  # Use weighted sampler
        num_workers=4
    )
    
    # Add more detailed analysis
    analyze_class_distribution(train_dataset)
    print("\nData Loading Statistics:")
    print(f"Total samples: {len(train_dataset)}")
    print(f"Number of batches: {len(train_loader)}")
    print(f"Effective samples per epoch: {len(train_loader) * train_loader.batch_size}")
    
    # Verify augmentation
    if train_dataset.augment:
        print("\nAugmentation Verification:")
        sample_idx = next(iter(train_dataset.get_sampler()))
        original_item = train_dataset.data[sample_idx]
        augmented_item = train_dataset[sample_idx]
        print(f"Original text: {original_item['text']}")
        print(f"Augmented text: {augmented_item['text']}")
    
    # Analyze and visualize the dataset
    generate_word_cloud(train_dataset)
    visualize_sample_memes(train_dataset)
    
    # Print a sample batch
    batch = next(iter(train_loader))
    print(f"Batch size: {len(batch['image'])}")
    print(f"Image shape: {batch['image'].shape}")
    print(f"Input IDs shape: {batch['input_ids'].shape}")
    print(f"Attention mask shape: {batch['attention_mask'].shape}")
    print(f"Labels: {batch['label']}")






[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Class distribution: Non-Hateful=5450, Hateful=3050
Class imbalance ratio: 1.79:1

Data Loading Statistics:
Total samples: 8500
Number of batches: 133
Effective samples per epoch: 8512

Augmentation Verification:
Original text: i'm sure glad i dodged that bullet fuck white women asians are cute and tight asf i should know, i watch anime and hentai
Augmented text: i'm sure glad i dodged that bullet fuck white women asians are cute and tight asf i should know, i watch anime and hentai
Batch size: 64
Image shape: torch.Size([64, 3, 224, 224])
Input IDs shape: torch.Size([64, 128])
Attention mask shape: torch.Size([64, 128])
Labels: tensor([1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 1.,
        0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
        0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
        1., 1., 0., 1., 1., 0., 1., 1., 1., 0.])


# Text Image Models

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class LSTMTextEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1, dropout=0.2, bidirectional=True):
        """
        LSTM-based text encoder
        
        Args:
            vocab_size: Size of the vocabulary
            embed_dim: Word embedding dimension
            hidden_dim: LSTM hidden dimension
            num_layers: Number of LSTM layers
            dropout: Dropout probability
            bidirectional: Whether to use bidirectional LSTM
        """
        super(LSTMTextEncoder, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
        self.dropout = nn.Dropout(dropout)
        self.num_directions = 2 if bidirectional else 1
        self.output_dim = hidden_dim * self.num_directions
        
    def forward(self, x, lengths=None):
        """
        Forward pass
        
        Args:
            x: Input tensor of token indices [batch_size, seq_len]
            lengths: Sequence lengths for packing (optional)
            
        Returns:
            output: Last hidden state [batch_size, hidden_dim*num_directions]
        """
        # Embed tokens
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)
        
        # Pack sequences if lengths are provided
        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            self.lstm.flatten_parameters()
            _, (hidden, _) = self.lstm(packed)
        else:
            # Otherwise, use standard LSTM
            self.lstm.flatten_parameters()
            output, (hidden, _) = self.lstm(embedded)
        
        # Concatenate bidirectional outputs
        if self.num_directions == 2:
            hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            hidden = hidden[-1]
            
        return hidden

class BERTTextEncoder(nn.Module):
    def __init__(self, pretrained_model='bert-base-uncased', freeze_bert=False):
        """
        BERT-based text encoder
        
        Args:
            pretrained_model: Name of the pretrained BERT model
            freeze_bert: Whether to freeze BERT parameters
        """
        super(BERTTextEncoder, self).__init__()
        
        # Load pretrained BERT model
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.output_dim = self.bert.config.hidden_size
        
        # Freeze BERT parameters if specified
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
    
    def forward(self, input_ids, attention_mask=None):
        """
        Forward pass
        
        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            
        Returns:
            cls_output: CLS token embedding [batch_size, hidden_size]
        """
        # BERT forward pass
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        # Use CLS token as the sentence representation
        cls_output = outputs.last_hidden_state[:, 0, :]
        
        return cls_output

class CNNImageEncoder(nn.Module):
    def __init__(self, out_dim=512):
        """
        CNN-based image encoder
        
        Args:
            out_dim: Output dimension
        """
        super(CNNImageEncoder, self).__init__()
        
        # CNN layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 28 * 28, 1024)
        self.fc2 = nn.Linear(1024, out_dim)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(0.5)
        
        self.output_dim = out_dim
        
    def forward(self, x):
        """
        Forward pass
        
        Args:
            x: Input image tensor [batch_size, 3, 224, 224]
            
        Returns:
            features: Image features [batch_size, out_dim]
        """
        # CNN feature extraction
        x = self.pool(F.relu(self.conv1(x)))  # -> [batch_size, 32, 112, 112]
        x = self.pool(F.relu(self.conv2(x)))  # -> [batch_size, 64, 56, 56]
        x = self.pool(F.relu(self.conv3(x)))  # -> [batch_size, 128, 28, 28]
        
        # Flatten
        x = x.view(x.size(0), -1)  # -> [batch_size, 128*28*28]
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        features = self.fc2(x)
        
        return features

class ResNetImageEncoder(nn.Module):
    def __init__(self, pretrained=True, out_dim=512):
        """
        ResNet-based image encoder
        
        Args:
            pretrained: Whether to use pretrained weights
            out_dim: Output dimension
        """
        super(ResNetImageEncoder, self).__init__()
        
        # Use newer ResNet initialization
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        
        # Remove final layers and add custom head
        modules = list(resnet.children())[:-2]  # Remove avg pool and fc
        self.features = nn.Sequential(*modules)
        
        # Add custom pooling and FC layers
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        
        # Doubled feature size due to concat of avg and max pool
        self.fc = nn.Sequential(
            nn.Linear(resnet.fc.in_features * 2, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, out_dim)
        )
        
        self.output_dim = out_dim
        
    def forward(self, x):
        x = self.features(x)
        
        # Combine average and max pooling
        avg_pooled = self.avg_pool(x).flatten(1)
        max_pooled = self.max_pool(x).flatten(1)
        pooled = torch.cat([avg_pooled, max_pooled], dim=1)
        
        features = self.fc(pooled)
        return F.normalize(features, p=2, dim=1)  # L2 normalize features

class EarlyFusionLSTMCNNModel(nn.Module):
    def __init__(self, vocab_size=30522, embed_dim=300, hidden_dim=256, num_layers=2, 
                 dropout=0.3, bidirectional=True, fusion_dim=512):
        super().__init__()
        
        # Text Encoder: LSTM
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
        self.text_dropout = nn.Dropout(dropout)
        self.num_directions = 2 if bidirectional else 1
        self.text_output_dim = hidden_dim * self.num_directions
        
        # Image Encoder: CNN
        self.image_encoder = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2),  # 112x112
            
            # Second conv block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2),  # 56x56
            
            # Third conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2),  # 28x28
        )
        
        # Early fusion: combine features before final processing
        # Calculate the output size of the CNN
        self.cnn_output_size = 128 * 28 * 28
        
        # Cross-modal attention mechanism
        self.use_attention = True
        if self.use_attention:
            self.text_attention = nn.Linear(self.text_output_dim, fusion_dim)
            self.image_attention = nn.Linear(self.cnn_output_size, fusion_dim)
            self.attention_weights = nn.Linear(fusion_dim, 2)
        
        # Fusion layers
        combined_dim = self.text_output_dim + self.cnn_output_size
        self.fusion = nn.Sequential(
            nn.Linear(combined_dim, fusion_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(fusion_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )
        
        # Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                if m.padding_idx is not None:
                    nn.init.zeros_(m.weight[m.padding_idx])
                    
    def forward(self, images, input_ids, attention_mask):
        # Process text with LSTM
        # Calculate sequence lengths from attention mask
        seq_lengths = attention_mask.sum(dim=1).cpu()
        
        # Embed tokens
        embedded = self.embedding(input_ids)
        embedded = self.text_dropout(embedded)
        
        # Pack sequences for LSTM
        packed = pack_padded_sequence(
            embedded, seq_lengths, batch_first=True, enforce_sorted=False
        )
        
        # Process with LSTM
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(packed)
        
        # Concatenate bidirectional outputs
        if self.num_directions == 2:
            text_features = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            text_features = hidden[-1]
        
        # Process images with CNN
        image_features = self.image_encoder(images)
        image_features = image_features.view(image_features.size(0), -1)  # Flatten
        
        # Apply cross-modal attention if enabled
        if self.use_attention:
            text_proj = self.text_attention(text_features)
            image_proj = self.image_attention(image_features)
            
            # Calculate attention scores
            fusion_repr = text_proj + image_proj
            attention_scores = F.softmax(self.attention_weights(fusion_repr), dim=1)
            
            # Apply attention weights
            text_features = text_features * attention_scores[:, 0].unsqueeze(1)
            image_features = image_features * attention_scores[:, 1].unsqueeze(1)
        
        # Early fusion: concatenate features
        combined = torch.cat([text_features, image_features], dim=1)
        
        # Classification
        return self.fusion(combined)


## Evaluation

In [28]:
import torch
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter
import torchvision
import os
from datetime import datetime
import torch.nn as nn
from transformers import BertModel, AutoModel
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch.amp import autocast, GradScaler
from PIL import Image


class HatefulMemesModel(nn.Module):
    def __init__(self, bert_model='bert-base-uncased'):
        super().__init__()
        
        # 1. Simpler image encoder with more freezing
        self.image_encoder = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Freeze everything except final classifier
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        self.image_encoder.fc = nn.Sequential(
            nn.Linear(2048, 768),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # 2. Back to BERT with more freezing
        self.text_encoder = BertModel.from_pretrained(bert_model)
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        # Only unfreeze final layer
        for param in self.text_encoder.encoder.layer[-1:].parameters():
            param.requires_grad = True
            
        # 3. Much simpler fusion
        self.fusion = nn.Sequential(
            nn.Linear(768 * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1)
        )
        
        # 4. Better initialization
        self._init_weights()
        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, images, input_ids, attention_mask):
        # 4. Better feature extraction
        image_features = self.image_encoder(images)
        image_features = F.normalize(image_features, p=2, dim=1)
        
        text_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        # Use both last hidden state and pooler output
        last_hidden = text_outputs.last_hidden_state[:, 0, :]
        pooler = text_outputs.pooler_output
        text_features = (last_hidden + pooler) / 2
        text_features = F.normalize(text_features, p=2, dim=1)
        
        # Combine features
        combined = torch.cat([image_features, text_features], dim=1)
        return self.fusion(combined)

class LSTMCNNHatefulMemesModel(nn.Module):
    def __init__(self, vocab_size=30522, embed_dim=300, hidden_dim=256, num_layers=2, 
                 dropout=0.3, bidirectional=True):
        super().__init__()
        
        # 1. CNN Image Encoder (simpler than ResNet)
        self.image_encoder = nn.Sequential(
            # First conv block
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2),  # 112x112
            
            # Second conv block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2),  # 56x56
            
            # Third conv block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2),  # 28x28
            
            # Fourth conv block
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2, 2),  # 14x14
            
            # Fifth conv block
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(2, 2),  # 7x7
            
            # Global pooling
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.5)
        )
        
        # 2. LSTM Text Encoder
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )
        self.text_dropout = nn.Dropout(dropout)
        
        # Calculate output dimensions
        self.num_directions = 2 if bidirectional else 1
        self.text_output_dim = hidden_dim * self.num_directions
        self.image_output_dim = 512
        
        # 3. Fusion and classification
        self.fusion = nn.Sequential(
            nn.Linear(self.text_output_dim + self.image_output_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )
        
        # 4. Initialize weights
        self._init_weights()
        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                if m.padding_idx is not None:
                    m.weight.data[m.padding_idx].zero_()

    def forward(self, images, input_ids, attention_mask):
        # Process images with CNN
        image_features = self.image_encoder(images)
        
        # Process text with LSTM
        # Calculate sequence lengths from attention mask
        seq_lengths = attention_mask.sum(dim=1).cpu()
        
        # Embed tokens
        embedded = self.embedding(input_ids)
        embedded = self.text_dropout(embedded)
        
        # Pack sequences for LSTM
        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        packed = pack_padded_sequence(
            embedded, seq_lengths, batch_first=True, enforce_sorted=False
        )
        
        # Process with LSTM
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(packed)
        
        # Concatenate bidirectional outputs
        if self.num_directions == 2:
            text_features = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            text_features = hidden[-1]
        
        # Normalize features
        text_features = F.normalize(text_features, p=2, dim=1)
        image_features = F.normalize(image_features, p=2, dim=1)
        
        # Combine features
        combined = torch.cat([image_features, text_features], dim=1)
        
        # Classification
        return self.fusion(combined)

class KaggleHatefulMemesEvaluator:
    def __init__(self, model, device, log_dir='/kaggle/working/runs/hateful_memes'):
        """
        Initialize the evaluator for Kaggle environment
        
        Args:
            model: The model to evaluate
            device: torch.device for computation
            log_dir: Directory for TensorBoard logs (default: Kaggle working directory)
        """
        self.model = model
        self.device = device
        
        # Create timestamp-based directory to avoid conflicts
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.log_dir = f"{log_dir}_{timestamp}"
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        
        # Initialize best metrics
        self.best_auroc = 0
        self.best_f1 = 0
        
    def evaluate(self, dataloader, epoch=0, mode='val'):
        """Evaluate the model on the given dataloader"""
        self.model.eval()
        all_preds = []
        all_labels = []
        all_probs = []
        
        with torch.no_grad():
            for batch in dataloader:
                # Move data to device
                images = batch['image'].to(self.device)
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)
                
                # Forward pass
                outputs = self.model(images, input_ids, attention_mask)
                probs = torch.sigmoid(outputs)
                preds = (probs > 0.5).float()
                
                # Store predictions and labels
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
        
        # Convert to numpy arrays
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        all_probs = np.array(all_probs)
        
        # Calculate metrics
        metrics = self.calculate_metrics(all_labels, all_preds, all_probs)
        
        # Log to TensorBoard and save visualizations
        self.log_metrics(metrics, epoch, mode)
        self.generate_visualizations(metrics, epoch, mode)
        
        # Save best model if applicable
        if mode == 'val':
            self.save_best_model(metrics, epoch)
        
        return metrics
    
    def calculate_metrics(self, labels, preds, probs):
        """Calculate evaluation metrics"""
        auroc = roc_auc_score(labels, probs)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, 
                                                                 average='binary')
        cm = confusion_matrix(labels, preds)
        
        return {
            'auroc': auroc,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': cm,
            'true_labels': labels,  # Store for ROC curve
            'probabilities': probs  # Store for ROC curve
        }
    
    def log_metrics(self, metrics, epoch, mode):
        """Log metrics to TensorBoard"""
        for metric_name, value in metrics.items():
            # Skip arrays and only log scalar values
            if metric_name not in ['confusion_matrix', 'true_labels', 'probabilities'] and np.isscalar(value):
                self.writer.add_scalar(f'{mode}/{metric_name}', value, epoch)
        
        # Save metrics to CSV for Kaggle
        metrics_file = os.path.join(self.log_dir, f'{mode}_metrics.csv')
        
        # Create file with headers if it doesn't exist
        if not os.path.exists(metrics_file):
            with open(metrics_file, 'w') as f:
                f.write('epoch,auroc,precision,recall,f1\n')
        
        # Append metrics
        with open(metrics_file, 'a') as f:
            f.write(f"{epoch},{metrics['auroc']:.4f},{metrics['precision']:.4f},"
                    f"{metrics['recall']:.4f},{metrics['f1']:.4f}\n")
    
    def generate_visualizations(self, metrics, epoch, mode):
        """Generate and save visualizations"""
        # Confusion Matrix
        plt.figure(figsize=(8, 6))
        sns.heatmap(metrics['confusion_matrix'], annot=True, fmt='d',
                   xticklabels=['Non-Hateful', 'Hateful'],
                   yticklabels=['Non-Hateful', 'Hateful'])
        plt.title(f'Confusion Matrix - {mode.capitalize()} (Epoch {epoch})')
        
        # Save plot to Kaggle working directory
        plt.savefig(os.path.join(self.log_dir, f'{mode}_confusion_matrix_epoch_{epoch}.png'))
        
        # Add to TensorBoard
        confusion_fig = plt.gcf()
        self.writer.add_figure(f'{mode}/confusion_matrix', confusion_fig, epoch)
        plt.close()
        
        # ROC Curve
        fpr, tpr, thresholds = roc_curve(metrics['true_labels'], metrics['probabilities'])
        roc_auc = metrics['auroc']
        
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'ROC Curve - {mode.capitalize()} (Epoch {epoch})')
        plt.legend(loc="lower right")
        
        # Save ROC curve
        plt.savefig(os.path.join(self.log_dir, f'{mode}_roc_curve_epoch_{epoch}.png'))
        
        # Add to TensorBoard
        roc_fig = plt.gcf()
        self.writer.add_figure(f'{mode}/roc_curve', roc_fig, epoch)
        plt.close()
    
    def save_best_model(self, metrics, epoch):
        """Save best model based on AUROC and F1 score"""
        if metrics['auroc'] > self.best_auroc:
            self.best_auroc = metrics['auroc']
            torch.save(self.model.state_dict(), 
                      os.path.join(self.log_dir, 'best_model_auroc.pth'))
            
        if metrics['f1'] > self.best_f1:
            self.best_f1 = metrics['f1']
            torch.save(self.model.state_dict(), 
                      os.path.join(self.log_dir, 'best_model_f1.pth'))
    
    def log_sample_predictions(self, batch, outputs, epoch, mode):
        """Log sample predictions with images and text"""
        probs = torch.sigmoid(outputs)
        preds = (probs > 0.5).float()
        
        # Create visualization grid
        num_samples = min(8, len(batch['image']))
        fig, axes = plt.subplots(2, 4, figsize=(15, 8))
        axes = axes.ravel()
        
        for idx in range(num_samples):
            img = batch['image'][idx].cpu().permute(1, 2, 0)
            img = torch.clamp(img * torch.tensor([0.229, 0.224, 0.225]) + 
                            torch.tensor([0.485, 0.456, 0.406]), 0, 1)
            
            axes[idx].imshow(img)
            axes[idx].axis('off')
            
            title = f"Text: {batch['text'][idx]}\n"
            title += f"True: {'Hateful' if batch['label'][idx] else 'Non-Hateful'}\n"
            title += f"Pred: {probs[idx].item():.2f}"
            
            axes[idx].set_title(title, fontsize=8)
        
        plt.tight_layout()
        
        # Save to Kaggle working directory
        plt.savefig(os.path.join(self.log_dir, f'{mode}_samples_epoch_{epoch}.png'))
        self.writer.add_figure(f'{mode}/Sample_Predictions', fig, epoch)
        plt.close()
    
    def close(self):
        """Close TensorBoard writer"""
        self.writer.close()

    def analyze_model_performance(self, dataloader, epoch, mode='val'):
        """Analyze model performance by examining correct and incorrect predictions"""
        self.model.eval()
        
        correct_examples = {'images': [], 'texts': [], 'probs': [], 'labels': []}
        incorrect_examples = {'images': [], 'texts': [], 'probs': [], 'labels': []}
        
        with torch.no_grad():
            for batch in dataloader:
                # Move data to device
                images = batch['image'].to(self.device)
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)
                texts = batch.get('text', None)
                
                # Forward pass
                outputs = self.model(images, input_ids, attention_mask)
                probs = torch.sigmoid(outputs).cpu().numpy().flatten()
                preds = (probs > 0.5).astype(int)
                labels_np = labels.cpu().numpy()
                
                # Collect correct and incorrect examples
                for i in range(len(labels)):
                    example = {
                        'images': images[i].cpu(),
                        'probs': probs[i],
                        'labels': labels_np[i]
                    }
                    if texts is not None:
                        example['texts'] = texts[i]
                    
                    if preds[i] == labels_np[i]:
                        # Correct prediction
                        if len(correct_examples['images']) < 10:  # Limit to 10 examples
                            for k, v in example.items():
                                correct_examples[k].append(v)
                    else:
                        # Incorrect prediction
                        if len(incorrect_examples['images']) < 10:  # Limit to 10 examples
                            for k, v in example.items():
                                incorrect_examples[k].append(v)
                    
                    # Break if we have enough examples
                    if (len(correct_examples['images']) >= 10 and 
                        len(incorrect_examples['images']) >= 10):
                        break
        
        # Visualize correct examples
        if correct_examples['images']:
            fig_correct = self._create_examples_figure(correct_examples, "Correctly Classified Examples")
            self.writer.add_figure(f'{mode}/correct_examples', fig_correct, epoch)
            plt.savefig(os.path.join(self.log_dir, f'{mode}_correct_examples_epoch_{epoch}.png'))
            plt.close()
        
        # Visualize incorrect examples
        if incorrect_examples['images']:
            fig_incorrect = self._create_examples_figure(incorrect_examples, "Incorrectly Classified Examples")
            self.writer.add_figure(f'{mode}/incorrect_examples', fig_incorrect, epoch)
            plt.savefig(os.path.join(self.log_dir, f'{mode}_incorrect_examples_epoch_{epoch}.png'))
            plt.close()
        
        return correct_examples, incorrect_examples

    def _create_examples_figure(self, examples, title):
        """Create figure for analyzing examples"""
        n_samples = min(5, len(examples['images']))
        fig, axes = plt.subplots(n_samples, 1, figsize=(12, 4*n_samples))
        if n_samples == 1:
            axes = [axes]
        
        for i in range(n_samples):
            # Convert tensor to numpy for visualization
            img = examples['images'][i].permute(1, 2, 0).numpy()
            # Denormalize
            img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img = np.clip(img, 0, 1)
            
            axes[i].imshow(img)
            
            label_text = "Hateful" if examples['labels'][i] == 1 else "Non-hateful"
            prob_text = f"Model confidence: {examples['probs'][i]:.3f}"
            
            if 'texts' in examples and examples['texts']:
                text = examples['texts'][i]
                axes[i].set_title(f"Label: {label_text} | {prob_text}\nText: {text[:100]}...")
            else:
                axes[i].set_title(f"Label: {label_text} | {prob_text}")
            
            axes[i].axis('off')
        
        plt.suptitle(title, fontsize=16)
        plt.tight_layout(rect=[0, 0, 1, 0.97])
        return fig

    def compare_models(self, models_dict, dataloader, mode='val'):
        """
        Compare performance of different models on the same validation set.
        
        Args:
            models_dict: Dictionary of {model_name: model}
            dataloader: Validation data loader
            mode: 'val' or 'test'
        """
        results = {name: {'preds': [], 'targets': [], 'probs': []} for name in models_dict}
        
        # Get predictions from each model
        for name, model in models_dict.items():
            model.eval()
            with torch.no_grad():
                for batch in dataloader:
                    images = batch['image'].to(self.device)
                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['label'].to(self.device)
                    
                    outputs = model(images, input_ids, attention_mask)
                    probs = torch.sigmoid(outputs).cpu().numpy().flatten()
                    preds = (probs > 0.5).astype(int)
                    
                    results[name]['probs'].extend(probs)
                    results[name]['preds'].extend(preds)
                    results[name]['targets'].extend(labels.cpu().numpy())
        
        # Create ROC curve comparison
        plt.figure(figsize=(10, 8))
        
        for name, data in results.items():
            fpr, tpr, _ = roc_curve(data['targets'], data['probs'])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.3f})')
        
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curves Comparison')
        plt.legend(loc="lower right")
        
        # Save and add to TensorBoard
        plt.savefig(os.path.join(self.log_dir, f'{mode}_model_comparison_roc.png'))
        self.writer.add_figure(f'{mode}/model_comparison_roc', plt.gcf())
        plt.close()
        
        # Create confusion matrices
        n_models = len(models_dict)
        fig, axes = plt.subplots(1, n_models, figsize=(6*n_models, 5))
        if n_models == 1:
            axes = [axes]
        
        for i, (name, data) in enumerate(results.items()):
            cm = confusion_matrix(data['targets'], data['preds'])
            
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                       xticklabels=['Non-hateful', 'Hateful'],
                       yticklabels=['Non-hateful', 'Hateful'], ax=axes[i])
            axes[i].set_xlabel('Predicted labels')
            axes[i].set_ylabel('True labels')
            axes[i].set_title(f'Confusion Matrix - {name}')
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, f'{mode}_model_comparison_cm.png'))
        self.writer.add_figure(f'{mode}/model_comparison_cm', plt.gcf())
        plt.close()
        
        # Calculate and print metrics
        metrics_table = []
        for name, data in results.items():
            auroc = roc_auc_score(data['targets'], data['probs'])
            precision, recall, f1, _ = precision_recall_fscore_support(
                data['targets'], data['preds'], average='binary')
            
            metrics_table.append([
                name, f"{auroc:.4f}", f"{precision:.4f}", 
                f"{recall:.4f}", f"{f1:.4f}"
            ])
        
        # Create metrics comparison table
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.axis('tight')
        ax.axis('off')
        
        table = ax.table(
            cellText=metrics_table,
            colLabels=['Model', 'AUROC', 'Precision', 'Recall', 'F1'],
            loc='center',
            cellLoc='center'
        )
        table.auto_set_font_size(False)
        table.set_fontsize(12)
        table.scale(1.2, 1.5)
        ax.set_title('Model Performance Metrics Comparison', fontsize=14)
        
        plt.savefig(os.path.join(self.log_dir, f'{mode}_model_comparison_metrics.png'))
        self.writer.add_figure(f'{mode}/model_comparison_metrics', plt.gcf())
        plt.close()
        
        # Write metrics to CSV
        metrics_file = os.path.join(self.log_dir, f'{mode}_model_comparison.csv')
        with open(metrics_file, 'w') as f:
            f.write('model,auroc,precision,recall,f1\n')
            for row in metrics_table:
                f.write(','.join(row) + '\n')
        
        return results

    def generate_analysis_report(self, model_results=None, dataset_info=None):
        """Generate a written analysis report"""
        report_path = os.path.join(self.log_dir, 'model_analysis.md')
        
        with open(report_path, 'w') as f:
            f.write("# Hateful Memes Detection Model Analysis\n\n")
            
            # Dataset information
            if dataset_info:
                f.write("## Dataset Information\n\n")
                f.write(f"- Training samples: {dataset_info.get('train_samples', 'N/A')}\n")
                f.write(f"- Validation samples: {dataset_info.get('val_samples', 'N/A')}\n")
                f.write(f"- Test samples: {dataset_info.get('test_samples', 'N/A')}\n")
                f.write(f"- Class distribution: {dataset_info.get('class_distribution', 'N/A')}\n\n")
            
            # Model architecture
            f.write("## Model Architecture\n\n")
            f.write("Our model uses a late fusion approach, combining:\n")
            f.write("- **Image Encoder**: ResNet50 pretrained on ImageNet\n")
            f.write("- **Text Encoder**: BERT pretrained on large text corpus\n")
            f.write("- **Fusion Strategy**: Concatenation of image and text features\n\n")
            
            # Performance metrics
            f.write("## Performance Metrics\n\n")
            if model_results:
                f.write("| Model | AUROC | Precision | Recall | F1 |\n")
                f.write("|-------|-------|-----------|--------|----|\n")
                for name, metrics in model_results.items():
                    auroc = metrics.get('auroc', 'N/A')
                    precision = metrics.get('precision', 'N/A')
                    recall = metrics.get('recall', 'N/A')
                    f1 = metrics.get('f1', 'N/A')
                    f.write(f"| {name} | {auroc} | {precision} | {recall} | {f1} |\n")
                f.write("\n")
            
            # Analysis of results
            f.write("## Analysis of Results\n\n")
            
            # BERT vs LSTM analysis
            f.write("### BERT vs LSTM Performance\n\n")
            f.write("BERT outperforms LSTM for text encoding in this task for several reasons:\n\n")
            f.write("1. **Contextual Understanding**: BERT's bidirectional attention mechanism captures context in both directions, essential for understanding nuanced hate speech.\n")
            f.write("2. **Pre-training Advantage**: BERT is pre-trained on a massive corpus, giving it strong language understanding capabilities.\n")
            f.write("3. **Handling of Out-of-Vocabulary Words**: BERT's WordPiece tokenization handles rare words better than LSTM's fixed vocabulary.\n")
            f.write("4. **Attention to Important Words**: BERT's self-attention mechanism focuses on relevant words for classification.\n\n")
            
            # Fusion strategy analysis
            f.write("### Fusion Strategy Impact\n\n")
            f.write("Late fusion (concatenation of features) works well for this task because:\n\n")
            f.write("1. **Modality Independence**: It allows each modality to be processed by specialized architectures.\n")
            f.write("2. **Feature Normalization**: L2 normalization before fusion prevents one modality from dominating.\n")
            f.write("3. **Complementary Information**: Text and image features provide complementary signals for classification.\n\n")
            
            # Error analysis
            f.write("### Error Analysis\n\n")
            f.write("Common patterns in misclassified examples:\n\n")
            f.write("1. **Subtle Hate Speech**: The model struggles with examples where hate is implied rather than explicit.\n")
            f.write("2. **Multimodal Understanding**: Some examples require complex reasoning about the relationship between text and image.\n")
            f.write("3. **Cultural Context**: Memes that require specific cultural knowledge are challenging.\n\n")
            
            # Limitations
            f.write("## Limitations\n\n")
            f.write("1. **Class Imbalance**: The dataset contains more non-hateful than hateful memes, potentially biasing the model.\n")
            f.write("2. **Dataset Size**: The limited size of the dataset may not capture the full diversity of hateful content.\n")
            f.write("3. **Modality Bias**: The model may rely too heavily on either text or image features in certain cases.\n")
            f.write("4. **Generalization**: The model may not generalize well to new types of hateful content or different visual styles.\n\n")
            
            # Future improvements
            f.write("## Future Improvements\n\n")
            f.write("1. **Cross-modal Attention**: Implementing attention mechanisms between modalities could improve fusion.\n")
            f.write("2. **Data Augmentation**: More sophisticated augmentation techniques could help address class imbalance.\n")
            f.write("3. **Ensemble Methods**: Combining multiple models could improve robustness.\n")
            f.write("4. **Explainability**: Adding visualization of attention weights could help interpret model decisions.\n")
        
        print(f"Analysis report generated at {report_path}")
        return report_path

# Just start using the classes and functions defined above
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


from torch.utils.data import DataLoader

# Create datasets
data_dir = "/kaggle/input/facebook-hateful-meme-dataset/data"
train_dataset = HatefulMemesDataset(
    data_dir=data_dir,
    split='train',
    augment=True
)

val_dataset = HatefulMemesDataset(
    data_dir=data_dir,
    split='dev',  # Using dev set for validation
    augment=False
)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=32,  # Smaller batch size for LSTM
    sampler=train_dataset.get_sampler(),  # Use weighted sampler
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2
)

# Initialize early fusion model
model = EarlyFusionLSTMCNNModel(
    vocab_size=30522,
    embed_dim=300,
    hidden_dim=256,
    num_layers=2,
    dropout=0.5,  # Increase dropout
    bidirectional=True,
    fusion_dim=512
)
model = model.to(device)

# Initialize optimizer and loss function
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',
    factor=0.5,
    patience=2,
    verbose=True
)

# Loss function with class weights
pos_weight = torch.tensor([2.0]).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Initialize evaluator
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = f"/kaggle/working/runs/early_fusion_{timestamp}"
os.makedirs(log_dir, exist_ok=True)  # Make sure the directory exists
evaluator = KaggleHatefulMemesEvaluator(model, device, log_dir=log_dir)

# Training loop
num_epochs = 10
early_stopping_patience = 5
no_improve_epochs = 0
best_val_auroc = 0

# Initialize gradient scaler for mixed precision training
scaler = GradScaler()

for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_loss = 0
    
    for batch_idx, batch in enumerate(train_loader):
        # Move data to device
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].float().to(device)
        
        # Mixed precision training - specify device type
        with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
            outputs = model(images, input_ids, attention_mask)
            loss = criterion(outputs.squeeze(), labels)
        
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        
        # Log batch loss to TensorBoard
        evaluator.writer.add_scalar('Batch/train_loss', loss.item(), 
                                   epoch * len(train_loader) + batch_idx)
        
        # Print progress
        if (batch_idx + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], '
                  f'Loss: {loss.item():.4f}, LR: {optimizer.param_groups[0]["lr"]:.2e}')
    
    avg_train_loss = total_loss / len(train_loader)
    print(f'\nEpoch [{epoch+1}/{num_epochs}], Average Training Loss: {avg_train_loss:.4f}')
    
    # Log average training loss
    evaluator.writer.add_scalar('Epoch/train_loss', avg_train_loss, epoch)
    
    # Validation phase
    val_metrics = evaluator.evaluate(val_loader, epoch, mode='val')
    print(f"Validation Metrics:")
    print(f"AUROC: {val_metrics['auroc']:.3f}")
    print(f"F1: {val_metrics['f1']:.3f}")
    
    # Update learning rate based on validation performance
    scheduler.step(val_metrics['auroc'])
    
    # Log sample predictions
    batch = next(iter(val_loader))
    with torch.no_grad():
        outputs = model(batch['image'].to(device),
                      batch['input_ids'].to(device),
                      batch['attention_mask'].to(device))
    evaluator.log_sample_predictions(batch, outputs, epoch, mode='val')
    
    # Analyze model performance (every 2 epochs to save time)
    if epoch % 2 == 0:
        evaluator.analyze_model_performance(val_loader, epoch, mode='val')
    
    # Early stopping check
    if val_metrics['auroc'] > best_val_auroc:
        best_val_auroc = val_metrics['auroc']
        torch.save(model.state_dict(), os.path.join(log_dir, 'best_model.pth'))
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        
    if no_improve_epochs >= early_stopping_patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

# Final evaluation
print("\nTraining completed! Loading best model for final evaluation...")
try:
    checkpoint = torch.load(os.path.join(log_dir, 'best_model.pth'))
    model.load_state_dict(checkpoint)
    print("Successfully loaded best model.")
    
    # Final validation evaluation
    final_metrics = evaluator.evaluate(val_loader, epoch='final', mode='val')
    print(f"Final Validation AUROC: {final_metrics['auroc']:.4f}")
    print(f"Final Validation F1: {final_metrics['f1']:.4f}")
    
except Exception as e:
    print(f"Error loading best model: {e}")

# Generate analysis report
dataset_info = {
    'train_samples': len(train_dataset),
    'val_samples': len(val_dataset),
    'class_distribution': "See training data distribution"
}

model_results = {
    'Early Fusion (LSTM+CNN)': {
        'auroc': final_metrics['auroc'] if 'final_metrics' in locals() else best_val_auroc,
        'precision': final_metrics['precision'] if 'final_metrics' in locals() else 0,
        'recall': final_metrics['recall'] if 'final_metrics' in locals() else 0,
        'f1': final_metrics['f1'] if 'final_metrics' in locals() else 0
    }
}

# Generate the analysis report
evaluator.generate_analysis_report(model_results, dataset_info)

# Close TensorBoard writer
evaluator.close()

print("\nEvaluation complete! Check the log directory for visualizations and analysis.")

Using device: cuda


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b3b78429940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b3b78429940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

Epoch [1/10], Batch [10/266], Loss: 15.6878, LR: 1.00e-03
Epoch [1/10], Batch [20/266], Loss: 9.9994, LR: 1.00e-03
Epoch [1/10], Batch [30/266], Loss: 9.1138, LR: 1.00e-03
Epoch [1/10], Batch [40/266], Loss: 13.1245, LR: 1.00e-03
Epoch [1/10], Batch [50/266], Loss: 5.1850, LR: 1.00e-03
Epoch [1/10], Batch [60/266], Loss: 4.8027, LR: 1.00e-03
Epoch [1/10], Batch [70/266], Loss: 7.9640, LR: 1.00e-03
Epoch [1/10], Batch [80/266], Loss: 5.2496, LR: 1.00e-03
Epoch [1/10], Batch [90/266], Loss: 4.8649, LR: 1.00e-03
Epoch [1/10], Batch [100/266], Loss: 6.0364, LR: 1.00e-03
Epoch [1/10], Batch [110/266], Loss: 3.0715, LR: 1.00e-03
Epoch [1/10], Batch [120/266], Loss: 4.6340, LR: 1.00e-03
Epoch [1/10], Batch [130/266], Loss: 2.4763, LR: 1.00e-03
Epoch [1/10], Batch [140/266], Loss: 3.4526, LR: 1.00e-03
Epoch [1/10], Batch [150/266], Loss: 2.2232, LR: 1.00e-03
Epoch [1/10], Batch [160/266], Loss: 1.9314, LR: 1.00e-03
Epoch [1/10], Batch [170/266], Loss: 1.1724, LR: 1.00e-03
Epoch [1/10], Batch [

In [30]:
!zip -r output.zip /kaggle/working/runs/early_fusion_20250519_091941_20250519_091941

  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/ (stored 0%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_roc_curve_epoch_2.png (deflated 11%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_samples_epoch_7.png (deflated 0%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_confusion_matrix_epoch_8.png (deflated 19%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_samples_epoch_5.png (deflated 0%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_incorrect_examples_epoch_8.png (deflated 90%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_samples_epoch_9.png (deflated 0%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_confusion_matrix_epoch_5.png (deflated 20%)
  adding: kaggle/working/runs/early_fusion_20250519_091941_20250519_091941/val_correct

In [27]:
!rm -rf /kaggle/working/*