In [None]:
# !pip install text_hammer

In [1]:
from torch.cuda.amp import GradScaler

In [2]:

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    CLIPModel,
    AutoModel,  # For MuRIL
    AutoTokenizer,  # For MuRIL
    get_linear_schedule_with_warmup
)
import albumentations as A
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler

In [3]:
# Training split of the Memotion 3 data (Google Drive mount under Colab).
# NOTE(review): train.csv lives two directories deeper than val.csv — confirm both
# paths point at the intended split.
train_df = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/memo_3/memotion3/memotion3/train.csv'
)

In [None]:
# Collapse the five Memotion sentiment labels into three integer classes:
# 0 = negative, 1 = neutral, 2 = positive.
# Values not in the table (e.g. a misspelled label) pass through `map` unchanged and are
# then rejected by `astype`, so a bad label fails here instead of inside the DataLoader.
# Already-mapped integers also pass through unchanged, so re-running this cell is safe.
OVERALL_TO_CLASS = {
    'very_positive': 2,
    'positive': 2,
    'neutral': 1,
    'very_negative': 0,
    'negative': 0,
}
train_df['overall'] = (
    train_df['overall']
    .map(lambda label: OVERALL_TO_CLASS.get(label, label))
    .astype('int64')
)

In [5]:
# Validation split of the Memotion 3 data (Google Drive mount under Colab).
val_df = pd.read_csv(
    filepath_or_buffer='/content/drive/MyDrive/memo_3/val.csv'
)

In [None]:
# Same 5 -> 3 class collapse as for the training split:
# 0 = negative, 1 = neutral, 2 = positive.
# Unknown labels survive `map` and are rejected by `astype`, failing loudly here.
# Already-mapped integers pass through unchanged, so re-running this cell is safe.
OVERALL_TO_CLASS = {
    'very_positive': 2,
    'positive': 2,
    'neutral': 1,
    'very_negative': 0,
    'negative': 0,
}
val_df['overall'] = (
    val_df['overall']
    .map(lambda label: OVERALL_TO_CLASS.get(label, label))
    .astype('int64')
)

In [7]:
import text_hammer as th

In [8]:
%%time

# `tqdm._tqdm_notebook` is a deprecated private alias of `tqdm.notebook`.
from tqdm.notebook import tqdm as tqdm_notebook
# Registers `Series.progress_apply` (pandas `.apply` with a notebook progress bar).
tqdm_notebook.pandas()


def text_preprocessing(df, col_name):
    """Clean a text column of ``df`` in place and return ``df``.

    Every value of ``df[col_name]`` is turned into a string and passed, in order,
    through these text_hammer helpers:
    ``cont_exp`` (you're -> you are; we'll be -> we will be), ``remove_emails``,
    ``remove_html_tags``, ``remove_special_chars``, ``remove_accented_chars``.
    Missing values (None/NaN) become "" instead of the literal string "nan";
    everything else becomes ``str(x).lower()`` first.

    Args:
        df: DataFrame to modify. It is mutated in place, and later cells rely on that.
        col_name: name of the text column to clean.

    Returns:
        ``df`` itself (not a copy).
    """
    def _clean(value):
        # Missing OCR text must stay empty rather than becoming the token "nan",
        # so downstream empty-caption handling can recognise it.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            text = ''
        else:
            text = str(value).lower()
        text = th.cont_exp(text)
        text = th.remove_emails(text)
        text = th.remove_html_tags(text)
        text = th.remove_special_chars(text)
        text = th.remove_accented_chars(text)
        return text

    # One pass over the column instead of one pass per cleaning step.
    df[col_name] = df[col_name].progress_apply(_clean)
    return df

CPU times: user 390 µs, sys: 0 ns, total: 390 µs
Wall time: 371 µs


