In [None]:
import os
import random
from enum import Enum

import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm


In [None]:
class PlotType(Enum):
    """Selects which view of a (possibly complex) 2-D array gets drawn.

    The string value of each member is what error messages list as the
    accepted choices; member order is also the order they are listed in.
    """

    REAL = "real"              # data.real
    IMAG = "imag"              # data.imag
    ABS = "abs"                # np.abs(data)
    ANGLE = "angle"            # np.angle(data), the phase
    REAL_IMAG = "real_imag"    # data.imag * data.real
    UNCHANGED = "unchanged"    # data passed to imshow as-is
    ABS_SQUARE = "abs_square"  # np.abs(data) ** 2

class PlotBuilder:
    """Fluent builder that lays out ``imshow`` panels on a grid of subplots.

    Usage: ``PlotBuilder().set_grid(rows, cols).add_plot(...)...build()``.
    ``set_grid`` and ``add_plot`` return ``self`` so calls can be chained;
    ``build`` returns ``(fig, axes)``.
    """

    # add_plot keyword arguments consumed by the builder itself; every other
    # keyword argument is forwarded to ``ax.imshow``.
    _SETTING_KEYS = frozenset({"xlabel", "ylabel", "title", "colorbar", "annotations"})

    def __init__(self):
        self.fig, self.axes = None, None
        # History of (index, data, plot_type, kwargs) for every add_plot call.
        self.plots = []

    def set_grid(self, rows, cols):
        """Create a new figure with a ``rows`` x ``cols`` grid of axes.

        Each cell is 5 units wide and 4 units tall (matplotlib figsize
        units). ``self.axes`` becomes a flat sequence indexed row-major.
        """
        self.fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4))
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid.
        self.axes = axes.flatten() if rows * cols > 1 else [axes]
        return self

    def add_plot(self, index, data, plot_type=PlotType.UNCHANGED, **kwargs):
        """Draw ``data`` into grid cell ``index``.

        Args:
            index: Flat (row-major) index into the grid.
            data: Array to draw; transformed according to ``plot_type``.
            plot_type: A ``PlotType`` member or its string value.
            **kwargs: ``xlabel``/``ylabel``/``title`` set axis text,
                ``colorbar=True`` adds a colorbar, ``annotations`` is a list
                of dicts passed to ``ax.annotate``; everything else is
                forwarded to ``ax.imshow``.

        Raises:
            ValueError: If ``set_grid`` has not been called, or
                ``plot_type`` is unknown.
            IndexError: If ``index`` is outside the grid.
        """
        if self.fig is None or self.axes is None:
            raise ValueError("Grid is not set. Use set_grid(rows, cols) first.")

        if index >= len(self.axes):
            raise IndexError("Index exceeds the number of grid cells.")

        ax = self.axes[index]
        image_kwargs = {k: v for k, v in kwargs.items() if k not in self._SETTING_KEYS}
        img = self._generate_image(ax, data, plot_type, **image_kwargs)

        self._apply_axes_settings(ax, **kwargs)
        self._apply_special_settings(ax, img, **kwargs)

        self.plots.append((index, data, plot_type, kwargs))
        return self

    def _generate_image(self, ax, data, plot_type, **kwargs):
        """Transform ``data`` per ``plot_type`` and draw it with ``ax.imshow``.

        ``plot_type`` may be a ``PlotType`` member or its string value
        (e.g. ``"abs"``). Returns the image object produced by ``imshow``.

        Raises:
            ValueError: If ``plot_type`` is not a valid ``PlotType``.
        """
        try:
            plot_type = PlotType(plot_type)
        except ValueError:
            valid_types = ", ".join(pt.value for pt in PlotType)
            raise ValueError(
                f"Unsupported plot_type. Use PlotType enum values: {valid_types}."
            ) from None

        # Built per call (not at class level) so it always keys on the
        # current PlotType class, even if that notebook cell is re-run.
        transforms = {
            PlotType.UNCHANGED: lambda arr: arr,
            PlotType.REAL: lambda arr: arr.real,
            PlotType.IMAG: lambda arr: arr.imag,
            PlotType.ABS: np.abs,
            PlotType.ANGLE: np.angle,
            PlotType.REAL_IMAG: lambda arr: arr.imag * arr.real,
            PlotType.ABS_SQUARE: lambda arr: np.abs(arr) ** 2,
        }
        return ax.imshow(transforms[plot_type](data), **kwargs)

    def _apply_axes_settings(self, ax, **kwargs):
        """Apply ``xlabel``, ``ylabel`` and ``title`` kwargs via ``ax.set_*``."""
        for key in ("xlabel", "ylabel", "title"):
            if key in kwargs:
                getattr(ax, f"set_{key}")(kwargs[key])

    def _apply_special_settings(self, ax, img, **kwargs):
        """Add an optional colorbar and optional annotations to ``ax``."""
        if kwargs.get("colorbar"):
            self.fig.colorbar(img, ax=ax)

        annotations = kwargs.get("annotations")
        if isinstance(annotations, list):
            for annotation in annotations:
                ax.annotate(
                    annotation.get("text", ""),
                    xy=annotation.get("xy", (0, 0)),
                    xytext=annotation.get("xytext", None),
                    arrowprops=annotation.get("arrowprops", None),
                    **annotation.get("kwargs", {})
                )

    def build(self):
        """Remove unused grid cells, tighten the layout and return ``(fig, axes)``.

        Raises:
            ValueError: If ``set_grid`` has not been called.
        """
        if self.fig is None or self.axes is None:
            raise ValueError("Grid is not set. Use set_grid(rows, cols) first.")

        for ax in [ax for ax in self.axes if not ax.has_data()]:
            ax.remove()

        self.axes = [ax for ax in self.axes if ax in self.fig.axes]
        # Lay out *this* figure, not whatever pyplot considers current.
        self.fig.tight_layout()
        return self.fig, self.axes

