In [1]:
%matplotlib notebook
import numpy as np
import torch
import torch.nn as nn
from torchsummary import summary
from torch.autograd import Variable
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.colors as colors
from collections import Counter
from sklearn.metrics import f1_score
import torch.nn.functional as F
import time
import h5py
import seaborn as sns
plt.rc('font',family='Times New Roman', size=12)

The prescribed density value was $1.1\ \mathrm{g/cm^3}$ (`density = 1100`, i.e. $1100\ \mathrm{kg/m^3}$, in the configuration cell below).

To reduce the influence of chance on the inversion results, we slide a 32 × 32 window over the gravity anomaly map (43 × 36 grid points) from left to right and from top to bottom with a step length of 50 m (one grid cell). This yields 12 × 5 = 60 overlapping gravity anomaly maps of size 32 × 32.

In [2]:
# ---- Reproducibility --------------------------------------------------------
# Noise augmentation, index shuffling and weight initialisation all draw from
# the global NumPy / PyTorch generators, so seed both once before anything runs.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# ---- Training schedule ------------------------------------------------------
patience = 20        # NOTE(review): not referenced by any cell visible in this notebook
epochs = 160         # number of training epochs
tra_num = 20000      # synthetic samples used for training (first tra_num files)
val_num = 2000       # synthetic samples used for validation (last val_num files)
batch_size = 20

# ---- Physical / grid parameters ---------------------------------------------
density = 1100       # models read from disk are divided by this value
num_cell = 32        # side length (in cells) of one input map / one model layer
num_obs = 32         # observation points per side, used when stitching real-data patches
length = 50          # grid spacing in metres
real_num = 60        # number of 32 x 32 slices cut from the field data
start_fm = 64        # NOTE(review): the model cell below passes start_fm=32 explicitly
noise_levels = [0.02, 0.04]   # noise std as a fraction of each map's maximum
xrange = [-2700, -600]        # easting extent of the field data (m)
yrange = [-1150, 600]         # northing extent of the field data (m)
x_num = int((xrange[1]-xrange[0])/length + 1)  # 43
y_num = int((yrange[1]-yrange[0])/length + 1)  # 36
k=8                  # NOTE(review): the model cell below passes k=8 as a literal

learning_rate = 4e-4
threshold = 1e-4     # NOTE(review): not referenced by any cell visible in this notebook
total_num = tra_num + val_num

# ---- Files --------------------------------------------------------------------
dataFile = './tra&val/data{}.mat'
real_dataFile = './real/data_slice{}.mat'
# Kept open on purpose: the next cell reads realdata_['d'] and pointdata lazily.
realdata_ = h5py.File('./real/data.mat','r')
pointdata = h5py.File('./real/pointdata.mat', 'r')['pointdata']
model_File = './epoch_model_vignn1.pth'   # reassigned before the inference cells

In [3]:
# Observed anomaly on the (x_num, y_num) grid: NaNs become 0, non-positive values
# are clipped to 0, and the map is transposed so rows follow northing.
realdata = np.reshape(np.nan_to_num(realdata_['d'][0]), (x_num, y_num))
realdata = np.where(realdata>0, realdata, 0).T
fig = plt.figure(figsize=(6.3, 5.4))
color = ('#00008F', '#0030FF', '#10FFEF', '#DFFF20', '#FF5000', '#800000')
my_levels = np.linspace(0., 2.3, 7)
cmap = colors.ListedColormap(color)
cf = plt.contourf(realdata, my_levels, cmap=cmap)
plt.contour(cf, my_levels, colors='k', linewidths=0.7)

# Read both coordinate rows from the HDF5 dataset ONCE. Indexing the dataset
# inside a per-point loop re-reads the whole row from disk on every iteration.
point_x = pointdata[0, :]
point_y = pointdata[1, :]
inside = (point_x > -2650) & (point_x < -601) & (point_y > -1150) & (point_y < 500)
# Convert metres to grid-cell coordinates so the points overlay the contour map.
x = (point_x[inside]-xrange[0])/length
y = (point_y[inside]-yrange[0])/length
plt.scatter(x, y, s=4, c = 'k')
# Traces of the two cross-sections plotted at the end of the notebook
# (column 20 <-> Easting = -1700 m, row 15 <-> Northing = -400 m).
plt.plot([20, 20], [0, 35])
plt.plot([0, 42], [15, 15])
# Label every 8th grid line with its coordinate in metres; the rest stay blank.
plt.xticks(np.arange(0, x_num), [str(xrange[0] + length*t) if t % 8 == 0 else ' ' for t in range(x_num)])
plt.xlabel('Easting (m)')
plt.yticks(np.arange(0, y_num), [str(yrange[0] + length*t) if t % 8 == 0 else ' ' for t in range(y_num)])
plt.ylabel('Northing (m)')
plt.tight_layout()
plt.tick_params(bottom=False, top=False, left=False, right=False)

plt.savefig("aa.png", dpi=300, bbox_inches='tight')
plt.close()

