Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 167 additions & 88 deletions plots/dendrogram-basic/implementations/plotnine.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
""" pyplots.ai
dendrogram-basic: Basic Dendrogram
Library: plotnine 0.15.2 | Python 3.13.11
Quality: 91/100 | Created: 2025-12-23
Library: plotnine 0.15.3 | Python 3.14.3
Quality: 89/100 | Updated: 2026-04-05
"""

import numpy as np
import pandas as pd
from plotnine import (
aes,
annotate,
coord_cartesian,
element_blank,
element_line,
element_rect,
element_text,
geom_hline,
geom_point,
geom_segment,
geom_text,
ggplot,
guide_legend,
guides,
labs,
scale_color_manual,
scale_x_continuous,
Expand All @@ -22,112 +29,184 @@
theme_minimal,
)
from scipy.cluster.hierarchy import dendrogram, linkage
from sklearn.datasets import load_iris


# Data - Iris flower measurements (4 features for 15 samples)
# Data - Real iris flower measurements (15 samples, 5 per species)
iris = load_iris()
np.random.seed(42)
species_names = ["Setosa", "Versicolor", "Virginica"]
species_counts = dict.fromkeys(species_names, 0)
sample_labels = []
indices = np.concatenate([np.random.choice(np.where(iris.target == i)[0], 5, replace=False) for i in range(3)])
for i in indices:
name = species_names[iris.target[i]]
species_counts[name] += 1
sample_labels.append(f"{name}-{species_counts[name]}")
features = iris.data[indices]

# Hierarchical clustering with Ward's method
linkage_matrix = linkage(features, method="ward")
palette = {"Setosa": "#306998", "Versicolor": "#E8833A", "Virginica": "#55A868"}

# Extract dendrogram coordinates
dend = dendrogram(linkage_matrix, labels=sample_labels, no_plot=True)

# Track species composition of each node for branch coloring
n = len(sample_labels)
leaf_species = {lbl: lbl.rsplit("-", 1)[0] for lbl in sample_labels}
node_species = {}
for i, label in enumerate(sample_labels):
node_species[i] = {leaf_species[label]}
for i, row in enumerate(linkage_matrix):
left, right = int(row[0]), int(row[1])
node_species[n + i] = node_species[left] | node_species[right]

# Branch type label for each merge: species name if pure, "Mixed" if mixed
branch_type_labels = {"Setosa": "Setosa (pure)", "Versicolor": "Versicolor (pure)", "Virginica": "Virginica (pure)"}
merge_branch_types = []
for i in range(len(linkage_matrix)):
sp = node_species[n + i]
if len(sp) == 1:
merge_branch_types.append(branch_type_labels[next(iter(sp))])
else:
merge_branch_types.append("Mixed species")

# Simulate iris-like measurements: sepal length, sepal width, petal length, petal width
# Three species with distinct characteristics
samples_per_species = 5

labels = []
data = []

# Setosa: shorter petals, wider sepals
for i in range(samples_per_species):
labels.append(f"Setosa-{i + 1}")
data.append(
[
5.0 + np.random.randn() * 0.3, # sepal length
3.4 + np.random.randn() * 0.3, # sepal width
1.5 + np.random.randn() * 0.2, # petal length
0.3 + np.random.randn() * 0.1, # petal width
]
)

# Versicolor: medium measurements
for i in range(samples_per_species):
labels.append(f"Versicolor-{i + 1}")
data.append(
[
5.9 + np.random.randn() * 0.4, # sepal length
2.8 + np.random.randn() * 0.3, # sepal width
4.3 + np.random.randn() * 0.4, # petal length
1.3 + np.random.randn() * 0.2, # petal width
]
)

# Virginica: longer petals and sepals
for i in range(samples_per_species):
labels.append(f"Virginica-{i + 1}")
data.append(
[
6.6 + np.random.randn() * 0.5, # sepal length
3.0 + np.random.randn() * 0.3, # sepal width
5.5 + np.random.randn() * 0.5, # petal length
2.0 + np.random.randn() * 0.3, # petal width
]
)

data = np.array(data)

# Compute hierarchical clustering using Ward's method
linkage_matrix = linkage(data, method="ward")

