##  Import Python modules

In [None]:
import time
import torch
import imageio
import numpy as np
from scipy import linalg
import scipy.io
import scipy.sparse as sp
from scipy.sparse import diags
from scipy.sparse.linalg import splu,spilu
from scipy.sparse import coo_matrix
import matplotlib.pyplot as plt
from torch.nn import functional as F
import torch.nn as nn
from scipy.sparse.linalg import spsolve
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import warnings

warnings.filterwarnings('ignore')

## Impedance matrix and source term generating function

In [None]:
def _pml_damping(idx, n_nodes, delta, c, h, log_r):
    """Return the PML damping value and its derivative term along one axis.

    ``idx`` is the 0-based interior index along the axis, ``n_nodes`` the
    number of grid nodes along that axis including the two boundary nodes.
    Outside the absorbing layers both returned values are zero.
    """
    if idx < delta - 1:
        # Layer at the low-index side: damping grows towards the boundary,
        # so its derivative term carries a negative sign.
        depth = (delta - idx - 1) / delta
        slope_sign = -1.0
    elif idx > n_nodes - delta - 2:
        # Layer at the high-index side.
        depth = (idx - n_nodes + delta + 2) / delta
        slope_sign = 1.0
    else:
        return 0.0, 0.0

    damping = -3 * c / (2 * delta * h) * log_r * depth ** 2
    slope = -3 * c / (delta * h) ** 2 * log_r * (slope_sign * depth)
    return damping, slope


def matrix_ofd4(nx, nz, h, f, delta, c_vec):
    """
    Construct the impedance matrix for 2D acoustic wave equation using a 4th-order
    optimized finite difference (OFD) method with perfectly matched layer (PML) boundary conditions.

    Parameters
    ----------
    nx : int
        Number of grid nodes in the x-direction (including boundaries).
    nz : int
        Number of grid nodes in the z-direction (including boundaries).
    h : float
        Spatial grid spacing (uniform in both directions).
    f : float
        Frequency (Hz). Must be non-zero (the PML stretching divides by
        the angular frequency).
    delta : int
        Width (in grid points) of the PML absorbing layer.
    c_vec : array_like
        Wave speeds at the interior grid points in row-major order. Any shape
        holding at least ``(nx - 2) * (nz - 2)`` values is accepted (flat
        vector, ``(N, 1)`` column or 2D grid); it is flattened internally.

    Returns
    -------
    A : scipy.sparse.csr_matrix
        The sparse complex impedance matrix of size (N, N), where
        N = (nx - 2) * (nz - 2), is the problem scale.

    Raises
    ------
    ValueError
        If the grid has no interior nodes or ``c_vec`` holds fewer than N values.

    Notes
    -----
    - This method assumes a rectangular 2D domain with a uniform grid.
    - The PML is applied to all four boundaries.
    - Each row couples a node with up to two neighbours on each side in x
      and z (weights 4/3 and -1/12); neighbours outside the interior grid are
      simply dropped, so rows never wrap around to the next grid line.
    """
    ncol = nx - 2  # interior nodes per grid row (x-direction)
    nrow = nz - 2  # interior grid rows (z-direction)
    if ncol < 1 or nrow < 1:
        raise ValueError("nx and nz must both be at least 3")

    # problem scale
    N = ncol * nrow

    # Flatten so that each velocity is a true scalar; indexing an (N, 1)
    # column would otherwise yield size-1 arrays for every coefficient.
    velocity = np.asarray(c_vec, dtype=float).ravel()
    if velocity.size < N:
        raise ValueError(
            f"c_vec holds {velocity.size} values but {N} are required"
        )

    # Exact number of stored entries: one diagonal per node plus the
    # first/second neighbours that exist in each direction.
    nnz = (
        N
        + 2 * nrow * (max(ncol - 1, 0) + max(ncol - 2, 0))
        + 2 * ncol * (max(nrow - 1, 0) + max(nrow - 2, 0))
    )
    rows = np.zeros(nnz, dtype=int)
    cols = np.zeros(nnz, dtype=int)
    vals = np.zeros(nnz, dtype=complex)
    pa = 0

    # angular frequency
    omega = 2 * np.pi * f

    # PML parameter: logarithm of the reflection ratio R = 1e-3
    log_r = np.log(1e-3)

    for k in range(N):
        # grid coordinate conversion (row-major): j = grid row, i = column
        j, i = divmod(k, ncol)
        c = velocity[k]

        dx, dxp = _pml_damping(i, nx, delta, c, h, log_r)
        dz, dzp = _pml_damping(j, nz, delta, c, h, log_r)

        # complex coordinate-stretching factors
        tx = 1 - 1j * dx / omega
        tz = 1 - 1j * dz / omega

        # Symmetric and antisymmetric parts of the first/second neighbour
        # weights; the antisymmetric part is non-zero only inside the PML.
        x_near = 4 / (3 * tx ** 2)
        x_near_skew = 2 * 1j * dxp * h / (3 * omega * tx ** 3)
        x_far = -1 / (12 * tx ** 2)
        x_far_skew = 1j * dxp * h / (12 * omega * tx ** 3)

        z_near = 4 / (3 * tz ** 2)
        z_near_skew = 2 * 1j * dzp * h / (3 * omega * tz ** 3)
        z_far = -1 / (12 * tz ** 2)
        z_far_skew = 1j * dzp * h / (12 * omega * tz ** 3)

        # (neighbour exists?, column offset, coefficient)
        stencil = (
            (i >= 1, -1, x_near - x_near_skew),             # left1
            (i >= 2, -2, x_far + x_far_skew),               # left2
            (i < ncol - 1, 1, x_near + x_near_skew),        # right1
            (i < ncol - 2, 2, x_far - x_far_skew),          # right2
            (j >= 1, -ncol, z_near - z_near_skew),          # up1
            (j >= 2, -2 * ncol, z_far + z_far_skew),        # up2
            (j < nrow - 1, ncol, z_near + z_near_skew),     # down1
            (j < nrow - 2, 2 * ncol, z_far - z_far_skew),   # down2
        )
        for present, offset, coeff in stencil:
            if present:
                rows[pa] = k
                cols[pa] = k + offset
                vals[pa] = coeff
                pa += 1

        # diagonal entry
        rows[pa] = k
        cols[pa] = k
        vals[pa] = (omega * h / c) ** 2 - 5 / 2 * (1 / tx ** 2 + 1 / tz ** 2)
        pa += 1

    A = coo_matrix((vals[:pa], (rows[:pa], cols[:pa])), shape=(N, N)).tocsr()

    return A