In [None]:
# `text_preprocessing` cleans the 'ocr' column in place and returns the same DataFrame,
# so keep the results under their DataFrame names; the `MemeDataset` objects are
# created later from train_df / val_df.
train_df = text_preprocessing(train_df, 'ocr')
val_df = text_preprocessing(val_df, 'ocr')

In [None]:
# ===================== Dataset Class =====================
class MemeDataset(Dataset):
    """Map-style dataset returning one meme as image tensor, tokenized caption and label.

    Args:
        images: sequence of image file names, resolved relative to ``image_dir``.
        captions: sequence of captions aligned with ``images``. Strings are used as-is;
            lists/tuples are space-joined, None/NaN count as empty, other values are
            ``str()``-ed. Empty or whitespace-only captions become "empty caption".
        sentiments: sequence of integer class labels aligned with ``images``.
        tokenizer: HuggingFace-style tokenizer callable (used with ``return_tensors="pt"``).
        image_transforms: albumentations-style callable used as
            ``image_transforms(image=array)['image']`` that returns an HWC array.
        image_dir: directory holding the image files.

    Each item is a dict with keys ``'image'`` (CHW float tensor), ``'input_ids'`` and
    ``'attention_mask'`` (1-D tensors padded/truncated to 128 tokens) and
    ``'sentiment'`` (scalar long tensor).
    """

    def __init__(self, images, captions, sentiments, tokenizer, image_transforms, image_dir):
        self.images = images
        self.captions = captions
        self.sentiments = sentiments
        self.tokenizer = tokenizer
        self.image_transforms = image_transforms
        self.image_dir = image_dir

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _load_image(image_path):
        """Read ``image_path`` as an RGB uint8 HWC array; fall back to a gray 224x224 image."""
        try:
            # `with` closes the underlying file handle once the pixels are copied out.
            with Image.open(image_path) as img:
                return np.array(img.convert('RGB'))
        except Exception as exc:
            # Best effort: one unreadable file must not abort an epoch, but it must not
            # silently turn into a gray training sample either.
            import warnings
            warnings.warn(f"Could not load image {image_path!r} ({exc}); using a gray placeholder.")
            return np.full((224, 224, 3), 128, dtype=np.uint8)

    @staticmethod
    def _normalize_caption(caption):
        """Coerce a raw caption into a non-empty string for the tokenizer."""
        # Lists must be handled before any str() conversion, otherwise they would be
        # rendered as "['a', 'b']".
        if isinstance(caption, (list, tuple)):
            caption = ' '.join(str(part) for part in caption)
        elif caption is None or (isinstance(caption, float) and np.isnan(caption)):
            # Missing values (e.g. NaN read from a CSV) are treated as empty, not "nan".
            caption = ''
        elif not isinstance(caption, str):
            caption = str(caption)
        if not caption.strip():
            caption = "empty caption"
        return caption

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images[idx])
        image = self._load_image(image_path)

        # Apply transforms, then HWC array -> CHW float tensor.
        image = self.image_transforms(image=image)['image']
        image = torch.tensor(image).permute(2, 0, 1).float()

        caption = self._normalize_caption(self.captions[idx])

        # Encode caption with the MuRIL tokenizer. Every caption is padded/truncated to
        # 128 tokens so samples batch with the default collate_fn.
        encoded_caption = self.tokenizer(
            caption,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=128,
        )
        # squeeze(0) removes only the batch dimension added by return_tensors="pt".
        input_ids = encoded_caption['input_ids'].squeeze(0)
        attention_mask = encoded_caption['attention_mask'].squeeze(0)

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'sentiment': torch.tensor(self.sentiments[idx], dtype=torch.long),
        }


