In [6]:
import argparse, json, math, random, shutil
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
from random import Random

In [2]:
# Maps each retained category name to the group folder it is filed under.
# Written as group -> members and flattened, so every group is listed once;
# the resulting dict keeps the same keys, values and insertion order.
# Category names that are not keys here (e.g. "person") are skipped downstream.
NAME_TO_GROUP = {
    member: group_name
    for group_name, members in {
        "Accessory":  ("handbag", "umbrella", "backpack", "tie"),
        "Animal1":    ("bird", "sheep", "cow", "horse"),
        "Animal2":    ("elephant", "dog", "zebra", "giraffe"),
        "Appliance":  ("sink", "oven", "refrigerator", "microwave"),
        "Electronic": ("cell phone", "tv", "remote", "laptop"),
        "Food1":      ("banana", "carrot", "broccoli", "donut"),
        "Food2":      ("orange", "cake", "apple", "pizza"),
        "Furniture":  ("chair", "dining table", "potted plant", "couch"),
        "Indoor":     ("book", "vase", "clock", "teddy bear"),
        "Kitchen":    ("bottle", "cup", "bowl", "wine glass"),
        "Outdoor":    ("traffic light", "bench", "stop sign", "fire hydrant"),
        "Sports1":    ("kite", "skis", "sports ball", "surfboard"),
        "Sports2":    ("skateboard", "tennis racket", "baseball glove", "baseball bat"),
        "Vehicle1":   ("car", "motorcycle", "boat", "truck"),
        "Vehicle2":   ("bicycle", "bus", "airplane", "train"),
    }.items()
    for member in members
}

In [8]:
def ensure(path: Path) -> Path:
    """Make sure `path` exists as a directory and hand the same object back.

    Missing parent directories are created as well, and an already existing
    directory is not an error. Returning `path` lets callers create and use
    the directory in a single expression.
    """
    path.mkdir(exist_ok=True, parents=True)
    return path

def load_instances(inst_json: Path):
    """Read an instances annotation file and index it for per-image lookup.

    Args:
        inst_json: Path to a JSON file with top-level "images", "categories"
            and "annotations" lists.

    Returns:
        A tuple ``(images, first_cat, id2name)``:
          * ``images``    - the "images" list exactly as stored in the file,
          * ``first_cat`` - image_id -> category_id of the first annotation
            that mentions that image, in file order,
          * ``id2name``   - category id -> category name.
    """
    # JSON is UTF-8 by specification; decode explicitly so the result does not
    # depend on the platform's default locale encoding.
    js = json.loads(inst_json.read_text(encoding="utf-8"))
    id2name = {c["id"]: c["name"] for c in js["categories"]}
    first_cat = {}
    for ann in js["annotations"]:
        # setdefault keeps the earliest category seen for each image.
        first_cat.setdefault(ann["image_id"], ann["category_id"])
    return js["images"], first_cat, id2name

def load_captions(cap_json: Path):
    """Read a captions annotation file and group the captions by image.

    Args:
        cap_json: Path to a JSON file whose top-level "annotations" list holds
            entries with "image_id" and "caption" keys.

    Returns:
        A ``defaultdict(list)`` mapping image_id to its captions in file
        order, each with leading/trailing whitespace stripped.
    """
    # Captions are free text; decode as UTF-8 explicitly rather than relying
    # on the platform's default locale encoding.
    js = json.loads(cap_json.read_text(encoding="utf-8"))
    caps = defaultdict(list)
    for ann in js["annotations"]:
        caps[ann["image_id"]].append(ann["caption"].strip())
    return caps

