In [3]:
import os
import shutil
from pathlib import Path
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel

# Folder layout: images are read from INPUT_FOLDER and copied into one
# subdirectory of OUTPUT_FOLDER per category key below.
INPUT_FOLDER = "input_folder"
OUTPUT_FOLDER = "output_folder"
OUTPUT_SUBFOLDERS = dict(
    pie_chart="pie_charts",
    bar_chart="bar_charts",
    line_chart="line_charts",
    table="tables",
)

# Make sure every category directory exists; exist_ok makes re-runs a no-op.
for folder in OUTPUT_SUBFOLDERS.values():
    (Path(OUTPUT_FOLDER) / folder).mkdir(parents=True, exist_ok=True)

# Choose the compute device up front (GPU when available, else CPU), then
# load the CLIP checkpoint together with its matching processor and place
# the model on that device.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

# Zero-shot labels: each category key (matching the keys of OUTPUT_SUBFOLDERS)
# is paired with the text prompt CLIP scores images against. Keeping the
# pairing in one mapping avoids re-deriving the folder key by splitting the
# prompt text, which silently breaks if a prompt is reworded.
CLASS_PROMPTS = {
    "pie_chart": "a pie chart",
    "bar_chart": "a bar chart",
    "line_chart": "a line chart",
    "table": "a table",
}
class_keys = list(CLASS_PROMPTS)
class_prompts = list(CLASS_PROMPTS.values())

# Encode the class prompts once. L2-normalising makes the later dot product
# with an equally normalised image embedding a cosine similarity.
with torch.no_grad():
    text_inputs = processor(text=class_prompts, return_tensors="pt", padding=True).to(device)
    text_features = model.get_text_features(**text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Supported image extensions, compared against the lower-cased file suffix.
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}

# Classify every supported image in INPUT_FOLDER and COPY it (the original is
# left in place) into the output subfolder of its best-matching class.
# iterdir() order is filesystem-dependent, so sort for a deterministic run.
for img_path in sorted(Path(INPUT_FOLDER).iterdir()):
    if not img_path.is_file() or img_path.suffix.lower() not in image_extensions:
        continue

    try:
        # The context manager closes the file handle; convert() returns an
        # independent in-memory image that stays usable after closing.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(device)
            image_features = model.get_image_features(**inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # Cosine similarity to each prompt. Softmax is monotonic, so the
            # arg-max is taken on the raw similarities directly.
            similarity = image_features @ text_features.T

        best_class = class_keys[similarity.argmax().item()]

        # Copy (not move) the image into its category folder.
        dest_folder = OUTPUT_SUBFOLDERS[best_class]
        dest_path = Path(OUTPUT_FOLDER) / dest_folder / img_path.name
        shutil.copy(img_path, dest_path)
        print(f"Copied {img_path.name} to {dest_folder}")

    except Exception as e:
        # Best-effort batch: report the failing file and continue with the rest.
        print(f"Failed to process {img_path.name}: {e}")


Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Moved IMG-1.jpg to tables
Moved IMG-2.png to bar_charts
Moved IMG-3.png to line_charts
Moved IMG-4.jpg to tables
Moved IMG-5.jpg to pie_charts
Moved IMG-6.png to bar_charts
Moved IMG-7.jpg to tables
Moved IMG-8.png to pie_charts
