In [None]:
!pip install gdown lightning transformers

# Download the zip file


In [None]:
!gdown https://drive.google.com/uc?id=19o_R8S5f09XFXZIk_F1B7nui7TxNq9mH
!mkdir temp
!mkdir data
!unzip -q data.zip -d temp
!mv temp/data/* data/
!rm -rf temp
!ls -l data

In [None]:
import shutil
from pathlib import Path

import torch
from PIL import Image
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

In [None]:
# Directory that is scanned for candidate images (and their mask_* companions).
IMG_DIR = Path("data") / "train_data"
# Directory that receives images classified as junk.
TRASH_DIR = Path("data") / "junk_images"
# Number of images scored by CLIP per forward pass.
BATCH_SIZE = 32

In [None]:
def filter_junk_with_clip():
    """Move images that CLIP scores as non-histology from IMG_DIR to TRASH_DIR.

    Every ``.png``/``.jpg``/``.jpeg`` file in ``IMG_DIR`` whose name does not
    start with ``mask_`` is scored zero-shot against two text prompts: one
    describing H&E stained tissue and one describing an unrelated image. When
    the "unrelated" prompt receives the higher softmax probability, the image
    is moved to ``TRASH_DIR``; a companion file named ``mask_<image name>`` in
    ``IMG_DIR`` is moved along with it if it exists.

    Images that cannot be opened or decoded are reported and left in place.

    Returns:
        None. The function acts through side effects (moving files and
        printing progress).
    """
    TRASH_DIR.mkdir(parents=True, exist_ok=True)
    print("load CLIP model...")
    model_id = "openai/clip-vit-base-patch32"
    model = CLIPModel.from_pretrained(model_id)
    processor = CLIPProcessor.from_pretrained(model_id)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    image_suffixes = {".png", ".jpg", ".jpeg"}
    all_images = [
        f
        for f in IMG_DIR.iterdir()
        if f.suffix.lower() in image_suffixes and not f.name.startswith("mask_")
    ]
    print(f"Total images: {len(all_images)}")

    # Index 0 is the "keep" prompt, index 1 is the "junk" prompt; the score
    # lookup further down relies on this order.
    text_queries = [
        "a microscope image of H&E stained tissue, four types: Luminal A, Luminal B, HER2-enriched, Triple-negative",
        "An image unrelated to H&E. For example, an Orcs and Slime",
    ]

    # The prompts never change, so they are tokenized once outside the loop.
    text_inputs = processor(text=text_queries, return_tensors="pt", padding=True).to(
        device
    )
    for i in tqdm(range(0, len(all_images), BATCH_SIZE)):
        batch_paths = all_images[i : i + BATCH_SIZE]
        batch_images = []
        valid_batch_paths = []

        for p in batch_paths:
            try:
                # The context manager closes the file handle; convert() returns
                # a separate, fully loaded image that outlives the handle.
                with Image.open(p) as raw_image:
                    img = raw_image.convert("RGB")
            except Exception as exc:
                # `Exception` (not a bare except) keeps the best-effort skip
                # while still letting KeyboardInterrupt / SystemExit through.
                print(f"Skipping broken image: {p.name} ({exc})")
                continue
            batch_images.append(img)
            valid_batch_paths.append(p)

        if not batch_images:
            continue

        image_inputs = processor(images=batch_images, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(
                input_ids=text_inputs.input_ids,
                # The prompts differ in length and are padded, so the padding
                # mask produced by the processor is passed along explicitly.
                attention_mask=text_inputs.attention_mask,
                pixel_values=image_inputs.pixel_values,
            )
            # One row per image, one column per prompt.
            probs = outputs.logits_per_image.softmax(dim=1)

        probs = probs.cpu().numpy()
        for idx, p in enumerate(valid_batch_paths):
            pathology_score = probs[idx][0]
            junk_score = probs[idx][1]

            if junk_score > pathology_score:
                print(
                    f"Moving to trash: {p.name} | pathology: {pathology_score:.4f}, anime: {junk_score:.4f}"
                )
                shutil.move(str(p), str(TRASH_DIR / p.name))

                # Keep the image and its mask together.
                mask_p = IMG_DIR / f"mask_{p.name}"
                if mask_p.exists():
                    shutil.move(str(mask_p), str(TRASH_DIR / mask_p.name))

In [None]:
# Run the CLIP-based junk filter; images it flags are moved from IMG_DIR to TRASH_DIR.


filter_junk_with_clip()

In [None]:
# Count the image files in the trash directory, ignoring mask_* companions.
num_images = sum(
    1
    for entry in TRASH_DIR.iterdir()
    if not entry.name.startswith("mask_")
    and entry.suffix.lower() in (".png", ".jpg", ".jpeg")
)
print(f"Total junk images moved to trash: {num_images}")

In [None]:
import matplotlib.pyplot as plt


def show_trash_images(num_images):
    """Display up to ``num_images`` images from ``TRASH_DIR`` in a grid.

    Only ``.png``/``.jpg``/``.jpeg`` files whose name does not start with
    ``mask_`` are considered. Files are shown in sorted path order, five per
    row, each titled with its file name.

    Args:
        num_images: Maximum number of images to display. Zero or a negative
            value displays nothing.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    # iterdir() yields entries in arbitrary order; sorting makes the displayed
    # subset reproducible between runs.
    trash_images = sorted(
        f
        for f in TRASH_DIR.iterdir()
        if f.suffix.lower() in (".png", ".jpg", ".jpeg")
        and not f.name.startswith("mask_")
    )
    print(f"Total trash images: {len(trash_images)}")

    # Limit to requested number of images
    num_to_show = min(num_images, len(trash_images))

    # `<= 0` also covers a negative request, which would otherwise reach
    # plt.subplots with a non-positive row count.
    if num_to_show <= 0:
        print("No trash images to display")
        return

    # Calculate grid dimensions (5 images per row)
    ncols = 5
    nrows = (num_to_show + ncols - 1) // ncols  # Ceiling division

    # squeeze=False always returns a 2-D array of Axes, so a single-row grid
    # needs no special handling before flattening.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(15, 3 * nrows), squeeze=False
    )
    axes = axes.flatten()

    for ax, path in zip(axes, trash_images[:num_to_show]):
        # Draw the image while the file is still open, then release the handle.
        with Image.open(path) as img:
            ax.imshow(img)
        ax.axis("off")
        ax.set_title(path.name, fontsize=8)

    # Hide unused subplots
    for ax in axes[num_to_show:]:
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Display all images found in the trash directory (num_images is the count computed above).
show_trash_images(num_images=num_images)

In [None]:
# Collect the paths of the image files in the trash directory, ignoring mask_* companions.
trash_image_paths = list(
    filter(
        lambda entry: entry.suffix.lower() in (".png", ".jpg", ".jpeg")
        and not entry.name.startswith("mask_"),
        TRASH_DIR.iterdir(),
    )
)
print(trash_image_paths)