In [13]:
from typing import Callable
import plotly.graph_objects as go
import numpy as np


def draw_function(
    f: Callable[[float, float], float],
    graph_range: tuple[float, float],
    resolution: int = 150,
) -> None:
    """Show ``z = f(x, y)`` as an interactive 3D Plotly surface.

    The same ``(low, high)`` interval is used for the x grid, the y grid and
    for clipping the computed z values.

    Args:
        f: Scalar function of two arguments; it is called once per grid point.
        graph_range: ``(low, high)`` bounds shared by the x, y and z axes.
        resolution: Number of sample points along each of the x and y axes.

    Raises:
        ValueError: If ``graph_range`` does not contain exactly two values.
    """
    # Unpack explicitly: with a wrong-length tuple, ``*graph_range`` would
    # silently shift the positional arguments of ``np.linspace``/``np.clip``
    # (e.g. a 1-tuple would turn the sample count into the stop value).
    low, high = graph_range

    # Vectorize so the scalar function can be applied to whole numpy arrays.
    # ``otypes`` is pinned to float: without it np.vectorize infers the output
    # dtype from the first evaluated point, so a function returning an ``int``
    # there would have every later float result truncated to an integer.
    f_vector = np.vectorize(f, otypes=[float])

    # Build a square grid of (x, y) sample points.
    x = np.linspace(low, high, resolution)
    y = np.linspace(low, high, resolution)
    X, Y = np.meshgrid(x, y)

    # Evaluate the function on the grid and limit the heights to the range.
    Z = f_vector(X, Y)
    Z_clipped = np.clip(Z, low, high)

    # Create the 3D surface plot.
    fig = go.Figure(data=[go.Surface(z=Z_clipped, x=X, y=Y)])

    # The z-axis bounds are the range scaled by 0.99.
    # NOTE(review): presumably this keeps the flat clipped plateaus just
    # outside the visible z-range -- confirm the intent.
    fig.update_layout(
        scene=dict(zaxis=dict(range=np.array(graph_range) * 0.99)),
    )

    # Show the plot.
    fig.show()

In [14]:
def my_function(x: float, y: float) -> float:
    """Return ``x`` squared minus ``y`` cubed."""
    squared_x = x**2
    cubed_y = y**3
    return squared_x - cubed_y


# Plot my_function over [-5, 5]; the same interval is passed as the range.
draw_function(f=my_function, graph_range=(-5, 5))