In [1]:
import random
import matplotlib.pyplot as plt
import cv2

from detectron2.data import DatasetCatalog, MetadataCatalog, get_detection_dataset_dicts, print_instances_class_histogram
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.utils.logger import setup_logger

from data_utils import read_split_file, register_dataset

setup_logger()
%matplotlib inline

### Check Mixed

In [2]:
# Read the image paths of each split; the result is consumed positionally as
# (train, val, test) by the zip below.
mixed_sets = read_split_file("data/panels/mixed/split.txt")

# Register mixed datasets
for spl, im_paths in zip(["train", "val", "test"], mixed_sets):
    dataset_name = f"mixed_{spl}"
    # DatasetCatalog.register refuses to register a name twice, which makes
    # re-running this cell in a live kernel fail. Skip names that are already
    # registered so the cell is idempotent. If the split file changed, restart
    # the kernel so the loaders are rebuilt from the new paths.
    if dataset_name not in DatasetCatalog.list():
        DatasetCatalog.register(
            # Bind im_paths as a default argument so each loader keeps its own
            # split instead of the loop's final value.
            dataset_name, lambda im_paths=im_paths: register_dataset(im_paths)
        )
    MetadataCatalog.get(dataset_name).set(
        thing_classes=["label", "button"],
        thing_colors=[(0, 255, 0), (0, 0, 255)],
    )

In [3]:
# Load the three registered "mixed" splits in train/val/test order, keeping
# images that have no annotations (filter_empty=False).
trainset, valset, testset = (
    get_detection_dataset_dicts(split_name, filter_empty=False)
    for split_name in ("mixed_train", "mixed_val", "mixed_test")
)

# Combined view over all splits, plus a quick size check.
dataset = trainset + valset + testset
print(len(dataset))

registering mixed dataset: 100%|██████████| 76/76 [00:03<00:00, 21.53it/s]