<IPython.core.display.Javascript object>

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Index-aligned (image, mask) pairs for use with a ``DataLoader``.

    Args:
        images: indexable collection of input samples.
        train: when True every item carries its mask; when False the mask
            slot of each returned item is ``None``.
        masks: indexable collection of targets aligned with ``images``.
            Required when ``train`` is True.

    Raises:
        ValueError: if ``train`` is True but no ``masks`` were supplied.
    """
    def __init__(self,images,train=True, masks=None):
        # Fail at construction time instead of with an opaque
        # "'NoneType' object is not subscriptable" on the first __getitem__.
        if train and masks is None:
            raise ValueError("masks must be provided when train=True")
        self.train = train
        self.images = images
        # Always define the attribute so it can be inspected in either mode.
        self.masks = masks if train else None

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx] if self.train else None
        return (image, mask)
    
class double_conv(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ELU stages.

    NOTE: a later cell of this notebook defines another class with the same
    name (without batch normalisation); once that cell runs it replaces this one.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for channels_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(channels_in, out_channels, **conv_kwargs),
                nn.BatchNorm2d(out_channels),
                nn.ELU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)



In [6]:
#############################
# Define basic modules
#############################

# Standard double convolution block (used for local feature extraction)
class double_conv(nn.Module):
    """Conv2d -> ELU -> Conv2d block (no activation after the second convolution)."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding)
        self.ELU = nn.ELU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.ELU(hidden)
        return self.conv2(hidden)

##############################################
# Graph Convolution with Multi-Head Updates
##############################################

class GraphConv(nn.Module):
    """
    Graph convolution over the nodes of each sample. It:
      1. Finds the K nearest neighbours of every node (Euclidean distance in
         feature space, the node itself excluded).
      2. Aggregates neighbour features via an elementwise maximum.
      3. Concatenates the original node feature with the aggregated neighbour feature.
      4. Applies multi-head linear updates (one Linear per head, outputs concatenated).

    Args:
        in_dim: feature size D of the incoming nodes.
        out_dim: feature size of the output; must be divisible by ``num_heads``.
        num_heads: number of independent linear heads.
        k: number of neighbours; capped at N - 1 for a graph with N nodes.

    Input ``x`` has shape (B, N, D); the output has shape (B, N, out_dim).
    A graph without any usable neighbour (N == 1, or k < 1) aggregates each
    node's own feature instead of failing on an empty neighbour set.
    """
    def __init__(self, in_dim, out_dim, num_heads=4, k=8):
        super(GraphConv, self).__init__()
        assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.k = k
        self.head_dim = out_dim // num_heads
        # Create a separate linear layer for each head.
        self.linears = nn.ModuleList([nn.Linear(2 * in_dim, self.head_dim) for _ in range(num_heads)])

    def forward(self, x):
        # x: (B, N, D)
        B, N, D = x.shape

        # Never ask for more neighbours than there are other nodes.
        effective_k = min(self.k, N - 1)

        if effective_k < 1:
            # No neighbour available: fall back to the node's own feature so the
            # concatenated input keeps its (B, N, 2*D) shape.
            agg = x
        else:
            # Neighbour selection only yields integer indices, so keep it out of
            # the autograd graph; gradients still flow through the gather below.
            with torch.no_grad():
                dist = torch.cdist(x, x, p=2)  # (B, N, N)
                # Mask self-distance so a node is never its own neighbour.
                diag = torch.eye(N, device=x.device).bool().unsqueeze(0)
                dist.masked_fill_(diag, float('inf'))
                # Smallest distances == largest negated distances.
                knn_indices = torch.topk(-dist, k=effective_k, dim=-1).indices  # (B, N, effective_k)

            # Gather neighbour features using advanced indexing.
            batch_indices = torch.arange(B, device=x.device).view(B, 1, 1).expand(B, N, effective_k)
            neighbors = x[batch_indices, knn_indices]  # (B, N, effective_k, D)

            # Elementwise maximum across the neighbour dimension.
            agg, _ = torch.max(neighbors, dim=2)  # (B, N, D)

        # Concatenate the original feature with the aggregated neighbour feature.
        concat_feat = torch.cat([x, agg], dim=-1)  # (B, N, 2*D)

        # Multi-head update: one projection per head, concatenated on the feature axis.
        out = torch.cat([linear(concat_feat) for linear in self.linears], dim=-1)  # (B, N, out_dim)
        return out

##############################################
# ViG Block: Graph-level processing with enhancement
##############################################

