In [15]:
import numpy as np
import plotly.graph_objects as go
from scipy.spatial.distance import mahalanobis

In [16]:
# Function to calculate Mahalanobis distance manually
def mahalanobis_distance(x, mean, inv_cov):
    """Mahalanobis distance of ``x`` from ``mean``.

    Computes ``sqrt((x - mean)^T @ inv_cov @ (x - mean))``. Weighting the
    difference by the inverse covariance rescales each direction by its
    variance and removes the correlation between coordinates, so the result
    equals the Euclidean distance after whitening the data (i.e. it is
    measured in "standard deviations" of the distribution).

    Parameters
    ----------
    x : array_like, shape (d,) or (..., d)
        A single point, or a stack of points along the leading axes.
    mean : array_like, shape (d,)
        Centre of the distribution.
    inv_cov : array_like, shape (d, d)
        Inverse of the covariance matrix (assumed symmetric positive definite).

    Returns
    -------
    float or numpy.ndarray
        A scalar for a single 1-D point, otherwise an array of shape
        ``x.shape[:-1]`` with one distance per point.
    """
    diff = np.asarray(x, dtype=float) - np.asarray(mean, dtype=float)
    # einsum evaluates the quadratic form per point, so a batch of points works
    # without a Python loop; for 1-D ``x`` this equals ``diff @ inv_cov @ diff``.
    squared = np.einsum(
        "...i,ij,...j->...", diff, np.asarray(inv_cov, dtype=float), diff
    )
    # Round-off can make the quadratic form marginally negative for points very
    # close to the mean; clamp so sqrt never yields NaN.
    return np.sqrt(np.maximum(squared, 0.0))


# Generate synthetic dataset
np.random.seed(42)  # fixed seed so the figure below is reproducible
mean = np.array([0, 0])
cov_matrix = np.array([[2.0, 0.5], [0.5, 1.0]])
N_SAMPLES = 100
data = np.random.multivariate_normal(mean, cov_matrix, size=N_SAMPLES)

# Calculate the inverse of the covariance matrix.
# NOTE: distances below use the *true* mean/covariance that generated the
# data, not the sample estimates computed from ``data``.
inv_cov_matrix = np.linalg.inv(cov_matrix)

# Calculate Mahalanobis distances for each point
distances = np.array(
    [mahalanobis_distance(point, mean, inv_cov_matrix) for point in data]
)

# Sanity check: the manual implementation must agree with SciPy's reference
# (SciPy's ``mahalanobis`` also expects the *inverse* covariance as ``VI``).
np.testing.assert_allclose(
    distances,
    [mahalanobis(point, mean, inv_cov_matrix) for point in data],
)

# Evaluate the distance on a regular square grid so it can be drawn as contours
GRID_LIMIT = 5  # grid spans [-GRID_LIMIT, GRID_LIMIT] on both axes
GRID_POINTS = 100  # samples per axis
x_range = np.linspace(-GRID_LIMIT, GRID_LIMIT, GRID_POINTS)
y_range = np.linspace(-GRID_LIMIT, GRID_LIMIT, GRID_POINTS)
X, Y = np.meshgrid(x_range, y_range)
grid_points = np.column_stack([X.ravel(), Y.ravel()])
Z = np.array(
    [mahalanobis_distance(p, mean, inv_cov_matrix) for p in grid_points]
).reshape(X.shape)
assert Z.shape == (GRID_POINTS, GRID_POINTS)

In [17]:
# Plotting with Plotly
fig = go.Figure()

# Add points, coloured by their Mahalanobis distance from the mean
fig.add_trace(
    go.Scatter(
        x=data[:, 0],
        y=data[:, 1],
        mode="markers",
        marker=dict(
            color=distances,
            colorscale="Viridis",
            colorbar=dict(title="Mahalanobis Distance"),
        ),
        name="Data Points",
    )
)

# Add iso-distance contour lines every 0.5 units from 1 to 5.
# The contour trace has no colour bar (showscale=False), so label each line
# with its distance value; otherwise the line colours carry no readable meaning.
fig.add_trace(
    go.Contour(
        x=x_range,
        y=y_range,
        z=Z,
        contours=dict(
            start=1, end=5, size=0.5, coloring="lines", showlabels=True
        ),
        line_smoothing=0.85,
        colorscale="Jet",
        showscale=False,
        name="Distance Boundary",
    )
)

# Update layout. Anchoring the y scale to x keeps one unit equally long on
# both axes, so the shape of the iso-distance contours is not distorted.
fig.update_layout(
    title="Mahalanobis Distance Visualization",
    xaxis_title="X-axis",
    yaxis_title="Y-axis",
    yaxis_scaleanchor="x",
    yaxis_scaleratio=1,
    showlegend=False,
    width=700,
    height=700,
)

# Show plot
fig.show()

In [43]:
# Example of 2D contour plot on non-uniformly spaced coordinates:
# row i of ``z`` is drawn at y[i], column j at x[j].
# Re-importing is redundant with the first cell but keeps this cell standalone.
import plotly.graph_objects as go

fig = go.Figure(
    data=go.Contour(
        z=[
            [10, 10.625, 12.5, 15.625, 20],
            [5.625, 6.25, 8.125, 11.25, 15.625],
            [2.5, 3.125, 5.0, 8.125, 12.5],
            [0.625, 1.25, 3.125, 6.25, 10.625],
            [0, 0.625, 2.5, 5.625, 10],
        ],
        x=[-9, -6, -5, -3, -1],  # horizontal axis
        y=[0, 1, 4, 5, 7],  # vertical axis
    ),
)
# Give the figure a title and axis labels so it reads on its own
fig.update_layout(
    title="Example 2D Contour Plot",
    xaxis_title="x",
    yaxis_title="y",
)
fig.show()