In [1]:
import numpy as np
import torch, scipy

# Create new floating-point tensors in single precision by default.
torch.set_default_dtype(torch.float32)
# Seed PyTorch's default RNG (used below for weight init and torch.randn draws).
torch.manual_seed(0)
# Device string used by every .to(device) call in this notebook: GPU when available, else CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
import torch.nn as nn

class Gradient(nn.Module):
    '''Network block that turns a displacement field into its gradient.

    Input:  (batch, 2, N*N); the two channels are the displacement components [ux, uy].
    Output: (batch, 4, Ngrad); channels are ordered [ux_1, uy_1, ux_2, uy_2], i.e. the
            Gradx operator applied to both components, followed by the Grady operator
            applied to both components.

    Gradx / Grady are kept as plain attributes (not parameters or buffers), so they
    are not moved by Module.to() and do not appear in state_dict().
    '''
    def __init__(self, Gradx, Grady):
        super().__init__()
        self.Gradx, self.Grady = Gradx, Grady

    def forward(self, x):
        # Contract the last (nodal) axis of x with the column axis of each operator;
        # the batch and channel axes pass through unchanged.
        parts = [torch.einsum('gn, bcn -> bcg', op, x) for op in (self.Gradx, self.Grady)]
        # Stack along the channel axis: Gradx results first, then Grady results.
        return torch.cat(parts, dim=1)

class InnerEnergy(nn.Module):
    '''Network block that outputs the inner (stored elastic) energy.

    Inputs
        x: (batch, 4, Ngrad) displacement gradient, channels as produced by
           Gradient: [Gradx ux, Gradx uy, Grady ux, Grady uy].
        E: (batch, N*N) Young's modulus.
    Output
        (batch,) energy: the per-point density weighted by Vitg and summed.

    Parameters
        CG2DG: (Ngrad, N*N) matrix applied to E to obtain the modulus at the
               points where the gradient is evaluated.
        Vitg:  (Ngrad,) weights used to sum the density.
        nu:    Poisson ratio used to derive the Lame parameters (default 0.3,
               the value previously hard-coded in forward).

    The density evaluated is
        psi = mu/2 * (Ic - 2) - mu * ln(J) + lambda/2 * ln(J)^2
    with F = I + grad(u), J = det(F), Ic = sum of squared entries of F.
    '''
    def __init__(self, CG2DG, Vitg, nu=0.3):
        super().__init__()
        self.CG2DG = CG2DG
        self.Vitg = Vitg
        self.nu = nu
    def forward(self, x, E):
        # Modulus at the gradient points: (batch, Ngrad).
        Em = torch.einsum('ij, kj -> ki', self.CG2DG, E)
        nu = self.nu
        # Lame parameters mu = E / (2(1+nu)), lambda = E*nu / ((1+nu)(1-2nu)).
        nu1, nu2 = 1.0/(2*(1 + nu)), nu/((1 + nu)*(1 - 2*nu))
        mu, lmbda = Em*nu1, Em*nu2
        # Entries of F = I + grad(u).
        F11, F22 = 1 + x[:,0,:], 1 + x[:,3,:]
        F21, F12 = x[:,1,:], x[:,2,:]
        J = F11*F22 - F21*F12
        Ic = F11**2 + F22**2 + F21**2 + F12**2
        # ln(J) appears twice in the density; evaluate it once.
        logJ = torch.log(J)
        psi = 0.5*mu*(Ic - 2) - mu*logJ + 0.5*lmbda*(logJ**2)
        Psi = torch.einsum('j, ij -> i', self.Vitg, psi)
        return Psi
    
class OuterEnergy(nn.Module):
    '''Network block that outputs the outer energy, a term linear in the displacement.

    Input:  (batch, 2, ...); the two channels are the displacement components [ux, uy].
            Everything after the channel axis is flattened per sample.
    Output: (batch,) = sum_j S_X[j]*ux[j] + sum_j S_Y[j]*uy[j].
    '''
    def __init__(self, S_X, S_Y):
        super().__init__()
        self.S_X, self.S_Y = S_X, S_Y

    def forward(self, x):
        n_batch = x.shape[0]
        # Flatten each displacement component to one row per sample.
        ux_flat = x[:, 0, ...].reshape((n_batch, -1))
        uy_flat = x[:, 1, ...].reshape((n_batch, -1))
        # Weighted sums of each component with its own weight vector.
        return (torch.einsum('n, bn -> b', self.S_X, ux_flat)
                + torch.einsum('n, bn -> b', self.S_Y, uy_flat))

