In [1]:
from pyspark.sql import SparkSession
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("GlobalProximityTree")
    .getOrCreate()
)

In [2]:
tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]},
    {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9]},
    {'label': 1, 'time_series': [1.4, 2.0, 2.6, 3.2, 3.8]},
    {'label': 2, 'time_series': [2.2, 3.1, 4.0, 4.9, 5.8]},
    {'label': 2, 'time_series': [2.6, 3.2, 3.8, 4.4, 5.0]},
    {'label': 3, 'time_series': [1.2, 2.0, 2.8, 3.6, 4.4]},
    {'label': 3, 'time_series': [0.6, 1.3, 2.0, 2.7, 3.4]},
    {'label': 4, 'time_series': [9.3, 7.5, 6.7, 6.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.48, 8.6, 9.4, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.06, 7.5, 8.0, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 41.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 16.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.46, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 62.3, 2.7]},
    {'label': 4, 'time_series': [6.3, 6.5, 6.7, 61.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.8, 8.6, 9.04, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.0, 7.5, 0.80, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 1.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 1.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.6, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 2.3, 2.7]},
    {'label': 4, 'time_series': [0.9, 1.7, 1.2, 2.4, 2.8]}
]

df = spark.createDataFrame(tsdata)

In [3]:
from pyspark.sql.types import *
import numpy as np
from random import choice
import math, json, collections, itertools
from pyspark.sql import functions as F

df = df.withColumn("time_series", F.col("time_series").cast(ArrayType(DoubleType())))

---
- group the instances by their class label
- collect the grouped instances into a list
- shuffle the list
- select the first row in the shuffled list
- returned: list of row objects, one exemplar per class

In [4]:
exemplar_rows = (
    df
    .groupBy("label")               # one group per class
    .agg(F.shuffle(F.collect_list("time_series")).alias("bag"))
    .select("label", F.expr("bag[0]").alias("time_series"))   # first random element
    .collect()
)
print("broadcasted exemplars:", exemplar_rows)

broadcasted exemplars: [Row(label=1, time_series=[0.9, 1.8, 2.7, 3.6, 4.5]), Row(label=3, time_series=[0.3, 1.7, 1.46, 2.2, 2.6]), Row(label=2, time_series=[2.2, 3.1, 4.0, 4.9, 5.8]), Row(label=4, time_series=[9.3, 7.5, 6.7, 6.9, 7.1])]


---
- convert that list of exemplars into a dictionary

In [5]:
# Turn into a driver‑side dict {label: vector}
GLOBAL_EXEMPLARS = {row["label"]: row["time_series"] for row in exemplar_rows}
print("broadcasted exemplars:", GLOBAL_EXEMPLARS)

broadcasted exemplars: {1: [0.9, 1.8, 2.7, 3.6, 4.5], 3: [0.3, 1.7, 1.46, 2.2, 2.6], 2: [2.2, 3.1, 4.0, 4.9, 5.8], 4: [9.3, 7.5, 6.7, 6.9, 7.1]}


---
- create broadcast variable containing the exemplars' data and label

In [6]:
ex_bc = spark.sparkContext.broadcast(GLOBAL_EXEMPLARS)

# checking the broadcasted data
for i in ex_bc.value:
    print(i, ex_bc.value[i])

1 [0.9, 1.8, 2.7, 3.6, 4.5]
3 [0.3, 1.7, 1.46, 2.2, 2.6]
2 [2.2, 3.1, 4.0, 4.9, 5.8]
4 [9.3, 7.5, 6.7, 6.9, 7.1]


---
define functions for euclidean and dtw distance

