# **Sensor Vibration Data Exploration and Interactive Visualization**

This notebook presents the final outputs for exploring and visualizing vibration sensor data. (See also `visualization_step_by_step.ipynb` for more on the development process.)

It includes two main components:
- **Summary Statistics**: Aggregate statistics comparing "good" and "bad" examples across key metrics, such as mean and standard deviation of vibration signals.
- **Interactive Visualization**: A dynamic Plotly dashboard that allows users to select any "bad" example from a dropdown menu and view the corresponding time-series vibration data compared side-by-side with a matching "good" example.

The goal of this notebook is to provide an intuitive, interactive overview of the vibration patterns and enable quick comparisons between good and bad sensor behaviors.


In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, asc
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, clear_output
from pyspark.sql.functions import count, mean, stddev
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.notebook import tqdm

## **Start the Spark session and Connect to Delta Lake**

In [None]:
# Build (or reuse) the Spark session with the Delta Lake SQL extension and catalog configured
spark = (
    SparkSession.builder
    .appName("Delta Lake Summary")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .getOrCreate()
)

In [None]:
# Relative path of the Delta table directory; change it if your data lives elsewhere
DELTA_PATH: str = "../data/delta"

In [None]:
# Open the Delta table at DELTA_PATH as a Spark DataFrame
delta_reader = spark.read.format("delta")
df = delta_reader.load(DELTA_PATH)

## **Basic Summary Statistics**

In [None]:
# Collect one row per distinct example (both labels); each row identifies a single example
example_key_columns = ["machine_id", "month", "year", "operation", "example_no", "label"]
example_keys = df.select(*example_key_columns).distinct().collect()

In [None]:

# Helper function to process one example
def compute_stats(key, sample_rate_hz=2000):
    """Aggregate row count, per-axis mean and per-axis standard deviation for one example.

    Parameters
    ----------
    key : mapping (e.g. a Spark ``Row`` or a dict)
        Must provide ``machine_id``, ``month``, ``year``, ``operation``,
        ``example_no`` and ``label``; together they select the rows of one example
        in the global ``df``.
    sample_rate_hz : number, optional
        Samples per second, used to turn the row count into a duration.
        Defaults to 2000.

    Returns
    -------
    dict
        On success: ``label``, ``n_rows``, ``length_seconds``, ``mean_x/y/z`` and
        ``sd_x/y/z``. If the Spark query fails: ``{"label": ..., "error": <message>}``,
        so that one failing example does not abort the whole batch.

    Raises
    ------
    KeyError
        If ``key`` is missing one of the identifying fields.
    """
    key_columns = ("machine_id", "month", "year", "operation", "example_no", "label")
    # Read the identifying values up front: a malformed key is a caller error and
    # should surface immediately rather than be recorded as a query failure.
    values = {name: key[name] for name in key_columns}
    label = values["label"]

    try:
        # Build the filter inside the try so a failure while constructing it is
        # recorded the same way as a failure while running the aggregation.
        condition = None
        for name in key_columns:
            clause = col(name) == values[name]
            condition = clause if condition is None else condition & clause

        stats = (
            df.filter(condition)
              .agg(
                  count("*").alias("n_rows"),
                  mean("x").alias("mean_x"),
                  mean("y").alias("mean_y"),
                  mean("z").alias("mean_z"),
                  stddev("x").alias("sd_x"),
                  stddev("y").alias("sd_y"),
                  stddev("z").alias("sd_z")
              )
              .collect()[0]
        )

        n_rows = stats["n_rows"]
        return {
            "label": label,
            "n_rows": n_rows,
            "length_seconds": n_rows / sample_rate_hz,
            "mean_x": stats["mean_x"],
            "mean_y": stats["mean_y"],
            "mean_z": stats["mean_z"],
            "sd_x": stats["sd_x"],
            "sd_y": stats["sd_y"],
            "sd_z": stats["sd_z"]
        }
    except Exception as e:
        # Best effort: keep the label so the example still shows up in the results,
        # and carry the message in an "error" field for later inspection.
        return {"label": label, "error": str(e)}

