In [None]:
import os
import cv2
import datetime
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


def prepare_ndarray_frame(data, vmin, vmax, cmap='viridis', title=""):
    """
    Render a 2D array as a colour-mapped image and return the drawn figure as pixels.

    Args:
        data: Array passed to ``imshow``.
        vmin, vmax: Colour-scale limits passed to ``imshow``.
        cmap: Matplotlib colormap name.
        title: Title drawn above the image.

    Returns:
        uint8 array of shape (height, width, 3): the RGB channels of the rendered
        5x5-inch figure (image, title and colorbar), with the alpha channel dropped.
    """
    figure, axis = plt.subplots(figsize=(5, 5))
    mappable = axis.imshow(
        data,
        cmap=cmap,
        origin='lower',
        aspect='auto',
        vmin=vmin,
        vmax=vmax,
    )
    axis.set_title(title, fontsize=12)
    axis.axis('off')
    plt.colorbar(mappable, ax=axis, fraction=0.046, pad=0.04)

    # Rasterise the figure, then reinterpret the canvas RGBA buffer as an image array.
    canvas = figure.canvas
    canvas.draw()
    n_cols, n_rows = canvas.get_width_height()
    rgba = np.frombuffer(canvas.buffer_rgba(), dtype='uint8').reshape(n_rows, n_cols, 4)

    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(figure)
    return rgba[..., :3]

def create_combined_frame(pinn_prod, sim_prod, mu_pred, mu_full, vmin_pinn, vmax_pinn, vmin_mu, vmax_mu):
    """
    Draw the four comparison panels in one row and return the figure as RGB pixels.

    Panels, left to right: predicted Real x Imag product, simulated Real x Imag
    product, predicted mu, original mu. The two product panels share the colour
    limits (vmin_pinn, vmax_pinn); the two mu panels share (vmin_mu, vmax_mu).

    Returns:
        uint8 array of shape (height, width, 3) holding the rendered 20x5-inch figure.
    """
    # One entry per panel: (data, title, lower colour limit, upper colour limit).
    panels = (
        (pinn_prod, "PINN: Real × Imag", vmin_pinn, vmax_pinn),
        (sim_prod, "Sim: Real × Imag", vmin_pinn, vmax_pinn),
        (mu_pred, "Predicted μ", vmin_mu, vmax_mu),
        (mu_full, "Original μ", vmin_mu, vmax_mu),
    )

    figure, axes = plt.subplots(1, 4, figsize=(20, 5))
    for axis, (field, heading, lower, upper) in zip(axes, panels):
        axis.imshow(field, cmap="viridis", origin="lower", vmin=lower, vmax=upper)
        axis.set_title(heading)
        axis.axis("off")
    figure.tight_layout()

    # Rasterise and reinterpret the canvas RGBA buffer as a (rows, cols, 4) array.
    canvas = figure.canvas
    canvas.draw()
    n_cols, n_rows = canvas.get_width_height()
    rgba = np.frombuffer(canvas.buffer_rgba(), dtype='uint8').reshape(n_rows, n_cols, 4)
    plt.close(figure)

    # Drop the alpha channel; callers expect RGB only.
    return rgba[..., :3]


def create_video(output_path, pinn_prod_frames, sim_prod_frames, mu_pred_frames, mu_full_frames, fps=30):
    """
    Render the four frame stacks side by side and encode them as an MP4 file.

    ``output_path`` is treated as a directory (created if missing); the video is
    written inside it as ``output_video_<YYYYmmddHHMMSS>.mp4``.

    Args:
        output_path: Directory that receives the video file.
        pinn_prod_frames: Sequence of predicted product frames; its length sets the
            number of video frames, and its min/max set the colour limits of both
            product panels.
        sim_prod_frames: Sequence of simulated product frames (same length).
        mu_pred_frames: Sequence of predicted mu frames; its min/max set the colour
            limits of both mu panels.
        mu_full_frames: Sequence of ground-truth mu frames (same length).
        fps: Frames per second of the encoded video.

    Raises:
        ValueError: If ``pinn_prod_frames`` is empty.
        RuntimeError: If the video writer cannot be opened or a frame fails to
            render/encode.
    """
    # Local import under a private alias: this notebook binds the name `datetime`
    # to the module in one cell and to the class in another, so relying on the
    # global name makes `datetime.now()` depend on cell execution order.
    from datetime import datetime as _datetime

    n_frames = len(pinn_prod_frames)
    if n_frames == 0:
        # Fail before touching the filesystem; there is nothing to encode.
        raise ValueError("Cannot create a video from an empty frame sequence.")

    os.makedirs(output_path, exist_ok=True)

    video_path = os.path.join(output_path, f"output_video_{_datetime.now().strftime('%Y%m%d%H%M%S')}.mp4")

    # Global colour limits so the colour scale does not flicker between frames.
    # They come from the predicted stacks only; ground-truth values outside these
    # limits are clipped by the colormap.
    vmin_pinn, vmax_pinn = np.min(pinn_prod_frames), np.max(pinn_prod_frames)
    vmin_mu, vmax_mu = np.min(mu_pred_frames), np.max(mu_pred_frames)

    # The first frame is rendered up front because the writer needs the frame size.
    first_frame = create_combined_frame(
        pinn_prod_frames[0], sim_prod_frames[0], mu_pred_frames[0], mu_full_frames[0],
        vmin_pinn, vmax_pinn, vmin_mu, vmax_mu
    )
    height, width, _ = first_frame.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    if not video_writer.isOpened():
        # Without this check the writer silently drops every frame and the
        # function would still report success.
        video_writer.release()
        raise RuntimeError(f"Failed to create video: could not open a writer for {video_path}")

    try:
        for i in tqdm(range(n_frames), desc="Creating video"):
            if i == 0:
                # Reuse the frame rendered above instead of drawing it twice.
                combined_frame = first_frame
            else:
                combined_frame = create_combined_frame(
                    pinn_prod_frames[i], sim_prod_frames[i], mu_pred_frames[i], mu_full_frames[i],
                    vmin_pinn, vmax_pinn, vmin_mu, vmax_mu
                )
            # Matplotlib yields RGB; OpenCV expects BGR.
            video_writer.write(cv2.cvtColor(combined_frame, cv2.COLOR_RGB2BGR))
    except Exception as e:
        raise RuntimeError(f"Failed to create video: {e}") from e
    finally:
        # Always finalise the container, on success and on failure alike.
        video_writer.release()

    print(f"Video saved at: {video_path}")


