In [1]:
import os
import torch
import warnings
import numpy as np
import gradio as gr
from data import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from dataclasses import dataclass, asdict
from model_io import load_model, save_model

# All model components (encoder, decoder, propagator, wrapper model, loss) are defined in new_model:
from new_model import Encoder, Decoder, Propagator_concat as Propagator, Model, loss_function

In [2]:
# NOTE(review): this silences *every* warning for the whole session (torch, matplotlib,
# gradio deprecations included) — consider narrowing it by category or module.
warnings.filterwarnings("ignore")
data_path = "./data"
model_path = "./checkpoints/FlexiPropagator_2D_2025-01-30-12-11-01_0aee8fb0_best.pt"
dataset_train, dataset_val, alpha_interval_split, tau_interval_split = load_from_path(data_path)
# Raw checkpoint contents, kept for inspection; the model itself is restored by load_model later.
# map_location="cpu" lets a checkpoint written from a GPU run load on a CPU-only machine;
# the interactive demo feeds CPU tensors to the model in any case.
state_dict = torch.load(model_path, map_location="cpu")

In [3]:
def get_model(latent_dim):
    """Assemble a freshly constructed Model whose parts share one latent size.

    Args:
        latent_dim: Size of the latent space handed to every component.

    Returns:
        ``Model`` wrapping, in this order, an encoder, a decoder for x(t), and a
        propagator mapping z(t) -> z(t + tau). Weights are not loaded here.
    """
    # Arguments are evaluated left to right, so components are still built
    # encoder -> decoder -> propagator.
    return Model(
        Encoder(latent_dim),
        Decoder(latent_dim),
        Propagator(latent_dim),
    )

In [4]:
# Latent size of the architecture; presumably must match the size the checkpoint was trained with.
model = get_model(3)
# Restores weights and metadata; note this overwrites the interval splits returned by load_from_path.
model, alpha_interval_split, tau_interval_split, config = load_model(model_path, model)
# The demo only runs inference: switch layers such as dropout/batch-norm (if the model has any)
# to evaluation behaviour. Assumes Model is a torch.nn.Module — TODO confirm in new_model.
model.eval()

