In [1]:
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import torch

from toy_functions import toy_polynomial_2d, six_hump_camel

In [2]:
def _surface_values(values, keep, logscale):
    """Prepare a 2-D array for ``go.Surface``: points not in ``keep`` become NaN.

    plotly converts its inputs to plain ndarrays, which silently discards the
    mask of a ``numpy.ma.MaskedArray``; NaN is what it actually leaves out of a
    surface. With ``logscale`` the result is ``log10(values)``, and points that
    have no real logarithm (value <= 0) are hidden as well.
    """
    z = np.where(keep, values, np.nan)
    if logscale:
        # Evaluate log10 only where it is defined; everything else (including
        # points already hidden above) keeps the NaN fill, with no RuntimeWarning.
        z = np.log10(z, out=np.full_like(z, np.nan), where=z > 0)
    return z


def _surface_figure(X, Y, Z, title, zaxis_title):
    """Return an 800x600 plotly figure showing Z over the (X, Y) grid as a 3-D surface."""
    fig = go.Figure(data=[go.Surface(x=X, y=Y, z=Z)])
    fig.update_layout(
        title=title,
        scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title=zaxis_title),
        width=800,
        height=600,
    )
    return fig


def plot_pes_and_gradient(
    pes_fn,
    cutoff=10,
    logscale=False,
    npoints=100,
    xlim=(-2, 2),
    ylim=(-2, 2),
    save_dir="plots",
):
    """Plot a 2-D potential energy surface (PES) and its gradient norm as 3-D surfaces.

    ``pes_fn`` is evaluated once on a regular ``npoints x npoints`` grid spanning
    ``xlim`` x ``ylim``. Grid points whose PES value is not below ``cutoff`` are
    hidden in *both* figures, so steep walls do not swamp the region of interest.

    Parameters
    ----------
    pes_fn : callable
        Called with a ``(npoints**2, 2)`` tensor of ``(x, y)`` points. Must return
        a pair ``(pes, grad_norm)`` of tensors with ``npoints**2`` elements each.
    cutoff : float, default 10
        Points with ``PES >= cutoff`` (or a NaN PES) are hidden.
    logscale : bool, default False
        Plot ``log10`` of the values. Non-positive values have no real logarithm
        and are hidden too.
    npoints : int, default 100
        Grid resolution along each axis.
    xlim, ylim : (float, float), default (-2, 2)
        Grid extent along x and y.
    save_dir : str, os.PathLike or None, default "plots"
        Directory that receives ``pes_surface.png`` and ``gradient_surface.png``.
        It is created if missing. Pass ``None`` to skip saving. Static export uses
        ``Figure.write_image``, which needs the ``kaleido`` package.
    """
    # Regular grid; 'ij' indexing gives X[i, j] == x[i] and Y[i, j] == y[j].
    x = torch.linspace(xlim[0], xlim[1], npoints)
    y = torch.linspace(ylim[0], ylim[1], npoints)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    points = torch.stack([X.flatten(), Y.flatten()], dim=1)

    pes, grad_norm = pes_fn(points)
    # .cpu() is a no-op for CPU tensors and lets a GPU-backed pes_fn work too.
    Z_pes = pes.reshape(npoints, npoints).detach().cpu().numpy()
    Z_grad = grad_norm.reshape(npoints, npoints).detach().cpu().numpy()
    X_np, Y_np = X.numpy(), Y.numpy()

    # The PES cutoff decides visibility for both plots; a NaN PES fails the
    # comparison and is hidden as well.
    keep = Z_pes < cutoff
    scale_note = ' (log scale)' if logscale else ''

    fig_pes = _surface_figure(
        X_np, Y_np, _surface_values(Z_pes, keep, logscale),
        title=f'Potential Energy Surface (Values < {cutoff})' + scale_note,
        zaxis_title='log10(PES)' if logscale else 'PES',
    )
    fig_grad = _surface_figure(
        X_np, Y_np, _surface_values(Z_grad, keep, logscale),
        title=f'Gradient Norm Surface (PES Values < {cutoff})' + scale_note,
        zaxis_title='log10(Gradient Norm)' if logscale else 'Gradient Norm',
    )

    out_dir = None
    if save_dir is not None:
        out_dir = Path(save_dir)
        # Static export does not create missing directories.
        out_dir.mkdir(parents=True, exist_ok=True)

    # Same side-effect order as before: save then show the PES, then the gradient.
    for fig, filename in ((fig_pes, "pes_surface.png"), (fig_grad, "gradient_surface.png")):
        if out_dir is not None:
            fig.write_image(out_dir / filename)
        fig.show()

SyntaxError: invalid syntax (1149590441.py, line 18)

In [9]:
# Toy polynomial PES and its gradient norm, with both z-axes shown as log10.
plot_pes_and_gradient(pes_fn=toy_polynomial_2d, logscale=True)

In [None]:
# NOTE(review): move this import into the imports cell at the top, so that a
# fresh "Restart & Run All" reports a missing `plotting` module immediately.
from plotting import plot_potential_energy_surface
# Alternative rendering of toy_polynomial_2d through the project's `plotting`
# helper; what it draws and whether it saves files is defined in that module.
plot_potential_energy_surface(toy_polynomial_2d)

In [None]:
# Six-hump camel PES and its gradient norm on a linear z-scale.
# NOTE(review): this writes plots/pes_surface.png and plots/gradient_surface.png,
# which are the same files the toy_polynomial_2d call saves; it overwrites them.
plot_pes_and_gradient(pes_fn=six_hump_camel, logscale=False)