In [None]:
# This notebook asks you to investigate a photoinitiator of your choice.

In [None]:
# Start-up: imports, a PySCF-backed ASE calculator, and analysis helpers

from pyscf import gto, scf, dft, grad, mp, cc, grad, hessian
from ase import Atoms, optimize
from ase.calculators.calculator import Calculator, all_changes
import numpy as np
import matplotlib.pyplot as plt
from ase.build import molecule
from ase.visualize import view
from ase.optimize import BFGS

from weas_widget import WeasWidget

# Define a custom PySCF ASE calculator

class PySCFCalculator(Calculator):
    """ASE calculator backed by PySCF (HF, DFT, MP2 or CCSD).

    Results follow ASE's unit conventions: ``results['energy']`` is in eV and
    ``results['forces']`` in eV/Å (PySCF works in Hartree and Hartree/Bohr, so
    values are converted before being stored). If ``frequencies`` is True,
    ``results['hessian']`` additionally holds the Cartesian Hessian as a
    (3N, 3N) array in PySCF's native Hartree/Bohr^2, rows/columns ordered
    x0, y0, z0, x1, ...

    After each calculation the PySCF molecule and mean-field objects are cached
    as ``_mol`` and ``_mf`` so orbital-analysis helpers can use them.

    Parameters
    ----------
    basis : str
        Basis-set name passed to PySCF.
    charge : int
        Total molecular charge.
    spin : int
        PySCF ``Mole.spin`` value (N_alpha - N_beta).
    method : str
        'RHF', 'ROHF', 'UHF', 'RKS', 'UKS', 'MP2' or 'CCSD' (case-insensitive).
        MP2 and CCSD run on top of an RHF reference.
    xc : str
        Exchange-correlation functional; only used for RKS/UKS.
    frequencies : bool
        Also compute the analytic Hessian (RHF and RKS only).
    **kwargs
        Forwarded to ``ase.calculators.calculator.Calculator``.
    """

    implemented_properties = ['energy', 'forces']

    def __init__(self, basis='sto-3g', charge=0, spin=0, method='RHF', xc='pbe', frequencies=False, **kwargs):
        super().__init__(**kwargs)
        self.basis = basis
        self.charge = charge
        self.spin = spin
        self.method = method.upper()
        self.xc = xc
        self.frequencies = frequencies
        self._mol = None  # last PySCF Mole built by calculate()
        self._mf = None   # last mean-field object run by calculate()

    def calculate(self, atoms=None, properties=['energy'], system_changes=all_changes):
        """Run PySCF on ``atoms`` and fill ``self.results``.

        Energy and forces are always computed together (one gradient call
        yields both); the Hessian is added when ``self.frequencies`` is set.
        """
        from ase.units import Bohr, Hartree

        super().calculate(atoms, properties, system_changes)
        # The base class stores a copy in self.atoms; using it also covers atoms=None.
        atoms = self.atoms

        mol = self._build_mol(atoms)
        self._mol = mol

        # --- Mean-field step ---
        mf = self._get_scf_method(mol).run()
        self._mf = mf  # cache for post-HF and orbital analysis

        # --- Energy (Hartree) and nuclear gradient (Hartree/Bohr) ---
        post_hf_class = {'MP2': mp.MP2, 'CCSD': cc.CCSD}.get(self.method)
        if post_hf_class is not None:
            post = post_hf_class(mf).run()
            energy = post.e_tot
            # PySCF exposes post-HF analytic gradients through nuc_grad_method().
            gradient = post.nuc_grad_method().kernel()
        else:
            energy = mf.energy_tot()
            gradient = self._get_grad_method(mf).kernel()

        # ASE expects eV and eV/Å; forces are the negative gradient.
        self.results = {
            'energy': energy * Hartree,
            'forces': -gradient * (Hartree / Bohr),
        }

        # --- Frequencies (optional): a single Hessian evaluation for any supported method ---
        if self.frequencies:
            hess_raw = self._get_hessian_method(mf).kernel()  # (natm, natm, 3, 3), Hartree/Bohr^2
            natoms = len(atoms)
            self.results['hessian'] = hess_raw.transpose(0, 2, 1, 3).reshape(3 * natoms, 3 * natoms)

    def _build_mol(self, atoms):
        """Build a PySCF Mole (Angstrom input) for ``atoms`` with this calculator's settings."""
        mol = gto.Mole()
        mol.atom = "; ".join(
            f"{s} {x} {y} {z}"
            for s, (x, y, z) in zip(atoms.get_chemical_symbols(), atoms.get_positions())
        )
        mol.basis = self.basis
        mol.charge = self.charge
        mol.spin = self.spin
        mol.unit = 'Angstrom'
        mol.build()
        return mol

    def _get_scf_method(self, mol):
        """Return the (unrun) SCF/DFT object for ``self.method``."""
        m = self.method
        if m == 'RHF':
            return scf.RHF(mol)
        elif m == 'ROHF':
            return scf.ROHF(mol)
        elif m == 'UHF':
            return scf.UHF(mol)
        elif m == 'RKS':
            return dft.RKS(mol).set(xc=self.xc)
        elif m == 'UKS':
            return dft.UKS(mol).set(xc=self.xc)
        elif m in ['MP2', 'CCSD']:
            return scf.RHF(mol)  # reference for post-HF
        else:
            raise ValueError(f"Unsupported method: {m}")

    def _get_grad_method(self, mf):
        """Return the mean-field gradient object for ``self.method``."""
        m = self.method
        if m == 'RHF': return grad.RHF(mf)
        if m == 'ROHF': return grad.ROHF(mf)
        if m == 'UHF': return grad.UHF(mf)
        if m == 'RKS': return grad.RKS(mf)
        if m == 'UKS': return grad.UKS(mf)
        raise ValueError(f"No grad available for {m}")

    def _get_hessian_method(self, mf):
        """Return the analytic Hessian object (RHF and RKS only)."""
        m = self.method
        if m == 'RHF':
            from pyscf.hessian import rhf as hess_rhf
            return hess_rhf.Hessian(mf)
        if m == 'RKS':
            from pyscf.hessian import rks as hess_rks
            return hess_rks.Hessian(mf)
        raise ValueError(f"Hessian not implemented for method: {m}")

    def get_potential_energy(self, atoms=None, **kwargs):
        """Return the energy in eV, recalculating if ``atoms`` changed since the last run."""
        # get_property compares atoms with the last calculated structure and
        # discards stale results; checking only for a cached value would hand an
        # optimizer the first geometry's energy on every step.
        return self.get_property('energy', atoms)

    def get_forces(self, atoms=None, **kwargs):
        """Return the forces in eV/Å, recalculating if ``atoms`` changed since the last run."""
        return self.get_property('forces', atoms)