def source_ofd(f, f0, N, h, c_vec, s_loc):
    """
    Generate the right-hand side (source term) for a frequency-domain 2D acoustic wave equation.

    Parameters
    ----------
    f : float
        Frequency (Hz) at which to evaluate the source.
    f0 : float
        Dominant frequency of the Ricker wavelet source (must be non-zero).
    N : int
        Problem scale, i.e. the length of the returned column vector.
    h : float
        Grid spacing (km).
    c_vec : array_like
        Wave speeds in the medium in row-major order. Both a flat vector and
        an ``(N, 1)`` column are accepted; the array is flattened internally.
    s_loc : int
        Index of the source location in the flattened grid (0-based indexing).

    Returns
    -------
    s : scipy.sparse.coo_matrix
        A sparse column vector of shape (N, 1) representing the complex-valued
        source term at frequency f.

    Notes
    -----
    - The source is implemented using a frequency-domain Ricker wavelet.
    - Only one non-zero element is present, located at the source index.
    - The amplitude and phase are frequency-dependent.
    """
    # source location (count from 0)
    s0 = int(s_loc)

    # Velocity at the source node. Flattening first makes the lookup work for
    # 1D input as well as for the (N, 1) column used elsewhere in this file.
    sc = np.asarray(c_vec).ravel()[s0]

    # Time shift (s) and amplitude of the wavelet
    t0 = 0.12
    Amp = 1e+5

    # Real-valued Ricker spectrum at frequency f
    spectrum = (
        np.sqrt(2) * Amp / (np.pi * f0) * (f / f0) ** 2 * np.exp(-(f / f0) ** 2)
    )

    # Grid/velocity scaling and the phase shift caused by the time delay t0
    point_value = -h ** 2 / sc ** 2 * np.exp(-1j * 2 * np.pi * f * t0)

    s = sp.coo_matrix(([point_value * spectrum], ([s0], [0])), shape=(N, 1))

    return s

# Basic parameters for generating the impedance matrix and source term

In [None]:
# Number of grid points in the x-direction (including boundaries)
nx = 306

# Number of grid points in the z-direction (including boundaries)
nz = 306

# Problem scale (number of interior grid points)
N = (nx - 2) * (nz - 2)

# Spatial step size (uniform in both directions), in kilometers
h = 0.025

# Source frequency (Hz)
f = 40

# Dominant frequency of the Ricker wavelet
f0 = 20

# Number of PML layers on each boundary
delta = 10

# ==================================================
# Velocity model 1: homogeneous model (306 × 306)
# ==================================================
# Stored as (z, x) so that the row-major flattening below matches the
# k = j * (nx - 2) + i ordering used by the impedance matrix.
c_velocity = 4 * np.ones((nz - 2, nx - 2), dtype=float)
# Flatten in row-major order into an (N, 1) column vector
c_vec = c_velocity.reshape(N, 1)

# # =================================================
# # # velocity model 2: two-layer model (258 × 258)
# # =================================================
# c_0, c_1 = 4 * np.ones(((nz - 2) // 2, nx - 2)), 5 * np.ones(((nz - 2) // 2, nx - 2))
# c_velocity = np.concatenate((c_0, c_1), axis=0)
# c_vec = c_velocity.reshape(N, 1)

# # ================================================
# # # velocity model 3: marmousi model (338 × 114)
# # ================================================
# marmousi = scipy.io.loadmat('/kaggle/input/wavefield-simulation-velocity-model/marmousi_vec.mat')
# c_velocity = marmousi['c_mat']
# c_vec = c_velocity.reshape(N, 1)

# # ================================================
# # # velocity model 4: bp2004 model (338 × 114)
# # ================================================
# bp2004 = scipy.io.loadmat('/kaggle/input/wavefield-simulation-velocity-model/bp2004_vec.mat')
# c_velocity = bp2004['c_mat']
# c_vec = c_velocity.reshape(N, 1)