# Fan the per-example aggregations out over a small thread pool;
# tqdm advances as each future completes, so results arrive in completion order
rows = []
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(compute_stats, key) for key in example_keys]
    finished = tqdm(as_completed(futures), total=len(futures))
    rows.extend(done.result() for done in finished)

In [None]:
# Assemble the per-example result dicts into a single pandas DataFrame (one row per example)
example_stats_df = pd.DataFrame(data=rows)

In [None]:
# Average the per-example statistics separately for each label
metric_columns = ["n_rows", "length_seconds",
                  "mean_x", "mean_y", "mean_z",
                  "sd_x", "sd_y", "sd_z"]

summary = {
    "stat": ["Number of examples", "Avg rows per example", "Avg length (sec)",
             "Avg mean x", "Avg mean y", "Avg mean z",
             "Avg sd x", "Avg sd y", "Avg sd z"],
    "good": [],
    "bad": []
}

for label in ("good", "bad"):
    subset = example_stats_df.loc[example_stats_df["label"] == label]
    # First entry is the example count, followed by one average per metric column
    averages = [subset[metric].mean() for metric in metric_columns]
    summary[label] = [len(subset), *averages]

# One row per statistic, one column per label
summary_df = pd.DataFrame(summary)

In [None]:
# Round the summary DataFrame
# The label has to match the "stat" text exactly (the comparison is case-sensitive),
# otherwise the count row is silently treated like every other statistic.
count_mask = summary_df["stat"] == "Number of examples"

for value_column in ["good", "bad"]:
    # Object dtype lets the count row hold integers next to the float rows;
    # in a float column the integers would be converted straight back to floats.
    summary_df[value_column] = summary_df[value_column].astype(object)

    # First, the "Number of examples" row (leave as integer)
    summary_df.loc[count_mask, value_column] = (
        summary_df.loc[count_mask, value_column].map(lambda value: int(round(value)))
    )

    # Then round all other stats to 1 decimal place
    summary_df.loc[~count_mask, value_column] = (
        summary_df.loc[~count_mask, value_column].map(lambda value: round(float(value), 1))
    )

In [None]:
# Show the good-vs-bad summary table as rich notebook output
display(summary_df)

## **Interactive dashboard to view the "bad" examples and "good" examples side by side**

In [None]:
# The dashboard callback always plots this example number, so only list
# machine/date/operation combinations that have a "bad" recording with it.
REFERENCE_EXAMPLE_NO = "000"
MAX_BAD_EXAMPLES = 36
META_COLUMNS = ["machine_id", "month", "year", "operation"]

# Get metadata for the first 36 distinct "bad" examples
bad_meta = (
    df.filter((col("label") == "bad") & (col("example_no") == REFERENCE_EXAMPLE_NO))
      .select(*META_COLUMNS)
      .distinct()
      .orderBy("year", "month", "operation", "machine_id")
      .limit(MAX_BAD_EXAMPLES)
      .collect()
)

# Convert to pandas DataFrame; naming the columns explicitly keeps them present
# even when the query returned no rows
bad_meta_df = pd.DataFrame([row.asDict() for row in bad_meta], columns=META_COLUMNS)


In [None]:
# Number the examples 1..N in row order; the dropdown labels use this number
bad_meta_df["example_number"] = range(1, len(bad_meta_df) + 1)

# Preview the first rows of the metadata table
bad_meta_df.head(5)

In [None]:
# Sensor sampling rate, used to convert a sample index into seconds
SAMPLE_RATE_HZ = 2000
# Example number plotted for every dropdown entry (both the "bad" and the "good" side)
REFERENCE_EXAMPLE_NO = "000"
# Signal columns, plotted top to bottom
SIGNAL_AXES = ("x", "y", "z")

