In [None]:
from IPython.display import Javascript

# Returning this Javascript object as the cell's value makes the front end
# execute the snippet, which asks the (classic) Notebook UI to run every cell
# below this one.
_RUN_CELLS_BELOW = 'IPython.notebook.execute_cells_below()'
Javascript(_RUN_CELLS_BELOW)

In [None]:
%%html
<style>
/* Hide figure-widget chrome (message bar, buttons, dialog title bars)
   and the Out[...] prompts. */
.mpl-message,
.output_wrapper button.btn.btn-default,
.output_wrapper .ui-dialog-titlebar,
.output_prompt {
  display: none;
}
/* Keep rendered LaTeX from producing scrollbars. */
.output_latex {
  overflow: hidden;
}
</style>
<script>
// Global flag: true means input cells are currently visible.
code_show=true;
// Flip input-cell visibility; jQuery .toggle(flag) shows for true, hides for false.
function code_toggle() {
  $('div.input').toggle(!code_show);
  code_show = !code_show;
}
// Run once when the document is ready, so code starts hidden.
$(document).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Toggle Code"></form>

In [None]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
from sympy import *
# SymPy symbols used by the symbolic math in App.generate_functions:
#   x            - variable of the Legendre / Laguerre polynomials
#   r, theta, phi - spherical coordinates the wavefunction is expressed in
#   a_0          - Bohr radius (set to 1 before numeric evaluation)
#   k, alpha, sigma - placeholder degree/order symbols substituted later
#   Z            - nuclear charge symbol (App substitutes the widget value)
# n, l, m are declared here but App reads the widget values instead.
x, n, k, alpha, l, m, r, theta, phi, Z, a_0, sigma = symbols("x n k alpha l m r theta phi Z a_0 sigma")
# Interactive figure backend; figures are drawn into ipywidgets Output areas.
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as axes3d
from matplotlib.colors import LinearSegmentedColormap, Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.ticker import ScalarFormatter
import ipywidgets as wg
from IPython.display import display, Math

class quantumNumbers:
    """Labelled bounded-integer input for one quantum number (or for Z).

    Args:
        name: widget label (may contain LaTeX, e.g. r"$n$").
        minmaxinit: sequence ``[minimum, maximum, initial value]``.
    """

    def __init__(self, name, minmaxinit):
        self.name = name
        self.minmaxinit = minmaxinit
        lowest = minmaxinit[0]
        highest = minmaxinit[1]
        start = minmaxinit[2]
        self.ipywg = wg.BoundedIntText(
            value=start,
            min=lowest,
            max=highest,
            description=name,
            layout={"width": "125px"},
            style={"description_width": "50px"},
        )

    def get_wg(self):
        """Return the underlying ``BoundedIntText`` widget."""
        return self.ipywg

class printWg:
    """Holder for an empty ``wg.Output`` area (used to show formulas)."""

    def __init__(self):
        out_area = wg.Output()
        self.ipywg = out_area

    def get_wg(self):
        """Return the wrapped ``Output`` widget."""
        return self.ipywg
        
