## Mscale FNO for wave prediction 

This notebook implements multi-task learning for the multi-scale (Mscale) FNO 3D model. 

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='2'
from neuralop.models import TFNO,FNO
import torch
import torch.nn.functional as F
import numpy as np
torch.manual_seed(0)
np.random.seed(0)
import logging
import datetime
import glob
from utilities3 import *

In [None]:
def read_data(file_paths, inlen=16, outlen=16):
    """Load the `eta` and `scatter` fields from .mat files and cut them into 32-step windows.

    For every file, the last axis of both fields is cropped to indices 117..436
    (320 steps) and split into 10 consecutive windows of 32 steps. The windows are
    stacked along axis 0, and the per-file results are concatenated in file order.

    Args:
        file_paths: sequence of .mat file paths readable by `MatReader`.
        inlen, outlen: kept for interface compatibility; they are not used here
            because the window length is fixed at 32 steps.

    Returns:
        (eta, scatter, x, y, t): the stacked windows of both fields, plus the
        flattened `x`, `y` fields and the first 32 entries of `t` taken from the
        last file that was read.

    Raises:
        ValueError: if `file_paths` is empty.
    """
    file_paths = list(file_paths)
    if not file_paths:
        raise ValueError("read_data needs at least one file path")

    # Collect the pieces in lists and concatenate once at the end. The previous
    # `if eta == []` check turned into an elementwise ndarray-vs-list comparison
    # as soon as a second file was processed.
    eta_parts = []
    scatter_parts = []
    for f in file_paths:
        reader = MatReader(f, to_torch=False)
        eta_tmp = reader.read_field('eta')[:, :, :, 117:437]
        scatter_tmp = reader.read_field('scatter')[:, :, :, 117:437]
        # 10 consecutive windows of 32 steps each, stacked along the sample axis.
        for i in range(10):
            eta_parts.append(eta_tmp[:, :, :, 32*i:32*(i+1)])
        for i in range(10):
            scatter_parts.append(scatter_tmp[:, :, :, 32*i:32*(i+1)])
        x = reader.read_field('x').flatten()
        y = reader.read_field('y').flatten()
        t = reader.read_field('t').flatten()
        t = t[0:32]

    eta = np.concatenate(eta_parts, axis=0)
    scatter = np.concatenate(scatter_parts, axis=0)
    return eta, scatter, x, y, t

def split_data(eta, scatter, train_test_split=[0.6,0.2,0.2], inlen=16, outlen=16, sub=2, to_torch=False):
    """Split samples along axis 0 into train / validation / test pieces.

    Each piece yields three arrays, all sub-sampled by `sub` on axes 1 and 2:
      * `a`: `scatter` restricted to the first `inlen` steps of the last axis,
      * `e`: `eta` restricted to the first `inlen` steps,
      * `u`: `eta` restricted to steps `inlen .. inlen+outlen-1`.

    Args:
        eta, scatter: 4-D arrays indexed as [sample, axis1, axis2, step].
        train_test_split: fractions `[train, val, test]`. The train and
            validation sizes are `int(n * fraction)`; the test piece takes all
            remaining samples. The default list is never mutated.
        inlen, outlen: number of input and output steps.
        sub: spatial sub-sampling stride.
        to_torch: convert every returned array with `torch.from_numpy`.

    Returns:
        (train_a, train_e, train_u, val_a, val_e, val_u, test_a, test_e, test_u)
    """
    n_samples = eta.shape[0]
    train_end = int(n_samples * train_test_split[0])
    # The validation size comes from the validation fraction (index 1). It used
    # to be computed from the test fraction (index 2), which swapped the two
    # sizes whenever the fractions differed.
    val_end = train_end + int(n_samples * train_test_split[1])

    def _cut(rows):
        # Extract the (a, e, u) triple for one contiguous range of samples.
        a = scatter[rows, ::sub, ::sub, :inlen]
        e = eta[rows, ::sub, ::sub, :inlen]
        u = eta[rows, ::sub, ::sub, inlen:inlen+outlen]
        if to_torch:
            a, e, u = torch.from_numpy(a), torch.from_numpy(e), torch.from_numpy(u)
        return a, e, u

    train_a, train_e, train_u = _cut(slice(0, train_end))
    val_a, val_e, val_u = _cut(slice(train_end, val_end))
    test_a, test_e, test_u = _cut(slice(val_end, None))
    return train_a, train_e, train_u, val_a, val_e, val_u, test_a, test_e, test_u

# Read and split data: every data file is split into train/val/test on its own and the
# pieces are concatenated. The hs and spread tags parsed from each file name are repeated
# once per test sample of that file, so they have the same length as the test arrays.
def read_and_split(file_paths, train_test_split=[0.6,0.2,0.2], inlen=16, outlen=16, sub=2, to_torch=False):
    """Load several data files, split each one separately and concatenate the pieces.

    Args:
        file_paths: sequence of .mat paths. Each path must contain
            `angle<spread>h<hs>.mat`, from which the two tags are parsed as floats.
        train_test_split, inlen, outlen, sub: forwarded to `split_data`
            (`inlen`/`outlen` also to `read_data`).
        to_torch: if true, every returned array except `x`, `y`, `t` is converted
            with `torch.from_numpy`.

    Returns:
        (train_a, train_e, train_u, val_a, val_e, val_u, test_a, test_e, test_u,
         x, y, t, hs, spread), where `x`, `y`, `t` come from the last file and
        `hs` / `spread` hold one value per test sample.

    Raises:
        ValueError: if `file_paths` is empty.
    """
    file_paths = list(file_paths)
    if not file_paths:
        raise ValueError("read_and_split needs at least one file path")

    split_parts = [[] for _ in range(9)]  # one bucket per array returned by split_data
    hs_parts = []
    spread_parts = []
    for f in file_paths:
        eta, scatter, x, y, t = read_data([f], inlen=inlen, outlen=outlen)
        # Always split as NumPy and convert once at the end. Converting inside
        # split_data made the final torch.from_numpy fail for a single file.
        pieces = split_data(eta, scatter, train_test_split=train_test_split,
                            inlen=inlen, outlen=outlen, sub=sub, to_torch=False)
        for bucket, piece in zip(split_parts, pieces):
            bucket.append(piece)

        tag = f.split('angle')[1]
        spread_value = float(tag.split('h')[0])
        hs_value = float(tag.split('h')[1].split('.mat')[0])
        n_test = pieces[6].shape[0]  # pieces[6] is test_a
        hs_parts.append(hs_value * np.ones(n_test))
        spread_parts.append(spread_value * np.ones(n_test))

    arrays = [np.concatenate(bucket) for bucket in split_parts]
    hs = np.concatenate(hs_parts)
    spread = np.concatenate(spread_parts)

    if to_torch:
        arrays = [torch.from_numpy(a) for a in arrays]
        hs = torch.from_numpy(hs)
        spread = torch.from_numpy(spread)

    train_a, train_e, train_u, val_a, val_e, val_u, test_a, test_e, test_u = arrays
    return train_a, train_e, train_u, val_a, val_e, val_u, test_a, test_e, test_u, x, y, t, hs, spread

