### Processes annotations into semantic segmentation masks: removes tooth quadrant and type, removes overlap with alveolar bone and teeth, and sorts classes into a hierarchy in which each parent class has no label of its own but is the union (sum) of its child classes.

In [6]:
import os
import json
import numpy as np
from PIL import Image
from shapely.geometry import Polygon, Point
from shapely.geometry import MultiPolygon
from PIL import ImageDraw
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt


# Rasterises every polygon, using its class id, into a single combined class mask.
def create_binary_mask(polygons, image_size, class_id, clss_map):
    """Rasterise annotation polygons into one single-channel uint8 class mask.

    Each kept polygon is filled with a grey value derived from its class id:
    ``int((class_id + 1) / len(clss_map) * 255)``. The ``+ 1`` keeps every
    class away from 0, which stays reserved for background. Where polygons
    overlap, the pixel keeps the smallest non-zero value, i.e. the class with
    the lowest id wins.

    Args:
        polygons: one coordinate list per annotation, each an iterable of
            ``(x, y)`` pairs in pixel units.
        image_size: ``(width, height)``, as returned by ``PIL.Image.size``.
        class_id: class index per polygon (same order as *polygons*);
            ``None`` marks an annotation that is dropped.
        clss_map: class mapping; only its length is used, as the divisor
            that spreads class ids over 1-255.

    Returns:
        ``np.ndarray`` of shape ``(height, width)`` and dtype ``uint8``.

    Raises:
        ValueError: if a kept polygon is empty, or if a class id maps to a
            pixel value outside 1-255 (id out of range for *clss_map*).
    """
    n_classes = len(clss_map)

    # Validate in the original annotation order so the first offending
    # annotation is the one reported.
    shapes = []
    for poly, cls_id in zip(polygons, class_id):
        # Dropped annotations are skipped before validation, so an empty
        # polygon in a discarded class cannot abort the whole file.
        if cls_id is None:
            continue
        if not poly:
            raise ValueError(f"Error: Empty polygon in class {cls_id}")
        pixel_val = int(((cls_id + 1) / n_classes) * 255)
        if not 1 <= pixel_val <= 255:
            raise ValueError(
                f"Error: class {cls_id} maps to pixel value {pixel_val}, "
                f"outside 1-255 for a class map of size {n_classes}"
            )
        shapes.append((pixel_val, [(x, y) for x, y in poly]))

    # Overlaps must resolve to the smallest non-zero value. Drawing every
    # polygon onto one canvas from the largest value to the smallest gives
    # exactly that (the last, smallest, value drawn wins), without building
    # and merging a full-frame array per polygon.
    canvas = Image.new('L', image_size, 0)
    draw = ImageDraw.Draw(canvas)
    for pixel_val, coords in sorted(shapes, key=lambda shape: shape[0], reverse=True):
        draw.polygon(coords, outline=pixel_val, fill=pixel_val)

    return np.array(canvas, dtype=np.uint8)

def _write_class_map_csv(csv_path, name_to_id, class_map_final, n_classes):
    """Write the ``class_id,class_name,pixel_val`` table for *class_map_final*."""
    with open(csv_path, 'w', encoding='utf-8') as f:
        f.write("class_id,class_name,pixel_val\n")
        for count, name in class_map_final.items():
            map_id = name_to_id.get(name)
            if map_id is None:
                # Classes absent from class_map get no pixel value of their own;
                # index 0 is always written as background with value 0.
                if count == 0:
                    f.write("0,background,0\n")
                else:
                    f.write(f"{count},{name},{None}\n")
            else:
                # Same formula the mask rasteriser uses for this class id.
                pixel_val = int(((int(map_id) + 1) / n_classes) * 255)
                f.write(f"{count},{name},{pixel_val}\n")


def _read_annotations(json_path):
    """Return parallel lists ``(polygons, class_ids)`` from a label JSON's 'annotations'."""
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    polygons, class_ids = [], []
    for annotation in data.get('annotations', []):
        class_ids.append(annotation.get('class_id'))
        polygons.append(annotation.get('coordinates'))
    return polygons, class_ids