class ViGBlock(nn.Module):
    """
    ViG block that applies:
      - A graph convolution wrapped in pre- and post- linear projections,
        followed by ReLU and dropout.
      - A feed-forward network (FFN) to further refine node features.

    The block computes:
         Y = LayerNorm(Dropout(ReLU(GraphConv(X W_in) W_out)) + S(X))
         Z = LayerNorm(Dropout(FFN(Y)) + Y)
    where S is the identity when ``in_dim == out_dim`` and a learned linear
    projection otherwise (so the residual sum is well defined for any pair of
    dimensions).

    Args:
        in_dim: feature size of the incoming nodes.
        out_dim: feature size of the outgoing nodes.
        num_heads: heads used by the graph convolution.
        k: neighbours used by the graph convolution.
        ff_hidden_dim: hidden width of the FFN; defaults to ``2 * in_dim``.
        dropout: dropout probability applied after both sub-layers.

    Input ``x`` has shape (B, N, in_dim); the output has shape (B, N, out_dim).
    """
    def __init__(self, in_dim, out_dim, num_heads=4, k=8, ff_hidden_dim=None, dropout=0.1):
        super(ViGBlock, self).__init__()
        ff_hidden_dim = ff_hidden_dim or in_dim * 2
        self.proj_in = nn.Linear(in_dim, in_dim)
        self.graph_conv = GraphConv(in_dim, out_dim, num_heads=num_heads, k=k)
        self.proj_out = nn.Linear(out_dim, out_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(out_dim)
        # Feed-forward network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(out_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, out_dim)
        )
        self.norm2 = nn.LayerNorm(out_dim)
        # Shortcut for the first residual. Created last, and only when needed, so
        # that blocks with in_dim == out_dim keep exactly the same parameters
        # (and the same state_dict keys) as before.
        self.shortcut = None if in_dim == out_dim else nn.Linear(in_dim, out_dim)

    def forward(self, x):
        # x: (B, N, in_dim)
        x_proj = self.proj_in(x)  # (B, N, in_dim)
        gc = self.graph_conv(x_proj)  # graph convolution: (B, N, out_dim)
        gc = self.proj_out(gc)
        gc = self.activation(gc)
        gc = self.dropout(gc)
        residual = x if self.shortcut is None else self.shortcut(x)
        y = residual + gc  # residual connection
        y = self.norm1(y)
        # FFN to further refine features
        ffn_out = self.ffn(y)
        ffn_out = self.dropout(ffn_out)
        out = y + ffn_out  # residual connection
        out = self.norm2(out)
        return out

##############################################
# ViG-based UNet Architecture (Pyramid Architecture)
##############################################