def plot_state_and_myu(state, myu, index):
    """Plot one time step of the complex state and its mu field on a 2x3 grid.

    Panels, in row-major order: Im(A), Re(A), Re(A)*Im(A), |A|^2 and the
    phase of ``state[index]``, followed by ``myu[index]`` unchanged. Every
    panel uses the viridis colormap, a colorbar and generic axis labels.

    Args:
        state: Indexable sequence of 2-D (presumably complex) arrays.
        myu: Indexable sequence of 2-D real arrays.
        index: Time step to draw from both ``state`` and ``myu``.

    Returns:
        ``(fig, axes)`` from ``PlotBuilder.build``. The figure is left open,
        so callers may still use ``plt.savefig`` / ``plt.close`` as before.
    """
    # (source sequence, view to draw, panel title) for each grid cell.
    panels = [
        (state, PlotType.IMAG, r"$\Im\{A\}$"),
        (state, PlotType.REAL, r"$\Re\{A\}$"),
        (state, PlotType.REAL_IMAG, r"$\Re\{A\} \times \Im\{A\}$"),
        (state, PlotType.ABS_SQUARE, r"$|A|^2$"),
        (state, PlotType.ANGLE, r"$\phi$"),
        (myu, PlotType.UNCHANGED, r"$\mu$"),
    ]

    builder = PlotBuilder().set_grid(2, 3)
    for cell, (source, plot_type, title) in enumerate(panels):
        builder.add_plot(
            index=cell,
            data=source[index],
            plot_type=plot_type,
            title=title,
            xlabel="X-axis",
            ylabel="Y-axis",
            cmap="viridis",
            colorbar=True,
        )
    return builder.build()