class SED(nn.Module):
    '''Network block that outputs the total energy for a displacement u and a modulus field E.

    Inputs
        x: (batch, 2, N, N) displacement.
        E: (batch, N, N) Young's modulus field before rescaling; forward maps it
           through 95*E + 5 (presumably E is normalised to [0, 1] -- confirm with the data).
    Output
        (batch,) sum of the InnerEnergy and OuterEnergy blocks.
    '''
    def __init__(self, Gradx, Grady, CG2DG, Vitg, S_X, S_Y):
        super().__init__()
        self.CG2DG, self.Vitg = CG2DG, Vitg
        self.S_X, self.S_Y = S_X, S_Y
        # Sub-blocks are registered in the order Grad, Inner, Outer.
        self.Grad = Gradient(Gradx, Grady)
        self.Inner = InnerEnergy(CG2DG, Vitg)
        self.Outer = OuterEnergy(S_X, S_Y)

    def forward(self, x, E):
        n_batch, n_comp = x.shape[0], x.shape[1]
        # Flatten the spatial axes: (batch, 2, N, N) -> (batch, 2, N*N).
        u_flat = x.reshape((n_batch, n_comp, -1))
        # Rescale the modulus and flatten it to (batch, N*N).
        modulus = (95.0*E + 5.0).reshape((E.shape[0], -1))
        inner = self.Inner(self.Grad(u_flat), modulus)
        outer = self.Outer(u_flat)
        return inner + outer
        
class Dirichlet(nn.Module):
    '''A network block that output the boundary sum of the Dirichlet boundary
        The input is (batch, 2, N, N)
        The output is (batch,)
    '''      
    def __init__(self, Mfix):
        super().__init__()
        self.Mfix = Mfix
    def forward(self, x):
        xr = x.reshape((x.shape[0], x.shape[1], -1))
        diri = torch.einsum('ij, klj -> kli', self.Mfix, xr)
        return torch.sum(torch.square(diri), (1,2))
    
class Observe(nn.Module):
    '''Network block that outputs the observation of a displacement field.

    Input:  (batch, 2, N, N) displacement.
    Output: (batch, 2, Nobs), the matrix Mobs applied to each flattened component.
    '''
    def __init__(self, Mobs):
        super().__init__()
        self.Mobs = Mobs

    def forward(self, x):
        n_batch, n_comp = x.shape[0], x.shape[1]
        # Flatten the spatial axes, then apply Mobs along the nodal axis.
        flat = x.reshape((n_batch, n_comp, -1))
        return torch.einsum('on, bcn -> bco', self.Mobs, flat)
        

In [3]:
import os, pickle

# All input pickles live in ./KnownData relative to the working directory.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
data_dir = os.path.join(os.path.abspath('.'), 'KnownData')

# Index maps and operators used by the loss blocks.
with open(os.path.join(data_dir, 'Losscomp.pickle'), 'rb') as file:
    (A2N, N2A, Ma2n, Mn2a, Mobs, Mfix,
     CG2DG, Gradx, Grady, Vitg, S_X, S_Y) = pickle.load(file)

# Parameter samples (training and test).
with open(os.path.join(data_dir, 'ParaSam.pickle'), 'rb') as file:
    sam, Tsam = pickle.load(file)

# Displacements and observations matching the training samples.
with open(os.path.join(data_dir, 'USam.pickle'), 'rb') as file:
    u_sam, obs_sam = pickle.load(file)

# Displacements and observations matching the test samples.
with open(os.path.join(data_dir, 'UTSam.pickle'), 'rb') as file:
    u_Tsam, obs_Tsam = pickle.load(file)

# Extra parameter samples; left in whatever type the pickle holds (no conversion here).
with open(os.path.join(data_dir, 'ParaSam_ve.pickle'), 'rb') as file:
    sam_ve = pickle.load(file)

# Sample arrays become float32 CPU tensors (moved to the device batch by batch later).
sam, Tsam, u_sam, obs_sam, u_Tsam, obs_Tsam = (
    torch.from_numpy(arr).float()
    for arr in (sam, Tsam, u_sam, obs_sam, u_Tsam, obs_Tsam)
)

# Mobs / Mfix expose .todense(), so they are densified before conversion.
Mobst, Mfixt = (
    torch.from_numpy(mat.todense()).float().to(device)
    for mat in (Mobs, Mfix)
)
# The remaining operators are converted directly and placed on the device.
CG2DGt, Gradxt, Gradyt, Vitgt, S_Xt, S_Yt = (
    torch.from_numpy(arr).float().to(device)
    for arr in (CG2DG, Gradx, Grady, Vitg, S_X, S_Y)
)

# Two training samples (indices 112 and 113) kept aside for quick checks.
check_rows = slice(112, 114)
u = u_sam[check_rows, ...].clone()
uobs = obs_sam[check_rows, ...].clone()
E = sam[check_rows, 0, ...].clone()

