# CutMix 알고리즘을 구현하는데 필요한 함수와 구현 함수가 있는 파일입니다.

In [1]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from copy import deepcopy
import random

In [2]:
# names= {0:'Door', 1:'House', 2:'Roof', 3:'Window'}
# object_name = 1.0
names = {0:'branch', 1:'crown', 2:'fruit', 3:'gnarl', 4:'root', 5:'tree'}
object_name = 5.0
# names = {0:'arm', 1:'eye', 2:'leg', 3:'mouth', 4:'person'}
# object_name = 4.0

In [3]:
def yolo_to_corners(boxes, img_width, img_height):
    box_list = []
    for box in boxes:
        x_center, y_center, width, height = box[0]
        x1 = (x_center - width / 2) * 512
        y1 = (y_center - height / 2) * 512
        x2 = (x_center + width / 2) * 512
        y2 = (y_center + height / 2) * 512
        box_list.append(np.stack([x1, y1, x2, y2], axis=0))
    box_list = np.array(box_list)
    return box_list

In [4]:
def corners_to_yolo(boxes, img_width, img_height):
    box_list = []
    for box in boxes:
        x1, y1, x2, y2 = box[0]
        x_center = (x1 + x2) / 2 / 512
        y_center = (y1 + y2) / 2 / 512
        width = (x2 - x1) / 512
        height = (y2 - y1) / 512
        box_list.append(np.stack([x_center, y_center, width, height], axis=0))
    box_list = np.array(box_list)
    return box_list

