In [None]:
import os, json, shutil
from pathlib import Path

import torch
from torchvision import models
from PIL import Image
import pandas as pd
from tqdm.auto import tqdm

# --- Paths (relative to the notebook's working directory) -------------------
INPUT_DIR = r"../data/raw_images"          # searched recursively for images
OUTPUT_DIR = r"../data/sorted"             # one sub-folder per coarse class is created here
RESULTS_CSV = r"../classifier_results.csv"  # per-image predictions are written here

# File suffixes treated as images; compared against the lower-cased suffix.
EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

# Minimum top-1 probability required before keyword mapping is attempted;
# images below it are assigned DEFAULT_CLASS.
TOP1_THRESHOLD = 0.25

# Coarse class -> keywords looked up (as substrings) in the predicted labels.
# Insertion order matters: the first coarse class with any match wins.
KEYWORDS = {
    "people": [
        "person", "man", "woman", "boy", "girl", "bride", "groom",
        "diver", "skier", "player", "guitarist", "singer",
    ],
    "bird": [
        "bird", "eagle", "hawk", "falcon", "parrot", "duck", "goose",
        "swan", "owl", "hen", "rooster",
    ],
    "animal": [
        "dog", "cat", "lion", "tiger", "leopard", "bear", "horse", "cow",
        "sheep", "goat", "pig", "deer", "monkey", "rabbit", "fox", "wolf",
    ],
    "vehicle": [
        "car", "taxi", "bus", "truck", "motorcycle", "bike", "bicycle",
        "train", "airplane", "helicopter", "boat", "ship", "sail",
    ],
    "indoor": [
        "sofa", "couch", "chair", "table", "desk", "lamp", "tv",
        "monitor", "keyboard", "bed",
    ],
    "outdoor": [
        "mountain", "beach", "forest", "valley", "cliff", "lakeside",
        "seashore", "desert",
    ],
    "building": [
        "castle", "palace", "mosque", "church", "pagoda", "barn",
        "bridge", "lighthouse",
    ],
}
DEFAULT_CLASS = "other"  # used when nothing matches or confidence is too low

# --- Device -----------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# --- Model: ResNet-50 with the IMAGENET1K_V2 weights, in inference mode ------
weights = models.ResNet50_Weights.IMAGENET1K_V2
classes = weights.meta["categories"]   # label list, indexed by predicted class id
preprocess = weights.transforms()      # preprocessing transform shipped with the weights
model = models.resnet50(weights=weights)
model = model.to(device)
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/user/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:08<00:00, 12.8MB/s]


In [2]:
def list_images(root: str, exts=None):
    """Recursively collect image files below ``root``.

    Args:
        root: Directory to search (a string or ``Path``).
        exts: Optional iterable of file suffixes to accept. Entries are
            matched case-insensitively and may be given with or without the
            leading dot (``"jpg"`` and ``".JPG"`` are equivalent). When
            omitted, the module-level ``EXTS`` set is used.

    Returns:
        A sorted list of ``Path`` objects for every regular file whose suffix
        is accepted. The list is empty when nothing matches.
    """
    if exts is None:
        exts = EXTS
    # Normalise once so the comparison below is a plain set lookup against
    # the lower-cased suffix.
    wanted = {e.lower() if e.startswith(".") else "." + e.lower() for e in exts}
    return sorted(
        p
        for p in Path(root).rglob("*")
        # Cheap suffix test first; is_file() touches the filesystem.
        if p.suffix.lower() in wanted and p.is_file()
    )

def predict_topk(img_path: Path, k=5):
    """Classify one image and return its ``k`` most probable labels.

    Args:
        img_path: Path of the image file to classify.
        k: Number of predictions to return. Values larger than the number of
            model outputs are clamped instead of raising.

    Returns:
        ``None`` when the file cannot be opened or decoded (the failure is
        printed so the skip is visible). Otherwise a tuple
        ``(indices, names, probabilities)`` of three equally long lists,
        ordered from most to least probable.
    """
    try:
        # Image.open is lazy; the context manager releases the file handle as
        # soon as convert() has decoded the pixels into a separate image.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
    except Exception as e:
        # Best effort: one unreadable file must not abort a whole run, but it
        # should not vanish silently either.
        print("READ ERROR:", img_path, e)
        return None
    x = preprocess(img).unsqueeze(0).to(device)  # batch of one
    with torch.no_grad():
        probs = torch.softmax(model(x), dim=1)[0]
        # topk raises if asked for more entries than exist.
        top_prob, top_idx = probs.topk(min(k, probs.numel()))
    idxs = top_idx.tolist()
    names = [classes[i] for i in idxs]
    return idxs, names, top_prob.tolist()

