# Surrogate Model for the Surface Pressure Field

## DeepONet Implementation

### Dataset class

In [26]:
import torch
from torch.utils.data import Dataset
import numpy as np
import pyvista as pv
import os

class DrivAerNetDeepONetDataset(Dataset):
    """Pair each preprocessed car point cloud with randomly sampled surface-pressure points.

    Each item combines:
      * the branch input: the point cloud stored in ``<npz_dir>/<stem>.npz``
        (keys ``points``, ``centroid``, ``scale``), and
      * the trunk input and target: ``num_trunk_samples`` points drawn at random
        from the mesh ``<vtk_dir>/<stem>.vtk`` with their pressure values
        (``mesh.point_data['p']``).

    Trunk coordinates are mapped into the point cloud's normalized frame with the
    centroid/scale stored in the .npz file, so both networks see the same geometry.
    """

    def __init__(self, npz_dir, vtk_dir, global_stats_path=None, num_trunk_samples=4096, normalize_pressure=False):
        """
        Args:
            npz_dir: Folder containing the preprocessed .npz files (branch data).
            vtk_dir: Folder containing the raw .vtk meshes (trunk data). A mesh
                must share its file stem with the matching .npz file.
            global_stats_path: Path to 'pressure_stats.npz' holding 'mean' and 'std'.
                Required when ``normalize_pressure`` is True.
            num_trunk_samples: Number of mesh points sampled per item for the TrunkNet.
            normalize_pressure: If True, targets are standardized with the global mean/std.

        Raises:
            ValueError: If ``normalize_pressure`` is True but no ``global_stats_path`` is given.
        """
        self.npz_files = sorted(f for f in os.listdir(npz_dir) if f.endswith('.npz'))
        self.npz_dir = npz_dir
        self.vtk_dir = vtk_dir
        self.num_trunk_samples = num_trunk_samples
        self.normalize_pressure = normalize_pressure

        # Load global statistics for pressure normalization.
        if self.normalize_pressure:
            if global_stats_path is None:
                raise ValueError("global_stats_path is required when normalize_pressure=True")
            # The context manager closes the underlying file handle; indexing an
            # NpzFile materializes the arrays, so they stay valid after closing.
            with np.load(global_stats_path) as stats:
                self.global_mean = stats['mean']
                self.global_std = stats['std']

    def __len__(self):
        return len(self.npz_files)

    @staticmethod
    def _sample_indices(num_points, num_samples):
        """Return ``num_samples`` random indices into ``range(num_points)``.

        Samples without replacement when the mesh has enough points. Otherwise it
        falls back to sampling with replacement instead of letting
        ``np.random.choice`` raise on small meshes.
        """
        if num_points <= 0:
            raise ValueError("mesh has no points to sample")
        replace = num_points < num_samples
        return np.random.choice(num_points, num_samples, replace=replace)

    def __getitem__(self, idx):
        npz_name = self.npz_files[idx]

        # 1. Load preprocessed branch data (already normalized by preprocessing).
        with np.load(os.path.join(self.npz_dir, npz_name)) as npz_data:
            branch_points = npz_data['points']  # assumed (N_branch, 3), e.g. 4096 points
            pc_centroid = npz_data['centroid']  # needed to map mesh points into the same frame
            pc_scale = npz_data['scale']

        # 2. Load raw trunk data (mesh) with the same file stem, e.g. 'car_001.npz' -> 'car_001.vtk'.
        stem = os.path.splitext(npz_name)[0]
        mesh = pv.read(os.path.join(self.vtk_dir, stem + '.vtk'))
        raw_coords = np.asarray(mesh.points)
        raw_pressure = np.asarray(mesh.point_data['p'])  # key must match the VTK array name

        # 3. Random sampling for the trunk.
        indices = self._sample_indices(len(raw_coords), self.num_trunk_samples)
        sampled_coords = raw_coords[indices]
        sampled_pressure = raw_pressure[indices]

        # 4. Apply the exact transform used for the point cloud so the geometries align.
        trunk_coords = (sampled_coords - pc_centroid) / pc_scale

        # 5. Optional pressure normalization.
        if self.normalize_pressure:
            target_pressure = (sampled_pressure - self.global_mean) / self.global_std
        else:
            target_pressure = sampled_pressure

        # 6. Convert to float32 tensors.
        return {
            "branch_input": torch.as_tensor(branch_points, dtype=torch.float32),               # (N_branch, 3)
            "trunk_input": torch.as_tensor(trunk_coords, dtype=torch.float32),                 # (num_trunk_samples, 3)
            "target": torch.as_tensor(target_pressure, dtype=torch.float32).unsqueeze(-1),     # (num_trunk_samples, 1)
        }

## Data Splitting and DataLoaders

In [27]:
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