class ViGUNet(nn.Module):
    """
    UNet-style encoder/decoder in which every stage is a convolution block
    followed by a ViG (graph) block.

    Flow of one stage:
       1. ``double_conv`` extracts local features, giving a (B, C, H, W) map.
       2. The map is flattened so that every spatial location becomes one graph
          node with a C-dimensional feature: (B, H*W, C).
       3. ``ViGBlock`` connects each node to its K nearest neighbours in feature
          space, aggregates them and refines the result with an FFN.
       4. The nodes are folded back into a (B, C, H, W) map.

    Four encoder stages (each followed by 2x2 max pooling), a bottleneck and four
    decoder stages (transposed-conv upsampling plus skip concatenation) form the
    pyramid. A 1x1 convolution, batch norm and sigmoid produce a 16-channel
    output at the input resolution.

    Args:
        start_fm: channels of the first stage; doubled at every encoder stage.
        num_heads: heads of every graph convolution.
        ff_hidden_dim: FFN hidden width forwarded to every ViG block
            (``None`` lets each block use twice its own width).
        dropout: dropout probability inside every ViG block.
        k: neighbours used by every graph convolution.

    Input: (B, 1, H, W) with H and W divisible by 16. Output: (B, 16, H, W).
    """
    def __init__(self, start_fm=32, num_heads=4, ff_hidden_dim=None, dropout=0.1, k=8):
        super(ViGUNet, self).__init__()
        self.start_fm = start_fm
        # Shared ViG settings. ff_hidden_dim is forwarded here; previously the
        # argument was accepted by this constructor but never used.
        vig_kwargs = dict(num_heads=num_heads, k=k, ff_hidden_dim=ff_hidden_dim, dropout=dropout)

        # Encoder Stage 1
        self.enc1_conv = double_conv(1, start_fm)
        self.enc1_vig = ViGBlock(start_fm, start_fm, **vig_kwargs)
        self.pool1 = nn.MaxPool2d(2)

        # Encoder Stage 2
        self.enc2_conv = double_conv(start_fm, start_fm * 2)
        self.enc2_vig = ViGBlock(start_fm * 2, start_fm * 2, **vig_kwargs)
        self.pool2 = nn.MaxPool2d(2)

        # Encoder Stage 3
        self.enc3_conv = double_conv(start_fm * 2, start_fm * 4)
        self.enc3_vig = ViGBlock(start_fm * 4, start_fm * 4, **vig_kwargs)
        self.pool3 = nn.MaxPool2d(2)

        # Encoder Stage 4
        self.enc4_conv = double_conv(start_fm * 4, start_fm * 8)
        self.enc4_vig = ViGBlock(start_fm * 8, start_fm * 8, **vig_kwargs)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck_conv = double_conv(start_fm * 8, start_fm * 16)
        self.bottleneck_vig = ViGBlock(start_fm * 16, start_fm * 16, **vig_kwargs)

        # Decoder Stage 4 (conv input = upsampled features + skip => 2x channels)
        self.up4 = nn.ConvTranspose2d(start_fm * 16, start_fm * 8, kernel_size=2, stride=2)
        self.dec4_conv = double_conv(start_fm * 16, start_fm * 8)
        self.dec4_vig = ViGBlock(start_fm * 8, start_fm * 8, **vig_kwargs)

        # Decoder Stage 3
        self.up3 = nn.ConvTranspose2d(start_fm * 8, start_fm * 4, kernel_size=2, stride=2)
        self.dec3_conv = double_conv(start_fm * 8, start_fm * 4)
        self.dec3_vig = ViGBlock(start_fm * 4, start_fm * 4, **vig_kwargs)

        # Decoder Stage 2
        self.up2 = nn.ConvTranspose2d(start_fm * 4, start_fm * 2, kernel_size=2, stride=2)
        self.dec2_conv = double_conv(start_fm * 4, start_fm * 2)
        self.dec2_vig = ViGBlock(start_fm * 2, start_fm * 2, **vig_kwargs)

        # Decoder Stage 1
        self.up1 = nn.ConvTranspose2d(start_fm * 2, start_fm, kernel_size=2, stride=2)
        self.dec1_conv = double_conv(start_fm * 2, start_fm)
        self.dec1_vig = ViGBlock(start_fm, start_fm, **vig_kwargs)

        self.final_conv = nn.Conv2d(start_fm, 16, kernel_size=1)
        self.final_bn = nn.BatchNorm2d(16)
        self.final_act = nn.Sigmoid()

    @staticmethod
    def _apply_vig(vig, feat):
        """Run ``vig`` on a (B, C, H, W) map, treating every spatial location as a node."""
        B, C, H, W = feat.shape
        nodes = feat.flatten(2).transpose(1, 2)  # (B, H*W, C)
        nodes = vig(nodes)
        # reshape (not view) so a non-contiguous result is handled as well.
        return nodes.transpose(1, 2).reshape(B, C, H, W)

    def forward(self, inputs):
        # Four 2x poolings must each divide evenly, otherwise the upsampled
        # decoder maps no longer match the skip connections.
        if inputs.shape[-2] % 16 or inputs.shape[-1] % 16:
            raise ValueError(
                f"ViGUNet needs H and W divisible by 16, got {tuple(inputs.shape[-2:])}")

        # Encoder
        enc1 = self._apply_vig(self.enc1_vig, self.enc1_conv(inputs))  # (B, start_fm, H, W)
        enc2 = self._apply_vig(self.enc2_vig, self.enc2_conv(self.pool1(enc1)))
        enc3 = self._apply_vig(self.enc3_vig, self.enc3_conv(self.pool2(enc2)))
        enc4 = self._apply_vig(self.enc4_vig, self.enc4_conv(self.pool3(enc3)))

        # Bottleneck
        bottleneck = self._apply_vig(self.bottleneck_vig, self.bottleneck_conv(self.pool4(enc4)))

        # Decoder: upsample, concatenate the matching encoder map, conv, ViG
        cat4 = torch.cat([self.up4(bottleneck), enc4], dim=1)
        dec4 = self._apply_vig(self.dec4_vig, self.dec4_conv(cat4))

        cat3 = torch.cat([self.up3(dec4), enc3], dim=1)
        dec3 = self._apply_vig(self.dec3_vig, self.dec3_conv(cat3))

        cat2 = torch.cat([self.up2(dec3), enc2], dim=1)
        dec2 = self._apply_vig(self.dec2_vig, self.dec2_conv(cat2))

        cat1 = torch.cat([self.up1(dec2), enc1], dim=1)
        dec1 = self._apply_vig(self.dec1_vig, self.dec1_conv(cat1))

        out = self.final_conv(dec1)
        out = self.final_bn(out)
        out = self.final_act(out)
        return out


In [7]:
##############################################
# Loss Functions
##############################################
def dice(pred, target):
    """Mean smoothed Dice coefficient over the batch (first) dimension.

    Each sample is flattened; the score is
    (2 * sum(p * t) + 1) / (sum(p * p) + sum(t * t) + 1), averaged over samples.
    """
    smooth = 1
    batch = pred.size(0)
    flat_pred = pred.reshape(batch, -1)
    flat_true = target.reshape(batch, -1)
    overlap = (flat_pred * flat_true).sum(1)
    energy = (flat_pred * flat_pred).sum(1) + (flat_true * flat_true).sum(1)
    per_sample = (2. * overlap + smooth) / (energy + smooth)
    return per_sample.sum() / batch

def my_loss(pre_y, tru_y):
    """Dice loss: one minus the batch-mean Dice coefficient of prediction and target."""
    return 1 - dice(pre_y, tru_y)

