In [None]:
# Colab-specific helper module (imported from the google.colab package).
from google.colab import drive
# Mount Google Drive at /content/drive; the next cell changes into a folder under this path.
drive.mount('/content/drive')

In [None]:
import os

# Make the course folder on the mounted Drive the working directory so the
# relative paths used by later cells resolve inside it.
project_dir = '/content/drive/My Drive/Colab Notebooks/cs7643'
os.chdir(project_dir)

# Echo the directory actually in effect as a sanity check.
print("Current working directory:", os.getcwd())

In [None]:
import os
import urllib.request
import tarfile

# URL of the Hateful Memes features archive
URL = "https://dl.fbaipublicfiles.com/mmf/data/datasets/hateful_memes/defaults/features/features_2020_10_01.tar.gz"
TAR_PATH = "features_2020_10_01.tar.gz"
EXTRACT_DIR = "detectron.lmdb"

# 1. Download the tar.gz file (if not already downloaded).
#    The bytes are written to a temporary ".part" file that is renamed to
#    TAR_PATH only after the download finishes, so an interrupted download is
#    never mistaken for a complete archive on the next run.
if not os.path.exists(TAR_PATH):
    partial_path = TAR_PATH + ".part"
    print(f"Downloading from {URL} ...")
    try:
        urllib.request.urlretrieve(URL, partial_path)
    except BaseException:
        # Covers network errors as well as a manual interrupt of the cell.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, TAR_PATH)
    print(f"Downloaded to {TAR_PATH}")
else:
    print(f"{TAR_PATH} already exists, skipping download.")