# Create dropdown
dropdown = widgets.Dropdown(
    # The option value is the row position, which update_plot resolves with .iloc
    options=[
        (f"Example {example_number}", position)
        for position, example_number in enumerate(bad_meta_df["example_number"])
    ],
    description='Example:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

# Output widget
output = widgets.Output()


def _load_signal(meta, label):
    """Collect the x/y/z samples of one example and build the matching time axis.

    Parameters
    ----------
    meta : mapping
        Provides ``machine_id``, ``month``, ``year`` and ``operation``.
    label : str
        ``"bad"`` or ``"good"``.

    Returns
    -------
    dict
        Lists under ``"x"``, ``"y"``, ``"z"`` and a NumPy array of seconds under
        ``"t"``. All four are empty when no row matches.
    """
    # NOTE(review): no orderBy is applied, so the sample order is whatever Spark
    # returns. If the table has a sample-index or timestamp column, sort by it here.
    samples = (
        df.filter(
            (col("machine_id") == meta["machine_id"]) &
            (col("month") == meta["month"]) &
            (col("year") == meta["year"]) &
            (col("operation") == meta["operation"]) &
            (col("example_no") == REFERENCE_EXAMPLE_NO) &
            (col("label") == label)
        )
        # Only the signal columns are needed on the driver
        .select(*SIGNAL_AXES)
        .collect()
    )

    signal = {axis: [sample[axis] for sample in samples] for axis in SIGNAL_AXES}
    # arange keeps the spacing at exactly 1 / SAMPLE_RATE_HZ; an endpoint-inclusive
    # linspace over n / rate would stretch it slightly.
    signal["t"] = np.arange(len(samples)) / SAMPLE_RATE_HZ
    return signal


# Callback
def update_plot(change):
    """Redraw the bad-vs-good comparison for the entry given by ``change['new']``.

    Parameters
    ----------
    change : mapping
        Change notification from ``dropdown.observe`` (or a plain dict). Only
        ``change['new']`` is read; it is used as a row position in ``bad_meta_df``.

    The figure is rendered into the ``output`` widget. When there is nothing to
    plot, a short message is printed there instead.
    """
    with output:
        clear_output(wait=True)
        print("Loading...")

    with output:
        # wait=True postpones the clear until the next output arrives, so
        # "Loading..." stays visible while Spark collects the data
        clear_output(wait=True)

        if bad_meta_df.empty:
            print('No "bad" examples available.')
            return

        # Select metadata
        selected = bad_meta_df.iloc[change['new']]

        bad = _load_signal(selected, "bad")
        good = _load_signal(selected, "good")

        if len(bad["t"]) == 0:
            # The axis ranges below are derived from the "bad" signal and need data
            print(f'No "bad" samples found for Example {selected["example_number"]}.')
            return

        # Create subplot
        fig = make_subplots(rows=3, cols=2,
                            shared_xaxes=False,
                            subplot_titles=("X (Bad)", "X (Good)",
                                            "Y (Bad)", "Y (Good)",
                                            "Z (Bad)", "Z (Good)"))

        for row_no, axis in enumerate(SIGNAL_AXES, start=1):
            fig.add_trace(go.Scatter(x=bad["t"], y=bad[axis], mode="lines"), row=row_no, col=1)
            fig.add_trace(go.Scatter(x=good["t"], y=good[axis], mode="lines"), row=row_no, col=2)
            # Use "bad" axis ranges for both columns of the row so the amplitudes
            # can be compared directly
            fig.update_yaxes(range=[min(bad[axis]), max(bad[axis])], row=row_no)

        # Axis syncing: every subplot shows the time window of the "bad" signal
        fig.update_xaxes(range=[bad["t"][0], bad["t"][-1]])
        fig.update_xaxes(title_text="Time (s)", row=3)

        fig.update_layout(
            height=900,
            title_text=f"Bad vs Good Vibration Signal – Example {selected['example_number']}",
            showlegend=False
        )

        fig.show()

# Connect callback
dropdown.observe(update_plot, names='value')

# Display widgets
display(dropdown, output)
update_plot({'new': 0})  # Show initial example