def get_dataloaders(npz_dir, vtk_dir, stats_path, batch_size=4, split_ratio=(0.8, 0.1, 0.1),
                    normalize_pressure=True, seed=42, num_workers=4, eval_num_workers=2):
    """Split the dataset and return Train, Val, and Test DataLoaders.

    Args:
        npz_dir: Folder with the preprocessed .npz point clouds.
        vtk_dir: Folder with the matching .vtk meshes.
        stats_path: Path to the global pressure statistics (.npz with 'mean'/'std').
        batch_size: Batch size shared by all three loaders.
        split_ratio: ``(train, val, test)`` sizes. ``train`` is passed to
            ``train_test_split`` as ``train_size`` (a fraction or an absolute
            count). The remaining samples are divided between val and test in
            the proportion ``val : test``. A zero val or test share yields an
            empty loader for that split.
        normalize_pressure: Forwarded to the dataset (default True, as before).
        seed: ``random_state`` for both splits, so the split is reproducible.
        num_workers: Worker processes for the training loader.
        eval_num_workers: Worker processes for the val/test loaders.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: If both val and test sizes are zero (nothing to hold out).
    """
    # 1. Initialize the full dataset
    full_dataset = DrivAerNetDeepONetDataset(
        npz_dir=npz_dir,
        vtk_dir=vtk_dir,
        global_stats_path=stats_path,
        normalize_pressure=normalize_pressure,
    )

    # 2. Generate indices for splitting
    indices = list(range(len(full_dataset)))
    train_size, val_size, test_size = split_ratio
    holdout = val_size + test_size
    if holdout <= 0:
        raise ValueError("split_ratio must reserve a non-zero share for val and/or test")

    # Split: Train and (Val + Test)
    train_idx, tmp_idx = train_test_split(
        indices, train_size=train_size, random_state=seed, shuffle=True
    )

    # Split: Val and Test. train_test_split rejects train_size of 0.0/1.0, so the
    # degenerate cases are handled explicitly.
    if test_size == 0:
        val_idx, test_idx = tmp_idx, []
    elif val_size == 0:
        val_idx, test_idx = [], tmp_idx
    else:
        val_idx, test_idx = train_test_split(
            tmp_idx, train_size=val_size / holdout, random_state=seed, shuffle=True
        )

    # 3. Create Subsets
    train_dataset = Subset(full_dataset, train_idx)
    val_dataset = Subset(full_dataset, val_idx)
    test_dataset = Subset(full_dataset, test_idx)

    # 4. Create DataLoaders
    # num_workers > 0 lets the CPU load the next batch while the GPU processes the current one.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=eval_num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=eval_num_workers)

    return train_loader, val_loader, test_loader

In [3]:
# Configuration: dataset locations and loader settings.
NPZ_PATH = "02_dataset/car_pc_vtk"
VTK_PATH = "02_dataset/car_pressure_field/F_D_WM_WW_1"
STATS_FILE = "02_dataset/pressure_stats.npz"
BATCH_SIZE = 8  # Adjust based on your GPU VRAM

# Build the three loaders and report how many batches each yields.
train_loader, val_loader, test_loader = get_dataloaders(
    npz_dir=NPZ_PATH,
    vtk_dir=VTK_PATH,
    stats_path=STATS_FILE,
    batch_size=BATCH_SIZE,
)

for split_name, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_name} batches: {len(split_loader)}")

Train batches: 50
Val batches: 7
Test batches: 7


## Model Architecture

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def index_points(points, idx):
    """Gather per-batch rows of ``points`` selected by ``idx``.

    Args:
        points: Tensor of shape (B, N, C).
        idx: Long tensor of shape (B, ...) holding indices into dimension N.

    Returns:
        Tensor of shape ``idx.shape + (C,)`` where
        ``out[b, ...] = points[b, idx[b, ...], :]``.
    """
    batch_size = points.shape[0]
    # (B, 1, 1, ...) so the batch index lines up with every entry of idx.
    broadcast_shape = [batch_size] + [1] * (idx.dim() - 1)
    batch_ids = torch.arange(batch_size, dtype=torch.long).to(points.device)
    batch_ids = batch_ids.view(broadcast_shape).expand_as(idx)
    return points[batch_ids, idx, :]

