In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
def plot_grid_search_results(results, model_type="MLP"):
    """
    Plot the validation loss of every configuration tried during a grid search.

    Configurations are drawn as a bar chart ordered from lowest to highest
    validation loss, and the left-most (best) bar is highlighted in green.
    Entries whose loss is NaN are always placed last, so a NaN can only be
    reported as the best model when every loss is NaN.

    Parameters
    ----------
    results : list of tuples
        Output from grid_search_* function, formatted as [({params_dict}, val_loss), ...].
        Each params_dict must contain 'hidden_layers', 'lr', and 'batch_size'.
        Each val_loss must be convertible with ``float()``.
    model_type : str, optional
        Model type for plot title (e.g., 'MLP', 'Autoencoder', etc.)

    Returns
    -------
    tuple of (list of str, list) or None
        ``(labels, losses)`` in plotting order (ascending loss, NaN last),
        or None when ``results`` is empty.
    """
    if not results:
        print("⚠️ No results to plot.")
        return

    # Build one tick label per configuration, keeping losses index-aligned.
    architectures, losses = [], []
    for params, loss in results:
        arch = "-".join(str(x) for x in params.get('hidden_layers', []))
        label = f"{arch}\nLR={params.get('lr')}, B={params.get('batch_size')}"
        architectures.append(label)
        losses.append(loss)

    def _rank(index):
        # Comparisons against NaN are always False, which makes a plain sort
        # on the raw losses order-dependent. Ranking by (is_nan, value) keeps
        # real losses in ascending order and pushes NaN entries to the end.
        value = float(losses[index])
        return (bool(np.isnan(value)), value)

    # sorted() is stable, so ties keep their original grid-search order.
    sorted_indices = sorted(range(len(losses)), key=_rank)
    architectures = [architectures[i] for i in sorted_indices]
    losses = [losses[i] for i in sorted_indices]

    # Plot
    plt.figure(figsize=(12, 6))
    bars = plt.bar(range(len(losses)), losses, color='skyblue')
    plt.xticks(range(len(losses)), architectures, rotation=90)
    plt.ylabel("Validation Loss (MSE)")
    plt.title(f"Grid Search Results: {model_type} Architectures vs Validation Loss")
    plt.tight_layout()

    # Highlight best model (first bar after sorting)
    bars[0].set_color('limegreen')
    plt.show()

    print(f"✅ Best model: {architectures[0]} with val_loss = {losses[0]:.6f}")
    return architectures, losses

In [None]:
def plot_3d_actual_vs_predicted(y_test_real, y_pred_real, title="3D Tip Position: Actual vs Predicted"):
    """
    Draw ground-truth and predicted 3D points together in one scatter plot.

    Ground-truth points are blue and predicted points are red. Once the
    figure has been shown, the mean Euclidean distance between each
    prediction and its ground-truth point is printed.

    Parameters
    ----------
    y_test_real : np.ndarray
        Ground-truth 3D positions (de-standardized).
    y_pred_real : np.ndarray
        Predicted 3D positions (de-standardized).
    title : str, optional
        Plot title.

    Returns
    -------
    residuals : np.ndarray
        Per-sample Euclidean distance between prediction and ground truth.
    y_pred_real, y_test_real
        The two inputs, handed back unchanged. Note the order: predictions
        first, ground truth second (the reverse of the parameter order).
    """
    figure = plt.figure(figsize=(8, 6))
    axis3d = figure.add_subplot(111, projection='3d')

    # (points, colour, legend label) for each cloud, drawn in this order.
    clouds = (
        (y_test_real, 'blue', 'Actual'),
        (y_pred_real, 'red', 'Predicted'),
    )
    for points, colour, tag in clouds:
        axis3d.scatter(
            points[:, 0], points[:, 1], points[:, 2],
            c=colour, label=tag, s=15, alpha=0.5
        )

    axis3d.set_xlabel('X (mm)')
    axis3d.set_ylabel('Y (mm)')
    axis3d.set_zlabel('Z (mm)')
    axis3d.set_title(title)
    axis3d.legend()
    plt.tight_layout()
    plt.show()

    # One distance per sample: row-wise norm of (prediction - truth).
    offsets = y_pred_real - y_test_real
    residuals = np.linalg.norm(offsets, axis=1)
    print(f"Mean residual distance: {np.mean(residuals):.3f} mm")

    return residuals, y_pred_real, y_test_real