ut, uobst, Et = (t.to(device) for t in (u, uobs, E))

  A2N, N2A, Ma2n, Mn2a, Mobs, Mfix, CG2DG, Gradx, Grady, Vitg, S_X, S_Y = pickle.load(file)
  A2N, N2A, Ma2n, Mn2a, Mobs, Mfix, CG2DG, Gradx, Grady, Vitg, S_X, S_Y = pickle.load(file)


In [4]:
# Build the loss blocks from the device-resident operators. The blocks keep
# their operators as plain attributes, so the trailing .to(device) does not move
# them; they are already on `device` from the loading cell.
sed = SED(Gradx=Gradxt, Grady=Gradyt, CG2DG=CG2DGt,
          Vitg=Vitgt, S_X=S_Xt, S_Y=S_Yt).to(device)
diri = Dirichlet(Mfix=Mfixt).to(device)
obs = Observe(Mobs=Mobst).to(device)

In [5]:
# print(sed(ut, Et))
# print(diri(ut)*10000)
# print(obs(ut)*100)

In [6]:
def test(u, E):
    '''NumPy reference for the energy and Dirichlet terms of a single sample.

    Mirrors SED (inner + outer energy) and Dirichlet using the NumPy/sparse
    operators loaded at module level (Gradx, Grady, CG2DG, Vitg, S_X, S_Y, Mfix),
    so the torch blocks can be cross-checked against it.

    u: displacement of one sample, indexed as u[0] = ux, u[1] = uy.
    E: modulus field of one sample; rescaled here through 95*E + 5.

    Prints the total energy and the Dirichlet sum of squares multiplied by 10000.
    Returns None.
    '''
    ux = u[0,...].reshape(-1)
    uy = u[1,...].reshape(-1)
    # Displacement-gradient entries at every gradient point.
    guxm = Gradx.dot(ux)
    guym = Grady.dot(ux)
    gvxm = Gradx.dot(uy)
    gvym = Grady.dot(uy)
    Em = E.reshape(-1)
    Em = Em*95.0 + 5.0
    Em = CG2DG.dot(Em)
    # Lame parameters for Poisson ratio 0.3 (constant, so computed once for all points).
    nu = 0.3
    mu = Em/(2*(1 + nu))
    lmbda = Em*nu/((1 + nu)*(1 - 2*nu))
    # Vectorised over all gradient points instead of a per-point Python loop.
    J = (1+guxm)*(1+gvym) - guym*gvxm
    Ic = (1+guxm)**2 + (1+gvym)**2 + guym**2 + gvxm**2
    logJ = np.log(J)
    # Stored as float64, as the original preallocated np.zeros buffer was.
    psi = np.asarray((mu/2)*(Ic - 2) - mu*logJ + (lmbda/2)*logJ**2, dtype=np.float64)
    Psi = np.dot(Vitg, psi) + np.dot(S_X, ux) + np.dot(S_Y, uy)
    dirux = Mfix.dot(ux)
    diruy = Mfix.dot(uy)
    dirv = np.square(dirux).sum() + np.square(diruy).sum()
    print(Psi)
    print(dirv*10000)
    # Optional observation check (kept disabled):
    # print(Mobs.dot(ux)*100)
    # print(Mobs.dot(uy)*100)
# test(u[0,...], E[0,...])
# test(u[1,...], E[1,...])

In [7]:
def loss_sed(u, E):
    '''Physics loss: batch mean of the energy plus batch mean of the Dirichlet term.

    u: network output; multiplied by 0.01 before the energy is evaluated, and the
       energy is multiplied by 100. The Dirichlet term is evaluated on u as given.
    E: modulus batch with a singleton axis at position 1, which is squeezed away.
    Uses the module-level `sed` and `diri` blocks.
    '''
    energy = sed(u*0.01, E.squeeze(1))*100
    boundary = diri(u)
    return energy.mean() + boundary.mean()

def loss_sup(u, u_ref, u_obs):
    '''Supervised loss: batch mean of (field MSE + observation MSE) per sample.

    u:     predicted displacement.
    u_ref: reference displacement, same shape as u.
    u_obs: reference observation, compared against obs(u) using the module-level
           `obs` block.
    '''
    def _per_sample_mse(pred, target):
        # Mean squared error over every non-batch entry: result has shape (batch,).
        sq_err = (pred - target)**2
        return torch.mean(sq_err.reshape(sq_err.shape[0], -1), dim=-1)

    field_term = _per_sample_mse(u, u_ref)
    obs_term = _per_sample_mse(obs(u), u_obs)
    return torch.mean(field_term + obs_term)

