In [None]:
#!/usr/bin/env python3
# Для обучения
from pathlib import Path
from ultralytics import YOLO


# Resolve paths relative to this script. Inside a Jupyter cell ``__file__`` is
# not defined, so fall back to the current working directory there instead of
# failing with NameError.
try:
    _BASE_DIR = Path(__file__).parent
except NameError:
    _BASE_DIR = Path.cwd()

DATA_DIR = _BASE_DIR / "insulators"
ANNOTATION_FILE = DATA_DIR / "annotation_data.json"
COCO_CATEGORIES = DATA_DIR / "coco_categories.json"
OUTPUT_DIR = _BASE_DIR / "results"
OUTPUT_DIR.mkdir(exist_ok=True)


def train_model(data_yaml_path):
    """Fine-tune YOLOv8-m on the dataset described by ``data_yaml_path``.

    Args:
        data_yaml_path: Path (``str`` or ``Path``) to the Ultralytics dataset
            YAML, passed to ``YOLO.train`` as ``data``.

    Returns:
        The object returned by ``YOLO.train``.
    """
    model = YOLO("yolov8m.pt")
    results = model.train(
        # Honour the caller's argument; it used to be ignored in favour of a
        # hard-coded "yolo_dataset/data.yaml".
        data=str(data_yaml_path),
        save=True,
        epochs=150,
        imgsz=768,
        # Small batch / few dataloader workers, presumably to fit limited
        # memory at imgsz=768 — TODO confirm.
        batch=4,
        workers=2,

        # Colour jitter in HSV space.
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,

        # Geometric jitter.
        degrees=5.0,
        translate=0.10,
        scale=0.50,
        shear=3.0,

        # Horizontal flips only; vertical flips disabled.
        flipud=0.0,
        fliplr=0.5,

        # Multi-image composition augmentations.
        mosaic=1.0,
        mixup=0.15,
        copy_paste=0.15,

        # Per the Ultralytics docs, mosaic is turned off for the last N epochs.
        close_mosaic=10,
    )

    return results


def main():
    """Train on ``yolo_dataset/data.yaml`` if the source annotation file exists.

    When ``ANNOTATION_FILE`` is missing, a message is printed to stderr and
    no training is started (previously this case exited silently).
    """
    if not ANNOTATION_FILE.exists():
        import sys

        print(
            f"annotation file not found: {ANNOTATION_FILE}; skipping training",
            file=sys.stderr,
        )
        return

    train_model("yolo_dataset/data.yaml")


# Start training when run as a script or notebook cell, not when imported.
if __name__ == "__main__":
    main()


In [12]:
# Для создания файлов с аугментацией
import albumentations as A
import cv2
from pathlib import Path
import numpy as np

# Training split of the YOLO dataset: images and their label files.
_DATASET_ROOT = Path("yolo_dataset")
IM_DIR = _DATASET_ROOT / "images" / "train"
LB_DIR = _DATASET_ROOT / "labels" / "train"

# Class ids treated as rare; images containing any of them are augmented.
RARE_IDS = [6, 7]
# Number of augmented copies attempted per rare-class image.
DUPS = 7

# Photometric and mild geometric augmentations used to oversample rare classes.
_RARE_AUG_STEPS = [
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
    A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, p=0.3),
    A.RandomScale(scale_limit=0.15, p=0.3),
    A.GaussNoise(p=0.2),
]

# Boxes are in YOLO format; their class ids are carried in the ``cls_ids`` field.
_YOLO_BBOX_PARAMS = A.BboxParams(
    format="yolo",
    label_fields=["cls_ids"],
    clip=True,
    min_visibility=0.4,  # drop boxes that keep too little of their area
    min_area=0.0,
)

transform = A.Compose(_RARE_AUG_STEPS, bbox_params=_YOLO_BBOX_PARAMS)

def read_yolo_labels(path: Path):
    """Parse a YOLO detection label file.

    Every non-blank line is ``<cls> <cx> <cy> <w> <h>``; any tokens after the
    fifth are ignored. The class token may be written as a float (``"6.0"``).

    Returns:
        ``(bboxes, cls_ids)``: a list of ``[cx, cy, w, h]`` float lists and the
        parallel list of integer class ids.
    """
    boxes, classes = [], []
    with open(path) as handle:
        for row in (raw.strip() for raw in handle):
            if not row:
                continue
            tokens = row.split()
            classes.append(int(float(tokens[0])))
            cx, cy, bw, bh = (float(tok) for tok in tokens[1:5])
            boxes.append([cx, cy, bw, bh])
    return boxes, classes

def write_yolo_labels(path: Path, bboxes, cls_ids):
    """Write boxes and class ids to a YOLO detection label file.

    Args:
        path: Destination ``.txt`` file; overwritten if it already exists.
        bboxes: Iterable of ``(cx, cy, w, h)`` boxes; only the first four
            values of each are used, and each is clipped to [0, 1].
        cls_ids: Class ids parallel to ``bboxes``. Cast to ``int`` because the
            YOLO format expects integer class indices; augmentation libraries
            may hand back float/numpy values (NOTE(review): confirm for the
            installed albumentations version).

    Lines are ``<cls> <cx> <cy> <w> <h>`` with six decimals. Pairs are formed
    with ``zip``, so surplus entries in the longer sequence are ignored.
    """
    # Format everything first so a bad row cannot leave a half-written file.
    lines = []
    for c, b in zip(cls_ids, bboxes):
        x, y, w, h = np.clip(np.asarray(b, dtype=float)[:4], 0.0, 1.0)
        lines.append(f"{int(c)} {x:.6f} {y:.6f} {w:.6f} {h:.6f}\n")
    with open(path, "w") as f:
        f.writelines(lines)

