# Visualization of the 1D Schrödinger Equation PINN Solution

This notebook visualizes the results of the Physics-Informed Neural Network (PINN) solution for the 1D time-dependent Schrödinger equation.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.lines import Line2D
from IPython.display import HTML
import torch

# Enable matplotlib widget
%matplotlib widget

# Import the reference solution of the 1D Schrödinger equation and tensor helpers
from examples.equations.schrodinger_1d.pde import reference_solution
from measure_uq.utilities import cartesian_product_of_rows, to_numpy

# Global plotting defaults. The style is applied first because
# plt.style.use() overwrites rcParams; the explicit overrides must come after.
plt.style.use('ggplot')
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'font.size': 12,
})

## Load the Trained Model and PDE

In [None]:
# Load the trained PINN, the PDE object it was trained against, and the
# trainer (whose recorded loss history is plotted at the end of the notebook).
# NOTE(review): these are pickle files; only load artifacts you produced
# yourself or otherwise trust.
from measure_uq.models import PINN
from measure_uq.pde import PDE
from measure_uq.trainers.trainer import Trainer

model = PINN.load('data/model_schrodinger_pinn.pickle')
pde = PDE.load('data/pde_schrodinger_pinn.pickle')
trainer = Trainer.load('data/trainer_schrodinger_pinn.pickle')

## Define Evaluation Grid

In [None]:
# Evaluation domain: t in [0, T] and x in [X_min, X_max].
# NOTE(review): these bounds are hard-coded here; confirm they match the
# domain the PINN was trained on.
T = 2.0
X_min, X_max = -10.0, 10.0

# Grid resolution: number of sample points along each axis.
Nt, Nx = 100, 200

t = np.linspace(0, T, Nt)
x = np.linspace(X_min, X_max, Nx)

# With the default 'xy' indexing both meshes have shape (Nx, Nt):
# rows follow x and columns follow t.
T_mesh, X_mesh = np.meshgrid(t, x)

# Inputs fed to the network must live on the same device as its weights.
device = next(model.parameters()).device

def evaluate_model(parameter_idx=0):
    """Evaluate the PINN and the reference solution for one test parameter set.

    The network is queried on every (t, x) pair of the module-level grid
    (``t``, ``x``), with the physical parameters held fixed at
    ``pde.parameters_test.values[parameter_idx]``.

    Parameters
    ----------
    parameter_idx : int, optional
        Row of ``pde.parameters_test.values`` to evaluate (default 0).

    Returns
    -------
    dict
        ``'u_pinn'``, ``'v_pinn'``: real / imaginary PINN outputs, shape (Nx, Nt).
        ``'psi_squared_pinn'``: ``u_pinn**2 + v_pinn**2``, shape (Nx, Nt).
        ``'u_ref'``, ``'v_ref'``: as returned by ``reference_solution``
        (presumably also (Nx, Nt) -- TODO confirm).
        ``'psi_squared_ref'``: ``u_ref**2 + v_ref**2``.
        ``'params'``: numpy array ``(hbar, m, k0, sigma)``.

    Raises
    ------
    ValueError
        If the model does not return exactly one output row per grid point.
    """
    # Physical parameters for this test case
    params = pde.parameters_test.values[parameter_idx].detach().cpu().numpy()
    hbar, m, k0, sigma = params

    # Grid coordinates as column vectors
    t_tensor = torch.tensor(t.reshape(-1, 1), dtype=torch.float32)
    x_tensor = torch.tensor(x.reshape(-1, 1), dtype=torch.float32)

    # The parameters are constant over the whole grid, so each one is a
    # single (1, 1) row. Assuming cartesian_product_of_rows forms every
    # combination of rows (as its name says), repeating them Nt times would
    # inflate the input by a factor of Nt**4 and break the (Nt, Nx) reshape.
    def _constant_row(value):
        return torch.full((1, 1), float(value), dtype=torch.float32)

    inputs = cartesian_product_of_rows(
        t_tensor,
        x_tensor,
        _constant_row(hbar),
        _constant_row(m),
        _constant_row(k0),
        _constant_row(sigma),
    ).to(device)

    with torch.no_grad():
        outputs = model(inputs)

    outputs_np = to_numpy(outputs)
    if outputs_np.shape[0] != Nt * Nx:
        raise ValueError(
            f'expected {Nt * Nx} model outputs for the (t, x) grid, '
            f'got {outputs_np.shape[0]}'
        )

    # Assumes the product varies t slowest and x fastest, so each block of Nx
    # consecutive rows is one time slice; transpose to (Nx, Nt) to match
    # T_mesh / X_mesh.
    u_pinn = outputs_np[:, 0].reshape(Nt, Nx).T  # Real part
    v_pinn = outputs_np[:, 1].reshape(Nt, Nx).T  # Imaginary part

    psi_squared_pinn = u_pinn**2 + v_pinn**2

    u_ref, v_ref = reference_solution(t, x, params)
    psi_squared_ref = u_ref**2 + v_ref**2

    return {
        'u_pinn': u_pinn,
        'v_pinn': v_pinn,
        'psi_squared_pinn': psi_squared_pinn,
        'u_ref': u_ref,
        'v_ref': v_ref,
        'psi_squared_ref': psi_squared_ref,
        'params': params
    }