def ase_to_pyscf(atoms, basis='sto-3g', charge=0, spin=0):
    """Build a PySCF ``Mole`` from an ASE ``Atoms`` object.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure whose positions (Å) are copied into the molecule.
    basis : str
        Basis-set name passed to PySCF.
    charge : int
        Total molecular charge (default 0).
    spin : int
        PySCF ``Mole.spin`` value, N_alpha - N_beta (default 0). Must be set
        for molecules with an odd electron count, otherwise ``build`` fails.

    Returns
    -------
    pyscf.gto.Mole
        The built molecule.
    """
    mol = gto.Mole()
    mol.atom = "\n".join(
        f"{symbol} {x} {y} {z}"
        for symbol, (x, y, z) in zip(atoms.get_chemical_symbols(), atoms.get_positions())
    )
    mol.basis = basis
    mol.unit = 'Angstrom'  # ASE positions are in Å (PySCF's default, stated explicitly)
    mol.charge = charge
    mol.spin = spin
    mol.build()
    return mol


# Convert PySCF molecule to ASE Atoms object for visualization
def pyscf_to_ase(mol, basis='sto-3g'):
    """Convert a PySCF ``Mole`` into an ASE ``Atoms`` object with positions in Å.

    ``Mole.atom_coords()`` returns Bohr by default while ASE expects Angstrom,
    so the coordinates are requested in Angstrom explicitly. Pure element
    symbols are used so labelled atoms (e.g. "H1") still map to valid elements.

    ``basis`` is unused and kept only for backward compatibility.
    """
    positions = mol.atom_coords(unit='Angstrom')
    symbols = [mol.atom_pure_symbol(i) for i in range(mol.natm)]
    return Atoms(symbols=symbols, positions=positions)



from scipy.linalg import eigh
from ase.units import _hplanck, _c, _amu, Bohr, Hartree

