# Download Dataset

In [None]:
# ruff: noqa: E722, T201
import shutil
import tarfile
import urllib.request
from pathlib import Path

# When True, the "imagesTr" and "imagesTs" folders of each dataset are deleted
# right after extraction.
ONLY_LABELS = True

# All datasets are extracted below this directory (one sub-folder per task)
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

datasets = [
    "Task01_BrainTumour",
    "Task02_Heart",
    "Task03_Liver",
    "Task04_Hippocampus",
    "Task05_Prostate",
    "Task06_Lung",
    "Task07_Pancreas",
    "Task08_HepaticVessel",
    "Task09_Spleen",
    "Task10_Colon",
]
for dataset in datasets:
    # A dataset folder that already exists is treated as "already downloaded"
    if (data_dir / dataset).is_dir():
        print(f"{dataset} already present, skipping download.")
        continue

    url = f"https://msd-for-monai.s3-us-west-2.amazonaws.com/{dataset}.tar"
    tar_path = data_dir / f"{dataset}.tar"
    try:
        urllib.request.urlretrieve(url, tar_path)  # noqa: S310
        with tarfile.open(tar_path, "r:*") as tar:
            tar.extractall(data_dir, filter="data")
    finally:
        # Remove the archive even when the download or extraction fails, so a
        # partial .tar file is never left behind in the data directory.
        tar_path.unlink(missing_ok=True)

    if ONLY_LABELS:
        shutil.rmtree(data_dir / dataset / "imagesTr")
        shutil.rmtree(data_dir / dataset / "imagesTs")
    print(f"Downloaded {dataset}.")

# Results

In [None]:
import json
import random
import re
from pathlib import Path
from time import perf_counter

import cc3d
import fill_voids
import itk
import nibabel as nib
import numpy as np
import trimesh
from numpy.typing import NDArray
from tqdm import tqdm

from edge_mender import EdgeMender, MeshGenerator

itk.Image  # pyright: ignore[reportAttributeAccessIssue]  # noqa: B018

import itk.CuberillePython  # noqa: E402

# Output locations: the JSON summary goes in results/, exported meshes in
# results/meshes/
results_dir = Path("results")
mesh_dir = results_dir / "meshes"
results_dir.mkdir(exist_ok=True)
mesh_dir.mkdir(exist_ok=True)


# Resume from an earlier run when a results file is already on disk
results_path = results_dir / "results.json"
results = json.loads(results_path.read_text()) if results_path.exists() else {}

# Inputs for which Cuberille meshing is skipped; entries are matched as
# substrings of the file path by the processing loop.
# NOTE: https://github.com/InsightSoftwareConsortium/ITKCuberille/issues/90
cuberille_kernel_crash = [
    "Task05_Prostate",
    "pancreas_08",
    "pancreas_12",
    "pancreas_40",
    "pancreas_41",
    "Task08_HepaticVessel",
    "spleen_2",
    "colon_069",
]


def mesh_is_malformed(mesh: trimesh.Trimesh) -> bool:
    """Report whether ``EdgeMender`` validation rejects the given mesh.

    The mesh is validated with unit spacing. A ``ValueError`` raised while
    building or running the validator is interpreted as "malformed"; any
    other exception propagates to the caller.
    """
    malformed = False
    try:
        EdgeMender(mesh).validate(spacing=(1.0, 1.0, 1.0))
    except ValueError:
        malformed = True
    return malformed


