In [None]:
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display, clear_output
from ase.build import mx2, graphene
from ase.visualize import view
from ase.dft.kpoints import sc_special_points as special_points, get_bandpath
import pyscf.pbc.tools.pyscf_ase as pyscf_ase
import pyscf.pbc.gto as pbcgto
import pyscf.pbc.dft as pbcdft
import matplotlib.pyplot as plt

def create_material(material, lattice_a, vacuum, basis, pseudo):
    """Build a 2D material with ASE, visualize it, and return it as a PySCF cell.

    Parameters
    ----------
    material : str
        ``'Graphene'`` or one of the 2H transition-metal dichalcogenides
        ``'MoS2'``, ``'WS2'``, ``'MoSe2'``, ``'WSe2'``.
    lattice_a : float
        In-plane lattice constant passed to the ASE builder as ``a`` (Å).
    vacuum : float
        Vacuum padding passed to the ASE builder.
    basis : str
        PySCF basis-set name assigned to ``Cell.basis``.
    pseudo : str
        PySCF pseudopotential name assigned to ``Cell.pseudo``.

    Returns
    -------
    pyscf.pbc.gto.Cell
        The built PySCF cell.

    Raises
    ------
    ValueError
        If ``material`` is not one of the supported names.
    """
    tmd_formulas = ('MoS2', 'WS2', 'MoSe2', 'WSe2')

    if material == 'Graphene':
        atoms = graphene(formula="C2", a=lattice_a, size=(1, 1, 1), vacuum=vacuum)
    elif material in tmd_formulas:
        # 2H polytype; ASE's ``thickness`` argument is fixed at 3.0 here.
        atoms = mx2(formula=material, kind='2H', a=lattice_a, thickness=3.0, vacuum=vacuum)
    else:
        raise ValueError("Unsupported material")

    print(f"{material} Cell Volume: {atoms.get_volume()} Å³")

    # Copy atoms and lattice vectors from the ASE object into a PySCF cell.
    pyscf_cell = pbcgto.Cell()
    pyscf_cell.atom = pyscf_ase.ase_atoms_to_pyscf(atoms)
    pyscf_cell.a = atoms.cell
    pyscf_cell.basis = basis
    pyscf_cell.pseudo = pseudo
    pyscf_cell.verbose = 3
    pyscf_cell.build(None, None)

    # NOTE(review): ase.visualize.view is called with its default viewer on
    # every invocation; confirm that is the intended behavior in a notebook.
    view(atoms)
    return pyscf_cell

def calculate_energy(cell, xc):
    """Run a restricted Kohn-Sham SCF calculation on ``cell`` and report the energy.

    No k-point is passed to ``pbcdft.RKS``, so PySCF's single default k-point
    is used.

    Parameters
    ----------
    cell : pyscf.pbc.gto.Cell
        Built PySCF cell.
    xc : str
        Exchange-correlation functional string assigned to ``mf.xc``.

    Returns
    -------
    The PySCF RKS object after ``kernel()`` has run.
    """
    solver = pbcdft.RKS(cell)
    solver.xc = xc
    e_tot = solver.kernel()
    print(f"Total Energy: {e_tot:.6f} Hartree")
    return solver
    
def calculate_bandstructure(cell, mf, npoints=50):
    """Plot the band structure of ``cell`` along the hexagonal path Γ–M–K–Γ.

    Band energies are obtained from ``mf.get_bands`` at k-points along the
    path, shifted so that the valence-band maximum (VBM) sits at 0 eV, and
    drawn with matplotlib.

    Parameters
    ----------
    cell : pyscf.pbc.gto.Cell
        Built PySCF cell; ``cell.a`` supplies the lattice vectors.
    mf : PySCF mean-field object
        Object providing ``get_bands``, e.g. the one returned by
        ``calculate_energy``.
    npoints : int, optional
        Total number of k-points sampled along the path. The default of 50
        matches the previously hard-coded value.
    """
    hex_points = special_points['hexagonal']
    path_names = ('G', 'M', 'K', 'G')
    scaled_kpts, kpath, sp_points = get_bandpath(
        [hex_points[name] for name in path_names], cell.a, npoints=npoints
    )
    # get_bandpath returns scaled (fractional) k-points; PySCF needs absolute ones.
    band_kpts = cell.get_abs_kpts(scaled_kpts)

    # Compute band energies.
    e_kn = mf.get_bands(band_kpts)[0]
    # Restricted (closed-shell) occupation: two electrons per band, so the
    # highest occupied band index is nelectron // 2 - 1.
    homo_index = cell.nelectron // 2 - 1
    vbmax = max(en[homo_index] for en in e_kn)  # Valence Band Maximum
    e_kn = [en - vbmax for en in e_kn]  # Align to VBM

    # Plot band structure
    au2ev = 27.21139  # Hartree to eV conversion
    plt.figure(figsize=(6, 6))
    # Take the band count from the returned eigenvalues rather than assuming
    # it equals the number of atomic orbitals.
    nbands = len(e_kn[0])
    for n in range(nbands):
        plt.plot(kpath, [e[n] * au2ev for e in e_kn], color='#4169E1')
    for p in sp_points:
        plt.axvline(p, color='k', linestyle='--', linewidth=0.5)
    plt.axhline(0, color='k', linestyle='-', linewidth=0.5)
    # Use mathtext's \Gamma command instead of a raw Unicode glyph inside $...$,
    # so the label does not depend on the math font's Unicode coverage.
    plt.xticks(sp_points, [r'$\Gamma$', '$M$', '$K$', r'$\Gamma$'])
    plt.xlabel('Wave Vector')
    plt.ylabel('Energy (eV)')
    plt.title('Band Structure')
    plt.grid()
    plt.show()

