<a href="https://colab.research.google.com/github/alisterpage/CHEM2410-Jupyter-Notebooks/blob/main/hydrogen_orbitals.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The Hydrogen Orbitals

In [1]:
%matplotlib inline
#%pip install pyvista sympy panel ipyvtklink
!pip install -q pyvista plotly ipywidgets

from google.colab import output
output.enable_custom_widget_manager()

import numpy as np
import plotly.graph_objects as go
from pyvista import examples
import ipywidgets as widgets
from IPython.display import display, clear_output
import time


In [8]:
#@title 💻🧪Hydrogen Orbital Plotter

# Output widget: every plot (or error message) produced by the controls below is rendered here.
plot_output = widgets.Output()

# code based on PyVista examples: https://docs.pyvista.org/examples/99-advanced/atomic_orbitals.html#sphx-glr-download-examples-99-advanced-atomic-orbitals-py


def _sample_orbital_points(orbital, n_samples, seed):
    """Draw ``n_samples`` grid points with probability proportional to ``real_wf**2``.

    Each sample is jittered uniformly within one grid cell (offset scaled by the grid
    spacing) to hide the grid structure. Returns ``(points, indices)``: the jittered
    ``(n_samples, 3)`` coordinates and the grid-point indices they were drawn from.
    """
    prob = np.abs(orbital['real_wf']) ** 2
    prob /= prob.sum()

    rng = np.random.default_rng(seed=seed)
    indices = rng.choice(orbital.n_points, n_samples, p=prob)

    # The grid spacing depends on the loaded orbital (NOTE(review): PyVista appears to
    # size the grid from n -- confirm), so a fixed +/-0.5 offset would over-smear compact
    # orbitals and leave a visible lattice in large ones. Scale the jitter to one cell.
    spacing = np.asarray(orbital.spacing, dtype=float)
    jitter = (rng.random((n_samples, 3)) - 0.5) * spacing
    return orbital.points[indices] + jitter, indices


def plot_hydrogen_orbital(n, l, m, x_range, y_range, z_range,
                          n_samples=30000, seed=0, axis_limit=None):
    """Render a phase-coloured point cloud of the hydrogen orbital (n, l, m) into ``plot_output``.

    Points are drawn with probability proportional to ``real_wf**2`` on PyVista's gridded
    orbital, jittered within one grid cell, clipped to the requested box, and coloured
    red where ``real_wf < 0`` and blue otherwise.

    Parameters
    ----------
    n, l, m : int
        Quantum numbers; must satisfy n > 0, 0 <= l < n and -l <= m <= l.
    x_range, y_range, z_range : (float, float)
        Inclusive (min, max) clipping limits along each axis, in grid coordinates.
    n_samples : int, default 30000
        Number of points drawn from the density.
    seed : int, default 0
        Seed for the sampler, so identical inputs always produce the same picture.
    axis_limit : float or None, default None
        Half-width of the displayed (cubic) axis range. ``None`` fits it to the smaller
        of the orbital's grid extent and the clipping box; pass e.g. 30 for a fixed frame.

    Returns nothing. Invalid quantum numbers, an empty clipping box and any exception are
    reported by printing into ``plot_output`` instead of raising, because this function is
    driven by a button callback.
    """
    plot_output.clear_output()
    with plot_output:
        try:
            if not (n > 0 and 0 <= l < n and -l <= m <= l):
                print("Invalid quantum numbers: ensure 0 ≤ l < n and -l ≤ m ≤ l")
                return

            # Load orbital data (wavefunction values on a uniform grid)
            orbital = examples.load_hydrogen_orbital(n, l, m, zoom_fac=1.0)
            points, indices = _sample_orbital_points(orbital, n_samples, seed)

            # Keep only the points inside the inclusive clipping box.
            lows = np.array([x_range[0], y_range[0], z_range[0]], dtype=float)
            highs = np.array([x_range[1], y_range[1], z_range[1]], dtype=float)
            mask = np.all((points >= lows) & (points <= highs), axis=1)
            filtered_points = points[mask]

            if filtered_points.shape[0] == 0:
                print("No points in specified range.")
                return

            # Color by phase: sign of real_wf at each sample's source grid point.
            phases = orbital['real_wf'][indices][mask]
            colors = np.where(phases < 0, 'red', 'blue')

            if axis_limit is None:
                # Frame the cube on the smaller of the orbital's grid (plus the half-cell
                # jitter margin) and the clipping box, so both compact and extended
                # orbitals fill the view instead of using one fixed size for every n.
                half_cell = 0.5 * float(np.max(orbital.spacing))
                grid_extent = float(np.max(np.abs(orbital.bounds))) + half_cell
                clip_extent = float(np.max(np.abs(np.concatenate((lows, highs)))))
                axis_limit = min(grid_extent, clip_extent)

            # Create interactive plot
            fig = go.Figure(data=go.Scatter3d(
                x=filtered_points[:, 0],
                y=filtered_points[:, 1],
                z=filtered_points[:, 2],
                mode='markers',
                marker=dict(size=2.5, color=colors, opacity=0.3),
            ))

            # Shared settings for all three axes: same span (keeps 'cube' undistorted), bare look.
            axis_style = dict(
                range=[-axis_limit, axis_limit],
                showbackground=False,
                showgrid=False,
                zeroline=False,
            )
            fig.update_layout(
                title=f"Hydrogen orbital n={n}, l={l}, m={m}",
                scene=dict(
                    xaxis=dict(axis_style, title='x (a.u.)'),
                    yaxis=dict(axis_style, title='y (a.u.)'),
                    zaxis=dict(axis_style, title='z (a.u.)'),
                    aspectmode='cube',
                    bgcolor='white',  # background of the scene (outside the box)
                ),
                margin=dict(l=0, r=0, b=0, t=80),
                paper_bgcolor='white',  # removes outer paper background
                plot_bgcolor='white',   # removes inner plot background
            )

            fig.show()

        except Exception as e:
            print(f"Error: {e}")

