In [None]:
import os
import zipfile
import shutil
import random
import yaml
import cv2
from pathlib import Path
from PIL import Image, ImageEnhance

# ======== KONFIGURASI (Colab compatible) ========
import sys
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Install required packages
    !pip install pillow pyyaml opencv-python --quiet
    from google.colab import files
    images_zip = "/content/images.zip"
    labels_zip = "/content/labels.zip"
    data_yaml = "/content/data.yaml"
    output_dir = "/content/balanced_dataset"
    splits = ["train", "val", "test"]
    random.seed(42)
    print("Please upload images.zip, labels.zip, and data.yaml if not already present in /content.")
    for fname in [images_zip, labels_zip, data_yaml]:
        if not os.path.exists(fname):
            print(f"Upload {os.path.basename(fname)}:")
            files.upload()
else:
    images_zip = "D:\\TBE\\AI\\trainerV0.10\\enhanced data\\images.zip"
    labels_zip = "D:\\TBE\\AI\\trainerV0.10\\enhanced data\\labels.zip"
    data_yaml = "D:\\TBE\\AI\\trainerV0.10\\enhanced data\\data.yaml"
    output_dir = "D:\\TBE\\AI\\trainerV0.10\\enhanced data\\balanced_dataset"
    splits = ["train", "val", "test"]
    random.seed(42)
# =============================

