In [1]:
import os
import sys
import cv2
import time
import json
import random
import pickle
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint
import matplotlib.pyplot as plt
from collections import Counter

from PIL import Image
import tensorflow as tf

---

In [2]:
## 공통(모든) 함수

def get_file_names(name):
    """Build the image and label file names that share the stem `name`.

    Args:
        name: Path or file stem without an extension.

    Returns:
        A `(image_file, label_file)` tuple: `name` with ".png" and ".txt"
        appended respectively.
    """
    image_file = "{}.png".format(name)
    label_file = "{}.txt".format(name)
    return image_file, label_file

def read_image(name):
    """Load an image from disk as an array with an explicit channel axis.

    Args:
        name: Path of the image file to read.

    Returns:
        The array produced by `cv2.imread`. If that array is not
        3-dimensional, a trailing axis is appended with `np.expand_dims`.

    Raises:
        FileNotFoundError: If `cv2.imread` returns None, which it does
            (instead of raising) when the file is missing or cannot be decoded.
    """
    img = cv2.imread(name)

    # Fail loudly here rather than with an AttributeError on `None.shape`.
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {name}")

    if img.ndim == 3:
        return img
    return np.expand_dims(img, -1)

def get_size(li):
    """Return the absolute product of the two coordinate differences in `li`.

    Computes `abs((li[0] - li[1]) * (li[2] - li[3]))`. For a box stored as
    `(x1, x2, y1, y2)` -- the order the rest of this notebook uses -- that is
    the box area, regardless of the order of each coordinate pair.

    Args:
        li: Indexable sequence with at least four numeric entries.

    Returns:
        The non-negative product described above.
    """
    first_extent = li[0] - li[1]
    second_extent = li[2] - li[3]
    return abs(first_extent * second_extent)

def read_labels(name, bound=32*32):
    """Read box labels from a text file, keeping only boxes up to `bound` in area.

    Every non-blank line must hold five whitespace-separated fields: a class
    label followed by four coordinates. Coordinates are parsed as floats and
    truncated to int, and are kept in file order (the rest of this notebook
    treats that order as x1, x2, y1, y2).

    Args:
        name: Path of the label file.
        bound: Largest area, as computed by `get_size`, a box may have to be
            kept. Larger boxes are dropped.

    Returns:
        A list of `(label, c1, c2, c3, c4)` tuples with integer coordinates.

    Raises:
        ValueError: If a non-blank line does not have exactly five fields or
            a coordinate is not numeric.
    """
    labels = []

    with open(name, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at end of file) carry
                # no label; skipping them avoids a spurious unpack error.
                continue

            label, c1, c2, c3, c4 = fields
            numbers = [int(float(value)) for value in (c1, c2, c3, c4)]

            if get_size(numbers) <= bound:
                labels.append((label, *numbers))

    return labels

def get_image(image, h, w, reshape):
    """Cut the `reshape` x `reshape` window that ends at row `h`, column `w`.

    Args:
        image: Array indexable as `[rows, cols, channels]`.
        h: Exclusive end row of the window.
        w: Exclusive end column of the window.
        reshape: Side length of the window.

    Returns:
        `image[h-reshape:h, w-reshape:w, :]`.
    """
    rows = slice(h - reshape, h)
    cols = slice(w - reshape, w)
    return image[rows, cols, :]

def get_minmax(n1, n2):
    """Return the two arguments ordered as `(smaller, larger)`.

    Args:
        n1: First value.
        n2: Second value.

    Returns:
        A `(minimum, maximum)` tuple. When the values compare equal, `n1` is
        used for both entries.
    """
    smaller = n2 if n2 < n1 else n1
    larger = n2 if n2 > n1 else n1
    return smaller, larger

def get_iou(y, p):
    """Compute the intersection-over-union of two axis-aligned boxes.

    Both boxes are sequences laid out as `(x1, x2, y1, y2)`, the order used
    by the callers in this notebook. The two x values (and the two y values)
    of a box may be given in either order.

    The previous implementation took `abs()` of the overlap extents and
    divided by the area of the box enclosing both inputs. That equals the
    IoU only when one box lies inside the other, and it reported a positive
    overlap for disjoint boxes. This version computes the real intersection
    and union; for nested boxes the result is unchanged.

    Args:
        y: First box.
        p: Second box.

    Returns:
        `intersection / union` as a float; 0 when the boxes do not overlap;
        1 when the boxes touch or overlap but the union has zero area
        (degenerate boxes), which keeps the earlier zero-division fallback.
    """
    # Normalise each box so that lo <= hi on both axes.
    yx_lo, yx_hi = min(y[0], y[1]), max(y[0], y[1])
    yy_lo, yy_hi = min(y[2], y[3]), max(y[2], y[3])
    px_lo, px_hi = min(p[0], p[1]), max(p[0], p[1])
    py_lo, py_hi = min(p[2], p[3]), max(p[2], p[3])

    overlap_w = min(yx_hi, px_hi) - max(yx_lo, px_lo)
    overlap_h = min(yy_hi, py_hi) - max(yy_lo, py_lo)

    # A negative extent on either axis means the boxes are disjoint.
    if overlap_w < 0 or overlap_h < 0:
        return 0

    inner = overlap_w * overlap_h
    area_y = (yx_hi - yx_lo) * (yy_hi - yy_lo)
    area_p = (px_hi - px_lo) * (py_hi - py_lo)
    union = area_y + area_p - inner

    if union == 0:
        # Both boxes are zero-area but not disjoint: treat as full overlap.
        return 1
    return inner / union