def farthest_point_sample(xyz, npoint):
    """Iterative Farthest Point Sampling to pick cluster centers.

    Args:
        xyz: Tensor of shape (B, N, C) with point coordinates (any C, typically 3).
        npoint: Number of centers to select per batch element.

    Returns:
        Long tensor of shape (B, npoint) with indices into dimension N.

    The starting point of each batch element is drawn with ``torch.randint`` on
    the CPU generator, so results depend on the global torch seed.
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from each point to the closest center chosen so far.
    distance = torch.full((B, N), 1e10, device=device)
    # Drawn on CPU and then moved, exactly as before, to keep seeded runs reproducible.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the real coordinate width C instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

def query_ball_point(radius, nsample, xyz, new_xyz):
    """Group the points within ``radius`` of each query center.

    Args:
        radius: Ball radius, in the same units as the coordinates.
        nsample: Number of neighbor indices returned per center.
        xyz: Tensor of shape (B, N, C) with all points.
        new_xyz: Tensor of shape (B, S, C) with the query centers.

    Returns:
        Long tensor of shape (B, S, nsample). For each center, the lowest-index
        in-radius points in ascending order. Unused slots are filled with the
        first in-radius index. This also applies when N < nsample, which
        previously caused a shape mismatch.

    Note: builds a dense (B, S, N) distance matrix, so memory grows with S * N.
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat(B, S, 1)
    sqdist = torch.sum((new_xyz.view(B, S, 1, C) - xyz.view(B, 1, N, C)) ** 2, -1)
    # Out-of-radius points get the sentinel N so that sorting pushes them to the end.
    group_idx[sqdist > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    if group_idx.shape[-1] < nsample:
        # Fewer points than requested samples: pad with the sentinel, filled below.
        pad = group_idx.new_full((B, S, nsample - group_idx.shape[-1]), N)
        group_idx = torch.cat([group_idx, pad], dim=-1)
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat(1, 1, nsample)
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx

In [29]:
class PointNetSetAbstraction(nn.Module):
    """One PointNet++ set-abstraction layer: sample centers, group neighbors, shared MLP, max-pool.

    Args:
        npoint: Number of centers chosen by farthest point sampling.
        radius: Ball-query radius used to collect neighbors.
        nsample: Neighbors gathered per center.
        in_channel: Channels fed to the first 1x1 conv. That is 3 local coords
            plus either 3 absolute coords (no input features) or D feature channels.
        mlp: Output channels of each successive 1x1 Conv2d + BatchNorm2d + ReLU stage.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp):
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        widths = [in_channel, *mlp]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(c_out))

    def forward(self, xyz, points):
        """
        Args:
            xyz: Point positions, [B, 3, N].
            points: Point features, [B, D, N], or None for the first layer.

        Returns:
            (new_xyz [B, 3, npoint], new_points [B, mlp[-1], npoint])
        """
        # Sampling and grouping work in the [B, N, C] layout.
        coords = xyz.transpose(1, 2)
        feats = None if points is None else points.transpose(1, 2)

        centre_idx = farthest_point_sample(coords, self.npoint)
        centres = index_points(coords, centre_idx)
        neighbour_idx = query_ball_point(self.radius, self.nsample, coords, centres)
        neighbours = index_points(coords, neighbour_idx)  # [B, npoint, nsample, 3]

        # Neighbor positions relative to their own center.
        relative = neighbours - centres.unsqueeze(2)

        # Second half of the channels: absolute coords on the first layer
        # (giving 3 + 3 = 6 channels), otherwise the gathered input features.
        if feats is None:
            extra = neighbours
        else:
            extra = index_points(feats, neighbour_idx)
        grouped = torch.cat([relative, extra], dim=-1)  # [B, npoint, nsample, 3 + D]

        # Conv2d expects [B, Channels, nsample, npoint].
        grouped = grouped.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            grouped = F.relu(bn(conv(grouped)))

        # Max-pool over the nsample neighbors of each center.
        pooled = grouped.max(dim=2).values
        return centres.transpose(1, 2), pooled
    

### DeepONet Architecture 

In [30]:
class PointDeepONet(nn.Module):
    """DeepONet with a PointNet++ branch (geometry encoder) and an MLP trunk (query encoder).

    The prediction at each query point is the dot product of the branch and
    trunk latent vectors plus a learned scalar bias.

    Args:
        latent_dim: Size of the shared latent space of the branch and trunk outputs.
    """

    def __init__(self, latent_dim=128):
        super().__init__()

        # --- BRANCH NET (PointNet++) ---
        # SA1: 4096 -> 1024 points
        self.sa1 = PointNetSetAbstraction(npoint=1024, radius=0.1, nsample=32, in_channel=3+3, mlp=[64, 64, 128])
        # SA2: 1024 -> 256 points
        self.sa2 = PointNetSetAbstraction(npoint=256, radius=0.2, nsample=32, in_channel=128+3, mlp=[128, 128, 256])
        # SA3: 256 -> 64 points
        self.sa3 = PointNetSetAbstraction(npoint=64, radius=0.4, nsample=32, in_channel=256+3, mlp=[256, 512, 1024])

        self.branch_fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

        # --- TRUNK NET (MLP) ---
        self.trunk_net = nn.Sequential(
            nn.Linear(3, 128),
            nn.SiLU(),
            nn.Linear(128, 256),
            nn.SiLU(),
            nn.Linear(256, 128),
            nn.SiLU(),
            nn.Linear(128, latent_dim)
        )

        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, branch_xyz, trunk_xyz):
        """
        Args:
            branch_xyz: (B, 3, N_branch) sampled point cloud (channels first for the SA layers).
            trunk_xyz: (B, M, 3) mesh query points.

        Returns:
            (B, M, 1) predicted field value at each query point.
        """
        # 1. Branch pass
        l1_xyz, l1_points = self.sa1(branch_xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        _, l3_points = self.sa3(l2_xyz, l2_points)

        # Global max pooling over the remaining 64 centers.
        branch_features = torch.max(l3_points, 2)[0]
        branch_v = self.branch_fc(branch_features)  # (B, latent_dim)

        # 2. Trunk pass. nn.Linear acts on the last dimension, so (B, M, 3) can be
        # fed directly. The previous .view(-1, 3) failed on non-contiguous input.
        trunk_v = self.trunk_net(trunk_xyz)  # (B, M, latent_dim)

        # 3. Dot-product fusion
        # bg: Batch, Latent | bmg: Batch, Points, Latent -> bm: Batch, Points
        out = torch.einsum('bg,bmg->bm', branch_v, trunk_v)

        return out.unsqueeze(-1) + self.bias  # (B, M, 1)

In [31]:
import torch

def model_sanity_check():
    """Smoke-test PointDeepONet on random tensors.

    Runs three checks on a freshly initialized model:
      1. the forward pass returns shape (batch, num_trunk_points, 1);
      2. a backward pass puts gradients on the first branch conv and on the first
         trunk layer;
      3. shuffling the branch points barely changes the output. This is only
         approximate permutation invariance: farthest_point_sample starts from a
         random point, so an exact match is not expected.

    Returns:
        bool: True if every check passed. False if the forward or backward
        check failed, or the permutation difference exceeded its tolerance.
    """
    # 1. Setup parameters
    batch_size = 4
    num_branch_points = 4096
    num_trunk_query_points = 2048
    latent_dim = 128
    permutation_tolerance = 1e-4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"--- Starting Sanity Check on {device} ---")

    # 2. Initialize model
    model = PointDeepONet(latent_dim=latent_dim).to(device)
    model.train()

    # 3. Dummy inputs mimicking the DataLoader.
    # Branch: (B, 3, 4096). The 3 is the second dim, as the SA layers expect.
    dummy_branch_pc = torch.randn(batch_size, 3, num_branch_points).to(device)
    # Trunk: (B, 2048, 3)
    dummy_trunk_coords = torch.randn(batch_size, num_trunk_query_points, 3).to(device)
    # Target: (B, 2048, 1)
    dummy_target = torch.randn(batch_size, num_trunk_query_points, 1).to(device)

    print(f"Input Shapes: Branch {dummy_branch_pc.shape}, Trunk {dummy_trunk_coords.shape}")

    # 4. Forward pass. Explicit checks instead of `assert` so they also run under `python -O`.
    try:
        output = model(dummy_branch_pc, dummy_trunk_coords)
    except Exception as e:
        print(f"❌ Forward Pass Failed: {e}")
        return False
    print(f"Output Shape: {output.shape}")
    if output.shape != (batch_size, num_trunk_query_points, 1):
        print("❌ Forward Pass Failed: ❌ Output shape mismatch!")
        return False
    print("✅ Forward Pass: Shape Correct.")

    # 5. Backward pass (gradient flow check)
    try:
        loss = torch.nn.MSELoss()(output, dummy_target)
        loss.backward()
    except Exception as e:
        print(f"❌ Backward Pass Failed: {e}")
        return False

    branch_grad = next(model.sa1.mlp_convs[0].parameters()).grad
    trunk_grad = next(model.trunk_net[0].parameters()).grad
    if branch_grad is None:
        print("❌ Backward Pass Failed: ❌ No gradients in Branch Net!")
        return False
    if trunk_grad is None:
        print("❌ Backward Pass Failed: ❌ No gradients in Trunk Net!")
        return False
    print("✅ Backward Pass: Gradient flow confirmed in both networks.")

    # 6. Permutation invariance check: shuffling the branch points should leave
    # the output (nearly) unchanged.
    with torch.no_grad():
        model.eval()
        perm = torch.randperm(num_branch_points)
        shuffled_branch_pc = dummy_branch_pc[:, :, perm]

        original_out = model(dummy_branch_pc, dummy_trunk_coords)
        shuffled_out = model(shuffled_branch_pc, dummy_trunk_coords)

        diff = torch.abs(original_out - shuffled_out).max().item()

    all_passed = diff < permutation_tolerance
    if all_passed:
        print(f"✅ Permutation Invariance: Confirmed (Max Diff: {diff:.2e})")
        print("\n--- All Sanity Checks Passed! ---")
    else:
        print(f"⚠️ Permutation Invariance Warning: Max Diff: {diff:.2e}")
        print("\n--- Sanity Checks Finished With Warnings ---")
    return all_passed

if __name__ == "__main__":
    # Ensure all previous code (Utilities, SA Module, PointDeepONet) is in the same file or imported
    model_sanity_check()

--- Starting Sanity Check on cuda ---
Input Shapes: Branch torch.Size([4, 3, 4096]), Trunk torch.Size([4, 2048, 3])
Output Shape: torch.Size([4, 2048, 1])
✅ Forward Pass: Shape Correct.
✅ Backward Pass: Gradient flow confirmed in both networks.
✅ Permutation Invariance: Confirmed (Max Diff: 1.96e-05)

--- All Sanity Checks Passed! ---


## Training Loops

In [38]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import os
import time
from tqdm import tqdm

# Utility to log progress to both console and a text file
def log_message(message, log_file='logs/deepONet_03_not_normalized/train_log.txt'):
    """Print ``message`` and append it, prefixed with ``[time.ctime()]``, to ``log_file``.

    The parent directory of ``log_file`` is created if it does not exist yet.
    Previously only the checkpoint directory was created, so the first write
    could fail with FileNotFoundError.
    """
    print(message)
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"[{time.ctime()}] {message}\n")

In [39]:
# Configuration
CHECKPOINT_DIR = '03_checkpoints/deepONet_03_not_normalized'
LOG_FILE = 'logs/deepONet_03_not_normalized/train_log.txt'
LATEST_CHECKPOINT = os.path.join(CHECKPOINT_DIR, 'latest_checkpoint.pth')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def initialize_training(model, optimizer, checkpoint_path, map_location=None):
    """Restore model/optimizer state from ``checkpoint_path`` if it exists.

    Args:
        model: Module whose ``state_dict`` was saved under 'model_state_dict'.
        optimizer: Optimizer whose state was saved under 'optimizer_state_dict'.
        checkpoint_path: Path of the resume checkpoint.
        map_location: Forwarded to ``torch.load``. Defaults to the device of the
            model's first parameter, so a checkpoint saved on GPU can be resumed
            on a CPU-only machine (and vice versa).

    Returns:
        (start_epoch, best_val_loss): ``start_epoch`` is the saved 'epoch' + 1,
        or 0 without a checkpoint. ``best_val_loss`` is the saved value, or inf.
    """
    start_epoch = 0
    best_val_loss = float('inf')

    if not os.path.exists(checkpoint_path):
        log_message("No checkpoint found. Starting training from scratch.")
        return start_epoch, best_val_loss

    log_message(f"Checking for checkpoint at {checkpoint_path}...")
    if map_location is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            map_location = first_param.device
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    log_message(f"Successfully resumed from Epoch {start_epoch}")

    return start_epoch, best_val_loss

In [40]:
def run_epoch(model, loader, optimizer, criterion, device, is_training=True):
    """Run one pass over ``loader``, training when ``is_training`` is True.

    Each batch is a dict with 'branch_input' (B, N, 3), which is permuted to
    (B, 3, N) for the PointNet++ branch, plus 'trunk_input' and 'target'.

    Returns:
        float: Mean loss per sample over the whole loader. Each batch loss is
        weighted by its batch size, so a smaller final batch no longer skews
        the average. This assumes ``criterion`` returns a per-element mean,
        as ``nn.MSELoss()`` does by default.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # train(False) is equivalent to eval().
    model.train(is_training)

    total_loss = 0.0
    total_samples = 0
    bar = tqdm(loader, desc="Train" if is_training else "Val")

    for batch in bar:
        # Prepare inputs (permute the branch to channels-first for PointNet++).
        branch_in = batch['branch_input'].permute(0, 2, 1).to(device)
        trunk_in = batch['trunk_input'].to(device)
        targets = batch['target'].to(device)

        with torch.set_grad_enabled(is_training):
            preds = model(branch_in, trunk_in)
            loss = criterion(preds, targets)

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # A single .item() per batch (one device sync).
        batch_loss = loss.item()
        batch_samples = targets.shape[0]
        total_loss += batch_loss * batch_samples
        total_samples += batch_samples
        bar.set_postfix(mse=batch_loss)

    if total_samples == 0:
        raise ValueError("run_epoch: loader yielded no batches")
    return total_loss / total_samples

In [41]:
# Initialize Model, Optimizer, and Scheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PointDeepONet(latent_dim=256).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
criterion = nn.MSELoss()

# Get state
start_epoch, best_val_loss = initialize_training(model, optimizer, LATEST_CHECKPOINT)

# Main Loop
NUM_EPOCHS = 100
for epoch in range(start_epoch, NUM_EPOCHS):
    # 1. Train and Validate
    avg_train_loss = run_epoch(model, train_loader, optimizer, criterion, device, is_training=True)
    avg_val_loss = run_epoch(model, val_loader, optimizer, criterion, device, is_training=False)
    
    scheduler.step(avg_val_loss)
    
    # 2. Log Results
    log_msg = f"Epoch {epoch+1}: Train {avg_train_loss:.6f} | Val {avg_val_loss:.6f} | LR {optimizer.param_groups[0]['lr']:.2e}"
    log_message(log_msg, LOG_FILE)

    # 3. Save "Latest" Checkpoint (Every Epoch for safety)
    checkpoint_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss,
    }
    torch.save(checkpoint_state, LATEST_CHECKPOINT)

    # 4. Save Periodic Checkpoint (Every 10 Epochs)
    if (epoch + 1) % 50 == 0:
        periodic_path = os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save(checkpoint_state, periodic_path)
        log_message(f"System: 10-Epoch Periodic Checkpoint Saved.")

    # 5. Save "Best" Model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_model.pth'))
        log_message("System: Best Model updated.")