In [8]:
# Sanity check of the physics loss on a random field: Gaussian noise drawn on the
# CPU, moved to the device, scaled by 0.01 and shifted by -0.03.
uran = torch.randn(ut.shape).to(device)*0.01-0.03
print(loss_sed(uran, Et))

tensor(0.1097, device='cuda:0')


In [9]:
class BiasLayer(torch.nn.Module):
    '''Adds one learnable bias tensor of the given shape to every sample of a batch.'''
    def __init__(self, shape) -> None:
        super().__init__()
        # The bias starts from standard-normal noise; wrapping it in Parameter
        # registers it with the module so optimizers and state_dict see it.
        self.bias_layer = torch.nn.Parameter(torch.randn(shape))

    def forward(self, x):
        # A leading singleton axis lets the bias broadcast over the batch axis.
        return x + self.bias_layer.unsqueeze(0)

class PICNN(nn.Module):
    '''A convolutional neural network to approximate the forward model.

    Architecture: a learnable per-pixel bias (fc1), seven residual stages, one
    9x9 convolution (conv15) and a final 1x1 convolution (fc16). Stage k
    (k = 1..7) computes

        y <- act( conv{2k}( act( conv{2k-1}(y) ) ) + res{k}(y) )

    where the conv layers are 9x9 with padding 4 and res{k} is a 1x1 projection,
    so the spatial size never changes. SiLU is the activation throughout.

    Parameters
        resolution: side length N of the square input grid (fc1's bias has shape
                    (channels[0], N, N), so inputs must be N x N).
        channels:   sequence of at least 16 channel counts; channels[0] is the
                    input width and channels[15] the output width.

    Input:  (batch, channels[0], N, N).  Output: (batch, channels[15], N, N).

    Layer attribute names and creation order are the same as in the original
    hand-unrolled definition, so state_dict keys, parameter order and seeded
    initialisation are unchanged. The pool{2k} modules are created but not
    applied in forward (they hold no parameters).
    '''
    _N_STAGES = 7

    def __init__(self, resolution=32, channels=(1, 1, 2, 2, 4, 4, 8, 8, 16, 16, 8, 8, 4, 4, 2, 2), *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.N = resolution
        self.Np = resolution*resolution
        # Copy into a fresh list: the default is immutable and a caller's
        # sequence is never aliased, so instances cannot affect each other.
        self.channels = list(channels)
        c = self.channels
        self.fc1 = BiasLayer((c[0], self.N, self.N))
        for k in range(1, self._N_STAGES + 1):
            c_in, c_mid, c_out = c[2*k - 2], c[2*k - 1], c[2*k]
            setattr(self, 'conv{}'.format(2*k - 1),
                    nn.Conv2d(c_in, c_mid, kernel_size=9, stride=1, padding=4, bias=True))
            setattr(self, 'conv{}'.format(2*k),
                    nn.Conv2d(c_mid, c_out, kernel_size=9, stride=1, padding=4, bias=True))
            # 1x1 projection so the skip path matches the stage's output width.
            setattr(self, 'res{}'.format(k),
                    nn.Conv2d(c_in, c_out, kernel_size=1, stride=1))
            setattr(self, 'pool{}'.format(2*k),
                    nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
        self.conv15 = nn.Conv2d(c[14], c[15], kernel_size=9, stride=1,
                                padding=4, bias=True)
        self.fc16 = nn.Conv2d(c[15], c[15], kernel_size=1, stride=1,
                              padding=0, bias=True)
        self.act = nn.SiLU()

    def forward(self, x):
        y = self.act(self.fc1(x))
        for k in range(1, self._N_STAGES + 1):
            first = getattr(self, 'conv{}'.format(2*k - 1))
            second = getattr(self, 'conv{}'.format(2*k))
            skip = getattr(self, 'res{}'.format(k))
            y = self.act(second(self.act(first(y))) + skip(y))
        y = self.act(self.conv15(y))
        # No activation after the last layer: the output is unconstrained.
        return self.fc16(y)
        

In [10]:
# Hyper-parameters for the forward-model training runs.
configfw = dict(
    n_epochsup=1500,         # number of training epochs for supervised learning
    n_epochuns=250,          # number of training epochs for unsupervised learning
    n_epochsem=400,          # number of training epochs for semisupervised learning
    batch_size=32,           # size of a mini-batch
    learning_rate=1.8e-3,    # learning rate
    learning_rate_sem=5e-5,  # learning rate for semisupervised learning
    ema_decay=0.999,         # decay rate for Exponential Moving Average
    lr_decay=0.9,            # factor applied to the learning rate when the loss plateaus
    lr_threshold=1e-5,       # relative improvement that counts as progress for the plateau test
    lr_min=5e-5,             # lower bound on the learning rate
    lr_min_sem=5e-6,         # lower bound on the learning rate for semisupervised learning
)

In [11]:
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm.notebook import trange
class FM():
    '''The forward model: a PICNN surrogate plus its training and evaluation routines.

    Parameters
        pdeloss: callable (prediction, E_batch) -> scalar loss; the unsupervised term.
        suploss: callable (prediction, u_ref, u_obs) -> scalar loss; the supervised term.
        Mobst:   observation matrix handed to Observe.
        config:  dict providing 'n_epochsup', 'n_epochuns', 'n_epochsem',
                 'batch_size', 'learning_rate', 'learning_rate_sem', 'lr_decay',
                 'lr_threshold', 'lr_min' and 'lr_min_sem'.

    The network is trained on targets in the scale of its raw output; evaluate()
    multiplies that output by 0.01.
    '''
    def __init__(self, pdeloss, suploss, Mobst, config) -> None:
        self.network = PICNN().to(device)
        self.obs = Observe(Mobst).to(device)
        self.pdeloss = pdeloss
        self.suploss = suploss
        self.n_epochsup = config['n_epochsup']
        self.n_epochuns = config['n_epochuns']
        self.n_epochsem = config['n_epochsem']
        self.batch_size = config['batch_size']
        self.lr = config['learning_rate']
        self.lr_s = config['learning_rate_sem']
        self.config = config

    def _load_checkpoint(self, filename):
        '''Load network weights from `filename` in the working directory; switch to eval mode.'''
        path = os.path.join(os.path.abspath('.'), filename)
        # map_location keeps a checkpoint written on a GPU loadable on a CPU-only machine.
        self.network.load_state_dict(torch.load(path, map_location=device))
        self.network.eval()

    def _make_optimizer(self, lr, min_lr):
        '''Return (Adam optimizer, ReduceLROnPlateau scheduler) with the shared plateau settings.'''
        optimizer = Adam(self.network.parameters(), lr=lr)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=self.config['lr_decay'], patience=120,
                                      threshold=self.config['lr_threshold'], threshold_mode='rel',
                                      cooldown=200, min_lr=min_lr)
        return optimizer, scheduler

    def load_para_sup(self):
        '''Load the checkpoint written by supervised_train.'''
        self._load_checkpoint('NNfw_para_sup.pth')
    def load_para_uns(self):
        '''Load the checkpoint written by unsupervised_train.'''
        self._load_checkpoint('NNfw_para_uns.pth')
    def load_para(self):
        '''Load the checkpoint written by semisupervised_train.'''
        self._load_checkpoint('NNfw_para.pth')

    def evaluate(self, E):
        '''Predict the displacement for E, returning the network output times 0.01.

        E may be 2-D (a single field), 3-D (a batch of fields) or already carry
        batch and channel axes; missing axes are added in front.
        '''
        self.network.eval()
        if E.dim() == 2:
            Et  = E[None,None,...]
        elif E.dim() == 3:
            Et = E[:,None,...]
        else:
            Et = E
        return self.network(Et)*0.01

    def observe(self, E):
        '''Apply the observation block to evaluate(E).'''
        u = self.evaluate(E)
        return self.obs(u)

    def derivative(self, E):
        '''Jacobian of the flattened observation of the raw network output w.r.t. E.

        E must hold exactly N*N entries (one sample). The differentiated quantity
        is self.obs(self.network(E)), i.e. WITHOUT the 0.01 factor that evaluate()
        applies.

        Returns a CPU tensor of shape (n_out, N*N), where n_out is the number of
        entries of the observation flattened in (component, observation) order.
        Row i holds the gradient of flattened entry i.

        NOTE(review): the previous implementation could not run (backward() on a
        non-scalar, 2-D indexing of a 1-D gradient, loop over the batch axis); the
        row layout above is the flattening the discarded obs.reshape(-1) pointed to.
        '''
        self.network.eval()
        # Fresh leaf copy so gradients can be taken without touching the caller's tensor.
        Et = E.detach().reshape(-1).clone().requires_grad_(True)
        Er = Et.reshape((1, 1, self.network.N, self.network.N))
        out = self.obs(self.network(Er)).reshape(-1)
        J = torch.empty(out.shape[0], self.network.Np)
        for i in range(out.shape[0]):
            # autograd.grad returns the gradient directly and leaves the network
            # parameters' .grad fields untouched.
            (g,) = torch.autograd.grad(out[i], Et, retain_graph=True)
            J[i, :].copy_(g)
        return J

    def supervised_train(self, dataset):
        '''Train on (E, u_ref, u_obs) triples with suploss; checkpoint to NNfw_para_sup.pth each epoch.'''
        data_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

        optimizer, scheduler = self._make_optimizer(self.lr, self.config['lr_min'])
        tqdm_epoch = trange(self.n_epochsup)

        self.network.train()
        for _ in tqdm_epoch:
            avg_loss = 0.
            num_items = 0
            for x, y, z in data_loader:
                # Batches are already tensors: move them instead of re-wrapping
                # with torch.tensor(), which warns and makes an extra copy.
                x, y, z = x.to(device), y.to(device), z.to(device)
                yp = self.network(x)
                loss = self.suploss(yp, y, z)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Stepped once per mini-batch, so patience/cooldown count batches.
                scheduler.step(loss)
                avg_loss += loss.item() * x.shape[0]
                num_items += x.shape[0]
            # Print the averaged training loss so far.
            tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items * 1000))
            # Update the checkpoint after each epoch of training.
            torch.save(self.network.state_dict(), 'NNfw_para_sup.pth')
        self.network.eval()

    def unsupervised_train(self, dataset):
        '''Train on (E, placeholder) pairs with pdeloss; checkpoint to NNfw_para_uns.pth each epoch.'''
        data_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

        optimizer, scheduler = self._make_optimizer(self.lr, self.config['lr_min'])
        tqdm_epoch = trange(self.n_epochuns)

        self.network.train()
        for _ in tqdm_epoch:
            avg_loss = 0.
            num_items = 0
            # The second dataset entry is a placeholder and is ignored.
            for x, _unused in data_loader:
                x = x.to(device)
                yp = self.network(x)
                loss = self.pdeloss(yp, x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step(loss)
                avg_loss += loss.item() * x.shape[0]
                num_items += x.shape[0]
            # Print the averaged training loss so far.
            tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items * 10))
            # Update the checkpoint after each epoch of training.
            torch.save(self.network.state_dict(), 'NNfw_para_uns.pth')
        self.network.eval()

    def semisupervised_train(self, dataset_sup, dataset_uns, if_pretrained = True):
        '''Train on 5*suploss + 0.05*pdeloss; checkpoint to NNfw_para.pth each epoch.

        One epoch is one pass over dataset_uns. Supervised batches are drawn from
        an iterator over dataset_sup that persists across epochs and restarts
        when exhausted. With if_pretrained the smaller learning rate
        ('learning_rate_sem') and floor ('lr_min_sem') are used.
        '''
        data_loader_sup = DataLoader(dataset_sup, batch_size=self.batch_size, shuffle=True, num_workers=4)
        data_loader_uns = DataLoader(dataset_uns, batch_size=self.batch_size, shuffle=True, num_workers=4)

        if if_pretrained:
            optimizer, scheduler = self._make_optimizer(self.lr_s, self.config['lr_min_sem'])
        else:
            optimizer, scheduler = self._make_optimizer(self.lr, self.config['lr_min'])
        tqdm_epoch = trange(self.n_epochsem)

        self.network.train()
        dataloader_iterator = iter(data_loader_sup)
        for _ in tqdm_epoch:
            avg_loss = 0.
            avg_loss1 = 0.
            num_items = 0
            num_items1 = 0
            for x, _unused in data_loader_uns:
                try:
                    x1, y1, z1 = next(dataloader_iterator)
                except StopIteration:
                    # Supervised set exhausted: restart it and restart its running average.
                    dataloader_iterator = iter(data_loader_sup)
                    avg_loss1 = 0.
                    num_items1 = 0
                    x1, y1, z1 = next(dataloader_iterator)
                x = x.to(device)
                x1, y1, z1 = x1.to(device), y1.to(device), z1.to(device)
                yp = self.network(x)
                yp1 = self.network(x1)
                loss = self.pdeloss(yp, x)
                loss1 = self.suploss(yp1, y1, z1)
                losst = 5*loss1 + 0.05*loss
                optimizer.zero_grad()
                losst.backward()
                optimizer.step()
                scheduler.step(losst)
                avg_loss += loss.item() * x.shape[0]
                avg_loss1 += loss1.item() * x1.shape[0]
                num_items += x.shape[0]
                num_items1 += x1.shape[0]
            tqdm_epoch.set_description('Average Loss: {:5f} = {:5f} + {:5f}'.format(
                (avg_loss/num_items+avg_loss1/num_items1)*1000, avg_loss/num_items*10, avg_loss1/num_items1*1000))
            torch.save(self.network.state_dict(), 'NNfw_para.pth')
        self.network.eval()
        