# Source location index (central position): middle grid row, middle column.
# Computing the row offset as (middle row) * (row length) keeps the source
# centred even when the number of interior rows is odd.
s_loc = int((nz - 2) // 2 * (nx - 2) + (nx - 2) // 2)

# Generate the impedance matrix using 4th-order OFD
A = matrix_ofd4(nx, nz, h, f, delta, c_vec)

# Generate the right-hand side source vector and convert it to a dense
# (N, 1) ndarray (toarray() avoids the discouraged np.matrix type)
b = source_ofd(f, f0, N, h, c_vec, s_loc).toarray()

## Training data (random version) generating function

In [None]:
def generate_random_data(num_samples, impedance_matrix, nx, nz, seed=1234):
    """
    Generate random complex-valued input-output training data for supervised learning.

    Parameters:
    -----------
    num_samples : int
        Number of training samples to generate.
    impedance_matrix : scipy.sparse matrix or ndarray
        The complex-valued system matrix A (e.g., impedance matrix) with shape (N, N),
        where N = (nx - 2) * (nz - 2).
    nx : int
        Number of grid points in the x-direction (including boundaries).
    nz : int
        Number of grid points in the z-direction (including boundaries).
    seed : int, optional
        Seed of the random stream. Use different seeds for data sets that
        must be independent (e.g. training vs. validation).

    Returns:
    --------
    x_tensor : torch.Tensor
        Input tensor of shape (num_samples, 2, nz-2, nx-2), where the 2 channels
        represent the real and imaginary parts of a random complex input.
    y_tensor : torch.Tensor
        Output tensor of shape (num_samples, 2, nz-2, nx-2), representing the real
        and imaginary parts of the corresponding matrix-vector product A @ x.

    Notes
    -----
    Each sample pair is divided by the largest absolute value of its output,
    so y lies in [-1, 1]. A sample whose output is identically zero is left
    unscaled instead of being divided by zero.
    """
    rows, cols = nz - 2, nx - 2

    # A private RandomState yields exactly the numbers that
    # np.random.seed(seed) + np.random.rand would, without reseeding the
    # global NumPy generator as a side effect.
    rng = np.random.RandomState(seed)

    x_data = np.zeros((num_samples, 2, rows, cols))    # Real and imaginary parts of input
    y_data = np.zeros((num_samples, 2, rows, cols))    # Real and imaginary parts of output

    for i in range(num_samples):
        # Random complex input field with real and imaginary parts in [-1, 1)
        x_real = 2 * rng.rand(rows, cols) - 1
        x_imag = 2 * rng.rand(rows, cols) - 1
        x_data[i, 0] = x_real
        x_data[i, 1] = x_imag

        # Flatten (row-major) for the matrix-vector product y = A @ x
        x_complex = (x_real + 1j * x_imag).reshape(-1, 1)
        y_complex = impedance_matrix @ x_complex

        # Reshape result back to the 2D grid and split into real/imaginary parts
        y_data[i, 0] = y_complex.real.reshape(rows, cols)
        y_data[i, 1] = y_complex.imag.reshape(rows, cols)

        # Normalize both x and y by the maximum absolute value of y
        max_val = np.max(np.abs(y_data[i]))
        if max_val > 0:
            x_data[i] /= max_val
            y_data[i] /= max_val

    # Convert to PyTorch tensors
    x_tensor = torch.tensor(x_data, dtype=torch.float32)
    y_tensor = torch.tensor(y_data, dtype=torch.float32)

    return x_tensor, y_tensor

## Complex media simulation requires integrating generalized residuals

In [None]:
from scipy.ndimage import gaussian_filter

def generate_residual_data(num_samples, impedance_matrix, nx, nz, residuals, seed=1234):
    """
    Generate complex-valued training data from residuals with small perturbations,
    suitable for supervised learning.

    Parameters:
    -----------
    num_samples : int
        Number of training samples to generate.
    impedance_matrix : scipy.sparse matrix or ndarray
        The complex-valued system matrix A with shape (N, N).
    nx : int
        Number of grid points in the x-direction (original grid).
    nz : int
        Number of grid points in the z-direction (original grid).
    residuals : sequence of ndarray
        Precomputed residual vectors; ``residuals[i]`` must hold N values for
        every ``i < num_samples``.
    seed : int, optional
        Random seed for reproducibility.

    Returns:
    --------
    x_tensor : torch.Tensor
        Input tensor of shape (num_samples, 2, nz-2, nx-2), representing real and
        imaginary parts of perturbed residual input.
    y_tensor : torch.Tensor
        Output tensor of shape (num_samples, 2, nz-2, nx-2), representing real and
        imaginary parts of matrix-vector product A @ x.

    Notes
    -----
    - Here the generalized residuals are obtained by running the ordinary BiCGSTAB iterations.
    - The perturbation is smoothed along the flattened vector with a Gaussian
      filter whose width grows with the sample index (sigma = i / num_samples).
    - Residuals or outputs that are identically zero are left unscaled
      instead of being divided by zero.
    - This is an initial demonstration; more details are to be developed further.
    """
    # Private generator: same stream as np.random.seed(seed) followed by
    # np.random.rand, but the global NumPy generator is left untouched.
    rng = np.random.RandomState(seed)

    N = (nx - 2) * (nz - 2)
    grid_shape = (nz - 2, nx - 2)
    x_data = np.zeros((num_samples, 2) + grid_shape)  # Real and imaginary input
    y_data = np.zeros((num_samples, 2) + grid_shape)  # Real and imaginary output

    for i in range(num_samples):
        # Smooth complex perturbation (real part is drawn first, then imaginary)
        sigma = i / num_samples
        delta_real = gaussian_filter(2 * rng.rand(N, 1) - 1, sigma=sigma)
        delta_imag = gaussian_filter(2 * rng.rand(N, 1) - 1, sigma=sigma)
        X_delta = delta_real + 1j * delta_imag

        # Scale the residual to unit peak magnitude and add the perturbation
        residual = np.asarray(residuals[i])
        peak = np.max(np.abs(residual))
        residual_norm = residual / peak if peak > 0 else residual
        X = residual_norm.reshape(-1, 1) + X_delta

        # Compute output: Y = A @ X
        Y = impedance_matrix @ X

        # Reshape and separate real/imag parts
        x_data[i, 0] = X.real.reshape(grid_shape)
        x_data[i, 1] = X.imag.reshape(grid_shape)
        y_data[i, 0] = Y.real.reshape(grid_shape)
        y_data[i, 1] = Y.imag.reshape(grid_shape)

        # Normalize input and output by max abs value in output
        max_val = np.max(np.abs(y_data[i]))
        if max_val > 0:
            x_data[i] /= max_val
            y_data[i] /= max_val

    # Convert to PyTorch tensors
    x_tensor = torch.tensor(x_data, dtype=torch.float32)
    y_tensor = torch.tensor(y_data, dtype=torch.float32)

    return x_tensor, y_tensor

## Generating Training data and Validation data

In [None]:
# Generate training and validation datasets from the impedance matrix.
# The validation set uses its own seed: with the same seed for both calls
# the 50 validation samples would be exact copies of the first 50 training
# samples, and the validation loss would say nothing about generalization.
trainX, trainY = generate_random_data(num_samples=200, impedance_matrix=A, nx=nx, nz=nz, seed=1234)
valX, valY = generate_random_data(num_samples=50, impedance_matrix=A, nx=nx, nz=nz, seed=4321)

# When the velocity model is complex, a mixed data set with residuals is required.
# NOTE: use different residual vectors (or at least different seeds) for the
# training and validation parts, otherwise the two sets are identical.
# residual = [] # Record the generalized residual vectors in BiCGSTAB (pk and qk).......
# trainX_res, trainY_res = generate_residual_data(num_samples=len(residual), impedance_matrix=A, nx=nx, nz=nz, residuals=residual)
# valX_res, valY_res = generate_residual_data(num_samples=len(residual), impedance_matrix=A, nx=nx, nz=nz, residuals=residual)
# trainX, trainY = torch.cat([trainX, trainX_res], dim=0),  torch.cat([trainY, trainY_res], dim=0)
# valX, valY = torch.cat([valX, valX_res], dim=0),  torch.cat([valY, valY_res], dim=0)

# Print the shape of the generated datasets
print("Training input shape:", trainX.shape)   # Expected: (200, 2, nz-2, nx-2)
print("Training label shape:", trainY.shape)
print("Validation input shape:", valX.shape)   # Expected: (50, 2, nz-2, nx-2)
print("Validation label shape:", valY.shape)

# Check numerical accuracy of matrix-vector multiplication for a sample
sample_index = 96  # Index of the sample to validate
x_complex = (trainX[sample_index, 0, :, :] + 1j * trainX[sample_index, 1, :, :]).numpy()
y_complex = (trainY[sample_index, 0, :, :] + 1j * trainY[sample_index, 1, :, :]).numpy()

# Flatten for multiplication: A @ x should give y
x_flat = x_complex.reshape(-1, 1)
y_flat = y_complex.reshape(-1, 1)
error = A @ x_flat - y_flat

# Compute the norm of the error to verify consistency. The tensors are stored
# in float32, so a small non-zero value (single-precision round-off) is expected.
print("Matrix-vector multiplication residual norm:", np.linalg.norm(error))

## Creating dataloader for batch training

In [None]:
# ----------------------------------------------------------------
# Wrap the tensors into datasets and mini-batch loaders.
#
# The network is trained to undo the operator: it receives A @ x
# (stored in *Y) as input and has to reproduce x (stored in *X),
# hence the (Y, X) ordering of every dataset below.
# ----------------------------------------------------------------
train_dataset = TensorDataset(trainY, trainX)
val_dataset = TensorDataset(valY, valX)

# Training batches of 16 samples are drawn in a new random order every epoch.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Validation batches of 16 samples keep the stored order.
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

## Training data visualization

In [None]:
# Index of the training sample shown in the figure
i = 5

# One row, two panels: network input on the left, label on the right
fig, axs = plt.subplots(1, 2, figsize=(10, 4))

# (tensor, panel title) pairs; channel 0 (the real part) is displayed for both
panels = (
    (trainY, 'Real part of Input $Ax$'),
    (trainX, 'Real part of label $x$'),
)

for ax, (data, title) in zip(axs, panels):
    image = ax.imshow(data[i, 0, :, :].numpy(), cmap='bwr')
    ax.set_title(title, fontsize=12)
    ax.axis('off')
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04, shrink=0.3)