def max0(x):
    """Clamp `x` from below at zero: negative values become 0, others pass through."""
    if 0 > x:
        return 0
    return x

def get_labels(labels, h, w, reshape, iou_threshold=None):
    """Select the labels visible in a crop window and express them in crop coordinates.

    The window is the `reshape` x `reshape` square spanning rows
    `h - reshape .. h` and columns `w - reshape .. w`, i.e. the same region
    that `get_image` cuts out.

    A label is kept when the part of its box inside the window, compared with
    the whole box via `get_iou`, reaches the threshold. Kept boxes are clipped
    to the window on all four sides. Previously only the lower edges were
    clipped, so a box crossing the right or bottom border was emitted with
    coordinates beyond the cropped image.

    Args:
        labels: Iterable of `(gt, x1, x2, y1, y2)` entries in full-image
            coordinates.
        h: Exclusive end row of the window.
        w: Exclusive end column of the window.
        reshape: Side length of the window.
        iou_threshold: Minimum `get_iou` score for a label to be kept. When
            None (the default) the module-level `IOU_THRESHOLD` is used.

    Returns:
        A list of `[gt, x1, x2, y1, y2]` lists whose coordinates are relative
        to the window's top-left corner.
    """
    if iou_threshold is None:
        iou_threshold = IOU_THRESHOLD

    start_y, end_y = h - reshape, h
    start_x, end_x = w - reshape, w
    selects = []

    for label in labels:
        gt, x1, x2, y1, y2 = label

        # Skip boxes lying entirely outside the window.
        if x1 > end_x or y1 > end_y or x2 < start_x or y2 < start_y:
            continue

        # Portion of the box that falls inside the window.
        clip_x1, clip_x2 = max(x1, start_x), min(x2, end_x)
        clip_y1, clip_y2 = max(y1, start_y), min(y2, end_y)

        iou = get_iou(
            [x1, x2, y1, y2],
            [clip_x1, clip_x2, clip_y1, clip_y2]
        )

        if iou >= iou_threshold:
            selects.append([
                gt,
                clip_x1 - start_x, clip_x2 - start_x,
                clip_y1 - start_y, clip_y2 - start_y
            ])

    return selects

def iter_resizing(image, labels, reshape=224, stride=112):
    """Slide a square window over `image` and yield each crop with its labels.

    Windows are `reshape` x `reshape` and advance by `stride` along both axes,
    starting at the top-left corner. The window whose bottom/right edge lands
    exactly on the image border is included; the previous `range` bounds
    excluded it, and so produced nothing for an image exactly `reshape` in
    size. A leftover strip narrower than one stride at the right or bottom is
    not covered.

    Args:
        image: Array whose first two axes are height and width.
        labels: Labels in full-image coordinates, forwarded to `get_labels`.
        reshape: Side length of each window.
        stride: Step between consecutive windows.

    Yields:
        `(crop, crop_labels)` pairs produced by `get_image` and `get_labels`.
    """
    height, width = image.shape[:2]

    # `h` / `w` are the exclusive end row / column of each window, so the
    # range must run up to and including the image size.
    for h in range(reshape, height + 1, stride):
        for w in range(reshape, width + 1, stride):
            yield get_image(image, h, w, reshape), get_labels(labels, h, w, reshape)

def save_data(image, labels, index):
    """Write one crop and its labels to `FOLDER` as `<index>.png` / `<index>.txt`.

    The index is zero-padded to six digits. The label file holds one label
    per line, with the label's fields separated by single spaces and no
    trailing newline.

    Args:
        image: Image array passed to `cv2.imwrite`.
        labels: Iterable of label sequences; each field is written via `str`.
        index: Integer used to build the file names.

    Raises:
        OSError: If `cv2.imwrite` reports that the image was not written.
            The label file is not created in that case, so no orphan label
            is left behind.
    """
    stem = f"{FOLDER}/{index:06}"

    # cv2.imwrite signals failure through its boolean return value.
    if not cv2.imwrite(f"{stem}.png", image):
        raise OSError(f"cv2.imwrite failed to write image: {stem}.png")

    with open(f"{stem}.txt", 'w') as f:
        f.write("\n".join(" ".join(map(str, label)) for label in labels))
        