# Loop through files: every training label volume of every downloaded dataset,
# visited in random order. Files whose name is already a key in `results` are
# skipped, so the cell can be interrupted and re-run.
paths = [
    path
    for path in Path("data").glob("*/labelsTr/*.nii.gz")
    if not path.name.startswith("._")  # Skip macOS metadata files
    # and np.random.rand() < 0.1
]
random.shuffle(paths)
pbar = tqdm(paths, smoothing=0.1)
for path in pbar:
    pbar.set_description(str(path))
    # Path.stem already dropped ".gz"; strip the remaining literal ".nii"
    # (dot escaped so that only a real extension matches).
    name = re.sub(r"\.nii$", "", path.stem, count=1)
    # Skip if already processed
    if name in results:
        continue

    result = {
        "path": str(path),
        "cuberille": {},
        "surface_nets": {},
        "nmes": {},
        "volume": {},
        "repair": {},
    }

    # Load data
    img = nib.load(path)  # pyright: ignore[reportPrivateImportUsage]
    data: NDArray = img.get_fdata()  # pyright: ignore[reportAttributeAccessIssue]

    result["data_size"] = path.stat().st_size
    result["size"] = data.size
    result["shape"] = list(data.shape)
    result["voxel_count"] = np.count_nonzero(data).item()

    # Mesh with Surface Nets
    # Reset first so that a failure below can never leave the mesh of a
    # previous iteration (or an unbound name) behind.
    surface_nets_mesh = None
    try:
        start = perf_counter()
        surface_nets_mesh = MeshGenerator.to_mesh_surface_nets(data)
        result["surface_nets"]["time"] = perf_counter() - start
        result["surface_nets"]["malformed"] = mesh_is_malformed(surface_nets_mesh)
    except KeyboardInterrupt:
        break
    except:
        # Best effort: any other failure is recorded instead of aborting the run
        result["surface_nets"]["failed"] = True

    # If the mesh crashes the program, use Surface Nets instead
    if any(x in str(path) for x in cuberille_kernel_crash):
        result["cuberille"]["failed"] = True

        # Only export the unrepaired mesh when the first attempt produced one
        if surface_nets_mesh is not None:
            surface_nets_mesh.export(mesh_dir / f"{name}.stl")
        # Keep the largest 6-connected component and fill its voids, then re-mesh
        data = cc3d.largest_k(data, k=1, connectivity=6).astype(np.uint8)
        data = fill_voids.fill(data, in_place=True)
        try:
            surface_nets_mesh = MeshGenerator.to_mesh_surface_nets(data)
        except KeyboardInterrupt:
            break
        except:
            result["surface_nets"]["repaired"] = False
        else:
            surface_nets_mesh.export(mesh_dir / f"{name}_repaired.stl")
            result["surface_nets"]["repaired"] = not mesh_is_malformed(
                surface_nets_mesh,
            )
        # Without a well-formed mesh there is nothing to repair; the result is
        # not stored, so the file is retried on the next run.
        if not result["surface_nets"]["repaired"]:
            continue
        mesh = surface_nets_mesh
    else:
        # Mesh with Cuberille
        try:
            start = perf_counter()
            cuberille_mesh = MeshGenerator.to_mesh_cuberille(data)
            result["cuberille"]["time"] = perf_counter() - start
            cuberille_mesh.export(mesh_dir / f"{name}.stl")
            result["cuberille"]["malformed"] = mesh_is_malformed(cuberille_mesh)
        except KeyboardInterrupt:
            break
        except:
            result["cuberille"]["failed"] = True
            continue

        # Attempt to fix malformed meshes
        if result["cuberille"]["malformed"]:
            data = cc3d.largest_k(data, k=1, connectivity=6).astype(np.uint8)
            try:
                cuberille_mesh = MeshGenerator.to_mesh_cuberille(data)
            except KeyboardInterrupt:
                break
            except:
                result["cuberille"]["repaired"] = False
            else:
                cuberille_mesh.export(mesh_dir / f"{name}_repaired.stl")
                result["cuberille"]["repaired"] = not mesh_is_malformed(cuberille_mesh)
            if not result["cuberille"]["repaired"]:
                continue
        mesh = cuberille_mesh

    # Measure properties before
    result["nmes"]["before"] = len(EdgeMender(mesh).find_non_manifold_edges()[2])
    result["volume"]["before"] = mesh.volume

    # Repair mesh
    try:
        # Clear trimesh's cache before timing the repair
        mesh._cache.clear()  # noqa: SLF001
        start = perf_counter()
        EdgeMender(mesh).repair()
        result["repair"]["time"] = perf_counter() - start
        result["repair"]["failed"] = False
    except KeyboardInterrupt:
        break
    except:
        result["repair"]["failed"] = True

    # Measure properties after
    result["nmes"]["after"] = len(EdgeMender(mesh).find_non_manifold_edges()[2])
    result["volume"]["after"] = mesh.volume
    result["is_watertight"] = mesh.is_watertight

    # Save results
    # Write to a temporary file and swap it in, so an interrupt or a
    # serialization error mid-write cannot corrupt the resume file.
    results[name] = result
    tmp_results_path = results_path.with_name(results_path.name + ".tmp")
    with tmp_results_path.open("w") as f:
        json.dump(results, f, indent=4)
    tmp_results_path.replace(results_path)

