This notebook contains a simple, hand-written implementation of a Proximity Forest time-series classifier, trained and evaluated on the ECG5000 dataset.

In [1]:
import numpy as np
import random
from collections import Counter

# Simple distance functions between two time series of equal length.
def euclidean(ts1, ts2):
    """Return the Euclidean (L2) distance between ``ts1`` and ``ts2``."""
    difference = np.array(ts1) - np.array(ts2)
    return np.linalg.norm(difference)


def manhattan(ts1, ts2):
    """Return the Manhattan (L1) distance between ``ts1`` and ``ts2``."""
    difference = np.array(ts1) - np.array(ts2)
    return np.sum(np.abs(difference))


# Pool of distance measures a tree samples from when proposing a split.
# Order matters for reproducibility: random.choice indexes into this list.
DISTANCE_MEASURES = [
    euclidean,
    manhattan,
]

# Node class for the Proximity Tree
class ProximityNode:
    """One node of a ProximityTree.

    A leaf node is built with only ``label`` set. An internal node is built
    with a distance ``measure``, ``exemplars`` (one series per class label) and
    ``branches`` mapping each class label to the child ProximityNode for the
    series routed to that label's exemplar.
    """

    def __init__(self, measure=None, exemplars=None, branches=None, label=None):
        # Leaf payload: the predicted class; None marks an internal node.
        self.label = label
        # Split payload (internal nodes only).
        self.measure = measure      # distance function: measure(ts_a, ts_b)
        self.exemplars = exemplars  # {class_label: exemplar series}
        self.branches = branches    # {class_label: ProximityNode}

# The Proximity Tree learner
class ProximityTree:
    """A single proximity tree for time-series classification.

    At every impure node, ``r`` candidate splits are drawn at random. A
    candidate is one distance function picked from ``DISTANCE_MEASURES`` plus
    one randomly chosen exemplar series per class present at the node; each
    training series is routed to the branch of its closest exemplar. The
    candidate with the largest Gini gain becomes the node's split.

    All randomness comes from the global ``random`` module, so call
    ``random.seed(...)`` before ``fit`` to make a tree reproducible.

    Parameters
    ----------
    r : int, default 2
        Number of candidate splits evaluated at each impure node.
    """

    def __init__(self, r=2):
        self.r = r  # number of candidate splits per impure node
        self.root = None  # set by fit()

    def fit(self, X, y):
        """Grow the tree from training series ``X`` and their labels ``y``.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` differ in length, or if they are empty.
        """
        # zip() would silently truncate mismatched inputs; fail loudly instead.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        # With no data there are no exemplars to split on, and predict()
        # would later fail on min() of an empty dict.
        if len(X) == 0:
            raise ValueError("cannot fit a ProximityTree on empty data")
        data = list(zip(X, y))
        self.root = self._build_tree(data)

    def _build_tree(self, data):
        """Recursively grow a subtree for ``data``, a list of ``(ts, label)`` pairs."""
        labels = [label for _, label in data]
        # Base case: pure node
        if len(set(labels)) == 1:
            return ProximityNode(label=labels[0])
        if self.r < 1:
            raise ValueError(f"r must be >= 1 to split an impure node, got {self.r}")

        best_split = None
        best_gain = -float('inf')

        # Try r candidate splits and keep the one with the largest Gini gain.
        for _ in range(self.r):
            measure = random.choice(DISTANCE_MEASURES)
            class_exemplars = self._select_random_exemplars(data)
            split = self._split_data(data, measure, class_exemplars)

            gain = self._gini_gain(labels, split)
            if gain > best_gain:
                best_gain = gain
                best_split = (measure, class_exemplars, split)

        measure, exemplars, split_data = best_split
        non_empty = {lbl: subset for lbl, subset in split_data.items() if subset}

        # If every series landed in one branch (e.g. identical series with
        # different labels), recursing would reprocess exactly the same data
        # forever. Stop with a majority-class leaf instead.
        if len(non_empty) < 2:
            majority_label = Counter(labels).most_common(1)[0][0]
            return ProximityNode(label=majority_label)

        # Every branch now holds a strict subset of data, so recursion terminates.
        branches = {lbl: self._build_tree(subset) for lbl, subset in non_empty.items()}
        return ProximityNode(measure=measure, exemplars=exemplars, branches=branches)

    def _select_random_exemplars(self, data):
        """Pick one random series per class in ``data``.

        Returns a ``{label: series}`` dict in first-seen label order.
        """
        class_dict = {}
        for ts, label in data:
            class_dict.setdefault(label, []).append(ts)
        # Random exemplar for each class
        return {label: random.choice(ts_list) for label, ts_list in class_dict.items()}

    def _split_data(self, data, measure, exemplars):
        """Route each ``(ts, label)`` pair to the branch of its closest exemplar.

        Returns ``{exemplar_label: [(ts, label), ...]}``, with an entry (possibly
        empty) for every exemplar. Distance ties go to the exemplar that comes
        first in ``exemplars``.
        """
        split = {label: [] for label in exemplars}
        for ts, label in data:
            distances = {label_: measure(ts, ex) for label_, ex in exemplars.items()}
            closest_label = min(distances, key=distances.get)
            split[closest_label].append((ts, label))
        return split

    def _gini_gain(self, parent_labels, split):
        """Return parent Gini impurity minus the size-weighted child impurity."""
        def gini(labels):
            count = Counter(labels)
            probs = [c / len(labels) for c in count.values()]
            return 1 - sum(p**2 for p in probs)

        parent_gini = gini(parent_labels)
        total = len(parent_labels)
        weighted_child_gini = sum(
            (len(branch) / total) * gini([label for _, label in branch])
            for branch in split.values() if branch
        )
        return parent_gini - weighted_child_gini

    def predict(self, ts):
        """Return the predicted class label for a single series ``ts``."""
        return self._predict_node(ts, self.root)

    def _predict_node(self, ts, node):
        """Descend from ``node`` to a leaf, following the closest exemplar."""
        # Iterate rather than recurse so deep trees cannot hit the recursion limit.
        while node.label is None:
            # Only exemplars that received training data have a child branch.
            # Restricting the search to those avoids a KeyError.
            distances = {
                label: node.measure(ts, ex)
                for label, ex in node.exemplars.items()
                if label in node.branches
            }
            closest_label = min(distances, key=distances.get)
            node = node.branches[closest_label]
        return node.label

