In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import get_laplacian
from torch_scatter import scatter_add
from torch_geometric.data import Data

In [2]:

# Define the Physics-Informed Graph Convolutional Network (GCN)
class PhysicsInformedGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super(PhysicsInformedGCN, self).__init__()
        self.convs = nn.ModuleList()
        # First layer: input to hidden
        self.convs.append(GCNConv(in_channels, hidden_channels))
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        # Final layer: hidden to output (e.g., predicted field u)
        self.convs.append(GCNConv(hidden_channels, out_channels))
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
        x = self.convs[-1](x, edge_index)
        return x

In [3]:
# Define a physics-informed loss function for a static PDE, e.g., Laplace's equation Δu = 0.
def physics_loss(data, predictions):
    """
    Compute the PDE residual loss based on the graph Laplacian.
    We assume that our PDE is Δu = 0, where the Laplacian is approximated on the graph.
    """
    num_nodes = data.num_nodes
    # Compute the graph Laplacian (symmetric normalization)
    lap_edge_index, lap_edge_weight = get_laplacian(data.edge_index, normalization='sym', num_nodes=num_nodes)
    row, col = lap_edge_index  # row: target nodes, col: source nodes
    
    # Compute the differences along the edges: u_j - u_i for each edge (i,j)
    diff = predictions[col] - predictions[row]
    # Weight differences by the Laplacian weights
    diff_weighted = lap_edge_weight.unsqueeze(1) * diff
    # Aggregate differences for each node using scatter
    laplacian_pred = scatter_add(diff_weighted, row, dim=0, dim_size=num_nodes)
    
    # The PDE residual at each node should be close to zero; use MSE loss
    loss_phys = torch.mean(laplacian_pred**2)
    return loss_phys

In [4]:

# For demonstration, create a dummy static graph with 100 nodes.
num_nodes = 100
# Example: use 2D coordinates as features (you can add more features if needed)
x = torch.randn((num_nodes, 2))

# Create a simple connectivity (e.g., a k-nearest neighbor graph or any mesh-based connectivity)
# Here, we create a dummy edge_index tensor (shape [2, num_edges])
# Note: In practice, edge_index should reflect the connectivity of your mesh.
edge_index = torch.randint(0, num_nodes, (2, 500))

# Create a PyG Data object
data = Data(x=x, edge_index=edge_index, num_nodes=num_nodes)

# Initialize the network and optimizer
model = PhysicsInformedGCN(in_channels=2, hidden_channels=32, out_channels=1, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Dummy ground-truth (if available) for a data-driven loss (here, we use zeros for demonstration)
# In a real scenario, you might have sparse measurements or boundary condition targets.
# target = torch.zeros((num_nodes, 1))

# Training loop (static case)
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    
    # Forward pass: predict the field u at each node
    pred = model(data)
    
    # Data-driven loss (MSE between prediction and target)
    # loss_data = F.mse_loss(pred, target)
    # Physics-informed loss enforcing the PDE Δu = 0
    loss_phys = physics_loss(data, pred)
    
    # Total loss: you may balance these terms with a weighting parameter (lambda)
    lambda_phys = 1.0  # Adjust this weight as needed
    loss = lambda_phys * loss_phys
    
    loss.backward()
    optimizer.step()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch:3d}, Loss: {loss.item():.6f}, Physics Loss: {loss_phys.item():.6f}")


Epoch   0, Loss: 0.000112, Physics Loss: 0.000112
Epoch  20, Loss: 0.000007, Physics Loss: 0.000007
Epoch  40, Loss: 0.000001, Physics Loss: 0.000001
Epoch  60, Loss: 0.000000, Physics Loss: 0.000000
Epoch  80, Loss: 0.000000, Physics Loss: 0.000000
Epoch 100, Loss: 0.000000, Physics Loss: 0.000000
Epoch 120, Loss: 0.000000, Physics Loss: 0.000000
Epoch 140, Loss: 0.000000, Physics Loss: 0.000000
Epoch 160, Loss: 0.000000, Physics Loss: 0.000000
Epoch 180, Loss: 0.000000, Physics Loss: 0.000000
