# 🔍 Hateful Meme Detection - Live Demo
### Graph-based Multimodal Fusion with CLIP + Graphormer

**Instructions:**
1. Click "Runtime" → "Run all" (or press Ctrl+F9)
2. Wait for the model to load (~2 minutes)
3. Upload your meme image when prompted
4. Enter the meme text
5. Get instant prediction!

---

## Step 1: Install Dependencies

In [None]:
# Install third-party packages into the notebook runtime; -q keeps pip's output short.
# NOTE(review): gradio is installed here but not imported by any cell shown in this
# notebook, while ipywidgets (used by the Step 6 cell) is not listed - presumably it
# is preinstalled on Colab; confirm.
!pip install -q transformers torch torchvision pillow numpy scikit-learn gradio

## Step 2: Load Model from Google Drive

**Automatically loads `best_graphormer.pt` from your Google Drive:**

In [None]:
from google.colab import drive
import os

# Expose Google Drive inside the runtime at /content/drive.
print("üìÇ Mounting Google Drive...")
drive.mount("/content/drive")

# Checkpoint location: the root of "MyDrive". Edit this if the file lives in a subfolder.
model_path = "/content/drive/MyDrive/best_graphormer.pt"

# Fail fast with instructions when the checkpoint is missing, so later cells
# never run against a path that does not exist.
if not os.path.exists(model_path):
    print("‚ùå Model not found!")
    print("\nPlease make sure 'best_graphormer.pt' is in your Google Drive root folder.")
    print("Or update the path above if it's in a subfolder.")
    raise FileNotFoundError("Model file not found in Google Drive")

# Checkpoint is present: report its path and its size in GiB (bytes / 1024**3).
print(f"‚úÖ Model found: {model_path}")
print(f"   Size: {os.path.getsize(model_path) / (1024**3):.2f} GB")

## Step 3: Model Architecture Code

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

# Graph builder constants: integer ids for the four edge categories.
# GraphormerModel uses these ids to index its edge-type embedding table.
EDGE_TEXT_TEXT = 0
EDGE_IMG_IMG = 1
EDGE_TEXT_IMG = 2
EDGE_GLOBAL = 3

