In [1]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging
from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw
import collections
from pprint import pprint
import random
import collections
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.window import Window

In [2]:
spark = SparkSession.builder.master("local[*]").appName("testingglobal").getOrCreate()
sc = spark.sparkContext

In [3]:
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType, ArrayType
import random
import collections

class GlobalProxTree:
    def __init__(self, spark, max_depth=5, min_samples=5):
        """
        Initialize the Time Series Decision Tree
        
        Parameters:
        -----------
        spark : SparkSession
            The Spark session to use
        max_depth : int
            Maximum depth of the tree
        min_samples : int
            Minimum number of samples required to split a node
        """
        self.spark = spark
        self.max_depth = max_depth
        self.min_samples = min_samples
        
        # Define the schema for tagged dataframe
        self.tagged_schema = StructType([
            StructField("node_id",     IntegerType(),           False),
            StructField("time_series", ArrayType(DoubleType()), False),
            StructField("branch_id",   IntegerType(),           False),
            StructField("true_label",  IntegerType(),           False),
            StructField("dist_calc",   DoubleType(),            False),
        ])
        
        # Define the TreeNode structure
        self.TreeNode = collections.namedtuple(
            "TreeNode",
            "node_id parent_id split_on is_leaf prediction children gini_parent".split()
        )
        
        # Initialize the tree with a root node
        self.tree = {
            0: self.TreeNode(
                node_id=0,
                parent_id=None,
                split_on=None,
                is_leaf=False,
                prediction=None,
                children={},
                gini_parent=None
            )
        }
    
    def _convert_to_time_series_format(self, df):
        """
        Convert wide dataframe (with each feature in its own column) to a dataframe 
        with a single array column containing all features
        
        Parameters:
        -----------
        df : Spark DataFrame
            Wide DataFrame with feature columns and label column
            
        Returns:
        --------
        Spark DataFrame
            DataFrame with 'time_series' and 'label' columns
        """
        # Check if 'time_series' column already exists
        if 'time_series' in df.columns:
            print("DataFrame already has 'time_series' column, no conversion needed")
            return df
            
        # Get all column names except 'label'
        feature_cols = [col for col in df.columns if col != 'label']
        
        print(f"Converting {len(feature_cols)} feature columns to 'time_series' array")
        
        # Use array() function to combine columns
        # This is more efficient than a UDF
        ts_df = df.select(
            F.array(*[F.col(c) for c in feature_cols]).alias("time_series"),
            df["label"].cast(IntegerType())  # Ensure label is an integer
        )
        
        # Show sample of converted data
        print("Sample of converted DataFrame:")
        ts_df.show(2, truncate=False)
        
        return ts_df
        
    def fit(self, df):
        """
        Fit the decision tree on the dataframe
        
        Parameters:
        -----------
        df : Spark DataFrame
            DataFrame with feature columns and 'label' column
            
        Returns:
        --------
        self : GlobalProxTree
            The fitted tree
        """
        # First, convert to time_series format if needed
        df = self._convert_to_time_series_format(df)
        
        # Initialize assignment dataframe with all rows at the root node
        assign_df = (
            df
            .withColumn("node_id", F.lit(0).cast(IntegerType()))
            .withColumn("true_label", F.col("label").cast(IntegerType()))
            .select("node_id", "time_series", "true_label")
            .cache()
        )
            
        
        open_nodes = {0}
        
        for depth in range(self.max_depth):
            # If no nodes to expand, stop
            if not open_nodes:
                break
                
            # 1. Sample local exemplars

            # collect all time_series per (node,label) as a list column
            expl_df = (
                assign_df
                .filter(F.col("node_id").isin(open_nodes))
                .groupBy("node_id","true_label")
                .agg(F.collect_list("time_series").alias("all_series"))
                .collect()
            )
            exemplars = {
                (row.node_id, row.true_label): random.choice(row.all_series)
                for row in expl_df
            }
            
            bc_exemplars = self.spark.sparkContext.broadcast(exemplars)
            # Use a Python function in map instead of a lambda to avoid serializing the class
            def tag_with_exemplars(node_id, time_series, true_label):
                exemplars = bc_exemplars.value
                best_branch, best_dist = None, float("inf")

                for (n, ex_lbl), ex_ts in exemplars.items():
                    if n != node_id:
                        continue
                    try:
                        import dtw
                        d = dtw.distance(time_series, ex_ts)
                    except ImportError:
                        d = sum((a - b) ** 2 for a, b in zip(time_series, ex_ts)) ** 0.5
                    if d < best_dist:
                        best_branch, best_dist = ex_lbl, d

                if best_branch is None:
                    # fallback
                    labels = [lbl for (n, lbl), _ in exemplars.items() if n == node_id]
                    best_branch = labels[0] if labels else 0

                return (node_id, time_series, best_branch, true_label, best_dist)
                        
            # 2. Tag every row with nearest exemplar
            tagged = (
                assign_df.rdd
                .map(lambda r: tag_with_exemplars(r.node_id, r.time_series, r.true_label))
                .toDF(self.tagged_schema)
                .cache()
            )
            
            # 3. Build histogram & decide splits
            hist = self._build_histogram(tagged, open_nodes, debug=(depth == 0))
            if not hist:
                print(f"No more splits at depth={depth}, stopping.")
                bc_ex.unpersist()
                tagged.unpersist()
                break
                
            # Print debugging information
            print(f"\n=== depth={depth}, open_nodes={open_nodes} ===")
            import pprint
            pprint.pprint(hist)
            
            # Decide splits
            next_open = set()
            for nid in open_nodes:
                if nid in hist:  # Check if the node has any data points
                    children = self._split_node_gini(
                        nid,
                        hist[nid],
                        depth,
                        self.max_depth,
                        self.min_samples
                    )
                    next_open |= children  # Union of new internal nodes
            
            # 4. Push rows down to child node_ids
            old = assign_df
            assign_df = self._push_rows_down(tagged, open_nodes, next_open)
            old.unpersist()
            
            open_nodes = next_open  # Next layer
            
            # Filter to only rows at open nodes
            assign_df = assign_df.where(F.col("node_id").isin(next_open))
            bc_exemplars.unpersist()  # Unpersist the broadcast variable
            
        return self
    
    def _build_histogram(self, tagged_df, open_nodes, debug=False):
        """
        Count rows for every (node_id, branch_id, true_label) among the open nodes.
        
        Parameters:
        -----------
        tagged_df : Spark DataFrame
            DataFrame containing tagged rows
        open_nodes : set
            Set of node IDs to process
        debug : bool
            Whether to print debug information
            
        Returns:
        --------
        dict : Nested dictionary with counts
            { node_id: { branch_id: { true_label: count, ... }, ... }, ... }
        """
        if not open_nodes:
            return {}
            
        # Filter to only the node_ids we're actually expanding
        filtered = tagged_df.where(F.col("node_id").isin(open_nodes))
        counts = (
            filtered
            .groupBy("node_id", "branch_id", "true_label")
            .count()
        )
        
        if debug:
            print(">>> FILTERED ROWS FOR THESE open_nodes:", open_nodes)
            filtered.show(2, truncate=False)
            print(">>> AGGREGATED COUNTS:")
            counts.show(5, truncate=False)
        
        # Build histogram directly from collected rows    
        hist = {}
        for r in counts.collect():
            node_id = r["node_id"]
            branch_id = r["branch_id"]
            true_label = r["true_label"]
            count = r["count"]
            
            if node_id not in hist:
                hist[node_id] = {}
            if branch_id not in hist[node_id]:
                hist[node_id][branch_id] = {}
            hist[node_id][branch_id][true_label] = count
                
        return hist
    
    def _split_node_gini(self, node_id, branches, depth, max_depth, min_samples):
        """
        Decide whether to split a node based on Gini impurity
        
        Parameters:
        -----------
        node_id : int
            ID of the node to potentially split
        branches : dict
            Dictionary of branches for this node
        depth : int
            Current depth in the tree
        max_depth : int
            Maximum allowed depth
        min_samples : int
            Minimum samples required to split
            
        Returns:
        --------
        set : Set of newly created internal nodes
        """
        new_nodes = set()
        
        # Skip if we're at max depth
        if depth >= max_depth - 1:
            self.tree[node_id] = self.tree[node_id]._replace(is_leaf=True)
            
            # Find most common class for prediction
            all_counts = {}
            for branch_id, class_counts in branches.items():
                for label, count in class_counts.items():
                    all_counts[label] = all_counts.get(label, 0) + count
            
            prediction = max(all_counts.items(), key=lambda x: x[1])[0] if all_counts else None
            self.tree[node_id] = self.tree[node_id]._replace(prediction=prediction)
            
            print(f"Node {node_id} is now a leaf with prediction {prediction}")
            return new_nodes
            
        # Create a child node for each branch
        next_node_id = max(self.tree.keys()) + 1
        
        for branch_id in branches:
            # Compute total counts for this branch
            branch_total = sum(count for label, count in branches[branch_id].items())
            
            # Skip branches with too few samples
            if branch_total < min_samples:
                print(f"Branch {branch_id} has only {branch_total} samples, skipping")
                continue
                
            # Create a new child node
            child_id = next_node_id
            next_node_id += 1
            
            # Create the child node
            self.tree[child_id] = self.TreeNode(
                node_id=child_id,
                parent_id=node_id,
                split_on=None,  # No further split yet
                is_leaf=False,  # Will be determined in next iteration
                prediction=None,  # Will be set if it becomes a leaf
                children={},
                gini_parent=None  # Calculate if needed
            )
            
            # Add child to parent's children
            self.tree[node_id].children[branch_id] = child_id
            new_nodes.add(child_id)
            
            print(f"Created child node {child_id} for branch {branch_id}")
            
        return new_nodes
    
    def _push_rows_down(self, tagged_df, old_nodes, new_nodes):
        """
        Push rows down from parent nodes to child nodes
        
        Parameters:
        -----------
        tagged_df : Spark DataFrame
            DataFrame containing tagged rows
        old_nodes : set
            Set of current node IDs
        new_nodes : set
            Set of new node IDs
            
        Returns:
        --------
        Spark DataFrame : Updated assignment DataFrame
        """
        # Create a map from (parent_id, branch_id) to child_id
        parent_branch_to_child = {}
        
        for nid in old_nodes:
            if nid in self.tree:
                for branch_id, child_id in self.tree[nid].children.items():
                    parent_branch_to_child[(nid, branch_id)] = child_id
        
        # Create a closure over the mapping
        mapping = parent_branch_to_child  # Local variable to avoid serializing self
        
        def map_to_child(node_id, branch_id):
            key = (node_id, branch_id)
            if key in mapping:
                return mapping[key]
            return node_id  # Keep at same node if no mapping
        
        # Create a UDF from the closure
        map_to_child_udf = F.udf(map_to_child, IntegerType())
        
        # Apply the mapping
        pushed_df = tagged_df.withColumn(
            "node_id",
            map_to_child_udf(F.col("node_id"), F.col("branch_id"))
        ).select("node_id", "time_series", "true_label").cache()
        
        return pushed_df
    
    def predict(self, df):
        """
        Make predictions using the trained tree
        
        Parameters:
        -----------
        df : Spark DataFrame
            DataFrame with feature columns or 'time_series' column
            
        Returns:
        --------
        Spark DataFrame : DataFrame with predictions
        """
        # First, convert to time_series format if needed
        df = self._convert_to_time_series_format(df)
        
        # Simplified predict implementation
        def traverse_tree(time_series):
            """Traverse the tree for a single time series instance"""
            node_id = 0  # Start at root
            
            while node_id in self.tree and not self.tree[node_id].is_leaf:
                # If we're at a leaf, return its prediction
                if self.tree[node_id].is_leaf:
                    return self.tree[node_id].prediction
                
                # Find distances to all exemplars for this node's children
                min_dist = float('inf')
                best_branch = None
                
                # Simplified based on the exemplar approach
                for branch_id in self.tree[node_id].children:
                    # Using branch_id directly as a proxy for distance
                    # In a full implementation, we would use actual exemplars here
                    if branch_id == time_series[0]:  # Simplified matching
                        best_branch = branch_id
                        break
                
                # Default to first child if no match
                if best_branch is None and self.tree[node_id].children:
                    best_branch = list(self.tree[node_id].children.keys())[0]
                
                # If we have a branch to follow, go there
                if best_branch is not None and best_branch in self.tree[node_id].children:
                    node_id = self.tree[node_id].children[best_branch]
                else:
                    # If we can't go further, use current node's prediction (or default)
                    break
            
            # Return prediction if node exists and is a leaf
            if node_id in self.tree and self.tree[node_id].is_leaf:
                return self.tree[node_id].prediction
            
            # Default prediction (most common class among all data)
            return 1  # Default prediction
        
        # This is a placeholder - in reality you would implement a full traversal
        print("Prediction functionality is simplified - modify for real use")
        return df.withColumn("prediction", F.lit(1))  # Return default prediction
    
    def print_tree(self):
        """
        Print a representation of the tree
        
        Returns:
        --------
        str : String representation of the tree
        """
        def print_node(node_id, depth=0):
            node = self.tree[node_id]
            indent = "  " * depth
            
            if node.is_leaf:
                return f"{indent}Node {node_id}: LEAF, prediction={node.prediction}\n"
            
            result = f"{indent}Node {node_id}: internal\n"
            
            for branch_id, child_id in sorted(node.children.items()):
                result += f"{indent}  Branch {branch_id} -> {child_id}\n"
                result += print_node(child_id, depth + 1)
                
            return result
        
        return print_node(0)  # Start at root