# Final Save after all epochs
torch.save(model.state_dict(), "final_point_deeponet.pth")
log_message("Training Complete. Final model saved.")

No checkpoint found. Starting training from scratch.


Train:   0%|          | 0/50 [00:00<?, ?it/s]

Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.993]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.876]


Epoch 1: Train 8.996874 | Val 0.986015 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.918]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.905]


Epoch 2: Train 0.953619 | Val 0.928972 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.929]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.87] 


Epoch 3: Train 0.918172 | Val 0.904115 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.891]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.867]


Epoch 4: Train 0.891813 | Val 0.882201 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.851]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.82s/it, mse=0.83] 


Epoch 5: Train 0.875503 | Val 0.862747 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, mse=0.862]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.788]


Epoch 6: Train 0.856197 | Val 0.835485 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.828]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.738]


Epoch 7: Train 0.890080 | Val 0.837254 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.792]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.734]


Epoch 8: Train 0.819450 | Val 0.810749 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:50<00:00,  1.00s/it, mse=0.809]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.741]


Epoch 9: Train 0.780515 | Val 0.819619 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.788]
Val: 100%|██████████| 7/7 [00:13<00:00,  1.86s/it, mse=0.733]


Epoch 10: Train 0.811559 | Val 0.766092 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.722]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.649]


