In [1]:
import sys
import os
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.append(src_dir)
import spectra as spectra
from utils.plotters import visualize_geometry, colorplot
from utils.geometries import trigonal_plane, tetrahedron
import jax.numpy as jnp
import plotly.subplots as sp

In [2]:
def geometry_given_angle(angle):
    """
    Generate a two-point geometry with a specified bond angle.

    The first point is the +x unit vector; the second is that unit vector
    rotated by ``angle`` within the xy-plane (z = 0). The angle between the
    two points, as seen from the origin, is therefore ``angle``.

    Args:
        angle (float): The bond angle in radians. Values in degrees must be
            converted first, e.g. with ``jnp.deg2rad``.

    Returns:
        jnp.ndarray: Array of shape (2, 3); each row is one point's
        (x, y, z) position on the unit circle in the xy-plane.
    """
    return jnp.array([
        [1.0, 0.0, 0.0],
        [jnp.cos(angle), jnp.sin(angle), 0.0],
    ])

In [3]:
# Sweep bond angle (rows) against lmax (columns) and render each combination
# as a 3D scene on a single grid figure, exported to bandwidth.png.
# Earlier sweeps used lmaxes = range(2, 6) and angles = 360/k for k = 5..2.
lmaxes = [2, 3, 4, 6]
angles = [60, 90, 120, 180]  # bond angles in degrees, smallest to largest

VERTICAL_SPACING = 0.02
HORIZONTAL_SPACING = 0.02
PANEL_SIZE_PX = 200          # width/height of each subplot in the exported image
AXIS_RANGE = [-2.5, 2.5]     # identical extent on every axis of every scene
BACKGROUND = 'white'

# Create subplot figure with 3D scenes
fig = sp.make_subplots(
    rows=len(angles),
    cols=len(lmaxes),
    # Titles are assigned row-major, so len(lmaxes) titles label only the first row.
    subplot_titles=[f'l = {lmax}' for lmax in lmaxes],
    specs=[[{'type': 'scene'} for _ in range(len(lmaxes))] for _ in range(len(angles))],
    vertical_spacing=VERTICAL_SPACING,
    horizontal_spacing=HORIZONTAL_SPACING,
)

# Height of one row in paper coordinates, using the same spacing passed to
# make_subplots, so the angle labels can be centred on the actual rows.
row_height = (1 - VERTICAL_SPACING * (len(angles) - 1)) / len(angles)

for i, angle in enumerate(angles):
    # Angle label to the left of each row, vertically centred on that row.
    fig.add_annotation(
        text=f'{int(round(angle))}°',
        xref='paper',
        yref='paper',
        x=-0.05,
        y=1 - i * (row_height + VERTICAL_SPACING) - row_height / 2,
        showarrow=False,
        font=dict(size=14),
    )

    # The geometry depends only on the row's angle, so build it once per row.
    geom = geometry_given_angle(jnp.deg2rad(angle))

    for j, lmax in enumerate(lmaxes):
        subfig = visualize_geometry(geom, lmax=lmax)

        for trace in subfig.data:
            # Hide colorbars. Only some trace types have a top-level
            # `showscale`; assigning it on others raises in plotly.
            if 'showscale' in trace:
                trace.showscale = False
            fig.add_trace(trace, row=i + 1, col=j + 1)

# Every scene shares the same styling, so apply it once to all of them.
hidden_axis = dict(
    title='',
    showticklabels=False,
    showgrid=False,
    zeroline=False,
    backgroundcolor=BACKGROUND,
    range=AXIS_RANGE,
)
fig.update_scenes(
    xaxis=hidden_axis,
    yaxis=hidden_axis,
    zaxis=hidden_axis,
    bgcolor=BACKGROUND,
    aspectmode='cube',
    # Camera on the +z axis, looking down onto the xy-plane.
    camera=dict(eye=dict(x=0, y=0, z=0.55)),
)

fig.update_layout(
    height=PANEL_SIZE_PX * len(angles),
    width=PANEL_SIZE_PX * len(lmaxes),
    showlegend=False,
    plot_bgcolor=BACKGROUND,
    paper_bgcolor=BACKGROUND,
    margin=dict(l=60, r=0, t=30, b=0),  # left margin leaves room for the angle labels
)

# Static export; presumably requires the kaleido package — confirm in the environment.
fig.write_image("bandwidth.png")