In [29]:
import numpy as np
import plotly.graph_objects as go
import ipywidgets as widgets
import time

# -------------------------------
# Data + Helpers
# -------------------------------
n = 300  # default number of samples drawn per generated cloud
np.random.seed(42)  # fixed global NumPy seed -> reproducible sampling

def generate_rotated_data(angle=30.0, var_parallel=5.0, var_perp=0.5, n_samples=n):
    """Sample an anisotropic, zero-mean 2-D Gaussian cloud rotated by ``angle``.

    Draws ``n_samples`` coordinates along a principal direction (variance
    ``var_parallel``), then ``n_samples`` across it (variance ``var_perp``), in
    that order, from the global NumPy RNG. The cloud is then rotated
    counter-clockwise by ``angle`` degrees about the origin.

    Returns an array of shape (n_samples, 2), one point per row.
    """
    along = np.random.randn(n_samples) * np.sqrt(var_parallel)
    across = np.random.randn(n_samples) * np.sqrt(var_perp)

    phi = np.deg2rad(angle)
    c, s = np.cos(phi), np.sin(phi)
    rotation = np.array([[c, -s],
                         [s,  c]])

    points_by_column = np.vstack([along, across])  # shape (2, n_samples)
    return (rotation @ points_by_column).T

def projection_error(data, alpha):
    """Mean squared distance of ``data`` rows from the origin line at ``alpha``.

    The line passes through the origin (the data is NOT re-centred) at angle
    ``alpha`` in degrees. Each point's signed offset is its dot product with
    the unit normal ``(-sin α, cos α)``.
    """
    theta = np.deg2rad(alpha)
    normal = np.array([-np.sin(theta), np.cos(theta)])
    residuals = data @ normal
    return np.mean(np.square(residuals))

def iterative_search(data, step_size=5, error_fn=None, start_alpha=0, max_steps=10000):
    """Greedy descent over the axis angle α (degrees, taken modulo 180).

    Starting from ``start_alpha``, evaluate the error at α and at its two
    neighbours α ± ``step_size``. Move to the lower neighbour while one of them
    improves on the current error. Stop at the first α that is no worse than
    either neighbour, i.e. a local minimum on the ``step_size`` grid.

    Parameters
    ----------
    data : array-like
        Passed unchanged to ``error_fn`` (an (m, 2) point cloud for the
        default objective).
    step_size : float, default 5
        Angular step in degrees.
    error_fn : callable ``(data, alpha) -> float``, optional
        Objective to minimise. Defaults to ``projection_error``.
    start_alpha : float, default 0
        Initial angle in degrees.
    max_steps : int, default 10000
        Upper bound on the number of recorded steps. Every move strictly lowers
        the error, so normal searches stop long before this; it is only a
        safety net.

    Returns
    -------
    list of (alpha, error)
        Every visited angle in order; the last entry is where the search
        stopped. If any evaluated error is NaN, the search stops right after
        recording the current step. NaN compares False with everything, so it
        would otherwise never satisfy the stopping test and the loop would
        never end (e.g. ``projection_error`` on empty data).
    """
    if error_fn is None:
        error_fn = projection_error

    alpha = start_alpha
    trajectory = []
    for _ in range(max_steps):
        left_alpha = (alpha - step_size) % 180
        right_alpha = (alpha + step_size) % 180
        err = error_fn(data, alpha)
        left = error_fn(data, left_alpha)
        right = error_fn(data, right_alpha)
        trajectory.append((alpha, err))

        if np.isnan(err) or np.isnan(left) or np.isnan(right):
            break  # undefined objective: there is no meaningful direction to move
        if err <= left and err <= right:
            break  # local minimum on the grid
        # Strictly-lower neighbour wins; a tie between the neighbours goes right.
        alpha = left_alpha if left < right else right_alpha
    return trajectory

# -------------------------------
# Robust scaling (for sensitivity) + RPCA (IALM)
# -------------------------------
def _col_mad(x):
    med = np.median(x, axis=0)
    mad = np.median(np.abs(x - med), axis=0)
    scale = 1.4826 * mad
    scale[scale < 1e-8] = 1.0  # avoid divide-by-zero; degenerate dims
    return med, scale

