In [None]:
# Imports
from typing import *
import ase
import e3nn_jax as e3nn
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
import plotly.subplots

import ml_collections
import logging
import ase.io
import ase.visualize
import PIL
import numpy as np
import matplotlib.pyplot as plt
import os
import seaborn as sns

logging.getLogger().setLevel(logging.INFO)

import sys
sys.path.append("..")

import analyses.analysis as analysis
import analyses.visualize_atom_removals as visualize_atom_removals
import configs.nequip
import train as train
import input_pipeline_tf as input_pipeline_tf
import models as models

In [None]:
def spherical_harmonics_as_signals(l: int, res: Tuple[int, int] = (50, 49)) -> Iterable[e3nn.SphericalSignal]:
    """Yield the spherical harmonics of degree ``l`` as grid signals, one per order ``m``.

    Orders are produced in increasing order ``m = -l, ..., l``. Each signal is
    built by placing a single unit coefficient at position ``m + l`` of a
    degree-``l`` ``e3nn.IrrepsArray`` and sampling it with ``e3nn.to_s2grid``
    (``"soft"`` quadrature, ``p_val=1``, ``p_arg=-1``).

    Args:
        l: Degree of the spherical harmonics. A negative value yields nothing.
        res: Grid resolution forwarded positionally to ``e3nn.to_s2grid`` as
            ``(res_beta, res_alpha)``. Defaults to ``(50, 49)``.

    Yields:
        ``2 * l + 1`` ``e3nn.SphericalSignal`` objects, one per order ``m``.
    """
    if l < 0:
        return  # range(-l, l + 1) is empty for negative l: nothing to yield

    # Loop-invariant: the irrep taken as the last entry of e3nn.s2_irreps(l).
    irrep = e3nn.s2_irreps(l)[-1]
    # Row i of the identity is the one-hot coefficient vector for order m = i - l.
    for one_hot in jnp.eye(2 * l + 1):
        coeffs = e3nn.IrrepsArray(irrep, one_hot)
        yield e3nn.to_s2grid(coeffs, *res, quadrature="soft", p_val=1, p_arg=-1)

def plot_spherical_harmonics(l: int) -> go.Figure:
    """Plot the spherical harmonics of degree ``l`` as 3D surfaces.

    The figure has a single row of ``2 * l + 1`` surface subplots, one per order
    ``m = -l, ..., l`` from left to right. Each surface's radius is scaled by the
    signal amplitude and coloured with the ``plasma`` scale clamped to ``[-2, 2]``.
    Only the right-most subplot shows the colour bar. Subplot and figure titles
    are labelled with the actual degree ``l``.

    Args:
        l: Degree of the spherical harmonics to plot.

    Returns:
        The assembled plotly figure.
    """
    num_orders = 2 * l + 1
    fig = plotly.subplots.make_subplots(
        rows=1,
        cols=num_orders,
        specs=[[{"type": "surface"} for _ in range(num_orders)]],
        # Label each subplot with its own degree l and order m.
        subplot_titles=[
            r"$\huge{Y^{" + str(l) + r"}_{" + str(m) + r"}(\theta, \phi)}$"
            for m in range(-l, l + 1)
        ],
    )

    for col, sig in enumerate(spherical_harmonics_as_signals(l), start=1):
        fig.add_trace(
            go.Surface(
                sig.plotly_surface(scale_radius_by_amplitude=True),
                colorscale="plasma",
                cmax=2,
                cmin=-2,
                showscale=(col == num_orders),  # one colour bar, on the last subplot only
                colorbar=dict(lenmode="fraction", len=0.5, thickness=20),
            ),
            row=1,
            col=col,
        )

    # Give every 3D scene the same camera and bare axes (no tick labels, no titles).
    fig.update_scenes(
        camera=dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=3.25, y=3.25, z=3.25),
        ),
        xaxis_showticklabels=False,
        yaxis_showticklabels=False,
        zaxis_showticklabels=False,
        xaxis_title_text="",
        yaxis_title_text="",
        zaxis_title_text="",
    )

    fig.update_layout(
        title=r"$\huge{\text{Spherical Harmonics} \ Y^{" + str(l) + r"}(\theta, \phi)}$",
        title_x=0.5,
        font_family="Serif",
        font_size=32,
        margin=dict(l=10, r=10, b=0, t=40),
    )
    # Shift the subplot-title annotations down, closer to their surfaces.
    fig.update_annotations(yshift=-50, font_size=32)
    return fig

# Render the degree-2 spherical harmonics inline: the returned Figure is the cell's last expression.
plot_spherical_harmonics(2)

In [None]:
# Export a static PDF of the figure. The output filename is derived from the same
# degree used to build the figure, so the two cannot drift apart.
# NOTE: plotly's write_image needs a static-image export engine (e.g. kaleido) installed.
export_degree = 2
fig = plot_spherical_harmonics(l=export_degree)
fig.write_image(f"spherical_harmonics_l={export_degree}.pdf", width=1000, height=400)