def process_json_files(input_dir, img_dir, output_dir, class_map, class_map_final, img_ext='.jpg'):
    """Convert every polygon-annotation JSON in *input_dir* into a PNG class mask.

    Writes ``class_map.csv`` (class_id, class_name, pixel_val) to *output_dir*,
    then, for each ``*.json`` file, reads its ``annotations`` (``class_id`` +
    ``coordinates``), drops class 4, shifts the ids above 4 down by one, and
    saves ``<stem>.png`` in *output_dir*.

    Args:
        input_dir: folder containing the annotation JSON files.
        img_dir: folder with the matching images, used only for their size.
        output_dir: destination folder; created if missing.
        class_map: ``{class_id: class_name}`` for the classes that receive a
            pixel value.
        class_map_final: ``{output_index: class_name}`` controlling the rows
            and order of ``class_map.csv``.
        img_ext: extension of the image that shares a JSON file's stem.
    """
    os.makedirs(output_dir, exist_ok=True)

    name_to_id = {v: k for k, v in class_map.items()}
    _write_class_map_csv(
        os.path.join(output_dir, 'class_map.csv'), name_to_id, class_map_final, len(name_to_id)
    )

    for file in tqdm(sorted(os.listdir(input_dir)), desc="Processing files"):
        if not file.endswith('.json'):
            continue
        # splitext (not split('.')) keeps stems that themselves contain dots.
        stem = os.path.splitext(file)[0]
        polygons, class_id = _read_annotations(os.path.join(input_dir, file))

        # Only the size is needed; the context manager closes the file handle.
        with Image.open(os.path.join(img_dir, stem + img_ext)) as img:
            image_size = img.size

        # Drop class 4 and shift the classes above it down by one so the
        # remaining ids line up with class_map.
        class_id = [None if i == 4 else i for i in class_id]
        class_id = [i - 1 if i is not None and i > 4 else i for i in class_id]

        combined_mask = create_binary_mask(polygons, image_size, class_id, name_to_id)
        Image.fromarray(combined_mask).save(os.path.join(output_dir, f"{stem}.png"))

# Foreground classes keyed by their (remapped) annotation class id. Only these
# receive a pixel value in the masks; background and parent classes do not.
class_map = {
    0: 'composite',
    1: 'enamel',
    2: 'pulp',
    3: 'dentin',
    4: 'upper',
    5: 'lower',
}
# Row order of the exported class_map.csv; names missing from class_map are
# written without a pixel value of their own.
class_map_final = {
    0: 'background',
    1: 'upper',
    2: 'lower',
    3: 'tooth',
    4: 'pulp',
    5: 'dentin',
    6: 'enamel',
    7: 'composite',
}

# Example usage: dataset locations on the annotation machine.
dataset_root = r"I:\Datasets\Hierarchical Datasets\TL_pano"
input_directory = dataset_root + r"\annotated\labels"
img_directory = dataset_root + r"\annotated\images"
output_directory = dataset_root + r"\processed"

process_json_files(input_directory, img_directory, output_directory, class_map, class_map_final)

Processing files:   7%|▋         | 14/197 [02:26<31:52, 10.45s/it]


KeyboardInterrupt: 

In [None]:
# 5 fold split


import os
import shutil
from pathlib import Path
import numpy as np

# Input folders: images and their PNG label masks, matched by file stem.
# Raw strings take single backslashes; the previous r"D:\\..." form embedded
# literal double separators that only worked because Windows collapses them.
IMG_DIR   = r"D:\Datasets Main\5. tooth layer segmentation\TL_pano\semantic\images"
LBL_DIR   = r"D:\Datasets Main\5. tooth layer segmentation\TL_pano\semantic\labels"

# Output root: receives test/ plus fold_<k>/{train,val}/ (or train/ + val/ for a single split).
OUT_DIR   = r"D:\Datasets Main\5. tooth layer segmentation\TL_pano\semantic\5_fold"

# Splitting behaviour
INCLUDE_TEST   = True   # hold out a test set before splitting
TEST_FRACTION  = 0.10   # fraction of all pairs (floored) used as the test set
FOLDS          = 5      # K for K-fold; None or < 2 selects a single train/val split

# Single-split train/val sizes (used only when not doing K-fold)
SINGLE_TRAIN_FRACTION = 0.80

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}  # compared lower-cased
LABEL_EXT  = ".png"

# Seed for numpy's default_rng so the split is reproducible.
SEED = 42

# True: delete an existing OUT_DIR before writing.
# False: keep it; same-named files are overwritten, any others are left in place.
CLEAR_OUT_DIR = True