In [4]:
config = {
    "databricks_data_path": "/mnt/2025-team6/fulldataset_ECG5000.csv",
    "local_data_path": "/fulldataset_ECG5000.csv",
    "label_col": "label",
    "data_percentage": 1.0,
    "min_number_iterarations": 2,

    "local_model_config": {
        "test_local_model" : True,
        "num_partitions": 10,  
        "tree_params": {
            "n_splitters": 5,  # Matches ProximityTree default
            "max_depth": None,  
            "min_samples_split": 5,  # From ProximityTree default
            "random_state": 123
            },
        "forest_params": {
            "random_state": 123,
            "n_jobs": -1  # Use all available cores
            }
    },
    "global_model_config": {
        "test_local_model" : False,
        "num_partitions": 10
    }
}

ingestion_config = {
                "data_path":r"D:\repos\BigData-main\BigData-1\fulldataset_ECG5000.csv",
                "data_percentage": config.get("data_percentage", 0.5)
}

In [5]:
ingestion = DataIngestion(spark=spark, config=ingestion_config)
preprocessor = Preprocessor(config=config)

In [6]:
# load + preprocess data
df = ingestion.load_data()
df = preprocessor.run_preprocessing(df)

Data Path: D:\repos\BigData-main\BigData-1\fulldataset_ECG5000.csv
Loading 100.0% of data
Data size: 5000