In [None]:
def Model(m, w):
    """Render every cell of the 3-D array ``m`` whose value is >= ``w`` as a coloured voxel.

    Args:
        m: 3-D array of values; the colour scale is laid out for values in [0, 1].
        w: display threshold; cells below it are left empty.

    A cell's colour is chosen from a 10-step palette by rounding ``10 * value``.
    The palette index is clamped to the valid range, so a value that rounds to 0
    gets the lowest colour (it used to wrap around to the highest one) and a
    value above 1.05 gets the highest colour (it used to raise IndexError).

    Raises:
        ValueError: if ``m`` is not 3-dimensional.
    """
    m = np.asarray(m)
    if m.ndim != 3:
        raise ValueError(f"m must be 3-dimensional, got shape {m.shape}")
    c = ["#D1FEFE", "#D1FEFE", "#00FEF9", "#00FDFE", "#50FB7F", "#D3F821", "#FFDE00", "#FF9D00", "#F03A00", "#E10000"]

    # Vectorised selection: one boolean mask instead of building a full-size
    # boolean volume for every single voxel.
    model = m >= w
    # rint rounds half-to-even, like round() on a NumPy scalar.
    palette_idx = np.rint(10 * m[model]).astype(int) - 1
    palette_idx = np.clip(palette_idx, 0, len(c) - 1)

    color = np.empty(m.shape, dtype=object)  # None wherever no voxel is drawn
    color[model] = np.array(c, dtype=object)[palette_idx]
    plt_model(model, color)

def plt_model(model, facecolors='r'):
    """Draw the boolean voxel volume ``model`` in a new 3-D figure and show it.

    Axis ticks are placed on every cell boundary; only every 8th tick carries a
    label (coordinate in metres, 50 m per cell), the others are blank.
    """
    def sparse_labels(count, first, step=50, every=8):
        # e.g. ['-2725', ' ', ' ', ' ', ' ', ' ', ' ', ' ', '-2325', ...]
        return [str(first + step * pos) if pos % every == 0 else ' ' for pos in range(count)]

    plt.figure(figsize=(8, 4))
    ax = plt.axes(projection='3d')
    ax.voxels(model, facecolors=facecolors, edgecolors='w', linewidth=0.4)

    plt.xticks(np.arange(0, 44, 1), sparse_labels(44, -2725))
    ax.set_xlabel('Easting (m)', labelpad=3)
    plt.yticks(np.arange(0, 37, 1), sparse_labels(37, -1175))
    ax.set_ylabel('Northing (m)', labelpad=2)
    ax.set_zticks(np.arange(0, 17, 1))
    ax.set_zticklabels(sparse_labels(17, 0))
    ax.set_zlabel('Depth (m)', labelpad=-2)
    # Depth increases downwards.
    ax.invert_zaxis()

    for axis, pad in ((ax.xaxis, -2), (ax.yaxis, -2), (ax.zaxis, 0)):
        axis.set_tick_params(pad=pad)
    plt.show()

    
def colorma():
    """Return the 10-colour discrete colormap (named 'indexed') used by the section plots."""
    palette = [
        "#F2F2F2", "#D1FEFE", "#00FEF9", "#00FDFE", "#50FB7F",
        "#D3F821", "#FFDE00", "#FF9D00", "#F03A00", "#E10000",
    ]
    return colors.ListedColormap(palette, name='indexed')

def plot_xoz(model, index):
    """Show the slice of ``model`` at position ``index`` along axis 1 and save it.

    The slice is drawn with Easting on the horizontal axis and Depth on the
    vertical axis, and the same array is written to 'real_xoz_data.npy'.
    """
    def sparse_labels(count, first):
        # Label every 8th tick with its coordinate (50 m per cell), blank otherwise.
        return [str(first + 50 * pos) if pos % 8 == 0 else ' ' for pos in range(count)]

    section = model.swapaxes(1, 2).T[index].T
    plt.imshow(section, cmap=colorma())
    plt.xticks(np.arange(-0.5, 43.5, 1), sparse_labels(44, -2725))
    plt.xlabel('Easting (m)')
    plt.yticks(np.arange(-0.5, 16.5, 1), sparse_labels(17, 0))
    plt.ylabel('Depth (m)')
    plt.grid()
    plt.tight_layout()
    np.save('real_xoz_data.npy', section)

def plot_yoz(model, index):
    """Show the slice of ``model`` at position ``index`` along its last axis and save it.

    The slice is drawn with Northing on the horizontal axis and Depth on the
    vertical axis, and the same array is written to 'real_yoz_data.npy'.
    """
    def sparse_labels(count, first):
        # Label every 8th tick with its coordinate (50 m per cell), blank otherwise.
        return [str(first + 50 * pos) if pos % 8 == 0 else ' ' for pos in range(count)]

    section = model.T[index].T
    plt.imshow(section, cmap=colorma())
    plt.xticks(np.arange(-0.5, 36.5, 1), sparse_labels(37, -1175))
    plt.xlabel('Northing (m)')
    plt.yticks(np.arange(-0.5, 16.5, 1), sparse_labels(17, 0))
    plt.ylabel('Depth (m)')
    plt.grid()
    plt.tight_layout()
    np.save('real_yoz_data.npy', section)
    
def plot_xoy(model, index):
    """Show ``model[index]`` (one slice along the first axis) as a map view.

    Easting runs along the horizontal axis and Northing along the vertical
    axis; the vertical axis is inverted after drawing.
    """
    def sparse_labels(count, first):
        # Label every 8th tick with its coordinate (50 m per cell), blank otherwise.
        return [str(first + 50 * pos) if pos % 8 == 0 else ' ' for pos in range(count)]

    ax = plt.gca()
    layer = model[index]
    plt.imshow(layer, cmap=colorma())
    ax.invert_yaxis()
    plt.xticks(np.arange(-0.5, 43.5, 1), sparse_labels(44, -2725))
    plt.xlabel('Easting (m)')
    plt.yticks(np.arange(-0.5, 36.5, 1), sparse_labels(37, -1175))
    plt.ylabel('Northing (m)')
    plt.grid()
    plt.tight_layout()

