### GraviGNN – Physics-Informed Gravity Inversion

* Loads gravity anomaly data and corresponding 3D density models
* Preprocesses and normalizes the dataset
* Defines a graph-enhanced GraviGNN architecture
* Performs 2D gravity → 3D density reconstruction
* Ensures physical consistency through forward modeling
* Evaluates performance using relative error and R² metrics


### Library Imports & Environment Setup

In [None]:

import numpy as np
import torch
import torch.nn as nn
from torchsummary import summary
from torch.autograd import Variable
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.colors as colors
from collections import Counter
from sklearn.metrics import f1_score
import torch.nn.functional as F
import time
import h5py
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
plt.rc('font',family='Times New Roman', size=12)

In [31]:
import torch

# Report whether a CUDA device is visible; later cells call .cuda() unconditionally.
torch.cuda.is_available()

True

### Hyperparameters & Physical Parameters

In [None]:
# --- Optimisation schedule ---
epochs = 130            # number of training epochs requested
batch_size = 16         # mini-batch size for every DataLoader
learning_rate = 4e-4    # Adam step size
patience = 20           # not referenced by the cells visible here
threshold = 1e-4        # not referenced by the cells visible here

# --- Dataset sizes ---
tra_num = 20000                     # samples taken from the front of the tra&val set
val_num = 2000                      # samples taken from the back of the tra&val set
total_num = tra_num + val_num       # files read from the tra&val folder
category = 6                        # number of synthetic test categories
part_num = 100                      # synthetic samples per category
syn_num = part_num * category       # files read from the synthetic folder
realdata_num = 1                    # not referenced by the cells visible here

# --- Grid / network width ---
num_cell = 32           # side length of the square horizontal grid
start_fm = 32           # not referenced by the cells visible here (the model is built with a literal 32)

# --- Physical parameters ---
density = 1000          # divisor used to normalise the density models on load


### Dataset File Paths

In [33]:
# Data folders
# Path templates: '{}' is filled with the sample index via str.format() in the loading loops.
dataFile = './data/tra&val/data{}.mat'  # training + validation samples
syn_dataFile = './data/syn/data{}.mat'  # synthetic test samples


### Data Loading & Preprocessing

In [34]:
def _load_samples(path_template, count):
    """Read `count` samples from HDF5-backed .mat files.

    Args:
        path_template: Path containing one '{}' placeholder for the sample index.
        count: Number of consecutive indices (0 .. count - 1) to read.

    Returns:
        Tuple ``(gravity, models)`` of two lists with one array per sample:
        gravity maps shaped (1, num_cell, num_cell) with NaNs replaced by
        np.nan_to_num, and density models shaped (16, num_cell, num_cell)
        divided by the global `density`.
    """
    gravity, models = [], []
    for idx in range(count):
        # Close every file as soon as its arrays are copied out; otherwise
        # thousands of HDF5 handles stay open for the life of the kernel.
        with h5py.File(path_template.format(idx), 'r') as data:
            m = data['m'][0] / density
            d = np.nan_to_num(data['d'][0])
        gravity.append(d.reshape(1, num_cell, num_cell))
        models.append(m.reshape(16, num_cell, num_cell))
    return gravity, models


x, y = _load_samples(dataFile, total_num)
syn_x, syn_y = _load_samples(syn_dataFile, syn_num)

# First tra_num samples train the network, the remainder validates it.
# Slicing from tra_num (rather than -val_num) stays correct when val_num == 0.
tra_x = x[:tra_num]
tra_y = y[:tra_num]
val_x = x[tra_num:]
val_y = y[tra_num:]

# Index lists keep the on-disk order (no shuffling is applied here).
tra_idxs = list(range(len(tra_x)))
val_idxs = list(range(len(val_x)))
syn_idxs = list(range(len(syn_x)))

