In [1]:
import torch
import os
import cv2
import datetime
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


def prepare_ndarray_frame(data, vmin, vmax, cmap='viridis', title=""):
    """Render a 2-D array (with a colorbar) and return the figure as an RGB image.

    Args:
        data: 2-D array passed to ``imshow`` (drawn with ``origin='lower'``).
        vmin, vmax: color-scale limits.
        cmap: matplotlib colormap name.
        title: axes title.

    Returns:
        ``uint8`` ndarray of shape (H, W, 3) holding the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        im = ax.imshow(data, cmap=cmap, origin='lower', aspect='auto', vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=12)
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.canvas.draw()
        # Take the buffer's own (H, W, 4) shape instead of reshaping by get_width_height(),
        # which reports logical pixels and does not match the buffer on HiDPI canvases.
        # Copy so the returned image does not alias the (soon closed) figure's buffer.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
    finally:
        # Always release the figure, even if drawing fails, so repeated calls do not leak figures.
        plt.close(fig)

def create_combined_frame(pinn_prod, sim_prod, mu_pred, mu_full, vmin_pinn, vmax_pinn, vmin_mu, vmax_mu):
    """Render four side-by-side panels and return them as a single RGB image.

    Panels, left to right: PINN Re*Im, simulated Re*Im, predicted mu, original mu.
    The two Re*Im panels share the (vmin_pinn, vmax_pinn) color scale; the two mu
    panels share (vmin_mu, vmax_mu).

    Returns:
        ``uint8`` ndarray of shape (H, W, 3).
    """
    panels = (
        (pinn_prod, "PINN: Real × Imag", vmin_pinn, vmax_pinn),
        (sim_prod, "Sim: Real × Imag", vmin_pinn, vmax_pinn),
        (mu_pred, "Predicted μ", vmin_mu, vmax_mu),
        (mu_full, "Original μ", vmin_mu, vmax_mu),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(20, 5))
    try:
        for ax, (image, title, lo, hi) in zip(axes, panels):
            ax.imshow(image, cmap="viridis", origin="lower", vmin=lo, vmax=hi)
            ax.set_title(title)
            ax.axis("off")

        fig.tight_layout()
        fig.canvas.draw()
        # Use the buffer's real (H, W, 4) shape (robust on HiDPI canvases) and copy it
        # so the result outlives the figure.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[:, :, :3].copy()  # Return only RGB
    finally:
        plt.close(fig)


def create_video(output_path, pinn_prod_frames, sim_prod_frames, mu_pred_frames, mu_full_frames, fps=30):
    """Render per-timestep comparison panels and encode them into an MP4 file.

    Args:
        output_path: directory to write into (created if missing). The file is named
            ``output_video_<YYYYmmddHHMMSS>.mp4``.
        pinn_prod_frames, sim_prod_frames: sequences of 2-D predicted / simulated Re*Im frames.
        mu_pred_frames, mu_full_frames: sequences of 2-D predicted / ground-truth mu frames.
        fps: frames per second of the output video.

    Raises:
        ValueError: if there are no frames or the four sequences differ in length.
        RuntimeError: if the video writer cannot be opened or rendering/encoding fails.
    """
    n_frames = len(pinn_prod_frames)
    if n_frames == 0:
        raise ValueError("No frames to write: pinn_prod_frames is empty.")
    if not (len(sim_prod_frames) == len(mu_pred_frames) == len(mu_full_frames) == n_frames):
        raise ValueError("All frame sequences must have the same length.")

    os.makedirs(output_path, exist_ok=True)
    video_path = os.path.join(output_path, f"output_video_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.mp4")

    # Each color scale is shared by two panels, so it must span BOTH arrays drawn with it;
    # otherwise the array not used for the range is clipped (or, for an all-constant
    # prediction, the scale collapses to vmin == vmax).
    vmin_pinn = min(np.min(pinn_prod_frames), np.min(sim_prod_frames))
    vmax_pinn = max(np.max(pinn_prod_frames), np.max(sim_prod_frames))
    vmin_mu = min(np.min(mu_pred_frames), np.min(mu_full_frames))
    vmax_mu = max(np.max(mu_pred_frames), np.max(mu_full_frames))

    def render(i):
        # One combined RGB frame for time index i.
        return create_combined_frame(
            pinn_prod_frames[i], sim_prod_frames[i], mu_pred_frames[i], mu_full_frames[i],
            vmin_pinn, vmax_pinn, vmin_mu, vmax_mu
        )

    # The first frame fixes the video size; it is reused below instead of being rendered twice.
    first_frame = render(0)
    height, width, _ = first_frame.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
    if not video_writer.isOpened():
        video_writer.release()
        raise RuntimeError(f"Failed to create video: could not open a writer for {video_path}")

    try:
        for i in tqdm(range(n_frames), desc="Creating video"):
            frame = first_frame if i == 0 else render(i)
            # OpenCV expects BGR channel order.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    except Exception as e:
        raise RuntimeError(f"Failed to create video: {e}") from e
    finally:
        video_writer.release()

    print(f"Video saved at: {video_path}")


def generate_video(state, mu_full, model, x_vals, y_vals, t_vals, device, output_path):
    """
    Generates a comparison video of predicted and actual values for states and mu fields.

    For each index ``i`` (time ``t_vals[i]``), the model is queried on the full
    ``x_vals`` x ``y_vals`` grid. Four frames are collected:
    - predicted Re*Im;
    - simulated Re*Im from ``state[i]``;
    - predicted mu from ``model.expand_myu_full``;
    - ground-truth mu from ``mu_full[i]``.
    The frames are then encoded by ``create_video`` into ``output_path``.

    Args:
        state: complex array laid out as [t, x, y] (the layout used when sampling the
            training data), so ``state[i]`` has shape (len(x_vals), len(y_vals)).
        mu_full: ground-truth mu, expected to share ``state``'s [t, x, y] layout.
        model: object with ``predict(x, y, t)`` -> (A_r, A_i) numpy arrays and
            ``expand_myu_full(do_binarize, scale_255)``.
        x_vals, y_vals, t_vals: 1-D coordinate arrays.
        device: device for the query tensors.
        output_path: output directory for the video.
    """
    pinn_prod_frames, sim_prod_frames, mu_pred_frames, mu_full_frames = [], [], [], []

    # model.predict() switches an nn.Module to eval mode; remember the current mode so
    # that calling this in the middle of training does not leave the model in eval mode.
    was_training = getattr(model, "training", None)
    try:
        # 1) Expand the predicted mu to the full grid once: mu_expanded[i] is the frame at t_vals[i].
        mu_expanded = model.expand_myu_full(do_binarize=True, scale_255=True)

        # 2) The spatial query grid does not depend on t, so build it once.
        # indexing='ij' gives arrays of shape (len(x_vals), len(y_vals)), the same [x, y]
        # layout as state[i]; the default 'xy' would produce the transpose.
        X, Y = np.meshgrid(x_vals, y_vals, indexing='ij')
        x_query = torch.tensor(X.ravel(), dtype=torch.float32).view(-1, 1).to(device)
        y_query = torch.tensor(Y.ravel(), dtype=torch.float32).view(-1, 1).to(device)

        # 3) Loop over each time index
        for i, t_val in enumerate(tqdm(t_vals, desc="Generating frames")):
            t_query = torch.full_like(x_query, float(t_val))

            A_r_pred, A_i_pred = model.predict(x_query, y_query, t_query)
            pinn_prod = A_r_pred.reshape(X.shape) * A_i_pred.reshape(X.shape)
            sim_prod = state[i].real * state[i].imag

            pinn_prod_frames.append(pinn_prod)
            sim_prod_frames.append(sim_prod)
            mu_pred_frames.append(mu_expanded[i])
            mu_full_frames.append(mu_full[i])
    finally:
        if was_training is not None:
            model.train(was_training)

    create_video(
        output_path,
        np.array(pinn_prod_frames),
        np.array(sim_prod_frames),
        np.array(mu_pred_frames),
        np.array(mu_full_frames),
    )

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class DNN(nn.Module):
    """Fully connected network: Linear layers separated by Softplus, linear output.

    ``layers`` lists the widths, e.g. ``[3, 64, 64, 2]`` -> Linear(3,64), Softplus,
    Linear(64,64), Softplus, Linear(64,2). The modules live in ``self.net``.
    """

    def __init__(self, layers):
        super().__init__()
        width_pairs = list(zip(layers[:-1], layers[1:]))
        final_idx = len(width_pairs) - 1
        stack = []
        for idx, (fan_in, fan_out) in enumerate(width_pairs):
            stack.append(nn.Linear(fan_in, fan_out))
            # No activation after the output layer.
            if idx != final_idx:
                stack.append(nn.Softplus())
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the network to ``x``."""
        return self.net(x)