def build_graph(text_feats, image_feats, top_k=8):
    """Build a heterogeneous graph over text tokens, image patches and a global node.

    Nodes are ordered ``[text tokens..., image patches..., global]``. The
    global node's feature is the mean of all text and image node features.

    Edges are stored as directed pairs:
      * text-text: consecutive tokens, both directions, weight 0.
      * image-image: every patch to its up/down/left/right neighbours on the
        row-major square patch grid, weight 0.
      * text-image: every token to its ``min(top_k, num_patches)`` most
        cosine-similar patches, both directions, weight = cosine similarity.
      * global: the global node to every other node, both directions, weight 0.

    Args:
        text_feats: 2-D tensor ``(num_tokens, dim)``; may have zero rows.
        image_feats: 2-D tensor ``(num_patches, dim)`` on the same device,
            where ``num_patches`` is a perfect square.
        top_k: maximum number of patches linked to each text token.

    Returns:
        dict with
          ``x``             float node features, ``(num_tokens + num_patches + 1, dim)``
          ``edge_index``    long ``(2, num_edges)`` rows = (source, target)
          ``edge_type``     long ``(num_edges,)`` of ``EDGE_*`` ids
          ``edge_weight``   float ``(num_edges,)``
          ``text_indices``  long node ids of the text tokens
          ``image_indices`` long node ids of the image patches
          ``global_index``  0-dim long tensor, node id of the global node

    Raises:
        ValueError: if ``num_patches`` is not a perfect square.
    """
    device = text_feats.device
    text_len = text_feats.shape[0]
    img_len = image_feats.shape[0]

    # Compute grid size. An explicit raise (not assert) keeps the check
    # active when Python runs with -O.
    grid_size = int(img_len ** 0.5)
    if grid_size * grid_size != img_len:
        raise ValueError(f"Image patches {img_len} must be square")

    node_feats = torch.cat([text_feats, image_feats], dim=0)
    global_feat = node_feats.mean(dim=0, keepdim=True)
    x = torch.cat([node_feats, global_feat], dim=0)

    edge_index = []
    edge_type = []
    edge_weight = []

    # Text-text chain
    for i in range(max(text_len - 1, 0)):
        edge_index.extend([(i, i + 1), (i + 1, i)])
        edge_type.extend([EDGE_TEXT_TEXT, EDGE_TEXT_TEXT])
        edge_weight.extend([0.0, 0.0])

    # Image-image grid (4-neighbourhood); image node ids are offset by text_len
    for p in range(img_len):
        row, col = divmod(p, grid_size)
        neighbors = []
        if row > 0:
            neighbors.append((row - 1) * grid_size + col)
        if row < grid_size - 1:
            neighbors.append((row + 1) * grid_size + col)
        if col > 0:
            neighbors.append(row * grid_size + (col - 1))
        if col < grid_size - 1:
            neighbors.append(row * grid_size + (col + 1))

        for n in neighbors:
            edge_index.append((text_len + p, text_len + n))
            edge_type.append(EDGE_IMG_IMG)
            edge_weight.append(0.0)

    # Text-image edges (L2-normalized cosine similarity)
    if text_len > 0 and img_len > 0:
        text_norm = F.normalize(text_feats, p=2, dim=-1)
        image_norm = F.normalize(image_feats, p=2, dim=-1)
        sim = torch.matmul(text_norm, image_norm.t())

        k = min(top_k, img_len)
        topk_vals, topk_idx = torch.topk(sim, k=k, dim=1)
        # Pull both results to the host once instead of calling .item()
        # (one device sync) for every single edge.
        idx_rows = topk_idx.tolist()
        val_rows = topk_vals.tolist()
        for t_node, (patches, weights) in enumerate(zip(idx_rows, val_rows)):
            for patch_idx, weight in zip(patches, weights):
                p_node = text_len + patch_idx
                edge_index.extend([(t_node, p_node), (p_node, t_node)])
                edge_type.extend([EDGE_TEXT_IMG, EDGE_TEXT_IMG])
                edge_weight.extend([weight, weight])

    # Global connections
    global_idx = text_len + img_len
    for node in range(global_idx):
        edge_index.extend([(global_idx, node), (node, global_idx)])
        edge_type.extend([EDGE_GLOBAL, EDGE_GLOBAL])
        edge_weight.extend([0.0, 0.0])

    # reshape(-1, 2) keeps the result (2, num_edges) even when there are no edges
    edge_index_tensor = (
        torch.tensor(edge_index, device=device, dtype=torch.long).reshape(-1, 2).t()
    )
    edge_type_tensor = torch.tensor(edge_type, device=device, dtype=torch.long)
    edge_weight_tensor = torch.tensor(edge_weight, device=device, dtype=torch.float)

    text_indices = torch.arange(text_len, device=device, dtype=torch.long)
    image_indices = torch.arange(img_len, device=device, dtype=torch.long) + text_len

    return {
        "x": x.float(),
        "edge_index": edge_index_tensor,
        "edge_type": edge_type_tensor,
        "edge_weight": edge_weight_tensor,
        "text_indices": text_indices,
        "image_indices": image_indices,
        "global_index": torch.tensor(global_idx, device=device, dtype=torch.long),
    }


