#### Crystal Compute (Single Atom Basis), Random Crystal Structure Generator

In [1]:
import sys, re, os, time, psutil
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import norm
from scipy.stats import multivariate_normal
from scipy.spatial.transform import Rotation as R
from scipy.spatial.distance import cdist
import plotly.graph_objs as go
import ipywidgets as widgets
from IPython.display import display, clear_output
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

class ComputeCrystal:
    def __init__(self, lattice_params):
        self.a, self.b, self.c, self.alpha, self.beta, self.gamma = lattice_params

        # - Compute Real Space Cartesian Vectors
        self.a1, self.a2, self.a3 = self.compute_realspacevecs()
        self.a1, self.a2, self.a3 = self.clean_data([self.a1, self.a2, self.a3])

        # - Compute the Volume of a Unit Cell
        self.volume = self.compute_volume()

        # - Compute Reciprocal Space Vectors
        self.b1, self.b2, self.b3 = self.compute_recipvecs()
        self.b1, self.b2, self.b3 = self.clean_data([self.b1, self.b2, self.b3])

    def initialize(self, Mhkl):
        self.Mhkl = Mhkl

        # - Computed (hkl) Coordinate Positions (Lattice Points)
        self.hkl_coords = self.compute_coords()
        
        # Create DataFrame to store (qx, qy, qz) values and corresponding (h, k, l) indices
        data = {
            'qx': self.hkl_coords[:, 0],
            'qy': self.hkl_coords[:, 1],
            'qz': self.hkl_coords[:, 2],
            'hkl': [str(tuple(hkl)) for hkl in self.hkl_indices]
        }
        self.df_coords = pd.DataFrame(data)

        # Group by coordinates and concatenate all corresponding (h, k, l) indices
        self.df_coords_grouped = self.df_coords.groupby(['qx', 'qy', 'qz'])['hkl'].apply(lambda x: ', '.join(x)).reset_index()

    # - Compute the real-space orthogonal basis vectors.
    def compute_realspacevecs(self):
        alpha = np.radians(self.alpha)
        beta = np.radians(self.beta)
        gamma = np.radians(self.gamma)

        a1 = np.array([self.a, 0, 0])
        a2 = np.array([self.b * np.cos(gamma), self.b * np.sin(gamma), 0])
        z_value = 1 - np.cos(beta)**2 - ((np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma))**2
        z_value = max(z_value, 0)
        a3 = np.array([self.c * np.cos(beta),
                       self.c * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma),
                       self.c * np.sqrt(z_value)])

        return a1, a2, a3

    def compute_volume(self):
        return np.dot(self.a1, np.cross(self.a2, self.a3))

    def compute_recipvecs(self):
        b1 = 2 * np.pi * np.cross(self.a2, self.a3) / self.volume
        b2 = 2 * np.pi * np.cross(self.a3, self.a1) / self.volume
        b3 = 2 * np.pi * np.cross(self.a1, self.a2) / self.volume

        return b1, b2, b3

    # def compute_coords(self):
    #     Mhkl = int(self.Mhkl)
    #     hkl_coords = []
    #     hkl_indices = []

    #     for h in range(Mhkl):
    #         for k in range(Mhkl):
    #             for l in range(Mhkl):
    #                 qx = h * self.b1[0] + k * self.b2[0] + l * self.b3[0]
    #                 qy = h * self.b1[1] + k * self.b2[1] + l * self.b3[1]
    #                 qz = h * self.b1[2] + k * self.b2[2] + l * self.b3[2]
    #                 hkl_coords.append((qx, qy, qz))  # Save only the (qx, qy, qz) coordinates
    #                 hkl_indices.append((h, k, l))  # Save the hkl indices separately

    #     self.hkl_indices = np.array(hkl_indices)  # Store the hkl indices as a new attribute
    #     return np.array(hkl_coords)

    def compute_coords(self):
        hkl_coords = []
        hkl_indices = []

        # Iterate through the range of -Mhkl to +Mhkl for each h, k, l
        for h in range(-self.Mhkl, self.Mhkl + 1):
            for k in range(-self.Mhkl, self.Mhkl + 1):
                for l in range(-self.Mhkl, self.Mhkl + 1):
                    qx = h * self.b1[0] + k * self.b2[0] + l * self.b3[0]
                    qy = h * self.b1[1] + k * self.b2[1] + l * self.b3[1]
                    qz = h * self.b1[2] + k * self.b2[2] + l * self.b3[2]
                    hkl_coords.append((qx, qy, qz))  # Save only the (qx, qy, qz) coordinates
                    hkl_indices.append((h, k, l))    # Save the hkl indices separately

        self.hkl_indices = np.array(hkl_indices)  # Store the hkl indices as a new attribute
        return np.array(hkl_coords)

    def clean_data(self, data):
        return [np.where(np.isclose(vec, 0, atol=1e-10), 0, vec) for vec in data]

class GenCrystal(ComputeCrystal):
    def __init__(self, lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        super().__init__(lattice_params)
        self.initialize(Mhkl)
        self.sigma_r = sigma_r
        self.sigma_theta = sigma_theta
        self.sigma_phi = sigma_phi
        self.m = m
        self.lattice_params = lattice_params

class RandomCrystal(ComputeCrystal):
    def __init__(self, lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        self.crystal_systems = {
            'triclinic': self.triclinic,
            'monoclinic': self.monoclinic,
            'orthorhombic': self.orthorhombic,
            'tetragonal': self.tetragonal,
            'trigonal': self.trigonal,
            'hexagonal': self.hexagonal,
            'cubic': self.cubic
        }

        # Select a random crystal system
        system = np.random.choice(list(self.crystal_systems.keys()))

        # Generate lattice parameters for the selected system
        lattice_params = self.crystal_systems[system]()
        
        print(f"Selected Crystal System: {system}")

        super().__init__(lattice_params)
        self.initialize(Mhkl)
        self.sigma_r = sigma_r
        self.sigma_theta = sigma_theta
        self.sigma_phi = sigma_phi
        self.m = m
        self.lattice_params = lattice_params

    @staticmethod
    def triclinic():
        a, b, c = np.round(np.random.uniform(1, 10, 3), 2)
        alpha, beta, gamma = np.round(np.random.uniform(20, 160, 3), 2)
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def monoclinic():
        a, b, c = np.round(np.random.uniform(1, 10, 3), 2)
        alpha, gamma = 90, 90
        beta = np.round(np.random.uniform(20, 160), 2)
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def orthorhombic():
        a, b, c = np.round(np.random.uniform(1, 10, 3), 2)
        alpha, beta, gamma = 90, 90, 90
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def tetragonal():
        a = np.round(np.random.uniform(1, 10), 2)
        c = np.round(np.random.uniform(1, 10), 2)
        alpha, beta, gamma = 90, 90, 90
        return a, a, c, alpha, beta, gamma

    @staticmethod
    def trigonal():
        a = np.round(np.random.uniform(1, 10), 2)
        alpha = np.round(np.random.uniform(20, 140), 2)
        return a, a, a, alpha, alpha, alpha

    @staticmethod
    def hexagonal():
        a = np.round(np.random.uniform(1, 10), 2)
        c = np.round(np.random.uniform(1, 10), 2)
        return a, a, c, 90, 90, 120

    @staticmethod
    def cubic():
        a = np.round(np.random.uniform(1, 10), 2)
        return a, a, a, 90, 90, 90

class PlotCrystal():
    def __init__(self, crystal, m):
        self.crystal = crystal
        self.m = m
        # self.crystal.generate_hkl_coords()  # generate hkl_coords
        self.create_pixel_space()

    # - Estimate the time order complexity for smearing a crystal structure.
    def estimate_time_complexity(self):
        """Estimate the time complexity based on system specs and current usage."""

        # Get the number of CPUs available
        num_cpus = os.cpu_count()

        # Get the CPU usage
        cpu_usage = psutil.cpu_percent()

        # Get the available memory
        avail_mem = psutil.virtual_memory().available

        # Estimate the time complexity based on the number of CPUs, CPU usage, and available memory
        # This is a very rough estimate and is unlikely to be accurate
        time_complexity = (self.Mhkl**3 * self.m**3) / (num_cpus * (1 - cpu_usage/100) * avail_mem)

        return time_complexity
    
    # - Generate Ewald Sphere Pixel Space
    def create_pixel_space(self):
        # Create the 3D pixel space with size m x m x m
        self.pixel_space = np.zeros((self.m, self.m, self.m))

    # - Convert Cartesian (qx, qy, qz) Coordinates to Spherical Coordinates
    def cart_to_sph(self, qx, qy, qz):
        qr = np.sqrt(qx**2 + qy**2 + qz**2)
        qtheta = np.arctan2(np.sqrt(qx**2 + qy**2), qz)
        qphi = np.arctan2(qy, qx)
        return qr, qtheta, qphi

    def plot_panel(self):
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))

        qx_values = self.hkl_coords[:, 0]
        qy_values = self.hkl_coords[:, 1]
        qz_values = self.hkl_coords[:, 2]

        axs[0].scatter(qx_values, qz_values, color='blue')
        axs[1].scatter(qy_values, qz_values, color='green')
        axs[2].scatter(np.sqrt(qx_values**2 + qy_values**2), qz_values, color='red')

        axs[0].grid()
        axs[1].grid()
        axs[2].grid()

        axs[0].set_xlabel(r'$q_x \, (\AA^{-1})$')
        axs[0].set_ylabel(r'$q_z \, (\AA^{-1})$')
        axs[1].set_xlabel(r'$q_y \, (\AA^{-1})$')
        axs[1].set_ylabel(r'$q_z \, (\AA^{-1})$')
        axs[2].set_xlabel(r'$q_{xy} \, (\AA^{-1})$')
        axs[2].set_ylabel(r'$q_z \, (\AA^{-1})$')

        plt.tight_layout()
        plt.show()

    def plot_qx_qz(self):
        plt.figure()
        qx_values = self.hkl_coords[:, 0]
        qz_values = self.hkl_coords[:, 2]
        plt.scatter(qx_values, qz_values, color='blue')
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qy_qz(self):
        plt.figure()
        qy_values = self.hkl_coords[:, 1]
        qz_values = self.hkl_coords[:, 2]
        plt.scatter(qy_values, qz_values, color='green')
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qxy_qz(self):
        plt.figure()
        qx_values = self.hkl_coords[:, 0]
        qy_values = self.hkl_coords[:, 1]
        qz_values = self.hkl_coords[:, 2]
        plt.scatter(np.sqrt(qx_values**2 + qy_values**2), qz_values, color='red')
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def gen_gaussian(self, q_vec):
        # Define the variance for each direction in 3D
        variance = [self.sigma_r**2, self.sigma_theta**2, self.sigma_phi**2]
        
        # Convert Cartesian coordinates to spherical
        qr, qtheta, qphi = self.cart_to_sph(*q_vec)
        
        # Create a grid of coordinates spanning the pixel space
        grid = np.indices((self.m, self.m, self.m)).reshape(3,-1).T

        # Generate a Gaussian centered at each peak
        gauss = multivariate_normal.pdf(grid, mean=[qr, qtheta, qphi], cov=variance)

        return gauss.reshape((self.m, self.m, self.m))

    # - Gaussian Convolution Methods + Plotting
    def gaussian(self, x, mu, sigma):
        return norm.pdf(x, mu, sigma)

    def convoluted_gaussian(self, q_vec):
        qx, qy, qz = q_vec
        qr, qtheta, qphi = self.cart_to_sph(qx, qy, qz)
        gauss_r = self.gaussian(qr, qr, self.sigma_r)
        gauss_theta = self.gaussian(qtheta, qtheta, self.sigma_theta)
        gauss_phi = self.gaussian(qphi, qphi, self.sigma_phi)
        return gauss_r * gauss_theta * gauss_phi
    
    def smear_peaks(self):
        for hkl, q_vec in self.hkl_coords:
            gauss = self.gen_gaussian(q_vec)
            self.pixel_space += gauss

    def plot_qx_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qx, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qy_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qy, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qxy_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_panel_convolution(self):
        plt.figure(figsize=(15,5))
        
        plt.subplot(131)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qx, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        
        plt.subplot(132)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qy, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.subplot(133)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.tight_layout()
        plt.show()

    def plot_3D(self):
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            ax.scatter(qx, qy, qz, c=intensity)

        ax.set_xlabel('Qx')
        ax.set_ylabel('Qy')
        ax.set_zlabel('Qz')
        plt.show()
    
    def plot_image(self, sigma):
        # Estimate the time complexity
        est_time_complexity = self.estimate_time_complexity()
        print(f"Estimated time complexity: {est_time_complexity}")

        # Ask the user to proceed
        proceed = input("Proceed ([y]/n)? ")
        if proceed.lower() != 'n':
            start_time = time.time()

            # Iterate through the points in the Cartesian coordinate system
            for hkl, q_vec in self.hkl_coords:
                gauss = self.gen_gaussian(q_vec)
                self.pixel_space += gauss

            # Define the threshold for plotting
            threshold = 0.5

            # Get the indices where the Gaussian is above the threshold
            ind = np.argwhere(self.pixel_space > threshold)

            # Plot the 3D pixel space
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(ind[:,0], ind[:,1], ind[:,2], c='r', marker='o')
            plt.show()
        
            end_time = time.time()
            total_time = end_time - start_time
            print(f"Total computation time: {total_time} seconds")

    def plot_pixel_space(self):
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        pixel_space_reshaped = self.pixel_space.reshape(self.m**3)
        x, y, z = np.indices((self.m, self.m, self.m)).reshape(3, -1)
        c = pixel_space_reshaped
        img = ax.scatter(x, y, z, c=c, cmap=plt.hot())
        fig.colorbar(img)
        plt.show()
    
    def create_image_plane(self):
        self.image_plane = ImagePlane()  # instantiate the ImagePlane class
        try:
            self.image_plane.load_data(self.hkl_coords)  # load the data using the ImagePlane's load_data method
        except AttributeError:
            print("hkl_coords attribute not found. Make sure it is defined in child classes.")
    
    def plotly_plot_panel(self):
        # Extract values from the grouped DataFrame
        qx_values = self.df_coords_grouped['qx']
        qy_values = self.df_coords_grouped['qy']
        qz_values = self.df_coords_grouped['qz']
        hover_text = self.df_coords_grouped['hkl']

        # Create subplots with 3 columns
        fig = make_subplots(rows=1, cols=3)

        # Add scatter plot for (qx, qz)
        fig.add_trace(go.Scatter(x=qx_values, y=qz_values, mode='markers', text=hover_text), row=1, col=1)
        fig.update_xaxes(title_text=r'$q_x \, (\AA^{-1})$', row=1, col=1, range=[0, max(qx_values)])
        fig.update_yaxes(title_text=r'$q_z \, (\AA^{-1})$', row=1, col=1, range=[0, max(qz_values)])

        # Add scatter plot for (qy, qz)
        fig.add_trace(go.Scatter(x=qy_values, y=qz_values, mode='markers', text=hover_text), row=1, col=2)
        fig.update_xaxes(title_text=r'$q_y \, (\AA^{-1})$', row=1, col=2, range=[0, max(qy_values)])
        fig.update_yaxes(title_text=r'$q_z \, (\AA^{-1})$', row=1, col=2, range=[0, max(qz_values)])

        # Add scatter plot for (qxy, qz)
        qxy_values = np.sqrt(qx_values**2 + qy_values**2)
        fig.add_trace(go.Scatter(x=qxy_values, y=qz_values, mode='markers', text=hover_text), row=1, col=3)
        fig.update_xaxes(title_text=r'$q_{xy} \, (\AA^{-1})$', row=1, col=3, range=[0, max(qxy_values)])
        fig.update_yaxes(title_text=r'$q_z \, (\AA^{-1})$', row=1, col=3, range=[0, max(qz_values)])

        fig.show()