class NPINN_PRO_MAX_TIMEBLOCK(nn.Module):
    """
    A 'PINN' for the Complex Ginzburg–Landau equation, with both
    - time downsampling: degrade_t
    - x,y downsampling:  degrade_x, degrade_y
    - We interpret each mu_small_raw in shape (Nt_down, Nx_down, Ny_down).

    PDE: A_t = mu A + delta Lap(A) - |A|^2 A
    where mu is a BINARY field => 0 or 1. It is stored as raw real values in
    mu_small_raw and thresholded at 0 in get_myu_collocation. A straight-through
    estimator lets the PDE loss still update mu_small_raw.

    'Time blocks': each coarse time index covers degrade_t frames in the full domain.
    Grid points past the last full coarse block occur when Nt, Nx or Ny is not a
    multiple of its degrade factor. Those points reuse the last coarse cell, both in
    get_myu_collocation (index clamping) and in expand_myu_full (edge padding).
    """

    def __init__(
        self,
        layers,               # e.g. [3,64,64,2] for A_r, A_i
        Nt, Nx, Ny,           # full domain sizes
        Nx_down, Ny_down,     # smaller, downsampled domain for mu in x,y
        dt, dx, dy,
        degrade_x, degrade_y,
        degrade_t,            # factor for time downsampling
        delta=0.01,
        weight_pde=1.0,
        device='cpu'
    ):
        super().__init__()
        self.device = device
        self.delta = delta
        self.weight_pde = weight_pde

        # Full domain
        self.Nt, self.Nx, self.Ny = Nt, Nx, Ny
        self.Nx_down, self.Ny_down = Nx_down, Ny_down
        self.dt, self.dx, self.dy = dt, dx, dy
        self.degrade_x, self.degrade_y = degrade_x, degrade_y
        self.degrade_t = degrade_t

        # Number of coarse time blocks; trailing frames (Nt % degrade_t) map onto the last block.
        self.Nt_down = Nt // degrade_t

        # 1) The neural net for A(x,y,t): input columns (x, y, t) -> outputs (A_r, A_i).
        self.dnn = DNN(layers).to(device)

        # 2) The trainable mu_small: shape (Nt_down, Nx_down, Ny_down); mu = (raw > 0).
        init = 0.3 * torch.randn(self.Nt_down, Nx_down, Ny_down)
        self.mu_small_raw = nn.Parameter(init.to(device))

    def forward(self, x, y, t):
        """Same as net_A: returns (A_r, A_i)."""
        return self.net_A(x, y, t)

    def net_A(self, x, y, t):
        """Evaluate the network on column tensors x, y, t (each (batch, 1)).

        Returns (A_r, A_i), each of shape (batch, 1).
        """
        inp = torch.cat([x, y, t], dim=1)
        out = self.dnn(inp)
        return out[:, 0:1], out[:, 1:2]

    def binarize_mu_small(self):
        """
        Hard threshold the entire mu_small_raw -> 0.0 or 1.0 in place.
        Runs without autograd and discards the raw magnitudes (cells that were
        <= 0 become exactly 0.0, which still thresholds to mu = 0).
        """
        with torch.no_grad():
            self.mu_small_raw.copy_((self.mu_small_raw > 0.0).float())

    def get_myu_collocation(self, x, y, t):
        """
        Look up the binary mu value at each collocation point.

        (x, y, t) -> coarse indices (i_down, j_down, k_down):
          i      = round(t / dt),                clamped to [0, Nt-1]
          i_down = i // degrade_t,               clamped to [0, Nt_down-1]
          j_down = floor(x / (dx * degrade_x)),  clamped to [0, Nx_down-1]
          k_down = floor(y / (dy * degrade_y)),  clamped to [0, Ny_down-1]

        The returned value is the hard threshold (raw > 0) -> {0., 1.}. A bare
        threshold has zero gradient, which would leave mu_small_raw untouched by the
        PDE loss. A straight-through estimator therefore adds
        (sigmoid(raw) - sigmoid(raw).detach()), which is exactly 0 in the forward pass
        but carries the sigmoid's gradient back to mu_small_raw.

        Returns: tensor of shape (batch, 1).
        """
        # Convert t -> i in [0..Nt-1], then to the coarse time block.
        i = (t[:, 0] / self.dt).round().long().clamp(0, self.Nt - 1)
        i_down = (i // self.degrade_t).clamp(0, self.Nt_down - 1)

        j_down = (x[:, 0] / (self.dx * self.degrade_x)).floor().long().clamp(0, self.Nx_down - 1)
        k_down = (y[:, 0] / (self.dy * self.degrade_y)).floor().long().clamp(0, self.Ny_down - 1)

        mu_vals_raw = self.mu_small_raw[i_down, j_down, k_down]
        mu_hard = (mu_vals_raw > 0.0).float()
        mu_soft = torch.sigmoid(mu_vals_raw)
        # Forward: exactly mu_hard. Backward: d sigmoid / d raw.
        mu_bin = mu_hard + (mu_soft - mu_soft.detach())
        return mu_bin.view(-1, 1)

    @staticmethod
    def _grad(outputs, inputs):
        """d(outputs)/d(inputs) with unit upstream gradients, keeping the graph so the
        result can be differentiated again and backpropagated into the loss."""
        return torch.autograd.grad(
            outputs, inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True, retain_graph=True,
        )[0]

    def pde_residual(self, x, y, t):
        """
        Real/imaginary residuals of A_t = mu A + delta Lap(A) - |A|^2 A:
            f = A_t - mu*A - delta*(A_xx + A_yy) + |A|^2 * A
        x, y, t must require grad. Returns (f_r, f_i), each shaped like A_r.
        """
        A_r, A_i = self.net_A(x, y, t)
        mu_vals = self.get_myu_collocation(x, y, t)
        d = self._grad

        # first derivatives wrt t, x, y
        A_r_t, A_i_t = d(A_r, t), d(A_i, t)
        A_r_x, A_i_x = d(A_r, x), d(A_i, x)
        A_r_y, A_i_y = d(A_r, y), d(A_i, y)

        # Laplacians from the second derivatives
        lapA_r = d(A_r_x, x) + d(A_r_y, y)
        lapA_i = d(A_i_x, x) + d(A_i_y, y)

        A_abs2 = A_r**2 + A_i**2

        f_r = A_r_t - mu_vals*A_r - self.delta*lapA_r + A_abs2*A_r
        f_i = A_i_t - mu_vals*A_i - self.delta*lapA_i + A_abs2*A_i
        return f_r, f_i

    def loss_pde(self, x_eqs, y_eqs, t_eqs):
        """Mean of f_r^2 + f_i^2 over the collocation points."""
        f_r, f_i = self.pde_residual(x_eqs, y_eqs, t_eqs)
        return torch.mean(f_r**2 + f_i**2)

    def loss_data(self, x_data, y_data, t_data, A_r_data, A_i_data):
        """Mean squared error of (A_r, A_i) against the supervised targets."""
        A_r_pred, A_i_pred = self.net_A(x_data, y_data, t_data)
        return torch.mean((A_r_pred - A_r_data)**2 + (A_i_pred - A_i_data)**2)

    def train_model(
        self,
        x_data, y_data, t_data, A_r_data, A_i_data,
        x_eqs, y_eqs, t_eqs,
        n_epochs=200000,
        lr=1e-3,
        model_name="MyModel",
        output_dir="./results",
        video_freq=10000,
        state_exp=None,
        myu_full_exp=None,
        x_vals=None,
        y_vals=None,
        t_vals=None,
        device="cpu"
    ):
        """
        Full-batch Adam training on  data_loss + weight_pde * pde_loss.

        Losses are printed every 500 epochs. Every ``video_freq`` epochs (epoch > 0) a
        checkpoint is saved into ``output_dir/model_name``. If state_exp, myu_full_exp,
        x_vals, y_vals and t_vals are all given, a comparison video is also rendered
        with generate_video. A final checkpoint is always written.
        A falsy ``video_freq`` (e.g. 0) disables periodic checkpoints/videos.

        x_eqs, y_eqs, t_eqs must require grad (PDE derivatives are taken w.r.t. them).
        """
        import os
        from datetime import datetime

        optimizer = optim.Adam(self.parameters(), lr=lr)
        model_folder = os.path.join(output_dir, model_name)
        os.makedirs(model_folder, exist_ok=True)

        self.train()
        for epoch in range(n_epochs):
            optimizer.zero_grad()
            pde_l = self.loss_pde(x_eqs, y_eqs, t_eqs)
            data_l = self.loss_data(x_data, y_data, t_data, A_r_data, A_i_data)
            loss = data_l + self.weight_pde*pde_l
            loss.backward()
            optimizer.step()

            if epoch % 500 == 0:
                print(f"Epoch={epoch}, total={loss.item():.4e}, data={data_l.item():.4e}, PDE={pde_l.item():.4e}")

            if video_freq and epoch > 0 and epoch % video_freq == 0:
                ckpt_path = os.path.join(model_folder, f"{model_name}_epoch_{epoch}.pt")
                torch.save(self.state_dict(), ckpt_path)
                print(f"Checkpoint saved at {ckpt_path}")

                # optional video
                if all(v is not None for v in (state_exp, myu_full_exp, x_vals, y_vals, t_vals)):
                    vid_name = f"{model_name}_epoch_{epoch}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
                    video_folder = os.path.join(model_folder, "videos")
                    os.makedirs(video_folder, exist_ok=True)

                    video_path = os.path.join(video_folder, vid_name)
                    generate_video(
                        state_exp,
                        myu_full_exp,
                        self,
                        x_vals, y_vals, t_vals,
                        device=device,
                        output_path=video_path
                    )
                    # Video rendering goes through predict(); make sure training resumes in train mode.
                    self.train()

        final_ckpt = os.path.join(model_folder, f"{model_name}_final_{n_epochs}.pt")
        torch.save(self.state_dict(), final_ckpt)
        print(f"Final checkpoint saved at {final_ckpt}\nTraining done.\n")

    def _fit_to_full_grid(self, mu):
        """Crop or edge-pad a (t, x, y) tensor to exactly (Nt, Nx, Ny)."""
        for dim, target in enumerate((self.Nt, self.Nx, self.Ny)):
            size = mu.shape[dim]
            if size > target:
                mu = mu.narrow(dim, 0, target)
            elif 0 < size < target:
                # Repeat the last slice: the same clamp-to-last-cell rule as get_myu_collocation.
                reps = [1] * mu.dim()
                reps[dim] = target - size
                edge = mu.narrow(dim, size - 1, 1)
                mu = torch.cat([mu, edge.repeat(*reps)], dim=dim)
        return mu

    def expand_myu_full(self, do_binarize=True, scale_255=False):
        """
        Expand mu_small_raw (Nt_down, Nx_down, Ny_down) to the full grid (Nt, Nx, Ny):
         1) optionally binarize (raw > 0 -> 1.0, else 0.0)
         2) repeat_interleave along time by degrade_t and along x, y by degrade_x, degrade_y
         3) crop / edge-pad to exactly (Nt, Nx, Ny) (only changes anything when a size
            is not a multiple of its degrade factor)
         4) optionally multiply by 255

        Returns: numpy array of shape (Nt, Nx, Ny) on the CPU.
        """
        with torch.no_grad():
            mu = self.mu_small_raw.detach()
            if do_binarize:
                mu = (mu > 0.0).float()

            mu = mu.repeat_interleave(self.degrade_t, dim=0)
            mu = mu.repeat_interleave(self.degrade_x, dim=1)
            mu = mu.repeat_interleave(self.degrade_y, dim=2)
            mu = self._fit_to_full_grid(mu)

            if scale_255:
                mu = mu * 255.0

            return mu.cpu().numpy()

    def predict(self, x, y, t):
        """Evaluate (A_r, A_i) without gradients; returns numpy arrays of shape (batch, 1).

        Runs in eval mode and then restores the previous train/eval mode, so calling it
        during training does not leave the module in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                A_r, A_i = self.net_A(x, y, t)
        finally:
            self.train(was_training)
        return A_r.cpu().numpy(), A_i.cpu().numpy()


In [10]:
state = np.load("../data/test_new/states_processed_cropped.npy")
myu_full = np.load("../data/test_new/myus_binarized_processed_cropped.npy")

# Saved run: state (350, 530, 880) complex128, myu_full (350, 530, 880) uint16.
print("State shape:", state.shape, state.dtype)
print("Myu shape:  ", myu_full.shape, myu_full.dtype)

# The real and imaginary parts of the complex field are the two network targets.
A_r_data, A_i_data = state.real, state.imag

# Full grid sizes and spacings (the spacings are example values).
Nt, Nx, Ny = state.shape
dt = 0.05
dx = dy = 0.3

# mu is represented on a coarse Nx_down x Ny_down grid; each coarse cell spans
# degrade_x x degrade_y fine cells (530 // 10 = 53 and 880 // 10 = 88 for this data).
Nx_down = Ny_down = 10
degrade_x, degrade_y = Nx // Nx_down, Ny // Ny_down

State shape: (350, 530, 880) complex128
Myu shape:   (350, 530, 880) uint16


In [4]:
# Baseline model: Softplus MLP for A(x, y, t) plus a coarse, time-blocked mu grid
# (each coarse time block spans 50 frames).
model = NPINN_PRO_MAX_TIMEBLOCK(
    layers=[3, 64, 64, 2],
    Nt=Nt, Nx=Nx, Ny=Ny,
    dt=dt, dx=dx, dy=dy,
    Nx_down=Nx_down, Ny_down=Ny_down,
    degrade_t=50, degrade_x=degrade_x, degrade_y=degrade_y,
    delta=0.01,
    weight_pde=0.1,
    device='cuda',
)
model.to('cuda')

NPINN_PRO_MAX_TIMEBLOCK(
  (dnn): DNN(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=64, bias=True)
      (1): Softplus(beta=1.0, threshold=20.0)
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Softplus(beta=1.0, threshold=20.0)
      (4): Linear(in_features=64, out_features=2, bias=True)
    )
  )
)

In [11]:
# Fixed seed so the sampled data / collocation points are reproducible on re-run.
SEED = 0
rng = np.random.default_rng(SEED)

# Supervised samples: random grid nodes (t, x, y) and the simulated A_r, A_i there.
n_data = 20000
idx_t = rng.integers(0, Nt, size=n_data)
idx_x = rng.integers(0, Nx, size=n_data)
idx_y = rng.integers(0, Ny, size=n_data)

# Physical coordinate of every grid node along each axis.
t_vals = np.arange(Nt) * dt
x_vals = np.arange(Nx) * dx
y_vals = np.arange(Ny) * dy

t_data_np = t_vals[idx_t]
x_data_np = x_vals[idx_x]
y_data_np = y_vals[idx_y]

Ar_data_np = A_r_data[idx_t, idx_x, idx_y]
Ai_data_np = A_i_data[idx_t, idx_x, idx_y]

import torch

device = 'cuda'

x_data_t = torch.tensor(x_data_np, dtype=torch.float32, device=device).view(-1, 1)
y_data_t = torch.tensor(y_data_np, dtype=torch.float32, device=device).view(-1, 1)
t_data_t = torch.tensor(t_data_np, dtype=torch.float32, device=device).view(-1, 1)
Ar_data_t = torch.tensor(Ar_data_np, dtype=torch.float32, device=device).view(-1, 1)
Ai_data_t = torch.tensor(Ai_data_np, dtype=torch.float32, device=device).view(-1, 1)

# PDE collocation points: uniform in the continuous box [0, t_max] x [0, x_max] x [0, y_max].
n_coll = 20000
t_eqs_np = rng.uniform(0, t_vals[-1], size=n_coll)
x_eqs_np = rng.uniform(0, x_vals[-1], size=n_coll)
y_eqs_np = rng.uniform(0, y_vals[-1], size=n_coll)

# Reshape on the NumPy side so each tensor is a leaf that requires grad
# (the PDE residual differentiates w.r.t. these inputs).
x_eqs_t = torch.tensor(x_eqs_np.reshape(-1, 1), dtype=torch.float32, device=device, requires_grad=True)
y_eqs_t = torch.tensor(y_eqs_np.reshape(-1, 1), dtype=torch.float32, device=device, requires_grad=True)
t_eqs_t = torch.tensor(t_eqs_np.reshape(-1, 1), dtype=torch.float32, device=device, requires_grad=True)

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class ResBlock(nn.Module):
    """Width-preserving residual block: GELU(x + branch(x)).

    branch = Linear -> BatchNorm1d -> GELU -> Linear -> BatchNorm1d.

    NOTE(review): in train mode BatchNorm1d normalizes with batch statistics, so each
    output depends on every sample in the batch. Autograd input-derivatives (as used
    for PDE residuals) are then not purely per-sample; confirm this is acceptable.
    """

    def __init__(self, dim):
        super().__init__()
        # Residual branch. The attribute name and module order define the state_dict keys.
        self.layers = nn.Sequential(
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
            nn.GELU(),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        )
        # Non-linearity applied after the skip connection is added.
        self.activation = nn.GELU()

    def forward(self, x):
        """Return GELU(x + branch(x)); the output has the same shape as ``x``."""
        branch_out = self.layers(x)
        return self.activation(branch_out + x)

class ImprovedDNN(nn.Module):
    """MLP with an input projection, residual blocks, optional dense stages and a linear head.

    For ``layers = [d_in, h_1, ..., h_k, d_out]`` the network is:
      Linear(d_in, h_1) -> BatchNorm1d -> GELU
      -> n_res_blocks x ResBlock(h_1)
      -> for each consecutive hidden pair (h_i, h_{i+1}): Linear -> BatchNorm1d -> GELU
      -> Linear(h_k, d_out)
    All Linear weights are Xavier-uniform initialized with zero biases.

    Raises:
        ValueError: if ``layers`` has fewer than 3 entries. There would be no hidden
            width for the residual blocks, and the head would not match the trunk's
            output width.
    """

    def __init__(self, layers, n_res_blocks=3):
        super().__init__()
        if len(layers) < 3:
            raise ValueError(
                f"ImprovedDNN needs at least [in, hidden, out] widths, got layers={list(layers)}"
            )

        # Input projection
        modules = [
            nn.Linear(layers[0], layers[1]),
            nn.BatchNorm1d(layers[1]),
            nn.GELU(),
        ]

        # Residual blocks at the first hidden width
        modules.extend(ResBlock(layers[1]) for _ in range(n_res_blocks))

        # Dense stages following the hidden widths layers[1] -> ... -> layers[-2]
        for fan_in, fan_out in zip(layers[1:-2], layers[2:-1]):
            modules.extend([
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.GELU(),
            ])

        # Output projection (no activation)
        modules.append(nn.Linear(layers[-2], layers[-1]))

        self.net = nn.Sequential(*modules)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero biases for every nn.Linear."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Apply the network to ``x`` of shape (batch, layers[0])."""
        return self.net(x)

In [13]:
class NPINN_PRO_MAX_TIMEBLOCK_V2(nn.Module):
    """
    Complex Ginzburg–Landau PINN with a coarse, time-blocked binary mu field.

    Compared with NPINN_PRO_MAX_TIMEBLOCK it uses:
    - ImprovedDNN (BatchNorm/GELU/residual blocks) for A(x, y, t);
    - a sinusoidal input encoding;
    - an L2 + L1 data loss and a gradient penalty;
    - minibatch AdamW training with a OneCycle LR schedule, gradient clipping and early stopping.

    PDE: A_t = mu A + delta Lap(A) - |A|^2 A, with mu = (mu_small_raw > 0) looked up on a
    (Nt_down, Nx_down, Ny_down) grid. Each coarse time index covers degrade_t frames.
    Points past the last full coarse cell reuse that cell (index clamping in
    get_myu_collocation, edge padding in expand_myu_full).
    """

    def __init__(
        self,
        layers,
        Nt, Nx, Ny,
        Nx_down, Ny_down,
        dt, dx, dy,
        degrade_x, degrade_y,
        degrade_t,
        delta=0.01,
        weight_pde=1.0,
        device='cpu'
    ):
        super().__init__()
        self.device = device
        self.delta = delta
        self.weight_pde = weight_pde

        # Domain parameters (same as original)
        self.Nt, self.Nx, self.Ny = Nt, Nx, Ny
        self.Nx_down, self.Ny_down = Nx_down, Ny_down
        self.dt, self.dx, self.dy = dt, dx, dy
        self.degrade_x, self.degrade_y = degrade_x, degrade_y
        self.degrade_t = degrade_t
        self.Nt_down = Nt // degrade_t

        # Improved neural network for A(x,y,t)
        self.dnn = ImprovedDNN(layers, n_res_blocks=3).to(device)

        # The trainable mu_small: shape (Nt_down, Nx_down, Ny_down); mu = (raw > 0).
        init = 0.3 * torch.randn(self.Nt_down, Nx_down, Ny_down)
        self.mu_small_raw = nn.Parameter(init.to(device))

        # Frequencies for the sinusoidal input encoding: linspace(0, 10, layers[0]).
        # These are created on the CPU; callers move them with .to(device).
        self.register_buffer('freq_x', torch.linspace(0, 10, layers[0]))
        self.register_buffer('freq_y', torch.linspace(0, 10, layers[0]))
        self.register_buffer('freq_t', torch.linspace(0, 10, layers[0]))

    def positional_encoding(self, x, y, t):
        """Average of sin(x*f_x), sin(y*f_y), sin(t*f_t): shape (batch, len(freq)).

        net_A adds this to [x, y, t], so len(freq) (= layers[0]) must be 3. With
        layers[0] == 3 the frequencies are [0, 5, 10], so the first column (x) gets
        no encoding.
        """
        enc_x = torch.sin(x * self.freq_x[None, :])
        enc_y = torch.sin(y * self.freq_y[None, :])
        enc_t = torch.sin(t * self.freq_t[None, :])
        return (enc_x + enc_y + enc_t) / 3.0

    def net_A(self, x, y, t):
        """Evaluate the network on column tensors x, y, t; returns (A_r, A_i), each (batch, 1)."""
        inp_raw = torch.cat([x, y, t], dim=1)
        inp = inp_raw + self.positional_encoding(x, y, t)
        out = self.dnn(inp)
        return out[:, 0:1], out[:, 1:2]

    def forward(self, x, y, t):
        """Same as net_A."""
        return self.net_A(x, y, t)

    def binarize_mu_small(self):
        """
        Hard threshold the entire mu_small_raw -> 0.0 or 1.0 in place.
        Runs without autograd and discards the raw magnitudes (cells that were
        <= 0 become exactly 0.0, which still thresholds to mu = 0).
        """
        with torch.no_grad():
            self.mu_small_raw.copy_((self.mu_small_raw > 0.0).float())

    def get_myu_collocation(self, x, y, t):
        """
        Look up the binary mu value at each collocation point.

        (x, y, t) -> coarse indices (i_down, j_down, k_down):
          i      = round(t / dt),                clamped to [0, Nt-1]
          i_down = i // degrade_t,               clamped to [0, Nt_down-1]
          j_down = floor(x / (dx * degrade_x)),  clamped to [0, Nx_down-1]
          k_down = floor(y / (dy * degrade_y)),  clamped to [0, Ny_down-1]

        The returned value is the hard threshold (raw > 0) -> {0., 1.}. A bare
        threshold has zero gradient, which would leave mu_small_raw untouched by the
        PDE loss. A straight-through estimator therefore adds
        (sigmoid(raw) - sigmoid(raw).detach()), which is exactly 0 in the forward pass
        but carries the sigmoid's gradient back to mu_small_raw.

        Returns: tensor of shape (batch, 1).
        """
        i = (t[:, 0] / self.dt).round().long().clamp(0, self.Nt - 1)
        i_down = (i // self.degrade_t).clamp(0, self.Nt_down - 1)

        j_down = (x[:, 0] / (self.dx * self.degrade_x)).floor().long().clamp(0, self.Nx_down - 1)
        k_down = (y[:, 0] / (self.dy * self.degrade_y)).floor().long().clamp(0, self.Ny_down - 1)

        mu_vals_raw = self.mu_small_raw[i_down, j_down, k_down]
        mu_hard = (mu_vals_raw > 0.0).float()
        mu_soft = torch.sigmoid(mu_vals_raw)
        # Forward: exactly mu_hard. Backward: d sigmoid / d raw.
        mu_bin = mu_hard + (mu_soft - mu_soft.detach())
        return mu_bin.view(-1, 1)

    @staticmethod
    def _grad(outputs, inputs):
        """d(outputs)/d(inputs) with unit upstream gradients, keeping the graph so the
        result can be differentiated again and backpropagated into the loss.

        NOTE(review): with BatchNorm in train mode, outputs depend on the whole batch,
        so this is a sum over the batch's outputs rather than a strictly per-sample
        derivative.
        """
        return torch.autograd.grad(
            outputs, inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True, retain_graph=True,
        )[0]

    def pde_residual(self, x, y, t):
        """
        Real/imaginary residuals of A_t = mu A + delta Lap(A) - |A|^2 A:
            f = A_t - mu*A - delta*(A_xx + A_yy) + |A|^2 * A
        x, y, t must require grad. Returns (f_r, f_i), each shaped like A_r.
        """
        A_r, A_i = self.net_A(x, y, t)
        mu_vals = self.get_myu_collocation(x, y, t)
        d = self._grad

        A_r_t, A_i_t = d(A_r, t), d(A_i, t)
        A_r_x, A_i_x = d(A_r, x), d(A_i, x)
        A_r_y, A_i_y = d(A_r, y), d(A_i, y)

        lapA_r = d(A_r_x, x) + d(A_r_y, y)
        lapA_i = d(A_i_x, x) + d(A_i_y, y)

        A_abs2 = A_r**2 + A_i**2

        f_r = A_r_t - mu_vals*A_r - self.delta*lapA_r + A_abs2*A_r
        f_i = A_i_t - mu_vals*A_i - self.delta*lapA_i + A_abs2*A_i
        return f_r, f_i

    def loss_pde(self, x_eqs, y_eqs, t_eqs):
        """Mean of f_r^2 + f_i^2 over the collocation points."""
        f_r, f_i = self.pde_residual(x_eqs, y_eqs, t_eqs)
        return torch.mean(f_r**2 + f_i**2)

    def gradient_penalty(self, x, y, t):
        """Mean over the batch of (dA_r/dx)^2 + (dA_i/dx)^2 (x must require grad).

        NOTE(review): only the x-derivative is penalized; y and t are not. Confirm
        this asymmetry is intended.
        """
        A_r, A_i = self.net_A(x, y, t)

        gradients_r = torch.autograd.grad(
            A_r.sum(), x, create_graph=True, retain_graph=True)[0]
        gradients_i = torch.autograd.grad(
            A_i.sum(), x, create_graph=True, retain_graph=True)[0]

        return (gradients_r.pow(2).sum() + gradients_i.pow(2).sum()) / x.shape[0]

    def loss_data(self, x_data, y_data, t_data, A_r_data, A_i_data):
        """MSE + 0.1 * MAE of (A_r, A_i) against the supervised targets."""
        A_r_pred, A_i_pred = self.net_A(x_data, y_data, t_data)

        # L2 loss
        l2_loss = torch.mean((A_r_pred - A_r_data)**2 + (A_i_pred - A_i_data)**2)

        # Add L1 loss for better stability
        l1_loss = torch.mean(torch.abs(A_r_pred - A_r_data) + torch.abs(A_i_pred - A_i_data))

        return l2_loss + 0.1 * l1_loss

    @staticmethod
    def _shuffled_batches(tensors, batch_size):
        """Yield row-aligned minibatches of ``tensors`` in a fresh random order.

        All tensors must share their first dimension; the last batch may be smaller.
        """
        n_rows = tensors[0].shape[0]
        order = torch.randperm(n_rows, device=tensors[0].device)
        for start in range(0, n_rows, batch_size):
            rows = order[start:start + batch_size]
            yield tuple(tensor[rows] for tensor in tensors)

    def train_model(
        self,
        x_data, y_data, t_data, A_r_data, A_i_data,
        x_eqs, y_eqs, t_eqs,
        n_epochs=200000,
        lr=1e-3,
        batch_size=1024,
        model_name="MyModel",
        output_dir="./results",
        video_freq=10000,
        state_exp=None,
        myu_full_exp=None,
        x_vals=None,
        y_vals=None,
        t_vals=None,
        device="cpu"
    ):
        """
        Minibatch AdamW training with a OneCycle LR schedule, gradient clipping (norm 1)
        and early stopping.

        Each step pairs a shuffled minibatch of data points with a shuffled minibatch
        of collocation points and minimizes
            data_loss + weight_pde * pde_loss + 0.01 * gradient_penalty.
        An epoch has min(#data batches, #collocation batches) steps.

        Every 500 epochs:
        - the epoch-average losses are printed;
        - they drive early stopping (patience: 10 checks);
        - the best model is saved as ``{model_name}_best.pt``.

        Every ``video_freq`` epochs (epoch > 0), a checkpoint is saved. A comparison
        video is also rendered if every video input is given. A falsy ``video_freq``
        disables both.
        """
        import math
        import os
        from datetime import datetime

        data_tensors = (x_data, y_data, t_data, A_r_data, A_i_data)
        coll_tensors = (x_eqs, y_eqs, t_eqs)
        # zip() over the two batch streams stops at the shorter one.
        steps_per_epoch = min(
            math.ceil(x_data.shape[0] / batch_size),
            math.ceil(x_eqs.shape[0] / batch_size),
        )
        if steps_per_epoch == 0:
            raise ValueError("train_model needs at least one data point and one collocation point")

        # Optimizer and scheduler
        optimizer = optim.AdamW(self.parameters(), lr=lr, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            epochs=n_epochs,
            steps_per_epoch=steps_per_epoch,
            pct_start=0.1,
            div_factor=25.0
        )

        # Gradient clipping
        max_grad_norm = 1.0

        model_folder = os.path.join(output_dir, model_name)
        os.makedirs(model_folder, exist_ok=True)

        best_loss = float('inf')
        patience_counter = 0
        patience = 10  # number of 500-epoch checks without improvement before stopping

        self.train()
        for epoch in range(n_epochs):
            total_loss = 0.0
            total_data_loss = 0.0
            total_pde_loss = 0.0
            n_steps = 0

            batches = zip(
                self._shuffled_batches(data_tensors, batch_size),
                self._shuffled_batches(coll_tensors, batch_size),
            )
            for (x_d, y_d, t_d, ar_d, ai_d), (x_e, y_e, t_e) in batches:
                # Fresh grad-requiring leaves: PDE derivatives are taken w.r.t. exactly this
                # minibatch, and no gradient flows back into the caller's full tensors.
                x_e, y_e, t_e = (v.detach().requires_grad_(True) for v in (x_e, y_e, t_e))

                optimizer.zero_grad()

                data_loss = self.loss_data(x_d, y_d, t_d, ar_d, ai_d)
                pde_loss = self.loss_pde(x_e, y_e, t_e)
                grad_penalty = self.gradient_penalty(x_e, y_e, t_e)

                loss = data_loss + self.weight_pde * pde_loss + 0.01 * grad_penalty

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
                optimizer.step()
                scheduler.step()

                total_loss += loss.item()
                total_data_loss += data_loss.item()
                total_pde_loss += pde_loss.item()
                n_steps += 1

            # Average over the steps actually taken in this epoch.
            avg_loss = total_loss / n_steps
            avg_data_loss = total_data_loss / n_steps
            avg_pde_loss = total_pde_loss / n_steps

            if epoch % 500 == 0:
                print(f"Epoch={epoch}, total={avg_loss:.4e}, data={avg_data_loss:.4e}, PDE={avg_pde_loss:.4e}")

                # Early stopping check
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    patience_counter = 0
                    torch.save(self.state_dict(), os.path.join(model_folder, f"{model_name}_best.pt"))
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print("Early stopping triggered")
                        break

            if video_freq and epoch > 0 and epoch % video_freq == 0:
                ckpt_path = os.path.join(model_folder, f"{model_name}_epoch_{epoch}.pt")
                torch.save(self.state_dict(), ckpt_path)
                print(f"Checkpoint saved at {ckpt_path}")

                if all(v is not None for v in (state_exp, myu_full_exp, x_vals, y_vals, t_vals)):
                    vid_name = f"{model_name}_epoch_{epoch}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
                    video_folder = os.path.join(model_folder, "videos")
                    os.makedirs(video_folder, exist_ok=True)
                    video_path = os.path.join(video_folder, vid_name)
                    generate_video(state_exp, myu_full_exp, self, x_vals, y_vals, t_vals, device=device, output_path=video_path)
                    # Video rendering goes through predict(); BatchNorm must be back in
                    # train mode (batch statistics, running-stat updates) for training.
                    self.train()

        final_ckpt = os.path.join(model_folder, f"{model_name}_final_{n_epochs}.pt")
        torch.save(self.state_dict(), final_ckpt)
        print(f"Final checkpoint saved at {final_ckpt}\nTraining done.\n")

    def _fit_to_full_grid(self, mu):
        """Crop or edge-pad a (t, x, y) tensor to exactly (Nt, Nx, Ny)."""
        for dim, target in enumerate((self.Nt, self.Nx, self.Ny)):
            size = mu.shape[dim]
            if size > target:
                mu = mu.narrow(dim, 0, target)
            elif 0 < size < target:
                # Repeat the last slice: the same clamp-to-last-cell rule as get_myu_collocation.
                reps = [1] * mu.dim()
                reps[dim] = target - size
                edge = mu.narrow(dim, size - 1, 1)
                mu = torch.cat([mu, edge.repeat(*reps)], dim=dim)
        return mu

    def expand_myu_full(self, do_binarize=True, scale_255=False):
        """
        Expand mu_small_raw (Nt_down, Nx_down, Ny_down) to the full grid (Nt, Nx, Ny):
         1) optionally binarize (raw > 0 -> 1.0, else 0.0)
         2) repeat_interleave along time by degrade_t and along x, y by degrade_x, degrade_y
         3) crop / edge-pad to exactly (Nt, Nx, Ny) (only changes anything when a size
            is not a multiple of its degrade factor)
         4) optionally multiply by 255

        Returns: numpy array of shape (Nt, Nx, Ny) on the CPU.
        """
        with torch.no_grad():
            mu = self.mu_small_raw.detach()
            if do_binarize:
                mu = (mu > 0.0).float()

            mu = mu.repeat_interleave(self.degrade_t, dim=0)
            mu = mu.repeat_interleave(self.degrade_x, dim=1)
            mu = mu.repeat_interleave(self.degrade_y, dim=2)
            mu = self._fit_to_full_grid(mu)

            if scale_255:
                mu = mu * 255.0

            return mu.cpu().numpy()

    def predict(self, x, y, t):
        """Evaluate (A_r, A_i) without gradients; returns numpy arrays of shape (batch, 1).

        Runs in eval mode (BatchNorm uses running statistics) and then restores the
        previous train/eval mode, so calling it during training is safe.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                A_r, A_i = self.net_A(x, y, t)
        finally:
            self.train(was_training)
        return A_r.cpu().numpy(), A_i.cpu().numpy()

In [15]:
# First initialize the improved model
model = NPINN_PRO_MAX_TIMEBLOCK_V2(
    layers=[3, 128, 256, 256, 128, 2],  # Deeper and wider architecture
    Nt=Nt, Nx=Nx, Ny=Ny,
    Nx_down=Nx_down, Ny_down=Ny_down,
    dt=dt, dx=dx, dy=dy,
    degrade_x=degrade_x, degrade_y=degrade_y,
    delta=0.01,
    weight_pde=0.1,
    device='cuda',
    degrade_t=50
)
model.to('cuda')

NPINN_PRO_MAX_TIMEBLOCK_V2(
  (dnn): ImprovedDNN(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
      (3): ResBlock(
        (layers): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): GELU(approximate='none')
          (3): Linear(in_features=128, out_features=128, bias=True)
          (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (activation): GELU(approximate='none')
      )
      (4): ResBlock(
        (layers): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): GELU(approximate='none'

In [17]:
# Then train with these parameters
model.train_model(
    x_data=x_data_t,
    y_data=y_data_t,
    t_data=t_data_t,
    A_r_data=Ar_data_t,
    A_i_data=Ai_data_t,
    x_eqs=x_eqs_t,
    y_eqs=y_eqs_t,
    t_eqs=t_eqs_t,
    n_epochs=200001,
    lr=1e-3,
    batch_size=2048,  # Increased batch size for better stability
    model_name="TimeBlockerV2_Test",
    output_dir="./results",
    video_freq=1000,
    state_exp=state,  # for generating comparison video
    myu_full_exp=myu_full,
    x_vals=x_vals,
    y_vals=y_vals,
    t_vals=t_vals,
    device='cuda'
)

Epoch=0, total=3.4745e-01, data=3.4321e-01, PDE=4.1571e-02
Epoch=500, total=1.6011e-01, data=1.5217e-01, PDE=7.8240e-02
Epoch=1000, total=1.4463e-01, data=1.3657e-01, PDE=7.9273e-02
Checkpoint saved at ./results\TimeBlockerV2_Test\TimeBlockerV2_Test_epoch_1000.pt


Generating frames: 100%|██████████| 350/350 [00:48<00:00,  7.22it/s]
Creating video: 100%|██████████| 350/350 [01:37<00:00,  3.60it/s]


Video saved at: ./results\TimeBlockerV2_Test\videos\TimeBlockerV2_Test_epoch_1000_20250124023846\output_video_20250124023936.mp4


KeyboardInterrupt: 