# Statistics and Plots

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Locate the stored results and make sure the plot output folder exists
results_dir = Path("results")
plots_dir = results_dir / "plots"
plots_dir.mkdir(exist_ok=True)
results_path = results_dir / "results.json"
if results_path.exists():
    # NOTE: `results` is only (re)assigned here when the file is present
    results = json.loads(results_path.read_text())

# Every label volume on disk (used for the overall file count)
paths = [
    label_path
    for label_path in Path("data").glob("*/labelsTr/*.nii.gz")
    if not label_path.name.startswith("._")  # Skip macOS metadata files
]

# Per-file measurements taken from the stored results
data_sizes = np.array([entry["data_size"] for entry in results.values()])
sizes = np.array([entry["size"] for entry in results.values()])
voxel_counts = np.array([entry["voxel_count"] for entry in results.values()])

# One column of values: the file count followed by mean/min/max/std of each
# measurement, in the same order as the index labels below.
data_summary = pd.DataFrame(
    {
        "values": [
            len(paths),
            # TODO: Convert the data size statistics to MB or KB
            *(
                statistic
                for measurements in (data_sizes, sizes, voxel_counts)
                for statistic in (
                    measurements.mean(),
                    measurements.min(),
                    measurements.max(),
                    measurements.std(),
                )
            ),
        ],
    },
    index=[
        "Count",
        "Average Data Size",
        "Minimum Data Size",
        "Maximum Data Size",
        "Data Size Standard Deviation",
        "Average Size",
        "Minimum Size",
        "Maximum Size",
        "Size Standard Deviation",
        "Average Voxel Count",
        "Minimum Voxel Count",
        "Maximum Voxel Count",
        "Voxel Count Standard Deviation",
    ],
)

# NOTE: despite "repair" in their names, these two arrays hold the time stored
# under each meshing method's "time" key, not EdgeMender repair times.
cuberille_repair_times = np.array(
    [
        entry["cuberille"]["time"]
        for entry in results.values()
        if "time" in entry["cuberille"]
    ],
)
surface_nets_repair_times = np.array(
    [
        entry["surface_nets"]["time"]
        for entry in results.values()
        if "time" in entry["surface_nets"]
    ],
)

# Row order of the summary: Cuberille first, Surface Nets second
method_times = (cuberille_repair_times, surface_nets_repair_times)
method_keys = ("cuberille", "surface_nets")
mesh_generation_summary = pd.DataFrame(
    {
        "Meshed": [len(times) for times in method_times],
        "Average Time (s)": [times.mean() for times in method_times],
        "Min Time (s)": [times.min() for times in method_times],
        "Max Time (s)": [times.max() for times in method_times],
        "Standard Deviation Time (s)": [times.std() for times in method_times],
        # Flag counts: number of results where the flag is present and truthy
        "# Failed": [
            sum(1 for entry in results.values() if entry[key].get("failed"))
            for key in method_keys
        ],
        "# Malformed": [
            sum(1 for entry in results.values() if entry[key].get("malformed"))
            for key in method_keys
        ],
        "# Fixed": [
            sum(1 for entry in results.values() if entry[key].get("repaired"))
            for key in method_keys
        ],
    },
    index=["Cuberille", "Surface Nets"],
)

