# N-Dimensional Simplex Tree Testing

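This notebook tests the `SimplexTreeClassifier` on synthetic datasets of increasing dimension (2D through 5D) and on the real-world Phoneme dataset (5D). For each dataset the model is fit, leaf simplices flagged as non-convex or same-side are pruned, the model is refit, and accuracy on the training data is reported.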


In [None]:
import numpy as np
import sys
import os

# Make the repository root (two directories above the notebook's working
# directory) importable, so the project packages imported below resolve.
current_dir = os.getcwd()
parent_dir = os.path.join(current_dir, os.pardir, os.pardir)
if parent_dir not in sys.path:
    sys.path[:0] = [parent_dir]  # prepend, same as sys.path.insert(0, ...)

from sklearn.metrics import accuracy_score
from sklearn.svm import LinearSVC
from in2D.classifying.classes.simplex_tree_classifier import SimplexTreeClassifier
from in2D.classifying.classes.utilss.convexity_check import check_convexity, get_shared_vertices
from data_generators.generate_data_nd import create_simplex_vertices


## Helper functions for testing any dimension


In [None]:
def test_nd_simplex_tree(dimensions, subdivision_levels=3, C=10000, EPSILON=0.30,
                         data_dir='datasets', show_tree=True):
    """Fit, prune and refit a SimplexTreeClassifier on a synthetic N-D dataset.

    Loads ``{data_dir}/dataset_{dimensions}d.npz`` (arrays ``X`` and ``y``) and
    builds a tree over the root simplex returned by ``create_simplex_vertices``.
    The tree is fitted with a ``LinearSVC``. The function then removes every
    leaf simplex that is either part of a non-convex pair of adjacent
    SVM-crossing simplices, or is reported by ``find_same_side_simplices``.
    If anything was removed, the node lookup is rebuilt and the model is refit
    on the same data.

    Args:
        dimensions: Dimensionality of the dataset and of the root simplex.
        subdivision_levels: Passed through to ``SimplexTreeClassifier``.
        C: Regularisation parameter of the ``LinearSVC``.
        EPSILON: Tolerance forwarded to ``check_convexity``.
        data_dir: Directory containing the ``dataset_{D}d.npz`` files.
        show_tree: If True, print the (pruned) tree at the end. The output
            can be very long for higher dimensions.

    Returns:
        The fitted, and possibly pruned and refit, ``SimplexTreeClassifier``.
    """
    print(f"\n{'='*60}")
    print(f"Testing {dimensions}D Simplex Tree")
    print(f"{'='*60}")

    # np.load on an .npz returns an NpzFile that holds the archive open;
    # the context manager closes it once the arrays have been read.
    with np.load(os.path.join(data_dir, f'dataset_{dimensions}d.npz')) as data:
        X, y = data['X'], data['y']
    print(f"Dataset: {X.shape[0]} points, Dimensions: {X.shape[1]}, Classes: {np.unique(y)}")

    root_vertices = create_simplex_vertices(dimensions)
    print(f"Root simplex has {len(root_vertices)} vertices")

    model = SimplexTreeClassifier(
        vertices=root_vertices,
        classifier=LinearSVC(C=C),
        subdivision_levels=subdivision_levels,
    )
    model.fit(X, y)
    print(f"Model built with {len(model.leaf_simplexes)} leaf simplices")

    weights = model.classifier.coef_[0]
    intercept = model.classifier.intercept_[0]
    # In ND, adjacent simplices share a (D-1)-face, i.e. D vertices.
    min_shared = dimensions

    nonconvex_simplex_keys = set()
    crossing_simplices = model.identify_svm_crossing_simplices()
    # NOTE(review): this pairwise scan is quadratic in the number of crossing simplices.
    for i, info1 in enumerate(crossing_simplices):
        simplex1 = info1['simplex']
        for info2 in crossing_simplices[i + 1:]:
            simplex2 = info2['simplex']
            if len(get_shared_vertices(simplex1, simplex2)) < min_shared:
                continue
            # Only the convexity verdict is used; the remaining diagnostics are ignored.
            is_convex, *_ = check_convexity(
                simplex1, simplex2, weights, intercept,
                global_tree=model.tree, epsilon=EPSILON
            )
            if not is_convex:
                nonconvex_simplex_keys.add(frozenset(simplex1.vertex_indices))
                nonconvex_simplex_keys.add(frozenset(simplex2.vertex_indices))

    same_side_simplex_keys = model.find_same_side_simplices()
    combined_keys = nonconvex_simplex_keys.union(same_side_simplex_keys)

    print("\nSimplices to remove:")
    print(f"  Non-convex: {len(nonconvex_simplex_keys)}")
    print(f"  Same-side: {len(same_side_simplex_keys)}")
    print(f"  Combined: {len(combined_keys)}")
    print(f"  Leaves before: {len(model.tree.get_leaves())}")

    removed_count = 0
    for simplex_key in combined_keys:
        if model.tree.remove_by_leaf_key(simplex_key):
            removed_count += 1

    if removed_count > 0:
        # The tree changed: rebuild the node lookup, then refit on the same data.
        model._build_node_lookup()
        model.fit(X, y)

    print("\nAfter removal:")
    print(f"  Removed: {removed_count}")
    print(f"  Leaves remaining: {len(model.tree.get_leaves())}")

    if show_tree:
        model.tree.print_tree(show_only_splitting_points=False)

    return model


In [17]:

def compute_accuracy(model, dimensions, data_dir='datasets'):
    """Print and return the accuracy of ``model`` on a synthetic N-D dataset.

    Loads ``{data_dir}/dataset_{dimensions}d.npz`` (arrays ``X`` and ``y``).
    When the model was fitted on that same file, this is training-set accuracy.

    Args:
        model: Fitted classifier exposing ``predict(X)``.
        dimensions: Dimensionality used to pick the dataset file.
        data_dir: Directory containing the ``dataset_{D}d.npz`` files.

    Returns:
        Accuracy as a percentage (``accuracy_score * 100``).
    """
    # Close the NpzFile once the arrays are read instead of leaking the handle.
    with np.load(os.path.join(data_dir, f'dataset_{dimensions}d.npz')) as data:
        X, y = data['X'], data['y']
    y_pred = model.predict(X)
    acc = accuracy_score(y, y_pred) * 100
    print(f"{dimensions}D Model Accuracy: {acc:.2f}%")
    return acc

## Test 2D

In [9]:
model_2d = test_nd_simplex_tree(2, subdivision_levels=4)  # 2D synthetic dataset


Testing 2D Simplex Tree
Dataset: 790 points, Dimensions: 2, Classes: [-1.  1.]
Root simplex has 3 vertices
(790, 43)
Model built with 81 leaf simplices

Simplices to remove:
  Non-convex: 24
  Same-side: 39
  Combined: 63
  Leaves before: 81
(790, 43)

After removal:
  Removed: 26
  Leaves remaining: 29
└── [0] vertices: [0, 1, 2]
    ├── [1] vertices: [1, 2, 3]
    │   ├── [2] vertices: [2, 3, 4]
    │   │   ├── vertices: [3, 4, 5]
    │   │   ├── vertices: [2, 4, 5]
    │   │   └── vertices: [2, 3, 5]
    │   ├── [6] vertices: [1, 3, 4]
    │   │   ├── vertices: [3, 4, 9]
    │   │   ├── vertices: [1, 4, 9]
    │   │   └── vertices: [1, 3, 9]
    │   └── [10] vertices: [1, 2, 4]
    │       ├── vertices: [2, 4, 13]
    │       ├── vertices: [1, 4, 13]
    │       └── vertices: [1, 2, 13]
    ├── [14] vertices: [0, 2, 3]
    │   ├── [15] vertices: [2, 3, 17]
    │   │   ├── vertices: [3, 17, 18]
    │   │   ├── vertices: [2, 17, 18]
    │   │   └── vertices: [2, 3, 18]
    │   ├── [1

In [15]:
compute_accuracy(model=model_2d, dimensions=2)

2D Model Accuracy: 99.49%


99.49367088607595

## Test 3D

In [18]:
model_3d = test_nd_simplex_tree(3, subdivision_levels=4)  # 3D synthetic dataset


Testing 3D Simplex Tree
Dataset: 759 points, Dimensions: 3, Classes: [-1.  1.]
Root simplex has 4 vertices
(759, 89)
Model built with 256 leaf simplices

Simplices to remove:
  Non-convex: 131
  Same-side: 80
  Combined: 211
  Leaves before: 256
(759, 89)

After removal:
  Removed: 63
  Leaves remaining: 67
└── [0] vertices: [0, 1, 2, 3]
    ├── [1] vertices: [1, 2, 3, 4]
    │   ├── [2] vertices: [2, 3, 4, 5]
    │   │   ├── vertices: [3, 4, 5, 6]
    │   │   ├── vertices: [2, 4, 5, 6]
    │   │   ├── vertices: [2, 3, 5, 6]
    │   │   └── vertices: [2, 3, 4, 6]
    │   ├── [7] vertices: [1, 3, 4, 5]
    │   │   ├── vertices: [3, 4, 5, 11]
    │   │   ├── vertices: [1, 4, 5, 11]
    │   │   ├── vertices: [1, 3, 5, 11]
    │   │   └── vertices: [1, 3, 4, 11]
    │   ├── [12] vertices: [1, 2, 4, 5]
    │   │   ├── vertices: [2, 4, 5, 16]
    │   │   ├── vertices: [1, 4, 5, 16]
    │   │   ├── vertices: [1, 2, 5, 16]
    │   │   └── vertices: [1, 2, 4, 16]
    │   └── [17] vertices: [1,

In [19]:
compute_accuracy(model=model_3d, dimensions=3)

3D Model Accuracy: 94.33%


94.33465085638998

## Test 4D


In [20]:
model_4d = test_nd_simplex_tree(4, subdivision_levels=4)  # 4D synthetic dataset


Testing 4D Simplex Tree
Dataset: 751 points, Dimensions: 4, Classes: [-1.  1.]
Root simplex has 5 vertices
(751, 161)
Model built with 625 leaf simplices

Simplices to remove:
  Non-convex: 483
  Same-side: 60
  Combined: 543
  Leaves before: 625
(751, 161)

After removal:
  Removed: 125
  Leaves remaining: 125
└── [0] vertices: [0, 1, 2, 3, 4]
    ├── [1] vertices: [1, 2, 3, 4, 5]
    │   ├── [2] vertices: [2, 3, 4, 5, 6]
    │   │   ├── vertices: [3, 4, 5, 6, 7]
    │   │   ├── vertices: [2, 4, 5, 6, 7]
    │   │   ├── vertices: [2, 3, 5, 6, 7]
    │   │   ├── vertices: [2, 3, 4, 6, 7]
    │   │   └── vertices: [2, 3, 4, 5, 7]
    │   ├── [8] vertices: [1, 3, 4, 5, 6]
    │   │   ├── vertices: [3, 4, 5, 6, 13]
    │   │   ├── vertices: [1, 4, 5, 6, 13]
    │   │   ├── vertices: [1, 3, 5, 6, 13]
    │   │   ├── vertices: [1, 3, 4, 6, 13]
    │   │   └── vertices: [1, 3, 4, 5, 13]
    │   ├── [14] vertices: [1, 2, 4, 5, 6]
    │   │   ├── vertices: [2, 4, 5, 6, 19]
    │   │   ├── ver

In [21]:
compute_accuracy(model=model_4d, dimensions=4)

4D Model Accuracy: 72.70%


72.70306258322236

In [22]:
# Test 5D: synthetic 5D dataset, same settings as the lower-dimensional runs.
model_5d = test_nd_simplex_tree(5, subdivision_levels=4)


Testing 5D Simplex Tree
Dataset: 706 points, Dimensions: 5, Classes: [-1.  1.]
Root simplex has 6 vertices
(706, 265)
Model built with 1296 leaf simplices

Simplices to remove:
  Non-convex: 1250
  Same-side: 0
  Combined: 1250
  Leaves before: 1296
(706, 265)

After removal:
  Removed: 216
  Leaves remaining: 216
└── [0] vertices: [0, 1, 2, 3, 4, 5]
    ├── [1] vertices: [1, 2, 3, 4, 5, 6]
    │   ├── [2] vertices: [2, 3, 4, 5, 6, 7]
    │   │   ├── vertices: [3, 4, 5, 6, 7, 8]
    │   │   ├── vertices: [2, 4, 5, 6, 7, 8]
    │   │   ├── vertices: [2, 3, 5, 6, 7, 8]
    │   │   ├── vertices: [2, 3, 4, 6, 7, 8]
    │   │   ├── vertices: [2, 3, 4, 5, 7, 8]
    │   │   └── vertices: [2, 3, 4, 5, 6, 8]
    │   ├── [9] vertices: [1, 3, 4, 5, 6, 7]
    │   │   ├── vertices: [3, 4, 5, 6, 7, 15]
    │   │   ├── vertices: [1, 4, 5, 6, 7, 15]
    │   │   ├── vertices: [1, 3, 5, 6, 7, 15]
    │   │   ├── vertices: [1, 3, 4, 6, 7, 15]
    │   │   ├── vertices: [1, 3, 4, 5, 7, 15]
    │   │   └──

In [None]:
compute_accuracy(model=model_5d, dimensions=5)

## Test with Real Dataset: Phoneme (5D)

In [23]:
def test_phoneme(subdivision_levels=3, C=1000, EPSILON=0.30,
                 data_path='datasets/dataset_phoneme.npz'):
    """Fit, prune and refit a SimplexTreeClassifier on the Phoneme dataset.

    The dimensionality is taken from ``X.shape[1]``. The function fits the
    model with a ``LinearSVC`` and prints accuracy on the training data. It
    then removes every leaf simplex that is either part of a non-convex pair
    of adjacent SVM-crossing simplices, or is reported by
    ``find_same_side_simplices``. If anything was removed, it refits and
    prints the accuracy again.

    Args:
        subdivision_levels: Passed through to ``SimplexTreeClassifier``.
        C: Regularisation parameter of the ``LinearSVC``.
        EPSILON: Tolerance forwarded to ``check_convexity``.
        data_path: Path to the ``.npz`` file holding arrays ``X`` and ``y``.

    Returns:
        The fitted, and possibly pruned and refit, ``SimplexTreeClassifier``.
    """
    print(f"\n{'='*60}")
    print(f"Testing Phoneme Dataset (5D)")
    print(f"{'='*60}")

    # Close the NpzFile once the arrays are read instead of leaking the handle.
    with np.load(data_path) as data:
        X, y = data['X'], data['y']
    print(f"Dataset: {X.shape[0]} points, Dimensions: {X.shape[1]}")
    print(f"Class distribution: {dict(zip(*np.unique(y, return_counts=True)))}")

    dimensions = X.shape[1]
    root_vertices = create_simplex_vertices(dimensions)
    print(f"Root simplex has {len(root_vertices)} vertices")

    model = SimplexTreeClassifier(
        vertices=root_vertices,
        classifier=LinearSVC(C=C),
        subdivision_levels=subdivision_levels,
    )
    model.fit(X, y)
    print(f"Model built with {len(model.leaf_simplexes)} leaf simplices")

    y_pred = model.predict(X)
    acc = accuracy_score(y, y_pred) * 100
    print(f"Accuracy: {acc:.2f}%")

    weights = model.classifier.coef_[0]
    intercept = model.classifier.intercept_[0]
    # In ND, adjacent simplices share a (D-1)-face, i.e. D vertices.
    min_shared = dimensions

    nonconvex_simplex_keys = set()
    crossing_simplices = model.identify_svm_crossing_simplices()
    # NOTE(review): this pairwise scan is quadratic in the number of crossing
    # simplices. The saved run with subdivision_levels=5 (7776 leaves) was
    # interrupted before the pruning summary below was printed.
    for i, info1 in enumerate(crossing_simplices):
        simplex1 = info1['simplex']
        for info2 in crossing_simplices[i + 1:]:
            simplex2 = info2['simplex']
            if len(get_shared_vertices(simplex1, simplex2)) < min_shared:
                continue
            # Only the convexity verdict is used; the remaining diagnostics are ignored.
            is_convex, *_ = check_convexity(
                simplex1, simplex2, weights, intercept,
                global_tree=model.tree, epsilon=EPSILON
            )
            if not is_convex:
                nonconvex_simplex_keys.add(frozenset(simplex1.vertex_indices))
                nonconvex_simplex_keys.add(frozenset(simplex2.vertex_indices))

    same_side_simplex_keys = model.find_same_side_simplices()
    combined_keys = nonconvex_simplex_keys.union(same_side_simplex_keys)

    print("\nSimplices to remove:")
    print(f"  Non-convex: {len(nonconvex_simplex_keys)}")
    print(f"  Same-side: {len(same_side_simplex_keys)}")
    print(f"  Combined: {len(combined_keys)}")
    print(f"  Leaves before: {len(model.tree.get_leaves())}")

    removed_count = 0
    for simplex_key in combined_keys:
        if model.tree.remove_by_leaf_key(simplex_key):
            removed_count += 1

    if removed_count > 0:
        # The tree changed: rebuild the node lookup, refit, and re-measure.
        model._build_node_lookup()
        model.fit(X, y)
        y_pred_after = model.predict(X)
        acc_after = accuracy_score(y, y_pred_after) * 100
    else:
        acc_after = acc

    print("\nAfter removal:")
    print(f"  Removed: {removed_count}")
    print(f"  Leaves remaining: {len(model.tree.get_leaves())}")
    print(f"  Accuracy after: {acc_after:.2f}%")

    return model

# NOTE: with 5 subdivision levels the tree has 7776 leaves; the saved run was
# interrupted (KeyboardInterrupt) after the initial accuracy was printed.
model_phoneme = test_phoneme(5)


Testing Phoneme Dataset (5D)
Dataset: 5404 points, Dimensions: 5
Class distribution: {-1.0: 3818, 1.0: 1586}
Root simplex has 6 vertices
(5404, 1561)
Model built with 7776 leaf simplices
Accuracy: 73.63%


KeyboardInterrupt: 