# Collect every data file from both directories. os.listdir returns entries in
# arbitrary order, so the paths are sorted to keep the sample order (and the
# per-sample hs / spread tags) the same on every run and machine.
dirs = sorted(os.path.join("./mixed_data", d) for d in os.listdir("./mixed_data"))
# add another dir
dirs2 = sorted(os.path.join("./mixed_data2", d) for d in os.listdir("./mixed_data2"))
dirs = dirs + dirs2
train_a, train_e, train_u, val_a, val_e, val_u, test_a, test_e, test_u,x_domain,y_domain,t_domain,hs, spread = read_and_split(dirs,inlen=16,outlen=16,sub=2,to_torch=True)

**Change your config here**

In [None]:
# parameter and logging setup
modesxy = 32 # modes in x,y dimension
modest = 8 # modes in t dimension
width = 32 # width in FNO layer
n_net = 8 # number of sub-network
multiscale = True # is multi-scale
cnnfusion = True # is using conv-block
learning_rate = 0.001 
epochs = 500
batch_size = 20
iterations = epochs*(train_u.shape[0]//batch_size)
prefix = "conv_kernel3_layer3" # experiment name prefix
path = 'exp/'+prefix+ '_m' + str(modesxy) + str(modest) + '_w' + str(width)+ 'multi'+str(multiscale) + '_n' + str(n_net) + '_c' + str(cnnfusion) 
# makedirs also creates the parent 'exp' directory and tolerates an existing one.
os.makedirs(path, exist_ok=True)
path_model = path+'/model'
path_train_err = path+'/train.txt'
path_test_err = path+'/test.txt'
path_image = path+'/image'
os.makedirs(path_image, exist_ok=True)
logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
# logging.getLogger returns the same object on every call, so re-running this
# cell would stack another pair of handlers and duplicate every log line.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
    old_handler.close()
sh = logging.StreamHandler()
log_file = path+'/train'+ datetime.datetime.now().strftime('%Y%m%d_%H%M%S')+'.log'
fh = logging.FileHandler(log_file,encoding="UTF-8")
formator = logging.Formatter(fmt = '%(asctime)s : %(message)s')
sh.setFormatter(formator)
fh.setFormatter(formator)
logger.addHandler(sh)
logger.addHandler(fh)

In [4]:
sub = 2  
S = 128 // sub  # spatial size after sub-sampling
T_in = 16       # number of input steps
T = 16          # number of output steps

ntrain = train_a.shape[0]
nval = val_a.shape[0]
ntest = test_a.shape[0]


def _tile_over_time(a, n):
    """Reshape `a` to (n, S, S, 1, T_in) and copy it T times along the new axis.

    A tensor that is already 5-D is returned unchanged, so re-running this cell
    does not fail on the reshape.
    """
    if a.dim() == 5:
        return a
    return a.reshape(n, S, S, 1, T_in).repeat([1, 1, 1, T, 1])


train_a = _tile_over_time(train_a, ntrain)
val_a = _tile_over_time(val_a, nval)
test_a = _tile_over_time(test_a, ntest)

In [None]:
# Not used in this notebook; plug it into your own experiment if you want a compact activation.
class CompactActivation(nn.Module):
    """Activation built from four shifted, squared ReLUs.

    Computes relu(x)^2 - 3*relu(x-1)^2 + 3*relu(x-2)^2 - relu(x-3)^2 element-wise.
    """

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        # Squared ReLU at shifts 0, 1, 2, 3, combined with weights +1, -3, +3, -1.
        squared = [self.relu(x) ** 2]
        for shift in (1, 2, 3):
            squared.append(self.relu(x - shift) ** 2)
        return squared[0] - 3 * squared[1] + 3 * squared[2] - squared[3]


# Vanilla multi-scale FNO: a weighted sum of n_net FNO sub-networks, each fed a differently
# scaled copy of the input. Author's note: the parallel forward variants do not seem to work.
class multiscaleFNO(nn.Module):
    """Weighted sum of `n_net` FNO sub-networks applied to scaled inputs.

    Sub-network i receives `cs[i] * x` and its output is multiplied by `gammas[i]`.
    `cs` starts at 2**-1, 2**0, ..., 2**(n_net-2) and `gammas` at 1/n_net; both are
    trainable parameters. Each sub-network takes `T_in + 3` input channels (`T_in`
    is a notebook-level global) and produces one output channel.
    """

    def __init__(self,modes1,modes2,modes3,width,n_net):
        super().__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.n_net = n_net

        subnets = []
        for _ in range(n_net):
            subnets.append(
                FNO(n_modes=(self.modes1, self.modes2, self.modes3),
                    n_layers=4,
                    hidden_channels=self.width,
                    lifting_channels=2*self.width,
                    projection_channels=2*self.width,
                    in_channels=T_in+3,
                    out_channels=1)
            )
        self.fno_subnets = nn.ModuleList(subnets)

        # Input scales, initialised to consecutive powers of two.
        initial_scales = torch.pow(2, torch.arange(-1, n_net-1, dtype=torch.float32))
        self.cs = torch.nn.Parameter(initial_scales, requires_grad=True)
        # Output weights, initialised to equal weights 1/n_net.
        self.gammas = torch.nn.Parameter(torch.tensor([1.0/n_net]*n_net), requires_grad=True)

    def forward(self,x):
        """Sequential evaluation: accumulate gamma_i * subnet_i(c_i * x)."""
        total = 0
        for idx in range(self.n_net):
            branch = self.fno_subnets[idx](self.cs[idx] * x)
            total = total + self.gammas[idx] * branch
        return total

    def forward2(self,x):
        """Same sum, with the scaling and weighting done by broadcasting."""
        scales = self.cs.view(self.n_net, 1, 1, 1, 1, 1)
        scaled_x = scales * x.unsqueeze(0)  # [n_net, batch, channel, x, y, t]
        branch_outputs = [self.fno_subnets[i](scaled_x[i]) for i in range(self.n_net)]
        stacked = torch.stack(branch_outputs, dim=0)  # [n_net, batch, out_channels, x, y, t]

        weights = self.gammas.view(self.n_net, 1, 1, 1, 1, 1)
        return (weights * stacked).sum(dim=0)

    def forward3(self,x):
        """Same sum, launching every sub-network through torch.jit.fork."""
        futures = [
            (torch.jit.fork(self.fno_subnets[i], self.cs[i] * x), i)
            for i in range(self.n_net)
        ]

        total = 0
        for fut, i in futures:
            total = total + self.gammas[i] * torch.jit.wait(fut)
        return total

    # jit
    def forward4(self, x):
        """Compact torch.jit.fork variant that iterates over the parameters directly."""
        futures = []
        for subnet, c in zip(self.fno_subnets, self.cs):
            futures.append(torch.jit.fork(subnet, x * c))

        outputs = []
        for fut in futures:
            outputs.append(torch.jit.wait(fut))

        return sum(g * out for g, out in zip(self.gammas, outputs))

# testcase, input shape [batch, chanel, S, S, T_in]
#input = torch.randn(10, 19, 32, 32, 16)
#model = multiscaleFNO(8, 8, 8, 32, n_net=4)
#output = model(input)
#output = model.forward4(input)
#print("input shape: ", input.shape)
#print("output shape: ", output.shape)

In [None]:
class Sin(nn.Module):
    """Element-wise sine activation."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Same result as torch.sin(x).
        return x.sin()

class res_conv_block(nn.Module):
    """Residual 3-D convolution stack mapping `n_net` channels to one channel.

    Layout: 1x1x1 conv (n_net -> 64), three residual stages of
    [3x3x3 conv + BatchNorm3d] followed by ReLU, then a 1x1x1 conv (64 -> 1).
    No activation is applied after the last convolution.
    """

    def __init__(self,n_net):
        super().__init__()
        self.conv0 = nn.Conv3d(in_channels=n_net, out_channels=64, kernel_size=1, padding=0)
        stages = []
        for _ in range(3):
            stages.append(
                nn.Sequential(
                    nn.Conv3d(64, 64, kernel_size=3, padding=1),
                    nn.BatchNorm3d(64),
                )
            )
        self.blocks = nn.ModuleList(stages)
        self.conv1 = nn.Conv3d(in_channels=64, out_channels=1, kernel_size=1, padding=0)

    def forward(self,x):
        hidden = self.conv0(x)
        for stage in self.blocks:
            skip = hidden
            hidden = stage(hidden)
            # In-place skip connection followed by an in-place ReLU.
            hidden += skip
            hidden = F.relu(hidden, inplace=True)
        return self.conv1(hidden)

class multiscaleFNOcnnFusion(nn.Module):
    """Multi-scale FNO whose sub-network outputs are fused by a small 3-D CNN.

    Sub-network i receives `cs[i] * x`, with `cs` a trainable parameter initialised
    to 2**-2, 2**-1, ..., 2**(n_net-3). The single-channel outputs are stacked as
    channels and passed through `fusion_net`, which returns one channel. Each
    sub-network takes `T_in + 3` input channels (`T_in` is a notebook-level global).
    """

    def __init__(self, modes1, modes2, modes3, width, n_net):
        super().__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width
        self.n_net = n_net

        subnets = []
        for _ in range(n_net):
            subnets.append(
                FNO(n_modes=(self.modes1, self.modes2, self.modes3),
                    n_layers=4,
                    hidden_channels=self.width,
                    lifting_channels=2*self.width,
                    projection_channels=2*self.width,
                    in_channels=T_in+3,
                    out_channels=1)
            )
        self.fno_subnets = nn.ModuleList(subnets)

        initial_scales = torch.pow(2, torch.arange(-2, n_net-2, dtype=torch.float32))
        self.cs = torch.nn.Parameter(initial_scales, requires_grad=True)

        # Alternative fusion head: self.fusion_net = res_conv_block(n_net)
        # Three 5x5x5 conv + BatchNorm + ReLU stages, then a 1x1x1 projection to one channel.
        fusion_layers = [
            nn.Conv3d(n_net, 32, kernel_size=5, padding=2),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.Conv3d(64, 32, kernel_size=5, padding=2),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 1, kernel_size=1),
        ]
        self.fusion_net = nn.Sequential(*fusion_layers)

    def forward(self, x):
        per_scale = []
        for idx in range(self.n_net):
            per_scale.append(self.fno_subnets[idx](self.cs[idx] * x))

        # Stack to [batch, n_net, 1, x, y, t], then drop the singleton channel axis
        # to get [batch, n_net, x, y, t].
        stacked = torch.stack(per_scale, dim=1)
        stacked = stacked.squeeze(2)
        return self.fusion_net(stacked)

# testcase, input shape [batch, chanel, S, S, T_in]
#input = torch.randn(10, 19, 64, 64, 16)
#model = multiscaleFNOcnnFusion(8, 8, 8, 32, n_net=8)
#output = model(input)
#print("input shape: ", input.shape)
#print("output shape: ", output.shape)

In [None]:
class FNO_multi(nn.Module):
    """Two-stage model: `fno_r` runs first and `fno_p` runs on its re-tiled output.

    Both stages share the same architecture, chosen by the flags:
      * multiscale and cnnfusion -> multiscaleFNOcnnFusion
      * multiscale only          -> multiscaleFNO
      * otherwise                -> a plain FNO with 19 input channels

    `forward` returns `(fno_r output, fno_p output)`.
    """

    def __init__(self, modes1, modes2, modes3, width, n_net=8, multiscale=False,cnnfusion=False):
        super().__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.width = width

        # An earlier variant (kept for reference) built both stages as
        # TFNO(n_modes=(modes1, modes2, modes3), hidden_channels=width, in_channels=19,
        #      out_channels=1, factorization='tucker', implementation='factorized',
        #      n_layers=4, rank=0.05)

        def build_stage():
            # Build one stage according to the multiscale / cnnfusion flags.
            if multiscale and cnnfusion:
                return multiscaleFNOcnnFusion(self.modes1, self.modes2, self.modes3, self.width, n_net=n_net)
            if multiscale:
                return multiscaleFNO(self.modes1, self.modes2, self.modes3, self.width, n_net=n_net)
            return FNO(n_modes=(self.modes1, self.modes2, self.modes3),
                       n_layers=4,
                       hidden_channels=self.width,
                       lifting_channels=2*self.width,
                       projection_channels=2*self.width,
                       in_channels=19,
                       out_channels=1)

        self.fno_r = build_stage()
        self.fno_p = build_stage()

    def forward(self,x):
        first = self.fno_r(self.get_grid(x))
        # Move the channel axis of the first output to position 3 and repeat it
        # 16 times there before appending the grid again.
        second_in = first.permute(0, 2, 3, 1, 4).repeat([1, 1, 1, 16, 1])
        second = self.fno_p(self.get_grid(second_in))
        return first, second

    def get_grid(self,x):
        """Append x/y/t coordinate channels and move channels to axis 1.

        Input is [batch, x, y, t, c]; output is [batch, c+3, x, y, t]. Coordinates
        run from 0 to the extent of the notebook-level globals `x_domain`,
        `y_domain` and `t_domain`.
        """
        device = x.device
        batchsize, size_x, size_y, size_z = x.shape[:4]
        sizes = (size_x, size_y, size_z)
        domains = (x_domain, y_domain, t_domain)

        coords = []
        for axis, (n_points, domain) in enumerate(zip(sizes, domains)):
            line = torch.tensor(np.linspace(0, domain[-1]-domain[0], n_points), dtype=torch.float)
            # Put the coordinate on its own axis, then tile it over every other axis.
            view_shape = [1, 1, 1, 1, 1]
            view_shape[axis + 1] = n_points
            tiling = [batchsize, size_x, size_y, size_z, 1]
            tiling[axis + 1] = 1
            coords.append(line.reshape(view_shape).repeat(tiling))

        grid = torch.cat(coords, dim=-1).to(device)
        with_grid = torch.cat((x, grid), dim=-1)
        return with_grid.permute(0, 4, 1, 2, 3)

In [None]:
def ssp_loss(pred, true, ax=(1, 2)):
    """Spectral relative error between `pred` and `true`.

    Both inputs are Fourier-transformed over the axes `ax`. For every remaining
    index the ratio ||F(pred) - F(true)|| / (||F(pred)|| + ||F(true)||) is formed,
    averaged over the last remaining axis and summed over the rest.

    Args:
        pred, true: tensors of identical shape.
        ax: axes over which the FFT and the norms are taken.

    Returns:
        A scalar tensor. If both spectra are zero for some index the ratio is
        0/0 and the result is NaN.
    """
    fpred = torch.fft.fftn(pred, dim=ax)
    ftrue = torch.fft.fftn(true, dim=ax)
    # torch.norm is deprecated. The 2-norm over `ax` from torch.linalg.vector_norm
    # is the same Frobenius norm that torch.norm computed here.
    norm_error = torch.linalg.vector_norm(fpred - ftrue, ord=2, dim=ax)
    norm_pred = torch.linalg.vector_norm(fpred, ord=2, dim=ax)
    norm_true = torch.linalg.vector_norm(ftrue, ord=2, dim=ax)
    ssps = norm_error / (norm_pred + norm_true)
    ssp = torch.sum(torch.mean(ssps, dim=-1))
    return ssp

# Use the GPU when one is visible, otherwise fall back to the CPU, and move the
# model with .to(device) instead of a hard-coded .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FNO_multi(modesxy,modesxy,modest,width,n_net,multiscale,cnnfusion).to(device)
#print(count_params(model))
logger.info("models params: %d" % count_params(model))
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u,train_e), batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(val_a, val_u,val_e), batch_size=batch_size, shuffle=False)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# The scheduler is stepped once per batch, so its horizon must be the real number
# of optimizer steps. len(train_loader) also counts a smaller final batch, which
# ntrain // batch_size would miss.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs*len(train_loader))
myloss = LpLoss(size_average=False)
#myloss = F.mse_loss(reduction='mean')

2025-10-18 04:38:51,585 : models params: 134915490


In [None]:
from timeit import default_timer
gamma = 1.0  # weight of the e (outeta) loss term relative to the u (out) term
best_err = 1e10
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_mse_e = 0
    train_mse_u = 0
    train_l2_e = 0 
    train_l2_u = 0
    train_ssp_e = 0
    train_ssp_u = 0
    for x, y ,e in train_loader:
        x, y ,e = x.to(device), y.to(device), e.to(device)
        # Take the batch size from the batch itself so a smaller final batch
        # does not break the .view() calls below.
        bs = x.shape[0]
        optimizer.zero_grad()
        outeta,out = model(x)
        out = out.view(bs, S, S, T)
        outeta = outeta.view(bs,S,S,T)

        # MSE and SSP are only logged; the optimised loss is the sum of the two LpLoss terms.
        mse_u = F.mse_loss(out, y, reduction='mean')
        mse_e = F.mse_loss(outeta, e, reduction='mean')

        l2_u = myloss(out.view(bs, -1), y.view(bs, -1))
        l2_e = myloss(outeta.view(bs, -1),e.view(bs, -1))
        ssp_e = ssp_loss(outeta,e)
        ssp_u = ssp_loss(out,y)
        l2 = l2_u+gamma* l2_e#+ssp_e+gamma*ssp_u
        #l2 = mse_u + gamma*mse_e
        l2.backward()

        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch
        train_mse_e += mse_e.item()
        train_mse_u += mse_u.item()
        train_l2_e += l2_e.item()
        train_l2_u += l2_u.item()
        train_ssp_e += ssp_e.item()
        train_ssp_u += ssp_u.item()

    # Validation pass. The variables and log keys keep their historical "test_"
    # prefix, but the data comes from val_loader.
    model.eval()
    test_l2_e = 0.0
    test_l2_u = 0.0
    with torch.no_grad():
        for x, y, e in val_loader:
            x, y ,e = x.to(device), y.to(device), e.to(device)
            bs = x.shape[0]

            outeta,out = model(x)
            out = out.view(bs, S, S, T)
            outeta = outeta.view(bs,S,S,T)
#            out = y_normalizer.decode(out)
            test_l2_u += myloss(out.view(bs, -1), y.view(bs, -1)).item()
            test_l2_e += myloss(outeta.view(bs, -1),e.view(bs, -1)).item()

    train_mse_e /= len(train_loader)
    train_mse_u /= len(train_loader)
    train_l2_e /= ntrain
    train_l2_u /= ntrain
    train_ssp_e /= ntrain
    train_ssp_u /= ntrain
    # These sums run over the validation set, so normalise by nval (not ntest).
    test_l2_e /= nval
    test_l2_u /= nval

    t2 = default_timer()
    #print(ep, t2-t1, train_mse_e,train_mse_u, train_l2_e,train_l2_u,test_l2_e, test_l2_u)
    logger.info('epoch:%d, time:%f, train_mse_e:%f, train_mse_u:%f, train_l2_e:%f, train_l2_u:%f, train_ssp_e:%f, train_ssp_u:%f,test_l2_e:%f, test_l2_u:%f'%
               (ep, t2-t1, train_mse_e,train_mse_u, train_l2_e,train_l2_u,train_ssp_e,train_ssp_u,test_l2_e, test_l2_u))
    # Keep the checkpoint with the lowest validation error on u. torch.save(model, ...)
    # pickles the whole module, so loading it needs the class definitions in scope.
    if test_l2_u<best_err:
        best_err = test_l2_u
        torch.save(model, path_model)
        logger.info('best model saved')


2025-10-18 04:39:27,250 : epoch:0, time:35.643630, train_mse_e:1.963805, train_mse_u:0.814142, train_l2_e:2.046303, train_l2_u:1.266463, train_ssp_e:0.792451, train_ssp_u:0.841687,test_l2_e:1.048204, test_l2_u:1.015782
2025-10-18 04:39:28,050 : best model saved
2025-10-18 04:40:01,723 : epoch:1, time:33.666396, train_mse_e:0.466719, train_mse_u:0.610112, train_l2_e:0.840851, train_l2_u:1.020655, train_ssp_e:0.623787, train_ssp_u:0.903694,test_l2_e:0.636064, test_l2_u:1.001139
2025-10-18 04:40:02,377 : best model saved
2025-10-18 04:40:36,563 : epoch:2, time:34.182002, train_mse_e:0.183929, train_mse_u:0.607055, train_l2_e:0.531026, train_l2_u:1.005810, train_ssp_e:0.291867, train_ssp_u:0.942897,test_l2_e:0.493439, test_l2_u:1.001190
2025-10-18 04:41:09,954 : epoch:3, time:33.386558, train_mse_e:0.098052, train_mse_u:0.603849, train_l2_e:0.397610, train_l2_u:0.998841, train_ssp_e:0.206889, train_ssp_u:0.957697,test_l2_e:0.357780, test_l2_u:0.992328
2025-10-18 04:41:10,616 : best model s

In [None]:
pred = torch.zeros(test_u.shape)
prede = torch.zeros(test_e.shape)
index = 0
# The checkpoint is a fully pickled nn.Module (saved with torch.save(model, ...)).
# Since PyTorch 2.6, torch.load defaults to weights_only=True and rejects such
# files, so weights_only=False is passed explicitly.
# NOTE: unpickling executes arbitrary code - only load checkpoints you created yourself.
model = torch.load(path_model, map_location=device, weights_only=False).to(device)
model.eval()  # make sure BatchNorm uses running statistics with batch_size=1
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u,test_e), batch_size=1, shuffle=False)
with torch.no_grad():
    for x, y,e in test_loader:
        x, y ,e = x.to(device), y.to(device), e.to(device)

        outeta,out = model(x)
        out = out.view(1, S, S, T)
        outeta = outeta.view(1,S,S,T)
#        out = y_normalizer.decode(out)
        pred[index] = out[0]
        prede[index] = outeta[0]

        # Per-sample error: sum of the two LpLoss terms (u and e).
        test_l2 = myloss(out.view(1, -1), y.view(1, -1)).item()
        test_l2 += myloss(outeta.view(1,-1),e.view(1,-1)).item()
        #print(index, test_l2)
        logger.info('%d    %f' % (index,test_l2))
        index = index + 1

scipy.io.savemat(path+'/pred.mat', mdict={'pred': pred.cpu().numpy(), 'prede': prede.cpu().numpy(), 'test_u': test_u.cpu().numpy(), 'test_e': test_e.cpu().numpy()})

2025-10-18 20:28:42,247 : 0    0.411797
2025-10-18 20:28:42,338 : 1    0.420941
2025-10-18 20:28:42,430 : 2    0.501765
2025-10-18 20:28:42,520 : 3    0.554314
2025-10-18 20:28:42,611 : 4    0.458553
2025-10-18 20:28:42,701 : 5    0.424976
2025-10-18 20:28:42,791 : 6    0.744665
2025-10-18 20:28:42,882 : 7    0.519301
2025-10-18 20:28:42,973 : 8    0.471453
2025-10-18 20:28:43,063 : 9    0.512453
2025-10-18 20:28:43,153 : 10    0.662857
2025-10-18 20:28:43,242 : 11    0.456347
2025-10-18 20:28:43,330 : 12    0.525291
2025-10-18 20:28:43,420 : 13    0.510111
2025-10-18 20:28:43,509 : 14    0.646572
2025-10-18 20:28:43,598 : 15    0.635722
2025-10-18 20:28:43,689 : 16    0.582651
2025-10-18 20:28:43,778 : 17    0.575729
2025-10-18 20:28:43,868 : 18    0.711782
2025-10-18 20:28:43,960 : 19    0.559714
2025-10-18 20:28:44,049 : 20    0.402115
2025-10-18 20:28:44,138 : 21    0.448920
2025-10-18 20:28:44,228 : 22    0.548589
2025-10-18 20:28:44,317 : 23    0.422418
2025-10-18 20:28:44,406 : 

**Process the results**

To use the cmocean color map, you may need to install cmocean

In [None]:
# compute ssp for pred and true with shape [num,nx,ny,nt]
def get_ssp(pred,true,ax=(1,2)):
    """Spectral relative error between `pred` and `true` (NumPy version).

    Both arrays are transformed with a 2-D FFT over the axes `ax`. For every
    remaining index the ratio ||F(pred) - F(true)|| / (||F(pred)|| + ||F(true)||)
    is computed.

    Returns:
        (ssp, ssps): the mean of all ratios and the array of ratios itself.
    """
    spec_pred = np.fft.fft2(pred, axes=ax)
    spec_true = np.fft.fft2(true, axes=ax)

    def _norm(spectrum):
        # Norm over the transformed axes only.
        return np.linalg.norm(spectrum, axis=ax)

    ssps = _norm(spec_pred - spec_true) / (_norm(spec_pred) + _norm(spec_true))
    return np.mean(ssps), ssps

T_in = 16
T = 16
# pred.mat was written to `path` (the experiment directory from the config cell),
# so read it back from there. `prefix` is only the experiment-name prefix and does
# not point at that directory.
# To inspect another run without retraining, set `path` first, e.g.:
#path = 'exp/your target dir'
path_image = path+'/image'
os.makedirs(path_image, exist_ok=True)
reader = MatReader(path+'/pred.mat', to_torch=False)
true_e = reader.read_field('test_e')
true_u = reader.read_field('test_u')
pred_e = reader.read_field('prede')
pred_u = reader.read_field('pred')
# hs may be a torch tensor (read_and_split(..., to_torch=True)); use a NumPy view
# so the metrics below are plain NumPy arithmetic.
hs_np = np.asarray(hs)
# mean over axis 1,2,3 while keep axis 0, then normalized with hs
rmse_e_hs = np.mean(np.sqrt(np.mean((pred_e-true_e)**2,axis=(1,2,3)))/hs_np)
rmse_u_hs = np.mean(np.sqrt(np.mean((pred_u-true_u)**2,axis=(1,2,3)))/hs_np)
mae_e_hs = np.mean(np.mean(np.abs(pred_e-true_e),axis=(1,2,3))/hs_np)
mae_u_hs = np.mean(np.mean(np.abs(pred_u-true_u),axis=(1,2,3))/hs_np)
ssp_e,ssps_e = get_ssp(pred_e,true_e)
ssp_u,ssps_u = get_ssp(pred_u,true_u)
print('normalized by Hs->rmse_e:%f, rmse_u:%f, mae_e:%f, mae_u:%f, ssp_e:%f, ssp_u:%f' % (rmse_e_hs,rmse_u_hs,mae_e_hs,mae_u_hs,ssp_e,ssp_u)) 

# Hs-normalised MAE resolved over time, over space, and over both.
mae_e_time = np.mean(np.mean(np.abs(pred_e-true_e),axis=(1,2))/hs_np.reshape(-1,1),axis=(0))
mae_u_time = np.mean(np.mean(np.abs(pred_u-true_u),axis=(1,2))/hs_np.reshape(-1,1),axis=(0))
mae_u_space = np.mean(np.mean(np.abs(pred_u-true_u),axis=(3))/hs_np.reshape(-1,1,1),axis=(0))
mae_e_space = np.mean(np.mean(np.abs(pred_e-true_e),axis=(3))/hs_np.reshape(-1,1,1),axis=(0))
mae_u_st = np.mean(np.abs(pred_u-true_u)/hs_np.reshape(-1,1,1,1),axis=(0))
mae_e_st = np.mean(np.abs(pred_e-true_e)/hs_np.reshape(-1,1,1,1),axis=(0))

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import cmocean

def plot_and_gif(tmp_eta, x, y, name,title,cmap,clim,label,save_fig=False,save_gif=False,grid=False,size=(8,6)):
    """Draw a filled-contour map and optionally save it as a figure and/or an animated GIF.

    Args:
        tmp_eta: field to plot. Used as a whole when `save_fig` is true;
            otherwise indexed as `tmp_eta[:, :, frame]`.
        x, y: coordinates passed to `contourf`.
        name: output path without extension.
        title: plot title (the GIF frames append ' Frame <k>').
        cmap: colormap.
        clim: (min, max) used for the 100 contour levels and the 5 colorbar ticks.
        label: colorbar label.
        save_fig: save `name`.jpg, .pdf and .eps.
        save_gif: animate over the last axis of `tmp_eta` and save `name`.gif.
        grid: forwarded to `plt.grid`.
        size: figure size.
    """
    fig, ax = plt.subplots(figsize=size)
    levels = np.linspace(clim[0], clim[1], 100)

    # Static plot: the whole array for a saved figure, otherwise the first frame.
    if save_fig:
        mesh = ax.contourf(x, y, tmp_eta[:,:], cmap=cmap, levels=levels, extend='both')
        mesh.set_edgecolor("face")
    else:
        mesh = ax.contourf(x, y, tmp_eta[:,:,0], cmap=cmap, levels=levels, extend='both')

    colorbar = plt.colorbar(mesh)
    colorbar.set_label(label, fontsize=16,fontstyle='italic')
    ticks = np.linspace(clim[0], clim[1], 5)
    colorbar.set_ticks(ticks)
    colorbar.set_ticklabels([format(tick, '.2f') for tick in ticks])

    plt.title(title)
    ax.set_xlabel(r'$x$ (m)', fontsize=16, fontstyle='italic')
    ax.set_ylabel(r'$y$ (m)', fontsize=16, fontstyle='italic')
    plt.grid(grid)

    if save_fig:
        # Raster/vector exports; only the jpg and pdf get an explicit dpi.
        for suffix, extra in (('.jpg', {'dpi': 600}), ('.pdf', {'dpi': 600}), ('.eps', {})):
            plt.savefig(name+suffix, bbox_inches='tight', **extra)

    if save_gif:
        plt.show()

        def update(frame):
            # Redraw the axes from scratch for each frame.
            ax.clear()
            frame_mesh = ax.contourf(x, y, tmp_eta[:,:,frame], cmap=cmap, levels=levels, extend='both')
            ax.set_title(title+' '+f'Frame {frame+1}')
            ax.set_xlabel(r'$x$ (m)', fontsize=16, fontstyle='italic')
            ax.set_ylabel(r'$y$ (m)', fontsize=16, fontstyle='italic')
            plt.grid(grid)
            return frame_mesh,

        ani = animation.FuncAnimation(fig, update, frames=tmp_eta.shape[2], interval=400, blit=False)
        ani.save(name+'.gif', writer='pillow', fps=5)
        plt.show()
    return

In [None]:
def plot_space(tmp_eta, x, y, name,title,cmap,clim,label,save_fig=False,save_gif=False,grid=False,size=(6,4.5)):
    """Draw a single filled-contour map of `tmp_eta` and optionally save it.

    Args:
        tmp_eta: 2-D field passed to `contourf`.
        x, y: coordinates passed to `contourf`.
        name: output path without extension.
        title, cmap, label: plot title, colormap and colorbar label.
        clim: (min, max) used for the 100 contour levels and the color limits.
        save_fig: save `name`.eps, .jpg and .pdf.
        save_gif: no GIF is written here; the flag only decides whether
            `plt.show()` is called before returning.
        grid: forwarded to `plt.grid`.
        size: figure size.
    """
    fig, ax = plt.subplots(figsize=size)

    lo, hi = clim[0], clim[1]
    levels = np.linspace(lo, hi, 100)
    mesh = ax.contourf(x, y, tmp_eta, cmap=cmap, levels=levels, extend='both')
    mesh.set_edgecolor("face")
    mesh.set_clim(lo, hi)

    colorbar = plt.colorbar(mesh)
    colorbar.set_label(label, fontsize=16,fontstyle='italic')
    plt.title(title)
    ax.set_xlabel(r'$x$ (m)', fontsize=16, fontstyle='italic')
    ax.set_ylabel(r'$y$ (m)', fontsize=16, fontstyle='italic')
    plt.grid(grid)

    if save_fig:
        # Only the jpg and pdf exports get an explicit dpi.
        for suffix, extra in (('.eps', {}), ('.jpg', {'dpi': 600}), ('.pdf', {'dpi': 600})):
            plt.savefig(name+suffix, bbox_inches='tight', **extra)

    if save_gif:
        plt.show()
    return

In [None]:
def show_video_line(data, ncols, vmax=0.6, vmin=0.0, cmap='gray',startnum = 0,norm=None, cbar=False, format='png', out_path=None, use_rgb=False):
    """Plot the first `ncols` frames of `data` side by side as filled-contour maps.

    Args:
        data: array indexed as `data[:, :, t]`. Arrays with more than 3 axes
            first have axis 1 moved to the end.
        ncols: number of frames / subplot columns.
        vmax, vmin: color limits; also the range of the 100 contour levels.
        cmap, norm: forwarded to `contourf`.
        startnum: offset added to the frame number shown in each subplot title.
        cbar: draw a shared colorbar with 5 ticks (only when `ncols > 1`).
        format: file format used when saving.
        out_path: if given, the figure is saved there.
        use_rgb: unused; kept for interface compatibility.
    """
    fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(3.25 * ncols, 3))
    plt.subplots_adjust(wspace=0.1, hspace=0)
    levels = np.linspace(vmin, vmax, 100)
    if len(data.shape) > 3:
        data = data.swapaxes(1,2).swapaxes(2,3)

    images = []
    if ncols == 1:
        # With a single column, plt.subplots returns one Axes instead of an array.
        im = axes.contourf(data[:,:,0], cmap=cmap, norm=norm,levels=levels)
        im.set_edgecolor("face")
        images.append(im)
        axes.axis('off')
        im.set_clim(vmin, vmax)
    else:
        for t, ax in enumerate(axes.flat):
            im = ax.contourf(data[:,:,t], cmap=cmap, norm=norm,levels=levels,extend='both')
            im.set_edgecolor("face")
            images.append(im)
            ax.axis('off')
            im.set_clim(vmin, vmax)
            ax.title.set_text(f'Frame {startnum+t+1}')
            ax.title.set_fontsize(25)

    # The tick setup used to run unconditionally on `cbar`, which is still the
    # boolean flag when no colorbar is created (cbar=False or ncols == 1), so it
    # raised AttributeError. Only configure a colorbar that actually exists.
    if cbar and ncols > 1:
        cbaxes = fig.add_axes([0.91, 0.15, 0.04 / ncols, 0.7]) 
        colorbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.1, cax=cbaxes)

        ticks = np.linspace(vmin, vmax, 5)  
        colorbar.set_ticks(ticks)
        colorbar.ax.tick_params(labelsize=20)  # font size of the colorbar labels
        colorbar.set_ticklabels([f'{tick:.2f}' for tick in ticks])

    # Save before showing: some backends release the figure once it has been shown.
    if out_path is not None:
        fig.savefig(out_path, format=format, pad_inches=0, bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [None]:
T_in = 16
T = 16
# pred.mat was written to `path` (the experiment directory from the config cell),
# so read it back from there. `prefix` is only the experiment-name prefix and does
# not point at that directory.
# To check results only, recover x_domain, y_domain, t_domain and hs from the data
# reading part and point `path` at the run you want to inspect, e.g.:
#path = 'exp/your target dir'
path_image = path+'/image'
os.makedirs(path_image, exist_ok=True)
reader = MatReader(path+'/pred.mat', to_torch=False)
true_e = reader.read_field('test_e')
true_u = reader.read_field('test_u')
pred_e = reader.read_field('prede')
pred_u = reader.read_field('pred')
# hs may be a torch tensor (read_and_split(..., to_torch=True)); use a NumPy view
# so the metrics below are plain NumPy arithmetic.
hs_np = np.asarray(hs)
# mean over axis 1,2,3 while keep axis 0, then normalized with hs
rmse_e_hs = np.mean(np.sqrt(np.mean((pred_e-true_e)**2,axis=(1,2,3)))/hs_np)
rmse_u_hs = np.mean(np.sqrt(np.mean((pred_u-true_u)**2,axis=(1,2,3)))/hs_np)
mae_e_hs = np.mean(np.mean(np.abs(pred_e-true_e),axis=(1,2,3))/hs_np)
mae_u_hs = np.mean(np.mean(np.abs(pred_u-true_u),axis=(1,2,3))/hs_np)
ssp_e,ssps_e = get_ssp(pred_e,true_e)
ssp_u,ssps_u = get_ssp(pred_u,true_u)
print('normalized by Hs->rmse_e:%f, rmse_u:%f, mae_e:%f, mae_u:%f, ssp_e:%f, ssp_u:%f' % (rmse_e_hs,rmse_u_hs,mae_e_hs,mae_u_hs,ssp_e,ssp_u)) 

# Hs-normalised MAE resolved over time, over space, and over both.
mae_e_time = np.mean(np.mean(np.abs(pred_e-true_e),axis=(1,2))/hs_np.reshape(-1,1),axis=(0))
mae_u_time = np.mean(np.mean(np.abs(pred_u-true_u),axis=(1,2))/hs_np.reshape(-1,1),axis=(0))
mae_u_space = np.mean(np.mean(np.abs(pred_u-true_u),axis=(3))/hs_np.reshape(-1,1,1),axis=(0))
mae_e_space = np.mean(np.mean(np.abs(pred_e-true_e),axis=(3))/hs_np.reshape(-1,1,1),axis=(0))
mae_u_st = np.mean(np.abs(pred_u-true_u)/hs_np.reshape(-1,1,1,1),axis=(0))
mae_e_st = np.mean(np.abs(pred_e-true_e)/hs_np.reshape(-1,1,1,1),axis=(0))

# plot mae with time and space
plt.figure(figsize=(6,4.5))
plt.title('MAE for prediction and reconstruction',fontsize=14, fontstyle='normal')
plt.plot(t_domain[T_in:T_in+T],mae_u_time,label='predicted elevation',marker='o')
plt.plot(t_domain[:T_in],mae_e_time,label='reconstructed elevation',marker='*')
plt.legend(fontsize=10,loc='upper left')
plt.xlabel('time(s)',fontsize=16, fontstyle='italic')
plt.ylabel('$MAE/H_s$',fontsize=16, fontstyle='italic')
plt.xlim([0,32])
#plt.ylim([0.02,0.035])
# save eps ,tight layout, high dpi
#plt.savefig(path_image+'/mae_time.jpg',dpi=300,bbox_inches = 'tight')
#plt.savefig(path_image+'/mae_time.pdf',dpi=600,bbox_inches = 'tight')
plt.show()

#plt.figure()
# plot mae in 2d meshgrid (x_domain,y_domain)
sub =2
x = x_domain[::sub]
y = y_domain[::sub]
plot_and_gif(mae_u_space,x,y,path_image+'/mae_u_space','MAE for prediction',cmap=cmocean.cm.balance,clim=[0,0.2],label=r'$MAE/H_s$',save_fig=True,save_gif=False,grid=False)
plot_and_gif(mae_e_space,x,y,path_image+'/mae_e_space','MAE for reconstrution',cmap=cmocean.cm.balance,clim=[0,0.08],label=r'$MAE/H_s$',save_fig=True,save_gif=False,grid=False)
#plot mae in 2d meshgrid with time 
# Each strip shows half of the frames; integer division keeps ncols an int.
show_video_line(mae_u_st[:,:,:8],ncols=mae_u_st.shape[2]//2,vmax=0.2,vmin=0.0,cmap=cmocean.cm.balance,startnum=0,cbar=True,format='pdf',out_path=path_image+'/mae_u_st1.pdf',use_rgb=False)
show_video_line(mae_u_st[:,:,8:],ncols=mae_u_st.shape[2]//2,vmax=0.2,vmin=0.0,cmap=cmocean.cm.balance,startnum=8,cbar=True,format='pdf',out_path=path_image+'/mae_u_st2.pdf',use_rgb=False)
show_video_line(mae_e_st[:,:,:8],ncols=mae_e_st.shape[2]//2,vmax=0.1,vmin=0.0,cmap=cmocean.cm.balance,startnum=0,cbar=True,format='pdf',out_path=path_image+'/mae_e_st1.pdf',use_rgb=False)
show_video_line(mae_e_st[:,:,8:],ncols=mae_e_st.shape[2]//2,vmax=0.1,vmin=0.0,cmap=cmocean.cm.balance,startnum=8,cbar=True,format='pdf',out_path=path_image+'/mae_e_st2.pdf',use_rgb=False)

plot_and_gif(mae_e_st,x,y,path_image+'/mae_r_st','MAE for reconstruction',cmap=cmocean.cm.balance,clim=[0,0.1],label=r'$MAE/H_s$',save_fig=False,save_gif=True,grid=False)
plot_and_gif(mae_u_st,x,y,path_image+'/mae_u_st','MAE for prediction',cmap=cmocean.cm.balance,clim=[0,0.2],label=r'$MAE/H_s$',save_fig=False,save_gif=True,grid=False)

normalized by Hs->rmse_e:0.045343, rmse_u:0.082862, mae_e:0.034548, mae_u:0.052706, ssp_e:0.092759, ssp_u:0.169024


NameError: name 'plt' is not defined