# In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import os


##########################################################################
# File path for saving plots
# Plots are saved in a folder named "PINN_ASDD" created under
# `saving_path` (by default the current working directory).
##########################################################################


# Current working directory, normalised to forward slashes with a trailing "/".
dir_path = os.getcwd()
dir_path = '/'.join(dir_path.split("\\")) + "/"

saving_path = dir_path  # Update this path

savename = input("Enter the name to Save the graphs (without .jpg/png extension): ").strip()
if not savename:
    # Avoid output files whose names start with "_" (e.g. "_wavefunction.jpg").
    savename = "PINN_ASDD"

# os.path.join drops earlier components when a later one is absolute, so an
# absolute `saving_path` (the default, from os.getcwd()) is used as-is; a
# relative `saving_path` would instead be resolved under the home directory.
desktop_path = os.path.join(os.path.expanduser('~'), saving_path, 'PINN_ASDD')

# Ensure the directory exists; exist_ok avoids a check-then-create race.
os.makedirs(desktop_path, exist_ok=True)


########################################################
# Define the input for potential energy function
########################################################

# Potential coefficients. Only those used by the chosen potential are set;
# the rest stay None (L appears unused in this script — TODO confirm).
a = b = c = d = L = None


def _read_float(prompt):
    """Prompt until the reply parses as a float and return it.

    Re-prompts instead of letting float() raise ValueError and abort the run.
    """
    while True:
        raw = input(prompt)
        try:
            return float(raw)
        except ValueError:
            print(f"'{raw}' is not a valid number. Try again.")


while True:
    # Case-insensitive so "Harmonic" or "DOUBLE_WELL" are accepted too.
    potential_type = input("Enter potential type (harmonic, double_well, asymmetric_double_well): ").strip().lower()

    if potential_type == "harmonic":
        # The harmonic potential uses C_2; `a` is set only to keep the
        # original module state.
        a = 1
        break
    elif potential_type == "double_well":
        a = _read_float("Enter value for a (e.g., 0.1): ")
        b = _read_float("Enter value for b (e.g., 1): ")
        break
    elif potential_type == "asymmetric_double_well":
        a = _read_float("Enter value for a (e.g., 0.1): ")
        b = _read_float("Enter value for b (e.g., 1): ")
        c = _read_float("Enter value for c (e.g., 1): ")
        d = _read_float("Enter value for d (e.g., 0): ")
        break
    else:
        print("Unsupported potential type. Please choose from 'harmonic', 'double_well', 'asymmetric_double_well'. Try again.")


########################################################
# Constants Calculations
########################################################

# Physical constants (SI units) as float32 tensors to match the network dtype.
h = tf.constant(6.62607015e-34, dtype=tf.float32)   # Planck's constant in J·s
pi = tf.constant(np.pi, dtype=tf.float32)           # π
hbar = h / (2.0 * pi)                               # Reduced Planck's constant in J·s
m_e = tf.constant(9.10938356e-31, dtype=tf.float32) # electron mass (kg)
omega = tf.constant(1.0e15, dtype=tf.float32)       # angular frequency (rad/s)

# Derived Hamiltonian coefficients are evaluated in float64 and cast only at
# the end. hbar**2 ≈ 1.1e-68 is below float32's smallest positive value
# (~1.4e-45), so squaring the float32 `hbar` underflows and made C_1 exactly
# 0, silently removing the kinetic term from the Hamiltonian.
_hbar_f64 = 6.62607015e-34 / (2.0 * np.pi)
_m_e_f64 = 9.10938356e-31
_omega_f64 = 1.0e15

C_1 = tf.constant(_hbar_f64**2 / (2.0 * _m_e_f64), dtype=tf.float32)  # ħ²/(2 m_e) ≈ 6.1e-39
C_2 = tf.constant(_m_e_f64 * _omega_f64**2 / 2.0, dtype=tf.float32)   # m_e ω²/2 ≈ 0.455