## Visualize the Solution

In [None]:
# Evaluate the model for a specific parameter set
param_idx = 0  # Change this to visualize different parameter sets
results = evaluate_model(param_idx)

# Extract results (these names are reused by the cells below)
u_pinn = results['u_pinn']
v_pinn = results['v_pinn']
psi_squared_pinn = results['psi_squared_pinn']
u_ref = results['u_ref']
v_ref = results['v_ref']
psi_squared_ref = results['psi_squared_ref']
params = results['params']

# One entry per subplot: (data, title, colormap, centre colour scale on zero?)
panels = [
    (u_pinn, 'Real Part of Wave Function (PINN)', 'RdBu', True),
    (v_pinn, 'Imaginary Part of Wave Function (PINN)', 'RdBu', True),
    (psi_squared_pinn, 'Probability Density |ψ|² (PINN)', 'viridis', False),
]

fig, axes = plt.subplots(3, 1, figsize=(12, 15))

for ax, (data, title, cmap, centred) in zip(axes, panels):
    color_limits = {}
    if centred:
        # A diverging colormap is only meaningful when its neutral colour sits
        # at 0; by default matplotlib would stretch it over [min, max].
        lim = float(np.nanmax(np.abs(np.asarray(data))))
        color_limits = {'vmin': -lim, 'vmax': lim}
    im = ax.pcolormesh(T_mesh, X_mesh, data, cmap=cmap, shading='auto', **color_limits)
    ax.set_title(title)
    ax.set_xlabel('Time')
    ax.set_ylabel('Position')
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

## Compare PINN and Reference Solutions

In [None]:
# Side-by-side comparison: PINN (left column) vs reference (right column).
fig, axes = plt.subplots(3, 2, figsize=(15, 15))

# One entry per row: (quantity label, PINN data, reference data, colormap,
# centre colour scale on zero?)
rows = [
    ('Real Part', u_pinn, u_ref, 'RdBu', True),
    ('Imaginary Part', v_pinn, v_ref, 'RdBu', True),
    ('Probability Density |ψ|²', psi_squared_pinn, psi_squared_ref, 'viridis', False),
]

