In [1]:
import torch
import torch.nn as nn
import xarray as xr
import numpy as np
from typing import Tuple, Dict
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from pathlib import Path
import time

class AdvectionDiffusionPINN(nn.Module):
    """
    Physics-Informed Neural Network for solving the 2D advection-diffusion equation:
    ∂C/∂t = D(∂²C/∂x² + ∂²C/∂y²) - u∂C/∂x - v∂C/∂y + S(x,y,t)
    """
    def __init__(self, D: float, hidden_layers: int = 10, neurons_per_layer: int = 64):
        super().__init__()
        self.D = D
        
        # Network architecture
        layers = []
        # Input layer (x, y, t) -> first hidden layer
        layers.append(nn.Linear(3, neurons_per_layer))
        layers.append(nn.Tanh())
        
        # Hidden layers
        for _ in range(hidden_layers - 1):
            layers.append(nn.Linear(neurons_per_layer, neurons_per_layer))
            layers.append(nn.Tanh())
            
        # Output layer (concentration C)
        layers.append(nn.Linear(neurons_per_layer, 1))
        
        self.network = nn.Sequential(*layers)
        
    def forward(self, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network."""
        inputs = torch.stack([x, y, t], dim=-1)
        return self.network(inputs)
    
    def compute_derivatives(self, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor,
                          u: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Compute the derivatives needed for the PDE residual."""
        # Create variables requiring gradient
        x_var = x.clone().detach().requires_grad_(True)
        y_var = y.clone().detach().requires_grad_(True)
        t_var = t.clone().detach().requires_grad_(True)
        
        # Forward pass
        C = self.forward(x_var, y_var, t_var)
        
        # First derivatives
        dC_dt = torch.autograd.grad(C.sum(), t_var, create_graph=True)[0]
        dC_dx = torch.autograd.grad(C.sum(), x_var, create_graph=True)[0]
        dC_dy = torch.autograd.grad(C.sum(), y_var, create_graph=True)[0]
        
        # Second derivatives
        d2C_dx2 = torch.autograd.grad(dC_dx.sum(), x_var, create_graph=True)[0]
        d2C_dy2 = torch.autograd.grad(dC_dy.sum(), y_var, create_graph=True)[0]
        
        return dC_dt, dC_dx, dC_dy, d2C_dx2, d2C_dy2
    
    def pde_residual(self, x: torch.Tensor, y: torch.Tensor, t: torch.Tensor,
                    u: torch.Tensor, v: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
        """
        Compute the PDE residual:
        R = ∂C/∂t - D(∂²C/∂x² + ∂²C/∂y²) + u∂C/∂x + v∂C/∂y - S
        """
        dC_dt, dC_dx, dC_dy, d2C_dx2, d2C_dy2 = self.compute_derivatives(x, y, t, u, v)
        
        # Compute diffusion term
        diffusion_term = self.D * (d2C_dx2 + d2C_dy2)
        
        # Compute advection term
        advection_term = u * dC_dx + v * dC_dy
        
        # Full residual
        residual = dC_dt - diffusion_term + advection_term - S
        
        return residual, dC_dt, dC_dx, dC_dy, d2C_dx2, d2C_dy2

def prepare_temporal_split_data(ds: xr.Dataset, train_days: int = 70, 
                           spatial_subsample: int = 4,
                           temporal_subsample: int = 4) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Prepare data with temporal train-test split.
    First 70 days for training, last 30 days for testing.
    """
    # Calculate split point in timesteps (24 timesteps per day)
    hours_per_day = 24
    train_timesteps = train_days * hours_per_day
    
    # Subsample coordinates first
    x_coords = ds.xC.values[::spatial_subsample]
    y_coords = ds.yC.values[::spatial_subsample]
    
    # Handle temporal subsampling and splitting
    all_timesteps = np.arange(len(ds.time))
    subsampled_timesteps = all_timesteps[::temporal_subsample]
    t_coords = subsampled_timesteps * 3600  # Convert to seconds
    
    # Create temporal mask based on subsampled timesteps
    time_train_mask = subsampled_timesteps <= train_timesteps
    time_test_mask = ~time_train_mask
    
    # Print information about subsampling
    print("\nSubsampled grid sizes:")
    print(f"X points: {len(x_coords)} (original: {len(ds.xC.values)})")
    print(f"Y points: {len(y_coords)} (original: {len(ds.yC.values)})")
    print(f"Time points: {len(t_coords)} (original: {len(ds.time)})")
    
    # Get the data arrays
    u_array = ds.u.values[::temporal_subsample, :, ::spatial_subsample, ::spatial_subsample]
    v_array = ds.v.values[::temporal_subsample, :, ::spatial_subsample, ::spatial_subsample]
    c_array = ds.c.values[::temporal_subsample, :, ::spatial_subsample, ::spatial_subsample]
    
    # Create coordinate grids for training data
    train_times = t_coords[time_train_mask]
    test_times = t_coords[time_test_mask]
    
    # Create meshgrids
    X_train, Y_train, T_train = np.meshgrid(x_coords, y_coords, train_times, indexing='ij')
    X_test, Y_test, T_test = np.meshgrid(x_coords, y_coords, test_times, indexing='ij')
    
    # Print debug information
    print("\nData shapes:")
    print(f"Original u shape: {ds.u.values.shape}")
    print(f"Subsampled u shape: {u_array.shape}")
    print(f"Time mask shape: {time_train_mask.shape}")
    print(f"X_train shape: {X_train.shape}")
    print(f"Training data shape: {u_array[time_train_mask].shape}")
    print(f"Number of training timesteps: {sum(time_train_mask)}")
    print(f"Number of test timesteps: {sum(time_test_mask)}")
    
    # Reshape velocity data for training
    u_train = u_array[time_train_mask].reshape(-1)
    v_train = v_array[time_train_mask].reshape(-1)
    c_train = c_array[time_train_mask].reshape(-1)
    
    # Reshape velocity data for testing
    u_test = u_array[time_test_mask].reshape(-1)
    v_test = v_array[time_test_mask].reshape(-1)
    c_test = c_array[time_test_mask].reshape(-1)
    
    # Create training tensors
    train_data = {
        'x': torch.tensor(X_train.flatten(), dtype=torch.float32),
        'y': torch.tensor(Y_train.flatten(), dtype=torch.float32),
        't': torch.tensor(T_train.flatten(), dtype=torch.float32),
        'u': torch.tensor(u_train, dtype=torch.float32),
        'v': torch.tensor(v_train, dtype=torch.float32),
        'c': torch.tensor(c_train, dtype=torch.float32),
    }
    
    # Create test tensors
    test_data = {
        'x': torch.tensor(X_test.flatten(), dtype=torch.float32),
        'y': torch.tensor(Y_test.flatten(), dtype=torch.float32),
        't': torch.tensor(T_test.flatten(), dtype=torch.float32),
        'u': torch.tensor(u_test, dtype=torch.float32),
        'v': torch.tensor(v_test, dtype=torch.float32),
        'c': torch.tensor(c_test, dtype=torch.float32),
    }
    
    # Handle source term if present
    if 'S' in ds:
        s_array = ds.S.values[::temporal_subsample, :, ::spatial_subsample, ::spatial_subsample]
        train_data['S'] = torch.tensor(s_array[time_train_mask].reshape(-1), dtype=torch.float32)
        test_data['S'] = torch.tensor(s_array[time_test_mask].reshape(-1), dtype=torch.float32)
    else:
        train_data['S'] = torch.zeros_like(train_data['x'])
        test_data['S'] = torch.zeros_like(test_data['x'])
    
    # Print final dataset sizes
    print("\nFinal dataset sizes:")
    print(f"Training samples: {len(train_data['x'])}")
    print(f"Test samples: {len(test_data['x'])}")
    
    return {'train': train_data, 'test': test_data}

def train_pinn(model: AdvectionDiffusionPINN, 
               data: Dict[str, torch.Tensor],
               num_epochs: int = 100, 
               learning_rate: float = 0.01,
               batch_size: int = 1000,
               save_dir: str = "models") -> Dict[str, list]:
    """
    Train the PINN model using mini-batches.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    n_samples = len(data['x'])
    history = {'loss': [], 'pde_loss': [], 'data_loss': []}
    
    # Create save directory if it doesn't exist
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"Starting training with {n_samples} samples...")
    start_time = time.time()

    # Weighted loss
    pde_weight = 1.0
    data_weight = 

    for epoch in range(num_epochs):
        # Random permutation for batching
        perm = torch.randperm(n_samples)
        
        running_loss = 0.0
        running_pde_loss = 0.0
        running_data_loss = 0.0
        
        for i in range(0, n_samples, batch_size):
            optimizer.zero_grad()
            
            # Get batch indices
            idx = perm[i:i + batch_size]
            
            # Get batch data
            x_batch = data['x'][idx]
            y_batch = data['y'][idx]
            t_batch = data['t'][idx]
            u_batch = data['u'][idx]
            v_batch = data['v'][idx]
            c_batch = data['c'][idx]
            S_batch = data['S'][idx]
            
            # Compute PDE residual
            residual, dC_dt, dC_dx, dC_dy, d2C_dx2, d2C_dy2 = model.pde_residual(
                x_batch, y_batch, t_batch, u_batch, v_batch, S_batch)

            pde_loss = torch.mean(residual**2)
            
            # Compute predicted concentration
            c_pred = model(x_batch, y_batch, t_batch)
            data_loss = torch.mean((c_pred - c_batch)**2)

            # Combined loss without weights
            #loss = pde_loss + data_loss

            # Combined loss with weights
            loss = pde_weight * pde_loss + data_weight * data_loss
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            running_pde_loss += pde_loss.item()
            running_data_loss += data_loss.item()

            #if i % 1000 == 0:
             #   print(f"Batch i={i}, PDE={pde_loss.item()}, Data={data_loss.item()}") 
            
        # Record losses
        avg_loss = running_loss / (n_samples / batch_size)
        avg_pde_loss = running_pde_loss / (n_samples / batch_size)
        avg_data_loss = running_data_loss / (n_samples / batch_size)
        
        history['loss'].append(avg_loss)
        history['pde_loss'].append(avg_pde_loss)
        history['data_loss'].append(avg_data_loss)
        
        if epoch % 10 == 0:
            elapsed = time.time() - start_time
            print(f'-----Epoch {epoch}, Loss: {avg_loss:.6f}, PDE Loss: {avg_pde_loss:.6f}, '
                  f'Data Loss: {avg_data_loss:.6f}, Time: {elapsed:.2f}s')
            
            # Save model checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, save_dir / f'pinn_checkpoint_epoch_{epoch}.pt')
    
    return history

def evaluate_model(model: nn.Module, data: Dict[str, torch.Tensor], 
                  batch_size: int = 1000) -> Dict[str, float]:
    """
    Evaluate model performance on given dataset.
    """
    model.eval()
    predictions = []
    actuals = []
    total_pde_loss = 0.0
    n_batches = 0
    
    # First collect predictions without gradients
    with torch.no_grad():
        for i in range(0, len(data['x']), batch_size):
            # Get batch
            x_batch = data['x'][i:i + batch_size]
            y_batch = data['y'][i:i + batch_size]
            t_batch = data['t'][i:i + batch_size]
            
            # Get predictions
            pred = model(x_batch, y_batch, t_batch)
            predictions.append(pred)
            actuals.append(data['c'][i:i + batch_size])
    
    # Then compute PDE residuals with gradients enabled
    for i in range(0, len(data['x']), batch_size):
        # Get batch
        x_batch = data['x'][i:i + batch_size]
        y_batch = data['y'][i:i + batch_size]
        t_batch = data['t'][i:i + batch_size]
        u_batch = data['u'][i:i + batch_size]
        v_batch = data['v'][i:i + batch_size]
        S_batch = data['S'][i:i + batch_size]
        
        # Compute PDE residual with gradients enabled
        residual, *_ = model.pde_residual(x_batch, y_batch, t_batch, 
                                    u_batch, v_batch, S_batch)
        total_pde_loss += torch.mean(residual**2).item()
        n_batches += 1
    
    # Concatenate predictions and actual values
    predictions = torch.cat(predictions, dim=0).numpy()
    actuals = torch.cat(actuals, dim=0).numpy()
    
    # Calculate metrics
    mse = mean_squared_error(actuals, predictions)
    rmse = np.sqrt(mse)
    r2 = r2_score(actuals, predictions)
    avg_pde_loss = total_pde_loss / n_batches
    
    return {
        'mse': mse,
        'rmse': rmse,
        'r2': r2,
        'pde_loss': avg_pde_loss
    }

def plot_results(model: nn.Module, data: Dict[str, torch.Tensor], 
                ds: xr.Dataset, timestep: int, save_dir: str = "figures") -> None:
    """
    Plot actual vs predicted concentration at a specific timestep.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    
    model.eval()
    
    # Get spatial coordinates
    x_coords = ds.xC.values
    y_coords = ds.yC.values
    
    # Create meshgrid for single timestep
    X, Y = np.meshgrid(x_coords, y_coords)
    T = np.full_like(X, timestep * 3600)  # Convert to seconds
    
    # Get predictions
    with torch.no_grad():
        x_tensor = torch.tensor(X.flatten(), dtype=torch.float32)
        y_tensor = torch.tensor(Y.flatten(), dtype=torch.float32)
        t_tensor = torch.tensor(T.flatten(), dtype=torch.float32)
        predictions = model(x_tensor, y_tensor, t_tensor)
    
    # Reshape predictions
    pred_grid = predictions.numpy().reshape(X.shape)
    actual_grid = np.squeeze(ds.c.isel(time=timestep).values)

    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot actual concentration
    im1 = ax1.pcolormesh(X, Y, actual_grid, shading='auto')
    ax1.set_title('Actual Concentration')
    plt.colorbar(im1, ax=ax1)
    
    # Plot predicted concentration
    im2 = ax2.pcolormesh(X, Y, pred_grid, shading='auto')
    ax2.set_title('Predicted Concentration')
    plt.colorbar(im2, ax=ax2)
    
    plt.tight_layout()
    
    # Save figure
    plt.savefig(save_dir / f'comparison_timestep_{timestep}.png')
    plt.close()

def plot_training_history(history: Dict[str, list], save_dir: str = "figures") -> None:
    """
    Plot training history including total loss, PDE loss, and data loss.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    
    plt.figure(figsize=(12, 6))
    plt.plot(history['loss'], label='Total Loss')
    plt.plot(history['pde_loss'], label='PDE Loss')
    plt.plot(history['data_loss'], label='Data Loss')
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training History')
    plt.legend()
    plt.grid(True)
    plt.savefig(save_dir / 'training_history.png')
    plt.close()



In [2]:
# Main execution code

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Parameters
TRAIN_DAYS = 70
NUM_EPOCHS = 100
LEARNING_RATE = 0.001
BATCH_SIZE = 1000
DIFFUSION_COEFF = 1e-4
HIDDEN_LAYERS = 5
NEURONS_PER_LAYER = 32

# Create directories for saving results
MODEL_DIR = Path("models")
FIGURE_DIR = Path("figures")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
FIGURE_DIR.mkdir(parents=True, exist_ok=True)

# Load data
print("Loading dataset...")
ds = xr.open_dataset("output-tracer-release_2025-02-16.nc")

# Prepare train-test split
print("Preparing temporal train-test split...")
data = prepare_temporal_split_data(ds, train_days=TRAIN_DAYS)

# Print dataset information
print(f"\nDataset information:")
print(f"Training samples: {len(data['train']['x'])}")
print(f"Test samples: {len(data['test']['x'])}")

# Initialize model
print("\nInitializing PINN model...")
model = AdvectionDiffusionPINN(
    D=DIFFUSION_COEFF,
    hidden_layers=HIDDEN_LAYERS,
    neurons_per_layer=NEURONS_PER_LAYER
)

# Train model
print("\nStarting training...")
history = train_pinn(
    model=model,
    data=data['train'],
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    save_dir=str(MODEL_DIR)
)

# Plot training history
print("\nPlotting training history...")
plot_training_history(history, save_dir=str(FIGURE_DIR))

# Evaluate on test set
print("\nEvaluating model on test set...")
test_metrics = evaluate_model(model, data['test'])
print("\nTest Set Metrics:")
print(f"MSE: {test_metrics['mse']:.6f}")
print(f"RMSE: {test_metrics['rmse']:.6f}")
print(f"R²: {test_metrics['r2']:.6f}")
print(f"PDE Loss: {test_metrics['pde_loss']:.6f}")

# Plot results for multiple timesteps in test period
print("\nGenerating visualization plots...")
test_timesteps = [1800, 1900, 2000, 2100]  # Example timesteps in test period
for timestep in test_timesteps:
    print(f"Plotting timestep {timestep}...")
    plot_results(model, data['test'], ds, timestep, save_dir=str(FIGURE_DIR))

print("\nTraining and evaluation complete!")
print(f"Models saved in: {MODEL_DIR}")
print(f"Figures saved in: {FIGURE_DIR}")


Loading dataset...
Preparing temporal train-test split...

Subsampled grid sizes:
X points: 25 (original: 100)
Y points: 25 (original: 100)
Time points: 601 (original: 2401)

Data shapes:
Original u shape: (2401, 1, 100, 100)
Subsampled u shape: (601, 1, 25, 25)
Time mask shape: (601,)
X_train shape: (25, 25, 421)
Training data shape: (421, 1, 25, 25)
Number of training timesteps: 421
Number of test timesteps: 180

Final dataset sizes:
Training samples: 263125
Test samples: 112500

Dataset information:
Training samples: 263125
Test samples: 112500

Initializing PINN model...

Starting training...
Starting training with 263125 samples...
-----Epoch 0, Loss: 1229.073644, PDE Loss: 1229.073634, Data Loss: 0.000014, Time: 3.38s
-----Epoch 10, Loss: 1229.073634, PDE Loss: 1229.073634, Data Loss: 0.000001, Time: 43.28s
-----Epoch 20, Loss: 1229.073634, PDE Loss: 1229.073634, Data Loss: 0.000000, Time: 80.21s
-----Epoch 30, Loss: 1229.073634, PDE Loss: 1229.073634, Data Loss: 0.000001, Time: 

In [None]:
ds.u.values.shape  # Likely (2401, 1, 100, 100) or similar