# In [None]:
import numpy as np
from numpy.polynomial.hermite import hermval
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tkinter import *
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Module-level handles shared between the plotting code and the GUI callbacks.
# All of them start out as None: plot_graphs() fills in the canvas/axes pairs
# the first time it runs, while m2_entry / n2_entry are only read (never
# assigned) by sum_intensity().
canvas_3d = ax1 = None
canvas_contour = ax2 = None
m2_entry = n2_entry = None

class LaserBeam:
    """Hermite-Gaussian beam whose waist is derived from `length` and `wavelength`.

    The waist radius is ``sqrt(length * wavelength / pi)`` and the radius at an
    axial position ``z`` is ``w_0 * sqrt(1 + (z / length)**2)``, i.e. `length`
    takes the place of the Rayleigh range in the usual Gaussian-beam formula.
    All lengths (`wavelength`, `length`, `x`, `y`, `z`) must share one unit.
    """

    def __init__(self, wavelength, beam_width, length):
        """Store the beam parameters.

        Args:
            wavelength: Wavelength of the light.
            beam_width: Kept on the instance for callers; no method of this
                class reads it (the waist comes from `length` and
                `wavelength` instead).
            length: Characteristic length used for the waist and for the
                axial scaling in `wz`; `wz` divides by it.
        """
        self.wavelength = wavelength
        self.beam_width = beam_width
        self.length = length

    def w_0(self):
        """Return the waist radius ``sqrt(length * wavelength / pi)``."""
        return np.sqrt(self.length * self.wavelength / (np.pi))

    def wz(self, z):
        """Return the beam radius at axial position `z` (scalar or array)."""
        return self.w_0() * np.sqrt(1 + (z / self.length)**2)

    def hermite(self, x, n, z=0):
        """Evaluate the order-`n` physicists' Hermite polynomial for the beam.

        Args:
            x: Transverse coordinate(s), scalar or NumPy array.
            n: Non-negative integer polynomial order.
            z: Axial position at which the beam radius is taken. Defaults to
                0 (the waist), which is what this method always used before
                the parameter existed.

        Returns:
            ``H_n(sqrt(2) * x / w(z))`` with the same shape as `x`.
        """
        w_z = self.wz(z)
        # A coefficient vector that is zero except for index n makes hermval
        # return the single basis polynomial H_n.
        coeffs = np.zeros(n + 1)
        coeffs[n] = 1
        return hermval(x / (w_z / np.sqrt(2)), coeffs)

    def intensity(self, x, y, I_0, m, n, z=0):
        """Return the intensity of the (m, n) Hermite-Gaussian mode.

        Args:
            x, y: Transverse coordinates (scalars or broadcastable arrays).
            I_0: Intensity scale factor.
            m, n: Non-negative integer mode orders along x and y.
            z: Axial position of the evaluation plane. Defaults to 0 (the
                waist), matching the behaviour before the parameter existed.

        Returns:
            ``I_0 * (w_0 / w(z))**2 * (H_m * H_n)**2 * exp(-2 r**2 / w(z)**2)``
            broadcast over `x` and `y`.
        """
        w_0 = self.w_0()
        w_z = self.wz(z)
        Hm_x = self.hermite(x, m, z)
        Hn_y = self.hermite(y, n, z)
        return I_0 * (w_0 / w_z)**2 * (Hm_x * Hn_y)**2 * np.exp(-2 * (x**2 + y**2) / w_z**2)
    
def update_plot(event=None):
    """Redraw both plots from the current slider, checkbox and entry values.

    `event` is accepted so the function can be used directly as a widget
    callback; it is never read.
    """
    first_mode = (m_slider.get(), n_slider.get())
    summing = sum_var.get() == 1

    # The second mode only comes from its own sliders while summing is on;
    # otherwise it simply mirrors the first mode.
    second_mode = (m2_slider.get(), n2_slider.get()) if summing else first_mode

    beam_params = [
        float(field.get())
        for field in (wavelength_entry, beam_width_entry, length_entry)
    ]
    plot_graphs(*beam_params, *first_mode, *second_mode, summing)

def toggle_m2_n2_sliders():
    """Show or hide the m2/n2 sliders to match the checkbox, then redraw."""
    second_mode_sliders = (m2_slider, n2_slider)
    if sum_var.get() == 1:
        for slider in second_mode_sliders:
            slider.pack(fill=X)
    else:
        for slider in second_mode_sliders:
            slider.pack_forget()
    # Either way the plot has to reflect the new checkbox state right away.
    update_plot()
    
