# Diptera wing classification using Topological Data Analysis

Guilherme Vituri F. Pinto ([ORCID](https://orcid.org/0000-0002-7813-8777); Universidade Estadual Paulista)  
Sergio Ura  
Northon  
February 24, 2026

We apply tools from Topological Data Analysis (TDA) to classify Diptera families based on wing venation patterns. We use five filtration strategies: Vietoris-Rips on point clouds, directional height filtrations (8 directions), radial filtrations, Euclidean Distance Transform filtrations, and grayscale sublevel-set (cubical) filtrations on wing images. From these we extract H0 and H1 topological features (persistence images, Betti curves, persistence landscapes and summary statistics) and compare distance-based and feature-based classifiers via leave-one-out cross-validation (LOOCV). Feature selection via Random Forest importance, evaluated with nested LOOCV, keeps the reported accuracy estimates free of selection bias.

In [2]:
using TDAfly, TDAfly.Preprocessing, TDAfly.TDA, TDAfly.Analysis
using Images: mosaicview, Gray
using Plots: plot, display, heatmap, scatter, bar
using StatsPlots: boxplot
using PersistenceDiagrams
using PersistenceDiagrams: BettiCurve, Landscape, PersistenceImage
using DataFrames
using Distances: euclidean
using LIBSVM
using StatsBase: mean
using ProgressMeter: @showprogress  # macro used by the persistence cells below

## Introduction

The order Diptera (true flies) comprises over 150,000 described species across more than 150 families. Wing venation patterns are a classical diagnostic character in Diptera systematics: the arrangement, branching and connectivity of veins vary markedly across families and provide a natural morphological signature.

In this work, we apply **Topological Data Analysis (TDA)** to the problem of classifying Diptera families from wing images. TDA provides a framework for extracting shape descriptors that are robust to continuous deformations, which is exactly the kind of invariance desirable when comparing biological structures that differ in scale and orientation and show minor deformations across individuals.

We employ five complementary filtration strategies:

1.  **Vietoris-Rips filtration** on point-cloud samples of wing silhouettes
2.  **Directional height filtrations** (8 directions) that sweep across the wing along different axes
3.  **Radial filtration** from the wing centroid to the periphery
4.  **Euclidean Distance Transform (EDT) filtration** capturing vein thickness hierarchy
5.  **Grayscale sublevel-set (cubical) filtration** on the raw wing image

For each image-based filtration, we compute both **H0** (connected components / vein branching) and **H1** (loops / enclosed cells) persistence; for the Vietoris-Rips filtration we use H1 only. We then vectorize the diagrams into feature representations (persistence images, Betti curves, persistence landscapes) and feed them into classifiers. Feature selection and nested cross-validation provide honest accuracy estimates.

## Methods

### Data loading and preprocessing

All images are in the `images/processed` directory. We load each image, apply a Gaussian blur (to close small gaps in the wing membrane and keep it connected), crop to the bounding box, and resize to a height of 150 pixels.

In [3]:
all_paths = readdir("images/processed", join = true)
all_filenames = basename.(all_paths) .|> (x -> replace(x, ".png" => ""))

function extract_family(name)
    family_raw = lowercase(split(name, r"[\s\-]")[1])
    if family_raw in ("bibionidae", "biobionidae")
        return "Bibionidae"
    elseif family_raw in ("sciaridae", "scaridae")
        return "Sciaridae"
    elseif family_raw == "simulidae"
        return "Simuliidae"
    else
        return titlecase(family_raw)
    end
end

function canonical_id(name)
    family = extract_family(name)
    parts = split(name, r"[\s\-]")
    number = parts[end]
    "$(family)-$(number)"
end

# Deduplicate (space vs hyphen variants of the same file)
seen = Set{String}()
keep_idx = Int[]
for (i, fname) in enumerate(all_filenames)
    cid = canonical_id(fname)
    if !(cid in seen)
        push!(seen, cid)
        push!(keep_idx, i)
    end
end

paths = all_paths[keep_idx]
species = all_filenames[keep_idx]
families = extract_family.(species)

individuals = map(species) do specie
    parts = split(specie, r"[\s\-]")
    string(extract_family(specie)[1]) * "-" * parts[end]
end

println("Total images after deduplication: $(length(paths))")
println("Families: ", sort(unique(families)))
println("\nSamples per family:")
for f in sort(unique(families))
    println("  $(f): $(count(==(f), families))")
end

Total images after deduplication: 72
Families: ["Asilidae", "Bibionidae", "Ceratopogonidae", "Chironomidae", "Pelecorhynchidae", "Rhagionidae", "Sciaridae", "Simuliidae", "Tabanidae", "Tipulidae"]

Samples per family:
  Asilidae: 8
  Bibionidae: 6
  Ceratopogonidae: 8
  Chironomidae: 8
  Pelecorhynchidae: 2
  Rhagionidae: 4
  Sciaridae: 6
  Simuliidae: 7
  Tabanidae: 11
  Tipulidae: 12

#### Excluding small families

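Families with fewer than 3 samples (e.g. Pelecorhynchidae with $n=2$) can distort cross-validation results: with two samples, a single misclassification changes that family's recall by 50 percentage points, and under leave-one-out only one training example of the family remains. We provide a filtered version and run the analysis both ways.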

In [4]:
MIN_FAMILY_SIZE = 3
family_counts = Dict(f => count(==(f), families) for f in unique(families))
small_families = [f for (f, c) in family_counts if c < MIN_FAMILY_SIZE]

if !isempty(small_families)
    println("Families with < $MIN_FAMILY_SIZE samples (excluded from filtered analysis):")
    for f in sort(small_families)
        println("  $(f): $(family_counts[f]) samples")
    end
end

# Build filtered indices
keep_filtered = [i for i in eachindex(families) if family_counts[families[i]] >= MIN_FAMILY_SIZE]
paths_filtered = paths[keep_filtered]
species_filtered = species[keep_filtered]
families_filtered = families[keep_filtered]
individuals_filtered = individuals[keep_filtered]

println("\nFiltered dataset: $(length(keep_filtered)) samples, $(length(unique(families_filtered))) families")

Families with < 3 samples (excluded from filtered analysis):
  Pelecorhynchidae: 2 samples

Filtered dataset: 70 samples, 9 families

In [5]:
wings = load_wing.(paths, blur = 1.3)
Xs = map(wings) do w
    image_to_r2(w; ensure_connected = true, connectivity = 8)
end;

In [6]:
mosaicview(wings, ncol = 6, fillvalue = 1)

### Example: forcing connectivity on 5 wings

The chunk below selects 5 wings (prioritizing those with the largest number of disconnected components before correction), then compares the binary pixel set before and after `connect_pixel_components`.

In [7]:
threshold_conn = 0.2
conn = 8

component_count_before = map(wings) do w
    ids0 = findall_ids(>(threshold_conn), image_to_array(w))
    length(pixel_components(ids0; connectivity = conn))
end

demo_idx = sortperm(component_count_before, rev = true)[1:min(5, length(wings))]

function ids_to_mask(ids)
    isempty(ids) && return zeros(Float32, 1, 1)
    xs = first.(ids)
    ys = last.(ids)
    M = zeros(Float32, maximum(xs), maximum(ys))
    for p in ids
        M[p[1], p[2]] = 1f0
    end
    M
end

demo_connectivity_df = DataFrame(
    sample = String[],
    n_components_before = Int[],
    n_components_after = Int[],
    n_pixels_before = Int[],
    n_pixels_after = Int[],
)

panel_plots = Any[]
for idx in demo_idx
    ids_before = findall_ids(>(threshold_conn), image_to_array(wings[idx]))
    ids_after = connect_pixel_components(ids_before; connectivity = conn)

    n_before = length(pixel_components(ids_before; connectivity = conn))
    n_after = length(pixel_components(ids_after; connectivity = conn))

    push!(demo_connectivity_df, (
        species[idx],
        n_before,
        n_after,
        length(ids_before),
        length(ids_after),
    ))

    M_before = ids_to_mask(ids_before)
    M_after = ids_to_mask(ids_after)

    p_before = heatmap(
        M_before[end:-1:1, :],
        color = :grays,
        colorbar = false,
        legend = false,
        aspect_ratio = :equal,
        xticks = false,
        yticks = false,
        title = "Before: $(species[idx])\ncomponents = $(n_before)",
    )

    p_after = heatmap(
        M_after[end:-1:1, :],
        color = :grays,
        colorbar = false,
        legend = false,
        aspect_ratio = :equal,
        xticks = false,
        yticks = false,
        title = "After: $(species[idx])\ncomponents = $(n_after)",
    )

    push!(panel_plots, p_before)
    push!(panel_plots, p_after)
end

plot(panel_plots..., layout = (length(demo_idx), 2), size = (900, 260 * length(demo_idx)))

In [8]:
demo_connectivity_df

## Topological feature extraction

We now compute persistent homology using five filtration strategies. For the Vietoris-Rips filtration on point clouds sampled from a connected silhouette, H0 carries little information (all components merge at small scales, leaving a single infinite bar), so we use only H1. For the cubical filtrations (directional, radial, EDT, grayscale), however, **H0 is highly informative**: it captures when disconnected vein segments merge as the filtration parameter grows, directly encoding vein count and branching patterns. We therefore compute both H0 and H1 for all cubical filtrations.

> **What is persistent homology?**
>
> Persistent homology is the main tool of TDA. Given a shape or dataset, it tracks how topological features (connected components in dimension 0, loops in dimension 1, voids in dimension 2, and so on) appear and disappear as we “grow” the shape through a filtration parameter. Each feature has a **birth** time (when it appears) and a **death** time (when it merges into an older feature or gets filled in). The collection of all (birth, death) pairs is called a **persistence diagram**. Features with long lifetimes (high persistence, i.e. death $-$ birth) represent genuine topological structure, while short-lived features are typically noise.
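
The birth and death bookkeeping can be made concrete on a toy example. The sketch below computes H0 sublevel-set persistence of a one-dimensional signal with a union-find structure. The names `find_root` and `sublevel_h0` are hypothetical helpers written for this illustration and are not part of TDAfly; the notebook itself relies on library routines for this step.

```julia
# H0 sublevel-set persistence of a 1-D signal (illustrative sketch, union-find).
function find_root(parent, i)
    while parent[i] != i
        i = parent[i]
    end
    return i
end

function sublevel_h0(f::Vector{Float64})
    n = length(f)
    parent = zeros(Int, n)            # 0 means "not yet in the filtration"
    birth = fill(Inf, n)              # birth value, read at component roots
    pairs = Tuple{Float64,Float64}[]
    for i in sortperm(f)              # add samples by increasing value
        parent[i] = i
        birth[i] = f[i]
        for j in (i - 1, i + 1)       # neighbours on the line
            (1 <= j <= n && parent[j] != 0) || continue
            ri, rj = find_root(parent, i), find_root(parent, j)
            ri == rj && continue
            # Elder rule: the component born later dies at the current value.
            old, young = birth[ri] <= birth[rj] ? (ri, rj) : (rj, ri)
            birth[young] < f[i] && push!(pairs, (birth[young], f[i]))
            parent[young] = old
        end
    end
    push!(pairs, (minimum(f), Inf))   # the oldest component never dies
    return pairs
end

f = [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
sublevel_h0(f)   # [(1.0, 3.0), (2.0, 4.0), (0.0, Inf)]
```

Each local minimum starts a component, and each local maximum merges two of them; the single infinite bar is the component born at the global minimum.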

### Strategy 1: Vietoris-Rips filtration on point clouds

> **Vietoris-Rips filtration**
>
> Given a set of points in $\mathbb{R}^n$, the Vietoris-Rips complex at scale $\varepsilon$ contains a simplex for every subset of points that are pairwise within distance $\varepsilon$. As $\varepsilon$ increases from 0, we obtain a nested sequence of simplicial complexes, the Rips filtration. This is the most common filtration in TDA for point-cloud data. It is computationally expensive (the number of simplices grows rapidly with the number of points), which is why we subsample the point clouds.

We sample 750 points from each wing silhouette using farthest-point sampling (which ensures good coverage of the shape), then compute 1-dimensional Rips persistence:

In [9]:
samples = Vector{Any}(undef, length(Xs))
Threads.@threads for i in eachindex(Xs)
    samples[i] = farthest_points_sample(Xs[i], 750)
end

In [10]:
pds_rips = @showprogress map(samples) do s
    rips_pd_1d(s, cutoff = 5, threshold = 200)
end;
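
For readers unfamiliar with farthest-point sampling, here is a minimal greedy sketch. It assumes points stored as `(x, y)` tuples and an arbitrary starting index; `fps_sketch` is a hypothetical name, and the actual `farthest_points_sample` in TDAfly may differ in data layout and in how it picks the first point.

```julia
# Greedy farthest-point sampling (illustrative sketch, not the TDAfly code).
function fps_sketch(points::Vector{NTuple{2,Float64}}, k::Int; start::Int = 1)
    k = min(k, length(points))
    sqdist(p, q) = (p[1] - q[1])^2 + (p[2] - q[2])^2
    chosen = [start]
    # Squared distance from every point to its nearest chosen point.
    mind = [sqdist(p, points[start]) for p in points]
    while length(chosen) < k
        next = argmax(mind)             # farthest from the current sample
        push!(chosen, next)
        for (i, p) in enumerate(points)
            mind[i] = min(mind[i], sqdist(p, points[next]))
        end
    end
    return points[chosen]
end

pts = [(0.0, 0.0), (1.0, 0.0), (10.0, 0.0), (10.0, 1.0), (5.0, 5.0)]
fps_sketch(pts, 3)   # [(0.0, 0.0), (10.0, 1.0), (5.0, 5.0)]
```

Because every new point maximizes its distance to the points already chosen, the sample spreads over the whole silhouette instead of clustering in dense regions.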

In [11]:
wing_arrays = [convert(Array{Float64}, w) for w in wings]

72-element Vector{Matrix{Float64}} (matrix entries omitted)

### Strategy 2: Directional height filtrations

> **Directional (height) filtrations**
>
> A **height filtration** sweeps a hyperplane across the shape in a chosen direction and tracks topology as the “visible” region grows. For a direction vector $v$, we assign each foreground pixel the value $\langle (i,j), v \rangle$ (its projection onto $v$), then compute sublevel-set persistence. Different directions capture different geometric aspects: a horizontal sweep detects how vein loops are arranged from base to tip, a vertical sweep captures anterior-posterior structure, and diagonal sweeps capture oblique patterns. Using multiple directions enriches the topological signature.

We compute persistence along **eight** directions (every 22.5°, from 0° to 157.5°) to capture finer angular structure of vein branching, including oblique vein angles that four directions would miss. The direction labels are rounded to whole degrees, so `Dir_22°` stands for 22.5°. For each direction, we extract both **H0** (connected component merging = vein branching) and **H1** (loop formation):

In [12]:
angles = range(0, π, length=9)[1:8]
directions = [[sin(θ), cos(θ)] for θ in angles]
direction_names = ["Dir_$(round(Int, rad2deg(θ)))°" for θ in angles]

println("Using $(length(directions)) directions:")
for (name, dir) in zip(direction_names, directions)
    println("  $name: $dir")
end

# H1 persistence (loops) along each of the 8 directions
pds_directional = Dict{String, Vector}()
for (dir, name) in zip(directions, direction_names)
    pds_directional[name] = @showprogress "$name H1" map(wing_arrays) do A
        directional_pd_1d(A, dir)
    end
end

# H0 persistence (connected component merging = vein branching patterns)
pds_directional_h0 = Dict{String, Vector}()
for (dir, name) in zip(directions, direction_names)
    pds_directional_h0[name] = @showprogress "$name H0" map(wing_arrays) do A
        directional_pd_0d(A, dir)
    end
end;

Using 8 directions:
  Dir_0°: [0.0, 1.0]
  Dir_22°: [0.3826834323650898, 0.9238795325112867]
  Dir_45°: [0.7071067811865475, 0.7071067811865476]
  Dir_68°: [0.9238795325112867, 0.38268343236508984]
  Dir_90°: [1.0, 6.123233995736766e-17]
  Dir_112°: [0.9238795325112867, -0.3826834323650897]
  Dir_135°: [0.7071067811865476, -0.7071067811865475]
  Dir_158°: [0.3826834323650899, -0.9238795325112867]
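
To illustrate how a direction vector turns a binary wing mask into filtration values, the sketch below projects each foreground pixel index onto `v`. It assumes background pixels are given the value `Inf` so that they never enter a sublevel set; `height_values` is a hypothetical helper, and the TDAfly functions `directional_pd_0d` and `directional_pd_1d` may normalize or treat the background differently.

```julia
# Height-filtration values for a binary image (illustrative sketch).
# Foreground pixels receive the projection of their (row, column) index onto v.
function height_values(mask::AbstractMatrix{Bool}, v::AbstractVector{<:Real})
    F = fill(Inf, size(mask))          # background never enters the filtration
    for I in CartesianIndices(mask)
        if mask[I]
            F[I] = I[1] * v[1] + I[2] * v[2]
        end
    end
    return F
end

mask = Bool[1 1 0;
            0 1 1]
F = height_values(mask, [0.0, 1.0])    # sweep along the column axis
# F == [1.0 2.0 Inf; Inf 2.0 3.0]
```

With `v = [0.0, 1.0]` (the `Dir_0°` vector printed above) the value depends only on the column index, so the sweep runs along the horizontal image axis.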

### Strategy 3: Radial filtration

> **Radial filtration**
>
> The **radial filtration** assigns each foreground pixel a value equal to its distance from the centroid of the wing. Sublevel-set persistence on this function captures how topological features (loops in the venation) are distributed from the center of the wing outward. This is complementary to the directional filtrations.

In [13]:
pds_radial = @showprogress "radial_pd_1d" map(wing_arrays) do A
    radial_pd_1d(A)
end;

We also compute H0 persistence for the radial filtration, capturing how disconnected vein segments merge as the radial sweep grows outward:

In [14]:
pds_radial_h0 = @showprogress "radial_pd_0d" map(wing_arrays) do A
    radial_pd_0d(A)
end;
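
The radial filtration function itself is simple to write down. The sketch below assumes the centroid is the mean index of the foreground pixels and, as an illustration choice, gives background pixels the value `Inf`; `radial_values` is a hypothetical name and not the TDAfly implementation.

```julia
using Statistics: mean

# Radial filtration values: distance of each foreground pixel to the centroid
# of the foreground (illustrative sketch).
function radial_values(mask::AbstractMatrix{Bool})
    fg = findall(mask)                 # CartesianIndex of each foreground pixel
    ci = mean(I[1] for I in fg)        # centroid row
    cj = mean(I[2] for I in fg)        # centroid column
    F = fill(Inf, size(mask))
    for I in fg
        F[I] = hypot(I[1] - ci, I[2] - cj)
    end
    return F
end

mask = Bool[1 0 1;
            0 0 0;
            1 0 1]
F = radial_values(mask)   # the four corners all sit at distance sqrt(2)
```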

### Strategy 4: Euclidean Distance Transform (EDT) filtration

> **EDT filtration**
>
> The **Euclidean Distance Transform** assigns each foreground pixel the distance to the nearest background pixel, so pixels at the center of thick veins get high EDT values. Using the negated EDT as the filtration value makes thick veins appear first in the sublevel-set filtration. This captures the **vein thickness hierarchy**, a diagnostic taxonomic character (e.g., Tabanidae have thickened costal and subcostal veins).

In [15]:
pds_edt_h1 = @showprogress "EDT H1" map(wing_arrays) do A
    edt_pd_1d(A)
end

pds_edt_h0 = @showprogress "EDT H0" map(wing_arrays) do A
    edt_pd_0d(A)
end;
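
A brute-force version of the distance transform shows what the filtration values look like. This is only a sketch with quadratic cost, assuming the mask contains at least one background pixel; `edt_bruteforce` is a hypothetical helper, and production code would use a linear-time algorithm instead.

```julia
# Brute-force Euclidean Distance Transform (illustrative sketch, O(N^2)).
function edt_bruteforce(mask::AbstractMatrix{Bool})
    bg = findall(!, mask)              # background pixels
    D = zeros(Float64, size(mask))
    for I in findall(mask)
        D[I] = minimum(hypot(I[1] - J[1], I[2] - J[2]) for J in bg)
    end
    return D
end

mask = Bool[0 0 0 0 0;
            0 1 1 1 0;
            0 1 1 1 0;
            0 1 1 1 0;
            0 0 0 0 0]
D = edt_bruteforce(mask)   # D[3, 3] == 2.0, the other foreground pixels are 1.0
F = -D                     # negated EDT: the thickest part enters first
```

In the sublevel-set filtration of `F`, the central pixel (value `-2.0`) appears before the surrounding ring (value `-1.0`), which is the "thick structures first" ordering described above.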

### Strategy 5: Cubical (grayscale sublevel-set) persistence

> **Grayscale sublevel-set persistence**
>
> The function `cubical_pd` computes sublevel-set persistence directly on the grayscale wing image (inverted so that dark veins have low filtration values). This captures the intensity landscape of the wing image without any thresholding, preserving information about semi-transparent wing membrane regions and vein intensity gradients.

In [16]:
pds_cubical = @showprogress "Cubical" map(wing_arrays) do A
    cubical_pd(A; dim_max=1)
end

pds_cubical_h0 = [pd[1] for pd in pds_cubical]
pds_cubical_h1 = [pd[2] for pd in pds_cubical];

### Persistence vectorization

Raw persistence diagrams live in a space that is not directly amenable to standard machine learning. We vectorize them using three approaches:

> **Persistence images**
>
> A **persistence image** is a stable, finite-dimensional representation of a persistence diagram. Each point $(b, d)$ is mapped to birth-persistence coordinates $(b, d - b)$ and weighted by a function that emphasizes long-lived features; the weighted points are then smoothed with a Gaussian kernel and discretized onto a grid. The result is a matrix (image) that can be flattened into a feature vector. Persistence images are stable with respect to the 1-Wasserstein distance and have proven effective in machine learning pipelines.
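
The construction can be sketched in a few lines. The version below makes several simplifying assumptions: the weight is linear in persistence, the Gaussian is evaluated at grid points rather than integrated over pixels, and the grid (`xs`, `ys`) and bandwidth `sigma` are chosen by hand. `persistence_image_sketch` is a hypothetical name; the `PersistenceImage` type from PersistenceDiagrams.jl used in the notebook chooses these settings itself and may differ in detail.

```julia
# Persistence image (illustrative sketch, not the PersistenceDiagrams.jl code).
# diagram: iterable of (birth, death) pairs with finite deaths.
function persistence_image_sketch(diagram; xs, ys, sigma = 0.5)
    img = zeros(length(ys), length(xs))
    for (b, d) in diagram
        p = d - b                          # (birth, persistence) coordinates
        for (c, x) in enumerate(xs), (r, y) in enumerate(ys)
            # Gaussian bump centred at (b, p), weighted by the persistence p.
            img[r, c] += p * exp(-((x - b)^2 + (y - p)^2) / (2 * sigma^2))
        end
    end
    return img
end

diagram = [(0.0, 2.0), (1.0, 1.5)]
img = persistence_image_sketch(diagram; xs = 0.0:0.5:2.0, ys = 0.0:0.5:2.0)
vec(img)    # flattened into a feature vector of length 25
```

The long-lived point `(0.0, 2.0)` dominates the image, while the short-lived point `(1.0, 1.5)` contributes only a faint bump near the bottom row.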

> **Betti curves**
>
> The **Betti curve** $\beta_k(t)$ counts the number of $k$-dimensional features alive at filtration value $t$. For dimension 1, it counts the number of loops present at each scale. Discretized over a grid, it produces a feature vector. Betti curves are simple, interpretable, and capture the “topological complexity” of the shape at each scale.
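
As a minimal sketch, a Betti curve can be evaluated pointwise on a grid by counting the intervals that contain each grid value. `betti_curve` is a hypothetical helper; the `BettiCurve` type from PersistenceDiagrams.jl used in the notebook may bin or average differently.

```julia
# Betti curve on a grid (illustrative sketch): count intervals [b, d) containing t.
betti_curve(diagram, ts) = [count(((b, d),) -> b <= t < d, diagram) for t in ts]

diagram = [(0.0, 2.0), (1.0, 3.0), (1.5, 1.75)]
betti_curve(diagram, 0.0:0.5:3.0)   # [1, 1, 2, 3, 1, 1, 0]
```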

> **Persistence landscapes**
>
> A **persistence landscape** is a sequence of piecewise-linear functions derived from a persistence diagram. Each interval $(b, d)$ defines a tent function $t \mapsto \max(0, \min(t - b, d - t))$, and the $k$-th landscape $\lambda_k(t)$ is the $k$-th largest of these tent values at $t$. Landscapes live in a Banach space, which means we can compute means, perform hypothesis tests, and use them directly in statistical and machine learning methods. They provide a richer representation than Betti curves.
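
The landscape definition translates almost directly into code. The sketch below evaluates the $k$-th landscape on a grid by sorting tent-function values; `tent` and `landscape_sketch` are hypothetical helpers written for illustration, whereas the notebook uses the `Landscape` type from PersistenceDiagrams.jl.

```julia
# Persistence landscape on a grid (illustrative sketch).
# Each interval (b, d) defines a tent: max(0, min(t - b, d - t)).
tent(b, d, t) = max(0.0, min(t - b, d - t))

function landscape_sketch(diagram, k::Int, ts)
    map(ts) do t
        heights = sort([tent(b, d, t) for (b, d) in diagram], rev = true)
        k <= length(heights) ? heights[k] : 0.0   # k-th largest tent value at t
    end
end

diagram = [(0.0, 4.0), (1.0, 3.0)]
ts = 0.0:1.0:4.0
λ1 = landscape_sketch(diagram, 1, ts)   # [0.0, 1.0, 2.0, 1.0, 0.0]
λ2 = landscape_sketch(diagram, 2, ts)   # [0.0, 0.0, 1.0, 0.0, 0.0]
```

The first landscape follows the outer tent, and the second picks up the nested interval, which is why higher landscapes carry information about features hidden under the dominant one.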

In [17]:
# Vectorize Rips persistence
PI_rips = PersistenceImage(pds_rips, size = (15, 15))
pi_rips = PI_rips.(pds_rips)

bc_rips = BettiCurve(pds_rips; length = 50)
betti_rips = bc_rips.(pds_rips)

land1_rips = Landscape(1, pds_rips; length = 50)
land2_rips = Landscape(2, pds_rips; length = 50)
land1_rips_vecs = land1_rips.(pds_rips)
land2_rips_vecs = land2_rips.(pds_rips);

In [18]:
# Vectorize directional persistence (H1)
pi_directional = Dict{String, Vector}()
betti_directional = Dict{String, Vector}()

for name in direction_names
    pds = pds_directional[name]
    PI_d = PersistenceImage(pds, size = (10, 10))
    pi_directional[name] = PI_d.(pds)

    bc_d = BettiCurve(pds; length = 30)
    betti_directional[name] = bc_d.(pds)
end

# Vectorize directional persistence (H0)
pi_directional_h0 = Dict{String, Vector}()
betti_directional_h0 = Dict{String, Vector}()

for name in direction_names
    pds = pds_directional_h0[name]
    # Filter out infinite intervals for vectorization
    pds_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds]
    if any(!isempty, pds_finite)
        PI_d0 = PersistenceImage(pds_finite, size = (10, 10))
        pi_directional_h0[name] = PI_d0.(pds_finite)
        bc_d0 = BettiCurve(pds_finite; length = 30)
        betti_directional_h0[name] = bc_d0.(pds_finite)
    else
        pi_directional_h0[name] = [zeros(10 * 10) for _ in pds]
        betti_directional_h0[name] = [zeros(30) for _ in pds]
    end
end

# Radial H1
PI_rad = PersistenceImage(pds_radial, size = (10, 10))
pi_radial = PI_rad.(pds_radial)

bc_rad = BettiCurve(pds_radial; length = 30)
betti_radial = bc_rad.(pds_radial)

# Radial H0
pds_radial_h0_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds_radial_h0]
if any(!isempty, pds_radial_h0_finite)
    PI_rad_h0 = PersistenceImage(pds_radial_h0_finite, size = (10, 10))
    pi_radial_h0 = PI_rad_h0.(pds_radial_h0_finite)
    bc_rad_h0 = BettiCurve(pds_radial_h0_finite; length = 30)
    betti_radial_h0 = bc_rad_h0.(pds_radial_h0_finite)
else
    pi_radial_h0 = [zeros(10 * 10) for _ in pds_radial_h0]
    betti_radial_h0 = [zeros(30) for _ in pds_radial_h0]
end

# EDT H1
if any(!isempty, pds_edt_h1)
    PI_edt_h1 = PersistenceImage(pds_edt_h1, size = (10, 10))
    pi_edt_h1 = PI_edt_h1.(pds_edt_h1)
    bc_edt_h1 = BettiCurve(pds_edt_h1; length = 30)
    betti_edt_h1 = bc_edt_h1.(pds_edt_h1)
else
    pi_edt_h1 = [zeros(10 * 10) for _ in pds_edt_h1]
    betti_edt_h1 = [zeros(30) for _ in pds_edt_h1]
end

# EDT H0
pds_edt_h0_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds_edt_h0]
if any(!isempty, pds_edt_h0_finite)
    PI_edt_h0 = PersistenceImage(pds_edt_h0_finite, size = (10, 10))
    pi_edt_h0 = PI_edt_h0.(pds_edt_h0_finite)
    bc_edt_h0 = BettiCurve(pds_edt_h0_finite; length = 30)
    betti_edt_h0 = bc_edt_h0.(pds_edt_h0_finite)
else
    pi_edt_h0 = [zeros(10 * 10) for _ in pds_edt_h0]
    betti_edt_h0 = [zeros(30) for _ in pds_edt_h0]
end

# Cubical H1, then H0
if any(!isempty, pds_cubical_h1)
    PI_cub_h1 = PersistenceImage(pds_cubical_h1, size = (10, 10))
    pi_cubical_h1 = PI_cub_h1.(pds_cubical_h1)
    bc_cub_h1 = BettiCurve(pds_cubical_h1; length = 30)
    betti_cubical_h1 = bc_cub_h1.(pds_cubical_h1)
else
    pi_cubical_h1 = [zeros(10 * 10) for _ in pds_cubical_h1]
    betti_cubical_h1 = [zeros(30) for _ in pds_cubical_h1]
end

pds_cubical_h0_finite = [filter(x -> isfinite(persistence(x)), pd) for pd in pds_cubical_h0]
if any(!isempty, pds_cubical_h0_finite)
    PI_cub_h0 = PersistenceImage(pds_cubical_h0_finite, size = (10, 10))
    pi_cubical_h0 = PI_cub_h0.(pds_cubical_h0_finite)
    bc_cub_h0 = BettiCurve(pds_cubical_h0_finite; length = 30)
    betti_cubical_h0 = bc_cub_h0.(pds_cubical_h0_finite)
else
    pi_cubical_h0 = [zeros(10 * 10) for _ in pds_cubical_h0]
    betti_cubical_h0 = [zeros(30) for _ in pds_cubical_h0]
end

println("Vectorization complete.")
println("  Directional: $(length(direction_names)) directions × H0 + H1")
println("  Radial: H0 + H1")
println("  EDT: H0 + H1")
println("  Cubical: H0 + H1");

Vectorization complete.
  Directional: 8 directions × H0 + H1
  Radial: H0 + H1
  EDT: H0 + H1
  Cubical: H0 + H1

### Examples

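Below, for one specimen per family, we plot the sorted persistence values of the H1 diagrams from the Rips, directional (0°), radial and EDT filtrations, together with the sampled point cloud and the finite H0 persistence values of the 0° directional filtration: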

In [19]:
example_indices = [findfirst(==(f), families) for f in sort(unique(families))]

for i in example_indices
    pers_rips = persistence.(pds_rips[i])
    pers_dir = persistence.(pds_directional[direction_names[1]][i])
    pers_rad = persistence.(pds_radial[i])
    pers_edt = persistence.(pds_edt_h1[i])

    p1 = isempty(pers_rips) ? plot(title = "Rips H₁ (empty)") :
         bar(sort(pers_rips, rev = true), title = "Rips H₁", legend = false, ylabel = "persistence")
    p2 = isempty(pers_dir) ? plot(title = "Dir 0° H₁ (empty)") :
         bar(sort(pers_dir, rev = true), title = "Dir 0° H₁", legend = false, ylabel = "persistence")
    p3 = isempty(pers_rad) ? plot(title = "Radial H₁ (empty)") :
         bar(sort(pers_rad, rev = true), title = "Radial H₁", legend = false, ylabel = "persistence")
    p4 = isempty(pers_edt) ? plot(title = "EDT H₁ (empty)") :
         bar(sort(pers_edt, rev = true), title = "EDT H₁", legend = false, ylabel = "persistence")
    p5 = scatter(last.(samples[i]), first.(samples[i]),
                 aspect_ratio = :equal, markersize = 1, legend = false, title = "Point cloud")

    # H0 examples
    pers_dir_h0 = [persistence(x) for x in pds_directional_h0[direction_names[1]][i] if isfinite(persistence(x))]
    p6 = isempty(pers_dir_h0) ? plot(title = "Dir 0° H₀ (empty)") :
         bar(sort(pers_dir_h0, rev = true), title = "Dir 0° H₀", legend = false, ylabel = "persistence")

    p = plot(p1, p2, p3, p4, p5, p6, layout = (3, 2), size = (900, 900),
             plot_title = "$(families[i]) ($(individuals[i]))")
    display(p)
end;

### Summary statistics

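We extract 11 summary statistics from each persistence diagram (interval count, maximum and total persistence, persistence quantiles, persistence entropy and standard deviation; see `stat_names`):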

In [20]:
stat_names = ["count", "max_pers", "total_pers", "total_pers2",
              "q10", "q25", "median", "q75", "q90", "entropy", "std_pers"]

stats_rips = collect(hcat([pd_statistics(pd) for pd in pds_rips]...)')

# Directional H1 stats
stats_directional = Dict{String, Matrix}()
for name in direction_names
    stats_directional[name] = collect(hcat([pd_statistics(pd) for pd in pds_directional[name]]...)')
end

# Directional H0 stats
stats_directional_h0 = Dict{String, Matrix}()
for name in direction_names
    stats_directional_h0[name] = collect(hcat([pd_statistics(pd) for pd in pds_directional_h0[name]]...)')
end

stats_radial = collect(hcat([pd_statistics(pd) for pd in pds_radial]...)')
stats_radial_h0 = collect(hcat([pd_statistics(pd) for pd in pds_radial_h0]...)')

# EDT stats
stats_edt_h1 = collect(hcat([pd_statistics(pd) for pd in pds_edt_h1]...)')
stats_edt_h0 = collect(hcat([pd_statistics(pd) for pd in pds_edt_h0]...)')

# Cubical stats
stats_cubical_h1 = collect(hcat([pd_statistics(pd) for pd in pds_cubical_h1]...)')
stats_cubical_h0 = collect(hcat([pd_statistics(pd) for pd in pds_cubical_h0]...)')

# Combined stats from ALL filtrations (H1 only)
stats_all_h1 = hcat(stats_rips, stats_radial,
                    [stats_directional[name] for name in direction_names]...,
                    stats_edt_h1, stats_cubical_h1)

# Combined stats from ALL filtrations (H0 + H1 = comprehensive)
stats_all = hcat(
    stats_rips,
    stats_radial, stats_radial_h0,
    [stats_directional[name] for name in direction_names]...,
    [stats_directional_h0[name] for name in direction_names]...,
    stats_edt_h1, stats_edt_h0,
    stats_cubical_h1, stats_cubical_h0,
)

println("Statistics dimensions:")
println("  Rips H1: $(size(stats_rips))")
println("  All H1 only: $(size(stats_all_h1))")
println("  All H0+H1 (comprehensive): $(size(stats_all))")

Statistics dimensions:
  Rips H1: (72, 11)
  All H1 only: (72, 132)
  All H0+H1 (comprehensive): (72, 253)
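
To make the statistics concrete, the sketch below computes four of them for a toy diagram. It assumes that infinite bars are dropped and that entropy is the Shannon entropy of the normalized persistence values; `pd_stats_sketch` is a hypothetical helper, and the `pd_statistics` function used above may handle these cases differently.

```julia
# A few persistence summary statistics (illustrative sketch).
function pd_stats_sketch(diagram)
    pers = [d - b for (b, d) in diagram if isfinite(d)]   # finite bars only
    isempty(pers) && return (count = 0, max_pers = 0.0, total_pers = 0.0, entropy = 0.0)
    total = sum(pers)
    probs = pers ./ total
    # Persistence entropy: Shannon entropy of the normalized persistence values.
    entropy = -sum(p * log(p) for p in probs if p > 0)
    return (count = length(pers), max_pers = maximum(pers),
            total_pers = total, entropy = entropy)
end

s = pd_stats_sketch([(0.0, 1.0), (0.0, 1.0), (2.0, 4.0), (0.5, Inf)])
# s.count == 3, s.max_pers == 2.0, s.total_pers == 4.0, s.entropy ≈ 1.5 * log(2)
```

The printed dimensions follow from the block structure: 11 statistics for each of 12 H1 diagrams per wing (Rips, radial, 8 directions, EDT, cubical) give 132 columns, and adding the 11 H0 diagrams (radial, 8 directions, EDT, cubical) gives 23 × 11 = 253.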

#### Statistics comparison by family

In [21]:
stats_df = DataFrame(
    sample = individuals,
    family = families,
    n_intervals_rips = stats_rips[:, 1],
    max_pers_rips = stats_rips[:, 2],
    entropy_rips = stats_rips[:, 10],
    n_intervals_rad = stats_radial[:, 1],
    max_pers_rad = stats_radial[:, 2],
    entropy_rad = stats_radial[:, 10],
    n_intervals_edt_h1 = stats_edt_h1[:, 1],
    max_pers_edt_h1 = stats_edt_h1[:, 2],
    n_intervals_cub_h1 = stats_cubical_h1[:, 1],
    max_pers_cub_h1 = stats_cubical_h1[:, 2],
)

p1 = boxplot(stats_df.family, stats_df.n_intervals_rips,
             title = "Rips: # intervals", legend = false, ylabel = "count", xrotation = 45)
p2 = boxplot(stats_df.family, stats_df.max_pers_rips,
             title = "Rips: max persistence", legend = false, ylabel = "persistence", xrotation = 45)
p3 = boxplot(stats_df.family, stats_df.n_intervals_edt_h1,
             title = "EDT H₁: # intervals", legend = false, ylabel = "count", xrotation = 45)
p4 = boxplot(stats_df.family, stats_df.max_pers_edt_h1,
             title = "EDT H₁: max persistence", legend = false, ylabel = "persistence", xrotation = 45)
plot(p1, p2, p3, p4, layout = (2, 2), size = (1000, 700))

### Decision tree on rich 1D persistence statistics

We now use the full set of persistence statistics from all filtrations, H0 and H1 combined (`stats_all`), as features for a single decision tree classifier.

In [22]:
using DecisionTree
using Random: MersenneTwister

labels_tree = families
X_tree = sanitize_feature_matrix(stats_all)

tree_blocks_h1 = ["Rips_H1", "Radial_H1", [name * "_H1" for name in direction_names]..., "EDT_H1", "Cubical_H1"]
tree_blocks_h0 = ["Radial_H0", [name * "_H0" for name in direction_names]..., "EDT_H0", "Cubical_H0"]
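# Block names must follow the column order of `stats_all` exactly,
# otherwise the importance table attributes scores to the wrong features.
tree_blocks = [
    "Rips_H1",
    "Radial_H1", "Radial_H0",
    [name * "_H1" for name in direction_names]...,
    [name * "_H0" for name in direction_names]...,
    "EDT_H1", "EDT_H0",
    "Cubical_H1", "Cubical_H0",
]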
tree_feature_names = ["$(block)__$(stat)" for block in tree_blocks for stat in stat_names]

function loocv_decision_tree(X::Matrix, y::Vector{String};
                             max_depth::Int = 6,
                             min_samples_leaf::Int = 2,
                             min_samples_split::Int = 2,
                             rng_seed::Int = 20260223)
    Xclean = sanitize_feature_matrix(X)
    n = size(Xclean, 1)
    predictions = Vector{String}(undef, n)

    for i in 1:n
        train_idx = setdiff(1:n, i)
        X_train = Xclean[train_idx, :]
        y_train = y[train_idx]

        tree = DecisionTree.build_tree(
            y_train,
            X_train,
            size(X_train, 2),
            max_depth,
            min_samples_leaf,
            min_samples_split,
            0.0;
            loss = DecisionTree.util.gini,
            rng = MersenneTwister(rng_seed + i),
            impurity_importance = true
        )

        predictions[i] = DecisionTree.apply_tree(tree, Xclean[i, :])
    end

    (accuracy = mean(predictions .== y), predictions = predictions)
end

tree_results = DataFrame(
    max_depth = Int[],
    min_samples_leaf = Int[],
    min_samples_split = Int[],
    n_correct = Int[],
    accuracy = Float64[],
    balanced_accuracy = Float64[],
    macro_f1 = Float64[],
)

for max_depth in [3, 4, 5, 6, 8]
    for min_leaf in [1, 2, 3]
        for min_split in [2, 4]
            r = loocv_decision_tree(X_tree, labels_tree;
                                    max_depth = max_depth,
                                    min_samples_leaf = min_leaf,
                                    min_samples_split = min_split)
            m = classification_metrics(labels_tree, r.predictions)
            push!(tree_results, (
                max_depth,
                min_leaf,
                min_split,
                sum(r.predictions .== labels_tree),
                r.accuracy,
                m.balanced_accuracy,
                m.macro_f1
            ))
        end
    end
end

sort!(tree_results, :accuracy, rev = true)
first(tree_results, 10)

In [23]:
best_tree = tree_results[1, :]

tree_model = DecisionTree.build_tree(
    labels_tree,
    X_tree,
    size(X_tree, 2),
    best_tree.max_depth,
    best_tree.min_samples_leaf,
    best_tree.min_samples_split,
    0.0;
    loss = DecisionTree.util.gini,
    rng = MersenneTwister(20260223),
    impurity_importance = true
)

tree_importance = DecisionTree.impurity_importance(tree_model; normalize = true)

tree_importance_df = DataFrame(
    feature = tree_feature_names,
    importance = tree_importance
)
sort!(tree_importance_df, :importance, rev = true)

println("Best Decision Tree LOOCV: $(best_tree.n_correct)/$(length(labels_tree)) ($(round(best_tree.accuracy * 100, digits = 1))%)")
println("Balanced accuracy: $(round(best_tree.balanced_accuracy * 100, digits = 1))%")
println("Macro-F1: $(round(best_tree.macro_f1 * 100, digits = 1))%")

first(filter(:importance => >(0.0), tree_importance_df), 15)

Best Decision Tree LOOCV: 38/72 (52.8%)
Balanced accuracy: 46.0%
Macro-F1: 45.6%
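
Balanced accuracy and macro-F1 are reported because the families are unequal in size. The sketch below shows one common way to compute them (mean per-class recall and mean per-class F1). `metrics_sketch` is a hypothetical helper, and the `classification_metrics` function used above may treat edge cases such as never-predicted classes differently.

```julia
# Balanced accuracy (mean per-class recall) and macro-F1 (illustrative sketch).
function metrics_sketch(y_true::Vector{String}, y_pred::Vector{String})
    classes = sort(unique(y_true))
    recalls, f1s = Float64[], Float64[]
    for c in classes
        tp = count((y_true .== c) .& (y_pred .== c))
        fn = count((y_true .== c) .& (y_pred .!= c))
        fp = count((y_true .!= c) .& (y_pred .== c))
        push!(recalls, tp / (tp + fn))
        push!(f1s, tp == 0 ? 0.0 : 2tp / (2tp + fp + fn))
    end
    return (balanced_accuracy = sum(recalls) / length(recalls),
            macro_f1 = sum(f1s) / length(f1s))
end

y_true = ["A", "A", "A", "A", "B", "B"]
y_pred = ["A", "A", "A", "A", "A", "B"]
m = metrics_sketch(y_true, y_pred)
# m.balanced_accuracy == 0.75, m.macro_f1 ≈ 7/9 (plain accuracy would be 5/6)
```

A single error in the two-sample class halves that class's recall, which pulls balanced accuracy well below plain accuracy; the same effect explains the gap between 52.8% and 46.0% in the output above.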

In [24]:
topk = min(12, nrow(tree_importance_df))
top_tree_importance = first(tree_importance_df, topk)

bar(
    top_tree_importance.feature,
    top_tree_importance.importance,
    xlabel = "1D persistence statistics",
    ylabel = "Normalized impurity importance",
    title = "Decision tree feature importance (top $(topk))",
    legend = false,
    xrotation = 45,
    size = (1100, 550),
)

## Distance matrices

> **Distances between persistence diagrams**
>
> The **Wasserstein distance** $W_q$ between two persistence diagrams is the cost of the optimal matching between their points (including matching points to the diagonal, which represents trivial features). With $q=1$ it equals the Earth Mover’s Distance; with $q=2$ it penalizes large mismatches more. The **Bottleneck distance** $d_B$ is the $\ell^\infty$ version: it measures the worst single mismatch in the optimal pairing. These distances are metrics on the space of persistence diagrams and are stable with respect to perturbations of the input data.
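
For very small diagrams, both distances can be computed by brute force, which makes the role of the diagonal explicit. The sketch below assumes the $\ell^\infty$ ground metric between points and enumerates all matchings, so its cost grows factorially; all function names are hypothetical. The notebook's `wasserstein_distance_matrix` and `bottleneck_distance_matrix` rely on proper matching algorithms instead.

```julia
# Brute-force Wasserstein-q and bottleneck distances for tiny diagrams
# (illustrative sketch only; assumes the L-infinity ground metric).
linf(p, q) = max(abs(p[1] - q[1]), abs(p[2] - q[2]))
to_diag(p) = (p[2] - p[1]) / 2             # L-infinity distance to the diagonal

function cost_matrix(X, Y)
    n, m = length(X), length(Y)
    C = zeros(n + m, n + m)                # diagonal-to-diagonal matches cost 0
    for i in 1:n, j in 1:m
        C[i, j] = linf(X[i], Y[j])
    end
    for i in 1:n
        C[i, m+1:end] .= to_diag(X[i])     # X[i] matched to the diagonal
    end
    for j in 1:m
        C[n+1:end, j] .= to_diag(Y[j])     # Y[j] matched to the diagonal
    end
    return C
end

# Minimize `combine` of the matched costs over all perfect matchings.
function best_matching(C, combine, row = 1, used = falses(size(C, 1)), acc = 0.0)
    row > size(C, 1) && return acc
    best = Inf
    for j in axes(C, 2)
        used[j] && continue
        used[j] = true
        best = min(best, best_matching(C, combine, row + 1, used, combine(acc, C[row, j])))
        used[j] = false
    end
    return best
end

wasserstein_sketch(X, Y; q = 1) = best_matching(cost_matrix(X, Y) .^ q, +)^(1 / q)
bottleneck_sketch(X, Y) = best_matching(cost_matrix(X, Y), max)

X = [(0.0, 4.0), (1.0, 2.0)]
Y = [(0.0, 5.0)]
wasserstein_sketch(X, Y), bottleneck_sketch(X, Y)   # (1.5, 1.0)
```

In the optimal matching, `(0.0, 4.0)` pairs with `(0.0, 5.0)` at cost 1 and the short bar `(1.0, 2.0)` goes to the diagonal at cost 0.5, so $W_1$ sums the two costs while $d_B$ keeps only the larger one.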

We compute multiple distance metrics between the persistence diagrams from each filtration:

In [25]:
labels = families

# Rips-based distances (Rips PDs have ~20 intervals, so Wasserstein/Bottleneck are feasible)
D_pi_rips = pairwise_distance([vec(v) for v in pi_rips])
D_betti_rips = pairwise_distance(betti_rips, euclidean)
D_land_rips = pairwise_distance(land1_rips_vecs, euclidean)
D_wass1_rips = wasserstein_distance_matrix(pds_rips, q = 1)
D_wass2_rips = wasserstein_distance_matrix(pds_rips, q = 2)
D_bott_rips = bottleneck_distance_matrix(pds_rips)

# Directional/radial PDs have hundreds of intervals, so
# Wasserstein/Bottleneck would be prohibitively slow on 72×72 pairs.
# We use only vectorized distances (persistence images, Betti curves, landscapes),
# which are Euclidean distances on fixed-size feature vectors and compute instantly.

# Directional H1 distances (combine all directions via sum of per-direction distances)
D_pi_dir = sum(pairwise_distance([vec(v) for v in pi_directional[name]]) for name in direction_names)
D_betti_dir = sum(pairwise_distance(betti_directional[name], euclidean) for name in direction_names)

# Directional H0 distances (sum of per-direction distances, as for H1)
D_pi_dir_h0 = sum(pairwise_distance([vec(v) for v in pi_directional_h0[name]]) for name in direction_names)
D_betti_dir_h0 = sum(pairwise_distance(betti_directional_h0[name], euclidean) for name in direction_names)

# Radial distances
D_pi_rad = pairwise_distance([vec(v) for v in pi_radial])
D_betti_rad = pairwise_distance(betti_radial, euclidean)

# Radial H0 distances
D_pi_rad_h0 = pairwise_distance([vec(v) for v in pi_radial_h0])
D_betti_rad_h0 = pairwise_distance(betti_radial_h0, euclidean)

# EDT (Euclidean Distance Transform) distances
D_pi_edt_h1 = pairwise_distance([vec(v) for v in pi_edt_h1])
D_betti_edt_h1 = pairwise_distance(betti_edt_h1, euclidean)
D_pi_edt_h0 = pairwise_distance([vec(v) for v in pi_edt_h0])
D_betti_edt_h0 = pairwise_distance(betti_edt_h0, euclidean)

# Cubical (grayscale sublevel-set) distances
D_pi_cub_h1 = pairwise_distance([vec(v) for v in pi_cubical_h1])
D_betti_cub_h1 = pairwise_distance(betti_cubical_h1, euclidean)
D_pi_cub_h0 = pairwise_distance([vec(v) for v in pi_cubical_h0])
D_betti_cub_h0 = pairwise_distance(betti_cubical_h0, euclidean)

distances = Dict(
    # Rips
    "Rips PI" => D_pi_rips,
    "Rips Bottleneck" => D_bott_rips,
    "Rips Wass-1" => D_wass1_rips,
    "Rips Wass-2" => D_wass2_rips,
    "Rips Betti" => D_betti_rips,
    "Rips Landscape" => D_land_rips,
    # Directional H1
    "Directional H1 PI" => D_pi_dir,
    "Directional H1 Betti" => D_betti_dir,
    # Directional H0
    "Directional H0 PI" => D_pi_dir_h0,
    "Directional H0 Betti" => D_betti_dir_h0,
    # Radial
    "Radial H1 PI" => D_pi_rad,
    "Radial H1 Betti" => D_betti_rad,
    "Radial H0 PI" => D_pi_rad_h0,
    "Radial H0 Betti" => D_betti_rad_h0,
    # EDT
    "EDT H1 PI" => D_pi_edt_h1,
    "EDT H1 Betti" => D_betti_edt_h1,
    "EDT H0 PI" => D_pi_edt_h0,
    "EDT H0 Betti" => D_betti_edt_h0,
    # Cubical
    "Cubical H1 PI" => D_pi_cub_h1,
    "Cubical H1 Betti" => D_betti_cub_h1,
    "Cubical H0 PI" => D_pi_cub_h0,
    "Cubical H0 Betti" => D_betti_cub_h0,
);

In [26]:
p1 = plot_heatmap(D_wass1_rips, individuals, "Rips Wasserstein-1")
p2 = plot_heatmap(D_pi_dir_h0, individuals, "Directional H0 PI")
p3 = plot_heatmap(D_pi_edt_h1, individuals, "EDT H1 PI")
p4 = plot_heatmap(D_pi_cub_h1, individuals, "Cubical H1 PI")
plot(p1, p2, p3, p4, layout = (2, 2), size = (1000, 900))

## Classification

> **Leave-one-out cross-validation (LOOCV)**
>
> With only 72 samples, we use **leave-one-out cross-validation**: for each sample, the classifier is trained on all other samples and tested on the held-out one. The accuracy is the fraction of correctly predicted labels across all 72 folds. LOOCV has low bias (nearly the entire dataset is used for training) and is a common validation strategy for small datasets, although its estimates can have high variance and become optimistic once hyperparameters are tuned on the same folds.
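
The LOOCV loop itself is short. The sketch below runs it with a 1-nearest-neighbor rule on a precomputed distance matrix; `loocv_1nn` and the toy data are hypothetical and only illustrate the procedure, not the TDAfly functions used in the cells that follow.

```julia
# Minimal LOOCV sketch on a precomputed distance matrix (1-NN for simplicity).
# `loocv_1nn` is a hypothetical helper, not part of TDAfly.
function loocv_1nn(D::AbstractMatrix{<:Real}, labels::AbstractVector)
    n = length(labels)
    @assert size(D) == (n, n)
    predictions = similar(labels)
    for i in 1:n
        # "Train" on every sample except i: find the closest other sample.
        best_j, best_d = 0, Inf
        for j in 1:n
            if j != i && D[i, j] < best_d
                best_j, best_d = j, D[i, j]
            end
        end
        predictions[i] = labels[best_j]
    end
    accuracy = count(predictions .== labels) / n
    return (predictions = predictions, accuracy = accuracy)
end

# Toy data: four points on a line, two well-separated classes.
x_toy = [0.0, 0.1, 1.0, 1.2]
D_toy = [abs(a - b) for a in x_toy, b in x_toy]
toy_labels = ["A", "A", "B", "B"]
toy_result = loocv_1nn(D_toy, toy_labels)
```

With a distance-based classifier there is no model to refit: "training on all other samples" simply means excluding row and column `i` from the neighbor search.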

### Distance-based classifiers: k-NN

> **k-Nearest Neighbors (k-NN)**
>
> Given a precomputed distance matrix, **k-NN** classifies a query point by majority vote among its $k$ nearest neighbors. **Weighted k-NN** weights each neighbor’s vote by $1/d$ (inverse distance), giving closer neighbors more influence. The **nearest centroid** classifier assigns the query to the class whose average distance to the query is smallest. These are nonparametric methods that work directly with any distance or dissimilarity measure — making them natural for TDA, where we have principled distances between persistence diagrams.
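
To make the difference between plain and inverse-distance voting concrete, here is a small sketch. `knn_predict` is a hypothetical helper; the small constant added to the distance and the arbitrary tie-breaking are assumptions of this sketch, and `loocv_knn` / `loocv_knn_weighted` may behave differently.

```julia
# Hypothetical helper: predict the label of sample `i` from row `i` of a
# precomputed distance matrix, using its k nearest neighbours (excluding i).
function knn_predict(D, labels, i; k = 3, weighted = false)
    others = [j for j in eachindex(labels) if j != i]
    nearest = sort(others, by = j -> D[i, j])[1:min(k, length(others))]
    votes = Dict{eltype(labels), Float64}()
    for j in nearest
        # Inverse-distance weight; the small constant avoids division by zero.
        w = weighted ? 1 / (D[i, j] + 1e-12) : 1.0
        votes[labels[j]] = get(votes, labels[j], 0.0) + w
    end
    # Label with the largest total vote (ties are broken arbitrarily).
    return first(sort(collect(votes), by = last, rev = true))[1]
end

# Toy data: the query (sample 1) has one very close "B" neighbour
# and two more distant "A" neighbours.
x_toy = [0.0, 0.05, 1.0, 1.1, 5.0]
D_toy = [abs(a - b) for a in x_toy, b in x_toy]
toy_labels = ["A", "B", "A", "A", "B"]

plain_vote = knn_predict(D_toy, toy_labels, 1; k = 3)
weighted_vote = knn_predict(D_toy, toy_labels, 1; k = 3, weighted = true)
```

In this toy case the plain vote follows the two distant "A" neighbors, while the weighted vote follows the single close "B" neighbor.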

In [27]:
knn_results = []
for (dist_name, D) in distances
    for k in [1, 3, 5]
        r = loocv_knn(D, labels; k = k)
        push!(knn_results, (
            method = "k-NN (k=$k)",
            distance = dist_name,
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy
        ))

        r2 = loocv_knn_weighted(D, labels; k = k)
        push!(knn_results, (
            method = "W-kNN (k=$k)",
            distance = dist_name,
            n_correct = sum(r2.predictions .== labels),
            n_total = length(labels),
            accuracy = r2.accuracy
        ))
    end

    r3 = loocv_nearest_centroid(D, labels)
    push!(knn_results, (
        method = "Nearest centroid",
        distance = dist_name,
        n_correct = sum(r3.predictions .== labels),
        n_total = length(labels),
        accuracy = r3.accuracy
    ))
end

knn_df = DataFrame(knn_results)
sort!(knn_df, :accuracy, rev = true)
first(knn_df, 20)

### Feature-based classifiers

We construct feature matrices at several levels of richness, from Rips summary statistics alone up to the concatenation of the vectorized TDA representations from all filtrations:

In [28]:
# Feature matrices at different levels of richness
X_stats_rips = sanitize_feature_matrix(stats_rips)
X_stats_all_h1 = sanitize_feature_matrix(stats_all_h1)
X_stats_all = sanitize_feature_matrix(stats_all)

X_rips_full = build_feature_matrix(
    stats = stats_rips,
    pi = pi_rips,
    betti = betti_rips,
    landscape = land1_rips_vecs,
) |> sanitize_feature_matrix

# Multi-filtration features (H1 only): combine everything
all_pi_h1 = [vcat(vec(pi_rips[i]),
                  vec(pi_radial[i]),
                  [vec(pi_directional[name][i]) for name in direction_names]...,
                  vec(pi_edt_h1[i]),
                  vec(pi_cubical_h1[i]))
             for i in 1:length(families)]

all_betti_h1 = [vcat(betti_rips[i],
                     betti_radial[i],
                     [betti_directional[name][i] for name in direction_names]...,
                     betti_edt_h1[i],
                     betti_cubical_h1[i])
                for i in 1:length(families)]

X_multi_h1 = build_feature_matrix(
    stats = stats_all_h1,
    pi = all_pi_h1,
    betti = all_betti_h1,
) |> sanitize_feature_matrix

# Multi-filtration features (H0 + H1 = comprehensive)
all_pi = [vcat(vec(pi_rips[i]),
               vec(pi_radial[i]), vec(pi_radial_h0[i]),
               [vec(pi_directional[name][i]) for name in direction_names]...,
               [vec(pi_directional_h0[name][i]) for name in direction_names]...,
               vec(pi_edt_h1[i]), vec(pi_edt_h0[i]),
               vec(pi_cubical_h1[i]), vec(pi_cubical_h0[i]))
          for i in 1:length(families)]

all_betti = [vcat(betti_rips[i],
                  betti_radial[i], betti_radial_h0[i],
                  [betti_directional[name][i] for name in direction_names]...,
                  [betti_directional_h0[name][i] for name in direction_names]...,
                  betti_edt_h1[i], betti_edt_h0[i],
                  betti_cubical_h1[i], betti_cubical_h0[i])
             for i in 1:length(families)]

X_multi = build_feature_matrix(
    stats = stats_all,
    pi = all_pi,
    betti = all_betti,
) |> sanitize_feature_matrix

println("Feature dimensions:")
println("  Rips stats only: $(size(X_stats_rips))")
println("  All H1 stats: $(size(X_stats_all_h1))")
println("  All H0+H1 stats: $(size(X_stats_all))")
println("  Rips full (stats+PI+Betti+Land): $(size(X_rips_full))")
println("  Multi-filtration H1 only: $(size(X_multi_h1))")
println("  Multi-filtration H0+H1 (comprehensive): $(size(X_multi))")

Feature dimensions:
  Rips stats only: (72, 11)
  All H1 stats: (72, 132)
  All H0+H1 stats: (72, 253)
  Rips full (stats+PI+Betti+Land): (72, 336)
  Multi-filtration H1 only: (72, 1837)
  Multi-filtration H0+H1 (comprehensive): (72, 3388)

#### SVM (Support Vector Machine)

> **Support Vector Machine (SVM)**
>
> An **SVM** finds the hyperplane that maximizes the margin between classes. The **RBF (Radial Basis Function) kernel** implicitly maps data into a high-dimensional space where linear separation becomes possible; the cost parameter $C$ sets the penalty for misclassified training points. For distance matrices, we convert distances to an RBF-like kernel $K(i,j) = \exp\left(-D_{ij}^2 / (2\sigma^2)\right)$ and train a linear SVM on the resulting kernel matrix. This is sometimes called an “empirical kernel map.”
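
The distance-to-kernel conversion can be sketched as follows. `rbf_kernel_from_distances` is a hypothetical helper, and choosing $\sigma$ as the median off-diagonal distance is an assumption of this sketch; the source does not say how `loocv_svm_distance` sets $\sigma$.

```julia
using Statistics: median

# Hypothetical helper: turn a symmetric distance matrix into an RBF-like kernel
# K[i, j] = exp(-D[i, j]^2 / (2σ^2)).
# The median heuristic for σ is an assumption, not taken from TDAfly.
function rbf_kernel_from_distances(D::AbstractMatrix{<:Real}; sigma = nothing)
    n = size(D, 1)
    if sigma === nothing
        offdiag = [D[i, j] for i in 1:n for j in 1:n if i < j]
        sigma = median(offdiag)
    end
    return exp.(-(D .^ 2) ./ (2 * sigma^2))
end

D_toy = [0.0 1.0 2.0;
         1.0 0.0 1.0;
         2.0 1.0 0.0]
K_toy = rbf_kernel_from_distances(D_toy)
```

Each row of the resulting matrix can then serve as a feature vector for a linear classifier. Note that a kernel built this way from Wasserstein or Bottleneck distances is not guaranteed to be positive semidefinite.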

In [29]:
feature_sets = [
    ("Rips stats", X_stats_rips),
    ("All H1 stats", X_stats_all_h1),
    ("All H0+H1 stats", X_stats_all),
    ("Rips full", X_rips_full),
    ("Multi-filtration H1", X_multi_h1),
    ("Multi-filtration H0+H1", X_multi),
]

svm_results = []
for (feat_name, X) in feature_sets
    for kernel in [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear]
        for cost in [0.1, 1.0, 10.0, 100.0]
            kernel_name = kernel == LIBSVM.Kernel.RadialBasis ? "RBF" : "Linear"
            r = loocv_svm(X, labels; kernel = kernel, cost = cost)
            push!(svm_results, (
                method = "SVM ($kernel_name, C=$cost)",
                features = feat_name,
                n_correct = sum(r.predictions .== labels),
                n_total = length(labels),
                accuracy = r.accuracy
            ))
        end
    end
end

svm_df = DataFrame(svm_results)
sort!(svm_df, :accuracy, rev = true)
first(svm_df, 15)

#### SVM on distance matrices

In [30]:
svm_dist_results = []
for (dist_name, D) in distances
    for cost in [0.1, 1.0, 10.0, 100.0]
        r = loocv_svm_distance(D, labels; cost = cost)
        push!(svm_dist_results, (
            method = "SVM-dist (C=$cost)",
            distance = dist_name,
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy
        ))
    end
end

svm_dist_df = DataFrame(svm_dist_results)
sort!(svm_dist_df, :accuracy, rev = true)
first(svm_dist_df, 10)

#### LDA (Linear Discriminant Analysis)

> **Linear Discriminant Analysis (LDA)**
>
> **LDA** finds a linear projection of the feature space that maximizes the ratio of between-class variance to within-class variance. The projected data is then classified with a simple 1-NN rule. LDA is a classical method that works well when classes are approximately Gaussian and the number of features is not too large relative to the number of samples. It provides an interpretable low-dimensional embedding.
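
In symbols, with class means $\mu_c$, class sizes $n_c$ and overall mean $\mu$, the standard Fisher formulation of this criterion reads as follows (a sketch of the textbook definition; the source does not show how `loocv_lda` regularizes or solves it):

$$
S_W = \sum_{c} \sum_{i \,:\, y_i = c} (x_i - \mu_c)(x_i - \mu_c)^\top,
\qquad
S_B = \sum_{c} n_c \, (\mu_c - \mu)(\mu_c - \mu)^\top,
\qquad
W^* = \arg\max_{W} \frac{\det\!\left(W^\top S_B W\right)}{\det\!\left(W^\top S_W W\right)}
$$

With $C$ classes, $S_B$ has rank at most $C - 1$, so the ten families in this dataset yield at most nine discriminant directions. When the number of features exceeds the number of samples, $S_W$ is singular and some form of regularization or prior dimensionality reduction is needed.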

In [31]:
lda_results = []
for (feat_name, X) in feature_sets
    r = loocv_lda(X, labels)
    push!(lda_results, (
        method = "LDA",
        features = feat_name,
        n_correct = sum(r.predictions .== labels),
        n_total = length(labels),
        accuracy = r.accuracy
    ))
end

lda_df = DataFrame(lda_results)
sort!(lda_df, :accuracy, rev = true)
lda_df

#### Random Forest

> **Random Forest**
>
> A **Random Forest** is an ensemble of decision trees, each trained on a bootstrap sample of the data with a random subset of features considered at each split. The final prediction is the majority vote across all trees. Random Forests are relatively robust to overfitting, handle high-dimensional features well, and provide built-in feature importance estimates. They are a strong baseline for tabular classification tasks.

In [32]:
rf_results = []
for (feat_name, X) in feature_sets
    for n_trees in [100, 500]
        r = loocv_random_forest(X, labels; n_trees = n_trees)
        m = classification_metrics(labels, r.predictions)
        push!(rf_results, (
            method = "RF (T=$n_trees)",
            features = feat_name,
            n_correct = sum(r.predictions .== labels),
            n_total = length(labels),
            accuracy = r.accuracy,
            balanced_accuracy = m.balanced_accuracy,
            macro_f1 = m.macro_f1
        ))

        rb = loocv_random_forest_balanced(X, labels; n_trees = n_trees, rng_seed = 20260223)
        mb = classification_metrics(labels, rb.predictions)
        push!(rf_results, (
            method = "Balanced RF (T=$n_trees)",
            features = feat_name,
            n_correct = sum(rb.predictions .== labels),
            n_total = length(labels),
            accuracy = rb.accuracy,
            balanced_accuracy = mb.balanced_accuracy,
            macro_f1 = mb.macro_f1
        ))
    end
end

rf_df = DataFrame(rf_results)
sort!(rf_df, :accuracy, rev = true)
first(rf_df, 12)

## Combined distance analysis

We combine the best topology-aware distance with a statistics-based distance using convex combinations: $$D_{\text{combined}}(\alpha) = \alpha \cdot D_1^* + (1 - \alpha) \cdot D_2^*$$ where $D_1^*$ and $D_2^*$ are distances normalized to $[0, 1]$.
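
The convex combination can be sketched in a few lines. The helper names are hypothetical, and dividing by the maximum entry is only one way to map distances into $[0, 1]$; the source does not state which normalization `combined_distance_grid_search` applies.

```julia
# Hypothetical helpers illustrating the convex combination of two distance matrices.
# Max-normalization is an assumption of this sketch.
normalize_distance(D) = maximum(D) > 0 ? D ./ maximum(D) : float.(D)

function combine_distances(D1, D2, alpha)
    @assert 0 <= alpha <= 1
    return alpha .* normalize_distance(D1) .+ (1 - alpha) .* normalize_distance(D2)
end

D1_toy = [0.0 1.0 4.0;
          1.0 0.0 2.0;
          4.0 2.0 0.0]
D2_toy = [0.0 10.0 5.0;
          10.0 0.0 5.0;
          5.0 5.0 0.0]
D_half = combine_distances(D1_toy, D2_toy, 0.5)
```

Setting `alpha = 1` recovers the first (normalized) distance and `alpha = 0` the second, so the grid over $\alpha$ interpolates between the two single-distance classifiers.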

In [33]:
stats_for_distance = zscore_normalize(sanitize_feature_matrix(stats_all))
stats_vectors_norm = [stats_for_distance[i, :] for i in axes(stats_for_distance, 1)]
D_stats = pairwise_distance(stats_vectors_norm, euclidean)

# Try combining best Rips distances with stats distance
grid_rips_w1 = combined_distance_grid_search(D_wass1_rips, D_stats, labels)
grid_rips_w2 = combined_distance_grid_search(D_wass2_rips, D_stats, labels)

println("Top 5 combinations (Rips Wass-1 + Stats):")
for r in grid_rips_w1[1:min(5, end)]
    println("  α=$(round(r.alpha, digits=1)), k=$(r.k): $(r.n_correct)/$(length(labels)) ($(round(r.accuracy * 100, digits=1))%)")
end

println("\nTop 5 combinations (Rips Wass-2 + Stats):")
for r in grid_rips_w2[1:min(5, end)]
    println("  α=$(round(r.alpha, digits=1)), k=$(r.k): $(r.n_correct)/$(length(labels)) ($(round(r.accuracy * 100, digits=1))%)")
end

Top 5 combinations (Rips Wass-1 + Stats):
  α=0.6, k=1: 54/72 (75.0%)
  α=0.7, k=1: 54/72 (75.0%)
  α=0.8, k=1: 54/72 (75.0%)
  α=0.3, k=1: 53/72 (73.6%)
  α=0.5, k=1: 52/72 (72.2%)

Top 5 combinations (Rips Wass-2 + Stats):
  α=0.6, k=3: 53/72 (73.6%)
  α=0.7, k=5: 53/72 (73.6%)
  α=0.5, k=1: 51/72 (70.8%)
  α=0.6, k=1: 51/72 (70.8%)
  α=0.8, k=5: 51/72 (70.8%)

In [34]:
# Visualize the grid search
alphas = 0.0:0.1:1.0
ks = [1, 3, 5]

acc_grid_w1 = zeros(length(alphas), length(ks))
for r in grid_rips_w1
    # isapprox avoids missing a cell when α carries floating-point rounding error
    i = findfirst(a -> isapprox(a, r.alpha; atol = 1e-8), alphas)
    j = findfirst(==(r.k), ks)
    if !isnothing(i) && !isnothing(j)
        acc_grid_w1[i, j] = r.accuracy
    end
end

acc_grid_w2 = zeros(length(alphas), length(ks))
for r in grid_rips_w2
    i = findfirst(a -> isapprox(a, r.alpha; atol = 1e-8), alphas)
    j = findfirst(==(r.k), ks)
    if !isnothing(i) && !isnothing(j)
        acc_grid_w2[i, j] = r.accuracy
    end
end

p1 = heatmap(string.(ks), string.(collect(alphas)),
        acc_grid_w1,
        xlabel = "k", ylabel = "α (Rips Wass-1 weight)",
        title = "Rips Wass-1 + Stats",
        color = :Blues, clims = (0.3, 1.0))
p2 = heatmap(string.(ks), string.(collect(alphas)),
        acc_grid_w2,
        xlabel = "k", ylabel = "α (Rips Wass-2 weight)",
        title = "Rips Wass-2 + Stats",
        color = :Blues, clims = (0.3, 1.0))
plot(p1, p2, layout = (1, 2), size = (1000, 450))

## Ensemble classification

> **Ensemble methods (majority voting)**
>
> **Ensemble methods** combine predictions from multiple classifiers. In **majority voting**, each classifier casts a vote for its predicted class, and the class with the most votes wins. In **weighted voting**, each classifier’s vote is weighted by its individual accuracy, giving more influence to better classifiers. Ensembles are often more robust than individual classifiers because different methods tend to make different errors, though a vote can still fall below its strongest member.
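
A minimal sketch of both voting schemes is given below. `vote` is a hypothetical re-implementation for illustration; the notebook's `ensemble_vote` may break ties differently.

```julia
# Hypothetical (weighted) majority vote over several prediction vectors.
# With the default unit weights this is plain majority voting.
function vote(predictions_list; weights = ones(length(predictions_list)))
    n = length(first(predictions_list))
    out = Vector{eltype(first(predictions_list))}(undef, n)
    for i in 1:n
        tally = Dict{eltype(out), Float64}()
        for (preds, w) in zip(predictions_list, weights)
            tally[preds[i]] = get(tally, preds[i], 0.0) + w
        end
        # Class with the largest total weight (ties are broken arbitrarily).
        out[i] = first(sort(collect(tally), by = last, rev = true))[1]
    end
    return out
end

# Toy example: three classifiers, two samples.
toy_predictions = [["A", "B"], ["B", "B"], ["B", "B"]]
majority = vote(toy_predictions)
weighted = vote(toy_predictions; weights = [0.9, 0.4, 0.4])
```

On the first toy sample the two weaker classifiers outvote the stronger one under majority voting, whereas the accuracy weights (0.9 against 0.4 + 0.4) let the stronger classifier prevail.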

We combine the best classifiers from each method family:

In [35]:
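# Best distance-based k-NN.
# Note: this assumes the top row of knn_df is a k-NN variant (the regex fails for
# "Nearest centroid"); a top "W-kNN" row is re-run below as unweighted k-NN.
best_knn_row = knn_df[1, :]
D_best_knn = distances[best_knn_row.distance]
k_best = parse(Int, match(r"k=(\d+)", best_knn_row.method)[1])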
knn_best = loocv_knn(D_best_knn, labels; k = k_best)

# Best SVM on features
best_svm_row = svm_df[1, :]
best_svm_X = Dict(feat_name => X for (feat_name, X) in feature_sets)[best_svm_row.features]
best_svm_kernel = occursin("RBF", best_svm_row.method) ? LIBSVM.Kernel.RadialBasis : LIBSVM.Kernel.Linear
best_svm_cost = parse(Float64, match(r"C=([\d.]+)", best_svm_row.method)[1])
svm_best = loocv_svm(best_svm_X, labels; kernel = best_svm_kernel, cost = best_svm_cost)

# Best Random Forest
best_rf_row = rf_df[1, :]
best_rf_X = Dict(feat_name => X for (feat_name, X) in feature_sets)[best_rf_row.features]
best_rf_ntrees = parse(Int, match(r"T=(\d+)", best_rf_row.method)[1])
best_rf_balanced = occursin("Balanced RF", best_rf_row.method)
if best_rf_balanced
    rf_best = loocv_random_forest_balanced(best_rf_X, labels; n_trees = best_rf_ntrees, rng_seed = 20260223)
else
    rf_best = loocv_random_forest(best_rf_X, labels; n_trees = best_rf_ntrees)
end

# Best LDA
best_lda_row = lda_df[1, :]
best_lda_X = Dict(feat_name => X for (feat_name, X) in feature_sets)[best_lda_row.features]
lda_best = loocv_lda(best_lda_X, labels)

# Ensemble: majority vote
predictions_list = [knn_best.predictions, svm_best.predictions, rf_best.predictions, lda_best.predictions]
accuracies = [knn_best.accuracy, svm_best.accuracy, rf_best.accuracy, lda_best.accuracy]

ensemble_preds = ensemble_vote(predictions_list)
ensemble_acc = mean(ensemble_preds .== labels)

ensemble_preds_w = ensemble_vote(predictions_list; weights = accuracies)
ensemble_acc_w = mean(ensemble_preds_w .== labels)

println("=== Ensemble Results ===")
println("Individual classifiers:")
println("  k-NN ($(best_knn_row.distance), k=$k_best): $(round(knn_best.accuracy * 100, digits=1))%")
println("  SVM ($(best_svm_row.method), $(best_svm_row.features)): $(round(svm_best.accuracy * 100, digits=1))%")
println("  RF ($(best_rf_row.method), $(best_rf_row.features)): $(round(rf_best.accuracy * 100, digits=1))%")
println("  LDA ($(best_lda_row.features)): $(round(lda_best.accuracy * 100, digits=1))%")
println()
println("Ensemble (majority vote): $(sum(ensemble_preds .== labels))/$(length(labels)) ($(round(ensemble_acc * 100, digits=1))%)")
println("Ensemble (weighted vote): $(sum(ensemble_preds_w .== labels))/$(length(labels)) ($(round(ensemble_acc_w * 100, digits=1))%)")

=== Ensemble Results ===
Individual classifiers:
  k-NN (Cubical H1 Betti, k=3): 66.7%
  SVM (SVM (Linear, C=0.1), Multi-filtration H1): 83.3%
  RF (Balanced RF (T=100), All H0+H1 stats): 80.6%
  LDA (Multi-filtration H0+H1): 86.1%

Ensemble (majority vote): 59/72 (81.9%)
Ensemble (weighted vote): 62/72 (86.1%)

## Comprehensive comparison

In [36]:
all_results = []

# Distance-based (top 5)
for row in eachrow(first(knn_df, 5))
    push!(all_results, (
        category = "Distance-based",
        method = "$(row.method) [$(row.distance)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# SVM on distances (top 3)
for row in eachrow(first(svm_dist_df, 3))
    push!(all_results, (
        category = "Distance-based",
        method = "$(row.method) [$(row.distance)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# LDA
for row in eachrow(lda_df)
    push!(all_results, (
        category = "Feature-based",
        method = "LDA [$(row.features)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# SVM on features (top 5)
for row in eachrow(first(svm_df, 5))
    push!(all_results, (
        category = "Feature-based",
        method = "$(row.method) [$(row.features)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# Random Forest (top 3)
for row in eachrow(first(rf_df, 3))
    push!(all_results, (
        category = "Feature-based",
        method = "$(row.method) [$(row.features)]",
        accuracy = row.accuracy,
        n_correct = row.n_correct,
        n_total = row.n_total
    ))
end

# Ensembles
push!(all_results, (category = "Ensemble", method = "Majority vote (4 classifiers)",
    accuracy = ensemble_acc, n_correct = sum(ensemble_preds .== labels), n_total = length(labels)))
push!(all_results, (category = "Ensemble", method = "Weighted vote (4 classifiers)",
    accuracy = ensemble_acc_w, n_correct = sum(ensemble_preds_w .== labels), n_total = length(labels)))

# Combined distances
best_rips_comb_w1 = grid_rips_w1[1]
push!(all_results, (category = "Combined distance",
    method = "Rips Wass-1 + Stats (α=$(round(best_rips_comb_w1.alpha, digits=1)), k=$(best_rips_comb_w1.k))",
    accuracy = best_rips_comb_w1.accuracy, n_correct = best_rips_comb_w1.n_correct, n_total = length(labels)))

best_rips_comb_w2 = grid_rips_w2[1]
push!(all_results, (category = "Combined distance",
    method = "Rips Wass-2 + Stats (α=$(round(best_rips_comb_w2.alpha, digits=1)), k=$(best_rips_comb_w2.k))",
    accuracy = best_rips_comb_w2.accuracy, n_correct = best_rips_comb_w2.n_correct, n_total = length(labels)))

comparison_df = DataFrame(all_results)
sort!(comparison_df, :accuracy, rev = true)
comparison_df

## Best classifier evaluation

In [37]:
best_overall = comparison_df[1, :]
println("=== Best Method ===")
println("$(best_overall.category): $(best_overall.method)")
println("Accuracy: $(best_overall.n_correct)/$(best_overall.n_total) ($(round(best_overall.accuracy * 100, digits=1))%)")

ci = wilson_ci(best_overall.n_correct, best_overall.n_total)
println("95% Wilson CI: [$(round(ci.lower * 100, digits=1))%, $(round(ci.upper * 100, digits=1))%]")

=== Best Method ===
Feature-based: LDA [Multi-filtration H0+H1]
Accuracy: 62/72 (86.1%)
95% Wilson CI: [76.3%, 92.3%]
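
The Wilson score interval is used instead of the simpler normal approximation because it behaves better for small samples and for proportions near 0 or 1. The sketch below implements the standard formula; `wilson_interval` is a hypothetical stand-in for the notebook's `wilson_ci`, and $z = 1.96$ is assumed for the 95% level.

```julia
# Wilson score interval for a binomial proportion.
# `wilson_interval` is a hypothetical stand-in for the notebook's `wilson_ci`.
function wilson_interval(n_correct::Integer, n_total::Integer; z = 1.96)
    p = n_correct / n_total
    denom = 1 + z^2 / n_total
    center = (p + z^2 / (2 * n_total)) / denom
    half = z * sqrt(p * (1 - p) / n_total + z^2 / (4 * n_total^2)) / denom
    return (lower = center - half, upper = center + half)
end

ci_check = wilson_interval(62, 72)
println(round(ci_check.lower * 100, digits = 1), "% to ",
        round(ci_check.upper * 100, digits = 1), "%")
```

For 62 correct predictions out of 72, this formula gives bounds that round to 76.3% and 92.3%, in agreement with the interval printed above.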

### Confusion matrix

In [38]:
# Use ensemble predictions for confusion matrix
final_preds = ensemble_acc_w >= ensemble_acc ? ensemble_preds_w : ensemble_preds
final_method = ensemble_acc_w >= ensemble_acc ? "Weighted ensemble" : "Majority ensemble"

cm_result = confusion_matrix(labels, final_preds)
classes = cm_result.classes

println("=== Confusion Matrix ($final_method) ===")
println("Per-class accuracy:")
for (i, cls) in enumerate(classes)
    correct = cm_result.matrix[i, i]
    total = sum(cm_result.matrix[i, :])
    println("  $(cls): $(correct)/$(total) ($(round(correct / total * 100, digits=1))%)")
end

=== Confusion Matrix (Weighted ensemble) ===
Per-class accuracy:
  Asilidae: 7/8 (87.5%)
  Bibionidae: 6/6 (100.0%)
  Ceratopogonidae: 8/8 (100.0%)
  Chironomidae: 6/8 (75.0%)
  Pelecorhynchidae: 0/2 (0.0%)
  Rhagionidae: 1/4 (25.0%)
  Sciaridae: 5/6 (83.3%)
  Simuliidae: 7/7 (100.0%)
  Tabanidae: 11/11 (100.0%)
  Tipulidae: 11/12 (91.7%)

In [39]:
heatmap(cm_result.matrix,
        xticks = (1:length(classes), classes),
        yticks = (1:length(classes), classes),
        xlabel = "Predicted", ylabel = "True",
        title = "Confusion Matrix ($final_method)",
        color = :Blues,
        clims = (0, maximum(cm_result.matrix)),
        xrotation = 45, size = (700, 600))

## Honest evaluation (Nested LOOCV)

Nested selection over distance matrices turns out to be unstable for this dataset (see the multi-distance results later in this section). We therefore first run an **honest nested LOOCV** on one of the stronger, lower-dimensional model families (Random Forest on summary statistics): the outer loop holds out one sample and the inner loop tunes RF hyperparameters using only the training fold.

> **Nested cross-validation**
>
> Standard LOOCV can give optimistically biased estimates when hyperparameters are tuned on the same data. **Nested LOOCV** adds an inner cross-validation loop: for each held-out test sample, the best hyperparameters are selected using only the training fold. Because the test sample never influences model selection, this gives a nearly unbiased estimate of generalization performance.
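
The two-level structure can be sketched with a deliberately simple model: k-NN on a distance matrix, with $k$ as the only hyperparameter. All names below are hypothetical, and the inner loop here is itself a leave-one-out pass, whereas the cells that follow use `inner_folds` to control the inner split.

```julia
# Hypothetical sketch of nested LOOCV for choosing k in k-NN on a distance matrix.

# Majority-vote k-NN prediction for sample i, using only the indices in `train`.
function knn_label(D, labels, i, train; k)
    nearest = sort(train, by = j -> D[i, j])[1:min(k, length(train))]
    counts = Dict{eltype(labels), Int}()
    for j in nearest
        counts[labels[j]] = get(counts, labels[j], 0) + 1
    end
    return first(sort(collect(counts), by = last, rev = true))[1]
end

function nested_loocv_knn(D, labels; ks = [1, 3, 5])
    n = length(labels)
    predictions = similar(labels)
    chosen_k = zeros(Int, n)
    for i in 1:n                          # outer loop: hold out sample i
        train = setdiff(1:n, i)
        # Inner loop: LOOCV restricted to the training fold, one accuracy per k.
        inner_acc = map(ks) do k
            count(knn_label(D, labels, j, setdiff(train, j); k = k) == labels[j]
                  for j in train) / length(train)
        end
        chosen_k[i] = ks[argmax(inner_acc)]   # first maximum, i.e. smallest k on ties
        predictions[i] = knn_label(D, labels, i, train; k = chosen_k[i])
    end
    return (predictions = predictions, chosen_k = chosen_k,
            accuracy = count(predictions .== labels) / n)
end

x_toy = [0.0, 0.1, 0.2, 1.0, 1.1, 1.2]
D_toy = [abs(a - b) for a in x_toy, b in x_toy]
toy_labels = ["A", "A", "A", "B", "B", "B"]
toy_nested = nested_loocv_knn(D_toy, toy_labels)
```

The held-out sample `i` appears only in the final prediction step, never in the inner accuracy computation that picks `k`.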

In [40]:
nested_rf = nested_loocv_random_forest(
    X_stats_all, labels;
    n_trees_grid = [200, 500],
    max_depth_grid = [-1],
    min_samples_leaf_grid = [1, 2],
    inner_folds = 4,
    balanced = true,
    rng_seed = 20260223
)
n_correct_nested = sum(nested_rf.predictions .== labels)

println("=== Nested LOOCV Result ===")
println("Model: Balanced Random Forest (All stats)")
println("Accuracy: $(n_correct_nested)/$(length(labels)) ($(round(nested_rf.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_rf.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_rf.macro_f1 * 100, digits=1))%")

ci_nested = wilson_ci(n_correct_nested, length(labels))
println("95% Wilson CI: [$(round(ci_nested.lower * 100, digits=1))%, $(round(ci_nested.upper * 100, digits=1))%]")

=== Nested LOOCV Result ===
Model: Balanced Random Forest (All stats)
Accuracy: 56/72 (77.8%)
Balanced accuracy: 76.2%
Macro-F1: 76.2%
95% Wilson CI: [66.9%, 85.8%]

In [41]:
cm_nested = confusion_matrix(labels, nested_rf.predictions)
classes_nested = cm_nested.classes

println("Per-class accuracy (Nested LOOCV):")
for (i, cls) in enumerate(classes_nested)
    correct = cm_nested.matrix[i, i]
    total = sum(cm_nested.matrix[i, :])
    println("  $(cls): $(correct)/$(total) ($(round(correct / total * 100, digits=1))%)")
end

Per-class accuracy (Nested LOOCV):
  Asilidae: 6/8 (75.0%)
  Bibionidae: 6/6 (100.0%)
  Ceratopogonidae: 6/8 (75.0%)
  Chironomidae: 5/8 (62.5%)
  Pelecorhynchidae: 1/2 (50.0%)
  Rhagionidae: 2/4 (50.0%)
  Sciaridae: 6/6 (100.0%)
  Simuliidae: 7/7 (100.0%)
  Tabanidae: 10/11 (90.9%)
  Tipulidae: 7/12 (58.3%)

In [42]:
heatmap(cm_nested.matrix,
        xticks = (1:length(classes_nested), classes_nested),
        yticks = (1:length(classes_nested), classes_nested),
        xlabel = "Predicted", ylabel = "True",
        title = "Confusion Matrix (Nested LOOCV - Balanced RF)",
        color = :Blues,
        clims = (0, maximum(cm_nested.matrix)),
        xrotation = 45, size = (700, 600))

### Nested LOOCV for Multi-filtration SVM

> **Why is nested CV needed here?**
>
> The Multi-filtration feature matrix `X_multi` has 3,388 features for only 72 samples (a ~47:1 feature-to-sample ratio). In such high-dimensional settings, an SVM can often separate the training data perfectly even when the labels are random. Furthermore, selecting the best kernel and cost parameter from many LOOCV runs introduces **selection bias**: the reported accuracy of the “best” configuration is upward-biased. Nested LOOCV removes this bias by selecting hyperparameters using only the training fold.

We evaluate the Multi-filtration SVM both with and without PCA dimensionality reduction:

In [43]:
# Nested LOOCV for Multi-filtration SVM (no PCA)
nested_svm_multi = nested_loocv_svm(
    X_multi, labels;
    kernels = [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear],
    costs = [0.1, 1.0, 10.0, 100.0],
    use_pca = false,
    inner_folds = 5,
    rng_seed = 20260223
)

println("=== Nested LOOCV: Multi-filtration SVM (no PCA) ===")
n_corr = sum(nested_svm_multi.predictions .== labels)
println("Accuracy: $(n_corr)/$(length(labels)) ($(round(nested_svm_multi.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_svm_multi.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_svm_multi.macro_f1 * 100, digits=1))%")

ci_svm = wilson_ci(n_corr, length(labels))
println("95% Wilson CI: [$(round(ci_svm.lower * 100, digits=1))%, $(round(ci_svm.upper * 100, digits=1))%]")

# Show which hyperparameters were selected in each fold
param_counts = Dict{String, Int}()
for p in nested_svm_multi.params
    key = "$(p.kernel), C=$(p.cost)"
    param_counts[key] = get(param_counts, key, 0) + 1
end
println("\nSelected hyperparameters across folds:")
for (k, v) in sort(collect(param_counts), by=last, rev=true)
    println("  $k: $v/$(length(labels)) folds")
end

=== Nested LOOCV: Multi-filtration SVM (no PCA) ===
Accuracy: 56/72 (77.8%)
Balanced accuracy: 70.8%
Macro-F1: 70.1%
95% Wilson CI: [66.9%, 85.8%]

Selected hyperparameters across folds:
  Linear, C=0.1: 71/72 folds
  RBF, C=10.0: 1/72 folds

In [44]:
# Nested LOOCV for Multi-filtration SVM with PCA (95% variance)
nested_svm_pca = nested_loocv_svm(
    X_multi, labels;
    kernels = [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear],
    costs = [0.1, 1.0, 10.0, 100.0],
    use_pca = true,
    variance_ratio = 0.95,
    inner_folds = 5,
    rng_seed = 20260223
)

println("=== Nested LOOCV: Multi-filtration SVM + PCA (95% var) ===")
n_corr_pca = sum(nested_svm_pca.predictions .== labels)
println("Accuracy: $(n_corr_pca)/$(length(labels)) ($(round(nested_svm_pca.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_svm_pca.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_svm_pca.macro_f1 * 100, digits=1))%")

ci_pca = wilson_ci(n_corr_pca, length(labels))
println("95% Wilson CI: [$(round(ci_pca.lower * 100, digits=1))%, $(round(ci_pca.upper * 100, digits=1))%]")

=== Nested LOOCV: Multi-filtration SVM + PCA (95% var) ===
Accuracy: 53/72 (73.6%)
Balanced accuracy: 68.6%
Macro-F1: 66.7%
95% Wilson CI: [62.4%, 82.4%]

For comparison, a simple PCA + SVM (non-nested) on the Multi-filtration features:

In [45]:
pca_svm_results = []
for kernel in [LIBSVM.Kernel.RadialBasis, LIBSVM.Kernel.Linear]
    for cost in [1.0, 10.0]
        kernel_name = kernel == LIBSVM.Kernel.RadialBasis ? "RBF" : "Linear"
        r = loocv_svm_pca(X_multi, labels;
                          variance_ratio = 0.95, kernel = kernel, cost = cost)
        push!(pca_svm_results, (
            method = "PCA+SVM ($kernel_name, C=$cost)",
            accuracy = r.accuracy,
            n_correct = sum(r.predictions .== labels),
            n_components = r.median_n_components
        ))
    end
end

pca_df = DataFrame(pca_svm_results)
sort!(pca_df, :accuracy, rev = true)
pca_df

### Permutation test

> **Permutation test for feature-based classifiers**
>
> A **permutation test** assesses whether the classifier’s accuracy is significantly better than chance. We shuffle the labels many times, recompute LOOCV accuracy each time, and measure how often the shuffled accuracy matches or exceeds the observed accuracy. If the observed accuracy is far above the null distribution, we can be confident the features contain genuine discriminative signal — even if the absolute accuracy estimate may be optimistically biased.
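
A generic version of this procedure is sketched below. `permutation_pvalue` and the toy scoring function are hypothetical; the "+1" correction in the p-value is a common convention assumed here, and `permutation_test_svm` may compute it differently.

```julia
using Random

# Hypothetical sketch: permutation test around any function `score(labels)`
# that returns a cross-validated accuracy for a given labelling.
function permutation_pvalue(score, labels; n_permutations = 200, rng = MersenneTwister(1))
    observed = score(labels)
    null_scores = [score(shuffle(rng, labels)) for _ in 1:n_permutations]
    # The +1 terms count the observed labelling itself, so p is never exactly 0.
    p_value = (1 + count(>=(observed), null_scores)) / (1 + n_permutations)
    return (observed = observed, null_scores = null_scores, p_value = p_value)
end

# Toy scoring function: LOOCV accuracy of 1-NN on a fixed distance matrix.
function loocv_1nn_accuracy(D, labels)
    n = length(labels)
    hits = 0
    for i in 1:n
        j = argmin([k == i ? Inf : D[i, k] for k in 1:n])
        hits += labels[j] == labels[i]
    end
    return hits / n
end

x_toy = [0.0, 0.1, 0.2, 0.3, 2.0, 2.1, 2.2, 2.3]
D_toy = [abs(a - b) for a in x_toy, b in x_toy]
toy_labels = repeat(["A", "B"], inner = 4)

toy_perm = permutation_pvalue(l -> loocv_1nn_accuracy(D_toy, l), toy_labels)
```

With this convention the smallest attainable p-value is $1 / (1 + n_{\text{permutations}})$, which for 500 permutations is roughly 0.002.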

In [46]:
# Permutation test for Multi-filtration SVM (takes a few minutes)
perm_multi = permutation_test_svm(
    X_multi, labels;
    n_permutations = 500,
    kernel = LIBSVM.Kernel.RadialBasis,
    cost = 10.0
)

println("=== Permutation Test: Multi-filtration SVM (RBF, C=10) ===")
println("Observed LOOCV accuracy: $(round(perm_multi.observed * 100, digits=1))%")
println("Null distribution: mean=$(round(perm_multi.perm_mean * 100, digits=1))%, std=$(round(perm_multi.perm_std * 100, digits=1))%")
println("Max null accuracy: $(round(perm_multi.perm_max * 100, digits=1))%")
println("p-value: $(perm_multi.p_value)")

### Feature selection via RF importance

> **Why feature selection helps**
>
> With 3,388 features in `X_multi` and $n=72$ samples, overfitting is the main accuracy bottleneck. Selecting the top features by Random Forest impurity importance reduces dimensionality and can improve generalization. **Critically**, feature selection is performed inside each LOOCV fold to avoid data leakage: each fold selects features using only its training data.
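
The point about leakage is easiest to see in code. The sketch below re-selects features inside every fold. To stay self-contained it swaps in a simple univariate score (between-class over within-class variance) for Random Forest importance and a 1-NN rule for the classifier. Both substitutions and all names are assumptions of this sketch, not the method used by `loocv_rf_with_selection`.

```julia
using Statistics: mean, var

# Stand-in relevance score (NOT Random Forest importance): variance of the class
# means divided by the mean within-class variance, computed per feature.
function fisher_scores(X, labels)
    classes = unique(labels)
    map(1:size(X, 2)) do f
        groups = [X[labels .== c, f] for c in classes]
        between = var([mean(g) for g in groups])
        within = mean([length(g) > 1 ? var(g) : 0.0 for g in groups])
        between / (within + 1e-12)
    end
end

# LOOCV in which the top_k features are re-selected from the training rows of every fold.
function loocv_with_selection(X, labels; top_k = 2)
    n = length(labels)
    predictions = similar(labels)
    for i in 1:n
        train = setdiff(1:n, i)
        scores = fisher_scores(X[train, :], labels[train])   # sample i is never seen here
        keep = partialsortperm(scores, 1:top_k, rev = true)
        # 1-NN on the selected features only (squared Euclidean distance)
        dists = [sum(abs2, X[i, keep] .- X[j, keep]) for j in train]
        predictions[i] = labels[train[argmin(dists)]]
    end
    return (predictions = predictions, accuracy = count(predictions .== labels) / n)
end

# Toy data: one informative feature and two uninformative ones.
toy_labels = repeat(["A", "B"], inner = 4)
signal = [0.0, 0.1, 0.2, 0.3, 2.0, 2.1, 2.2, 2.3]
noise = [0.5, -0.3, 0.1, 0.4, -0.2, 0.3, -0.5, 0.0]
X_toy = hcat(signal, noise, reverse(noise))
toy_sel = loocv_with_selection(X_toy, toy_labels; top_k = 1)
```

Computing the scores once on the full matrix and then running LOOCV on the reduced features would let every held-out sample influence which features are kept, which is the leakage the callout warns against.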

In [47]:
# RF with feature selection inside each fold (honest evaluation)
for top_k in [20, 30, 50]
    r_sel = loocv_rf_with_selection(
        X_multi, labels;
        n_trees_select = 500, n_trees_classify = 300,
        top_k = top_k, balanced = true, rng_seed = 20260223
    )
    n_corr_sel = sum(r_sel.predictions .== labels)
    println("RF + selection (top_k=$top_k): $(n_corr_sel)/$(length(labels)) ($(round(r_sel.accuracy * 100, digits=1))%)")
    println("  Balanced acc: $(round(r_sel.balanced_accuracy * 100, digits=1))%  Macro-F1: $(round(r_sel.macro_f1 * 100, digits=1))%")
end

RF + selection (top_k=20): 47/72 (65.3%)
  Balanced acc: 58.9%  Macro-F1: 57.2%
RF + selection (top_k=30): 52/72 (72.2%)
  Balanced acc: 66.8%  Macro-F1: 64.3%
RF + selection (top_k=50): 46/72 (63.9%)
  Balanced acc: 58.0%  Macro-F1: 56.2%

In [48]:
# Also try on the H0+H1 stats-only feature matrix (lower dimensional)
for top_k in [20, 30, 50]
    r_sel_stats = loocv_rf_with_selection(
        X_stats_all, labels;
        n_trees_select = 500, n_trees_classify = 300,
        top_k = top_k, balanced = true, rng_seed = 20260223
    )
    n_corr_sel = sum(r_sel_stats.predictions .== labels)
    println("RF + selection on stats (top_k=$top_k): $(n_corr_sel)/$(length(labels)) ($(round(r_sel_stats.accuracy * 100, digits=1))%)")
end

RF + selection on stats (top_k=20): 47/72 (65.3%)
RF + selection on stats (top_k=30): 52/72 (72.2%)
RF + selection on stats (top_k=50): 46/72 (63.9%)

### Nested LOOCV with multi-distance selection

> **Multi-distance nested LOOCV**
>
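> With many distance matrices available, picking the best one on the full dataset and reporting its accuracy would be optimistically biased. Instead, the inner loop evaluates all (distance, k) combinations on the training fold only, and the outer loop scores the selected combination on the held-out sample, so the accuracy estimate is not contaminated by the selection step.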
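
The two loops can be sketched as follows. This is a generic illustration, not the `nested_loocv_multi_distance` implementation: the function names, the tie-breaking rules and the use of a `Dict` of named distance matrices are assumptions made for the example.

```julia
using Random

# Majority vote among the k nearest training samples under a precomputed distance matrix D.
function knn_predict(D, y, train, i, k)
    order = sortperm(D[i, train])
    neighbours = train[order[1:min(k, length(train))]]
    votes = Dict{eltype(y), Int}()
    for j in neighbours
        votes[y[j]] = get(votes, y[j], 0) + 1
    end
    # Ties go to the label of the closest neighbour among the tied labels
    best = maximum(values(votes))
    for j in neighbours
        votes[y[j]] == best && return y[j]
    end
end

# LOOCV accuracy of k-NN restricted to the samples in `idx`.
function loocv_knn_accuracy(D, y, idx, k)
    hits = 0
    for i in idx
        hits += knn_predict(D, y, setdiff(idx, i), i, k) == y[i]
    end
    return hits / length(idx)
end

function nested_loocv(distances::Dict{String, <:AbstractMatrix}, y; ks = [1, 3, 5])
    n = length(y)
    preds = similar(y)
    chosen = Vector{Tuple{String, Int}}(undef, n)
    for i in 1:n
        train = setdiff(1:n, i)
        # Inner loop: choose (distance, k) using the training samples only
        best_acc, best = -Inf, ("", 0)
        for name in sort(collect(keys(distances))), k in ks
            acc = loocv_knn_accuracy(distances[name], y, train, k)
            if acc > best_acc
                best_acc, best = acc, (name, k)
            end
        end
        # Outer loop: predict the held-out sample with the selected configuration
        chosen[i] = best
        preds[i] = knn_predict(distances[best[1]], y, train, i, best[2])
    end
    return (predictions = preds, params = chosen, accuracy = sum(preds .== y) / n)
end

# Toy data (not the wing dataset): one informative distance, one pure-noise distance
rng = MersenneTwister(3)
y = repeat(["A", "B"], inner = 10)
x = vcat(randn(rng, 10), randn(rng, 10) .+ 5.0)
D_good = abs.(x .- x')
noise = rand(rng, 20, 20)
D_noise = (noise .+ noise') ./ 2
res = nested_loocv(Dict("informative" => D_good, "noise" => D_noise), y)
println("nested accuracy = ", res.accuracy)
println("selected in fold 1: ", res.params[1])
```

Since the held-out sample `i` never enters the inner loop, the selected configuration may differ between folds, which is what the per-fold selection counts summarize.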

In [49]:
nested_multi_dist = nested_loocv_multi_distance(
    distances, labels;
    ks = [1, 3, 5]
)

n_corr_multi = sum(nested_multi_dist.predictions .== labels)
println("=== Nested LOOCV: Multi-distance selection ===")
println("Accuracy: $(n_corr_multi)/$(length(labels)) ($(round(nested_multi_dist.accuracy * 100, digits=1))%)")
println("Balanced accuracy: $(round(nested_multi_dist.balanced_accuracy * 100, digits=1))%")
println("Macro-F1: $(round(nested_multi_dist.macro_f1 * 100, digits=1))%")

# Show which distances were selected most often
dist_selection_counts = Dict{String, Int}()
for p in nested_multi_dist.params
    dist_selection_counts[p.distance] = get(dist_selection_counts, p.distance, 0) + 1
end
println("\nSelected distances across folds:")
for (d, v) in sort(collect(dist_selection_counts), by=last, rev=true)
    println("  $d: $v/$(length(labels)) folds")
end

=== Nested LOOCV: Multi-distance selection ===
Accuracy: 30/72 (41.7%)
Balanced accuracy: 36.6%
Macro-F1: 35.5%

Selected distances across folds:
  Cubical H1 Betti: 50/72 folds
  Rips Betti: 21/72 folds
  Cubical H1 PI: 1/72 folds

### Honest comparison summary

In [50]:
honest_results = []

# Nested RF on comprehensive stats
push!(honest_results, (
    method = "Nested LOOCV: Balanced RF (All H0+H1 stats, $(size(X_stats_all, 2)) features)",
    accuracy = nested_rf.accuracy,
    balanced_accuracy = nested_rf.balanced_accuracy,
    macro_f1 = nested_rf.macro_f1,
    n_correct = sum(nested_rf.predictions .== labels),
    n_total = length(labels),
    honest = "Yes"
))

# Nested SVM on multi-filtration (no PCA)
push!(honest_results, (
    method = "Nested LOOCV: SVM (Multi-filtration H0+H1, $(size(X_multi, 2)) features)",
    accuracy = nested_svm_multi.accuracy,
    balanced_accuracy = nested_svm_multi.balanced_accuracy,
    macro_f1 = nested_svm_multi.macro_f1,
    n_correct = sum(nested_svm_multi.predictions .== labels),
    n_total = length(labels),
    honest = "Yes"
))

# Nested SVM + PCA on multi-filtration
push!(honest_results, (
    method = "Nested LOOCV: SVM + PCA (Multi-filtration H0+H1)",
    accuracy = nested_svm_pca.accuracy,
    balanced_accuracy = nested_svm_pca.balanced_accuracy,
    macro_f1 = nested_svm_pca.macro_f1,
    n_correct = sum(nested_svm_pca.predictions .== labels),
    n_total = length(labels),
    honest = "Yes"
))

# RF with feature selection (one row per top_k value)
for top_k in [20, 30, 50]
    r_sel = loocv_rf_with_selection(
        X_multi, labels;
        n_trees_select = 500, n_trees_classify = 300,
        top_k = top_k, balanced = true, rng_seed = 20260223
    )
    push!(honest_results, (
        method = "RF + Feature Selection (top_k=$top_k, multi H0+H1)",
        accuracy = r_sel.accuracy,
        balanced_accuracy = r_sel.balanced_accuracy,
        macro_f1 = r_sel.macro_f1,
        n_correct = sum(r_sel.predictions .== labels),
        n_total = length(labels),
        honest = "Yes (selection inside fold)"
    ))
end

# Nested multi-distance selection
push!(honest_results, (
    method = "Nested LOOCV: Multi-distance k-NN",
    accuracy = nested_multi_dist.accuracy,
    balanced_accuracy = nested_multi_dist.balanced_accuracy,
    macro_f1 = nested_multi_dist.macro_f1,
    n_correct = n_corr_multi,
    n_total = length(labels),
    honest = "Yes"
))

# 1-NN on Rips Wasserstein distances (k is fixed at 1, so no hyperparameter selection is needed)
for (wname, D_wass) in [("Wass-1", D_wass1_rips), ("Wass-2", D_wass2_rips)]
    r_knn1 = loocv_knn(D_wass, labels; k = 1)
    m_knn1 = classification_metrics(labels, r_knn1.predictions)
    push!(honest_results, (
        method = "1-NN on Rips $(wname) (no tuning)",
        accuracy = r_knn1.accuracy,
        balanced_accuracy = m_knn1.balanced_accuracy,
        macro_f1 = m_knn1.macro_f1,
        n_correct = sum(r_knn1.predictions .== labels),
        n_total = length(labels),
        honest = "Yes (no hyperparams)"
    ))
end

honest_df = DataFrame(honest_results)
sort!(honest_df, :accuracy, rev = true)
honest_df

## Discussion

We applied multiple TDA filtration strategies to classify Diptera families from wing venation images. Key findings:

1.  **Multiple filtrations are complementary**: The Vietoris-Rips filtration on point-cloud samples captures the global loop structure of the wing venation. Directional height filtrations (8 directions) encode how topological features are spatially distributed along specific axes, the radial filtration captures the center-to-periphery organization, the EDT filtration captures vein thickness hierarchy, and cubical (grayscale sublevel-set) persistence captures intensity landscape information. Together, these views capture different geometric and topological aspects of the wing.

2.  **H0 persistence from directional/radial/EDT filtrations is informative**: H0 is largely uninformative for Rips on a point cloud sampled from a connected structure. In contrast, H0 from these filtrations captures vein branching: it records when disconnected vein segments merge as the sweep progresses. This is directly related to vein count and branching patterns, a key taxonomic character for Diptera families.

3.  **8 filtration directions improve coverage**: Expanding from 4 to 8 directions (every 22.5°) captures oblique vein angles missed previously, providing finer angular resolution of the venation topology.

4.  **EDT filtration captures vein thickness**: The Euclidean Distance Transform filtration captures the vein thickness hierarchy, which is a diagnostic character (e.g., Tabanidae have thickened C and Sc veins). This provides complementary information to structural topology.

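5.  **Feature selection reduces overfitting**: Random Forest importance-based selection, performed inside each CV fold to avoid leakage, sharply reduces the feature-to-sample ratio. On the multi-filtration features it reached 65.3%, 72.2% and 63.9% LOOCV accuracy for top_k = 20, 30 and 50 respectively. Because top_k itself was not chosen in an inner loop, the best of these three figures is still mildly optimistic.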

6.  **Multi-distance nested LOOCV provides honest model selection**: With many distance matrices available, the nested multi-distance evaluation selects the best (distance, k) combination in the inner loop and provides an unbiased accuracy estimate in the outer loop. On our data this estimate is 30/72 (41.7%), with the Cubical H1 Betti distance selected in 50 of 72 folds.

7.  **Wasserstein-2 vs Wasserstein-1**: Both Wasserstein distances are effective for comparing persistence diagrams, with W-2 penalizing large mismatches more heavily. The comparison reveals whether fine or coarse topological differences are more discriminative.

8.  **Statistical rigor**: We report LOOCV accuracy with Wilson confidence intervals, nested LOOCV for unbiased evaluation when hyperparameters are tuned, and permutation tests to verify that observed accuracy is significantly above chance level.
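
As a worked example of the interval mentioned in the last point, the Wilson score interval for a binomial proportion can be computed as sketched below. The function name `wilson_interval` is hypothetical, and the 95% level (`z = 1.96`) is an assumption about the level used in this analysis.

```julia
# Wilson score interval for k successes out of n trials.
# z = 1.96 corresponds to an (assumed) 95% confidence level.
function wilson_interval(k::Integer, n::Integer; z::Real = 1.96)
    p_hat = k / n
    denom = 1 + z^2 / n
    center = (p_hat + z^2 / (2n)) / denom
    half = z * sqrt(p_hat * (1 - p_hat) / n + z^2 / (4n^2)) / denom
    return (lower = center - half, upper = center + half)
end

# Example: 52 correct out of 72 (the top_k = 30 feature-selection result)
ci = wilson_interval(52, 72)
println(round.((ci.lower, ci.upper) .* 100; digits = 1))
```

For 52/72 this gives an interval of roughly 61% to 81%, which illustrates how wide the uncertainty remains with 72 samples.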

### Limitations

-   **Class imbalance**: Tipulidae has 12 samples while Pelecorhynchidae has only 2, which may affect some classifiers. The filtered analysis (excluding families with fewer than 3 samples) provides a fairer comparison.
-   Image quality and preprocessing parameters (blur, threshold) influence the extracted topological features.
-   The non-nested LOOCV results for feature-based classifiers are optimistically biased because hyperparameters were selected on the evaluation data. The honest comparison table should be preferred.
-   With only 72 samples, confidence intervals remain wide regardless of method.

### Future work

-   Extend dataset with more specimens per family, especially underrepresented families
-   Improve imaging/segmentation quality and reevaluate image-based filtrations with less noise sensitivity
-   Apply extended persistence or zigzag persistence for richer invariants
-   Evaluate gradient-boosted tree classifiers (e.g. XGBoost) on the tabular feature data
-   Apply deep learning to persistence images or directly to persistence diagrams