Repartitioning to 10 workers


In [7]:
# The original data is in wide format
print("Original DataFrame:")

print(f"DataFrame shape: {len(df.columns)} columns, {df.count()} rows")

# Limit to fewer rows for quicker testing
test_df = df.limit(50)
print(f"Test DataFrame shape: {len(test_df.columns)} columns, {test_df.count()} rows")

Original DataFrame:
DataFrame shape: 141 columns, 5000 rows
Test DataFrame shape: 141 columns, 50 rows


In [8]:
# from pyspark.sql.window import Window
# from pyspark.sql.functions import row_number, rand, col

# # ─── balance to 50 samples per class ─────────────────────────
# w = Window.partitionBy("label").orderBy(rand())
# df = (
#     df
#     .withColumn("rn", row_number().over(w))
#     .filter(col("rn") <= 50)
#     .drop("rn")
# )
# # now df has up to 50 rows for each label

# # ─── inspect! ────────────────────────────────────────────────
# print("Balanced DataFrame:")
# df.printSchema()
# print(f"DataFrame shape: {len(df.columns)} columns, {df.count()} rows")

In [9]:
# ─── split into train/test ──────────────────────────────────
train_df, test_df = df.randomSplit([0.8, 0.2], seed=1234)

In [10]:
def randomSplit_stratified_via_sampleBy(df, label_col, weights=[0.8, 0.2], seed=123):
    
    """
    Splits a Spark DataFrame into train/test sets based on partition-Preserves per‑class proportions

    """
    
    assert abs(sum(weights) - 1.0) < 1e-6 # ensure that our weights must sum to 1.0
    train_frac = weights[0]

    # figure out all the distinct label values 
    labels = [row[label_col] for row in df.select(label_col) 
                                            .distinct()  # build a tiny DataFrame of unique labels
                                            .collect() # brings the list to the driver
                                            ]

    # build a dict: each label -> same fraction
    fractions = {dict_lbl: train_frac for dict_lbl in labels}

    # sample train set: Use Spark’s native stratified sampler
    train_df = df.stat.sampleBy(label_col, fractions, seed) # map‑side sampling per key, jno shuffle
    # everything else is test
    test_df  = df.join(train_df, on=df.columns, how="left_anti") # one shuffles to get the rest of the data
    return train_df, test_df 