def _rpca_ialm(Xz, lam=0.06, mu=None, max_iter=1000, tol=1e-7):
    """RPCA on standardized data Xz -> returns Lz, Sz.

    Splits ``Xz`` into a low-rank part ``Lz`` and a sparse part ``Sz`` with an
    augmented-Lagrangian (ADMM-style) iteration. Each pass does:

    * ``Lz``: singular-value soft-thresholding of ``Xz - Sz + Y/mu`` at ``1/mu``;
    * ``Sz``: entry-wise soft-thresholding of ``Xz - Lz + Y/mu`` at ``lam/mu``;
    * dual: ``Y += mu * (Xz - Lz - Sz)``.

    ``mu`` is held fixed for the whole run (the loop never changes it).

    Parameters
    ----------
    Xz : 2-D array, shape (m, d)
        Input matrix (the caller in this file passes robustly standardized data).
    lam : float
        Sparsity weight. A larger value raises the soft-threshold ``lam/mu``,
        so more entries of ``Sz`` become exactly zero.
    mu : float, optional
        Penalty parameter. Defaults to ``m*d / (4 * sum(|Xz|))``, with the
        denominator floored at 1e-8.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop once ``||Xz - Lz - Sz||_F / ||Xz||_F < tol``.

    Returns
    -------
    (Lz, Sz) : arrays shaped like ``Xz``
        ``Lz + Sz`` matches ``Xz`` to within ``tol`` (relative Frobenius) on
        convergence. Both are all-zero if ``Xz`` is all-zero or ``max_iter`` is 0.
    """
    m, d = Xz.shape
    normX = np.linalg.norm(Xz, 'fro')
    if normX == 0:
        # Zero input: the trivial split, and avoids dividing by normX below.
        return np.zeros_like(Xz), np.zeros_like(Xz)

    # Dual variable starts as Xz scaled to unit Frobenius norm.
    Y = Xz / normX
    Lz = np.zeros_like(Xz)  # returned unchanged only if the loop never runs
    Sz = np.zeros_like(Xz)  # read by the first Lz update

    if mu is None:
        denom = np.sum(np.abs(Xz))
        mu = (m * d) / (4.0 * max(denom, 1e-8))

    for _ in range(max_iter):
        # SVT for L: shrink singular values by 1/mu, dropping those below it
        U, s, VT = np.linalg.svd(Xz - Sz + (1/mu)*Y, full_matrices=False)
        s_thr = np.maximum(s - 1/mu, 0)
        Lz = U @ np.diag(s_thr) @ VT

        # Soft-threshold for S (λ acts here): entries with |R| <= lam/mu become 0
        R = Xz - Lz + (1/mu)*Y
        Sz = np.sign(R) * np.maximum(np.abs(R) - lam/mu, 0)

        # Dual update on the constraint residual Xz - Lz - Sz
        Z = Xz - Lz - Sz
        Y += mu * Z

        # Relative residual test
        if np.linalg.norm(Z, 'fro') / normX < tol:
            break

    return Lz, Sz

def robust_pca_scaled(X, lam=0.06, max_iter=1000, tol=1e-7):
    """Robust PCA on column-standardized data, mapped back to original units.

    Each column of ``X`` is centred on its median and divided by its scaled
    MAD (see ``_col_mad``). This puts all columns on a comparable scale, so
    ``lam`` has a consistent effect. The standardized matrix is split by
    ``_rpca_ialm``.

    Returns
    -------
    (L, S, Sz)
        ``L`` (low-rank, original units, medians added back), ``S`` (sparse,
        original units) and ``Sz`` (sparse, standardized units; used for the
        outlier test and heatmap).
    """
    center, spread = _col_mad(X)
    standardized = (X - center) / spread
    Lz, Sz = _rpca_ialm(standardized, lam=lam, max_iter=max_iter, tol=tol)

    # Undo the standardization; S is a deviation, so it gets no location shift.
    low_rank = Lz * spread + center
    sparse = Sz * spread
    return low_rank, sparse, Sz