class App:
    """Interactive viewer for hydrogen-like atomic orbitals.

    The widgets select n, l, |m| (plus the +/- combination used to build a real
    spherical harmonic), the nuclear charge Z and the number of sample points.
    "Generate Plots" builds the wavefunction symbolically with SymPy and draws
    |psi|^2, the real spherical harmonic and the radial density r^2 R^2, and
    prints the LaTeX form of each expression. The assembled layout is
    ``self.GUI``.
    """

    # GUI ###################################################
    def __init__(self):
        # GUI  
        self.n = quantumNumbers(r"$n$",[1,10,1]).get_wg()
        self.l = quantumNumbers(r"$l$",[0,0,0]).get_wg()
        self.m = quantumNumbers(r"$\left| m \right|$",[0,0,0]).get_wg()
        self.Z = quantumNumbers(r"$Z$",[1,180,1]).get_wg()
        
        # Sign choice for the real combination Y_l^m +/- Y_l^-m; hidden while m == 0.
        self.pm = wg.RadioButtons(
            options = ["+","-"],
            description = r"$Y_l^m \pm Y_l^{-m}$",
            layout = wg.Layout(display="none",width="125px"))

        self.plotButton = wg.Button(
            description="Generate Plots",
            layout = {
                "width": "125px"
            },
            style = {
                "button_color": "lightgreen"
            }
        )
        
        self.nPoints = wg.IntSlider(
            orientation="vertical",
            value = 2000,
            min=500,
            max=10000,
            description="# of points",
            layout={
                "height": "170px"
            }
        )
        
        self.resetButton = wg.Button(
            description="Reset",
            layout = {
                "width": "100px"
            },
            style = {
                "button_color": "coral"
            }
        )
        
        self.topLBox = wg.VBox(
            [self.n,self.l,self.m,self.pm,self.Z,self.nPoints,
             self.plotButton,self.resetButton],
            layout={
                "width": "160px",
                "display": "flex",
                "flex_flow": "column wrap",
                "align_items": "center"
            }
        )
        
        self.plotPsi = wg.Output(
            layout={
                "height": "100%"
            }
        )

        self.topRBox = wg.HBox(
            [self.plotPsi],
            layout={
                "width": "100%",
                "display": "flex",
                "flex_flow": "column",
                "align_items": "center"
            }
        )
        
        self.topBox = wg.HBox([self.topLBox,self.topRBox])
        
        self.plotY,self.plotR = wg.Output(),wg.Output()
        self.botLeftBox = wg.HBox(
            [self.plotY],
            layout={
                "width": "50%",
                "display": "flex",
                "flex_flow": "column",
                "align_items": "center"
            }
        )
        self.botRightBox = wg.HBox(
            [self.plotR],
            layout={
                "width": "50%",
                "display": "flex",
                "flex_flow": "column",
                "align_items": "center"
            }
        )
        
        
        self.botBox = wg.VBox(
            [self.botLeftBox,self.botRightBox],
            layout={
                "width": "100%",
                "display": "flex",
                "flex_flow": "row",
                "align_items": "center"
            }
        )
        
        # Output areas for the LaTeX form of psi, Y and R.
        self.psiOut,self.YOut,self.ROut = printWg().get_wg(),printWg().get_wg(),printWg().get_wg()
        self.tBox = wg.VBox(
            [self.psiOut,self.YOut,self.ROut],
            layout={
                "width": "100%",
                "display": "flex",
                "flex_flow": "column",
                "align_items": "center"
            }
        )
        
        self.ghLogo = wg.HTML(
            value="<div><a href=\"https://github.com/Noah-Burns/Orbitals\"><img src=\"data/ghmark.png\"></a></div>"
        )
        self.docNB = wg.HTML(
            value="<div><a href=\"https://mybinder.org/v2/gh/Noah-Burns/Orbitals/main?filepath=Orbitals%20Documented.ipynb\">Documented notebook</a></div>"
        )
        self.fBox = wg.VBox(
            [self.ghLogo,self.docNB],
            layout={
                "width": "100%",
                "display": "flex",
                "flex_flow": "column",
                "align_items": "center"
            }
        )
        
        # Enforce 0 <= l <= n-1; BoundedIntText clamps l when its max shrinks.
        def update_l_bounds(change):
            self.l.max = change["new"]-1
        self.n.observe(update_l_bounds, "value")
        
        # Enforce |m| <= l.
        def update_m_bounds(change):
            self.m.max = change["new"]
        self.l.observe(update_m_bounds, "value")
        
        # The +/- choice only matters for m != 0, so show it only then.
        def update_pm(change):
            if change["new"] > 0 and change["old"] == 0:
                self.pm.layout.display = "flex" # show radio buttons
            elif change["new"] == 0 and change["old"] > 0:
                self.pm.layout.display = "none" # hide radio buttons
        self.m.observe(update_pm, "value")
        
        
        self.plotButton.on_click(self.plotter)
        self.resetButton.on_click(self.reset)
        self.GUI = wg.VBox([self.topBox,self.botBox,self.tBox,self.fBox])
    
    # Math #############################################################
    
    def generate_functions(self):
        """Build SymPy expressions for the state selected in the widgets.

        Returns:
            list: ``[psi, Y, R]``. ``Y`` is the real spherical harmonic in
            (theta, phi), ``R`` is the radial function in r, and
            ``psi = Y * R``. The widget values of n, l, |m> and Z are inserted
            as numbers; ``a_0`` stays symbolic.
        """
        # Rodrigues' formula for the Legendre polynomial P_k(x).
        Legendre = 1/(2**k*factorial(k)) * ((x**2-1)**k).diff((x,k))
        # Associated Legendre function of degree sigma, order |alpha|
        # (no Condon-Shortley phase factor).
        ALegendre = (1-x**2)**(abs(alpha)/2) * (Legendre.subs(k,sigma)).diff((x,abs(alpha)))
        
        def P(l,m):
            """Associated Legendre function P_l^|m| evaluated at cos(theta)."""
            expr = (ALegendre.subs([(sigma,l),(alpha,m)]).doit()
                    .subs(x,cos(theta)).simplify())
            # sqrt(sin^2) = |sin| = sin for theta in [0, pi]. Build the pattern
            # with SymPy's exact 1/2 exponent: a Float 0.5 exponent need not
            # match it, and the exact form also lets odd powers be rewritten.
            return expr.subs(sqrt(sin(theta)**2), sin(theta))
        
        def Y(l,m):
            """Complex spherical harmonic built from P(l, m)."""
            return (sqrt( ((2*l+1)/(4*pi)) * factorial(l-abs(m))/factorial(l+abs(m)) ) * exp(I*m*phi)) * P(l,m)
        
        def real_Y(l,mag_m,pm):
            """Real spherical harmonic: Y_l^0, or (Y_l^m +/- Y_l^-m)/sqrt(2) (divided by i for "-")."""
            if mag_m == 0:
                return Y(l,mag_m)
            if pm[0] == "-":
                return simplify((1/(sqrt(2)*I)) * (Y(l,mag_m) - Y(l,-mag_m)))
            return simplify((1/sqrt(2)) * (Y(l,mag_m) + Y(l,-mag_m))) 
        
        def Laguerre(k,alpha):
            """Generalized Laguerre polynomial L_k^alpha(x) via the three-term recurrence.

            Iterative, so the cost is linear in k. A naive double recursion is
            exponential in k.
            """
            if k <= 0:
                return 1
            prev, curr = 1, 1 + alpha - x
            for j in range(2, k + 1):
                prev, curr = curr, expand(((-x + 2*j + alpha - 1)*curr - (j + alpha - 1)*prev)/j)
            return curr
            
        def subsLaguerre(n,l,Z):
            """L_{n-l-1}^{2l+1} evaluated at 2 Z r / (n a_0)."""
            res = Laguerre(n-l-1,2*l+1)
            if res == 1:
                # Constant polynomial: a plain int has no .subs().
                return res
            return res.subs(x,2*Z*r/(n*a_0))
            
        def R(n,l,Z):
            """Normalized hydrogen-like radial function R_{n,l}(r) for nuclear charge Z."""
            # Exact rational exponent l + 3/2. Floats here would only be
            # approximately recovered by nsimplify and would put a_0**-1.5
            # into the printed formula.
            half_int = l + Rational(3, 2)
            return (Rational(2, n)**half_int * (Z/a_0)**half_int
                    * sqrt(factorial(n-l-1)/(2*n*factorial(n+l)))
                    * subsLaguerre(n,l,Z) * r**l * exp(-Z*r/(n*a_0)))
        
        # Build Y and R once (simplify is expensive) and reuse them for psi.
        angular = real_Y(self.l.value, self.m.value, self.pm.value)
        radial = R(self.n.value, self.l.value, self.Z.value)
        return [angular*radial, angular, radial]
        
    # Plots #############################################################

    def plotter(self,*args):
        """Regenerate all plots and formula read-outs (button callback; args ignored)."""
        funcs = self.generate_functions()   
        plt.close("all")
        self._plot_psi(funcs[0])
        self._plot_Y(funcs[1])
        self._plot_R(funcs[2])
        # Rendered after (not inside) the plot Output contexts.
        self.update_outs(funcs)

    @staticmethod
    def _randrange(nPoints, vmin, vmax):
        """Return nPoints uniform random samples from [vmin, vmax)."""
        return (vmax - vmin)*np.random.rand(nPoints) + vmin

    def _plot_psi(self, psi):
        """3-D scatter of |psi|^2 at random (r, theta, phi) samples."""
        nPoints = self.nPoints.value
        # Sampling radius in units of a_0. The n-dependence is a heuristic.
        # The whole wavefunction contracts by 1/Z, so the range does too.
        r_Range = (10+20*(self.n.value-1))/2/self.Z.value

        self.plotPsi.clear_output()
        with self.plotPsi:
            f = lambdify([r,theta,phi], psi.subs(a_0,1), "numpy")

            fig = plt.figure(figsize=(5,4), num = " ", facecolor="grey")
            ax = fig.add_subplot(1,1,1, projection='3d', facecolor="grey")

            # inferno colour map with alpha ramping 0 -> 1, so low-density points
            # fade out. Use the Colormap object directly instead of registering
            # it by name: re-registering an existing name on every click is
            # rejected by newer Matplotlib, and register_cmap is deprecated.
            nColors = 256
            cArray = plt.get_cmap("inferno")(range(nColors))
            cArray[:,-1] = np.linspace(0.0,1.0,nColors)
            cmap = LinearSegmentedColormap.from_list(
                name="inferno_alpha",colors=cArray)

            rs = self._randrange(nPoints, 0, r_Range)
            thetas = self._randrange(nPoints, 0, np.pi)
            phis = self._randrange(nPoints, 0, 2*np.pi)

            xs = rs * np.sin(thetas) * np.cos(phis)
            ys = rs * np.sin(thetas) * np.sin(phis)
            zs = rs * np.cos(thetas)

            # |psi|^2 as labelled. abs() also discards any complex dtype that
            # remains in the lambdified expression. Evaluated once and shared by
            # the scatter and the colorbar.
            density = np.broadcast_to(np.abs(f(rs,thetas,phis))**2, rs.shape)
            # One norm for both, so the colorbar matches the point colours.
            norm = Normalize(0, density.max())

            ax.scatter(
                xs, ys, zs, marker="o",
                c=density, cmap=cmap, norm=norm
            )
            
            fmt = ScalarFormatter(useMathText=True)
            fmt.set_powerlimits((0, 0))
            fig.colorbar(
                ScalarMappable(cmap=cmap,norm=norm),
                ax = ax,
                orientation = "vertical", pad = .15,
                label = r"$\left|\psi\right|^2$ $(a_0^{-3})$",
                format = fmt
            )
            
            lDesignation = "spdfghiklm"
            if not self.m.value:
                pmv = ""
            else:
                pmv = f"^{self.pm.value}"
            ax.set_title(f"$f(r,\\theta,\\phi)=\\left|\\psi_"
                         f'{{{self.n.value},{self.l.value},{self.m.value}{pmv}}}'
                         f"\\right|^2$ for $Z = {self.Z.value}$\n"
                         f"{self.n.value}{lDesignation[self.l.value]} Atomic Orbital",
                         y = 1.0,
                         pad = -18
            )
            ax.set_xlim(-r_Range,r_Range)
            ax.set_ylim(-r_Range,r_Range)
            ax.set_zlim(-r_Range,r_Range)
            ax.set_box_aspect((1,1,1))
            ax.view_init(5,-45)

            ax.set_xlabel(r'$x$ $(a_0)$')
            ax.set_ylabel(r'$y$ $(a_0)$')
            ax.set_zlabel(r'$z$ $(a_0)$')
            plt.tight_layout()

    def _plot_Y(self, Y_expr):
        """Polar surface plot: distance from the origin is |Y(theta, phi)|."""
        self.plotY.clear_output()
        with self.plotY:
            f = lambdify([theta,phi],Y_expr,"numpy")
    
            t, p = np.linspace(0, np.pi, 100), np.linspace(0, 2*np.pi, 100)
            THETA, PHI = np.meshgrid(t, p)
            # For l = 0, Y is constant and f returns a scalar; broadcasting below
            # still produces full grids.
            rho = abs(f(THETA,PHI)) 
            surf_x = rho * np.sin(THETA) * np.cos(PHI)
            surf_y = rho * np.sin(THETA) * np.sin(PHI)
            surf_z = rho * np.cos(THETA)
        
            fig = plt.figure(figsize=(4,3), num = "   ")
            ax = fig.add_subplot(1,1,1, projection='3d')
            ax.plot_surface(
                surf_x, surf_y, surf_z, rstride=1, cstride=1, cmap=plt.get_cmap('jet'),
                linewidth=0, antialiased=False, alpha=0.5)
            
            if self.m.value == 0:
                ax.set_title(f"Real Spherical Harmonic\n$Y_{{{self.l.value}}}^{{0}}$")
            else:
                ax.set_title(
                    f"Real Spherical Harmonic\n$\\frac{{1}}{{\\sqrt{{2}}"
                    f"{'i' if self.pm.value == '-' else ''}"
                    f"}}\\left(Y_{{{self.l.value}}}^{{{self.m.value}}}"
                    f"{self.pm.value}"
                    f"Y_{{{self.l.value}}}^{{-{self.m.value}}}\\right)$"
                 )
            
            rho_max = np.amax(rho)
            ax.set_xlim(-rho_max,rho_max)
            ax.set_ylim(-rho_max,rho_max)
            ax.set_zlim(-rho_max,rho_max)
            ax.view_init(5,-45)
            ax.set_box_aspect((1,1,1))

            ax.set_xlabel(r'$x$')
            ax.set_ylabel(r'$y$')
            ax.set_zlabel(r'$z$')
            plt.tight_layout()

    def _plot_R(self, R_expr):
        """Line plot of the radial density r^2 R(r)^2 with a_0 = 1."""
        self.plotR.clear_output()
        with self.plotR:
            # Z is already a number in R_expr; only the Bohr radius needs a value.
            f = lambdify(r, r**2*(R_expr.subs(a_0,1))**2, "numpy")
    
            # Plot range in units of a_0: heuristic in n, contracted by 1/Z.
            r_Range = (10+20*(self.n.value-1))/self.Z.value

            radius = np.linspace(0, r_Range, 1000)
            R_val = f(radius)

            fig = plt.figure(figsize=(4,3), num = "  ")
            ax = fig.add_subplot(1,1,1)
            ax.plot(radius,R_val)
            # Shade the classically forbidden region. In atomic units
            # V(r) = -Z/r exceeds E_n = -Z^2/(2 n^2) for r > 2 n^2 / Z.
            # Skip the shading when the turning point lies beyond the plot;
            # otherwise the span would be drawn backwards and stretch the axis.
            r_turn = 2*self.n.value**2/self.Z.value
            if r_turn < r_Range:
                ax.axvspan(r_turn, r_Range, facecolor='gray', alpha=0.15)

            ax.set_title(f"$f(r) = r^2R_{{{self.n.value},{self.l.value}}}^2$ for $Z = {self.Z.value}$\n")
            ax.set_xlabel(r'$r/a_0$')
            ax.set_ylabel(r'$f(r)$ $(a_0^{{-1}})$')
            plt.tight_layout()
    
    def update_outs(self,funcs):
        """Show ``psi``, ``Y`` and ``R`` (from ``funcs``) as LaTeX in the text Output widgets."""
        n_val, l_val, m_val, pm_val = (self.n.value, self.l.value,
                                       self.m.value, self.pm.value)
        if not m_val:
            pmv = ""
            Y_prescript = f"Y_{l_val}^0(\\theta,\\phi)"
        else:
            pmv = f"^{pm_val}"
            istr = "" if pm_val == "+" else "i"
            Y_prescript = (
                f"Y_{l_val}^{{{m_val}{pmv}}}(\\theta,\\phi)="
                f"\\frac{{1}}{{\\sqrt{{2}}{istr}}}"
                f"(Y_{l_val}^{m_val}{pm_val}"
                f"Y_{l_val}^{{-{m_val}}})"
            )

        prescripts = [
            f"\\psi_{{{n_val},{l_val},{m_val}{pmv}}}(r,\\theta,\\phi)",
            Y_prescript,
            f"R_{{{n_val},{l_val}}}(r)",
        ]
        for opw, prescript, expr in zip([self.psiOut, self.YOut, self.ROut],
                                        prescripts, funcs):
            opw.clear_output()
            with opw:
                display(Math(f"${prescript} = {latex(expr)}$"))
    
    def reset(self,*args):
        """Clear every output and restore the default widget values (button callback)."""
        for opw in [self.plotPsi, self.plotY, self.plotR,
                    self.psiOut, self.YOut, self.ROut]:
            opw.clear_output()
        # Setting n = 1 cascades through the observers: l and |m| clamp to 0,
        # and the +/- selector is hidden.
        self.n.value = 1
        self.Z.value = 1
        self.pm.value = "+"
        self.nPoints.value = 2000

# Build the interface; the VBox as the cell's last expression displays it.
runApp = App()
runApp.GUI