In [5]:
def create_interface(model, exact_solution, dt=2 / 500):
    """Build a Gradio app comparing model predictions with an exact solution.

    Three controls (Re, t_0, tau) drive three plots:
      * the exact field at t_0,
      * a 3-D overlay of the exact fields at t_0 and t_0 + tau * dt,
      * a 1x3 panel of model prediction, exact evolved field, and pointwise
        squared error.

    Args:
        model: Called as ``model(x, tau, Re)`` under ``torch.no_grad()`` with
            ``x`` shaped (1, 1, H, W) and ``tau``/``Re`` shaped (1, 1) float32
            tensors; its second output is the prediction at t_0 + tau * dt.
        exact_solution: Called as ``exact_solution(Re, t)``; must return a 2-D
            array (H, W) of the field at time ``t``.
        dt: Physical time represented by one unit of ``tau``. Defaults to
            2 / 500, the step previously hard-coded in this function.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    # Start-up values for the controls. The initial render reads the controls
    # themselves, so the first plots always match what the controls display.
    default_Re, default_t0, default_tau = 4, 0.05, 225

    def evolved_time(t_0, tau):
        """Time reached after advancing t_0 by tau steps of size dt."""
        return t_0 + tau * dt

    def generate_3d_visualization(Re, t_0, tau):
        """3-D surface overlay of the exact field at t_0 and at the evolved time (None on NaNs)."""
        t = evolved_time(t_0, tau)

        # Generate initial and evolved states
        U_initial = exact_solution(Re, t_0)
        U_evolved = exact_solution(Re, t)

        if np.isnan(U_initial).any() or np.isnan(U_evolved).any():
            print("Warning: NaN values detected in solutions.")
            return None  # Avoid rendering empty plots

        fig3d = plt.figure(figsize=(12, 6))
        ax3d = fig3d.add_subplot(111, projection='3d')

        # The field is drawn on [-2, 2] x [-2, 2], sampled at the array's resolution.
        x_vals = np.linspace(-2, 2, U_initial.shape[1])
        y_vals = np.linspace(-2, 2, U_initial.shape[0])
        X, Y = np.meshgrid(x_vals, y_vals)

        surf1 = ax3d.plot_surface(X, Y, U_initial, cmap="viridis", alpha=0.6, label="Initial")
        surf2 = ax3d.plot_surface(X, Y, U_evolved, cmap="plasma", alpha=0.8, label="Evolved")

        ax3d.set_xlim(-3, 3)
        ax3d.set_xlabel("x")
        ax3d.set_ylabel("y")
        ax3d.set_zlabel("u(x,y,t)")
        ax3d.view_init(elev=25, azim=-45)
        ax3d.set_box_aspect((2, 1, 1))

        fig3d.colorbar(surf1, ax=ax3d, shrink=0.5, label="Initial")
        fig3d.colorbar(surf2, ax=ax3d, shrink=0.5, label="Evolved")
        ax3d.set_title(f"Solution Evolution\nInitial (t={t_0:.2f}) vs Evolved (t={t:.2f})")

        # Figure-level calls avoid relying on pyplot's "current figure" inside handler threads.
        fig3d.tight_layout()
        plt.close(fig3d)  # release pyplot's reference; the returned Figure is still rendered by gr.Plot
        return fig3d

    def process(Re, t_0, tau):
        """Run the model from the exact field at t_0 and compare with the exact field at the evolved time.

        Returns a 1x3 figure (prediction, exact, squared error), or None when the
        exact fields contain NaNs or the prediction's shape does not match.
        """
        exact_initial = exact_solution(Re, t_0)
        exact_final = exact_solution(Re, evolved_time(t_0, tau))

        if np.isnan(exact_initial).any() or np.isnan(exact_final).any():
            print("Warning: NaN values in exact solutions.")
            return None  # Skip rendering if invalid

        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        x_in = torch.tensor(exact_initial, dtype=torch.float32)[None, None, :, :]
        Re_in = torch.tensor([[Re]], dtype=torch.float32)
        tau_in = torch.tensor([[tau]], dtype=torch.float32)

        with torch.no_grad():
            _, x_hat_tau, *_ = model(x_in, tau_in, Re_in)

        # .cpu() is a no-op for CPU tensors and keeps .numpy() working otherwise.
        pred = x_hat_tau.squeeze().cpu().numpy()
        if pred.shape != exact_final.shape:
            print(f"Warning: Shape mismatch {pred.shape} vs {exact_final.shape}")
            return None

        # Pointwise squared error (shown under the "MSE Error" title).
        sq_err = np.square(pred - exact_final)

        fig, (ax_pred, ax_exact, ax_err) = plt.subplots(1, 3, figsize=(15, 4))

        for ax, field, title in ((ax_pred, pred, "Model Prediction"),
                                 (ax_exact, exact_final, "Exact Solution")):
            ax.imshow(field, cmap="jet")
            ax.set_title(title)
            ax.axis("off")

        # Fixed colour range [0, 1e-2] so error maps are comparable across parameter choices.
        im = ax_err.imshow(sq_err, cmap="viridis", vmin=0, vmax=1e-2)
        fig.colorbar(im, ax=ax_err, fraction=0.075)
        ax_err.set_title("MSE Error")
        ax_err.axis("off")

        fig.tight_layout()
        plt.close(fig)  # Ensure figure is closed
        return fig

    def update_initial_plot(Re, t_0):
        """Heat map of the exact field at t_0."""
        exact_initial = exact_solution(Re, t_0)
        fig, ax = plt.subplots(figsize=(5, 5))
        im = ax.imshow(exact_initial, cmap='jet')
        fig.colorbar(im, ax=ax)
        ax.set_title('Initial State')
        # Close like the other builders; otherwise every control change leaks a pyplot figure.
        plt.close(fig)
        return fig

    with gr.Blocks() as demo:
        gr.Markdown("## Dynamical System Visualizer")

        with gr.Row():
            with gr.Column(scale=1):
                Re_slider = gr.Slider(1, 10, value=default_Re, step=0.1, label="Reynolds Number (Re)", info="Interpolation: [2.09, 2.99], Extrapolation Left: [1, 1.9], Extrapolation Right: [9.1, 10]")
                t0_input = gr.Number(value=default_t0, label="Initial Time (t₀)",
                                     info="Keep close to zero (0-0.1 recommended)")
                tau_slider = gr.Slider(150, 425, value=default_tau, step=1, label="Tau (τ)", info="Interpolation: [364.7, 392.2], Extrapolation Left: [150, 177.5], Extrapolation Right: [397.5, 425]")
                initial_plot = gr.Plot(label="Initial State")

            with gr.Column(scale=3):
                three_d_plot = gr.Plot(label="3D Evolution")
                comparison_plots = gr.Plot(label="Model Comparison")

        inputs = [Re_slider, t0_input, tau_slider]
        outputs = [three_d_plot, comparison_plots, initial_plot]

        def update_all(Re, t0, tau):
            """Recompute every plot; tuple order matches ``outputs``."""
            return (
                generate_3d_visualization(Re, t0, tau),
                process(Re, t0, tau),
                update_initial_plot(Re, t0)
            )

        for component in inputs:
            component.change(update_all, inputs=inputs, outputs=outputs)

        # Initial render uses the controls' current values rather than separate constants.
        demo.load(update_all, inputs=inputs, outputs=outputs)

    return demo

# Re-running this cell would otherwise leave the previous Gradio server alive, and a new one
# would start on another port. Shut the old app down before building a new one.
if "demo" in globals():
    demo.close()
demo = create_interface(model, exact_solution)
demo.launch()


* Running on local URL:  http://127.0.0.1:7898

To create a public link, set `share=True` in `launch()`.