# -------------------------------
# Setup Plot
# -------------------------------
# data0 is the pristine sample restored by Reset; `data` is the working copy
# that the button callbacks rebind (add point / add outliers / reset).
data0 = generate_rotated_data(40, 5, 0.5)
data = data0.copy()
# Candidate axis angles 0..180 degrees in 1-degree steps (181 values).
alphas = np.linspace(0, 180, 181)
# Initial error curve: projection MSE about a line through the origin, per angle.
errors = [projection_error(data, a) for a in alphas]

fig = go.FigureWidget()

# Traces (the callbacks address them by POSITION via fig.data[i], so the
# add order below must match this list):
# 0: Inliers (blue solid)  – from L (RPCA) or X (Standard / initial)
# 1: Outliers (open circle) – from S (shown at X)
# 2: α-axis (red dashed)
# 3: ⊥ axis (green dashed)
# 4: Error curve
# 5: Current α line
# 6: Current error point
# 7: Binary heatmap of |Sz|>eps  (blue=0, yellow=1)

fig.add_scatter(x=[], y=[], mode="markers",
                marker=dict(color="blue", size=7),
                name="Inliers (L)")

fig.add_scatter(x=[], y=[], mode="markers",
                marker=dict(symbol="circle-open", size=8, line=dict(width=2, color="black")),
                name="Outliers (S ≠ 0, plotted at X)")

fig.add_scatter(x=[], y=[], mode="lines",
                line=dict(color="red", dash="dash"), name="α-axis")

fig.add_scatter(x=[], y=[], mode="lines",
                line=dict(color="green", dash="dash"), name="⊥ axis")

# Traces 4-6 live on the top-right panel (x2/y2 axes).
fig.add_scatter(x=alphas.tolist(), y=[float(v) for v in errors],
                mode="lines", line=dict(color="gray"),
                name="Error curve", xaxis="x2", yaxis="y2")

fig.add_scatter(x=[], y=[], mode="lines",
                line=dict(color="red", dash="dash"),
                name="Current α", xaxis="x2", yaxis="y2")

fig.add_scatter(x=[], y=[], mode="markers",
                marker=dict(color="red", size=10),
                name="Current error", xaxis="x2", yaxis="y2")

# --- Binary heatmap placeholder (0=blue, 1=yellow), no colorbar ---
# One row per sample, one column per coordinate; lives on the bottom-right panel (x3/y3).
binary_zero = np.zeros((data.shape[0], data.shape[1]), dtype=int)
fig.add_trace(go.Heatmap(
    z=binary_zero,
    zmin=0, zmax=1,
    x=["x1","x2"],
    y=list(range(data.shape[0])),
    showscale=False,
    # Step colorscale: the duplicated 0.5 stop makes a hard jump, so with
    # zmin=0/zmax=1 a 0/1 mask renders as exactly two flat colours.
    colorscale=[ [0.0, "#1f77b4"], [0.5, "#1f77b4"], [0.5, "#FFD700"], [1.0, "#FFD700"] ],
    name="S nonzero mask",
    xaxis="x3", yaxis="y3"
))

# --- Seed the plot with the raw data so it's visible before Recompute ---
with fig.batch_update():
    fig.data[0].x = data[:,0].tolist()        # show raw X as blue points initially
    fig.data[0].y = data[:,1].tolist()
    fig.layout.title = "Data ready — press Recompute to run PCA / RPCA"

# Left panel: data (equal aspect via scaleanchor); right panel split top/bottom.
fig.update_layout(
    width=1200, height=600,
    xaxis=dict(domain=[0.00, 0.45], title="x", scaleanchor="y", scaleratio=1),
    yaxis=dict(domain=[0.05, 0.95], title="y"),
    # Right panel TOP: error curve
    xaxis2=dict(domain=[0.55, 1.00], title="α (degrees)"),
    yaxis2=dict(domain=[0.55, 0.95], title="Projection error (MSE)"),
    # Right panel BOTTOM: binary heatmap of S≠0
    xaxis3=dict(domain=[0.55, 1.00], title="dimension"),
    yaxis3=dict(domain=[0.05, 0.50], title="sample index"),
    margin=dict(l=40, r=20, t=50, b=40)
)