def map_to_coarse(names, keywords=None, default=None):
    """Map predicted label names to a single coarse class.

    Args:
        names: Iterable of label strings (e.g. the top-k predictions).
        keywords: Optional mapping ``coarse class -> list of keywords``.
            Defaults to the module-level ``KEYWORDS`` table.
        default: Value returned when no keyword matches. Defaults to the
            module-level ``DEFAULT_CLASS``.

    Returns:
        The first coarse class, in the mapping's iteration order, for which
        any keyword occurs in any of the names; otherwise ``default``.

    Notes:
        * Precedence follows the order of ``keywords``, not the order of
          ``names``: an early class matched by a low-ranked label wins over a
          later class matched by the top label.
        * Matching is a case-insensitive *substring* test, so a keyword also
          matches inside longer words (``"owl"`` matches ``"soup bowl"``).
    """
    if keywords is None:
        keywords = KEYWORDS
    if default is None:
        default = DEFAULT_CLASS
    lowered = [n.lower() for n in names]
    for coarse, kws in keywords.items():
        if any(kw.lower() in n for n in lowered for kw in kws):
            return coarse
    return default


In [3]:
# Classify every image under INPUT_DIR and persist one row per readable image.
files = list_images(INPUT_DIR)
print("Found images:", len(files))

# Explicit schema: without it an empty run yields a column-less DataFrame and
# any later access such as df["coarse"] raises KeyError.
RESULT_COLUMNS = ["path", "top1_name", "top1_prob", "top5_name", "top5_prob", "coarse"]

rows = []
n_unreadable = 0
for p in tqdm(files, desc="Classifying"):
    pred = predict_topk(p, k=5)
    if pred is None:
        # predict_topk returns None for files it could not open/decode.
        n_unreadable += 1
        continue
    idxs, names, probs = pred
    top1_prob = probs[0]
    # Keyword mapping is only trusted when the best prediction is confident
    # enough; otherwise the image goes to the fallback class.
    coarse = map_to_coarse(names) if top1_prob >= TOP1_THRESHOLD else DEFAULT_CLASS
    rows.append({
        "path": str(p),
        "top1_name": names[0],
        "top1_prob": top1_prob,
        "top5_name": json.dumps(names),   # JSON-encoded so the list fits in one CSV cell
        "top5_prob": json.dumps(probs),
        "coarse": coarse
    })

if n_unreadable:
    print("Skipped unreadable images:", n_unreadable)

df = pd.DataFrame(rows, columns=RESULT_COLUMNS)
df.to_csv(RESULTS_CSV, index=False)
df.head()

Found images: 0


Classifying: 0it [00:00, ?it/s]


In [None]:
# Copy every classified image into OUTPUT_DIR/<coarse class>/.
out = Path(OUTPUT_DIR)
out.mkdir(parents=True, exist_ok=True)

if df.empty:
    # An empty result frame may not even have a "coarse" column, so indexing
    # it would raise KeyError.
    print("Nothing to copy: no images were classified.")
else:
    in_root = Path(INPUT_DIR)

    # Inputs are collected recursively, so files from different sub-folders
    # can share a basename. Copying both under that basename would silently
    # drop all but the first one. Detect such clashes per coarse class
    # (case-insensitively, for case-insensitive filesystems) up front.
    basenames = df["path"].map(lambda s: Path(s).name.lower())
    clashes = pd.DataFrame({"coarse": df["coarse"], "name": basenames}).duplicated(keep=False)

    for path_str, coarse, clash in tqdm(
        zip(df["path"], df["coarse"], clashes), total=len(df), desc="Copying"
    ):
        src = Path(path_str)
        dst_name = src.name
        if clash:
            # Deterministic, re-run-safe disambiguation: fold the path
            # relative to the input root into the file name.
            try:
                rel = src.relative_to(in_root)
            except ValueError:
                # Path does not live under INPUT_DIR; fall back to its parent folder.
                rel = Path(src.parent.name) / src.name
            dst_name = "__".join(rel.parts)

        dst_dir = out / coarse
        dst_dir.mkdir(parents=True, exist_ok=True)
        dst = dst_dir / dst_name
        # Existing targets are left untouched so the cell can be re-run safely.
        if not dst.exists():
            try:
                shutil.copy2(src, dst)
            except OSError as e:
                print("COPY ERROR:", src, "->", dst, e)

    print("Done. Classes distribution:")
    print(df["coarse"].value_counts())