# Metadata

**L1 Taxonomy** - Problem Solving & Algorithms

**L2 Taxonomy** - Machine Learning Algorithms

**Subtopic** - Scikit-learn integration for explainable k-means

**Use Case** - Create a Python solution that integrates scikit-learn’s k-means clustering with an additional explainability layer. After obtaining an initial partition, implement a greedy decision tree method that refines the clusters by selecting axis-aligned cuts based on weighted depth and cost metrics. This integration offers a practical way to deliver interpretable machine learning models in domains where auditability and transparency are priorities.

**Programming Language** - Python

**Target Model** - o1

# Model Breaking Hints


1) **What is the initial use case?**

   The initial problem involves creating a Python solution that combines scikit-learn’s k-means clustering with an explainability layer. After obtaining clusters via k-means, it requires implementing a greedy decision tree to refine these clusters using axis-aligned cuts based on weighted depth and cost metrics, enhancing interpretability for domains prioritizing auditability and transparency.

2) **Why is the initial use case easy?**

   - It uses standard algorithms like k-means clustering and basic decision trees, which are well-documented and have extensive library support.
   - The clustering and refinement steps are performed on static data, eliminating the complexity of handling dynamic or streaming inputs.
   - The axis-aligned cuts are straightforward to implement and don't require advanced mathematical concepts.
   - There are no additional constraints like handling high-dimensional data, complex distance metrics, or multi-objective optimization.

3) **How could we make it harder?**

   - **Incorporate online clustering over streaming data**: Introduce the need to handle real-time data updates efficiently using advanced data structures like dynamic kd-trees.
   - **Handle high-dimensional data with manifold learning**: Extend the model to work with high-dimensional data while maintaining interpretability, requiring sophisticated mathematical reasoning.
   - **Introduce multi-objective optimization**: Refine clusters using a decision tree that not only minimizes variance but also maximizes interpretability and satisfies constraints like fairness and privacy.
   - **Adapt to non-Euclidean spaces with custom metric learning**: Require the algorithm to operate in non-standard distance spaces, where obscure corner cases challenge traditional clustering approaches.
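
The streaming variant in the first bullet can be prototyped with scikit-learn's `MiniBatchKMeans.partial_fit`, which updates the centroids one batch at a time. This is a hedged sketch, not the dynamic kd-tree design the bullet describes: `stream_batches` is a hypothetical synthetic stream with two fixed centers, and `online_kmeans` is an illustrative wrapper.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans


def stream_batches(n_batches: int, batch_size: int, seed: int = 0):
    """Yield synthetic 2-D mini-batches scattered around two fixed centers (hypothetical stream)."""
    rng = np.random.default_rng(seed)
    centers = np.array([[0.0, 0.0], [10.0, 10.0]])
    for _ in range(n_batches):
        picks = rng.integers(0, len(centers), size=batch_size)
        yield centers[picks] + rng.normal(scale=0.5, size=(batch_size, 2))


def online_kmeans(n_clusters: int = 2, n_batches: int = 20, batch_size: int = 50) -> MiniBatchKMeans:
    """Fit k-means incrementally: each partial_fit call sees only one batch."""
    model = MiniBatchKMeans(n_clusters=n_clusters, n_init=3, random_state=0)
    for batch in stream_batches(n_batches, batch_size):
        # Centroids are updated in place; earlier batches need not be kept in memory
        model.partial_fit(batch)
    return model


model = online_kmeans()
print(model.cluster_centers_)
```

Because only the current batch is in memory, the same loop works for an unbounded stream; an explainability tree would then have to be rebuilt or updated as the centroids drift.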

4) **Which parameters can we change?**

   - **Data Nature**: Shift from static to streaming data to introduce the complexity of online clustering.
   - **Data Dimensionality**: Move from low-dimensional to high-dimensional data, necessitating techniques like manifold learning or topological data analysis.
   - **Distance Metrics**: Replace standard Euclidean metrics with custom, learned metrics to handle non-Euclidean spaces.
   - **Refinement Method**: Upgrade from simple greedy decision trees to multi-objective decision trees that address multiple constraints simultaneously.
   - **Constraints**: Introduce additional constraints such as fairness and privacy, which require careful consideration during the clustering and refinement processes.
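
The "custom, learned metrics" parameter can be prototyped with a Mahalanobis metric, the simplest linear metric. If `inv(S) = L @ L.T`, mapping each row `x` to `x @ L` makes squared Euclidean distance in the mapped space equal the squared Mahalanobis distance, so ordinary k-means on the mapped rows is k-means under that metric. In this sketch `S` is the sample covariance estimated from the data (an assumption); a genuine metric-learning method would learn the matrix from supervision or constraints instead, and both function names are hypothetical.

```python
import numpy as np
from sklearn.cluster import KMeans


def mahalanobis_transform(X: np.ndarray, ridge: float = 1e-9) -> np.ndarray:
    """Map rows of X so squared Euclidean distance equals (x - y) inv(S) (x - y)^T."""
    X = np.asarray(X, dtype=np.float64)
    # Small ridge keeps S invertible when features are (nearly) collinear
    S = np.cov(X, rowvar=False) + ridge * np.eye(X.shape[1])
    # inv(S) = L @ L.T, so ||(x - y) @ L||^2 = (x - y) @ inv(S) @ (x - y)
    L = np.linalg.cholesky(np.linalg.inv(S))
    return X @ L


def kmeans_under_metric(X: np.ndarray, k: int, seed: int = 0) -> np.ndarray:
    """Run ordinary Euclidean k-means in the transformed space."""
    Z = mahalanobis_transform(X)
    return KMeans(n_clusters=k, n_init=10, random_state=seed).fit_predict(Z)
```

Since the map is linear, centroids found in the mapped space map back to means in the original space, and axis-aligned cuts should still be taken on the original features if they are to stay interpretable.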

5) **What can be a final hard prompt?**

   By combining these complexities, we arrive at a challenging problem that requires advanced knowledge in data structures, algorithms, and mathematical reasoning:

   "Develop a Python algorithm that performs online clustering of streaming, high-dimensional, non-Euclidean data by integrating dynamic kd-trees with manifold learning techniques. After initial clustering, refine the clusters using a multi-objective decision tree that minimizes intra-cluster variance, maximizes interpretability, and satisfies constraints like fairness and privacy, all while adapting to custom metric learning in non-standard distance spaces."

# Setup

```requirements.txt
scikit-learn>=1.2.0
numpy>=1.23.0

```


# Prompt

I want to build a Python program that uses scikit-learn's k-means clustering to segment a dataset into meaningful groups and then enhances this segmentation with an explainability layer. The explainability layer must apply a custom greedy decision tree algorithm that selects axis-aligned cuts to refine the initial k-means clusters.


**Input Format**

Three parameters will be provided as inputs:

  - data: A NumPy array of shape (n_samples, n_features), where all values are numeric.
  - k: An integer specifying the number of clusters for the initial k-means. It must satisfy 1 <= k <= 15.
  - max_depth: An integer indicating the maximum depth of the greedy decision tree (>= 1).

- Each row of data represents a sample, and each column represents a feature.


**Output Format**

- The program must return a dictionary with the following keys:

  - refined_labels: A list of integers of length n_samples, representing the refined cluster assignment for each sample.
  - tree_structure: A list of dictionaries, each describing a node in the decision tree with the keys:

    - depth (int): Depth of the node.
    - split_feature (int): Index of the feature used for splitting, or -1 if it is a leaf.
    - split_value (float): Value of the axis-aligned cut, or None if it is a leaf.
    - cluster_distribution (dict): Mapping from cluster label to count of samples in this node.
  - initial_centroids: A list of k lists, each being a centroid from the original k-means clustering.
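
To make the output contract concrete, here is a hedged sketch of a checker for a result dictionary. `check_output_schema` is a hypothetical helper, not part of the task; it encodes only the keys above plus the leaf convention (`split_feature == -1` with `split_value` None) stated in the requirements.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = {"refined_labels", "tree_structure", "initial_centroids"}


def check_output_schema(result: Dict[str, Any], n_samples: int, k: int) -> List[str]:
    """Return human-readable problems with an explainable_kmeans result (empty list if none)."""
    if set(result) != REQUIRED_KEYS:
        return ["unexpected top-level keys"]
    problems: List[str] = []
    labels = result["refined_labels"]
    if len(labels) != n_samples or not all(isinstance(v, int) for v in labels):
        problems.append("refined_labels must be a list of n_samples ints")
    if len(result["initial_centroids"]) != k:
        problems.append("initial_centroids must contain k centroids")
    for node in result["tree_structure"]:
        # Leaves carry split_feature == -1 and split_value None; internal nodes carry both
        if (node["split_feature"] == -1) != (node["split_value"] is None):
            problems.append(f"leaf/split mismatch at depth {node['depth']}")
        if not isinstance(node["cluster_distribution"], dict):
            problems.append(f"cluster_distribution is not a dict at depth {node['depth']}")
    return problems
```

Applied to the expected output of Example 1, the checker returns an empty list.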


**Example 1:**

```python
Input:
  data = np.array([[1.0, 2.0], [1.1, 1.9], [9.0, 8.8], [9.1, 9.0]])
  k = 2
  max_depth = 2

Output:
  {
    "refined_labels": [0, 0, 1, 1],
    "tree_structure": [
      {"depth": 0, "split_feature": 0, "split_value": 5.0, "cluster_distribution": {0: 2, 1: 2}},
      {"depth": 1, "split_feature": -1, "split_value": None, "cluster_distribution": {0: 2}},
      {"depth": 1, "split_feature": -1, "split_value": None, "cluster_distribution": {1: 2}}
    ],
    "initial_centroids": [[1.05, 1.95], [9.05, 8.9]]
  }
```
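
The root split in Example 1 can be scored by hand under two assumptions borrowed from the reference solution later in this document: candidate cuts sit at midpoints between consecutive unique feature values, and each cut is scored by the sample-weighted Gini impurity of the k-means labels on either side. The helper names `gini` and `best_midpoint_cut` are hypothetical.

```python
import numpy as np


def gini(labels: np.ndarray) -> float:
    """Gini impurity 1 - sum(p_c^2), computed in float64."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts.astype(np.float64) / labels.size
    return float(1.0 - np.sum(p * p))


def best_midpoint_cut(values: np.ndarray, labels: np.ndarray):
    """Scan midpoints between consecutive unique values; return (threshold, weighted Gini)."""
    uniq = np.unique(values)
    best = (None, np.inf)
    for lo, hi in zip(uniq[:-1], uniq[1:]):
        t = (lo + hi) / 2.0
        left = values <= t
        score = (left.sum() * gini(labels[left]) + (~left).sum() * gini(labels[~left])) / values.size
        if score < best[1]:  # strict "<" keeps the first of equally good cuts
            best = (float(t), float(score))
    return best


x0 = np.array([1.0, 1.1, 9.0, 9.1])  # feature 0 of Example 1
y = np.array([0, 0, 1, 1])           # k-means labels from Example 1
print(best_midpoint_cut(x0, y))
```

The three candidates 1.05, 5.05 and 9.05 score 1/3, 0.0 and 1/3, so this rule returns a threshold of 5.05 with weighted Gini 0.0; the 5.0 shown in the expected output appears to be a rounded or illustrative value.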

**Example 2:**

```python
Input:
  data = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
  k = 1
  max_depth = 1

Output:
  {
    "refined_labels": [0, 0, 0, 0],
    "tree_structure": [],
    "initial_centroids": [[0.5, 0.5]]
  }
```

# Requirements

**Explicit and Implicit Points:**

  - The refined_labels must reflect the updated groupings after applying the decision tree.
  - If the decision tree reaches max_depth or no split improves separation, stop splitting.
  - Nodes in tree_structure must appear in depth-first (preorder) order.
  - If a node is a leaf, it must have split_feature: -1 and split_value: None.
  - The cluster_distribution must include all clusters observed in the node.
  - Use axis-aligned splits only (a threshold on a single feature). Do not use multi-dimensional (oblique) splits.
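
The depth-first ordering requirement amounts to a preorder traversal: emit the parent, then the whole left subtree, then the right subtree. The sketch below assumes a hypothetical `Node` dataclass (not the `TreeNode` of this spec) whose `children` list is either empty or `[left, right]`, and uses an explicit stack so deep trees cannot hit Python's recursion limit.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Node:
    """Hypothetical minimal node mirroring the tree_structure fields."""
    depth: int
    split_feature: int = -1
    split_value: Optional[float] = None
    cluster_distribution: Dict[int, int] = field(default_factory=dict)
    children: List["Node"] = field(default_factory=list)  # [left, right] or []


def to_preorder(root: Node) -> List[dict]:
    """Serialize nodes parent-first, left subtree before right (depth-first preorder)."""
    out: List[dict] = []
    stack = [root]
    while stack:
        node = stack.pop()
        out.append({
            "depth": node.depth,
            "split_feature": node.split_feature,
            "split_value": node.split_value,
            "cluster_distribution": node.cluster_distribution,
        })
        # Push the right child first so the left child is popped (and emitted) next
        stack.extend(reversed(node.children))
    return out
```

For a root whose left child splits again and whose right child is a leaf, the emitted depths are `[0, 1, 2, 2, 1]`.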

**Solution Expectations:**

  - The model should cluster the data using k-means.
  - The explainability layer should simulate how decisions refine the clusters.

**Function Signature:**

```python
def explainable_kmeans(data: np.ndarray, k: int, max_depth: int) -> dict:
```

**Classes or Structures**

```python
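from typing import Dict, Optional


class TreeNode:
    def __init__(self, depth: int):
        self.depth: int = depth
        self.split_feature: int = -1              # -1 marks a leaf
        self.split_value: Optional[float] = None  # None marks a leaf
        self.cluster_distribution: Dict[int, int] = {}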
```


**Edge Case Behavior**

- Empty dataset: Raise ValueError "Input data is empty".
- All identical points: Return 1 cluster, empty tree.
- k > unique samples: Raise ValueError "k exceeds number of unique samples".
- max_depth = 0: Tree is never built; return initial clusters.
- 1D feature space: The tree can still split along the single axis.
- Floating-point instability: Ensure centroids and splits use np.float64 for precision.
- Dataset with NaNs: Raise ValueError "Input contains NaNs".
- High-dimensional data (>= 100 features): The tree should consider only the 10 features with the highest variance.
- Overlapping clusters: The tree should try to separate them but may retain some overlap.
- Small sample set (n < k): Raise ValueError "Not enough samples for k clusters".
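
The high-dimensional rule can be sketched as a feature filter applied before the split search. This is a minimal version under stated assumptions: `top_variance_features` is a hypothetical name, the 100-column threshold and top-10 count come from the edge case above, and ties in variance are broken by the lower feature index (the spec does not say how ties are handled).

```python
import numpy as np


def top_variance_features(X: np.ndarray, n_top: int = 10, threshold: int = 100) -> list:
    """Return candidate split features: all of them, or the n_top highest-variance ones
    when X has at least `threshold` columns."""
    n_features = X.shape[1]
    if n_features < threshold:
        return list(range(n_features))
    variances = np.var(X.astype(np.float64), axis=0)
    # A stable sort on negated variance keeps lower indices first among ties
    order = np.argsort(-variances, kind="stable")
    return sorted(order[:n_top].tolist())
```

Whether variance is measured on the full dataset or on each node's samples is a design choice; the reference solution in this document recomputes it per node.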

**Constraints**

- You must not use eval() or exec() anywhere in the code.
- You must ensure that every feature considered for splitting has at least 2 unique values in the node.
- All calculations for cost must be in float64 precision. Do not cast to lower precision types.
- Use only the numpy, scikit-learn, and typing libraries. Do not use pandas or any visualization package.
- You must write your own function to calculate the cluster distribution within a node. Do not use collections.Counter.
- All inputs must be validated. Invalid types or out-of-bounds values must raise clear ValueError messages stating the exact reason.
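
One way to satisfy the "own function, no collections.Counter" constraint is to count with `np.unique(..., return_counts=True)`. This is a sketch with a hypothetical function name; it converts keys and counts to built-in `int` so the result compares equal to literal dicts such as `{0: 2, 1: 2}` and contains exactly the clusters observed in the node.

```python
from typing import Dict

import numpy as np


def cluster_distribution(labels: np.ndarray) -> Dict[int, int]:
    """Count samples per cluster label present in a node, without collections.Counter."""
    values, counts = np.unique(np.asarray(labels), return_counts=True)
    # Built-in ints keep the dict JSON-friendly and equal to plain-int literals
    return {int(v): int(c) for v, c in zip(values, counts)}
```

For an empty node the function returns `{}`, so callers that require a non-empty distribution should check the sample count first.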


# code

import numpy as np
from sklearn.cluster import KMeans
from typing import Dict, List, Tuple, Optional


class TreeNode:
    """Represents a node in the decision tree."""

    def __init__(self, depth: int):
        self.depth = depth
        self.split_feature = -1  # -1 for leaf nodes
        self.split_value = None  # None for leaf nodes
        self.cluster_distribution = {}
        self.left_child = None
        self.right_child = None
        self.sample_indices = []


def explainable_kmeans(data: np.ndarray, k: int, max_depth: int) -> dict:
    """
    Performs k-means clustering with explainable decision tree refinement.
    """
    # Input validation
    _validate_inputs(data, k, max_depth)

    # Perform initial k-means clustering (float64 throughout)
    kmeans_result = _perform_kmeans(data, k)
    initial_labels = kmeans_result['labels']
    initial_centroids = kmeans_result['centroids']

    # Edge case: max_depth = 0 means the tree is never built,
    # so the initial k-means labels are returned unchanged
    if max_depth == 0:
        return {
            "refined_labels": initial_labels.tolist(),
            "tree_structure": [],
            "initial_centroids": initial_centroids.tolist()
        }

    # Build decision tree for explainability
    root_node = _build_decision_tree(data, initial_labels, max_depth)

    # Extract refined labels from tree
    refined_labels = _extract_refined_labels(data, root_node, len(data))

    # Convert tree to required format
    tree_structure = _convert_tree_to_structure(root_node)

    return {
        "refined_labels": refined_labels.tolist(),
        "tree_structure": tree_structure,
        "initial_centroids": initial_centroids.tolist()
    }


def _validate_inputs(data: np.ndarray, k: int, max_depth: int) -> None:
    """Validates all input parameters according to constraints."""

    # Check data type and format
    if not isinstance(data, np.ndarray):
        raise ValueError("Data must be a numpy array")

    if data.size == 0:
        raise ValueError("Input data is empty")

    if data.ndim != 2:
        raise ValueError("Data must be 2-dimensional")

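    # Reject non-numeric arrays up front; np.isnan raises TypeError on them
    if not np.issubdtype(data.dtype, np.number):
        raise ValueError("Data must contain only numeric values")

    # Check for NaNs
    if np.isnan(data).any():
        raise ValueError("Input contains NaNs")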

    # Validate k parameter (bool is a subclass of int, so reject it explicitly)
    if isinstance(k, bool) or not isinstance(k, int) or k < 1 or k > 15:
        raise ValueError("k must be an integer between 1 and 15")

    # Check unique samples constraint
    unique_samples = len(np.unique(data, axis=0))
    if k > unique_samples:
        raise ValueError("k exceeds number of unique samples")

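    # Check minimum samples constraint (defensive only: the unique-samples check
    # above already fails whenever len(data) < k, because unique rows <= len(data))
    if len(data) < k:
        raise ValueError("Not enough samples for k clusters")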

    # Validate max_depth (0 is allowed and skips the tree; bool is rejected explicitly)
    if isinstance(max_depth, bool) or not isinstance(max_depth, int) or max_depth < 0:
        raise ValueError("max_depth must be a non-negative integer")


def _perform_kmeans(data: np.ndarray, k: int) -> Dict:
    """Performs k-means clustering with float64 precision."""

    # Ensure float64 precision
    data_float64 = data.astype(np.float64)

    # Handle single cluster case
    if k == 1:
        labels = np.zeros(len(data), dtype=int)
        centroids = np.mean(data_float64, axis=0, keepdims=True)
        return {'labels': labels, 'centroids': centroids}

    # Perform k-means clustering
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
    labels = kmeans.fit_predict(data_float64)
    centroids = kmeans.cluster_centers_

    return {'labels': labels, 'centroids': centroids}


def _build_decision_tree(data: np.ndarray, labels: np.ndarray, max_depth: int) -> TreeNode:
    """Builds a greedy decision tree for cluster refinement."""

    root_indices = np.arange(len(data))
    root_node = TreeNode(depth=0)
    root_node.sample_indices = root_indices
    root_node.cluster_distribution = _calculate_cluster_distribution(labels[root_indices])

    # Build tree recursively
    _build_tree_recursive(root_node, data, labels, max_depth)

    return root_node


def _build_tree_recursive(node: TreeNode, data: np.ndarray, labels: np.ndarray, max_depth: int) -> None:
    """Recursively builds the decision tree using greedy axis-aligned splits."""

    # Stop if max depth reached
    if node.depth >= max_depth:
        return

    # Stop if node is pure or has insufficient samples
    if len(set(labels[node.sample_indices])) <= 1 or len(node.sample_indices) <= 1:
        return

    # Find best split
    best_split = _find_best_split(data, labels, node.sample_indices)

    if best_split is None:
        return  # No valid split found

    feature_idx, split_value, left_indices, right_indices = best_split

    # Set node split information
    node.split_feature = feature_idx
    node.split_value = float(split_value)

    # Create child nodes
    node.left_child = TreeNode(depth=node.depth + 1)
    node.left_child.sample_indices = left_indices
    node.left_child.cluster_distribution = _calculate_cluster_distribution(labels[left_indices])

    node.right_child = TreeNode(depth=node.depth + 1)
    node.right_child.sample_indices = right_indices
    node.right_child.cluster_distribution = _calculate_cluster_distribution(labels[right_indices])

    # Recursively build subtrees
    _build_tree_recursive(node.left_child, data, labels, max_depth)
    _build_tree_recursive(node.right_child, data, labels, max_depth)


def _find_best_split(data: np.ndarray, labels: np.ndarray, sample_indices: np.ndarray) -> Optional[Tuple]:
    """Finds the best axis-aligned split for the given samples."""

    node_data = data[sample_indices]
    node_labels = labels[sample_indices]

    # For high dimensional data, consider only top 10 most variant features
    features_to_consider = _select_features_to_consider(node_data)

    best_split = None
    best_score = float('inf')

    for feature_idx in features_to_consider:
        feature_values = node_data[:, feature_idx]

        # Check if feature has at least 2 unique values
        unique_values = np.unique(feature_values)
        if len(unique_values) < 2:
            continue

        # Try splits at midpoints between consecutive unique values
        for i in range(len(unique_values) - 1):
            split_value = (unique_values[i] + unique_values[i + 1]) / 2.0

            # Create split
            left_mask = feature_values <= split_value
            right_mask = ~left_mask

            # Skip if either side is empty
            if not np.any(left_mask) or not np.any(right_mask):
                continue

            # Calculate split quality using weighted impurity
            score = _calculate_split_score(node_labels, left_mask, right_mask)

            if score < best_score:
                best_score = score
                left_indices = sample_indices[left_mask]
                right_indices = sample_indices[right_mask]
                best_split = (feature_idx, split_value, left_indices, right_indices)

    return best_split


def _select_features_to_consider(data: np.ndarray) -> List[int]:
    """Selects features to consider for splitting, limiting to top 10 for high-dimensional data."""

    n_features = data.shape[1]

    # For high dimensional data (>=100 features), select top 10 most variant
    if n_features >= 100:
        variances = np.var(data, axis=0)
        top_features = np.argsort(variances)[-10:]
        return top_features.tolist()

    return list(range(n_features))


def _calculate_split_score(labels: np.ndarray, left_mask: np.ndarray, right_mask: np.ndarray) -> float:
    """Calculates weighted impurity score for a split using float64 precision."""

    total_samples = len(labels)
    left_samples = np.sum(left_mask)
    right_samples = np.sum(right_mask)

    # Calculate weighted gini impurity
    left_weight = left_samples / total_samples
    right_weight = right_samples / total_samples

    left_gini = _calculate_gini_impurity(labels[left_mask])
    right_gini = _calculate_gini_impurity(labels[right_mask])

    # Return weighted impurity as float64
    weighted_impurity = np.float64(left_weight * left_gini + right_weight * right_gini)
    return weighted_impurity


def _calculate_gini_impurity(labels: np.ndarray) -> float:
    """Calculates Gini impurity for a set of labels with float64 precision."""

    if len(labels) == 0:
        return 0.0

    # Calculate cluster distribution
    cluster_dist = _calculate_cluster_distribution(labels)
    total_samples = len(labels)

    # Calculate Gini impurity
    gini = np.float64(1.0)
    for count in cluster_dist.values():
        probability = np.float64(count) / np.float64(total_samples)
        gini -= probability * probability

    return gini


def _calculate_cluster_distribution(labels: np.ndarray) -> Dict[int, int]:
    """Calculates cluster distribution without using collections.Counter."""

    distribution = {}
    for label in labels:
        label_int = int(label)
        if label_int in distribution:
            distribution[label_int] += 1
        else:
            distribution[label_int] = 1

    return distribution


def _extract_refined_labels(data: np.ndarray, root_node: TreeNode, n_samples: int) -> np.ndarray:
    """Extracts refined cluster labels from the decision tree."""

    refined_labels = np.zeros(n_samples, dtype=int)

    # Assign cluster labels based on tree structure
    _assign_labels_recursive(root_node, refined_labels, 0)

    return refined_labels


def _assign_labels_recursive(node: TreeNode, labels: np.ndarray, cluster_id: int) -> int:
    """Recursively assigns cluster labels to samples based on tree structure."""

    if node.split_feature == -1:  # Leaf node
        # Assign current cluster_id to all samples in this leaf
        for idx in node.sample_indices:
            labels[idx] = cluster_id
        return cluster_id + 1
    else:
        # Process left subtree first, then right subtree
        cluster_id = _assign_labels_recursive(node.left_child, labels, cluster_id)
        cluster_id = _assign_labels_recursive(node.right_child, labels, cluster_id)
        return cluster_id


def _convert_tree_to_structure(root_node: TreeNode) -> List[Dict]:
    """Converts tree to required output format in depth-first order."""

    if root_node is None:
        return []

    structure = []
    _traverse_tree_depth_first(root_node, structure)
    return structure


def _traverse_tree_depth_first(node: TreeNode, structure: List[Dict]) -> None:
    """Traverses tree in depth-first order and builds structure list."""

    # Add current node to structure
    node_dict = {
        "depth": node.depth,
        "split_feature": node.split_feature,
        "split_value": node.split_value,
        "cluster_distribution": node.cluster_distribution
    }
    structure.append(node_dict)

    # Traverse children in depth-first order
    if node.left_child is not None:
        _traverse_tree_depth_first(node.left_child, structure)

    if node.right_child is not None:
        _traverse_tree_depth_first(node.right_child, structure)



# tests
import unittest
import numpy as np
from main import explainable_kmeans

class TestExplainableKMeans(unittest.TestCase):

    def test_simple_two_clusters(self):
        data = np.array([[1.0, 2.0], [1.1, 1.9], [9.0, 8.8], [9.1, 9.0]])
        result = explainable_kmeans(data, k=2, max_depth=2)
        self.assertEqual(len(result["refined_labels"]), 4)
        self.assertEqual(len(result["initial_centroids"]), 2)
        self.assertTrue(all(isinstance(node, dict) for node in result["tree_structure"]))

    def test_single_cluster_no_split(self):
        data = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
        result = explainable_kmeans(data, k=1, max_depth=1)
        self.assertEqual(result["tree_structure"], [
            {
                "depth": 0,
                "split_feature": -1,
                "split_value": None,
                "cluster_distribution": {0: 4}
            }
        ])
        self.assertEqual(result["refined_labels"], [0, 0, 0, 0])

    def test_empty_data(self):
        with self.assertRaises(ValueError) as ctx:
            explainable_kmeans(np.array([]).reshape(0, 2), k=1, max_depth=1)
        self.assertIn("Input data is empty", str(ctx.exception))

    def test_all_identical_points(self):
        data = np.array([[1.0, 1.0]] * 5)
        k = 1
        max_depth = 2
        result = explainable_kmeans(data, k, max_depth)

        expected_structure = [
            {
                "depth": 0,
                "split_feature": -1,
                "split_value": None,
                "cluster_distribution": {0: 5}
            }
        ]

        self.assertEqual(result["tree_structure"], expected_structure)

    def test_k_greater_than_unique(self):
        data = np.array([[1, 2], [1, 2], [1, 2]])
        with self.assertRaises(ValueError) as ctx:
            explainable_kmeans(data, k=5, max_depth=2)
        self.assertIn("k exceeds number of unique samples", str(ctx.exception))

    def test_max_depth_zero(self):
        data = np.random.rand(10, 2)
        result = explainable_kmeans(data, k=3, max_depth=0)
        self.assertEqual(result["tree_structure"], [])
        self.assertEqual(len(result["initial_centroids"]), 3)

    def test_single_feature(self):
        data = np.array([[1], [2], [9], [10]])
        result = explainable_kmeans(data, k=2, max_depth=2)
        self.assertEqual(len(result["refined_labels"]), 4)
        self.assertTrue(isinstance(result["tree_structure"], list))

    def test_nan_input(self):
        data = np.array([[1.0, 2.0], [np.nan, 3.0]])
        with self.assertRaises(ValueError) as ctx:
            explainable_kmeans(data, k=2, max_depth=1)
        self.assertIn("Input contains NaNs", str(ctx.exception))

    def test_high_dimensional_input(self):
        data = np.random.rand(50, 150)
        result = explainable_kmeans(data, k=4, max_depth=3)
        self.assertEqual(len(result["initial_centroids"]), 4)

    def test_k_equals_1(self):
        data = np.random.rand(20, 5)
        result = explainable_kmeans(data, k=1, max_depth=2)
        self.assertEqual(len(set(result["refined_labels"])), 1)

    def test_depth_limit_enforced(self):
        data = np.random.rand(100, 2)
        result = explainable_kmeans(data, k=3, max_depth=1)
        for node in result["tree_structure"]:
            self.assertLessEqual(node["depth"], 1)

    def test_axis_aligned_only(self):
        data = np.random.rand(10, 3)
        result = explainable_kmeans(data, k=2, max_depth=2)
        for node in result["tree_structure"]:
            self.assertIn(node["split_feature"], range(-1, 3))

    def test_feature_with_one_unique_value(self):
        data = np.array([[1, 2], [1, 3], [1, 4]])
        result = explainable_kmeans(data, k=1, max_depth=2)
        self.assertTrue(isinstance(result["refined_labels"], list))

    def test_correct_centroid_precision(self):
        data = np.random.rand(5, 2)
        result = explainable_kmeans(data, k=2, max_depth=1)
        for centroid in result["initial_centroids"]:
            self.assertTrue(all(isinstance(v, float) for v in centroid))

    def test_non_integer_k(self):
        with self.assertRaises(ValueError) as ctx:
            explainable_kmeans(np.random.rand(10, 2), k="3", max_depth=2)
        self.assertIn("k must be an integer", str(ctx.exception))

    def test_non_integer_max_depth(self):
        with self.assertRaises(ValueError) as ctx:
            explainable_kmeans(np.random.rand(10, 2), k=3, max_depth="2")
        self.assertIn("max_depth must be a non-negative integer", str(ctx.exception))

    def test_invalid_data_type(self):
        with self.assertRaises(ValueError) as ctx:
            explainable_kmeans("not an array", k=3, max_depth=2)
        self.assertIn("Data must be a numpy array", str(ctx.exception))

    def test_tree_structure_content(self):
        data = np.random.rand(10, 2)
        result = explainable_kmeans(data, k=2, max_depth=2)
        for node in result["tree_structure"]:
            self.assertIn("depth", node)
            self.assertIn("split_feature", node)
            self.assertIn("split_value", node)
            self.assertIn("cluster_distribution", node)

    def test_cluster_distribution_is_dict(self):
        data = np.random.rand(20, 3)
        result = explainable_kmeans(data, k=3, max_depth=2)
        for node in result["tree_structure"]:
            self.assertTrue(isinstance(node["cluster_distribution"], dict))

    def test_node_leaf_condition(self):
        data = np.random.rand(10, 2)
        result = explainable_kmeans(data, k=2, max_depth=1)
        for node in result["tree_structure"]:
            if node["split_feature"] == -1:
                self.assertIsNone(node["split_value"])

    def test_cluster_label_range(self):
        data = np.random.rand(12, 2)
        result = explainable_kmeans(data, k=3, max_depth=2)
        self.assertTrue(set(result["refined_labels"]).issubset(set(range(3))))

    def test_no_split_when_all_same_cluster(self):
        data = np.random.rand(10, 2)
        result = explainable_kmeans(data, k=1, max_depth=3)
        self.assertEqual(len(set(result["refined_labels"])), 1)

    def test_random_large_sample(self):
        data = np.random.rand(1000, 2)
        result = explainable_kmeans(data, k=5, max_depth=3)
        self.assertEqual(len(result["refined_labels"]), 1000)

    def test_only_top_10_features_considered(self):
        data = np.random.rand(200, 100)
        result = explainable_kmeans(data, k=4, max_depth=2)
        used_features = {node["split_feature"] for node in result["tree_structure"] if node["split_feature"] != -1}
        self.assertLessEqual(len(used_features), 10)

    def test_refined_label_length_matches_data(self):
        data = np.random.rand(75, 5)
        result = explainable_kmeans(data, k=4, max_depth=2)
        self.assertEqual(len(result["refined_labels"]), len(data))

    def test_refined_label_type(self):
        data = np.random.rand(20, 2)
        result = explainable_kmeans(data, k=3, max_depth=1)
        self.assertTrue(all(isinstance(lbl, int) for lbl in result["refined_labels"]))

    def test_node_ordering_depth_first(self):
        data = np.array([
            [1.0, 2.0],
            [1.1, 2.1],
            [9.0, 8.9],
            [9.1, 9.0],
            [5.0, 5.0],
        ])
        k = 2
        max_depth = 2

        result = explainable_kmeans(data, k, max_depth)
        tree = result["tree_structure"]

        # Validate preorder depth-first traversal (parent before children):
        # depth may increase by 1 (child), stay the same (sibling), or drop
        # (backtracking), but it must never jump by more than 1
        last_depth = -1
        for node in tree:
            depth = node["depth"]
            if depth > last_depth + 1:
                self.fail(f"Invalid depth transition from {last_depth} to {depth}")
            last_depth = depth

        # The first node must be the root
        self.assertEqual(tree[0]["depth"], 0)

    def test_split_feature_valid_range(self):
        data = np.random.rand(20, 4)
        result = explainable_kmeans(data, k=3, max_depth=2)
        for node in result["tree_structure"]:
            if node["split_feature"] != -1:
                self.assertTrue(0 <= node["split_feature"] < 4)

    def test_all_clusters_present_in_distribution(self):
        data = np.random.rand(30, 2)
        result = explainable_kmeans(data, k=3, max_depth=2)
        for node in result["tree_structure"]:
            total = sum(node["cluster_distribution"].values())
            self.assertTrue(total > 0)

    def test_precision_stability(self):
        data = np.random.rand(100, 2).astype(np.float64)
        result = explainable_kmeans(data, k=5, max_depth=2)
        for node in result["tree_structure"]:
            self.assertTrue(isinstance(node["split_value"], (float, type(None))))

    def test_not_enough_samples_for_k(self):
        data = np.random.rand(2, 2)
        with self.assertRaises(ValueError) as ctx:
            explainable_kmeans(data, k=5, max_depth=1)
        self.assertIn("k exceeds number of unique samples", str(ctx.exception))

if __name__ == "__main__":
    unittest.main(argv=[''], exit=False, verbosity=2)

```text
test_all_clusters_present_in_distribution (__main__.TestExplainableKMeans.test_all_clusters_present_in_distribution) ... ok
test_all_identical_points (__main__.TestExplainableKMeans.test_all_identical_points) ... ok
test_axis_aligned_only (__main__.TestExplainableKMeans.test_axis_aligned_only) ... ok
test_cluster_distribution_is_dict (__main__.TestExplainableKMeans.test_cluster_distribution_is_dict) ... ok
test_cluster_label_range (__main__.TestExplainableKMeans.test_cluster_label_range) ... ok
test_correct_centroid_precision (__main__.TestExplainableKMeans.test_correct_centroid_precision) ... ok
test_depth_limit_enforced (__main__.TestExplainableKMeans.test_depth_limit_enforced) ... ok
test_empty_data (__main__.TestExplainableKMeans.test_empty_data) ... ok
test_feature_with_one_unique_value (__main__.TestExplainableKMeans.test_feature_with_one_unique_value) ... ok
test_high_dimensional_input (__main__.TestExplainableKMeans.test_high_dimensional_input) ... ok
test_invalid_data_type (__
```

# Model Breaking Proof

#### Model Breaking Task URL: <https://rlhf-v3.turing.com/prompt/44ec3a3f-1a5a-4c14-be48-0d1fc51131bb?origin=https%3A%2F%2Flabeling-ms.turing.com&redirect_url=https%3A%2F%2Flabeling-ms.turing.com%2Fconversations%2F44735%2Fview>

#### Model code:

```python
import numpy as np
from typing import Dict, List
from sklearn.cluster import KMeans

class TreeNode:
    def __init__(self, depth: int):
        self.depth = depth
        self.split_feature = -1      # -1 means leaf
        self.split_value = None      # None means leaf
        self.cluster_distribution: Dict[int, int] = {}
        self.majority_label = None   # Used to assign a final label when leaf
        self.left_child = None
        self.right_child = None

def explainable_kmeans(data: np.ndarray, k: int, max_depth: int) -> dict:
    """
    Clusters the data into k clusters using KMeans, then refines/segments
    those assignments via a custom greedy (axis-aligned) decision tree.

    Parameters
    ----------
    data : np.ndarray
        2D array of shape (n_samples, n_features).
    k : int
        Number of clusters for the initial KMeans. Must be 1 <= k <= 15.
    max_depth : int
        Maximum depth of the decision tree. Must be >= 1.

    Returns
    -------
    dict
        A dictionary with keys:
          - refined_labels: List[int] of length n_samples with the updated cluster/segment assignment.
          - tree_structure: List[Dict], each dict describes a node in depth-first order:
                {
                  "depth": int,
                  "split_feature": int,
                  "split_value": float,
                  "cluster_distribution": Dict[int, int]
                }
          - initial_centroids: List[List[float]] of shape (k, n_features).
    """

    # ---------------------------- Input Validation ----------------------------

    if not isinstance(data, np.ndarray):
        raise ValueError("Input data must be a NumPy array.")

    if data.size == 0:
        raise ValueError("Input data is empty")

    if np.isnan(data).any():
        raise ValueError("Input contains NaNs")

    if len(data.shape) != 2:
        raise ValueError("Input data must be 2-dimensional")

    n_samples, n_features = data.shape

    if not isinstance(k, int) or k < 1 or k > 15:
        raise ValueError("k must be an integer between 1 and 15")

    if not isinstance(max_depth, int) or max_depth < 1:
        raise ValueError("max_depth must be an integer >= 1")

    # Number of unique samples check
    unique_samples = np.unique(data, axis=0).shape[0]
    if k > unique_samples:
        raise ValueError("k exceeds number of unique samples")

    # If all points identical AND k > 1, KMeans will fail to create distinct clusters
    # but to align with the requirement, let's allow KMeans to handle that scenario.
    # If the user specifically wants "All identical => return 1 cluster, empty tree",
    # they must pass k=1 in that scenario.

    # If n_samples < k as a second protective check
    if n_samples < k:
        raise ValueError("Not enough samples for k clusters")

    # ----------------------------- KMeans Clustering -----------------------------

    # Perform KMeans clustering
    # We use a fixed random_state=0 for reproducibility.
    # Note: n_init changed in sklearn >= 1.4; we set n_init explicitly to avoid warnings.
    kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto', max_iter=300)
    kmeans.fit(data)
    initial_centroids = kmeans.cluster_centers_.astype(np.float64)
    cluster_labels = kmeans.labels_

    # If k=1, there's only one cluster; no split is possible, or if all points are identical
    # we effectively have a single cluster. Then either no tree or an empty tree is returned.
    if k == 1 or unique_samples == 1:
        # All refined labels are the same
        return {
            "refined_labels": [0] * n_samples,
            "tree_structure": [],
            "initial_centroids": initial_centroids.tolist()
        }

    # ----------------------------- Precompute Feature Set -----------------------------
    # If high dimensional (>= 100 features), select the top 10 most variant features globally.
    if n_features >= 100:
        variances = np.var(data, axis=0)  # shape (n_features,)
        # get indices of top 10 features
        top_10_idx = np.argsort(variances)[::-1][:10]
        candidate_features = set(top_10_idx.tolist())
    else:
        # consider all features
        candidate_features = set(range(n_features))

    # ----------------------- Helper Functions for Decision Tree -----------------------

    def get_cluster_distribution(labels_subset: np.ndarray) -> Dict[int, int]:
        """
        Build a dictionary cluster_label -> count for the given labels_subset.
        """
        dist = {}
        for lab in labels_subset:
            if lab not in dist:
                dist[lab] = 0
            dist[lab] += 1
        return dist

    def gini_impurity(dist: Dict[int, int]) -> float:
        """
        Compute Gini impurity given a cluster distribution dictionary.
        Gini = 1 - sum_{all clusters} (p_i^2),
        where p_i = dist[cluster_label] / total.
        """
        total = sum(dist.values())
        if total == 0:
            # No samples => no impurity
            return 0.0
        gini = 1.0
        for count in dist.values():
            p = count / total
            gini -= p ** 2
        return gini

    # Build a recursive decision tree
    def build_tree(data_indices: np.ndarray, depth: int) -> TreeNode:
        """
        Recursively build the decision tree up to max_depth, selecting
        the best single-feature split (axis-aligned) to improve cluster purity.
        """
        node = TreeNode(depth)
        node_labels = cluster_labels[data_indices]  # cluster assignments at this node
        dist = get_cluster_distribution(node_labels)
        node.cluster_distribution = dist
        node_gini = gini_impurity(dist)

        # majority label in current node
        node.majority_label = max(dist.keys(), key=lambda c: dist[c])

        # Stopping conditions
        # 1) If depth == max_depth, stop.
        # 2) If all examples are from the same cluster (pure node).
        # 3) If no split can reduce impurity (no improvement).
        if depth >= max_depth:
            return node
        if len(dist) == 1:
            # Already pure, no need to split
            return node

        best_gain = 0.0
        best_feature = -1
        best_value = None
        best_left_indices = None
        best_right_indices = None

        # For each feature in candidate_features, try all possible splits
        for f in candidate_features:
            # Values of this feature for the samples in data_indices
            feature_values = data[data_indices, f]

            # If there's only 1 unique value among these samples, skip
            unique_vals = np.unique(feature_values)
            if len(unique_vals) < 2:
                continue

            # For an axis-aligned cut, we try midpoints between consecutive unique values
            # Sort unique_vals and consider midpoints
            unique_vals_sorted = np.sort(unique_vals)

            # Potential cut points: midpoints between consecutive unique values
            # e.g., mid = (u[i] + u[i+1]) / 2.0
            for i in range(len(unique_vals_sorted) - 1):

                # <Issue>: <The prompt requires every split value to use np.float64 precision; 'cut' is never explicitly cast to np.float64, so its dtype depends on the dtype of the input data.>
                cut = 0.5 * (unique_vals_sorted[i] + unique_vals_sorted[i+1])

                left_mask = feature_values <= cut
                right_mask = feature_values > cut

                left_indices = data_indices[left_mask]
                right_indices = data_indices[right_mask]

                # If either child is empty, skip
                if len(left_indices) == 0 or len(right_indices) == 0:
                    continue

                # Calculate gini for children
                left_dist = get_cluster_distribution(cluster_labels[left_indices])
                right_dist = get_cluster_distribution(cluster_labels[right_indices])
                left_gini = gini_impurity(left_dist)
                right_gini = gini_impurity(right_dist)
                n_left = len(left_indices)
                n_right = len(right_indices)
                n_total = n_left + n_right
                weighted_child_gini = (n_left / n_total) * left_gini + (n_right / n_total) * right_gini

                gain = node_gini - weighted_child_gini
                if gain > best_gain:
                    best_gain = gain
                    best_feature = f
                    best_value = cut
                    best_left_indices = left_indices
                    best_right_indices = right_indices

        # If we found a beneficial split, apply it
        if best_gain > 1e-15:  # some small epsilon
            node.split_feature = best_feature
            node.split_value = best_value

            # Build children
            node.left_child = build_tree(best_left_indices, depth + 1)
            node.right_child = build_tree(best_right_indices, depth + 1)

        return node

    # ----------------------- Build the Tree from the Whole Dataset -----------------------
    root = build_tree(np.arange(n_samples), depth=0)

    # ----------------------- Extract Tree Structure in Depth-First Order -----------------------
    tree_structure = []

    def dfs_tree_structure(node: TreeNode):
        if node is None:
            return
        tree_structure.append({
            "depth": node.depth,
            "split_feature": node.split_feature,
            "split_value": node.split_value,
            "cluster_distribution": dict(sorted(node.cluster_distribution.items()))
        })
        if node.split_feature != -1:
            dfs_tree_structure(node.left_child)
            dfs_tree_structure(node.right_child)

    # Collect the tree nodes if the root is non-leaf
    if root.split_feature != -1:
        dfs_tree_structure(root)
    else:
        # root is a leaf => no splits
        # In that case, the tree_structure has just one node or is empty?
        # The examples show that if no splits are done, we might return an empty list.
        # But the question's examples suggest if all labels remain the same, tree_structure is empty.
        # We'll check if the root is truly a leaf. If so, that typically means no splits.
        # The requirement states "tree_structure must appear in depth first order."
        # If no splits are made, the user example (like k=1) uses an empty list.
        # So we follow that convention.

        # <Issue>: <The prompt requires leaf nodes to appear in tree_structure with split_feature=-1 and split_value=None; this code instead returns an empty tree_structure when the root is a leaf.>
        tree_structure = []

    
    # ----------------------- Assign Refined Labels Using the Built Tree -----------------------

    # We'll define a function that traverses the tree from the root
    # for a given sample row until it reaches a leaf node, returning the node's majority_label.
    def classify(sample: np.ndarray, node: TreeNode):
        if node.split_feature == -1:
            # Leaf node
            return node.majority_label
        else:
            if sample[node.split_feature] <= node.split_value:
                return classify(sample, node.left_child)
            else:
                return classify(sample, node.right_child)

    refined_labels = []
    for i in range(n_samples):
        refined_labels.append(classify(data[i], root))

    # ----------------------- Return the Required Dictionary -----------------------
    return {
        "refined_labels": refined_labels,
        "tree_structure": tree_structure,
        "initial_centroids": initial_centroids.tolist()
    }
```
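
For reference, the split criterion that `build_tree` applies can be written out as follows. Here $S$ is the set of samples at a node, $p_c(S)$ is the fraction of $S$ that KMeans assigned to cluster $c$, and a cut at threshold $t$ on feature $f$ sends a sample $x$ left when $x_f \le t$:

$$
G(S) = 1 - \sum_{c} p_c(S)^2, \qquad
\Delta(f, t) = G(S) - \frac{|L|}{|S|}\,G(L) - \frac{|R|}{|S|}\,G(R),
$$

$$
L = \{x \in S : x_f \le t\}, \qquad R = \{x \in S : x_f > t\}, \qquad t \in \left\{ \tfrac{1}{2}\,(u_i + u_{i+1}) \right\},
$$

where $u_1 < u_2 < \dots$ are the sorted unique values of feature $f$ within $S$, and $f$ ranges over all features (or over the 10 highest-variance features when there are at least 100). The node splits on the $(f, t)$ with the largest $\Delta$ only if that $\Delta$ exceeds $10^{-15}$; otherwise it stays a leaf labeled with its majority cluster.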
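
The second `<Issue>` annotation says the prompt expects every node, including a root that never splits, to be listed in `tree_structure` in depth-first order, with leaves carrying `split_feature=-1` and `split_value=None`. A minimal stdlib sketch of such a serializer is shown below; `Node` and `serialize_tree` are hypothetical stand-ins for the model's `TreeNode` and `dfs_tree_structure`, not part of the original code. Split values are converted with `float()`: Python's `float` is a C double, which on mainstream platforms is the same IEEE 754 binary64 format that `np.float64` uses. A NumPy-based version could call `np.float64(...)` instead to match the first annotation literally.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for the model's TreeNode (a leaf when split_feature == -1)."""
    depth: int
    cluster_distribution: Dict[int, int]
    split_feature: int = -1
    split_value: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def serialize_tree(root: Node) -> List[dict]:
    """Pre-order (depth-first) dump that always includes the root, even when it is a leaf."""
    out: List[dict] = []

    def visit(node: Node) -> None:
        is_leaf = node.split_feature == -1
        out.append({
            "depth": node.depth,
            "split_feature": node.split_feature,
            # Leaves keep None; internal nodes get a 64-bit float threshold.
            "split_value": None if is_leaf else float(node.split_value),
            "cluster_distribution": dict(sorted(node.cluster_distribution.items())),
        })
        if not is_leaf:
            visit(node.left)
            visit(node.right)

    visit(root)
    return out


# A root that never splits yields one leaf entry instead of an empty list.
leaf_root = Node(depth=0, cluster_distribution={1: 4, 0: 6})
print(serialize_tree(leaf_root))
# [{'depth': 0, 'split_feature': -1, 'split_value': None, 'cluster_distribution': {0: 6, 1: 4}}]

# A single split is emitted root first, then left subtree, then right subtree.
split_root = Node(0, {0: 3, 1: 2}, split_feature=0, split_value=2.5,
                  left=Node(1, {0: 3}), right=Node(1, {1: 2}))
print([n["depth"] for n in serialize_tree(split_root)])  # [0, 1, 1]
```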