In [2]:
class ProximityForest:
    """An ensemble of ProximityTrees combined by majority vote.

    Every tree is fit on the same training data; the trees differ only
    through the random choices each one makes while splitting.
    """

    def __init__(self, n_trees=10, r=5):
        self.n_trees = n_trees
        self.r = r
        self.trees = []
        for _ in range(n_trees):
            self.trees.append(ProximityTree(r))

    def fit(self, X, y):
        """Fit every tree in the ensemble on ``X`` and ``y``."""
        for member in self.trees:
            member.fit(X, y)

    def predict(self, ts):
        """Return the most common label predicted for a single series ``ts``.

        Ties are broken in favour of the label first predicted, i.e. by the
        earliest tree in ``self.trees``.
        """
        votes = Counter(member.predict(ts) for member in self.trees)
        top_label, _ = votes.most_common(1)[0]
        return top_label

In [3]:
from sklearn.model_selection import train_test_split
import pandas as pd

# Data location and split settings, kept together so they are easy to change.
# Each CSV row is one time series; its class is stored in the `label` column.
DATA_PATH = '../../fulldataset_ECG5000.csv'
TEST_SIZE = 0.2
SPLIT_SEED = 42

df = pd.read_csv(DATA_PATH)  # comma is pandas' default separator

X = df.drop('label', axis=1).values
y = df['label'].values
# NOTE(review): the classes are heavily imbalanced (class 5 has only 5 test
# samples below). Passing `stratify=y` would keep class proportions equal
# across train/test, at the cost of changing the recorded results.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=SPLIT_SEED
)
assert len(X_train) + len(X_test) == len(X), "split lost or duplicated rows"
print(f'Lengths:\n  X_train:{len(X_train)}\n  y_train:{len(y_train)}\n  X_test:{len(X_test)}\n  y_test:{len(y_test)}')

Lengths:
  X_train:4000
  y_train:4000
  X_test:1000
  y_test:1000


In [4]:
# Initialize and train a Proximity Forest.
# Every tree draws its random choices (distance measure, exemplars) from
# Python's global `random` module, imported in the first cell. Seed it first
# so the trained forest, and therefore the evaluation metrics, are
# reproducible after a kernel restart.
random.seed(42)
forest = ProximityForest(n_trees=5, r=20)
forest.fit(X_train, y_train)
# Test-set predictions are computed once, in the evaluation cell that follows.

In [5]:
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Predict a class for every test series with the trained forest.
y_pred = list(map(forest.predict, X_test))

# Overall accuracy, shown as a percentage.
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")

# Each report section prints a header line, then the metric on the test set.
# In the confusion matrix, rows are true classes and columns are predictions.
for header, metric in (
    ("\nConfusion Matrix:", confusion_matrix),
    ("\nClassification Report:", classification_report),
):
    print(header)
    print(metric(y_test, y_pred))

Accuracy: 94.60%

Confusion Matrix:
[[584   0   2   0   0]
 [  4 332   5  10   2]
 [  1   5  10   2   0]
 [  1  16   0  19   2]
 [  2   1   1   0   1]]

Classification Report:
              precision    recall  f1-score   support

           1       0.99      1.00      0.99       586
           2       0.94      0.94      0.94       353
           3       0.56      0.56      0.56        18
           4       0.61      0.50      0.55        38
           5       0.20      0.20      0.20         5

    accuracy                           0.95      1000
   macro avg       0.66      0.64      0.65      1000
weighted avg       0.94      0.95      0.94      1000