def process_split(split: str, coco_root: Path, out_root: Path):
    """Lay out one split as per-group / per-category image links and caption files.

    For every image record of `split` whose first-annotation category is a key
    of NAME_TO_GROUP, this creates:
      * a symlink  out_root/image/<split>/<Group>_image/<category>/<file name>
        pointing at the resolved source image, and
      * a text file out_root/text/<split>/<Group>_text/<category>/<stem>.txt
        holding the image's captions, one per line.

    Images whose category is not mapped (or that have no annotation) are
    skipped. Links and text files that already exist are left untouched.

    Args:
        split: Split name; used in the annotation file names and output paths.
        coco_root: Directory containing "annotations" and the image folders.
        out_root: Directory under which "image" and "text" trees are written.
    """
    annotations_dir = coco_root / "annotations"
    images, first_cat, id2name = load_instances(annotations_dir / f"instances_{split}2017.json")
    caps = load_captions(annotations_dir / f"captions_{split}2017.json")

    # Source images live in "train" for the train split and "val" otherwise.
    source_dir = coco_root / ("train" if split == "train" else "val")

    image_root = ensure(out_root / "image" / split)
    text_root = ensure(out_root / "text" / split)

    for record in tqdm(images, desc=f"{split} images"):
        image_id = record["id"]
        source = source_dir / record["file_name"]

        # Unused classes, 'person', and unannotated images all resolve to None.
        category = id2name.get(first_cat.get(image_id))
        group = NAME_TO_GROUP.get(category)
        if group is None:
            continue

        # ----------- image link -----------
        link = ensure(image_root / f"{group}_image" / category) / source.name
        try:
            link.symlink_to(source.resolve())
        except FileExistsError:
            # A link from a previous run is kept as is.
            pass

        # ----------- caption TXT -----------
        caption_file = ensure(text_root / f"{group}_text" / category) / (source.stem + ".txt")
        if caption_file.exists():
            continue
        caption_file.write_text("\n".join(caps.get(image_id, [])), encoding="utf-8")

In [7]:
def main():
    """Run process_split for the train and val splits.

    The source root and the output root are the same directory, so the
    "image" and "text" trees are created next to the original data.
    Note: a later cell of this notebook defines another main(), which
    shadows this one once that cell has been executed.
    """
    coco_root, out_root = (
        Path(location).expanduser().resolve()
        for location in ('/data_library/mscoco', '/data_library/mscoco')
    )

    splits = ("train", "val")
    for split in splits:
        print(f"➡  Processing {split} split …")
        process_split(split, coco_root, out_root)

    print("\n✅  Done.")


if __name__ == "__main__":
    main()

➡  Processing train split …


train images: 100%|██████████| 118287/118287 [00:10<00:00, 10906.15it/s]


➡  Processing val split …


val images: 100%|██████████| 5000/5000 [00:00<00:00, 6396.00it/s]


✅  Done.





## 각 클래스별 valid set 개수

In [None]:
# Root directory of the dataset tree.
MSC_ROOT = Path("/data_library/mscoco")

# Minimum number of files each subclass directory is expected to hold.
THRESHOLD = 50

# Directory that is checked; per the original note, 'val' was renamed to 'test'.
valid_dir = MSC_ROOT.joinpath("image", "test")

def iter_subclass_dirs(root: Path):
    """Yield ``(group directory name, subclass directory path)`` pairs.

    Walks exactly two levels below `root`
    (``<root>/<Group>_image/<Subclass>/``). Entries that are not directories
    are ignored at both levels. The order follows ``Path.iterdir`` and is
    therefore not sorted.
    """
    group_dirs = (entry for entry in root.iterdir() if entry.is_dir())
    for group_dir in group_dirs:
        yield from (
            (group_dir.name, child)
            for child in group_dir.iterdir()
            if child.is_dir()
        )

def main():
    """Report every subclass directory under `valid_dir` holding fewer than THRESHOLD files.

    Prints either a confirmation that all subclasses meet the threshold, or a
    sorted list of ``<group>/<subclass>`` entries with their file counts.
    Note: this redefines main() from an earlier cell of the notebook.
    """
    short_list = []
    for gname, sdir in iter_subclass_dirs(valid_dir):
        n_files = sum(1 for f in sdir.iterdir() if f.is_file())
        if n_files < THRESHOLD:
            short_list.append((gname, sdir.name, n_files))

    if not short_list:
        # Report the configured threshold instead of a hard-coded "50" so the
        # message stays correct when THRESHOLD is changed.
        print(f"✅ 모든 서브클래스가 최소 {THRESHOLD}개 이상 보유하고 있습니다.")
    else:
        print(f"⚠  valid set < {THRESHOLD} 인 서브클래스 ({len(short_list)}개):")
        for gname, subc, n in sorted(short_list):
            print(f"  • {gname}/{subc}  →  {n} files")