class GraphormerLayer(nn.Module):
    """Post-norm self-attention layer whose attention logits take an additive bias.

    Each layer is multi-head self-attention (with an externally supplied bias
    added to the scaled dot-product scores) followed by a two-layer GELU
    feed-forward network; both sub-layers use a residual connection followed
    by LayerNorm.

    Args:
        d_model: width of the node embeddings.
        num_heads: number of attention heads; must divide ``d_model``.
        dropout: dropout probability used on the attention weights, on the
            attention output and inside the feed-forward network.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Reject bad head splits here; otherwise the floor division below
        # silently shrinks the heads and forward() fails in .view().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Single projection producing queries, keys and values together
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, attn_bias, node_mask):
        """Run one attention + feed-forward step.

        Args:
            x: node embeddings, ``(B, N, d_model)``.
            attn_bias: tensor added to the attention scores; must broadcast
                to ``(B, num_heads, N, N)``.
            node_mask: bool tensor ``(B, N)``; positions that are False are
                excluded as attention keys.

        Returns:
            Updated node embeddings, ``(B, N, d_model)``.
        """
        B, N, _ = x.shape
        # (B, N, 3*d_model) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).view(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
        scores = scores + attn_bias

        # (B, N) -> (B, 1, 1, N): mask along the key axis for every head and query
        key_mask = node_mask.unsqueeze(1).unsqueeze(2)
        scores = scores.masked_fill(~key_mask, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)

        # (B, heads, N, head_dim) -> (B, N, d_model)
        out = out.transpose(1, 2).contiguous().view(B, N, self.d_model)
        out = self.out_proj(out)
        out = self.dropout(out)
        x = self.norm1(x + out)

        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x


class GraphormerModel(nn.Module):
    """Graph transformer that scores a batch of text/image graphs.

    Node features are projected to ``d_model``, refined by a stack of
    ``GraphormerLayer`` blocks under an edge-derived attention bias, and then
    summarised as [global-node embedding, mean text embedding, mean image
    embedding, text/image "conflict" (1 - cosine similarity)]. An MLP turns
    that summary into one logit per graph.
    """

    def __init__(self, input_dim, d_model=256, num_heads=4, num_layers=3, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers

        self.node_proj = nn.Linear(input_dim, d_model)
        encoder_blocks = []
        for _ in range(num_layers):
            encoder_blocks.append(GraphormerLayer(d_model, num_heads, dropout))
        self.layers = nn.ModuleList(encoder_blocks)
        # One learned scalar per (edge type, head); 4 rows, one per edge type id.
        self.edge_type_emb = nn.Embedding(4, num_heads)

        self.conflict_eps = 1e-6
        summary_dim = d_model * 3 + 1  # global + text pool + image pool + conflict scalar
        hidden_dim = d_model * 2
        self.mlp = nn.Sequential(
            nn.Linear(summary_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1),
        )

    def _build_attention_bias(self, edge_index_list, edge_type_list, edge_weight_list, max_nodes, device):
        """Return a ``(batch, heads, max_nodes, max_nodes)`` attention bias.

        Entry ``[b, :, src, dst]`` receives the per-head embedding of the
        edge's type; text-image edges additionally receive their edge weight
        on every head. Pairs without an edge stay at zero.
        """
        batch_size = len(edge_index_list)
        bias = torch.zeros(batch_size, self.num_heads, max_nodes, max_nodes, device=device)
        type_table = self.edge_type_emb.weight

        per_graph = zip(edge_index_list, edge_type_list, edge_weight_list)
        for g, (edges, kinds, weights) in enumerate(per_graph):
            if edges.numel() > 0:
                sources, targets = edges[0], edges[1]
                # (num_edges, heads) -> (heads, num_edges) to line up with bias[g, :, src, dst]
                bias[g, :, sources, targets] += type_table[kinds].transpose(0, 1)

                cross_modal = kinds == EDGE_TEXT_IMG
                if cross_modal.any():
                    similarity = weights[cross_modal].unsqueeze(0)
                    bias[g, :, sources[cross_modal], targets[cross_modal]] += similarity
        return bias

    def forward(self, node_feats, node_mask, text_mask, image_mask, global_indices,
                edge_index_list, edge_type_list, edge_weight_list):
        """Compute one logit per graph.

        ``node_feats`` is ``(B, N, input_dim)``; the three masks are ``(B, N)``
        and select valid, text and image nodes respectively;
        ``global_indices`` holds the global node's position for each graph;
        the three lists carry one edge tensor per graph.
        """
        device = node_feats.device
        attn_bias = self._build_attention_bias(
            edge_index_list, edge_type_list, edge_weight_list, node_feats.size(1), device
        )

        hidden = self.node_proj(node_feats)
        for block in self.layers:
            hidden = block(hidden, attn_bias, node_mask)

        def masked_mean(mask):
            # Mean over the selected nodes; the count is clamped to 1 so an
            # empty selection yields zeros instead of dividing by zero.
            selector = mask.float()
            count = selector.sum(dim=1, keepdim=True).clamp_min(1.0)
            return (hidden * selector.unsqueeze(-1)).sum(dim=1) / count

        text_pool = masked_mean(text_mask)
        image_pool = masked_mean(image_mask)

        rows = torch.arange(hidden.size(0), device=device)
        global_embeds = hidden[rows, global_indices]

        agreement = F.cosine_similarity(text_pool, image_pool, dim=-1, eps=self.conflict_eps)
        conflict = (1.0 - agreement).unsqueeze(-1)

        summary = torch.cat([global_embeds, text_pool, image_pool, conflict], dim=-1)
        return self.mlp(summary).squeeze(-1)

print("‚úÖ Model architecture loaded")

## Step 4: Load Model

In [None]:
from transformers import CLIPModel, CLIPProcessor
from PIL import Image

# Run on the GPU when the runtime has one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load checkpoint
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only point model_path at a checkpoint you produced or trust.
print("\nüì• Loading trained model...")
checkpoint = torch.load(model_path, map_location=device)
cfg = checkpoint["config"]

# Load CLIP
clip_backbone = cfg.get("clip_backbone", "openai/clip-vit-base-patch32")
# Index into CLIP's hidden_states tuple used as token/patch features (-1 = last entry)
clip_layer_idx = cfg.get("clip_layer_idx", -1)

print(f"Loading CLIP: {clip_backbone}")
processor = CLIPProcessor.from_pretrained(clip_backbone)
clip_model = CLIPModel.from_pretrained(clip_backbone).to(device)

# Overwrite the pretrained weights with fine-tuned ones when the checkpoint has them
if "clip_state" in checkpoint:
    clip_model.load_state_dict(checkpoint["clip_state"])
    print("‚úì Loaded fine-tuned CLIP")

clip_model.eval()

# Projection layers: map CLIP text/vision hidden states into the shared node space
shared_dim = cfg["input_dim"]
text_dim = cfg.get("text_dim", shared_dim)
vision_dim = cfg.get("vision_dim", 768)

text_proj = nn.Identity() if text_dim == shared_dim else nn.Linear(text_dim, shared_dim)
image_proj = nn.Linear(vision_dim, shared_dim)
text_proj = text_proj.to(device)
image_proj = image_proj.to(device)

# A Linear projection without saved weights keeps its random initialisation and
# makes every prediction meaningless, so say so loudly instead of staying silent.
if "text_proj_state" in checkpoint:
    text_proj.load_state_dict(checkpoint["text_proj_state"])
elif not isinstance(text_proj, nn.Identity):
    print("WARNING: checkpoint has no 'text_proj_state'; text projection is randomly initialized")
if "image_proj_state" in checkpoint:
    image_proj.load_state_dict(checkpoint["image_proj_state"])
else:
    print("WARNING: checkpoint has no 'image_proj_state'; image projection is randomly initialized")

text_proj.eval()
image_proj.eval()

# Graphormer: rebuild with the training-time hyperparameters, then restore weights
model = GraphormerModel(
    input_dim=shared_dim,
    d_model=cfg["d_model"],
    num_heads=cfg["num_heads"],
    num_layers=cfg["num_layers"],
    dropout=cfg.get("dropout", 0.1),
).to(device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

# Number of image patches linked to each text token when the graph is built
top_k = cfg.get("top_k", 4)

print(f"\n‚úÖ Model loaded successfully!")
print(f"   Layers: {cfg['num_layers']}, Dim: {cfg['d_model']}, Top-K: {top_k}")

## Step 5: Prediction Function

In [None]:
def predict_meme(image, text):
    """Return the model's probability that a meme is hateful.

    The image/text pair is encoded with CLIP, the per-token and per-patch
    hidden states are projected and assembled into a graph, and the
    Graphormer's logit for that single graph is squashed with a sigmoid.

    Args:
        image: PIL image; converted to RGB when it is in another mode.
        text: the meme's text as a string.

    Returns:
        A Python float in [0, 1].
    """
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    # Tokenise / preprocess, then move every tensor to the model's device
    encoded = processor(text=[text], images=[rgb_image], return_tensors="pt", padding=True, truncation=True)
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        clip_out = clip_model(**inputs, output_hidden_states=True, return_dict=True)
        text_states = clip_out.text_model_output.hidden_states[clip_layer_idx]
        vision_states = clip_out.vision_model_output.hidden_states[clip_layer_idx]

        # Keep text positions 1 .. (attended length - 2): this skips the first
        # and the last attended token (presumably CLIP's start/end markers).
        attended = int(inputs["attention_mask"][0].sum().item())
        first_tok = 1
        last_tok = max(attended - 1, first_tok)
        text_feats = text_proj(text_states[0, first_tok:last_tok])

        # Skip vision position 0 (presumably the class token); the rest are patches
        image_feats = image_proj(vision_states[0, 1:])

        graph = build_graph(text_feats, image_feats, top_k)

        # Wrap the single graph as a batch of one
        node_feats = graph["x"].unsqueeze(0)
        num_nodes = node_feats.size(1)

        def index_mask(indices):
            # (1, num_nodes) bool mask that is True exactly at `indices`
            mask = torch.zeros(1, num_nodes, dtype=torch.bool, device=device)
            mask[0, indices] = True
            return mask

        node_mask = torch.ones(1, num_nodes, dtype=torch.bool, device=device)
        text_mask = index_mask(graph["text_indices"])
        image_mask = index_mask(graph["image_indices"])
        global_indices = graph["global_index"].unsqueeze(0)

        logits = model(
            node_feats,
            node_mask,
            text_mask,
            image_mask,
            global_indices,
            [graph["edge_index"]],
            [graph["edge_type"]],
            [graph["edge_weight"]],
        )

        prob = torch.sigmoid(logits).item()

    return prob

print("‚úÖ Prediction function ready")

## Step 6: Interactive Demo - Upload & Test!

**Upload a meme image and enter its text below:**

In [None]:
from IPython.display import display, HTML
import ipywidgets as widgets

# Input controls for the demo: a single-image uploader, a text box for the
# meme's caption, the button that triggers analysis, and an output pane that
# collects everything the click handler prints or displays.
uploader = widgets.FileUpload(
    accept='image/*',
    multiple=False,
)
text_input = widgets.Textarea(
    description='Meme Text:',
    placeholder='Enter meme text here...',
)
button = widgets.Button(
    button_style='primary',
    description='üîé Analyze Meme',
)
output_area = widgets.Output()

def on_button_click(b):
    """Handle a click on the analyze button: classify the uploaded meme.

    Clears the output pane, reads the uploaded image and the text box, runs
    ``predict_meme`` and prints the verdict, the hateful probability and the
    confidence, followed by a 300x300 preview of the image.

    Args:
        b: the widget that fired the click (unused).
    """
    import io  # local import: only this handler needs it

    output_area.clear_output()

    with output_area:
        uploads = uploader.value
        if not uploads:
            print("‚ùå Please upload an image first!")
            return

        # Get image. ipywidgets 7 exposes the uploads as a dict keyed by file
        # name; ipywidgets 8 exposes a tuple of records. Accept both.
        if isinstance(uploads, dict):
            uploaded_file = next(iter(uploads.values()))
        else:
            uploaded_file = uploads[0]

        # 'content' is the raw file data (bytes or memoryview). PIL treats a
        # bytes argument as a file *name*, so wrap the data in a file object.
        image = Image.open(io.BytesIO(bytes(uploaded_file['content'])))

        # Get text; fall back to a placeholder so the text encoder always has input
        text = text_input.value.strip()
        if not text:
            text = "[No text provided]"

        print("üîÑ Processing...\n")

        # Predict
        prob = predict_meme(image, text)

        # Display results: threshold at 0.5; confidence is the probability of
        # whichever class was chosen.
        is_hateful = prob > 0.5
        confidence = prob if is_hateful else (1 - prob)

        print("="*60)
        if is_hateful:
            print("üö´ HATEFUL CONTENT DETECTED")
            print("  ‚ö†Ô∏è  This meme contains offensive content")
        else:
            print("‚úÖ SAFE CONTENT")
            print("  ‚úì This meme is appropriate")

        print(f"\nüìä Statistics:")
        print(f"   Hateful Probability: {prob:.1%}")
        print(f"   Confidence: {confidence:.1%}")
        print("="*60)

        # Display image
        display(image.resize((300, 300)))

# Run the handler whenever the analyze button is clicked.
button.on_click(on_button_click)

# Render the interface top to bottom: heading, then each control in order.
display(HTML("<h2>üì§ Upload Meme & Get Prediction</h2>"))
for control in (uploader, text_input, button, output_area):
    display(control)

print("\n‚ú® Demo ready! Upload an image and click 'Analyze Meme'")