# Number of non-manifold edges (NMEs) found in each mesh before repair
nmes_before = np.array(
    [v["nmes"]["before"] for v in results.values() if "before" in v["nmes"]],
)
# Mesh volume before repair. These values feed the "... Volume" rows of the
# summary, so they must be the volume itself and not the after/before
# difference (that difference is reported separately as "Average Volume Change").
volumes_before = np.array(
    [v["volume"]["before"] for v in results.values() if "before" in v["volume"]],
)
# Repair timings only consider meshes that actually had NMEs to repair
repair_times = np.array(
    [
        v["repair"]["time"]
        for v in results.values()
        if "time" in v["repair"] and v["nmes"]["before"] > 0
    ],
)
# Repair time divided by the number of NMEs that had to be repaired
time_per_repair_times = np.array(
    [
        v["repair"]["time"] / v["nmes"]["before"]
        for v in results.values()
        if "time" in v["repair"] and v["nmes"]["before"] > 0
    ],
)
# One column of values, in the same order as the index labels below
repair_summary = pd.DataFrame(
    {
        "values": [
            len(nmes_before),
            np.sum(nmes_before > 0),
            nmes_before.mean(),
            nmes_before.min(),
            nmes_before.max(),
            nmes_before.std(),
            np.mean(
                [v["nmes"]["after"] for v in results.values() if "after" in v["nmes"]],
            ),
            volumes_before.mean(),
            volumes_before.min(),
            volumes_before.max(),
            volumes_before.std(),
            np.mean(
                [
                    v["volume"]["after"] - v["volume"]["before"]
                    for v in results.values()
                    if "before" in v["volume"] and "after" in v["volume"]
                ],
            ),
            repair_times.mean(),
            repair_times.min(),
            repair_times.max(),
            repair_times.std(),
            time_per_repair_times.mean(),
            time_per_repair_times.min(),
            time_per_repair_times.max(),
            time_per_repair_times.std(),
            # Repair time relative to the meshing time: the Cuberille time is
            # used when recorded, otherwise the Surface Nets time
            np.mean(
                [
                    v["repair"]["time"]
                    / (
                        v["cuberille"]["time"]
                        if "time" in v["cuberille"]
                        else v["surface_nets"]["time"]
                    )
                    for v in results.values()
                    if "time" in v["repair"] and v["nmes"]["before"] > 0
                ],
            ),
            # Success: repair did not fail, no NMEs remain, mesh is watertight
            sum(
                1
                for v in results.values()
                if not v["repair"]["failed"]
                and v["nmes"]["after"] == 0
                and v["is_watertight"]
            )
            / len(nmes_before),
        ],
    },
    index=[
        "Count",
        "Number with NMEs",
        "Average # of NMEs Before",
        "Minimum # of NMEs Before",
        "Maximum # of NMEs Before",
        "Standard Deviation # of NMEs Before",
        "Average # of NMEs After",
        "Average Volume",
        "Minimum Volume",
        "Maximum Volume",
        "Standard Deviation Volume",
        "Average Volume Change",
        "Average Repair Time (s)",
        "Minimum Repair Time (s)",
        "Maximum Repair Time (s)",
        "Standard Deviation Repair Time (s)",
        "Total Average Time per Repair (s)",
        "Minimum Average Time per Repair (s)",
        "Maximum Average Time per Repair (s)",
        "Standard Deviation Average Time per Repair (s)",
        "Average Repair Time Relative to Mesh Generation Time",
        "Success Rate",
    ],
)

# Write the three summaries to one workbook, one sheet each
with pd.ExcelWriter(results_dir / "results.xlsx", engine="openpyxl") as writer:
    data_summary.to_excel(writer, sheet_name="Data")
    mesh_generation_summary.to_excel(writer, sheet_name="Mesh Generation")
    repair_summary.to_excel(writer, sheet_name="Repair", header=False)

In [None]:
# TODO: Add repair statistics excluding meshes with no NMEs
# TODO: Add relative repair time vs. mesh generation time statistics

In [None]:
# Sizes: histogram of the total voxel count of every processed volume
sizes = [entry["size"] for entry in results.values()]

ax = plt.gca()
plt.hist(sizes, bins=20, edgecolor="black")
ax.set_xlabel("Size (voxels)")
ax.set_ylabel("Frequency")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "size_histogram.png")
plt.show()

In [None]:
# Size vs. Mesh Time per Method
# Select the results that recorded a meshing time for each method once
cuberille_timed = [v for v in results.values() if "time" in v["cuberille"]]
surface_nets_timed = [v for v in results.values() if "time" in v["surface_nets"]]

cuberille_sizes = [v["size"] for v in cuberille_timed]
cuberille_times = [v["cuberille"]["time"] for v in cuberille_timed]
surface_nets_sizes = [v["size"] for v in surface_nets_timed]
surface_nets_times = [v["surface_nets"]["time"] for v in surface_nets_timed]

ax = plt.gca()
plt.scatter(cuberille_sizes, cuberille_times, label="Cuberille")
plt.scatter(surface_nets_sizes, surface_nets_times, label="Surface Nets")
ax.set_xlabel("Size (voxels)")
ax.set_ylabel("Time (seconds)")
ax.legend()
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "size_vs_mesh_time.png")
plt.show()

