In [None]:
import tifffile as tif
from glob import glob
from collections import OrderedDict
import nrrd
import os
import matplotlib.pyplot as plt
import cv2
import re
import numpy as np

In [None]:
def get_filenames(path, ext):
    """
    List the files directly inside ``path`` whose names end in ``.{ext}``.

    ``path`` is escaped before globbing, so directory names that contain glob
    metacharacters (``*``, ``?``, ``[``) are matched literally instead of being
    interpreted as a pattern. ``ext`` is still used as part of the glob pattern.

    :param path: Directory to search (not recursive).
    :param ext: Extension without the leading dot, e.g. ``"tif"`` or ``"seg.nrrd"``.
    :return: Matching paths, sorted lexicographically.
    """
    # The module-level name `glob` is the function, not the module, so pull
    # `escape` in locally.
    from glob import escape

    return sorted(glob(os.path.join(escape(path), f"*.{ext}")))

def bgr2rgb(image):
    """Convert an OpenCV image from BGR channel order to RGB channel order."""
    conversion_code = cv2.COLOR_BGR2RGB
    converted = cv2.cvtColor(image, conversion_code)
    return converted

def natural_sort_key(s, _nsre=re.compile('([0-9]+)')):
    """
    Build a sort key that orders embedded numbers by value ("Segment2" < "Segment10").

    The string is split on runs of ASCII digits. Digit runs become ints and
    the text between them is lower-cased, so the ordering is case-insensitive.

    :param s: String to build a key for.
    :param _nsre: Compiled splitting pattern with one capturing group. It is a
        default argument, so it is compiled once, when the function is defined.
    :return: List of lower-cased text pieces and ints. With the default
        pattern it alternates text / int, always starting with a (possibly
        empty) text piece.
    """
    # Classify each piece with the splitting pattern rather than str.isdigit().
    # isdigit() is also True for characters such as "²" (int() raises
    # ValueError) or "٣" (int() returns 3, placing an int where other keys hold
    # a str, so comparing keys raises TypeError).
    return [int(piece) if _nsre.fullmatch(piece) else piece.lower()
            for piece in _nsre.split(s)]

In [None]:
# Folder holding one study's volumes (*.tif) and segmentations (*.seg.nrrd).
# Set the OCT_DATA_DIR environment variable to point the notebook at another
# folder without editing it; the default is the original hardcoded location.
dir_path = os.environ.get(
    "OCT_DATA_DIR", "/data/dkermany_data/3D-OCT/first-batch-labeled/GD_NORMAL-3"
)

In [None]:
# Volumes are every *.tif except those whose file name contains "slo". Only the
# basename is checked, so a parent directory whose name contains "slo" does not
# filter out everything.
vol_paths = [p for p in get_filenames(dir_path, "tif") if "slo" not in os.path.basename(p)]
seg_paths = get_filenames(dir_path, "seg.nrrd")

# NOTE(review): volumes and segmentations are later paired purely by position
# (zip over these two sorted lists). Confirm that their file names sort in the
# same order.
assert len(vol_paths) == len(seg_paths), (
    f"found {len(vol_paths)} volume(s) but {len(seg_paths)} segmentation(s) in {dir_path}"
)

In [None]:
def overlay_segments(segments, colors):
    """
    Composite a stack of segmentation masks into a single RGB volume.

    Each mask is painted with its color, and the contributions of all masks
    are summed. The sum saturates at 255, so overlapping masks brighten
    instead of wrapping around in uint8 arithmetic.

    :param segments: 4-D array; ``segments[k]`` is the k-th mask. Each mask is
        transposed (``mask.T``) before being split into slices, so an input of
        shape ``(S, A, B, C)`` yields ``C`` slices of shape ``(B, A)``. Any
        nonzero voxel counts as inside the mask.
    :param colors: Sequence of RGB triplets (0-255), one per mask. Pairing
        uses ``zip``, so masks or colors beyond the shorter of the two are
        ignored.
    :return: ``uint8`` array of shape ``segments.T.shape[:-1] + (3,)``.
    :raises ValueError: If ``segments`` is not 4-dimensional.
    """
    segments = np.asarray(segments)
    if segments.ndim != 4:
        raise ValueError(f"expected a 4-D stack of masks, got shape {segments.shape}")

    # Start from a black image.
    final_images = np.zeros(segments.T.shape[:-1] + (3,), dtype=np.uint8)

    for segment, color in zip(segments, colors):
        # Widen the color so the per-slice sum cannot overflow before clipping.
        color = np.asarray(color, dtype=np.uint16)
        for i, seg_slice in enumerate(segment.T):
            # Broadcast the boolean mask across the three color channels.
            painted = (seg_slice != 0)[..., np.newaxis] * color
            final_images[i] = np.minimum(final_images[i] + painted, 255)

    return final_images

In [None]:
# Matches header keys that belong to a segment ("Segment0_...", "Segment12_...").
pattern = re.compile(r"^Segment\d+")

for vol_path, seg_path in zip(vol_paths, seg_paths):
    name = os.path.splitext(os.path.basename(vol_path))[0]
    print(f"Name: {name}")

    vol = tif.imread(vol_path)
    seg, header = nrrd.read(seg_path)
    print(f"TIFF type: {type(vol)}, TIFF shape: {vol.shape}, SEG.NRRD type: {type(seg)}, SEG.NRRD shape: {seg.shape}")

    # Per-segment color entries, keyed by their "Segment<N>" prefix. Requiring
    # `pattern` as well as the suffix keeps unrelated header fields that happen
    # to end in "Color" out of the map.
    segment_colors = {
        k.split("_")[0]: v
        for k, v in header.items()
        if pattern.match(k) and k.endswith("Color")
    }
    if not segment_colors:
        print(f"No Segment*Color entries in {seg_path}; skipping.")
        continue

    # Natural order, so "Segment10" follows "Segment9" rather than "Segment1".
    # NOTE(review): overlay_segments pairs the k-th entry along seg's first axis
    # with the k-th color in this order. Confirm that this matches how the
    # header maps segments to that axis.
    sorted_color_map = sorted(segment_colors.items(), key=lambda item: natural_sort_key(item[0]))
    print(f"sorted_color_map: {sorted_color_map}")

    # Each value is a whitespace-separated string of floats, presumably in
    # [0, 1], scaled here to 0-255.
    rgb_colors = np.array(
        [[round(255. * float(c)) for c in color.split()] for _, color in sorted_color_map],
        dtype=np.uint8,
    )
    print(f"rgb_colors: {rgb_colors}")
    if seg.shape[0] != len(rgb_colors):
        print(f"Warning: {len(rgb_colors)} color(s) for {seg.shape[0]} entries along the first "
              "SEG.NRRD axis; unmatched entries are not drawn.")

    seg_vol = overlay_segments(seg, rgb_colors)
    if len(seg_vol) != len(vol):
        print(f"Warning: {len(seg_vol)} segmentation slice(s) vs {len(vol)} volume slice(s); "
              "only the first min(...) are shown.")

    for idx, (seg_slice, vol_slice) in enumerate(zip(seg_vol, vol)):
        fig, axes = plt.subplots(1, 2, figsize=(25, 14))
        axes[0].imshow(seg_slice)
        axes[0].set_title(f"{name}: segmentation overlay, slice {idx}")
        axes[1].imshow(vol_slice, cmap="gray")
        axes[1].set_title(f"{name}: volume, slice {idx}")
        # Render this figure now and release it. Otherwise every figure created
        # in the loop stays open (matplotlib warns past 20) and memory keeps growing.
        plt.show()
        plt.close(fig)

    