if __name__ == "__main__":
    main()


⚠  valid set < 50 인 서브클래스 (34개):
  • Accessory_image/backpack  →  6 files
  • Accessory_image/handbag  →  5 files
  • Appliance_image/microwave  →  9 files
  • Appliance_image/oven  →  10 files
  • Appliance_image/refrigerator  →  43 files
  • Appliance_image/sink  →  23 files
  • Electronic_image/laptop  →  23 files
  • Electronic_image/remote  →  9 files
  • Food1_image/banana  →  31 files
  • Food1_image/broccoli  →  29 files
  • Food1_image/carrot  →  6 files
  • Food1_image/donut  →  20 files
  • Food2_image/apple  →  12 files
  • Food2_image/cake  →  12 files
  • Food2_image/orange  →  14 files
  • Food2_image/pizza  →  39 files
  • Indoor_image/book  →  7 files
  • Indoor_image/teddy bear  →  45 files
  • Indoor_image/vase  →  30 files
  • Kitchen_image/bowl  →  38 files
  • Kitchen_image/cup  →  33 files
  • Kitchen_image/wine glass  →  5 files
  • Outdoor_image/bench  →  32 files
  • Outdoor_image/fire hydrant  →  35 files
  • Outdoor_image/stop sign  →  43 files
  • Outdoor_i

In [7]:
# Default dataset root.
# NOTE(review): not referenced by rebalance() or the driver cell visible in
# this notebook, which builds its own path from a string literal.
ROOT_DEFAULT = Path("/data_library/mscoco")
# Minimum ratio of test ("valid") files to train files required per subclass.
MIN_RATIO    = 0.20          # valid count must be >= 20 % of the train count

# ---------- Utilities ----------
def sorted_dirs(path: Path):
    """Return the directories directly under `path` as a list sorted by name."""
    subdirs = [entry for entry in path.iterdir() if entry.is_dir()]
    subdirs.sort(key=lambda entry: entry.name)
    return subdirs

def iter_subclasses(img_train_root: Path):
    """Yield ``(group name, subclass name, subclass path)`` in sorted order.

    Walks the two directory levels ``<img_train_root>/<Group>_image/<Subclass>/``
    with groups sorted by name and, within each group, subclasses sorted by
    name.
    """
    for group_dir in sorted_dirs(img_train_root):
        group_name = group_dir.name
        yield from (
            (group_name, subclass_dir.name, subclass_dir)
            for subclass_dir in sorted_dirs(group_dir)
        )

def count_files(p: Path) -> int:
    """Count the direct children of `p` for which ``is_file()`` is true."""
    total = 0
    for entry in p.iterdir():
        if entry.is_file():
            total += 1
    return total

def ensure(p: Path) -> Path:
    """Create directory `p` (and any missing parents) if needed, then return `p`.

    An existing directory is not an error. Note: this redefines the ensure()
    helper from an earlier cell of the notebook with the same behavior.
    """
    p.mkdir(exist_ok=True, parents=True)
    return p