# -------------------------------
# Widgets
# -------------------------------
# Manual point entry: the values typed here are appended to `data` by "Add point".
x_input = widgets.FloatText(description="x:")
y_input = widgets.FloatText(description="y:")
add_btn = widgets.Button(description="Add point", button_style="info")

# Random-outlier controls: how many points (k) and how far out (strength)
add_outliers_btn = widgets.Button(description="Add random outliers", button_style="warning")
outlier_k_slider = widgets.IntSlider(value=15, min=1, max=200, step=1, description="k outliers")
outlier_strength_slider = widgets.FloatSlider(value=5.0, min=1.0, max=10.0, step=0.5, description="strength")

recompute_btn = widgets.Button(description="Recompute", button_style="success")
reset_btn = widgets.Button(description="Reset", button_style="danger")

# The option strings are compared literally in recompute() ("Robust PCA").
method_dropdown = widgets.Dropdown(
    options=['Standard PCA', 'Robust PCA'],
    value='Standard PCA',
    description='Method:',
)

# λ slider (directly used inside RPCA soft-thresholding)
# The default is computed once from the initial sample count n; it is not
# updated when points are added later.
default_lambda = 1 / np.sqrt(max(n, 2))  # ≈ 0.0577 for n=300
lambda_slider = widgets.FloatLogSlider(
    value=float(default_lambda), base=10,
    min=-3,  # 0.001  (min/max/step are exponents of `base`)
    max=0,   # 1.0
    step=0.01,
    description='λ (RPCA):',
    readout_format=".4f",
    continuous_update=False  # only report the value on release, not while dragging
)

# -------------------------------
# Button callbacks
# -------------------------------
def add_point(b):
    """Append the (x, y) typed into the text boxes to ``data`` and show raw data.

    Handler for the "Add point" button (``b`` is the clicked button, unused).
    Redraws all of ``data`` as blue points in trace 0. It also clears the
    open-circle outlier trace (1) and zeroes the S-mask heatmap (7) left over
    from a previous Recompute. Otherwise samples flagged as outliers would be
    drawn twice, and the heatmap's ``z`` would have fewer rows than its ``y``.
    PCA/RPCA results are only refreshed by Recompute.
    """
    global data
    px, py = x_input.value, y_input.value
    data = np.vstack([data, np.array([[px, py]])])
    n_rows, n_cols = data.shape

    with fig.batch_update():
        fig.data[0].x = data[:, 0].tolist()
        fig.data[0].y = data[:, 1].tolist()
        # Trace 0 now contains every sample, including any previously flagged
        # outliers, so the old open circles would duplicate them.
        fig.data[1].x, fig.data[1].y = [], []
        # Keep heatmap z and y the same length; the mask is stale until Recompute.
        fig.data[7].z = np.zeros((n_rows, n_cols), dtype=int)
        fig.data[7].y = list(range(n_rows))
        fig.layout.title = f"Point added ({px:.2f}, {py:.2f}). Press Recompute to update PCA / RPCA."