Epoch 11: Train 0.753347 | Val 0.740618 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.714]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.664]


Epoch 12: Train 0.743871 | Val 0.727699 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.737]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.686]


Epoch 13: Train 0.745075 | Val 0.718447 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.693]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.695]


Epoch 14: Train 0.724842 | Val 0.731485 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.744]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.678]


Epoch 15: Train 0.730750 | Val 0.738896 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.755]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.716]


Epoch 16: Train 0.721655 | Val 0.738542 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.944]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.682]


Epoch 17: Train 0.720398 | Val 0.700993 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.713]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.645]


Epoch 18: Train 0.711639 | Val 0.708048 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.695]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.85s/it, mse=0.664]


Epoch 19: Train 0.696576 | Val 0.701563 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.65] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.81s/it, mse=0.63] 


Epoch 20: Train 0.689007 | Val 0.695669 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.691]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.636]


Epoch 21: Train 0.697524 | Val 0.694683 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.633]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.644]


Epoch 22: Train 0.672189 | Val 0.692178 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.693]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.652]


Epoch 23: Train 0.684514 | Val 0.687056 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.736]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.613]


Epoch 24: Train 0.677194 | Val 0.665272 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, mse=0.658]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.583]


Epoch 25: Train 0.655818 | Val 0.649028 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.654]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.82s/it, mse=0.537]


Epoch 26: Train 0.647751 | Val 0.621878 | LR 1.00e-03
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it, mse=0.66] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.83s/it, mse=0.622]


Epoch 27: Train 0.640581 | Val 0.660213 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:51<00:00,  1.04s/it, mse=0.626]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.606]


Epoch 28: Train 0.630695 | Val 0.629014 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.646]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.634]


Epoch 29: Train 0.648768 | Val 0.651959 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.624]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.579]


Epoch 30: Train 0.632761 | Val 0.622234 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:49<00:00,  1.00it/s, mse=0.614]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.604]


Epoch 31: Train 0.616800 | Val 0.622568 | LR 1.00e-03


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.589]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.657]


Epoch 32: Train 0.624168 | Val 0.657315 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.62] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.78s/it, mse=0.577]


Epoch 33: Train 0.633051 | Val 0.618173 | LR 5.00e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.612]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.556]


Epoch 34: Train 0.614945 | Val 0.592830 | LR 5.00e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.57] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.85s/it, mse=0.513]