In [None]:
# Voxel Count vs. Mesh Time per Method
# Select the results that recorded a meshing time for each method once
cuberille_timed = [v for v in results.values() if "time" in v["cuberille"]]
surface_nets_timed = [v for v in results.values() if "time" in v["surface_nets"]]

cuberille_voxel_counts = [v["voxel_count"] for v in cuberille_timed]
cuberille_times = [v["cuberille"]["time"] for v in cuberille_timed]
surface_nets_voxel_counts = [v["voxel_count"] for v in surface_nets_timed]
surface_nets_times = [v["surface_nets"]["time"] for v in surface_nets_timed]

ax = plt.gca()
plt.scatter(cuberille_voxel_counts, cuberille_times, label="Cuberille")
plt.scatter(surface_nets_voxel_counts, surface_nets_times, label="Surface Nets")
ax.set_xlabel("Foreground Voxels")
ax.set_ylabel("Time (seconds)")
ax.legend()
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "voxel_count_vs_mesh_time.png")
plt.show()

In [None]:
# Size vs. Non-manifold Edges
# Only results that recorded an NME count before repair are plotted
measured = [v for v in results.values() if "before" in v["nmes"]]
sizes = [v["size"] for v in measured]
nmes = [v["nmes"]["before"] for v in measured]

ax = plt.gca()
plt.scatter(sizes, nmes)
ax.set_xlabel("Size (voxels)")
ax.set_ylabel("Non-Manifold Edges")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "size_vs_nmes.png")
plt.show()

In [None]:
# Voxel Count vs. Non-Manifold Edges
# Only results that recorded an NME count before repair are plotted
measured = [v for v in results.values() if "before" in v["nmes"]]
voxel_counts = [v["voxel_count"] for v in measured]
nmes = [v["nmes"]["before"] for v in measured]

ax = plt.gca()
plt.scatter(voxel_counts, nmes)
ax.set_xlabel("Foreground Voxels")
ax.set_ylabel("Non-Manifold Edges")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "voxel_counts_vs_nmes.png")
plt.show()

In [None]:
# Size vs. Repair Time
# Meshes that had at least one NME and a recorded repair time
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
sizes = [v["size"] for v in repaired]
times = [v["repair"]["time"] for v in repaired]

ax = plt.gca()
plt.scatter(sizes, times)
ax.set_xlabel("Size (voxels)")
ax.set_ylabel("Time (seconds)")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "size_vs_repair_time.png")
plt.show()

In [None]:
# Voxel Count vs. Repair Time
# Meshes that had at least one NME and a recorded repair time
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
voxel_counts = [v["voxel_count"] for v in repaired]
times = [v["repair"]["time"] for v in repaired]

ax = plt.gca()
plt.scatter(voxel_counts, times)
ax.set_xlabel("Foreground Voxels")
ax.set_ylabel("Time (seconds)")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "voxel_count_vs_repair_time.png")
plt.show()

In [None]:
# Non-Manifold Edges vs. Repair Time
# Meshes that had at least one NME and a recorded repair time
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
nmes = [v["nmes"]["before"] for v in repaired]
times = [v["repair"]["time"] for v in repaired]

ax = plt.gca()
plt.scatter(nmes, times)
ax.set_xlabel("Non-Manifold Edges")
ax.set_ylabel("Time (seconds)")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "nmes_vs_repair_time.png")
plt.show()

In [None]:
# Size vs. Time per Repair
# Meshes that had at least one NME and a recorded repair time
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
sizes = [v["size"] for v in repaired]
# Repair time divided by the number of NMEs found before repair
times = [v["repair"]["time"] / v["nmes"]["before"] for v in repaired]

ax = plt.gca()
plt.scatter(sizes, times)
ax.set_xlabel("Size (voxels)")
ax.set_ylabel("Time (seconds)")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "size_vs_time_per_repair.png")
plt.show()

In [None]:
# Voxel Count vs. Time per Repair
# Meshes that had at least one NME and a recorded repair time
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
voxel_counts = [v["voxel_count"] for v in repaired]
# Repair time divided by the number of NMEs found before repair
times = [v["repair"]["time"] / v["nmes"]["before"] for v in repaired]