In [12]:
from torch.utils.data import TensorDataset

# Surrogate model wired to the physics loss and the supervised loss.
Forward_Model = FM(loss_sed, loss_sup, Mobst, configfw)

# Labelled set: the first len(u_sam) parameter samples with their displacements
# and observations, both multiplied by 100.
n_labelled = u_sam.shape[0]
dataset_sup = TensorDataset(sam[:n_labelled, ...], u_sam*100, obs_sam*100)

# Unlabelled set: all parameter samples, clamped to [0, 1]; the second tensor is
# an uninitialised placeholder so the dataset yields (x, y) pairs.
n_unlabelled = 2**15
sam_uns = torch.cat((sam, sam_ve), 0).clamp(0.0, 1.0)
dataset_uns = TensorDataset(sam_uns[:n_unlabelled], torch.empty(n_unlabelled))
#Forward_Model.supervised_train(dataset_sup)
#Forward_Model.load_para_sup()
#Forward_Model.unsupervised_train(dataset_uns)
#Forward_Model.load_para_sup()
#Forward_Model.semisupervised_train(dataset_sup, dataset_uns)

In [13]:
# Held-out test set; targets are multiplied by 100, the same scaling dataset_sup uses.
dataset_T = TensorDataset(Tsam, u_Tsam*100, obs_Tsam*100)
def test_FM(Forward_Model, dataset):
    '''Evaluate a forward model on a dataset of (E, u_ref, u_obs) triples.

    Prints the sample-weighted average of loss_sup (times 1000) followed by
    per-observation-entry statistics of the observation error obs(prediction) - u_obs,
    taken over all samples: RMS, standard deviation, max |error|, min |error| and the
    95% / 50% / 5% quantiles of |error|.

    Returns (mean of the RMS errors, mean of the standard deviations).
    Uses the module-level `loss_sup`, `obs` and `device`.
    '''
    data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    Forward_Model.network.eval()
    avg_loss = 0.
    num_items = 0
    diffs = []
    with torch.no_grad():
        for x, y, z in data_loader:
            # Batches are already tensors: move them instead of re-wrapping with
            # torch.tensor(), which warns and makes an extra copy.
            x, y, z = x.to(device), y.to(device), z.to(device)
            yp = Forward_Model.network(x)
            loss = loss_sup(yp, y, z)
            avg_loss += loss.item() * x.shape[0]
            num_items += x.shape[0]
            diffs.append((obs(yp) - z).cpu())
        # Concatenate once at the end rather than growing a tensor every batch.
        diff = torch.cat(diffs, 0)
        qtl = torch.tensor([0.95, 0.5, 0.05])
        err_mean = torch.sqrt(torch.mean(diff**2, 0))
        err_std = torch.std(diff, 0)
        err_max = torch.sqrt(torch.amax(diff**2, 0))
        err_min = torch.sqrt(torch.amin(diff**2, 0))
        err_q = torch.sqrt(torch.quantile(diff**2, qtl, dim=0))
        print('Average Test Loss: {:5f}'.format(avg_loss/num_items * 1000))
        print(err_mean)
        print(err_std)
        print(err_max)
        print(err_min)
        print(err_q)
    return torch.mean(err_mean), torch.mean(err_std)