In [5]:
def plot_image_with_boxes(image, boxes, labels):
    """
    하나의 이미지에 대해 바운딩 박스를 그리고 레이블을 표시.
    Args:
        image: 원본 이미지 (PyTorch 텐서 (3, H, W) 형식)
        boxes: 이미지에 대한 YOLO 형식의 바운딩 박스
        labels: 이미지의 바운딩 박스 레이블
        names: 클래스 이름 리스트 (선택사항)
    """
    cnt = 0

    image = image.astype(np.uint8)
    
    # YOLO 형식을 corners로 변환
    corners = yolo_to_corners(boxes, 512, 512)

    # 바운딩 박스 그리기
    for box, label in zip(corners, labels):
        cnt = cnt % 2
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        if names is not None:
            cv2.putText(image, names[label], (x1, y1 + 20* cnt), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
        else:
            cv2.putText(image, str(names), (x1, y1 + 20* cnt), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 1)
        cnt = cnt + 1

    # 이미지 표시
    plt.imshow(image)
    plt.axis('off')
    plt.show()

In [6]:
def plot_images_with_boxes(images, boxes_list, labels_list):
    """
    원본 이미지 2개와 CutMix된 이미지를 표시하고, 바운딩 박스를 그림.
    Args:
        images: 원본 이미지 2개와 CutMix된 이미지 (각각 PyTorch 텐서 (3, H, W) 형식)
        boxes_list: 각 이미지에 대한 YOLO 형식의 바운딩 박스 리스트
        labels_list: 각 이미지의 바운딩 박스 레이블 리스트
    """
    num_images = len(images)
    fig, axs = plt.subplots(1, num_images, figsize=(15, 5))
    cnt = 0
    for i, (image, boxes, labels) in enumerate(zip(images, boxes_list, labels_list)):
        img_height, img_width = image.shape[1], image.shape[2]
        image = (image - image.min()) / (image.max() - image.min())  # 정규화
        image = (image * 255).astype(np.uint8)  # (H, W, 3) 형식으로 변경

        # YOLO 형식을 corners로 변환
        corners = yolo_to_corners(boxes, img_width, img_height)

        # 바운딩 박스 그리기
        for box, label in zip(corners, labels):
            cnt = cnt % 2
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            if names is not None:
                cv2.putText(image, names[label], (x1, y1 + 20* cnt), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 2)
            else:
                cv2.putText(image, str(names), (x1, y1 + 20* cnt), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 1)
            cnt = cnt + 1

        axs[i].imshow(image)
        axs[i].axis('off')
    
    plt.show()


In [7]:
def is_non_overlapping(box1, box2):
    """
    두 바운딩 박스가 겹치지 않는지 확인.
    Args:
        box1, box2: (x1, y1, x2, y2) 형식의 바운딩 박스
    Returns:
        겹치지 않으면 True, 겹치면 False
    """
    # Box1이 Box2의 왼쪽에 있는 경우
    if box1[2] <= box2[0]:
        return True
    # Box1이 Box2의 오른쪽에 있는 경우
    if box1[0] >= box2[2]:
        return True
    # Box1이 Box2의 위에 있는 경우
    if box1[3] <= box2[1]:
        return True
    # Box1이 Box2의 아래에 있는 경우
    if box1[1] >= box2[3]:
        return True
    # 겹치는 경우
    return False

def not_overlapping_boxes(corners):
    """
    다른 바운딩 박스와 겹치지 않는 바운딩 박스를 반환.
    Args:
        corners: (x1, y1, x2, y2) 형식의 바운딩 박스 배열
    Returns:
        다른 바운딩 박스와 겹치지 않는 바운딩 박스 리스트
    """
    non_overlapping = []

    for i, box in enumerate(corners):
        is_non_overlapping_with_all = True
        for j, box_other in enumerate(corners):
            if i != j:
                if not is_non_overlapping(box, box_other):
                    is_non_overlapping_with_all = False
                    break
        if is_non_overlapping_with_all:
            non_overlapping.append(box)

    return np.array(non_overlapping)


In [8]:
def change_cutmix(image1, boxes1, labels1, image2, boxes2, labels2):
    h, w = image1.shape[1], image1.shape[2]

    # YOLO형식을 to corners형태로 변환
    corners1 = yolo_to_corners(boxes1, w, h)
    corners2 = yolo_to_corners(boxes2, w, h)
    
    # 2개의 이미지에서 object_name에 해당하는 인덱스를 찾음
    object_indices1 = [i for i, label in enumerate(labels1) if label == object_name]
    object_indices2 = [i for i, label in enumerate(labels2) if label == object_name]

    # 2개의 이미지에서 object_name에 해당하지 않는 인덱스를 찾음
    non_object_indices1 = [i for i, label in enumerate(labels1) if label != object_name]
    non_object_indices2 = [i for i, label in enumerate(labels2) if label != object_name]

    if len(non_object_indices1) == 0 or len(non_object_indices2) == 0:
        raise ValueError(f"Non-object regions not found in one of the images")

    # 2개의 이미지에서 object_name에 해당하는 바운딩 박스 중에서 겹치지 않는 바운딩 박스를 찾음
    not_overlapping_boxes1 = not_overlapping_boxes(corners1[non_object_indices1])
    not_overlapping_boxes2 = not_overlapping_boxes(corners2[non_object_indices2])

    if len(not_overlapping_boxes1) == 0 or len(not_overlapping_boxes2) == 0:
        raise ValueError(f"Not overlapping boxes not found in one of the images")

    # 적당한 값으로 swap 횟수를 설정
    swap_count = min(len(not_overlapping_boxes1), len(not_overlapping_boxes2))
    if swap_count >= 3:
        swap_count = 3

    past_index1 = []
    past_index2 = []

    for _ in range(swap_count):
        index1 = random.choice(range(len(not_overlapping_boxes1)))
        index2 = random.choice(range(len(not_overlapping_boxes2)))
        if index1 in past_index1 and index2 in past_index2:
            continue
        box1 = not_overlapping_boxes1[index1]
        box2 = not_overlapping_boxes2[index2]

        for i, box in enumerate(corners1):
            if np.array_equal(box, box1):
                label_index1 = i
                break

        for i, box in enumerate(corners2):
            if np.array_equal(box, box2):
                label_index2 = i
                break

        box1 = box1.astype(int)
        box2 = box2.astype(int)
        

        
        # Resize and swap 
        resize_box2 = deepcopy(image2[box2[1]:box2[3], box2[0]:box2[2], :])
        resize_box2 = cv2.resize(resize_box2, (box1[2] - box1[0], box1[3] - box1[1]))

        resize_box1 = deepcopy(image1[box1[1]:box1[3], box1[0]:box1[2], :])
        resize_box1 = cv2.resize(resize_box1, (box2[2] - box2[0], box2[3] - box2[1]))

        image2[box2[1]:box2[3], box2[0]:box2[2], :] = resize_box1
        image1[box1[1]:box1[3], box1[0]:box1[2], :] = resize_box2
        
        # 바뀐 object의 레이블 교환
        changed_label1 = labels1[label_index1]
        changed_label2 = labels2[label_index2]

        labels1[label_index1] = changed_label2
        labels2[label_index2] = changed_label1

        past_index1.append(index1)
        past_index2.append(index2)

    return image1, boxes1, labels1, image2, boxes2, labels2

In [9]:
def save_cutmix_results(image1, boxes1, labels1, image2, boxes2, labels2, img1_filename, img2_filename, txt1_filename, txt2_filename):
    """
    이미지와 바운딩 박스 정보를 파일로 저장하는 함수
    Args:
        image1, image2: 수정된 이미지 배열
        boxes1, boxes2: 수정된 바운딩 박스 (YOLO 형식)
        labels1, labels2: 수정된 레이블
        img1_filename, img2_filename: 저장할 이미지 파일 이름
        txt1_filename, txt2_filename: 저장할 텍스트 파일 이름
    """
    # 이미지 저장
    cv2.imwrite(img1_filename, cv2.cvtColor(image1, cv2.COLOR_RGB2BGR))
    cv2.imwrite(img2_filename, cv2.cvtColor(image2, cv2.COLOR_RGB2BGR))

    # 바운딩 박스 정보 저장 (YOLO 형식)
    with open(txt1_filename, 'w') as f:
        for box, label in zip(boxes1, labels1):
            box = box[0]
            x_center, y_center, width, height = box
            f.write(f"{int(label)} {x_center} {y_center} {width} {height}\n")

    with open(txt2_filename, 'w') as f:
        for box, label in zip(boxes2, labels2):
            box = box[0]
            x_center, y_center, width, height = box
            f.write(f"{int(label)} {x_center} {y_center} {width} {height}\n")

In [10]:
# cutmix data 생성
cat = "tree"
path = "/home/connet/jay_deb/DMC-Connet-Team2-Project-1/cutmix/cutmix_images"
path = os.path.join(path, cat)
save_cnt = 0
overlapping_cnt = 0
for phase in ["train", "valid"]:
    join_path = os.path.join(path, phase)
    file_names = os.listdir(join_path+"/images/")
    for file_index in range(0, len(file_names)):
        try:
            image1 = cv2.imread(join_path + "/images/" + file_names[file_index])
            image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
            boxes1 = []  # 첫 번째 이미지의 바운딩 박스 (YOLO 형식)
            labels1 = []  # 첫 번째 이미지의 클래스 레이블
            label_name = file_names[file_index].rsplit(".", 1)[0]
            with open(join_path + "/labels/" + label_name+".txt") as f:
                for line in f.readlines():
                    cn, x_center, y_center, width, height = map(float, line.strip().split())
                    box = np.array([[x_center, y_center, width, height]])
                    boxes1.append(box)
                    labels1.append(cn)
            boxes1 = np.array(boxes1)
            labels1 = np.array(labels1)

            image2 = cv2.imread(join_path + "/images/" + file_names[file_index+3])
            image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
            boxes2 = []  # 첫 번째 이미지의 바운딩 박스 (YOLO 형식)
            labels2 =[]  # 첫 번째 이미지의 클래스 레이블
            label_name = file_names[file_index+3].rsplit(".", 1)[0]
            with open(join_path + "/labels/" + label_name+".txt") as f:
                for line in f.readlines():
                    cn, x_center, y_center, width, height = map(float, line.strip().split())
                    box = np.array([[x_center, y_center, width, height]])
                    boxes2.append(box)
                    labels2.append(cn)
            boxes2 = np.array(boxes2)
            labels2 = np.array(labels2)

            mixed_image1, mixed_boxes1, mixed_labels1, mixed_image2, mixed_boxes2, mixed_labels2 = change_cutmix(image1, boxes1, labels1, image2, boxes2, labels2)
            save_cutmix_results(mixed_image1, mixed_boxes1, mixed_labels1, mixed_image2, mixed_boxes2, mixed_labels2, join_path + f"/images/cutmix_{save_cnt}.jpg", join_path+f"/images/cutmix_{save_cnt+1}.jpg", join_path+f"/labels/cutmix_{save_cnt}.txt", join_path+f"/labels/cutmix_{save_cnt+1}.txt")
            print("saved")
            save_cnt += 2
        except:
            print("overlapping")
            overlapping_cnt += 1
            continue
        

saved
overlapping
saved
saved
overlapping
overlapping
saved
saved
overlapping
saved
overlapping
saved
overlapping
overlapping
saved
overlapping
saved
saved
overlapping
saved
saved
overlapping
overlapping
saved
overlapping
overlapping
saved
saved
saved
saved
overlapping
saved
overlapping
overlapping
saved
overlapping
saved
saved
overlapping
saved
overlapping
saved
saved
overlapping
overlapping
saved
overlapping
overlapping
saved
overlapping
saved
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
saved
overlapping
overlapping
overlapping
saved
overlapping
overlapping
saved
overlapping
saved
overlapping
overlapping
saved
overlapping
overlapping
saved
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
saved
overlapping
overlapping
overlapping
saved
overlapping
overlapping
saved
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
overlapping
saved
overlappin