def generate_video(state, mu_full, model, x_vals, y_vals, t_vals, device, output_path):
    """
    Generates a comparison video of predicted and actual values for states and mu fields.

    For every entry of ``t_vals`` the model is evaluated on the full (x, y) grid and
    the product A_r * A_i is compared with ``state[i].real * state[i].imag``; the
    expanded predicted mu is compared with ``mu_full[i]``. The collected frames are
    handed to ``create_video``.

    Args:
        state: Complex array indexed as state[time_index]; each slice must have
            shape (len(x_vals), len(y_vals)).
        mu_full: Ground-truth mu, indexed as mu_full[time_index].
        model: Object providing ``expand_myu_full(do_binarize, scale_255)`` and
            ``predict(x, y, t)`` returning two NumPy arrays.
        x_vals, y_vals, t_vals: 1D coordinate arrays.
        device: Torch device on which the query points are placed.
        output_path: Forwarded unchanged to ``create_video``.
    """
    # Local import: this cell uses torch without importing it, which otherwise
    # only works if a later cell has already been executed.
    import torch

    pinn_prod_frames, sim_prod_frames, mu_pred_frames, mu_full_frames = [], [], [], []

    # 1) Expand the predicted mu to full resolution once.
    # NOTE(review): mu_expanded[i] is indexed for every i in range(len(t_vals)); this
    # presumably requires the expanded time axis to cover all of t_vals - confirm.
    mu_expanded = model.expand_myu_full(do_binarize=True, scale_255=True)

    # 2) Build the spatial query grid once; it does not change between time steps.
    # 'ij' indexing yields grids of shape (len(x_vals), len(y_vals)), i.e. axis 0 is x,
    # matching how the notebook indexes `state` as [t, x, y]. The default 'xy'
    # indexing would produce transposed (len(y_vals), len(x_vals)) frames.
    X, Y = np.meshgrid(x_vals, y_vals, indexing="ij")
    grid_shape = X.shape
    x_test_t = torch.tensor(X.ravel(), dtype=torch.float32).view(-1, 1).to(device)
    y_test_t = torch.tensor(Y.ravel(), dtype=torch.float32).view(-1, 1).to(device)

    # 3) Loop over each time index
    for i, t_val in enumerate(tqdm(t_vals, desc="Generating frames")):
        # Constant time column with the same shape, dtype and device as the x column.
        t_test_t = torch.full_like(x_test_t, float(t_val))

        # Predict A_r and A_i
        A_r_pred, A_i_pred = model.predict(x_test_t, y_test_t, t_test_t)
        A_r_pred_2d = A_r_pred.reshape(grid_shape)
        A_i_pred_2d = A_i_pred.reshape(grid_shape)

        # Predicted product and true product
        pinn_prod_frames.append(A_r_pred_2d * A_i_pred_2d)
        sim_prod_frames.append(state[i].real * state[i].imag)

        # Predicted and true mu values for this time slice
        mu_pred_frames.append(mu_expanded[i])
        mu_full_frames.append(mu_full[i])

    # Convert to NumPy arrays
    pinn_prod_frames = np.array(pinn_prod_frames)
    sim_prod_frames = np.array(sim_prod_frames)
    mu_pred_frames = np.array(mu_pred_frames)
    mu_full_frames = np.array(mu_full_frames)

    # Create the video
    create_video(output_path, pinn_prod_frames, sim_prod_frames, mu_pred_frames, mu_full_frames)

In [None]:
from torch import optim
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class CNN(nn.Module):
    """
    Plain convolutional stack mapping ``in_channels`` feature maps to ``out_channels``.

    Every convolution uses a 3x3 kernel with padding 1, so the spatial size of the
    input is preserved. A ReLU follows every convolution except the last one.
    """
    def __init__(self, in_channels, out_channels, num_layers=4, base_filters=32):
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels (e.g., 2 for `A_r, A_i`).
            num_layers: Total number of convolutions; ``num_layers - 2`` of them are
                hidden ``base_filters -> base_filters`` layers.
            base_filters: Channel width of every hidden feature map.
        """
        super().__init__()

        def conv3x3(c_in, c_out):
            # Size-preserving convolution shared by every layer of the stack.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)

        # Entry block: lift the input to the hidden width.
        stack = [conv3x3(in_channels, base_filters), nn.ReLU()]

        # Hidden blocks at constant width.
        for _ in range(num_layers - 2):
            stack += [conv3x3(base_filters, base_filters), nn.ReLU()]

        # Projection to the output channels; no activation afterwards.
        stack.append(conv3x3(base_filters, out_channels))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, in_channels, H, W).
        Returns:
            Tensor of shape (batch_size, out_channels, H, W).
        """
        return self.net(x)


import os
from datetime import datetime
from tqdm import tqdm