# --- Path utilities ---
def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    An already existing directory is accepted silently.
    """
    os.makedirs(p, exist_ok=True)

def rel_from_labels_root(label_path: Path, labels_root: Path) -> Path:
    """Return ``label_path`` expressed relative to ``labels_root``.

    When the relative path starts with a ``labels`` component (compared
    case-insensitively) -- as happens when the archive itself contains a
    nested ``labels/`` folder -- that first component is dropped.
    Result looks like ``train/sub/xxx.txt`` or ``xxx.txt``.
    """
    relative = label_path.relative_to(labels_root)
    has_labels_prefix = bool(relative.parts) and relative.parts[0].lower() == "labels"
    return Path(*relative.parts[1:]) if has_labels_prefix else relative

def find_image_for_label(label_path: Path, labels_root: Path, images_root: Path):
    """Locate the image that belongs to a label file.

    The label's path relative to ``labels_root`` (minus any nested
    ``labels/`` layer) is looked up under ``images_root`` and under
    ``images_root / "images"``, trying a fixed list of image extensions.

    Args:
        label_path: Path of the ``.txt`` label file.
        labels_root: Directory the label path is made relative to.
        images_root: Directory searched for the matching image.

    Returns:
        The first existing candidate path, or ``None`` when no image matches.
    """
    rel = rel_from_labels_root(label_path, labels_root)      # e.g. train/xxx.txt
    # Candidate image roots: images_root/... and images_root/images/...
    candidate_roots = [images_root, images_root / "images"]
    for base in candidate_roots:
        for ext in [".jpg", ".jpeg", ".png", ".JPG", ".PNG", ".JPEG"]:
            # Replace only the label's own extension. Stripping the suffix
            # first and calling with_suffix() again would eat part of a
            # dotted stem ("a.rf.123.txt" -> "a.rf.jpg" instead of
            # "a.rf.123.jpg"), so such images were never found.
            cand = base / rel.with_suffix(ext)
            if cand.exists():
                return cand
    return None

def split_of_label(label_rel: Path):
    """Return the leading path component (the split, e.g. train/val/test).

    A path with fewer than two components has no split folder, in which
    case ``None`` is returned.
    """
    components = label_rel.parts
    if len(components) < 2:
        return None
    return components[0]

# ==== Photometric augmentation (does not move pixels, so labels stay valid) ====
def augment_image_photometric(img_path: Path) -> Image.Image:
    """Load an image as RGB and randomly jitter its brightness and contrast.

    Each adjustment is applied independently with probability 0.9:
    brightness by a factor drawn from [0.8, 1.25], contrast by a factor
    drawn from [0.85, 1.2].

    Args:
        img_path: Path of the image to load.

    Returns:
        The (possibly adjusted) RGB image.
    """
    # Use a context manager so the underlying file handle is released right
    # away; convert() returns an independent in-memory copy.
    with Image.open(str(img_path)) as opened:
        img = opened.convert("RGB")
    if random.random() < 0.9:
        img = ImageEnhance.Brightness(img).enhance(random.uniform(0.8, 1.25))
    if random.random() < 0.9:
        img = ImageEnhance.Contrast(img).enhance(random.uniform(0.85, 1.2))
    return img

# ==== Label I/O ====
def read_yolo_labels(lbl_path: Path):
    """Parse a YOLO label file into ``(cid, cx, cy, w, h)`` tuples.

    Lines with fewer than five whitespace-separated fields are ignored;
    only the first five fields of longer lines are used. A missing file
    yields an empty list.
    """
    if not lbl_path.exists():
        return []

    parsed = []
    with open(lbl_path, "r") as handle:
        for raw in handle:
            fields = raw.strip().split()
            if len(fields) < 5:
                continue
            class_id = int(fields[0])
            cx, cy, w, h = (float(value) for value in fields[1:5])
            parsed.append((class_id, cx, cy, w, h))
    return parsed

def write_yolo_labels(lbl_path: Path, labels):
    """Write ``(cid, cx, cy, w, h)`` tuples to ``lbl_path``, one per line.

    The file is overwritten; the four box values are formatted with six
    decimal places.
    """
    with open(lbl_path, "w") as sink:
        sink.writelines(
            f"{cid} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n"
            for (cid, cx, cy, w, h) in labels
        )

# ==== Copy-paste an object into a new file ====
def copy_paste_object_to_new_file(src_img_path: Path, src_lbl_path: Path,
                                  tgt_img_path: Path, tgt_lbl_path: Path,
                                  out_img_path: Path, out_lbl_path: Path,
                                  target_cid: int):
    """Paste one ``target_cid`` object from a source image onto a target image.

    The output label file starts as a copy of the target's labels. If the
    source has a usable box of class ``target_cid``, one is picked at
    random, cropped, rescaled by a random factor in [0.6, 1.2], pasted at
    a random position on a copy of the target, and its new box is appended
    to the output labels. Otherwise the target image is written unchanged.

    Args:
        src_img_path: Image the object is cropped from.
        src_lbl_path: YOLO labels of the source image.
        tgt_img_path: Image the object is pasted onto.
        tgt_lbl_path: YOLO labels of the target image.
        out_img_path: Where the resulting image is written.
        out_lbl_path: Where the resulting labels are written.
        target_cid: Class id of the object to paste.

    Returns:
        ``False`` when either input image cannot be read or the output
        image cannot be written; ``True`` otherwise.
    """
    # Read source & target
    img_src = cv2.imread(str(src_img_path))
    img_tgt = cv2.imread(str(tgt_img_path))
    if img_src is None or img_tgt is None:
        return False

    h_src, w_src = img_src.shape[:2]
    h_tgt, w_tgt = img_tgt.shape[:2]

    # The output always starts from the target's own annotations.
    tgt_labels = read_yolo_labels(tgt_lbl_path)
    write_yolo_labels(out_lbl_path, tgt_labels)

    # Collect every box of the requested class in the source
    src_boxes = [box for box in read_yolo_labels(src_lbl_path) if box[0] == target_cid]

    crop = None
    if src_boxes:
        # Pick one box and convert it to pixel coordinates clamped to the image
        cid, cx, cy, w, h = random.choice(src_boxes)
        x1 = int((cx - w/2) * w_src); y1 = int((cy - h/2) * h_src)
        x2 = int((cx + w/2) * w_src); y2 = int((cy + h/2) * h_src)
        x1 = max(0, min(x1, w_src - 1)); y1 = max(0, min(y1, h_src - 1))
        x2 = max(1, min(x2, w_src));     y2 = max(1, min(y2, h_src))
        if x2 > x1 and y2 > y1:
            crop = img_src[y1:y2, x1:x2]

    if crop is None or crop.size == 0:
        # Nothing usable to paste: emit the target image unchanged. Report
        # the write status so the caller can react to a failed write.
        return bool(cv2.imwrite(str(out_img_path), img_tgt))

    # Random scale, kept strictly smaller than the target but never below
    # 1 px (a 1-px-wide target would otherwise request a 0-px resize).
    # The size is clamped before resizing so the crop is resampled once.
    scale = random.uniform(0.6, 1.2)
    new_w = max(1, min(max(1, int(crop.shape[1] * scale)), w_tgt - 1))
    new_h = max(1, min(max(1, int(crop.shape[0] * scale)), h_tgt - 1))
    crop = cv2.resize(crop, (new_w, new_h))

    # Random top-left corner such that the patch fits inside the target
    tx = random.randint(0, w_tgt - new_w)
    ty = random.randint(0, h_tgt - new_h)

    # Paste onto a copy of the target
    img_out = img_tgt.copy()
    img_out[ty:ty+new_h, tx:tx+new_w] = crop

    # Save the image first; only record the new box if the image exists.
    if not cv2.imwrite(str(out_img_path), img_out):
        return False

    # Normalised YOLO box of the pasted patch
    new_cx = (tx + new_w/2) / w_tgt
    new_cy = (ty + new_h/2) / h_tgt
    new_bw = new_w / w_tgt
    new_bh = new_h / h_tgt

    with open(out_lbl_path, "a") as f:
        f.write(f"{cid} {new_cx:.6f} {new_cy:.6f} {new_bw:.6f} {new_bh:.6f}\n")

    return True

# ====== 1) Extract the zip archives ======
images_root = Path("images")
labels_root = Path("labels")

# Wipe leftovers from a previous run, then recreate both extraction folders.
for stale_root in (images_root, labels_root):
    shutil.rmtree(stale_root, ignore_errors=True)
for fresh_root in (images_root, labels_root):
    ensure_dir(fresh_root)

# Unpack each archive into its own folder (images first, then labels).
for archive_path, extract_root in ((images_zip, images_root), (labels_zip, labels_root)):
    with zipfile.ZipFile(archive_path, "r") as z:
        z.extractall(extract_root)

# ====== 2) Read data.yaml ======
# Decode explicitly as UTF-8 instead of the platform default encoding, so
# non-ASCII class names load the same way on every OS.
with open(data_yaml, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f) or {}  # an empty file parses to None
# A present-but-null "names:" entry would otherwise break len() below.
names = cfg.get("names") or []
# Fall back to the declared class count ("nc") when names is missing or
# shorter, so those classes are still tracked in the per-class counts.
num_classes = max(len(names), int(cfg.get("nc") or 0))

# ====== 3) Count annotations per class ======
label_counts = {i: 0 for i in range(num_classes)}
all_label_files = list(labels_root.rglob("*.txt"))
for lp in all_label_files:
    for cid, *_ in read_yolo_labels(lp):
        # Label files may reference class ids that data.yaml does not list;
        # count them instead of failing with KeyError.
        label_counts[cid] = label_counts.get(cid, 0) + 1

print("Jumlah anotasi per kelas:")
# Iterate the counted ids (not just range(num_classes)) so classes found
# only in the label files are reported too.
for cid in sorted(label_counts):
    print(f"{cid} ({names[cid] if 0 <= cid < len(names) else cid}): {label_counts.get(cid,0)}")

# Size of the largest class; every smaller class is topped up towards it.
max_count = max(label_counts.values()) if label_counts else 0

# ====== 4) Prepare the output tree (base copy) ======
out_img_root = Path(output_dir) / "images"
out_lbl_root = Path(output_dir) / "labels"
# Always start from an empty output directory.
if Path(output_dir).exists():
    shutil.rmtree(output_dir)
for sub in ["images", "labels"]:
    for sp in splits:
        ensure_dir(Path(output_dir) / sub / sp)

# YOLO expects .../images/<split> and .../labels/<split>. Copy each split
# that exists in the extracted data, looking both directly under the
# extraction root and under a nested 'images/' / 'labels/' folder.
for sp in splits:
    # A split is copied only when both its label and image folders exist.
    cand_lbl_dirs = [labels_root / sp, labels_root / "labels" / sp]
    lbl_dir = next((d for d in cand_lbl_dirs if d.exists()), None)
    cand_img_dirs = [images_root / sp, images_root / "images" / sp]
    img_dir = next((d for d in cand_img_dirs if d.exists()), None)
    if lbl_dir and img_dir:
        shutil.copytree(img_dir, out_img_root / sp, dirs_exist_ok=True)
        shutil.copytree(lbl_dir, out_lbl_root / sp, dirs_exist_ok=True)

# If no split received any label file, the data has no split structure:
# copy everything into the 'train' split instead.
# NOTE: glob()/rglob() return generators, which are truthy even when they
# yield nothing, so each one must be probed for a first element.
has_split_labels = any(
    next((out_lbl_root / sp).rglob("*.txt"), None) is not None for sp in splits
)
if not has_split_labels:
    ensure_dir(out_img_root / "train"); ensure_dir(out_lbl_root / "train")
    # copy every image
    for p in images_root.rglob("*.*"):
        if p.suffix.lower() in [".jpg", ".jpeg", ".png"]:
            shutil.copy2(p, out_img_root / "train" / p.name)
    # copy every label
    for p in labels_root.rglob("*.txt"):
        shutil.copy2(p, out_lbl_root / "train" / p.name)

# Re-scan the label files from the OUTPUT tree so split names are consistent
out_all_lbl_files = list(out_lbl_root.rglob("*.txt"))

# ====== 5) Oversample + augmentation ======
# Index the output label files by the class ids they contain. This is done
# once, with the same parser used for counting, so blank or short lines are
# skipped consistently and every file handle is closed.
files_by_cid = {}
for lp in out_all_lbl_files:
    for present_cid in {item[0] for item in read_yolo_labels(lp)}:
        files_by_cid.setdefault(present_cid, []).append(lp)

for cid, count in label_counts.items():
    if count == 0:
        print(f"⚠ Lewati kelas {cid} ({names[cid] if cid < len(names) else cid}), tidak ada data sumber.")
        continue
    if count >= max_count:
        continue

    # Label files containing this class (taken from OUTPUT so split paths hold)
    files_for_cid = files_by_cid.get(cid, [])
    if not files_for_cid:
        print(f"⚠ Tidak menemukan file untuk kelas {cid}, lewati.")
        continue

    needed = max_count - count
    print(f"Menambah kelas {cid} ({names[cid] if cid < len(names) else cid}): {needed} sampel tambahan")

    for i in range(needed):
        src_lbl = random.choice(files_for_cid)
        # Work out the split and the source image path
        src_rel = rel_from_labels_root(src_lbl, out_lbl_root)
        sp = split_of_label(src_rel) or "train"
        src_img = find_image_for_label(src_lbl, out_lbl_root, out_img_root)
        if src_img is None or not src_img.exists():
            continue

        # Build the new output names inside the same split as the source
        base = src_img.stem
        out_img_dir = out_img_root / sp
        out_lbl_dir = out_lbl_root / sp
        ensure_dir(out_img_dir); ensure_dir(out_lbl_dir)
        out_img = out_img_dir / f"{base}_aug_{cid}_{i}.jpg"
        out_lbl = out_lbl_dir / f"{base}_aug_{cid}_{i}.txt"

        # Half of the samples attempt a copy-paste onto a random target image
        pasted = False
        if random.random() >= 0.5:
            tgt_lbl = random.choice(out_all_lbl_files)
            tgt_img = find_image_for_label(tgt_lbl, out_lbl_root, out_img_root)
            if tgt_img is not None and tgt_img.exists():
                pasted = copy_paste_object_to_new_file(src_img, src_lbl, tgt_img, tgt_lbl, out_img, out_lbl, cid)

        # The other half -- and every failed copy-paste -- get a photometric
        # variant of the source; pixels do not move, so the source labels
        # can be copied unchanged.
        if not pasted:
            img_aug = augment_image_photometric(src_img)
            img_aug.save(out_img, quality=90)
            shutil.copy2(src_lbl, out_lbl)

print(f"\n✅ Selesai. Dataset seimbang + augmentasi tersimpan di: {output_dir}")