In [None]:
import warnings

# The Hugging Face hub warns when it cannot read an `HF_TOKEN` secret; the dataset
# loaded below is public, so that single message is silenced.
HF_TOKEN_WARNING = "Error while fetching `HF_TOKEN` secret value from your vault"
warnings.filterwarnings(action="ignore", message=HF_TOKEN_WARNING)

from datasets import load_dataset

DATASET_NAME = "neuralcatcher/hateful_memes"  # public: no authentication needed
dataset = load_dataset(DATASET_NAME)
print(dataset)

In [None]:
import pandas as pd
from datasets import Dataset

# Remove rows that are identical in every column, split by split.
# `to_pandas()` converts the whole split in one call, whereas
# `pd.DataFrame(i_data)` would iterate every row through Python one at a time.
# Iterate over a snapshot because the loop reassigns entries of `dataset`.
for i_split, i_data in list(dataset.items()):
    dataset[i_split] = Dataset.from_pandas(
        i_data.to_pandas().drop_duplicates(), preserve_index=False
    )

print(dataset)

# Report ids that still occur more than once, either within one split or
# across splits, after exact-duplicate rows have been removed.
total_set = set()
for i_split, i_data in dataset.items():
    print(i_split)
    split_set = set()
    # Only the id column is needed, so read it directly instead of whole rows.
    for sample_id in i_data["id"]:
        if sample_id in split_set:
            print(f"duplicate id: {sample_id} in split set")
        else:
            split_set.add(sample_id)
            if sample_id in total_set:
                print(f"duplicate id: {sample_id} in total set")
            else:
                total_set.add(sample_id)

In [None]:
import sys

# Number of rows used to extrapolate each split's approximate in-memory size.
SIZE_SAMPLE_ROWS = 100

for split, data in dataset.items():
    print(f"{split}: {len(data)} examples")
    print(data.features)
    # Never index past the end of a split smaller than the sample size.
    n_sample = min(SIZE_SAMPLE_ROWS, len(data))
    if n_sample:
        # sys.getsizeof is shallow: for a dict it counts only the container, so
        # the size of every value in the row (id, path string, text, ...) is added.
        sample_bytes = sum(
            sys.getsizeof(row) + sum(sys.getsizeof(value) for value in row.values())
            for row in (data[i] for i in range(n_sample))
        )
        total_bytes = sample_bytes * len(data) // n_sample
    else:
        total_bytes = 0
    print(f"Approximate size of {split} split: {total_bytes / 1e6:.2f} MB")
    print()

In [None]:
# Download and extract img archive

import os
import gdown
import tarfile

# Google Drive copy of the image archive (gzipped tarball).
ARCHIVE_URL = "https://drive.google.com/uc?id=1VZ2WQrh4MRStFfWRSx0ezYJ_DlcaCGwI"
ARCHIVE_PATH = "img.tar.gz"

if not os.path.exists("img/"):
    downloaded = gdown.download(ARCHIVE_URL, ARCHIVE_PATH, quiet=False)
    # NOTE(review): some gdown versions return None on failure instead of raising;
    # stop here rather than failing later with a confusing tarfile error.
    if downloaded is None:
        raise RuntimeError(f"Failed to download {ARCHIVE_URL}")
    print("Download complete.")

    print("Extracting...")
    with tarfile.open(ARCHIVE_PATH, "r:gz") as tar:
        # The archive comes from an external share: the "data" filter refuses
        # members that would land outside the working directory or create special
        # files. It exists on Python 3.12+ and in backported security releases.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(filter="data")
        else:
            tar.extractall()
    # Extracts into ./img/ if the archive contains an img/ folder
    print("Extraction complete.")

In [None]:
# Fetch missing images if any

import os
import requests

base_url = (
    "https://huggingface.co/datasets/limjiayi/hateful_memes_expanded/resolve/main"
)
# Seconds before a stalled request fails instead of hanging the notebook forever.
REQUEST_TIMEOUT_S = 60

