In [None]:
import sys
import subprocess
import pkg_resources

required_packages = [
    'numpy', 'pandas', 'pymatgen', 'plotly', 'pyexcel_ods3', 'dash', 'cifkit'
]


def install(package, upgrade=False):
    """Install ``package`` with the running interpreter's pip.

    Args:
        package: Distribution name to pass to ``pip install``.
        upgrade: If True, pass ``--upgrade`` so an existing install is updated.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    cmd = [sys.executable, "-m", "pip", "install"]
    if upgrade:
        # "--upgrade" must be its own argv element; "--upgrade pkg" as a
        # single string is parsed by pip as an (unknown) option name.
        cmd.append("--upgrade")
    cmd.append(package)
    subprocess.check_call(cmd)


for package in required_packages:
    try:
        pkg_resources.require(package)
    except pkg_resources.DistributionNotFound:
        print(f"{package} not found. Installing...")
        install(package)
    except pkg_resources.VersionConflict:
        print(f"Updating {package}...")
        install(package, upgrade=True)

# Packages installed above after the interpreter started may not be found by
# the import system's cached directory listings until the caches are reset.
import importlib

importlib.invalidate_caches()

print("All required packages are installed and up to date.")

import os
import json
import numpy as np
import pandas as pd
from math import sin, radians, asin, degrees, pi, cos
from pymatgen.core import Structure, Lattice
from pymatgen.io.cif import CifParser
from pymatgen.analysis.diffraction.core import AbstractDiffractionPatternCalculator, DiffractionPattern, get_unique_families
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
import plotly.graph_objects as go
from pyexcel_ods3 import get_data
from dash import Dash, dcc, html
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from cifkit import CifEnsemble

# Characteristic X-ray wavelengths in angstroms, keyed by anode element and
# emission line. Insertion order matters: the interactive menu numbers the
# entries in this order.
WAVELENGTHS = {
    # Copper
    "CuKa": 1.54184,
    "CuKa2": 1.54439,
    "CuKa1": 1.54056,
    "CuKb1": 1.39222,
    # Molybdenum
    "MoKa": 0.71073,
    "MoKa2": 0.71359,
    "MoKa1": 0.70930,
    "MoKb1": 0.63229,
    # Chromium
    "CrKa": 2.29100,
    "CrKa2": 2.29361,
    "CrKa1": 2.28970,
    "CrKb1": 2.08487,
    # Iron
    "FeKa": 1.93735,
    "FeKa2": 1.93998,
    "FeKa1": 1.93604,
    "FeKb1": 1.75661,
    # Cobalt
    "CoKa": 1.79026,
    "CoKa2": 1.79285,
    "CoKa1": 1.78896,
    "CoKb1": 1.63079,
    # Silver
    "AgKa": 0.560885,
    "AgKa2": 0.563813,
    "AgKa1": 0.559421,
    "AgKb1": 0.497082,
}

# Load atomic scattering parameters. The relative path below is resolved
# against the current working directory.
atomic_scattering_params_path = "atomic_scattering_params.json"
if not os.path.exists(atomic_scattering_params_path):
    # Report the directory that was actually searched. Note that __file__ is
    # not defined inside a notebook, so it must not be used here.
    raise FileNotFoundError(
        f"Required file '{atomic_scattering_params_path}' not found in directory: {os.getcwd()}"
    )
with open(atomic_scattering_params_path, encoding="utf-8") as file:
    ATOMIC_SCATTERING_PARAMS = json.load(file)

class XRDCalculator(AbstractDiffractionPatternCalculator):
    """Powder X-ray diffraction pattern calculator.

    Computes peak positions, intensities, hkl families and d-spacings for a
    pymatgen ``Structure``. Atomic scattering coefficients come from the
    module-level ``ATOMIC_SCATTERING_PARAMS`` table and named radiation
    sources from ``WAVELENGTHS``.
    """

    # Radiation names accepted as the ``wavelength`` argument.
    AVAILABLE_RADIATION = tuple(WAVELENGTHS)

    def __init__(self, wavelength="CuKa", symprec: float = 0, debye_waller_factors=None):
        """Configure the calculator.

        Args:
            wavelength: A key of ``WAVELENGTHS`` (e.g. ``"CuKa"``) or a numeric
                wavelength in angstroms.
            symprec: When non-zero, each structure is symmetry-refined with
                this tolerance before the pattern is computed.
            debye_waller_factors: Optional mapping from element symbol to
                Debye-Waller factor; symbols not present use 0.

        Raises:
            TypeError: If ``wavelength`` is not a float, int or str.
            KeyError: If ``wavelength`` is a string not in ``WAVELENGTHS``.
        """
        if isinstance(wavelength, (float, int)):
            self.wavelength = wavelength
        elif isinstance(wavelength, str):
            self.radiation = wavelength
            self.wavelength = WAVELENGTHS[wavelength]
        else:
            raise TypeError(f"{type(wavelength)=} must be either float, int or str")
        self.symprec = symprec
        self.debye_waller_factors = debye_waller_factors or {}

    def get_pattern(self, structure: Structure, scaled=True, two_theta_range=(0, 90)):
        """Compute the powder XRD pattern of ``structure``.

        Args:
            structure: Structure to simulate.
            scaled: If True, normalize intensities so the strongest peak is 100.
            two_theta_range: ``(min, max)`` 2-theta window in degrees, or None
                for every reflection reachable at this wavelength.

        Returns:
            DiffractionPattern holding 2-theta positions, intensities, hkl
            families with multiplicities, and d-spacings. The pattern is
            empty when no reflection falls inside ``two_theta_range``.

        Raises:
            ValueError: If a species has no entry in ``ATOMIC_SCATTERING_PARAMS``.
        """
        if self.symprec:
            finder = SpacegroupAnalyzer(structure, symprec=self.symprec)
            structure = finder.get_refined_structure()

        wavelength = self.wavelength
        lattice = structure.lattice
        is_hex = lattice.is_hexagonal()

        # Bragg's law in reciprocal space, |g| = 2 sin(theta) / lambda, turns
        # the 2-theta window into a shell of reciprocal-vector lengths.
        min_r, max_r = (
            (0, 2 / wavelength)
            if two_theta_range is None
            else [2 * sin(radians(t / 2)) / wavelength for t in two_theta_range]
        )

        recip_lattice = lattice.reciprocal_lattice_crystallographic
        recip_pts = recip_lattice.get_points_in_sphere([[0, 0, 0]], [0, 0, 0], max_r)
        if min_r:
            recip_pts = [pt for pt in recip_pts if pt[1] >= min_r]

        # Flatten every (site, species) pair into parallel arrays so that the
        # structure factor of each reflection is one vectorized sum.
        _zs, _coeffs, _frac_coords, _occus, _dw_factors = [], [], [], [], []
        for site in structure:
            for sp, occu in site.species.items():
                _zs.append(sp.Z)
                try:
                    c = ATOMIC_SCATTERING_PARAMS[sp.symbol]
                except KeyError:
                    raise ValueError(
                        f"Unable to calculate XRD pattern as there is no scattering coefficients for {sp.symbol}."
                    ) from None
                _coeffs.append(c)
                _dw_factors.append(self.debye_waller_factors.get(sp.symbol, 0))
                _frac_coords.append(site.frac_coords)
                _occus.append(occu)

        zs = np.array(_zs)
        # NOTE(review): assumes every element's entry in the JSON has the same
        # number of (a, b) coefficient pairs so this forms a 3-D array — confirm.
        coeffs = np.array(_coeffs)
        frac_coords = np.array(_frac_coords)
        occus = np.array(_occus)
        dw_factors = np.array(_dw_factors)

        # peaks: 2-theta -> [summed intensity, list of hkl tuples, d-spacing].
        # two_thetas: the same keys in insertion order, for tolerance matching.
        peaks = {}
        two_thetas = []

        for hkl, g_hkl, _image, _ in sorted(recip_pts, key=lambda i: (i[1], -i[0][0], -i[0][1], -i[0][2])):
            if g_hkl == 0:
                continue  # the origin is not a reflection
            hkl = [int(round(i)) for i in hkl]
            theta = asin(wavelength * g_hkl / 2)
            # s = |g| / 2 = sin(theta) / lambda
            s = g_hkl / 2
            s2 = s**2
            g_dot_r = np.dot(frac_coords, np.transpose([hkl])).T[0]

            # Atomic scattering factor of every (site, species) entry at this s.
            fs = zs - 41.78214 * s2 * np.sum(
                coeffs[:, :, 0] * np.exp(-coeffs[:, :, 1] * s2),
                axis=1,
            )

            dw_correction = np.exp(-dw_factors * s2)
            f_hkl = np.sum(fs * occus * np.exp(2j * pi * g_dot_r) * dw_correction)
            # Lorentz-polarization correction.
            lorentz_factor = (1 + cos(2 * theta) ** 2) / (sin(theta) ** 2 * cos(theta))
            i_hkl = (f_hkl * f_hkl.conjugate()).real
            two_theta = degrees(2 * theta)

            if is_hex:
                # Report hexagonal reflections as four-index (h, k, -h-k, l).
                hkl = (hkl[0], hkl[1], -hkl[0] - hkl[1], hkl[2])

            # Reflections whose 2-theta coincides within tolerance are merged
            # into the earlier peak.
            match = np.where(
                np.abs(np.subtract(two_thetas, two_theta)) < AbstractDiffractionPatternCalculator.TWO_THETA_TOL
            )[0]
            if len(match) > 0:
                peak = peaks[two_thetas[match[0]]]
                peak[0] += i_hkl * lorentz_factor
                peak[1].append(tuple(hkl))
            else:
                peaks[two_theta] = [i_hkl * lorentz_factor, [tuple(hkl)], 1 / g_hkl]
                two_thetas.append(two_theta)

        if not peaks:
            # Nothing in range: return an empty pattern rather than letting
            # max() below fail on an empty sequence (normalizing would fail too).
            return DiffractionPattern([], [], [], [])

        max_intensity = max(v[0] for v in peaks.values())
        x, y, hkls, d_hkls = [], [], [], []
        for two_theta in sorted(peaks):
            intensity, hkl_list, d_hkl = peaks[two_theta]
            # Keep only peaks above SCALED_INTENSITY_TOL percent of the strongest.
            if intensity / max_intensity * 100 > AbstractDiffractionPatternCalculator.SCALED_INTENSITY_TOL:
                fam = get_unique_families(hkl_list)
                x.append(two_theta)
                y.append(intensity)
                hkls.append([{"hkl": hkl, "multiplicity": mult} for hkl, mult in fam.items()])
                d_hkls.append(d_hkl)
        xrd = DiffractionPattern(x, y, hkls, d_hkls)
        if scaled:
            xrd.normalize(mode="max", value=100)
        return xrd

def parse_xy_raw(file_path):
    """Read a whitespace-separated 2-theta/intensity text file.

    Only the first two columns are used, so files carrying extra columns
    (e.g. an error column) are accepted too.

    Args:
        file_path: Path to the text file; no header row is expected.

    Returns:
        DataFrame with columns ``'2_theta'`` and ``'intensity'``.
    """
    # Raw string: '\s' in a normal string literal is an invalid escape.
    data = pd.read_csv(file_path, sep=r'\s+', header=None, usecols=[0, 1])
    data.columns = ['2_theta', 'intensity']
    return data

def parse_dif(file_path):
    """Read a 2-theta/intensity table from a spreadsheet via ``pyexcel_ods3``.

    The sheet named ``'Sheet1'`` is used when present, otherwise the first
    sheet. The first row is treated as a header and skipped, only the first
    two columns are kept, and rows with missing values are dropped.

    NOTE(review): ``pyexcel_ods3`` is an ODS reader; confirm it can actually
    open the ``.dif`` files this is used for.

    Args:
        file_path: Path to the spreadsheet file.

    Returns:
        DataFrame with columns ``'2_theta'`` and ``'intensity'``.

    Raises:
        ValueError: If the file contains no sheets.
    """
    sheets = get_data(file_path)
    if not sheets:
        raise ValueError(f"No sheets found in '{file_path}'.")
    rows = sheets['Sheet1'] if 'Sheet1' in sheets else next(iter(sheets.values()))
    # Skip the header row and keep only the 2-theta and intensity columns, so
    # extra columns do not break the fixed two-name column assignment.
    data = pd.DataFrame([row[:2] for row in rows[1:]], columns=['2_theta', 'intensity'])
    data = data.dropna().reset_index(drop=True)
    return data

def plot_xrd(patterns, titles, wavelength, experimental_data=None):
    """Build a Plotly figure overlaying calculated (and optional measured) XRD data.

    Args:
        patterns: Calculated patterns; each must expose ``x`` (2-theta) and
            ``y`` (intensity). Drawn as 0.2-wide overlaid bars.
        titles: Legend names, paired with ``patterns`` positionally.
        wavelength: Value shown in the figure title.
        experimental_data: Optional DataFrame with ``'2_theta'`` and
            ``'intensity'`` columns, drawn as a line.

    Returns:
        plotly.graph_objects.Figure
    """
    traces = [
        go.Bar(x=pat.x, y=pat.y, name=label, width=0.2)
        for pat, label in zip(patterns, titles)
    ]
    if experimental_data is not None:
        traces.append(
            go.Scatter(
                x=experimental_data['2_theta'],
                y=experimental_data['intensity'],
                mode='lines',
                name='Experimental Data',
            )
        )

    figure = go.Figure(data=traces)
    figure.update_layout(
        title=f"XRD Patterns (Wavelength: {wavelength})",
        xaxis_title="2 Theta",
        yaxis_title="Intensity",
        template="plotly_white",
        barmode='overlay',
    )
    return figure

def preclean_cifs(directory):
    """Build a cifkit ``CifEnsemble`` for ``directory`` and report its files.

    NOTE(review): this function only prints progress messages; any actual
    cleaning presumably happens inside the ``CifEnsemble`` constructor —
    confirm against the cifkit documentation.

    Args:
        directory: Directory containing the CIF files.

    Returns:
        The ``CifEnsemble`` built from ``directory``.
    """
    cif_ensemble = CifEnsemble(directory)
    print(f"Cleaning {cif_ensemble.file_count} CIF files in {cif_ensemble.dir_path}")
    for cif_path in (entry.file_path for entry in cif_ensemble.cifs):
        print(f"Cleaning: {cif_path}")
    return cif_ensemble

# Module-level state shared by the Dash callbacks; the __main__ section fills
# it in before the server starts.
selected_wavelength = "CuKa"  # WAVELENGTHS key used when (re)computing patterns
experimental_data = None  # optional DataFrame with '2_theta' / 'intensity' columns
patterns = dict()  # CIF file basename -> calculated diffraction pattern
cif_structures = dict()  # CIF file basename -> pymatgen Structure

# Dash app
app = Dash(__name__)

# (id prefix, displayed symbol) of each lattice-parameter input, in display order.
_LATTICE_INPUT_FIELDS = (
    ('a', 'a'),
    ('b', 'b'),
    ('c', 'c'),
    ('alpha', '\u03B1'),  # Unicode for alpha
    ('beta', '\u03B2'),  # Unicode for beta
    ('gamma', '\u03B3'),  # Unicode for gamma
)


def _lattice_input_box(name, symbol, is_last):
    """Return an inline, labelled numeric input whose id is '<name>-input'.

    Every box except the last one gets a 10px right margin.
    """
    box_style = {'display': 'inline-block'}
    if not is_last:
        box_style['marginRight'] = '10px'
    return html.Div(
        [
            html.Label(f"{symbol}:"),
            dcc.Input(id=f"{name}-input", type='number', placeholder=symbol),
        ],
        style=box_style,
    )


app.layout = html.Div([
    html.H1("XRD Pattern Customizer"),
    dcc.Dropdown(
        id='cif-selector',
        options=[],
        value=None,
        placeholder="Select a CIF file",
    ),
    html.Div([
        _lattice_input_box(name, symbol, is_last=(pos == len(_LATTICE_INPUT_FIELDS) - 1))
        for pos, (name, symbol) in enumerate(_LATTICE_INPUT_FIELDS)
    ]),
    html.Button('Update', id='update-button', n_clicks=0),
    html.Button('Show CIF Summary', id='show-summary-button', n_clicks=0),
    html.Div(id='cif-summary', style={'marginTop': '20px', 'whiteSpace': 'pre-wrap'}),
    dcc.Graph(id='xrd-plot'),
])

@app.callback(
    [Output('cif-selector', 'options'),
     Output('cif-selector', 'value'),
     Output('a-input', 'value'),
     Output('b-input', 'value'),
     Output('c-input', 'value'),
     Output('alpha-input', 'value'),
     Output('beta-input', 'value'),
     Output('gamma-input', 'value')],
    [Input('cif-selector', 'value')]
)
def update_cif_selection(selected_cif):
    """Refresh the CIF dropdown and fill the lattice inputs for the selection.

    If ``selected_cif`` is empty or not a loaded CIF (e.g. a stale value),
    the first entry of ``cif_structures`` is selected instead of failing.

    Args:
        selected_cif: Current dropdown value (a key of ``cif_structures`` or None).

    Returns:
        Tuple ``(options, selected name, a, b, c, alpha, beta, gamma)``;
        everything except ``options`` is None when no CIFs are loaded.
    """
    options = [{'label': name, 'value': name} for name in cif_structures]

    current = selected_cif if selected_cif in cif_structures else next(iter(cif_structures), None)
    if current is None:
        return options, None, None, None, None, None, None, None

    lattice = cif_structures[current].lattice
    return options, current, lattice.a, lattice.b, lattice.c, lattice.alpha, lattice.beta, lattice.gamma

@app.callback(
    Output('xrd-plot', 'figure'),
    [Input('update-button', 'n_clicks')],
    [State('cif-selector', 'value'),
     State('a-input', 'value'),
     State('b-input', 'value'),
     State('c-input', 'value'),
     State('alpha-input', 'value'),
     State('beta-input', 'value'),
     State('gamma-input', 'value')]
)
def update_xrd_plot(n_clicks, selected_cif, a, b, c, alpha, beta, gamma):
    """Recompute the selected CIF's pattern with user-edited lattice parameters.

    The species, occupancies and fractional coordinates of the loaded
    structure are kept; only the lattice is replaced. The new pattern
    overwrites that CIF's entry in the global ``patterns`` dict, and all
    stored patterns are re-plotted.

    Raises:
        PreventUpdate: If no loaded CIF is selected, or a lattice parameter
            is missing (cleared input) or not positive.
    """
    if not selected_cif or selected_cif not in cif_structures:
        raise PreventUpdate

    params = (a, b, c, alpha, beta, gamma)
    # A cleared numeric input arrives as None; Lattice.from_parameters cannot
    # build a lattice from missing or non-positive values.
    if any(p is None for p in params) or min(params) <= 0:
        raise PreventUpdate

    # Create a new structure with updated lattice parameters. site.species is
    # the full species->occupancy composition of each site, so disordered
    # (partially occupied) CIFs work; Structure.species only supports
    # ordered structures.
    original_structure = cif_structures[selected_cif]
    new_lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma)
    new_structure = Structure(
        new_lattice,
        [site.species for site in original_structure],
        original_structure.frac_coords,
    )

    # Calculate XRD pattern and store it under the CIF's name.
    calculator = XRDCalculator(wavelength=selected_wavelength)
    patterns[selected_cif] = calculator.get_pattern(new_structure)

    # Plot XRD patterns
    return plot_xrd(list(patterns.values()), list(patterns.keys()), selected_wavelength, experimental_data)

@app.callback(
    Output('cif-summary', 'children'),
    [Input('show-summary-button', 'n_clicks')],
    [State('cif-selector', 'value')]
)
def show_cif_summary(n_clicks, selected_cif):
    """Return a plain-text summary of the selected (originally loaded) structure.

    Args:
        n_clicks: Click count of the summary button; nothing is shown until
            it has been clicked.
        selected_cif: Key of ``cif_structures`` to summarise.

    Returns:
        Multi-line summary string, or '' when there is nothing to show.
    """
    if not n_clicks or not selected_cif or selected_cif not in cif_structures:
        return ''

    structure = cif_structures[selected_cif]
    lattice = structure.lattice
    # Space-group detection runs a symmetry analysis; do it only once.
    sg_symbol, sg_number = structure.get_space_group_info()

    lines = [
        f"Summary for {selected_cif}:\n",
        f"Formula: {structure.composition.reduced_formula}",
        f"Space Group: {sg_symbol} ({sg_number})",
        "Lattice Parameters:",
        f"  a = {lattice.a:.4f} Å",
        f"  b = {lattice.b:.4f} Å",
        f"  c = {lattice.c:.4f} Å",
        f"  α = {lattice.alpha:.4f}°",
        f"  β = {lattice.beta:.4f}°",
        f"  γ = {lattice.gamma:.4f}°",
        f"Volume: {structure.volume:.4f} Å³",
        f"Number of sites: {len(structure)}",
    ]
    return "\n".join(lines) + "\n"

def _prompt_indices(prompt, count, allow_multiple=True):
    """Prompt until the user enters valid 1-based indices into ``count`` items.

    Args:
        prompt: Text passed to ``input``.
        count: Number of selectable items.
        allow_multiple: Accept a comma-separated list instead of a single index.

    Returns:
        List of 0-based indices, in the order entered.
    """
    while True:
        raw = input(prompt)
        parts = raw.split(',') if allow_multiple else [raw]
        try:
            indices = [int(part) - 1 for part in parts if part.strip()]
        except ValueError:
            print("Please enter whole numbers only.")
            continue
        # Reject 0 and negatives explicitly: as list indices they would count
        # from the end and silently pick the wrong item.
        if indices and all(0 <= i < count for i in indices):
            return indices
        print(f"Please enter numbers between 1 and {count}.")


if __name__ == "__main__":
    # Step 1: Enter the directory containing CIF files
    directory = input("Enter the directory containing CIF files: ").strip()

    # Pre-clean CIF files
    cleaned_ensemble = preclean_cifs(directory)

    # Step 2: Select CIF files from the directory (sorted for stable numbering;
    # extension match is case-insensitive so ".CIF" files are listed too)
    files = sorted(f for f in os.listdir(directory) if f.lower().endswith('.cif'))
    if not files:
        raise FileNotFoundError(f"No CIF files found in directory: {directory}")
    print("Available CIF files:")
    for idx, file in enumerate(files, start=1):
        print(f"{idx}. {file}")
    selected_indices = _prompt_indices(
        "Enter the indices of CIF files to select (comma separated): ", len(files)
    )
    selected_files = [os.path.join(directory, files[i]) for i in selected_indices]

    # Step 3: Select the wavelength
    wavelength_names = list(WAVELENGTHS)
    print("Available wavelengths:")
    for idx, wl in enumerate(wavelength_names, start=1):
        print(f"{idx}. {wl}")
    (wavelength_index,) = _prompt_indices(
        "Enter the index of the wavelength to select: ", len(wavelength_names), allow_multiple=False
    )
    selected_wavelength = wavelength_names[wavelength_index]

    # Step 4: Ask if you want to upload an experimental file
    answer = input("Do you want to upload an experimental file for data comparison? (yes/no): ")
    upload_experimental = answer.strip().lower() in ('yes', 'y')
    experimental_file_path = None
    if upload_experimental:
        experimental_file_path = input("Enter the path of the experimental file (XY RAW or DIF): ").strip()
        if not os.path.exists(experimental_file_path):
            raise FileNotFoundError(f"File '{experimental_file_path}' not found.")

        if experimental_file_path.lower().endswith('.dif'):
            experimental_data = parse_dif(experimental_file_path)
        else:
            experimental_data = parse_xy_raw(experimental_file_path)

        # Normalize experimental data to a maximum of 100, like the scaled
        # calculated patterns; skip when the maximum is not positive (or NaN
        # for an empty table) to avoid dividing by zero.
        max_intensity = experimental_data['intensity'].max()
        if max_intensity > 0:
            experimental_data['intensity'] = (experimental_data['intensity'] / max_intensity) * 100

    # Calculate XRD patterns
    calculator = XRDCalculator(wavelength=selected_wavelength)
    for file in selected_files:
        name = os.path.basename(file)
        structure = Structure.from_file(file)
        cif_structures[name] = structure
        patterns[name] = calculator.get_pattern(structure)

    # Run the Dash app. Dash 3 removed run_server; use app.run when available.
    run_app = getattr(app, "run", None) or app.run_server
    run_app(debug=True, port=8050)

All required packages are installed and up to date.
