In [None]:
import torch
from geqtrain.data import AtomicDataDict
from hamiltonian import get_irreps, compute_hamiltonian, find_homo_lumo_iterative

# --- Define System and Basis Set Here ---
# Everything created in this cell without an explicit dtype is double precision.
torch.set_default_dtype(torch.float64)

# Define the molecule by a list of atomic numbers (Z)
node_types = torch.tensor([8, 1, 1])  # Example: water, listed as O, H, H
# NOTE: atom_positions is only consumed by the disabled compute_properties
# block further down; the Hamiltonian assembly below does not read it.
atom_positions = torch.tensor([ # Example positions for water
        [0.0000, 0.0000, 0.1173],
        [0.0000, 0.7572, -0.4692],
        [0.0000, -0.7572, -0.4692]
    ])
N_nodes = len(node_types)

# Basis definition keyed by atomic number; each entry maps a shell label to a count.
basis_def = {
    1: {'1s': 1},
    8: {'1s': 1, '2s': 1, '2p': 1},
    # Add other elements as needed, e.g. 6: {'1s': 1, ...} for Carbon
}
node_irreps, edge_irreps, _, _ = get_irreps(basis_def)

# One edge per unordered atom pair (i < j), laid out as a (2, N_edges) index tensor.
edge_index = torch.combinations(torch.arange(N_nodes), r=2).t()
N_edges = edge_index.shape[1]

# Random node/edge features stand in for what a trained model would predict.
data = {
    AtomicDataDict.NODE_TYPE_KEY: node_types,
    AtomicDataDict.NODE_FEATURES_KEY: torch.randn(N_nodes, node_irreps.dim),
    AtomicDataDict.EDGE_FEATURES_KEY: torch.randn(N_edges, edge_irreps.dim),
    AtomicDataDict.EDGE_INDEX_KEY: edge_index,
}

H_original = compute_hamiltonian(data)

# Disabled: compute_properties is not imported in this cell and `H` is not
# defined here; kept for reference only.
# # --- Compute Properties ---
# properties = compute_properties(H, node_types, atom_positions, basis_def)

# print("\n--- Computed Electronic Properties ---")
# if properties:
#     print(f"  HOMO Energy:      {properties['HOMO_energy']:.4f} a.u.")
#     print(f"  LUMO Energy:      {properties['LUMO_energy']:.4f} a.u.")
#     print(f"  HOMO-LUMO Gap:    {properties['HOMO_LUMO_gap']:.4f} a.u.")
#     mu_vec = properties['dipole_moment_vector']
#     print(f"  Dipole Vector (μ): [{mu_vec[0]:.3f}, {mu_vec[1]:.3f}, {mu_vec[2]:.3f}] a.u.")
#     print(f"  Total Dipole (||μ||): {properties['dipole_moment_total']:.3f} Debye") # Note: 1 a.u. = 2.54 Debye

print("\n--- Hamiltonian Assembled ---")
print(f"Shape: {H_original.shape}")

# --- Test Iterative Eigensolver ---
print("\n--- Testing Iterative Eigensolver ---")

# For testing, we find the "true" HOMO/LUMO via full diagonalization.
# eigvalsh returns the eigenvalues in ascending order, so with doubly occupied
# orbitals the HOMO is entry (num_occupied - 1) and the LUMO is entry num_occupied.
true_eigenvalues = torch.linalg.eigvalsh(H_original)
num_electrons = node_types.sum().item()  # neutral molecule: sum of atomic numbers
num_occupied = num_electrons // 2
true_homo = true_eigenvalues[num_occupied - 1]
true_lumo = true_eigenvalues[num_occupied]
true_gap = true_lumo - true_homo

print("Ground Truth (from full diagonalization):")
print(f"  HOMO: {true_homo:.4f}, LUMO: {true_lumo:.4f}, Gap: {true_gap:.4f}")

# The GNN would predict this shift. We'll use the true midpoint as a perfect guess.
# The midpoint is already a tensor, so copy it with detach().clone() instead of
# re-wrapping it in torch.tensor(), which emits a copy-construct warning.
energy_shift_guess = ((true_homo + true_lumo) / 2.0).detach().clone()
print(f"Using energy shift guess: {energy_shift_guess:.4f}")