def ensure_empty_dir(p: Path):
    """Make sure directory *p* exists, wiping it first when CLEAR_OUT_DIR is set.

    Parent directories are created as needed. With CLEAR_OUT_DIR false an
    existing directory is kept as-is.
    """
    wipe_first = p.exists() and CLEAR_OUT_DIR
    if wipe_first:
        shutil.rmtree(p)
    p.mkdir(exist_ok=True, parents=True)

# matches images and labels
def pair_images_and_labels(img_dir: Path, lbl_dir: Path, image_exts=None, label_ext=None):
    """Match image files to label files by file stem.

    Args:
        img_dir: folder scanned (non-recursively) for images whose lower-cased
            suffix is in *image_exts*.
        lbl_dir: folder scanned (non-recursively) for labels whose lower-cased
            suffix equals *label_ext*.
        image_exts: accepted image suffixes; defaults to module-level IMAGE_EXTS.
        label_ext: accepted label suffix; defaults to module-level LABEL_EXT.

    Returns:
        ``(stems, stem_to_img, stem_to_lbl)``: the sorted stems present in both
        folders, plus stem -> Path maps for every image and label found.
    """
    if image_exts is None:
        image_exts = IMAGE_EXTS
    if label_ext is None:
        label_ext = LABEL_EXT
    image_exts = {ext.lower() for ext in image_exts}
    label_ext = label_ext.lower()

    # Iterate in sorted order so that, when a stem exists with several
    # extensions (e.g. a.jpg and a.png), the file kept does not depend on the
    # filesystem's directory order.
    stem_to_img = {}
    for p in sorted(img_dir.iterdir()):
        if p.is_file() and p.suffix.lower() in image_exts:
            stem_to_img.setdefault(p.stem, p)

    stem_to_lbl = {}
    for p in sorted(lbl_dir.iterdir()):
        if p.is_file() and p.suffix.lower() == label_ext:
            stem_to_lbl.setdefault(p.stem, p)

    # Keep only stems that have both an image and a label.
    stems = sorted(stem_to_img.keys() & stem_to_lbl.keys())
    return stems, stem_to_img, stem_to_lbl

# randomises and splits dataset into per fold sets
def split_indices(n_total, test_frac, folds, rng, include_test=None, train_fraction=None):
    """Shuffle sample indices and split them into test + K folds, or test + train/val.

    Args:
        n_total: number of samples; indices ``0..n_total-1`` are split.
        test_frac: fraction of *n_total* (floored) held out as the test set.
        folds: K for K-fold; ``None`` or < 2 selects a single train/val split.
        rng: ``numpy.random.Generator`` used for the shuffle (seed it for
            reproducible splits).
        include_test: whether to hold out a test set; defaults to INCLUDE_TEST.
        train_fraction: train share of the non-test indices for the single
            split; defaults to SINGLE_TRAIN_FRACTION.

    Returns:
        ``(test_idx, fold_val_lists, single_train_idx, single_val_idx)``. In
        K-fold mode *fold_val_lists* is a list of K disjoint validation index
        arrays covering every non-test index, and the last two items are
        ``None``. In single-split mode *fold_val_lists* is ``None``.

    Raises:
        ValueError: if K-fold is requested with more folds than non-test
            samples, which would leave some folds with an empty validation set.
    """
    if include_test is None:
        include_test = INCLUDE_TEST

    indices = np.arange(n_total)
    rng.shuffle(indices)

    n_test = int(np.floor(test_frac * n_total)) if include_test else 0
    test_idx = indices[:n_test]
    remain_idx = indices[n_test:]

    if folds is not None and folds >= 2:
        if folds > len(remain_idx):
            raise ValueError(
                f"Cannot make {folds} folds from {len(remain_idx)} non-test samples"
            )
        # array_split allows uneven division: the first len % folds chunks get
        # one extra index.
        return test_idx, list(np.array_split(remain_idx, folds)), None, None

    # Single train/val split over the remaining indices.
    if train_fraction is None:
        train_fraction = SINGLE_TRAIN_FRACTION
    n_train = int(np.floor(train_fraction * len(remain_idx)))
    return test_idx, None, remain_idx[:n_train], remain_idx[n_train:]


