In [11]:
import os
import numpy as np
import json
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.structures import BoxMode
import random
import cv2
import itertools
from pathlib import Path
from torchvision import transforms
from PIL import Image
import torch

In [2]:
DATASET = Path("./Datasets/training")
IMAGES = DATASET / "images"
GROUNDTRUTH = DATASET / "groundtruth"
IMG_FORMAT = ".png"

In [13]:
def get_img_size(img_path):
    """
    returns height and width of img_path
    """
    image = Image.open(img_path)
    width, height = image.size
    return width, height

def groundtruth_to_sem_seg(img_path):
    """
    Transform a binary .jpg image to torch.tensor for sem_seg
    """
    image = Image.open(img_path)
    tensor_mask = transforms.ToTensor()(image).int().squeeze_()
    return tensor_mask

In [9]:
def get_data_dicts(img_dir):
    """
    Function to return the json dicts to detectron2
    """
    data_dicts = []
    for img_path in IMAGES.glob("**/*"+IMG_FORMAT):
        filename = img_path
        sem_seg_file_name = GROUNDTRUTH / img_path.name
        width, height = get_img_size(img_path)
        img_id = img_path.name.split("_")[1].split(".")[0]
        sem_seg = groundtruth_to_sem_seg(sem_seg_file_name)
        img_dict = {
            "filename": str(filename),
            "sem_seg_file_name": str(sem_seg_file_name),
            "sem_seg": sem_seg,
            "height": height,
            "width": width,
            "img_id": img_id,
        }
        data_dicts.append(img_dict)
    return data_dicts      

from detectron2.data import DatasetCatalog, MetadataCatalog

In [14]:
training_dir = Path("./Datasets/training/")
dataset_dicts = get_data_dicts(training_dir)

In [6]:
DatasetCatalog.register("road_training", get_data_dicts)

In [15]:
MetadataCatalog.get("road_training").set(thing_classes=["road"],
                                        stuff_classes=["roads"])
road_metadata = MetadataCatalog.get("road_training")

In [None]:
for d in random.sample(dataset_dicts, 3):
    print(d)
    img = cv2.imread(d["filename"])
    visualizer = Visualizer(img[:, :, ::-1], metadata=road_metadata, 
                            scale=0.5)
    #vis = visualizer.draw_dataset_dict(d)
    vis = visualizer.draw_sem_seg(sem_seg=d["sem_seg"])
    cv2.imshow(d["img_id"] ,vis.get_image()[:, :, ::-1])
    cv2.waitKey(0)
    cv2.destroyAllWindows()

{'filename': 'Datasets/training/images/satImage_083.png', 'sem_seg_file_name': 'Datasets/training/groundtruth/satImage_083.png', 'sem_seg': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int32), 'height': 400, 'width': 400, 'img_id': '083'}


In [13]:
road_metadata.stuff_classes[0]

'roads'

In [8]:
from PIL import Image
img_path = Path("./Datasets/training/groundtruth/satImage_001.png")
img = Image.open(img_path)

In [13]:
from torchvision import transforms
transforms.ToTensor()(img).squeeze_().shape

torch.Size([400, 400])