In [23]:
from scalabel.label.transforms import poly_to_patch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import json
from pathlib import Path
import numpy as np
from tqdm import tqdm
import os

# Dataset split and folder layout.
mode = 'train' # 'val' or 'train'
img_width, img_height = 1280, 720  # expected size of every image, in pixels

root_folder = Path('BDDDataset')
images_folder = root_folder.joinpath('images', mode)
labels_folder = root_folder.joinpath('labels')
# Converted labels are written under labels/<task>/<mode>.
det_folder = labels_folder.joinpath('det', mode)
seg_folder = labels_folder.joinpath('seg', mode)

# Output folders are created on demand; existing ones are left untouched.
det_folder.mkdir(exist_ok=True, parents=True)
seg_folder.mkdir(exist_ok=True, parents=True)

# Detection categories exported to the YOLO label files. The class id of each
# category is its position in this list, starting at 0.
detection_classes = {
    category: class_id
    for class_id, category in enumerate([
        'pedestrian',
        'rider',
        'car',
        'truck',
        'bus',
        'train',
        'motorcycle',
        'bicycle',
        'traffic light',
        'traffic sign',
    ])
}

# Load the labels (both detection and segmentation annotations) for every image
# of the selected split. JSON text is UTF-8, so decode it explicitly instead of
# relying on the platform's default encoding.
with open(labels_folder / f'{mode}.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Delete image files on disk that have no entry in the label file, so the
# image folder only contains annotated images.
img_names = [img['name'] for img in data]
# Membership is tested once per file on disk; a set keeps each test O(1)
# instead of scanning the ~70k-entry list every time.
known_names = set(img_names)
actual_imgs = [entry.name for entry in images_folder.iterdir()]
remove_imgs = [
    name for name in actual_imgs
    # Only plain files are candidates; os.remove cannot delete a directory.
    if name not in known_names and (images_folder / name).is_file()
]
for img in remove_imgs:
    os.remove(images_folder / img)

def _write_yolo_labels(labels, out_path):
    """Write the box2d annotations in *labels* to *out_path* in YOLO format.

    One line is written per object whose category is a key of
    ``detection_classes``: ``<class_id> <x_center> <y_center> <width> <height>``.
    The four coordinates are normalized by ``img_width`` / ``img_height`` and
    printed with 4 decimals. Objects of other categories are skipped, and an
    image with no matching objects still gets an (empty) file.
    """
    lines = []
    for label in labels:
        class_id = detection_classes.get(label['category'])
        if class_id is None:
            continue
        box = label['box2d']
        # Pixel corner coordinates -> normalized center / size.
        x_center = (box['x1'] + box['x2']) / (2 * img_width)
        y_center = (box['y1'] + box['y2']) / (2 * img_height)
        width = (box['x2'] - box['x1']) / img_width
        height = (box['y2'] - box['y1']) / img_height
        lines.append(f"{class_id} {x_center:.4f} {y_center:.4f} {width:.4f} {height:.4f}\n")
    with open(out_path, 'w') as f:
        f.writelines(lines)


def _save_drivable_mask(labels, out_path):
    """Render the 'drivable area' polygons in *labels* to an image at *out_path*.

    Each poly2d polygon is drawn in white on a black img_width x img_height
    canvas. scalabel's ``poly_to_patch`` reads the polygon's ``types`` string to
    connect vertices with straight segments or Bezier curves. matplotlib picks
    the output format from the suffix of *out_path*.
    """
    fig = plt.figure(facecolor="0")  # "0" is black in matplotlib's grayscale-string notation
    try:
        dpi = fig.get_dpi()
        # Size the figure so the saved image is img_width x img_height pixels.
        fig.set_size_inches(img_width / dpi, img_height / dpi)
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis("off")
        ax.set_xlim(0, img_width)
        ax.set_ylim(0, img_height)
        # Transparent axes let the black figure background show through.
        ax.set_facecolor((0, 0, 0, 0))
        # Image coordinates have their origin at the top-left corner.
        ax.invert_yaxis()
        for label in labels:
            if label['category'] != 'drivable area':
                continue
            for poly in label['poly2d']:
                polygon = [(vertex[0], vertex[1]) for vertex in poly['vertices']]
                ax.add_patch(poly_to_patch(polygon, types=poly['types'], color=(1.0, 1.0, 1.0), closed=True))
        # Pass the background explicitly: older matplotlib releases default the
        # savefig background to white, which would whiten the whole mask.
        fig.savefig(out_path, facecolor=fig.get_facecolor())
    finally:
        # Figures stay registered with pyplot until closed; without this, memory
        # grows with every image in the split.
        plt.close(fig)


for img in tqdm(data):
    img_name = img['name']
    img_path = images_folder / img_name

    # Skip images that are not in the subset of the dataset we selected
    if not img_path.exists():
        continue

    # Guard against entries whose 'labels' is missing or null: treat them as
    # images with no annotations.
    labels = img.get('labels') or []

    # Detection labels: one YOLO .txt file per image.
    _write_yolo_labels(labels, det_folder / Path(img_name).with_suffix('.txt'))

    # Segmentation labels: drivable-area mask saved under the image's own name.
    # NOTE(review): with a .jpg name, matplotlib writes a lossy JPEG, so mask
    # edges get compression artifacts. A .png would keep the mask exact, but
    # downstream code may expect this path; confirm before changing it.
    _save_drivable_mask(labels, seg_folder / img_name)

100%|██████████| 69863/69863 [03:09<00:00, 368.88it/s]