Epoch 35: Train 0.608651 | Val 0.591847 | LR 5.00e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.576]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.80s/it, mse=0.647]


Epoch 36: Train 0.613777 | Val 0.613670 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.619]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.608]


Epoch 37: Train 0.601845 | Val 0.603868 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.604]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.551]


Epoch 38: Train 0.606860 | Val 0.590269 | LR 5.00e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.594]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.54] 


Epoch 39: Train 0.613260 | Val 0.590655 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.599]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.606]


Epoch 40: Train 0.595897 | Val 0.597316 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.573]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.84s/it, mse=0.6]  


Epoch 41: Train 0.600016 | Val 0.606246 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.572]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.549]


Epoch 42: Train 0.598003 | Val 0.590732 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.551]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.563]


Epoch 43: Train 0.595637 | Val 0.599046 | LR 5.00e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.573]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.55] 


Epoch 44: Train 0.590508 | Val 0.590427 | LR 2.50e-04


Train: 100%|██████████| 50/50 [00:48<00:00,  1.04it/s, mse=0.583]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.624]


Epoch 45: Train 0.589068 | Val 0.597183 | LR 2.50e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.599]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.517]


Epoch 46: Train 0.600671 | Val 0.567949 | LR 2.50e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.568]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.548]


Epoch 47: Train 0.593540 | Val 0.580949 | LR 2.50e-04


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.608]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.564]


Epoch 48: Train 0.584052 | Val 0.581486 | LR 2.50e-04


Train: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, mse=0.566]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.611]


Epoch 49: Train 0.582484 | Val 0.587523 | LR 2.50e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.613]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.553]


Epoch 50: Train 0.577074 | Val 0.579551 | LR 2.50e-04
System: 10-Epoch Periodic Checkpoint Saved.


Train: 100%|██████████| 50/50 [00:50<00:00,  1.02s/it, mse=0.53] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.565]


Epoch 51: Train 0.574022 | Val 0.579881 | LR 2.50e-04


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.591]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.549]


Epoch 52: Train 0.575099 | Val 0.592815 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.567]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.581]


Epoch 53: Train 0.577680 | Val 0.576882 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.566]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.534]


Epoch 54: Train 0.575253 | Val 0.563757 | LR 1.25e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.575]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.568]


Epoch 55: Train 0.574096 | Val 0.575149 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, mse=0.548]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.583]


Epoch 56: Train 0.570441 | Val 0.571284 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.563]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.568]


Epoch 57: Train 0.571860 | Val 0.579755 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.555]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.79s/it, mse=0.531]


Epoch 58: Train 0.573088 | Val 0.565052 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.00it/s, mse=0.569]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.84s/it, mse=0.521]


Epoch 59: Train 0.566322 | Val 0.564190 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.00it/s, mse=0.556]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.542]


Epoch 60: Train 0.565907 | Val 0.555628 | LR 1.25e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.566]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.633]


Epoch 61: Train 0.572634 | Val 0.604843 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.562]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.637]


Epoch 62: Train 0.567773 | Val 0.576886 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.585]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.56] 


Epoch 63: Train 0.561403 | Val 0.568596 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:51<00:00,  1.04s/it, mse=0.583]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.578]


Epoch 64: Train 0.565123 | Val 0.568289 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.552]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.82s/it, mse=0.52] 


Epoch 65: Train 0.574481 | Val 0.553127 | LR 1.25e-04
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, mse=0.528]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.84s/it, mse=0.514]


Epoch 66: Train 0.562611 | Val 0.564228 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.00it/s, mse=0.574]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.83s/it, mse=0.499]


Epoch 67: Train 0.558535 | Val 0.558391 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.526]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.503]


Epoch 68: Train 0.556574 | Val 0.570761 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.562]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.514]


Epoch 69: Train 0.570189 | Val 0.558478 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:50<00:00,  1.00s/it, mse=0.547]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.521]


Epoch 70: Train 0.555550 | Val 0.562298 | LR 1.25e-04


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.557]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.83s/it, mse=0.581]


Epoch 71: Train 0.549735 | Val 0.567011 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.55] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.84s/it, mse=0.517]


Epoch 72: Train 0.551693 | Val 0.549307 | LR 6.25e-05
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.58] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.529]


Epoch 73: Train 0.560999 | Val 0.565347 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.537]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.84s/it, mse=0.522]


Epoch 74: Train 0.553264 | Val 0.550945 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.529]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.535]


Epoch 75: Train 0.557621 | Val 0.547094 | LR 6.25e-05
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.542]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.511]


Epoch 76: Train 0.552310 | Val 0.552901 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.573]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.78s/it, mse=0.529]


Epoch 77: Train 0.551961 | Val 0.549644 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.541]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.528]


Epoch 78: Train 0.553740 | Val 0.582191 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.533]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.562]


Epoch 79: Train 0.557437 | Val 0.547991 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.517]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.53] 


Epoch 80: Train 0.550582 | Val 0.545796 | LR 6.25e-05
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it, mse=0.506]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.85s/it, mse=0.521]


Epoch 81: Train 0.556645 | Val 0.553456 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, mse=0.527]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.58] 


Epoch 82: Train 0.552623 | Val 0.553970 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.534]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.562]


Epoch 83: Train 0.554173 | Val 0.562749 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.53] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.498]


Epoch 84: Train 0.545819 | Val 0.554585 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.566]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.553]


Epoch 85: Train 0.545905 | Val 0.545133 | LR 6.25e-05
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.518]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.508]


Epoch 86: Train 0.547404 | Val 0.548114 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.535]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.589]


Epoch 87: Train 0.545162 | Val 0.568415 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:50<00:00,  1.01s/it, mse=0.553]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.481]