In [35]:
class Dataset(torch.utils.data.Dataset):
    """Indexable collection of (image, mask) pairs.

    Args:
        images: Indexable container of input samples.
        train: When True, `masks` is stored and returned alongside each image.
            When False, no masks are stored and every item carries None.
        masks: Indexable container of targets aligned with `images`.
    """

    def __init__(self, images, train=True, masks=None):
        self.images = images
        self.train = train
        # Targets are only kept in training mode.
        if train:
            self.masks = masks

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for `idx`; mask is None when not in training mode."""
        sample = self.images[idx]
        if not self.train:
            return (sample, None)
        return (sample, self.masks[idx])


# Wrap each split as float32 arrays, ordered by its index list.
tra = Dataset(
    images=np.asarray(tra_x, dtype=np.float32)[tra_idxs],
    masks=np.asarray(tra_y, dtype=np.float32)[tra_idxs],
    train=True,
)
val = Dataset(
    images=np.asarray(val_x, dtype=np.float32)[val_idxs],
    masks=np.asarray(val_y, dtype=np.float32)[val_idxs],
    train=True,
)
syn = Dataset(
    images=np.asarray(syn_x, dtype=np.float32)[syn_idxs],
    masks=np.asarray(syn_y, dtype=np.float32)[syn_idxs],
    train=True,
)

# All three loaders share the same settings and iterate in stored order.
# NOTE(review): the training loader is not shuffled either - confirm this is intended.
_loader_kwargs = dict(batch_size=batch_size, shuffle=False, pin_memory=False)
tra_loader = torch.utils.data.DataLoader(dataset=tra, **_loader_kwargs)
val_loader = torch.utils.data.DataLoader(dataset=val, **_loader_kwargs)
syn_loader = torch.utils.data.DataLoader(dataset=syn, **_loader_kwargs)

### GraviGNN with Forward-fitting and Smoothness Regularizers.
* Implements **GraviGNN**
* Uses `double_conv` blocks for spatial feature extraction
* Applies **GraphConv with k-NN aggregation** for global feature learning
* Outputs a **16-layer 3D density model** from 2D gravity input
* Defines a **Physics-Informed Loss (PINN)** combining:

  * Dice loss (supervised reconstruction)
  * Physics/data fidelity loss using forward matrix **G**
  * Smoothness regularizer on the predicted density model
* Enforces gravity consistency: $d_{pred} = G\, m_{pred}$
* Trains using Adam optimizer with validation and synthetic testing
* Saves trained model weights after each epoch


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import h5py
import time

##############################################
# 1. Basic Modules & ViG Structure
##############################################

class double_conv(nn.Module):
    """Two stacked 2-D convolutions with an ELU between them.

    No activation is applied after the second convolution.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        kernel_size, stride, padding: Shared by both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_cfg = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_cfg)
        self.ELU = nn.ELU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_cfg)

    def forward(self, x):
        """Apply conv1 -> ELU -> conv2 and return the result."""
        hidden = self.conv1(x)
        hidden = self.ELU(hidden)
        return self.conv2(hidden)


class GraphConv(nn.Module):
    """Multi-head k-NN graph convolution with max aggregation.

    For every node the k nearest other nodes (Euclidean distance between
    feature vectors) are looked up, their features are max-pooled
    element-wise, and the pooled vector is concatenated with the node's own
    features. Each head applies its own linear map to that 2 * in_dim vector
    and the head outputs are concatenated to `out_dim` features.

    Args:
        in_dim: Feature size of the incoming nodes.
        out_dim: Feature size of the result; must be divisible by `num_heads`.
        num_heads: Number of independent linear heads.
        k: Number of neighbours per node (capped at N - 1 at run time).
    """

    def __init__(self, in_dim, out_dim, num_heads=4, k=8):
        super(GraphConv, self).__init__()
        self.num_heads = num_heads
        self.k = k
        assert out_dim % num_heads == 0
        self.head_dim = out_dim // num_heads
        self.linears = nn.ModuleList(
            [nn.Linear(2 * in_dim, self.head_dim) for _ in range(num_heads)]
        )

    def forward(self, x):
        """Aggregate neighbour features and project them.

        Args:
            x: Tensor unpacked as (B, N, D): batch, nodes, features.

        Returns:
            Tensor of shape (B, N, out_dim).
        """
        B, N, D = x.shape
        effective_k = min(self.k, N - 1)

        if effective_k < 1:
            # A node without any possible neighbour (e.g. N == 1) would make
            # the max over an empty neighbour axis fail; fall back to the
            # node's own features as its neighbourhood summary.
            agg = x
        else:
            # Neighbour indices are integers and carry no gradient, so the
            # search runs outside autograd. This also makes the in-place
            # masking of the distance matrix safe.
            with torch.no_grad():
                dist = torch.cdist(x, x, p=2)
                diag = torch.eye(N, device=x.device).bool().unsqueeze(0)
                dist.masked_fill_(diag, float('inf'))  # never pick the node itself
                knn_indices = torch.topk(-dist, k=effective_k, dim=-1).indices

            batch_indices = torch.arange(B, device=x.device).view(B, 1, 1).expand(B, N, effective_k)
            neighbors = x[batch_indices, knn_indices]  # (B, N, k, D)
            agg, _ = torch.max(neighbors, dim=2)

        concat_feat = torch.cat([x, agg], dim=-1)
        return torch.cat([linear(concat_feat) for linear in self.linears], dim=-1)


