# Import and Mount

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.autograd import Variable, grad
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.utils import data
from torch.distributions import MultivariateNormal
from torch.nn.utils import weight_norm
from torchvision import models
import torchvision.utils as vutils

try:
    from torchinfo import summary
except ImportError:
    !pip install torchinfo
    from torchinfo import summary

try:
    import mat73
except ImportError:
    !pip install mat73

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import sys
import os
import time
import math
from collections import defaultdict
from timeit import default_timer
import random

from google.colab import drive
drive.mount('/content/gdrive')

work_dir = './RGA-FNO-HT-INV'

os.chdir(work_dir)
!pwd

from plot_utils.plotslib import *
from FNO.utils import _get_act, add_padding2, remove_padding2
from FNO.utilities3 import UnitGaussianNormalizer, LpLoss
from FNO.basics import SpectralConv2d


In [None]:
# Set random seeds for reproducibility (Python, NumPy and PyTorch RNGs).
manualSeed = 999
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)

def try_gpu(i=0):
    """Return the device ``cuda:i`` if that GPU exists, otherwise the CPU device.

    Side effect: when a GPU is selected, the default tensor type is switched
    to ``torch.cuda.FloatTensor``, so tensors created later without an
    explicit device are placed on the current CUDA device.
    """
    if i >= torch.cuda.device_count():
        return torch.device('cpu')
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    return torch.device(f'cuda:{i}')

device0 = try_gpu()
print(device0)

# Load data

In [None]:
data_dir = work_dir + '/Data/'

model_dir = work_dir + '/models/'
fig_dir = work_dir + '/figs/'

# Create the output directories if they are missing. exist_ok replaces the
# previous `except: Exception` (a bare except with a no-op body). That version
# swallowed every error and skipped creating fig_dir whenever model_dir
# already existed.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Dataset name; the files read below are <data_dir><fname>_fields.npz / _heads.npy.
fname = "Exponential_steady_state"
filename = data_dir + fname + "_fields.npz"
with np.load(filename) as npzfile:
    realizations = npzfile['realizations']
    alpha = npzfile['alpha']
    Z = npzfile['Z']
    Q = npzfile['Q'][0]
    pump_id_list = npzfile['pump_id_list']
    # Optional metadata, printed only when stored in the archive.
    for k in ['lx', 'ly', 'sigma2']:
        if k in npzfile.keys():
            print(k, npzfile[k])

# Heads are memory-mapped read-only instead of being loaded eagerly.
filename = data_dir + fname + "_heads.npy"
all_heads = np.load(filename, mmap_mode='r')
print(Q)
print(pump_id_list)
print(alpha.shape)
print(Z.shape)

print(realizations.shape)
print(all_heads.shape)

# Number of realizations.
NR = realizations.shape[0]


# Experimental Domain
## Unit square, zero-centered

In [None]:
############### define domain with (0,0) at center ######
############### define domain with (0,0) at center ######
# The heads are stored as flattened square grids: recover the side length
# exactly (integer sqrt) and refuse non-square layouts.
n_cells = all_heads.shape[1]
nx = ny = math.isqrt(n_cells)
if nx * ny != n_cells:
    raise ValueError(f"expected a square grid, got {n_cells} cells")
dx_real = 5.0/(nx/64.0)
dt_real = 0.1

# Unit square [-Lox/2, Lox/2] x [-Loy/2, Loy/2] with cell-centred coordinates.
Lox, Loy = 1, 1
dx, dy = Lox/nx, Loy/ny

x = np.arange((-Lox/2 + dx/2), (Lox/2), dx)
y = np.arange((-Loy/2 + dy/2), (Loy/2), dy)

Xm, Ym = np.meshgrid(x, y)

# All cell centres as an (nx*ny, 2) array of (x, y) pairs.
X_star = np.hstack((Xm.flatten()[:, None], Ym.flatten()[:, None]))


# Fourier Neural Operator

In [None]:
class FNO2d(nn.Module):
    """2-D Fourier Neural Operator on a regular grid.

    The input is concatenated with two grid-coordinate channels, lifted by
    ``fc0``, passed through a stack of (spectral conv + 1x1 conv) layers and
    projected to ``out_dim`` channels by ``fc1``-``fc3``.

    Args:
        modes1, modes2: per-layer numbers of Fourier modes along the first /
            second spatial axis (one entry per Fourier layer).
        width: channel width of every layer when ``layers`` is None.
        fc_dim: hidden width of the projection head.
        layers: channel widths, ``len(modes1) + 1`` entries; defaults to
            ``[width] * (len(modes1) + 1)``.
        in_dim: channels seen by ``fc0``, i.e. input channels + 2 grid channels.
        out_dim: number of output channels.
        act: activation name resolved by ``_get_act``.
        pad_ratio: scalar or 2 values; each entry times the spatial size gives
            one padding amount (semantics defined by ``add_padding2`` -
            presumably before/after padding, TODO confirm).
    """
    def __init__(self, modes1, modes2,
                 width=64, fc_dim=128,
                 layers=None,
                 in_dim=3, out_dim=1,
                 act='gelu',
                 pad_ratio=(0., 0.)):
        super(FNO2d, self).__init__()
        # A scalar ratio (int or float) is used for both entries.
        if isinstance(pad_ratio, (int, float)):
            pad_ratio = [pad_ratio, pad_ratio]
        else:
            assert len(pad_ratio) == 2, 'Cannot add padding in more than 2 directions'
            pad_ratio = list(pad_ratio)
        self.modes1 = modes1
        self.modes2 = modes2

        self.pad_ratio = pad_ratio
        if layers is None:
            self.layers = [width] * (len(modes1) + 1)
        else:
            self.layers = layers
        # Always read widths from self.layers: the `layers` argument may be None.
        self.fc0 = nn.Linear(in_dim, self.layers[0])

        self.sp_convs = nn.ModuleList([SpectralConv2d(
            in_size, out_size, mode1_num, mode2_num)
            for in_size, out_size, mode1_num, mode2_num
            in zip(self.layers, self.layers[1:], self.modes1, self.modes2)])

        # Pointwise (1x1) linear path applied alongside each spectral conv.
        self.ws = nn.ModuleList([nn.Conv1d(in_size, out_size, 1)
                                 for in_size, out_size in zip(self.layers, self.layers[1:])])

        self.fc1 = nn.Linear(self.layers[-1], fc_dim)
        self.fc2 = nn.Linear(fc_dim, self.layers[-1])
        self.fc3 = nn.Linear(self.layers[-1], out_dim)
        self.act = _get_act(act)

    def forward(self, x):
        """Map ``x`` of shape (B, X, Y, in_dim - 2) to (B, X, Y, out_dim).

        Two grid channels are appended internally. The spatial axes are padded
        before the Fourier layers and the padding is stripped afterwards
        (assuming ``remove_padding2`` inverts ``add_padding2`` - TODO confirm).
        """
        size_1, size_2 = x.shape[1], x.shape[2]
        if max(self.pad_ratio) > 0:
            num_pad1 = [round(i * size_1) for i in self.pad_ratio]
            num_pad2 = [round(i * size_2) for i in self.pad_ratio]
        else:
            num_pad1 = num_pad2 = [0.]

        length = len(self.ws)
        batchsize = x.shape[0]
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 3, 1, 2)   # B, C, X, Y
        x = add_padding2(x, num_pad1, num_pad2)
        size_x, size_y = x.shape[-2], x.shape[-1]

        for i, (speconv, w) in enumerate(zip(self.sp_convs, self.ws)):
            x1 = speconv(x)
            # Flatten the spatial axes for the 1x1 Conv1d, then restore them.
            x2 = w(x.view(batchsize, self.layers[i], -1)).view(batchsize, self.layers[i+1], size_x, size_y)
            x = x1 + x2
            # No activation after the last Fourier layer.
            if i != length - 1:
                x = self.act(x)
        x = remove_padding2(x, num_pad1, num_pad2)
        x = x.permute(0, 2, 3, 1)   # back to B, X, Y, C
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.act(x)
        x = self.fc3(x)
        return x

    def get_grid(self, shape, device):
        """Return [0, 1]-normalised (x, y) coordinates of shape (B, X, Y, 2) on ``device``."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).to(device)

## FNO Forward Model


In [None]:
# Hyper-parameters of the pre-trained FNO forward model. The keys are exactly
# the FNO2d constructor arguments, so the dict is unpacked directly.
config = defaultdict(dict)
config['model'] = dict(
    modes1=[20] * 4,
    modes2=[20] * 4,
    fc_dim=128,
    layers=[64] * 5,
    pad_ratio=[1, 1],
    out_dim=5,
    act='gelu',
)
FONet = FNO2d(**config['model']).to(device0)


*Load the trained model weights (state dict) and freeze them*

In [None]:
# Load the pre-trained FNO weights and freeze them; from here on the FNO is
# only used as a fixed forward model.
FONet.load_state_dict(torch.load(model_dir+fname+'_FNO_HT.pt', map_location=device0))
FONet.eval()                 # inference mode for any train/eval-dependent layers
FONet.requires_grad_(False)  # freeze all parameters
# Summarise on the actual data grid rather than a fixed 32x32 input.
summary(FONet, (1, nx, ny, 1))


## RGA Decoder


In [None]:
class FC_block(nn.Module):
    '''Fully-connected block: Linear -> [BatchNorm1d] -> [act] -> [exp].

    In this notebook it is used as a fixed linear decoder from latent
    coefficients to a flattened field.

    Args:
        in_feat: input feature size.
        out_feat: output feature size.
        params: optional dict used to initialise the linear layer.
            ``params['Z']`` becomes the weight and must have shape
            (out_feat, in_feat). ``params['mu']`` becomes the bias and must
            broadcast against out_feat. Both are also registered as buffers
            ``W`` and ``b``; the layer's weight/bias Parameters wrap those
            same tensors.
        normalize: if True, append ``nn.BatchNorm1d(out_feat, eps=0.8)``.
        act: activation module appended after the linear layer (skipped when
            falsy, e.g. None). NOTE: the default ``nn.Tanh()`` instance is
            shared by all FC_block objects; harmless since Tanh is stateless.
        exp: if True, ``forward`` returns the element-wise exp of the output.

    Raises:
        ValueError: if ``params['Z']`` does not have shape (out_feat, in_feat).
    '''
    def __init__(self, in_feat, out_feat, params=None, normalize=False, act=nn.Tanh(), exp=True):
        super(FC_block, self).__init__()
        self.layer = [nn.Linear(in_feat, out_feat)]
        if params is not None:
            W = torch.tensor(params['Z'], dtype=torch.float32)
            # Fail early: assigning a mis-shaped weight would only break later in forward.
            if tuple(W.shape) != (out_feat, in_feat):
                raise ValueError(
                    f"params['Z'] has shape {tuple(W.shape)}, expected "
                    f"(out_feat, in_feat) = {(out_feat, in_feat)}"
                )
            self.register_buffer('W', W)
            self.register_buffer('b', torch.tensor(params['mu'], dtype=torch.float32))
            # The Parameters wrap (share storage with) the buffers above.
            self.layer[0].weight = nn.Parameter(self.W)
            self.layer[0].bias = nn.Parameter(self.b)

        if normalize:
            # NOTE(review): the second positional argument of BatchNorm1d is
            # eps, so this is eps=0.8 (not momentum). Kept as-is; confirm intent.
            self.layer.append(nn.BatchNorm1d(out_feat, eps=0.8))
        if act:
            self.layer.append(act)

        self.module = nn.Sequential(*self.layer)
        self.exp = exp

    def forward(self, x):
        '''Apply the block to ``x`` (last dim ``in_feat``); exponentiate if ``self.exp``.'''
        y = self.module(x)
        return y.exp() if self.exp else y

In [None]:
# Frozen linear decoder initialised from the basis Z: maps 50 latent
# coefficients to a flattened nx*ny field, then exponentiates (exp=True).
in_feat = 50
out_feat = nx * ny
params = dict(Z=Z.T, mu=np.zeros((1,)))
KNet = FC_block(in_feat, out_feat, params, act=None, exp=True).to(device0).requires_grad_(False)


# Training data
*Sampling & normalization*

In [None]:
# Last entry along axis 2 of all_heads (presumably the final / steady-state
# time step), reshaped to (realization, nx, ny, -1); the last axis presumably
# indexes the pump tests.
heads = all_heads[:,:,-1,:].reshape((NR,nx,ny,-1))

hmean = np.mean(heads,0)
hstd = np.std(heads,0)

# Pixel-wise statistics of K = exp(lnK); exp is evaluated once.
K_fields = np.exp(realizations)
Kmean = np.mean(K_fields,0)
Kstd = np.std(K_fields,0)
del K_fields

# Build both normalizers the same way: set the statistics first, then move to
# the device without relying on the return value of .to().
x_normalizer_recover = UnitGaussianNormalizer(torch.ones((0,)))
x_normalizer_recover.mean = torch.tensor(Kmean,dtype=torch.float32)
x_normalizer_recover.std = torch.tensor(Kstd,dtype=torch.float32)
x_normalizer_recover.to(device0)
y_normalizer_recover = UnitGaussianNormalizer(torch.ones((0,)))
y_normalizer_recover.mean = torch.tensor(hmean,dtype=torch.float32)
y_normalizer_recover.std = torch.tensor(hstd,dtype=torch.float32)
y_normalizer_recover.to(device0)

*forward propagation and results of RGA-FNO*

In [None]:
shift = 10
n_show = 5  # realizations shift .. shift+n_show-1 are checked and plotted

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    inputs = torch.tensor(alpha[shift:shift+n_show,:],dtype=torch.float32,device=device0)
    # Decode latent vectors to K fields and swap the two spatial axes.
    x_recover = KNet(inputs).view(-1,nx,ny).permute(0,2,1)
    x_recover = x_normalizer_recover.encode(x_recover)
    outs = FONet(x_recover.unsqueeze(-1)).squeeze(-1)
    outs = y_normalizer_recover.decode(outs)

ys = torch.tensor(heads[shift:shift+n_show,...], dtype=torch.float32)

preds = outs.detach().cpu().numpy()
refs = ys.detach().cpu().numpy()
# Log-fields, predictions and references all cover the same n_show realizations.
fig = plot_forward_operator(realizations[shift:shift+n_show,...], preds[...,0], refs[...,0], cmaplnK='jet')


## Reference Data
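1. Sample the heads at the observation wells (HT data)
2. Add Gaussian noise with a standard deviation of 1% of each measured value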

In [None]:
#################  well network cell id ##################
print(int(nx/8))
# 5 x 5 lattice of candidate pumping wells in the central half of the domain.
xloc = np.arange(int(nx/4),int(nx*3/4+1),int(nx/8))
yloc = np.arange(int(ny/4),int(ny*3/4+1),int(ny/8))

xloc = np.tile(xloc,5)
yloc = np.repeat(yloc,5)

pump_id_list = [0,4,12,20,24]
pump_well_index = [xloc[pump_id_list], yloc[pump_id_list],[]]
print(pump_well_index)

# Hnum x Hnum observation wells, inset dtb cells from the domain edges.
dtb = int(195 / (1024/nx))
Hnum = 6
xloc = np.linspace(0+dtb,nx-dtb,Hnum ,dtype=int)
yloc = np.linspace(0+dtb,ny-dtb,Hnum ,dtype=int)
xloc = np.tile(xloc,Hnum)
yloc = np.repeat(yloc,Hnum)

HT_mask = torch.zeros(nx,ny,dtype=torch.bool)
HT_mask[xloc,yloc] = True

# Trailing singleton axis so the mask broadcasts over the last axis of the
# (nx, ny, n) head fields.
HT_mask = HT_mask.unsqueeze(-1)

fid = 3  # index of the realization used as the "true" field

# Observed heads are fixed data: they must not require gradients. Otherwise
# every loss built from them shares one autograd graph across iterations.
ytrue = torch.tensor(heads[fid,...], dtype=torch.float32)
ytrue_masked_real = torch.masked_select(ytrue,HT_mask)

# Gaussian noise with a standard deviation of 1% of each measured value.
ytrue_masked = ytrue_masked_real + torch.normal(mean=0.0,std=0.01*ytrue_masked_real.abs())


## Validation Data

In [None]:
# Validation data: true K sampled on a Knum x Knum lattice of cells, inset
# dtb cells from the domain edges. meshgrid(...).ravel() yields the same
# (tile, repeat) index ordering as the explicit np.tile / np.repeat form.
dtb = int(50 / (1024/nx))
Knum = 6
xlin, ylin = (
    g.ravel()
    for g in np.meshgrid(
        np.linspace(0 + dtb, nx - dtb, Knum, dtype=int),
        np.linspace(0 + dtb, ny - dtb, Knum, dtype=int),
    )
)

Ktrue = torch.tensor(np.exp(realizations[fid:fid + 1]), dtype=torch.float32, device=device0)

K_mask = torch.zeros(1, nx, ny, dtype=torch.bool)
K_mask[0, xlin, ylin] = True
K_masked = Ktrue.masked_select(K_mask)


# Optimization

## Initial Guess

In [None]:
# Variational parameters of the 50-dim latent Gaussian: mean `loc` (zeros)
# and the diagonal variance `var` (ones), both optimised leaves.
loc = torch.zeros(50, dtype=torch.float32, requires_grad=True)
var = torch.ones_like(loc, requires_grad=True)

## Optimization Loop

In [None]:
# Return unoccupied memory held by PyTorch's CUDA caching allocator before the
# optimization loop; does not affect any results.
torch.cuda.empty_cache()

In [None]:
batch_size=1
relu_layer = torch.nn.ReLU()
num_epoch = 5
p_intervals = 1

loss_func = nn.MSELoss()

# Optimise the variational parameters; both groups override the base lr (1e-2) with 1e1.
optimizer = torch.optim.SGD(
    [{'params': [var], 'lr':1e1, 'weight_decay':0 },
      {'params': [loc], 'lr':1e1, 'weight_decay':0 }],
    lr=1e-2,
    momentum=0.9
)
scheduler =  torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1, last_epoch=-1)

start_time = time.time()
for epoch in range(num_epoch):
    optimizer.zero_grad()

    # Diagonal Gaussian over the latent vector (ReLU + 1e-6 keeps the variance
    # positive). Averaging 100 reparameterised samples keeps gradients to loc/var.
    alpha_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc, torch.diag(relu_layer(var)+1e-6))
    initial_guess_alpha = alpha_dist.rsample((100,)).mean(dim=0)

    Krecover = KNet(initial_guess_alpha.unsqueeze(0)).reshape(batch_size,nx,ny).permute(0,2,1)
    # Monitoring only: misfit to the true K field (not part of the objective).
    with torch.no_grad():
        l1 = 1e0*loss_func(Krecover, Ktrue)

    Krecover = x_normalizer_recover.encode(Krecover)
    out = FONet(Krecover.unsqueeze(-1)).squeeze(-1)
    out = y_normalizer_recover.decode(out)
    # Head misfit at the observation wells. The noisy observations are fixed
    # data; detaching them keeps each iteration's graph self-contained, so a
    # plain backward() suffices (no retain_graph).
    loss = 1e3*loss_func(torch.masked_select(out[0],HT_mask), ytrue_masked.detach())

    loss.backward()

    optimizer.step()
    if epoch < 100:
        scheduler.step()
    if epoch%p_intervals==0 or epoch==num_epoch-1:
        elapsed = time.time() - start_time
        print("Epoch#: %d" % epoch, end="\t")
        print("Loss1: %.4f" % l1.item(), end="\t")
        print("Loss2: %.4f" % loss.item(), end="\t")
        print("Time: %.4f" % elapsed, end="\t")
        print("LR1: %f" % optimizer.param_groups[0]['lr'], end="\n")

## Save Results

In [None]:

# torch.save([loc, var], model_dir+"/" +fname+"_"+str(fid)+"_loc_var_FNO.pt")

# loc, var = torch.load(model_dir+"/" +fname + "_"+str(fid)+"_loc_var_FNO.pt",map_location=device0)


# Plot all results

In [None]:
relu_layer = torch.nn.ReLU()
# Fitted diagonal Gaussian over the latent vector (mean loc, variance ReLU(var) + 1e-6).
alpha_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc, torch.diag(relu_layer(var)+1e-6))

################### hydraulic conductivity colormap plot: 2D ##################
# 100 samples of the latent vector and the corresponding log-fields.
initial_guess_alpha = alpha_dist.rsample((100,)).detach().cpu().numpy()
logK_preds = np.matmul(initial_guess_alpha, Z)

# Point estimate from the mean `loc`. Prediction and truth are reordered
# identically: reshape(order='F') followed by a C-order flatten.
logK_pred = np.matmul(loc.detach().cpu().numpy()[None, :], Z)[0].reshape((nx, ny), order='F').flatten()
logK_true = np.matmul(alpha[fid, :], Z).reshape((nx, ny), order='F').flatten()

############################# metrics: accuracy with #############################
# Fraction of cells whose |error| is below 10% of the true lnK range.
minlK, maxlK = np.min(logK_true), np.max(logK_true)
thres = 0.10
K_len = maxlK - minlK

acc = np.abs(logK_true - logK_pred) / K_len
acc = np.count_nonzero(acc < thres) / (nx * ny)
print(acc)

# Relative L2 error of K = exp(lnK).
K_true = np.exp(logK_true)
K_pred = np.exp(logK_pred)
K_err = K_true - K_pred
error_K = np.linalg.norm(K_err, 2) / np.linalg.norm(K_true, 2)
print(error_K)


## Forward HT with Recovered Field

In [None]:
# Forward-simulate heads for the field decoded from the fitted mean `loc`.
# Inference only: no_grad avoids building an autograd graph back to `loc`.
with torch.no_grad():
    x_recover = KNet(loc.unsqueeze(0)).view(-1,nx,ny).permute(0,2,1)
    x_recover = x_normalizer_recover.encode(x_recover)
    outs = FONet(x_recover.unsqueeze(-1)).squeeze(-1)
    outs = y_normalizer_recover.decode(outs)

preds = outs.detach().cpu().numpy()
refs = ytrue.reshape(1,nx,ny,-1).detach().cpu().numpy()

# Per-channel colour limits over the two spatial axes, widened slightly.
hmin, hmax = np.min(preds,axis=(1,2)), np.max(preds,axis=(1,2))

hmin *= 0.95
hmax *= 1.05


## Hydraulic Heads

In [None]:
# Per pump test: relative L2 error of the full head field. Then R^2 between
# predicted and true heads sampled at the observation wells.
#
# Both masked_select operands are made (nx, ny). Masking a (1, nx, ny) slice
# with the (nx, ny, 1) HT_mask would broadcast to (nx, nx, ny) and select whole
# grid rows instead of the well cells.
well_mask = HT_mask.reshape(nx, ny)
pred_heads = outs.reshape(nx, ny, -1)   # drop the batch axis (batch size 1)
n_pumps = pred_heads.shape[-1]

u_pred_plot_list = []
u_true_plot_list = []
for i in range(n_pumps):
    ref_i = refs[..., i].flatten()
    error_h = np.linalg.norm(ref_i - preds[..., i].flatten(), 2) / np.linalg.norm(ref_i, 2)
    print(error_h)
    upred_masked = torch.masked_select(pred_heads[..., i], well_mask).detach().cpu().numpy()
    utrue_masked = torch.masked_select(ytrue[..., i], well_mask).detach().cpu().numpy()
    u_pred_plot_list.append(upred_masked)
    u_true_plot_list.append(utrue_masked)

# Accumulate the R^2 sums in float64.
u_pred_plot = np.concatenate(u_pred_plot_list).astype(np.float64)
u_true_plot = np.concatenate(u_true_plot_list).astype(np.float64)
var_model = np.sum((u_pred_plot - u_true_plot)**2)
var_data = np.sum((u_true_plot - np.mean(u_true_plot))**2)
r2_inverse_head = (var_data - var_model) / var_data
print(r2_inverse_head)

In [None]:
# Shared font sizes for the summary figure.
axis_label_font_size = 15
axis_tick_font_size = 15
legend_fontszie = 15  # NOTE(review): misspelled, but the legend calls below use this exact name
colorbar_font_size = 15
title_size = 15
contour_size = 10  # font size of contour labels (ax.clabel)

ti = 0  # NOTE(review): not referenced anywhere in the visible code

iter = 0  # batch index into `outs` for the head panels; shadows the builtin iter()

# 3 x 5 panel grid, flattened so panels are addressed as axs[0..14].
gridspec_kw=dict(wspace=0.55,hspace=0.5)
fig, axs = plt.subplots(3, 5,figsize=(25,13), gridspec_kw=gridspec_kw)
axs = axs.flatten()

# Panels A1/A2: true vs. estimated log-field on a shared colour range.
ax = axs[0]
im = ax.pcolormesh(Xm,Ym,logK_true.reshape((nx,ny)),cmap='jet')
im.set_clim(minlK, maxlK)
ax.set_title('(A1). True $lnT$', fontsize=title_size)

ax = axs[1]
im = ax.pcolormesh(Xm,Ym,logK_pred.reshape((nx,ny)),cmap='jet')
im.set_clim(minlK, maxlK)
ax.set_title('(A2). Estimated $lnT$', fontsize=title_size)

for ax in axs[:2]:
    # Divide existing axes and create a colorbar axis
    # (make_axes_locatable presumably comes from the plot_utils star import).
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="8%", pad=0.1)

    # Both colorbars reuse `im` from panel A2; valid because A1 and A2 share clim.
    cbar = fig.colorbar(
        im, cax=cax,
        ticks=[round(x,1) for x in np.arange(minlK*1.05, maxlK*1.05,(maxlK-minlK)/3)]
    )
    cbar.ax.tick_params(labelsize=colorbar_font_size)

# Panel D: pixel-wise variance over the sampled log-fields. logK_true and
# logK_pred were reordered with reshape(order='F'), so the variance is
# reshaped the same way; a C-order reshape would draw it transposed.
ax = axs[2]
logK_var = np.var(logK_preds,axis=0)
im = ax.pcolormesh(Xm,Ym,logK_var.reshape((nx,ny), order='F'),cmap='jet')
ax.set_title('(D). Variance Map',fontsize=title_size)
# Divide existing axes and create new axes
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="8%", pad=0.1)
cbar = fig.colorbar(im, cax=cax)
cbar.ax.tick_params(labelsize=colorbar_font_size)

####################################################################################
####################################################################################
####################################################################################
####################################################################################
ax = axs[3]

# 1-based coefficient indices, so the tick labels 1, 10, ..., 50 below sit on
# the coefficient they name.
alpha_idx = np.arange(1, alpha[fid].shape[0] + 1)
im = ax.plot(alpha_idx, alpha[fid], 'or', label='True')
im = ax.plot(alpha_idx, loc.detach().cpu().numpy(), 'b', label='Est')
ax.legend(
    loc='upper right',ncol=1,prop={'size': legend_fontszie}, framealpha=1,
    facecolor='none',borderpad=0.1,labelspacing=0.1,handletextpad=0.5,
    handlelength=0.2,columnspacing=0.01
)
ax.set_xlabel("index",fontsize=axis_label_font_size)
ax.set_ylabel(r"Latent variable $\alpha$",fontsize=axis_label_font_size)


ticks = [1, 10, 20, 30, 40, 50]
labels = [1, 10, 20, 30, 40, 50]
ax.set_xticks(ticks)
ax.set_xticklabels(labels,fontsize=axis_tick_font_size,ha='center')
ticks = [-3.0, -1.5, 0.0, 1.5, 3.0]
labels = ticks
ax.set_ylim([-3.5,3.5])
ax.set_yticks(ticks)
ax.set_yticklabels(labels,fontsize=axis_tick_font_size, ha='right', va='center')
ax.set_title('(E). true vs. estimates',fontsize=title_size)

####################################################################################
####################################################################################
####################################################################################
####################################################################################
# Panel F: 1:1 reference line plus a per-pump scatter of predicted vs. true
# heads (from u_pred_plot_list / u_true_plot_list), annotated with R^2.
ax = axs[4]
# NOTE(review): the upper end of the 1:1 line (-0.01), the ticks and the text
# position below are hard-coded for this dataset's head range.
min_u_true_plot, max_u_true_plot = min(u_true_plot), -0.01 # alternative: max(u_true_plot)
ax.plot([min_u_true_plot, max_u_true_plot], [min_u_true_plot, max_u_true_plot],color='k')

# One scatter series per pump test, labelled p1..p5.
for iii in range(5):
    ax.scatter(u_pred_plot_list[iii],u_true_plot_list[iii], label='p'+str(iii+1))

ax.legend(
    loc='lower right',ncol=1,prop={'size': legend_fontszie}, framealpha=0,\
    facecolor='none',borderpad=0.01,labelspacing=0.001,handletextpad=0.5, \
    handlelength=0.2,columnspacing=0.02
)
ax.set_xlabel('predicted heads [m]',fontsize=axis_label_font_size)
ax.set_ylabel('true heads [m]',fontsize=axis_label_font_size)

ticks = [-0.16, -0.10, -0.02]
labels = ticks # alternative labels: [18, 10, 2]
ax.set_xticks(ticks)
ax.set_yticks(ticks)

ax.set_xticklabels(labels,fontsize=axis_tick_font_size,ha='center')
ax.set_yticklabels(labels,fontsize=axis_tick_font_size,rotation=90, ha='right', va='center')

ax.set_title('(F). data vs. prediction',fontsize=title_size)

# R^2 annotation, placed in data coordinates.
ax.text(
    -0.16, -0.05, '$R^2$=%.4f'%(r2_inverse_head), fontsize=15,
    bbox={'edgecolor':'w','facecolor':'w'}
)

####################################################################################
####################################################################################
####################################################################################
####################################################################################

# Panels B (FNO heads for the recovered field) and C (true heads of the
# inverted realization `fid`), one column per pump test.
preds = outs[iter,...].detach().cpu().numpy()
# Reference: the observed realization `fid` (ytrue). `ys` holds the
# realizations of the forward-operator check, not the inverted field.
refs = ytrue.reshape(nx, ny, -1).detach().cpu().numpy()

# Per-pump colour range taken from the prediction.
hmin, hmax = np.min(preds,axis=(0,1)), np.max(preds,axis=(0,1))

hmin *= 0.95
hmax *= 1.00
for iid in range(5,10):
    k = iid - 5  # pump-test index 0..4

    ax = axs[iid]
    im = ax.pcolormesh(Xm, Ym, preds[...,k], vmin=hmin[k], vmax=hmax[k])
    ax.set_title('(B%d) FNO p%d' % (iid-4,iid-4), fontsize=title_size)

    ax = axs[iid+5]
    im = ax.pcolormesh(Xm, Ym, refs[...,k], vmin=hmin[k], vmax=hmax[k])
    ax.set_title('(C%d) True p%d' % (iid-4,iid-4), fontsize=title_size)

    for ax in axs[[iid,iid+5]]:
        # Divide existing axes and create a colorbar axis; both panels share
        # this pump's vmin/vmax, so reusing `im` is valid.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="8%", pad=0.1)

        cbar = fig.colorbar(
            im, cax=cax,
            ticks=[round(x,1) for x in np.arange(hmin[k], hmax[k]*0.95, (hmax[k]-hmin[k])/4)]
        )
        cbar.ax.tick_params(labelsize=colorbar_font_size)

        # Contours of the true heads overlaid on both panels.
        lvls = np.linspace(hmin[k], hmax[k], 7)
        CP = ax.contour(Xm, Ym, refs[...,k], levels=lvls, cmap="coolwarm")
        ax.clabel(CP, fontsize=contour_size, inline=True, inline_spacing=1, fmt='%.2f')


# Common axis styling for every spatial-map panel, i.e. all panels except the
# latent-coefficient plot (axs[3]) and the head scatter plot (axs[4]).
# Tick positions are in the centred domain [-0.5, 0.5] but labelled on [0, 1].
ticks = [-0.5, 0.0, 0.5]
labels = [0, 0.5, 1]
for panel_no, ax in enumerate(axs):
    if panel_no in (3, 4):
        continue
    # x axis
    ax.set_xlabel('x', fontsize=axis_label_font_size)
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, fontsize=axis_tick_font_size, ha='center')
    # y axis
    ax.set_ylabel('y', fontsize=axis_label_font_size)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, fontsize=axis_tick_font_size, rotation=90, ha='right', va='center')