'''
    def plotly_plot_panel(self):
        # Extract values from the grouped DataFrame
        qx_values = self.df_coords_grouped['qx']
        qy_values = self.df_coords_grouped['qy']
        qz_values = self.df_coords_grouped['qz']
        hover_text = self.df_coords_grouped['hkl']

        fig = make_subplots(rows=1, cols=3)

        # Extract the data you need for the plot
        qx_values = self.hkl_coords[:, 0]
        qy_values = self.hkl_coords[:, 1]
        qz_values = self.hkl_coords[:, 2]
        hkl_indices = self.crystal.hkl_indices # Assuming hkl_indices is an attribute of the crystal object

        for i, (x_values, label) in enumerate([(qx_values, 'q_x'), (qy_values, 'q_y'), (np.sqrt(qx_values**2 + qy_values**2), 'q_xy')]):
            # Create hover text for each point
            hover_text = [f"(h:{h}, k:{k}, l:{l})" for h, k, l in hkl_indices]
            
            fig.add_trace(go.Scatter(x=x_values, y=qz_values, mode='markers', hoverinfo="text", text=hover_text), row=1, col=i+1)
            fig.update_xaxes(title_text=f'{label} (A^(-1))', row=1, col=i+1)
            fig.update_yaxes(title_text='q_z (A^(-1))', row=1, col=i+1)

        # # Create hover text for each point
        # hover_text = [f"(h:{h}, k:{k}, l:{l})" for h, k, l in hkl_indices]

        #         # Create a layout with three subplots
        # fig = make_subplots(rows=1, cols=3,
        #                     subplot_titles=(r'q_x v. q_z', r'q_y v. q_z', r'q_xy v. q_z'))

        # # Create scatter plots
        # fig.add_trace(go.Scatter(x=qx_values, y=qz_values, mode='markers', marker=dict(color='blue'), hoverinfo="text", text=hover_text), row=1, col=1)
        # fig.add_trace(go.Scatter(x=qy_values, y=qz_values, mode='markers', marker=dict(color='green'), hoverinfo="text", text=hover_text), row=1, col=2)
        # fig.add_trace(go.Scatter(x=np.sqrt(qx_values**2 + qy_values**2), y=qz_values, mode='markers', marker=dict(color='red'), hoverinfo="text", text=hover_text), row=1, col=3)

        # Update axes labels and range to only include positive values
        fig.update_layout(height=500, width=1500)
        fig.update_xaxes(title_text='q_x (A^(-1))', range=[0, max(qx_values)], row=1, col=1)
        fig.update_yaxes(title_text='q_z (A^(-1))', range=[0, max(qz_values)], row=1, col=1)
        fig.update_xaxes(title_text='q_y (A^(-1))', range=[0, max(qy_values)], row=1, col=2)
        fig.update_yaxes(title_text='q_z (A^(-1))', range=[0, max(qz_values)], row=1, col=2)
        fig.update_xaxes(title_text='q_xy (A^(-1))', range=[0, max(np.sqrt(qx_values**2 + qy_values**2))], row=1, col=3)
        fig.update_yaxes(title_text='q_z (A^(-1))', range=[0, max(qz_values)], row=1, col=3)

        # fig.update_xaxes(title_text=r'$q_x \, (\AA^{-1})$', row=1, col=1)
        # fig.update_yaxes(title_text=r'$q_z \, (\AA^{-1})$', row=1, col=1)
        # fig.update_xaxes(title_text=r'$q_y \, (\AA^{-1})$', row=1, col=2)
        # fig.update_yaxes(title_text=r'$q_z \, (\AA^{-1})$', row=1, col=2)
        # fig.update_xaxes(title_text=r'$q_{xy} \, (\AA^{-1})$', range=[0, max(np.sqrt(qx_values**2 + qy_values**2))], row=1, col=3)
        # fig.update_yaxes(title_text=r'$q_z \, (\AA^{-1})$', range=[0, max(qz_values)], row=1, col=3)
            
        # Plot the figure
        fig.show()
'''     