Epoch 88: Train 0.550530 | Val 0.537140 | LR 6.25e-05
System: Best Model updated.


Train: 100%|██████████| 50/50 [00:48<00:00,  1.03it/s, mse=0.548]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.529]


Epoch 89: Train 0.560065 | Val 0.549147 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.551]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.85s/it, mse=0.535]


Epoch 90: Train 0.540512 | Val 0.541663 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.553]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.78s/it, mse=0.531]


Epoch 91: Train 0.547253 | Val 0.543681 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.534]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.519]


Epoch 92: Train 0.547161 | Val 0.541004 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.514]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.78s/it, mse=0.546]


Epoch 93: Train 0.540965 | Val 0.548010 | LR 6.25e-05


Train: 100%|██████████| 50/50 [00:48<00:00,  1.02it/s, mse=0.558]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.568]


Epoch 94: Train 0.553375 | Val 0.560217 | LR 3.13e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.527]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.543]


Epoch 95: Train 0.557884 | Val 0.551792 | LR 3.13e-05


Train: 100%|██████████| 50/50 [00:51<00:00,  1.03s/it, mse=0.521]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, mse=0.542]


Epoch 96: Train 0.544411 | Val 0.540419 | LR 3.13e-05


Train: 100%|██████████| 50/50 [00:51<00:00,  1.02s/it, mse=0.525]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.78s/it, mse=0.517]


Epoch 97: Train 0.556600 | Val 0.542374 | LR 3.13e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.51] 
Val: 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, mse=0.512]


Epoch 98: Train 0.539373 | Val 0.540454 | LR 3.13e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.01it/s, mse=0.576]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, mse=0.557]


Epoch 99: Train 0.543801 | Val 0.540869 | LR 3.13e-05


Train: 100%|██████████| 50/50 [00:49<00:00,  1.02it/s, mse=0.562]
Val: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it, mse=0.544]

Epoch 100: Train 0.552631 | Val 0.543192 | LR 1.56e-05
System: 10-Epoch Periodic Checkpoint Saved.
Training Complete. Final model saved.





## Inference

In [42]:
import torch
import numpy as np

# Set device: prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Re-initialize the model architecture
# NOTE(review): latent_dim must match the configuration used during training,
# otherwise the checkpoint's tensor shapes will not line up — confirm.
model = PointDeepONet(latent_dim=256).to(device)

# 2. Load the trained weights
# The checkpoint is passed straight to load_state_dict, so it only needs to hold
# tensors; weights_only=True restricts unpickling to such data instead of
# executing arbitrary pickled objects from the file.
MODEL_PATH = "03_checkpoints/deepONet_03/best_model.pth"  # or "final_point_deeponet.pth"
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

# 3. Load Global Stats for Denormalization
# np.load on an .npz returns a lazily-read archive that keeps the file open;
# the context manager closes it once both arrays have been materialized.
with np.load("pressure_stats.npz") as stats:
    global_mean = stats['mean']
    global_std = stats['std']

print(f"Model loaded from {MODEL_PATH}")
print(f"Global Stats - Mean: {global_mean:.2f}, Std: {global_std:.2f}")

Model loaded from 03_checkpoints/deepONet_03/best_model.pth
Global Stats - Mean: -93.65, Std: 114.89


In [43]:
def predict_pressure(model, batch, device, mean, std):
    """
    Run the model on one DataLoader batch and return de-normalized predictions and targets.

    Args:
        model: Callable invoked as ``model(branch_in, trunk_in)``. Its output is
            treated as normalized pressure (assumes the model was trained on
            normalized targets — TODO confirm).
        batch: Dict with tensors 'branch_input' (B, N, 3), 'trunk_input' and 'target'.
        device: Device the model lives on; inputs are moved there.
        mean: Global pressure mean used for normalization.
        std: Global pressure standard deviation used for normalization.

    Returns:
        Tuple ``(preds_real, targets_real)`` of NumPy arrays, each computed as
        ``normalized_value * std + mean``.
    """
    # (B, N, 3) -> (B, 3, N): channels-first layout for the PointNet++ branch.
    branch_in = batch['branch_input'].permute(0, 2, 1).to(device)
    trunk_in = batch['trunk_input'].to(device)

    # Model prediction (normalized space); no autograd graph is needed.
    with torch.no_grad():
        preds_norm = model(branch_in, trunk_in)

    # Targets are only compared on the host, so they are read on the CPU directly
    # rather than being copied to `device` and immediately back again.
    targets_norm = batch['target'].detach().cpu().numpy()

    # Denormalize: P_real = (P_norm * std) + mean
    preds_real = (preds_norm.cpu().numpy() * std) + mean
    targets_real = (targets_norm * std) + mean

    return preds_real, targets_real

In [44]:
# Pick a batch from the test loader.
# NOTE(review): repeated runs of this cell have printed different numbers, so the
# loader and/or the dataset's point sampling appears to be random — the metrics
# below describe this one batch only.
test_batch = next(iter(test_loader))

# Run inference; predictions and targets come back de-normalized (reported in Pa).
preds, targets = predict_pressure(model, test_batch, device, global_mean, global_std)

# Calculate Errors (MAE and Max Error in Pascals) over every sample and point in the batch
abs_error = np.abs(preds - targets)
mae_pascals = np.mean(abs_error)
max_error_pascals = np.max(abs_error)

print(f"--- Inference Results (Batch Size: {preds.shape[0]}) ---")
print(f"Mean Absolute Error: {mae_pascals:.2f} Pa")
print(f"Max Error in Batch: {max_error_pascals:.2f} Pa")

