In [1]:
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split

from models import *

# Load the three input arrays from disk:
#   state        - complex-valued field (its .real / .imag parts are used later)
#   myu_full     - binarized, processed and cropped myu array
#   myu_original - cropped myu array before binarization/processing (per file name)
state, myu_full, myu_original = map(
    np.load,
    (
        "data/Data-dmd-11-03/states_processed_cropped.npy",
        "data/Data-dmd-11-03/myus_binarized_processed_cropped.npy",
        "data/Data-dmd-11-03/myu_cropped.npy",
    ),
)

# Sanity check of shapes and dtypes. Values from the last recorded run:
#   state        -> (1500, 360, 637)  complex64
#   myu_full     -> (1500, 360, 637)  uint8
#   myu_original -> (1500, 742, 1356) uint8
print("State shape:", state.shape, state.dtype)
print("Myu shape:  ", myu_full.shape, myu_full.dtype)
# NOTE: same label as the line above, but this one reports myu_original.
print("Myu shape:  ", myu_original.shape, myu_original.dtype)

State shape: (1500, 360, 637) complex64
Myu shape:   (1500, 360, 637) uint8
Myu shape:   (1500, 742, 1356) uint8


In [2]:
# Reproducibility: seed the global RNGs so that the randomly sampled data
# points and collocation points below (and any later numpy/torch randomness)
# are identical on every "Restart Kernel -> Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Split the complex field into its real and imaginary parts
A_r_data = state.real
A_i_data = state.imag

# Grid configuration
Nt, Nx, Ny = state.shape          # number of time steps / grid points in x / grid points in y
dt, dx, dy = 0.05, 0.3, 0.3       # grid spacing used to build the coordinate axes below
Nx_down, Ny_down = 22, 26         # size of the coarse (downsampled) grid handed to the model
degrade_x = Nx // Nx_down         # integer downsampling factor along x
degrade_y = Ny // Ny_down         # integer downsampling factor along y

# Physical coordinates of every grid index
t_vals = np.arange(Nt) * dt
x_vals = np.arange(Nx) * dx
y_vals = np.arange(Ny) * dy

# Sample supervised data points: n_data random grid indices (with replacement)
n_data = 20000
idx_t = np.random.randint(0, Nt, size=n_data)
idx_x = np.random.randint(0, Nx, size=n_data)
idx_y = np.random.randint(0, Ny, size=n_data)

t_data_np = t_vals[idx_t]
x_data_np = x_vals[idx_x]
y_data_np = y_vals[idx_y]

Ar_data_np = A_r_data[idx_t, idx_x, idx_y]
Ai_data_np = A_i_data[idx_t, idx_x, idx_y]

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


def _to_column_tensor(values, requires_grad=False):
    """Return `values` as a float32 tensor of shape (N, 1) on `device`.

    The array is reshaped *before* the tensor is created, so with
    ``requires_grad=True`` the result is a leaf tensor (not a view of a
    hidden leaf, which is what ``torch.tensor(...).view(-1, 1)`` produces).
    """
    column = np.asarray(values).reshape(-1, 1)
    return torch.tensor(
        column, dtype=torch.float32, device=device, requires_grad=requires_grad
    )


# Convert the sampled data points to tensors
x_data_t = _to_column_tensor(x_data_np)
y_data_t = _to_column_tensor(y_data_np)
t_data_t = _to_column_tensor(t_data_np)
Ar_data_t = _to_column_tensor(Ar_data_np)
Ai_data_t = _to_column_tensor(Ai_data_np)

# Collocation points for PDE constraints: drawn uniformly over the continuous
# domain [0, last grid coordinate] on each axis (not restricted to grid nodes).
n_coll = 20000
t_eqs_np = np.random.uniform(0, t_vals[-1], size=n_coll)
x_eqs_np = np.random.uniform(0, x_vals[-1], size=n_coll)
y_eqs_np = np.random.uniform(0, y_vals[-1], size=n_coll)

# requires_grad=True so derivatives with respect to the coordinates can be
# taken with autograd.
x_eqs_t = _to_column_tensor(x_eqs_np, requires_grad=True)
y_eqs_t = _to_column_tensor(y_eqs_np, requires_grad=True)
t_eqs_t = _to_column_tensor(t_eqs_np, requires_grad=True)

