# PIBERT Ablation Study

This notebook replicates the ablation study from the paper on a single RTX 3060 GPU.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pibert import PIBERT
from pibert.utils import load_dataset, plot_results
from tqdm import tqdm

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(f"Using device: {device}")

# Report the name and total memory of GPU 0 when running on CUDA.
if device.type == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is in bytes; dividing by 1024**3 converts it for display.
    print(f"VRAM: {gpu_props.total_memory / 1024**3:.2f} GB")

In [None]:
# Load the "cylinder_wake" case (CFDBench) and split it into train/test parts.
dataset = load_dataset("cylinder_wake")
train_data, test_data = dataset["train"], dataset["test"]

# Report how many samples each split holds (length of its "x" entry).
n_train = len(train_data["x"])
print(f"Training samples: {n_train}")
n_test = len(test_data["x"])
print(f"Test samples: {n_test}")

In [None]:
# Model variants for the ablation study.
#
# Each entry maps a display name to the keyword arguments that are forwarded
# to the PIBERT constructor. Relative to "PIBERT-full" (all three switches
# on), every other variant turns exactly one switch off.
variants = {
    "PIBERT-full": dict(
        fourier_features=True,
        wavelet_features=True,
        physics_attention=True,
    ),
    # Wavelet features disabled.
    "Fourier-only": dict(
        fourier_features=True,
        wavelet_features=False,
        physics_attention=True,
    ),
    # Fourier features disabled.
    "Wavelet-only": dict(
        fourier_features=False,
        wavelet_features=True,
        physics_attention=True,
    ),
    # Physics attention disabled.
    "Standard-attention": dict(
        fourier_features=True,
        wavelet_features=True,
        physics_attention=False,
    ),
}

In [None]:
# Train and evaluate each ablation variant.
#
# For every entry in `variants`, a fresh PIBERT model is built, fitted with
# full-batch Adam on a reconstruction MSE objective (the target is the input
# field itself), and then scored on the test split. `results[name]` stores
# the test MSE, the test NMSE (MSE divided by the mean squared test field),
# and the trained model so that later cells can reuse it.
num_epochs = 50  # Reduced for demonstration
log_every = 10   # Print the training loss every `log_every` epochs

# The data tensors are the same for every variant, so move them to the
# device once instead of once per loop iteration.
x_train = train_data["x"].to(device)
coords_train = train_data["coords"].to(device)
x_test = test_data["x"].to(device)
coords_test = test_data["coords"].to(device)

results = {}
for name, config in tqdm(variants.items()):
    # The leading newline keeps this message off the tqdm progress-bar line.
    print(f"\nTraining {name}...")

    # Initialize model with variant configuration
    model = PIBERT(
        input_dim=2,  # u, v components
        hidden_dim=128,
        num_layers=4,
        num_heads=8,
        **config
    ).to(device)

    # Train model (simplified training loop)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        # NOTE(review): training goes through `model.predict`; confirm that
        # this method keeps gradients enabled, otherwise backward() will fail.
        pred = model.predict(x_train, coords_train)
        loss = torch.mean((pred - x_train) ** 2)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.6f}")

    # Evaluate on test set
    model.eval()
    with torch.no_grad():
        pred_test = model.predict(x_test, coords_test)
        test_mse = torch.mean((pred_test - x_test) ** 2).item()

        # NMSE: MSE normalized by the mean squared magnitude of the target.
        test_nmse = test_mse / torch.mean(x_test ** 2).item()

    results[name] = {
        "test_mse": test_mse,
        "test_nmse": test_nmse,
        "model": model
    }

    print(f"{name} - Test MSE: {test_mse:.6f}, NMSE: {test_nmse:.6f}")

In [None]:
# Grouped bar chart comparing test MSE and NMSE across the trained variants.
fig, ax = plt.subplots(figsize=(10, 6))

variants_list = list(results)
mse_values = []
nmse_values = []
for variant_name in variants_list:
    mse_values.append(results[variant_name]["test_mse"])
    nmse_values.append(results[variant_name]["test_nmse"])

# One group per variant; the two bars sit side by side around the tick.
x = np.arange(len(variants_list))
width = 0.35
ax.bar(x - width/2, mse_values, width, label='MSE')
ax.bar(x + width/2, nmse_values, width, label='NMSE')

ax.set_xlabel('Model Variant')
ax.set_ylabel('Error')
ax.set_title('Ablation Study Results')
ax.set_xticks(x)
ax.set_xticklabels(variants_list, rotation=15)
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.7)

fig.tight_layout()
plt.show()

In [None]:
# Visualize predictions for sample 0.
#
# Layout: two rows, one column per variant. The top row shows each variant's
# predicted field for the chosen component, and the bottom row shows the
# signed error (prediction minus ground truth) directly underneath it.
# The grid is sized from len(results), so every variant gets its own column
# and no subplot index exceeds the grid or overwrites another panel.
sample_idx = 0
component = 0  # u-velocity

# Move the test data to the device once; it is the same for every variant.
x_test = test_data["x"].to(device)
coords_test = test_data["coords"].to(device)
sample_x = x_test[sample_idx:sample_idx+1]
sample_coords = coords_test[sample_idx:sample_idx+1]
target = x_test[sample_idx, :, :, component].cpu().numpy()

num_variants = len(results)
plt.figure(figsize=(15, 10))
for i, (name, res) in enumerate(results.items()):
    model = res["model"]

    with torch.no_grad():
        pred = model.predict(sample_x, sample_coords)
    pred_field = pred[0, :, :, component].cpu().numpy()

    # Top row: predicted field
    plt.subplot(2, num_variants, i + 1)
    plt.imshow(pred_field, cmap='viridis')
    plt.title(f'{name} Prediction')
    plt.colorbar()

    # Bottom row: signed error on a fixed color scale so panels are comparable
    plt.subplot(2, num_variants, num_variants + i + 1)
    error = pred_field - target
    plt.imshow(error, cmap='coolwarm', vmin=-0.5, vmax=0.5)
    plt.title(f'{name} Error')
    plt.colorbar()

# Add the title before tight_layout and reserve the top strip for it, so the
# title does not overlap the first row of panels.
plt.suptitle('PIBERT Ablation Study: Velocity Field Predictions', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()