# Loading Libraries

In [14]:
import json
import os
from itertools import combinations

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

from shapely.geometry import Polygon as ShapelyPolygon, mapping
from shapely.validation import make_valid

%matplotlib inline

# Functions

## Combine

In [15]:
def change(txt: list[str] | str) -> str:
    if isinstance(txt, str):
        return txt.replace(" ", "_")
    return combine(txt[0], txt[1])

def combine(t1, t2):
    return change(t1) + "-" + change(t2)

## Make Valid

In [16]:
def ensure_validity(poly: ShapelyPolygon) -> ShapelyPolygon:
    return poly if poly.is_valid else make_valid(poly)

# Main Class

In [17]:
class DataIterator:
    def __init__(self, data: list, host: list[str] | str, target: str, ratio: float = 0):
        self.data = data
        self.host = host
        self.target = target
        self.ratio = ratio
        
        self.counter = {}
        self.perm = ""

In [18]:
class CompositeIterator(DataIterator):
    def __init__(self, data: list, host: list[str] | str, target: str, ratio: float = 0):
        super().__init__(data, host, target, ratio)
        self.class_coordinates = {}
        self.host_counter = 0
        self.run()
    
    def __count_combination(self):
        for c1 in self.class_coordinates[self.host[0]]:
            if len(c1) <= 3:
                continue
            
            h1_poly = ensure_validity(ShapelyPolygon(c1))
            
            if mapping(h1_poly)["type"] != "Polygon":
                continue
            
            for c2 in self.class_coordinates[self.host[1]]:
                if len(c2) <= 3:
                    continue
                
                h2_poly = ensure_validity(ShapelyPolygon(c2))
                
                if mapping(h2_poly)["type"] != "Polygon":
                    continue
                
                if h1_poly.intersects(h2_poly) or h1_poly.touches(h2_poly):
                    combined_poly = h1_poly.union(h2_poly)
                    
                    try:
                        coordinates = list(mapping(combined_poly)["coordinates"][0])
                    except KeyError:
                        continue
                    
                    if len(coordinates) < 2:
                        continue
                    
                    self.host_counter += 1
    
    def __separate_labels(self):
        for image in self.data:
            self.class_coordinates = {k: [] for k in self.host}
            
            for label in image["labels"]:
                if label["category"] in self.host:
                    self.class_coordinates[label["category"]].append(label["coordinates"])
            
            self.__count_combination()
    
    def run(self):
        self.__separate_labels()
    
    def get_counter(self):
        return self.host_counter
    
    def __generate_perm(self):
        np.random.seed(1337)
        self.perm = np.zeros(self.host_counter, dtype=np.uint8)
        ones = round(self.host_counter * self.ratio)
        self.perm[:ones] = 1
        
        np.random.shuffle(self.perm)
        
    
    def get_perm(self):
        self.__generate_perm()
        return self.perm

In [19]:
class Selector:
    def __init__(self, root: str):
        self.root = root
        self.method = "train"
        self.data = []
        self.categories = []

In [23]:
class BDDSelector(Selector):
    def __init__(self, root: str):
        super().__init__(root)        
        self.run()
    
    def __load_data(self):
        print("Loading data", end="...")
        with open(os.path.join(self.root, "labels", "sem_seg", "polygons", "sem_seg_train.json"), "r") as f:
            self.data = json.load(f)
        print("✅")
    
    def __load_categories(self):
        for row in self.data:
            for label in row["labels"]:
                if label["category"] not in self.categories:
                    self.categories.append(label["category"])
        self.categories.sort()
        self.categories = {v: i for i, v in enumerate(self.categories)}
    
    def __fix_data(self):
        print("Fixing data", end="...")
        output = []
        for row in self.data:
            img_info = {
                "name": row["name"],
                "width": 1280,
                "height": 720,
                "labels": []
            }
            
            for label in row["labels"]:
                img_info["labels"].append({
                    "category": label["category"],
                    "coordinates": label["poly2d"][0]["vertices"]
                })
            
            output.append(img_info)
        
        self.data = output
        print("✅")
    
    def run(self):
        self.__load_data()
        self.__load_categories()
        self.__fix_data()
        
        for pair in combinations(self.categories, 2):
            print(combine(*pair), end="\t")
            data_iterator = CompositeIterator(self.data, list(pair), "")
            print(data_iterator.get_counter())

In [None]:
BDDSelector(r"D:\datasets\bdd100k")

Loading data...✅
Fixing data...✅
banner-bicycle	3
banner-billboard	30
banner-bridge	69
banner-building	1678
banner-bus	15
banner-bus_stop	0
banner-car	119
banner-caravan	4
banner-dynamic	6
banner-ego_vehicle	17
banner-fence	21
banner-fire_hydrant	2
banner-garage	1
banner-ground	11
banner-guard_rail	14
banner-lane_divider	6
banner-mail_box	0
banner-motorcycle	0
banner-parking	2
banner-parking_sign	8
banner-person	14
banner-pole	1613
banner-polegroup	4
banner-rail_track	0
banner-rider	0
banner-road	41
banner-sidewalk	43
banner-sky	1080
banner-static	247
banner-street_light	53
banner-terrain	32
banner-traffic_cone	4
banner-traffic_device	16
banner-traffic_light	62
banner-traffic_sign	111
banner-traffic_sign_frame	27
banner-trailer	0
banner-train	0
banner-trash_can	0
banner-truck	24
banner-tunnel	2
banner-unlabeled	0
banner-vegetation	922
banner-wall	11
bicycle-billboard	4
bicycle-bridge	12
bicycle-building	358
bicycle-bus	14
bicycle-bus_stop	0
bicycle-car	238
bicycle-caravan	0
bicycle-dyn