In [None]:
import os
import json
from PIL import Image, ImageEnhance
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

# -----------------------------
# 1) DETECT DEVICE
# -----------------------------
# Use the CUDA GPU when PyTorch reports one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# -----------------------------
# 2) PREPROCESSING FUNCTION
# -----------------------------
def preprocess_image(
    img_path: str,
    target_size=(384, 384),
    enhance_contrast=1.5,
    enhance_sharpness=2.0,
    apply_noise_removal=True
) -> torch.Tensor:
    """
    Letterbox, enhance, and normalize a single image.

    Steps 1-4 (PIL/OpenCV) run on CPU.
    Steps 5-7 (convert to tensor, resize, normalize) run on the module-level
    `device` (GPU if available).

    Args:
        img_path: Path of the image file to load.
        target_size: Output size as (width, height), following the PIL
            convention used for the letterbox canvas.
        enhance_contrast: Factor passed to `ImageEnhance.Contrast.enhance`.
        enhance_sharpness: Factor passed to `ImageEnhance.Sharpness.enhance`.
        apply_noise_removal: When true, apply a 3x3 median blur (OpenCV).

    Returns:
        A CPU float tensor in range [-1, +1], shape [3, target_h, target_w].
    """
    tgt_w, tgt_h = target_size

    # ------------ CPU STEPS: Load & Enhance ------------
    # 1. Load image (CPU). The context manager releases the file handle as
    #    soon as the pixels have been decoded into the RGB copy.
    with Image.open(img_path) as src:
        image = src.convert("RGB")
    orig_w, orig_h = image.size

    # 2. Aspect-preserving resize, then center on a black canvas (letterbox).
    scale = min(tgt_w / orig_w, tgt_h / orig_h)
    # Clamp to 1 px: flooring can hit 0 for extreme aspect ratios, and a
    # zero-sized resize is invalid.
    new_w = max(1, int(orig_w * scale))
    new_h = max(1, int(orig_h * scale))
    image_resized = image.resize((new_w, new_h), Image.BICUBIC)

    canvas = Image.new("RGB", (tgt_w, tgt_h), (0, 0, 0))
    paste_x = (tgt_w - new_w) // 2
    paste_y = (tgt_h - new_h) // 2
    canvas.paste(image_resized, (paste_x, paste_y))

    # 3. Enhance contrast & sharpness (CPU)
    im_contrasted = ImageEnhance.Contrast(canvas).enhance(enhance_contrast)
    im_sharp = ImageEnhance.Sharpness(im_contrasted).enhance(enhance_sharpness)

    # 4. Optional median blur (CPU via OpenCV, which works on BGR arrays)
    if apply_noise_removal:
        np_im = cv2.cvtColor(np.array(im_sharp), cv2.COLOR_RGB2BGR)
        np_im = cv2.medianBlur(np_im, 3)
        im_processed = Image.fromarray(cv2.cvtColor(np_im, cv2.COLOR_BGR2RGB))
    else:
        im_processed = im_sharp

    # ------------ GPU STEPS: Convert & Normalize ------------
    # 5. Convert CPU PIL -> float tensor in [0,1], then send to `device`
    tensor_cpu = torch.from_numpy(
        np.array(im_processed).transpose(2, 0, 1)  # [H,W,3] -> [3,H,W], uint8
    ).float() / 255.0
    tensor = tensor_cpu.to(device)

    # 6. Resize on `device` to the exact target size. F.interpolate expects
    #    (height, width), whereas target_size is (width, height); passing
    #    target_size directly would transpose non-square targets. The canvas
    #    already has the target size, so this does not change the shape.
    tensor = F.interpolate(
        tensor.unsqueeze(0),          # [1,3,H,W]
        size=(tgt_h, tgt_w),
        mode="bicubic",
        align_corners=False
    ).squeeze(0)                      # [3,target_h,target_w]

    # 7. Normalize -> range [-1, +1]
    mean = torch.tensor([0.5, 0.5, 0.5], device=device).view(3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5], device=device).view(3, 1, 1)
    tensor = (tensor - mean) / std

    # 8. Move back to CPU before returning/saving
    return tensor.cpu()