In [9]:
# ---- Synthetic training / validation samples --------------------------------
# Each file holds one sample: 'd' is the anomaly map (network input), 'm' the
# model (network target), which is normalised by `density`.
x = []
y = []
for i in range(total_num):
    # Context manager: close every file right after reading. Without it all
    # total_num handles stay open for the lifetime of the kernel.
    with h5py.File(dataFile.format(i), 'r') as data:
        m = data['m'][0]/density
        d = data['d'][0]
    d = np.nan_to_num(d)
    x.append(d.reshape(1, num_cell, num_cell))
    y.append(m.reshape(16, num_cell, num_cell))

# ---- Field-data slices -------------------------------------------------------
real_x = []
real_y = []
for i in range(real_num):
    with h5py.File(real_dataFile.format(i), 'r') as data:
        m = data['m'][0]/density
        d = data['d_slice'][0]
    d = np.nan_to_num(d)
    real_x.append(d.reshape(1, num_cell, num_cell))
    real_y.append(m.reshape(16, num_cell, num_cell))

# First tra_num samples train, last val_num validate. The slices are new lists,
# so appending augmented samples later does not touch x / y.
tra_x = x[:tra_num]
tra_y = y[:tra_num]
val_x = x[-val_num:]
val_y = y[-val_num:]

In [10]:
# Sanity check: `m` is the still-flat model vector of the last file read above
# (it is reshaped to (16, num_cell, num_cell) only when appended to the lists).
m.shape

(16384,)

In [11]:
# Noise augmentation: for every noise level, append one noisy copy of each clean
# sample (noise std = level * that map's maximum). Targets are reused unchanged.
# The clean samples always occupy the first tra_num / val_num positions.
for level in noise_levels:
    noisy_tra = [
        tra_x[idx] + level*tra_x[idx].max()*np.random.normal(0, 1, (1, num_cell, num_cell))
        for idx in range(tra_num)
    ]
    tra_x.extend(noisy_tra)
    tra_y.extend(tra_y[:tra_num])

    noisy_val = [
        val_x[idx] + level*val_x[idx].max()*np.random.normal(0, 1, (1, num_cell, num_cell))
        for idx in range(val_num)
    ]
    val_x.extend(noisy_val)
    val_y.extend(val_y[:val_num])

In [12]:
# Shape check after augmentation: each split should now hold
# (1 + len(noise_levels)) times its original number of samples.
print(np.shape(tra_x), np.shape(tra_y))
print(np.shape(val_x), np.shape(val_y))
print(np.shape(real_x))

(60000, 1, 32, 32) (60000, 16, 32, 32)
(6000, 1, 32, 32) (6000, 16, 32, 32)
(60, 1, 32, 32)


In [13]:
# One index list per split; only training and validation orders are shuffled,
# the field-data slices keep their original order.
tra_idxs, val_idxs, real_idxs = (
    list(range(len(samples))) for samples in (tra_x, val_x, real_x)
)
np.random.shuffle(tra_idxs)
np.random.shuffle(val_idxs)

In [14]:
# Build each array directly as float32. np.array(...).astype(np.float32) first
# materialises a full float64 copy, which for the training targets
# (60000 x 16 x 32 x 32 values) costs several GiB of peak memory for nothing.
tra = Dataset(np.asarray(tra_x, dtype=np.float32)[tra_idxs], train=True, masks=np.asarray(tra_y, dtype=np.float32)[tra_idxs])
val = Dataset(np.asarray(val_x, dtype=np.float32)[val_idxs], train=True, masks=np.asarray(val_y, dtype=np.float32)[val_idxs])
real = Dataset(np.asarray(real_x, dtype=np.float32)[real_idxs], train=True, masks=np.asarray(real_y, dtype=np.float32)[real_idxs])
# shuffle=False everywhere: the sample order was fixed by the index lists above,
# and the field-data slices must stay in order for the later stitching step.
tra_loader = torch.utils.data.DataLoader(dataset=tra, batch_size=batch_size, shuffle=False, pin_memory=False)
val_loader = torch.utils.data.DataLoader(dataset=val, batch_size=batch_size, shuffle=False, pin_memory=False)
real_loader = torch.utils.data.DataLoader(dataset=real, batch_size=batch_size, shuffle=False, pin_memory=False)

In [16]:
# Build the ViG-based UNet on the GPU.
# NOTE(review): start_fm and k are passed as literals here, not taken from the
# `start_fm` / `k` constants of the configuration cell.
model = ViGUNet(start_fm=32, num_heads=4, dropout=0.3, k=8).cuda()