def sum_intensity(event=None):
    """Plot the sum of the (m, n) profile and a second (m2, n2) profile.

    The first mode is read from the m/n sliders. The second mode is read from
    the `m2_entry` / `n2_entry` widgets when they have been created, and from
    the m2/n2 sliders otherwise. Beam parameters come from the three entry
    fields.

    `event` is accepted so the function can be used as a widget callback; it
    is never read.

    Raises:
        ValueError: If an entry field does not hold a valid number.
    """
    # m2_entry / n2_entry are module-level placeholders that may still be
    # None; calling .get() on them would raise AttributeError, so fall back
    # to the sliders that carry the same information.
    m2_widget = m2_slider if m2_entry is None else m2_entry
    n2_widget = n2_slider if n2_entry is None else n2_entry

    plot_graphs(
        float(wavelength_entry.get()),
        float(beam_width_entry.get()),
        float(length_entry.get()),
        m_slider.get(),
        n_slider.get(),
        int(m2_widget.get()),
        int(n2_widget.get()),
        True,  # sum_profiles is a required argument; this callback always sums
    )

def plot_graphs(wavelength, beam_width, length, m1, n1, m2, n2, sum_profiles):
    """Compute the beam intensity on a fixed grid and draw it twice.

    The (m1, n1) profile is evaluated on a 1001 x 1001 grid spanning
    -1e-3 .. 1e-3 on both axes; when `sum_profiles` is true the (m2, n2)
    profile is added to it. The result is shown as a 3-D surface and as a
    filled contour plot. The two canvases are created inside `plot_frame` on
    the first call and reused (axes cleared) on later calls.
    """
    global canvas_3d, ax1, canvas_contour, ax2

    beam = LaserBeam(wavelength, beam_width, length)

    # Both axes use the same sample points, so one vector serves for x and y.
    axis = np.linspace(-1e-3, 1e-3, 1001)
    X, Y = np.meshgrid(axis, axis)

    Z = beam.intensity(X, Y, 1, m1, n1)
    if sum_profiles:
        Z += beam.intensity(X, Y, 1, m2, n2)

    canvases_ready = canvas_3d is not None and canvas_contour is not None
    if canvases_ready:
        # Reuse the existing axes; just wipe what was drawn last time.
        ax1.cla()
        ax2.cla()
    else:
        # First call: build the 3-D figure, then the contour figure, packing
        # them into plot_frame in that order.
        surface_figure = Figure(dpi=100)
        ax1 = surface_figure.add_subplot(111, projection='3d')
        canvas_3d = FigureCanvasTkAgg(surface_figure, master=plot_frame)
        canvas_3d.get_tk_widget().pack(fill=BOTH, expand=True)

        contour_figure = Figure(dpi=100)
        ax2 = contour_figure.add_subplot(111)
        canvas_contour = FigureCanvasTkAgg(contour_figure, master=plot_frame)
        canvas_contour.get_tk_widget().pack(fill=BOTH, expand=True)

    ax1.plot_surface(X, Y, Z, cmap='viridis')
    ax2.contourf(X, Y, Z, cmap='viridis')

    for canvas in (canvas_3d, canvas_contour):
        canvas.draw()


# GUI application code: a column of controls on the left, the plots on the right.
root = Tk()
root.title("Laser Beam Intensity Visualization")

input_frame = Frame(root)
plot_frame = Frame(root)
input_frame.pack(side=LEFT, fill=Y)
plot_frame.pack(side=RIGHT, fill=BOTH, expand=True)


def _add_labelled_entry(caption, initial_text):
    """Pack a caption label and a pre-filled entry into input_frame; return the entry."""
    Label(input_frame, text=caption).pack()
    field = Entry(input_frame)
    field.pack()
    field.insert(0, initial_text)
    return field


def _make_order_slider(caption):
    """Create (without packing) a horizontal 0-15 slider that calls update_plot."""
    return Scale(
        input_frame,
        from_=0,
        to=15,
        orient=HORIZONTAL,
        label=caption,
        command=update_plot,
    )


# Beam parameters typed in by the user, each with a default value.
wavelength_entry = _add_labelled_entry("Wavelength (m)", "632.8e-9")
beam_width_entry = _add_labelled_entry("Beam Width (m)", "3e-6")
length_entry = _add_labelled_entry("Length (m)", "1")

# Mode orders of the first profile.
m_slider = _make_order_slider("m (Order)")
m_slider.pack(fill=X)
m_slider.set(0)

n_slider = _make_order_slider("n (Order)")
n_slider.pack(fill=X)
n_slider.set(1)

# Mode orders of the second profile: created now, packed only when the
# "Sum Intensities" checkbox is ticked.
m2_slider = _make_order_slider("m2 (Order)")
n2_slider = _make_order_slider("n2 (Order)")

# Checkbox (backed by an IntVar) that switches summing on and off.
sum_var = IntVar()
sum_check = Checkbutton(
    input_frame,
    text="Sum Intensities",
    variable=sum_var,
    command=toggle_m2_n2_sliders,
)
sum_check.pack()

root.mainloop()