# Call the new iterative function
iterative_results = find_homo_lumo_iterative(
    H=H_original,
    num_electrons=num_electrons,
    energy_shift_guess=energy_shift_guess,
    k=2, # Find the 2 eigenvalues closest to the shift (ideally HOMO and LUMO)
    num_iterations=10 # Number of Lanczos steps
)

print("\nIterative Solver Results:")
if iterative_results.get('HOMO_energy') is not None:
    iter_homo = iterative_results['HOMO_energy']
    iter_lumo = iterative_results['LUMO_energy']
    iter_gap = iterative_results['HOMO_LUMO_gap']
    print(f"  HOMO: {iter_homo:.4f}, LUMO: {iter_lumo:.4f}, Gap: {iter_gap:.4f}")

    homo_error = torch.abs(true_homo - iter_homo)
    lumo_error = torch.abs(true_lumo - iter_lumo)
    print(f"  Error vs Ground Truth -> HOMO: {homo_error:.2e}, LUMO: {lumo_error:.2e}")
else:
    # Make the failure visible instead of printing an empty section.
    print("  Solver did not return a HOMO/LUMO estimate.")

print("\nFound eigenvalues around the shift:")
# Guard against a missing entry: calling .numpy() on None would raise.
found_eigenvalues = iterative_results.get('found_eigenvalues')
if found_eigenvalues is not None:
    print(found_eigenvalues.numpy())
else:
    print("  (none returned)")

In [None]:
import torch
import torch.nn as nn
import numpy as np
from scipy.linalg import eigh
import matplotlib.pyplot as plt

# Reset the process-wide default dtype to single precision: the previous cell
# switched it to float64, and tensors created below without an explicit dtype
# (torch.randn, torch.eye, torch.tensor(1.0), ...) pick up this default.
torch.set_default_dtype(torch.float32)

