In [1]:
import os
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from datetime import datetime
from tqdm import tqdm

from models import *

In [36]:
state = np.load(r"C:\Users\Alex\Desktop\data-processing\cropped_and_ready\states_processed_cropped.npy")[:120]
myu = np.load(r"C:\Users\Alex\Desktop\data-processing\cropped_and_ready\myus_binarized_processed_cropped.npy")[:120]

In [42]:
myu_original = np.load(r"O:\Data-New\Data-dmd-11-03\myu_cropped.npy")[:120, 0::2, 0::2] 

In [43]:
print("State shape:",           state.shape,        state.dtype)
print("Myu shape:  ",           myu.shape,          myu.dtype)
print("Myu original shape:  ",  myu_original.shape, myu_original.dtype)

State shape: (120, 360, 637) complex64
Myu shape:   (120, 360, 637) uint8
Myu original shape:   (120, 371, 678) uint8


In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


In [24]:
A_r_data = state.real
A_i_data = state.imag

Nt, Nx, Ny = state.shape
dt, dx, dy = 0.05, 0.3, 0.3  # Example values
Nx_down, Ny_down = 18, 20
degrade_x = Nx // Nx_down  # 530//10=53
degrade_y = Ny // Ny_down  # 880//10=88

In [25]:
n_data = 20000
idx_t = np.random.randint(0, Nt, size=n_data)
idx_x = np.random.randint(0, Nx, size=n_data)
idx_y = np.random.randint(0, Ny, size=n_data)

t_vals = np.arange(Nt) * dt
x_vals = np.arange(Nx) * dx
y_vals = np.arange(Ny) * dy

t_data_np = t_vals[idx_t]
x_data_np = x_vals[idx_x]
y_data_np = y_vals[idx_y]

Ar_data_np = A_r_data[idx_t, idx_x, idx_y]
Ai_data_np = A_i_data[idx_t, idx_x, idx_y]

x_data_t = torch.tensor(x_data_np, dtype=torch.float32, device=device).view(-1, 1)
y_data_t = torch.tensor(y_data_np, dtype=torch.float32, device=device).view(-1, 1)
t_data_t = torch.tensor(t_data_np, dtype=torch.float32, device=device).view(-1, 1)
Ar_data_t = torch.tensor(Ar_data_np, dtype=torch.float32, device=device).view(-1, 1)
Ai_data_t = torch.tensor(Ai_data_np, dtype=torch.float32, device=device).view(-1, 1)

n_coll = 20000
t_eqs_np = np.random.uniform(0, t_vals[-1], size=n_coll)
x_eqs_np = np.random.uniform(0, x_vals[-1], size=n_coll)
y_eqs_np = np.random.uniform(0, y_vals[-1], size=n_coll)

x_eqs_t = torch.tensor(x_eqs_np, dtype=torch.float32, device=device, requires_grad=True).view(-1, 1)
y_eqs_t = torch.tensor(y_eqs_np, dtype=torch.float32, device=device, requires_grad=True).view(-1, 1)
t_eqs_t = torch.tensor(t_eqs_np, dtype=torch.float32, device=device, requires_grad=True).view(-1, 1)

In [29]:
model_5 = NPINN_PRO_MAX_TIMEBLOCK_V2(
    layers=[3, 128, 256, 256, 128, 2],
    Nt=Nt, Nx=Nx, Ny=Ny,
    Nx_down=Nx_down, Ny_down=Ny_down,
    dt=dt, dx=dx, dy=dy,
    degrade_x=degrade_x, degrade_y=degrade_y,
    delta=0.01,
    weight_pde=0.1,
    device=device,
    degrade_t=10
).to(device)

model_path = r"C:\Users\Alex\Desktop\gl-pinn-new\TimeBlockerV2_Test_2_final_2000.pt"

model_state = torch.load(model_path, map_location=device, weights_only=True)
model_5.load_state_dict(model_state)
model_5.eval()