In [None]:
def compute_3d_metrics(y_pred_real, y_test_real, print_results=True):
    """
    Computes per-axis RMSE and MAE, as well as 3D Euclidean error statistics.

    Both inputs are converted with ``np.asarray`` first, so nested lists are
    accepted as well as arrays. Their shapes must match exactly: NumPy
    broadcasting would otherwise turn e.g. (N, 3) vs (N, 1) into a result
    that looks valid but measures the wrong thing.

    Parameters
    ----------
    y_pred_real : array-like
        Predicted (de-standardized) outputs, shape (N, 3).
    y_test_real : array-like
        True (de-standardized) outputs, shape (N, 3).
    print_results : bool, optional
        Whether to print the metrics.

    Returns
    -------
    metrics : dict
        Dictionary with keys:
        'RMSE' and 'MAE' (one value per column),
        'Mean_3D_Error' and 'Std_3D_Error' (mean / std of the per-sample
        Euclidean error), and 'Per_Sample_Euclidean_Error' (one value per row).

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    y_pred = np.asarray(y_pred_real)
    y_true = np.asarray(y_test_real)
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"Shape mismatch: predictions {y_pred.shape} vs targets {y_true.shape}"
        )

    # Signed residuals (prediction - truth), one row per sample
    residuals = y_pred - y_true

    # Per-axis RMSE and MAE (reduce over samples, keep one value per column)
    rmse = np.sqrt(np.mean(residuals**2, axis=0))
    mae = np.mean(np.abs(residuals), axis=0)

    # 3D Euclidean error per sample (row-wise norm)
    euclidean_error = np.linalg.norm(residuals, axis=1)
    mean_3d_error = np.mean(euclidean_error)
    std_3d_error = np.std(euclidean_error)

    metrics = {
        "RMSE": rmse,
        "MAE": mae,
        "Mean_3D_Error": mean_3d_error,
        "Std_3D_Error": std_3d_error,
        "Per_Sample_Euclidean_Error": euclidean_error
    }

    if print_results:
        print(f"RMSE [x, y, z]: {rmse}")
        print(f"MAE  [x, y, z]: {mae}")
        print(f"Mean 3D deviation: {mean_3d_error:.4f} ± {std_3d_error:.4f}")

    return metrics

In [None]:
def plot_residuals_and_errors(y_test_real, y_pred_real):
    """
    Show a one-row, four-panel diagnostic figure for 3D predictions.

    Panels 1-3 scatter the signed residual (prediction minus truth) of the
    X, Y and Z columns against the true value of that column; panel 4 is a
    histogram of the per-sample Euclidean error. The mean and standard
    deviation of that error are printed after the figure is shown.

    Parameters
    ----------
    y_test_real : np.ndarray
        Ground truth values (N, 3)
    y_pred_real : np.ndarray
        Predicted values (N, 3)

    Returns
    -------
    residuals : np.ndarray
        Signed residuals, ``y_pred_real - y_test_real``.
    euclidean_error : np.ndarray
        Row-wise Euclidean norm of the residuals.
    """
    residuals = y_pred_real - y_test_real
    euclidean_error = np.linalg.norm(residuals, axis=1)
    axis_names = ('X', 'Y', 'Z')

    # One row of four panels: three residual scatters followed by the histogram.
    _, axes = plt.subplots(1, 4, figsize=(18, 4))
    hist_panel = axes[3]

    # zip stops after the three named axes, so the histogram panel is skipped here.
    for col, (panel, axis_name) in enumerate(zip(axes, axis_names)):
        panel.scatter(y_test_real[:, col], residuals[:, col], s=8, alpha=0.6, color='teal')
        panel.axhline(0, color='black', linestyle='--')  # zero-residual reference line
        panel.set_xlabel(f"True {axis_name} (mm)")
        panel.set_ylabel("Residual (Pred - True)")
        panel.set_title(f"{axis_name} residuals vs actual")
        panel.grid(True)

    hist_panel.hist(euclidean_error, bins=25, color='salmon', alpha=0.8)
    hist_panel.set_xlabel("3D Euclidean Error (mm)")
    hist_panel.set_ylabel("Count")
    hist_panel.set_title("Distribution of 3D Errors")
    hist_panel.grid(True)

    plt.tight_layout()
    plt.show()

    error_mean = np.mean(euclidean_error)
    error_std = np.std(euclidean_error)
    print(f"Mean 3D deviation: {error_mean:.4f} ± {error_std:.4f}")

    return residuals, euclidean_error