class PlotGenCrystal(PlotCrystal):
    def __init__(self, lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        self.crystal = GenCrystal(lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m)
        self.lattice_params = self.crystal.lattice_params
        self.m = m
        self.create_pixel_space()

        # Access attributes from GenCrystal (and thereby ComputeCrystal)
        self.a = self.crystal.a
        self.b = self.crystal.b
        self.c = self.crystal.c
        self.alpha = self.crystal.alpha
        self.beta = self.crystal.beta
        self.gamma = self.crystal.gamma

        self.a1 = self.crystal.a1
        self.a2 = self.crystal.a2
        self.a3 = self.crystal.a3

        self.volume = self.crystal.volume

        self.b1 = self.crystal.b1
        self.b2 = self.crystal.b2
        self.b3 = self.crystal.b3

        self.Mhkl = self.crystal.Mhkl

        self.hkl_coords = self.crystal.hkl_coords.copy()  # Assuming hkl_coords is a numpy array or a list
        self.hkl_indices = self.crystal.hkl_indices.copy()  # Access the hkl_indices from the GenCrystal object
        self.initialize_df_coords()

    def create_image_plane(self):
        # Initialize ImagePlane with initial points
        self.image_plane = ImagePlane(np.array([]))
        # Load data
        self.image_plane.load_data(self.hkl_coords, self.a1, self.a2, self.a3)

    def initialize_df_coords(self):
        # Creating a DataFrame with (qx, qy, qz) and (h, k, l) values
        data = []
        for i, (qx, qy, qz) in enumerate(self.hkl_coords):
            h, k, l = self.hkl_indices[i]
            data.append((qx, qy, qz, h, k, l))

        self.df_coords = pd.DataFrame(data, columns=['qx', 'qy', 'qz', 'h', 'k', 'l'])
        
        # Define the custom aggregation function
        def aggregate_hkl(x):
            return list(zip(self.df_coords.loc[x.index, 'h'], self.df_coords.loc[x.index, 'k'], self.df_coords.loc[x.index, 'l']))

        # Grouping by (qx, qy, qz) and aggregating (h, k, l) using the custom aggregation function
        self.df_coords_grouped = self.df_coords.groupby(['qx', 'qy', 'qz']).agg(
            hkl_values=pd.NamedAgg(column='h', aggfunc=aggregate_hkl)
        ).reset_index()

    # def initialize_df_coords(self):
    #     # Creating a DataFrame with (qx, qy, qz) and (h, k, l) values
    #     data = []
    #     for i, (qx, qy, qz) in enumerate(self.hkl_coords):
    #         h, k, l = self.hkl_indices[i]
    #         data.append((qx, qy, qz, h, k, l))
        
    #     self.df_coords = pd.DataFrame(data, columns=['qx', 'qy', 'qz', 'h', 'k', 'l'])
        
    #     # Grouping by (qx, qy, qz) and aggregating (h, k, l)
    #     self.df_coords_grouped = self.df_coords.groupby(['qx', 'qy', 'qz']).agg(
    #         hkl_values=pd.NamedAgg(column='h', aggfunc=lambda x: list(zip(x, self.df_coords.loc[x.index, 'k'], self.df_coords.loc[x.index, 'l'])))
    #     ).reset_index()

        
class PlotRandCrystal(PlotCrystal):
    def __init__(self, lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        self.lattice_params = None
        self.crystal = RandomCrystal(lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m)
        self.lattice_params = self.crystal.lattice_params

        # Access attributes from GenCrystal (and thereby ComputeCrystal)
        self.a = self.crystal.a
        self.b = self.crystal.b
        self.c = self.crystal.c
        self.alpha = self.crystal.alpha
        self.beta = self.crystal.beta
        self.gamma = self.crystal.gamma

        self.a1 = self.crystal.a1
        self.a2 = self.crystal.a2
        self.a3 = self.crystal.a3

        self.volume = self.crystal.volume

        self.b1 = self.crystal.b1
        self.b2 = self.crystal.b2
        self.b3 = self.crystal.b3

        self.Mhkl = self.crystal.Mhkl

        self.hkl_coords = self.crystal.hkl_coords.copy()  # Assuming hkl_coords is a numpy array or a list

class ImagePlane:
    def __init__(self, hkl_coords=None):
        self.points = hkl_coords
        self.fig_3d = None
        self.fig_2d = None
        self.scatter = None
        self.plane = None
        self.planes = []  # Store references to generated planes
        self.slider_a = None
        self.slider_b = None
        self.slider_c = None
        self.slider_d = None
        self.slider_thickness = None
        self.btn_level_parallel = None
        self.btn_level_perpendicular = None

    def update_plane(self, a, b, c, d):
        x = np.linspace(0, 1, 10)
        y = np.linspace(0, 1, 10)
        X, Y = np.meshgrid(x, y)
        Z = (-a * X - b * Y - d) / c

        return go.Surface(x=X, y=Y, z=Z, opacity=0.5, showscale=False)

    def add_plane(self, a, b, c, d, offset):
        x = np.linspace(0, 1, 10)
        y = np.linspace(0, 1, 10)
        X, Y = np.meshgrid(x, y)
        Z = (-a * X - b * Y - (d + offset)) / c

        plane = go.Surface(x=X, y=Y, z=Z, opacity=0.5, showscale=False)
        self.fig_3d.add_trace(plane)
        self.planes.append(plane)  # Store reference to the generated plane

    def delete_plane(self):
        plane = self.planes.pop()  # Remove reference to the last generated plane
        self.fig_3d.data.remove(plane)

    def update(self, change):
        a = self.slider_a.value
        b = self.slider_b.value
        c = self.slider_c.value
        d = self.slider_d.value
        plane_thickness = self.slider_thickness.value
        new_plane = self.update_plane(a, b, c, d)
        self.fig_3d.data[1].x = new_plane.x
        self.fig_3d.data[1].y = new_plane.y
        self.fig_3d.data[1].z = new_plane.z

        # Calculate intersection points
        if np.isclose(a, 0) and np.isclose(c, 0):  # This means the plane is parallel to the XZ plane
            intersection_points = self.points[np.abs(self.points[:, 1] - (-d/b)) <= plane_thickness]
        else:
            plane_distance = np.abs(a * self.points[:, 0] + b * self.points[:, 1] + c * self.points[:, 2] + d) / np.sqrt(
                a ** 2 + b ** 2 + c ** 2)
            intersection_points = self.points[plane_distance <= plane_thickness]

        # Calculate the thickness threshold from the origin plane
        origin_plane_distance = np.abs(a * self.points[:, 0] + b * self.points[:, 1] + c * self.points[:, 2] + d) / np.sqrt(
            a ** 2 + b ** 2 + c ** 2)
        origin_plane_thickness = np.max(origin_plane_distance)

        # Update the scatter plot with intersection points
        self.fig_2d.data[0].x = intersection_points[:, 0]
        self.fig_2d.data[0].y = intersection_points[:, 1]

        # Generate or delete visual plane objects based on the thickness
        existing_planes = len(self.planes)
        thickness_range = self.slider_thickness.max - self.slider_thickness.min
        thickness_increment = thickness_range * 0.1  # Generate new plane objects every 10% of the thickness slider range

        target_planes = min(
            existing_planes, int(np.ceil((origin_plane_thickness - self.slider_thickness.min) / thickness_increment))
        )

        if target_planes > existing_planes:  # Generate additional plane objects
            offset_factor = (self.slider_thickness.value - self.slider_thickness.min) / thickness_range
            for _ in range(target_planes - existing_planes):
                offset = offset_factor * origin_plane_thickness
                self.add_plane(a, b, c, d, offset)

        elif target_planes < existing_planes:  # Delete excess plane objects
            for _ in range(existing_planes - target_planes):
                self.delete_plane()

    def level_parallel(self, _):
        # Level the plane parallel to the xy-plane about its center
        self.slider_a.value = 0
        self.slider_b.value = 0
        self.slider_c.value = 1
        self.slider_d.value = -np.mean(self.points[:, 2])

    def level_perpendicular(self, _):
        # Level the plane parallel to the xz-plane about its center
        self.slider_a.value = 0
        self.slider_b.value = 1
        self.slider_c.value = 0
        self.slider_d.value = -np.mean(self.points[:, 1])

    def load_data(self, hkl_coords, a1, a2, a3):
        hkl_coords = np.asarray(hkl_coords)
        a1 = np.asarray(a1)
        a2 = np.asarray(a2)
        a3 = np.asarray(a3)

        # Check if the shapes are compatible
        if hkl_coords.shape[1] != 3 or a1.shape != (3,) or a2.shape != (3,) or a3.shape != (3,):
            raise ValueError("Invalid shapes of input arrays")

        # Convert hkl coordinates to Cartesian coordinates
        self.points = hkl_coords[:, 0][:, np.newaxis] * a1 + hkl_coords[:, 1][:, np.newaxis] * a2 + hkl_coords[:, 2][:, np.newaxis] * a3

        # Rescale the plane to match the data range
        self.scale_plane()

        # Calculate the range of the data
        data_range = np.max(self.points, axis=0) - np.min(self.points, axis=0)

        # Create a 3D scatter plot
        self.scatter = go.Scatter3d(x=self.points[:, 0], y=self.points[:, 1], z=self.points[:, 2], mode='markers',
                                    marker=dict(size=5))

        # Initial plane parameters
        a, b, c, d = 1, -1, 1, 0

        self.plane = self.update_plane(a, b, c, d)

        # Create the initial 3D plot
        self.fig_3d = go.FigureWidget(data=[self.scatter, self.plane])
        self.fig_3d.layout.title = "3D Plot with Plane"
        self.fig_3d.layout.width = 800
        self.fig_3d.layout.height = 600

        # Create the 2D scatter plot for intersection points
        self.fig_2d = go.FigureWidget(data=[go.Scatter(x=[], y=[], mode='markers')])
        self.fig_2d.layout.title = "Intersection Points"
        self.fig_2d.layout.xaxis.title = 'X'
        self.fig_2d.layout.yaxis.title = 'Y'

        # Create sliders for plane parameters
        self.slider_a = widgets.FloatSlider(min=-1, max=1, step=0.01, value=a, description='a (X coefficient)')
        self.slider_b = widgets.FloatSlider(min=-1, max=1, step=0.01, value=b, description='b (Y coefficient)')
        self.slider_c = widgets.FloatSlider(min=-1, max=1, step=0.01, value=c, description='c (Z coefficient)')
        self.slider_d = widgets.FloatSlider(min=-1, max=1, step=0.01, value=d, description='d (Constant)')

        # Update the thickness slider range
        self.slider_thickness = widgets.FloatSlider(min=0.01, max=np.max(data_range), step=0.01, value=0.05, description='Thickness')

        # Create spring-action buttons for leveling the plane
        self.btn_level_parallel = widgets.Button(description="Level Parallel to XY-Plane")
        self.btn_level_perpendicular = widgets.Button(description="Level Perpendicular to XY-Plane")

        # Add event handlers to the buttons
        self.btn_level_parallel.on_click(self.level_parallel)
        self.btn_level_perpendicular.on_click(self.level_perpendicular)

        # Add the observer to the sliders
        self.slider_a.observe(self.update, names='value')
        self.slider_b.observe(self.update, names='value')
        self.slider_c.observe(self.update, names='value')
        self.slider_d.observe(self.update, names='value')
        self.slider_thickness.observe(self.update, names='value')

        # Display the interactive plot and sliders
        display(widgets.HBox([self.fig_3d, self.fig_2d]))
        display(widgets.VBox([widgets.Label('Plane Parameters:'), self.slider_a, self.slider_b, self.slider_c, self.slider_d,
                              self.slider_thickness, self.btn_level_parallel, self.btn_level_perpendicular]))

    def scale_plane(self):
        min_vals = np.min(self.points, axis=0)
        max_vals = np.max(self.points, axis=0)
        self.points = (self.points - min_vals) / (max_vals - min_vals)

    def level_parallel(self, _):
        # Level the plane parallel to the xy-plane about its center
        self.slider_a.value = 0
        self.slider_b.value = 0
        self.slider_c.value = 1
        self.slider_d.value = -np.mean(self.points[:, 2])

    def level_perpendicular(self, _):
        # Level the plane parallel to the xz-plane about its center
        self.slider_a.value = 0
        self.slider_b.value = 1
        self.slider_c.value = 0
        self.slider_d.value = -np.mean(self.points[:, 1])
        

In [2]:
import numpy as np
import matplotlib.pyplot as plt

# Define lattice parameters
lattice_params = [4, 2, 3, 80, 90, 90]  # Example for a tetragonal lattice
Mhkl = 5
sigma_r = 0.1
sigma_theta = 0.1
sigma_phi = 0.1
m = 100

# - Generate a pattern from input parameters.
crystal = PlotGenCrystal(lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m)

# Create image plane using the hkl coordinates of the GenCrystal object
crystal.create_image_plane()


HBox(children=(FigureWidget({
    'data': [{'marker': {'size': 5},
              'mode': 'markers',
          …

VBox(children=(Label(value='Plane Parameters:'), FloatSlider(value=1.0, description='a (X coefficient)', max=1…

In [2]:
%matplotlib widget

# Define lattice parameters
lattice_params = [1, 2, 1, 60, 90, 90]  # Example for a tetragonal lattice
Mhkl = 3
sigma_r = 0.1
sigma_theta = 0.1
sigma_phi = 0.1
m = 100

# - Generate a pattern from input parameters.
crystal = PlotGenCrystal(lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m)
# crystal.plot_panel()
crystal.plotly_plot_panel()
# crystal.plot_pixel_space()
print("Lattice parameters: ", crystal.lattice_params)

# - Generate a pattern from randomized parameters.
# randcrystal = PlotRandCrystal(lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m)
# randcrystal.plot_panel()
# print("Lattice parameters: ", randcrystal.lattice_params)

# Access some attributes
# print("Real Space Cartesian Vectors: ", [crystal.a1, crystal.a2, crystal.a3])
# print("Volume of Unit Cell: ", crystal.volume)
# print("Reciprocal Space Vectors: ", [crystal.b1, crystal.b2, crystal.b3])
# print("Computed (hkl) Coordinate Positions: ", crystal.hkl_coords)

KeyError: 'hkl'

In [35]:
indsize = np.size(crystal.hkl_indices)
print (indsize)
coordsize = np.size(crystal.hkl_coords)
print (coordsize)

81
81


#### (C) Random GIWAXS Image Generator

#### (D) Trained Convolutional Neural Net I: Peak Recognition

#### (E) Trained Convolutional Neural Net II: Space Group/Symmetry Classifier

##### Old Code Snippets

In [None]:
# - Old 1
import numpy as np
import scipy as sp
import os, re, gc
import matplotlib
import matplotlib.pyplot as plt

class CrystalStructureGenerator:
    def __init__(self):
        self.crystal_systems = {
            'triclinic': self.triclinic,
            'monoclinic': self.monoclinic,
            'orthorhombic': self.orthorhombic,
            'tetragonal': self.tetragonal,
            'trigonal': self.trigonal,
            'hexagonal': self.hexagonal,
            'cubic': self.cubic
        }

    def generate_random_lattice_parameters(self):
        # Select a random crystal system
        system = np.random.choice(list(self.crystal_systems.keys()))
        # Generate lattice parameters for the selected system
        return system, self.crystal_systems[system]()

    @staticmethod
    def triclinic():
        a, b, c = np.random.uniform(1, 10, 3)  # in angstrom
        alpha, beta, gamma = np.random.uniform(20, 160, 3)  # in degrees
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def monoclinic():
        a, b, c = np.random.uniform(1, 10, 3)
        alpha, gamma = 90, 90
        beta = np.random.uniform(20, 160)
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def orthorhombic():
        a, b, c = np.random.uniform(1, 10, 3)
        alpha, beta, gamma = 90, 90, 90
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def tetragonal():
        a, b = np.random.uniform(1, 10, 2)
        c = np.random.uniform(1, 10)
        alpha, beta, gamma = 90, 90, 90
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def trigonal():
        a = np.random.uniform(1, 10)
        alpha = np.random.uniform(20, 160)
        return a, a, a, alpha, alpha, alpha

    @staticmethod
    def hexagonal():
        a = np.random.uniform(1, 10)
        c = np.random.uniform(1, 10)
        return a, a, c, 90, 90, 120

    @staticmethod
    def cubic():
        a = np.random.uniform(1, 10)
        return a, a, a, 90, 90, 90

class ComputeCrystal:
    def __init__(self, lattice_params):
        
        self.a, self.b, self.c, self.alpha, self.beta, self.gamma = lattice_params
        self.lattice_params = self.randomize_latticeparams()

        # - Computed Real Space Cartesian Vectors
        self.a1, self.a2, self.a3 = self.compute_realspacevecs()

        # - Compute the Volume of a Unit Cell
        self.volume = self.compute_volume()

        # - Computed Reciprocal Space Vectors
        self.b1, self.b2, self.b3 = self.compute_recipvecs()

        # - Computed (hkl) Coordinate Positions
        self.hkl_coords = self.compute_coords()
        # pass

    # - Compute the real-space orthogonal basis vectors.
    def compute_realspacevecs(self):
        alpha = np.radians(self.alpha)
        beta = np.radians(self.beta)
        gamma = np.radians(self.gamma)

        a1 = np.array([self.a, 0, 0])
        a2 = np.array([self.b * np.cos(gamma), self.b * np.sin(gamma), 0])
        a3 = np.array([self.c * np.cos(beta),
                       self.c * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma),
                       self.c * np.sqrt(1 - np.cos(beta)**2 - ((np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma))**2)])

        return a1, a2, a3
    
    def compute_volume(self):
        return np.dot(self.a1, np.cross(self.a2, self.a3))

    def compute_recipvecs(self):
        b1 = 2 * np.pi * np.cross(self.a2, self.a3) / self.volume
        b2 = 2 * np.pi * np.cross(self.a3, self.a1) / self.volume
        b3 = 2 * np.pi * np.cross(self.a1, self.a2) / self.volume

        return b1, b2, b3

    # - Determine the crystal system from the input set of lattice vectors.
    def determine_crystal_system(self):
        if self.a == self.b == self.c and self.alpha == self.beta == self.gamma == 90:
            return self.cubic
        elif self.a == self.b != self.c and self.alpha == self.beta == 90 and self.gamma == 120:
            return self.hexagonal
        elif self.a == self.b != self.c and self.alpha == self.beta == self.gamma == 90:
            return self.tetragonal
        elif self.a != self.b != self.c and self.alpha == self.beta == self.gamma == 90:
            return self.orthorhombic
        elif self.a != self.b != self.c and self.alpha == self.gamma == 90 and self.beta != 90:
            return self.monoclinic
        elif self.a == self.b == self.c and self.alpha == self.beta == self.gamma != 90:
            return self.rhombohedral
        else:
            return self.triclinic

    # - (1) Cubic Crystal System
    def cubic(self):
        return self.a, self.a, self.a, 90, 90, 90

    # - (2) Hexagonal Crystal System
    def hexagonal(self):
        return self.a, self.a, self.c, 90, 90, 120

    # - (3) Tetragonal Crystal System
    def tetragonal(self):
        return self.a, self.a, self.c, 90, 90, 90

    # - (4) Orthorhombic Crystal System
    def orthorhombic(self):
        return self.a, self.b, self.c, 90, 90, 90

    # - (5) Monoclinic Crystal System
    def monoclinic(self):
        return self.a, self.b, self.c, 90, self.beta, 90

    # - (6) Triclinic Crystal System
    def triclinic(self):
        return self.a, self.b, self.c, self.alpha, self.beta, self.gamma

    # - (7) Rhombohedral Crystal System
    def rhombohedral(self):
        return self.a, self.a, self.a, self.alpha, self.alpha, self.alpha
    
# usage
csg = CrystalStructureGenerator()
system, lattice_parameters = csg.generate_random_lattice_parameters()
print("Crystal system: ", system)
print("Lattice constants a, b, c: ", lattice_parameters[:3])
print("Lattice angles alpha, beta, gamma: ", lattice_parameters[3:])

cc = ComputeCrystal(lattice_parameters)
print("Real space vectors: ", cc.a1, cc.a2, cc.a3)
print("Volume: ", cc.volume)
print("Reciprocal space vectors: ", cc.b1, cc.b2, cc.b3)


In [None]:
# - Old 2
import numpy as np
import scipy as sp
import os, re, gc
import matplotlib
import matplotlib.pyplot as plt

class ComputeCrystal:
    def __init__(self, lattice_params, Mhkl):
        self.a, self.b, self.c, self.alpha, self.beta, self.gamma = lattice_params
        self.Mhkl = Mhkl

        # - Computed Real Space Cartesian Vectors
        self.a1, self.a2, self.a3 = self.compute_realspacevecs()

        # - Compute the Volume of a Unit Cell
        self.volume = self.compute_volume()

        # - Computed Reciprocal Space Vectors
        self.b1, self.b2, self.b3 = self.compute_recipvecs()

        # - Computed (hkl) Coordinate Positions (Lattice Points)
        self.hkl_coords = self.compute_coords()

    # - Compute the real-space orthogonal basis vectors.
    def compute_realspacevecs(self):
        alpha = np.radians(self.alpha)
        beta = np.radians(self.beta)
        gamma = np.radians(self.gamma)

        a1 = np.array([self.a, 0, 0])
        a2 = np.array([self.b * np.cos(gamma), self.b * np.sin(gamma), 0])
        # Compute value under the square root
        z_value = 1 - np.cos(beta)**2 - ((np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma))**2
        # If z_value is negative, set it to zero
        z_value = max(z_value, 0)
        a3 = np.array([self.c * np.cos(beta),
                    self.c * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma),
                    self.c * np.sqrt(z_value)])

        return a1, a2, a3

    def compute_volume(self):
        return np.dot(self.a1, np.cross(self.a2, self.a3))

    def compute_recipvecs(self):
        b1 = 2 * np.pi * np.cross(self.a2, self.a3) / self.volume
        b2 = 2 * np.pi * np.cross(self.a3, self.a1) / self.volume
        b3 = 2 * np.pi * np.cross(self.a1, self.a2) / self.volume

        return b1, b2, b3

    def compute_coords(self):
        Mhkl = int(self.Mhkl)
        hkl_coords = []

        for h in range(Mhkl):
            for k in range(Mhkl):
                for l in range(Mhkl):
                    qx = h * self.b1[0] + k * self.b2[0] + l * self.b3[0]
                    qy = h * self.b1[1] + k * self.b2[1] + l * self.b3[1]
                    qz = h * self.b1[2] + k * self.b2[2] + l * self.b3[2]
                    hkl_coords.append((qx, qy, qz))

        return np.array(hkl_coords)

class RandomCrystal(ComputeCrystal):
    def __init__(self, Mhkl):
        self.crystal_systems = {
            'triclinic': self.triclinic,
            'monoclinic': self.monoclinic,
            'orthorhombic': self.orthorhombic,
            'tetragonal': self.tetragonal,
            'trigonal': self.trigonal,
            'hexagonal': self.hexagonal,
            'cubic': self.cubic
        }

        # Select a random crystal system
        system = np.random.choice(list(self.crystal_systems.keys()))
        # Generate lattice parameters for the selected system
        lattice_params = self.crystal_systems[system]()
        super().__init__(lattice_params, Mhkl)
        
    @staticmethod
    def triclinic():
        a, b, c = np.random.uniform(1, 10, 3)  # in angstrom
        alpha, beta, gamma = np.random.uniform(20, 160, 3)  # in degrees
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def monoclinic():
        a, b, c = np.random.uniform(1, 10, 3)
        alpha, gamma = 90, 90
        beta = np.random.uniform(20, 160)
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def orthorhombic():
        a, b, c = np.random.uniform(1, 10, 3)
        alpha, beta, gamma = 90, 90, 90
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def tetragonal():
        a = np.random.uniform(1, 10)
        c = np.random.uniform(1, 10)
        alpha, beta, gamma = 90, 90, 90
        return a, a, c, alpha, beta, gamma

    @staticmethod
    def trigonal():
        a = np.random.uniform(1, 10)
        alpha = np.random.uniform(20, 160)
        return a, a, a, alpha, alpha, alpha

    @staticmethod
    def hexagonal():
        a = np.random.uniform(1, 10)
        c = np.random.uniform(1, 10)
        return a, a, c, 90, 90, 120

    @staticmethod
    def cubic():
        a = np.random.uniform(1, 10)
        return a, a, a, 90, 90, 90


In [None]:
# - Old 3
import numpy as np
import scipy as sp
import os, re, gc
import matplotlib
import matplotlib.pyplot as plt

class ComputeCrystal:
    def __init__(self, lattice_params, Mhkl):
        self.a, self.b, self.c, self.alpha, self.beta, self.gamma = lattice_params
        self.Mhkl = Mhkl

        # - Computed Real Space Cartesian Vectors
        self.a1, self.a2, self.a3 = self.compute_realspacevecs()

        # - Compute the Volume of a Unit Cell
        self.volume = self.compute_volume()

        # - Computed Reciprocal Space Vectors
        self.b1, self.b2, self.b3 = self.compute_recipvecs()

        # - Computed (hkl) Coordinate Positions (Lattice Points)
        self.hkl_coords = self.compute_coords()

    # - Compute the real-space orthogonal basis vectors.
    def compute_realspacevecs(self):
        alpha = np.radians(self.alpha)
        beta = np.radians(self.beta)
        gamma = np.radians(self.gamma)

        a1 = np.array([self.a, 0, 0])
        a2 = np.array([self.b * np.cos(gamma), self.b * np.sin(gamma), 0])
        # Compute value under the square root
        z_value = 1 - np.cos(beta)**2 - ((np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma))**2
        # If z_value is negative, set it to zero
        z_value = max(z_value, 0)
        a3 = np.array([self.c * np.cos(beta),
                    self.c * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma),
                    self.c * np.sqrt(z_value)])

        return a1, a2, a3

    def compute_volume(self):
        return np.dot(self.a1, np.cross(self.a2, self.a3))

    def compute_recipvecs(self):
        b1 = 2 * np.pi * np.cross(self.a2, self.a3) / self.volume
        b2 = 2 * np.pi * np.cross(self.a3, self.a1) / self.volume
        b3 = 2 * np.pi * np.cross(self.a1, self.a2) / self.volume

        return b1, b2, b3

def compute_coords(self):
    Mhkl = int(self.Mhkl)
    hkl_coords = []

    for h in range(Mhkl):
        for k in range(Mhkl):
            for l in range(Mhkl):
                qx = h * self.b1[0] + k * self.b2[0] + l * self.b3[0]
                qy = h * self.b1[1] + k * self.b2[1] + l * self.b3[1]
                qz = h * self.b1[2] + k * self.b2[2] + l * self.b3[2]
                hkl_coords.append(((h, k, l), (qx, qy, qz)))

    return np.array(hkl_coords)

class RandomCrystal(ComputeCrystal):
    def __init__(self, Mhkl):
        self.crystal_systems = {
            'triclinic': self.triclinic,
            'monoclinic': self.monoclinic,
            'orthorhombic': self.orthorhombic,
            'tetragonal': self.tetragonal,
            'trigonal': self.trigonal,
            'hexagonal': self.hexagonal,
            'cubic': self.cubic
        }

        # Select a random crystal system
        system = np.random.choice(list(self.crystal_systems.keys()))
        # Generate lattice parameters for the selected system
        lattice_params = self.crystal_systems[system]()
        super().__init__(lattice_params, Mhkl)
        
    @staticmethod
    def triclinic():
        a, b, c = np.random.uniform(1, 10, 3)  # in angstrom
        alpha, beta, gamma = np.random.uniform(20, 160, 3)  # in degrees
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def monoclinic():
        a, b, c = np.random.uniform(1, 10, 3)
        alpha, gamma = 90, 90
        beta = np.random.uniform(20, 160)
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def orthorhombic():
        a, b, c = np.random.uniform(1, 10, 3)
        alpha, beta, gamma = 90, 90, 90
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def tetragonal():
        a = np.random.uniform(1, 10)
        c = np.random.uniform(1, 10)
        alpha, beta, gamma = 90, 90, 90
        return a, a, c, alpha, beta, gamma

    @staticmethod
    def trigonal():
        a = np.random.uniform(1, 10)
        alpha = np.random.uniform(20, 140)
        return a, a, a, alpha, alpha, alpha

    @staticmethod
    def hexagonal():
        a = np.random.uniform(1, 10)
        c = np.random.uniform(1, 10)
        return a, a, c, 90, 90, 120

    @staticmethod
    def cubic():
        a = np.random.uniform(1, 10)
        return a, a, a, 90, 90, 90


In [None]:
# - Old 4
'''
    # def plot_image(self, sigma):
    #     # Iterate through the points in the Cartesian coordinate system
    #     for hkl, q_vec in self.hkl_coords:
    #         gauss = self.gen_gaussian(q_vec)
    #         self.pixel_space += gauss

    #     # Plot the 3D pixel space
    #     fig = plt.figure()
    #     ax = fig.add_subplot(111, projection='3d')
    #     ax.scatter(self.pixel_space[:,:,0], self.pixel_space[:,:,1], self.pixel_space[:,:,2], c='r', marker='o')
    #     plt.show()

    # def gen_gaussian(self, m, q_vec, x_std_dev, y_std_dev):
    #     qx, qy, qz = q_vec
    #     x = np.linspace(-m/2, m/2, m)
    #     y = np.linspace(0, m, m)
    #     X, Y = np.meshgrid(x, y)
    #     qr, qtheta, qphi = self.cart_to_sph(qx, qy, qz)
    #     rot = R.from_rotvec(qphi * np.array([0, 0, 1]))
    #     rot_q_vec = rot.apply([qx, qy, qz])
    #     cov = [[x_std_dev**2, 0], [0, y_std_dev**2]]
    #     gauss = multivariate_normal.pdf(np.dstack((X, Y)), mean=rot_q_vec[0:2], cov=cov)
    #     return gauss
 '''


In [None]:
# - Old 5
import sys, re, os, time, psutil
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import norm
from scipy.stats import multivariate_normal
from scipy.spatial.transform import Rotation as R
from scipy.spatial.distance import cdist

# Crystal computation class
class ComputeCrystal:
    def __init__(self, lattice_params):
        self.a, self.b, self.c, self.alpha, self.beta, self.gamma = lattice_params

        # - Compute Real Space Cartesian Vectors
        self.a1, self.a2, self.a3 = self.compute_realspacevecs()
        self.a1, self.a2, self.a3 = self.clean_data([self.a1, self.a2, self.a3])

        # - Compute the Volume of a Unit Cell
        self.volume = self.compute_volume()

        # - Compute Reciprocal Space Vectors
        self.b1, self.b2, self.b3 = self.compute_recipvecs()
        self.b1, self.b2, self.b3 = self.clean_data([self.b1, self.b2, self.b3])

    def initialize(self, Mhkl):
        self.Mhkl = Mhkl

        # - Computed (hkl) Coordinate Positions (Lattice Points)
        self.hkl_coords = self.compute_coords()
        pass

    def compute_realspacevecs(self):
        alpha = np.radians(self.alpha)
        beta = np.radians(self.beta)
        gamma = np.radians(self.gamma)

        a1 = np.array([self.a, 0, 0])
        a2 = np.array([self.b * np.cos(gamma), self.b * np.sin(gamma), 0])
        z_value = 1 - np.cos(beta)**2 - ((np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma))**2
        z_value = max(z_value, 0)
        a3 = np.array([self.c * np.cos(beta),
                       self.c * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma),
                       self.c * np.sqrt(z_value)])

        return a1, a2, a3

    def compute_volume(self):
        return np.dot(self.a1, np.cross(self.a2, self.a3))

    def compute_recipvecs(self):
        b1 = 2 * np.pi * np.cross(self.a2, self.a3) / self.volume
        b2 = 2 * np.pi * np.cross(self.a3, self.a1) / self.volume
        b3 = 2 * np.pi * np.cross(self.a1, self.a2) / self.volume

        return b1, b2, b3

    def compute_coords(self):
        Mhkl = int(self.Mhkl)
        hkl_coords = []

        for h in range(Mhkl):
            for k in range(Mhkl):
                for l in range(Mhkl):
                    qx = h * self.b1[0] + k * self.b2[0] + l * self.b3[0]
                    qy = h * self.b1[1] + k * self.b2[1] + l * self.b3[1]
                    qz = h * self.b1[2] + k * self.b2[2] + l * self.b3[2]
                    hkl_coords.append(((h, k, l), (qx, qy, qz)))

        return np.array(hkl_coords)

    def clean_data(self, data):
        return [np.where(np.isclose(vec, 0, atol=1e-10), 0, vec) for vec in data]

# Crystal plotting class
class PlotCrystal(ComputeCrystal):
    def __init__(self, lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        super().__init__(lattice_params, Mhkl)
        self.sigma_r = sigma_r
        self.sigma_theta = sigma_theta
        self.sigma_phi = sigma_phi
        self.m = m
        self.create_pixel_space()

    # - Estimate the time order complexity for smearing a crystal structure.
    def estimate_time_complexity(self):
        """Estimate the time complexity based on system specs and current usage."""

        # Get the number of CPUs available
        num_cpus = os.cpu_count()

        # Get the CPU usage
        cpu_usage = psutil.cpu_percent()

        # Get the available memory
        avail_mem = psutil.virtual_memory().available

        # Estimate the time complexity based on the number of CPUs, CPU usage, and available memory
        # This is a very rough estimate and is unlikely to be accurate
        time_complexity = (self.Mhkl**3 * self.m**3) / (num_cpus * (1 - cpu_usage/100) * avail_mem)

        return time_complexity
    
    # - Generate Ewald Sphere Pixel Space
    def create_pixel_space(self):
        # Create the 3D pixel space with size m x m x m
        self.pixel_space = np.zeros((self.m, self.m, self.m))

    # - Convert Cartesian (qx, qy, qz) Coordinates to Spherical Coordinates
    def cart_to_sph(self, qx, qy, qz):
        qr = np.sqrt(qx**2 + qy**2 + qz**2)
        qtheta = np.arctan2(np.sqrt(qx**2 + qy**2), qz)
        qphi = np.arctan2(qy, qx)
        return qr, qtheta, qphi

    # - Plot Generated Points in a 3-Panel
    def plot_panel(self):
        plt.figure(figsize=(15,5))

        plt.subplot(131)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qx, qz, color = 'blue')
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.subplot(132)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qy, qz, color = 'green')
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.subplot(133)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, color = 'red')
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.tight_layout()
        plt.show()

    def plot_qx_qz(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qx, qz, color='blue')
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qy_qz(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qy, qz, color = 'green')
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qxy_qz(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, color = 'red')
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def gen_gaussian(self, q_vec):
        # Define the variance for each direction in 3D
        variance = [self.sigma_r**2, self.sigma_theta**2, self.sigma_phi**2]
        
        # Convert Cartesian coordinates to spherical
        qr, qtheta, qphi = self.cart_to_sph(*q_vec)
        
        # Create a grid of coordinates spanning the pixel space
        grid = np.indices((self.m, self.m, self.m)).reshape(3,-1).T

        # Generate a Gaussian centered at each peak
        gauss = multivariate_normal.pdf(grid, mean=[qr, qtheta, qphi], cov=variance)

        return gauss.reshape((self.m, self.m, self.m))

    # - Gaussian Convolution Methods + Plotting
    def gaussian(self, x, mu, sigma):
        return norm.pdf(x, mu, sigma)

    def convoluted_gaussian(self, q_vec):
        qx, qy, qz = q_vec
        qr, qtheta, qphi = self.cart_to_sph(qx, qy, qz)
        gauss_r = self.gaussian(qr, qr, self.sigma_r)
        gauss_theta = self.gaussian(qtheta, qtheta, self.sigma_theta)
        gauss_phi = self.gaussian(qphi, qphi, self.sigma_phi)
        return gauss_r * gauss_theta * gauss_phi
    
    def plot_qx_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qx, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qy_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qy, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qxy_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_panel_convolution(self):
        plt.figure(figsize=(15,5))
        
        plt.subplot(131)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qx, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        
        plt.subplot(132)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qy, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.subplot(133)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.tight_layout()
        plt.show()

    def plot_3D(self):
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            ax.scatter(qx, qy, qz, c=intensity)

        ax.set_xlabel('Qx')
        ax.set_ylabel('Qy')
        ax.set_zlabel('Qz')
        plt.show()
    
    def plot_image(self, sigma):
        # Estimate the time complexity
        est_time_complexity = self.estimate_time_complexity()
        print(f"Estimated time complexity: {est_time_complexity}")

        # Ask the user to proceed
        proceed = input("Proceed ([y]/n)? ")
        if proceed.lower() != 'n':
            start_time = time.time()

            # Iterate through the points in the Cartesian coordinate system
            for hkl, q_vec in self.hkl_coords:
                gauss = self.gen_gaussian(q_vec)
                self.pixel_space += gauss

            # Define the threshold for plotting
            threshold = 0.5

            # Get the indices where the Gaussian is above the threshold
            ind = np.argwhere(self.pixel_space > threshold)

            # Plot the 3D pixel space
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(ind[:,0], ind[:,1], ind[:,2], c='r', marker='o')
            plt.show()
        
            end_time = time.time()
            total_time = end_time - start_time
            print(f"Total computation time: {total_time} seconds")

# Your specific classes for generated and randomly generated crystals
class PlotGenCrystal(PlotCrystal):
    def __init__(self, lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        super().__init__(lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m)

class PlotRandCrystal(PlotCrystal):
    def __init__(self, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        # Generate random lattice parameters here
        lattice_params = [np.random.uniform(1, 5), np.random.uniform(1, 5), np.random.uniform(1, 5),
                          np.random.uniform(60, 120), np.random.uniform(60, 120), np.random.uniform(60, 120)]
        super().__init__(lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m)


In [None]:
# - Old 6
import sys, re, os, time, psutil
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import norm
from scipy.stats import multivariate_normal
from scipy.spatial.transform import Rotation as R
from scipy.spatial.distance import cdist

class ComputeCrystal:
    def __init__(self, lattice_params):
        self.a, self.b, self.c, self.alpha, self.beta, self.gamma = lattice_params

        # - Compute Real Space Cartesian Vectors
        self.a1, self.a2, self.a3 = self.compute_realspacevecs()
        self.a1, self.a2, self.a3 = self.clean_data([self.a1, self.a2, self.a3])

        # - Compute the Volume of a Unit Cell
        self.volume = self.compute_volume()

        # - Compute Reciprocal Space Vectors
        self.b1, self.b2, self.b3 = self.compute_recipvecs()
        self.b1, self.b2, self.b3 = self.clean_data([self.b1, self.b2, self.b3])

    def initialize(self, Mhkl):
        self.Mhkl = Mhkl

        # - Computed (hkl) Coordinate Positions (Lattice Points)
        self.hkl_coords = self.compute_coords()

    # - Compute the real-space orthogonal basis vectors.
    def compute_realspacevecs(self):
        alpha = np.radians(self.alpha)
        beta = np.radians(self.beta)
        gamma = np.radians(self.gamma)

        a1 = np.array([self.a, 0, 0])
        a2 = np.array([self.b * np.cos(gamma), self.b * np.sin(gamma), 0])
        z_value = 1 - np.cos(beta)**2 - ((np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma))**2
        z_value = max(z_value, 0)
        a3 = np.array([self.c * np.cos(beta),
                       self.c * (np.cos(alpha) - np.cos(beta) * np.cos(gamma)) / np.sin(gamma),
                       self.c * np.sqrt(z_value)])

        return a1, a2, a3

    def compute_volume(self):
        return np.dot(self.a1, np.cross(self.a2, self.a3))

    def compute_recipvecs(self):
        b1 = 2 * np.pi * np.cross(self.a2, self.a3) / self.volume
        b2 = 2 * np.pi * np.cross(self.a3, self.a1) / self.volume
        b3 = 2 * np.pi * np.cross(self.a1, self.a2) / self.volume

        return b1, b2, b3

    def compute_coords(self):
        Mhkl = int(self.Mhkl)
        hkl_coords = []

        for h in range(Mhkl):
            for k in range(Mhkl):
                for l in range(Mhkl):
                    qx = h * self.b1[0] + k * self.b2[0] + l * self.b3[0]
                    qy = h * self.b1[1] + k * self.b2[1] + l * self.b3[1]
                    qz = h * self.b1[2] + k * self.b2[2] + l * self.b3[2]
                    hkl_coords.append(((h, k, l), (qx, qy, qz)))

        return np.array(hkl_coords)

    def clean_data(self, data):
        return [np.where(np.isclose(vec, 0, atol=1e-10), 0, vec) for vec in data]

class GenCrystal(ComputeCrystal):
    def __init__(self, lattice_params, Mhkl):
        super().__init__(lattice_params)
        self.initialize(Mhkl)

class RandomCrystal(ComputeCrystal):
    def __init__(self, Mhkl):
        self.crystal_systems = {
            'triclinic': self.triclinic,
            'monoclinic': self.monoclinic,
            'orthorhombic': self.orthorhombic,
            'tetragonal': self.tetragonal,
            'trigonal': self.trigonal,
            'hexagonal': self.hexagonal,
            'cubic': self.cubic
        }

        # Select a random crystal system
        system = np.random.choice(list(self.crystal_systems.keys()))
        # Generate lattice parameters for the selected system
        lattice_params = self.crystal_systems[system]()
        
        # print lattice parameters
        print(f"Selected system: {system}")
        # print(f"Lattice parameters: {lattice_params}")

        super().__init__(lattice_params)
        self.initialize(Mhkl)

    @staticmethod
    def triclinic():
        a, b, c = np.round(np.random.uniform(1, 10, 3), 2)
        alpha, beta, gamma = np.round(np.random.uniform(20, 160, 3), 2)
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def monoclinic():
        a, b, c = np.round(np.random.uniform(1, 10, 3), 2)
        alpha, gamma = 90, 90
        beta = np.round(np.random.uniform(20, 160), 2)
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def orthorhombic():
        a, b, c = np.round(np.random.uniform(1, 10, 3), 2)
        alpha, beta, gamma = 90, 90, 90
        return a, b, c, alpha, beta, gamma

    @staticmethod
    def tetragonal():
        a = np.round(np.random.uniform(1, 10), 2)
        c = np.round(np.random.uniform(1, 10), 2)
        alpha, beta, gamma = 90, 90, 90
        return a, a, c, alpha, beta, gamma

    @staticmethod
    def trigonal():
        a = np.round(np.random.uniform(1, 10), 2)
        alpha = np.round(np.random.uniform(20, 140), 2)
        return a, a, a, alpha, alpha, alpha

    @staticmethod
    def hexagonal():
        a = np.round(np.random.uniform(1, 10), 2)
        c = np.round(np.random.uniform(1, 10), 2)
        return a, a, c, 90, 90, 120

    @staticmethod
    def cubic():
        a = np.round(np.random.uniform(1, 10), 2)
        return a, a, a, 90, 90, 90

class PlotCrystal(RandomCrystal):
    def __init__(self, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        super().__init__(Mhkl)
        self.sigma_r = sigma_r
        self.sigma_theta = sigma_theta
        self.sigma_phi = sigma_phi
        self.m = m
        self.create_pixel_space()

    # - Estimate the time order complexity for smearing a crystal structure.
    def estimate_time_complexity(self):
        """Estimate the time complexity based on system specs and current usage."""

        # Get the number of CPUs available
        num_cpus = os.cpu_count()

        # Get the CPU usage
        cpu_usage = psutil.cpu_percent()

        # Get the available memory
        avail_mem = psutil.virtual_memory().available

        # Estimate the time complexity based on the number of CPUs, CPU usage, and available memory
        # This is a very rough estimate and is unlikely to be accurate
        time_complexity = (self.Mhkl**3 * self.m**3) / (num_cpus * (1 - cpu_usage/100) * avail_mem)

        return time_complexity
    
    # - Generate Ewald Sphere Pixel Space
    def create_pixel_space(self):
        # Create the 3D pixel space with size m x m x m
        self.pixel_space = np.zeros((self.m, self.m, self.m))

    # - Convert Cartesian (qx, qy, qz) Coordinates to Spherical Coordinates
    def cart_to_sph(self, qx, qy, qz):
        qr = np.sqrt(qx**2 + qy**2 + qz**2)
        qtheta = np.arctan2(np.sqrt(qx**2 + qy**2), qz)
        qphi = np.arctan2(qy, qx)
        return qr, qtheta, qphi

    # - Plot Generated Points in a 3-Panel
    def plot_panel(self):
        plt.figure(figsize=(15,5))

        plt.subplot(131)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qx, qz, color = 'blue')
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.subplot(132)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qy, qz, color = 'green')
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.subplot(133)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, color = 'red')
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.tight_layout()
        plt.show()

    def plot_qx_qz(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qx, qz, color='blue')
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qy_qz(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(qy, qz, color = 'green')
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qxy_qz(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, color = 'red')
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def gen_gaussian(self, q_vec):
        # Define the variance for each direction in 3D
        variance = [self.sigma_r**2, self.sigma_theta**2, self.sigma_phi**2]
        
        # Convert Cartesian coordinates to spherical
        qr, qtheta, qphi = self.cart_to_sph(*q_vec)
        
        # Create a grid of coordinates spanning the pixel space
        grid = np.indices((self.m, self.m, self.m)).reshape(3,-1).T

        # Generate a Gaussian centered at each peak
        gauss = multivariate_normal.pdf(grid, mean=[qr, qtheta, qphi], cov=variance)

        return gauss.reshape((self.m, self.m, self.m))

    # - Gaussian Convolution Methods + Plotting
    def gaussian(self, x, mu, sigma):
        return norm.pdf(x, mu, sigma)

    def convoluted_gaussian(self, q_vec):
        qx, qy, qz = q_vec
        qr, qtheta, qphi = self.cart_to_sph(qx, qy, qz)
        gauss_r = self.gaussian(qr, qr, self.sigma_r)
        gauss_theta = self.gaussian(qtheta, qtheta, self.sigma_theta)
        gauss_phi = self.gaussian(qphi, qphi, self.sigma_phi)
        return gauss_r * gauss_theta * gauss_phi
    
    def smear_peaks(self):
        for hkl, q_vec in self.hkl_coords:
            gauss = self.gen_gaussian(q_vec)
            self.pixel_space += gauss

    def plot_qx_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qx, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qy_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qy, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_qxy_qz_convolution(self):
        plt.figure()
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        plt.show()

    def plot_panel_convolution(self):
        plt.figure(figsize=(15,5))
        
        plt.subplot(131)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qx, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_x \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')
        
        plt.subplot(132)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(qy, qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_y \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.subplot(133)
        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            plt.scatter(np.sqrt(qx**2 + qy**2), qz, c=intensity)
        plt.grid()
        plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
        plt.ylabel(r'$q_z \, (\AA^{-1})$')

        plt.tight_layout()
        plt.show()

    def plot_3D(self):
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        for hkl, q_vec in self.hkl_coords:
            qx, qy, qz = q_vec
            intensity = self.convoluted_gaussian(q_vec)
            ax.scatter(qx, qy, qz, c=intensity)

        ax.set_xlabel('Qx')
        ax.set_ylabel('Qy')
        ax.set_zlabel('Qz')
        plt.show()
    
    def plot_image(self, sigma):
        # Estimate the time complexity
        est_time_complexity = self.estimate_time_complexity()
        print(f"Estimated time complexity: {est_time_complexity}")

        # Ask the user to proceed
        proceed = input("Proceed ([y]/n)? ")
        if proceed.lower() != 'n':
            start_time = time.time()

            # Iterate through the points in the Cartesian coordinate system
            for hkl, q_vec in self.hkl_coords:
                gauss = self.gen_gaussian(q_vec)
                self.pixel_space += gauss

            # Define the threshold for plotting
            threshold = 0.5

            # Get the indices where the Gaussian is above the threshold
            ind = np.argwhere(self.pixel_space > threshold)

            # Plot the 3D pixel space
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            ax.scatter(ind[:,0], ind[:,1], ind[:,2], c='r', marker='o')
            plt.show()
        
            end_time = time.time()
            total_time = end_time - start_time
            print(f"Total computation time: {total_time} seconds")

    def plot_pixel_space(self):
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        pixel_space_reshaped = self.pixel_space.reshape(self.m**3)
        x, y, z = np.indices((self.m, self.m, self.m)).reshape(3, -1)
        c = pixel_space_reshaped
        img = ax.scatter(x, y, z, c=c, cmap=plt.hot())
        fig.colorbar(img)
        plt.show()

class PlotGenCrystal(GenCrystal, PlotCrystal):
    def __init__(self, lattice_params, Mhkl, sigma_r, sigma_theta, sigma_phi, m):
        GenCrystal.__init__(self, lattice_params, Mhkl)
        PlotCrystal.__init__(self, Mhkl, sigma_r, sigma_theta, sigma_phi, m)

class ImagePlane:
    def __init__(self, plot_crystal, plane_normal=np.array([0,0,1]), plane_point=np.array([0,0,0]), plane_thickness=0.05):
        self.plot_crystal = plot_crystal
        self.plane_normal = plane_normal
        self.plane_point = plane_point
        self.plane_thickness = plane_thickness
        self.plane_d = -plane_point.dot(plane_normal)

    def get_intersection_points(self):
        """
        Get the points in the plot crystal that intersect with the plane
        """

        # Calculate the distance of each point in the pixel space from the plane
        distances = np.abs((self.plane_normal[0] * self.plot_crystal.pixel_space[:,:,0] + 
                            self.plane_normal[1] * self.plot_crystal.pixel_space[:,:,1] + 
                            self.plane_normal[2] * self.plot_crystal.pixel_space[:,:,2] + 
                            self.plane_d) / 
                           np.linalg.norm(self.plane_normal))

        # Get the points in the pixel space where the distance is less than the plane thickness
        intersection_points = np.where(distances <= self.plane_thickness)

        return intersection_points

    def plot_intersection(self):
        """
        Plot the intersection of the plane and the plot crystal
        """

        intersection_points = self.get_intersection_points()

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(intersection_points[0], intersection_points[1], intersection_points[2], c='r', marker='o')
        plt.show()


In [None]:
# - Old 7
class ImagePlane:
    def __init__(self, plot_crystal, plane_normal=np.array([0,0,1]), plane_point=np.array([0,0,0]), plane_thickness=0.05):
        self.plot_crystal = plot_crystal
        self.plane_normal = plane_normal
        self.plane_point = plane_point
        self.plane_thickness = plane_thickness
        self.plane_d = -plane_point.dot(plane_normal)

    def get_intersection_points(self):
        """
        Get the points in the plot crystal that intersect with the plane
        """

        # Calculate the distance of each point in the pixel space from the plane
        distances = np.abs((self.plane_normal[0] * self.plot_crystal.pixel_space[:,:,0] + 
                            self.plane_normal[1] * self.plot_crystal.pixel_space[:,:,1] + 
                            self.plane_normal[2] * self.plot_crystal.pixel_space[:,:,2] + 
                            self.plane_d) / 
                           np.linalg.norm(self.plane_normal))

        # Get the points in the pixel space where the distance is less than the plane thickness
        intersection_points = np.where(distances <= self.plane_thickness)

        return intersection_points

    def plot_intersection(self):
        """
        Plot the intersection of the plane and the plot crystal
        """

        intersection_points = self.get_intersection_points()

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(intersection_points[0], intersection_points[1], intersection_points[2], c='r', marker='o')
        plt.show()

In [None]:
    # - Plot Generated Points in a 3-Panel
    # def plot_panel(self):
    #     plt.figure(figsize=(15,5))

    #     plt.subplot(131)
    #     for hkl, q_vec in self.hkl_coords:
    #         qx, qy, qz = q_vec
    #         plt.scatter(qx, qz, color = 'blue')
    #     plt.grid()
    #     plt.xlabel(r'$q_x \, (\AA^{-1})$')
    #     plt.ylabel(r'$q_z \, (\AA^{-1})$')

    #     plt.subplot(132)
    #     for hkl, q_vec in self.hkl_coords:
    #         qx, qy, qz = q_vec
    #         plt.scatter(qy, qz, color = 'green')
    #     plt.grid()
    #     plt.xlabel(r'$q_y \, (\AA^{-1})$')
    #     plt.ylabel(r'$q_z \, (\AA^{-1})$')

    #     plt.subplot(133)
    #     for hkl, q_vec in self.hkl_coords:
    #         qx, qy, qz = q_vec
    #         plt.scatter(np.sqrt(qx**2 + qy**2), qz, color = 'red')
    #     plt.grid()
    #     plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
    #     plt.ylabel(r'$q_z \, (\AA^{-1})$')

    #     plt.tight_layout()
    #     plt.show()

    # def plot_qx_qz(self):
    #     plt.figure()
    #     for hkl, q_vec in self.hkl_coords:
    #         qx, qy, qz = q_vec
    #         plt.scatter(qx, qz, color='blue')
    #     plt.grid()
    #     plt.xlabel(r'$q_x \, (\AA^{-1})$')
    #     plt.ylabel(r'$q_z \, (\AA^{-1})$')
    #     plt.show()

    # def plot_qy_qz(self):
    #     plt.figure()
    #     for hkl, q_vec in self.hkl_coords:
    #         qx, qy, qz = q_vec
    #         plt.scatter(qy, qz, color = 'green')
    #     plt.grid()
    #     plt.xlabel(r'$q_y \, (\AA^{-1})$')
    #     plt.ylabel(r'$q_z \, (\AA^{-1})$')
    #     plt.show()

    # def plot_qxy_qz(self):
    #     plt.figure()
    #     for hkl, q_vec in self.hkl_coords:
    #         qx, qy, qz = q_vec
    #         plt.scatter(np.sqrt(qx**2 + qy**2), qz, color = 'red')
    #     plt.grid()
    #     plt.xlabel(r'$q_{xy} \, (\AA^{-1})$')
    #     plt.ylabel(r'$q_z \, (\AA^{-1})$')
    #     plt.show()


In [None]:

'''
# class ImagePlane:
#     def __init__(self, points):
#         self.points = points
#         self.fig_3d = None
#         self.fig_2d = None
#         self.scatter = None
#         self.plane = None
#         self.slider_a = None
#         self.slider_b = None
#         self.slider_c = None
#         self.slider_d = None
#         self.slider_thickness = None
#         self.btn_level_parallel = None
#         self.btn_level_perpendicular = None

#     def update_plane(self, a, b, c, d):
#         x = np.linspace(0, 1, 10)
#         y = np.linspace(0, 1, 10)
#         X, Y = np.meshgrid(x, y)
#         Z = (-a * X - b * Y - d) / c

#         return go.Surface(x=X, y=Y, z=Z, opacity=0.5, showscale=False)

#     def update(self, change):
#         a = self.slider_a.value
#         b = self.slider_b.value
#         c = self.slider_c.value
#         d = self.slider_d.value
#         plane_thickness = self.slider_thickness.value
#         new_plane = self.update_plane(a, b, c, d)
#         self.fig_3d.data[1].x = new_plane.x
#         self.fig_3d.data[1].y = new_plane.y
#         self.fig_3d.data[1].z = new_plane.z

#         # Calculate intersection points
#         if np.isclose(a, 0) and np.isclose(c, 0):  # This means the plane is parallel to the XZ plane
#             intersection_points = self.points[np.abs(self.points[:, 1] - (-d/b)) <= plane_thickness]
#         else:
#             plane_distance = np.abs(a * self.points[:, 0] + b * self.points[:, 1] + c * self.points[:, 2] + d) / np.sqrt(
#                 a ** 2 + b ** 2 + c ** 2)
#             intersection_points = self.points[plane_distance <= plane_thickness]

#         # Update the scatter plot with intersection points
#         self.fig_2d.data[0].x = intersection_points[:, 0]
#         self.fig_2d.data[0].y = intersection_points[:, 1]

#     def level_parallel(self, _):
#         # Level the plane parallel to the xy-plane about its center
#         self.slider_a.value = 0
#         self.slider_b.value = 0
#         self.slider_c.value = 1
#         self.slider_d.value = -np.mean(self.points[:, 2])

#     def level_perpendicular(self, _):
#         # Level the plane parallel to the xz-plane about its center
#         self.slider_a.value = 0
#         self.slider_b.value = 1
#         self.slider_c.value = 0
#         self.slider_d.value = -np.mean(self.points[:, 1])

#     def load_data(self, hkl_coords, a1, a2, a3):
#         # Convert hkl coordinates to Cartesian coordinates
#         self.points = hkl_coords[:,0]*a1 + hkl_coords[:,1]*a2 + hkl_coords[:,2]*a3
        
#         # Create a 3D scatter plot
#         self.scatter = go.Scatter3d(x=self.points[:, 0], y=self.points[:, 1], z=self.points[:, 2], mode='markers',
#                                     marker=dict(size=5))

#         # Initial plane parameters
#         a, b, c, d = 1, -1, 1, 0

#         self.plane = self.update_plane(a, b, c, d)

#         # Create the initial 3D plot
#         self.fig_3d = go.FigureWidget(data=[self.scatter, self.plane])
#         self.fig_3d.layout.title = "3D Plot with Plane"
#         self.fig_3d.layout.width = 800
#         self.fig_3d.layout.height = 600

#         # Create the 2D scatter plot for intersection points
#         self.fig_2d = go.FigureWidget(data=[go.Scatter(x=[], y=[], mode='markers')])
#         self.fig_2d.layout.title = "Intersection Points"
#         self.fig_2d.layout.xaxis.title = 'X'
#         self.fig_2d.layout.yaxis.title = 'Y'

#         # Create sliders for plane parameters
#         self.slider_a = widgets.FloatSlider(min=-1, max=1, step=0.01, value=a, description='a (X coefficient)')
#         self.slider_b = widgets.FloatSlider(min=-1, max=1, step=0.01, value=b, description='b (Y coefficient)')
#         self.slider_c = widgets.FloatSlider(min=-1, max=1, step=0.01, value=c, description='c (Z coefficient)')
#         self.slider_d = widgets.FloatSlider(min=-1, max=1, step=0.01, value=d, description='d (Constant)')
#         self.slider_thickness = widgets.FloatSlider(min=0.01, max=0.5, step=0.01, value=0.05, description='Thickness')

#         # Create spring-action buttons for leveling the plane
#         self.btn_level_parallel = widgets.Button(description="Level Parallel to XY-Plane")
#         self.btn_level_perpendicular = widgets.Button(description="Level Perpendicular to XY-Plane")

#         # Add event handlers to the buttons
#         self.btn_level_parallel.on_click(self.level_parallel)
#         self.btn_level_perpendicular.on_click(self.level_perpendicular)

#         # Add the observer to the sliders
#         self.slider_a.observe(self.update, names='value')
#         self.slider_b.observe(self.update, names='value')
#         self.slider_c.observe(self.update, names='value')
#         self.slider_d.observe(self.update, names='value')
#         self.slider_thickness.observe(self.update, names='value')

#         # Display the interactive plot and sliders
#         display(widgets.HBox([self.fig_3d, self.fig_2d]))
#         display(widgets.VBox([widgets.Label('Plane Parameters:'), self.slider_a, self.slider_b, self.slider_c, self.slider_d,
#                               self.slider_thickness, self.btn_level_parallel, self.btn_level_perpendicular]))
'''

In [None]:

# class ImagePlane:
#     def __init__(self, hkl_coords=None):
#         self.points = hkl_coords
#         self.fig_3d = None
#         self.fig_2d = None
#         self.scatter = None
#         self.plane = None
#         self.slider_a = None
#         self.slider_b = None
#         self.slider_c = None
#         self.slider_d = None
#         self.slider_thickness = None
#         self.btn_level_parallel = None
#         self.btn_level_perpendicular = None

#     def update_plane(self, a, b, c, d):
#         x = np.linspace(0, 1, 10)
#         y = np.linspace(0, 1, 10)
#         X, Y = np.meshgrid(x, y)
#         Z = (-a * X - b * Y - d) / c

#         return go.Surface(x=X, y=Y, z=Z, opacity=0.5, showscale=False)

#     def update(self, change):
#         a = self.slider_a.value
#         b = self.slider_b.value
#         c = self.slider_c.value
#         d = self.slider_d.value
#         plane_thickness = self.slider_thickness.value
#         new_plane = self.update_plane(a, b, c, d)
#         self.fig_3d.data[1].x = new_plane.x
#         self.fig_3d.data[1].y = new_plane.y
#         self.fig_3d.data[1].z = new_plane.z

#         # Calculate intersection points
#         if np.isclose(a, 0) and np.isclose(c, 0):  # This means the plane is parallel to the XZ plane
#             intersection_points = self.points[np.abs(self.points[:, 1] - (-d/b)) <= plane_thickness]
#         else:
#             plane_distance = np.abs(a * self.points[:, 0] + b * self.points[:, 1] + c * self.points[:, 2] + d) / np.sqrt(
#                 a ** 2 + b ** 2 + c ** 2)
#             intersection_points = self.points[plane_distance <= plane_thickness]

#         # Calculate the thickness threshold from the origin plane
#         origin_plane_distance = np.abs(a * self.points[:, 0] + b * self.points[:, 1] + c * self.points[:, 2] + d) / np.sqrt(
#             a ** 2 + b ** 2 + c ** 2)
#         origin_plane_thickness = np.max(origin_plane_distance)

#         # Update the scatter plot with intersection points
#         self.fig_2d.data[0].x = intersection_points[:, 0]
#         self.fig_2d.data[0].y = intersection_points[:, 1]

#         # Generate or delete visual plane objects based on the thickness
#         existing_planes = len(self.fig_3d.data) - 1  # Exclude the scatter plot
#         thickness_range = self.slider_thickness.max - self.slider_thickness.min
#         thickness_increment = thickness_range * 0.1  # Generate new plane objects every 10% of the thickness slider range

#         target_planes = min(existing_planes, int(np.ceil((origin_plane_thickness - self.slider_thickness.min) / thickness_increment)))

#         if target_planes > existing_planes:  # Generate additional plane objects
#             offset_factor = (self.slider_thickness.value - self.slider_thickness.min) / thickness_range  # Offset factor based on the thickness slider value
#             for _ in range(target_planes - existing_planes):
#                 offset = offset_factor * origin_plane_thickness
#                 offset_plane = self.update_plane(a, b, c, d + offset)
#                 self.fig_3d.add_trace(go.Surface(x=offset_plane.x, y=offset_plane.y, z=offset_plane.z, opacity=0.5, showscale=False))

#         elif target_planes < existing_planes:  # Delete excess plane objects
#             for _ in range(existing_planes - target_planes):
#                 self.fig_3d.data.pop()
                
#     # def update(self, change):
#     #     a = self.slider_a.value
#     #     b = self.slider_b.value
#     #     c = self.slider_c.value
#     #     d = self.slider_d.value
#     #     plane_thickness = self.slider_thickness.value
#     #     new_plane = self.update_plane(a, b, c, d)
#     #     self.fig_3d.data[1].x = new_plane.x
#     #     self.fig_3d.data[1].y = new_plane.y
#     #     self.fig_3d.data[1].z = new_plane.z

#     #     # Calculate intersection points
#     #     if np.isclose(a, 0) and np.isclose(c, 0):  # This means the plane is parallel to the XZ plane
#     #         intersection_points = self.points[np.abs(self.points[:, 1] - (-d/b)) <= plane_thickness]
#     #     else:
#     #         plane_distance = np.abs(a * self.points[:, 0] + b * self.points[:, 1] + c * self.points[:, 2] + d) / np.sqrt(
#     #             a ** 2 + b ** 2 + c ** 2)
#     #         intersection_points = self.points[plane_distance <= plane_thickness]

#     #     # Update the scatter plot with intersection points
#     #     self.fig_2d.data[0].x = intersection_points[:, 0]
#     #     self.fig_2d.data[0].y = intersection_points[:, 1]

#     def level_parallel(self, _):
#         # Level the plane parallel to the xy-plane about its center
#         self.slider_a.value = 0
#         self.slider_b.value = 0
#         self.slider_c.value = 1
#         self.slider_d.value = -np.mean(self.points[:, 2])

#     def level_perpendicular(self, _):
#         # Level the plane parallel to the xz-plane about its center
#         self.slider_a.value = 0
#         self.slider_b.value = 1
#         self.slider_c.value = 0
#         self.slider_d.value = -np.mean(self.points[:, 1])

#     def load_data(self, hkl_coords, a1, a2, a3):
#         hkl_coords = np.asarray(hkl_coords)
#         a1 = np.asarray(a1)
#         a2 = np.asarray(a2)
#         a3 = np.asarray(a3)

#         # Check if the shapes are compatible
#         if hkl_coords.shape[1] != 3 or a1.shape != (3,) or a2.shape != (3,) or a3.shape != (3,):
#             raise ValueError("Invalid shapes of input arrays")

#         # Convert hkl coordinates to Cartesian coordinates
#         self.points = hkl_coords[:, 0][:, np.newaxis] * a1 + hkl_coords[:, 1][:, np.newaxis] * a2 + hkl_coords[:, 2][:, np.newaxis] * a3

#         # Rescale the plane to match the data range
#         self.scale_plane()

#         # Calculate the range of the data
#         data_range = np.max(self.points, axis=0) - np.min(self.points, axis=0)

#         # Create a 3D scatter plot
#         self.scatter = go.Scatter3d(x=self.points[:, 0], y=self.points[:, 1], z=self.points[:, 2], mode='markers',
#                                     marker=dict(size=5))

#         # Initial plane parameters
#         a, b, c, d = 1, -1, 1, 0

#         self.plane = self.update_plane(a, b, c, d)

#         # Create the initial 3D plot
#         self.fig_3d = go.FigureWidget(data=[self.scatter, self.plane])
#         self.fig_3d.layout.title = "3D Plot with Plane"
#         self.fig_3d.layout.width = 800
#         self.fig_3d.layout.height = 600

#         # Create the 2D scatter plot for intersection points
#         self.fig_2d = go.FigureWidget(data=[go.Scatter(x=[], y=[], mode='markers')])
#         self.fig_2d.layout.title = "Intersection Points"
#         self.fig_2d.layout.xaxis.title = 'X'
#         self.fig_2d.layout.yaxis.title = 'Y'

#         # Create sliders for plane parameters
#         self.slider_a = widgets.FloatSlider(min=-1, max=1, step=0.01, value=a, description='a (X coefficient)')
#         self.slider_b = widgets.FloatSlider(min=-1, max=1, step=0.01, value=b, description='b (Y coefficient)')
#         self.slider_c = widgets.FloatSlider(min=-1, max=1, step=0.01, value=c, description='c (Z coefficient)')
#         self.slider_d = widgets.FloatSlider(min=-1, max=1, step=0.01, value=d, description='d (Constant)')

#         # Update the thickness slider range
#         self.slider_thickness = widgets.FloatSlider(min=0.01, max=np.max(data_range), step=0.01, value=0.05, description='Thickness')

#         # Create spring-action buttons for leveling the plane
#         self.btn_level_parallel = widgets.Button(description="Level Parallel to XY-Plane")
#         self.btn_level_perpendicular = widgets.Button(description="Level Perpendicular to XY-Plane")

#         # Add event handlers to the buttons
#         self.btn_level_parallel.on_click(self.level_parallel)
#         self.btn_level_perpendicular.on_click(self.level_perpendicular)

#         # Add the observer to the sliders
#         self.slider_a.observe(self.update, names='value')
#         self.slider_b.observe(self.update, names='value')
#         self.slider_c.observe(self.update, names='value')
#         self.slider_d.observe(self.update, names='value')
#         self.slider_thickness.observe(self.update, names='value')

#         # Display the interactive plot and sliders
#         display(widgets.HBox([self.fig_3d, self.fig_2d]))
#         display(widgets.VBox([widgets.Label('Plane Parameters:'), self.slider_a, self.slider_b, self.slider_c, self.slider_d,
#                               self.slider_thickness, self.btn_level_parallel, self.btn_level_perpendicular]))

    
#     def scale_plane(self):
#         min_vals = np.min(self.points, axis=0)
#         max_vals = np.max(self.points, axis=0)
#         self.points = (self.points - min_vals) / (max_vals - min_vals)