In [None]:
class Simulation:
    """
    A class to simulate a dynamical system involving nonlinear operations
    on a spatial grid using Fourier transforms.

    The complex field ``A`` is stepped in Fourier space. The linear part is
    advanced with the exact factor ``exp(q * dt)``. The nonlinear term
    ``A * (mu - |A|^2)`` is handled explicitly with exponential-time-
    differencing style weights (``step1`` / ``step2``).
    """

    def __init__(self, d=(0.01, 100/320, 100/320), N=(100, 320, 320),
                 myu_size=(10, 2, 2), myu_mstd=(5.4, 0.8)):
        """
        Initializes the simulation parameters and precomputes necessary Fourier transform terms.

        Args:
            d (tuple): Time step and spatial step sizes (dt, dx, dy).
            N (tuple): Number of time steps and spatial grid points (Nt, Nx, Ny).
            myu_size (tuple): Dimensions of the small-scale initial condition matrix.
            myu_mstd (tuple): Mean and standard deviation for the noise distribution (mean, std).
        """
        self.dt, self.dx, self.dy = d
        self.Nt, self.Nx, self.Ny = N
        self.myu_size = myu_size
        self.myu_mstd = myu_mstd

        kx = np.fft.fftfreq(self.Nx, self.dx)
        ky = np.fft.fftfreq(self.Ny, self.dy)
        Kx, Ky = np.meshgrid(kx, ky, indexing='ij')
        # Fourier symbol of the Laplacian, -(2*pi*k)^2, shifted by 1e-6 so
        # q is nonzero at k = 0 (step1/step2 below divide by q).
        q = 10**-6 - 4.0 * np.pi**2 * (Kx**2 + Ky**2)
        self.q = q
        self.exponent = np.exp(q * self.dt)
        # expm1 keeps (exp(q*dt) - 1) accurate when q*dt is tiny.
        expm1 = np.expm1(q * self.dt)
        self.step1 = expm1 / q
        self.step2 = (expm1 - self.dt*q) / (self.dt*(q**2))
        # FFT of the nonlinear term from the previous step of the current
        # trajectory; None means "no history", forcing a first-order step.
        self.N_hat_prev = None

    def non_linear_function(self, xx, yy):
        """Defines the nonlinear interaction in the system."""
        return xx * (yy - np.abs(xx)**2)

    def next_state(self, A, myu, order=2):
        """Advance ``A`` by one time step ``dt`` under the field ``myu``.

        The first step of a trajectory (no stored history) or ``order=1``
        uses the first-order update. Otherwise a correction built from the
        change in the nonlinear term since the previous call is applied.
        Always stores this step's nonlinear term for the next call.

        Returns:
            The new state, as a complex array from ``np.fft.ifft2``.
        """
        A_hat = np.fft.fft2(A)
        N_hat = np.fft.fft2(self.non_linear_function(A, myu))
        # getattr: tolerate instances created without __init__.
        N_hat_prev = getattr(self, "N_hat_prev", None)
        self.N_hat_prev = N_hat

        first_order_hat = A_hat*self.exponent + N_hat*self.step1
        if order == 1 or N_hat_prev is None:
            return np.fft.ifft2(first_order_hat)

        # NOTE(review): standard ETD2 *adds* (N_n - N_{n-1}) * step2; the
        # minus sign and 0.01 damping are preserved as-is — confirm intent.
        correction_hat = (N_hat - N_hat_prev)*self.step2*.01
        return np.fft.ifft2(first_order_hat - correction_hat)

    def compute_myu(self):
        """Build a blocky binary mu field of shape (Nt, Nx, Ny).

        Draws a coarse ``myu_size`` array from N(mean, std), thresholds it
        at 0 to values {0, 255}, then upsamples each coarse cell to a
        constant block with ``np.kron``.

        Raises:
            ValueError: If ``myu_size`` entries are not positive or do not
                evenly divide (Nt, Nx, Ny); the result would otherwise have
                the wrong shape.
        """
        full_shape = np.array((self.Nt, self.Nx, self.Ny))
        coarse_shape = np.array(self.myu_size)
        if np.any(coarse_shape <= 0) or np.any(full_shape % coarse_shape):
            raise ValueError(
                f"myu_size {tuple(self.myu_size)} must evenly divide "
                f"(Nt, Nx, Ny) = {tuple(full_shape.tolist())}."
            )

        myu_small = np.random.normal(*self.myu_mstd, size=self.myu_size)
        myu_small = (myu_small > 0).astype(np.float32) * 255.0

        myu = np.kron(myu_small, np.ones(full_shape // coarse_shape))
        return myu

    def compute_state(self, myu):
        """Simulates the system over time, generating the state matrix.

        Starts from small complex Gaussian noise (std 0.01 per component)
        and steps ``Nt - 1`` times, using ``myu[i - 1]`` to produce step i.

        Returns:
            complex64 array of shape (Nt, Nx, Ny).
        """
        A_0 = np.random.normal(size=(self.Nx, self.Ny)) * 0.01 + \
              np.random.normal(size=(self.Nx, self.Ny)) * 0.01j
        A = np.zeros([self.Nt, self.Nx, self.Ny], dtype=np.complex64)
        A[0] = A_0

        # New trajectory: don't let a previous run's history leak in.
        self.N_hat_prev = None
        for i in tqdm(range(1, self.Nt), desc="Computing States"):
            A[i] = self.next_state(A[i - 1], myu[i - 1])
        return A

    def compute_state_for_frame(self, myu, state, idx, state_weight=0.5):
        """Take one simulator step starting from (a noisy copy of) ``state[idx]``.

        The start state is ``state_weight * state[idx] + (1 - state_weight)
        * noise``, with complex Gaussian noise of std 0.01 per component. It
        is advanced one step under ``myu[idx]``, matching ``compute_state``'s
        convention that step i uses ``myu[i - 1]``.

        Returns:
            complex64 array of shape (2, Nx, Ny): the start state and the
            stepped state.
        """
        noise = np.random.normal(size=(self.Nx, self.Ny)) * 0.01 + \
                np.random.normal(size=(self.Nx, self.Ny)) * 0.01j
        A_0 = state_weight * state[idx] + (1 - state_weight) * noise

        A = np.zeros([2, self.Nx, self.Ny], dtype=np.complex64)
        A[0] = A_0

        # Each call is an independent one-step trajectory, so the history
        # from a previous frame must not feed the second-order correction.
        self.N_hat_prev = None
        A[1] = self.next_state(A[0], myu[idx])
        return A

    def compute(self):
        """Main method to compute the simulation.

        Returns:
            ``(state, myu)`` from ``compute_state`` and ``compute_myu``.
        """
        myu = self.compute_myu()
        state = self.compute_state(myu)
        return state, myu

    def check_properties(self, A, myu):
        """Prints properties of the matrices A and myu to ensure correctness and stability."""
        unique_values, counts = np.unique(myu, return_counts=True)
        # Number of distinct mu values (including 0).
        print("Unique Myus count\t", unique_values.size)
        print("Max value of myu:\t", np.max(myu))
        print("Min value of myu:\t", np.min(myu))
        print("Unique values:", unique_values.tolist())
        print("Counts:\t\t", counts)
        print(f"A.shape={A.shape},\nMyu.shape={myu.shape},\n")
        print("Any NaN values in Myu\t\t", np.isnan(myu).any())
        print("Any NaN values in A\t\t", np.isnan(A).any())

In [None]:
# Load the predicted state (A) and mu arrays from .npy files in the working
# directory, then echo their shapes (state first, mu second) for a quick check.
state_pred, myu_pred = (np.load(path) for path in ("A_predictions.npy", "mu_predictions.npy"))
print(state_pred.shape, myu_pred.shape, sep="\n")

In [None]:
# Simulator used to step each predicted frame forward by one dt.
# NOTE(review): Nx, Ny here must match the spatial shape of state_pred and
# myu_pred loaded above — confirm against their printed shapes.
simulation = Simulation(
    N=(100, 60, 60),           # (Nt, Nx, Ny) grid sizes
    d=(0.0005, 0.15, 0.15),    # (dt, dx, dy) step sizes
    myu_mstd=(1.0, 10.5),      # (mean, std) of the Gaussian behind mu
    myu_size=(10, 4, 4),       # coarse mu grid upsampled to (Nt, Nx, Ny)
)

In [None]:
# savefig does not create missing directories; make sure the target exists.
os.makedirs("plots", exist_ok=True)

for frame in range(state_pred.shape[0] - 1):
    print("Frame:", frame)
    # One simulator step from a 50/50 blend of predicted frame `frame` and noise.
    state_simulated = simulation.compute_state_for_frame(myu_pred, state_pred, frame, 0.5)

    # Model prediction for frame + 1.
    plot_state_and_myu(state_pred, myu_pred, frame + 1)
    plt.savefig(f"plots/predicted_{frame + 1}.png", dpi=300, bbox_inches='tight')
    plt.close()

    # Simulated frame + 1 is state_simulated[1]; myu_pred[frame:][1] is
    # myu_pred[frame + 1], so both figures show the same mu.
    plot_state_and_myu(state_simulated, myu_pred[frame:], 1)
    plt.savefig(f"plots/simulated_{frame + 1}.png", dpi=300, bbox_inches='tight')
    plt.close()