# Check a single point comparison; the header is built from the indices so it
# stays accurate when they are changed.
sample_idx = 0
point_idx = 0
print(f"\nSample {sample_idx}, Point {point_idx} Comparison:")
print(f"Predicted: {preds[sample_idx, point_idx, 0]:.2f} Pa")
print(f"Ground Truth: {targets[sample_idx, point_idx, 0]:.2f} Pa")

--- Inference Results (Batch Size: 8) ---
Mean Absolute Error: 47.09 Pa
Max Error in Batch: 1053.46 Pa

Sample 0, Point 0 Comparison:
Predicted: -200.80 Pa
Ground Truth: -224.77 Pa


--- Inference Results (Batch Size: 8) ---

Mean Absolute Error: 52.13 Pa

Max Error in Batch: 842.20 Pa



Sample 0, Point 0 Comparison:

Predicted: -111.99 Pa

Ground Truth: -44.50 Pa

--- Inference Results (Batch Size: 8) Second time ---

Mean Absolute Error: 48.44 Pa

Max Error in Batch: 1136.48 Pa


Sample 0, Point 0 Comparison:
Predicted: -84.42 Pa

Ground Truth: -79.92 Pa

In [45]:
# Pick a batch from the test loader.
# NOTE(review): repeated runs of this cell have printed different numbers, so the
# loader and/or the dataset's point sampling appears to be random — the metrics
# below describe this one batch only.
test_batch = next(iter(test_loader))

# Run inference; predictions and targets come back de-normalized (reported in Pa).
preds, targets = predict_pressure(model, test_batch, device, global_mean, global_std)

# Calculate Errors (MAE and Max Error in Pascals) over every sample and point in the batch
abs_error = np.abs(preds - targets)
mae_pascals = np.mean(abs_error)
max_error_pascals = np.max(abs_error)

print(f"--- Inference Results (Batch Size: {preds.shape[0]}) ---")
print(f"Mean Absolute Error: {mae_pascals:.2f} Pa")
print(f"Max Error in Batch: {max_error_pascals:.2f} Pa")

# Check a single point comparison; the header is built from the indices so it
# stays accurate when they are changed.
sample_idx = 0
point_idx = 0
print(f"\nSample {sample_idx}, Point {point_idx} Comparison:")
print(f"Predicted: {preds[sample_idx, point_idx, 0]:.2f} Pa")
print(f"Ground Truth: {targets[sample_idx, point_idx, 0]:.2f} Pa")

--- Inference Results (Batch Size: 8) ---
Mean Absolute Error: 48.18 Pa
Max Error in Batch: 1032.96 Pa

Sample 0, Point 0 Comparison:
Predicted: -192.34 Pa
Ground Truth: -119.14 Pa


--- Inference Results (Batch Size: 8) ---

Mean Absolute Error: 52.12 Pa

Max Error in Batch: 910.79 Pa

Sample 0, Point 0 Comparison:

Predicted: -131.98 Pa

Ground Truth: -64.42 Pa

In [46]:
import pyvista as pv

def save_prediction_vtk(coords, preds, targets, filename="prediction_output_deNorm_epoch.vtp"):
    """
    Write a point cloud carrying predicted and reference pressure to a VTK file for ParaView.

    Args:
        coords: (M, 3) point positions used as the cloud's vertices.
        preds: Predicted pressure per point, e.g. (M, 1); flattened before storing.
        targets: Reference pressure per point, same layout as ``preds``.
        filename: Output path passed to ``PolyData.save``.

    Three point arrays are attached: "Predicted_Pressure", "Actual_Pressure"
    and "Error" (absolute difference between the two).
    """
    cloud = pv.PolyData(coords)
    point_fields = (
        ("Predicted_Pressure", preds),
        ("Actual_Pressure", targets),
        ("Error", np.abs(preds - targets)),
    )
    for field_name, values in point_fields:
        cloud[field_name] = values.flatten()

    cloud.save(filename)
    print(f"Visualization saved to {filename}")

# Export the first car (index 0) of the current test batch.
# NOTE(review): trunk_input coordinates may be the dataset's normalized
# (centered/scaled) frame rather than raw mesh coordinates — confirm before
# overlaying the result on the original geometry.
save_prediction_vtk(
    coords=test_batch['trunk_input'][0].numpy(),
    preds=preds[0],
    targets=targets[0],
)

Visualization saved to prediction_output_deNorm_epoch.vtp


In [40]:
import pyvista as pv

def save_prediction_vtk(coords, preds, targets, filename="prediction_output_200_epoch_2.vtp"):
    """
    Store predicted vs. reference pressure on a point cloud as a VTK file for ParaView.

    Args:
        coords: (M, 3) point positions.
        preds: Predicted pressure per point, e.g. (M, 1).
        targets: Reference pressure per point, same layout as ``preds``.
        filename: Output path passed to ``PolyData.save``.

    Attached point arrays: "Predicted_Pressure", "Actual_Pressure" and
    "Error" (absolute difference), each flattened to one value per point.
    """
    mesh = pv.PolyData(coords)
    mesh["Predicted_Pressure"] = np.ravel(preds.flatten())
    mesh["Actual_Pressure"] = np.ravel(targets.flatten())
    abs_diff = np.abs(preds - targets)
    mesh["Error"] = abs_diff.flatten()

    mesh.save(filename)
    print(f"Visualization saved to {filename}")

# Export the first car (index 0) of the current test batch.
# NOTE(review): trunk_input coordinates may be the dataset's normalized
# (centered/scaled) frame rather than raw mesh coordinates — confirm before
# overlaying the result on the original geometry.
save_prediction_vtk(
    coords=test_batch['trunk_input'][0].numpy(),
    preds=preds[0],
    targets=targets[0],
)

Visualization saved to prediction_output_200_epoch_2.vtp