def hessian_to_frequencies(hessian, atoms):
    """Convert a Cartesian Hessian (Hartree/Bohr^2) into harmonic wavenumbers (cm^-1).

    Parameters
    ----------
    hessian : array_like, shape (3N, 3N)
        Cartesian second derivatives in Hartree/Bohr^2, ordered x0, y0, z0, x1, ...
        (e.g. ``results['hessian']`` from PySCFCalculator).
    atoms : ase.Atoms or any object with ``__len__`` and ``get_masses()``
        Supplies the masses (amu) of the same N atoms the Hessian describes.

    Returns
    -------
    freqs : ndarray, shape (3N,)
        Wavenumbers in cm^-1, in ascending eigenvalue order. Negative values
        mark imaginary modes (negative curvature). Translational/rotational
        modes are expected near zero.
    eigvecs : ndarray, shape (3N, 3N)
        Columns are eigenvectors of the mass-weighted Hessian (normal modes).

    Raises
    ------
    ValueError
        If the Hessian is not (3N, 3N) for N = len(atoms).
    """
    from scipy import constants as sc

    natoms = len(atoms)
    hessian = np.asarray(hessian, dtype=float)
    if hessian.shape != (3 * natoms, 3 * natoms):
        raise ValueError(
            f"Hessian shape {hessian.shape} does not match {natoms} atoms "
            f"(expected {(3 * natoms, 3 * natoms)})"
        )
    # eigh reads only one triangle, so remove any numerical asymmetry first.
    hessian = 0.5 * (hessian + hessian.T)

    # Mass-weight in Hartree / (Bohr^2 amu): H_ij / sqrt(m_i m_j).
    inv_sqrt_mass = 1.0 / np.sqrt(np.repeat(np.asarray(atoms.get_masses(), dtype=float), 3))
    mw_hessian = hessian * np.outer(inv_sqrt_mass, inv_sqrt_mass)

    eigvals, eigvecs = np.linalg.eigh(mw_hessian)

    # Eigenvalue in Hartree/(Bohr^2 amu) -> omega^2 in s^-2.
    hartree_j = sc.physical_constants['Hartree energy'][0]
    bohr_m = sc.physical_constants['Bohr radius'][0]
    to_si = hartree_j / (bohr_m ** 2 * sc.atomic_mass)
    # Angular frequency (rad/s) -> wavenumber (cm^-1): divide by 2*pi*c with c in cm/s.
    to_wavenumber = 1.0 / (2.0 * np.pi * sc.c * 100.0)

    freqs = np.sign(eigvals) * np.sqrt(np.abs(eigvals) * to_si) * to_wavenumber
    return freqs, eigvecs


from pyscf.tools import cubegen
from ase.io.cube import read_cube_data
from weas_widget import WeasWidget

def print_orbitals(calc):
    """Print MO energies and occupancies from a PySCFCalculator."""
    mf = getattr(calc, "_mf", None)
    if mf is None:
        raise RuntimeError("Calculator has no stored SCF object. Run atoms.get_potential_energy() first.")
    
    print(" MO Index | Energy (Ha) | Occupancy ")
    print("----------|-------------|-----------")
    for i, (e, occ) in enumerate(zip(mf.mo_energy, mf.mo_occ)):
        print(f"{i:>9} | {e:>11.6f} | {occ:>9.2f}")