In [14]:
# Test-set error of the checkpoint stored in NNfw_para.pth (written by semisupervised_train).
Forward_Model.load_para()
err_mean, err_std = test_FM(Forward_Model, dataset_T)
print(err_mean, err_std)

  x = torch.tensor(x, device=device) # x.to(device)
  y = torch.tensor(y, device=device)
  z = torch.tensor(z, device=device)


Average Test Loss: 3.567199
tensor([[0.0172, 0.0069, 0.0196, 0.0080, 0.0183, 0.0097, 0.0103, 0.0132, 0.0175,
         0.0139, 0.0138, 0.0133, 0.0131, 0.0148, 0.0135, 0.0156, 0.0147, 0.0161,
         0.0156, 0.0172, 0.0165, 0.0170, 0.0174, 0.0174, 0.0193, 0.0184, 0.0220,
         0.0194, 0.0255, 0.0210, 0.0301, 0.0230, 0.0345, 0.0263, 0.0380, 0.0310,
         0.0391, 0.0375, 0.0430, 0.0399, 0.0454, 0.0413, 0.0448, 0.0436, 0.0448,
         0.0460, 0.0469, 0.0503, 0.0506, 0.0538, 0.0538, 0.0588, 0.0576, 0.0655,
         0.0618, 0.0678, 0.0617, 0.0710, 0.0645, 0.0796, 0.0809, 0.0740, 0.0743,
         0.0722, 0.0661, 0.0647, 0.0643, 0.0628, 0.0619, 0.0600, 0.0597, 0.0582,
         0.0557, 0.0525, 0.0523, 0.0527, 0.0524, 0.0517, 0.0527, 0.0538, 0.0562,
         0.0591, 0.0600, 0.0598, 0.0609, 0.0616, 0.0619, 0.0671, 0.0707, 0.0740,
         0.0783, 0.0977],
        [0.0261, 0.0205, 0.0283, 0.0234, 0.0252, 0.0218, 0.0153, 0.0191, 0.0196,
         0.0177, 0.0153, 0.0175, 0.0173, 0.0159, 0.0192

In [15]:
# Test-set error of the checkpoint stored in NNfw_para_sup.pth (written by supervised_train).
Forward_Model.load_para_sup()
err_mean, err_std = test_FM(Forward_Model, dataset_T)
print(err_mean, err_std)

Average Test Loss: 5.315718
tensor([[0.0127, 0.0072, 0.0164, 0.0092, 0.0133, 0.0135, 0.0111, 0.0142, 0.0185,
         0.0136, 0.0144, 0.0216, 0.0162, 0.0175, 0.0152, 0.0167, 0.0162, 0.0178,
         0.0173, 0.0222, 0.0192, 0.0212, 0.0199, 0.0219, 0.0216, 0.0223, 0.0253,
         0.0267, 0.0317, 0.0289, 0.0376, 0.0308, 0.0413, 0.0343, 0.0445, 0.0390,
         0.0484, 0.0478, 0.0527, 0.0509, 0.0544, 0.0503, 0.0543, 0.0533, 0.0544,
         0.0560, 0.0561, 0.0607, 0.0597, 0.0650, 0.0597, 0.0690, 0.0646, 0.0790,
         0.0688, 0.0803, 0.0718, 0.0864, 0.0779, 0.0966, 0.0992, 0.0890, 0.0958,
         0.0866, 0.0810, 0.0768, 0.0759, 0.0752, 0.0720, 0.0727, 0.0688, 0.0713,
         0.0685, 0.0674, 0.0672, 0.0636, 0.0627, 0.0651, 0.0666, 0.0681, 0.0696,
         0.0749, 0.0750, 0.0770, 0.0792, 0.0789, 0.0784, 0.0869, 0.0848, 0.0899,
         0.0972, 0.1209],
        [0.0191, 0.0145, 0.0273, 0.0185, 0.0264, 0.0200, 0.0152, 0.0180, 0.0192,
         0.0187, 0.0167, 0.0194, 0.0182, 0.0185, 0.0200

  x = torch.tensor(x, device=device) # x.to(device)
  y = torch.tensor(y, device=device)
  z = torch.tensor(z, device=device)


In [16]:
# Probe the NNfw_para.pth checkpoint with two random Gaussian inputs (unit scale
# and 10x scale) and print the observation of the raw network output for each.
Forward_Model.load_para()
samran = torch.randn(1, 1, 32, 32, device=device)
samran2 = torch.randn(1, 1, 32, 32, device=device)*10
yp1 = Forward_Model.network(samran)
yp2 = Forward_Model.network(samran2)
o1 = obs(yp1)
o2 = obs(yp2)
print(o1)
print(o2)

tensor([[[-1.6675e-01,  4.8080e-02, -2.3354e-01, -3.2537e-04, -2.2117e-01,
          -1.1583e-01, -1.4750e-01, -3.6622e-01, -2.6865e-02, -7.3794e-01,
           7.1534e-02, -1.2144e+00,  1.7536e-01, -1.8057e+00,  3.0486e-01,
          -2.3247e+00,  4.9611e-01, -2.6890e+00,  7.0414e-01, -3.1587e+00,
           8.9178e-01, -3.4436e+00,  1.0413e+00, -3.3143e+00,  1.1463e+00,
          -2.6680e+00,  1.2334e+00, -1.6627e+00,  1.3667e+00, -7.3424e-01,
           1.5868e+00, -3.6029e-01,  1.8900e+00, -2.4546e-01,  2.2503e+00,
          -2.7555e-01,  2.6248e+00, -2.6100e-01,  2.9327e+00, -1.7228e-01,
           3.1716e+00, -1.7399e-01,  3.3772e+00, -4.5975e-01,  3.5024e+00,
          -9.6460e-01,  3.4526e+00, -1.4400e+00,  3.1929e+00, -1.3165e+00,
           2.8034e+00, -5.8132e-01,  2.4184e+00,  2.2955e-01,  2.1507e+00,
           7.5141e-01,  2.0038e+00,  5.3017e-01,  2.0271e+00, -2.7533e-01,
           2.1479e+00,  2.2445e+00,  2.2200e+00,  2.1803e+00,  1.9143e+00,
           1.8649e+00,  1