train_df, test_df = randomSplit_stratified_via_sampleBy(df, label_col = "label", weights=[0.8, 0.2], seed=123)

In [11]:
# Example usage of the updated GlobalProxTree with the sample data

# First, create the tree with the desired parameters
tree = GlobalProxTree(spark, max_depth=15, min_samples=3)

# The original data is in wide format
print("Original DataFrame:")
#df.printSchema()  # Show the schema to confirm structure
print(f"DataFrame shape: {len(train_df.columns)} columns, {df.count()} rows")



# Now we can directly fit the tree on the wide DataFrame
# The conversion will happen automatically inside the fit method
try:
    print("\nFitting tree on wide DataFrame...")
    tree.fit(train_df)
    
    print("\nTree structure:")
    print(tree.print_tree())
except Exception as e:
    print(f"Error while fitting tree: {e}")
    import traceback
    traceback.print_exc()

print("\nTree fitting complete!")



Original DataFrame:
DataFrame shape: 141 columns, 5000 rows

Fitting tree on wide DataFrame...
Converting 140 feature columns to 'time_series' array
Sample of converted DataFrame:
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [12]:
for node_id, node in tree.tree.items():
    print(f"Node {node_id}: is_leaf={node.is_leaf}, prediction={node.prediction}, children={node.children}")