In [8]:
def euclid(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

In [9]:
from dtaidistance import dtw

---
calculate which exemplar each instance is closest to

Euclidean distance

In [14]:
# ---------------------------------------------------------------------------
# 2.  spark ⇢ workers: tag every row with nearest exemplar -------------------
# ---------------------------------------------------------------------------
def tag_nearest_euclidian(row):
    """
    input: single row of a df or rdd
    function: measure distance of that the time series in the row to all the exemplars
    output: the best matching exemplar
    """
    vec   = row.time_series
    label = row.label
    best_id, best_dist = None, float("inf")
    for ex_label, ex_vec in ex_bc.value.items():
        d = euclid(vec, ex_vec)
        if d < best_dist:
            best_id, best_dist = ex_label, d
            print("best_id", best_id, "best_dist", best_dist, "ex_label", ex_label, "ex_vec", ex_vec)
    # We start with a single node (node_id = 0)
    return (0, best_id, label)        # (node_id, branch_id == exemplarLabel, true class)

# example usage
print(tag_nearest_euclidian(df.first()))

best_id 1 best_dist 2.2248595461286986 ex_label 1 ex_vec [0.9, 1.8, 2.7, 3.6, 4.5]
best_id 2 best_dist 1.30384048104053 ex_label 2 ex_vec [2.2, 3.1, 4.0, 4.9, 5.8]
(0, 2, 1)


DTW distance

In [15]:
# ---------------------------------------------------------------------------
# 2.  spark ⇢ workers: tag every row with nearest exemplar -------------------
# ---------------------------------------------------------------------------
def tag_nearest_dtw(row):
    vec   = row.time_series
    label = row.label
    best_id, best_dist = None, float("inf")
    for ex_label, ex_vec in ex_bc.value.items():
        d = dtw.distance(vec, ex_vec)
        if d < best_dist:
            best_id, best_dist = ex_label, d
    # We start with a single node (node_id = 0)
    return (0, best_id, label)        # (node_id, branch_id == exemplarLabel, true class)

# example usage
print(tag_nearest_euclidian(df.first()))

best_id 1 best_dist 2.2248595461286986 ex_label 1 ex_vec [0.9, 1.8, 2.7, 3.6, 4.5]
best_id 2 best_dist 1.30384048104053 ex_label 2 ex_vec [2.2, 3.1, 4.0, 4.9, 5.8]
(0, 2, 1)


also returns the distance from instance to closest exemplar

In [16]:
def tag_nearest_euclid_debug(row):
    vec   = row.time_series
    label = row.label
    best_id, best_dist = None, float("inf")
    for ex_label, ex_vec in ex_bc.value.items():
        d = euclid(vec, ex_vec)
        if d < best_dist:
            best_id, best_dist = ex_label, d
    # ► We send the debug fields back to the driver as extra columns
    return (0,            # node_id  (root)
            best_id,      # branch_id   == exemplar chosen
            label,        # true label
            best_dist)    # DEBUG: distance to that exemplar

# example usage
print(tag_nearest_euclid_debug(df.first()))

(0, 2, 1, 1.30384048104053)


In [18]:
# ---------------------------------------------------------------------------
# 2.  spark ⇢ workers: tag every row with nearest exemplar -------------------
# ---------------------------------------------------------------------------
def tag_nearest_dtw_debug(row):
    vec   = row.time_series
    label = row.label
    best_id, best_dist = None, float("inf")
    for ex_label, ex_vec in ex_bc.value.items():
        d = dtw.distance(vec, ex_vec)
        if d < best_dist:
            best_id, best_dist = ex_label, d
    # We start with a single node (node_id = 0)
    return (0, best_id, label, best_dist)        # (node_id, branch_id == exemplarLabel, true class)

# example usage
print(tag_nearest_dtw_debug(df.first()))

(0, 2, 1, 1.2247448713915892)


---
calculate which exemplar each instance is closest to, done in parallel in each partition

In [19]:
from pyspark.sql import types as T, functions as F

schema = (T.StructType()
            .add("node_id",     T.IntegerType())
            .add("branch_id",   T.IntegerType())
            .add("true_label",  T.IntegerType())
            .add("dist_calc",   T.DoubleType()))        # debug column

tagged = df.rdd.map(tag_nearest_euclid_debug).toDF(schema)

tagged.show(truncate=False)

+-------+---------+----------+-------------------+
|node_id|branch_id|true_label|dist_calc          |
+-------+---------+----------+-------------------+
|0      |2        |1         |1.30384048104053   |
|0      |1        |1         |0.38729833462074165|
|0      |1        |1         |0.0                |
|0      |1        |1         |0.9486832980505139 |
|0      |1        |1         |0.6855654600401044 |
|0      |2        |2         |1.4662878298615183 |
|0      |2        |2         |1.3784048752090219 |
|0      |2        |2         |2.037154878746336  |
|0      |1        |3         |0.5477225575051663 |
|0      |1        |3         |1.4317821063276353 |
|0      |4        |4         |5.263078946776308  |
|0      |4        |4         |3.563705936241093  |
|0      |3        |1         |0.9239047569960877 |
|0      |1        |1         |0.9486832980505134 |
|0      |3        |1         |0.8096912991998864 |
|0      |2        |2         |1.4832396974191324 |
|0      |2        |2         |0

In [20]:
# Example: only the rows that picked exemplar 1
tagged.where(F.col("branch_id") == 1).show()

+-------+---------+----------+-------------------+
|node_id|branch_id|true_label|          dist_calc|
+-------+---------+----------+-------------------+
|      0|        1|         1|0.38729833462074165|
|      0|        1|         1|                0.0|
|      0|        1|         1| 0.9486832980505139|
|      0|        1|         1| 0.6855654600401044|
|      0|        1|         3| 0.5477225575051663|
|      0|        1|         3| 1.4317821063276353|
|      0|        1|         1| 0.9486832980505134|
|      0|        1|         3|0.38729833462074165|
|      0|        1|         1| 1.3784048752090217|
|      0|        1|         1| 0.9746794344808964|
|      0|        1|         3| 0.3872983346207414|
+-------+---------+----------+-------------------+



node_id – which node of the tree the row is currently at (starts at 0, the root).

branch_id – ID of the exemplar (here equal to its class label) to which this row is closest.

true_label – the actual class label of the row.

In [None]:
def histogram(tagged_df, open_nodes, debug=False):
    """
    3a — Count #rows for every (node_id, branch_id, true_label) among the open nodes.
    Returns a tiny nested dict:
      { node_id: { branch_id: { true_label: count, … }, … }, … }
    """
    if not open_nodes:
        return {}

    # filter to only the node_ids we're actually expanding
    filtered = tagged_df.where(F.col("node_id").isin(open_nodes))
    counts   = (
      filtered
        .groupBy("node_id", "branch_id", "true_label")
        .count()
    )

    if debug:
        print(">>> FILTERED ROWS FOR THESE open_nodes:", open_nodes)
        filtered.show(2, truncate=False)
        print(">>> AGGREGATED COUNTS:")
        counts.show(2, truncate=False)
        counts.printSchema()

    # collect to driver (should be very small!)
    hist = {}
    for r in counts.collect():
        hist.setdefault(r.node_id, {}) \
            .setdefault(r.branch_id, {})[r.true_label] = r["count"]
    return hist

# example usage
open_nodes = [0] # root node
print("open_nodes", open_nodes)
hist = histogram(tagged, open_nodes, debug=True)
print("histogram", hist)


open_nodes [0]
>>> FILTERED ROWS FOR THESE open_nodes: [0]
+-------+---------+----------+-------------------+
|node_id|branch_id|true_label|dist_calc          |
+-------+---------+----------+-------------------+
|0      |2        |1         |1.30384048104053   |
|0      |1        |1         |0.38729833462074165|
+-------+---------+----------+-------------------+
only showing top 2 rows

>>> AGGREGATED COUNTS:
+-------+---------+----------+-----+
|node_id|branch_id|true_label|count|
+-------+---------+----------+-----+
|0      |1        |1         |7    |
|0      |2        |1         |1    |
+-------+---------+----------+-----+
only showing top 2 rows

root
 |-- node_id: integer (nullable = true)
 |-- branch_id: integer (nullable = true)
 |-- true_label: integer (nullable = true)
 |-- count: long (nullable = false)

histogram {0: {1: {1: 7, 3: 4}, 2: {1: 1, 2: 8}, 4: {4: 10, 1: 1}, 3: {1: 4, 3: 3, 2: 1, 4: 2}}}


^^^ at node 0:

Branch 1 contains 7 instances of label 1 and 4 instances of label 3

Branch 2 contains 1 instance of label 1 and 8 instances of label 2

Branch 4 contains 10 instances of label 4 and 1 instance of label 1

Branch 3 contains 4 instances of label 1, 3 instances of label 3, 1 instance of label 2, and 2 instances of label 4

[branch 3 = the instances who were closest to an exemplar with label 3]

In [None]:
def split_node(nid, branches, tree):
    """
    3b — Given one node-id and its branch histogram, append children
    to the global `tree` dict and return the set of *impure* children
    that must be processed in the next layer.
    """
    next_open = set()

    for br_id, cls_count in branches.items():
        if len(cls_count) == 1:                 # ——— pure → leaf
            pred = next(iter(cls_count))
            leaf = TreeNode(
                node_id=len(tree),
                parent_id=nid,
                split_on=None,
                is_leaf=True,
                prediction=pred,
                children={}
            )
            tree[leaf.node_id]      = leaf
            tree[nid].children[br_id] = leaf.node_id

        else:                                   # ——— impure → internal
            child_id = len(tree)
            twin = TreeNode(child_id, nid, None, False, None, {})
            tree[child_id] = twin
            tree[nid].children[br_id] = child_id
            next_open.add(child_id)

    # mark parent as decided
    tree[nid] = tree[nid]._replace(split_on="nearest_exemplar")
    return next_open

In [19]:
def calculate_gini(labels):
    if not labels:
        return 0.0
    counts = collections.Counter(labels)
    total = sum(counts.values())
    return 1.0 - sum((cnt/total)**2 for cnt in counts.values())

In [20]:
def split_node_gini(nid, branches, tree, depth, max_depth, min_samples):
    # 1) compute parent Gini
    parent_counts = collections.Counter()
    for bc in branches.values():
        parent_counts.update(bc)
    labels_parent = list(parent_counts.elements())
    gini_parent  = calculate_gini(labels_parent)

    # 2) stop if too few or too deep
    if len(labels_parent) < min_samples or depth >= max_depth:
        majority = parent_counts.most_common(1)[0][0]
        tree[nid] = tree[nid]._replace(
            is_leaf=True,
            prediction=majority,
            split_on=None,
            gini_parent=gini_parent
        )
        return set()

    # 3) stop if already pure
    if gini_parent == 0.0:
        tree[nid] = tree[nid]._replace(
            is_leaf=True,
            prediction=labels_parent[0],
            split_on=None,
            gini_parent=gini_parent
        )
        return set()

    # 4) compute weighted‑child Gini
    total = len(labels_parent)
    weighted = 0.0
    for bc in branches.values():
        child_labels = sum(([lbl]*cnt for lbl,cnt in bc.items()), [])
        weighted += (len(child_labels)/total) * calculate_gini(child_labels)

    # 5) stop if no improvement
    if weighted >= gini_parent:
        majority = parent_counts.most_common(1)[0][0]
        tree[nid] = tree[nid]._replace(
            is_leaf=True,
            prediction=majority,
            split_on=None,
            gini_parent=gini_parent
        )
        return set()

    # 6) otherwise create children
    next_open = set()
    tree[nid] = tree[nid]._replace(
        split_on="nearest_exemplar",
        gini_parent=gini_parent
    )
    for br_id, bc in branches.items():
        child_labels = sum(([lbl]*cnt for lbl,cnt in bc.items()), [])
        if calculate_gini(child_labels) == 0.0:
            # pure → leaf
            leaf = TreeNode(
                node_id    = len(tree),
                parent_id  = nid,
                split_on   = None,
                is_leaf    = True,
                prediction = child_labels[0],
                children   = {},
                gini_parent=gini_parent
            )
            tree[leaf.node_id] = leaf
            tree[nid].children[br_id] = leaf.node_id
        else:
            # impure → internal
            cid = len(tree)
            internal = TreeNode(
                node_id    = cid,
                parent_id  = nid,
                split_on   = None,  # will be set when we visit it
                is_leaf    = False,
                prediction = None,
                children   = {},
                gini_parent=gini_parent
            )
            tree[cid] = internal
            tree[nid].children[br_id] = cid
            next_open.add(cid)

    return next_open

In [21]:
def push_rows_down(tagged_df, tree, open_nodes, next_open, schema):
    """
    3 c — Rows that belonged to a parent now get the child node‑id
    (narrow map, no shuffle).  Returns a **new** DataFrame.
    """
    if not next_open:
        return tagged_df                       # nothing to change

    mapping = {                                 # (parent,branch) ➜ child
        (nid, br): child_id
        for nid in open_nodes
        for br, child_id in tree[nid].children.items()
        if child_id in next_open
    }
    bc = spark.sparkContext.broadcast(mapping)
    
    def _push(r):
        key   = (r.node_id, r.branch_id)
        new_n = bc.value.get(key, r.node_id)

        return (
            new_n,               # node_id
            r.time_series,       # carry the series through
            r.branch_id,
            r.true_label,
            r.dist_calc
        )

    return tagged_df.rdd.map(_push).toDF(schema)

In [22]:
import pandas as pd

def hist_to_dataframe(hist):
    rows = []
    for node_id, branches in hist.items():
        for br_id, cls_dict in branches.items():
            for lbl, cnt in cls_dict.items():
                rows.append((node_id, br_id, lbl, cnt))
    return pd.DataFrame(rows,
                        columns=["node_id", "branch_id", "true_label", "count"])

In [23]:
from pprint import pprint
import random, collections

assign_df = df.rdd \
    .map(lambda r: (0, r.time_series, r.label)) \
    .toDF(["node_id","time_series","true_label"]) \
    .cache()
#
# Schema for the “tagged” DataFrame

tagged_schema = StructType([
    StructField("node_id",     IntegerType(),           False),
    StructField("time_series", ArrayType(DoubleType()), False),
    StructField("branch_id",   IntegerType(),           False),
    StructField("true_label",  IntegerType(),           False),
    StructField("dist_calc",   DoubleType(),            False),
])

TreeNode  = collections.namedtuple(
    "TreeNode",
    "node_id parent_id split_on is_leaf prediction children gini_parent".split()
)

tree = {
    0: TreeNode(
        node_id         = 0,
        parent_id       = None,
        split_on        = None,
        is_leaf         = False,
        prediction      = None,
        children        = {},
        gini_parent     = None
        
    )
}
open_nodes = {0}
max_depth  = 5
min_samples =5


for depth in range(max_depth):

    # ---------------------------------------------------------- 3 a
    if not open_nodes:                # nothing more to grow → stop
        break
    
 # ──────────── 1) SAMPLE local exemplars ────────────
    exemplars = {}
    # pull out only the rows at the currently open nodes
    for row in assign_df.filter(F.col("node_id").isin(open_nodes)).collect():
        node, series, lbl = row
        exemplars.setdefault((node, lbl), []).append(series)
    
    # for each (node, class) pick one exemplar at random
    exemplars = { k: random.choice(v) for k, v in exemplars.items() }
    ex_bc = spark.sparkContext.broadcast(exemplars)

    # ──────────── 2) TAG every row with nearest exemplar ────────────
    def tag_row(r):
        nid, vec, true_lbl = r
        best_branch, best_dist = None, float("inf")
        for (node_key, ex_lbl), ex_vec in ex_bc.value.items():
            if node_key != nid: 
                continue
            d = dtw.distance(vec, ex_vec)
            if d < best_dist:
                best_branch, best_dist = ex_lbl, d
        if best_branch is None:
            raise RuntimeError(f"No exemplar for node {nid}")
        return (nid, vec, best_branch, true_lbl, best_dist)
        
  
    
    tagged = (
        assign_df.rdd
        .map(tag_row)
        .toDF(tagged_schema)
        .cache()
    )

    # ──────────── 3) BUILD histogram & decide splits ────────────
    #  Build your histogram, split_node_gini, push_rows_down exactly as before,
    #    but operating on `tagged`.
    
    hist = histogram(tagged, open_nodes, debug=(depth == 0))
    if not hist:
        print(f"No more splits at depth={depth}, stopping.")
        break

    # ----- (optional) pretty print for debugging ---------------
    import pprint
    print(f"\n=== depth={depth}, open_nodes={open_nodes} ===")
    pprint.pprint(hist)
    df_hist = hist_to_dataframe(hist)      # convert **once**
    display(df_hist)                       # or  print(df_hist)

    # ----------------------------------------------------------  decide splits
    next_open = set()
    for nid in open_nodes:
        children = split_node_gini(
            nid,
            hist[nid],      # branches for this node
            tree,
            depth,
            max_depth,
            min_samples
        )
        next_open |= children              # union of new internal nodes

    # ──────────── 4) PUSH rows down to child node_ids ────────────
    
    assign_df = (
        push_rows_down(tagged, tree, open_nodes, next_open, tagged_schema)
        .select("node_id","time_series","true_label")
        .cache()
    )
                
    
    open_nodes = next_open                # next layer
    
    assign_df = assign_df.where(F.col("node_id").isin(next_open)).cache()


>>> FILTERED ROWS FOR THESE open_nodes: {0}
+-------+-------------------------+---------+----------+------------------+
|node_id|time_series              |branch_id|true_label|dist_calc         |
+-------+-------------------------+---------+----------+------------------+
|0      |[1.2, 2.4, 3.6, 4.8, 6.0]|2        |1         |2.0199009876724157|
|0      |[1.0, 1.8, 2.6, 3.4, 4.2]|3        |1         |1.1224972160321824|
+-------+-------------------------+---------+----------+------------------+
only showing top 2 rows

>>> AGGREGATED COUNTS:
+-------+---------+----------+-----+
|node_id|branch_id|true_label|count|
+-------+---------+----------+-----+
|0      |2        |1         |1    |
|0      |3        |1         |7    |
+-------+---------+----------+-----+
only showing top 2 rows

root
 |-- node_id: integer (nullable = false)
 |-- branch_id: integer (nullable = false)
 |-- true_label: integer (nullable = false)
 |-- count: long (nullable = false)


=== depth=0, open_nodes={0} ===
{0

Unnamed: 0,node_id,branch_id,true_label,count
0,0,2,1,1
1,0,2,2,7
2,0,2,4,1
3,0,3,1,7
4,0,3,3,4
5,0,3,2,1
6,0,4,4,9
7,0,4,1,1
8,0,1,1,4
9,0,1,3,3



=== depth=1, open_nodes={1, 2, 3, 4} ===
{1: {1: {1: 1, 2: 3}, 2: {2: 4}, 4: {4: 1}},
 2: {1: {1: 3, 3: 1}, 2: {2: 1}, 3: {1: 4, 3: 3}},
 3: {1: {1: 1, 4: 7}, 4: {4: 2}},
 4: {1: {1: 3, 3: 1}, 2: {1: 1, 2: 1, 4: 1}, 3: {3: 2}, 4: {4: 1}}}


Unnamed: 0,node_id,branch_id,true_label,count
0,1,1,1,1
1,1,1,2,3
2,1,2,2,4
3,1,4,4,1
4,2,3,1,4
5,2,3,3,3
6,2,1,1,3
7,2,1,3,1
8,2,2,2,1
9,3,1,4,7



=== depth=2, open_nodes={5, 8, 9, 11, 13, 15} ===
{5: {1: {1: 1, 2: 2}, 2: {2: 1}},
 8: {1: {1: 3, 3: 1}, 3: {1: 1, 3: 2}},
 9: {1: {1: 3}, 3: {3: 1}},
 11: {1: {1: 1}, 4: {4: 7}},
 13: {1: {1: 3}, 3: {3: 1}},
 15: {1: {1: 1}, 2: {2: 1}, 4: {4: 1}}}


Unnamed: 0,node_id,branch_id,true_label,count
0,8,1,1,3
1,8,1,3,1
2,8,3,1,1
3,8,3,3,2
4,5,1,1,1
5,5,1,2,2
6,5,2,2,1
7,9,1,1,3
8,9,3,3,1
9,11,4,4,7



=== depth=3, open_nodes={17, 18} ===
{17: {1: {1: 1}, 3: {1: 2, 3: 1}}, 18: {1: {1: 1, 3: 1}, 3: {3: 1}}}


Unnamed: 0,node_id,branch_id,true_label,count
0,17,3,1,2
1,17,3,3,1
2,17,1,1,1
3,18,1,1,1
4,18,1,3,1
5,18,3,3,1


In [41]:
print("Final tree:")
pprint.pprint(tree)

Final tree:
{0: TreeNode(node_id=0, parent_id=None, split_on='nearest_exemplar', is_leaf=False, prediction=None, children={2: 1, 3: 2, 4: 3, 1: 4}, gini_parent=0.7364663890541344),
 1: TreeNode(node_id=1, parent_id=0, split_on='nearest_exemplar', is_leaf=False, prediction=None, children={1: 5, 2: 6, 4: 7}, gini_parent=0.37037037037037024),
 2: TreeNode(node_id=2, parent_id=0, split_on='nearest_exemplar', is_leaf=False, prediction=None, children={3: 8, 1: 9, 2: 10}, gini_parent=0.5416666666666666),
 3: TreeNode(node_id=3, parent_id=0, split_on='nearest_exemplar', is_leaf=False, prediction=None, children={1: 11, 4: 12}, gini_parent=0.17999999999999994),
 4: TreeNode(node_id=4, parent_id=0, split_on='nearest_exemplar', is_leaf=False, prediction=None, children={1: 13, 3: 14, 2: 15, 4: 16}, gini_parent=0.7),
 5: TreeNode(node_id=5, parent_id=1, split_on=None, is_leaf=True, prediction=2, children={}, gini_parent=0.375),
 6: TreeNode(node_id=6, parent_id=1, split_on=None, is_leaf=True, predic

In [51]:
# 1) Suppose your test data is in the same format as tsdata:
test_data = [
    # Class 1 (≈0→5)
    {'label': 1, 'time_series': [0.55, 1.15, 1.75, 2.35, 2.95]},
    {'label': 1, 'time_series': [1.05, 2.05, 3.05, 4.05, 5.05]},
    {'label': 1, 'time_series': [0.10, 1.00, 2.10, 3.00, 4.10]},
    {'label': 1, 'time_series': [0.80, 1.60, 2.40, 3.20, 4.00]},
    {'label': 1, 'time_series': [0.30, 1.20, 2.10, 3.00, 3.90]},

    # Class 2 (≈2→7)
    {'label': 2, 'time_series': [2.05, 3.25, 4.45, 5.65, 6.85]},
    {'label': 2, 'time_series': [1.95, 2.95, 3.95, 4.95, 5.95]},
    {'label': 2, 'time_series': [2.20, 3.20, 4.20, 5.20, 6.20]},
    {'label': 2, 'time_series': [2.50, 3.50, 4.50, 5.50, 6.50]},
    {'label': 2, 'time_series': [1.80, 2.80, 3.80, 4.80, 5.80]},

    # Class 3 (≈0.5→4)
    {'label': 3, 'time_series': [0.45, 1.45, 2.45, 3.45, 4.45]},
    {'label': 3, 'time_series': [0.95, 1.95, 2.95, 3.95, 4.95]},
    {'label': 3, 'time_series': [0.60, 1.60, 2.60, 3.60, 4.60]},
    {'label': 3, 'time_series': [0.30, 1.30, 2.30, 3.30, 4.30]},
    {'label': 3, 'time_series': [0.75, 1.75, 2.75, 3.75, 4.75]},

    # Class 4 (≈5.5→10)
    {'label': 4, 'time_series': [5.55, 6.55, 7.55, 8.55, 9.55]},
    {'label': 4, 'time_series': [6.05, 6.55, 7.05, 7.55, 8.05]},
    {'label': 4, 'time_series': [5.80, 6.80, 7.80, 8.80, 9.80]},
    {'label': 4, 'time_series': [5.20, 6.20, 7.20, 8.20, 9.20]},
    {'label': 4, 'time_series': [6.30, 7.30, 8.30, 9.30, 10.30]},
]

def predict_root(series):
    # pick the class whose exemplar at node 0 is closest
    return min(
        GLOBAL_EXEMPLARS.keys(),
        key=lambda lbl: dtw.distance(series, GLOBAL_EXEMPLARS[lbl])
    )
    
# 2) Run predictions and tally accuracy:
correct = 0
for row in test_data:
    true = row['label']
    pred = predict_root(row['time_series'])
    print(f"True: {true:>2}  →  Pred: {pred}")
    if pred == true:
        correct += 1

print(f"\nAccuracy: {correct}/{len(test_data)} = {correct/len(test_data):.2%}")


True:  1  →  Pred: 1
True:  1  →  Pred: 4
True:  1  →  Pred: 4
True:  1  →  Pred: 4
True:  1  →  Pred: 4
True:  2  →  Pred: 4
True:  2  →  Pred: 4
True:  2  →  Pred: 4
True:  2  →  Pred: 4
True:  2  →  Pred: 4
True:  3  →  Pred: 4
True:  3  →  Pred: 4
True:  3  →  Pred: 4
True:  3  →  Pred: 4
True:  3  →  Pred: 4
True:  4  →  Pred: 4
True:  4  →  Pred: 4
True:  4  →  Pred: 4
True:  4  →  Pred: 4
True:  4  →  Pred: 4

Accuracy: 6/20 = 30.00%


In [24]:
tagged.printSchema()
tagged.select("node_id", "branch_id", "true_label").summary("min", "max").show()


root
 |-- node_id: integer (nullable = false)
 |-- time_series: array (nullable = false)
 |    |-- element: double (containsNull = true)
 |-- branch_id: integer (nullable = false)
 |-- true_label: integer (nullable = false)
 |-- dist_calc: double (nullable = false)

+-------+-------+---------+----------+
|summary|node_id|branch_id|true_label|
+-------+-------+---------+----------+
|    min|     17|        1|         1|
|    max|     18|        3|         3|
+-------+-------+---------+----------+



In [25]:
import pandas as pd

def hist_to_dataframe(hist):
    rows = []
    for node_id, branches in hist.items():
        for br_id, cls_dict in branches.items():
            for lbl, cnt in cls_dict.items():
                rows.append((node_id, br_id, lbl, cnt))
    return pd.DataFrame(rows,
                        columns=["node_id", "branch_id", "true_label", "count"])

df_hist = hist_to_dataframe(histogram(tagged, open_nodes))
display(df_hist)          # in a notebook – or print(df_hist)

Unnamed: 0,node_id,branch_id,true_label,count


In [22]:
!pip install graphviz

Collecting graphviz
  Downloading graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)
Downloading graphviz-0.20.3-py3-none-any.whl (47 kB)
Installing collected packages: graphviz
Successfully installed graphviz-0.20.3


In [23]:
# ------------------------------------------------------------------
# helper: pretty‑print / draw the learned tree
# ------------------------------------------------------------------
def show_tree(tree: dict, root_id: int = 0, graphviz: bool = False,
              file_name: str = "proximity_tree") -> None:
    """
    Parameters
    ----------
    tree       : the dict that the loop filled  {node_id : TreeNode}
    root_id    : normally 0
    graphviz   : draw a .png with Graphviz in addition to console print
    file_name  : base name for the .dot and .png files
    """
    # ---- 1.  ASCII print -----------------------------------------
    def _ascii(nid: int, indent: str = ""):
        node = tree[nid]
        if node.is_leaf:
            print(f"{indent}*[{nid}]  LEAF  → predict class {node.prediction}")
        else:
            print(f"{indent}*[{nid}]  split = nearest_exemplar")
            for br_id, child_id in node.children.items():
                print(f"{indent}  ├─ branch {br_id}")
                _ascii(child_id, indent + "  │   ")

    print("\n===== Proximity‑Tree (depth ≤ {}) =====".format(max_depth))
    _ascii(root_id)
    print("========================================\n")

    # ---- 2.  optional Graphviz (.png) -----------------------------
    if graphviz:
        try:
            import graphviz                              # pip install graphviz
            dot = graphviz.Digraph(format="png")
            for nid, node in tree.items():
                label = f"{nid}\\nleaf→{node.prediction}" if node.is_leaf \
                        else str(nid)
                shape = "box" if node.is_leaf else "ellipse"
                dot.node(str(nid), label, shape=shape)

            for nid, node in tree.items():
                for br, child in node.children.items():
                    dot.edge(str(nid), str(child), label=str(br))

            dot.render(file_name, cleanup=True)
            print(f"Graphviz output written to {file_name}.png")

        except ImportError:
            print("‑‑ Graphviz not installed; skipped picture generation ‑‑")


In [24]:
# … your breadth‑first loop finished …

show_tree(tree, graphviz=False)             # just ASCII
# show_tree(tree, graphviz=True)            # ASCII + PNG (needs `pip install graphviz`)


NameError: name 'tree' is not defined

In [29]:
def print_tree_with_points(tree, tagged_df, max_rows=5):
    """
    Pretty‑print the tree and, under each node, show up to `max_rows`
    example rows that sit in that node (taken from the *current* tagged DF).
    """

    # ------------------------------------------------------------------ #
    # 1. helper to fetch rows of one node (returns a list of dicts)       #
    # ------------------------------------------------------------------ #
    def sample_rows(node_id, n=max_rows):
        rows = (
            tagged_df.where(F.col("node_id") == node_id)
                     .limit(n)                         # <‑‑ avoid huge output
                     .collect()
        )
        return [r.asDict() for r in rows]

    # ------------------------------------------------------------------ #
    # 2. breadth‑first walk & print                                       #
    # ------------------------------------------------------------------ #
    queue = collections.deque([(0, 0, None)])  # (node_id, depth, via_branch)

    while queue:
        nid, depth, via_br = queue.popleft()
        node = tree[nid]

        indent = "  " * depth
        branch_info = f"[from branch {via_br}]" if via_br is not None else ""
        leaf_flag   = "🌿" if node.is_leaf else ""
        print(f"{indent}{leaf_flag}Node {nid} {branch_info}")

        # ------------- show a few rows ---------------------------------
        for row in sample_rows(nid):
            print(f"{indent}   · {row}")

        # ------------- enqueue children -------------------------------
        for br_id, child_id in sorted(node.children.items()):
            queue.append((child_id, depth + 1, br_id))

# -----------------------------------------------------------------------
# call it right after the training loop (when `tagged` & `tree` exist)
print_tree_with_points(tree, tagged, max_rows=3)


Node 0 
  Node 4 [from branch 1]
  Node 1 [from branch 2]
  Node 2 [from branch 3]
  Node 3 [from branch 4]
    🌿Node 13 [from branch 1]
    🌿Node 15 [from branch 2]
    🌿Node 14 [from branch 3]
    🌿Node 16 [from branch 4]
    🌿Node 5 [from branch 1]
    🌿Node 6 [from branch 2]
    🌿Node 7 [from branch 4]
    🌿Node 9 [from branch 1]
    🌿Node 10 [from branch 2]
    Node 8 [from branch 3]
    Node 11 [from branch 1]
    🌿Node 12 [from branch 4]
      🌿Node 17 [from branch 1]
         · {'node_id': 17, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'branch_id': 3, 'true_label': 1, 'dist_calc': 0.0}
         · {'node_id': 17, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0], 'branch_id': 1, 'true_label': 1, 'dist_calc': 0.0}
         · {'node_id': 17, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'branch_id': 3, 'true_label': 3, 'dist_calc': 0.0}
      🌿Node 18 [from branch 3]
         · {'node_id': 18, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9], 'branch_id': 1, 'true_label': 1, 'dist_calc': 0.0}
         

In [None]:
import collections
import pyspark.sql.functions as F


# ╔══════════════════════════════════════════════════════════════════╗
# ║  Pretty‑print the tree AND show a few sample rows per (node,     ║
# ║  branch) so you can see exactly which points flow where.         ║
# ╚══════════════════════════════════════════════════════════════════╝
def print_tree_with_points(tree, tagged_df, max_rows=5, show_branch_rows=True):
    """
    Print every node top‑down; under each node print up to `max_rows`
    sample rows that currently sit *inside that node*.

    If `show_branch_rows=True` we also indent once more and show a few
    rows *per branch* before they flow into their child node – handy to
    verify the “routing” in an internal node.
    """

    # ───────────────────────────────────────────────────────────
    # helpers
    # ───────────────────────────────────────────────────────────
    def sample_rows(node_id, branch_id=None, n=max_rows):
        """
        Return at most `n` Row‑dicts that are inside *node_id*.
        If `branch_id` is not None, we additionally filter that branch.
        """
        cond = (F.col("node_id") == node_id)
        if branch_id is not None:
            cond &= (F.col("branch_id") == branch_id)

        return (tagged_df
                .where(cond)
                .limit(n)              # avoid huge output
                .collect())

    # ───────────────────────────────────────────────────────────
    # breadth‑first traversal (queue holds tuples)
    # (node_id , depth , incoming_branch_id)
    # ───────────────────────────────────────────────────────────
    queue = collections.deque([(0, 0, None)])

    while queue:
        nid, depth, via_branch = queue.popleft()
        node = tree[nid]

        # ---------- headline for the node ----------
        indent  = "  " * depth
        leaf_fl = "🌿" if node.is_leaf else "├"
        heading = f"{indent}{leaf_fl} node {nid}"
        if via_branch is not None:
            heading += f"  (arrived via branch {via_branch})"
        if node.is_leaf:
            heading += f"  ➜ predict label {node.prediction}"
        print(heading)

        # ---------- some sample rows that live in *this* node ----------
        for r in sample_rows(nid):
            print(f"{indent}     · {r.asDict()}")

        # ---------- enqueue / optionally show per‑branch samples ----------
        for br_id, child_id in sorted(node.children.items()):
            if show_branch_rows and not node.is_leaf:
                sub_indent = indent + "  "
                print(f"{sub_indent}branch {br_id} → child {child_id}")
                for r in sample_rows(nid, br_id):
                    print(f"{sub_indent}   · {r.asDict()}")
            queue.append((child_id, depth + 1, br_id))

        print()            # blank line between siblings for readability


In [31]:
print_tree_with_points(tree, tagged, max_rows=3)


├ node 0
  branch 1 → child 4
  branch 2 → child 1
  branch 3 → child 2
  branch 4 → child 3

  ├ node 4  (arrived via branch 1)
    branch 1 → child 13
    branch 2 → child 15
    branch 3 → child 14
    branch 4 → child 16

  ├ node 1  (arrived via branch 2)
    branch 1 → child 5
    branch 2 → child 6
    branch 4 → child 7

  ├ node 2  (arrived via branch 3)
    branch 1 → child 9
    branch 2 → child 10
    branch 3 → child 8

  ├ node 3  (arrived via branch 4)
    branch 1 → child 11
    branch 4 → child 12

    🌿 node 13  (arrived via branch 1)  ➜ predict label 1

    🌿 node 15  (arrived via branch 2)  ➜ predict label 2

    🌿 node 14  (arrived via branch 3)  ➜ predict label 3

    🌿 node 16  (arrived via branch 4)  ➜ predict label 4

    🌿 node 5  (arrived via branch 1)  ➜ predict label 2

    🌿 node 6  (arrived via branch 2)  ➜ predict label 2

    🌿 node 7  (arrived via branch 4)  ➜ predict label 4

    🌿 node 9  (arrived via branch 1)  ➜ predict label 1

    🌿 node 10  (arr

In [32]:
def print_tree_sideways(tree, node_id=0, indent=0, gap=6):
    """
    Recursively prints the whole Proximity‑Tree sideways.
    ─  root is on the *left*, leaves on the *right*.
    ─  ‘In‑order’ = visit left‑half children ▸ node ▸ right‑half children
      (if you have more than two children we simply take the lower‑index
       half as the “left” side).

    Parameters
    ----------
    tree     dict {node_id ➜ TreeNode}
    node_id  int   id of the current node (start with the root, 0)
    indent   int   how many spaces to shift this subtree to the right
    gap      int   horizontal spacing between levels (tweak to taste)
    """
    node = tree[node_id]

    # sort branch labels so the picture is deterministic
    children = sorted(node.children.items())           # [(br, child_id) …]

    # **split** children into a ‘left’ and ‘right’ half for in‑order
    mid = len(children) // 2
    left  = children[:mid]
    right = children[mid:]

    # ---------- print RIGHT half first (so it ends up *above* the parent)
    for br, c_id in right:
        print_tree_sideways(tree, c_id, indent + gap, gap)

    # ---------- print the current node
    label = f"[{node.node_id}]"
    if node.is_leaf:
        label += f"⟶{node.prediction}"
    print(" " * indent + label)

    # ---------- print LEFT half afterwards (goes *below* the parent)
    for br, c_id in left:
        print_tree_sideways(tree, c_id, indent + gap, gap)


In [33]:
print_tree_sideways(tree)        # `tree` is the dict you built earlier


            [10]⟶2
                  [18]⟶3
            [8]
                  [17]⟶1
      [2]
            [9]⟶1
            [12]⟶4
      [3]
                  [19]⟶4
            [11]
                  [20]⟶1
[0]
            [14]⟶3
            [16]⟶4
      [4]
            [13]⟶1
            [15]⟶2
            [6]⟶2
            [7]⟶4
      [1]
            [5]⟶2


In [34]:
import math
from collections import deque

# ------------------------------------------------------------
# pretty‑printer  (root on top, breadth‑first levels)
# ------------------------------------------------------------
def print_tree_topdown(tree, root_id=0, gap=3):
    """
    Nicely print a Proximity‑Tree in 2‑D (root at top).

    Parameters
    ----------
    tree     dict { node_id -> TreeNode }
    root_id  int   where to start (default = 0)
    gap      int   horizontal space between sibling sub‑trees
    """

    # 1.  Gather each level with a simple BFS
    levels = []                       # list of [node_id ...] per depth
    q = deque([(root_id, 0)])         # (node_id, depth)
    while q:
        nid, d = q.popleft()
        if len(levels) <= d:
            levels.append([])
        levels[d].append(nid)

        # enqueue children in *branch order* for stability
        for _, cid in sorted(tree[nid].children.items()):
            q.append((cid, d+1))

    # 2.  Determine the printable label for every node once
    labels = {}
    for nid, node in tree.items():
        if node.is_leaf:
            labels[nid] = f"[{nid}→{node.prediction}]"
        else:
            labels[nid] = f"[{nid}]"

    # 3.  Compute width of the bottom layer → overall canvas width
    bottom = levels[-1]
    max_label_len = max(len(labels[n]) for n in tree)
    cell_w = max_label_len + gap                    # width per “slot”
    width  = len(bottom) * cell_w

    # 4.  Print each level centred in its allotted range
    for depth, layer in enumerate(levels):
        n_slots = 2 ** depth                        # binary tree assumption
        slot_w  = width // n_slots                  # width per logical slot

        line = ""
        for nid in layer:
            pos  = levels[depth].index(nid)         # index in this layer
            cell = labels[nid].center(slot_w)
            line += cell
        print(line.rstrip())

    # 5.  Optional: print an underline for visual separation
    print("=" * width)
print_tree_topdown(tree)          # <-- just call it


                [0]
       [4]               [1]               [2]               [3]
  [13→1]   [15→2]   [14→3]   [16→4]   [5→2]    [6→2]    [7→4]    [9→1]    [10→2]    [8]      [11]    [12→4]
[17→1][18→3][20→1][19→4]


In [35]:
def print_pyramid(tree, root_id=0):
    """
    Show the tree breadth‑first with the root centred on the first line.

    Example for a 3‑layer tree:
        (0)
      (1) (2)
     (3)(4)(5)

    Only node‑ids are displayed for brevity; adapt as you like.
    """
    # --- collect nodes per depth ------------------------------------------
    by_depth = collections.defaultdict(list)
    queue    = collections.deque([(root_id, 0)])          # (node_id, depth)

    while queue:
        nid, depth = queue.popleft()
        by_depth[depth].append(nid)

        # enqueue children (order by branch‑id just for nicer output)
        for br_id in sorted(tree[nid].children):
            queue.append((tree[nid].children[br_id], depth + 1))

    # --- pretty‑print ------------------------------------------------------
    max_depth   = max(by_depth)
    width_first = len(" ".join(f"({n})" for n in by_depth[0]))  # width of top line

    for d in range(max_depth + 1):
        nodes   = " ".join(f"({n})" for n in by_depth[d])
        pad     = " " * ((width_first - len(nodes)) // 2)
        print(f"depth {d}: {pad}{nodes}")

# -------------------------------------------------------------------------
# call it after the training loop
print_pyramid(tree)


depth 0: (0)
depth 1: (4) (1) (2) (3)
depth 2: (13) (15) (14) (16) (5) (6) (7) (9) (10) (8) (11) (12)
depth 3: (17) (18) (20) (19)