# copy of image and annotation files into corresponding subfolders
def copy_pairs(stems, stem_to_img, stem_to_lbl, dst_images: Path, dst_labels: Path):
    """Copy the image and the label of every stem in *stems* into the destination folders.

    Both destination folders are created if missing. Files keep their names
    and metadata (``shutil.copy2``); same-named files already there are
    overwritten.
    """
    for folder in (dst_images, dst_labels):
        folder.mkdir(parents=True, exist_ok=True)

    for stem in stems:
        image_src, label_src = stem_to_img[stem], stem_to_lbl[stem]
        shutil.copy2(image_src, dst_images / image_src.name)
        shutil.copy2(label_src, dst_labels / label_src.name)





img_dir = Path(IMG_DIR)
lbl_dir = Path(LBL_DIR)
out_dir = Path(OUT_DIR)

# Fail fast on missing inputs before anything under OUT_DIR is touched.
for _src_dir, _kind in ((img_dir, "Image"), (lbl_dir, "Label")):
    if not _src_dir.exists():
        raise FileNotFoundError(f"{_kind} directory not found: {_src_dir}")

# Prepare the output root (wiped first when CLEAR_OUT_DIR is set).
ensure_empty_dir(out_dir)

stems, stem_to_img, stem_to_lbl = pair_images_and_labels(img_dir, lbl_dir)
if not stems:
    raise RuntimeError("No paired image/label files found. Check paths and extensions.")

rng = np.random.default_rng(SEED)
use_kfold = FOLDS is not None and FOLDS >= 2
test_idx, fold_val_lists, single_train_idx, single_val_idx = split_indices(
    n_total=len(stems),
    test_frac=TEST_FRACTION,
    folds=FOLDS if use_kfold else None,
    rng=rng,
)


def _make_split_dirs(*split_dirs):
    """Create images/ and labels/ under each given split directory."""
    for split_dir in split_dirs:
        (split_dir / "images").mkdir(parents=True, exist_ok=True)
        (split_dir / "labels").mkdir(parents=True, exist_ok=True)


def _copy_split(split_stems, split_dir):
    """Copy the listed image/label pairs into split_dir/images and split_dir/labels."""
    copy_pairs(split_stems, stem_to_img, stem_to_lbl, split_dir / "images", split_dir / "labels")


# Held-out test set, written whenever it was requested and is non-empty.
if INCLUDE_TEST and len(test_idx) > 0:
    test_dir = out_dir / "test"
    _make_split_dirs(test_dir)
    test_stems = [stems[i] for i in test_idx]
    _copy_split(test_stems, test_dir)

if fold_val_lists is not None:
    # K-fold CV: each fold validates on one chunk and trains on the rest of the
    # non-test indices, kept in ascending index (= sorted stem) order.
    held_out = set(test_idx.tolist())
    remain_idx = [i for i in range(len(stems)) if i not in held_out]

    for k, val_idx_chunk in enumerate(fold_val_lists, start=1):
        fold_dir = out_dir / f"fold_{k}"
        train_dir, val_dir = fold_dir / "train", fold_dir / "val"
        _make_split_dirs(train_dir, val_dir)

        in_val = set(val_idx_chunk.tolist())
        val_stems = [stems[i] for i in val_idx_chunk]
        train_stems = [stems[i] for i in remain_idx if i not in in_val]

        _copy_split(train_stems, train_dir)
        _copy_split(val_stems, val_dir)
else:
    # Single train/val split written directly under the output root.
    train_dir, val_dir = out_dir / "train", out_dir / "val"
    _make_split_dirs(train_dir, val_dir)
    train_stems = [stems[i] for i in single_train_idx]
    val_stems = [stems[i] for i in single_val_idx]
    _copy_split(train_stems, train_dir)
    _copy_split(val_stems, val_dir)

# Simple summary
n_test = len(test_idx) if INCLUDE_TEST else 0
print(f"Done. Total paired samples: {len(stems)}")
print(f"Test set: {n_test}")
if fold_val_lists is not None:
    remain_count = len(stems) - n_test
    for k, val_idx_chunk in enumerate(fold_val_lists, start=1):
        val_count = len(val_idx_chunk)
        print(f"Fold {k}: train={remain_count - val_count}, val={val_count}")
else:
    print(f"Train: {len(single_train_idx)}, Val: {len(single_val_idx)}")