Node 0: is_leaf=False, prediction=None, children={1: 1, 2: 2, 4: 3, 5: 4, 3: 5}
Node 1: is_leaf=False, prediction=None, children={5: 6, 4: 7, 1: 8}
Node 2: is_leaf=False, prediction=None, children={3: 9, 2: 10, 4: 11, 5: 12}
Node 3: is_leaf=False, prediction=None, children={1: 13, 4: 14, 2: 15, 5: 16, 3: 17}
Node 4: is_leaf=False, prediction=None, children={1: 18, 2: 19, 5: 20, 4: 21, 3: 22}
Node 5: is_leaf=False, prediction=None, children={4: 23, 5: 24, 2: 25, 3: 26}
Node 6: is_leaf=False, prediction=None, children={5: 27, 3: 28, 1: 29}
Node 7: is_leaf=False, prediction=None, children={1: 30}
Node 8: is_leaf=False, prediction=None, children={1: 31}
Node 9: is_leaf=False, prediction=None, children={2: 32, 4: 33, 3: 34}
Node 10: is_leaf=False, prediction=None, children={4: 35, 2: 36, 1: 37, 5: 38, 3: 39}
Node 11: is_leaf=False, prediction=None, children={4: 40, 2: 41, 1: 42, 3: 43}
Node 12: is_leaf=False, prediction=None, children={}
Node 13: is_leaf=False, prediction=None, children={2:

In [13]:


from pyspark.ml.evaluation import MulticlassClassificationEvaluator
test_df = test_df

from pyspark.sql.types import DoubleType

# 2) Run predict() on your held‑out test_df (must have a "label" column):
pred_df = tree.predict(test_df)

# Cast both label and prediction to Double so the evaluator is happy
pred_df = (
    pred_df
    .withColumn("prediction", F.col("prediction").cast(DoubleType()))
    .withColumn("label",      F.col("label")     .cast(DoubleType()))
)

# Show a few predictions vs. truth
pred_df.select("label", "prediction").show(10)

# 3a) Manual accuracy via a “correct” flag + avg:
acc_manual = (
    pred_df
    .withColumn("correct", F.when(F.col("prediction") == F.col("label"), 1).otherwise(0))
    .agg(F.avg("correct").alias("accuracy"))
    .collect()[0]["accuracy"]
)
print(f"Manual accuracy = {acc_manual:.3f}")

# 3b) Or use the built‑in evaluator:
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy"
)
acc_evaluator = evaluator.evaluate(pred_df)
print(f"Evaluator accuracy = {acc_evaluator:.3f}")


Converting 140 feature columns to 'time_series' array
Sample of converted DataFrame:
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [14]:
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType

def convert_using_udf(df):
    """Alternative method using a UDF"""
    
    feature_cols = [col for col in df.columns if col != 'label']
    
    # Define a UDF to combine values into an array
    @F.udf(returnType=ArrayType(DoubleType()))
    def combine_features(*cols):
        return [float(c) if c is not None else 0.0 for c in cols]
    
    # Apply the UDF to create time_series column
    udf_ts_df = df.select(
        combine_features(*feature_cols).alias("time_series"),
        "label"
    )
    
    return udf_ts_df

# Convert using the UDF method
udf_ts_df = convert_using_udf(df)

print("\nConverted DataFrame (UDF method):")
udf_ts_df.show(5, truncate=False)


Converted DataFrame (UDF method):
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------