In [67]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging

from pyspark.sql import SparkSession
import os
from pyspark.sql import SparkSession
from data_ingestion import DataIngestion
from preprocessing import Preprocessor
from prediction_manager import PredictionManager
from local_model_manager import LocalModelManager
from evaluation import Evaluator
from utilities import show_compact
import time
import json
from random import sample
from dtaidistance import dtw

## ---


In [68]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("GenericRDD").getOrCreate()

# Access the SparkContext
sc = spark.sparkContext

# ---

In [69]:
data = [
    {"label": 1, "time_series": [1.0, 2.1, 3.2, 4.3, 5.4]},
    {"label": 2, "time_series": [2.0, 3.1, 4.2, 5.3, 6.4]},
    {"label": 3, "time_series": [3.0, 4.1, 5.2, 6.3, 7.4]},
    {"label": 4, "time_series": [4.0, 5.1, 6.2, 7.3, 8.4]},
    {"label": 1, "time_series": [1.5, 2.6, 3.7, 4.8, 5.9]},
    {"label": 2, "time_series": [2.5, 3.6, 4.7, 5.8, 6.9]},
    {"label": 3, "time_series": [3.5, 4.6, 5.7, 6.8, 7.9]},
    {"label": 4, "time_series": [4.5, 5.6, 6.7, 7.8, 8.9]}
]

rdd = sc.parallelize(data)

In [70]:
rdd = rdd.repartition(2)
rdd.getNumPartitions()

2

In [71]:
def print_partition_rows(index, iterator):
    # Add partition index to each row
    return [(index, row) for row in iterator]

# Use mapPartitionsWithIndex to include partition index
partitioned_rdd = rdd.mapPartitionsWithIndex(print_partition_rows)

# Collect and print the rows along with their partition index
for partition_index, row in partitioned_rdd.collect():
    print(f"Partition {partition_index}: {row}")

Partition 0: {'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4]}
Partition 0: {'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4]}
Partition 0: {'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9]}
Partition 0: {'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9]}
Partition 0: {'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9]}
Partition 1: {'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4]}
Partition 1: {'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4]}
Partition 1: {'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9]}


# adding exemplar column

In [72]:
def sample_and_add_column(iterator):
    partition_data = list(iterator)
    sampled_element = sample(partition_data, 1)[0]['time_series']
    return iter([{**row, "exemplar": sampled_element} for row in partition_data])

rdd_with_sampled_column = rdd.mapPartitions(sample_and_add_column)

# Collect and print the updated RDD
for row in rdd_with_sampled_column.collect():
    print(row)

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9]}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9]}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9]}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9]}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9]}
{'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4], 'exemplar': [2.0, 3.1, 4.2, 5.3, 6.4]}
{'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4], 'exemplar': [2.0, 3.1, 4.2, 5.3, 6.4]}
{'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9], 'exemplar': [2.0, 3.1, 4.2, 5.3, 6.4]}


# calculating DTW distance using time series and exemplar columns

In [73]:
# def calc_dtw_distance(iterator):
#     partition_data = list(iterator)
#     time_series = partition_data['time_series']
#     exemplar = partition_data['exemplar']
#     dtw_distance = dtw.distance(time_series, exemplar)
#     return iter([{**row, "dtw_distance": dtw_distance} for row in partition_data])

def calc_dtw_distance(iterator):
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        time_series = row['time_series']
        exemplar = row['exemplar']
        
        dtw_distance = dtw.distance(time_series, exemplar)
        
        updated_row = {**row, "dtw_distance": dtw_distance}
        updated_rows.append(updated_row)
    return iter(updated_rows)

rdd_with_dtw = rdd_with_sampled_column.mapPartitions(calc_dtw_distance)
for row in rdd_with_dtw.collect():
    print(row)

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9], 'dtw_distance': 1.1180339887498947}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9], 'dtw_distance': 2.2671568097509267}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9], 'dtw_distance': 1.42828568570857}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9], 'dtw_distance': 0.0}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplar': [2.5, 3.6, 4.7, 5.8, 6.9], 'dtw_distance': 1.42828568570857}
{'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4], 'exemplar': [2.0, 3.1, 4.2, 5.3, 6.4], 'dtw_distance': 1.42828568570857}
{'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4], 'exemplar': [2.0, 3.1, 4.2, 5.3, 6.4], 'dtw_distance': 0.0}
{'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9], 'exemplar': [2.0, 3.1, 4.2, 5.3, 6.4], 'dtw_distance': 4.08533964316309

---

# WORKS FOR ANY NUM OF PARTITIONS AND EXEMPLARS

In [74]:
def create_sample_and_add_column_function(num_exemplars):
    def sample_and_add_column(iterator):
        partition_data = list(iterator)
        exemplars = []
        for row in sample(partition_data, min(num_exemplars, len(partition_data))):
            exemplars.append(row['time_series'])
        return iter([{**row, "exemplars": exemplars} for row in partition_data])
    return sample_and_add_column

# Example usage
num_exemplars = 2

chosen_exemplars = create_sample_and_add_column_function(num_exemplars)
rdd_with_exemplar_column = rdd.mapPartitions(chosen_exemplars)

for row in rdd_with_exemplar_column.collect():
    print(row)

print(f'\nrdd num partitions: {rdd_with_exemplar_column.getNumPartitions()}')

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]]}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]]}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]]}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]]}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]]}
{'label': 1, 'time_series': [1.0, 2.1, 3.2, 4.3, 5.4], 'exemplars': [[2.0, 3.1, 4.2, 5.3, 6.4], [1.0, 2.1, 3.2, 4.3, 5.4]]}
{'label': 2, 'time_series': [2.0, 3.1, 4.2, 5.3, 6.4], 'exemplars': [[2.0, 3.1, 4.2, 5.3, 6.4], [1.0, 2.1, 3.2, 4.3, 5.4]]}
{'label': 4, 'time_series': [4.5, 5.6, 6.7, 7.8, 8.9], 'exemplars': [[2.0, 3.1, 4.2, 5.3, 6.4], [1.0, 2.1, 3.2, 4.3, 5.4]]}