Using device: cuda


In [3]:
# Where the run's artefacts go: everything is stored under
# <output_dir>/<model_name>/ (see the os.path.join calls in the training cell).
output_dir = "./results"
model_name = "TimeBlockerV2_WithValidation"

# Build the network from the grid configuration defined in the previous cell.
model = NPINN_PRO_MAX_TIMEBLOCK_V2(
    # Layer widths: 3 inputs, four hidden layers, 2 outputs (deeper and wider
    # architecture). Presumably inputs are (x, y, t) and outputs are (A_r, A_i)
    # -- TODO confirm against the model definition.
    layers=[3, 128, 256, 256, 128, 2],
    # Full-resolution grid sizes and spacings
    Nt=Nt,
    Nx=Nx,
    Ny=Ny,
    dt=dt,
    dx=dx,
    dy=dy,
    # Coarse grid sizes and the matching integer downsampling factors
    Nx_down=Nx_down,
    Ny_down=Ny_down,
    degrade_x=degrade_x,
    degrade_y=degrade_y,
    degrade_t=150,
    # Remaining model hyper-parameters
    delta=0.01,
    weight_pde=0.1,
    device=device,
)
# Move the model to the selected device.
model = model.to(device)

In [4]:
# Train the model on the sampled data points and the PDE collocation points.
# The full experimental arrays (state, myu_full) and coordinate axes are
# passed as well; they are also used for the video generation further down.
model.train_model(
    x_data=x_data_t,
    y_data=y_data_t,
    t_data=t_data_t,
    A_r_data=Ar_data_t,
    A_i_data=Ai_data_t,
    x_eqs=x_eqs_t,
    y_eqs=y_eqs_t,
    t_eqs=t_eqs_t,
    n_epochs=120,
    lr=1e-3,
    batch_size=2048,
    model_name=model_name,
    output_dir=output_dir,
    video_freq=120,
    state_exp=state,
    myu_full_exp=myu_full,
    x_vals=x_vals,
    y_vals=y_vals,
    t_vals=t_vals,
    device=device,
    validation_split=0.2,  # Use 20% of data for validation
    val_freq=50            # Validate every 50 epochs
)

# After training, we can analyze the results
model_folder = os.path.join(output_dir, model_name)
losses_df = pd.read_csv(os.path.join(model_folder, f"{model_name}_losses.csv"))

# Validation only runs every `val_freq` epochs, so the last row of the log may
# not carry a validation value (assumption about the CSV layout -- verify
# against train_model). Report the most recent recorded validation loss; this
# is identical to `.iloc[-1]` whenever the column has no gaps.
val_history = losses_df['val_total_loss'].dropna()
final_val_loss = val_history.iloc[-1] if not val_history.empty else float("nan")

print("\nTraining completed!")
print(f"Final training loss: {losses_df['train_total_loss'].iloc[-1]:.6e}")
print(f"Final validation loss: {final_val_loss:.6e}")

# Load best model. map_location remaps the stored tensors onto the current
# device, so a checkpoint written on a GPU also loads on a CPU-only machine.
best_model_path = os.path.join(model_folder, f"{model_name}_best.pt")
model.load_state_dict(torch.load(best_model_path, map_location=device))
print(f"Best model loaded from {best_model_path}")

# Optional: Generate final validation visualization
print("Generating final visualization...")
final_vid_path = os.path.join(model_folder, "videos", f"{model_name}_final_visualization")
# Make sure the target folder exists even if training produced no video.
os.makedirs(os.path.dirname(final_vid_path), exist_ok=True)
generate_video(state, myu_full, model, x_vals, y_vals, t_vals, device=device, output_path=final_vid_path)
print(f"Final visualization saved to {final_vid_path}")

Starting training with 16000 training samples and 4000 validation samples
Epoch 0: New best model saved (val_loss=1.2974e+00)
Epoch 0: Train [total=2.8273e+00, data=1.9626e+00, PDE=8.6461e+00] | Val [total=1.2974e+00, data=1.2191e+00, PDE=7.8298e-01]


KeyboardInterrupt: 