def add_random_outliers(b):
    """Add k outliers far from the current cloud, controlled by 'strength'.

    Handler for the "Add random outliers" button (``b`` is unused). ``k`` and
    ``strength`` come from the two sliders. Points are placed around the
    column-wise median of ``data``:

    * angle ~ U(0, 2π);
    * radius (1.5 + U(0, 1)) * strength * R, where R is the largest current
      distance from that median (1.0 if the cloud is a single point);
    * Gaussian jitter with σ = 0.05 * R per coordinate.

    Afterwards the raw data is redrawn in trace 0, and the stale outlier trace
    (1) and S-mask heatmap (7) from any previous Recompute are cleared.
    """
    global data
    k = int(outlier_k_slider.value)
    strength = float(outlier_strength_slider.value)

    # The median centre is robust to earlier outliers. The radius is not: R is
    # the current MAXIMUM distance, which already includes earlier outliers, so
    # each batch enlarges R and the next batch lands proportionally further out.
    center = np.median(data, axis=0)
    R = float(np.max(np.linalg.norm(data - center, axis=1)))
    if R < 1e-6:
        R = 1.0

    angles = 2*np.pi*np.random.rand(k)
    radii = (1.5 + np.random.rand(k)) * strength * R  # far from cloud
    noise = 0.05 * R * np.random.randn(k, 2)          # tiny jitter
    ring = np.column_stack([np.cos(angles), np.sin(angles)]) * radii[:, None]
    data = np.vstack([data, center + ring + noise])
    n_rows, n_cols = data.shape

    with fig.batch_update():
        fig.data[0].x = data[:, 0].tolist()
        fig.data[0].y = data[:, 1].tolist()
        # Trace 0 now holds every sample; old open circles would duplicate some.
        fig.data[1].x, fig.data[1].y = [], []
        # Resize the heatmap consistently (z and y); the mask is stale until Recompute.
        fig.data[7].z = np.zeros((n_rows, n_cols), dtype=int)
        fig.data[7].y = list(range(n_rows))
        fig.layout.title = f"Added {k} random outliers (strength={strength:.1f}). Press Recompute."

def recompute(b):
    """Fit the selected method to ``data`` and animate the greedy α search.

    Handler for the "Recompute" button (``b`` is unused). Reads the method
    dropdown and the λ slider:

    * "Robust PCA": ``robust_pca_scaled`` splits X into L + S. A row is an
      outlier iff any entry of its standardized sparse part ``Sz`` exceeds
      ``eps`` in magnitude. Outliers are drawn as open circles at X, the
      remaining rows in blue at L, and the axis search runs on L.
    * otherwise (Standard PCA): every row is an inlier drawn at X, the axis
      search runs on X, and the S-mask heatmap is all zeros.

    Points, error curve and heatmap do not change during the animation, so they
    are sent to the figure once. Each animation step then only updates the two
    axes, the current-α marker and the title, pausing 0.3 s between steps.
    """
    method = method_dropdown.value
    lam = float(lambda_slider.value)
    eps = 1e-6  # |Sz| above this counts as a non-zero sparse entry

    if method == "Robust PCA":
        # RPCA with robust column scaling -> L, S in original units; Sz standardized
        L, S, Sz = robust_pca_scaled(data, lam=lam)
        heat_mask = (np.abs(Sz) > eps).astype(int)
        outlier_mask = heat_mask.any(axis=1)

        # Each sample appears exactly once: inliers at L, outliers at X.
        L_inliers = L[~outlier_mask]
        X_outliers = data[outlier_mask]
        clean_for_axis = L if len(L) else data
        title_left = f"Robust PCA (λ={lam:.4f}) • Outliers={int(outlier_mask.sum())}"
    else:
        # Standard PCA: all points shown as inliers at X; S = 0
        L_inliers = data
        X_outliers = np.empty((0, 2))
        clean_for_axis = data
        title_left = "Standard PCA • Outliers=0"
        heat_mask = np.zeros(data.shape, dtype=int)

    # Shared by both methods: full error curve and greedy search on the clean data.
    all_errors = [float(projection_error(clean_for_axis, a)) for a in alphas]
    traj = iterative_search(clean_for_axis)
    curve_top = float(max(all_errors)) * 1.05
    # Long axes for clarity
    Lmax = float(np.max(np.abs(clean_for_axis)) * 6.0)

    def _xy(points):
        """Split an (m, 2) array into x/y lists (empty lists when m == 0)."""
        if len(points) == 0:
            return [], []
        return points[:, 0].tolist(), points[:, 1].tolist()

    # Static content: pushed once instead of on every animation step.
    with fig.batch_update():
        fig.data[0].x, fig.data[0].y = _xy(L_inliers)    # blue inliers
        fig.data[1].x, fig.data[1].y = _xy(X_outliers)   # open-circle outliers
        fig.data[4].x, fig.data[4].y = alphas.tolist(), all_errors
        # 7: binary heatmap (0=blue, 1=yellow)
        fig.data[7].z = heat_mask
        fig.data[7].x = ["x1", "x2"]
        fig.data[7].y = list(range(data.shape[0]))

    for step, (a, e) in enumerate(traj):
        theta = np.deg2rad(a)
        axis_end = Lmax * np.array([np.cos(theta), np.sin(theta)])
        perp_end = Lmax * np.array([-np.sin(theta), np.cos(theta)])
        with fig.batch_update():
            # 2-3: axes through the origin
            fig.data[2].x, fig.data[2].y = [-axis_end[0], axis_end[0]], [-axis_end[1], axis_end[1]]
            fig.data[3].x, fig.data[3].y = [-perp_end[0], perp_end[0]], [-perp_end[1], perp_end[1]]
            # 5-6: current α marker on the error panel
            fig.data[5].x, fig.data[5].y = [a, a], [0, curve_top]
            fig.data[6].x, fig.data[6].y = [a], [float(e)]
            fig.layout.title = f"{title_left} • Step {step+1}: α={a:.1f}°, Error={e:.3f}"

        time.sleep(0.3)  # pause so each greedy step is visible as a frame