Done. Total paired samples: 197
Test set: 19
Fold 1: train=142, val=36
Fold 2: train=142, val=36
Fold 3: train=142, val=36
Fold 4: train=143, val=35
Fold 5: train=143, val=35


In [None]:
# Checks the split datasets: the fold train/val sets and the test set must not overlap, and every image must have a matching label.


from pathlib import Path
import re, shutil
from collections import defaultdict


# Root produced by the 5-fold split step; expected to contain test/ and fold_<k>/.
# Single backslashes: in a raw string r"\\" is two literal backslash characters.
OUT_DIR = Path(r"D:\Datasets Main\5. tooth layer segmentation\TL_pano\semantic\5_fold")

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}  # compared lower-cased
LABEL_EXT  = ".png"
# True: raise AssertionError listing every overlap; False: only print them.
RAISE_ON_OVERLAP = True



# gets sorted list of base names
def stems_in_images_dir(images_dir: Path):
    """Return the sorted stems of image files directly inside *images_dir*.

    A file counts as an image when its lower-cased suffix is in IMAGE_EXTS.
    A missing directory yields an empty list.
    """
    if not images_dir.exists():
        return []
    found = []
    for entry in images_dir.iterdir():
        if entry.is_file() and entry.suffix.lower() in IMAGE_EXTS:
            found.append(entry.stem)
    found.sort()
    return found

# gets missing labels and orphan labels as a sorted list
def check_labels_match(images_dir: Path, labels_dir: Path, label_ext: str = LABEL_EXT):
    """Compare image stems with label stems for one split.

    Returns ``(missing_labels, orphan_labels)`` as sorted stem lists:
    images with no label file whose (case-insensitive) suffix is *label_ext*,
    and such label files with no matching image. A missing *labels_dir*
    counts as having no labels.

    NOTE: the *label_ext* default is captured from LABEL_EXT when the
    function is defined, not when it is called.
    """
    img_stems = set(stems_in_images_dir(images_dir))

    lbl_stems = set()
    if labels_dir.exists():
        for entry in labels_dir.iterdir():
            if entry.is_file() and entry.suffix.lower() == label_ext.lower():
                lbl_stems.add(entry.stem)

    missing_labels = sorted(img_stems - lbl_stems)
    orphan_labels = sorted(lbl_stems - img_stems)
    return missing_labels, orphan_labels

def print_overlap(name_a, set_a, name_b, set_b):
    """Report whether two stem sets intersect and return the intersection.

    Prints a one-line verdict and, on overlap, up to 10 shared items in
    sorted order.
    """
    shared = set_a & set_b
    if not shared:
        print(f"No overlap between {name_a} and {name_b}")
        return shared

    print(f"OVERLAP between {name_a} and {name_b}: {len(shared)} items")
    for item in sorted(shared)[:10]:
        print("   -", item)
    return shared


# Discover fold_<k> directories and order them numerically (fold_10 after fold_9).
_FOLD_DIR_RE = re.compile(r"fold_\d+")
fold_dirs = sorted(
    (p for p in OUT_DIR.iterdir() if p.is_dir() and _FOLD_DIR_RE.fullmatch(p.name)),
    key=lambda p: int(p.name.split("_")[1]),
)

if not fold_dirs:
    raise RuntimeError(
        f"No fold_* directories found under {OUT_DIR}. "
        "If you did a single split, adapt the checker manually."
    )

test_images_dir = OUT_DIR / "test" / "images"
test_labels_dir = OUT_DIR / "test" / "labels"
test_stems = set(stems_in_images_dir(test_images_dir))
print(f"Detected {len(fold_dirs)} folds. Test set images: {len(test_stems)}")

# Per fold: the images/labels paths and the set of image stems for train and val.
fold_data = {}
for fd in fold_dirs:
    entry = {}
    for split in ("train", "val"):
        split_images = fd / split / "images"
        entry[f"{split}_images"] = split_images
        entry[f"{split}_labels"] = fd / split / "labels"
        entry[f"{split}_stems"] = set(stems_in_images_dir(split_images))
    fold_data[fd.name] = entry

for name, info in fold_data.items():
    print(f"{name}: train={len(info['train_stems'])}, val={len(info['val_stems'])}")