# NOTE(review): C_1 ≈ 6.1e-39 is a float32 subnormal, which TensorFlow kernels
# may flush to zero, and with x sampled in [-5, 5] the kinetic term is ~38
# orders of magnitude smaller than C_2·x². A dimensionless (natural-unit)
# formulation is likely needed for the kinetic energy to matter — TODO
# confirm the intended units of x.

########################################################
# Wavefunction and its derivative calculations 
########################################################

##################################
# ψ(x) using a neural network
##################################
class Wavefunction(tf.keras.Model):
    """Fully connected network producing the raw (un-enveloped) trial ψ value.

    Architecture: Dense(num_hidden) -> Dense(num_hidden // 2) ->
    Dense(num_hidden // 4), all tanh, followed by a linear Dense(1) output.
    In this script it is called with inputs of shape (N, 1).
    """

    def __init__(self, num_hidden):
        super().__init__()
        # Integer division: Dense expects an integer unit count (num_hidden/2
        # would pass a float); max(1, ...) keeps very small widths valid.
        self.layer1 = tf.keras.layers.Dense(num_hidden, activation='tanh')
        self.layer2 = tf.keras.layers.Dense(max(1, num_hidden // 2), activation='tanh')
        self.layer3 = tf.keras.layers.Dense(max(1, num_hidden // 4), activation='tanh')
        self.layer4 = tf.keras.layers.Dense(1, activation=None)

    def call(self, x):
        """Forward pass: x -> NN(x), one output value per input row."""
        h1 = self.layer1(x)
        h2 = self.layer2(h1)
        h3 = self.layer3(h2)
        output = self.layer4(h3)
        return output

wavefunction = Wavefunction(num_hidden=128)

##################################
# Wavefunction ψ(x)
##################################

def compute_psi(x, wavefunction):
    """Evaluate the trial wavefunction ψ(x) = exp(-x²) · NN(x).

    The Gaussian envelope makes ψ(x) → 0 as x → ±∞ regardless of the network
    output (the network's tanh hidden layers keep NN(x) from growing faster
    than linearly).
    """
    envelope = tf.exp(-tf.square(x))
    network_out = wavefunction(x)
    return envelope * network_out

##################################
# Normalization constraint: 
# ∫ |ψ(x)|^2 dx = 1 
# (implemented as penalty)
##################################
def normalization_penalty(psi, domain_length=1.0):
    """Return (estimate of ∫|ψ|² dx − 1)², a soft normalisation penalty.

    With psi evaluated at points drawn uniformly from an interval of length
    `domain_length`, the Monte Carlo estimate of the integral is
    domain_length · mean(|ψ|²).

    Args:
        psi: ψ values at the sample points.
        domain_length: length of the sampling interval. The default 1.0
            reproduces the original behaviour (it constrains mean(|ψ|²) to 1);
            pass the real interval length (10.0 for samples in [-5, 5]) to
            enforce the integral constraint ∫|ψ|² dx = 1.

    Returns:
        Scalar penalty tensor.
    """
    norm = domain_length * tf.reduce_mean(tf.square(psi))
    return tf.square(norm - 1.0)

##################################
# First derivative of ψ(x)
##################################
@tf.function
def compute_d_psi(x, wavefunction):
    """Return dψ/dx evaluated at every point of x.

    x is watched explicitly so it need not be a trainable variable.
    tape.gradient differentiates sum(ψ) with respect to x; for the row-wise
    Wavefunction model in this file, where each ψ_i depends only on x_i, this
    equals the pointwise derivative.
    """
    with tf.GradientTape() as grad_tape:
        grad_tape.watch(x)
        psi_values = compute_psi(x, wavefunction)
    return grad_tape.gradient(psi_values, x)

##################################
# Second derivative of ψ(x)
##################################
@tf.function
def compute_d2_psi(x, wavefunction):
    """Return d²ψ/dx² at every point of x by differentiating dψ/dx once more."""
    with tf.GradientTape() as outer_tape:
        outer_tape.watch(x)
        first_derivative = compute_d_psi(x, wavefunction)
    second_derivative = outer_tape.gradient(first_derivative, x)
    return second_derivative

########################################################
# Kinetic Energy Operator
########################################################

@tf.function
def T(x, wavefunction):
    """Apply the kinetic-energy operator: Tψ = -C_1 · d²ψ/dx², with C_1 = ħ²/(2 m_e)."""
    d2psi = compute_d2_psi(x, wavefunction)
    return -C_1 * d2psi

########################################################
# Define Potential Energy Operator
########################################################

def V(x, potential_type="harmonic"):
    """Return the potential energy V(x) for the selected potential type.

    Supported potentials (a, b, c, d are the module-level coefficients set at
    the interactive prompt; C_2 = m_e·ω²/2):
        "harmonic":               C_2 · x²
        "double_well":            a·x⁴ − b·x²
        "asymmetric_double_well": a·(x² − b²)² + c·(x − d)

    Args:
        x: input position(s), a tf.Tensor or NumPy array.
        potential_type: one of the names listed above.

    Returns:
        Potential energy values, element-wise in x.

    Raises:
        ValueError: if potential_type is unsupported, or if a coefficient the
            chosen potential needs has not been set (is None).
    """
    if potential_type == "harmonic":
        return C_2 * tf.square(x)

    if potential_type == "double_well":
        if a is None or b is None:
            raise ValueError(
                f"{potential_type!r} needs coefficients a and b; got a={a}, b={b}")
        return a * x**4 - b * tf.square(x)

    if potential_type == "asymmetric_double_well":
        if any(coef is None for coef in (a, b, c, d)):
            raise ValueError(
                f"{potential_type!r} needs coefficients a, b, c and d; "
                f"got a={a}, b={b}, c={c}, d={d}")
        return a * (tf.square(x) - b ** 2) ** 2 + c * (x - d)

    raise ValueError("Unsupported potential type")



    
########################################################
# Hamiltonian operator
########################################################

def H_operator(x, psi, model=None):
    """Apply the Hamiltonian to the trial wavefunction: Hψ = Tψ + V(x)ψ.

    Args:
        x: sample positions.
        psi: ψ(x) already evaluated at x; used for the potential term.
        model: network whose ψ the kinetic term differentiates. Defaults to
            the module-level `wavefunction` (looked up at call time). It must
            be the model that produced `psi`, otherwise T and V act on
            different functions.

    Returns:
        Hψ evaluated at x, using the module-level `potential_type`.
    """
    if model is None:
        model = wavefunction
    kinetic = T(x, model)
    potential = V(x, potential_type = potential_type) * psi
    return kinetic + potential
    
########################################################    
# Generate training data
########################################################

num_samples = 10000
# Random collocation points drawn uniformly from [-5, 5) (Monte Carlo
# sampling, not a regular grid); shape (num_samples, 1), float32 to match
# the network dtype.
x_samples = np.random.uniform(low=-5.0, high=5.0, size=(num_samples, 1)).astype(np.float32)

# Epoch-averaged energies, appended to by the training loop.
energy_history = []

########################################################
# Create the wavefunction and optimizer
########################################################

# Fresh, untrained network for this training run.
wavefunction = Wavefunction(num_hidden=128)

# Learning rate decays by a factor of 0.90 per 1000 optimizer steps
# (smoothly, since staircase is left at its default of False).
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=1000,
    decay_rate=0.90
)
optimizer = tfa.optimizers.AdamW(weight_decay=1e-4, learning_rate=lr_schedule)

########################################################
# Train the wavefunction
########################################################
# For more accurate results, increase num_epochs and/or lower the initial
# learning rate. Training stops early once the epoch-averaged energy changes
# by less than `tolerance` between consecutive epochs.
num_epochs = 500
batch_size = 256
tolerance = 1e-5

# Dataset: the shuffle buffer spans all samples so each epoch sees a full
# random permutation.
dataset = tf.data.Dataset.from_tensor_slices(x_samples).shuffle(buffer_size=num_samples).batch(batch_size)

########################################################
# Main Ground State Energy (E) Optimization Loop
########################################################
# Each step minimises  loss = E + norm_weight * normalization_penalty(ψ),
# where E = <ψ|H|ψ> / <ψ|ψ> is the Rayleigh quotient estimated as a batch
# average. Only E is recorded and printed as the energy: the penalty is not
# an energy and would otherwise distort the value reported in J.
norm_weight = 0.1
prev_energy = None

for epoch in range(num_epochs):

    total_energy = 0.0
    total_batches = 0

    for x_batch in dataset:

        with tf.GradientTape() as tape:
            psi = compute_psi(x_batch, wavefunction)

            H_psi = H_operator(x_batch, psi)

            numerator = tf.reduce_mean(H_psi * psi)
            denominator = tf.reduce_mean(tf.square(psi))

            # Rayleigh quotient; the epsilon guards against a vanishing ψ.
            energy = numerator / (denominator + 1e-10)

            # E is unchanged by rescaling ψ, so the penalty is what fixes
            # the overall amplitude of ψ.
            norm_penalty = normalization_penalty(psi)
            loss = energy + norm_weight * norm_penalty

        # Apply gradients of the full loss
        grads = tape.gradient(loss, wavefunction.trainable_variables)
        optimizer.apply_gradients(zip(grads, wavefunction.trainable_variables))

        total_energy += energy.numpy()
        total_batches += 1

    avg_energy = total_energy / total_batches
    energy_history.append(avg_energy)
    print(f"Epoch {epoch+1:3d}/{num_epochs}, Energy: {avg_energy:.6e} J")

    # Diagnostic from the last batch of the epoch only.
    print("Mean Wavefunction =",tf.reduce_mean(psi).numpy())

    # Check stopping condition
    if prev_energy is not None and abs(avg_energy - prev_energy) < tolerance:
        print(f"Stopping early at epoch {epoch+1} due to energy convergence.")
        break

    prev_energy = avg_energy


########################################################
# Plotting Graphs
########################################################
##################################
# Training inputs (x): a random subset of the Monte Carlo samples
##################################
subset_ratio = 0.04
subset_size = int(num_samples * subset_ratio)
indices = np.random.choice(num_samples, size=subset_size, replace=False)
x_subset = x_samples[indices]
plt.figure(figsize=(6, 2))
# Labelled so that plt.legend() below has an entry to show (an unlabelled
# scatter makes it warn "No artists with labels found").
plt.scatter(x_subset, np.zeros_like(x_subset), alpha=0.5, color='black', marker='|', s=100,
            label=f'{subset_size} of {num_samples} samples')
# 12 pt to match the tick labels (fontsize=1 rendered an unreadable label).
plt.xlabel(r'$x$', fontsize=12, font="Arial")
plt.yticks([])
plt.xticks(fontsize=12)
plt.title('Random Sampling of Training Points (Monte Carlo)', fontsize=16)
plt.grid(axis='x', linestyle='--', alpha=0.3)
plt.box(False)
plt.legend(loc='upper right', fontsize=10)
plt.tight_layout()
plt.savefig(os.path.join(desktop_path, savename + '_input_samples.jpg'), dpi=400)

##################################
# After training: epoch-averaged energy
##################################
plt.figure(figsize=(10, 4))
ax = plt.gca()
plt.plot(range(1, len(energy_history)+1), energy_history, linestyle='-', color="darkblue")
plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Energy (J)', fontsize=16)
plt.title('Energy Convergence During Training', fontsize=18)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# Inward ticks on the bottom and left spines only.
ax.tick_params(axis='x', direction='in', length=6, width=1.2, bottom=True, top=False)
ax.tick_params(axis='y', direction='in', length=6, width=1.2, left=True, right=False)
plt.grid(False)
plt.tight_layout()
plt.savefig(os.path.join(desktop_path, savename + '_energy_convergence.jpg'), dpi=400)

# Dense evaluation grid on [-4, 4] and the quantities shared by the figures below.
x_plot = np.linspace(-4, 4, 500).reshape(-1, 1).astype(np.float32)
x_grid = tf.constant(x_plot)
psi_plot = compute_psi(x_grid, wavefunction)          # learned ψ(x)
prob_density = tf.square(psi_plot)                     # |ψ(x)|²
V_plot = V(x_grid, potential_type=potential_type)      # potential on the grid


plt.figure(figsize=(8, 5))
ax = plt.gca()
plt.plot(x_plot, psi_plot.numpy(), color='royalblue', label=r'Learned $\Psi(x)$')
plt.xlabel(r'$x$', fontsize=16)
plt.ylabel(r'$\Psi(x)$', fontsize=16)
plt.title('Wavefunction after Energy Minimization', fontsize=18)
plt.legend()
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# Inward ticks on the bottom and left spines only.
for axis_name, sides in (('x', {'bottom': True, 'top': False}),
                         ('y', {'left': True, 'right': False})):
    ax.tick_params(axis=axis_name, direction='in', length=6, width=1.2, **sides)
plt.grid(False)
plt.tight_layout()
plt.savefig(os.path.join(desktop_path, savename + '_wavefunction.jpg'), dpi=400)


# Potential V(x) on the left axis overlaid with |Ψ(x)|² on a twin right axis.
fig, ax1 = plt.subplots(figsize=(8, 5))

# Inward tick styling shared by the tick_params calls below.
tick_style = {'direction': 'in', 'length': 6, 'width': 1.2}

# Left axis: V(x)
line1, = ax1.plot(x_plot, V_plot.numpy(), color='red', linestyle="-", label=r'$V(x)$')
ax1.set_ylabel(r'$V(x)$', color='red', fontsize=16)
ax1.tick_params(axis='x', bottom=True, top=False, **tick_style)
ax1.tick_params(axis='y', left=True, right=False, **tick_style)

# Right axis (shared x): probability density
ax2 = ax1.twinx()
line2, = ax2.plot(x_plot, prob_density.numpy(), color='royalblue', label=r'$|\Psi(x)|^2$')
ax2.set_ylabel(r'$|\Psi(x)|^2$', color='royalblue', fontsize=16)
ax2.tick_params(axis='x', bottom=True, top=False, **tick_style)
ax2.tick_params(axis='y', left=False, right=True, **tick_style)

plt.title('Potential vs Probability Density', fontsize=18)
# Pin the current tick positions so the label-size change applies cleanly.
ax1.set_xticks(ax1.get_xticks())
ax1.set_yticks(ax1.get_yticks())
ax2.set_yticks(ax2.get_yticks())
for twin_axis in (ax1, ax2):
    twin_axis.tick_params(labelsize=14)

# A single legend listing both curves, which live on different axes.
lines = [line1, line2]
labels = [ln.get_label() for ln in lines]
ax1.legend(lines, labels, loc='upper right')

ax1.grid(False)

plt.tight_layout()
plt.savefig(os.path.join(desktop_path, savename + '_potential_vs_PDF.jpg'), dpi=400)


# Potential energy V(x) on the plotting grid, saved as "<savename>_potential.jpg".
# (The learned wavefunction has its own figure, "_wavefunction.jpg".)
plt.figure(figsize=(8, 5))
plt.plot(x_plot, V_plot.numpy(), color='red', label=r'$V(x)$')
plt.xlabel(r'$x$', fontsize=16)
plt.ylabel(r'$V(x)$', fontsize=16)
plt.title(f'Potential Energy ({potential_type})', fontsize=18)
plt.legend()
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.grid(False)
plt.tight_layout()
plt.savefig(os.path.join(desktop_path, savename + '_potential.jpg'), dpi=400)

# Scatter of V(x) at every training sample, coloured by the same V(x) values
# so the colour bar acts as a heat-map legend. A new figure is opened first;
# otherwise the points and colour bar would be drawn onto the most recent
# (already saved) figure.
plt.figure()

# Evaluate V once; flatten (N, 1) -> (N,) for scatter's x, y and c.
V_samples = np.asarray(V(x_samples, potential_type = potential_type)).ravel()
colors = V_samples

# Plot the data using plt.scatter
plt.scatter(x_samples.ravel(), V_samples, c=colors, cmap='jet', marker='.')

# Add colorbar for reference
cbar = plt.colorbar()
cbar.set_label('Function Value')

# Set labels and title
plt.xlabel(r'$x$')
plt.ylabel(r'$V(x)$')
plt.title('Heatmap-like Plot of V(x)')

# Show the plot
plt.show()





# Analytical harmonic-oscillator eigenstates (SI units, using m_e, omega, hbar).
def analytical_psi(x_vals, n=0):
    """Return the n-th harmonic-oscillator eigenfunction ψ_n evaluated at x_vals.

    ψ_n(x) = (m ω / (π ħ))^(1/4) / sqrt(2^n n!) · H_n(ξ) · exp(-ξ²/2),
    with ξ = sqrt(m ω / ħ) · x and H_n the physicists' Hermite polynomial.

    Args:
        x_vals: positions as a float32 tensor (eager; .numpy() is called on ξ).
        n: non-negative quantum number; 0 gives the ground state.

    Returns:
        float32 tensor of ψ_n values with the same shape as x_vals.

    NOTE(review): with the module's SI constants sqrt(m ω / ħ) ≈ 2.9e9, so ψ_n
    is non-negligible only for |x| of order 1e-9 and underflows to 0 on an
    O(1) grid.
    """
    import math

    xi = tf.sqrt((m_e * omega) / hbar) * x_vals

    # Physicists' Hermite series whose coefficient vector selects H_n alone.
    hermite_coeffs = np.zeros(n + 1)
    hermite_coeffs[n] = 1.0
    H_vals = tf.convert_to_tensor(
        np.polynomial.hermite.hermval(xi.numpy(), hermite_coeffs).astype(np.float32),
        dtype=tf.float32)

    norm_const = ((m_e * omega) / (np.pi * hbar)) ** 0.25 / math.sqrt(2.0**n * math.factorial(n))
    psi = norm_const * H_vals * tf.exp(-0.5 * tf.square(xi))
    return psi

# Compare the learned ground state with the analytical harmonic-oscillator
# ground state. The reference is only physically meaningful for "harmonic".
if potential_type != "harmonic":
    print(f"Note: comparing against the harmonic-oscillator ground state, "
          f"but the selected potential is '{potential_type}'.")

# Generate x values for plotting
x_plot = np.linspace(-4, 4, 500).reshape(-1, 1).astype(np.float32)
x_tensor = tf.constant(x_plot)

# Compute learned and analytical (n = 0, ground-state) wavefunctions
psi_learned = compute_psi(x_tensor, wavefunction)
psi_analytical = analytical_psi(x_tensor, n=0)

# Normalize both to unit discrete L2 norm on the grid for a fair comparison.
# divide_no_nan yields zeros instead of NaN when a function is zero at every
# grid point, which happens for the SI-scaled analytical state: its width
# (~3e-10) is far below the grid spacing (~0.016).
psi_learned_norm = tf.math.divide_no_nan(
    psi_learned, tf.sqrt(tf.reduce_sum(tf.square(psi_learned))))
psi_analytical_norm = tf.math.divide_no_nan(
    psi_analytical, tf.sqrt(tf.reduce_sum(tf.square(psi_analytical))))
if int(tf.math.count_nonzero(psi_analytical)) == 0:
    print("Warning: the analytical wavefunction is zero on the entire plotting grid; "
          "its curve and the overlap below are not meaningful.")

# Plot
plt.figure(figsize=(8, 5))
plt.plot(x_tensor, psi_learned_norm.numpy(), label=r'Learned $\Psi(x)$', linewidth=2)
plt.plot(x_tensor, psi_analytical_norm.numpy(), label=r'Analytical $\Psi_o(x)$', linestyle='--', linewidth=2)
plt.xlabel(r'$(x)$')
plt.ylabel(r'Normalized $\Psi(x)$')
plt.title('Comparison: Learned vs Analytical Ground State (Harmonic Oscillator)')
plt.legend()
plt.grid(True)
plt.show()

# Optional overlap (inner product of the two grid-normalized vectors)
overlap = tf.reduce_sum(psi_learned_norm * psi_analytical_norm)
print(fr'Overlap ⟨$\psi_(learned)$|$\psi_(analytical)$⟩ = {overlap.numpy():.6f}')