for row_axes, (label, pinn_data, ref_data, cmap, centred) in zip(axes, rows):
    pinn_arr = np.asarray(pinn_data)
    ref_arr = np.asarray(ref_data)
    # Both panels in a row share one colour scale; otherwise each autoscales
    # independently and equal colours would not mean equal values.
    if centred:
        lim = max(float(np.nanmax(np.abs(pinn_arr))), float(np.nanmax(np.abs(ref_arr))))
        vmin, vmax = -lim, lim
    else:
        vmin = min(float(np.nanmin(pinn_arr)), float(np.nanmin(ref_arr)))
        vmax = max(float(np.nanmax(pinn_arr)), float(np.nanmax(ref_arr)))

    for ax, data, source in zip(row_axes, (pinn_data, ref_data), ('PINN', 'Reference')):
        im = ax.pcolormesh(
            T_mesh, X_mesh, data, cmap=cmap, shading='auto', vmin=vmin, vmax=vmax
        )
        ax.set_title(f'{label} ({source})')
        ax.set_xlabel('Time')
        ax.set_ylabel('Position')
        plt.colorbar(im, ax=ax)

# Add parameter information
param_text = f'Parameters: ħ={params[0]:.3f}, m={params[1]:.3f}, k₀={params[2]:.3f}, σ={params[3]:.3f}'
fig.suptitle(param_text, fontsize=14)

plt.tight_layout()
plt.subplots_adjust(top=0.95)  # leave room for the suptitle
plt.show()

## Create Animation of Wave Packet Evolution

In [None]:
# Animate |ψ|² over time for the PINN and the reference solution.
fig, ax = plt.subplots(figsize=(10, 6))

# Initial state (column 0 is the first time sample)
line_pinn, = ax.plot(x, psi_squared_pinn[:, 0], 'b-', label='PINN')
line_ref, = ax.plot(x, psi_squared_ref[:, 0], 'r--', label='Reference')

ax.set_xlim(X_min, X_max)
# Scale to the larger of the two solutions so neither curve gets clipped;
# fall back to 1.0 if both are identically zero (a zero-height ylim is invalid).
y_max = max(float(np.nanmax(psi_squared_pinn)), float(np.nanmax(psi_squared_ref)))
ax.set_ylim(0, y_max * 1.1 if y_max > 0 else 1.0)
ax.set_xlabel('Position')
ax.set_ylabel('Probability Density |ψ|²')
ax.set_title('Wave Packet Evolution')
ax.legend()

# Time indicator in axes coordinates (top-left corner)
time_text = ax.text(0.02, 0.95, '', transform=ax.transAxes)


def animate(i):
    """Show time sample ``i``; return the changed artists for blitting."""
    line_pinn.set_ydata(psi_squared_pinn[:, i])
    line_ref.set_ydata(psi_squared_ref[:, i])
    time_text.set_text(f'Time: {t[i]:.2f}')
    return line_pinn, line_ref, time_text


# Keep a reference to the animation so it is not garbage-collected
anime = FuncAnimation(
    fig, animate, frames=Nt, interval=50, blit=True
)

# Close the live figure so only the rendered video is displayed below the cell
plt.close(fig)

# Display animation (to_html5_video requires ffmpeg to be installed)
HTML(anime.to_html5_video())

## Plot Training and Testing Losses

In [None]:
# Plot training and testing losses on a log scale.
fig, ax = plt.subplots(figsize=(10, 6))

train_loss = trainer.trainer_data.train_loss
test_loss = trainer.trainer_data.test_loss
iterations = np.arange(len(train_loss))
# Assumes the test loss was recorded at iteration 0 and every `test_every`
# iterations after that. Building the axis from len(test_loss) guarantees the
# x and y arrays have equal length; the previous slice of
# arange(0, len(train_loss), test_every) could come out shorter than test_loss
# and make semilogy fail.
test_iterations = np.arange(len(test_loss)) * trainer.trainer_data.test_every

ax.semilogy(iterations, train_loss, 'b-', label='Training Loss')
ax.semilogy(test_iterations, test_loss, 'r-', label='Testing Loss')

ax.set_xlabel('Iteration')
ax.set_ylabel('Loss (log scale)')
ax.set_title('Training and Testing Losses')
ax.legend()
ax.grid(True)

plt.tight_layout()
plt.show()