def interactive_pyscf(material, lattice_a, vacuum, basis, pseudo, xc):
    """Widget callback: build the cell, run the SCF, then plot the bands.

    Clears the previous output first (``wait=True`` defers clearing until new
    output arrives), then runs ``create_material``, ``calculate_energy`` and
    ``calculate_bandstructure`` in order, printing a progress line before each.
    """
    clear_output(wait=True)

    print(f"Creating {material} cell...")
    pyscf_cell = create_material(
        material=material,
        lattice_a=lattice_a,
        vacuum=vacuum,
        basis=basis,
        pseudo=pseudo,
    )

    print("\nCalculating total energy...")
    scf = calculate_energy(cell=pyscf_cell, xc=xc)

    print("\nCalculating band structure...")
    calculate_bandstructure(cell=pyscf_cell, mf=scf)

# Widget components
#
# Every value change re-runs a full SCF and band-structure calculation through
# interactive_output, so the sliders use continuous_update=False: the callback
# fires once when the slider is released, not for every intermediate value
# while dragging.
#
# Several descriptions are longer than the default label width; the
# 'initial' description width lets them show in full.
_wide_label = {'description_width': 'initial'}

lattice_a_slider = widgets.FloatSlider(
    value=2.46,
    min=2.0,
    max=4.0,
    step=0.01,
    description='Lattice (a):',
    continuous_update=False,
)
vacuum_slider = widgets.FloatSlider(
    value=10.0,
    min=5.0,
    max=20.0,
    step=0.1,
    description='Vacuum:',
    continuous_update=False,
)
basis_dropdown = widgets.Dropdown(
    options=['gth-szv', 'gth-dzvp', 'gth-tzvp', 'cc-pvdz', 'sto-3g'],
    value='gth-szv',
    description='Basis Set:',
)
pseudo_dropdown = widgets.Dropdown(
    options=['gth-pade', 'gth-blyp', 'gth-pbe', 'gth-hf'],
    value='gth-pade',
    description='Pseudopotential:',
    style=_wide_label,
)
xc_dropdown = widgets.Dropdown(
    options=['lda,vwn', 'pbe', 'b88,lyp', 'pbe0', 'b3lyp', 'tpss'],
    value='lda,vwn',
    description='Exchange correlation functional:',
    style=_wide_label,
)
material_dropdown = widgets.Dropdown(
    options=['Graphene'],
    value='Graphene',
    description='Material:',
)

# Interactive display
ui = widgets.VBox([
    material_dropdown,
    lattice_a_slider,
    vacuum_slider,
    basis_dropdown,
    pseudo_dropdown,
    xc_dropdown,
])
# NOTE(review): nothing writes to `out`. The callback's prints and plots go
# to the Output widget returned by interactive_output below. `out` is kept
# (and displayed) only so existing references to it keep working.
out = widgets.Output()

interactive_display = widgets.interactive_output(
    interactive_pyscf,
    {
        'material': material_dropdown,
        'lattice_a': lattice_a_slider,
        'vacuum': vacuum_slider,
        'basis': basis_dropdown,
        'pseudo': pseudo_dropdown,
        'xc': xc_dropdown,
    },
)

# Display interface
display(ui, out, interactive_display)