class ViGBlock(nn.Module):
    """Graph-convolution block followed by a feed-forward block, each with a residual and LayerNorm.

    Args:
        in_dim: Feature size of the incoming nodes.
        out_dim: Feature size after the graph convolution. The residual
            ``x + graph_features`` only lines up when this equals `in_dim`.
        num_heads: Heads passed to GraphConv.
        k: Neighbour count passed to GraphConv.
        ff_hidden_dim: Hidden width of the feed-forward part; a falsy value
            selects ``2 * in_dim``.
        dropout: Dropout probability used after the graph branch and after
            the feed-forward branch.
    """

    def __init__(self, in_dim, out_dim, num_heads=4, k=8, ff_hidden_dim=None, dropout=0.1):
        super().__init__()
        if not ff_hidden_dim:
            ff_hidden_dim = in_dim * 2

        # Graph branch: linear -> k-NN graph conv -> linear.
        self.proj_in = nn.Linear(in_dim, in_dim)
        self.graph_conv = GraphConv(in_dim, out_dim, num_heads, k)
        self.proj_out = nn.Linear(out_dim, out_dim)

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(out_dim)

        # Position-wise feed-forward branch.
        self.ffn = nn.Sequential(
            nn.Linear(out_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, out_dim),
        )

        self.norm2 = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Return the block output for node features `x` (last axis = features)."""
        graph_feat = self.graph_conv(self.proj_in(x))
        graph_feat = self.proj_out(graph_feat)
        graph_feat = self.dropout(self.activation(graph_feat))
        hidden = self.norm1(x + graph_feat)

        ffn_out = self.dropout(self.ffn(hidden))
        return self.norm2(hidden + ffn_out)


class ViGUNet(nn.Module):
    """U-Net whose every stage is a double_conv followed by a ViGBlock.

    The network takes a single-channel map, goes through four encoder stages
    (each followed by 2x max-pooling), a bottleneck, and four decoder stages
    that upsample with transposed convolutions and concatenate the matching
    encoder output. A 1x1 convolution, BatchNorm and Sigmoid produce 16
    output channels.

    Because of the four 2x poolings, the spatial size should be divisible by
    16 so that the skip connections line up.

    Args:
        start_fm: Channel width of the first stage; it doubles at each level.
        num_heads: Heads used by every ViGBlock.
        dropout: Dropout probability used by every ViGBlock.
        k: Neighbour count used by every ViGBlock.
    """

    def __init__(self, start_fm=32, num_heads=4, dropout=0.1, k=8):
        super().__init__()
        w1, w2, w4, w8, w16 = (start_fm * scale for scale in (1, 2, 4, 8, 16))

        def make_vig(width):
            # Every graph block keeps its channel width unchanged.
            return ViGBlock(width, width, num_heads, k, dropout=dropout)

        # ---- Encoder ----
        self.enc1_conv = double_conv(1, w1)
        self.enc1_vig = make_vig(w1)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2_conv = double_conv(w1, w2)
        self.enc2_vig = make_vig(w2)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3_conv = double_conv(w2, w4)
        self.enc3_vig = make_vig(w4)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4_conv = double_conv(w4, w8)
        self.enc4_vig = make_vig(w8)
        self.pool4 = nn.MaxPool2d(2)

        # ---- Bottleneck ----
        self.bottleneck_conv = double_conv(w8, w16)
        self.bottleneck_vig = make_vig(w16)

        # ---- Decoder (conv input is upsampled features + skip, hence 2x width) ----
        self.up4 = nn.ConvTranspose2d(w16, w8, kernel_size=2, stride=2)
        self.dec4_conv = double_conv(w16, w8)
        self.dec4_vig = make_vig(w8)

        self.up3 = nn.ConvTranspose2d(w8, w4, kernel_size=2, stride=2)
        self.dec3_conv = double_conv(w8, w4)
        self.dec3_vig = make_vig(w4)

        self.up2 = nn.ConvTranspose2d(w4, w2, kernel_size=2, stride=2)
        self.dec2_conv = double_conv(w4, w2)
        self.dec2_vig = make_vig(w2)

        self.up1 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.dec1_conv = double_conv(w2, w1)
        self.dec1_vig = make_vig(w1)

        # ---- Output head ----
        self.final_conv = nn.Conv2d(w1, 16, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(16)
        self.final_act = nn.Sigmoid()

    def _apply_vig(self, x, vig_module):
        """Run `vig_module` on a (B, C, H, W) map by treating each pixel as a graph node."""
        B, C, H, W = x.shape
        nodes = x.reshape(B, C, H * W).transpose(1, 2)  # (B, H*W, C)
        nodes = vig_module(nodes)
        return nodes.transpose(1, 2).reshape(B, C, H, W)

    def forward(self, inputs):
        """Map a (B, 1, H, W) input to a (B, 16, H, W) output in (0, 1)."""
        # Encoder: keep each stage's output for the skip connections.
        skip1 = self._apply_vig(self.enc1_conv(inputs), self.enc1_vig)
        skip2 = self._apply_vig(self.enc2_conv(self.pool1(skip1)), self.enc2_vig)
        skip3 = self._apply_vig(self.enc3_conv(self.pool2(skip2)), self.enc3_vig)
        skip4 = self._apply_vig(self.enc4_conv(self.pool3(skip3)), self.enc4_vig)

        feat = self._apply_vig(self.bottleneck_conv(self.pool4(skip4)), self.bottleneck_vig)

        # Decoder: upsample, concatenate the skip on the channel axis, refine.
        feat = torch.cat([self.up4(feat), skip4], 1)
        feat = self._apply_vig(self.dec4_conv(feat), self.dec4_vig)

        feat = torch.cat([self.up3(feat), skip3], 1)
        feat = self._apply_vig(self.dec3_conv(feat), self.dec3_vig)

        feat = torch.cat([self.up2(feat), skip2], 1)
        feat = self._apply_vig(self.dec2_conv(feat), self.dec2_vig)

        feat = torch.cat([self.up1(feat), skip1], 1)
        feat = self._apply_vig(self.dec1_conv(feat), self.dec1_vig)

        logits = self.final_bn(self.final_conv(feat))
        return self.final_act(logits)

##############################################
# 2. Updated PINN Loss (Dice + Physics + Smoothness)
##############################################

def dice_func(pred, target):
    """Return the batch-mean soft Dice score between `pred` and `target`.

    Each sample is flattened; the score is
    ``(2 * sum(p * t) + 1) / (sum(p^2) + sum(t^2) + 1)``,
    where the additive 1 is a smoothing constant.
    """
    smooth = 1.0
    batch = pred.size(0)
    flat_pred = pred.reshape(batch, -1)
    flat_target = target.reshape(batch, -1)

    overlap = (flat_pred * flat_target).sum(dim=1)
    denominator = flat_pred.pow(2).sum(dim=1) + flat_target.pow(2).sum(dim=1) + smooth
    per_sample = (2. * overlap + smooth) / denominator
    return per_sample.mean()


def smoothness_loss(m):
    """Return the mean squared first difference of `m` along axes 3 and 2, summed.

    Only the last two axes of a 4-D tensor are differenced; axis 1 is not
    regularised.
    """
    diff_axis3 = m.diff(dim=3)
    diff_axis2 = m.diff(dim=2)
    return diff_axis3.pow(2).mean() + diff_axis2.pow(2).mean()


def total_pinn_loss(outputs, masks, images, G_matrix,
                    lambda_physics=0.01,
                    lambda_smooth=0.001):
    """Combine the supervised, physics and smoothness terms into one scalar loss.

    Args:
        outputs: Predicted models, first axis = batch.
        masks: Target models, same layout as `outputs`.
        images: Observed data, first axis = batch.
        G_matrix: Forward operator; the flattened prediction is multiplied by
            its transpose, so it must have one column per model cell.
        lambda_physics: Weight of the data-misfit (MSE) term.
        lambda_smooth: Weight of the smoothness term.

    Returns:
        ``(1 - dice) + lambda_physics * mse + lambda_smooth * smoothness``.
    """
    n_samples = outputs.size(0)

    # Supervised reconstruction term.
    dice_term = 1 - dice_func(outputs, masks)

    # Physics term: forward-model the prediction and compare with the observations.
    # NOTE(review): compute_metrics forward-models with G @ (density * m) while
    # this term uses G @ m on the normalised model - confirm which scaling is intended.
    predicted_data = outputs.reshape(n_samples, -1) @ G_matrix.t()
    observed_data = images.reshape(n_samples, -1)
    physics_term = F.mse_loss(predicted_data, observed_data)

    smooth_term = smoothness_loss(outputs)

    return dice_term + lambda_physics * physics_term + lambda_smooth * smooth_term


##############################################
# 3. Training Setup
##############################################

# Loss weights handed to total_pinn_loss (they override its smaller defaults).
lambda_phys, lambda_smooth = 0.1, 0.001

# Forward operator: read from disk, replace NaNs, transpose the stored array
# and move it to the GPU. The file is closed as soon as the data is copied.
with h5py.File('./G.mat', 'r') as f:
    G = torch.Tensor(np.nan_to_num(f['G'][:])).T.cuda()

# Network and optimiser.
model = ViGUNet(start_fm=32, num_heads=4, dropout=0.3, k=8)
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Per-epoch mean losses for the three data splits.
mean_tra_losses = []
mean_val_losses = []
mean_syn_losses = []

# Containers for evaluation outputs.
val_data = []
syn_data = []

epoch = 0
start_time = time.time()

def _evaluate_losses(loader):
    """Return the per-batch total_pinn_loss values of `model` over `loader`.

    Runs without gradient tracking; the caller is responsible for putting the
    model in eval mode first.
    """
    batch_losses = []
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.cuda(), masks.cuda()
            outputs = model(images)
            batch_loss = total_pinn_loss(
                outputs, masks, images, G,
                lambda_physics=lambda_phys,
                lambda_smooth=lambda_smooth
            )
            batch_losses.append(batch_loss.item())
    return batch_losses


# Epochs are numbered 1..epochs so exactly `epochs` passes are run
# (a `while epoch <= epochs` loop starting at 0 would run one pass too many).
for epoch in range(1, epochs + 1):
    # ---- Training pass ----
    tra_losses = []
    model.train()
    for images, masks in tra_loader:
        images, masks = images.cuda(), masks.cuda()
        optimizer.zero_grad()
        outputs = model(images)

        loss = total_pinn_loss(
            outputs, masks, images, G,
            lambda_physics=lambda_phys,
            lambda_smooth=lambda_smooth
        )

        loss.backward()
        optimizer.step()
        tra_losses.append(loss.item())

    # ---- Validation and synthetic-test passes ----
    model.eval()
    val_losses = _evaluate_losses(val_loader)
    syn_losses = _evaluate_losses(syn_loader)

    m_tra = np.mean(tra_losses)
    m_val = np.mean(val_losses)
    m_syn = np.mean(syn_losses)

    mean_tra_losses.append(m_tra)
    mean_val_losses.append(m_val)
    mean_syn_losses.append(m_syn)

    # The checkpoint is overwritten every epoch, so the file holds the latest weights.
    torch.save(model.state_dict(), 'pinn_vignn_phylossgeoloss.pth')

    print(f'Epoch: {epoch:03d} | Tra: {m_tra:.4f} | Val: {m_val:.4f} | Syn: {m_syn:.4f}')

print(f"Training completed in {time.time() - start_time:.2f} seconds.")



Epoch: 001 | Tra: 0.8416 | Val: 0.6685 | Syn: 0.6291
Epoch: 002 | Tra: 0.7146 | Val: 0.6211 | Syn: 0.6154
Epoch: 003 | Tra: 0.5638 | Val: 0.5133 | Syn: 0.4995
Epoch: 004 | Tra: 0.4595 | Val: 0.4222 | Syn: 0.4023
Epoch: 005 | Tra: 0.3976 | Val: 0.3683 | Syn: 0.3407
Epoch: 006 | Tra: 0.3622 | Val: 0.3438 | Syn: 0.3091
Epoch: 007 | Tra: 0.3406 | Val: 0.3250 | Syn: 0.2857
Epoch: 008 | Tra: 0.3258 | Val: 0.3123 | Syn: 0.2607
Epoch: 009 | Tra: 0.3157 | Val: 0.3012 | Syn: 0.2492
Epoch: 010 | Tra: 0.3068 | Val: 0.2982 | Syn: 0.2443
Epoch: 011 | Tra: 0.2996 | Val: 0.2874 | Syn: 0.2331
Epoch: 012 | Tra: 0.2938 | Val: 0.2905 | Syn: 0.2287
Epoch: 013 | Tra: 0.2886 | Val: 0.2793 | Syn: 0.2220
Epoch: 014 | Tra: 0.2842 | Val: 0.2757 | Syn: 0.2176
Epoch: 015 | Tra: 0.2804 | Val: 0.2745 | Syn: 0.2140
Epoch: 016 | Tra: 0.2773 | Val: 0.2712 | Syn: 0.2081
Epoch: 017 | Tra: 0.2732 | Val: 0.2715 | Syn: 0.2077
Epoch: 018 | Tra: 0.2693 | Val: 0.2673 | Syn: 0.2138
Epoch: 019 | Tra: 0.2620 | Val: 0.2616 | Syn: 

In [None]:
# Checkpoint loaded by the evaluation cell below.
# NOTE(review): the training loop in this notebook saves to
# 'pinn_vignn_phylossgeoloss.pth' - confirm this different file is the intended checkpoint.
model_File = './pinn_vignn_reg.pth'

### Model Evaluation

* Loads trained model weights using `load_state_dict()`
* Sets model to evaluation mode using `model.eval()`
* Disables gradient computation with `torch.no_grad()`
* Performs forward pass on **validation dataset**
* Performs forward pass on **synthetic dataset**
* Stores predicted outputs, ground truth masks, and input images for further analysis


In [38]:
# Load the model for evaluation
val_data = []
syn_data = []

# strict=False tolerates key mismatches; report them instead of silently
# evaluating a model whose layers were left at their current values.
load_result = model.load_state_dict(torch.load(model_File), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Warning: checkpoint {model_File} does not fully match the model "
          f"(missing keys: {list(load_result.missing_keys)}, "
          f"unexpected keys: {list(load_result.unexpected_keys)})")

model.eval()
with torch.no_grad():
    # Each entry is [predictions, ground-truth models, input gravity] for one batch.
    for images, masks in val_loader:
        images, masks = images.cuda(), masks.cuda()
        outputs = model(images)
        val_data.append([outputs, masks, images])

    for images, masks in syn_loader:
        images, masks = images.cuda(), masks.cuda()
        outputs = model(images)
        syn_data.append([outputs, masks, images])

### Post-Processing of Evaluation Results

* Extracts predictions, ground truth density, and input gravity data from `val_data`

* Reshapes 3D density volumes (16×32×32) into 1D vectors (16384 elements)

* Reshapes 2D gravity images (32×32) into 1D vectors (1024 elements)

* Converts tensors from GPU to CPU and NumPy format

* Stores processed validation results in:

  * `val_truth`
  * `val_predict`
  * `val_truth_d`

* Repeats the same processing for synthetic data

* Stores processed synthetic results in:

  * `syn_truth`
  * `syn_predict`
  * `syn_truth_d`

This prepares the outputs for metric computation and quantitative evaluation.


In [39]:
def _flatten_samples(batches, field):
    """Return one flat NumPy vector per sample.

    Args:
        batches: List of ``[outputs, masks, images]`` entries, one per batch.
        field: Position inside each entry (0 = outputs, 1 = masks, 2 = images).

    Returns:
        List of 1-D NumPy arrays in batch order, then sample order. The length
        of each vector follows the sample's own shape, so no grid size is
        hard-coded here.
    """
    return [
        sample.detach().reshape(-1).cpu().numpy()
        for batch in batches
        for sample in batch[field]
    ]


# Validation split: ground-truth models, predicted models, observed gravity.
val_truth = _flatten_samples(val_data, 1)
val_predict = _flatten_samples(val_data, 0)
val_truth_d = _flatten_samples(val_data, 2)

# Synthetic split: same three quantities.
syn_truth = _flatten_samples(syn_data, 1)
syn_predict = _flatten_samples(syn_data, 0)
syn_truth_d = _flatten_samples(syn_data, 2)

In [40]:
# Sanity check: print the shape of the first flattened sample of every list,
# validation lists first, then the synthetic ones.
for _samples in (val_truth, val_predict, val_truth_d,
                 syn_truth, syn_predict, syn_truth_d):
    print(_samples[0].shape)

(16384,)
(16384,)
(1024,)
(16384,)
(16384,)
(1024,)


### Performance Evaluation & Metrics Computation

* Converts forward matrix **G** to NumPy format for metric calculation
* Defines `compute_metrics()` to evaluate model performance

#### 🔹 For Each Sample:

* Flattens predicted and true density models
* Computes **Relative Error (L2 norm)** between prediction and ground truth
* Reconstructs gravity data using forward modeling:

  $$
  d_{pred} = G \cdot (\rho \cdot m_{pred})
  $$
* Computes **R² score** to measure physical data fidelity

---

#### 🔹 Validation Evaluation

* Computes mean Relative Error and R² for validation dataset
* Stores results in `E` and `R_2`

---

#### 🔹 Synthetic Evaluation

* Computes metrics for full synthetic dataset
* Segments results into predefined geological categories
* Calculates mean metrics for each category

---

#### 🔹 Final Output

* Displays tabulated results including:

  * Validation performance
  * Category-wise synthetic performance
* Reports:

  * Relative Error (E)
  * R² Score (physical consistency)


In [None]:
import numpy as np
import torch

# Mean results: index 0 is the validation set, later entries are appended per category.
E, R_2 = [], []

# compute_metrics works on NumPy arrays, so bring G off the GPU if it is a tensor.
if torch.is_tensor(G):
    G_arr = G.cpu().numpy()
else:
    G_arr = np.array(G)

def compute_metrics(predict_list, truth_list, truth_d_list, G_mat, den):
    """Compute per-sample model error and data-fit R^2.

    Args:
        predict_list: Predicted models (tensors or array-likes), one per sample.
        truth_list: Ground-truth models aligned with `predict_list`.
        truth_d_list: Observed data aligned with `predict_list`.
        G_mat: Forward operator applied as ``G_mat @ (den * model)``.
        den: Scalar that rescales the predicted model before forward modelling.

    Returns:
        Tuple ``(rel_list, r2_list)``: for each sample, the relative L2 error
        of the model and the R^2 of the forward-modelled data against the
        observed data. Both denominators carry a 1e-8 guard.
    """
    def _flat(value):
        # Accept torch tensors (any device) as well as plain array-likes.
        if torch.is_tensor(value):
            return value.detach().cpu().numpy().flatten()
        return np.array(value).flatten()

    rel_list, r2_list = [], []

    for idx, prediction in enumerate(predict_list):
        m_pre = _flat(prediction)
        m_tru = _flat(truth_list[idx])
        tru_d = _flat(truth_d_list[idx])

        # 1. Relative error of the model (L2 norm).
        misfit_norm = np.linalg.norm(m_pre - m_tru)
        rel_list.append(misfit_norm / (np.linalg.norm(m_tru) + 1e-8))

        # 2. R^2 of the reconstructed data: d_pred = G * (density * m_pred).
        pre_d = G_mat @ (den * m_pre)
        ss_res = np.sum((tru_d - pre_d) ** 2)
        ss_tot = np.sum((tru_d - np.mean(tru_d)) ** 2)
        r2_list.append(1 - (ss_res / (ss_tot + 1e-8)))

    return rel_list, r2_list

# --- 1. Validation set: one mean per metric, stored at index 0 ---
val_rel, val_R_2 = compute_metrics(val_predict, val_truth, val_truth_d, G_arr, density)
E.append(np.mean(val_rel))
R_2.append(np.mean(val_R_2))

# --- 2. Synthetic set: metrics for every sample ---
syn_rel, syn_R_2 = compute_metrics(syn_predict, syn_truth, syn_truth_d, G_arr, density)

# --- 3. Per-category means ---
# Samples are taken in consecutive runs of `part_num`, one run per category.
for cat_idx in range(category):
    chunk = slice(cat_idx * part_num, (cat_idx + 1) * part_num)
    E.append(np.mean(syn_rel[chunk]))
    R_2.append(np.mean(syn_R_2[chunk]))

# --- 4. Results table ---
header = f"{'Data Group':<15} | {'Rel. Error (E)':<15} | {'R2 Score':<10}"
print(header)
print("-" * 50)
print(f"{'Validation':<15} | {E[0]:<15.4f} | {R_2[0]:<10.4f}")
for cat_no in range(1, category + 1):
    err, r2 = E[cat_no], R_2[cat_no]
    print(f"{'Category ' + str(cat_no):<15} | {err:<15.4f} | {r2:<10.4f}")