# -----------------------------
# 3) SPLIT PROCESSING FUNCTION
# -----------------------------
def process_split(
    split: str,
    data_root: str,
    output_root: str,
    target_size=(384, 384),
    enhance_contrast=1.5,
    enhance_sharpness=2.0,
    apply_noise_removal=True
):
    """
    Processes all images in one split: 'train', 'dev', or 'test'.

    Reads `<data_root>/<split>.jsonl` line by line. For each record it takes
    the image path from 'img' (or 'img_fn'), the 'label' (0 when absent or
    null), and the caption from 'text' (or 'text_only'). Records without an
    image path, or whose image file does not exist, are skipped and counted.

    Each preprocessed image is saved as `<output_root>/<split>/images/<name>.pt`
    and a `metadata.pt` dict with 'filenames', 'labels' (long tensor) and
    'texts' is written to `<output_root>/<split>/`.

    Args:
        split: One of 'train', 'dev', 'test'.
        data_root: Folder holding the JSONL files; image paths in the records
            are resolved relative to it.
        output_root: Folder under which the per-split output is written.
        target_size, enhance_contrast, enhance_sharpness, apply_noise_removal:
            Forwarded unchanged to `preprocess_image`.

    Raises:
        AssertionError: If `split` is not a known split name.
        FileNotFoundError: If the split's JSONL file is missing.
    """
    assert split in ["train", "dev", "test"], "split must be 'train', 'dev', or 'test'"

    # 3.1. Locate JSONL
    jsonl_path = os.path.join(data_root, f"{split}.jsonl")
    if not os.path.isfile(jsonl_path):
        raise FileNotFoundError(f"Cannot find {split}.jsonl in {data_root}")

    # 3.2. Paths
    img_dir = data_root
    split_out = os.path.join(output_root, split)
    images_out = os.path.join(split_out, "images")
    os.makedirs(images_out, exist_ok=True)

    filenames = []
    labels = []
    texts = []
    skipped = 0

    # 3.3. Iterate and preprocess
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc=f"Preprocessing {split}", unit="img"):
            # Blank lines (e.g. a trailing newline) are not JSON records.
            if not line.strip():
                continue

            record = json.loads(line)
            img_filename = record.get("img") or record.get("img_fn")
            if img_filename is None:
                skipped += 1
                continue

            # `or` also covers an explicit JSON null, which int() would reject.
            label = int(record.get("label") or 0)
            text = record.get("text") or record.get("text_only") or ""

            src_path = os.path.join(img_dir, img_filename)
            if not os.path.isfile(src_path):
                # Best-effort: skip missing files, but keep count for the summary.
                skipped += 1
                continue

            # Preprocess and save tensor
            tensor = preprocess_image(
                src_path,
                target_size=target_size,
                enhance_contrast=enhance_contrast,
                enhance_sharpness=enhance_sharpness,
                apply_noise_removal=apply_noise_removal
            )

            base_name = os.path.splitext(img_filename)[0]
            out_path = os.path.join(images_out, base_name + ".pt")

            # img_filename may carry a sub-folder (e.g. "img/01234.png"), so the
            # parent of out_path can be deeper than images_out.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            torch.save(tensor, out_path)

            filenames.append(base_name)
            labels.append(label)
            texts.append(text)

    if skipped:
        tqdm.write(
            f"[{split}] skipped {skipped} record(s) with no image path or a missing image file"
        )

    # 3.4. Save metadata
    meta = {
        "filenames": filenames,
        # Explicit dtype: an empty list would otherwise become a float tensor.
        "labels": torch.tensor(labels, dtype=torch.long),  # shape [N]
        "texts": texts
    }
    torch.save(meta, os.path.join(split_out, "metadata.pt"))


# -----------------------------
# 4) MAIN: PROCESS ALL SPLITS
# -----------------------------
if __name__ == "__main__":
    # Where the raw dataset lives and where the processed tensors are written.
    DATA_ROOT = "./HatefulMemes/HatefulMemes"
    OUTPUT_ROOT = "./HatefulMemes/HatefulMemes/HatefulMemes_processed/"

    # Fail fast when the raw image folder is absent or has nothing in it.
    img_check = os.path.join(DATA_ROOT, "img")
    has_images = os.path.isdir(img_check) and len(os.listdir(img_check)) > 0
    if not has_images:
        raise FileNotFoundError(
            f"Folder '{img_check}' either does not exist or is empty. "
            "Place your .jpg files in that directory and re-run."
        )

    # Make sure the output root is there before any split writes into it.
    os.makedirs(OUTPUT_ROOT, exist_ok=True)

    # Every split is preprocessed with the same settings.
    shared_options = dict(
        data_root=DATA_ROOT,
        output_root=OUTPUT_ROOT,
        target_size=(384, 384),
        enhance_contrast=1.5,
        enhance_sharpness=2.0,
        apply_noise_removal=True,
    )
    for split_name in ("train", "dev", "test"):
        process_split(split=split_name, **shared_options)

    print("✅ All splits processed. Saved under:", OUTPUT_ROOT)


Preprocessing train: 8500img [07:56, 17.83img/s]
Preprocessing dev: 500img [00:28, 17.61img/s]
Preprocessing test: 1000img [00:57, 17.45img/s]

✅ All splits processed. Saved under: ./HatefulMemes/HatefulMemes/HatefulMemes_processed/



