In [7]:
import os
import torch
import shutil
from PIL import Image
from tqdm import tqdm
from dataset import YoloTrainConvert

def get_image_shape(image_name, image_dir):
    """Return the ``(width, height)`` in pixels of an image on disk.

    Args:
        image_name: File name of the image, relative to ``image_dir``.
        image_dir: Directory that contains the image.

    Returns:
        A ``(width, height)`` tuple taken from the image's ``size``.
    """
    # Use the image as a context manager so the underlying file handle is
    # released right away; this function is called once per dataset item and
    # would otherwise leave one open file behind per call.
    with Image.open(os.path.join(image_dir, image_name)) as image:
        width, height = image.size
    return width, height

def convert_box(box, image_width, image_height):
    """Convert a corner-format box into a normalised centre-format box.

    Args:
        box: Sequence of four numbers ``(x1, y1, x2, y2)``.
        image_width: Width used to normalise the x centre and box width.
        image_height: Height used to normalise the y centre and box height.

    Returns:
        A list ``[x_center, y_center, width, height]`` where each value has
        been divided by the matching image dimension.
    """
    assert len(box) == 4
    x1, y1, x2, y2 = box
    # True division: flooring the centre (``// 2``) shifts it by up to half a
    # pixel towards the origin before normalisation, which skews the label.
    x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
    width, height = (x2 - x1), (y2 - y1)
    return [
        x_center / image_width,
        y_center / image_height,
        width / image_width,
        height / image_height,
    ]

def get_label(target, image_width, image_height):
    """Build one label row per annotated box in ``target``.

    Args:
        target: Mapping with a ``'boxes'`` entry and a ``'labels'`` entry;
            the two are iterated in lockstep.
        image_width: Image width forwarded to ``convert_box``.
        image_height: Image height forwarded to ``convert_box``.

    Returns:
        A list of rows, each being ``[label]`` followed by the four values
        returned by ``convert_box`` for the matching box.
    """
    pairs = zip(target['boxes'], target['labels'])
    return [
        [cls] + convert_box(bbox, image_width, image_height)
        for bbox, cls in pairs
    ]

def write_rows(file_name, file_dir, rows):
    """Write ``rows`` to a text file, one space-separated row per line.

    Args:
        file_name: Name of the file to create (or overwrite).
        file_dir: Directory the file is written into; it must already exist.
        rows: Iterable of rows; every item of a row is converted with
            ``str`` and the items are joined with single spaces.
    """
    # The ``with`` block closes the file even if converting or writing a row
    # raises, which the manual open()/close() pair did not guarantee.
    with open(os.path.join(file_dir, file_name), 'w') as f:
        f.writelines(' '.join(map(str, row)) + '\n' for row in rows)

def copy_image(source_path, target_dir):
    """Copy the file at ``source_path`` to ``target_dir`` via ``shutil.copy2``.

    Args:
        source_path: Path of the file to copy.
        target_dir: Destination passed to ``shutil.copy2``.
    """
    shutil.copy2(source_path, target_dir)

# Input annotation file and image directory, plus the output image/label
# directories, all relative to the current working directory.
mat_path = os.path.join('..', 'data', "train answer", 'digitStruct.mat')
print(f'mat_path: {mat_path}')

source_image_dir = os.path.join('..', 'data', 'train')
print(f'source_image_dir: {source_image_dir}')

target_image_dir = os.path.join('..', 'data', 'yolov5_train', 'images')
print(f'target_image_dir: {target_image_dir}')

target_label_dir = os.path.join('..', 'data', 'yolov5_train', 'labels')
print(f'target_label_dir: {target_label_dir}')

# Make sure both output directories exist before writing. Without this,
# opening a label file fails with FileNotFoundError, and shutil.copy2 would
# treat a missing ``images`` directory as a destination *file name* and
# overwrite that single file on every iteration.
os.makedirs(target_image_dir, exist_ok=True)
os.makedirs(target_label_dir, exist_ok=True)

dataset = YoloTrainConvert(mat_path)

pbar = tqdm(dataset)

# For every (target, image_name) item: write the label file, then copy the
# image next to it in the output tree.
for target, image_name in pbar:
    image_width, image_height = get_image_shape(image_name, source_image_dir)
    rows = get_label(target, image_width, image_height)
    # Strip only the final extension so a name such as ``a.b.png`` maps to
    # ``a.b.txt`` instead of being truncated at the first dot.
    label_name = os.path.splitext(image_name)[0] + '.txt'
    write_rows(label_name, target_label_dir, rows)
    source_image_path = os.path.join(source_image_dir, image_name)
    copy_image(source_image_path, target_image_dir)

mat_path: ..\data\train answer\digitStruct.mat
source_image_dir: ..\data\train
target_image_dir: ..\data\yolov5_train\images
target_label_dir: ..\data\yolov5_train\labels


100%|██████████| 33402/33402 [02:59<00:00, 185.92it/s]
