In [None]:
# %% [markdown]
# # Joint Reconstruction of Initial Pressure (IP) and Speed of Sound (SOS)
# Test the joint reconstruction with support-mask and bound constraints plus TV regularization

# %%
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy.sparse.linalg import LinearOperator, cg
import warnings
# NOTE(review): this silences *all* warnings for the rest of the session,
# including NumPy RuntimeWarnings (overflow / invalid value) and SciPy
# optimizer warnings; consider narrowing the filter.
warnings.filterwarnings('ignore')

# Load data. The context manager closes the .npz file handle; copying the
# entries into a plain dict keeps `data[...]` usable after the file is closed.
with np.load('data/simulated_data.npz') as _npz:
    data = dict(_npz)
p0_true = data['p0_true']
c_true = data['c_true']
mask = data['mask']
measurements = data['measurements']
dt = float(data['dt'])
dx = float(data['dx'])

print(f"Data loaded. Shape: {p0_true.shape}")
print(f"Measurements shape: {measurements.shape}")

# %%
class JointReconstructor:
    """Jointly estimate initial pressure ``p0`` and speed of sound ``c`` on a 2-D grid.

    The unknowns are found by minimising a penalised least-squares objective:
    a data-fidelity term between simulated and measured sensor traces, total
    variation on both maps, a quadratic penalty driving ``p0`` to 0 and ``c``
    to ``c0`` outside the support mask, and quadratic penalties for leaving
    the box bounds.

    Parameters
    ----------
    grid_size : tuple of int
        Grid shape ``(nx, ny)``.
    dx : float
        Grid spacing (used in both directions).
    c0 : float, optional
        Background speed of sound; used as the initial guess for ``c`` and as
        its target value outside the support.
    dt : float, optional
        Simulation time step. If omitted, the module-level ``dt`` is read at
        simulation time (legacy behaviour of this notebook).
    sensor_radius : float, optional
        Radius of the circular sensor ring centred on the grid, in the same
        length unit as ``dx``.
    """

    def __init__(self, grid_size, dx, c0=1540.0, dt=None, sensor_radius=0.072):
        self.grid_size = grid_size
        self.nx, self.ny = grid_size
        self.dx = dx
        self.c0 = c0
        self.dt = dt
        self.sensor_radius = sensor_radius

        # Parameters for constraints
        self.support_mask = None
        self.p0_bounds = (0, 1)
        self.c_bounds = (1400, 1700)
        self.tv_weight = 0.01
        # Weight of the quadratic support and bound penalty terms.
        self.penalty_weight = 1000.0

    def set_support(self, mask):
        """Set the support region; pixels where ``mask`` is False are penalised.

        The mask is converted to ``bool`` so that a 0/1 integer mask selects
        pixels (``~`` on an integer array would yield -1/-2, which NumPy then
        treats as integer indices). Pass ``None`` to disable the penalty.

        Raises
        ------
        ValueError
            If the mask shape differs from ``grid_size``.
        """
        if mask is None:
            self.support_mask = None
            return
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != tuple(self.grid_size):
            raise ValueError(
                f"support mask shape {mask.shape} does not match grid "
                f"{tuple(self.grid_size)}"
            )
        self.support_mask = mask

    def set_bounds(self, p0_bounds, c_bounds):
        """Set the ``(min, max)`` box bounds for ``p0`` and ``c``."""
        self.p0_bounds = p0_bounds
        self.c_bounds = c_bounds

    def _laplacian(self, p):
        """Five-point Laplacian of ``p``; left at zero on the outermost ring of pixels."""
        lap = np.zeros_like(p)
        lap[1:-1, 1:-1] = (
            p[2:, 1:-1] + p[:-2, 1:-1] +
            p[1:-1, 2:] + p[1:-1, :-2] - 4 * p[1:-1, 1:-1]
        ) / (self.dx ** 2)
        return lap

    def _sensor_indices(self, n_sensors, nx, ny):
        """Grid indices of ``n_sensors`` equally spaced ring sensors.

        Positions are converted to indices by truncation (``astype(int)``),
        relative to the grid centre ``(nx // 2, ny // 2)``.

        Returns
        -------
        xi, yi : ndarray of int
            Sensor indices along each axis.
        inside : ndarray of bool
            True for sensors that fall inside the grid.
        """
        angles = np.linspace(0, 2 * np.pi, n_sensors, endpoint=False)
        xi = (self.sensor_radius * np.cos(angles) / self.dx + nx // 2).astype(int)
        yi = (self.sensor_radius * np.sin(angles) / self.dx + ny // 2).astype(int)
        inside = (xi >= 0) & (xi < nx) & (yi >= 0) & (yi < ny)
        return xi, yi, inside

    def forward_model(self, p0, c, n_steps=200, n_sensors=None):
        """Simplified forward model: explicit 2-D wave propagation sampled at a sensor ring.

        Parameters
        ----------
        p0 : ndarray, shape (nx, ny)
            Initial pressure.
        c : ndarray, shape (nx, ny), or scalar
            Speed of sound.
        n_steps : int, optional
            Number of time steps to simulate (one output row per step).
        n_sensors : int, optional
            Number of sensors. If omitted, taken from the module-level
            ``measurements.shape[1]`` (legacy behaviour).

        Returns
        -------
        ndarray, shape (n_steps, n_sensors)
            Pressure at each sensor after each step. Sensors outside the grid
            read 0.

        Notes
        -----
        The first step is ``p + 0.5 c^2 dt^2 lap(p)``. Later steps use the
        leapfrog update ``2 p - p_prev + c^2 dt^2 lap(p)``. The Laplacian is
        zero on the outer ring of pixels, so there is no absorbing boundary.
        NOTE(review): this explicit scheme is only stable for a small enough
        Courant number ``max(c) * dt / dx``; that is not checked here.
        """
        nx, ny = p0.shape
        if n_sensors is None:
            n_sensors = measurements.shape[1]
        step = dt if self.dt is None else self.dt

        # Work in float so an integer p0 does not truncate the Laplacian.
        p = np.array(p0, dtype=float)
        p_prev = p.copy()
        y_sim = np.zeros((n_steps, n_sensors))

        # Sensor positions do not change over time: compute their indices once.
        xi, yi, inside = self._sensor_indices(n_sensors, nx, ny)
        xi_in, yi_in = xi[inside], yi[inside]
        c2dt2 = (c ** 2) * (step ** 2)

        for t in range(n_steps):
            lap = self._laplacian(p)
            if t == 0:
                # Start-up step (consistent with zero initial velocity).
                p = p + 0.5 * (c ** 2) * (step ** 2) * lap
            else:
                p_new = 2 * p - p_prev + c2dt2 * lap
                p_prev, p = p, p_new
            y_sim[t, inside] = p[xi_in, yi_in]

        return y_sim

    def tv_norm(self, x):
        """Isotropic total variation of a 2-D array, smoothed by ``1e-8``.

        Forward differences are used, with a zero difference across the last
        row/column (no wrap-around), so opposite image edges are not treated
        as neighbours.
        """
        x = np.asarray(x, dtype=float)
        gx = np.zeros_like(x)
        gy = np.zeros_like(x)
        gx[:-1, :] = x[1:, :] - x[:-1, :]
        gy[:, :-1] = x[:, 1:] - x[:, :-1]
        return np.sum(np.sqrt(gx ** 2 + gy ** 2 + 1e-8))

    def loss_function(self, params, measurements):
        """Penalised objective for the flattened parameter vector.

        Parameters
        ----------
        params : ndarray, shape (2 * nx * ny,)
            ``p0`` followed by ``c``, each flattened in C order.
        measurements : ndarray, shape (n_steps, n_sensors)
            Measured traces. The simulation length and sensor count are taken
            from this shape.

        Returns
        -------
        float
            Data misfit + TV + support penalty + bound penalty.
        """
        nx, ny = self.grid_size
        n_pix = nx * ny
        p0 = params[:n_pix].reshape(nx, ny)
        c = params[n_pix:].reshape(nx, ny)

        # Data fidelity: simulate exactly as many steps/sensors as were measured.
        meas_shape = np.shape(measurements)
        y_sim = self.forward_model(p0, c, n_steps=meas_shape[0], n_sensors=meas_shape[1])
        data_loss = 0.5 * np.sum((y_sim - measurements) ** 2)

        # TV regularization
        tv_loss = self.tv_weight * (self.tv_norm(p0) + self.tv_norm(c))

        w = self.penalty_weight

        # Support constraint penalty: p0 -> 0 and c -> c0 outside the support.
        support_loss = 0
        if self.support_mask is not None:
            outside = ~self.support_mask
            support_loss = w * np.sum(p0[outside] ** 2) + \
                w * np.sum((c[outside] - self.c0) ** 2)

        # Bound constraint penalty (keeps the objective meaningful for
        # optimizers that ignore `bounds`).
        p0_min, p0_max = self.p0_bounds
        c_min, c_max = self.c_bounds
        bound_loss = 0
        bound_loss += w * np.sum(np.maximum(0, p0_min - p0) ** 2)
        bound_loss += w * np.sum(np.maximum(0, p0 - p0_max) ** 2)
        bound_loss += w * np.sum(np.maximum(0, c_min - c) ** 2)
        bound_loss += w * np.sum(np.maximum(0, c - c_max) ** 2)

        return data_loss + tv_loss + support_loss + bound_loss

    def reconstruct(self, measurements, n_iter=50, method='L-BFGS-B'):
        """Perform the joint reconstruction.

        Parameters
        ----------
        measurements : ndarray, shape (n_steps, n_sensors)
            Measured sensor traces.
        n_iter : int, optional
            Maximum number of optimizer iterations (``maxiter``).
        method : str, optional
            ``scipy.optimize.minimize`` method; box bounds are passed as well.

        Returns
        -------
        p0_est, c_est : ndarray, shape (nx, ny)
            Estimated initial pressure and speed of sound.
        final_loss : float
            Objective value at the returned point.

        Notes
        -----
        No analytic gradient (``jac``) is supplied, so SciPy approximates it
        by finite differences. That costs one forward simulation per unknown
        per gradient evaluation.
        """
        nx, ny = self.grid_size
        n_pix = nx * ny

        # Initial guess: zero pressure, homogeneous background speed of sound.
        params_init = np.concatenate([np.zeros(n_pix), np.full(n_pix, float(self.c0))])

        p0_min, p0_max = self.p0_bounds
        c_min, c_max = self.c_bounds
        bounds = [(p0_min, p0_max)] * n_pix + [(c_min, c_max)] * n_pix

        print(f"Starting optimization with {method}...")

        result = minimize(
            self.loss_function,
            params_init,
            args=(measurements,),
            method=method,
            bounds=bounds,
            options={'maxiter': n_iter, 'disp': True}
        )
        if not result.success:
            print(f"Optimizer stopped without convergence: {result.message}")

        params_opt = result.x
        p0_est = params_opt[:n_pix].reshape(nx, ny)
        c_est = params_opt[n_pix:].reshape(nx, ny)

        return p0_est, c_est, result.fun

# %%
# Configure the solver on the phantom's grid with the box constraints
# p0 in [0, 1], c in [1400, 1700] and the support mask.
reconstructor = JointReconstructor(grid_size=p0_true.shape, dx=dx)
reconstructor.set_bounds(p0_bounds=(0, 1), c_bounds=(1400, 1700))
reconstructor.set_support(mask)

# Banner: one output line per positional argument.
print("=" * 60, "Joint Reconstruction with Constraints", "=" * 60, sep="\n")

# Only 30 iterations to keep the run time manageable.
p0_est, c_est, final_loss = reconstructor.reconstruct(
    measurements, n_iter=30, method='L-BFGS-B',
)

# %%
# Calculate metrics
def calculate_metrics(true, est, mask=None):
    """Compare an estimate with the ground truth.

    Parameters
    ----------
    true, est : array_like
        Arrays of the same shape.
    mask : array_like, optional
        Region to evaluate. It is converted to ``bool``, so a 0/1 integer mask
        selects pixels rather than acting as integer fancy indices. If
        omitted, every element is used.

    Returns
    -------
    dict
        ``'MSE'``: mean squared error.
        ``'NRMSE'``: RMSE divided by the range (max - min) of ``true`` in the
        region, with a ``1e-8`` guard against a zero range.
        ``'Correlation'``: Pearson correlation, or NaN when fewer than two
        elements are selected or either input is constant there.

    Raises
    ------
    ValueError
        If the mask selects no elements.
    """
    true = np.asarray(true, dtype=float)
    est = np.asarray(est, dtype=float)
    if mask is not None:
        sel = np.asarray(mask, dtype=bool)
        true_sel = true[sel]
        est_sel = est[sel]
    else:
        true_sel = true.ravel()
        est_sel = est.ravel()

    if true_sel.size == 0:
        raise ValueError("calculate_metrics: the mask selects no elements")

    mse = np.mean((true_sel - est_sel) ** 2)
    nrmse = np.sqrt(mse) / (np.max(true_sel) - np.min(true_sel) + 1e-8)

    # The Pearson correlation is undefined for constant inputs; report NaN
    # explicitly instead of relying on np.corrcoef's divide-by-zero path.
    if true_sel.size < 2 or np.std(true_sel) == 0 or np.std(est_sel) == 0:
        correlation = np.nan
    else:
        correlation = np.corrcoef(true_sel, est_sel)[0, 1]

    return {
        'MSE': mse,
        'NRMSE': nrmse,
        'Correlation': correlation
    }

# Score both reconstructions inside the support mask.
p0_metrics = calculate_metrics(p0_true, p0_est, mask)
c_metrics = calculate_metrics(c_true, c_est, mask)

# Summary report: each print() emits one output line per positional argument.
print("\n" + "=" * 60, "RECONSTRUCTION RESULTS", "=" * 60,
      f"Final loss: {final_loss:.6e}", sep="\n")
print("\nInitial Pressure:",
      f"  MSE: {p0_metrics['MSE']:.6e}",
      f"  NRMSE: {p0_metrics['NRMSE']:.4f}",
      f"  Correlation: {p0_metrics['Correlation']:.4f}",
      sep="\n")
print("\nSpeed of Sound:",
      f"  MSE: {c_metrics['MSE']:.2f}",
      f"  NRMSE: {c_metrics['NRMSE']:.4f}",
      f"  Correlation: {c_metrics['Correlation']:.4f}",
      sep="\n")

# %%
# Visualize results.
# Layout (2 x 4): row 0 = true p0 | true c | p0 error | p0 profile
#                 row 1 = est. p0 | est. c | c error  | c profile
# Maps are drawn transposed with origin='lower', so the first array index (x)
# runs horizontally and the second index (y) runs vertically.


def _show_map(ax, image, cmap, title):
    """Draw ``image`` on ``ax`` (x horizontal, y vertical) with a colorbar."""
    im = ax.imshow(image.T, cmap=cmap, origin='lower')
    ax.set_title(title)
    plt.colorbar(im, ax=ax)
    return im


fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# True and estimated distributions
_show_map(axes[0, 0], p0_true, 'hot', 'True Initial Pressure')
_show_map(axes[0, 1], c_true, 'viridis', 'True Speed of Sound (m/s)')
_show_map(axes[1, 0], p0_est, 'hot', 'Estimated Initial Pressure')
_show_map(axes[1, 1], c_est, 'viridis', 'Estimated Speed of Sound (m/s)')

# Absolute error maps
p0_error = np.abs(p0_true - p0_est)
c_error = np.abs(c_true - c_est)
_show_map(axes[0, 2], p0_error, 'Reds', 'Pressure Error')
_show_map(axes[1, 2], c_error, 'Reds', 'SOS Error')

# Profiles through the centre column. `arr[x_center, :]` varies along y at a
# fixed x, which is a *vertical* line in the transposed maps above.
x_center = p0_true.shape[0] // 2
for ax, true_map, est_map, ylabel in (
    (axes[0, 3], p0_true, p0_est, 'Pressure'),
    (axes[1, 3], c_true, c_est, 'Speed of Sound (m/s)'),
):
    ax.plot(true_map[x_center, :], 'b-', label='True', linewidth=2)
    ax.plot(est_map[x_center, :], 'r--', label='Estimated', linewidth=2)
    ax.set_title(f'Vertical Profile (x index = {x_center})')
    ax.set_xlabel('y pixel')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('data/reconstruction_results.png', dpi=150, bbox_inches='tight')
plt.show()