def reset_plot(b):
    """Restore the pristine sample and clear everything derived from it.

    Handler for the "Reset" button (``b`` is the clicked button and is ignored).
    Steps:

    * rebind the global ``data`` to a fresh copy of ``data0`` and redraw it as
      the blue trace;
    * empty the outlier, axis and current-α/error traces;
    * recompute the error curve for the original data;
    * zero the S-mask heatmap at the matching size.
    """
    global data
    data = data0.copy()

    n_rows, n_cols = data.shape
    fresh_curve = [float(projection_error(data, a)) for a in alphas]

    with fig.batch_update():
        fig.data[0].x, fig.data[0].y = data[:, 0].tolist(), data[:, 1].tolist()
        for idx in (1, 2, 3, 5, 6):
            fig.data[idx].x, fig.data[idx].y = [], []
        fig.data[4].x = alphas.tolist()
        fig.data[4].y = fresh_curve
        fig.data[7].z = np.zeros((n_rows, n_cols), dtype=int)
        fig.data[7].y = list(range(n_rows))
        fig.layout.title = "Data reset — press Recompute to run PCA / RPCA"

# Connect each button to its callback.
for _button, _handler in (
    (add_btn, add_point),
    (add_outliers_btn, add_random_outliers),
    (recompute_btn, recompute),
    (reset_btn, reset_plot),
):
    _button.on_click(_handler)
del _button, _handler  # keep the loop variables out of the notebook namespace

# -------------------------------
# Expose UI
# -------------------------------
# Row 1: point entry and action buttons; row 2: method / parameter controls.
ui_plotly = widgets.VBox(
    [
        widgets.HBox([x_input, y_input, add_btn, add_outliers_btn, recompute_btn, reset_btn]),
        widgets.HBox([method_dropdown, lambda_slider, outlier_k_slider, outlier_strength_slider]),
    ]
)
display(fig, ui_plotly)


FigureWidget({
    'data': [{'marker': {'color': 'blue', 'size': 7},
              'mode': 'markers',
              'name': 'Inliers (L)',
              'type': 'scatter',
              'uid': '89e9c010-a031-4299-be79-e6c26803e26b',
              'x': [1.2276296861580742, 0.01777647715448856, 0.7697839759880035,
                    ..., 0.9712489571128977, 1.5943137757894184,
                    0.9070181027370388],
              'y': [0.26488953686131694, -0.5021657900034965, 1.3357240390051466,
                    ..., -0.0884852904615885, 0.9276777220670097,
                    1.1093504339222986]},
             {'marker': {'line': {'color': 'black', 'width': 2}, 'size': 8, 'symbol': 'circle-open'},
              'mode': 'markers',
              'name': 'Outliers (S ≠ 0, plotted at X)',
              'type': 'scatter',
              'uid': '74948836-b17f-4f23-a106-0c10d1859a98',
              'x': [],
              'y': []},
             {'line': {'color': 'red', 'dash': 'dash'},
  

VBox(children=(HBox(children=(FloatText(value=0.0, description='x:'), FloatText(value=0.0, description='y:'), …