def _is_augmented_stem(stem):
    """Return True for stems ending in ``_aug<digits>``, the suffix augment_rare writes."""
    _, sep, tail = stem.rpartition("_aug")
    return bool(sep) and tail.isdigit()


def _find_image(stem):
    """Return the image in IM_DIR for *stem* (.jpg, then .png, then .jpeg), or None."""
    for suffix in (".jpg", ".png", ".jpeg"):
        candidate = IM_DIR / f"{stem}{suffix}"
        if candidate.exists():
            return candidate
    return None


def augment_rare():
    """Write up to DUPS augmented copies of each training image with a rare class.

    For every original label file in LB_DIR that contains at least one id from
    RARE_IDS, the matching image in IM_DIR is passed through ``transform``
    DUPS times. Each result that keeps at least one box is saved as
    ``<stem>_aug<i>.jpg`` with labels in ``<stem>_aug<i>.txt``, next to the
    originals. Prints how many copies were written.
    """
    rare = set(RARE_IDS)
    # Materialise (and sort, for a deterministic order) the listing before we
    # start writing into LB_DIR, and skip files produced by an earlier run so
    # re-running does not augment augmentations (``x_aug0_aug3``).
    label_paths = sorted(
        p for p in LB_DIR.glob("*.txt") if not _is_augmented_stem(p.stem)
    )

    cnt = 0
    for lbl_path in label_paths:
        bboxes, cls_ids = read_yolo_labels(lbl_path)
        if rare.isdisjoint(cls_ids):
            continue

        img_path = _find_image(lbl_path.stem)
        if img_path is None:
            continue
        img = cv2.imread(str(img_path))
        if img is None:  # cv2 could not decode the file
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Clamp coordinates that drift slightly outside [0, 1] before handing
        # them to albumentations.
        bboxes = np.clip(np.asarray(bboxes, dtype=np.float32), 0.0, 1.0).tolist()

        for i in range(DUPS):
            aug = transform(image=img, bboxes=bboxes, cls_ids=cls_ids)
            abox, acls = aug["bboxes"], aug["cls_ids"]
            if len(abox) == 0:  # no box survived this transform
                continue

            out_img = IM_DIR / f"{lbl_path.stem}_aug{i}.jpg"
            out_lbl = LB_DIR / f"{lbl_path.stem}_aug{i}.txt"
            # Write the label only if the image was actually saved, so a
            # failed imwrite does not leave an orphan label behind.
            saved = cv2.imwrite(
                str(out_img),
                cv2.cvtColor(aug["image"], cv2.COLOR_RGB2BGR),
            )
            if not saved:
                continue
            write_yolo_labels(out_lbl, abox, acls)
            cnt += 1

    print(f"created {cnt} augmented images for rare classes")

# Generate the augmented copies when run as a script or notebook cell.
if __name__ == "__main__":
    augment_rare()


created 3584 augmented images for rare classes


In [11]:
# Для удаления файлов с аугментацией
from pathlib import Path

# Image and label folders of the YOLO training split.
IM_DIR, LB_DIR = (
    Path("yolo_dataset") / kind / "train" for kind in ("images", "labels")
)

def delete_augmented(im_dir=None, lb_dir=None):
    """Remove augmented copies named ``<stem>_aug<N>`` and their label files.

    Args:
        im_dir: Image directory to clean; defaults to ``IM_DIR``.
        lb_dir: Label directory to clean; defaults to ``LB_DIR``.

    Deletes every ``*_aug<N>.jpg`` / ``*_aug<N>.png`` image together with its
    ``.txt`` label. It then removes any remaining ``*_aug<N>.txt`` label whose
    image was missing. Only stems ending in ``_aug`` + digits are touched, so
    originals that merely contain "aug" (e.g. ``august_01.jpg``) are kept.
    Prints the number of images and labels removed.
    """
    im_dir = IM_DIR if im_dir is None else Path(im_dir)
    lb_dir = LB_DIR if lb_dir is None else Path(lb_dir)

    def is_aug(stem):
        _, sep, tail = stem.rpartition("_aug")
        return bool(sep) and tail.isdigit()

    img_deleted = 0
    lbl_deleted = 0

    for suffix in (".jpg", ".png"):
        # list(): don't iterate the directory while deleting from it.
        for img_path in list(im_dir.glob(f"*_aug*{suffix}")):
            if not is_aug(img_path.stem):
                continue
            img_path.unlink(missing_ok=True)
            img_deleted += 1

            lbl_path = lb_dir / f"{img_path.stem}.txt"
            if lbl_path.exists():
                lbl_path.unlink()
                lbl_deleted += 1

    # Sweep augmented labels that had no matching image.
    for lbl_path in list(lb_dir.glob("*_aug*.txt")):
        if is_aug(lbl_path.stem):
            lbl_path.unlink(missing_ok=True)
            lbl_deleted += 1

    print(f"deleted {img_deleted} aug images and {lbl_deleted} aug labels")

# Remove previously generated augmentations when run as a script or notebook cell.
if __name__ == "__main__":
    delete_augmented()


deleted 533 aug images and 532 aug labels
