In [1]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging

from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("GenericRDD").getOrCreate()

# Access the SparkContext
sc = spark.sparkContext

In [2]:
tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]},
    {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9]},
    {'label': 1, 'time_series': [1.4, 2.0, 2.6, 3.2, 3.8]},
    {'label': 2, 'time_series': [2.2, 3.1, 4.0, 4.9, 5.8]},
    {'label': 2, 'time_series': [2.6, 3.2, 3.8, 4.4, 5.0]},
    {'label': 3, 'time_series': [1.2, 2.0, 2.8, 3.6, 4.4]},
    {'label': 3, 'time_series': [0.6, 1.3, 2.0, 2.7, 3.4]},
    {'label': 4, 'time_series': [6.3, 6.5, 6.7, 6.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.8, 8.6, 9.4, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.0, 7.5, 8.0, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 1.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 1.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.6, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 2.3, 2.7]},
    {'label': 4, 'time_series': [0.9, 1.7, 1.2, 2.4, 2.8]}
]

df = spark.createDataFrame(tsdata)

# 1

In [3]:
def repartition_sparkdf(df, num_partitions):
    rdd = df.rdd
    rdd = rdd.repartition(num_partitions)
    rdd = rdd.mapPartitionsWithIndex(
            lambda idx, iter: [{**row.asDict(), "partition_id": idx} for row in iter]
        )
    return rdd

# example usage
rdd = repartition_sparkdf(df, 2)
print(rdd.getNumPartitions())  # should print 1
rdd.collect()

2


