In [1]:
from src.recurrence.protree.data.stream_generators import Sine, Plane, RandomTree
from src.recurrence.protree.data.river_generators import Sea, Rbf, Stagger, Mixed
import random
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from river import drift as river_drift
from sklearn.neural_network import MLPClassifier
from src.DDM.BinaryErrorDriftDescriptor import BinaryErrorDriftDescriptor

In [2]:
import plotly.io as pio
pio.renderers.default = "browser"

In [3]:
num_windows = 60
window_length = 1000
ds = Mixed(drift_position=[20_000, 30_000, 45_000], seed = random.randint(1,10000))

# Generate all windows
all_x = []
all_y = []

for i in range(num_windows):
    x_block, y_block = zip(*ds.take(window_length))
    all_x.extend(x_block)
    all_y.extend(y_block)

# Convert to DataFrame and Series
# The x_block items are dictionaries with feature names as keys
X = pd.DataFrame(all_x)
y = pd.Series(all_y, name='target')

X_np = X.to_numpy() if hasattr(X, "to_numpy") else X
y_np = y.to_numpy() if hasattr(y, "to_numpy") else y

In [4]:
def generate_error_stream(X_np, y_np, model):
    """
    Runs test-then-train once and returns the error stream and predictions.
    """
    classes = np.unique(y_np)

    error_stream = []
    predictions = []

    # Initialize model
    model.partial_fit(X_np[0].reshape(1, -1), [y_np[0]], classes=classes)
    error_stream.append(1)  # Add the first sample as if it was an error

    for i in range(1, len(X_np)):
        x_i = X_np[i].reshape(1, -1)
        y_true = y_np[i]

        # Predict
        y_pred = model.predict(x_i)[0]
        predictions.append(y_pred)

        # Binary error
        error_stream.append(int(y_pred != y_true))

        # Learn
        model.partial_fit(x_i, [y_true])


    return error_stream, predictions


def run_drift_detection(
    error_stream,
    detector_type,
    warning_grace_period,
    rate_calculation_sample_size,
    lookback_method='gradient',
    lookforward_method='peak',
    confidence_level=None
):
    # Create detector
    if detector_type == "DDM":
        detector = river_drift.binary.DDM()
    elif detector_type == "EDDM":
        detector = river_drift.binary.EDDM()
    elif detector_type == "FHDDM":
        detector = river_drift.binary.FHDDM(confidence_level=confidence_level)
    elif detector_type == "HDDM_A":
        detector = river_drift.binary.HDDM_A()
    elif detector_type == "HDDM_W":
        detector = river_drift.binary.HDDM_W()

    drift_descriptor = BinaryErrorDriftDescriptor(
        warning_grace_period=warning_grace_period,
        rate_calculation_sample_size=rate_calculation_sample_size,
        ddm=detector,
        lookback_method=lookback_method,
        lookforward_method=lookforward_method
    )

    drift_descriptions = []

    for i, error in enumerate(error_stream):
        drift_descriptor.update(error)

        if drift_descriptor.drift_detected:
            drift = drift_descriptor.last_detected_drift
            drift.detected_at = i
            drift_descriptions.append(drift)

    # Post-process to find actual drift ends
    drift_descriptions = drift_descriptor.post_process_drift_ends(drift_descriptions)

    return drift_descriptions