# Extract dendrogram coordinates using scipy (no_plot=True returns coordinates only)
dend = dendrogram(linkage_matrix, labels=labels, no_plot=True)
# Map dendrogram order to linkage order via merge heights
height_to_merge = {}
for i, h in enumerate(linkage_matrix[:, 2]):
height_to_merge.setdefault(round(h, 10), []).append(i)

# Convert dendrogram coordinates to segment data for plotnine
# icoord contains x coords (pairs of 4 for each merge)
# dcoord contains y coords (pairs of 4 for each merge)
# Build segment dataframe
segments = []
color_threshold = 0.7 * max(linkage_matrix[:, 2])

for xs, ys in zip(dend["icoord"], dend["dcoord"], strict=True):
# Each merge has 4 points forming a U-shape: [x1, x2, x3, x4], [y1, y2, y3, y4]
# We need 3 segments: left vertical, horizontal, right vertical

# Determine color based on height (merge distance)
merge_height = max(ys)
if merge_height > color_threshold:
color = "#306998" # Python Blue for high-level merges
h = round(max(ys), 10)
if h in height_to_merge and height_to_merge[h]:
merge_idx = height_to_merge[h].pop(0)
btype = merge_branch_types[merge_idx]
else:
color = "#FFD43B" # Python Yellow for low-level merges

segments.append({"x": xs[0], "xend": xs[1], "y": ys[0], "yend": ys[1], "color": color})
segments.append({"x": xs[1], "xend": xs[2], "y": ys[1], "yend": ys[2], "color": color})
segments.append({"x": xs[2], "xend": xs[3], "y": ys[2], "yend": ys[3], "color": color})
btype = "Mixed species"
segments.append({"x": xs[0], "xend": xs[1], "y": ys[0], "yend": ys[1], "branch_type": btype})
segments.append({"x": xs[1], "xend": xs[2], "y": ys[1], "yend": ys[2], "branch_type": btype})
segments.append({"x": xs[2], "xend": xs[3], "y": ys[2], "yend": ys[3], "branch_type": btype})

segments_df = pd.DataFrame(segments)

# Create label data using the actual leaf positions from dendrogram
# dend['leaves'] gives the order, and x positions are at 5, 15, 25, ... (spacing of 10)
leaf_positions = [(i + 1) * 10 - 5 for i in range(len(dend["ivl"]))]
ivl = dend["ivl"] # Reordered labels from dendrogram
label_df = pd.DataFrame({"x": leaf_positions, "label": ivl, "y": [-0.8] * len(ivl)})
# Leaf labels with species-based coloring
n_leaves = len(dend["ivl"])
leaf_positions = [(i + 1) * 10 - 5 for i in range(n_leaves)]
leaf_labels = dend["ivl"]
leaf_btypes = [branch_type_labels[leaf_species[lbl]] for lbl in leaf_labels]
label_df = pd.DataFrame({"x": leaf_positions, "label": leaf_labels, "y": [0.0] * n_leaves, "branch_type": leaf_btypes})

# Ordered category for consistent legend
category_order = ["Setosa (pure)", "Versicolor (pure)", "Virginica (pure)", "Mixed species"]
color_map = {
"Setosa (pure)": palette["Setosa"],
"Versicolor (pure)": palette["Versicolor"],
"Virginica (pure)": palette["Virginica"],
"Mixed species": "#888888",
}
segments_df["branch_type"] = pd.Categorical(segments_df["branch_type"], categories=category_order, ordered=True)
label_df["branch_type"] = pd.Categorical(label_df["branch_type"], categories=category_order, ordered=True)

# Merge node points - highlight where clusters join (plotnine geom_point layer)
merge_nodes = []
for xs, ys, btype in zip(dend["icoord"], dend["dcoord"], merge_branch_types, strict=True):
cx = (xs[1] + xs[2]) / 2
cy = max(ys)
merge_nodes.append({"x": cx, "y": cy, "branch_type": btype})
merge_df = pd.DataFrame(merge_nodes)
merge_df["branch_type"] = pd.Categorical(merge_df["branch_type"], categories=category_order, ordered=True)

# Key merge threshold: where Setosa separates from the rest
setosa_sep_height = linkage_matrix[-2, 2]
threshold_df = pd.DataFrame({"yintercept": [setosa_sep_height]})

# Plot
y_max = max(linkage_matrix[:, 2]) * 1.08
x_min = min(segments_df["x"].min(), segments_df["xend"].min())
x_max = max(segments_df["x"].max(), segments_df["xend"].max())
x_pad = (x_max - x_min) * 0.06