ax = plt.gca()
plt.scatter(voxel_counts, times)
ax.set_xlabel("Foreground Voxels")
ax.set_ylabel("Time (seconds)")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "voxel_count_vs_time_per_repair.png")
plt.show()

In [None]:
# Non-Manifold Edges vs. Time per Repair
# Meshes that had at least one NME and a recorded repair time
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
nmes = [v["nmes"]["before"] for v in repaired]
# Repair time divided by the number of NMEs found before repair
times = [v["repair"]["time"] / v["nmes"]["before"] for v in repaired]

ax = plt.gca()
plt.scatter(nmes, times)
ax.set_xlabel("Non-Manifold Edges")
ax.set_ylabel("Time (seconds)")
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
plt.savefig(plots_dir / "nmes_vs_time_per_repair.png")
plt.show()

In [None]:
# Mesh Time per Method vs. Repair Time
# Meshes that had at least one NME and a recorded repair time ...
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
# ... split by which meshing method recorded a time for them
cuberille_cases = [v for v in repaired if "time" in v["cuberille"]]
surface_nets_cases = [v for v in repaired if "time" in v["surface_nets"]]

cuberille_mesh_times = [v["cuberille"]["time"] for v in cuberille_cases]
cuberille_repair_times = [v["repair"]["time"] for v in cuberille_cases]
surface_nets_mesh_times = [v["surface_nets"]["time"] for v in surface_nets_cases]
surface_nets_repair_times = [v["repair"]["time"] for v in surface_nets_cases]

ax = plt.gca()
plt.scatter(cuberille_mesh_times, cuberille_repair_times, label="Cuberille")
plt.scatter(surface_nets_mesh_times, surface_nets_repair_times, label="Surface Nets")
ax.set_xlabel("Mesh Time (seconds)")
ax.set_ylabel("Repair Time (seconds)")
# Linear trend lines (np.polyfit / np.poly1d per method) are intentionally
# left disabled here.
ax.legend()
ax.grid(alpha=0.2)
# Hide the top and right frame lines
ax.spines[["top", "right"]].set_visible(False)
# Fixed y-range and 1:1 aspect so both axes use the same scale
ax.set_ylim(-5, 90)
ax.set_aspect("equal", adjustable="box")
plt.savefig(plots_dir / "mesh_time_vs_repair_time.png")
plt.show()

In [None]:
# Mesh Time per Method vs. Time per Repair
# x: time spent generating the mesh; y: repair time divided by the number of
# non-manifold edges found before repair. Both series use the same units, so
# the two methods can be compared on one set of axes.
repaired = [
    v
    for v in results.values()
    if v["nmes"]["before"] != 0 and "time" in v["repair"]
]
cuberille_cases = [v for v in repaired if "time" in v["cuberille"]]
surface_nets_cases = [v for v in repaired if "time" in v["surface_nets"]]

cuberille_mesh_times = [v["cuberille"]["time"] for v in cuberille_cases]
cuberille_repair_times = [
    v["repair"]["time"] / v["nmes"]["before"] for v in cuberille_cases
]
surface_nets_mesh_times = [v["surface_nets"]["time"] for v in surface_nets_cases]
surface_nets_repair_times = [
    v["repair"]["time"] / v["nmes"]["before"] for v in surface_nets_cases
]

plt.scatter(cuberille_mesh_times, cuberille_repair_times, label="Cuberille")
plt.scatter(surface_nets_mesh_times, surface_nets_repair_times, label="Surface Nets")
plt.xlabel("Mesh Time (seconds)")
plt.ylabel("Time per Repair (seconds)")
plt.legend()
# z = np.polyfit(cuberille_mesh_times, cuberille_repair_times, 1)
# p = np.poly1d(z)
# plt.plot(cuberille_mesh_times, p(cuberille_mesh_times), color="darkblue")
# z = np.polyfit(surface_nets_mesh_times, surface_nets_repair_times, 1)
# p = np.poly1d(z)
# plt.plot(surface_nets_mesh_times, p(surface_nets_mesh_times), color="red")
plt.grid(alpha=0.2)
ax = plt.gca()
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
# NOTE(review): these fixed limits were chosen for the previous (mixed-unit)
# data; re-check them against the per-repair values.
ax.set_ylim(-2, 25)
ax.set_aspect("equal", adjustable="box")
plt.savefig(plots_dir / "mesh_time_vs_time_per_repair.png")
plt.show()