rdd num

In [75]:
def calc_dtw_distance(iterator):
    partition_data = list(iterator)
    updated_rows = []
    
    for row in partition_data:
        time_series = row['time_series']
        exemplars = row['exemplars']
        
        # Calculate DTW distances for each exemplar
        dtw_distances = [dtw.distance(time_series, exemplar) for exemplar in exemplars]
        
        # Add each DTW distance as a separate column
        updated_row = {**row}
        for i, dtw_distance in enumerate(dtw_distances):
            updated_row[f"dtw_distance_exemplar_{i+1}"] = dtw_distance
        
        updated_rows.append(updated_row)
    
    return iter(updated_rows)

# Example usage
rdd_with_dtw = rdd_with_exemplar_column.mapPartitions(calc_dtw_distance)
for row in rdd_with_dtw.collect():
    print(row)

print(f'\nrdd num partitions: {rdd_with_dtw.getNumPartitions()}')

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 1.1180339887498947, 'dtw_distance_exemplar_2': 1.42828568570857}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 2.2671568097509267, 'dtw_distance_exemplar_2': 0.0}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 1.42828568570857, 'dtw_distance_exemplar_2': 4.085339643163099}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 0.0, 'dtw_distance_exemplar_2': 2.2671568097509267}
{'label': 3, 'time_series': [3.5, 4.6, 5.7, 6.8, 7.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 1.4282856857

In [76]:
# not sure if this is needed

def assign_closest_exemplar(iterator):
    partition_data = list(iterator)

    for row in partition_data:
        # Check if there are DTW distances for exemplars
        exemplar_distances = {key: value for key, value in row.items() if key.startswith("dtw_distance_exemplar_")}
        
        if exemplar_distances:
            # Find the exemplar with the smallest DTW distance
            closest_exemplar = min(exemplar_distances, key=exemplar_distances.get)
            
            # Assign the closest exemplar to the row
            row["closest exemplar"] = closest_exemplar

    return iter(partition_data)

# Example usage
rdd_with_classification = rdd_with_dtw.mapPartitions(assign_closest_exemplar)
for row in rdd_with_classification.collect():
    print(row)

{'label': 3, 'time_series': [3.0, 4.1, 5.2, 6.3, 7.4], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 1.1180339887498947, 'dtw_distance_exemplar_2': 1.42828568570857, 'closest exemplar': 'dtw_distance_exemplar_1'}
{'label': 4, 'time_series': [4.0, 5.1, 6.2, 7.3, 8.4], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 2.2671568097509267, 'dtw_distance_exemplar_2': 0.0, 'closest exemplar': 'dtw_distance_exemplar_2'}
{'label': 1, 'time_series': [1.5, 2.6, 3.7, 4.8, 5.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 1.42828568570857, 'dtw_distance_exemplar_2': 4.085339643163099, 'closest exemplar': 'dtw_distance_exemplar_1'}
{'label': 2, 'time_series': [2.5, 3.6, 4.7, 5.8, 6.9], 'exemplars': [[2.5, 3.6, 4.7, 5.8, 6.9], [4.0, 5.1, 6.2, 7.3, 8.4]], 'dtw_distance_exemplar_1': 0.0, 'dtw_distance_exemplar_2': 2.2671568097509267, 'closest exemplar': 'd

In [77]:
def calculate_partition_gini(iterator):
    labels = [row['label'] for row in iterator]

    label_counts_dict = {}
    for label in labels:
        if label in label_counts_dict:
            label_counts_dict[label] += 1
        else:
            label_counts_dict[label] = 1
    
    total = sum(label_counts_dict.values())
    proportion_sqrd_values = [(count / total) ** 2 for count in label_counts_dict.values()]
    gini_impurity = 1 - sum(proportion_sqrd_values)
    
    return iter([gini_impurity])

In [78]:
# Example usage
gini_rdd = rdd_with_classification.mapPartitions(calculate_partition_gini)

# Collect and print the Gini impurity for each partition
i=0
for gini in gini_rdd.collect():
    print(f'gini impurity of partition {i+1}: {gini}')
    i+=1

gini impurity of partition 1: 0.72
gini impurity of partition 2: 0.6666666666666667


### trying splitting code

In [79]:
# tsdata = [
#     {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0], 'closest_exemplar': 'exemplar_1'},
#     {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9], 'closest_exemplar': 'exemplar_2'},
#     {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5], 'closest_exemplar': 'exemplar_1'},
#     {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2], 'closest_exemplar': 'exemplar_2'},
#     {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'closest_exemplar': 'exemplar_1'},
#     {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9], 'closest_exemplar': 'exemplar_2'},
#     {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0], 'closest_exemplar': 'exemplar_1'},
#     {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5], 'closest_exemplar': 'exemplar_2'},
#     {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5], 'closest_exemplar': 'exemplar_1'},
#     {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5], 'closest_exemplar': 'exemplar_2'}
# ]

# ts_rdd = sc.parallelize(tsdata)
# ts_rdd = ts_rdd.repartition(2)

# print(f'ts_rdd num partitions: {ts_rdd.getNumPartitions()}')

# ts_rdd_gini = ts_rdd.mapPartitions(calculate_partition_gini)
# # Collect and print the Gini impurity for each partition
# i=0
# for gini in ts_rdd_gini.collect():
#     print(f'gini impurity of partition {i+1}: {gini}')
#     i+=1

In [80]:
# closestto1_yes = [row['label'] for row in tsdata if row['closest_exemplar'] == 'exemplar_1']
# closestto1_no = [row['label'] for row in tsdata if row['closest_exemplar'] != 'exemplar_1']

# print(closestto1_yes)
# print(closestto1_no)

In [84]:
def calculate_gini(labels):
    label_counts = {}
    for label in labels:
        label_counts[label] = label_counts.get(label, 0) + 1
    total = sum(label_counts.values())
    gini = 1 - sum((count / total) ** 2 for count in label_counts.values())
    return gini

def evaluate_splits_within_partition(iterator):
    partition_data = list(iterator)
    
    # If the partition is empty, return an empty iterator
    if not partition_data:
        return iter([])
    
    # Get all unique exemplar names in the partition
    unique_exemplars = set(row['closest exemplar'] for row in partition_data)
    
    results = []
    
    # Loop through each exemplar to evaluate splits
    for exemplar_name in unique_exemplars:
        # Split the data based on the current exemplar
        yes_split = [row for row in partition_data if row['closest exemplar'] == exemplar_name]
        no_split = [row for row in partition_data if row['closest exemplar'] != exemplar_name]
        
        # Calculate metrics for the split (e.g., Gini impurity)
        yes_labels = [row['label'] for row in yes_split]
        no_labels = [row['label'] for row in no_split]
        
        yes_gini = calculate_gini(yes_labels)
        no_gini = calculate_gini(no_labels)
        
        # Store the results for this split
        results.append({
            'exemplar': exemplar_name,
            'yes_gini': yes_gini,
            'no_gini': no_gini,
            'yes_split_size': len(yes_split),
            'no_split_size': len(no_split)
        })
    
    # Return the results as an iterator
    return iter(results)

In [85]:
# example usage
ts_rdd_split_results = rdd_with_classification.mapPartitions(evaluate_splits_within_partition)
# Collect and print the results
for result in ts_rdd_split_results.collect():
    print(result)

{'exemplar': 'dtw_distance_exemplar_1', 'yes_gini': 0.6666666666666667, 'no_gini': 0.5, 'yes_split_size': 3, 'no_split_size': 2}
{'exemplar': 'dtw_distance_exemplar_2', 'yes_gini': 0.5, 'no_gini': 0.6666666666666667, 'yes_split_size': 2, 'no_split_size': 3}
{'exemplar': 'dtw_distance_exemplar_1', 'yes_gini': 0.5, 'no_gini': 0.0, 'yes_split_size': 2, 'no_split_size': 1}
{'exemplar': 'dtw_distance_exemplar_2', 'yes_gini': 0.0, 'no_gini': 0.5, 'yes_split_size': 1, 'no_split_size': 2}


In [36]:
# def split_within_partition(iterator):
#     partition_data = list(iterator)
#     exemplar_name = partition_data[0]['closest_exemplar'] # randomly chosen exemplar name
    
#     # Split the data within the partition
#     yes_split = [row for row in partition_data if row['closest_exemplar'] == exemplar_name]
#     no_split = [row for row in partition_data if row['closest_exemplar'] != exemplar_name]
    
#     # Add a flag to indicate which split the row belongs to
#     yes_split = [{**row, 'split': 'yes'} for row in yes_split]
#     no_split = [{**row, 'split': 'no'} for row in no_split]
    
#     # Combine the splits and return as an iterator
#     return iter(yes_split + no_split)

In [37]:
# # example usage
# split_rdd = ts_rdd.mapPartitions(split_within_partition)

# # Collect and print the results
# for row in split_rdd.collect():
#     print(row)

{'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5], 'closest_exemplar': 'exemplar_1', 'split': 'yes'}
{'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2], 'closest_exemplar': 'exemplar_1', 'split': 'yes'}
{'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0], 'closest_exemplar': 'exemplar_1', 'split': 'yes'}
{'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2], 'closest_exemplar': 'exemplar_2', 'split': 'no'}
{'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9], 'closest_exemplar': 'exemplar_2', 'split': 'no'}
{'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5], 'closest_exemplar': 'exemplar_2', 'split': 'no'}
{'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0], 'closest_exemplar': 'exemplar_1', 'split': 'yes'}
{'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5], 'closest_exemplar': 'exemplar_1', 'split': 'yes'}
{'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9], 'closest_exemplar': 'exemplar_2', 'split': 'no'}
{'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5], 'clo

###running petrus stuff real quick

In [None]:
import pickle
import pandas as pd
import numpy as np
from pyspark.sql import DataFrame
from aeon.classification.distance_based import ProximityTree, ProximityForest
import logging
from random import sample
from dtaidistance import dtw

class GlobalModelManager:
    def __init__(self):
        self.num_exemplars = 3
        self.num_partitions = 2

    def train(self, df: DataFrame):
        rdd = self.partition_data(df)
        rdd = rdd.repartition(self.num_partitions)

        choose_exemplars = self.choose_exemplars_function(self.num_exemplars)
        rdd_with_exemplar_column = rdd.mapPartitions(choose_exemplars)

        rdd_with_dtw = rdd_with_exemplar_column.mapPartitions(self.calc_dtw_distance)

        rdd_with_closest_exemplar = rdd_with_dtw.mapPartitions(self.assign_closest_exemplar)

        rdd_with_gini = rdd_with_closest_exemplar.mapPartitions(self.calculate_partition_gini) # gini impurity before splitting

        rdd_splits = rdd_with_closest_exemplar.mapPartitions(self.evaluate_splits_within_partition)

        return rdd_splits.collect()

    def partition_data(self, df: DataFrame) -> DataFrame:
        rdd = df.rdd
        repartitioned_rdd = rdd.repartition(self.num_partitions)
        return repartitioned_rdd

    def choose_exemplars_function(self, num_exemplars):
        def choose_exemplars(iterator):
            partition_data = list(iterator)
            exemplars = []
            for row in sample(partition_data, min(num_exemplars, len(partition_data))):
                exemplars.append(row['time_series'])
            return iter([{**row.asDict(), "exemplars": exemplars} for row in partition_data])
        return choose_exemplars
    
    def calc_dtw_distance(self, iterator):
        partition_data = list(iterator)
        updated_rows = []
        
        for row in partition_data:
            time_series = row['time_series']
            exemplars = row['exemplars']
            
            dtw_distances = [dtw.distance(time_series, exemplar) for exemplar in exemplars]
            
            updated_row = {**row}

            for i, dtw_distance in enumerate(dtw_distances):
                updated_row[f"dtw_distance_exemplar_{i+1}"] = dtw_distance
            
            updated_rows.append(updated_row)
        
        return iter(updated_rows)
    
    def assign_closest_exemplar(self, iterator):
        partition_data = list(iterator)

        for row in partition_data:
            # Check if there are DTW distances for exemplars
            exemplar_distances = {key: value for key, value in row.items() if key.startswith("dtw_distance_exemplar_")}
            
            if exemplar_distances:
                # Find the exemplar with the smallest DTW distance
                closest_exemplar = min(exemplar_distances, key=exemplar_distances.get)
                
                # Assign the closest exemplar to the row
                row["closest exemplar"] = closest_exemplar

        return iter(partition_data)
    
    def calculate_partition_gini(self, iterator):
        partition_data = list(iterator)
        labels = [row['label'] for row in partition_data]

        # Calculate Gini impurity for the partition
        label_counts_dict = {}
        for label in labels:
            label_counts_dict[label] = label_counts_dict.get(label, 0) + 1

        total = sum(label_counts_dict.values())
        proportion_sqrd_values = [(count / total) ** 2 for count in label_counts_dict.values()]
        gini_impurity = 1 - sum(proportion_sqrd_values)

        # Add Gini impurity to each row in the partition
        updated_rows = []
        for row in partition_data:
            updated_row = {**row, "partition_gini": gini_impurity}
            updated_rows.append(updated_row)

        return iter(updated_rows)
    
    def calculate_gini(self, labels):
        label_counts = {}
        for label in labels:
            label_counts[label] = label_counts.get(label, 0) + 1
        total = sum(label_counts.values())
        gini = 1 - sum((count / total) ** 2 for count in label_counts.values())
        return gini

    def evaluate_splits_within_partition(self, iterator):
        partition_data = list(iterator)
        
        # If the partition is empty, return an empty iterator
        if not partition_data:
            return iter([])
        
        partition_gini = self.calculate_partition_gini(partition_data)
        
        # Get all unique exemplar names in the partition
        unique_exemplars = set(row['closest exemplar'] for row in partition_data)
        
        updated_rows = []
        
        # Loop through each exemplar to evaluate splits
        for exemplar_name in unique_exemplars:
            # Split the data based on the current exemplar
            yes_split = [row for row in partition_data if row['closest exemplar'] == exemplar_name]
            no_split = [row for row in partition_data if row['closest exemplar'] != exemplar_name]
            
            # Calculate metrics for the split (e.g., Gini impurity)
            yes_labels = [row['label'] for row in yes_split]
            no_labels = [row['label'] for row in no_split]
            
            yes_gini = self.calculate_gini(yes_labels)
            no_gini = self.calculate_gini(no_labels)
            
            # Add the split information to each row in the partition
            for row in partition_data:
                updated_row = {
                    **row,
                    "before_split_partition_gini": partition_gini,
                    "split_exemplar": exemplar_name,
                    "yes_gini": yes_gini,
                    "no_gini": no_gini,
                    "yes_split_size": len(yes_split),
                    "no_split_size": len(no_split)
                }
                updated_rows.append(updated_row)
    
        # Return the updated rows as an iterator
        return iter(updated_rows)

In [97]:
tsdata = [
    {'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0]},
    {'label': 1, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 1, 'time_series': [0.9, 1.8, 2.7, 3.6, 4.5]},
    {'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9]},
    {'label': 1, 'time_series': [0.8, 1.7, 2.5, 3.2, 4.0]},
    {'label': 2, 'time_series': [2.1, 3.3, 4.5, 5.7, 6.9]},
    {'label': 2, 'time_series': [3.0, 3.8, 4.6, 5.4, 6.2]},
    {'label': 2, 'time_series': [3.3, 4.1, 4.9, 5.7, 6.5]},
    {'label': 3, 'time_series': [0.5, 1.5, 2.5, 3.5, 4.5]},
    {'label': 3, 'time_series': [2.0, 2.5, 3.0, 3.5, 4.0]},
    {'label': 4, 'time_series': [5.5, 6.6, 7.7, 8.8, 9.9]},
    {'label': 4, 'time_series': [6.1, 6.2, 6.3, 6.4, 6.5]},
    {'label': 1, 'time_series': [0.7, 1.3, 1.9, 2.5, 3.1]},
    {'label': 1, 'time_series': [1.1, 2.1, 3.1, 4.1, 5.1]},
    {'label': 1, 'time_series': [0.6, 1.2, 1.8, 2.4, 3.0]},
    {'label': 2, 'time_series': [2.4, 3.5, 4.6, 5.7, 6.8]},
    {'label': 2, 'time_series': [1.9, 2.8, 3.7, 4.6, 5.5]},
    {'label': 3, 'time_series': [1.0, 1.8, 2.6, 3.4, 4.2]},
    {'label': 4, 'time_series': [6.0, 7.0, 8.0, 9.0, 10.0]},
    {'label': 1, 'time_series': [1.3, 2.3, 3.3, 4.3, 5.3]}
]

df = spark.createDataFrame(tsdata)
df = df.repartition(2)

In [64]:
df.show(5)

+-----+--------------------+
|label|         time_series|
+-----+--------------------+
|    1|[1.0, 1.8, 2.6, 3...|
|    1|[0.9, 1.8, 2.7, 3...|
|    2|[2.1, 3.3, 4.5, 5...|
|    3|[2.0, 2.5, 3.0, 3...|
|    2|[3.3, 4.1, 4.9, 5...|
+-----+--------------------+
only showing top 5 rows



In [102]:
global_model = GlobalModelManager()
results = global_model.train(df)
for result in results:
    print(result)

{'label': 1, 'time_series': [1.2, 2.4, 3.6, 4.8, 6.0], 'exemplars': [[0.7, 1.3, 1.9, 2.5, 3.1], [0.8, 1.7, 2.5, 3.2, 4.0], [1.5, 2.1, 2.7, 3.3, 3.9]], 'dtw_distance_exemplar_1': 3.4741905532080417, 'dtw_distance_exemplar_2': 2.2847319317591723, 'dtw_distance_exemplar_3': 2.3622023622035435, 'closest exemplar': 'dtw_distance_exemplar_2', 'before_split_partition_gini': <list_iterator object at 0x000001ECAC73C2B0>, 'split_exemplar': 'dtw_distance_exemplar_1', 'yes_gini': 0.0, 'no_gini': 0.6666666666666667, 'yes_split_size': 1, 'no_split_size': 9}
{'label': 1, 'time_series': [1.5, 2.1, 2.7, 3.3, 3.9], 'exemplars': [[0.7, 1.3, 1.9, 2.5, 3.1], [0.8, 1.7, 2.5, 3.2, 4.0], [1.5, 2.1, 2.7, 3.3, 3.9]], 'dtw_distance_exemplar_1': 1.2, 'dtw_distance_exemplar_2': 0.8426149773176358, 'dtw_distance_exemplar_3': 0.0, 'closest exemplar': 'dtw_distance_exemplar_3', 'before_split_partition_gini': <list_iterator object at 0x000001ECAD34EAD0>, 'split_exemplar': 'dtw_distance_exemplar_1', 'yes_gini': 0.0, 'n