# Adam over all model parameters
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Layer-by-layer summary for a single-channel 32 x 32 input
summary(model, input_size=(1, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             320
               ELU-2           [-1, 32, 32, 32]               0
            Conv2d-3           [-1, 32, 32, 32]           9,248
       double_conv-4           [-1, 32, 32, 32]               0
            Linear-5             [-1, 1024, 32]           1,056
            Linear-6              [-1, 1024, 8]             520
            Linear-7              [-1, 1024, 8]             520
            Linear-8              [-1, 1024, 8]             520
            Linear-9              [-1, 1024, 8]             520
        GraphConv-10             [-1, 1024, 32]               0
           Linear-11             [-1, 1024, 32]           1,056
             ReLU-12             [-1, 1024, 32]               0
          Dropout-13             [-1, 1024, 32]               0
        LayerNorm-14             [-1, 1

In [17]:
mean_tra_losses = []   # one mean training loss per epoch
mean_val_losses = []   # one mean validation loss per epoch
real_data = []         # field-data predictions of the most recent epoch

start = time.time()
# Run exactly `epochs` epochs, numbered from 1. The former
# `while epoch <= epochs` loop starting at 0 ran one epoch too many.
for epoch in range(1, epochs + 1):
    tra_losses = []
    val_losses = []

    # ---- Training pass ----
    model.train()
    for images, masks in tra_loader:
        images = images.cuda()
        masks = masks.cuda()
        optimizer.zero_grad()
        outputs = model(images)
        loss = my_loss(outputs, masks)
        loss.backward()
        optimizer.step()
        # detach() keeps the value without holding on to the autograd graph.
        tra_losses.append(loss.detach())

    # ---- Validation and field-data passes (no gradients) ----
    model.eval()
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.cuda()
            masks = masks.cuda()
            outputs = model(images)
            val_losses.append(my_loss(outputs, masks))

        # Start from an empty list every epoch: appending to one list across all
        # epochs kept every epoch's GPU tensors alive until the end of training.
        real_data = []
        for images, masks in real_loader:
            images = images.cuda()
            masks = masks.cuda()
            outputs = model(images)
            real_data.append([outputs, masks, images])

    # The checkpoint is overwritten after every epoch.
    # NOTE(review): `patience` / `threshold` from the configuration cell suggest
    # early stopping was intended; none is implemented here.
    torch.save(model.state_dict(), 'epoch_model_vignn_real.pth')  # Save at the end of every epoch
    print(f"Model saved at the end of epoch {epoch}.")

    mean_tra_loss = torch.mean(torch.stack(tra_losses)).item() if tra_losses else 0.0
    mean_val_loss = torch.mean(torch.stack(val_losses)).item() if val_losses else 0.0
    mean_tra_losses.append(mean_tra_loss)
    mean_val_losses.append(mean_val_loss)
    print(f'Epoch: {epoch}. Tra Loss: {mean_tra_loss:.4f}. Val Loss: {mean_val_loss:.4f}.')

end = time.time()
run_time = end - start
print(f"Total runtime: {run_time:.2f} seconds.")


Model saved at the end of epoch 0.
Epoch: 1. Tra Loss: 0.7019. Val Loss: 0.5946.
Model saved at the end of epoch 1.
Epoch: 2. Tra Loss: 0.4906. Val Loss: 0.3906.
Model saved at the end of epoch 2.
Epoch: 3. Tra Loss: 0.3694. Val Loss: 0.3523.
Model saved at the end of epoch 3.
Epoch: 4. Tra Loss: 0.3393. Val Loss: 0.3313.
Model saved at the end of epoch 4.
Epoch: 5. Tra Loss: 0.3243. Val Loss: 0.3223.
Model saved at the end of epoch 5.
Epoch: 6. Tra Loss: 0.3158. Val Loss: 0.3114.
Model saved at the end of epoch 6.
Epoch: 7. Tra Loss: 0.3097. Val Loss: 0.3047.
Model saved at the end of epoch 7.
Epoch: 8. Tra Loss: 0.3001. Val Loss: 0.2956.
Model saved at the end of epoch 8.
Epoch: 9. Tra Loss: 0.2933. Val Loss: 0.2908.
Model saved at the end of epoch 9.
Epoch: 10. Tra Loss: 0.2893. Val Loss: 0.2879.
Model saved at the end of epoch 10.
Epoch: 11. Tra Loss: 0.2863. Val Loss: 0.2864.
Model saved at the end of epoch 11.
Epoch: 12. Tra Loss: 0.2838. Val Loss: 0.2850.
Model saved at the end 

In [None]:
# Point at the checkpoint written by the training loop (same file name as in its
# torch.save call); this replaces the path set in the configuration cell.
model_File = './epoch_model_vignn_real.pth'

In [None]:

# Fresh network on the GPU, restored from the checkpoint and put in eval mode
model = ViGUNet(start_fm=32).cuda()
model.load_state_dict(torch.load(model_File))
model.eval()

# One [outputs, masks, images] triple per batch of field-data slices
real_data = []

# Gradient-free forward pass over the field data
with torch.no_grad():
    for images, masks in real_loader:
        images = images.cuda()
        masks = masks.cuda()
        outputs = model(images)
        real_data.append([outputs, masks, images])



In [88]:
# Reload the checkpoint into the existing `model` and redo the field-data pass.
# NOTE(review): this repeats the forward pass of the preceding cell.
model.load_state_dict(torch.load(model_File))
model.eval()
real_data = []
with torch.no_grad():
    for images, masks in real_loader:
        images, masks = images.cuda(), masks.cuda()
        real_data.append([model(images), masks, images])

In [89]:
# The first entry of each triple is the network output, still a torch tensor.
type(real_data[0][0])

torch.Tensor

In [None]:

# Bring every [output, mask, image] triple to the CPU as NumPy arrays, then
# split the predictions and the targets into their own per-batch lists.
real_data_np = [
    [tensor.cpu().detach().numpy() for tensor in (output, mask, image)]
    for output, mask, image in real_data
]
real_predict = [triple[0] for triple in real_data_np]
real_truth_d = [triple[1] for triple in real_data_np]


In [91]:
# Per-batch target shapes: (batch, 16, 32, 32) expected
print(real_truth_d[1].shape)

print(real_truth_d[0].shape)

(20, 16, 32, 32)
(20, 16, 32, 32)


In [92]:
# Per-batch prediction shapes: should match the targets above
print(real_predict[1].shape)
print(real_predict[0].shape)                                

(20, 16, 32, 32)
(20, 16, 32, 32)


In [93]:
# `m` is still the flat model vector left over from the loading cell;
# compare its size with the shape of one prediction batch.
print(m.shape)
print(m.size)
real_predict[0].shape
# len(y0)

(16384,)
16384


(20, 16, 32, 32)

In [96]:
# Grid offsets (in cells) of the top-left corner of every 32 x 32 window that
# fits into the x_num x y_num field grid when moved one cell at a time.
x0 = np.arange(0, xrange[1] - num_obs * length + 2 * length - xrange[0], length) / length
y0 = np.arange(0, yrange[1] - num_obs * length + 2 * length - yrange[0], length) / length

# Join the per-batch predictions along the sample axis. np.concatenate also
# works when the last batch is shorter than the others (np.array would not),
# and the isinstance guard keeps the cell re-runnable once real_predict is
# already an array.
if isinstance(real_predict, (list, tuple)):
    real_predict = np.concatenate([np.asarray(batch) for batch in real_predict], axis=0)
real_predict = np.asarray(real_predict).reshape(-1, 16, num_cell, num_cell)

# Validate patch count
assert real_predict.shape[0] == len(x0) * len(y0),\
    f"Mismatch: {real_predict.shape[0]} predictions vs {len(x0) * len(y0)} patches"

# Accumulators: summed predictions and number of windows covering each cell
m_m = np.zeros((16, y_num, x_num), dtype=float)
m_times = np.zeros((16, y_num, x_num))

# Patches are consumed with x as the outer and y as the inner loop.
# NOTE(review): this assumes the slice files were written in that same order.
n = 0
for i in range(len(x0)):
    start_x = int(x0[i])
    for j in range(len(y0)):
        start_y = int(y0[j])
        patch = real_predict[n]  # shape: (16, 32, 32)
        n += 1

        # One slice-add per patch instead of a Python loop over every element;
        # patch axes are (layer, y, x), matching the accumulator layout.
        window = (slice(None),
                  slice(start_y, start_y + num_cell),
                  slice(start_x, start_x + num_cell))
        m_m[window] += patch
        m_times[window] += 1

# Average overlapping predictions; cells never covered stay 0.
mask = m_times != 0
m_m[mask] = m_m[mask] / m_times[mask]


In [97]:
# 3-D voxel view of the stitched result: m_m is (layer, northing, easting), so
# the transpose puts easting first and the layers last, as the axis labels of
# plt_model expect. Only cells with value >= 0.5 are drawn.
Model(m_m.T, 0.5)
plt.savefig("real_Newaaa1.png", dpi=300, bbox_inches='tight')
plt.close()

<IPython.core.display.Javascript object>

The reconstructed subsurface is shown in two cross-sections (Northing = -400 m and Easting = -1700 m)

In [98]:
# Row index of Northing = -400 m on the grid (offset from yrange[0] in cells)
index = int((-400-yrange[0])/length)
plt.figure(figsize = (10.5, 4))
# Values are rounded to one decimal before plotting with the discrete colormap.
plot_xoz(np.around(m_m, decimals=1), index=index)
plt.savefig("real_Newaaa2.png", dpi=300, bbox_inches='tight')
plt.close()

<IPython.core.display.Javascript object>

In [99]:
# Column index of Easting = -1700 m on the grid (offset from xrange[0] in cells)
index = int((-1700-xrange[0])/length)
plt.figure(figsize = (8.75, 4))
# Values are rounded to one decimal before plotting with the discrete colormap.
plot_yoz(np.around(m_m, decimals=1), index=index)
plt.savefig("real_Newaaa3.png", dpi=300, bbox_inches='tight')
plt.close()

<IPython.core.display.Javascript object>

In [100]:
# Map view of one slice along the first axis of m_m (index 5), rounded to one decimal
plt.figure(figsize = (7.2, 6))
plot_xoy(np.around(m_m, decimals=1), index=5)

<IPython.core.display.Javascript object>