class MolecularOrbitalCalculator:
    """
    Simplified closed-shell Hartree-Fock calculator for HOMO and LUMO energies.

    Matrices are held as PyTorch tensors; the generalized eigenvalue problem is
    solved with SciPy on detached copies, so no autograd graph is carried
    through the SCF procedure.
    """

    def __init__(self, n_basis, n_electrons, device='cpu'):
        """
        Initialize the molecular orbital calculator.

        Args:
            n_basis (int): Number of basis functions
            n_electrons (int): Number of electrons (must be even for closed shell)
            device (str): PyTorch device ('cpu' or 'cuda')
        """
        self.n_basis = n_basis
        self.n_electrons = n_electrons
        self.n_occupied = n_electrons // 2  # Assuming closed shell
        self.device = device

        # Populated by initialize_matrices()
        self.H_core = None
        self.S = None
        self.two_electron_integrals = None

    def initialize_matrices(self, H_core=None, S=None, eri=None):
        """
        Store the core Hamiltonian, overlap matrix, and two-electron integrals.

        Any argument left as None is replaced by a randomly generated
        demonstration matrix of the appropriate shape.

        Args:
            H_core (torch.Tensor): Core Hamiltonian matrix
            S (torch.Tensor): Overlap matrix
            eri (torch.Tensor): Two-electron repulsion integrals
        """
        if H_core is None:
            # Generate a random symmetric core Hamiltonian for demonstration
            H_core = torch.randn(self.n_basis, self.n_basis, device=self.device)
            H_core = (H_core + H_core.T) / 2
            # Shift the diagonal by a random positive amount (keeps symmetry)
            H_core.diagonal().add_(torch.abs(torch.randn(self.n_basis, device=self.device)) * 5)

        if S is None:
            # Generate overlap matrix (should be positive definite)
            S = torch.eye(self.n_basis, device=self.device)
            # Add some off-diagonal elements
            S += 0.1 * torch.randn(self.n_basis, self.n_basis, device=self.device)
            S = (S + S.T) / 2
            # Ensure positive definiteness by flooring the spectrum at 0.1
            eigenvals, eigenvecs = torch.linalg.eigh(S)
            eigenvals = torch.clamp(eigenvals, min=0.1)
            S = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T

        if eri is None:
            # Generate two-electron integrals (4D tensor)
            eri = torch.randn(self.n_basis, self.n_basis, self.n_basis, self.n_basis, device=self.device)
            # Enforce (μν|λσ) = (νμ|λσ) = (μν|σλ) = (λσ|μν).
            # Each step symmetrizes under one generator while preserving the
            # symmetries already imposed, so the result is the average over all
            # 8 index permutations. Averaging only four of them (as a single
            # sum of four transposes would) does not yield a symmetric tensor.
            eri = eri + eri.transpose(0, 1)
            eri = eri + eri.transpose(2, 3)
            eri = eri + eri.permute(2, 3, 0, 1)
            eri = torch.abs(eri / 8)  # Make positive for stability

        self.H_core = H_core
        self.S = S
        self.two_electron_integrals = eri

    def build_density_matrix(self, C_occupied):
        """
        Build the density matrix from occupied molecular orbitals.

        Args:
            C_occupied (torch.Tensor): Occupied molecular orbital coefficients,
                one orbital per column

        Returns:
            torch.Tensor: Density matrix P
        """
        # P_μν = 2 * Σᵢ^(occ) C_μᵢ * C_νᵢ
        P = 2.0 * torch.mm(C_occupied, C_occupied.T)
        return P

    def build_fock_matrix(self, P):
        """
        Build the Fock matrix from the density matrix.

        Args:
            P (torch.Tensor): Density matrix

        Returns:
            torch.Tensor: Fock matrix F (a new tensor; H_core is not modified)
        """
        # F_μν = H_μν^core + Σ_λσ P_λσ * [(μν|λσ) - 0.5*(μλ|νσ)]
        eri = self.two_electron_integrals
        # Two tensor contractions replace the four-deep Python loop.
        coulomb = torch.einsum('mnls,ls->mn', eri, P)
        exchange = torch.einsum('mlns,ls->mn', eri, P)
        return self.H_core + coulomb - 0.5 * exchange

    def solve_roothaan_hall(self, F):
        """
        Solve the Roothaan-Hall equation FC = SCE.

        Args:
            F (torch.Tensor): Fock matrix

        Returns:
            tuple: (orbital_energies, molecular_orbitals), both with the dtype
            of F and placed on self.device
        """
        # Convert to numpy for scipy's generalized eigenvalue solver.
        # detach() means no gradient flows back through this call.
        F_np = F.detach().cpu().numpy()
        S_np = self.S.detach().cpu().numpy()

        # Solve generalized eigenvalue problem
        eigenvalues, eigenvectors = eigh(F_np, S_np)

        # Convert back to PyTorch tensors, keeping the caller's precision.
        # Forcing float32 here would mix dtypes with float64 inputs and break
        # the matrix products in scf_iteration().
        orbital_energies = torch.from_numpy(eigenvalues).to(device=self.device, dtype=F.dtype)
        molecular_orbitals = torch.from_numpy(eigenvectors).to(device=self.device, dtype=F.dtype)

        return orbital_energies, molecular_orbitals

    def scf_iteration(self, max_iterations=50, convergence_threshold=1e-6):
        """
        Perform self-consistent field iterations.

        Args:
            max_iterations (int): Maximum number of SCF iterations
            convergence_threshold (float): Energy convergence threshold

        Returns:
            tuple: (converged_energies, converged_orbitals, energy_history)

        Raises:
            RuntimeError: If initialize_matrices() has not been called.
        """
        if self.H_core is None or self.S is None or self.two_electron_integrals is None:
            raise RuntimeError("Call initialize_matrices() before running the SCF procedure.")

        # Initial guess: diagonalize core Hamiltonian. These values are also
        # what is returned when max_iterations is 0.
        orbital_energies, molecular_orbitals = self.solve_roothaan_hall(self.H_core)

        # Take occupied orbitals for initial density
        C_occupied = molecular_orbitals[:, :self.n_occupied]

        energy_history = []

        for iteration in range(max_iterations):
            # Build density matrix
            P = self.build_density_matrix(C_occupied)

            # Build Fock matrix
            F = self.build_fock_matrix(P)

            # Solve Roothaan-Hall equation
            orbital_energies, molecular_orbitals = self.solve_roothaan_hall(F)

            # Calculate total electronic energy: E = 0.5 * Tr[P (H_core + F)]
            electronic_energy = torch.trace(torch.mm(P, self.H_core + F)) / 2
            energy_history.append(electronic_energy.item())

            # Check for convergence
            if iteration > 0:
                energy_change = abs(energy_history[-1] - energy_history[-2])
                if energy_change < convergence_threshold:
                    print(f"SCF converged after {iteration + 1} iterations")
                    break

            # Update occupied orbitals
            C_occupied = molecular_orbitals[:, :self.n_occupied]

            if iteration % 10 == 0:
                print(f"Iteration {iteration}: Energy = {electronic_energy.item():.6f}")
        else:
            # Loop ran out without hitting the convergence break.
            print(f"SCF did not converge within {max_iterations} iterations")

        return orbital_energies, molecular_orbitals, energy_history

    def compute_homo_lumo(self):
        """
        Compute HOMO and LUMO energies and gap.

        Returns:
            dict: Dictionary containing HOMO energy, LUMO energy, gap, all
            orbital energies, orbital coefficients, and the SCF energy history

        Raises:
            ValueError: If there is not at least one occupied and one
                unoccupied orbital.
        """
        # Without this check n_occupied == 0 would silently read index -1 as
        # the "HOMO", and n_occupied >= n_basis would only fail after the SCF run.
        if not 0 < self.n_occupied < self.n_basis:
            raise ValueError(
                f"Need 0 < n_occupied < n_basis to define HOMO and LUMO; "
                f"got n_occupied={self.n_occupied}, n_basis={self.n_basis}"
            )

        # Perform SCF calculation
        orbital_energies, molecular_orbitals, energy_history = self.scf_iteration()

        # HOMO is the highest occupied orbital
        homo_energy = orbital_energies[self.n_occupied - 1]

        # LUMO is the lowest unoccupied orbital
        lumo_energy = orbital_energies[self.n_occupied]

        # HOMO-LUMO gap
        gap = lumo_energy - homo_energy

        results = {
            'homo_energy': homo_energy.item(),
            'lumo_energy': lumo_energy.item(),
            'gap': gap.item(),
            'all_energies': orbital_energies.detach().cpu().numpy(),
            'molecular_orbitals': molecular_orbitals.detach().cpu().numpy(),
            'energy_history': energy_history
        }

        return results

    def plot_orbital_energies(self, results):
        """
        Plot molecular orbital energy levels.

        The axis and annotation labels are hard-coded to "eV"; no unit
        conversion is applied to the values.

        Args:
            results (dict): Results from compute_homo_lumo()
        """
        energies = results['all_energies']

        plt.figure(figsize=(10, 8))

        # Plot occupied orbitals
        occupied_energies = energies[:self.n_occupied]
        plt.hlines(occupied_energies, 0, 1, colors='blue', linewidth=3, label='Occupied')

        # Plot unoccupied orbitals
        unoccupied_energies = energies[self.n_occupied:]
        plt.hlines(unoccupied_energies, 0, 1, colors='red', linewidth=3, label='Unoccupied')

        # Highlight HOMO and LUMO
        plt.hlines(results['homo_energy'], 0, 1, colors='darkblue', linewidth=5, label='HOMO')
        plt.hlines(results['lumo_energy'], 0, 1, colors='darkred', linewidth=5, label='LUMO')

        # Add gap annotation at the midpoint between HOMO and LUMO
        plt.annotate(f'Gap = {results["gap"]:.3f} eV', 
                    xy=(0.5, (results['homo_energy'] + results['lumo_energy'])/2),
                    xytext=(0.7, (results['homo_energy'] + results['lumo_energy'])/2),
                    arrowprops=dict(arrowstyle='<->', color='green', lw=2),
                    fontsize=12, ha='center')

        plt.xlim(-0.1, 1.1)
        plt.ylabel('Energy (eV)')
        plt.title('Molecular Orbital Energy Levels')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def plot_scf_convergence(self, results):
        """
        Plot SCF energy convergence.

        Args:
            results (dict): Results from compute_homo_lumo()
        """
        energy_history = results['energy_history']

        plt.figure(figsize=(10, 6))
        plt.plot(energy_history, 'bo-', linewidth=2, markersize=6)
        plt.xlabel('SCF Iteration')
        plt.ylabel('Electronic Energy')
        plt.title('SCF Energy Convergence')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()