[{'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0], 'partition_id': 0},
 {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9], 'partition_id': 0},
 {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2], 'partition_id': 0},
 {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5], 'partition_id': 0},
 {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5], 'partition_id': 0},
 {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0], 'partition_id': 0},
 {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9], 'partition_id': 0},
 {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5], 'partition_id': 0},
 {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1], 'partition_id': 0},
 {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1], 'partition_id': 0},
 {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5], 'partition_id': 0},
 {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'partition_id': 0},
 {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0], 'partition_id': 0},
 {'label': 

# 2

In [4]:
def choose_exemplars(iterator):
    partition_data = list(iterator)
    if not partition_data:
        return iter([])
    
    # Group data by class
    grouped_data_by_class = {}
    for row in partition_data:
        label = row.get('label')
        if label is not None:
            grouped_data_by_class.setdefault(label, []).append(row)
    
    # Select one exemplar per class and track exemplar-to-label mapping
    chosen_exemplars = []
    exemplar_labels = {}  # Map to track exemplar labels
    
    for label, instances in grouped_data_by_class.items():
        if instances:  # Ensure there are instances for the class
            exemplar = sample(instances, 1)[0]
            chosen_exemplars.append((exemplar['time_series'], label))  # Store tuple of (time_series, label)
    
    # Remove chosen exemplars from the working data
    exemplar_time_series = [ex[0] for ex in chosen_exemplars]
    filtered_partition = [
        row for row in partition_data
        if row['time_series'] not in exemplar_time_series
    ]
    
    # Return rows with individual exemplar columns and their labels
    result = []
    for row in filtered_partition:
        new_row = {**row}
        # Add each exemplar and its label as columns
        for i, (exemplar, label) in enumerate(chosen_exemplars, 1):
            new_row[f"exemplar_{i}"] = exemplar
            new_row[f"exemplar_{i}_label"] = label
        result.append(new_row)
    
    return iter(result)

# example usage
rdd_with_exemplars = rdd.mapPartitions(choose_exemplars)
rdd_with_exemplars.collect()

[{'label': 1,
  'time_series': [0.8, 1.7, 2.5, 3.2, 4.0],
  'partition_id': 0,
  'exemplar_1': [1.1, 2.1, 3.1, 4.1, 5.1],
  'exemplar_1_label': 1,
  'exemplar_2': [2.1, 3.3, 4.5, 5.7, 6.9],
  'exemplar_2_label': 2,
  'exemplar_3': [2.0, 2.5, 3.0, 3.5, 4.0],
  'exemplar_3_label': 3,
  'exemplar_4': [6.3, 6.5, 6.7, 6.9, 7.1],
  'exemplar_4_label': 4},
 {'label': 2,
  'time_series': [3.0, 3.8, 4.6, 5.4, 6.2],
  'partition_id': 0,
  'exemplar_1': [1.1, 2.1, 3.1, 4.1, 5.1],
  'exemplar_1_label': 1,
  'exemplar_2': [2.1, 3.3, 4.5, 5.7, 6.9],
  'exemplar_2_label': 2,
  'exemplar_3': [2.0, 2.5, 3.0, 3.5, 4.0],
  'exemplar_3_label': 3,
  'exemplar_4': [6.3, 6.5, 6.7, 6.9, 7.1],
  'exemplar_4_label': 4},
 {'label': 2,
  'time_series': [3.3, 4.1, 4.9, 5.7, 6.5],
  'partition_id': 0,
  'exemplar_1': [1.1, 2.1, 3.1, 4.1, 5.1],
  'exemplar_1_label': 1,
  'exemplar_2': [2.1, 3.3, 4.5, 5.7, 6.9],
  'exemplar_2_label': 2,
  'exemplar_3': [2.0, 2.5, 3.0, 3.5, 4.0],
  'exemplar_3_label': 3,
  'exemplar_4

# 3

In [5]:
def calc_dtw_distance(iterator):
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        time_series = row.get('time_series', [])
        
        # Check for individual exemplar columns (exemplar_1, exemplar_2, etc.)
        exemplar_columns = {k: v for k, v in row.items() if k.startswith('exemplar_') and isinstance(v, list)}
        
        if not exemplar_columns:
            # Try to get exemplars from the 'exemplars' list if individual columns aren't found
            exemplars = row.get('exemplars', [])
            if not exemplars:
                continue  # Skip if no exemplars found
            
            # Calculate DTW distances for each exemplar in the list
            updated_row = {**row}
            for i, exemplar in enumerate(exemplars):
                dtw_distance = dtw.distance(time_series, exemplar)
                updated_row[f"dtw_distance_exemplar_{i+1}"] = dtw_distance
            
            updated_rows.append(updated_row)
        else:
            # Calculate DTW distances for each exemplar column
            updated_row = {**row}
            for col_name, exemplar in exemplar_columns.items():
                # Extract index from column name (e.g., "exemplar_1" -> "1")
                idx = col_name.split('_')[1]
                dtw_distance = dtw.distance(time_series, exemplar)
                updated_row[f"dtw_distance_exemplar_{idx}"] = dtw_distance
            
            updated_rows.append(updated_row)
    
    return iter(updated_rows)

# example usage
rdd_with_dtw = rdd_with_exemplars.mapPartitions(calc_dtw_distance)
rdd_with_dtw.collect()[:3]

[{'label': 1,
  'time_series': [0.8, 1.7, 2.5, 3.2, 4.0],
  'partition_id': 0,
  'exemplar_1': [1.1, 2.1, 3.1, 4.1, 5.1],
  'exemplar_1_label': 1,
  'exemplar_2': [2.1, 3.3, 4.5, 5.7, 6.9],
  'exemplar_2_label': 2,
  'exemplar_3': [2.0, 2.5, 3.0, 3.5, 4.0],
  'exemplar_3_label': 3,
  'exemplar_4': [6.3, 6.5, 6.7, 6.9, 7.1],
  'exemplar_4_label': 4,
  'dtw_distance_exemplar_1': 1.2806248474865694,
  'dtw_distance_exemplar_2': 3.683748091278773,
  'dtw_distance_exemplar_3': 1.2884098726725126,
  'dtw_distance_exemplar_4': 9.707213812418061},
 {'label': 2,
  'time_series': [3.0, 3.8, 4.6, 5.4, 6.2],
  'partition_id': 0,
  'exemplar_1': [1.1, 2.1, 3.1, 4.1, 5.1],
  'exemplar_1_label': 1,
  'exemplar_2': [2.1, 3.3, 4.5, 5.7, 6.9],
  'exemplar_2_label': 2,
  'exemplar_3': [2.0, 2.5, 3.0, 3.5, 4.0],
  'exemplar_3_label': 3,
  'exemplar_4': [6.3, 6.5, 6.7, 6.9, 7.1],
  'exemplar_4_label': 4,
  'dtw_distance_exemplar_1': 2.463736998950984,
  'dtw_distance_exemplar_2': 1.284523257866513,
  'dtw_

# 4

In [6]:
def assign_closest_exemplar(iterator):
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        # Get all DTW distances
        dtw_distances = {k: v for k, v in row.items() if k.startswith('dtw_distance_exemplar_')}
        
        if not dtw_distances:
            # Create a simplified row without exemplar columns
            simplified_row = {k: v for k, v in row.items() 
                             if not k.startswith('exemplar_')}
            updated_rows.append(simplified_row)
            continue
        
        # Find the closest exemplar based on the minimum DTW distance
        closest_exemplar_key = min(dtw_distances, key=dtw_distances.get)
        min_distance = dtw_distances[closest_exemplar_key]
        
        # Extract exemplar number from the key (e.g., "dtw_distance_exemplar_1" -> "1")
        exemplar_num = closest_exemplar_key.split('_')[-1]
        
        # Get the corresponding exemplar time series data and its original label
        exemplar_key = f'exemplar_{exemplar_num}'
        exemplar_label_key = f'exemplar_{exemplar_num}_label'
        exemplar_time_series = row.get(exemplar_key, None)
        exemplar_original_label = row.get(exemplar_label_key, None)
        
        # Create a new row without the DTW distance columns and exemplar columns
        updated_row = {k: v for k, v in row.items() 
                      if not k.startswith('dtw_distance_exemplar_') and not k.startswith('exemplar_')}
        
        # Add information about the closest exemplar
        updated_row['closest_exemplar_id'] = closest_exemplar_key
        updated_row['closest_exemplar_data'] = exemplar_time_series
        updated_row['closest_exemplar_original_label'] = exemplar_original_label
        
        updated_rows.append(updated_row)
    
    return iter(updated_rows)

# Example usage
rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(assign_closest_exemplar)
rdd_with_closest_exemplar.collect()

[{'label': 1,
  'time_series': [0.8, 1.7, 2.5, 3.2, 4.0],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_1',
  'closest_exemplar_data': [1.1, 2.1, 3.1, 4.1, 5.1],
  'closest_exemplar_original_label': 1},
 {'label': 2,
  'time_series': [3.0, 3.8, 4.6, 5.4, 6.2],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_2',
  'closest_exemplar_data': [2.1, 3.3, 4.5, 5.7, 6.9],
  'closest_exemplar_original_label': 2},
 {'label': 2,
  'time_series': [3.3, 4.1, 4.9, 5.7, 6.5],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_2',
  'closest_exemplar_data': [2.1, 3.3, 4.5, 5.7, 6.9],
  'closest_exemplar_original_label': 2},
 {'label': 3,
  'time_series': [0.5, 1.5, 2.5, 3.5, 4.5],
  'partition_id': 0,
  'closest_exemplar_id': 'dtw_distance_exemplar_1',
  'closest_exemplar_data': [1.1, 2.1, 3.1, 4.1, 5.1],
  'closest_exemplar_original_label': 1},
 {'label': 4,
  'time_series': [5.5, 6.6, 7.7, 8.8, 9.9],
  'partition_id': 0,
  'closest_exempl

# 5

In [7]:
def calculate_gini(labels):
    if not labels:
        return 0
    label_counts = {}
    for label in labels:
        label_counts[label] = label_counts.get(label, 0) + 1
    total = sum(label_counts.values())
    gini = 1 - sum((count / total) ** 2 for count in label_counts.values()) if total > 0 else 0
    return gini

def evaluate_splits_within_partition(index, iterator):
    partition_data = list(iterator)
    results = []
    
    # Calculate Gini impurity before splitting
    labels = [row.get('label') for row in partition_data if row.get('label') is not None]
    before_split_gini = calculate_gini(labels)
    
    # Get all unique exemplars in the partition
    unique_exemplars = set(
        row['closest_exemplar_id'] for row in partition_data
        if row.get('closest_exemplar_id') is not None
    )
    
    # Create a mapping of exemplar_id to exemplar_data and exemplar_label
    exemplar_data_map = {}
    exemplar_label_map = {}
    
    for row in partition_data:
        exemplar_id = row.get('closest_exemplar_id')
        if exemplar_id and exemplar_id not in exemplar_data_map:
            exemplar_data_map[exemplar_id] = row.get('closest_exemplar_data')
            # Use the original exemplar label instead of the row's label
            exemplar_label_map[exemplar_id] = row.get('closest_exemplar_original_label')
    
    # Evaluate all possible splits
    for exemplar_id in unique_exemplars:
        # Split the data based on the current exemplar
        yes_split = [r for r in partition_data if r.get('closest_exemplar_id') == exemplar_id]
        no_split = [r for r in partition_data if r.get('closest_exemplar_id') != exemplar_id]
        
        # Calculate Gini for each daughter node
        yes_labels = [r.get('label') for r in yes_split if r.get('label') is not None]
        no_labels = [r.get('label') for r in no_split if r.get('label') is not None]
        
        yes_gini = calculate_gini(yes_labels)
        no_gini = calculate_gini(no_labels)
        
        # Calculate weighted Gini after split
        total_size = len(yes_split) + len(no_split)
        weighted_gini = (yes_gini * len(yes_split) / total_size + no_gini * len(no_split) / total_size) if total_size > 0 else float('inf')
        
        # Calculate Gini reduction
        gini_reduction = before_split_gini - weighted_gini if total_size > 0 else float('-inf')
        
        # Get the exemplar data and label for this exemplar_id
        exemplar_data = exemplar_data_map.get(exemplar_id)
        exemplar_label = exemplar_label_map.get(exemplar_id)
        
        # Add this split evaluation to results
        results.append({
            "partition_id": index,
            "exemplar_id": exemplar_id,
            "exemplar_data": exemplar_data,
            "exemplar_label": exemplar_label,
            "gini_reduction": gini_reduction   
        })
    
    return iter(results)

# Example usage
rdd_with_splits = rdd_with_closest_exemplar.mapPartitionsWithIndex(evaluate_splits_within_partition)
rdd_with_splits.collect()

[{'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_4',
  'exemplar_data': [6.3, 6.5, 6.7, 6.9, 7.1],
  'exemplar_label': 4,
  'gini_reduction': 0.16625},
 {'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_1',
  'exemplar_data': [1.1, 2.1, 3.1, 4.1, 5.1],
  'exemplar_label': 1,
  'gini_reduction': 0.08738095238095245},
 {'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_3',
  'exemplar_data': [2.0, 2.5, 3.0, 3.5, 4.0],
  'exemplar_label': 3,
  'gini_reduction': 0.0009340659340658641},
 {'partition_id': 0,
  'exemplar_id': 'dtw_distance_exemplar_2',
  'exemplar_data': [2.1, 3.3, 4.5, 5.7, 6.9],
  'exemplar_label': 2,
  'gini_reduction': 0.11735294117647066},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_1',
  'exemplar_data': [0.9, 1.4, 1.9, 2.4, 2.9],
  'exemplar_label': 1,
  'gini_reduction': 0.0},
 {'partition_id': 1,
  'exemplar_id': 'dtw_distance_exemplar_2',
  'exemplar_data': [2.4, 3.5, 4.6, 5.7, 6.8],
  'exemplar_label': 2,
  'gini_redu

# ---

In [None]:
# global_proximity_tree.py
#
# Build ONE proximity‑tree in a truly distributed way.
# ----------------------------------------------------
# 1.  Driver picks one exemplar per class  (global!)
# 2.  Broadcast the exemplar list (a few KB)
# 3.  Workers tag every row with its nearest exemplar
# 4.  Spark counts      (node_id, branch_id, class)  →   n
# 5.  Driver decides the best split, creates children
# 6.  Repeat steps 2‑5 until all nodes are pure or max_depth reached
#
# NOTE: for brevity we use plain Euclidean distance; swap dtw.distance if you like.
#       The code runs on the same SparkSession as your previous script.
from random import sample
import json
from pyspark.sql import SparkSession
import collections

from pyspark.sql import functions as F
from pyspark.sql.types import *
import numpy as np
from random import choice
import math, json, collections, itertools


tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]},
    {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9]},
    {'label': 1, 'time_series': [1.4, 2.0, 2.6, 3.2, 3.8]},
    {'label': 2, 'time_series': [2.2, 3.1, 4.0, 4.9, 5.8]},
    {'label': 2, 'time_series': [2.6, 3.2, 3.8, 4.4, 5.0]},
    {'label': 3, 'time_series': [1.2, 2.0, 2.8, 3.6, 4.4]},
    {'label': 3, 'time_series': [0.6, 1.3, 2.0, 2.7, 3.4]},
    {'label': 4, 'time_series': [6.3, 6.5, 6.7, 6.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.8, 8.6, 9.4, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.0, 7.5, 8.0, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 1.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 1.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.6, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 2.3, 2.7]},
    {'label': 4, 'time_series': [0.9, 1.7, 1.2, 2.4, 2.8]}
]

# ---------------------------------------------------------------------------
# helper: tiny distance function (Euclidean) --------------------------------
# ---------------------------------------------------------------------------
def euclid(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

# ---------------------------------------------------------------------------
# 0.  create Spark session & DataFrame --------------------------------------
# ---------------------------------------------------------------------------
spark = (
    SparkSession.builder
    .appName("GlobalProximityTree")
    .getOrCreate()
)

tsdata = [...]        # <‑‑ paste the 34‑item list exactly as in your post
df = spark.createDataFrame(tsdata)        # cols: label, time_series

# Make sure the time_series column is an array<double>
df = df.withColumn("time_series", F.col("time_series").cast(ArrayType(DoubleType())))

# ---------------------------------------------------------------------------
# 1.  driver picks ONE exemplar per class (cheap: input is small) -----------
# ---------------------------------------------------------------------------
exemplar_rows = (
    df
    .groupBy("label")               # one group per class
    .agg(F.shuffle(F.collect_list("time_series")).alias("bag"))
    .select("label", F.expr("bag[0]").alias("time_series"))   # first random element
    .collect()
)

# Turn into a driver‑side dict {label: vector}
GLOBAL_EXEMPLARS = {row["label"]: row["time_series"] for row in exemplar_rows}
print("broadcasted exemplars:", GLOBAL_EXEMPLARS)

ex_bc = spark.sparkContext.broadcast(GLOBAL_EXEMPLARS)   # a few KB

# ---------------------------------------------------------------------------
# 2.  spark ⇢ workers: tag every row with nearest exemplar -------------------
# ---------------------------------------------------------------------------
def tag_nearest(row):
    vec   = row.time_series
    label = row.label
    best_id, best_dist = None, float("inf")
    for ex_label, ex_vec in ex_bc.value.items():
        d = euclid(vec, ex_vec)
        if d < best_dist:
            best_id, best_dist = ex_label, d
    # We start with a single node (node_id = 0)
    return (0, best_id, label)        # (node_id, branch_id == exemplarLabel, true class)

schema = StructType().add("node_id", IntegerType()) \
                     .add("branch_id", IntegerType()) \
                     .add("true_label", IntegerType())

tagged = df.rdd.map(tag_nearest).toDF(schema)
# ┌ node_id ┬ branch_id ┬ true_label ┐
# └      0  ┴     1     ┴     1     ┘  etc.

# ---------------------------------------------------------------------------
# 3.  iterate breadth‑first until tree finished ------------------------------
# ---------------------------------------------------------------------------
TreeNode  = collections.namedtuple(
    "TreeNode",
    "node_id parent_id split_on is_leaf prediction children".split()
)

tree = {0: TreeNode(0, None, None, False, None, {})}   # root placeholder
open_nodes = {0}
max_depth  = 3

for depth in range(max_depth):
    if not open_nodes:
        break

    # 3a.  count class distribution per (node_id, branch_id)
    counts = (
        tagged
        .where(F.col("node_id").isin(list(open_nodes)))
        .groupBy("node_id", "branch_id", "true_label")
        .count()
    )

    # bring tiny histogram to driver
    hist = {}
    for r in counts.collect():
        hist.setdefault(r.node_id, {}).setdefault(r.branch_id, {})[r.true_label] = r["count"]

    # 3b.  decide splits on the driver
    next_open = set()
    for nid in open_nodes:
        branches = hist.get(nid, {})
        # compute parent gini
        total_per_class = collections.Counter()
        for br in branches.values():
            total_per_class.update(br)
        n_total = sum(total_per_class.values())
        gini_parent = 1 - sum((c/n_total)**2 for c in total_per_class.values())

        # the split is already fixed: one branch per exemplar;
        # we only need to mark children as leaf or internal
        for br_id, cls_count in branches.items():
            n = sum(cls_count.values())
            if len(cls_count) == 1:          # pure → leaf
                pred = next(iter(cls_count))
                leaf = TreeNode(node_id=len(tree),
                                parent_id=nid,
                                split_on=None,
                                is_leaf=True,
                                prediction=pred,
                                children={})
                tree[leaf.node_id] = leaf
                tree[nid].children[br_id] = leaf.node_id
            else:                            # impure → internal node
                child_id = len(tree)
                twin  = TreeNode(child_id, nid, None, False, None, {})
                tree[child_id] = twin
                tree[nid].children[br_id] = child_id
                next_open.add(child_id)

        # mark the current node as decided (its split is “nearest exemplar”)
        tree[nid] = tree[nid]._replace(split_on="nearest_exemplar")

    # 3c.  update tagged DataFrame with new node_id for impure children
    if next_open:
        mapping = {old: new for nid in open_nodes
                             for br, new in tree[nid].children.items()
                             if new in next_open}

        # broadcast the dict { (parentId,branchId) : childId }
        map_bc = spark.sparkContext.broadcast(mapping)

        # update node_id column (narrow map, no shuffle)
        def push_down(r):
            key = (r.node_id, r.branch_id)
            new_nid = map_bc.value.get(key, r.node_id)
            return (new_nid, r.branch_id, r.true_label)

        tagged = tagged.rdd.map(push_down).toDF(schema)

    open_nodes = next_open    # loop
# ---------------------------------------------------------------------------
# 4.  pretty‑print the tree --------------------------------------------------
# ---------------------------------------------------------------------------
def show(nid, indent=""):
    node = tree[nid]
    if node.is_leaf:
        print(f"{indent}Leaf ⇒ predict {node.prediction}")
    else:
        print(f"{indent}Node {nid}: split = nearest exemplar")
        for br, child in node.children.items():
            print(f"{indent}  if nearest == class‑{br} exemplar →")
            show(child, indent + "      ")

print("\nFinal tree")
show(0)

spark.stop()


PySparkTypeError: [CANNOT_INFER_SCHEMA_FOR_TYPE] Can not infer schema for type: `ellipsis`.

: 

In [3]:
from pyspark.sql import SparkSession
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("GlobalProximityTree")
    .getOrCreate()
)

In [4]:
tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]},
    {'label': 1, 'time_series': [0.9, 1.4, 1.9, 2.4, 2.9]},
    {'label': 1, 'time_series': [1.4, 2.0, 2.6, 3.2, 3.8]},
    {'label': 2, 'time_series': [2.2, 3.1, 4.0, 4.9, 5.8]},
    {'label': 2, 'time_series': [2.6, 3.2, 3.8, 4.4, 5.0]},
    {'label': 3, 'time_series': [1.2, 2.0, 2.8, 3.6, 4.4]},
    {'label': 3, 'time_series': [0.6, 1.3, 2.0, 2.7, 3.4]},
    {'label': 4, 'time_series': [6.3, 6.5, 6.7, 6.9, 7.1]},
    {'label': 4, 'time_series': [7.0, 7.8, 8.6, 9.4, 10.2]},
    {'label': 4, 'time_series': [6.5, 7.0, 7.5, 8.0, 8.5]},
    {'label': 1, 'time_series': [0.5, 1.0, 1.5, 2.0, 2.5]},
    {'label': 2, 'time_series': [0.6, 1.4, 1.3, 2.1, 2.5]},
    {'label': 3, 'time_series': [0.3, 1.7, 1.6, 2.2, 2.6]},
    {'label': 4, 'time_series': [0.2, 1.9, 1.6, 2.3, 2.7]},
    {'label': 4, 'time_series': [0.9, 1.7, 1.2, 2.4, 2.8]}
]

df = spark.createDataFrame(tsdata)

In [8]:
from pyspark.sql.types import *
import numpy as np
from random import choice
import math, json, collections, itertools
from pyspark.sql import functions as F

df = df.withColumn("time_series", F.col("time_series").cast(ArrayType(DoubleType())))

In [11]:
exemplar_rows = (
    df
    .groupBy("label")               # one group per class
    .agg(F.shuffle(F.collect_list("time_series")).alias("bag"))
    .select("label", F.expr("bag[0]").alias("time_series"))   # first random element
    .collect()
)
print("broadcasted exemplars:", exemplar_rows)

broadcasted exemplars: [Row(label=1, time_series=[1.1, 2.1, 3.1, 4.1, 5.1]), Row(label=2, time_series=[2.1, 3.3, 4.5, 5.7, 6.9]), Row(label=3, time_series=[0.3, 1.7, 1.6, 2.2, 2.6]), Row(label=4, time_series=[7.0, 7.8, 8.6, 9.4, 10.2])]


In [12]:
# Turn into a driver‑side dict {label: vector}
GLOBAL_EXEMPLARS = {row["label"]: row["time_series"] for row in exemplar_rows}
print("broadcasted exemplars:", GLOBAL_EXEMPLARS)

broadcasted exemplars: {1: [1.1, 2.1, 3.1, 4.1, 5.1], 2: [2.1, 3.3, 4.5, 5.7, 6.9], 3: [0.3, 1.7, 1.6, 2.2, 2.6], 4: [7.0, 7.8, 8.6, 9.4, 10.2]}


In [23]:
ex_bc = spark.sparkContext.broadcast(GLOBAL_EXEMPLARS)
for i in ex_bc.value:
    print(i, ex_bc.value[i]) # a few KB


1 [1.1, 2.1, 3.1, 4.1, 5.1]
2 [2.1, 3.3, 4.5, 5.7, 6.9]
3 [0.3, 1.7, 1.6, 2.2, 2.6]
4 [7.0, 7.8, 8.6, 9.4, 10.2]


In [18]:
def euclid(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

In [None]:
from dtaidistance import dtw

In [None]:
# ---------------------------------------------------------------------------
# 2.  spark ⇢ workers: tag every row with nearest exemplar -------------------
# ---------------------------------------------------------------------------
def tag_nearest_euclidian(row):
    vec   = row.time_series
    label = row.label
    best_id, best_dist = None, float("inf")
    for ex_label, ex_vec in ex_bc.value.items():
        d = dtw(vec, ex_vec)
        if d < best_dist:
            best_id, best_dist = ex_label, d
    # We start with a single node (node_id = 0)
    return (0, best_id, label)        # (node_id, branch_id == exemplarLabel, true class)

In [None]:
# ---------------------------------------------------------------------------
# 2.  spark ⇢ workers: tag every row with nearest exemplar -------------------
# ---------------------------------------------------------------------------
def tag_nearest_dtw(row):
    vec   = row.time_series
    label = row.label
    best_id, best_dist = None, float("inf")
    for ex_label, ex_vec in ex_bc.value.items():
        d = euclid(vec, ex_vec)
        if d < best_dist:
            best_id, best_dist = ex_label, d
    # We start with a single node (node_id = 0)
    return (0, best_id, label)        # (node_id, branch_id == exemplarLabel, true class)