NPINN_PRO_MAX_TIMEBLOCK_V2(
  (dnn): ImprovedDNN(
    (net): Sequential(
      (0): Linear(in_features=3, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GELU(approximate='none')
      (3): ResBlock(
        (layers): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): GELU(approximate='none')
          (3): Linear(in_features=128, out_features=128, bias=True)
          (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (activation): GELU(approximate='none')
      )
      (4): ResBlock(
        (layers): Sequential(
          (0): Linear(in_features=128, out_features=128, bias=True)
          (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): GELU(approximate='none'

In [30]:
video_path = r'./videos_test'

In [47]:
def create_combined_frame(pinn_prod, original_prod, abs_diff, mu_pred, mu_full, mu_original):
    vmin = np.min(original_prod)
    vmax = np.max(original_prod)

    fig = plt.figure(figsize=(18, 12))
    spec = gridspec.GridSpec(ncols=2, nrows=3, figure=fig)

    ax1 = fig.add_subplot(spec[0, 0])
    ax2 = fig.add_subplot(spec[0, 1])
    ax3 = fig.add_subplot(spec[1, 0])
    ax4 = fig.add_subplot(spec[1, 1])
    ax5 = fig.add_subplot(spec[2, 0])
    ax6 = fig.add_subplot(spec[2, 1])

    im1 = ax1.imshow(original_prod, cmap="viridis", origin="lower", vmin=vmin, vmax=vmax)
    ax1.set_title("Original: Real x Imag")
    ax1.set_xlabel("X")
    ax1.set_ylabel("Y")

    im2 = ax2.imshow(pinn_prod, cmap="viridis", origin="lower", vmin=vmin, vmax=vmax)
    ax2.set_title("PINN: Real x Imag")
    ax2.set_xlabel("X")
    ax2.set_ylabel("Y")

    im3 = ax3.imshow(abs_diff, cmap="viridis", origin="lower")
    ax3.set_title("Absolute Difference")
    ax3.set_xlabel("X")
    ax3.set_ylabel("Y")

    im4 = ax4.imshow(mu_full, cmap="viridis", origin="lower")
    ax4.set_title("Original μ (Old)")
    ax4.set_xlabel("X")
    ax4.set_ylabel("Y")

    im5 = ax5.imshow(mu_original, cmap="viridis", origin="lower")
    ax5.set_title("Original μ (New)")
    ax5.set_xlabel("X")
    ax5.set_ylabel("Y")

    im6 = ax6.imshow(mu_pred, cmap="viridis", origin="lower")
    ax6.set_title("Predicted μ")
    ax6.set_xlabel("X")
    ax6.set_ylabel("Y")

    fig.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)
    fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
    fig.colorbar(im3, ax=ax3, fraction=0.046, pad=0.04)
    fig.colorbar(im4, ax=ax4, fraction=0.046, pad=0.04)
    fig.colorbar(im5, ax=ax5, fraction=0.046, pad=0.04)
    fig.colorbar(im6, ax=ax6, fraction=0.046, pad=0.04)

    fig.tight_layout()
    fig.canvas.draw()
    width, height = fig.canvas.get_width_height()
    image = np.frombuffer(fig.canvas.buffer_rgba(), dtype='uint8').reshape(height, width, 4)
    plt.close(fig)
    return image[:, :, :3]


def create_video(output_path, pinn_prod_frames, original_prod_frames, abs_diff_frames, mu_pred_frames, mu_full_frames, mu_original_frames, fps=30, additional_title=""):
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    video_path = os.path.join(output_path, f"output_video_{datetime.now().strftime('%Y%m%d%H%M%S')}_{additional_title}.mp4")

    first_frame = create_combined_frame(
        pinn_prod_frames[0], original_prod_frames[0], abs_diff_frames[0], mu_pred_frames[0], mu_full_frames[0], mu_original_frames[0]
    )
    height, width, _ = first_frame.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(video_path, fourcc, fps, (width, height))

    try:
        for i in tqdm(range(len(pinn_prod_frames)), desc="Creating video"):
            combined_frame = create_combined_frame(
                pinn_prod_frames[i], original_prod_frames[i], abs_diff_frames[i], mu_pred_frames[i], mu_full_frames[i], mu_original_frames[i]
            )
            video_writer.write(cv2.cvtColor(combined_frame, cv2.COLOR_RGB2BGR))
        video_writer.release()
    except Exception as e:
        video_writer.release()
        raise RuntimeError(f"Failed to create video: {e}")

    print(f"Video saved at: {video_path}")


def generate_video(state, mu_full, mu_full_original, model, x_vals, y_vals, t_vals, device, output_path):

    pinn_prod_frames, original_state_prod_frames, abs_diff_frames, mu_pred_frames, mu_full_frames, mu_original_frames = [], [], [], [], [], []
    mu_expanded = model.expand_myu_full(do_binarize=True, scale_255=True)
    
    for i, t_val in enumerate(tqdm(t_vals, desc="Generating frames")):
        X, Y = np.meshgrid(x_vals, y_vals)
        XX = X.ravel()
        YY = Y.ravel()
        TT = np.full_like(XX, t_val)

        x_test_t = torch.tensor(XX, dtype=torch.float32, device=device).view(-1, 1)
        y_test_t = torch.tensor(YY, dtype=torch.float32, device=device).view(-1, 1)
        t_test_t = torch.tensor(TT, dtype=torch.float32, device=device).view(-1, 1)

        A_r_pred, A_i_pred = model.predict(x_test_t, y_test_t, t_test_t)
        A_r_pred_2d = A_r_pred.reshape(X.shape)
        A_i_pred_2d = A_i_pred.reshape(X.shape)

        pinn_prod = A_r_pred_2d * A_i_pred_2d
        original_prod = state[i].real * state[i].imag
        
        if pinn_prod.shape != original_prod.shape:
            pinn_prod = pinn_prod.T
    
        abs_diff = np.abs(original_prod - pinn_prod)

        mu_pred_2d = mu_expanded[i]
        mu_full_2d = mu_full[i]
        mu_original_2d = mu_full_original[i]

        pinn_prod_frames.append(pinn_prod)
        original_state_prod_frames.append(original_prod)
        abs_diff_frames.append(abs_diff)
        mu_pred_frames.append(mu_pred_2d)
        mu_full_frames.append(mu_full_2d)
        mu_original_frames.append(mu_original_2d)

    pinn_prod_frames           = np.array(pinn_prod_frames)
    original_state_prod_frames = np.array(original_state_prod_frames)
    abs_diff_frames            = np.array(abs_diff_frames)
    mu_pred_frames             = np.array(mu_pred_frames)
    mu_full_frames             = np.array(mu_full_frames)
    mu_original_frames         = np.array(mu_full_original)

    create_video(output_path, pinn_prod_frames, original_state_prod_frames, abs_diff_frames, mu_pred_frames, mu_full_frames, mu_original_frames, fps=30)

In [48]:
generate_video(
    state=state,
    mu_full=myu,
    model=model_5, 
    x_vals=x_vals, 
    y_vals=y_vals,
    t_vals=t_vals,
    mu_full_original=myu_original, 
    device=device, 
    output_path=video_path
)

Generating frames: 100%|██████████| 120/120 [00:08<00:00, 14.22it/s]
Creating video: 100%|██████████| 120/120 [01:00<00:00,  1.97it/s]

Video saved at: ./videos_test\output_video_20250312123201_.mp4