# Plot using plotnine's native geom_segment
plot = (
ggplot()
+ geom_segment(aes(x="x", xend="xend", y="y", yend="yend", color="color"), data=segments_df, size=1.8)
+ geom_text(aes(x="x", y="y", label="label"), data=label_df, angle=45, ha="right", va="top", size=9)
+ scale_color_manual(values={"#306998": "#306998", "#FFD43B": "#FFD43B"}, guide=None)
+ scale_x_continuous(breaks=[], expand=(0.12, 0.05))
+ scale_y_continuous(expand=(0.25, 0.02))
+ labs(x="Sample", y="Distance (Ward)", title="dendrogram-basic · plotnine · pyplots.ai")
# Dendrogram branches - thicker for HD visibility
+ geom_segment(aes(x="x", xend="xend", y="y", yend="yend", color="branch_type"), data=segments_df, size=2.2)
# Threshold line using idiomatic geom_hline
+ geom_hline(aes(yintercept="yintercept"), data=threshold_df, linetype="dashed", color="#AAAAAA", size=0.8)
# Threshold annotation using plotnine annotate
+ annotate(
"text",
x=x_max - x_pad,
y=setosa_sep_height + 0.35,
label="Setosa separates",
size=13,
color="#555555",
fontstyle="italic",
ha="right",
)
# Intermixing annotation - data storytelling for Versicolor/Virginica
+ annotate(
"text",
x=x_max - x_pad,
y=linkage_matrix[-1, 2] * 0.55,
label="Versicolor & Virginica intermixed",
size=12,
color="#888888",
fontstyle="italic",
ha="right",
)
# Leaf labels - larger for readability
+ geom_text(
aes(x="x", y="y", label="label", color="branch_type"),
data=label_df,
angle=45,
ha="right",
va="top",
size=13,
nudge_y=-0.3,
show_legend=False,
)
# Merge node markers - emphasize join points
+ geom_point(aes(x="x", y="y", color="branch_type"), data=merge_df, size=3.5, show_legend=False)
+ scale_color_manual(values=color_map, name="Branch Type")
+ guides(color=guide_legend(override_aes={"size": 4, "alpha": 1}))
+ scale_x_continuous(breaks=[], expand=(0.04, 0))
+ scale_y_continuous(breaks=np.arange(0, y_max, 2).tolist(), expand=(0.10, 0))
+ coord_cartesian(xlim=(x_min - x_pad, x_max + x_pad), ylim=(-2.5, y_max))
+ labs(
x="",
y="Ward Linkage Distance",
title="Iris Species Clustering · dendrogram-basic · plotnine · pyplots.ai",
subtitle="Hierarchical clustering of 15 iris samples using Ward's minimum variance method",
)
+ theme_minimal()
+ theme(
figure_size=(16, 9),
text=element_text(size=14),
axis_title=element_text(size=20),
axis_text=element_text(size=16),
text=element_text(size=14, family="sans-serif"),
axis_title_x=element_blank(),
axis_title_y=element_text(size=20, margin={"r": 12}),
axis_text=element_text(size=16, color="#444444"),
axis_text_x=element_blank(),
plot_title=element_text(size=24),
axis_ticks_major_x=element_blank(),
plot_title=element_text(size=24, weight="bold", margin={"b": 4}),
plot_subtitle=element_text(size=15, color="#666666", margin={"b": 12}),
plot_background=element_rect(fill="#FAFAFA", color="none"),
panel_background=element_rect(fill="#FAFAFA", color="none"),
panel_grid_major_x=element_blank(),
panel_grid_minor_x=element_blank(),
panel_grid_major_y=element_line(alpha=0.3, linetype="dashed"),
panel_grid_minor_y=element_blank(),
panel_grid_major_y=element_line(alpha=0.2, size=0.5, color="#CCCCCC"),
legend_title=element_text(size=16, weight="bold"),
legend_text=element_text(size=14),
legend_position="right",
legend_background=element_rect(fill="#FAFAFA", color="#DDDDDD", size=0.5),
legend_key=element_rect(fill="none", color="none"),
plot_margin=0.02,
)
)

plot.save("plot.png", dpi=300)
# Save with tight layout
fig = plot.draw()
fig.savefig("plot.png", dpi=300, bbox_inches="tight")
Loading
Loading