In [18]:
from typing import Callable, Optional

import numpy as np
import plotly.graph_objects as go


def draw_function(
    f: Callable[[float, float], float],
    graph_range: tuple[float, float],
    resolution: int = 150,
    z_range: Optional[tuple[float, float]] = None,
):
    """Render ``z = f(x, y)`` as an interactive 3D surface over a square domain.

    Args:
        f: Scalar function of two real arguments. It is called once per grid
            point (via ``np.vectorize``), so it does not need to support
            numpy broadcasting.
        graph_range: ``(low, high)`` bounds shared by the x and y axes.
        resolution: Number of sample points along each axis (default 150,
            i.e. a 150 x 150 grid).
        z_range: ``(low, high)`` bounds for the z axis and the colour scale.
            Defaults to ``graph_range``, which keeps the original behaviour of
            clipping the surface to the same window as x and y.

    Raises:
        ValueError: if ``graph_range`` or ``z_range`` does not contain exactly
            two values.
    """
    # Explicit unpacking: a malformed range fails loudly here instead of having
    # its extra items splatted into np.linspace's ``num``/``endpoint`` params.
    low, high = graph_range
    z_low, z_high = graph_range if z_range is None else z_range

    # Without ``otypes``, np.vectorize infers the output dtype from the FIRST
    # call; an int returned at the first grid point would truncate every other
    # value to int. Forcing float keeps the surface exact.
    f_vector = np.vectorize(f, otypes=[float])

    # x and y share the same sampling, so build the axis once.
    axis = np.linspace(low, high, resolution)
    X, Y = np.meshgrid(axis, axis)
    Z = f_vector(X, Y)

    # Colour scale is pinned to the z window so colours match the visible axis.
    fig = go.Figure(
        data=[go.Surface(z=Z, x=X, y=Y, cmin=z_low, cmax=z_high)]
    )

    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=False,
        width=800,
        height=600,
        scene=dict(
            xaxis=dict(title=dict(text="x")),
            yaxis=dict(title=dict(text="y")),
            zaxis=dict(title=dict(text="f(x, y)"), range=[z_low, z_high]),
        ),
    )

    fig.show()

In [19]:
def my_function(x: float, y: float) -> float:
    """Example surface to plot: ``x**2 - y**3``.

    Quadratic in ``x`` and cubic in ``y``, so it rises on both sides along x
    and changes sign along y.
    """
    square_term = pow(x, 2)
    cubic_term = pow(y, 3)
    return square_term - cubic_term


# Plot the example on [-5, 5] x [-5, 5]; the z axis uses the same window,
# so the steep cubic term is clipped at the edges of the view.
draw_function(my_function, graph_range=(-5, 5))