# Download every image the dataset references that is not on disk yet.
# A single Session reuses the HTTPS connection across requests.
with requests.Session() as session:
    for i_data in dataset.values():
        # "img" still holds relative path strings here (paths are made absolute
        # and cast to images later), so read the column instead of whole rows.
        for img_path in i_data["img"]:
            if os.path.exists(img_path):
                continue
            response = session.get(f"{base_url}/{img_path}", timeout=REQUEST_TIMEOUT_S)
            response.raise_for_status()
            # The target directory may not exist if the archive step was skipped.
            os.makedirs(os.path.dirname(img_path) or ".", exist_ok=True)
            # Write to a temporary name and rename into place, so an interrupted
            # write never leaves a truncated file that the exists() check skips.
            part_path = f"{img_path}.part"
            with open(part_path, "wb") as f:
                f.write(response.content)
            os.replace(part_path, img_path)

In [None]:
# Turn img paths into PIL images and load each one to confirm it exists and decodes

from datasets import Image as HFImage

# Absolute path of the current working directory; the dataset's relative image
# paths are resolved against it.
dataset_dir = os.path.abspath(os.curdir)


def fix_paths(example):
    """Make ``example["img"]`` absolute by prefixing it with ``dataset_dir``.

    The example dict is updated in place and also returned, as ``Dataset.map``
    expects. A path that is already absolute is kept unchanged by ``os.path.join``.
    """
    relative_img = example["img"]
    example["img"] = os.path.join(dataset_dir, relative_img)
    return example


# Convert paths only once: after `cast_column`, `map` would pass `fix_paths`
# decoded PIL images instead of path strings, so re-running this cell would fail.
if not all(isinstance(split.features["img"], HFImage) for split in dataset.values()):
    dataset = dataset.map(fix_paths)
    dataset = dataset.cast_column("img", HFImage())

# Force a full decode of every image so missing or unreadable files surface now.
for i_split, i_data in dataset.items():
    for i_sample in i_data:
        i_sample["img"].load()

In [None]:
!pip install pytesseract

In [None]:
# Helper func: extract text from meme (optional)

import cv2
import numpy as np
import pytesseract


def extract_text(image):
    """Run Tesseract OCR on a meme image and return its text on a single line.

    Args:
        image: a PIL image, or an array-like accepted by ``np.array`` whose
            channels (if any) are in RGB/RGBA order.

    Returns:
        The recognised text with every whitespace run (including newlines)
        collapsed to a single space and no leading/trailing whitespace.
    """
    # PIL images come in many modes (RGBA, L, P, ...); normalise to RGB so the
    # array below has three channels in a known order.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    image = np.array(image)
    if image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # Edge-preserving smoothing before thresholding.
    image = cv2.bilateralFilter(image, 5, 55, 60)
    if image.ndim == 3:
        # Channels are RGB (from PIL), not OpenCV's default BGR order.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Inverse binary threshold: pixels brighter than 240 become 0 (black), all
    # others 255, so bright text ends up dark on a light background.
    _, image = cv2.threshold(image, 240, 255, cv2.THRESH_BINARY_INV)

    # --psm 11: sparse text. The whitelist value is quoted as a whole so that it
    # stays one argument after pytesseract's shlex split; it admits only
    # upper-case letters and space.
    custom_config = (
        "--oem 3 --psm 11 -c tessedit_char_whitelist='ABCDEFGHIJKLMNOPQRSTUVWXYZ '"
    )
    text = pytesseract.image_to_string(image, lang="eng", config=custom_config)
    return " ".join(text.split())

In [None]:
# Validate text in memes (optional)

from difflib import SequenceMatcher

# Scan the training split from START_INDEX and stop at the first meme whose OCR
# output agrees poorly with the reference text, then show it for inspection.
START_INDEX = 70
MAX_POOR_RATIO = 0.60  # similarity at or below this counts as a poor extraction

train_split = dataset["train"]
for index in range(START_INDEX, len(train_split)):
    sample = train_split[index]  # decode the row (and its image) only once
    text_ref = sample["text"]
    text_ext = extract_text(sample["img"])
    # extract_text uses an upper-case-only whitelist, so compare
    # case-insensitively to avoid penalising case differences alone.
    ratio = SequenceMatcher(None, text_ref.lower(), text_ext.lower()).ratio()
    if ratio <= MAX_POOR_RATIO:
        break
else:
    raise LookupError(
        f"No training sample from index {START_INDEX} onward has OCR similarity "
        f"<= {MAX_POOR_RATIO}"
    )

# Print the extracted text
print(index)
print(ratio)
print(text_ref)
print(text_ext)
sample["img"]