def generate_and_view_cube(atoms, calc, orbital_index, isovalue=0.01, spin=0):
    """Write a cube file for one MO and return a WeasWidget showing its isosurfaces.

    Parameters
    ----------
    atoms : ase.Atoms
        Not used; the geometry comes from the calculator's cached PySCF
        molecule. Kept for backward compatibility.
    calc : PySCFCalculator
        Calculator that has already run (``_mol`` and ``_mf`` populated).
    orbital_index : int
        MO index, same numbering as ``print_orbitals`` (negative indices count
        from the top, as in numpy).
    isovalue : float
        Positive isosurface level; the negative lobe is drawn at ``-isovalue``.
    spin : int
        Spin channel (0 = alpha, 1 = beta) for unrestricted results; ignored
        for restricted ones.

    Returns
    -------
    WeasWidget
        Viewer with the molecule and the orbital's volumetric data.

    Raises
    ------
    RuntimeError
        If no calculation has been run on ``calc``.
    IndexError
        If ``orbital_index`` is outside the available orbitals.
    """
    mol = getattr(calc, "_mol", None)
    mf = getattr(calc, "_mf", None)

    if mol is None or mf is None:
        raise RuntimeError("PySCF molecule or SCF not found. Did you run atoms.get_potential_energy()?")

    mo_coeff = np.asarray(mf.mo_coeff)
    unrestricted = mo_coeff.ndim == 3  # (spin, nao, nmo) for UHF/UKS
    if unrestricted:
        mo_coeff = mo_coeff[spin]

    n_orbitals = mo_coeff.shape[1]
    if not -n_orbitals <= orbital_index < n_orbitals:
        raise IndexError(f"orbital_index {orbital_index} is out of range for {n_orbitals} orbitals")

    cube_filename = (
        f"orbital_{orbital_index}_spin{spin}.cube" if unrestricted else f"orbital_{orbital_index}.cube"
    )

    cubegen.orbital(
        mol, cube_filename, mo_coeff[:, orbital_index],
        nx=100, ny=100, nz=100,
        margin=10  # extra space (PySCF units) around the molecule in the grid box
    )

    volume, cube_atoms = read_cube_data(cube_filename)

    viewer = WeasWidget()
    viewer.from_ase(cube_atoms)
    viewer.avr.iso.volumetric_data = {"values": volume}
    viewer.avr.iso.settings = {
        "positive": {"isovalue": isovalue},
        "negative": {"isovalue": -isovalue, "color": "yellow"}
    }
    return viewer


from ase.units import Hartree, eV

def get_excitation_energies(calc, n=5):
    """Return the lowest ``n`` HOMO -> virtual orbital-energy gaps in eV.

    These are differences of SCF orbital energies (eps_virtual - eps_HOMO), a
    crude estimate of excitation energies; no TDDFT/CIS response calculation
    is performed.

    For unrestricted (UHF/UKS) results, gaps are formed within each spin
    channel (HOMO_alpha -> alpha virtuals, HOMO_beta -> beta virtuals) and
    then merged.

    Parameters
    ----------
    calc : PySCFCalculator
        Calculator that has already run (``_mf`` populated).
    n : int
        Maximum number of gaps to return; values <= 0 give an empty list.

    Returns
    -------
    list of float
        Gaps in eV, in ascending order.

    Raises
    ------
    RuntimeError
        If no SCF object is cached on ``calc``.
    ValueError
        If no spin channel has both occupied and virtual orbitals.
    """
    mf = getattr(calc, "_mf", None)
    if mf is None:
        raise RuntimeError("Calculator must have an SCF object (_mf). Run get_potential_energy() first.")

    mo_energy = np.asarray(mf.mo_energy)
    mo_occ = np.asarray(mf.mo_occ)
    # Restricted results are 1-D; unrestricted ones carry a leading spin axis.
    channels = [(mo_energy, mo_occ)] if mo_energy.ndim == 1 else list(zip(mo_energy, mo_occ))

    gaps = []
    for energies, occupations in channels:
        occupied = energies[occupations > 0]
        virtual = energies[occupations == 0]
        if occupied.size and virtual.size:
            gaps.extend(np.sort(virtual) - occupied.max())

    if not gaps:
        raise ValueError("No occupied or virtual orbitals found.")

    gaps.sort()
    return [float(gap * Hartree / eV) for gap in gaps[:max(n, 0)]]



import matplotlib.pyplot as plt
from pyscf import gto, scf, dft
from ase import Atoms