def visual(file, color=(255,0,0), thickness=2):
    """Display one saved sample with its label boxes drawn on top.

    Prints `file`, loads `FOLDER/<file>.png` and `FOLDER/<file>.txt`, draws a
    rectangle for every label returned by `read_labels`, pretty-prints how
    often each class occurs, then shows the image in an OpenCV window and
    blocks until a key is pressed.

    Args:
        file: File stem (without extension) inside `FOLDER`.
        color: Rectangle colour tuple passed to `cv2.rectangle`.
        thickness: Rectangle line thickness passed to `cv2.rectangle`.
    """
    print(file)
    classes = list()
    image_name, label_name = get_file_names(f"{FOLDER}/{file}")
    image = read_image(image_name)
    labels = read_labels(label_name)

    for label in labels:
        gt = label[0]
        # Labels are stored as (class, x1, x2, y1, y2).
        x1, x2, y1, y2 = [int(value) for value in label[1:]]
        classes.append(gt)
        top_left, bottom_right = (x1, y1), (x2, y2)
        image = cv2.rectangle(image, top_left, bottom_right, color, thickness)

    class_counts = dict(Counter(classes))
    pprint(class_counts)
    cv2.imshow(file, image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

---

## Data Save
- labels 구성 : '정답, x1, x2, y1, y2'로 이루어져 있다.

In [3]:
# Minimum get_iou() score a label needs to be kept in a crop.
IOU_THRESHOLD = 0.5

# Directory holding the source image / label pairs.
load_path = os.path.join(os.getcwd(), 'save_dota_aug_train_change')

# Output folder for the generated crops, stamped with today's date.
FOLDER = f"data-{datetime.date.today()}"

os.makedirs(FOLDER, exist_ok=True)

In [None]:
# Stems (path without extension) of every .png in load_path that has a
# matching .txt label file. Filtering on the extension keeps stray directory
# entries (checkpoint folders, hidden files, ...) from being turned into
# bogus stems that read_image() cannot open.
file_list = sorted({
    os.path.join(load_path, stem)
    for stem, ext in map(os.path.splitext, os.listdir(load_path))
    if ext == ".png" and os.path.isfile(os.path.join(load_path, stem + ".txt"))
})

index = 0  # running number used to name the saved crops
label_sizes = list()  # not filled in this cell; kept in case another cell reads it

# Wrapping the list itself (not enumerate(...)) lets tqdm show the total.
for name in tqdm(file_list):
    image_name, label_name = get_file_names(name)
    image = read_image(image_name)
    labels = read_labels(label_name)

    for re_img, re_label in iter_resizing(image, labels, 224, 112):
        # Only crops containing at least one label are written out.
        if re_label:
            save_data(re_img, re_label, index)
            index += 1

# Train data: 3528 items in total
# At 2-10 items per second (about 5 on average) that is roughly 700 s, i.e. about 12 minutes

1127it [05:13,  5.36it/s]

----

## Data Visualize

In [None]:
# Show a range of saved crops from a fixed, previously generated dataset folder.
# NOTE: this rebinds the global FOLDER that save_data() and visual() read, so
# any cell run after this one uses this folder instead of the dated one.
FOLDER = "data-2021-09-11"
start, cnt = 0, 10  # first crop index and number of crops to display
for i in range(start, start + cnt):
    visual(f"{i:06}")  # crop files are named by their zero-padded 6-digit index

In [None]:
# 수정 전 기존 파일 확인
## FOLDER = "save_dota_aug_train_change"
## FILE = "P0002_augment"
## visual(FILE)

---

In [None]:
# 모델 메모리 예측?

# def s(x):
#     return x ** 2

# def cal_memory(x, byte=8):
#     return (
#         s(x) * 3 + 
#         s(x) * 64 *2 +
#         s(x/2) * 128*2 +
#         s(x/4) * 256*3 +
#         s(x/8) * 512*3 +
#         s(x/16) * 512*3 +
        
#         s(x*(9/8))*(256+512+512) + 
#         s(x/8)*2 +
        
#         s(x/8) * (9*1+9*4) +
#         s(7)*256 +
#         4096 + 4096 + 4096
#         ) * byte

# def b2gb(x):
#     return x / 1024 / 1024 / 1024