plt.show()

## Building UNet Architecture

In [None]:
# ----------------------------------------
# Conv: feature-extraction unit of the U-Net
# (5x5 conv -> Tanh) applied twice, spatial size preserved
# ----------------------------------------
class Conv(nn.Module):
    """Two stacked 5x5 convolutions, each followed by a Tanh activation.

    The first convolution maps ``C_in`` to ``C_out`` channels, the second
    keeps ``C_out`` channels. Stride 1 with padding 2 leaves height and
    width unchanged.
    """

    def __init__(self, C_in, C_out):
        super().__init__()
        stages = []
        for channels_in in (C_in, C_out):
            stages.append(
                nn.Conv2d(channels_in, C_out, kernel_size=5, stride=1, padding=2)
            )
            stages.append(nn.Tanh())
        # Positional Sequential keeps the parameter names layer.0.* / layer.2.*
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution/activation stages to ``x``."""
        return self.layer(x)

# ----------------------------------------
# DownSampling: learned 2x spatial reduction
# Stride-2 5x5 convolution + Tanh, channel count unchanged
# ----------------------------------------
class DownSampling(nn.Module):
    """Reduce height and width by a factor of two with a strided convolution."""

    def __init__(self, C):
        super().__init__()
        strided_conv = nn.Conv2d(C, C, kernel_size=5, stride=2, padding=2)
        # Kept inside a Sequential named ``Down`` so parameters stay Down.0.*
        self.Down = nn.Sequential(strided_conv, nn.Tanh())

    def forward(self, x):
        """Return the downsampled, Tanh-activated feature map."""
        return self.Down(x)

# ========================================
# UpSampling Block
# Resizes the feature map to the spatial size of the skip connection
# (a 2x enlargement for even-sized levels) via nearest-exact interpolation
# Reduces channels by half via 1x1 convolution
# Concatenates with corresponding skip connection
# ========================================
class UpSampling(nn.Module):
    """Upsample a decoder feature map and merge it with an encoder skip tensor."""

    def __init__(self, C):
        super(UpSampling, self).__init__()
        # 1x1 convolution that halves the number of channels
        self.Up = nn.Conv2d(C, C // 2, kernel_size=1, stride=1)

    def forward(self, x, r):
        """Return ``cat(conv1x1(upsample(x)), r)`` along the channel dimension.

        x : feature map coming from the coarser level, shape (B, C, H, W)
        r : skip-connection tensor of the finer level, shape (B, C', H', W')

        The interpolation targets the spatial size of ``r`` instead of a fixed
        factor of two. For H' = 2H and W' = 2W the result is the same as with
        ``scale_factor=2``; for encoder levels with an odd size (where the
        strided convolution rounds up) it keeps the two tensors compatible
        for concatenation.
        """
        up = F.interpolate(x, size=tuple(r.shape[-2:]), mode="nearest-exact")
        x = self.Up(up)
        return torch.cat((x, r), dim=1)  # Concatenate along channel dimension

# ----------------------------------------
# UNet: four-level encoder/decoder with skip connections
# 2 input channels -> 2 output channels (real and imaginary part)
# ----------------------------------------
class UNet(nn.Module):
    """Symmetric U-Net built from Conv, DownSampling and UpSampling blocks.

    Sub-modules are registered under the names C1..C9, D1..D4, U1..U4,
    ``pred`` and ``Th``, in the order C1, D1, ..., C5, U1, C6, ..., U4, C9.
    """

    def __init__(self):
        super().__init__()

        # Encoder: Conv at every level, DownSampling between levels
        widths = (64, 128, 256, 512, 1024)
        channels_in = 2
        for level, width in enumerate(widths, start=1):
            setattr(self, f"C{level}", Conv(channels_in, width))
            if level < len(widths):
                setattr(self, f"D{level}", DownSampling(width))
            channels_in = width

        # Decoder: UpSampling halves the channels and appends the skip
        # tensor (restoring the width), the following Conv halves them again
        for step, width in enumerate(reversed(widths[1:]), start=1):
            setattr(self, f"U{step}", UpSampling(width))
            setattr(self, f"C{step + 5}", Conv(width, width // 2))

        # Output head: 5x5 convolution to 2 channels followed by Tanh
        self.pred = nn.Conv2d(64, 2, kernel_size=5, stride=1, padding=2)
        self.Th = nn.Tanh()

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        skips = []
        feat = x

        # Encoder: remember the output of each level before downsampling
        for level in range(1, 5):
            feat = getattr(self, f"C{level}")(feat)
            skips.append(feat)
            feat = getattr(self, f"D{level}")(feat)

        # Bottleneck
        feat = self.C5(feat)

        # Decoder: consume the stored skips from deepest to shallowest
        for step in range(1, 5):
            feat = getattr(self, f"U{step}")(feat, skips.pop())
            feat = getattr(self, f"C{step + 5}")(feat)

        # Tanh-bounded two-channel prediction
        return self.Th(self.pred(feat))

## Network training configuration

In [None]:
# Run on the first GPU when one is available, otherwise on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Network, loss and optimizer
model = UNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=4e-5)

# The learning rate is multiplied by 0.5 after every 100 epochs
scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

# Total number of passes over the training set
num_epochs = 300

# Per-epoch loss history (sample-weighted mean over the data set)
train_loss = []
valid_loss = []


def _epoch_mean_loss(loader, update_weights):
    """Run one pass over ``loader`` and return the mean loss per sample.

    With ``update_weights=True`` every batch is followed by a backward pass
    and an optimizer step; otherwise gradients are disabled.
    """
    running = 0.0
    with torch.set_grad_enabled(update_weights):
        for inputs, targets in loader:
            inputs = inputs.float().to(device)
            targets = targets.float().to(device)

            batch_loss = criterion(model(inputs), targets)

            if update_weights:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            # Weight by the batch size so a smaller last batch counts correctly
            running += batch_loss.item() * inputs.size(0)
    return running / len(loader.dataset)


# -------------------------
# Epoch loop
# -------------------------
for epoch in range(num_epochs):
    # Optimization pass over the training set
    model.train()
    train_epoch_loss = _epoch_mean_loss(train_loader, update_weights=True)

    # Advance the learning-rate schedule once per epoch
    scheduler.step()

    # Evaluation pass over the validation set
    model.eval()
    valid_epoch_loss = _epoch_mean_loss(val_loader, update_weights=False)

    # Keep the history for plotting
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    # Epoch summary
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_epoch_loss:.6f} | "
          f"Valid Loss: {valid_epoch_loss:.6f}")

## Loss function visualization

In [None]:
# Loss history of both data sets on a logarithmic y-axis
plt.figure(figsize=(8, 6))

# (history, line format, legend label) for each curve
curves = (
    (train_loss, 'b-', 'Training Loss'),
    (valid_loss, 'r--', 'Validation Loss'),
)
for history, line_format, legend_label in curves:
    plt.semilogy(history, line_format, label=legend_label, linewidth=2)

# Axis labels
plt.xlabel('Epochs', fontsize=13)
plt.ylabel('MSE loss', fontsize=13)

# Frameless legend
plt.legend(frameon=False, fontsize=12)

# Dashed grid on major and minor ticks
plt.grid(True, which="both", linestyle='--', linewidth=0.5)

# Tighten the layout and show the figure
plt.tight_layout()
plt.show()

## Save the trained network weights

In [None]:
# Persist the trained network

# Target file for the weights
model_save_path = 'unet_weights.pth'

# Store the parameter dictionary only (not the pickled module object), so the
# file can be loaded into any freshly constructed UNet
trained_parameters = model.state_dict()
torch.save(trained_parameters, model_save_path)

# Report where the weights were written
print(f'Model weights saved successfully to "{model_save_path}"')

# Testing efficiency of DL preconditioners in local PyCharm

## Taking homogeneous medium as an example

## Step 1: Load model weights

In [None]:
# Fresh network whose parameters are replaced by the stored ones
model = UNet()

# File written by the training part of this notebook
model_weights_path = 'unet_weights.pth'

# Read the parameter dictionary, remapping tensors to the active device.
# NOTE(review): torch.load unpickles the file; only load weight files from a
# trusted source (consider weights_only=True where the installed PyTorch
# version supports it).
stored_parameters = torch.load(model_weights_path, map_location=device)
model.load_state_dict(stored_parameters)

# Place the network on the target device and switch to inference mode
model.to(device).eval()

# Confirmation message
print(f'Model weights loaded successfully from \"{model_weights_path}\"')

## Step 2: Add network predictions to bicgstab method

In [None]:
class NetPreconditioner:
    """Apply a trained network as an approximate inverse of the system matrix.

    Parameters
    ----------
    model : torch.nn.Module
        Network mapping a (1, 2, rows, cols) tensor (real and imaginary
        channel) to a tensor of the same shape.
    device : torch.device
        Device on which the network is evaluated.
    grid_shape : tuple of int, optional
        (rows, cols) of the interior grid, i.e. (nz - 2, nx - 2). Defaults to
        (304, 304), the grid used in this notebook.
    """

    def __init__(self, model, device, grid_shape=(304, 304)):
        self.model = model
        self.device = device
        self.grid_shape = tuple(grid_shape)

    def __call__(self, x):
        """
        x: numpy array with rows * cols entries (shape (N,) or (N, 1))
        Returns: complex128 numpy array of shape (N,)
        """
        rows, cols = self.grid_shape

        # Ensure x is a complex ndarray laid out on the 2D grid
        p = np.asarray(x)
        if not np.iscomplexobj(p):
            p = p.astype(np.complex128)
        field = p.reshape(rows, cols)

        # Convert to network input tensor (channel 0: real, channel 1: imag)
        tensor = np.zeros((1, 2, rows, cols), dtype=np.float32)
        tensor[0, 0] = field.real
        tensor[0, 1] = field.imag

        # Scale to unit peak magnitude, matching the normalized training data
        power = np.max(np.abs(tensor))
        if power == 0:
            # A zero vector cannot be normalized (0 / 0 would give NaN);
            # return zero, as a linear preconditioner would.
            return np.zeros(rows * cols, dtype=np.complex128)
        tensor = tensor / power
        tensor = torch.from_numpy(tensor).float().to(self.device)

        # Forward through network
        with torch.no_grad():
            output = self.model(tensor).cpu().numpy()
        pred = output[0, 0] + 1j * output[0, 1]

        # Undo the scaling; cast to the dtype declared by the linear operator
        return (pred * power).reshape(-1).astype(np.complex128)


def get_net_linear_operator(model, device, shape, grid_shape=(304, 304)):
    """Wrap the network preconditioner in a SciPy ``LinearOperator``.

    Parameters
    ----------
    model, device :
        Passed on to ``NetPreconditioner``.
    shape : tuple of int
        Operator shape (N, N) with N = rows * cols.
    grid_shape : tuple of int, optional
        (rows, cols) of the interior grid; defaults to (304, 304).

    Returns
    -------
    scipy.sparse.linalg.LinearOperator
        Operator of dtype complex128 whose ``matvec`` evaluates the network.
    """
    # Imported here because the notebook's import cell does not provide it
    from scipy.sparse.linalg import LinearOperator

    preconditioner = NetPreconditioner(model, device, grid_shape=grid_shape)
    return LinearOperator(
        dtype=np.complex128,
        shape=shape,
        matvec=preconditioner
    )

In [None]:
# Sentinel marking "``tol`` was not passed" (``None`` is a legal value).
_TOL_NOT_GIVEN = object()


def bicgstab(A, b, *, x0=None, tol=_TOL_NOT_GIVEN, maxiter=None, M=None,
             callback=None, atol=0., rtol=1e-5):
    """Use BIConjugate Gradient STABilized iteration to solve ``Ax = b``.

    Variant of SciPy's ``bicgstab`` that additionally returns the history of
    relative residuals. It relies only on public SciPy/NumPy functions.

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real or complex N-by-N matrix of the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).
    x0 : ndarray or 'Mb', optional
        Starting guess for the solution (zero vector if omitted). The string
        ``'Mb'`` selects ``M @ b`` as the starting guess.
    rtol, atol : float, optional
        Parameters for the convergence test. For convergence,
        ``norm(b - A @ x) <= max(rtol*norm(b), atol)`` should be satisfied.
        The default is ``atol=0.`` and ``rtol=1e-5``.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to ``10 * N``.
    M : {sparse matrix or Network, ndarray, LinearOperator}
        Preconditioner for A.  The preconditioner should approximate the
        inverse of A.  Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.
    tol : float, optional, deprecated
        Deprecated alias of ``rtol``; if given (and not None) it overrides
        ``rtol`` and a ``DeprecationWarning`` is issued.

    Returns
    -------
    x : ndarray
        The converged solution.
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : parameter breakdown
    rel_residuals : list of float
        Relative residual ``norm(r) / norm(b)`` recorded during the iterative
        process. The first entry belongs to the starting guess, so a run with
        k completed iterations yields k + 1 entries.

    Raises
    ------
    ValueError
        If ``A`` is not square or the shapes of ``b``, ``x0`` or ``M`` do not
        match ``A``.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> R = np.array([[4, 2, 0, 1],
    ...               [3, 0, 0, 2],
    ...               [0, 1, 1, 1],
    ...               [0, 2, 1, 0]])
    >>> A = csc_matrix(R)
    >>> b = np.array([-1, -0.5, -1, 2])
    >>> x, exit_code, rel_residuals = bicgstab(A, b, atol=1e-5)
    >>> print(exit_code)  # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True

    """
    import warnings
    from scipy.sparse.linalg import aslinearoperator

    # ---- normalise the system (operator, right-hand side, dtype) ----
    A = aslinearoperator(A)
    if A.shape[0] != A.shape[1]:
        raise ValueError(f'expected square matrix, but got shape={A.shape}')
    n = A.shape[0]

    b = np.asarray(b)
    if b.shape not in ((n,), (n, 1)):
        raise ValueError(f'shapes of A {A.shape} and b {b.shape} are '
                         'incompatible')

    # Working precision: common type of A and b, at least single-precision
    # floating point / complex.
    xtype = np.result_type(A.dtype, b.dtype)
    if xtype.char not in 'fdFD':
        xtype = np.dtype(np.complex128 if xtype.kind == 'c' else np.float64)
    b = np.asarray(b, dtype=xtype).ravel()

    # ---- preconditioner ----
    if M is None:
        def psolve(v):
            return v
    else:
        M = aslinearoperator(M)
        if M.shape != A.shape:
            raise ValueError('matrix and preconditioner have different shapes')
        psolve = M.matvec

    # ---- starting guess ----
    if x0 is None:
        x = np.zeros(n, dtype=xtype)
    elif isinstance(x0, str):
        if x0 != 'Mb':
            raise ValueError(f"unknown starting guess {x0!r}")
        x = np.array(psolve(b.copy()), dtype=xtype).ravel()
    else:
        x = np.array(x0, dtype=xtype)
        if x.shape not in ((n,), (n, 1)):
            raise ValueError(f'shapes of A {A.shape} and x0 {x.shape} are '
                             'incompatible')
        x = x.ravel()

    # ---- tolerances ----
    bnrm2 = np.linalg.norm(b)
    rel_residuals = []

    if tol is not _TOL_NOT_GIVEN:
        warnings.warn("`tol` is deprecated in favor of `rtol`; if set, it "
                      "overrides `rtol`.",
                      category=DeprecationWarning, stacklevel=2)
        if tol is not None:
            rtol = float(tol)
    if isinstance(atol, str):
        # legacy string values (e.g. 'legacy') are treated as "no absolute
        # tolerance"
        atol = 0.
    # effective absolute threshold on the residual norm
    atol = max(float(atol), float(rtol) * float(bnrm2))

    if bnrm2 == 0:
        # b is the zero vector, which is also the solution
        return b, 0, [0.0]

    dotprod = np.vdot if np.iscomplexobj(x) else np.dot

    if maxiter is None:
        maxiter = n*10

    matvec = A.matvec

    # These values make no sense but coming from original Fortran code
    # sqrt might have been meant instead.
    rhotol = np.finfo(x.dtype.char).eps**2
    omegatol = rhotol

    # Dummy values to initialize vars, silence linter warnings
    rho_prev, omega, alpha, p, v = None, None, None, None, None

    r = b - matvec(x) if x.any() else b.copy()
    rtilde = r.copy()
    rel_residuals.append(np.linalg.norm(r) / bnrm2)
    for iteration in range(maxiter):
        if np.linalg.norm(r) < atol:  # Are we done?
            return x, 0, rel_residuals

        rho = dotprod(rtilde, r)
        # The rho-breakdown test of the reference implementation is
        # intentionally disabled in this variant:
        # if np.abs(rho) < rhotol:  # rho breakdown
        #     return x, -10, rel_residuals

        if iteration > 0:
            if np.abs(omega) < omegatol:  # omega breakdown
                return x, -11, rel_residuals

            beta = (rho / rho_prev) * (alpha / omega)
            p -= omega*v
            p *= beta
            p += r
        else:  # First spin
            s = np.empty_like(r)
            p = r.copy()

        phat = psolve(p)
        v = matvec(phat)
        rv = dotprod(rtilde, v)
        if rv == 0:
            return x, -11, rel_residuals
        alpha = rho / rv
        r -= alpha*v
        s[:] = r[:]

        # Half-step convergence test
        if np.linalg.norm(s) < atol:
            x += alpha * phat
            rel_residuals.append(np.linalg.norm(s) / bnrm2)
            return x, 0, rel_residuals

        shat = psolve(s)
        t = matvec(shat)
        omega = dotprod(t, s) / dotprod(t, t)
        x += alpha*phat
        x += omega*shat
        r -= omega*t
        rho_prev = rho

        rel_residuals.append(np.linalg.norm(r) / bnrm2)

        if callback:
            callback(x)

    # for loop exhausted: return incomplete progress
    return x, maxiter, rel_residuals

## Step 3: Evaluate the computational time and number of iterations of UNet-BiCGSTAB method

In [None]:
# Convert trained UNet model to a linear operator to serve as preconditioner
M = get_net_linear_operator(model, device, (N, N))  # UNet Preconditioner

# Stopping criteria
max_iter = 5000  # Maximum number of BiCGSTAB iterations
tol_res = 1e-5  # Relative residual tolerance of the BiCGSTAB iterations

# NOTE: the residual histories returned by bicgstab start with the residual
# of the initial guess, so the number of iterations is len(history) - 1.
# time.perf_counter() is used because it is a monotonic, high-resolution
# clock intended for measuring elapsed time.

# ------------------------------------------------------------
# BiCGSTAB with no preconditioner
# ------------------------------------------------------------
start_time = time.perf_counter()
x_none, info_none, residuals_none = bicgstab(A, b, maxiter=max_iter, rtol=tol_res)
elapsed_none = time.perf_counter() - start_time

print("BiCGSTAB without preconditioner")
print(f"Final residual: {residuals_none[-1]:.3e}")
print(f"Convergence info: {info_none}")
print(f"Iterative steps: {len(residuals_none) - 1}")
print(f"Elapsed time: {elapsed_none:.2f} seconds")
print("---------------------------------------------------------------")

# ------------------------------------------------------------
# BiCGSTAB with UNet preconditioner
# ------------------------------------------------------------
start_time = time.perf_counter()
x_dl, info_dl, residuals_dl = bicgstab(A, b, M=M, maxiter=max_iter, rtol=tol_res)
elapsed_dl = time.perf_counter() - start_time

print("BiCGSTAB with UNet preconditioner")
print(f"Final residual: {residuals_dl[-1]:.3e}")
print(f"Convergence info: {info_dl}")
print(f"Iterative steps: {len(residuals_dl) - 1}")
print(f"Elapsed time: {elapsed_dl:.2f} seconds")

# ------------------------------------------------------------
# Plotting convergence curves (x = 0 is the initial residual)
# ------------------------------------------------------------
plt.figure(figsize=(8, 6))
plt.semilogy(residuals_none, marker='o', label='No preconditioner')
plt.semilogy(residuals_dl, marker='^', label='UNet preconditioner')
plt.xlabel("iterative step", fontsize=12)
plt.ylabel("relative residual", fontsize=12)
plt.legend()
plt.show()