# 2. Extract the tar.gz file (if not already extracted).
if not os.path.exists(EXTRACT_DIR):
    print(f"Extracting {TAR_PATH} ...")
    with tarfile.open(TAR_PATH, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Refuse members that would escape the target directory
            # (absolute paths, "..", links pointing outside, device files).
            tar.extractall(filter="data")
        else:
            # Older interpreters without extraction filters: only acceptable
            # because the archive comes from a trusted, hard-coded URL.
            tar.extractall()
    if os.path.exists(EXTRACT_DIR):
        print(f"Extraction complete. Files are in: {EXTRACT_DIR}/")
    else:
        # The skip-check above assumes the archive unpacks to EXTRACT_DIR.
        print(f"Extraction finished, but {EXTRACT_DIR}/ was not created; check the archive contents.")
else:
    print(f"{EXTRACT_DIR}/ already exists, skipping extraction.")


In [None]:
!pip install lmdb

In [None]:
import lmdb
import torch
import pickle
from torch.utils.data import Dataset


class HatefulMemesDataset(Dataset):
    """Hateful Memes rows joined with precomputed visual features from LMDB.

    Each item is a dict of tensors: the tokenized meme text, the region
    features of the meme image, matching masks / type ids for the visual
    tokens, and the label.
    """

    def __init__(self, hf_split, lmdb_path, tokenizer, max_length=48):
        """
        hf_split: one split from the HF DatasetDict (e.g. hf_ds['train']);
            every row must provide "text", "img" and "label".
        lmdb_path: path of the LMDB database holding one pickled record per
            image id, with "features" and "bbox" entries.
        tokenizer: HuggingFace tokenizer (required); its output must contain
            "input_ids", "attention_mask" and "token_type_ids".
        max_length: length every text is padded / truncated to (default 48).
        """
        self.data = hf_split
        self.lmdb_path = lmdb_path
        # Opened lazily on first read so that each DataLoader worker ends up
        # with its own LMDB handle instead of sharing the parent's.
        self.env = None
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of rows in the wrapped split."""
        return len(self.data)

    def _get_image_id(self, img_path):
        """Return the file stem used as LMDB key, e.g. "img/40259.png" -> "40259"."""
        return img_path.split("/")[-1].split(".")[0]

    def _load_visual_feats(self, img_id):
        """Read the record stored under ``img_id`` and return ``(features, bbox)``.

        Both are float32 tensors; per the original author's notes they are
        expected to be (num_boxes, 2048) and (num_boxes, 4).

        Raises:
            KeyError: if the LMDB database has no entry for ``img_id``.
        """
        if self.env is None:  # opened separately in each worker
            self.env = lmdb.open(
                self.lmdb_path,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False
            )
        with self.env.begin(write=False) as txn:
            buf = txn.get(img_id.encode("utf-8"))
            if buf is None:
                # Fail with the offending id instead of the opaque TypeError
                # that pickle.loads(None) would raise.
                raise KeyError(
                    f"No visual features for image id {img_id!r} in {self.lmdb_path!r}"
                )
            # NOTE(security): pickle can execute arbitrary code while loading;
            # only point lmdb_path at feature files from a trusted source.
            sample = pickle.loads(buf)
        feats = torch.tensor(sample["features"], dtype=torch.float32)
        bbox = torch.tensor(sample["bbox"], dtype=torch.float32)
        return feats, bbox

    def __getitem__(self, idx):
        """Build the model inputs for row ``idx``.

        Returns a dict with "input_ids", "attention_mask", "token_type_ids"
        (each of length ``max_length``), "visual_embeds",
        "visual_attention_mask", "visual_token_type_ids" (one entry per
        region) and "label".
        """
        row = self.data[idx]

        img_id = self._get_image_id(row["img"])
        # Bounding boxes are loaded but not used by the model inputs.
        visual_embeds, _ = self._load_visual_feats(img_id)

        encoded = self.tokenizer(
            row["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # NOTE(review): default DataLoader collation stacks these tensors, so
        # every image in a batch must have the same number of regions —
        # TODO confirm the feature files use a fixed box count.
        num_boxes = visual_embeds.size(0)

        # return_tensors="pt" adds a leading batch axis of size 1; drop it so
        # the DataLoader can add the real batch dimension.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "token_type_ids": encoded["token_type_ids"].squeeze(0),
            "visual_embeds": visual_embeds,
            "visual_attention_mask": torch.ones(num_boxes, dtype=torch.long),
            "visual_token_type_ids": torch.zeros(num_boxes, dtype=torch.long),
            "label": torch.tensor(row["label"], dtype=torch.long),
        }

In [None]:
import os

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import VisualBertModel, VisualBertConfig, BertTokenizer

lmdb_path = "detectron.lmdb"
dataset = load_dataset("neuralcatcher/hateful_memes")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Never request more loader processes than the machine has CPU cores: extra
# workers only add process overhead, and PyTorch warns about it.
num_workers = min(16, os.cpu_count() or 1)

# Shuffled training batches; each worker process opens its own LMDB handle
# the first time it reads a sample.
loader_train = DataLoader(
    HatefulMemesDataset(dataset["train"], lmdb_path, tokenizer),
    batch_size=32,
    shuffle=True,
    num_workers=num_workers,
)

In [None]:
import torch.nn as nn
from transformers import VisualBertModel

class VisualBertForClassification(nn.Module):
    """Pretrained VisualBERT encoder with a single linear classification layer.

    Args:
        num_labels: number of output classes (size of the logits' last axis).
        pretrained_name: checkpoint identifier passed to
            ``VisualBertModel.from_pretrained``; defaults to the checkpoint
            this notebook was written for.
    """

    def __init__(self, num_labels=2, pretrained_name="uclanlp/visualbert-vqa-coco-pre"):
        super().__init__()
        self.visualbert = VisualBertModel.from_pretrained(pretrained_name)
        # Size the head from the loaded config rather than hard-coding it
        # (usually 768).
        hidden_size = self.visualbert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        visual_embeds,
        visual_attention_mask,
        visual_token_type_ids,
    ):
        """Encode a text + visual batch and return unnormalized class logits.

        All arguments are forwarded unchanged to the VisualBERT encoder. The
        result has one row per batch element and ``num_labels`` columns; no
        softmax is applied.
        """
        outputs = self.visualbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
        )

        # The encoder's pooled summary vector, one per example: (B, hidden_size).
        pooled = outputs.pooler_output
        return self.classifier(pooled)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

# Instantiate the classifier, then move its weights onto the chosen device.
model = VisualBertForClassification()
model = model.to(device)

# Cross-entropy over the class logits, optimized with AdamW on every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5)


In [None]:
NUM_EPOCHS = 10  # number of full passes over loader_train

model.train()

for epoch in range(NUM_EPOCHS):
    total_loss = 0.0   # sum of per-sample losses seen this epoch
    num_samples = 0    # number of samples seen this epoch

    for batch in loader_train:
        # Move every tensor of the batch to the training device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass through the classifier model
        logits = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
            visual_embeds=batch["visual_embeds"],
            visual_attention_mask=batch["visual_attention_mask"],
            visual_token_type_ids=batch["visual_token_type_ids"],
        )

        # Compute loss (the criterion's default reduction averages over the batch)
        loss = criterion(logits, batch["label"])

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a smaller final batch does not
        # skew the epoch average.
        batch_size = batch["label"].size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    # Report the mean per-sample loss: unlike a running sum of batch losses it
    # does not grow with the number of batches. max() guards an empty loader.
    epoch_loss = total_loss / max(num_samples, 1)
    print(f"Epoch {epoch+1} | Loss: {epoch_loss:.4f}")