# Quantum number widgets. The l and m bounds are tied to n and l below, so the
# boxes can only hold combinations with 0 <= l < n and -l <= m <= l.
n_input = widgets.BoundedIntText(value=1, min=1, max=10, description='n:')
l_input = widgets.BoundedIntText(value=0, min=0, max=9, description='l:')
m_input = widgets.BoundedIntText(value=0, min=-9, max=9, description='m:')


def _sync_l_bounds(change=None):
    """Restrict l to 0..n-1 for the current n; BoundedIntText clamps l if it falls outside."""
    l_input.max = n_input.value - 1


def _sync_m_bounds(change=None):
    """Restrict m to -l..l for the current l.

    max is set before min: the new max (l >= 0) is never below the old min (<= 0), and the
    new min (-l) is never above the new max, so the widget never sees min > max.
    """
    l_value = l_input.value
    m_input.max = l_value
    m_input.min = -l_value


n_input.observe(_sync_l_bounds, names='value')
l_input.observe(_sync_m_bounds, names='value')  # also fires when _sync_l_bounds clamps l
_sync_l_bounds()
_sync_m_bounds()


def _range_slider(axis_label):
    """Clipping slider for one Cartesian axis, spanning -40..40 in steps of 1."""
    return widgets.FloatRangeSlider(
        value=[-40, 40], min=-40, max=40, step=1,
        description=f'{axis_label} range:', continuous_update=False,
    )


# Range sliders
x_slider, y_slider, z_slider = (_range_slider(axis) for axis in 'XYZ')

# Button
plot_button = widgets.Button(description="Plot Orbital")

def on_plot_click(b):
    """Button callback: read the current widget values and redraw the orbital.

    ``b`` is the clicked Button, passed in by ipywidgets and unused here. No delay is
    needed before reading the widgets. ipywidgets syncs state through comm messages
    that the kernel handles one at a time, in arrival order, so a slider change sent
    before the click is already visible in ``.value``. A ``time.sleep`` here would only
    block that processing.
    """
    plot_hydrogen_orbital(
        n_input.value,
        l_input.value,
        m_input.value,
        tuple(x_slider.value),
        tuple(y_slider.value),
        tuple(z_slider.value),
    )

plot_button.on_click(on_plot_click)

# Stack the controls vertically, with the output area that receives the plot at the bottom.
ui = widgets.VBox(children=[
    n_input, l_input, m_input,      # quantum numbers
    x_slider, y_slider, z_slider,   # clipping box
    plot_button,
    plot_output,
])
display(ui)

VBox(children=(BoundedIntText(value=1, description='n:', max=10, min=1), BoundedIntText(value=0, description='…