# ===================== Cross-Attention Fusion Module =====================
class CrossAttentionFusion(nn.Module):
    """Bidirectional cross-attention between image tokens and text tokens.

    Text tokens first attend to the image tokens (attention + FFN, each with a
    post-norm residual), then the image tokens attend to the *updated* text tokens.
    Both sequences are mean-pooled and the pooled vectors are concatenated.

    Args:
        embed_dim: feature size of both token sequences.
        num_heads: number of heads in each ``nn.MultiheadAttention``.
        dropout: dropout used inside attention, the FFNs and the residual branches.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.1):
        super(CrossAttentionFusion, self).__init__()

        # Bidirectional cross-attention
        self.text_to_image = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.image_to_text = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )

        # Feed-forward networks
        self.ffn_text = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout)
        )

        self.ffn_image = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout)
        )

        # Layer normalization
        self.norm_text_1 = nn.LayerNorm(embed_dim)
        self.norm_text_2 = nn.LayerNorm(embed_dim)
        self.norm_image_1 = nn.LayerNorm(embed_dim)
        self.norm_image_2 = nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, image_feats, text_feats, text_mask=None):
        """Fuse the two token sequences.

        Args:
            image_feats: (batch, n_image_tokens, embed_dim) tensor.
            text_feats: (batch, n_text_tokens, embed_dim) tensor.
            text_mask: optional (batch, n_text_tokens) mask, nonzero for real tokens and
                0 for padding (e.g. a tokenizer's ``attention_mask``). When given, padded
                text tokens are ignored as attention keys and excluded from text pooling.
                When None, every text token is used.

        Returns:
            ``(fused_features, image_pooled, text_pooled)`` with shapes
            (batch, 2 * embed_dim), (batch, embed_dim), (batch, embed_dim).
        """
        key_padding_mask = None
        if text_mask is not None:
            key_padding_mask = text_mask == 0  # True = ignore this key
            # A row with no valid token would make the attention softmax NaN; let such
            # rows fall back to attending over every position.
            key_padding_mask = key_padding_mask & ~key_padding_mask.all(dim=1, keepdim=True)

        # Text attends to image
        text_attended, _ = self.text_to_image(
            query=text_feats, key=image_feats, value=image_feats
        )
        text_feats = self.norm_text_1(text_feats + self.dropout(text_attended))
        text_feats = self.norm_text_2(text_feats + self.ffn_text(text_feats))

        # Image attends to the (updated) text; padded text tokens are not valid keys.
        image_attended, _ = self.image_to_text(
            query=image_feats, key=text_feats, value=text_feats,
            key_padding_mask=key_padding_mask,
        )
        image_feats = self.norm_image_1(image_feats + self.dropout(image_attended))
        image_feats = self.norm_image_2(image_feats + self.ffn_image(image_feats))

        # Pool and concatenate; text pooling averages only over real tokens.
        image_pooled = image_feats.mean(dim=1)
        if key_padding_mask is None:
            text_pooled = text_feats.mean(dim=1)
        else:
            token_weights = (~key_padding_mask).unsqueeze(-1).to(text_feats.dtype)
            # Every row has at least one weight of 1 (see the fallback above).
            text_pooled = (text_feats * token_weights).sum(dim=1) / token_weights.sum(dim=1)
        fused_features = torch.cat([image_pooled, text_pooled], dim=1)

        return fused_features, image_pooled, text_pooled


# ===================== Enhanced Loss Function with Class Weights =====================
class EnhancedLoss(nn.Module):
    """(Optionally class-weighted) cross-entropy plus a symmetric image-text contrastive term.

    total = CE(logits, labels) + contrastive_weight * (CE(S, i) + CE(S^T, i)) / 2,
    where S = normalize(image_feats) @ normalize(text_feats)^T / temperature and
    i = arange(batch): the k-th image in the batch is matched with the k-th text.

    Args:
        contrastive_weight: multiplier for the contrastive term.
        temperature: divisor applied to the cosine-similarity matrix.
        class_weights: optional per-class weight tensor for the classification CE.

    ``forward`` returns ``(total_loss, ce_loss, contrastive_loss)``.
    """

    def __init__(self, contrastive_weight=0.04, temperature=0.07, class_weights=None):
        super(EnhancedLoss, self).__init__()
        # Weighted cross entropy counteracts class imbalance when weights are given.
        if class_weights is not None:
            self.ce_loss = nn.CrossEntropyLoss(weight=class_weights)
        else:
            self.ce_loss = nn.CrossEntropyLoss()
        self.contrastive_weight = contrastive_weight
        self.temperature = temperature

    def forward(self, logits, labels, image_feats, text_feats):
        ce_loss = self.ce_loss(logits, labels)

        # Contrastive alignment on cosine similarities
        image_feats_norm = F.normalize(image_feats, dim=-1)
        text_feats_norm = F.normalize(text_feats, dim=-1)
        similarity = torch.matmul(image_feats_norm, text_feats_norm.T) / self.temperature

        batch_size = image_feats.size(0)
        # Positive pairs are on the diagonal.
        labels_contrastive = torch.arange(batch_size, device=image_feats.device)

        contrastive_loss = (
            F.cross_entropy(similarity, labels_contrastive) +
            F.cross_entropy(similarity.T, labels_contrastive)
        ) / 2

        total_loss = ce_loss + self.contrastive_weight * contrastive_loss
        return total_loss, ce_loss, contrastive_loss


# ===================== Main Model with MuRIL =====================
class CustomCLIPMuRILModel(nn.Module):
    """
    CLIP + MuRIL model for Hinglish meme sentiment classification.

    The CLIP vision tower encodes the image and MuRIL encodes the caption. Both token
    sequences are fused by ``CrossAttentionFusion``, with padding tokens masked out,
    and an MLP maps the fused vector to 3 logits.

    ``forward(image, input_ids, attention_mask)`` returns
    ``(logits, image_pooled, text_pooled)``.
    """

    def __init__(self, clip_model, muril_model):
        super(CustomCLIPMuRILModel, self).__init__()
        self.clip_model = clip_model
        self.muril_model = muril_model

        # Fine-tune both backbones end to end.
        for param in self.clip_model.parameters():
            param.requires_grad = True
        for param in self.muril_model.parameters():
            param.requires_grad = True

        # Project CLIP vision tokens into the 768-d space of MuRIL.
        # NOTE: 768 assumes clip-vit-base-patch32's vision width and a base-size MuRIL;
        # update these sizes if either backbone changes.
        self.image_proj = nn.Sequential(
            nn.Linear(768, 768),
            nn.LayerNorm(768),
            nn.GELU()
        )

        # Cross-attention fusion
        self.cross_fusion = CrossAttentionFusion(embed_dim=768, num_heads=8, dropout=0.1)

        # MLP classifier: 4 hidden layers (dropout 0.1-0.2) + a 3-way output layer
        self.classifier = nn.Sequential(
            # Layer 1
            nn.Linear(768 * 2, 768),
            nn.LayerNorm(768),
            nn.GELU(),
            nn.Dropout(0.1),

            # Layer 2
            nn.Linear(768, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.2),

            # Layer 3
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.2),

            # Layer 4
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(0.1),

            # Output layer
            nn.Linear(128, 3)
        )

    def forward(self, image, input_ids, attention_mask):
        # CLIP vision token features
        vision_outputs = self.clip_model.vision_model(pixel_values=image)
        image_features = vision_outputs.last_hidden_state  # [B, 50, 768] for ViT-B/32 at 224px
        image_features = self.image_proj(image_features)

        # MuRIL text token features. Only the last layer is used, so all hidden
        # states are not requested (they would only cost memory).
        muril_output = self.muril_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        text_features = muril_output.last_hidden_state  # [B, seq_len, 768]

        # Cross-attention fusion; padded caption tokens are masked out.
        fused_features, image_pooled, text_pooled = self.cross_fusion(
            image_features, text_features, text_mask=attention_mask
        )

        # Classification
        logits = self.classifier(fused_features)
        return logits, image_pooled, text_pooled


# ===================== Utility Classes =====================
class AvgMeter:
    """Running, sample-weighted average of a scalar such as a per-batch loss.

    Attributes (all 0 after ``reset``): ``val`` is the last value seen, ``sum`` is the
    weighted total, ``count`` is the total weight, and ``avg`` is ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded observation."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    With several param groups, only the first group's rate is reported.
    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']


# ===================== Training Function =====================
def train_epoch(model, train_loader, optimizer, scheduler, device, criterion, scaler=None):
    """Train ``model`` for one pass over ``train_loader``.

    For each batch this runs a forward pass and ``criterion`` (which returns total, CE and
    contrastive losses), then a backward pass, gradient clipping to a global L2 norm of 1.0,
    an optimizer step and one ``scheduler.step()``. If ``scaler`` (a GradScaler) is given,
    the forward pass and loss run under autocast and the backward pass uses loss scaling.

    Returns:
        ``(loss_meter, ce_loss_meter, contrastive_loss_meter, accuracy)``: three AvgMeter
        objects with sample-weighted running averages (read ``.avg``), and the training
        accuracy over the epoch (0.0 if the loader yields no samples).
    """
    model.train()
    loss_meter = AvgMeter()
    ce_loss_meter = AvgMeter()
    contrastive_loss_meter = AvgMeter()
    correct_predictions = 0
    total_predictions = 0
    # `torch.cuda.amp.autocast()` is deprecated; `torch.autocast` takes the device type.
    autocast_device_type = torch.device(device).type

    tqdm_object = tqdm(train_loader, total=len(train_loader))

    for batch in tqdm_object:
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        sentiments = batch['sentiment'].to(device)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.autocast(device_type=autocast_device_type):
                logits, image_feats, text_feats = model(images, input_ids, attention_mask)
                loss, ce_loss, contrastive_loss = criterion(
                    logits, sentiments, image_feats, text_feats
                )

            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits, image_feats, text_feats = model(images, input_ids, attention_mask)
            loss, ce_loss, contrastive_loss = criterion(
                logits, sentiments, image_feats, text_feats
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        # The LR schedule is defined in optimizer steps, i.e. one step per batch.
        scheduler.step()

        count = images.size(0)
        loss_meter.update(loss.item(), count)
        ce_loss_meter.update(ce_loss.item(), count)
        contrastive_loss_meter.update(contrastive_loss.item(), count)

        preds = logits.argmax(dim=1)
        correct_predictions += (preds == sentiments).sum().item()
        total_predictions += sentiments.size(0)

        tqdm_object.set_postfix(
            train_loss=loss_meter.avg,
            ce_loss=ce_loss_meter.avg,
            contrast_loss=contrastive_loss_meter.avg,
            lr=get_lr(optimizer)
        )

    accuracy = correct_predictions / total_predictions if total_predictions else 0.0
    return loss_meter, ce_loss_meter, contrastive_loss_meter, accuracy


# ===================== Evaluation Function =====================
def evaluate(model, data_loader, device, criterion):
    """Run ``model`` over ``data_loader`` in eval mode without gradient tracking.

    Returns:
        A 6-tuple ``(predictions, true_labels, loss, ce_loss, contrastive_loss, accuracy)``:
        per-sample predicted and true class ids (numpy scalars, in loader order), the
        sample-weighted mean of each of the three losses returned by ``criterion``, and
        the fraction of correct predictions.
    """
    model.eval()
    predictions, true_labels = [], []
    running = {'loss': AvgMeter(), 'ce': AvgMeter(), 'contrastive': AvgMeter()}
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch in data_loader:
            images, input_ids, attention_mask, sentiments = (
                batch[key].to(device)
                for key in ('image', 'input_ids', 'attention_mask', 'sentiment')
            )

            logits, image_feats, text_feats = model(images, input_ids, attention_mask)
            loss, ce_loss, contrastive_loss = criterion(
                logits, sentiments, image_feats, text_feats
            )

            batch_size = len(images)
            for meter, value in zip(running.values(), (loss, ce_loss, contrastive_loss)):
                meter.update(value.item(), batch_size)

            batch_preds = logits.argmax(dim=1)
            n_correct += (batch_preds == sentiments).sum().item()
            n_seen += sentiments.size(0)

            predictions.extend(batch_preds.cpu().numpy())
            true_labels.extend(sentiments.cpu().numpy())

    accuracy = n_correct / n_seen
    return (
        predictions,
        true_labels,
        running['loss'].avg,
        running['ce'].avg,
        running['contrastive'].avg,
        accuracy,
    )


# ===================== Visualization Functions =====================
def plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies,
                          save_path='training_history_muril.png'):
    """Plot per-epoch loss and accuracy curves side by side, save them and show them.

    Args:
        train_losses, val_losses: per-epoch mean losses.
        train_accuracies, val_accuracies: per-epoch accuracies.
        save_path: output image path (saved at 300 dpi).
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    for ax, train_values, val_values, metric in (
        (loss_ax, train_losses, val_losses, 'Loss'),
        (acc_ax, train_accuracies, val_accuracies, 'Accuracy'),
    ):
        # Number epochs from 1 to match the "Epoch k/N" lines printed during training.
        ax.plot(range(1, len(train_values) + 1), train_values, label=f'Train {metric}', marker='o')
        ax.plot(range(1, len(val_values) + 1), val_values, label=f'Val {metric}', marker='s')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.set_title(f'Training and Validation {metric}')
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


def plot_confusion_matrix(true_labels, predictions, class_names=('Negative', 'Neutral', 'Positive'),
                          title='Confusion Matrix - CLIP + MuRIL',
                          save_path='confusion_matrix_muril.png'):
    """Draw, save and show an annotated confusion-matrix heatmap.

    Args:
        true_labels, predictions: integer class ids in ``range(len(class_names))``.
        class_names: display names, indexed by class id.
        title: figure title.
        save_path: output image path (saved at 300 dpi).
    """
    class_names = list(class_names)
    # Fix the label set so the matrix is always len(class_names) x len(class_names),
    # even when a class is absent from both the labels and the predictions.
    cm = confusion_matrix(true_labels, predictions, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# ===================== Main Training Script =====================
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# The CLIP vision encoder was pre-trained with its own normalization statistics,
# not ImageNet's. These are the values used by OpenAI's CLIP preprocessing.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Data augmentation (training only).
# NOTE(review): HorizontalFlip mirrors any text rendered inside the meme image;
# consider dropping it if in-image text matters to the model.
train_image_transforms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.CoarseDropout(max_holes=8, max_height=16, max_width=16, p=0.2),
    A.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
])

# Deterministic preprocessing for evaluation.
val_image_transforms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
])


# Load models
print("\n" + "="*70)
print("LOADING MODELS: CLIP + MuRIL")
print("="*70)
print("Loading CLIP (Vision Encoder)...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# One checkpoint id for both tokenizer and encoder so they can never drift apart.
MURIL_CHECKPOINT = 'google/muril-base-cased'
print("Loading MuRIL (Hinglish Text Encoder)...")
muril_tokenizer = AutoTokenizer.from_pretrained(MURIL_CHECKPOINT)
muril_model = AutoModel.from_pretrained(MURIL_CHECKPOINT)

print("✓ Models loaded successfully!")
# The CLIP count includes its text tower, which this model never uses.
print(f"  CLIP: {sum(p.numel() for p in clip_model.parameters()):,} parameters")
print(f"  MuRIL (cased): {sum(p.numel() for p in muril_model.parameters()):,} parameters")
print("\n💡 MuRIL-cased Features:")
print("  - Trained on 17 Indian languages")
print("  - Excellent Hinglish support")
print("  - Cased vocabulary (this notebook lowercases OCR text during preprocessing)")
print("  - Understands code-mixing naturally")
print("  - Indian cultural context awareness")
print("="*70)

# Create model
model = CustomCLIPMuRILModel(clip_model, muril_model)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\nModel Statistics:")
print(f"  Total parameters:     {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

# Dataset paths
train_image_dir = '/content/drive/MyDrive/memo_3/trainImages/trainImages'
val_image_dir = '/content/drive/MyDrive/memo_3/valImages/valImages'

# Create datasets
print("\nCreating datasets with MuRIL tokenizer...")
train_dataset = MemeDataset(
    images=train_df['image_url'].tolist(),
    captions=train_df['ocr'].tolist(),
    sentiments=train_df['overall'].tolist(),
    tokenizer=muril_tokenizer,  # Using MuRIL tokenizer
    image_transforms=train_image_transforms,
    image_dir=train_image_dir
)

val_dataset = MemeDataset(
    images=val_df['image_url'].tolist(),
    captions=val_df['ocr'].tolist(),
    sentiments=val_df['overall'].tolist(),
    tokenizer=muril_tokenizer,  # Using MuRIL tokenizer
    image_transforms=val_image_transforms,
    image_dir=val_image_dir
)

# Pinned host memory speeds up the per-batch `.to(device)` copies when training on a GPU.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4,
                          pin_memory=(device == 'cuda'))
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4,
                        pin_memory=(device == 'cuda'))

# Inspect the label balance of both splits, then derive class weights for the loss.
print("\n" + "="*70)
print("CLASS DISTRIBUTION ANALYSIS")
print("="*70)

train_class_counts = train_df['overall'].value_counts().sort_index()
val_class_counts = val_df['overall'].value_counts().sort_index()

for split_title, split_counts, split_df in (
    ("\nTraining Set:", train_class_counts, train_df),
    ("\nValidation Set:", val_class_counts, val_df),
):
    print(split_title)
    for class_idx in range(3):
        count = split_counts.get(class_idx, 0)
        percentage = (count / len(split_df)) * 100
        print(f"  Class {class_idx}: {count:,} samples ({percentage:.2f}%)")

# Inverse-frequency weights, rescaled so they sum to 3 (the number of classes).
# A class absent from the training split is counted as 1 sample to avoid dividing by 0.
class_counts = torch.tensor(
    [train_class_counts.get(i, 1) for i in range(3)], dtype=torch.float32
)
inverse_frequency = 1.0 / class_counts
class_weights = (inverse_frequency / inverse_frequency.sum() * 3).to(device)

print(f"\nComputed Class Weights: {class_weights.cpu().numpy()}")
print("="*70)

# Optimizer: a small learning rate for the pretrained backbones (CLIP, MuRIL) and a
# 250x larger one for the freshly initialised projection, fusion and classifier heads.
optimizer = torch.optim.AdamW([
    {'params': model.clip_model.parameters(), 'lr': 2e-6},
    {'params': model.muril_model.parameters(), 'lr': 2e-6},
    {'params': model.image_proj.parameters(), 'lr': 5e-4},
    {'params': model.cross_fusion.parameters(), 'lr': 5e-4},
    {'params': model.classifier.parameters(), 'lr': 5e-4}
], weight_decay=1e-4)

num_epochs = 10
# Linear warm-up over the first 3 epochs, then linear decay to 0. Both are counted in
# batches because train_epoch calls scheduler.step() once per batch.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=len(train_loader) * 3,
    num_training_steps=len(train_loader) * num_epochs
)

# Class-weighted cross-entropy plus a small image-text contrastive term.
criterion = EnhancedLoss(
    contrastive_weight=0.04,  # Reduced from 0.1
    temperature=0.07,
    class_weights=class_weights
)
scaler = GradScaler() if device == 'cuda' else None

print(f"\n{'='*70}")
print("LOSS CONFIGURATION")
print(f"{'='*70}")
ce_kind = "Class-weighted" if class_weights is not None else "Unweighted"
print(f"Using: {ce_kind} Cross-Entropy + "
      f"{criterion.contrastive_weight} x image-text contrastive loss "
      f"(temperature={criterion.temperature})")
print("Focus: Sentiment classification (contrastive term aligns image/text features)")
print(f"{'='*70}\n")

# Training tracking
BEST_MODEL_PATH = 'best_model_muril.pth'
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
# Start below any reachable accuracy so the first epoch always writes a checkpoint.
# Otherwise the torch.load below fails if validation accuracy never rises above 0.
best_val_accuracy = -1.0

print("\n" + "="*70)
print("STARTING TRAINING WITH CLIP + MuRIL")
print("="*70)

for epoch in range(num_epochs):
    print(f"\n{'='*50}")
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"{'='*50}")

    # Train: returns three AvgMeters and the training accuracy.
    train_loss_meter, train_ce_loss, train_contrastive_loss, train_accuracy = train_epoch(
        model, train_loader, optimizer, scheduler, device, criterion, scaler
    )

    # Validate: returns predictions, labels, three mean losses and the accuracy.
    val_predictions, val_true_labels, val_loss, val_ce_loss, val_contrastive_loss, val_accuracy = evaluate(
        model, val_loader, device, criterion
    )

    # Keep the checkpoint with the highest validation accuracy (strict improvement).
    # NOTE: the same validation split is also used for the final report, so the final
    # numbers are optimistically biased.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_accuracy': val_accuracy,
        }, BEST_MODEL_PATH)
        print(f"✓ Best model saved with validation accuracy: {val_accuracy:.4f}")

    # Store metrics
    train_losses.append(train_loss_meter.avg)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # Print summary
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Train Loss: {train_loss_meter.avg:.4f}")
    print(f"  Train CE Loss: {train_ce_loss.avg:.4f}")
    print(f"  Train Contrastive Loss: {train_contrastive_loss.avg:.4f}")
    print(f"  Train Accuracy: {train_accuracy:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  Val CE Loss: {val_ce_loss:.4f}")
    print(f"  Val Contrastive Loss: {val_contrastive_loss:.4f}")
    print(f"  Val Accuracy: {val_accuracy:.4f}")
    print(f"  Best Val Accuracy: {best_val_accuracy:.4f}")

# Load best model
print(f"\nLoading best model with validation accuracy: {best_val_accuracy:.4f}")
# map_location lets a checkpoint written on one device be restored on the current `device`.
checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])


In [None]:
# Final evaluation on validation set.
# NOTE: this is the split used for checkpoint selection, so these scores are
# optimistically biased; use a held-out test split for unbiased numbers.
print("\n" + "="*70)
print("FINAL EVALUATION ON VALIDATION SET")
print("="*70)

val_predictions, val_true_labels, val_loss, val_ce_loss, val_contrastive_loss, val_accuracy = evaluate(
    model, val_loader, device, criterion
)

print(f"\nFinal Validation Accuracy: {val_accuracy:.4f}")
print(f"Final Validation Loss: {val_loss:.4f}")

# Classification report. Passing `labels` keeps the rows aligned with the three target
# names even if a class never appears in the labels or the predictions (without it,
# scikit-learn raises on a class-count mismatch).
print("\nClassification Report:")
print(classification_report(
    val_true_labels,
    val_predictions,
    labels=[0, 1, 2],
    target_names=['Negative', 'Neutral', 'Positive']
))

# F1 scores
f1_macro = f1_score(val_true_labels, val_predictions, average='macro')
f1_weighted = f1_score(val_true_labels, val_predictions, average='weighted')
print(f"F1 Score (Macro): {f1_macro:.4f}")
print(f"F1 Score (Weighted): {f1_weighted:.4f}")

print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)