def compare_basis_sets(atoms_or_str, basis_sets, method='RHF', xc='b3lyp', title=None, charge=0, spin=0):
    """
    Compare the SCF total energy of one geometry across basis sets and plot it.

    Parameters:
        atoms_or_str : ASE Atoms object or PySCF atom string (Angstrom)
        basis_sets   : list of basis sets (e.g., ['sto-3g', '6-31g', 'cc-pvdz'])
        method       : 'RHF', 'UHF', 'ROHF', 'RKS', or 'UKS' (case-insensitive)
        xc           : XC functional (only used for DFT)
        title        : Optional plot title
        charge       : total molecular charge (default 0)
        spin         : PySCF spin value, N_alpha - N_beta (default 0)

    Returns:
        basis_sets   : list of basis set names (the input, unchanged)
        energies     : list of total energies (in Hartree)

    Raises:
        ValueError   : if ``method`` is not one of the supported SCF methods.

    A warning is issued for any basis in which the SCF does not converge,
    since that energy is not reliable.
    """
    import warnings

    method = method.upper()
    scf_builders = {
        'RHF': scf.RHF,
        'ROHF': scf.ROHF,
        'UHF': scf.UHF,
        'RKS': lambda mol: dft.RKS(mol).set(xc=xc),
        'UKS': lambda mol: dft.UKS(mol).set(xc=xc),
    }
    # Validate once, before any (possibly expensive) calculation starts.
    if method not in scf_builders:
        raise ValueError(f"Unsupported method: {method}")

    if isinstance(atoms_or_str, Atoms):
        atom_str = "; ".join(
            f"{s} {x} {y} {z}"
            for s, (x, y, z) in zip(atoms_or_str.get_chemical_symbols(), atoms_or_str.get_positions())
        )
    else:
        atom_str = atoms_or_str  # Already PySCF format

    energies = []
    for basis in basis_sets:
        # gto.M builds the molecule itself; no separate mol.build() is needed.
        mol = gto.M(atom=atom_str, basis=basis, unit='Angstrom', charge=charge, spin=spin)
        mf = scf_builders[method](mol)
        energy = mf.kernel()
        if not mf.converged:
            warnings.warn(f"SCF did not converge for basis {basis!r}; its energy is unreliable.")
        energies.append(energy)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(basis_sets, energies, marker='o', linestyle='-')
    ax.set_xlabel('Basis Set')
    ax.set_ylabel('Total Energy (Ha)')
    ax.set_title(title or f'Effect of Basis Set on Energy ({method})')
    ax.grid(True)
    fig.tight_layout()
    plt.show()

    return basis_sets, energies



In [None]:
# Please build a photoinitiator of your choice.
# The placeholder below is a C6H6 ring (first six rows carbon, last six hydrogen);
# positions are Cartesian coordinates in Å.
# NOTE(review): the last H (row 11) breaks the ring symmetry. Its y is 1.3463887
# where its mirror images use 1.23463887, and its z is 0.5 instead of ~0. This may
# be a deliberate perturbation for the optimizer; if not, the symmetric value is
# [-2.13845764, 1.23463887, 0.0].

my_symbols = ['C'] * 6 + ['H'] * 6
my_positions = np.array([
    # carbon ring
    [-8.65057836e-15,  1.38678435e+00, -5.82155956e-16],
    [ 1.20099037e+00,  6.93391791e-01, -2.98042781e-16],
    [ 1.20099037e+00, -6.93391791e-01,  3.41626044e-16],
    [ 7.48069161e-15, -1.38678435e+00, -1.16716527e-16],
    [-1.20099037e+00, -6.93391791e-01, -6.99526672e-16],
    [-1.20099037e+00,  6.93391791e-01, -1.05679706e-15],
    # hydrogens
    [-1.96000728e-15,  2.46927819e+00, -2.58718204e-16],
    [ 2.13845764e+00,  1.23463887e+00,  4.15790469e-16],
    [ 2.13845764e+00, -1.23463887e+00,  6.27619087e-16],
    [ 7.38038621e-16, -2.46927819e+00,  3.68590689e-16],
    [-2.13845764e+00, -1.23463887e+00,  6.08598507e-16],
    [-2.13845764e+00,  1.3463887e+00,   0.5],
])

test_system = Atoms(symbols=my_symbols, positions=my_positions)


In [None]:
# Let's see what we've made: show the structure in an interactive 3D viewer.
# `viewer` is deliberately a global; the next cells read the (possibly edited)
# structure back with viewer.to_ase().

viewer = WeasWidget()
viewer.from_ase(test_system)
viewer

# Viewer keyboard shortcuts:
#   del - delete
#   g   - move
#   d   - duplicate

In [None]:
# Bring the (possibly edited) molecule back from the viewer into Python
# and list its Cartesian coordinates once, as a readable table.

atoms = viewer.to_ase()

positions = atoms.get_positions()
symbols = atoms.get_chemical_symbols()

print(f"{atoms.get_chemical_formula()} ({len(atoms)} atoms), positions in Å:")
for symbol, pos in zip(symbols, positions):
    print(f"{symbol:>2}  {pos[0]:>8.4f}  {pos[1]:>8.4f}  {pos[2]:>8.4f}")