# ---------- Main ----------
def rebalance(root: Path, seed: int):
    """Move train files into test until each subclass holds at least MIN_RATIO of train.

    For every ``image/train/<Group>_image/<Subclass>`` directory, the number of
    files in the matching ``image/test`` directory is compared with
    ``ceil(train_count * MIN_RATIO)``. If test is short, the missing number of
    files is sampled from train with a seeded RNG and moved to test, together
    with the same-stem ``.txt`` file from ``text/train`` to ``text/test`` when
    that file exists. A summary of the moves is printed at the end.

    Args:
        root: Dataset root containing the "image" and "text" trees.
        seed: Seed for the private ``Random`` instance, so the same directory
            contents always produce the same selection.
    """
    rng = Random(seed)       # independent RNG with a fixed seed

    img_tr = root / "image" / "train"
    img_te = root / "image" / "test"
    txt_tr = root / "text"  / "train"
    txt_te = root / "text"  / "test"

    # round() rather than int(): int(0.29 * 100) would truncate to 28.
    pct = round(MIN_RATIO * 100)

    moved_summary = []

    for group, subclass, sdir_tr in iter_subclasses(img_tr):
        # The test directory is created up front so it can be counted.
        sdir_te = ensure(img_te / group / subclass)

        n_tr = count_files(sdir_tr)
        n_te = count_files(sdir_te)

        if n_tr == 0:
            continue

        min_need = math.ceil(n_tr * MIN_RATIO)
        if n_te >= min_need:
            continue

        # Fix the candidate order by name before sampling so the selection
        # depends only on the seed and the file names.
        candidates = sorted((f for f in sdir_tr.iterdir() if f.is_file()),
                            key=lambda x: x.name)
        # Never ask for more files than exist (possible only if MIN_RATIO > 1);
        # rng.sample would raise ValueError otherwise.
        deficit = min(min_need - n_te, len(candidates))
        move_files = rng.sample(candidates, deficit)

        # Loop-invariant text directories for this subclass.
        text_group = group.replace("_image", "_text")
        txt_src_dir = txt_tr / text_group / subclass
        txt_dst_dir = txt_te / text_group / subclass

        # ---------- actual move ----------
        for src_img in move_files:
            shutil.move(src_img, sdir_te / src_img.name)

            txt_src = txt_src_dir / (src_img.stem + ".txt")
            if txt_src.exists():
                # The text test directory is created only when there is something to move.
                shutil.move(txt_src, ensure(txt_dst_dir) / txt_src.name)

        moved_summary.append((group, subclass, n_tr, n_te, deficit))

    # ---------- summary ----------
    if moved_summary:
        print(f"\n[Rebalanced] valid < {pct} % 이었던 {len(moved_summary)}개 서브클래스")
        for g, s, tr, te, add in moved_summary:
            print(f"  • {g}/{s:<18}  train={tr:<5} → test {te}+{add} = {te+add}")
    else:
        # Derived from MIN_RATIO so both branches always report the same figure.
        print(f"✅ 모든 서브클래스가 이미 {pct} % 이상입니다.")

In [8]:
# Driver cell: rebalance the dataset tree in place.
# WARNING: rebalance() moves files on disk. A second run skips subclasses that
# already meet the ratio.
if __name__ == "__main__":
    root='/data_library/mscoco'   # dataset root holding the image/ and text/ trees
    seed=42                       # fixed seed so the sampled files are reproducible

    rebalance(Path(root).expanduser().resolve(), seed)


[Rebalanced] valid < 20 % 이었던 60개 서브클래스
  • Accessory_image/backpack            train=143   → test 6+23 = 29
  • Accessory_image/handbag             train=112   → test 5+18 = 23
  • Accessory_image/tie                 train=2238  → test 91+357 = 448
  • Accessory_image/umbrella            train=2137  → test 93+335 = 428
  • Animal1_image/bird                train=2591  → test 106+413 = 519
  • Animal1_image/cow                 train=1563  → test 62+251 = 313
  • Animal1_image/horse               train=2664  → test 116+417 = 533
  • Animal1_image/sheep               train=1283  → test 53+204 = 257
  • Animal2_image/dog                 train=3999  → test 169+631 = 800
  • Animal2_image/elephant            train=1308  → test 53+209 = 262
  • Animal2_image/giraffe             train=1906  → test 80+302 = 382
  • Animal2_image/zebra               train=1651  → test 81+250 = 331
  • Appliance_image/microwave           train=197   → test 9+31 = 40
  • Appliance_image/oven                train