class NPINN_CNN_TIMEBLOCK(nn.Module):
    """
    Physics-informed model pairing a CNN for A(x, y, t) with a trainable coarse mu field.

    The CNN maps every (x, y, t) sample to the pair (A_r, A_i). The coefficient mu is
    stored as a trainable tensor of shape (Nt // degrade_t, Nx_down, Ny_down) and is
    looked up per sample, so each block of time steps / grid cells shares one value.
    """
    def __init__(
        self,
        in_channels,         # Input channels; net_A always builds exactly 3 (x, y, t)
        out_channels,        # Output channels (e.g., 2 for A_r, A_i)
        Nt, Nx, Ny,
        Nx_down, Ny_down,    # Downsampled domain sizes for mu
        dt, dx, dy,
        degrade_x, degrade_y,
        degrade_t,
        delta=0.01,
        weight_pde=1.0,
        device="cpu"
    ):
        """
        Args:
            in_channels: Channel count of the CNN input (net_A reshapes to 3 channels).
            out_channels: Channel count of the CNN output; net_A reads channels 0 and 1.
            Nt, Nx, Ny: Full-resolution grid sizes in time, x and y.
            Nx_down, Ny_down: Spatial sizes of the coarse mu grid.
            dt, dx, dy: Grid spacings used to turn coordinates into indices.
            degrade_x, degrade_y, degrade_t: Number of fine cells/steps that share one
                coarse mu value along each axis.
            delta: Coefficient of the Laplacian term in the PDE residual.
            weight_pde: Multiplier of the PDE loss in the total training loss.
            device: Device for the CNN and the mu parameter.
        """
        super().__init__()
        self.device = device
        self.delta = delta
        self.weight_pde = weight_pde

        self.Nt, self.Nx, self.Ny = Nt, Nx, Ny
        self.Nx_down, self.Ny_down = Nx_down, Ny_down
        self.dt, self.dx, self.dy = dt, dx, dy
        self.degrade_x, self.degrade_y = degrade_x, degrade_y
        self.degrade_t = degrade_t
        self.Nt_down = Nt // degrade_t

        # CNN for predicting A(x, y, t)
        self.cnn = CNN(
            in_channels=in_channels,  # 3 for (x, y, t)
            out_channels=out_channels
        ).to(device)

        # Trainable mu_small (reduced domain in time and space)
        init = 0.3 * torch.randn(self.Nt_down, Nx_down, Ny_down)
        self.mu_small_raw = nn.Parameter(init.to(device))

    def forward(self, x, y, t):
        """Alias for ``net_A`` so the module can be called directly."""
        return self.net_A(x, y, t)

    def net_A(self, x, y, t):
        """
        Evaluate the CNN at the given sample points.

        Args:
            x, y, t: Tensors of shape (batch, 1).
        Returns:
            Tuple (A_r, A_i), each of shape (batch, 1).
        """
        batch_size = x.shape[0]
        # Stacking three (batch, 1) columns gives (batch, 3, 1); the view then adds
        # the trailing singleton so the CNN sees a 1x1 "image" with 3 channels.
        inputs = torch.stack([x, y, t], dim=1)
        inputs = inputs.view(batch_size, 3, 1, 1)
        outputs = self.cnn(inputs)
        A_r, A_i = outputs[:, 0:1, 0, 0], outputs[:, 1:2, 0, 0]
        return A_r, A_i

    def get_myu_collocation(self, x, y, t):
        """
        Look up the trainable coarse mu value for every sample point.

        The time index is rounded to the nearest step before being mapped to its
        time block; spatial coordinates are floored into coarse cells. All indices
        are clamped to the valid range.

        Returns:
            Tensor of shape (batch, 1) holding raw (un-binarized) mu values.
        """
        i = (t[:, 0] / self.dt).round().long().clamp(0, self.Nt - 1)
        i_down = (i // self.degrade_t).clamp(0, self.Nt_down - 1)
        j_down = (x[:, 0] / (self.dx * self.degrade_x)).floor().long().clamp(0, self.Nx_down - 1)
        k_down = (y[:, 0] / (self.dy * self.degrade_y)).floor().long().clamp(0, self.Ny_down - 1)
        mu_vals_raw = self.mu_small_raw[i_down, j_down, k_down]
        return mu_vals_raw.view(-1, 1)

    def pde_residual(self, x, y, t):
        """
        Compute the PDE residual for the real and imaginary parts.

        For each part the residual is
            f = A_t - mu * A - delta * (A_xx + A_yy) + |A|^2 * A
        with |A|^2 = A_r^2 + A_i^2. ``x``, ``y`` and ``t`` must require gradients.

        Returns:
            Tuple (f_r, f_i).
        """
        def grad(outputs, inputs):
            # Derivative of `outputs` w.r.t. `inputs`, keeping the graph so the
            # result can be differentiated again and back-propagated through.
            return torch.autograd.grad(
                outputs, inputs,
                grad_outputs=torch.ones_like(outputs),
                create_graph=True, retain_graph=True,
            )[0]

        A_r, A_i = self.net_A(x, y, t)
        mu_vals = self.get_myu_collocation(x, y, t)

        A_r_t, A_i_t = grad(A_r, t), grad(A_i, t)
        A_r_x, A_i_x = grad(A_r, x), grad(A_i, x)
        A_r_y, A_i_y = grad(A_r, y), grad(A_i, y)

        # NOTE(review): self.cnn uses ReLU activations; the second derivatives of a
        # piecewise-linear network are zero almost everywhere, so the Laplacian
        # term may contribute nothing to the residual - confirm the activation choice.
        lapA_r = grad(A_r_x, x) + grad(A_r_y, y)
        lapA_i = grad(A_i_x, x) + grad(A_i_y, y)

        A_abs2 = A_r**2 + A_i**2

        f_r = A_r_t - mu_vals * A_r - self.delta * lapA_r + A_abs2 * A_r
        f_i = A_i_t - mu_vals * A_i - self.delta * lapA_i + A_abs2 * A_i
        return f_r, f_i

    def loss_pde(self, x_eqs, y_eqs, t_eqs):
        """Mean of f_r^2 + f_i^2 over the collocation points."""
        f_r, f_i = self.pde_residual(x_eqs, y_eqs, t_eqs)
        return torch.mean(f_r**2 + f_i**2)

    def loss_data(self, x_data, y_data, t_data, A_r_data, A_i_data):
        """Mean squared error between predicted and measured (A_r, A_i)."""
        A_r_pred, A_i_pred = self.net_A(x_data, y_data, t_data)
        return torch.mean((A_r_pred - A_r_data)**2 + (A_i_pred - A_i_data)**2)

    def train_model(
        self,
        x_data, y_data, t_data, A_r_data, A_i_data,
        x_eqs, y_eqs, t_eqs,
        n_epochs=200000,
        lr=1e-3,
        model_name="MyModel",
        output_dir="./results",
        video_freq=10000,
        state_exp=None,
        myu_full_exp=None,
        x_vals=None,
        y_vals=None,
        t_vals=None,
        device="cpu"
    ):
        """
        Full-batch Adam training on ``data_loss + weight_pde * pde_loss``.

        Losses are printed every 500 epochs. Every ``video_freq`` epochs (epoch 0
        excluded) a checkpoint is written to ``<output_dir>/<model_name>/`` and, when
        ``state_exp``, ``myu_full_exp``, ``x_vals``, ``y_vals`` and ``t_vals`` are all
        provided, a comparison video is rendered as well. A falsy ``video_freq``
        (0 or None) disables periodic checkpoints and videos. The final weights are
        saved as ``<model_name>_final.pt`` in the same folder.
        """
        optimizer = optim.Adam(self.parameters(), lr=lr)
        model_folder = os.path.join(output_dir, model_name)
        os.makedirs(model_folder, exist_ok=True)

        # The video inputs default to None; rendering without them would abort the
        # run in the middle of training, so it is skipped instead.
        can_render_video = all(
            item is not None for item in (state_exp, myu_full_exp, x_vals, y_vals, t_vals)
        )

        self.train()
        for epoch in tqdm(range(n_epochs)):
            optimizer.zero_grad()
            pde_loss = self.loss_pde(x_eqs, y_eqs, t_eqs)
            data_loss = self.loss_data(x_data, y_data, t_data, A_r_data, A_i_data)
            loss = data_loss + self.weight_pde * pde_loss
            loss.backward()
            optimizer.step()

            if epoch % 500 == 0:
                print(f"Epoch {epoch}: Total={loss.item():.4e}, Data={data_loss.item():.4e}, PDE={pde_loss.item():.4e}")

            if video_freq and epoch > 0 and epoch % video_freq == 0:
                # Save the checkpoint first so a failure while rendering the video
                # cannot lose the weights of this epoch.
                mdl_path = os.path.join(model_folder, f"{model_name}_epoch_{epoch}_trained.pt")
                torch.save(self.state_dict(), mdl_path)
                if can_render_video:
                    vid_path = os.path.join(model_folder, f"{model_name}_epoch_{epoch}_video.mp4")
                    generate_video(state_exp, myu_full_exp, self, x_vals, y_vals, t_vals, device, vid_path)

        final_path = os.path.join(model_folder, f"{model_name}_final.pt")
        torch.save(self.state_dict(), final_path)
        print(f"Final model saved at {final_path}")

    def expand_myu_full(self, do_binarize=True, scale_255=False):
        """
        Expand mu_small_raw of shape (Nt_down, Nx_down, Ny_down) to full resolution by
         1) repeat_interleave along the time dim by degrade_t
         2) repeat_interleave along the x, y dims by degrade_x, degrade_y

        The result has shape (Nt_down * degrade_t, Nx_down * degrade_x,
        Ny_down * degrade_y), which equals (Nt, Nx, Ny) only when the full sizes are
        exact multiples of the degrade factors.

        Args:
            do_binarize: If True, replace mu by 1.0 where it is > 0 and 0.0 elsewhere.
            scale_255: If True, multiply the result by 255.
        Returns:
            NumPy array on the CPU.
        """
        with torch.no_grad():
            mu_raw = self.mu_small_raw.detach()  # shape (Nt_down, Nx_down, Ny_down)

            if do_binarize:
                mu_bin = (mu_raw > 0.0).float()
            else:
                mu_bin = mu_raw

            # Expand time dimension
            mu_time = mu_bin.repeat_interleave(self.degrade_t, dim=0)

            # Expand spatial dimensions
            mu_full_x = mu_time.repeat_interleave(self.degrade_x, dim=1)
            mu_full_xy = mu_full_x.repeat_interleave(self.degrade_y, dim=2)

            if scale_255:
                mu_full_xy = mu_full_xy * 255.0

            return mu_full_xy.cpu().numpy()

    def predict(self, x, y, t):
        """
        Evaluate the neural net for A(x, y, t) -> (A_r, A_i) in NumPy form.

        The module is switched to eval mode for the forward pass and then returned
        to whatever mode it was in, so calling this in the middle of training (as the
        periodic video export does) does not leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                A_r, A_i = self.net_A(x, y, t)
        finally:
            self.train(was_training)
        return A_r.cpu().numpy(), A_i.cpu().numpy()

In [None]:
import torch

# Load the simulated complex field and the ground-truth mu masks.
state = np.load("../data/test_new/states_processed_cropped.npy")  # Complex (Nt, Nx, Ny)
myu_full = np.load("../data/test_new/myus_binarized_processed_cropped.npy")  # Binary (Nt, Nx, Ny)

print("State shape:", state.shape, state.dtype)  # (Nt, Nx, Ny)
print("Myu shape:  ", myu_full.shape, myu_full.dtype)

# Fail fast: both arrays are indexed with the same (t, x, y) indices downstream.
assert state.ndim == 3, f"expected a (Nt, Nx, Ny) state array, got shape {state.shape}"
assert myu_full.shape == state.shape, (
    f"mu shape {myu_full.shape} does not match state shape {state.shape}"
)

A_r_data = state.real  # Real part (Nt, Nx, Ny)
A_i_data = state.imag  # Imaginary part (Nt, Nx, Ny)

Nt, Nx, Ny = state.shape
dt, dx, dy = 0.05, 0.3, 0.3  # Temporal and spatial step sizes
Nx_down, Ny_down = 10, 10  # Downsampled sizes for `mu`
degrade_x, degrade_y = Nx // Nx_down, Ny // Ny_down
degrade_t = 50  # Each block of 50 time steps has the same `mu`
Nt_down = Nt // degrade_t

# expand_myu_full rebuilds the full grid by repeating every coarse mu cell an integer
# number of times, so the full sizes must be exact multiples of the coarse sizes.
# Otherwise the expanded mu is smaller than (Nt, Nx, Ny) and the first video export,
# many epochs into training, is where the mismatch would surface.
assert Nt % degrade_t == 0, f"Nt={Nt} is not a multiple of degrade_t={degrade_t}"
assert Nx % Nx_down == 0 and Ny % Ny_down == 0, (
    f"grid ({Nx}, {Ny}) is not a multiple of the coarse mu grid ({Nx_down}, {Ny_down})"
)

In [None]:
# Randomly sample training points for data and collocation (physics-based loss)
# Seed NumPy and torch so the sampled points, and any torch initialisation that runs
# afterwards in this kernel session, are reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

n_data = 20000  # Number of data points
idx_t = np.random.randint(0, Nt, size=n_data)
idx_x = np.random.randint(0, Nx, size=n_data)
idx_y = np.random.randint(0, Ny, size=n_data)

# Convert sampled indices to physical coordinates
t_vals = np.arange(Nt) * dt
x_vals = np.arange(Nx) * dx
y_vals = np.arange(Ny) * dy

t_data_np = t_vals[idx_t]
x_data_np = x_vals[idx_x]
y_data_np = y_vals[idx_y]

# Ground-truth A_r and A_i values
Ar_data_np = A_r_data[idx_t, idx_x, idx_y]
Ai_data_np = A_i_data[idx_t, idx_x, idx_y]

# Convert to tensors
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x_data_t = torch.tensor(x_data_np, dtype=torch.float32).view(-1, 1).to(device)
y_data_t = torch.tensor(y_data_np, dtype=torch.float32).view(-1, 1).to(device)
t_data_t = torch.tensor(t_data_np, dtype=torch.float32).view(-1, 1).to(device)
Ar_data_t = torch.tensor(Ar_data_np, dtype=torch.float32).view(-1, 1).to(device)
Ai_data_t = torch.tensor(Ai_data_np, dtype=torch.float32).view(-1, 1).to(device)

# Collocation points for PDE residuals (sampled continuously inside the domain)
n_coll = 20000
t_eqs_np = np.random.uniform(0, t_vals[-1], size=n_coll)
x_eqs_np = np.random.uniform(0, x_vals[-1], size=n_coll)
y_eqs_np = np.random.uniform(0, y_vals[-1], size=n_coll)

# Gradient tracking is switched on after the reshape/device placement so that the
# tensors handed to the model are leaf tensors rather than views of a hidden leaf.
x_eqs_t = torch.tensor(x_eqs_np, dtype=torch.float32, device=device).view(-1, 1).requires_grad_(True)
y_eqs_t = torch.tensor(y_eqs_np, dtype=torch.float32, device=device).view(-1, 1).requires_grad_(True)
t_eqs_t = torch.tensor(t_eqs_np, dtype=torch.float32, device=device).view(-1, 1).requires_grad_(True)

In [None]:
# Initialize the model: a CNN for A(x, y, t) plus a trainable coarse mu field.
model = NPINN_CNN_TIMEBLOCK(
    in_channels=3,         # Input channels: net_A stacks the samples as (x, y, t)
    out_channels=2,        # Output channels (A_r, A_i)
    Nt=Nt, Nx=Nx, Ny=Ny,
    Nx_down=Nx_down, Ny_down=Ny_down,
    dt=dt, dx=dx, dy=dy,
    degrade_x=degrade_x, degrade_y=degrade_y,
    degrade_t=degrade_t,
    delta=0.01,            # Coefficient of the Laplacian term in the PDE residual
    weight_pde=0.1,        # Multiplier of the PDE loss in the total loss
    device=device
).to(device)

In [None]:
# Training configuration
n_epochs = 200000
lr = 1e-3
video_freq = 10000  # Save a checkpoint and comparison video every 10000 epochs
output_dir = "./results"
model_name = "CNN_PINN"

# Train the model; outputs are written under <output_dir>/<model_name>/.
model.train_model(
    x_data_t, y_data_t, t_data_t,
    Ar_data_t, Ai_data_t,
    x_eqs_t, y_eqs_t, t_eqs_t,
    n_epochs=n_epochs,
    lr=lr,
    model_name=model_name,
    output_dir=output_dir,
    video_freq=video_freq,
    state_exp=state,  # Original state for video comparison
    myu_full_exp=myu_full,  # Original mu for video comparison
    x_vals=x_vals,
    y_vals=y_vals,
    t_vals=t_vals,
    device=device
)

In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class CNNWithDropout(nn.Module):
    """
    Convolutional stack with channel dropout after every hidden activation.

    Every convolution uses a 3x3 kernel with padding 1, so the spatial size of the
    input is preserved. Each hidden unit is Conv2d -> ReLU -> Dropout2d; the final
    projection convolution has neither activation nor dropout.
    """
    def __init__(self, in_channels, out_channels, num_layers=6, base_filters=32, dropout_rate=0.3):
        """
        Args:
            in_channels: Number of input channels (e.g., 3 for x, y, t).
            out_channels: Number of output channels (e.g., 2 for A_r, A_i).
            num_layers: Total number of convolutions; ``num_layers - 2`` of them are
                hidden ``base_filters -> base_filters`` layers.
            base_filters: Channel width of every hidden feature map.
            dropout_rate: Probability passed to every ``Dropout2d`` layer.
        """
        super().__init__()

        def hidden_unit(c_in):
            # Conv -> ReLU -> Dropout2d, producing `base_filters` channels.
            return [
                nn.Conv2d(c_in, base_filters, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Dropout2d(dropout_rate),
            ]

        # Entry unit lifts the input to the hidden width.
        stack = hidden_unit(in_channels)

        # Hidden units at constant width.
        for _ in range(num_layers - 2):
            stack.extend(hidden_unit(base_filters))

        # Projection to the output channels.
        stack.append(nn.Conv2d(base_filters, out_channels, kernel_size=3, padding=1))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, in_channels, H, W).
        Returns:
            Tensor of shape (batch_size, out_channels, H, W).
        """
        return self.net(x)


class NPINN_CNN_TIMEBLOCK_WITH_DROPOUT(nn.Module):
    """
    Physics-informed model pairing a dropout-regularised CNN for A(x, y, t) with a
    trainable coarse mu field.

    The CNN maps every (x, y, t) sample to the pair (A_r, A_i). The coefficient mu is
    stored as a trainable tensor of shape (Nt // degrade_t, Nx_down, Ny_down) and is
    looked up per sample, so each block of time steps / grid cells shares one value.
    """
    def __init__(
        self,
        in_channels,         # Input channels; net_A always builds exactly 3 (x, y, t)
        out_channels,        # Output channels (e.g., 2 for A_r, A_i)
        Nt, Nx, Ny,
        Nx_down, Ny_down,    # Downsampled domain sizes for mu
        dt, dx, dy,
        degrade_x, degrade_y,
        degrade_t,
        delta=0.01,
        weight_pde=1.0,
        device="cpu",
        dropout_rate=0.3,
    ):
        """
        Args:
            in_channels: Channel count of the CNN input (net_A reshapes to 3 channels).
            out_channels: Channel count of the CNN output; net_A reads channels 0 and 1.
            Nt, Nx, Ny: Full-resolution grid sizes in time, x and y.
            Nx_down, Ny_down: Spatial sizes of the coarse mu grid.
            dt, dx, dy: Grid spacings used to turn coordinates into indices.
            degrade_x, degrade_y, degrade_t: Number of fine cells/steps that share one
                coarse mu value along each axis.
            delta: Coefficient of the Laplacian term in the PDE residual.
            weight_pde: Multiplier of the PDE loss in the total training loss.
            device: Device for the CNN and the mu parameter.
            dropout_rate: Dropout probability forwarded to the CNN.
        """
        super().__init__()
        self.device = device
        self.delta = delta
        self.weight_pde = weight_pde

        self.Nt, self.Nx, self.Ny = Nt, Nx, Ny
        self.Nx_down, self.Ny_down = Nx_down, Ny_down
        self.dt, self.dx, self.dy = dt, dx, dy
        self.degrade_x, self.degrade_y = degrade_x, degrade_y
        self.degrade_t = degrade_t
        self.Nt_down = Nt // degrade_t

        # CNN with Dropout for predicting A(x, y, t)
        self.cnn = CNNWithDropout(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=6,  # Increased number of layers
            base_filters=64,  # Increased base filters
            dropout_rate=dropout_rate
        ).to(device)

        # Trainable mu_small (reduced domain in time and space)
        init = 0.3 * torch.randn(self.Nt_down, Nx_down, Ny_down)
        self.mu_small_raw = nn.Parameter(init.to(device))

    def forward(self, x, y, t):
        """Alias for ``net_A`` so the module can be called directly."""
        return self.net_A(x, y, t)

    def net_A(self, x, y, t):
        """
        Evaluate the CNN at the given sample points.

        Args:
            x, y, t: Tensors of shape (batch, 1).
        Returns:
            Tuple (A_r, A_i), each of shape (batch, 1).
        """
        batch_size = x.shape[0]
        # Stacking three (batch, 1) columns gives (batch, 3, 1); the view then adds
        # the trailing singleton so the CNN sees a 1x1 "image" with 3 channels.
        inputs = torch.stack([x, y, t], dim=1)
        inputs = inputs.view(batch_size, 3, 1, 1)
        outputs = self.cnn(inputs)
        A_r, A_i = outputs[:, 0:1, 0, 0], outputs[:, 1:2, 0, 0]
        return A_r, A_i

    def get_myu_collocation(self, x, y, t):
        """
        Look up the trainable coarse mu value for every sample point.

        The time index is rounded to the nearest step before being mapped to its
        time block; spatial coordinates are floored into coarse cells. All indices
        are clamped to the valid range.

        Returns:
            Tensor of shape (batch, 1) holding raw (un-binarized) mu values.
        """
        i = (t[:, 0] / self.dt).round().long().clamp(0, self.Nt - 1)
        i_down = (i // self.degrade_t).clamp(0, self.Nt_down - 1)
        j_down = (x[:, 0] / (self.dx * self.degrade_x)).floor().long().clamp(0, self.Nx_down - 1)
        k_down = (y[:, 0] / (self.dy * self.degrade_y)).floor().long().clamp(0, self.Ny_down - 1)
        mu_vals_raw = self.mu_small_raw[i_down, j_down, k_down]
        return mu_vals_raw.view(-1, 1)

    def pde_residual(self, x, y, t):
        """
        Compute the PDE residual for the real and imaginary parts.

        For each part the residual is
            f = A_t - mu * A - delta * (A_xx + A_yy) + |A|^2 * A
        with |A|^2 = A_r^2 + A_i^2. ``x``, ``y`` and ``t`` must require gradients.

        Returns:
            Tuple (f_r, f_i).
        """
        def grad(outputs, inputs):
            # Derivative of `outputs` w.r.t. `inputs`, keeping the graph so the
            # result can be differentiated again and back-propagated through.
            return torch.autograd.grad(
                outputs, inputs,
                grad_outputs=torch.ones_like(outputs),
                create_graph=True, retain_graph=True,
            )[0]

        A_r, A_i = self.net_A(x, y, t)
        mu_vals = self.get_myu_collocation(x, y, t)

        A_r_t, A_i_t = grad(A_r, t), grad(A_i, t)
        A_r_x, A_i_x = grad(A_r, x), grad(A_i, x)
        A_r_y, A_i_y = grad(A_r, y), grad(A_i, y)

        # NOTE(review): self.cnn uses ReLU activations; the second derivatives of a
        # piecewise-linear network are zero almost everywhere, so the Laplacian
        # term may contribute nothing to the residual - confirm the activation choice.
        lapA_r = grad(A_r_x, x) + grad(A_r_y, y)
        lapA_i = grad(A_i_x, x) + grad(A_i_y, y)

        A_abs2 = A_r**2 + A_i**2

        f_r = A_r_t - mu_vals * A_r - self.delta * lapA_r + A_abs2 * A_r
        f_i = A_i_t - mu_vals * A_i - self.delta * lapA_i + A_abs2 * A_i
        return f_r, f_i

    def loss_pde(self, x_eqs, y_eqs, t_eqs):
        """Mean of f_r^2 + f_i^2 over the collocation points."""
        f_r, f_i = self.pde_residual(x_eqs, y_eqs, t_eqs)
        return torch.mean(f_r**2 + f_i**2)

    def loss_data(self, x_data, y_data, t_data, A_r_data, A_i_data):
        """Mean squared error between predicted and measured (A_r, A_i)."""
        A_r_pred, A_i_pred = self.net_A(x_data, y_data, t_data)
        return torch.mean((A_r_pred - A_r_data)**2 + (A_i_pred - A_i_data)**2)

    def train_model(
        self,
        x_data, y_data, t_data, A_r_data, A_i_data,
        x_eqs, y_eqs, t_eqs,
        n_epochs=200000,
        lr=1e-3,
        model_name="MyModel",
        output_dir="./results",
        video_freq=10000,
        state_exp=None,
        myu_full_exp=None,
        x_vals=None,
        y_vals=None,
        t_vals=None,
        device="cpu"
    ):
        """
        Full-batch Adam training on ``data_loss + weight_pde * pde_loss``.

        Losses are printed every 500 epochs. Every ``video_freq`` epochs (epoch 0
        excluded) a checkpoint is written to ``<output_dir>/<model_name>/`` and, when
        ``state_exp``, ``myu_full_exp``, ``x_vals``, ``y_vals`` and ``t_vals`` are all
        provided, a comparison video is rendered as well. A falsy ``video_freq``
        (0 or None) disables periodic checkpoints and videos. The final weights are
        saved as ``<model_name>_final.pt`` in the same folder.
        """
        optimizer = optim.Adam(self.parameters(), lr=lr)
        model_folder = os.path.join(output_dir, model_name)
        os.makedirs(model_folder, exist_ok=True)

        # The video inputs default to None; rendering without them would abort the
        # run in the middle of training, so it is skipped instead.
        can_render_video = all(
            item is not None for item in (state_exp, myu_full_exp, x_vals, y_vals, t_vals)
        )

        # Make sure dropout is active while optimising.
        self.train()
        for epoch in tqdm(range(n_epochs)):
            optimizer.zero_grad()
            pde_loss = self.loss_pde(x_eqs, y_eqs, t_eqs)
            data_loss = self.loss_data(x_data, y_data, t_data, A_r_data, A_i_data)
            loss = data_loss + self.weight_pde * pde_loss
            loss.backward()
            optimizer.step()

            if epoch % 500 == 0:
                print(f"Epoch {epoch}: Total={loss.item():.4e}, Data={data_loss.item():.4e}, PDE={pde_loss.item():.4e}")

            if video_freq and epoch > 0 and epoch % video_freq == 0:
                # Save the checkpoint first so a failure while rendering the video
                # cannot lose the weights of this epoch.
                mdl_path = os.path.join(model_folder, f"{model_name}_epoch_{epoch}_trained.pt")
                torch.save(self.state_dict(), mdl_path)
                if can_render_video:
                    vid_path = os.path.join(model_folder, f"{model_name}_epoch_{epoch}_video.mp4")
                    generate_video(state_exp, myu_full_exp, self, x_vals, y_vals, t_vals, device, vid_path)

        final_path = os.path.join(model_folder, f"{model_name}_final.pt")
        torch.save(self.state_dict(), final_path)
        print(f"Final model saved at {final_path}")

    def expand_myu_full(self, do_binarize=True, scale_255=False):
        """
        Expand mu_small_raw of shape (Nt_down, Nx_down, Ny_down) to full resolution by
         1) repeat_interleave along the time dim by degrade_t
         2) repeat_interleave along the x, y dims by degrade_x, degrade_y

        The result has shape (Nt_down * degrade_t, Nx_down * degrade_x,
        Ny_down * degrade_y), which equals (Nt, Nx, Ny) only when the full sizes are
        exact multiples of the degrade factors.

        Args:
            do_binarize: If True, replace mu by 1.0 where it is > 0 and 0.0 elsewhere.
            scale_255: If True, multiply the result by 255.
        Returns:
            NumPy array on the CPU.
        """
        with torch.no_grad():
            mu_raw = self.mu_small_raw.detach()  # shape (Nt_down, Nx_down, Ny_down)

            if do_binarize:
                mu_bin = (mu_raw > 0.0).float()
            else:
                mu_bin = mu_raw

            # Expand time dimension
            mu_time = mu_bin.repeat_interleave(self.degrade_t, dim=0)

            # Expand spatial dimensions
            mu_full_x = mu_time.repeat_interleave(self.degrade_x, dim=1)
            mu_full_xy = mu_full_x.repeat_interleave(self.degrade_y, dim=2)

            if scale_255:
                mu_full_xy = mu_full_xy * 255.0

            return mu_full_xy.cpu().numpy()

    def predict(self, x, y, t):
        """
        Evaluate the neural net for A(x, y, t) -> (A_r, A_i) in NumPy form.

        Dropout is disabled for the forward pass by switching to eval mode; the
        previous mode is restored afterwards. Without the restore, the periodic video
        export inside ``train_model`` would leave the model in eval mode and every
        later training step would run with dropout switched off.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                A_r, A_i = self.net_A(x, y, t)
        finally:
            self.train(was_training)
        return A_r.cpu().numpy(), A_i.cpu().numpy()


In [None]:
# Import necessary libraries
import numpy as np
import torch

# Load state and mu data
state = np.load("../data/test_new/states_processed_cropped.npy")  # Complex (Nt, Nx, Ny)
myu_full = np.load("../data/test_new/myus_binarized_processed_cropped.npy")  # Binary (Nt, Nx, Ny)

print("State shape:", state.shape, state.dtype)  # (Nt, Nx, Ny)
print("Mu shape:   ", myu_full.shape, myu_full.dtype)

# Fail fast: both arrays are indexed with the same (t, x, y) indices downstream.
assert state.ndim == 3, f"expected a (Nt, Nx, Ny) state array, got shape {state.shape}"
assert myu_full.shape == state.shape, (
    f"mu shape {myu_full.shape} does not match state shape {state.shape}"
)

# Real and imaginary parts of the state
A_r_data = state.real  # Real part
A_i_data = state.imag  # Imaginary part

# Extract domain sizes
Nt, Nx, Ny = state.shape
dt, dx, dy = 0.05, 0.3, 0.3  # Temporal and spatial step sizes
Nx_down, Ny_down = 10, 10  # Downsampled sizes for `mu`
degrade_x, degrade_y = Nx // Nx_down, Ny // Ny_down
degrade_t = 50  # Each block of 50 time steps has the same `mu`
Nt_down = Nt // degrade_t

# expand_myu_full rebuilds the full grid by repeating every coarse mu cell an integer
# number of times, so the full sizes must be exact multiples of the coarse sizes.
# Otherwise the expanded mu is smaller than (Nt, Nx, Ny) and the first video export,
# many epochs into training, is where the mismatch would surface.
assert Nt % degrade_t == 0, f"Nt={Nt} is not a multiple of degrade_t={degrade_t}"
assert Nx % Nx_down == 0 and Ny % Ny_down == 0, (
    f"grid ({Nx}, {Ny}) is not a multiple of the coarse mu grid ({Nx_down}, {Ny_down})"
)

# Prepare training and collocation points
# Seed NumPy and torch so the sampled points, and any torch initialisation that runs
# afterwards in this kernel session, are reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

n_data = 20000  # Number of data points
idx_t = np.random.randint(0, Nt, size=n_data)
idx_x = np.random.randint(0, Nx, size=n_data)
idx_y = np.random.randint(0, Ny, size=n_data)

# Convert indices to physical coordinates
t_vals = np.arange(Nt) * dt
x_vals = np.arange(Nx) * dx
y_vals = np.arange(Ny) * dy

t_data_np = t_vals[idx_t]
x_data_np = x_vals[idx_x]
y_data_np = y_vals[idx_y]

# Ground-truth real and imaginary parts
Ar_data_np = A_r_data[idx_t, idx_x, idx_y]
Ai_data_np = A_i_data[idx_t, idx_x, idx_y]

# Convert to tensors
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x_data_t = torch.tensor(x_data_np, dtype=torch.float32).view(-1, 1).to(device)
y_data_t = torch.tensor(y_data_np, dtype=torch.float32).view(-1, 1).to(device)
t_data_t = torch.tensor(t_data_np, dtype=torch.float32).view(-1, 1).to(device)
Ar_data_t = torch.tensor(Ar_data_np, dtype=torch.float32).view(-1, 1).to(device)
Ai_data_t = torch.tensor(Ai_data_np, dtype=torch.float32).view(-1, 1).to(device)

# Collocation points for PDE residuals (sampled continuously inside the domain)
n_coll = 20000
t_eqs_np = np.random.uniform(0, t_vals[-1], size=n_coll)
x_eqs_np = np.random.uniform(0, x_vals[-1], size=n_coll)
y_eqs_np = np.random.uniform(0, y_vals[-1], size=n_coll)

# Gradient tracking is switched on after the reshape/device placement so that the
# tensors handed to the model are leaf tensors rather than views of a hidden leaf.
x_eqs_t = torch.tensor(x_eqs_np, dtype=torch.float32, device=device).view(-1, 1).requires_grad_(True)
y_eqs_t = torch.tensor(y_eqs_np, dtype=torch.float32, device=device).view(-1, 1).requires_grad_(True)
t_eqs_t = torch.tensor(t_eqs_np, dtype=torch.float32, device=device).view(-1, 1).requires_grad_(True)

# Initialize the model: a dropout CNN for A(x, y, t) plus a trainable coarse mu field.
model = NPINN_CNN_TIMEBLOCK_WITH_DROPOUT(
    in_channels=3,         # Input channels: net_A stacks the samples as (x, y, t)
    out_channels=2,        # Output channels (A_r, A_i)
    Nt=Nt, Nx=Nx, Ny=Ny,
    Nx_down=Nx_down, Ny_down=Ny_down,
    dt=dt, dx=dx, dy=dy,
    degrade_x=degrade_x, degrade_y=degrade_y,
    degrade_t=degrade_t,
    delta=0.01,            # Coefficient of the Laplacian term in the PDE residual
    weight_pde=0.1,        # Multiplier of the PDE loss in the total loss
    dropout_rate=0.5,      # Dropout probability
    device=device
).to(device)

# Training configuration
n_epochs = 200000
lr = 1e-3
video_freq = 10000  # Save comparison video every 10000 epochs
output_dir = "./results"
model_name = "CNN_PINN_Dropout"

# Print the trainable coarse mu field before training, for comparison afterwards.
print("Initial mu_small_raw:")
print(model.mu_small_raw)
# Train the model; outputs are written under <output_dir>/<model_name>/.
model.train_model(
    x_data_t, y_data_t, t_data_t,
    Ar_data_t, Ai_data_t,
    x_eqs_t, y_eqs_t, t_eqs_t,
    n_epochs=n_epochs,
    lr=lr,
    model_name=model_name,
    output_dir=output_dir,
    video_freq=video_freq,
    state_exp=state,  # Original state for video comparison
    myu_full_exp=myu_full,  # Original mu for video comparison
    x_vals=x_vals,
    y_vals=y_vals,
    t_vals=t_vals,
    device=device
)

print("Updated mu_small_raw:")
print(model.mu_small_raw)
# Save the final trained model
# NOTE(review): train_model already writes <output_dir>/<model_name>/<model_name>_final.pt;
# this saves a second copy one level up, directly inside <output_dir> - confirm both are wanted.
final_model_path = f"{output_dir}/{model_name}_final.pt"
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved at {final_model_path}")

State shape: (350, 530, 880) complex128
Mu shape:    (350, 530, 880) uint16
Initial mu_small_raw:
Parameter containing:
tensor([[[-0.1337, -0.0711, -0.4917, -0.0978,  0.3985, -0.1620, -0.3516,
          -0.1934, -0.0863,  0.2276],
         [ 0.1529, -0.3025, -0.4092, -0.1418,  0.2641, -0.3506, -0.2351,
          -0.0710,  0.0350, -0.0425],
         [-0.1450, -0.1569,  0.1331,  0.2254, -0.0472,  0.2859, -0.1885,
          -0.1547,  0.3316, -0.1803],
         [ 0.4727, -0.1520, -0.4075,  0.3553,  0.1617,  0.1352, -0.0836,
           0.2357, -0.1484, -0.2308],
         [ 0.3905, -0.1791,  0.6272,  0.2286,  0.0017,  0.2941,  0.5514,
          -0.1074, -0.1156,  0.4443],
         [ 0.1383,  0.2162, -0.3742,  0.3170,  0.1288, -0.0113, -0.0124,
           0.2511,  0.4652, -0.1399],
         [-0.0333,  0.1135,  0.1674,  0.3537,  0.0588, -0.6603,  0.5049,
          -0.2520, -0.1286,  0.2642],
         [ 0.4199, -0.0227,  0.1994,  0.2506, -0.2709, -0.1783,  0.1414,
           0.1434, -0.0756,  0

  0%|          | 2/200000 [00:00<10:57:25,  5.07it/s]

Epoch 0: Total=3.7807e-01, Data=3.7806e-01, PDE=4.4449e-05


  0%|          | 502/200000 [00:51<5:43:52,  9.67it/s]

Epoch 500: Total=2.2827e-01, Data=2.2803e-01, PDE=2.4031e-03


  1%|          | 1002/200000 [01:43<5:47:35,  9.54it/s]

Epoch 1000: Total=2.1490e-01, Data=2.1421e-01, PDE=6.8202e-03


  1%|          | 1502/200000 [02:35<5:42:52,  9.65it/s]

Epoch 1500: Total=2.1059e-01, Data=2.0960e-01, PDE=9.8863e-03


  1%|          | 2002/200000 [03:27<5:44:38,  9.58it/s]

Epoch 2000: Total=2.0876e-01, Data=2.0781e-01, PDE=9.5204e-03


  1%|▏         | 2502/200000 [04:19<5:37:23,  9.76it/s]

Epoch 2500: Total=2.0897e-01, Data=2.0792e-01, PDE=1.0490e-02


  2%|▏         | 3002/200000 [05:11<5:40:54,  9.63it/s]

Epoch 3000: Total=2.0850e-01, Data=2.0744e-01, PDE=1.0611e-02


  2%|▏         | 3469/200000 [05:59<5:36:50,  9.72it/s]