In [5]:
def visualization_and_correction(drift_descriptions, error_rate, rate_calculation_sample_size, visualize = True):
    corrected_descriptions = []

    if visualize:
        fig = go.Figure()

        # Add error rate line
        fig.add_trace(go.Scatter(
            x=list(range(rate_calculation_sample_size,
                            rate_calculation_sample_size + len(error_rate))),
            y=error_rate,
            mode='lines',
            name='Error Rate',
            line=dict(color='royalblue', width=2),
            hovertemplate='Index: %{x}<br>Error Rate: %{y:.3f}<extra></extra>'
        ))

    # Add drift annotations using ACTUAL drift start and end
    for idx, drift in enumerate(drift_descriptions):
        # Use the actual drift start index
        if hasattr(drift, 'drift_start_index') and drift.drift_start_index is not None:
            start = drift.drift_start_index
        else:
            start = max(0, drift.detected_at - drift.drift_duration)

        # Use the actual drift end index (peak or recovery point)
        if hasattr(drift, 'drift_end_index') and drift.drift_end_index is not None:
            end = drift.drift_end_index
        else:
            end = drift.detected_at

        # Update end to the index with maximum error rate
        start_idx = start - rate_calculation_sample_size
        end_idx = end - rate_calculation_sample_size
        start_idx = max(0, min(start_idx, len(error_rate) - 1))
        end_idx = max(0, min(end_idx, len(error_rate) - 1))

        if start_idx <= end_idx:
            max_idx_in_range = start_idx + np.argmax(error_rate[start_idx:end_idx + 1])
            end = max_idx_in_range + rate_calculation_sample_size
            drift.drift_end_index = end

        detection = drift.detected_at

        # Make sure they're within the error_rate bounds
        start_plot = max(rate_calculation_sample_size,
                            min(start, rate_calculation_sample_size + len(error_rate) - 1))
        end_plot = max(rate_calculation_sample_size,
                        min(end, rate_calculation_sample_size + len(error_rate) - 1))

        # Calculate error_rate array indices
        start_idx = start - rate_calculation_sample_size
        end_idx = end - rate_calculation_sample_size
        detection_idx = detection - rate_calculation_sample_size
        start_idx = max(0, min(start_idx, len(error_rate) - 1))
        end_idx = max(0, min(end_idx, len(error_rate) - 1))
        detection_idx = max(0, min(detection_idx, len(error_rate) - 1))

        description = {
            'start': start_plot,
            'end': end_plot,
            'duration': end_plot - start_plot,
            'start_error_rate': error_rate[start_idx],
            'end_error_rate': error_rate[end_idx],
            'error_rate_increase': error_rate[end_idx] - error_rate[start_idx]
        }
        corrected_descriptions.append(description)

        if visualize:
            # Add drift region as shaded area (from start to end)
            fig.add_vrect(
                x0=start_plot,
                x1=end_plot,
                fillcolor="rgba(255, 107, 53, 0.2)",
                layer="below",
                line_width=0,
            )

            # Add drift line from start to end
            fig.add_trace(go.Scatter(
                x=[start_plot, end_plot],
                y=[error_rate[start_idx], error_rate[end_idx]],
                mode='lines+markers',
                name=f'Drift {idx+1}',
                line=dict(color='#FF6B35', width=3, dash='dash'),
                marker=dict(size=10, color='#FF6B35', symbol=['circle', 'x']),
                hovertemplate=(
                    f'<b>Drift {idx+1}</b><br>'
                    f'Start Index: {start}<br>'
                    f'Peak/End Index: {end}<br>'
                    f'Duration: {drift.drift_duration}<br>'
                    f'Error Rate at Start: {error_rate[start_idx]:.3f}<br>'
                    f'Error Rate at Peak: {error_rate[end_idx]:.3f}<br>'
                    f'<extra></extra>'
                )
            ))
    if visualize:
        # Update layout
        fig.update_layout(
            title=f'Drift Detection using',
            xaxis_title='Data Point Index',
            yaxis_title='Error Rate',
            hovermode='closest',
            height=600,
            showlegend=True,
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01
            )
        )
        fig.show()
        
    return corrected_descriptions

In [6]:
def process_DDM(X, y, detector, warning_grace_period = 4, lookback_method: str = 'gradient',
    lookforward_method: str = 'recovery', model=None, rate_calculation_sample_size = 100, visualize = False):
    if model is None:
        model = MLPClassifier(
            hidden_layer_sizes=(10,),
            max_iter=1,
            random_state=42
        )

    X_np = X.to_numpy() if hasattr(X, "to_numpy") else X
    y_np = y.to_numpy() if hasattr(y, "to_numpy") else y

    error_stream, predictions = generate_error_stream(X_np, y_np, model)
    error_array = np.array(error_stream)
    error_rate = [
        np.mean(error_array[i:i + rate_calculation_sample_size])
        for i in range(len(error_stream) - rate_calculation_sample_size + 1)
    ]

    drift_descriptions = run_drift_detection(error_stream, detector, warning_grace_period, rate_calculation_sample_size, lookback_method, lookforward_method)
    corrected_descriptions = visualization_and_correction(drift_descriptions, error_rate, rate_calculation_sample_size, visualize)

    return corrected_descriptions

In [7]:
process_DDM(X, y, "DDM", visualize=True)

[{'start': 19828,
  'end': np.int64(20109),
  'duration': np.int64(281),
  'start_error_rate': np.float64(0.05),
  'end_error_rate': np.float64(0.94),
  'error_rate_increase': np.float64(0.8899999999999999)}]