# Validation sets of different folds must be pairwise disjoint.
violations = []
fold_names = list(fold_data)
for i, fi in enumerate(fold_names):
    for fj in fold_names[i + 1:]:
        inter = print_overlap(f"{fi} val", fold_data[fi]["val_stems"], f"{fj} val", fold_data[fj]["val_stems"])
        if inter:
            violations.append(("val_vs_val", fi, fj, inter))

# Within each fold, train, val and test must not share any stem.
for fname, d in fold_data.items():
    checks = (
        ("train_vs_val", fname, print_overlap(f"{fname} train", d["train_stems"], f"{fname} val", d["val_stems"])),
        ("train_vs_test", "test", print_overlap(f"{fname} train", d["train_stems"], "test", test_stems)),
        ("val_vs_test", "test", print_overlap(f"{fname} val", d["val_stems"], "test", test_stems)),
    )
    for kind, other, inter in checks:
        if inter:
            violations.append((kind, fname, other, inter))

# Cross-check the unions of all folds against the test set (informational only).
all_train_union = set().union(*(d["train_stems"] for d in fold_data.values()))
all_val_union = set().union(*(d["val_stems"] for d in fold_data.values()))
print_overlap("ALL-TRAIN-UNION", all_train_union, "TEST", test_stems)
print_overlap("ALL-VAL-UNION", all_val_union, "TEST", test_stems)

val_total = sum(len(d["val_stems"]) for d in fold_data.values())
if all_val_union and val_total != len(all_val_union):
    print("Warning: Sum of per-fold val sizes != size of union(all val). There may be overlap.")

def report_pairing(name, img_dir, lbl_dir):
    """Print image/label pairing problems for one split folder.

    Lists images without a label and labels without an image (at most 10
    stems each), or prints a one-line all-clear message per category.
    """
    missing, orphan = check_labels_match(img_dir, lbl_dir, label_ext=LABEL_EXT)

    def _show(stems, header, suffix, all_clear):
        # Print up to 10 offending stems under *header*, or the all-clear line.
        if not stems:
            print(all_clear)
            return
        print(header)
        for stem in stems[:10]:
            print("   -", stem + suffix)

    _show(
        missing,
        f"Missing labels for images in {name}: {len(missing)} (showing up to 10)",
        " (no label)",
        f"All images in {name} have labels",
    )
    _show(
        orphan,
        f"Labels without matching images in {name}: {len(orphan)} (showing up to 10)",
        " (orphan label)",
        f"No orphan labels in {name}",
    )

# Label pairing for every fold split, then for the test set if it exists.
for fname, d in fold_data.items():
    for split in ("train", "val"):
        report_pairing(f"{fname}/{split}", d[f"{split}_images"], d[f"{split}_labels"])

if test_images_dir.exists():
    report_pairing("test", test_images_dir, test_labels_dir)

# Final verdict: pass, warn, or raise with a summary of every violation.
if not violations:
    print("All checks passed: no overlaps between splits and label pairing is consistent.")
elif RAISE_ON_OVERLAP:
    msg_lines = ["Overlap violations detected:"]
    for kind, a, b, inter in violations:
        examples = ', '.join(sorted(inter)[:5])
        msg_lines.append(f" - {kind}: {a} vs {b} -> {len(inter)} overlapping items (e.g. {examples})")
    raise AssertionError("\n".join(msg_lines))
else:
    print("Completed with overlaps detected (RAISE_ON_OVERLAP=False). See messages above.")


Detected 5 folds. Test set images: 19
fold_1: train=142, val=36
fold_2: train=142, val=36
fold_3: train=142, val=36
fold_4: train=143, val=35
fold_5: train=143, val=35
No overlap between fold_1 val and fold_2 val
No overlap between fold_1 val and fold_3 val
No overlap between fold_1 val and fold_4 val
No overlap between fold_1 val and fold_5 val
No overlap between fold_2 val and fold_3 val
No overlap between fold_2 val and fold_4 val
No overlap between fold_2 val and fold_5 val
No overlap between fold_3 val and fold_4 val
No overlap between fold_3 val and fold_5 val
No overlap between fold_4 val and fold_5 val
No overlap between fold_1 train and fold_1 val
No overlap between fold_1 train and test
No overlap between fold_1 val and test
No overlap between fold_2 train and fold_2 val
No overlap between fold_2 train and test
No overlap between fold_2 val and test
No overlap between fold_3 train and fold_3 val
No overlap between fold_3 train and test
No overlap between fold_3 val and test
N