In [None]:
# Optimize the molecule's geometry.
#
# Choose the electronic-structure method with PySCFCalculator(method=...):
#   method='RHF' -> Hartree-Fock, e.g. PySCFCalculator(basis='sto-3g', method='RHF')
#   method='RKS' -> DFT; also pass a functional, e.g. xc='PBE'

# Start from the structure currently shown in the viewer.
atoms = viewer.to_ase()

calc = PySCFCalculator(basis='sto-3g', method='RKS', xc='PBE')
atoms.calc = calc

# Using the optimizer as a context manager closes the trajectory file
# as soon as the run finishes.
with BFGS(atoms, trajectory='opt.traj') as opt:
    opt.run(fmax=0.05)

print("Final energy:", atoms.get_potential_energy())
print("Final positions:\n", atoms.get_positions())

In [None]:
# Make sure it's doing what you think it's doing: view the optimized structure.
# NOTE: this rebinds the global `viewer`, so later cells that call
# viewer.to_ase() read the geometry shown here.
viewer = WeasWidget()
viewer.from_ase(atoms)
viewer

In [None]:
# Basis-set comparison: total PBE energy of the current geometry in each basis.
basis_sets = ['sto-3g', '6-31g', '6-31++g**', '6-311++g**', 'cc-pvdz', 'cc-pvtz']

# Keep the results rather than letting Jupyter echo the raw (names, energies) tuple.
basis_names, basis_energies = compare_basis_sets(
    atoms, basis_sets, method='RKS', xc='PBE', title='PBE Energy vs Basis Set'
)

# Energy of each basis relative to the last (largest) one in the list.
reference_energy = basis_energies[-1]
for name, energy in zip(basis_names, basis_energies):
    print(f"{name:>11}: {energy:14.6f} Ha   ({(energy - reference_energy) * 1000:+10.2f} mHa vs {basis_names[-1]})")

# What basis set should you use? Go back and re-optimize with that basis set.

In [None]:
# Now let's look at the vibrational (IR-region) frequencies.
# Only harmonic frequencies are computed here; IR intensities are not.

atoms = viewer.to_ase()
calc.frequencies = True

# Discard cached results so the next call recomputes, now including the Hessian.
calc.reset()

# Recalculate
atoms.calc = calc
atoms.get_potential_energy()

# Extract the Hessian and convert it to frequencies. Use the masses of the
# structure that was actually calculated (`atoms`), not the original
# `test_system`, which no longer matches if atoms were edited in the viewer.
hess = atoms.calc.results['hessian']
freqs, modes = hessian_to_frequencies(hess, atoms)

# Negative values denote imaginary modes. For a non-linear molecule the lowest
# ~6 values (translations/rotations) are expected to be close to zero.
print("Vibrational frequencies (cm⁻¹):")
for i, f in enumerate(freqs):
    print(f"{i:>2}: {f:10.2f}")

In [None]:
# Electronic structure: MO energies and occupations, followed by the
# HOMO -> virtual orbital-energy gaps reported by get_excitation_energies.

print_orbitals(atoms.calc)

excitation_energies = get_excitation_energies(atoms.calc)

print("Excitation energies (eV):")
for rank, gap in enumerate(excitation_energies, start=1):
    print(f"{rank:2}: {gap:.3f} eV")

In [None]:
# We now visualize the HOMO and LUMO.
# Their indices are derived from the orbital occupations rather than hard-coded:
# index 20 is the HOMO only for a closed-shell, 42-electron molecule such as the
# C6H6 starting structure.
from IPython.display import display

occupations = np.asarray(atoms.calc._mf.mo_occ)
if occupations.ndim == 2:  # unrestricted: use the alpha spin channel
    occupations = occupations[0]
homo_index = int(np.flatnonzero(occupations > 0)[-1])
lumo_index = homo_index + 1

# Separate names keep the global `viewer` pointing at the molecule, so
# re-running earlier cells that call viewer.to_ase() still works.
homo_viewer = generate_and_view_cube(atoms, atoms.calc, orbital_index=homo_index)
lumo_viewer = generate_and_view_cube(atoms, atoms.calc, orbital_index=lumo_index)

print(f"HOMO = MO {homo_index}, LUMO = MO {lumo_index}")
display(homo_viewer, lumo_viewer)