# Example usage and demonstration
if __name__ == "__main__":
    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Create a simple molecular system
    print("=== HOMO-LUMO Calculation Example ===\n")

    # Initialize calculator for a small molecule (6 basis functions, 8 electrons)
    calc = MolecularOrbitalCalculator(n_basis=6, n_electrons=8, device='cpu')

    # Initialize matrices (in practice, these come from quantum chemistry integrals)
    calc.initialize_matrices()

    print("System parameters:")
    print(f"Number of basis functions: {calc.n_basis}")
    print(f"Number of electrons: {calc.n_electrons}")
    print(f"Number of occupied orbitals: {calc.n_occupied}")
    print()

    # Compute HOMO and LUMO
    print("Starting SCF calculation...")
    results = calc.compute_homo_lumo()

    # Display results
    print("\n=== Results ===")
    print(f"HOMO energy: {results['homo_energy']:.6f} eV")
    print(f"LUMO energy: {results['lumo_energy']:.6f} eV")
    print(f"HOMO-LUMO gap: {results['gap']:.6f} eV")

    print("\nAll orbital energies:")
    for i, energy in enumerate(results['all_energies']):
        orbital_type = "occupied" if i < calc.n_occupied else "unoccupied"
        special = ""
        if i == calc.n_occupied - 1:
            special = " (HOMO)"
        elif i == calc.n_occupied:
            special = " (LUMO)"
        print(f"Orbital {i+1}: {energy:.6f} eV ({orbital_type}){special}")

    # Plot results
    print("\nGenerating plots...")
    calc.plot_orbital_energies(results)
    calc.plot_scf_convergence(results)

    # Demonstrate gradient computation for geometry optimization
    print("\n=== Gradient Computation Example ===")

    # Create a simple model where core Hamiltonian depends on a parameter
    class GeometryOptimizer(nn.Module):
        """Toy model: the core Hamiltonian is the reference one scaled by a scalar."""

        def __init__(self, calc):
            super().__init__()
            self.calc = calc
            # Parameter representing molecular geometry
            self.geometry_param = nn.Parameter(torch.tensor(1.0))
            # Unscaled reference Hamiltonian. Keeping a private copy means
            # repeated forward() calls do not compound the scaling.
            self.register_buffer("base_h_core", calc.H_core.detach().clone())

        def forward(self):
            """Return the HOMO-LUMO gap as a tensor connected to geometry_param.

            The SCF solver goes through SciPy on detached tensors, so autograd
            cannot see it. The gap value is taken from the SCF result and the
            derivative is attached analytically: with the density held fixed
            and S-normalized orbitals, d(eps_i)/d(theta) = C_i^T (dF/dtheta) C_i
            (Hellmann-Feynman). Orbital relaxation through the density matrix
            is neglected, so this is a frozen-density gradient.
            """
            # Modify core Hamiltonian based on geometry
            h_scaled = self.base_h_core * self.geometry_param

            # Run the SCF on a detached copy and always restore the calculator,
            # so the caller's calc object is left as it was.
            saved_h_core = self.calc.H_core
            self.calc.H_core = h_scaled.detach()
            try:
                results = self.calc.compute_homo_lumo()
            finally:
                self.calc.H_core = saved_h_core

            orbitals = torch.as_tensor(
                results['molecular_orbitals'],
                dtype=h_scaled.dtype,
                device=h_scaled.device,
            )
            n_occ = self.calc.n_occupied
            c_homo = orbitals[:, n_occ - 1]
            c_lumo = orbitals[:, n_occ]

            # Core-Hamiltonian contribution to the gap; the only part of the
            # Fock matrix that depends explicitly on geometry_param.
            core_gap = c_lumo @ h_scaled @ c_lumo - c_homo @ h_scaled @ c_homo

            # Value equals the SCF gap; gradient equals d(core_gap)/d(theta).
            # Wrapping results['gap'] in a fresh torch.tensor(..., requires_grad=True)
            # would create an unrelated leaf and leave geometry_param.grad as None.
            return core_gap + (results['gap'] - core_gap.detach())

    # Create optimizer
    optimizer_model = GeometryOptimizer(calc)

    # Compute gradient of HOMO-LUMO gap with respect to geometry
    gap = optimizer_model()
    gap.backward()

    print(f"HOMO-LUMO gap: {gap.item():.6f} eV")
    print(f"Gradient w.r.t. geometry parameter: {optimizer_model.geometry_param.grad.item():.6f}")
    print("\nThis gradient could be used for geometry optimization or molecular dynamics!")