[32m[02/28 23:25:38 d2.data.build]: [0mDistribution of instances among all 2 categories:
[36m|  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|
|   label    | 948          |   button   | 948          |
|            |              |            |              |
|   total    | 1896         |            |              |[0m



registering mixed dataset: 100%|██████████| 10/10 [00:00<00:00, 18.70it/s]

[32m[02/28 23:25:38 d2.data.build]: [0mDistribution of instances among all 2 categories:
[36m|  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|
|   label    | 169          |   button   | 169          |
|            |              |            |              |
|   total    | 338          |            |              |[0m



registering mixed dataset: 100%|██████████| 22/22 [00:00<00:00, 26.68it/s]

[32m[02/28 23:25:39 d2.data.build]: [0mDistribution of instances among all 2 categories:
[36m|  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|
|   label    | 423          |   button   | 423          |
|            |              |            |              |
|   total    | 846          |            |              |[0m
108





### Check the UT-West-Campus Dataset

In [4]:
# Read the image paths of each split; the result is consumed positionally as
# (train, val, test) by the zip below.
ut_west_campus_sets = read_split_file("data/panels/ut_west_campus/split.txt")

# Register ut_west_campus datasets
for spl, im_paths in zip(["train", "val", "test"], ut_west_campus_sets):
    dataset_name = f"ut_west_campus_{spl}"
    # DatasetCatalog.register refuses to register a name twice, which makes
    # re-running this cell in a live kernel fail. Skip names that are already
    # registered so the cell is idempotent. If the split file changed, restart
    # the kernel so the loaders are rebuilt from the new paths.
    if dataset_name not in DatasetCatalog.list():
        DatasetCatalog.register(
            # Bind im_paths as a default argument so each loader keeps its own
            # split instead of the loop's final value.
            dataset_name, lambda im_paths=im_paths: register_dataset(im_paths)
        )
    MetadataCatalog.get(dataset_name).set(
        thing_classes=["label", "button"],
        thing_colors=[(0, 255, 0), (0, 0, 255)],
    )

In [5]:
# Load the three registered "ut_west_campus" splits in train/val/test order,
# keeping images that have no annotations (filter_empty=False).
# Note: this rebinds trainset/valset/testset/dataset from the "mixed" cell above.
trainset, valset, testset = (
    get_detection_dataset_dicts(split_name, filter_empty=False)
    for split_name in (
        "ut_west_campus_train",
        "ut_west_campus_val",
        "ut_west_campus_test",
    )
)

# Combined view over all splits, plus a quick size check.
dataset = trainset + valset + testset
print(len(dataset))

registering ut_west_campus dataset: 100%|██████████| 192/192 [00:41<00:00,  4.64it/s]

[32m[02/28 23:26:36 d2.data.build]: [0mDistribution of instances among all 2 categories:
[36m|  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|
|   label    | 2310         |   button   | 2310         |
|            |              |            |              |
|   total    | 4620         |            |              |[0m



registering ut_west_campus dataset: 100%|██████████| 12/12 [00:03<00:00,  3.86it/s]

[32m[02/28 23:26:40 d2.data.build]: [0mDistribution of instances among all 2 categories:
[36m|  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|
|   label    | 174          |   button   | 174          |
|            |              |            |              |
|   total    | 348          |            |              |[0m



registering ut_west_campus dataset: 100%|██████████| 88/88 [00:18<00:00,  4.72it/s]

[32m[02/28 23:26:58 d2.data.build]: [0mDistribution of instances among all 2 categories:
[36m|  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|
|   label    | 982          |   button   | 982          |
|            |              |            |              |
|   total    | 1964         |            |              |[0m
292





Use the function below to generate the tables if they are not showing up for you.

In [None]:
# Log the per-class instance counts for the combined train/val/test records.
print_instances_class_histogram(dataset_dicts=dataset, class_names=["label", "button"])

In [None]:
# Log the per-class instance counts for the test split only.
print_instances_class_histogram(dataset_dicts=testset, class_names=["label", "button"])

In [None]:
# Draw up to 20 random training records for visual inspection. The size is
# capped at the split size because random.sample raises ValueError when asked
# for more items than the population holds.
# The sample is unseeded, so a different subset is drawn on every run.
sampled_dicts = random.sample(trainset, min(20, len(trainset)))
if not sampled_dicts:
    raise ValueError("trainset is empty; nothing to sample")

# Preview the first sampled image.
d = sampled_dicts[0]
og_img = cv2.imread(d["file_name"])
if og_img is None:
    # cv2.imread signals an unreadable/missing file by returning None
    # instead of raising, so fail here with the offending path.
    raise FileNotFoundError(f"could not read image: {d['file_name']}")
# OpenCV loads images as BGR; reverse the channel axis to get RGB for matplotlib.
plt.imshow(og_img[:, :, ::-1])

In [None]:
def get_optimal_font_scale(text, width, font_face=None, thickness=1):
    """Return the largest font scale at which ``text`` is at most ``width`` pixels wide.

    Candidate scales are tried from 5.9 down to 0.0 in steps of 0.1, and the
    text width at each scale is measured with ``cv2.getTextSize``.

    Args:
        text: String to measure.
        width: Maximum allowed text width in pixels.
        font_face: OpenCV font used for the measurement. Defaults to
            ``cv2.FONT_HERSHEY_DUPLEX``. Pass the same font that will be used
            for drawing so the measurement matches the rendered text.
        thickness: Stroke thickness used for the measurement. Defaults to 1.

    Returns:
        The first (largest) candidate scale whose measured width fits, or 1
        if no candidate fits.
    """
    # Resolve the default at call time so the signature does not depend on cv2
    # being imported when the function is defined.
    if font_face is None:
        font_face = cv2.FONT_HERSHEY_DUPLEX

    # NOTE(review): the smallest candidate is a scale of 0.0, so a zero scale
    # can be returned when nothing larger fits -- confirm callers handle that.
    for tenths in range(59, -1, -1):
        scale = tenths / 10
        (text_width, _text_height), _baseline = cv2.getTextSize(
            text, fontFace=font_face, fontScale=scale, thickness=thickness
        )
        if text_width <= width:
            return scale
    return 1

In [None]:
import json
import os

from data_utils import generate_bbox, generate_gt_mask_coords

# Readable names for button annotations that are stored as symbols.
symbol_map = {"<|>": "open", ">|<": "close", "^": "alarm", "&": "call", "#": "stop"}

# NOTE(review): the annotation file and the metadata below are hard-coded to the
# "mixed" dataset, while `sampled_dicts` comes from whichever `trainset` was
# loaded last -- confirm the two refer to the same images.
# Use a context manager so the annotation file is closed after parsing.
with open("data/panels/mixed/annotations.json") as anno_file:
    annos = json.load(anno_file)

# cv2.imwrite does not create missing directories, so make sure the output
# folder exists before the loop starts writing into it.
os.makedirs("big_graphic", exist_ok=True)

for d in sampled_dicts:
    print(d["file_name"])
    img = cv2.imread(d["file_name"])
    height, width = img.shape[:2]
    # The visualizer gets the channel-reversed (RGB) image; everything drawn on
    # `out` below is therefore in RGB order until it is flipped back for saving.
    visualizer = Visualizer(
        img[:, :, ::-1],
        metadata=MetadataCatalog.get("mixed_train"),
        scale=1,
        instance_mode=ColorMode.SEGMENTATION,
    )
    out = visualizer.draw_dataset_dict(d)
    out = out.get_image()
    # Annotations are keyed by the image's base file name.
    img_dict = annos[os.path.basename(d["file_name"])]
    for r in img_dict["regions"]:
        pair = r["region_attributes"].get("pair")
        # Regions without a pair value have no text to overlay.
        if pair is None or pair == "":
            continue

        pair = pair.rstrip().lower()
        # Only buttons get their pair text drawn. `pair` is already known to be
        # non-empty here, so no further check on it is needed.
        if r["region_attributes"]["category_id"] != "button":
            continue

        # presumably bbox is (x0, y0, x1, y1) -- TODO confirm against data_utils
        bbox = generate_bbox(*generate_gt_mask_coords(r, height, width), height, width)
        cv2.putText(
            out,
            # Show the readable name for symbol annotations, otherwise the raw text.
            text=symbol_map.get(pair, pair),
            # Slightly inside the left edge, vertically centred between bbox[1] and bbox[3].
            org=(int(bbox[0]) + 5, int((bbox[1] + bbox[3]) / 2)),
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=1,
            color=(255, 0, 0),
            thickness=2,
        )

    # Reverse the channels back to OpenCV's BGR order before saving.
    cv2.imwrite(f"big_graphic/{os.path.basename(d['file_name'])}", out[:, :, ::-1])
    # plt.imshow(out.get_image())
    # plt.axis("off")
    # plt.show()